Add ResizeInputTensorStrict to C++ API.

PiperOrigin-RevId: 305802949
Change-Id: I131bf9a43f086fef4301a9858ab34cea379d7040
Nupur Garg 2020-04-09 18:32:02 -07:00 committed by TensorFlower Gardener
parent 1734e0c64e
commit 993a6e4f0e
5 changed files with 129 additions and 6 deletions
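In short, ResizeInputTensorStrict only allows changing dimensions that a tensor's dims_signature marks as unknown (-1); fully specified dimensions are rejected, unlike the existing ResizeInputTensor. An illustrative sketch, condensed from the tests added in this commit (the tensor index, name, and shapes are made up for the example, not part of the change):

#include <vector>
#include "tensorflow/lite/interpreter.h"

// Sketch: batch and height are unknown (-1); the channel dimension is fixed.
void ResizeStrictSketch() {
  tflite::Interpreter interpreter;
  interpreter.AddTensors(1);
  interpreter.SetInputs({0});
  interpreter.SetOutputs({0});

  std::vector<int> dims_signature = {-1, -1, 3};
  interpreter.SetTensorParametersReadWrite(
      0, kTfLiteFloat32, "input", {1, 1, 3}, TfLiteQuantizationParams(),
      /*is_variable=*/false, &dims_signature);

  interpreter.ResizeInputTensorStrict(0, {4, 2, 3});  // kTfLiteOk: only -1 dims change.
  interpreter.ResizeInputTensorStrict(0, {4, 2, 5});  // kTfLiteError: dim 2 is fixed at 3.
  interpreter.AllocateTensors();
}

As with ResizeInputTensor, AllocateTensors() still has to run after a successful resize before the next invocation.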


@@ -762,6 +762,36 @@ TfLiteStatus Subgraph::ResizeInputTensor(int tensor_index,
  return ResizeTensorImpl(tensor, ConvertVectorToTfLiteIntArray(dims));
}

TfLiteStatus Subgraph::ResizeInputTensorStrict(int tensor_index,
                                               const std::vector<int>& dims) {
  TF_LITE_ENSURE(&context_,
                 tensor_index < context_.tensors_size && tensor_index >= 0);
  TfLiteTensor* tensor = &context_.tensors[tensor_index];

  // Ensure that only unknown dimensions can be resized.
  TF_LITE_ENSURE_EQ(&context_, tensor->dims->size, dims.size());
  for (size_t idx = 0; idx < dims.size(); idx++) {
    // `dims_signature` is not defined when no unknown dimensions are present.
    int dim_signature;
    if (tensor->dims_signature && tensor->dims_signature->size) {
      dim_signature = tensor->dims_signature->data[idx];
    } else {
      dim_signature = tensor->dims->data[idx];
    }

    if (dim_signature != -1 && dim_signature != dims[idx]) {
      ReportError(
          "Attempting to resize dimension %d of tensor %d with value %d to %d. "
          "ResizeInputTensorStrict only allows mutating unknown dimensions "
          "identified by -1.",
          idx, tensor_index, dim_signature, dims[idx]);
      return kTfLiteError;
    }
  }

  return ResizeInputTensor(tensor_index, dims);
}

TfLiteStatus Subgraph::ReleaseNonPersistentMemory() {
  if (memory_planner_) {
    TF_LITE_ENSURE_STATUS(memory_planner_->ReleaseNonPersistentMemory());


@@ -221,6 +221,15 @@ class Subgraph {
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims);

  // WARNING: Experimental interface, subject to change
  // Change the dimensionality of a given tensor. This is only acceptable for
  // tensor indices that are inputs or variables. Only unknown dimensions can be
  // resized with this function. Unknown dimensions are indicated as `-1` in the
  // `dims_signature` attribute of a `TfLiteTensor`. Returns status of failure
  // or success.
  TfLiteStatus ResizeInputTensorStrict(int tensor_index,
                                       const std::vector<int>& dims);

  // This releases memory held by non-persistent tensors. It does NOT re-perform
  // memory planning.
  // AllocateTensors needs to be called before next invocation.


@@ -197,6 +197,11 @@ TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
  return primary_subgraph().ResizeInputTensor(tensor_index, dims);
}

TfLiteStatus Interpreter::ResizeInputTensorStrict(
    int tensor_index, const std::vector<int>& dims) {
  return primary_subgraph().ResizeInputTensorStrict(tensor_index, dims);
}

TfLiteStatus Interpreter::ReleaseNonPersistentMemory() {
  // TODO(b/138790287): We could do this for all subgraphs whose tensors have
  // been allocated. However, AllocateTensors() relies on Control Flow ops to
@@ -256,10 +261,12 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
    int tensor_index, TfLiteType type, const char* name, const size_t rank,
    const int* dims, TfLiteQuantizationParams quantization, bool is_variable) {
    const int* dims, TfLiteQuantizationParams quantization, bool is_variable,
    const size_t rank_dims_signature, const int* dims_signature) {
  TfLiteQuantization new_quantization = GetQuantizationFromLegacy(quantization);
  return primary_subgraph().SetTensorParametersReadWrite(
      tensor_index, type, name, rank, dims, new_quantization, is_variable);
      tensor_index, type, name, rank, dims, new_quantization, is_variable,
      rank_dims_signature, dims_signature);
}

TfLiteStatus Interpreter::SetExecutionPlan(const std::vector<int>& new_plan) {


@@ -166,14 +166,23 @@ class Interpreter {
  inline TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      bool is_variable = false) {
    return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
                                        dims.data(), quantization, is_variable);
      bool is_variable = false,
      const std::vector<int>* dims_signature = nullptr) {
    size_t rank_dims_signature = 0;
    const int* dims_signature_pointer = nullptr;
    if (dims_signature) {
      rank_dims_signature = dims_signature->size();
      dims_signature_pointer = dims_signature->data();
    }
    return SetTensorParametersReadWrite(
        tensor_index, type, name, dims.size(), dims.data(), quantization,
        is_variable, rank_dims_signature, dims_signature_pointer);
  }

  TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      bool is_variable = false);
      bool is_variable = false, const size_t rank_dims_signature = 0,
      const int* dims_signature = nullptr);

#endif  // DOXYGEN_SKIP

  // Functions to access tensor data
@@ -319,6 +328,15 @@ class Interpreter {
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims);

  // WARNING: Experimental interface, subject to change
  // Change the dimensionality of a given tensor. This is only acceptable for
  // tensor indices that are inputs or variables. Only unknown dimensions can be
  // resized with this function. Unknown dimensions are indicated as `-1` in the
  // `dims_signature` attribute of a `TfLiteTensor`. Returns status of failure
  // or success.
  TfLiteStatus ResizeInputTensorStrict(int tensor_index,
                                       const std::vector<int>& dims);

  // This releases memory held by non-persistent tensors. It does NOT re-perform
  // memory planning.
  // AllocateTensors needs to be called before next invocation.

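One consequence of the fallback in the implementation above: if a tensor was registered without a dims_signature, its current dims serve as the signature, so a strict resize can only restate the existing shape and any real change must go through the non-strict ResizeInputTensor. The first test below exercises this; an illustrative sketch (index, name, and shapes are made up):

#include "tensorflow/lite/interpreter.h"

// Sketch: no dims_signature, so ResizeInputTensorStrict rejects any change.
void StrictResizeWithoutSignature() {
  tflite::Interpreter interpreter;
  interpreter.AddTensors(1);
  interpreter.SetInputs({0});
  interpreter.SetOutputs({0});
  interpreter.SetTensorParametersReadWrite(
      0, kTfLiteFloat32, "input", {1, 1, 3}, TfLiteQuantizationParams());

  interpreter.ResizeInputTensorStrict(0, {1, 1, 3});  // kTfLiteOk: same shape.
  interpreter.ResizeInputTensorStrict(0, {1, 2, 3});  // kTfLiteError: dims are fixed.
  interpreter.ResizeInputTensor(0, {1, 2, 3});        // kTfLiteOk: non-strict path.
}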

@@ -551,6 +551,65 @@ TEST(BasicInterpreter, NoopResizingTensors) {
  ASSERT_EQ(tensor->data.f[5], 0.123f);
}

TEST(BasicInterpreter, ResizingTensorsStrictInvalid) {
  // Tests ResizeInputTensorStrict where `dims_signature` is not specified.
  Interpreter interpreter;
  ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
  ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
  ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
  ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
                0, kTfLiteFloat32, "", {1, 1, 3}, TfLiteQuantizationParams()),
            kTfLiteOk);

  int t = interpreter.inputs()[0];
  TfLiteTensor* tensor = interpreter.tensor(t);

  ASSERT_EQ(interpreter.ResizeInputTensorStrict(t, {1, 1, 3}), kTfLiteOk);
  EXPECT_EQ(tensor->bytes, 3 * sizeof(float));
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

  // Invalid because `dims_signature` is not specified.
  ASSERT_EQ(interpreter.ResizeInputTensorStrict(t, {1, 2, 3}), kTfLiteError);
  EXPECT_EQ(tensor->bytes, 3 * sizeof(float));
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

  // Assert that ResizeInputTensor works for this value.
  ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk);
  EXPECT_EQ(tensor->bytes, 6 * sizeof(float));
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
}

TEST(BasicInterpreter, ResizingTensorsStrict) {
  // Tests ResizeInputTensorStrict where `dims_signature` is specified.
  Interpreter interpreter;
  ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
  ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
  ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);

  std::vector<int> dims_signature = {-1, -1, 3};
  ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
                0, kTfLiteFloat32, "", {1, 1, 3}, TfLiteQuantizationParams(),
                false, &dims_signature),
            kTfLiteOk);

  int t = interpreter.inputs()[0];
  TfLiteTensor* tensor = interpreter.tensor(t);

  ASSERT_EQ(interpreter.ResizeInputTensorStrict(t, {1, 2, 3}), kTfLiteOk);
  EXPECT_EQ(tensor->bytes, 6 * sizeof(float));
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

  ASSERT_EQ(interpreter.ResizeInputTensorStrict(t, {1, 2, 4}), kTfLiteError);
  EXPECT_EQ(tensor->bytes, 6 * sizeof(float));
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

  // Assert that ResizeInputTensor works for this value.
  ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 4}), kTfLiteOk);
  EXPECT_EQ(tensor->bytes, 8 * sizeof(float));
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
}

// Simple op that does input = output.
TfLiteRegistration GetPassthroughOpRegistration() {
  TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};