diff --git a/tensorflow/lite/interpreter_test.cc b/tensorflow/lite/interpreter_test.cc index 3c8d72866cd..e25c038967c 100644 --- a/tensorflow/lite/interpreter_test.cc +++ b/tensorflow/lite/interpreter_test.cc @@ -1162,9 +1162,16 @@ class TestDelegate : public ::testing::Test { [](TfLiteContext* context, TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, TfLiteTensor* output) -> TfLiteStatus { - // TODO(ycling): Implement tests to test buffer copying logic. + TFLITE_CHECK_GE(buffer_handle, -1); + TFLITE_CHECK_EQ(output->buffer_handle, buffer_handle); + const float floats[] = {6., 6., 6.}; + int num = output->dims->data[0]; + for (int i = 0; i < num; i++) { + output->data.f[i] = floats[i]; + } return kTfLiteOk; }; + delegate_.FreeBufferHandle = [](TfLiteContext* context, TfLiteDelegate* delegate, TfLiteBufferHandle* handle) { *handle = kTfLiteNullBufferHandle; }; @@ -1176,6 +1183,21 @@ class TestDelegate : public ::testing::Test { static TfLiteRegistration FakeFusedRegistration() { TfLiteRegistration reg = {nullptr}; reg.custom_name = "fake_fused_op"; + + reg.invoke = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + // Copy input data to output data. + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* out = &context->tensors[node->outputs->data[0]]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + out->data.f[i] = a0->data.f[i] + a1->data.f[i]; + } + // Make the data stale so that CopyFromBufferHandle can be invoked + out->data_is_stale = true; + return kTfLiteOk; + }; return reg; } @@ -1332,6 +1354,64 @@ TEST_F(TestDelegate, ResizeInputWithNonDynamicDelegateShouldFail) { ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 2}), kTfLiteError); } +TEST_F(TestDelegate, TestCopyFromBufferInvoke) { + delegate_ = std::unique_ptr(new SimpleDelegate({0, 1, 2})); + TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); + interpreter_->ModifyGraphWithDelegate(delegate); + + constexpr int kOutputTensorIndex = 3; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + std::vector floats = {1.0f, 2.0f, 3.0f}; + memcpy(interpreter_->typed_tensor(0), floats.data(), + floats.size() * sizeof(float)); + + memcpy(interpreter_->typed_tensor(1), floats.data(), + floats.size() * sizeof(float)); + + // Before setting the buffer handle, the tensor's `delegate` is already set + // because it will be written by the delegate. + ASSERT_EQ(tensor->delegate, delegate); + ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); + + // Called Invoke without setting the buffer will not call the CopyFromBuffer + interpreter_->Invoke(); + std::vector res = {2.0f, 4.0f, 6.0f}; + for (int i = 0; i < tensor->dims->data[0]; ++i) { + ASSERT_EQ(tensor->data.f[i], res[i]); + } +} + +TEST_F(TestDelegate, TestCopyFromBuffer) { + delegate_ = std::unique_ptr(new SimpleDelegate({0, 1, 2})); + TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); + interpreter_->ModifyGraphWithDelegate(delegate); + + constexpr int kOutputTensorIndex = 3; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + std::vector floats = {1.0f, 2.0f, 3.0f}; + memcpy(interpreter_->typed_tensor(0), floats.data(), + floats.size() * sizeof(float)); + + memcpy(interpreter_->typed_tensor(1), floats.data(), + floats.size() * sizeof(float)); + + // Before setting the buffer handle, the tensor's `delegate` is already set + // because it will be written by the delegate. + ASSERT_EQ(tensor->delegate, delegate); + ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); + + TfLiteBufferHandle handle = AllocateBufferHandle(); + TfLiteStatus status = + interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate); + interpreter_->Invoke(); + ASSERT_EQ(status, kTfLiteOk); + EXPECT_EQ(tensor->delegate, delegate); + EXPECT_EQ(tensor->buffer_handle, handle); + for (int i = 0; i < tensor->dims->data[0]; ++i) { + ASSERT_EQ(tensor->data.f[i], 6.0f); + } +} + class TestDelegateWithDynamicTensors : public ::testing::Test { protected: void SetUp() override {