FlexDelegate: Provide constant tensors to ShapeInferenceFn
Some op (such as RFFT) requires a constant tensor to calculate the exact output tensor shape. PiperOrigin-RevId: 333643286 Change-Id: I826ab115551ad805aeaccb132dd6a2b8f133e578
This commit is contained in:
parent
8dbab4f830
commit
7e6061037f
@ -149,6 +149,11 @@ tensorflow::Tensor BufferMap::GetTensor(int tensor_index) const {
|
|||||||
return id_to_tensor_.at(tensor_index);
|
return id_to_tensor_.at(tensor_index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const tensorflow::Tensor* BufferMap::GetTensorPtr(int tensor_index) const {
|
||||||
|
auto& tensor = id_to_tensor_.at(tensor_index);
|
||||||
|
return &tensor;
|
||||||
|
}
|
||||||
|
|
||||||
void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) {
|
void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) {
|
||||||
tensorflow::TensorShape shape;
|
tensorflow::TensorShape shape;
|
||||||
int num_dims = tensor->dims->size;
|
int num_dims = tensor->dims->size;
|
||||||
|
@ -47,6 +47,11 @@ class BufferMap {
|
|||||||
// Precondition: HasTensor() is true.
|
// Precondition: HasTensor() is true.
|
||||||
tensorflow::Tensor GetTensor(int tensor_index) const;
|
tensorflow::Tensor GetTensor(int tensor_index) const;
|
||||||
|
|
||||||
|
// Returns the const pointer to tensorflow::Tensor associated with the given
|
||||||
|
// 'tensor_index'.
|
||||||
|
// Precondition: HasTensor() is true.
|
||||||
|
const tensorflow::Tensor* GetTensorPtr(int tensor_index) const;
|
||||||
|
|
||||||
// Associates the given tensorflow::Tensor with the given 'tensor_index'.
|
// Associates the given tensorflow::Tensor with the given 'tensor_index'.
|
||||||
// Note that TensorFlow Tensors share data buffers, so this method is only a
|
// Note that TensorFlow Tensors share data buffers, so this method is only a
|
||||||
// shallow copy.
|
// shallow copy.
|
||||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/lite/delegates/flex/delegate.h"
|
#include "tensorflow/lite/delegates/flex/delegate.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
#include <gmock/gmock.h>
|
#include <gmock/gmock.h>
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include "tensorflow/lite/delegates/flex/test_util.h"
|
#include "tensorflow/lite/delegates/flex/test_util.h"
|
||||||
@ -343,6 +345,30 @@ TEST_F(DelegateTest, StaticOutput) {
|
|||||||
ASSERT_FALSE(IsDynamicTensor(6));
|
ASSERT_FALSE(IsDynamicTensor(6));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DelegateTest, StaticOutputRFFT) {
|
||||||
|
// Define the graph with input, output shapes of [3, 257].
|
||||||
|
AddTensors(4, {0, 1}, {3}, kTfLiteFloat32, {3, 257});
|
||||||
|
int32_t rfft_length[] = {512};
|
||||||
|
SetConstTensor(1, {1}, kTfLiteInt32,
|
||||||
|
reinterpret_cast<const char*>(&rfft_length),
|
||||||
|
sizeof(rfft_length));
|
||||||
|
|
||||||
|
AddTfOp(testing::kRfft, {0, 1}, {2});
|
||||||
|
AddTfOp(testing::kImag, {2}, {3});
|
||||||
|
|
||||||
|
// Apply the delegate.
|
||||||
|
ConfigureDelegate();
|
||||||
|
|
||||||
|
// Define inputs.
|
||||||
|
SetShape(0, {3, 512});
|
||||||
|
|
||||||
|
ASSERT_TRUE(Invoke());
|
||||||
|
|
||||||
|
ASSERT_EQ(GetType(3), kTfLiteFloat32);
|
||||||
|
// Since shapes are consistent, static output tensor is used.
|
||||||
|
ASSERT_FALSE(IsDynamicTensor(3));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(DelegateTest, DynamicOutputAfterReshape) {
|
TEST_F(DelegateTest, DynamicOutputAfterReshape) {
|
||||||
// Define the graph.
|
// Define the graph.
|
||||||
AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
|
AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
|
||||||
|
@ -533,6 +533,12 @@ TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency(
|
|||||||
for (int i = 0; i < num_inputs; ++i) {
|
for (int i = 0; i < num_inputs; ++i) {
|
||||||
const auto input_tensor_index = node_data->inputs().TfLiteIndex(i);
|
const auto input_tensor_index = node_data->inputs().TfLiteIndex(i);
|
||||||
TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index];
|
TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index];
|
||||||
|
// Provide constant input tensors since some op ("RFFT") needs it to
|
||||||
|
// calculate the output shape.
|
||||||
|
if (IsConstantTensor(tfl_tensor)) {
|
||||||
|
input_tensors_vector[i] =
|
||||||
|
op_data_->buffer_map->GetTensorPtr(input_tensor_index);
|
||||||
|
}
|
||||||
const auto dims_array = tfl_tensor->dims;
|
const auto dims_array = tfl_tensor->dims;
|
||||||
std::vector<DimensionHandle> dims(dims_array->size);
|
std::vector<DimensionHandle> dims(dims_array->size);
|
||||||
for (int j = 0; j < dims_array->size; ++j) {
|
for (int j = 0; j < dims_array->size; ++j) {
|
||||||
@ -540,6 +546,7 @@ TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency(
|
|||||||
}
|
}
|
||||||
c.SetInput(i, c.MakeShape(dims));
|
c.SetInput(i, c.MakeShape(dims));
|
||||||
}
|
}
|
||||||
|
c.set_input_tensors(input_tensors_vector);
|
||||||
|
|
||||||
tensorflow::Status status = c.construction_status();
|
tensorflow::Status status = c.construction_status();
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
|
@ -92,6 +92,18 @@ void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
|
|||||||
CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
|
CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void FlexModelTest::SetConstTensor(int tensor_index,
|
||||||
|
const std::vector<int>& values,
|
||||||
|
TfLiteType type, const char* buffer,
|
||||||
|
size_t bytes) {
|
||||||
|
TfLiteQuantizationParams quant;
|
||||||
|
CHECK_EQ(interpreter_->SetTensorParametersReadOnly(tensor_index, type,
|
||||||
|
/*name=*/"",
|
||||||
|
/*dims=*/values, quant,
|
||||||
|
buffer, bytes),
|
||||||
|
kTfLiteOk);
|
||||||
|
}
|
||||||
|
|
||||||
void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
|
void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
|
||||||
const std::vector<int>& outputs) {
|
const std::vector<int>& outputs) {
|
||||||
++next_op_index_;
|
++next_op_index_;
|
||||||
@ -158,6 +170,10 @@ void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
|
|||||||
} else if (op == kMul) {
|
} else if (op == kMul) {
|
||||||
string attributes = type_attribute;
|
string attributes = type_attribute;
|
||||||
AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
|
AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
|
||||||
|
} else if (op == kRfft) {
|
||||||
|
AddTfOp("FlexRFFT", "RFFT", "", inputs, outputs);
|
||||||
|
} else if (op == kImag) {
|
||||||
|
AddTfOp("FlexImag", "Imag", "", inputs, outputs);
|
||||||
} else if (op == kNonExistent) {
|
} else if (op == kNonExistent) {
|
||||||
AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
|
AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
|
||||||
} else if (op == kIncompatibleNodeDef) {
|
} else if (op == kIncompatibleNodeDef) {
|
||||||
|
@ -28,6 +28,8 @@ enum TfOpType {
|
|||||||
kIdentity,
|
kIdentity,
|
||||||
kAdd,
|
kAdd,
|
||||||
kMul,
|
kMul,
|
||||||
|
kRfft,
|
||||||
|
kImag,
|
||||||
// Represents an op that does not exist in TensorFlow.
|
// Represents an op that does not exist in TensorFlow.
|
||||||
kNonExistent,
|
kNonExistent,
|
||||||
// Represents an valid TensorFlow op where the NodeDef is incompatible.
|
// Represents an valid TensorFlow op where the NodeDef is incompatible.
|
||||||
@ -92,6 +94,11 @@ class FlexModelTest : public ::testing::Test {
|
|||||||
const std::vector<int>& outputs, TfLiteType type,
|
const std::vector<int>& outputs, TfLiteType type,
|
||||||
const std::vector<int>& dims);
|
const std::vector<int>& dims);
|
||||||
|
|
||||||
|
// Set a constant tensor of the given shape, type and buffer at the given
|
||||||
|
// index.
|
||||||
|
void SetConstTensor(int tensor_index, const std::vector<int>& values,
|
||||||
|
TfLiteType type, const char* buffer, size_t bytes);
|
||||||
|
|
||||||
// Adds a TFLite Mul op. `inputs` contains the indices of the input tensors
|
// Adds a TFLite Mul op. `inputs` contains the indices of the input tensors
|
||||||
// and `outputs` contains the indices of the output tensors.
|
// and `outputs` contains the indices of the output tensors.
|
||||||
void AddTfLiteMulOp(const std::vector<int>& inputs,
|
void AddTfLiteMulOp(const std::vector<int>& inputs,
|
||||||
|
Loading…
Reference in New Issue
Block a user