From 7e6061037f87dacc545d834952883a08842d92eb Mon Sep 17 00:00:00 2001 From: Terry Heo Date: Thu, 24 Sep 2020 18:37:16 -0700 Subject: [PATCH] FlexDelegate: Provide constant tensors to ShapeInferenceFn Some op (such as RFFT) requires a constant tensor to calculate the exact output tensor shape. PiperOrigin-RevId: 333643286 Change-Id: I826ab115551ad805aeaccb132dd6a2b8f133e578 --- tensorflow/lite/delegates/flex/buffer_map.cc | 5 ++++ tensorflow/lite/delegates/flex/buffer_map.h | 5 ++++ .../lite/delegates/flex/delegate_test.cc | 26 +++++++++++++++++++ tensorflow/lite/delegates/flex/kernel.cc | 7 +++++ tensorflow/lite/delegates/flex/test_util.cc | 16 ++++++++++++ tensorflow/lite/delegates/flex/test_util.h | 7 +++++ 6 files changed, 66 insertions(+) diff --git a/tensorflow/lite/delegates/flex/buffer_map.cc b/tensorflow/lite/delegates/flex/buffer_map.cc index c2611290c1b..86ea4b849ea 100644 --- a/tensorflow/lite/delegates/flex/buffer_map.cc +++ b/tensorflow/lite/delegates/flex/buffer_map.cc @@ -149,6 +149,11 @@ tensorflow::Tensor BufferMap::GetTensor(int tensor_index) const { return id_to_tensor_.at(tensor_index); } +const tensorflow::Tensor* BufferMap::GetTensorPtr(int tensor_index) const { + auto& tensor = id_to_tensor_.at(tensor_index); + return &tensor; +} + void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) { tensorflow::TensorShape shape; int num_dims = tensor->dims->size; diff --git a/tensorflow/lite/delegates/flex/buffer_map.h b/tensorflow/lite/delegates/flex/buffer_map.h index 6c35895c249..6a29c7f80dc 100644 --- a/tensorflow/lite/delegates/flex/buffer_map.h +++ b/tensorflow/lite/delegates/flex/buffer_map.h @@ -47,6 +47,11 @@ class BufferMap { // Precondition: HasTensor() is true. tensorflow::Tensor GetTensor(int tensor_index) const; + // Returns the const pointer to tensorflow::Tensor associated with the given + // 'tensor_index'. + // Precondition: HasTensor() is true. + const tensorflow::Tensor* GetTensorPtr(int tensor_index) const; + // Associates the given tensorflow::Tensor with the given 'tensor_index'. // Note that TensorFlow Tensors share data buffers, so this method is only a // shallow copy. diff --git a/tensorflow/lite/delegates/flex/delegate_test.cc b/tensorflow/lite/delegates/flex/delegate_test.cc index 02ad4201307..e50ea1d9b0e 100644 --- a/tensorflow/lite/delegates/flex/delegate_test.cc +++ b/tensorflow/lite/delegates/flex/delegate_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/flex/delegate.h" +#include + #include #include #include "tensorflow/lite/delegates/flex/test_util.h" @@ -343,6 +345,30 @@ TEST_F(DelegateTest, StaticOutput) { ASSERT_FALSE(IsDynamicTensor(6)); } +TEST_F(DelegateTest, StaticOutputRFFT) { + // Define the graph with input, output shapes of [3, 257]. + AddTensors(4, {0, 1}, {3}, kTfLiteFloat32, {3, 257}); + int32_t rfft_length[] = {512}; + SetConstTensor(1, {1}, kTfLiteInt32, + reinterpret_cast(&rfft_length), + sizeof(rfft_length)); + + AddTfOp(testing::kRfft, {0, 1}, {2}); + AddTfOp(testing::kImag, {2}, {3}); + + // Apply the delegate. + ConfigureDelegate(); + + // Define inputs. + SetShape(0, {3, 512}); + + ASSERT_TRUE(Invoke()); + + ASSERT_EQ(GetType(3), kTfLiteFloat32); + // Since shapes are consistent, static output tensor is used. + ASSERT_FALSE(IsDynamicTensor(3)); +} + TEST_F(DelegateTest, DynamicOutputAfterReshape) { // Define the graph. AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); diff --git a/tensorflow/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc index 9674ee7b7f1..f21c984fe3e 100644 --- a/tensorflow/lite/delegates/flex/kernel.cc +++ b/tensorflow/lite/delegates/flex/kernel.cc @@ -533,6 +533,12 @@ TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency( for (int i = 0; i < num_inputs; ++i) { const auto input_tensor_index = node_data->inputs().TfLiteIndex(i); TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index]; + // Provide constant input tensors since some op ("RFFT") needs it to + // calculate the output shape. + if (IsConstantTensor(tfl_tensor)) { + input_tensors_vector[i] = + op_data_->buffer_map->GetTensorPtr(input_tensor_index); + } const auto dims_array = tfl_tensor->dims; std::vector dims(dims_array->size); for (int j = 0; j < dims_array->size; ++j) { @@ -540,6 +546,7 @@ TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency( } c.SetInput(i, c.MakeShape(dims)); } + c.set_input_tensors(input_tensors_vector); tensorflow::Status status = c.construction_status(); if (!status.ok()) { diff --git a/tensorflow/lite/delegates/flex/test_util.cc b/tensorflow/lite/delegates/flex/test_util.cc index fd566034b3d..02685aa0502 100644 --- a/tensorflow/lite/delegates/flex/test_util.cc +++ b/tensorflow/lite/delegates/flex/test_util.cc @@ -92,6 +92,18 @@ void FlexModelTest::AddTensors(int num_tensors, const std::vector& inputs, CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk); } +void FlexModelTest::SetConstTensor(int tensor_index, + const std::vector& values, + TfLiteType type, const char* buffer, + size_t bytes) { + TfLiteQuantizationParams quant; + CHECK_EQ(interpreter_->SetTensorParametersReadOnly(tensor_index, type, + /*name=*/"", + /*dims=*/values, quant, + buffer, bytes), + kTfLiteOk); +} + void FlexModelTest::AddTfLiteMulOp(const std::vector& inputs, const std::vector& outputs) { ++next_op_index_; @@ -158,6 +170,10 @@ void FlexModelTest::AddTfOp(TfOpType op, const std::vector& inputs, } else if (op == kMul) { string attributes = type_attribute; AddTfOp("FlexMul", "Mul", attributes, inputs, outputs); + } else if (op == kRfft) { + AddTfOp("FlexRFFT", "RFFT", "", inputs, outputs); + } else if (op == kImag) { + AddTfOp("FlexImag", "Imag", "", inputs, outputs); } else if (op == kNonExistent) { AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs); } else if (op == kIncompatibleNodeDef) { diff --git a/tensorflow/lite/delegates/flex/test_util.h b/tensorflow/lite/delegates/flex/test_util.h index bc74d8578a2..c00adbfe9b3 100644 --- a/tensorflow/lite/delegates/flex/test_util.h +++ b/tensorflow/lite/delegates/flex/test_util.h @@ -28,6 +28,8 @@ enum TfOpType { kIdentity, kAdd, kMul, + kRfft, + kImag, // Represents an op that does not exist in TensorFlow. kNonExistent, // Represents an valid TensorFlow op where the NodeDef is incompatible. @@ -92,6 +94,11 @@ class FlexModelTest : public ::testing::Test { const std::vector& outputs, TfLiteType type, const std::vector& dims); + // Set a constant tensor of the given shape, type and buffer at the given + // index. + void SetConstTensor(int tensor_index, const std::vector& values, + TfLiteType type, const char* buffer, size_t bytes); + // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors // and `outputs` contains the indices of the output tensors. void AddTfLiteMulOp(const std::vector& inputs,