Flex delegate: Do not use dynamic tensors if possible
Added ValidateOutputTensorShapeConsistency() function to check if the existing tensor shapes are consistent. When they're consistent, the Flex delegate will not use dynamic tensors since it prevents from using another delegate. Updated multi_add_flex.bin to use 1-d tensors instead of scalar tensors. PiperOrigin-RevId: 332992406 Change-Id: Ifaf518cc81cd0adad5d327584a284a50a9f2657b
This commit is contained in:
parent
127d50358a
commit
1a2b1d0ae1
@ -313,6 +313,64 @@ TEST_F(DelegateTest, TF_AcquireFlexDelegate) {
|
||||
}
|
||||
#endif // !defined(__ANDROID__)
|
||||
|
||||
TEST_F(DelegateTest, StaticOutput) {
|
||||
// Define the graph with input, output shapes of [2].
|
||||
AddTensors(7, {0, 1, 2, 3}, {6}, kTfLiteFloat32, {2});
|
||||
|
||||
AddTfOp(testing::kAdd, {0, 2}, {4});
|
||||
AddTfOp(testing::kAdd, {1, 3}, {5});
|
||||
AddTfOp(testing::kMul, {4, 5}, {6});
|
||||
|
||||
// Apply the delegate.
|
||||
ConfigureDelegate();
|
||||
|
||||
// Define inputs which matech with the original shapes.
|
||||
SetShape(0, {2});
|
||||
SetShape(1, {2});
|
||||
SetShape(2, {2});
|
||||
SetShape(3, {2});
|
||||
SetValues(0, {1.1f, 2.2f});
|
||||
SetValues(1, {3.3f, 4.4f});
|
||||
SetValues(2, {1.1f, 2.2f});
|
||||
SetValues(3, {3.3f, 4.4f});
|
||||
|
||||
ASSERT_TRUE(Invoke());
|
||||
|
||||
ASSERT_THAT(GetShape(6), ElementsAre(2));
|
||||
ASSERT_THAT(GetValues(6), ElementsAre(14.52f, 38.72f));
|
||||
ASSERT_EQ(GetType(6), kTfLiteFloat32);
|
||||
// Since shapes are consistent, static output tensor is used.
|
||||
ASSERT_FALSE(IsDynamicTensor(6));
|
||||
}
|
||||
|
||||
TEST_F(DelegateTest, DynamicOutputAfterReshape) {
|
||||
// Define the graph.
|
||||
AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
|
||||
|
||||
AddTfOp(testing::kUnpack, {0}, {1, 2});
|
||||
AddTfOp(testing::kUnpack, {3}, {4, 5});
|
||||
AddTfOp(testing::kAdd, {1, 4}, {6});
|
||||
AddTfOp(testing::kAdd, {2, 5}, {7});
|
||||
AddTfOp(testing::kMul, {6, 7}, {8});
|
||||
|
||||
// Apply the delegate.
|
||||
ConfigureDelegate();
|
||||
|
||||
// Define inputs with reshape.
|
||||
SetShape(0, {2, 2, 1});
|
||||
SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
|
||||
SetShape(3, {2, 2, 1});
|
||||
SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
|
||||
|
||||
ASSERT_TRUE(Invoke());
|
||||
|
||||
ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
|
||||
ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
|
||||
ASSERT_EQ(GetType(8), kTfLiteFloat32);
|
||||
// Since shapes are inconsistent, dynamic output tensor is used.
|
||||
ASSERT_TRUE(IsDynamicTensor(8));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace flex
|
||||
} // namespace tflite
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/flex/delegate_data.h"
|
||||
#include "tensorflow/lite/delegates/flex/util.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
#include "tensorflow/lite/string_type.h"
|
||||
|
||||
// Note: this is part of TF Lite's Flex delegation code which is to be
|
||||
@ -48,6 +49,16 @@ limitations under the License.
|
||||
// retrieve the associated NodeDef, which is then used to configure the
|
||||
// corresponding TensorFlow/Eager Op.
|
||||
|
||||
using tensorflow::shape_inference::DimensionHandle;
|
||||
using tensorflow::shape_inference::InferenceContext;
|
||||
using tensorflow::shape_inference::ShapeAndType;
|
||||
using tensorflow::shape_inference::ShapeHandle;
|
||||
|
||||
const std::string GetDimsDebugString(const TfLiteIntArray* dims) {
|
||||
return absl::StrCat("[", absl::StrJoin(tflite::TfLiteIntArrayView(dims), ","),
|
||||
"]");
|
||||
}
|
||||
|
||||
namespace tflite {
|
||||
namespace flex {
|
||||
|
||||
@ -188,6 +199,9 @@ class OpNode {
|
||||
void set_index(int index) { index_ = index; }
|
||||
|
||||
const tensorflow::NodeDef& nodedef() const { return nodedef_; }
|
||||
const tensorflow::OpRegistrationData* op_reg_data() const {
|
||||
return op_reg_data_;
|
||||
}
|
||||
|
||||
const OpInputs& inputs() const { return inputs_; }
|
||||
OpInputs* mutable_inputs() { return &inputs_; }
|
||||
@ -222,10 +236,9 @@ class OpNode {
|
||||
}
|
||||
|
||||
// Fill NodeDef with defaults if it's a valid op.
|
||||
const tensorflow::OpRegistrationData* op_reg_data;
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data));
|
||||
AddDefaultsToNodeDef(op_reg_data->op_def, &nodedef_);
|
||||
tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data_));
|
||||
AddDefaultsToNodeDef(op_reg_data_->op_def, &nodedef_);
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
@ -312,6 +325,8 @@ class OpNode {
|
||||
int index_;
|
||||
// The corresponding NodeDef, containing the attributes for the op.
|
||||
tensorflow::NodeDef nodedef_;
|
||||
// The corresponding OpRegistrationData pointer.
|
||||
const tensorflow::OpRegistrationData* op_reg_data_;
|
||||
// List of inputs, as TF Lite tensor indices.
|
||||
OpInputs inputs_;
|
||||
// List of outputs, as TF Lite tensor indices.
|
||||
@ -455,10 +470,22 @@ TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
tensor_ref_count[tensor_index] += 2;
|
||||
}
|
||||
|
||||
const bool shapes_are_valid =
|
||||
(ValidateOutputTensorShapeConsistency(context) == kTfLiteOk);
|
||||
if (shapes_are_valid) {
|
||||
TFLITE_LOG(tflite::TFLITE_LOG_INFO,
|
||||
"FlexDelegate: All tensor shapes are consistent.");
|
||||
} else {
|
||||
TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
|
||||
"FlexDelegate: Some tensor shapes are inconsistent.");
|
||||
}
|
||||
|
||||
// All output tensors are allocated by TensorFlow/Eager, so we
|
||||
// mark them as kTfLiteDynamic.
|
||||
for (auto tensor_index : op_data_->subgraph_outputs) {
|
||||
SetTensorToDynamic(&context->tensors[tensor_index]);
|
||||
if (!shapes_are_valid) {
|
||||
SetTensorToDynamic(&context->tensors[tensor_index]);
|
||||
}
|
||||
++tensor_ref_count[tensor_index];
|
||||
}
|
||||
|
||||
@ -488,6 +515,78 @@ TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency(
|
||||
TfLiteContext* context) const {
|
||||
for (const auto& node_data : op_data_->nodes) {
|
||||
auto op_name = node_data->name().c_str();
|
||||
// Create an InferenceContext object.
|
||||
auto num_inputs = node_data->inputs().Size();
|
||||
std::vector<const tensorflow::Tensor*> input_tensors_vector(num_inputs,
|
||||
nullptr);
|
||||
InferenceContext c(
|
||||
TF_GRAPH_DEF_VERSION, node_data->nodedef(),
|
||||
node_data->op_reg_data()->op_def, std::vector<ShapeHandle>(num_inputs),
|
||||
input_tensors_vector, {},
|
||||
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
|
||||
|
||||
// Set input_shapes for ShapeInferenceFn.
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
const auto input_tensor_index = node_data->inputs().TfLiteIndex(i);
|
||||
TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index];
|
||||
const auto dims_array = tfl_tensor->dims;
|
||||
std::vector<DimensionHandle> dims(dims_array->size);
|
||||
for (int j = 0; j < dims_array->size; ++j) {
|
||||
dims[j] = c.MakeDim(dims_array->data[j]);
|
||||
}
|
||||
c.SetInput(i, c.MakeShape(dims));
|
||||
}
|
||||
|
||||
tensorflow::Status status = c.construction_status();
|
||||
if (!status.ok()) {
|
||||
TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
|
||||
"Shape construction failed for op '%s'", op_name);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Run ShapeInferenceFn to calculate output shapes.
|
||||
if (node_data->op_reg_data()->shape_inference_fn == nullptr) {
|
||||
TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
|
||||
"No shape inference function exists for op '%s'", op_name);
|
||||
return kTfLiteError;
|
||||
}
|
||||
status = c.Run(node_data->op_reg_data()->shape_inference_fn);
|
||||
|
||||
// Compare calculated output shapes with node_data->outputs
|
||||
auto num_outputs = node_data->outputs().Size();
|
||||
if (num_outputs != c.num_outputs()) {
|
||||
TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
|
||||
"Number of output tensors are mismatched for op '%s' %d != %d",
|
||||
op_name, num_outputs, c.num_outputs());
|
||||
return kTfLiteError;
|
||||
}
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
const auto output_tensor_index = node_data->outputs().TfLiteIndex(i);
|
||||
TfLiteTensor* tfl_tensor = &context->tensors[output_tensor_index];
|
||||
// tfl_tensor->dims only has valid information if the given model is
|
||||
// converted by the MLIR converter. Also when ResizeInputTensor() is
|
||||
// called the dims information becomes invalid.
|
||||
const std::string tfl_shape_string = GetDimsDebugString(tfl_tensor->dims);
|
||||
const std::string calculated_shape_string = c.DebugString(c.output(i));
|
||||
// Getting a shape string via c.DebugString() is the easiest way to get
|
||||
// the shape information of the given ShapeHandle for now.
|
||||
// TODO(b/169017408): Find a better approach without using debug string.
|
||||
if (tfl_shape_string != calculated_shape_string) {
|
||||
TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
|
||||
"op '%s' output%d tensor#%d shape mismatch for %s != %s",
|
||||
op_name, i, output_tensor_index, tfl_shape_string.c_str(),
|
||||
calculated_shape_string.c_str());
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus DelegateKernel::Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
BufferMap* buffer_map = op_data_->buffer_map;
|
||||
|
||||
@ -522,12 +621,30 @@ TfLiteStatus DelegateKernel::Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Copy TF tensor data to TFL allocated buffer for non dynamic tensors.
|
||||
// For dynamic tensors, copy shape and put buffer_handle for the later
|
||||
// CopyFromBufferHandle() call.
|
||||
TfLiteTensor* tensor = &context->tensors[tensor_index];
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor));
|
||||
tensor->buffer_handle = tensor_index;
|
||||
tensor->data_is_stale = true;
|
||||
const tensorflow::Tensor& tf_tensor = buffer_map->GetTensor(tensor_index);
|
||||
if (tensor->allocation_type == kTfLiteDynamic) {
|
||||
TF_LITE_ENSURE_OK(context, CopyShapeAndType(context, tf_tensor, tensor));
|
||||
tensor->buffer_handle = tensor_index;
|
||||
tensor->data_is_stale = true;
|
||||
continue;
|
||||
}
|
||||
// If the tensor isn't dynamic, we can copy data directly to the buffer of
|
||||
// the tensor. Before copying the data, check if the target buffer has
|
||||
// expected size.
|
||||
if (tf_tensor.NumElements() != NumElements(tensor) ||
|
||||
tf_tensor.TotalBytes() != tensor->bytes) {
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "Tensor: %s(%d) buffer size mismatch %zu(%lld) != %ld(%ld)",
|
||||
tensor->name, tensor_index, tf_tensor.TotalBytes(),
|
||||
tf_tensor.NumElements(), tensor->bytes, NumElements(tensor));
|
||||
return kTfLiteError;
|
||||
}
|
||||
tensorflow::StringPiece t_data = tf_tensor.tensor_data();
|
||||
memcpy(tensor->data.raw, t_data.data(), t_data.size());
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
|
@ -35,6 +35,11 @@ class DelegateKernel : public SimpleDelegateKernelInterface {
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) override;
|
||||
|
||||
private:
|
||||
// Validate that the computed output tensor shape for the Flex node matches
|
||||
// the existing output shape assigned to the output tensor.
|
||||
TfLiteStatus ValidateOutputTensorShapeConsistency(
|
||||
TfLiteContext* context) const;
|
||||
|
||||
std::unique_ptr<OpData> op_data_;
|
||||
};
|
||||
|
||||
|
@ -18,6 +18,8 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/flex/delegate_data.h"
|
||||
#include "tensorflow/lite/delegates/flex/test_util.h"
|
||||
|
||||
extern const std::string GetDimsDebugString(const TfLiteIntArray* dims);
|
||||
|
||||
namespace tflite {
|
||||
namespace flex {
|
||||
namespace testing {
|
||||
@ -351,6 +353,61 @@ TEST_F(MultipleSubgraphsTest, DoNotForwardInputTensors) {
|
||||
})));
|
||||
}
|
||||
|
||||
tensorflow::OpDef MakeOpDef(int num_inputs, int num_outputs) {
|
||||
tensorflow::OpRegistrationData op_reg_data;
|
||||
tensorflow::OpDefBuilder b("dummy");
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
b.Input(tensorflow::strings::StrCat("i", i, ": float"));
|
||||
}
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
b.Output(tensorflow::strings::StrCat("o", i, ": float"));
|
||||
}
|
||||
CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
|
||||
return op_reg_data.op_def;
|
||||
}
|
||||
|
||||
tensorflow::PartialTensorShape S(std::initializer_list<int64> dims) {
|
||||
return tensorflow::PartialTensorShape(dims);
|
||||
}
|
||||
|
||||
TEST(ValidateOutputTensorShapeConsistencyTest, ShapeHandleDebugString) {
|
||||
// Setup test to contain an input tensor list of size 3.
|
||||
tensorflow::OpDef op_def = MakeOpDef(4, 1);
|
||||
tensorflow::NodeDef def;
|
||||
tensorflow::shape_inference::InferenceContext c(
|
||||
0, def, op_def, {S({1}), S({2, 3}), S({4, 5, 6}), {}}, {}, {}, {});
|
||||
c.SetInput(3, c.UnknownShape());
|
||||
|
||||
std::vector<tensorflow::shape_inference::ShapeHandle> shapes;
|
||||
EXPECT_EQ("[1]", c.DebugString(c.input(0)));
|
||||
EXPECT_EQ("[2,3]", c.DebugString(c.input(1)));
|
||||
EXPECT_EQ("[4,5,6]", c.DebugString(c.input(2)));
|
||||
// c.DebugString() returns "?" for the unknown shape which is different with
|
||||
// "-1" of TFLite. But this is intended behavior since we should use dynamic
|
||||
// tensor for unknown shape so the shape comparison must fail.
|
||||
EXPECT_EQ("?", c.DebugString(c.input(3)));
|
||||
}
|
||||
|
||||
TEST(ValidateOutputTensorShapeConsistencyTest, GetDimsDebugString) {
|
||||
TfLiteIntArray* dims1 = TfLiteIntArrayCreate(1);
|
||||
dims1->data[0] = 1;
|
||||
EXPECT_EQ("[1]", GetDimsDebugString(dims1));
|
||||
free(dims1);
|
||||
|
||||
TfLiteIntArray* dims2 = TfLiteIntArrayCreate(2);
|
||||
dims2->data[0] = 2;
|
||||
dims2->data[1] = 3;
|
||||
EXPECT_EQ("[2,3]", GetDimsDebugString(dims2));
|
||||
free(dims2);
|
||||
|
||||
TfLiteIntArray* dims3 = TfLiteIntArrayCreate(3);
|
||||
dims3->data[0] = 4;
|
||||
dims3->data[1] = 5;
|
||||
dims3->data[2] = 6;
|
||||
EXPECT_EQ("[4,5,6]", GetDimsDebugString(dims3));
|
||||
free(dims3);
|
||||
}
|
||||
|
||||
} // namespace testing
|
||||
} // namespace flex
|
||||
} // namespace tflite
|
||||
|
@ -67,6 +67,10 @@ TfLiteType FlexModelTest::GetType(int tensor_index) {
|
||||
return interpreter_->tensor(tensor_index)->type;
|
||||
}
|
||||
|
||||
bool FlexModelTest::IsDynamicTensor(int tensor_index) {
|
||||
return interpreter_->tensor(tensor_index)->allocation_type == kTfLiteDynamic;
|
||||
}
|
||||
|
||||
void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
|
||||
const std::vector<int>& outputs, TfLiteType type,
|
||||
const std::vector<int>& dims) {
|
||||
|
@ -80,6 +80,9 @@ class FlexModelTest : public ::testing::Test {
|
||||
// Returns the tensor's type at the given index.
|
||||
TfLiteType GetType(int tensor_index);
|
||||
|
||||
// Returns if the tensor at the given index is dynamic.
|
||||
bool IsDynamicTensor(int tensor_index);
|
||||
|
||||
const TestErrorReporter& error_reporter() const { return error_reporter_; }
|
||||
|
||||
// Adds `num_tensor` tensors to the model. `inputs` contains the indices of
|
||||
|
BIN
tensorflow/lite/testdata/multi_add_flex.bin
vendored
BIN
tensorflow/lite/testdata/multi_add_flex.bin
vendored
Binary file not shown.
@ -103,14 +103,14 @@ TEST_F(FlexOpsListTest, TestZeroSubgraphs) {
|
||||
TEST_F(FlexOpsListTest, TestFlexAdd) {
|
||||
ReadOps("tensorflow/lite/testdata/multi_add_flex.bin");
|
||||
EXPECT_EQ(output_text_,
|
||||
"[[\"Add\",\"BinaryOp<CPUDevice, functor::add<float>>\"]]\n");
|
||||
"[[\"AddV2\",\"BinaryOp<CPUDevice, functor::add<float>>\"]]\n");
|
||||
}
|
||||
|
||||
TEST_F(FlexOpsListTest, TestTwoModel) {
|
||||
ReadOps("tensorflow/lite/testdata/multi_add_flex.bin");
|
||||
ReadOps("tensorflow/lite/testdata/softplus_flex.bin");
|
||||
EXPECT_EQ(output_text_,
|
||||
"[[\"Add\",\"BinaryOp<CPUDevice, "
|
||||
"[[\"AddV2\",\"BinaryOp<CPUDevice, "
|
||||
"functor::add<float>>\"],[\"Softplus\",\"SoftplusOp<CPUDevice, "
|
||||
"float>\"]]\n");
|
||||
}
|
||||
@ -119,7 +119,7 @@ TEST_F(FlexOpsListTest, TestDuplicatedOp) {
|
||||
ReadOps("tensorflow/lite/testdata/multi_add_flex.bin");
|
||||
ReadOps("tensorflow/lite/testdata/multi_add_flex.bin");
|
||||
EXPECT_EQ(output_text_,
|
||||
"[[\"Add\",\"BinaryOp<CPUDevice, functor::add<float>>\"]]\n");
|
||||
"[[\"AddV2\",\"BinaryOp<CPUDevice, functor::add<float>>\"]]\n");
|
||||
}
|
||||
|
||||
TEST_F(FlexOpsListTest, TestInvalidCustomOptions) {
|
||||
|
Loading…
Reference in New Issue
Block a user