Introduce persistent, read-only TFLite tensor type

Several operators (e.g. Rank, Shape) are critical for preserving the ability
to resize graphs correctly at runtime. However, introducing such ops into a
graph currently makes it impossible to fully propagate shapes at the time
tensors are allocated. It also prevents most delegates from delegating the
graph, as these ops introduce dynamic shapes.

Introduce a new, persistent tensor type that can be treated as "constant"
at the time of TfLiteRegistration::Prepare. This tensor type is allocated
immediately when requested, similar to a dynamic tensor, but promises that
its contents will be populated after the "producing" node is prepared, and
that it won't change across subsequent evals.

Update Rank/Shape operators to use this tensor allocation type.

A follow-up CL will introduce a new pseudo-constant tensor check
that can be used by various kernels to avoid making them dynamic.

PiperOrigin-RevId: 311199934
Change-Id: I050704be7d1ff264fc1a852efade53d4021cb034
Author: Jared Duke (2020-05-12 14:16:20 -07:00), committed by TensorFlower Gardener
Commit: f581c55e4d (parent: 1638fe218d)
12 changed files with 129 additions and 30 deletions
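
The producer-side pattern the Rank/Shape kernels adopt below can be summarized
as follows. This is a minimal, illustrative sketch distilled from the Rank
kernel changes in this commit (registration boilerplate and the int64 output
path are omitted); it assumes the usual kernel_util.h helpers (GetInput,
GetOutput, GetTensorData) plus the SetTensorToPersistentRo helper added here.

// Sketch: a Prepare() that produces a persistent, read-only output.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, /*index=*/0);
  TfLiteTensor* output = GetOutput(context, node, /*index=*/0);
  output->type = kTfLiteInt32;

  // Mark the output as persistent-RO: it is heap-allocated when resized, and
  // its contents remain valid from the end of Prepare() across all Evals.
  SetTensorToPersistentRo(output);

  // Allocate the scalar output now...
  TF_LITE_ENSURE_STATUS(
      context->ResizeTensor(context, output, TfLiteIntArrayCreate(0)));

  // ...and populate it immediately, so downstream nodes can treat it as a
  // constant input during their own Prepare().
  *GetTensorData<int32_t>(output) = NumDimensions(input);
  return kTfLiteOk;
}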


@@ -79,7 +79,8 @@ TfLiteFloatArray* TfLiteFloatArrayCreate(int size) {
 void TfLiteFloatArrayFree(TfLiteFloatArray* a) { free(a); }
 
 void TfLiteTensorDataFree(TfLiteTensor* t) {
-  if (t->allocation_type == kTfLiteDynamic) {
+  if (t->allocation_type == kTfLiteDynamic ||
+      t->allocation_type == kTfLitePersistentRo) {
     free(t->data.raw);
   }
   t->data.raw = NULL;
@@ -172,7 +173,8 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
 }
 
 void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
-  if (tensor->allocation_type != kTfLiteDynamic) {
+  if (tensor->allocation_type != kTfLiteDynamic &&
+      tensor->allocation_type != kTfLitePersistentRo) {
     return;
   }
   // TODO(b/145340303): Tensor data should be aligned.


@@ -321,15 +321,23 @@ typedef union TfLitePtrUnion {
   void* data;
 } TfLitePtrUnion;
 
-// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
-// data (or data externally allocated). kTfLiteArenaRw is arena allocated
-// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
+// Memory allocation strategies.
+//  * kTfLiteMmapRo: Read-only memory-mapped data, or data externally allocated.
+//  * kTfLiteArenaRw: Arena allocated with no guarantees about persistence,
+//    and available during eval.
+//  * kTfLiteArenaRwPersistent: Arena allocated but persistent across eval, and
+//    only available during eval.
+//  * kTfLiteDynamic: Allocated during eval, or for string tensors.
+//  * kTfLitePersistentRo: Allocated and populated during prepare. This is
+//    useful for tensors that can be computed during prepare and treated
+//    as constant inputs for downstream ops (also in prepare).
 typedef enum TfLiteAllocationType {
   kTfLiteMemNone = 0,
   kTfLiteMmapRo,
   kTfLiteArenaRw,
   kTfLiteArenaRwPersistent,
   kTfLiteDynamic,
+  kTfLitePersistentRo,
 } TfLiteAllocationType;
 
 // The delegates should use zero or positive integers to represent handles.


@@ -1183,7 +1183,8 @@ TfLiteStatus Subgraph::ResizeTensorImpl(TfLiteTensor* tensor,
   // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too.
   if (tensor->allocation_type == kTfLiteArenaRw ||
       tensor->allocation_type == kTfLiteDynamic ||
-      tensor->allocation_type == kTfLiteArenaRwPersistent) {
+      tensor->allocation_type == kTfLiteArenaRwPersistent ||
+      tensor->allocation_type == kTfLitePersistentRo) {
     tensor_resized_since_op_invoke_ |=
         TfLiteIntArrayEqual(tensor->dims, new_size) == 0;
     if (tensor->type != kTfLiteString) {
@@ -1195,14 +1196,16 @@ TfLiteStatus Subgraph::ResizeTensorImpl(TfLiteTensor* tensor,
       return kTfLiteError;
     }
 
-    // Realloc space for kTfLiteDynamic tensors.
+    // Realloc space for heap-allocated tensors.
     TfLiteTensorRealloc(bytesRequired, tensor);
     tensor->bytes = bytesRequired;
 
     if (tensor->dims) TfLiteIntArrayFree(tensor->dims);
     tensor->dims = new_size;
 
-    if (tensor->allocation_type != kTfLiteDynamic) {
+    // Reset arena-allocated tensors; they will be allocated later.
+    if (tensor->allocation_type == kTfLiteArenaRw ||
+        tensor->allocation_type == kTfLiteArenaRwPersistent) {
       tensor->data.raw = nullptr;
     }
   } else {


@@ -87,6 +87,10 @@ inline const TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context,
 }
 
 // Determines whether tensor is constant.
+// TODO(b/138199592): Introduce new query which checks for constant OR
+// persistent-read-only, which would be useful for most tensor kernels that
+// are potentially dynamic based on the input tensor value availability at the
+// time of prepare.
 inline bool IsConstantTensor(const TfLiteTensor* tensor) {
   return tensor->allocation_type == kTfLiteMmapRo;
 }
@@ -105,6 +109,14 @@ inline void SetTensorToDynamic(TfLiteTensor* tensor) {
   }
 }
 
+// Sets tensor to persistent and read-only.
+inline void SetTensorToPersistentRo(TfLiteTensor* tensor) {
+  if (tensor->allocation_type != kTfLitePersistentRo) {
+    tensor->allocation_type = kTfLitePersistentRo;
+    tensor->data.raw = nullptr;
+  }
+}
+
 // Determines whether it is a hybrid op - one that has float inputs and
 // quantized weights.
 inline bool IsHybridOp(const TfLiteTensor* input, const TfLiteTensor* weight) {
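
The TODO above corresponds to the "pseudo-constant tensor check" promised as a
follow-up in the commit message. As a rough sketch of what such a query could
look like (the name IsConstantOrPersistentTensor is hypothetical and not part
of this change):

// Hypothetical follow-up helper: treat persistent read-only tensors like
// constants, so kernels whose output shape depends on an input *value* can
// still resolve that shape during Prepare() instead of becoming dynamic.
inline bool IsConstantOrPersistentTensor(const TfLiteTensor* tensor) {
  return IsConstantTensor(tensor) ||
         tensor->allocation_type == kTfLitePersistentRo;
}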


@@ -30,19 +30,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
   output->type = kTfLiteInt32;
 
+  // By design, the input shape is always known at the time of Prepare, even
+  // if the preceding op that generates |input| is dynamic. Thus, we can
+  // always compute the rank immediately, without waiting for Eval.
+  SetTensorToPersistentRo(output);
+
   // Rank produces a 0-D int32 Tensor representing the rank of input.
   TfLiteIntArray* output_size = TfLiteIntArrayCreate(0);
-  return context->ResizeTensor(context, output, output_size);
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_size));
 
   TF_LITE_ENSURE_EQ(context, NumDimensions(output), 0);
 
+  // Immediately propagate the known rank to the output tensor. This allows
+  // downstream ops that rely on the value to use it during prepare.
   if (output->type == kTfLiteInt32) {
     int32_t* output_data = GetTensorData<int32_t>(output);
     *output_data = NumDimensions(input);
@@ -53,6 +57,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
 }  // namespace rank
 
 TfLiteRegistration* Register_RANK() {


@@ -43,6 +43,9 @@ class RankOpModel : public SingleOpModel {
   std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+  TfLiteAllocationType GetOutputAllocationType() const {
+    return interpreter_->tensor(interpreter_->outputs()[0])->allocation_type;
+  }
 
  private:
   int input_;
@@ -51,6 +54,13 @@ class RankOpModel : public SingleOpModel {
 
 TEST(RankOpTest, InputTypeFloat) {
   RankOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32);
+  ASSERT_EQ(model.GetOutputAllocationType(), kTfLitePersistentRo);
+
+  // Unlike most ops, Rank populates outputs in Prepare().
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
+  EXPECT_TRUE(model.GetOutputShape().empty());
+
+  // Invoke is superfluous and shouldn't change the output.
   model.Invoke();
 
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
@@ -59,7 +69,6 @@ TEST(RankOpTest, InputTypeFloat) {
 
 TEST(RankOpTest, InputTypeInt) {
   RankOpModel model({1, 3, 1, 3, 5}, TensorType_INT32);
-  model.Invoke();
 
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
   EXPECT_TRUE(model.GetOutputShape().empty());
@@ -67,7 +76,6 @@ TEST(RankOpTest, InputTypeInt) {
 
 TEST(RankOpTest, ScalarTensor) {
   RankOpModel model({}, TensorType_FLOAT32);
-  model.Invoke();
 
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
   EXPECT_TRUE(model.GetOutputShape().empty());
@@ -75,7 +83,6 @@ TEST(RankOpTest, ScalarTensor) {
 
 TEST(RankOpTest, EmptyTensor) {
   RankOpModel model({1, 0}, TensorType_FLOAT32);
-  model.Invoke();
 
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({2}));
   EXPECT_TRUE(model.GetOutputShape().empty());


@@ -54,19 +54,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteError;
   }
 
+  // By design, the input shape is always known at the time of Prepare, even
+  // if the preceding op that generates |input| is dynamic. Thus, we can
+  // always compute the shape immediately, without waiting for Eval.
+  SetTensorToPersistentRo(output);
+
   // Shape always produces a 1-dimensional output tensor, where each output
   // element is the length of the corresponding input tensor's dimension.
   TfLiteIntArray* output_size = TfLiteIntArrayCreate(1);
   output_size->data[0] = NumDimensions(input);
-  return context->ResizeTensor(context, output, output_size);
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_size));
 
   TFLITE_DCHECK_EQ(NumDimensions(output), 1);
   TFLITE_DCHECK_EQ(SizeOfDimension(output, 0), NumDimensions(input));
 
+  // Immediately propagate the known shape to the output tensor. This allows
+  // downstream ops that rely on the value to use it during prepare.
   switch (output->type) {
     case kTfLiteInt32:
       ExtractShape(input, GetTensorData<int32_t>(output));
@@ -81,6 +84,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
 }  // namespace shape
 
 TfLiteRegistration* Register_SHAPE() {


@@ -45,6 +45,9 @@ class ShapeOpModel : public SingleOpModel {
   int32_t GetOutputSize() { return GetTensorSize(output_); }
   std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+  TfLiteAllocationType GetOutputAllocationType() const {
+    return interpreter_->tensor(interpreter_->outputs()[0])->allocation_type;
+  }
 
  private:
   int input_;
@@ -54,6 +57,13 @@ class ShapeOpModel : public SingleOpModel {
 
 TEST(ShapeOpTest, OutTypeInt) {
   ShapeOpModel<int32_t> model({1, 3, 1, 3, 5}, TensorType_FLOAT32,
                               TensorType_INT32);
+  ASSERT_EQ(model.GetOutputAllocationType(), kTfLitePersistentRo);
+
+  // Unlike most ops, Shape populates outputs in Prepare().
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5}));
+
+  // Invoke is superfluous and shouldn't change the output.
   model.Invoke();
 
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5}));
@@ -63,7 +73,6 @@ TEST(ShapeOpTest, OutTypeInt) {
 
 TEST(ShapeOpTest, OutTypeInt64) {
   ShapeOpModel<int64_t> model({1, 3, 1, 3, 5}, TensorType_FLOAT32,
                               TensorType_INT64);
-  model.Invoke();
 
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5}));
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5}));
@@ -71,7 +80,6 @@ TEST(ShapeOpTest, OutTypeInt64) {
 
 TEST(ShapeOpTest, ScalarTensor) {
   ShapeOpModel<int32_t> model({}, TensorType_FLOAT32, TensorType_INT32);
-  model.Invoke();
 
   EXPECT_EQ(model.GetOutputSize(), 0);
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({0}));
@@ -79,7 +87,6 @@ TEST(ShapeOpTest, ScalarTensor) {
 
 TEST(ShapeOpTest, EmptyTensor) {
   ShapeOpModel<int32_t> model({1, 0}, TensorType_FLOAT32, TensorType_INT32);
-  model.Invoke();
 
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0}));
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));


@@ -95,6 +95,8 @@ const char* AllocTypeName(TfLiteAllocationType type) {
       return "kTfLiteArenaRw";
     case kTfLiteArenaRwPersistent:
       return "kTfLiteArenaRwPersistent";
+    case kTfLitePersistentRo:
+      return "kTfLitePersistentRo";
   }
   return "(invalid)";
 }


@@ -77,6 +77,8 @@ const char* AllocTypeName(TfLiteAllocationType type) {
       return "kTfLiteArenaRw";
     case kTfLiteArenaRwPersistent:
       return "kTfLiteArenaRwPersistent";
+    case kTfLitePersistentRo:
+      return "kTfLitePersistentRo";
   }
   return "(invalid)";
 }


@@ -269,9 +269,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
                                                   [out_tensor])
     converter.inference_input_type = lite_constants.QUANTIZED_UINT8
     converter.inference_type = lite_constants.FLOAT
-    converter.quantized_input_stats = {
-        'Placeholder': (0., 1.)
-    }  # mean, std_dev
+    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
     tflite_model = converter.convert()
 
     self.assertTrue(tflite_model)
@@ -1327,6 +1325,41 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     tflite_model = converter.convert()
     self.assertTrue(tflite_model)
 
+  def testResizeWithShape(self):
+    with ops.Graph().as_default():
+      # Construct a graph with a dynamically shaped input and an internal node
+      # that relies on the output of that input's shape.
+      in_tensor = array_ops.placeholder(
+          shape=[None, None], dtype=dtypes.float32)
+      in_tensor2 = [[1, 2], [3, 4]]
+      out_tensor = array_ops.reshape(in_tensor2, array_ops.shape(in_tensor))
+      sess = session.Session()
+
+    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+                                                  [out_tensor])
+    converter.experimental_new_converter = True
+    tflite_model = converter.convert()
+
+    # Check values from converted model.
+    interpreter = Interpreter(model_content=tflite_model)
+    input_details = interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertTrue(([1, 1] == input_details[0]['shape']).all())
+    self.assertTrue(([-1, -1] == input_details[0]['shape_signature']).all())
+
+    # Resize tensor and invoke.
+    interpreter.resize_tensor_input(0, [4])
+    interpreter.allocate_tensors()
+    interpreter.invoke()
+
+    # The output should be reshaped properly according to the resized input.
+    output_details = interpreter.get_output_details()
+    self.assertLen(output_details, 1)
+    self.assertEqual(np.int32, output_details[0]['dtype'])
+    self.assertTrue(([4] == output_details[0]['shape']).all())
+    output_data = interpreter.get_tensor(output_details[0]['index'])
+    self.assertTrue(([1, 2, 3, 4] == output_data).all())
+
   def testResizingIntermediateDynamicTensor(self):
     # This is a regression test for the case where shape of dynamic output
     # tensors changes between invocations.


@@ -321,15 +321,23 @@ typedef union TfLitePtrUnion {
   void* data;
 } TfLitePtrUnion;
 
-// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
-// data (or data externally allocated). kTfLiteArenaRw is arena allocated
-// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
+// Memory allocation strategies.
+//  * kTfLiteMmapRo: Read-only memory-mapped data, or data externally allocated.
+//  * kTfLiteArenaRw: Arena allocated with no guarantees about persistence,
+//    and available during eval.
+//  * kTfLiteArenaRwPersistent: Arena allocated but persistent across eval, and
+//    only available during eval.
+//  * kTfLiteDynamic: Allocated during eval, or for string tensors.
+//  * kTfLitePersistentRo: Allocated and populated during prepare. This is
+//    useful for tensors that can be computed during prepare and treated
+//    as constant inputs for downstream ops (also in prepare).
 typedef enum TfLiteAllocationType {
   kTfLiteMemNone = 0,
   kTfLiteMmapRo,
   kTfLiteArenaRw,
   kTfLiteArenaRwPersistent,
   kTfLiteDynamic,
+  kTfLitePersistentRo,
 } TfLiteAllocationType;
 
 // The delegates should use zero or positive integers to represent handles.