diff --git a/tensorflow/lite/delegates/flex/delegate_test.cc b/tensorflow/lite/delegates/flex/delegate_test.cc
index 6861729e8c8..02ad4201307 100644
--- a/tensorflow/lite/delegates/flex/delegate_test.cc
+++ b/tensorflow/lite/delegates/flex/delegate_test.cc
@@ -313,6 +313,70 @@ TEST_F(DelegateTest, TF_AcquireFlexDelegate) {
 }
 #endif  // !defined(__ANDROID__)
 
+// Covers the consistent-shape path: the subgraph output tensor is expected not
+// to be switched to kTfLiteDynamic.
+TEST_F(DelegateTest, StaticOutput) {
+  // Graph: t6 = (t0 + t2) * (t1 + t3), all tensors declared with shape [2].
+  AddTensors(7, {0, 1, 2, 3}, {6}, kTfLiteFloat32, {2});
+
+  AddTfOp(testing::kAdd, {0, 2}, {4});
+  AddTfOp(testing::kAdd, {1, 3}, {5});
+  AddTfOp(testing::kMul, {4, 5}, {6});
+
+  // Apply the delegate to the graph.
+  ConfigureDelegate();
+
+  // Feed inputs whose shapes match the declared shapes.
+  for (const int input : {0, 1, 2, 3}) {
+    SetShape(input, {2});
+  }
+  const std::vector<float> first_pair = {1.1f, 2.2f};
+  const std::vector<float> second_pair = {3.3f, 4.4f};
+  SetValues(0, first_pair);
+  SetValues(1, second_pair);
+  SetValues(2, first_pair);
+  SetValues(3, second_pair);
+
+  ASSERT_TRUE(Invoke());
+
+  ASSERT_THAT(GetShape(6), ElementsAre(2));
+  ASSERT_THAT(GetValues(6), ElementsAre(14.52f, 38.72f));
+  ASSERT_EQ(GetType(6), kTfLiteFloat32);
+  // The shapes are consistent, so the output must not have been made dynamic.
+  ASSERT_FALSE(IsDynamicTensor(6));
+}
+
+// Covers the fallback path: the recorded tensor shapes do not agree with the
+// inferred ones, so the subgraph output is expected to be a dynamic tensor.
+TEST_F(DelegateTest, DynamicOutputAfterReshape) {
+  // Graph: t8 = (t1 + t4) * (t2 + t5), where t1/t2 and t4/t5 are unpacked from
+  // the inputs t0 and t3. All tensors are declared with shape [3].
+  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
+
+  AddTfOp(testing::kUnpack, {0}, {1, 2});
+  AddTfOp(testing::kUnpack, {3}, {4, 5});
+  AddTfOp(testing::kAdd, {1, 4}, {6});
+  AddTfOp(testing::kAdd, {2, 5}, {7});
+  AddTfOp(testing::kMul, {6, 7}, {8});
+
+  // Apply the delegate to the graph.
+  ConfigureDelegate();
+
+  // Reshape both inputs away from the declared shape before feeding values.
+  for (const int input : {0, 3}) {
+    SetShape(input, {2, 2, 1});
+    SetValues(input, {1.1f, 2.2f, 3.3f, 4.4f});
+  }
+
+  ASSERT_TRUE(Invoke());
+
+  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
+  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+  ASSERT_EQ(GetType(8), kTfLiteFloat32);
+  // The shapes are inconsistent, so the output must have been made dynamic.
+  ASSERT_TRUE(IsDynamicTensor(8));
+}
+
 }  // namespace
 }  // namespace flex
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc
index b3e978908bd..9674ee7b7f1 100644
--- a/tensorflow/lite/delegates/flex/kernel.cc
+++ b/tensorflow/lite/delegates/flex/kernel.cc
@@ -29,6 +29,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/flex/delegate_data.h"
 #include "tensorflow/lite/delegates/flex/util.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/minimal_logging.h"
 #include "tensorflow/lite/string_type.h"
 
 // Note: this is part of TF Lite's Flex delegation code which is to be
@@ -48,6 +49,16 @@ limitations under the License.
 // retrieve the associated NodeDef, which is then used to configure the
 // corresponding TensorFlow/Eager Op.
 
+using tensorflow::shape_inference::DimensionHandle;
+using tensorflow::shape_inference::InferenceContext;
+using tensorflow::shape_inference::ShapeAndType;
+using tensorflow::shape_inference::ShapeHandle;
+
+const std::string GetDimsDebugString(const TfLiteIntArray* dims) {
+  return absl::StrCat("[", absl::StrJoin(tflite::TfLiteIntArrayView(dims), ","),
+                      "]");
+}
+
 namespace tflite {
 namespace flex {
 
@@ -188,6 +199,9 @@ class OpNode {
   void set_index(int index) { index_ = index; }
 
   const tensorflow::NodeDef& nodedef() const { return nodedef_; }
+  const tensorflow::OpRegistrationData* op_reg_data() const {
+    return op_reg_data_;
+  }
 
   const OpInputs& inputs() const { return inputs_; }
   OpInputs* mutable_inputs() { return &inputs_; }
@@ -222,10 +236,9 @@ class OpNode {
     }
 
     // Fill NodeDef with defaults if it's a valid op.
-    const tensorflow::OpRegistrationData* op_reg_data;
     TF_RETURN_IF_ERROR(
-        tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data));
-    AddDefaultsToNodeDef(op_reg_data->op_def, &nodedef_);
+        tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data_));
+    AddDefaultsToNodeDef(op_reg_data_->op_def, &nodedef_);
 
     return tensorflow::Status::OK();
   }
@@ -312,6 +325,8 @@ class OpNode {
   int index_;
   // The corresponding NodeDef, containing the attributes for the op.
   tensorflow::NodeDef nodedef_;
+  // Registration data from the global OpRegistry; null until it is looked up.
+  const tensorflow::OpRegistrationData* op_reg_data_ = nullptr;
   // List of inputs, as TF Lite tensor indices.
   OpInputs inputs_;
   // List of outputs, as TF Lite tensor indices.
@@ -455,10 +470,22 @@ TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) {
     tensor_ref_count[tensor_index] += 2;
   }
 
+  const bool shapes_are_valid =
+      (ValidateOutputTensorShapeConsistency(context) == kTfLiteOk);
+  if (shapes_are_valid) {
+    TFLITE_LOG(tflite::TFLITE_LOG_INFO,
+               "FlexDelegate: All tensor shapes are consistent.");
+  } else {
+    TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
+               "FlexDelegate: Some tensor shapes are inconsistent.");
+  }
+
-  // All output tensors are allocated by TensorFlow/Eager, so we
-  // mark them as kTfLiteDynamic.
+  // Output tensors are left untouched when all shapes are consistent;
+  // otherwise TensorFlow/Eager allocates them, so mark them as kTfLiteDynamic.
   for (auto tensor_index : op_data_->subgraph_outputs) {
-    SetTensorToDynamic(&context->tensors[tensor_index]);
+    if (!shapes_are_valid) {
+      SetTensorToDynamic(&context->tensors[tensor_index]);
+    }
     ++tensor_ref_count[tensor_index];
   }
 
@@ -488,6 +515,88 @@ TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+// Runs TensorFlow shape inference for every node of the delegated subgraph and
+// compares each inferred output shape with the shape recorded on the matching
+// TF Lite tensor. Returns kTfLiteOk only when every output shape matches; any
+// failure or mismatch is logged as a warning and reported as kTfLiteError.
+TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency(
+    TfLiteContext* context) const {
+  for (const auto& node_data : op_data_->nodes) {
+    auto op_name = node_data->name().c_str();
+    // Create an InferenceContext object.
+    auto num_inputs = node_data->inputs().Size();
+    std::vector<const tensorflow::Tensor*> input_tensors_vector(num_inputs,
+                                                                nullptr);
+    InferenceContext c(
+        TF_GRAPH_DEF_VERSION, node_data->nodedef(),
+        node_data->op_reg_data()->op_def, std::vector<ShapeHandle>(num_inputs),
+        input_tensors_vector, {},
+        std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
+
+    // Set input_shapes for ShapeInferenceFn.
+    for (int i = 0; i < num_inputs; ++i) {
+      const auto input_tensor_index = node_data->inputs().TfLiteIndex(i);
+      TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index];
+      const auto dims_array = tfl_tensor->dims;
+      std::vector<DimensionHandle> dims(dims_array->size);
+      for (int j = 0; j < dims_array->size; ++j) {
+        dims[j] = c.MakeDim(dims_array->data[j]);
+      }
+      c.SetInput(i, c.MakeShape(dims));
+    }
+
+    tensorflow::Status status = c.construction_status();
+    if (!status.ok()) {
+      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
+                 "Shape construction failed for op '%s'", op_name);
+      return kTfLiteError;
+    }
+
+    // Run ShapeInferenceFn to calculate output shapes.
+    if (node_data->op_reg_data()->shape_inference_fn == nullptr) {
+      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
+                 "No shape inference function exists for op '%s'", op_name);
+      return kTfLiteError;
+    }
+    status = c.Run(node_data->op_reg_data()->shape_inference_fn);
+    if (!status.ok()) {
+      // The inferred outputs cannot be trusted when inference itself failed.
+      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
+                 "Shape inference failed for op '%s'", op_name);
+      return kTfLiteError;
+    }
+
+    // Compare calculated output shapes with node_data->outputs
+    auto num_outputs = node_data->outputs().Size();
+    if (num_outputs != c.num_outputs()) {
+      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
+                 "Number of output tensors are mismatched for op '%s' %d != %d",
+                 op_name, num_outputs, c.num_outputs());
+      return kTfLiteError;
+    }
+    for (int i = 0; i < num_outputs; ++i) {
+      const auto output_tensor_index = node_data->outputs().TfLiteIndex(i);
+      TfLiteTensor* tfl_tensor = &context->tensors[output_tensor_index];
+      // tfl_tensor->dims only has valid information if the given model is
+      // converted by the MLIR converter. Also when ResizeInputTensor() is
+      // called the dims information becomes invalid.
+      const std::string tfl_shape_string = GetDimsDebugString(tfl_tensor->dims);
+      const std::string calculated_shape_string = c.DebugString(c.output(i));
+      // Getting a shape string via c.DebugString() is the easiest way to get
+      // the shape information of the given ShapeHandle for now.
+      // TODO(b/169017408): Find a better approach without using debug string.
+      if (tfl_shape_string != calculated_shape_string) {
+        TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
+                   "op '%s' output%d tensor#%d shape mismatch for %s != %s",
+                   op_name, i, output_tensor_index, tfl_shape_string.c_str(),
+                   calculated_shape_string.c_str());
+        return kTfLiteError;
+      }
+    }
+  }
+  return kTfLiteOk;
+}
+
 TfLiteStatus DelegateKernel::Eval(TfLiteContext* context, TfLiteNode* node) {
   BufferMap* buffer_map = op_data_->buffer_map;
 
@@ -522,12 +631,32 @@ TfLiteStatus DelegateKernel::Eval(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteError;
     }
 
+    // Copy TF tensor data to TFL allocated buffer for non dynamic tensors.
+    // For dynamic tensors, copy shape and put buffer_handle for the later
+    // CopyFromBufferHandle() call.
     TfLiteTensor* tensor = &context->tensors[tensor_index];
-    TF_LITE_ENSURE_OK(
-        context,
-        CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor));
-    tensor->buffer_handle = tensor_index;
-    tensor->data_is_stale = true;
+    const tensorflow::Tensor& tf_tensor = buffer_map->GetTensor(tensor_index);
+    if (tensor->allocation_type == kTfLiteDynamic) {
+      TF_LITE_ENSURE_OK(context, CopyShapeAndType(context, tf_tensor, tensor));
+      tensor->buffer_handle = tensor_index;
+      tensor->data_is_stale = true;
+      continue;
+    }
+    // If the tensor isn't dynamic, we can copy data directly to the buffer of
+    // the tensor. Before copying the data, check if the target buffer has
+    // expected size.
+    if (tf_tensor.NumElements() != NumElements(tensor) ||
+        tf_tensor.TotalBytes() != tensor->bytes) {
+      // Byte sizes are size_t (%zu); element counts are cast to long long so
+      // that %lld is correct on both 32-bit and 64-bit targets.
+      TF_LITE_KERNEL_LOG(
+          context, "Tensor: %s(%d) buffer size mismatch %zu(%lld) != %zu(%lld)",
+          tensor->name, tensor_index, tf_tensor.TotalBytes(),
+          static_cast<long long>(tf_tensor.NumElements()), tensor->bytes,
+          static_cast<long long>(NumElements(tensor)));
+      return kTfLiteError;
+    }
+    tensorflow::StringPiece t_data = tf_tensor.tensor_data();
+    memcpy(tensor->data.raw, t_data.data(), t_data.size());
   }
 
   return kTfLiteOk;
diff --git a/tensorflow/lite/delegates/flex/kernel.h b/tensorflow/lite/delegates/flex/kernel.h
index 9a7b93e31f2..b2ab485bdaa 100644
--- a/tensorflow/lite/delegates/flex/kernel.h
+++ b/tensorflow/lite/delegates/flex/kernel.h
@@ -35,6 +35,11 @@ class DelegateKernel : public SimpleDelegateKernelInterface {
   TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) override;
 
  private:
+  // Validate that the computed output tensor shape for the Flex node matches
+  // the existing output shape assigned to the output tensor.
+  TfLiteStatus ValidateOutputTensorShapeConsistency(
+      TfLiteContext* context) const;
+
   std::unique_ptr<OpData> op_data_;
 };
 
diff --git a/tensorflow/lite/delegates/flex/kernel_test.cc b/tensorflow/lite/delegates/flex/kernel_test.cc
index f7234075c95..9cd37a1d26e 100644
--- a/tensorflow/lite/delegates/flex/kernel_test.cc
+++ b/tensorflow/lite/delegates/flex/kernel_test.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include "tensorflow/lite/delegates/flex/delegate_data.h" #include "tensorflow/lite/delegates/flex/test_util.h" +extern const std::string GetDimsDebugString(const TfLiteIntArray* dims); + namespace tflite { namespace flex { namespace testing { @@ -351,6 +353,61 @@ TEST_F(MultipleSubgraphsTest, DoNotForwardInputTensors) { }))); } +tensorflow::OpDef MakeOpDef(int num_inputs, int num_outputs) { + tensorflow::OpRegistrationData op_reg_data; + tensorflow::OpDefBuilder b("dummy"); + for (int i = 0; i < num_inputs; ++i) { + b.Input(tensorflow::strings::StrCat("i", i, ": float")); + } + for (int i = 0; i < num_outputs; ++i) { + b.Output(tensorflow::strings::StrCat("o", i, ": float")); + } + CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok()); + return op_reg_data.op_def; +} + +tensorflow::PartialTensorShape S(std::initializer_list dims) { + return tensorflow::PartialTensorShape(dims); +} + +TEST(ValidateOutputTensorShapeConsistencyTest, ShapeHandleDebugString) { + // Setup test to contain an input tensor list of size 3. + tensorflow::OpDef op_def = MakeOpDef(4, 1); + tensorflow::NodeDef def; + tensorflow::shape_inference::InferenceContext c( + 0, def, op_def, {S({1}), S({2, 3}), S({4, 5, 6}), {}}, {}, {}, {}); + c.SetInput(3, c.UnknownShape()); + + std::vector shapes; + EXPECT_EQ("[1]", c.DebugString(c.input(0))); + EXPECT_EQ("[2,3]", c.DebugString(c.input(1))); + EXPECT_EQ("[4,5,6]", c.DebugString(c.input(2))); + // c.DebugString() returns "?" for the unknown shape which is different with + // "-1" of TFLite. But this is intended behavior since we should use dynamic + // tensor for unknown shape so the shape comparison must fail. 
+ EXPECT_EQ("?", c.DebugString(c.input(3))); +} + +TEST(ValidateOutputTensorShapeConsistencyTest, GetDimsDebugString) { + TfLiteIntArray* dims1 = TfLiteIntArrayCreate(1); + dims1->data[0] = 1; + EXPECT_EQ("[1]", GetDimsDebugString(dims1)); + free(dims1); + + TfLiteIntArray* dims2 = TfLiteIntArrayCreate(2); + dims2->data[0] = 2; + dims2->data[1] = 3; + EXPECT_EQ("[2,3]", GetDimsDebugString(dims2)); + free(dims2); + + TfLiteIntArray* dims3 = TfLiteIntArrayCreate(3); + dims3->data[0] = 4; + dims3->data[1] = 5; + dims3->data[2] = 6; + EXPECT_EQ("[4,5,6]", GetDimsDebugString(dims3)); + free(dims3); +} + } // namespace testing } // namespace flex } // namespace tflite diff --git a/tensorflow/lite/delegates/flex/test_util.cc b/tensorflow/lite/delegates/flex/test_util.cc index 8c0e40b58dd..fd566034b3d 100644 --- a/tensorflow/lite/delegates/flex/test_util.cc +++ b/tensorflow/lite/delegates/flex/test_util.cc @@ -67,6 +67,10 @@ TfLiteType FlexModelTest::GetType(int tensor_index) { return interpreter_->tensor(tensor_index)->type; } +bool FlexModelTest::IsDynamicTensor(int tensor_index) { + return interpreter_->tensor(tensor_index)->allocation_type == kTfLiteDynamic; +} + void FlexModelTest::AddTensors(int num_tensors, const std::vector& inputs, const std::vector& outputs, TfLiteType type, const std::vector& dims) { diff --git a/tensorflow/lite/delegates/flex/test_util.h b/tensorflow/lite/delegates/flex/test_util.h index 1913a406e83..bc74d8578a2 100644 --- a/tensorflow/lite/delegates/flex/test_util.h +++ b/tensorflow/lite/delegates/flex/test_util.h @@ -80,6 +80,9 @@ class FlexModelTest : public ::testing::Test { // Returns the tensor's type at the given index. TfLiteType GetType(int tensor_index); + // Returns if the tensor at the given index is dynamic. + bool IsDynamicTensor(int tensor_index); + const TestErrorReporter& error_reporter() const { return error_reporter_; } // Adds `num_tensor` tensors to the model. 
`inputs` contains the indices of diff --git a/tensorflow/lite/testdata/multi_add_flex.bin b/tensorflow/lite/testdata/multi_add_flex.bin index 9aac2155fed..9ab31ed63d9 100644 Binary files a/tensorflow/lite/testdata/multi_add_flex.bin and b/tensorflow/lite/testdata/multi_add_flex.bin differ diff --git a/tensorflow/lite/tools/list_flex_ops_test.cc b/tensorflow/lite/tools/list_flex_ops_test.cc index 7d81dda71e6..8b3757a4039 100644 --- a/tensorflow/lite/tools/list_flex_ops_test.cc +++ b/tensorflow/lite/tools/list_flex_ops_test.cc @@ -103,14 +103,14 @@ TEST_F(FlexOpsListTest, TestZeroSubgraphs) { TEST_F(FlexOpsListTest, TestFlexAdd) { ReadOps("tensorflow/lite/testdata/multi_add_flex.bin"); EXPECT_EQ(output_text_, - "[[\"Add\",\"BinaryOp>\"]]\n"); + "[[\"AddV2\",\"BinaryOp>\"]]\n"); } TEST_F(FlexOpsListTest, TestTwoModel) { ReadOps("tensorflow/lite/testdata/multi_add_flex.bin"); ReadOps("tensorflow/lite/testdata/softplus_flex.bin"); EXPECT_EQ(output_text_, - "[[\"Add\",\"BinaryOp>\"],[\"Softplus\",\"SoftplusOp\"]]\n"); } @@ -119,7 +119,7 @@ TEST_F(FlexOpsListTest, TestDuplicatedOp) { ReadOps("tensorflow/lite/testdata/multi_add_flex.bin"); ReadOps("tensorflow/lite/testdata/multi_add_flex.bin"); EXPECT_EQ(output_text_, - "[[\"Add\",\"BinaryOp>\"]]\n"); + "[[\"AddV2\",\"BinaryOp>\"]]\n"); } TEST_F(FlexOpsListTest, TestInvalidCustomOptions) {