From a8d2d3bd42d3d2d2b359164005173a5b5faa0f79 Mon Sep 17 00:00:00 2001 From: Yu-Cheng Ling Date: Thu, 25 Jul 2019 11:29:58 -0700 Subject: [PATCH] Prorotype TFLite resource variables PiperOrigin-RevId: 259986961 --- tensorflow/lite/BUILD | 1 + tensorflow/lite/core/subgraph.cc | 6 +- tensorflow/lite/core/subgraph.h | 13 +- .../lite/experimental/resource_variable/BUILD | 17 ++ .../resource_variable/resource_variable.cc | 78 +++++++++ .../resource_variable/resource_variable.h | 62 ++++++++ tensorflow/lite/interpreter.cc | 4 +- tensorflow/lite/interpreter.h | 5 + tensorflow/lite/kernels/BUILD | 30 ++++ tensorflow/lite/kernels/assign_variable.cc | 86 ++++++++++ tensorflow/lite/kernels/read_variable.cc | 88 +++++++++++ tensorflow/lite/kernels/variable_ops_test.cc | 149 ++++++++++++++++++ 12 files changed, 534 insertions(+), 5 deletions(-) create mode 100644 tensorflow/lite/experimental/resource_variable/BUILD create mode 100644 tensorflow/lite/experimental/resource_variable/resource_variable.cc create mode 100644 tensorflow/lite/experimental/resource_variable/resource_variable.h create mode 100644 tensorflow/lite/kernels/assign_variable.cc create mode 100644 tensorflow/lite/kernels/read_variable.cc create mode 100644 tensorflow/lite/kernels/variable_ops_test.cc diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index e97de3d0f2e..853ba3d473c 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -224,6 +224,7 @@ cc_library( "//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/nnapi:nnapi_implementation", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/experimental/resource_variable:resource_variable", ] + select({ ":with_select_tf_ops": [ "//tensorflow/lite/delegates/flex:delegate", diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index acbd41d19b8..b77f6fa09ef 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -156,12 +156,14 @@ class InterpreterInfo : public GraphInfo { Subgraph::Subgraph(ErrorReporter* error_reporter, TfLiteExternalContext** external_contexts, - std::vector>* subgraphs) + std::vector>* subgraphs, + ResourceVariableMap* resource_variables) : external_contexts_(external_contexts), error_reporter_(error_reporter), next_execution_plan_index_to_prepare_(0), next_execution_plan_index_to_plan_allocation_(0), - subgraphs_(subgraphs) { + subgraphs_(subgraphs), + resource_variables_(resource_variables) { context_.impl_ = static_cast(this); context_.ResizeTensor = ResizeTensor; context_.ReportError = ReportErrorC; diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 0a6bb634cfd..b9736d89f9a 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -16,12 +16,14 @@ limitations under the License. #define TENSORFLOW_LITE_CORE_SUBGRAPH_H_ #include +#include #include #include "tensorflow/lite/allocation.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/core/api/profiler.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" +#include "tensorflow/lite/experimental/resource_variable/resource_variable.h" #include "tensorflow/lite/memory_planner.h" #include "tensorflow/lite/util.h" @@ -36,7 +38,8 @@ class Subgraph { Subgraph(ErrorReporter* error_reporter, TfLiteExternalContext** external_contexts, - std::vector>* subgraphs); + std::vector>* subgraphs, + ResourceVariableMap* resource_variables); Subgraph(const Subgraph&) = delete; @@ -160,6 +163,10 @@ class Subgraph { // Read only access to list of variable tensors. const std::vector& variables() const { return variables_; } + // WARNING: Experimental interface, subject to change. + // TODO(ycling): Move this function to an external context interface. + ResourceVariableMap& resource_variables() { return *resource_variables_; } + size_t tensors_size() const { return tensors_.size(); } // Return the number of ops in the model. @@ -581,6 +588,10 @@ class Subgraph { // Reference to data used by the cancellation function in // `check_cancelled_func_`. void* cancellation_data_ = nullptr; + + // A map of resource variables. Owned by interpreter and shared by multiple + // subgraphs. + ResourceVariableMap* resource_variables_ = nullptr; }; } // namespace tflite diff --git a/tensorflow/lite/experimental/resource_variable/BUILD b/tensorflow/lite/experimental/resource_variable/BUILD new file mode 100644 index 00000000000..af2ed19d214 --- /dev/null +++ b/tensorflow/lite/experimental/resource_variable/BUILD @@ -0,0 +1,17 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "resource_variable", + srcs = [ + "resource_variable.cc", + ], + hdrs = [ + "resource_variable.h", + ], + deps = [ + "//tensorflow/lite/c:c_api_internal", + ], +) diff --git a/tensorflow/lite/experimental/resource_variable/resource_variable.cc b/tensorflow/lite/experimental/resource_variable/resource_variable.cc new file mode 100644 index 00000000000..502ca273464 --- /dev/null +++ b/tensorflow/lite/experimental/resource_variable/resource_variable.cc @@ -0,0 +1,78 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/resource_variable/resource_variable.h" + +#include +#include +#include + +namespace tflite { + +ResourceVariable::ResourceVariable() { + memset(&tensor_, 0, sizeof(TfLiteTensor)); +} + +ResourceVariable::ResourceVariable(ResourceVariable&& other) { + tensor_ = other.tensor_; + is_initialized_ = other.is_initialized_; + + memset(&other.tensor_, 0, sizeof(TfLiteTensor)); + other.is_initialized_ = false; +} + +ResourceVariable::~ResourceVariable() { + if (is_initialized_) { + free(tensor_.data.raw); + if (tensor_.dims) { + TfLiteIntArrayFree(tensor_.dims); + } + } +} + +TfLiteStatus ResourceVariable::AssignFrom(const TfLiteTensor* tensor) { + // Save the old allocated resources and attributes that we might use. + char* old_raw = tensor_.data.raw; + size_t old_bytes = tensor_.bytes; + TfLiteIntArray* old_dims = tensor_.dims; + + // Copy primitive parameters. + memset(&tensor_, 0, sizeof(tensor_)); + tensor_.allocation_type = kTfLiteDynamic; + tensor_.type = tensor->type; + tensor_.params = tensor->params; + tensor_.quantization = tensor->quantization; + + // Copy old shape if possible otherwise create a new one. + if (TfLiteIntArrayEqual(old_dims, tensor->dims)) { + tensor_.dims = old_dims; + } else { + TfLiteIntArrayFree(old_dims); + tensor_.dims = TfLiteIntArrayCopy(tensor->dims); + } + + // Reuse the same buffer if possible otherwise allocate a new one. + tensor_.data.raw = old_raw; + if (old_bytes != tensor->bytes) { + TfLiteTensorRealloc(tensor->bytes, &tensor_); + } + + memcpy(tensor_.data.raw, tensor->data.raw, tensor_.bytes); + is_initialized_ = true; + + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/lite/experimental/resource_variable/resource_variable.h b/tensorflow/lite/experimental/resource_variable/resource_variable.h new file mode 100644 index 00000000000..6a938489eea --- /dev/null +++ b/tensorflow/lite/experimental/resource_variable/resource_variable.h @@ -0,0 +1,62 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_ + +#include + +#include "tensorflow/lite/c/c_api_internal.h" + +namespace tflite { + +/// WARNING: Experimental interface, subject to change. +// A resource variable class. It's similar to TensorFlow Resource +// Variable, but it's identified with int32 ID in TFLite (instead of +// using Resource handle like TensorFlow). +// +// TODO(b/137042749): TFLite converter cannot convert variables yet. +// Variable functionalities are only tested with unit tests now. +class ResourceVariable { + public: + ResourceVariable(); + ResourceVariable(ResourceVariable&& other); + + ResourceVariable(const ResourceVariable&) = delete; + ResourceVariable& operator=(const ResourceVariable&) = delete; + + ~ResourceVariable(); + + // Assigns data from a tensor. Copies its type, shape and data over. + TfLiteStatus AssignFrom(const TfLiteTensor* tensor); + + // Get the data tensor stored in the resource variable. + // Returns `nullptr` if the variable is never initialized by calling + // `AssignFrom`. + TfLiteTensor* GetTensor() { return is_initialized_ ? &tensor_ : nullptr; } + + private: + // The tensor (and its buffer stored in `tensor_.data` is fully owned by + // the `ResourceVariable` object. + TfLiteTensor tensor_; + // True if `AssignFrom` function is every called. + // False if and only if `tensor_` is filled with zeros. + bool is_initialized_ = false; +}; + +using ResourceVariableMap = std::unordered_map; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_ diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc index bf72f7822ad..6ef6c2ce194 100644 --- a/tensorflow/lite/interpreter.cc +++ b/tensorflow/lite/interpreter.cc @@ -134,8 +134,8 @@ void Interpreter::AddSubgraphs(int subgraphs_to_add, subgraphs_.reserve(base_index + subgraphs_to_add); for (int i = 0; i < subgraphs_to_add; ++i) { - Subgraph* subgraph = - new Subgraph(error_reporter_, external_contexts_, &subgraphs_); + Subgraph* subgraph = new Subgraph(error_reporter_, external_contexts_, + &subgraphs_, &resource_variables_); subgraphs_.emplace_back(subgraph); } } diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h index 8eef58530e2..397d47a6a8d 100644 --- a/tensorflow/lite/interpreter.h +++ b/tensorflow/lite/interpreter.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/profiler.h" #include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/experimental/resource_variable/resource_variable.h" #include "tensorflow/lite/external_cpu_backend_context.h" #include "tensorflow/lite/memory_planner.h" #include "tensorflow/lite/stderr_reporter.h" @@ -539,6 +540,10 @@ class Interpreter { // Subgraphs std::vector> subgraphs_; + + // A map of resource variables. Owned by interpreter and shared by multiple + // subgraphs. + ResourceVariableMap resource_variables_; }; } // namespace tflite diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 4d3876ec0e5..d1088d335ba 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -493,6 +493,36 @@ cc_library( ], ) +cc_library( + name = "variable_op_kernels", + srcs = [ + "assign_variable.cc", + "read_variable.cc", + ], + deps = [ + ":kernel_util", + ":op_macros", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels/internal:tensor", + ], +) + +cc_test( + name = "variable_ops_test", + size = "small", + srcs = [ + "variable_ops_test.cc", + ], + deps = [ + ":test_main", + ":test_util", + ":variable_op_kernels", + "//tensorflow/lite:framework", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "custom_ops", srcs = ["rfft2d.cc"], diff --git a/tensorflow/lite/kernels/assign_variable.cc b/tensorflow/lite/kernels/assign_variable.cc new file mode 100644 index 00000000000..099b8e16cfb --- /dev/null +++ b/tensorflow/lite/kernels/assign_variable.cc @@ -0,0 +1,86 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace assign_variable { + +constexpr int kInputVariableId = 0; +constexpr int kInputValue = 1; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + // TODO(b/137042749): TFLite infrastructure (converter, delegate) doesn't + // fully support 0-output ops yet. Currently it works if we manually crfat + // a TFLite graph that contains variable ops. Note: + // * The TFLite Converter need to be changed to be able to produce an op + // with 0 output. + // * The delegation code need to be changed to handle 0 output ops. However + // everything still works fine when variable ops aren't used. + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0); + + const TfLiteTensor* input_variable_id_tensor = + GetInput(context, node, kInputVariableId); + TF_LITE_ENSURE_EQ(context, input_variable_id_tensor->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, NumElements(input_variable_id_tensor), 1); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + Subgraph* subgraph = reinterpret_cast(context->impl_); + + const TfLiteTensor* input_variable_id_tensor = + GetInput(context, node, kInputVariableId); + const TfLiteTensor* input_value_tensor = GetInput(context, node, kInputValue); + + int variable_id = input_variable_id_tensor->data.i32[0]; + auto& resource_variables = subgraph->resource_variables(); + + auto variable_iterator = resource_variables.find(variable_id); + if (variable_iterator == resource_variables.end()) { + auto ret = resource_variables.emplace(variable_id, ResourceVariable()); + variable_iterator = ret.first; + } + + auto& variable = variable_iterator->second; + variable.AssignFrom(input_value_tensor); + + return kTfLiteOk; +} + +} // namespace assign_variable + +TfLiteRegistration* Register_ASSIGN_VARIABLE() { + static TfLiteRegistration r = {nullptr, nullptr, assign_variable::Prepare, + assign_variable::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/read_variable.cc b/tensorflow/lite/kernels/read_variable.cc new file mode 100644 index 00000000000..4996bcc0b4a --- /dev/null +++ b/tensorflow/lite/kernels/read_variable.cc @@ -0,0 +1,88 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace read_variable { + +constexpr int kInputVariableId = 0; +constexpr int kOutputValue = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, node->inputs->size, 1); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + + const TfLiteTensor* input_variable_id_tensor = + GetInput(context, node, kInputVariableId); + TF_LITE_ENSURE_EQ(context, input_variable_id_tensor->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, NumElements(input_variable_id_tensor), 1); + + TfLiteTensor* output = GetOutput(context, node, kOutputValue); + SetTensorToDynamic(output); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + Subgraph* subgraph = reinterpret_cast(context->impl_); + + const TfLiteTensor* input_variable_id_tensor = + GetInput(context, node, kInputVariableId); + int variable_id = input_variable_id_tensor->data.i32[0]; + auto& resource_variables = subgraph->resource_variables(); + + const auto& variable_iterator = resource_variables.find(variable_id); + if (variable_iterator == resource_variables.end()) { + context->ReportError(context, "Variable ID %d is read before initialized.", + variable_id); + return kTfLiteError; + } + auto& variable = variable_iterator->second; + + TfLiteTensor* variable_tensor = variable.GetTensor(); + TfLiteTensor* output = GetOutput(context, node, kOutputValue); + + TF_LITE_ENSURE_EQ(context, variable_tensor->type, output->type); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor( + context, output, TfLiteIntArrayCopy(variable_tensor->dims))); + memcpy(output->data.raw, variable_tensor->data.raw, output->bytes); + + return kTfLiteOk; +} + +} // namespace read_variable + +TfLiteRegistration* Register_READ_VARIABLE() { + static TfLiteRegistration r = {nullptr, nullptr, read_variable::Prepare, + read_variable::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/variable_ops_test.cc b/tensorflow/lite/kernels/variable_ops_test.cc new file mode 100644 index 00000000000..e6e1a403f99 --- /dev/null +++ b/tensorflow/lite/kernels/variable_ops_test.cc @@ -0,0 +1,149 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { + +// Forward declaraction for op kernels. +namespace ops { +namespace custom { + +TfLiteRegistration* Register_ASSIGN_VARIABLE(); +TfLiteRegistration* Register_READ_VARIABLE(); + +} // namespace custom +} // namespace ops + +namespace { + +class VariableOpsTest : public ::testing::Test { + protected: + void SetUp() override { + assign_registration_ = ::tflite::ops::custom::Register_ASSIGN_VARIABLE(); + ASSERT_NE(assign_registration_, nullptr); + read_registration_ = ::tflite::ops::custom::Register_READ_VARIABLE(); + ASSERT_NE(read_registration_, nullptr); + + ConstructGraph(); + } + + void ConstructGraph() { + // Construct a graph like ths: + // Input: %0, %1, %2 + // Output: %3 + // variable_assign(%0, %2) + // %3 = read(%1) + + int first_new_tensor_index; + ASSERT_EQ(interpreter_.AddTensors(4, &first_new_tensor_index), kTfLiteOk); + ASSERT_EQ(interpreter_.SetInputs({0, 1, 2}), kTfLiteOk); + ASSERT_EQ(interpreter_.SetOutputs({3}), kTfLiteOk); + interpreter_.SetTensorParametersReadWrite(0, kTfLiteInt32, "", 0, nullptr, + {}, false); + interpreter_.SetTensorParametersReadWrite(1, kTfLiteInt32, "", 0, nullptr, + {}, false); + interpreter_.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", 0, nullptr, + {}, false); + interpreter_.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", 0, nullptr, + {}, false); + int node_index; + interpreter_.AddNodeWithParameters({0, 2}, {}, nullptr, 0, nullptr, + assign_registration_, &node_index); + interpreter_.AddNodeWithParameters({1}, {3}, nullptr, 0, nullptr, + read_registration_, &node_index); + } + TfLiteRegistration* assign_registration_; + TfLiteRegistration* read_registration_; + Interpreter interpreter_; +}; + +TEST_F(VariableOpsTest, TestAssignThenReadVariable) { + ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk); + TfLiteTensor* input_assign_index = interpreter_.tensor(0); + input_assign_index->data.i32[0] = 1; + TfLiteTensor* input_read_index = interpreter_.tensor(1); + input_read_index->data.i32[0] = 1; + TfLiteTensor* input_data_index = interpreter_.tensor(2); + input_data_index->data.f[0] = 1717; + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + + // Verify output. + TfLiteTensor* output = interpreter_.tensor(3); + ASSERT_EQ(output->dims->size, 0); + EXPECT_EQ(output->data.f[0], 1717); +} + +TEST_F(VariableOpsTest, TestReadVariableBeforeAssign) { + ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk); + TfLiteTensor* input_assign_index = interpreter_.tensor(0); + input_assign_index->data.i32[0] = 1; + TfLiteTensor* input_read_index = interpreter_.tensor(1); + input_read_index->data.i32[0] = 2; + TfLiteTensor* input_data_index = interpreter_.tensor(2); + input_data_index->data.f[0] = 1717; + + // Error because variable 2 is never initialized. + ASSERT_EQ(interpreter_.Invoke(), kTfLiteError); +} + +TEST_F(VariableOpsTest, TestReeasignToDifferentSize) { + // 1st invocation. The variable is assigned as a scalar. + { + ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk); + + TfLiteTensor* input_assign_index = interpreter_.tensor(0); + input_assign_index->data.i32[0] = 1; + TfLiteTensor* input_read_index = interpreter_.tensor(1); + input_read_index->data.i32[0] = 1; + TfLiteTensor* input_data_index = interpreter_.tensor(2); + input_data_index->data.f[0] = 1717; + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + + // Verify output. + TfLiteTensor* output = interpreter_.tensor(3); + ASSERT_EQ(output->dims->size, 0); + EXPECT_EQ(output->data.f[0], 1717); + } + + // 2nd invocation. The variable is assigned as a 1D vector with 2 elements. + { + interpreter_.ResizeInputTensor(2, {2}); + ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk); + + TfLiteTensor* input_assign_index = interpreter_.tensor(0); + input_assign_index->data.i32[0] = 1; + TfLiteTensor* input_read_index = interpreter_.tensor(1); + input_read_index->data.i32[0] = 1; + TfLiteTensor* input_data_index = interpreter_.tensor(2); + input_data_index->data.f[0] = 1717; + input_data_index->data.f[1] = 2121; + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + + // Verify output. + TfLiteTensor* output = interpreter_.tensor(3); + ASSERT_EQ(output->dims->size, 1); + ASSERT_EQ(output->dims->data[0], 2); + EXPECT_EQ(output->data.f[0], 1717); + EXPECT_EQ(output->data.f[1], 2121); + } +} + +} // namespace +} // namespace tflite