FlexDelegate: Provide constant tensors to ShapeInferenceFn

Some op (such as RFFT) requires a constant tensor to calculate the exact
output tensor shape.

PiperOrigin-RevId: 333643286
Change-Id: I826ab115551ad805aeaccb132dd6a2b8f133e578
This commit is contained in:
Terry Heo 2020-09-24 18:37:16 -07:00 committed by TensorFlower Gardener
parent 8dbab4f830
commit 7e6061037f
6 changed files with 66 additions and 0 deletions

View File

@ -149,6 +149,11 @@ tensorflow::Tensor BufferMap::GetTensor(int tensor_index) const {
return id_to_tensor_.at(tensor_index); return id_to_tensor_.at(tensor_index);
} }
const tensorflow::Tensor* BufferMap::GetTensorPtr(int tensor_index) const {
auto& tensor = id_to_tensor_.at(tensor_index);
return &tensor;
}
void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) { void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) {
tensorflow::TensorShape shape; tensorflow::TensorShape shape;
int num_dims = tensor->dims->size; int num_dims = tensor->dims->size;

View File

@ -47,6 +47,11 @@ class BufferMap {
// Precondition: HasTensor() is true. // Precondition: HasTensor() is true.
tensorflow::Tensor GetTensor(int tensor_index) const; tensorflow::Tensor GetTensor(int tensor_index) const;
// Returns the const pointer to tensorflow::Tensor associated with the given
// 'tensor_index'.
// Precondition: HasTensor() is true.
const tensorflow::Tensor* GetTensorPtr(int tensor_index) const;
// Associates the given tensorflow::Tensor with the given 'tensor_index'. // Associates the given tensorflow::Tensor with the given 'tensor_index'.
// Note that TensorFlow Tensors share data buffers, so this method is only a // Note that TensorFlow Tensors share data buffers, so this method is only a
// shallow copy. // shallow copy.

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate.h" #include "tensorflow/lite/delegates/flex/delegate.h"
#include <cstdint>
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "tensorflow/lite/delegates/flex/test_util.h" #include "tensorflow/lite/delegates/flex/test_util.h"
@ -343,6 +345,30 @@ TEST_F(DelegateTest, StaticOutput) {
ASSERT_FALSE(IsDynamicTensor(6)); ASSERT_FALSE(IsDynamicTensor(6));
} }
TEST_F(DelegateTest, StaticOutputRFFT) {
// Define the graph with input, output shapes of [3, 257].
AddTensors(4, {0, 1}, {3}, kTfLiteFloat32, {3, 257});
int32_t rfft_length[] = {512};
SetConstTensor(1, {1}, kTfLiteInt32,
reinterpret_cast<const char*>(&rfft_length),
sizeof(rfft_length));
AddTfOp(testing::kRfft, {0, 1}, {2});
AddTfOp(testing::kImag, {2}, {3});
// Apply the delegate.
ConfigureDelegate();
// Define inputs.
SetShape(0, {3, 512});
ASSERT_TRUE(Invoke());
ASSERT_EQ(GetType(3), kTfLiteFloat32);
// Since shapes are consistent, static output tensor is used.
ASSERT_FALSE(IsDynamicTensor(3));
}
TEST_F(DelegateTest, DynamicOutputAfterReshape) { TEST_F(DelegateTest, DynamicOutputAfterReshape) {
// Define the graph. // Define the graph.
AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});

View File

@ -533,6 +533,12 @@ TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency(
for (int i = 0; i < num_inputs; ++i) { for (int i = 0; i < num_inputs; ++i) {
const auto input_tensor_index = node_data->inputs().TfLiteIndex(i); const auto input_tensor_index = node_data->inputs().TfLiteIndex(i);
TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index]; TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index];
// Provide constant input tensors since some op ("RFFT") needs it to
// calculate the output shape.
if (IsConstantTensor(tfl_tensor)) {
input_tensors_vector[i] =
op_data_->buffer_map->GetTensorPtr(input_tensor_index);
}
const auto dims_array = tfl_tensor->dims; const auto dims_array = tfl_tensor->dims;
std::vector<DimensionHandle> dims(dims_array->size); std::vector<DimensionHandle> dims(dims_array->size);
for (int j = 0; j < dims_array->size; ++j) { for (int j = 0; j < dims_array->size; ++j) {
@ -540,6 +546,7 @@ TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency(
} }
c.SetInput(i, c.MakeShape(dims)); c.SetInput(i, c.MakeShape(dims));
} }
c.set_input_tensors(input_tensors_vector);
tensorflow::Status status = c.construction_status(); tensorflow::Status status = c.construction_status();
if (!status.ok()) { if (!status.ok()) {

View File

@ -92,6 +92,18 @@ void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk); CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
} }
void FlexModelTest::SetConstTensor(int tensor_index,
const std::vector<int>& values,
TfLiteType type, const char* buffer,
size_t bytes) {
TfLiteQuantizationParams quant;
CHECK_EQ(interpreter_->SetTensorParametersReadOnly(tensor_index, type,
/*name=*/"",
/*dims=*/values, quant,
buffer, bytes),
kTfLiteOk);
}
void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs, void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
const std::vector<int>& outputs) { const std::vector<int>& outputs) {
++next_op_index_; ++next_op_index_;
@ -158,6 +170,10 @@ void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
} else if (op == kMul) { } else if (op == kMul) {
string attributes = type_attribute; string attributes = type_attribute;
AddTfOp("FlexMul", "Mul", attributes, inputs, outputs); AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
} else if (op == kRfft) {
AddTfOp("FlexRFFT", "RFFT", "", inputs, outputs);
} else if (op == kImag) {
AddTfOp("FlexImag", "Imag", "", inputs, outputs);
} else if (op == kNonExistent) { } else if (op == kNonExistent) {
AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs); AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
} else if (op == kIncompatibleNodeDef) { } else if (op == kIncompatibleNodeDef) {

View File

@ -28,6 +28,8 @@ enum TfOpType {
kIdentity, kIdentity,
kAdd, kAdd,
kMul, kMul,
kRfft,
kImag,
// Represents an op that does not exist in TensorFlow. // Represents an op that does not exist in TensorFlow.
kNonExistent, kNonExistent,
// Represents an valid TensorFlow op where the NodeDef is incompatible. // Represents an valid TensorFlow op where the NodeDef is incompatible.
@ -92,6 +94,11 @@ class FlexModelTest : public ::testing::Test {
const std::vector<int>& outputs, TfLiteType type, const std::vector<int>& outputs, TfLiteType type,
const std::vector<int>& dims); const std::vector<int>& dims);
// Set a constant tensor of the given shape, type and buffer at the given
// index.
void SetConstTensor(int tensor_index, const std::vector<int>& values,
TfLiteType type, const char* buffer, size_t bytes);
// Adds a TFLite Mul op. `inputs` contains the indices of the input tensors // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors
// and `outputs` contains the indices of the output tensors. // and `outputs` contains the indices of the output tensors.
void AddTfLiteMulOp(const std::vector<int>& inputs, void AddTfLiteMulOp(const std::vector<int>& inputs,