FlexDelegate: Provide constant tensors to ShapeInferenceFn

Some ops (such as RFFT) require a constant input tensor to calculate the exact
output tensor shape.

PiperOrigin-RevId: 333643286
Change-Id: I826ab115551ad805aeaccb132dd6a2b8f133e578
Authored by Terry Heo on 2020-09-24 18:37:16 -07:00, committed by TensorFlower Gardener
parent 8dbab4f830
commit 7e6061037f
6 changed files with 66 additions and 0 deletions
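
For context on why the constant tensor matters: RFFT's output shape depends on the value of its fft_length input, not just on its shape, because the last output dimension is fft_length / 2 + 1. The short helper below is an illustrative sketch of that rule (the function name and signature are made up for this note and are not part of the commit); without the constant fft_length at shape-inference time, this dimension stays unknown and the delegate has to fall back to a dynamic output tensor.

#include <cstdint>
#include <vector>

// Illustrative only: mirrors the shape rule an RFFT shape function can apply
// once the fft_length input is available as a constant tensor.
std::vector<int64_t> RfftStaticOutputShape(std::vector<int64_t> input_shape,
                                           int64_t fft_length) {
  // A real-input FFT of length N produces N / 2 + 1 frequency bins.
  input_shape.back() = fft_length / 2 + 1;  // e.g. 512 -> 257
  return input_shape;
}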

@@ -149,6 +149,11 @@ tensorflow::Tensor BufferMap::GetTensor(int tensor_index) const {
  return id_to_tensor_.at(tensor_index);
}

const tensorflow::Tensor* BufferMap::GetTensorPtr(int tensor_index) const {
  auto& tensor = id_to_tensor_.at(tensor_index);
  return &tensor;
}

void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) {
  tensorflow::TensorShape shape;
  int num_dims = tensor->dims->size;

@@ -47,6 +47,11 @@ class BufferMap {
  // Precondition: HasTensor() is true.
  tensorflow::Tensor GetTensor(int tensor_index) const;

  // Returns a const pointer to the tensorflow::Tensor associated with the
  // given 'tensor_index'.
  // Precondition: HasTensor() is true.
  const tensorflow::Tensor* GetTensorPtr(int tensor_index) const;

  // Associates the given tensorflow::Tensor with the given 'tensor_index'.
  // Note that TensorFlow Tensors share data buffers, so this method is only a
  // shallow copy.
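
A minimal usage sketch of the new accessor (hypothetical caller code, not part of the commit): unlike GetTensor(), which returns the tensorflow::Tensor by value, GetTensorPtr() hands out a pointer to the entry owned by the BufferMap, so the tensor can be passed to shape inference without copying.

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/lite/delegates/flex/buffer_map.h"

// Hypothetical helper: look up a tensor that may or may not be in the map.
const tensorflow::Tensor* LookupTensorPtr(
    const tflite::flex::BufferMap& buffer_map, int tensor_index) {
  // GetTensorPtr() requires HasTensor() to be true, so check first.
  return buffer_map.HasTensor(tensor_index)
             ? buffer_map.GetTensorPtr(tensor_index)
             : nullptr;
}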

@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate.h"

#include <cstdint>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/flex/test_util.h"
@@ -343,6 +345,30 @@ TEST_F(DelegateTest, StaticOutput) {
  ASSERT_FALSE(IsDynamicTensor(6));
}

TEST_F(DelegateTest, StaticOutputRFFT) {
  // Define the graph with input, output shapes of [3, 257].
  AddTensors(4, {0, 1}, {3}, kTfLiteFloat32, {3, 257});
  int32_t rfft_length[] = {512};
  SetConstTensor(1, {1}, kTfLiteInt32,
                 reinterpret_cast<const char*>(&rfft_length),
                 sizeof(rfft_length));
  AddTfOp(testing::kRfft, {0, 1}, {2});
  AddTfOp(testing::kImag, {2}, {3});

  // Apply the delegate.
  ConfigureDelegate();

  // Define inputs.
  SetShape(0, {3, 512});

  ASSERT_TRUE(Invoke());

  ASSERT_EQ(GetType(3), kTfLiteFloat32);
  // Since shapes are consistent, static output tensor is used.
  ASSERT_FALSE(IsDynamicTensor(3));
}

TEST_F(DelegateTest, DynamicOutputAfterReshape) {
  // Define the graph.
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
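
For reference, the numbers in StaticOutputRFFT line up as follows: with a constant fft_length of 512 and the input reshaped to [3, 512], RFFT produces 512 / 2 + 1 = 257 frequency bins, matching the [3, 257] shape the tensors were declared with. Because the constant fft_length tensor is now visible to the shape function, the inferred shape agrees with the declared one and the output can stay a static tensor, which is exactly what ASSERT_FALSE(IsDynamicTensor(3)) verifies.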

@@ -533,6 +533,12 @@ TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency(
  for (int i = 0; i < num_inputs; ++i) {
    const auto input_tensor_index = node_data->inputs().TfLiteIndex(i);
    TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index];
    // Provide constant input tensors since some ops (e.g. "RFFT") need them to
    // calculate the output shape.
    if (IsConstantTensor(tfl_tensor)) {
      input_tensors_vector[i] =
          op_data_->buffer_map->GetTensorPtr(input_tensor_index);
    }
    const auto dims_array = tfl_tensor->dims;
    std::vector<DimensionHandle> dims(dims_array->size);
    for (int j = 0; j < dims_array->size; ++j) {
@@ -540,6 +546,7 @@
    }
    c.SetInput(i, c.MakeShape(dims));
  }
  c.set_input_tensors(input_tensors_vector);

  tensorflow::Status status = c.construction_status();
  if (!status.ok()) {
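
The hunk above is the heart of the change: constant inputs are fetched from the BufferMap and handed to the shape-inference context through c.set_input_tensors(), while non-constant inputs stay nullptr so only their shapes are known. A standalone sketch of that pattern follows; the helper name and the explicit kTfLiteMmapRo check (standing in for the commit's IsConstantTensor() call) are illustrative, not the commit's code.

#include <cstddef>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/flex/buffer_map.h"

// Hypothetical sketch: collect constant inputs so a shape function (e.g.
// RFFT's) can read their values; non-constant inputs remain nullptr and only
// their shapes are visible to shape inference.
std::vector<const tensorflow::Tensor*> CollectConstantInputs(
    TfLiteContext* context, const std::vector<int>& input_indices,
    const tflite::flex::BufferMap& buffer_map) {
  std::vector<const tensorflow::Tensor*> tensors(input_indices.size(), nullptr);
  for (size_t i = 0; i < input_indices.size(); ++i) {
    const TfLiteTensor* tfl_tensor = &context->tensors[input_indices[i]];
    // kTfLiteMmapRo marks read-only (constant) tensors in TFLite.
    if (tfl_tensor->allocation_type == kTfLiteMmapRo &&
        buffer_map.HasTensor(input_indices[i])) {
      tensors[i] = buffer_map.GetTensorPtr(input_indices[i]);
    }
  }
  return tensors;
}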

@@ -92,6 +92,18 @@ void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
  CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
}

void FlexModelTest::SetConstTensor(int tensor_index,
                                   const std::vector<int>& values,
                                   TfLiteType type, const char* buffer,
                                   size_t bytes) {
  TfLiteQuantizationParams quant;
  CHECK_EQ(interpreter_->SetTensorParametersReadOnly(tensor_index, type,
                                                     /*name=*/"",
                                                     /*dims=*/values, quant,
                                                     buffer, bytes),
           kTfLiteOk);
}

void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
                                   const std::vector<int>& outputs) {
  ++next_op_index_;
@@ -158,6 +170,10 @@ void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
  } else if (op == kMul) {
    string attributes = type_attribute;
    AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
  } else if (op == kRfft) {
    AddTfOp("FlexRFFT", "RFFT", "", inputs, outputs);
  } else if (op == kImag) {
    AddTfOp("FlexImag", "Imag", "", inputs, outputs);
  } else if (op == kNonExistent) {
    AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
  } else if (op == kIncompatibleNodeDef) {
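
Worth noting: SetTensorParametersReadOnly() is what makes the fft_length tensor a constant from TFLite's point of view, since it gives the tensor a kTfLiteMmapRo allocation backed by the provided buffer. That is the property IsConstantTensor() checks in the kernel change above before fetching the corresponding tensorflow::Tensor from the BufferMap.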

@@ -28,6 +28,8 @@ enum TfOpType {
  kIdentity,
  kAdd,
  kMul,
  kRfft,
  kImag,
  // Represents an op that does not exist in TensorFlow.
  kNonExistent,
  // Represents a valid TensorFlow op where the NodeDef is incompatible.
@@ -92,6 +94,11 @@ class FlexModelTest : public ::testing::Test {
                  const std::vector<int>& outputs, TfLiteType type,
                  const std::vector<int>& dims);

  // Sets a constant tensor of the given shape, type and buffer at the given
  // index.
  void SetConstTensor(int tensor_index, const std::vector<int>& values,
                      TfLiteType type, const char* buffer, size_t bytes);

  // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors
  // and `outputs` contains the indices of the output tensors.
  void AddTfLiteMulOp(const std::vector<int>& inputs,