Prorotype TFLite resource variables
PiperOrigin-RevId: 259986961
This commit is contained in:
parent
bdcd68771c
commit
a8d2d3bd42
@ -224,6 +224,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
||||
"//tensorflow/lite/nnapi:nnapi_implementation",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/experimental/resource_variable:resource_variable",
|
||||
] + select({
|
||||
":with_select_tf_ops": [
|
||||
"//tensorflow/lite/delegates/flex:delegate",
|
||||
|
@ -156,12 +156,14 @@ class InterpreterInfo : public GraphInfo {
|
||||
|
||||
Subgraph::Subgraph(ErrorReporter* error_reporter,
|
||||
TfLiteExternalContext** external_contexts,
|
||||
std::vector<std::unique_ptr<Subgraph>>* subgraphs)
|
||||
std::vector<std::unique_ptr<Subgraph>>* subgraphs,
|
||||
ResourceVariableMap* resource_variables)
|
||||
: external_contexts_(external_contexts),
|
||||
error_reporter_(error_reporter),
|
||||
next_execution_plan_index_to_prepare_(0),
|
||||
next_execution_plan_index_to_plan_allocation_(0),
|
||||
subgraphs_(subgraphs) {
|
||||
subgraphs_(subgraphs),
|
||||
resource_variables_(resource_variables) {
|
||||
context_.impl_ = static_cast<void*>(this);
|
||||
context_.ResizeTensor = ResizeTensor;
|
||||
context_.ReportError = ReportErrorC;
|
||||
|
@ -16,12 +16,14 @@ limitations under the License.
|
||||
#define TENSORFLOW_LITE_CORE_SUBGRAPH_H_
|
||||
|
||||
#include <cstdlib>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/allocation.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/core/api/profiler.h"
|
||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||
#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
|
||||
#include "tensorflow/lite/memory_planner.h"
|
||||
#include "tensorflow/lite/util.h"
|
||||
|
||||
@ -36,7 +38,8 @@ class Subgraph {
|
||||
|
||||
Subgraph(ErrorReporter* error_reporter,
|
||||
TfLiteExternalContext** external_contexts,
|
||||
std::vector<std::unique_ptr<Subgraph>>* subgraphs);
|
||||
std::vector<std::unique_ptr<Subgraph>>* subgraphs,
|
||||
ResourceVariableMap* resource_variables);
|
||||
|
||||
Subgraph(const Subgraph&) = delete;
|
||||
|
||||
@ -160,6 +163,10 @@ class Subgraph {
|
||||
// Read only access to list of variable tensors.
|
||||
const std::vector<int>& variables() const { return variables_; }
|
||||
|
||||
// WARNING: Experimental interface, subject to change.
|
||||
// TODO(ycling): Move this function to an external context interface.
|
||||
ResourceVariableMap& resource_variables() { return *resource_variables_; }
|
||||
|
||||
size_t tensors_size() const { return tensors_.size(); }
|
||||
|
||||
// Return the number of ops in the model.
|
||||
@ -581,6 +588,10 @@ class Subgraph {
|
||||
// Reference to data used by the cancellation function in
|
||||
// `check_cancelled_func_`.
|
||||
void* cancellation_data_ = nullptr;
|
||||
|
||||
// A map of resource variables. Owned by interpreter and shared by multiple
|
||||
// subgraphs.
|
||||
ResourceVariableMap* resource_variables_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
17
tensorflow/lite/experimental/resource_variable/BUILD
Normal file
17
tensorflow/lite/experimental/resource_variable/BUILD
Normal file
@ -0,0 +1,17 @@
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "resource_variable",
|
||||
srcs = [
|
||||
"resource_variable.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"resource_variable.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
],
|
||||
)
|
@ -0,0 +1,78 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
|
||||
namespace tflite {
|
||||
|
||||
ResourceVariable::ResourceVariable() {
|
||||
memset(&tensor_, 0, sizeof(TfLiteTensor));
|
||||
}
|
||||
|
||||
ResourceVariable::ResourceVariable(ResourceVariable&& other) {
|
||||
tensor_ = other.tensor_;
|
||||
is_initialized_ = other.is_initialized_;
|
||||
|
||||
memset(&other.tensor_, 0, sizeof(TfLiteTensor));
|
||||
other.is_initialized_ = false;
|
||||
}
|
||||
|
||||
ResourceVariable::~ResourceVariable() {
|
||||
if (is_initialized_) {
|
||||
free(tensor_.data.raw);
|
||||
if (tensor_.dims) {
|
||||
TfLiteIntArrayFree(tensor_.dims);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus ResourceVariable::AssignFrom(const TfLiteTensor* tensor) {
|
||||
// Save the old allocated resources and attributes that we might use.
|
||||
char* old_raw = tensor_.data.raw;
|
||||
size_t old_bytes = tensor_.bytes;
|
||||
TfLiteIntArray* old_dims = tensor_.dims;
|
||||
|
||||
// Copy primitive parameters.
|
||||
memset(&tensor_, 0, sizeof(tensor_));
|
||||
tensor_.allocation_type = kTfLiteDynamic;
|
||||
tensor_.type = tensor->type;
|
||||
tensor_.params = tensor->params;
|
||||
tensor_.quantization = tensor->quantization;
|
||||
|
||||
// Copy old shape if possible otherwise create a new one.
|
||||
if (TfLiteIntArrayEqual(old_dims, tensor->dims)) {
|
||||
tensor_.dims = old_dims;
|
||||
} else {
|
||||
TfLiteIntArrayFree(old_dims);
|
||||
tensor_.dims = TfLiteIntArrayCopy(tensor->dims);
|
||||
}
|
||||
|
||||
// Reuse the same buffer if possible otherwise allocate a new one.
|
||||
tensor_.data.raw = old_raw;
|
||||
if (old_bytes != tensor->bytes) {
|
||||
TfLiteTensorRealloc(tensor->bytes, &tensor_);
|
||||
}
|
||||
|
||||
memcpy(tensor_.data.raw, tensor->data.raw, tensor_.bytes);
|
||||
is_initialized_ = true;
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace tflite
|
@ -0,0 +1,62 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
/// WARNING: Experimental interface, subject to change.
|
||||
// A resource variable class. It's similar to TensorFlow Resource
|
||||
// Variable, but it's identified with int32 ID in TFLite (instead of
|
||||
// using Resource handle like TensorFlow).
|
||||
//
|
||||
// TODO(b/137042749): TFLite converter cannot convert variables yet.
|
||||
// Variable functionalities are only tested with unit tests now.
|
||||
class ResourceVariable {
|
||||
public:
|
||||
ResourceVariable();
|
||||
ResourceVariable(ResourceVariable&& other);
|
||||
|
||||
ResourceVariable(const ResourceVariable&) = delete;
|
||||
ResourceVariable& operator=(const ResourceVariable&) = delete;
|
||||
|
||||
~ResourceVariable();
|
||||
|
||||
// Assigns data from a tensor. Copies its type, shape and data over.
|
||||
TfLiteStatus AssignFrom(const TfLiteTensor* tensor);
|
||||
|
||||
// Get the data tensor stored in the resource variable.
|
||||
// Returns `nullptr` if the variable is never initialized by calling
|
||||
// `AssignFrom`.
|
||||
TfLiteTensor* GetTensor() { return is_initialized_ ? &tensor_ : nullptr; }
|
||||
|
||||
private:
|
||||
// The tensor (and its buffer stored in `tensor_.data` is fully owned by
|
||||
// the `ResourceVariable` object.
|
||||
TfLiteTensor tensor_;
|
||||
// True if `AssignFrom` function is every called.
|
||||
// False if and only if `tensor_` is filled with zeros.
|
||||
bool is_initialized_ = false;
|
||||
};
|
||||
|
||||
using ResourceVariableMap = std::unordered_map<int, ResourceVariable>;
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_
|
@ -134,8 +134,8 @@ void Interpreter::AddSubgraphs(int subgraphs_to_add,
|
||||
|
||||
subgraphs_.reserve(base_index + subgraphs_to_add);
|
||||
for (int i = 0; i < subgraphs_to_add; ++i) {
|
||||
Subgraph* subgraph =
|
||||
new Subgraph(error_reporter_, external_contexts_, &subgraphs_);
|
||||
Subgraph* subgraph = new Subgraph(error_reporter_, external_contexts_,
|
||||
&subgraphs_, &resource_variables_);
|
||||
subgraphs_.emplace_back(subgraph);
|
||||
}
|
||||
}
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/profiler.h"
|
||||
#include "tensorflow/lite/core/subgraph.h"
|
||||
#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
|
||||
#include "tensorflow/lite/external_cpu_backend_context.h"
|
||||
#include "tensorflow/lite/memory_planner.h"
|
||||
#include "tensorflow/lite/stderr_reporter.h"
|
||||
@ -539,6 +540,10 @@ class Interpreter {
|
||||
|
||||
// Subgraphs
|
||||
std::vector<std::unique_ptr<Subgraph>> subgraphs_;
|
||||
|
||||
// A map of resource variables. Owned by interpreter and shared by multiple
|
||||
// subgraphs.
|
||||
ResourceVariableMap resource_variables_;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -493,6 +493,36 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "variable_op_kernels",
|
||||
srcs = [
|
||||
"assign_variable.cc",
|
||||
"read_variable.cc",
|
||||
],
|
||||
deps = [
|
||||
":kernel_util",
|
||||
":op_macros",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels/internal:tensor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "variable_ops_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"variable_ops_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":test_main",
|
||||
":test_util",
|
||||
":variable_op_kernels",
|
||||
"//tensorflow/lite:framework",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "custom_ops",
|
||||
srcs = ["rfft2d.cc"],
|
||||
|
86
tensorflow/lite/kernels/assign_variable.cc
Normal file
86
tensorflow/lite/kernels/assign_variable.cc
Normal file
@ -0,0 +1,86 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/core/subgraph.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
namespace assign_variable {
|
||||
|
||||
constexpr int kInputVariableId = 0;
|
||||
constexpr int kInputValue = 1;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
// TODO(b/137042749): TFLite infrastructure (converter, delegate) doesn't
|
||||
// fully support 0-output ops yet. Currently it works if we manually crfat
|
||||
// a TFLite graph that contains variable ops. Note:
|
||||
// * The TFLite Converter need to be changed to be able to produce an op
|
||||
// with 0 output.
|
||||
// * The delegation code need to be changed to handle 0 output ops. However
|
||||
// everything still works fine when variable ops aren't used.
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);
|
||||
|
||||
const TfLiteTensor* input_variable_id_tensor =
|
||||
GetInput(context, node, kInputVariableId);
|
||||
TF_LITE_ENSURE_EQ(context, input_variable_id_tensor->type, kTfLiteInt32);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(input_variable_id_tensor), 1);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
|
||||
|
||||
const TfLiteTensor* input_variable_id_tensor =
|
||||
GetInput(context, node, kInputVariableId);
|
||||
const TfLiteTensor* input_value_tensor = GetInput(context, node, kInputValue);
|
||||
|
||||
int variable_id = input_variable_id_tensor->data.i32[0];
|
||||
auto& resource_variables = subgraph->resource_variables();
|
||||
|
||||
auto variable_iterator = resource_variables.find(variable_id);
|
||||
if (variable_iterator == resource_variables.end()) {
|
||||
auto ret = resource_variables.emplace(variable_id, ResourceVariable());
|
||||
variable_iterator = ret.first;
|
||||
}
|
||||
|
||||
auto& variable = variable_iterator->second;
|
||||
variable.AssignFrom(input_value_tensor);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace assign_variable
|
||||
|
||||
TfLiteRegistration* Register_ASSIGN_VARIABLE() {
|
||||
static TfLiteRegistration r = {nullptr, nullptr, assign_variable::Prepare,
|
||||
assign_variable::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
88
tensorflow/lite/kernels/read_variable.cc
Normal file
88
tensorflow/lite/kernels/read_variable.cc
Normal file
@ -0,0 +1,88 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/core/subgraph.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
namespace read_variable {
|
||||
|
||||
constexpr int kInputVariableId = 0;
|
||||
constexpr int kOutputValue = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
|
||||
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
|
||||
|
||||
const TfLiteTensor* input_variable_id_tensor =
|
||||
GetInput(context, node, kInputVariableId);
|
||||
TF_LITE_ENSURE_EQ(context, input_variable_id_tensor->type, kTfLiteInt32);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(input_variable_id_tensor), 1);
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputValue);
|
||||
SetTensorToDynamic(output);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
|
||||
|
||||
const TfLiteTensor* input_variable_id_tensor =
|
||||
GetInput(context, node, kInputVariableId);
|
||||
int variable_id = input_variable_id_tensor->data.i32[0];
|
||||
auto& resource_variables = subgraph->resource_variables();
|
||||
|
||||
const auto& variable_iterator = resource_variables.find(variable_id);
|
||||
if (variable_iterator == resource_variables.end()) {
|
||||
context->ReportError(context, "Variable ID %d is read before initialized.",
|
||||
variable_id);
|
||||
return kTfLiteError;
|
||||
}
|
||||
auto& variable = variable_iterator->second;
|
||||
|
||||
TfLiteTensor* variable_tensor = variable.GetTensor();
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputValue);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, variable_tensor->type, output->type);
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, context->ResizeTensor(
|
||||
context, output, TfLiteIntArrayCopy(variable_tensor->dims)));
|
||||
memcpy(output->data.raw, variable_tensor->data.raw, output->bytes);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace read_variable
|
||||
|
||||
TfLiteRegistration* Register_READ_VARIABLE() {
|
||||
static TfLiteRegistration r = {nullptr, nullptr, read_variable::Prepare,
|
||||
read_variable::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
149
tensorflow/lite/kernels/variable_ops_test.cc
Normal file
149
tensorflow/lite/kernels/variable_ops_test.cc
Normal file
@ -0,0 +1,149 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Forward declaraction for op kernels.
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
|
||||
TfLiteRegistration* Register_ASSIGN_VARIABLE();
|
||||
TfLiteRegistration* Register_READ_VARIABLE();
|
||||
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
|
||||
namespace {
|
||||
|
||||
class VariableOpsTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
assign_registration_ = ::tflite::ops::custom::Register_ASSIGN_VARIABLE();
|
||||
ASSERT_NE(assign_registration_, nullptr);
|
||||
read_registration_ = ::tflite::ops::custom::Register_READ_VARIABLE();
|
||||
ASSERT_NE(read_registration_, nullptr);
|
||||
|
||||
ConstructGraph();
|
||||
}
|
||||
|
||||
void ConstructGraph() {
|
||||
// Construct a graph like ths:
|
||||
// Input: %0, %1, %2
|
||||
// Output: %3
|
||||
// variable_assign(%0, %2)
|
||||
// %3 = read(%1)
|
||||
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(interpreter_.AddTensors(4, &first_new_tensor_index), kTfLiteOk);
|
||||
ASSERT_EQ(interpreter_.SetInputs({0, 1, 2}), kTfLiteOk);
|
||||
ASSERT_EQ(interpreter_.SetOutputs({3}), kTfLiteOk);
|
||||
interpreter_.SetTensorParametersReadWrite(0, kTfLiteInt32, "", 0, nullptr,
|
||||
{}, false);
|
||||
interpreter_.SetTensorParametersReadWrite(1, kTfLiteInt32, "", 0, nullptr,
|
||||
{}, false);
|
||||
interpreter_.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", 0, nullptr,
|
||||
{}, false);
|
||||
interpreter_.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", 0, nullptr,
|
||||
{}, false);
|
||||
int node_index;
|
||||
interpreter_.AddNodeWithParameters({0, 2}, {}, nullptr, 0, nullptr,
|
||||
assign_registration_, &node_index);
|
||||
interpreter_.AddNodeWithParameters({1}, {3}, nullptr, 0, nullptr,
|
||||
read_registration_, &node_index);
|
||||
}
|
||||
TfLiteRegistration* assign_registration_;
|
||||
TfLiteRegistration* read_registration_;
|
||||
Interpreter interpreter_;
|
||||
};
|
||||
|
||||
TEST_F(VariableOpsTest, TestAssignThenReadVariable) {
|
||||
ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
|
||||
TfLiteTensor* input_assign_index = interpreter_.tensor(0);
|
||||
input_assign_index->data.i32[0] = 1;
|
||||
TfLiteTensor* input_read_index = interpreter_.tensor(1);
|
||||
input_read_index->data.i32[0] = 1;
|
||||
TfLiteTensor* input_data_index = interpreter_.tensor(2);
|
||||
input_data_index->data.f[0] = 1717;
|
||||
ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
|
||||
|
||||
// Verify output.
|
||||
TfLiteTensor* output = interpreter_.tensor(3);
|
||||
ASSERT_EQ(output->dims->size, 0);
|
||||
EXPECT_EQ(output->data.f[0], 1717);
|
||||
}
|
||||
|
||||
TEST_F(VariableOpsTest, TestReadVariableBeforeAssign) {
|
||||
ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
|
||||
TfLiteTensor* input_assign_index = interpreter_.tensor(0);
|
||||
input_assign_index->data.i32[0] = 1;
|
||||
TfLiteTensor* input_read_index = interpreter_.tensor(1);
|
||||
input_read_index->data.i32[0] = 2;
|
||||
TfLiteTensor* input_data_index = interpreter_.tensor(2);
|
||||
input_data_index->data.f[0] = 1717;
|
||||
|
||||
// Error because variable 2 is never initialized.
|
||||
ASSERT_EQ(interpreter_.Invoke(), kTfLiteError);
|
||||
}
|
||||
|
||||
TEST_F(VariableOpsTest, TestReeasignToDifferentSize) {
|
||||
// 1st invocation. The variable is assigned as a scalar.
|
||||
{
|
||||
ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
|
||||
|
||||
TfLiteTensor* input_assign_index = interpreter_.tensor(0);
|
||||
input_assign_index->data.i32[0] = 1;
|
||||
TfLiteTensor* input_read_index = interpreter_.tensor(1);
|
||||
input_read_index->data.i32[0] = 1;
|
||||
TfLiteTensor* input_data_index = interpreter_.tensor(2);
|
||||
input_data_index->data.f[0] = 1717;
|
||||
ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
|
||||
|
||||
// Verify output.
|
||||
TfLiteTensor* output = interpreter_.tensor(3);
|
||||
ASSERT_EQ(output->dims->size, 0);
|
||||
EXPECT_EQ(output->data.f[0], 1717);
|
||||
}
|
||||
|
||||
// 2nd invocation. The variable is assigned as a 1D vector with 2 elements.
|
||||
{
|
||||
interpreter_.ResizeInputTensor(2, {2});
|
||||
ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
|
||||
|
||||
TfLiteTensor* input_assign_index = interpreter_.tensor(0);
|
||||
input_assign_index->data.i32[0] = 1;
|
||||
TfLiteTensor* input_read_index = interpreter_.tensor(1);
|
||||
input_read_index->data.i32[0] = 1;
|
||||
TfLiteTensor* input_data_index = interpreter_.tensor(2);
|
||||
input_data_index->data.f[0] = 1717;
|
||||
input_data_index->data.f[1] = 2121;
|
||||
ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
|
||||
|
||||
// Verify output.
|
||||
TfLiteTensor* output = interpreter_.tensor(3);
|
||||
ASSERT_EQ(output->dims->size, 1);
|
||||
ASSERT_EQ(output->dims->data[0], 2);
|
||||
EXPECT_EQ(output->data.f[0], 1717);
|
||||
EXPECT_EQ(output->data.f[1], 2121);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
Loading…
Reference in New Issue
Block a user