Prorotype TFLite resource variables

PiperOrigin-RevId: 259986961
This commit is contained in:
Yu-Cheng Ling 2019-07-25 11:29:58 -07:00 committed by TensorFlower Gardener
parent bdcd68771c
commit a8d2d3bd42
12 changed files with 534 additions and 5 deletions

View File

@ -224,6 +224,7 @@ cc_library(
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
"//tensorflow/lite/nnapi:nnapi_implementation",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/experimental/resource_variable:resource_variable",
] + select({
":with_select_tf_ops": [
"//tensorflow/lite/delegates/flex:delegate",

View File

@ -156,12 +156,14 @@ class InterpreterInfo : public GraphInfo {
Subgraph::Subgraph(ErrorReporter* error_reporter,
TfLiteExternalContext** external_contexts,
std::vector<std::unique_ptr<Subgraph>>* subgraphs)
std::vector<std::unique_ptr<Subgraph>>* subgraphs,
ResourceVariableMap* resource_variables)
: external_contexts_(external_contexts),
error_reporter_(error_reporter),
next_execution_plan_index_to_prepare_(0),
next_execution_plan_index_to_plan_allocation_(0),
subgraphs_(subgraphs) {
subgraphs_(subgraphs),
resource_variables_(resource_variables) {
context_.impl_ = static_cast<void*>(this);
context_.ResizeTensor = ResizeTensor;
context_.ReportError = ReportErrorC;

View File

@ -16,12 +16,14 @@ limitations under the License.
#define TENSORFLOW_LITE_CORE_SUBGRAPH_H_
#include <cstdlib>
#include <map>
#include <vector>
#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/util.h"
@ -36,7 +38,8 @@ class Subgraph {
Subgraph(ErrorReporter* error_reporter,
TfLiteExternalContext** external_contexts,
std::vector<std::unique_ptr<Subgraph>>* subgraphs);
std::vector<std::unique_ptr<Subgraph>>* subgraphs,
ResourceVariableMap* resource_variables);
Subgraph(const Subgraph&) = delete;
@ -160,6 +163,10 @@ class Subgraph {
// Read only access to list of variable tensors.
const std::vector<int>& variables() const { return variables_; }
// WARNING: Experimental interface, subject to change.
// TODO(ycling): Move this function to an external context interface.
ResourceVariableMap& resource_variables() { return *resource_variables_; }
size_t tensors_size() const { return tensors_.size(); }
// Return the number of ops in the model.
@ -581,6 +588,10 @@ class Subgraph {
// Reference to data used by the cancellation function in
// `check_cancelled_func_`.
void* cancellation_data_ = nullptr;
// A map of resource variables. Owned by interpreter and shared by multiple
// subgraphs.
ResourceVariableMap* resource_variables_ = nullptr;
};
} // namespace tflite

View File

@ -0,0 +1,17 @@
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "resource_variable",
srcs = [
"resource_variable.cc",
],
hdrs = [
"resource_variable.h",
],
deps = [
"//tensorflow/lite/c:c_api_internal",
],
)

View File

@ -0,0 +1,78 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
#include <cstdlib>
#include <cstring>
#include <map>
namespace tflite {
ResourceVariable::ResourceVariable() {
memset(&tensor_, 0, sizeof(TfLiteTensor));
}
ResourceVariable::ResourceVariable(ResourceVariable&& other) {
tensor_ = other.tensor_;
is_initialized_ = other.is_initialized_;
memset(&other.tensor_, 0, sizeof(TfLiteTensor));
other.is_initialized_ = false;
}
ResourceVariable::~ResourceVariable() {
if (is_initialized_) {
free(tensor_.data.raw);
if (tensor_.dims) {
TfLiteIntArrayFree(tensor_.dims);
}
}
}
TfLiteStatus ResourceVariable::AssignFrom(const TfLiteTensor* tensor) {
// Save the old allocated resources and attributes that we might use.
char* old_raw = tensor_.data.raw;
size_t old_bytes = tensor_.bytes;
TfLiteIntArray* old_dims = tensor_.dims;
// Copy primitive parameters.
memset(&tensor_, 0, sizeof(tensor_));
tensor_.allocation_type = kTfLiteDynamic;
tensor_.type = tensor->type;
tensor_.params = tensor->params;
tensor_.quantization = tensor->quantization;
// Copy old shape if possible otherwise create a new one.
if (TfLiteIntArrayEqual(old_dims, tensor->dims)) {
tensor_.dims = old_dims;
} else {
TfLiteIntArrayFree(old_dims);
tensor_.dims = TfLiteIntArrayCopy(tensor->dims);
}
// Reuse the same buffer if possible otherwise allocate a new one.
tensor_.data.raw = old_raw;
if (old_bytes != tensor->bytes) {
TfLiteTensorRealloc(tensor->bytes, &tensor_);
}
memcpy(tensor_.data.raw, tensor->data.raw, tensor_.bytes);
is_initialized_ = true;
return kTfLiteOk;
}
} // namespace tflite

View File

@ -0,0 +1,62 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_
#include <unordered_map>
#include "tensorflow/lite/c/c_api_internal.h"
namespace tflite {
/// WARNING: Experimental interface, subject to change.
// A resource variable class. It's similar to TensorFlow Resource
// Variable, but it's identified with int32 ID in TFLite (instead of
// using Resource handle like TensorFlow).
//
// TODO(b/137042749): TFLite converter cannot convert variables yet.
// Variable functionalities are only tested with unit tests now.
class ResourceVariable {
public:
ResourceVariable();
ResourceVariable(ResourceVariable&& other);
ResourceVariable(const ResourceVariable&) = delete;
ResourceVariable& operator=(const ResourceVariable&) = delete;
~ResourceVariable();
// Assigns data from a tensor. Copies its type, shape and data over.
TfLiteStatus AssignFrom(const TfLiteTensor* tensor);
// Get the data tensor stored in the resource variable.
// Returns `nullptr` if the variable is never initialized by calling
// `AssignFrom`.
TfLiteTensor* GetTensor() { return is_initialized_ ? &tensor_ : nullptr; }
private:
// The tensor (and its buffer stored in `tensor_.data` is fully owned by
// the `ResourceVariable` object.
TfLiteTensor tensor_;
// True if `AssignFrom` function is every called.
// False if and only if `tensor_` is filled with zeros.
bool is_initialized_ = false;
};
using ResourceVariableMap = std::unordered_map<int, ResourceVariable>;
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_

View File

@ -134,8 +134,8 @@ void Interpreter::AddSubgraphs(int subgraphs_to_add,
subgraphs_.reserve(base_index + subgraphs_to_add);
for (int i = 0; i < subgraphs_to_add; ++i) {
Subgraph* subgraph =
new Subgraph(error_reporter_, external_contexts_, &subgraphs_);
Subgraph* subgraph = new Subgraph(error_reporter_, external_contexts_,
&subgraphs_, &resource_variables_);
subgraphs_.emplace_back(subgraph);
}
}

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/stderr_reporter.h"
@ -539,6 +540,10 @@ class Interpreter {
// Subgraphs
std::vector<std::unique_ptr<Subgraph>> subgraphs_;
// A map of resource variables. Owned by interpreter and shared by multiple
// subgraphs.
ResourceVariableMap resource_variables_;
};
} // namespace tflite

View File

@ -493,6 +493,36 @@ cc_library(
],
)
cc_library(
name = "variable_op_kernels",
srcs = [
"assign_variable.cc",
"read_variable.cc",
],
deps = [
":kernel_util",
":op_macros",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels/internal:tensor",
],
)
cc_test(
name = "variable_ops_test",
size = "small",
srcs = [
"variable_ops_test.cc",
],
deps = [
":test_main",
":test_util",
":variable_op_kernels",
"//tensorflow/lite:framework",
"@com_google_googletest//:gtest",
],
)
cc_library(
name = "custom_ops",
srcs = ["rfft2d.cc"],

View File

@ -0,0 +1,86 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
#include <memory>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace custom {
namespace assign_variable {
constexpr int kInputVariableId = 0;
constexpr int kInputValue = 1;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
// TODO(b/137042749): TFLite infrastructure (converter, delegate) doesn't
// fully support 0-output ops yet. Currently it works if we manually crfat
// a TFLite graph that contains variable ops. Note:
// * The TFLite Converter need to be changed to be able to produce an op
// with 0 output.
// * The delegation code need to be changed to handle 0 output ops. However
// everything still works fine when variable ops aren't used.
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);
const TfLiteTensor* input_variable_id_tensor =
GetInput(context, node, kInputVariableId);
TF_LITE_ENSURE_EQ(context, input_variable_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumElements(input_variable_id_tensor), 1);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
const TfLiteTensor* input_variable_id_tensor =
GetInput(context, node, kInputVariableId);
const TfLiteTensor* input_value_tensor = GetInput(context, node, kInputValue);
int variable_id = input_variable_id_tensor->data.i32[0];
auto& resource_variables = subgraph->resource_variables();
auto variable_iterator = resource_variables.find(variable_id);
if (variable_iterator == resource_variables.end()) {
auto ret = resource_variables.emplace(variable_id, ResourceVariable());
variable_iterator = ret.first;
}
auto& variable = variable_iterator->second;
variable.AssignFrom(input_value_tensor);
return kTfLiteOk;
}
} // namespace assign_variable
TfLiteRegistration* Register_ASSIGN_VARIABLE() {
static TfLiteRegistration r = {nullptr, nullptr, assign_variable::Prepare,
assign_variable::Eval};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,88 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
#include <memory>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace custom {
namespace read_variable {
constexpr int kInputVariableId = 0;
constexpr int kOutputValue = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
const TfLiteTensor* input_variable_id_tensor =
GetInput(context, node, kInputVariableId);
TF_LITE_ENSURE_EQ(context, input_variable_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumElements(input_variable_id_tensor), 1);
TfLiteTensor* output = GetOutput(context, node, kOutputValue);
SetTensorToDynamic(output);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
const TfLiteTensor* input_variable_id_tensor =
GetInput(context, node, kInputVariableId);
int variable_id = input_variable_id_tensor->data.i32[0];
auto& resource_variables = subgraph->resource_variables();
const auto& variable_iterator = resource_variables.find(variable_id);
if (variable_iterator == resource_variables.end()) {
context->ReportError(context, "Variable ID %d is read before initialized.",
variable_id);
return kTfLiteError;
}
auto& variable = variable_iterator->second;
TfLiteTensor* variable_tensor = variable.GetTensor();
TfLiteTensor* output = GetOutput(context, node, kOutputValue);
TF_LITE_ENSURE_EQ(context, variable_tensor->type, output->type);
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(
context, output, TfLiteIntArrayCopy(variable_tensor->dims)));
memcpy(output->data.raw, variable_tensor->data.raw, output->bytes);
return kTfLiteOk;
}
} // namespace read_variable
TfLiteRegistration* Register_READ_VARIABLE() {
static TfLiteRegistration r = {nullptr, nullptr, read_variable::Prepare,
read_variable::Eval};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,149 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
// Forward declaraction for op kernels.
namespace ops {
namespace custom {
TfLiteRegistration* Register_ASSIGN_VARIABLE();
TfLiteRegistration* Register_READ_VARIABLE();
} // namespace custom
} // namespace ops
namespace {
class VariableOpsTest : public ::testing::Test {
protected:
void SetUp() override {
assign_registration_ = ::tflite::ops::custom::Register_ASSIGN_VARIABLE();
ASSERT_NE(assign_registration_, nullptr);
read_registration_ = ::tflite::ops::custom::Register_READ_VARIABLE();
ASSERT_NE(read_registration_, nullptr);
ConstructGraph();
}
void ConstructGraph() {
// Construct a graph like ths:
// Input: %0, %1, %2
// Output: %3
// variable_assign(%0, %2)
// %3 = read(%1)
int first_new_tensor_index;
ASSERT_EQ(interpreter_.AddTensors(4, &first_new_tensor_index), kTfLiteOk);
ASSERT_EQ(interpreter_.SetInputs({0, 1, 2}), kTfLiteOk);
ASSERT_EQ(interpreter_.SetOutputs({3}), kTfLiteOk);
interpreter_.SetTensorParametersReadWrite(0, kTfLiteInt32, "", 0, nullptr,
{}, false);
interpreter_.SetTensorParametersReadWrite(1, kTfLiteInt32, "", 0, nullptr,
{}, false);
interpreter_.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", 0, nullptr,
{}, false);
interpreter_.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", 0, nullptr,
{}, false);
int node_index;
interpreter_.AddNodeWithParameters({0, 2}, {}, nullptr, 0, nullptr,
assign_registration_, &node_index);
interpreter_.AddNodeWithParameters({1}, {3}, nullptr, 0, nullptr,
read_registration_, &node_index);
}
TfLiteRegistration* assign_registration_;
TfLiteRegistration* read_registration_;
Interpreter interpreter_;
};
TEST_F(VariableOpsTest, TestAssignThenReadVariable) {
ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
TfLiteTensor* input_assign_index = interpreter_.tensor(0);
input_assign_index->data.i32[0] = 1;
TfLiteTensor* input_read_index = interpreter_.tensor(1);
input_read_index->data.i32[0] = 1;
TfLiteTensor* input_data_index = interpreter_.tensor(2);
input_data_index->data.f[0] = 1717;
ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
// Verify output.
TfLiteTensor* output = interpreter_.tensor(3);
ASSERT_EQ(output->dims->size, 0);
EXPECT_EQ(output->data.f[0], 1717);
}
TEST_F(VariableOpsTest, TestReadVariableBeforeAssign) {
ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
TfLiteTensor* input_assign_index = interpreter_.tensor(0);
input_assign_index->data.i32[0] = 1;
TfLiteTensor* input_read_index = interpreter_.tensor(1);
input_read_index->data.i32[0] = 2;
TfLiteTensor* input_data_index = interpreter_.tensor(2);
input_data_index->data.f[0] = 1717;
// Error because variable 2 is never initialized.
ASSERT_EQ(interpreter_.Invoke(), kTfLiteError);
}
TEST_F(VariableOpsTest, TestReeasignToDifferentSize) {
// 1st invocation. The variable is assigned as a scalar.
{
ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
TfLiteTensor* input_assign_index = interpreter_.tensor(0);
input_assign_index->data.i32[0] = 1;
TfLiteTensor* input_read_index = interpreter_.tensor(1);
input_read_index->data.i32[0] = 1;
TfLiteTensor* input_data_index = interpreter_.tensor(2);
input_data_index->data.f[0] = 1717;
ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
// Verify output.
TfLiteTensor* output = interpreter_.tensor(3);
ASSERT_EQ(output->dims->size, 0);
EXPECT_EQ(output->data.f[0], 1717);
}
// 2nd invocation. The variable is assigned as a 1D vector with 2 elements.
{
interpreter_.ResizeInputTensor(2, {2});
ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
TfLiteTensor* input_assign_index = interpreter_.tensor(0);
input_assign_index->data.i32[0] = 1;
TfLiteTensor* input_read_index = interpreter_.tensor(1);
input_read_index->data.i32[0] = 1;
TfLiteTensor* input_data_index = interpreter_.tensor(2);
input_data_index->data.f[0] = 1717;
input_data_index->data.f[1] = 2121;
ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
// Verify output.
TfLiteTensor* output = interpreter_.tensor(3);
ASSERT_EQ(output->dims->size, 1);
ASSERT_EQ(output->dims->data[0], 2);
EXPECT_EQ(output->data.f[0], 1717);
EXPECT_EQ(output->data.f[1], 2121);
}
}
} // namespace
} // namespace tflite