From f581c55e4d01e4ebdeaebf6c095aff547745d893 Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Tue, 12 May 2020 14:16:20 -0700 Subject: [PATCH] Introduce persistent, read-only TFLite tensor type Several operators (rank, shape) are critical for preserving the ability to resize graphs correctly at runtime. However, introduction of such ops in the graph currently makes it impossible to fully propagate shapes when tensors are allocated. This also prevents delegation of the graph for most delegates, as it introduces dynamic shapes. Introduce a new, persistent tensor type that can be treated as "constant" at the time of TfLiteRegistration::Prepare. This tensor type is allocated immediately when requested, similar to a dynamic tensor, but promises that its contents will be populated after the "producing" node is prepared, and that it won't change across subsequent evals. Update Rank/Shape operators to use this tensor allocation type. A follow-up CL will introduce a new pseudo-constant tensor check that can be used by various kernels to avoid making them dynamic. PiperOrigin-RevId: 311199934 Change-Id: I050704be7d1ff264fc1a852efade53d4021cb034 --- tensorflow/lite/c/common.c | 6 ++- tensorflow/lite/c/common.h | 14 +++++-- tensorflow/lite/core/subgraph.cc | 9 +++-- tensorflow/lite/kernels/kernel_util.h | 12 ++++++ tensorflow/lite/kernels/rank.cc | 18 ++++++--- tensorflow/lite/kernels/rank_test.cc | 13 +++++-- tensorflow/lite/kernels/shape.cc | 17 +++++--- tensorflow/lite/kernels/shape_test.cc | 13 +++++-- .../lite/micro/micro_optional_debug_tools.cc | 2 + tensorflow/lite/optional_debug_tools.cc | 2 + tensorflow/lite/python/lite_test.py | 39 +++++++++++++++++-- .../benchmark/experimental/c/c_api_types.h | 14 +++++-- 12 files changed, 129 insertions(+), 30 deletions(-) diff --git a/tensorflow/lite/c/common.c b/tensorflow/lite/c/common.c index f70a60002dd..e6b47896528 100644 --- a/tensorflow/lite/c/common.c +++ b/tensorflow/lite/c/common.c @@ -79,7 +79,8 @@ TfLiteFloatArray* TfLiteFloatArrayCreate(int size) { void TfLiteFloatArrayFree(TfLiteFloatArray* a) { free(a); } void TfLiteTensorDataFree(TfLiteTensor* t) { - if (t->allocation_type == kTfLiteDynamic) { + if (t->allocation_type == kTfLiteDynamic || + t->allocation_type == kTfLitePersistentRo) { free(t->data.raw); } t->data.raw = NULL; @@ -172,7 +173,8 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, } void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) { - if (tensor->allocation_type != kTfLiteDynamic) { + if (tensor->allocation_type != kTfLiteDynamic && + tensor->allocation_type != kTfLitePersistentRo) { return; } // TODO(b/145340303): Tensor data should be aligned. diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h index 9657c7e564c..ab150e87d93 100644 --- a/tensorflow/lite/c/common.h +++ b/tensorflow/lite/c/common.h @@ -321,15 +321,23 @@ typedef union TfLitePtrUnion { void* data; } TfLitePtrUnion; -// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped -// data (or data externally allocated). kTfLiteArenaRw is arena allocated -// data. kTfLiteDynamic is for tensors that are allocated during evaluation. +// Memory allocation strategies. +// * kTfLiteMmapRo: Read-only memory-mapped data, or data externally allocated. +// * kTfLiteArenaRw: Arena allocated with no guarantees about persistence, +// and available during eval. +// * kTfLiteArenaRwPersistent: Arena allocated but persistent across eval, and +// only available during eval. +// * kTfLiteDynamic: Allocated during eval, or for string tensors. +// * kTfLitePersistentRo: Allocated and populated during prepare. This is +// useful for tensors that can be computed during prepare and treated +// as constant inputs for downstream ops (also in prepare). typedef enum TfLiteAllocationType { kTfLiteMemNone = 0, kTfLiteMmapRo, kTfLiteArenaRw, kTfLiteArenaRwPersistent, kTfLiteDynamic, + kTfLitePersistentRo, } TfLiteAllocationType; // The delegates should use zero or positive integers to represent handles. diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 4cebd059a80..7f4e0e286ea 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -1183,7 +1183,8 @@ TfLiteStatus Subgraph::ResizeTensorImpl(TfLiteTensor* tensor, // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too. if (tensor->allocation_type == kTfLiteArenaRw || tensor->allocation_type == kTfLiteDynamic || - tensor->allocation_type == kTfLiteArenaRwPersistent) { + tensor->allocation_type == kTfLiteArenaRwPersistent || + tensor->allocation_type == kTfLitePersistentRo) { tensor_resized_since_op_invoke_ |= TfLiteIntArrayEqual(tensor->dims, new_size) == 0; if (tensor->type != kTfLiteString) { @@ -1195,14 +1196,16 @@ TfLiteStatus Subgraph::ResizeTensorImpl(TfLiteTensor* tensor, return kTfLiteError; } - // Realloc space for kTfLiteDynamic tensors. + // Realloc space for heap-allocated tensors. TfLiteTensorRealloc(bytesRequired, tensor); tensor->bytes = bytesRequired; } if (tensor->dims) TfLiteIntArrayFree(tensor->dims); tensor->dims = new_size; - if (tensor->allocation_type != kTfLiteDynamic) { + // Reset arena-allocated tensors; they will be allocated later. + if (tensor->allocation_type == kTfLiteArenaRw || + tensor->allocation_type == kTfLiteArenaRwPersistent) { tensor->data.raw = nullptr; } } else { diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h index ad068ddd3fd..5793b08616d 100644 --- a/tensorflow/lite/kernels/kernel_util.h +++ b/tensorflow/lite/kernels/kernel_util.h @@ -87,6 +87,10 @@ inline const TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context, } // Determines whether tensor is constant. +// TODO(b/138199592): Introduce new query which checks for constant OR +// persistent-read-only, which would be useful for most tensor kernels that +// are potentially dynamic based on the input tensor value availability at the +// time of prepare. inline bool IsConstantTensor(const TfLiteTensor* tensor) { return tensor->allocation_type == kTfLiteMmapRo; } @@ -105,6 +109,14 @@ inline void SetTensorToDynamic(TfLiteTensor* tensor) { } } +// Sets tensor to persistent and read-only. +inline void SetTensorToPersistentRo(TfLiteTensor* tensor) { + if (tensor->allocation_type != kTfLitePersistentRo) { + tensor->allocation_type = kTfLitePersistentRo; + tensor->data.raw = nullptr; + } +} + // Determines whether it is a hybrid op - one that has float inputs and // quantized weights. inline bool IsHybridOp(const TfLiteTensor* input, const TfLiteTensor* weight) { diff --git a/tensorflow/lite/kernels/rank.cc b/tensorflow/lite/kernels/rank.cc index 8e27ebcc325..53fd92f1682 100644 --- a/tensorflow/lite/kernels/rank.cc +++ b/tensorflow/lite/kernels/rank.cc @@ -30,19 +30,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); output->type = kTfLiteInt32; + // By design, the input shape is always known at the time of Prepare, even + // if the preceding op that generates |input| is dynamic. Thus, we can + // always compute the rank immediately, without waiting for Eval. + SetTensorToPersistentRo(output); + // Rank produces a 0-D int32 Tensor representing the rank of input. TfLiteIntArray* output_size = TfLiteIntArrayCreate(0); - return context->ResizeTensor(context, output, output_size); -} + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_size)); -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, NumDimensions(output), 0); + // Immediately propagate the known rank to the output tensor. This allows + // downstream ops that rely on the value to use it during prepare. if (output->type == kTfLiteInt32) { int32_t* output_data = GetTensorData(output); *output_data = NumDimensions(input); @@ -53,6 +57,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + } // namespace rank TfLiteRegistration* Register_RANK() { diff --git a/tensorflow/lite/kernels/rank_test.cc b/tensorflow/lite/kernels/rank_test.cc index f3dc97126ba..5373a0a66fe 100644 --- a/tensorflow/lite/kernels/rank_test.cc +++ b/tensorflow/lite/kernels/rank_test.cc @@ -43,6 +43,9 @@ class RankOpModel : public SingleOpModel { std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } + TfLiteAllocationType GetOutputAllocationType() const { + return interpreter_->tensor(interpreter_->outputs()[0])->allocation_type; + } private: int input_; @@ -51,6 +54,13 @@ class RankOpModel : public SingleOpModel { TEST(RankOpTest, InputTypeFloat) { RankOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32); + ASSERT_EQ(model.GetOutputAllocationType(), kTfLitePersistentRo); + + // Unlike most ops, Rank populates outputs in Prepare(). + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5})); + EXPECT_TRUE(model.GetOutputShape().empty()); + + // Invoke is superfluous and shouldn't change the output. model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({5})); @@ -59,7 +69,6 @@ TEST(RankOpTest, InputTypeFloat) { TEST(RankOpTest, InputTypeInt) { RankOpModel model({1, 3, 1, 3, 5}, TensorType_INT32); - model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({5})); EXPECT_TRUE(model.GetOutputShape().empty()); @@ -67,7 +76,6 @@ TEST(RankOpTest, InputTypeInt) { TEST(RankOpTest, ScalarTensor) { RankOpModel model({}, TensorType_FLOAT32); - model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({0})); EXPECT_TRUE(model.GetOutputShape().empty()); @@ -75,7 +83,6 @@ TEST(RankOpTest, ScalarTensor) { TEST(RankOpTest, EmptyTensor) { RankOpModel model({1, 0}, TensorType_FLOAT32); - model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({2})); EXPECT_TRUE(model.GetOutputShape().empty()); diff --git a/tensorflow/lite/kernels/shape.cc b/tensorflow/lite/kernels/shape.cc index 88794fefac4..d979f083f70 100644 --- a/tensorflow/lite/kernels/shape.cc +++ b/tensorflow/lite/kernels/shape.cc @@ -54,19 +54,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } + // By design, the input shape is always known at the time of Prepare, even + // if the preceding op that generates |input| is dynamic. Thus, we can + // always compute the shape immediately, without waiting for Eval. + SetTensorToPersistentRo(output); + // Shape always produces a 1-dimensional output tensor, where each output // element is the length of the corresponding input tensor's dimension. TfLiteIntArray* output_size = TfLiteIntArrayCreate(1); output_size->data[0] = NumDimensions(input); - return context->ResizeTensor(context, output, output_size); -} + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_size)); -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TFLITE_DCHECK_EQ(NumDimensions(output), 1); TFLITE_DCHECK_EQ(SizeOfDimension(output, 0), NumDimensions(input)); + // Immediately propagate the known shape to the output tensor. This allows + // downstream ops that rely on the value to use it during prepare. switch (output->type) { case kTfLiteInt32: ExtractShape(input, GetTensorData(output)); @@ -81,6 +84,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + } // namespace shape TfLiteRegistration* Register_SHAPE() { diff --git a/tensorflow/lite/kernels/shape_test.cc b/tensorflow/lite/kernels/shape_test.cc index 6a7dad4d3e0..3eeb83f5000 100644 --- a/tensorflow/lite/kernels/shape_test.cc +++ b/tensorflow/lite/kernels/shape_test.cc @@ -45,6 +45,9 @@ class ShapeOpModel : public SingleOpModel { int32_t GetOutputSize() { return GetTensorSize(output_); } std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } + TfLiteAllocationType GetOutputAllocationType() const { + return interpreter_->tensor(interpreter_->outputs()[0])->allocation_type; + } private: int input_; @@ -54,6 +57,13 @@ class ShapeOpModel : public SingleOpModel { TEST(ShapeOpTest, OutTypeInt) { ShapeOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32, TensorType_INT32); + ASSERT_EQ(model.GetOutputAllocationType(), kTfLitePersistentRo); + + // Unlike most ops, Rank populates outputs in Prepare(). + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5})); + + // Invoke is superfluous and shouldn't change the output. model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5})); @@ -63,7 +73,6 @@ TEST(ShapeOpTest, OutTypeInt) { TEST(ShapeOpTest, OutTypeInt64) { ShapeOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32, TensorType_INT64); - model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5})); @@ -71,7 +80,6 @@ TEST(ShapeOpTest, OutTypeInt64) { TEST(ShapeOpTest, ScalarTensor) { ShapeOpModel model({}, TensorType_FLOAT32, TensorType_INT32); - model.Invoke(); EXPECT_EQ(model.GetOutputSize(), 0); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({0})); @@ -79,7 +87,6 @@ TEST(ShapeOpTest, ScalarTensor) { TEST(ShapeOpTest, EmptyTensor) { ShapeOpModel model({1, 0}, TensorType_FLOAT32, TensorType_INT32); - model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); diff --git a/tensorflow/lite/micro/micro_optional_debug_tools.cc b/tensorflow/lite/micro/micro_optional_debug_tools.cc index 70f16c78d79..42c42aea9f8 100644 --- a/tensorflow/lite/micro/micro_optional_debug_tools.cc +++ b/tensorflow/lite/micro/micro_optional_debug_tools.cc @@ -95,6 +95,8 @@ const char* AllocTypeName(TfLiteAllocationType type) { return "kTfLiteArenaRw"; case kTfLiteArenaRwPersistent: return "kTfLiteArenaRwPersistent"; + case kTfLitePersistentRo: + return "kTfLitePersistentRo"; } return "(invalid)"; } diff --git a/tensorflow/lite/optional_debug_tools.cc b/tensorflow/lite/optional_debug_tools.cc index c5ccdb98390..2e25b0a17f7 100644 --- a/tensorflow/lite/optional_debug_tools.cc +++ b/tensorflow/lite/optional_debug_tools.cc @@ -77,6 +77,8 @@ const char* AllocTypeName(TfLiteAllocationType type) { return "kTfLiteArenaRw"; case kTfLiteArenaRwPersistent: return "kTfLiteArenaRwPersistent"; + case kTfLitePersistentRo: + return "kTfLitePersistentRo"; } return "(invalid)"; } diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index 9ddd09edca6..1bcb2ce0ee4 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -269,9 +269,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): [out_tensor]) converter.inference_input_type = lite_constants.QUANTIZED_UINT8 converter.inference_type = lite_constants.FLOAT - converter.quantized_input_stats = { - 'Placeholder': (0., 1.) - } # mean, std_dev + converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev tflite_model = converter.convert() self.assertTrue(tflite_model) @@ -1327,6 +1325,41 @@ class FromSessionTest(TestModels, parameterized.TestCase): tflite_model = converter.convert() self.assertTrue(tflite_model) + def testResizeWithShape(self): + with ops.Graph().as_default(): + # Construct a graph with a dynamically shapped input and an internal node + # that relies on the output of that input's shape. + in_tensor = array_ops.placeholder( + shape=[None, None], dtype=dtypes.float32) + in_tensor2 = [[1, 2], [3, 4]] + out_tensor = array_ops.reshape(in_tensor2, array_ops.shape(in_tensor)) + sess = session.Session() + + converter = lite.TFLiteConverter.from_session(sess, [in_tensor], + [out_tensor]) + converter.experimental_new_converter = True + tflite_model = converter.convert() + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + input_details = interpreter.get_input_details() + self.assertLen(input_details, 1) + self.assertTrue(([1, 1] == input_details[0]['shape']).all()) + self.assertTrue(([-1, -1] == input_details[0]['shape_signature']).all()) + + # Resize tensor and invoke. + interpreter.resize_tensor_input(0, [4]) + interpreter.allocate_tensors() + interpreter.invoke() + + # The output should be reshaped properly according to the resized input. + output_details = interpreter.get_output_details() + self.assertLen(output_details, 1) + self.assertEqual(np.int32, output_details[0]['dtype']) + self.assertTrue(([4] == output_details[0]['shape']).all()) + output_data = interpreter.get_tensor(output_details[0]['index']) + self.assertTrue(([1, 2, 3, 4] == output_data).all()) + def testResizingIntermediateDynamicTensor(self): # This is a regression test for the case where shape of dynamic output # tensors changes between invocations. diff --git a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h index 9657c7e564c..ab150e87d93 100644 --- a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h +++ b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h @@ -321,15 +321,23 @@ typedef union TfLitePtrUnion { void* data; } TfLitePtrUnion; -// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped -// data (or data externally allocated). kTfLiteArenaRw is arena allocated -// data. kTfLiteDynamic is for tensors that are allocated during evaluation. +// Memory allocation strategies. +// * kTfLiteMmapRo: Read-only memory-mapped data, or data externally allocated. +// * kTfLiteArenaRw: Arena allocated with no guarantees about persistence, +// and available during eval. +// * kTfLiteArenaRwPersistent: Arena allocated but persistent across eval, and +// only available during eval. +// * kTfLiteDynamic: Allocated during eval, or for string tensors. +// * kTfLitePersistentRo: Allocated and populated during prepare. This is +// useful for tensors that can be computed during prepare and treated +// as constant inputs for downstream ops (also in prepare). typedef enum TfLiteAllocationType { kTfLiteMemNone = 0, kTfLiteMmapRo, kTfLiteArenaRw, kTfLiteArenaRwPersistent, kTfLiteDynamic, + kTfLitePersistentRo, } TfLiteAllocationType; // The delegates should use zero or positive integers to represent handles.