Add hash table resource implementation into TFLite

PiperOrigin-RevId: 282460990
Change-Id: If489e2ab2d3e9ad2b4899f4ec055e611def6696d
This commit is contained in:
Jaesung Chung 2019-11-25 16:48:44 -08:00 committed by TensorFlower Gardener
parent 1b17b134a9
commit 3c725ae67b
17 changed files with 535 additions and 70 deletions

View File

@ -215,7 +215,7 @@ cc_library(
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/core/api",
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
"//tensorflow/lite/experimental/resource_variable",
"//tensorflow/lite/experimental/resource",
"//tensorflow/lite/nnapi:nnapi_implementation",
"//tensorflow/lite/schema:schema_fbs",
],

View File

@ -158,13 +158,13 @@ class InterpreterInfo : public GraphInfo {
Subgraph::Subgraph(ErrorReporter* error_reporter,
TfLiteExternalContext** external_contexts,
std::vector<std::unique_ptr<Subgraph>>* subgraphs,
ResourceVariableMap* resource_variables)
resource::ResourceMap* resources)
: external_contexts_(external_contexts),
error_reporter_(error_reporter),
next_execution_plan_index_to_prepare_(0),
next_execution_plan_index_to_plan_allocation_(0),
subgraphs_(subgraphs),
resource_variables_(resource_variables) {
resources_(resources) {
context_.impl_ = static_cast<void*>(this);
context_.ResizeTensor = ResizeTensor;
context_.ReportError = ReportErrorC;

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/util.h"
@ -40,7 +40,7 @@ class Subgraph {
Subgraph(ErrorReporter* error_reporter,
TfLiteExternalContext** external_contexts,
std::vector<std::unique_ptr<Subgraph>>* subgraphs,
ResourceVariableMap* resource_variables);
resource::ResourceMap* resources);
Subgraph(const Subgraph&) = delete;
@ -166,7 +166,7 @@ class Subgraph {
// WARNING: Experimental interface, subject to change.
// TODO(ycling): Move this function to an external context interface.
ResourceVariableMap& resource_variables() { return *resource_variables_; }
resource::ResourceMap& resources() { return *resources_; }
size_t tensors_size() const { return tensors_.size(); }
@ -635,9 +635,8 @@ class Subgraph {
// `check_cancelled_func_`.
void* cancellation_data_ = nullptr;
// A map of resource variables. Owned by interpreter and shared by multiple
// subgraphs.
ResourceVariableMap* resource_variables_ = nullptr;
// A map of resources. Owned by interpreter and shared by multiple subgraphs.
resource::ResourceMap* resources_ = nullptr;
};
} // namespace tflite

View File

@ -0,0 +1,25 @@
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "resource",
srcs = [
"resource_variable.cc",
"static_hashtable.cc",
],
hdrs = [
"lookup_interfaces.h",
"lookup_util.h",
"resource_base.h",
"resource_variable.h",
"static_hashtable.h",
],
deps = [
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/kernels/internal:tensor",
],
)

View File

@ -0,0 +1,64 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_INTERFACES_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_INTERFACES_H_
#include <unordered_map>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/resource/lookup_util.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
namespace resource {
/// WARNING: Experimental interface, subject to change.
// A resource hash table interface. It is similar to TensorFlow core's
// LookupInterface class, but in TFLite the table is identified by an int32 ID
// (instead of by a resource handle as in TensorFlow).
class LookupInterface : public ResourceBase {
public:
virtual TfLiteStatus Lookup(TfLiteContext* context, const TfLiteTensor* keys,
TfLiteTensor* values,
const TfLiteTensor* default_value) = 0;
virtual TfLiteStatus Import(TfLiteContext* context, const TfLiteTensor* keys,
const TfLiteTensor* values) = 0;
virtual size_t Size() = 0;
virtual TfLiteType GetKeyType() const = 0;
virtual TfLiteType GetValueType() const = 0;
virtual TfLiteStatus CheckKeyAndValueTypes(TfLiteContext* context,
const TfLiteTensor* keys,
const TfLiteTensor* values) = 0;
};
// Creates a resource hash table for the given resource id if one does not
// already exist. The table is shared among all the subgraphs.
// WARNING: Experimental interface, subject to change.
void CreateHashtableResourceIfNotAvailable(ResourceMap* resources,
int resource_id,
TfLiteType key_dtype,
TfLiteType value_dtype);
// Returns the corresponding resource hash table, or nullptr if none.
// WARNING: Experimental interface, subject to change.
LookupInterface* GetHashtableResource(ResourceMap* resources, int resource_id);
} // namespace resource
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_INTERFACES_H_
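
For context, a minimal sketch (not part of this change) of how a hash-table kernel could drive these helpers from its Eval function. The input index, the scalar int32 resource-id tensor, and the GetInput helper from kernel_util.h are illustrative assumptions; only Subgraph::resources(), CreateHashtableResourceIfNotAvailable() and GetHashtableResource() come from this commit.

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/lookup_interfaces.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {

TfLiteStatus HashtableEval(TfLiteContext* context, TfLiteNode* node) {
  Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  // Assumed: input 0 is a scalar int32 tensor carrying the resource id.
  const TfLiteTensor* resource_id_tensor = GetInput(context, node, 0);
  const int resource_id = resource_id_tensor->data.i32[0];

  auto& resources = subgraph->resources();
  // Creates the table on first use; later calls reuse the shared instance.
  resource::CreateHashtableResourceIfNotAvailable(
      &resources, resource_id, kTfLiteInt32, kTfLiteString);
  resource::LookupInterface* table =
      resource::GetHashtableResource(&resources, resource_id);
  TF_LITE_ENSURE(context, table != nullptr);
  // From here the kernel can call table->Import(...) or table->Lookup(...).
  return kTfLiteOk;
}

}  // namespace tflite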

View File

@ -0,0 +1,114 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
namespace resource {
namespace internal {
/// Helper class for accessing TFLite tensor data.
template <typename T>
class TensorReader {
public:
explicit TensorReader(const TfLiteTensor* input) {
input_data_ = GetTensorData<T>(input);
}
// Returns the corresponding scalar data at the given index position.
// For performance reasons, it does not validate the index; the caller is
// responsible for making sure the index is valid.
T GetData(int index) { return input_data_[index]; }
private:
const T* input_data_;
};
/// Helper class for accessing TFLite tensor data. This specialized class is for
/// std::string type.
template <>
class TensorReader<std::string> {
public:
explicit TensorReader(const TfLiteTensor* input) : input_(input) {}
// Returns the corresponding string data at the given index position.
// For performance reasons, it does not validate the index; the caller is
// responsible for making sure the index is valid.
std::string GetData(int index) {
auto string_ref = GetString(input_, index);
return std::string(string_ref.str, string_ref.len);
}
private:
const TfLiteTensor* input_;
};
/// WARNING: Experimental interface, subject to change.
/// Helper class for writing TFLite tensor data.
template <typename ValueType>
class TensorWriter {
public:
explicit TensorWriter(TfLiteTensor* values) {
output_data_ = GetTensorData<ValueType>(values);
}
// Sets the given value to the given index position of the tensor storage.
// For performance reasons, it does not validate the index; the caller is
// responsible for making sure the index is valid.
void SetData(int index, ValueType& value) { output_data_[index] = value; }
// Commit updates. In this case, it does nothing since the SetData method
// writes data directly.
void Commit() {
// Noop.
}
private:
ValueType* output_data_;
};
/// WARNING: Experimental interface, subject to change.
/// Helper class for writing TFLite tensor data. This specialized class is for
/// std::string type.
template <>
class TensorWriter<std::string> {
public:
explicit TensorWriter(TfLiteTensor* values) : values_(values) {}
// Queues the given string value to the buffer regardless of the provided
// index.
// For performance reasons, no index validation is performed; the caller is
// responsible for writing values in index order.
void SetData(int index, const std::string& value) {
buf_.AddString(value.data(), value.length());
}
// Commit updates. The stored data in DynamicBuffer will be written into the
// tensor storage.
void Commit() { buf_.WriteToTensor(values_, nullptr); }
private:
TfLiteTensor* values_;
DynamicBuffer buf_;
};
} // namespace internal
} // namespace resource
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_
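
A minimal sketch (not part of this change) of the reader/writer pair in use. For scalar types SetData() writes in place and Commit() is a no-op; the std::string specialization buffers writes in a DynamicBuffer until Commit(). The CopyElements helper and the assumption that both tensors hold at least `size` elements are illustrative.

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/resource/lookup_util.h"

namespace tflite {

template <typename T>
void CopyElements(const TfLiteTensor* input, TfLiteTensor* output, int size) {
  resource::internal::TensorReader<T> reader(input);
  resource::internal::TensorWriter<T> writer(output);
  for (int i = 0; i < size; ++i) {
    // No bounds checking happens here; the caller guarantees i < size.
    T value = reader.GetData(i);
    writer.SetData(i, value);
  }
  // A no-op for scalar T; for std::string this flushes the DynamicBuffer
  // into `output`.
  writer.Commit();
}

}  // namespace tflite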

View File

@ -0,0 +1,43 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_
#include <memory>
#include <unordered_map>
#include "tensorflow/lite/kernels/internal/compatibility.h"
namespace tflite {
namespace resource {
// ResourceBase is an abstract base class for resources.
/// WARNING: Experimental interface, subject to change.
class ResourceBase {
public:
explicit ResourceBase() {}
virtual ~ResourceBase() {}
// Returns true if it is initialized.
virtual bool IsInitialized() = 0;
};
/// WARNING: Experimental interface, subject to change.
using ResourceMap = std::unordered_map<int32, std::unique_ptr<ResourceBase>>;
} // namespace resource
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_
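
A minimal sketch (not part of this change) of how a new resource type plugs into this map; the counter class and the resource id are hypothetical.

#include <memory>

#include "tensorflow/lite/experimental/resource/resource_base.h"

class CounterResource : public tflite::resource::ResourceBase {
 public:
  bool IsInitialized() override { return true; }
  void Increment() { ++count_; }
  int count() const { return count_; }

 private:
  int count_ = 0;
};

void Example(tflite::resource::ResourceMap* resources) {
  // Map keys are int32 resource ids; values own the resources polymorphically.
  resources->emplace(7, std::unique_ptr<tflite::resource::ResourceBase>(
                            new CounterResource()));
  auto* counter = static_cast<CounterResource*>(resources->at(7).get());
  counter->Increment();
}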

View File

@ -13,13 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
#include "tensorflow/lite/experimental/resource/resource_variable.h"
#include <cstdlib>
#include <cstring>
#include <map>
#include <memory>
namespace tflite {
namespace resource {
ResourceVariable::ResourceVariable() {
memset(&tensor_, 0, sizeof(TfLiteTensor));
@ -75,4 +77,22 @@ TfLiteStatus ResourceVariable::AssignFrom(const TfLiteTensor* tensor) {
return kTfLiteOk;
}
void CreateResourceVariableIfNotAvailable(ResourceMap* resources,
int resource_id) {
if (resources->count(resource_id) != 0) {
return;
}
resources->emplace(
resource_id, std::unique_ptr<ResourceVariable>(new ResourceVariable()));
}
ResourceVariable* GetResourceVariable(ResourceMap* resources, int resource_id) {
auto it = resources->find(resource_id);
if (it != resources->end()) {
return static_cast<ResourceVariable*>(it->second.get());
}
return nullptr;
}
} // namespace resource
} // namespace tflite

View File

@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_
#include <unordered_map>
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_VARIABLE_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_VARIABLE_H_
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
namespace tflite {
namespace resource {
/// WARNING: Experimental interface, subject to change.
// A resource variable class. It's similar to TensorFlow Resource
@ -28,7 +28,7 @@ namespace tflite {
//
// TODO(b/137042749): TFLite converter cannot convert variables yet.
// Variable functionalities are only tested with unit tests now.
class ResourceVariable {
class ResourceVariable : public ResourceBase {
public:
ResourceVariable();
ResourceVariable(ResourceVariable&& other);
@ -36,7 +36,7 @@ class ResourceVariable {
ResourceVariable(const ResourceVariable&) = delete;
ResourceVariable& operator=(const ResourceVariable&) = delete;
~ResourceVariable();
~ResourceVariable() override;
// Assigns data from a tensor. Copies its type, shape and data over.
TfLiteStatus AssignFrom(const TfLiteTensor* tensor);
@ -46,6 +46,9 @@ class ResourceVariable {
// `AssignFrom`.
TfLiteTensor* GetTensor() { return is_initialized_ ? &tensor_ : nullptr; }
// Returns true if this resource variable is initialized.
bool IsInitialized() override { return is_initialized_; }
private:
// The tensor (and its buffer, stored in `tensor_.data`) is fully owned by
// the `ResourceVariable` object.
@ -55,8 +58,17 @@ class ResourceVariable {
bool is_initialized_ = false;
};
using ResourceVariableMap = std::unordered_map<int, ResourceVariable>;
// Creates a resource variable for the given resource id if one does not
// already exist. The variable is shared among all the subgraphs.
// WARNING: Experimental interface, subject to change.
void CreateResourceVariableIfNotAvailable(ResourceMap* resources,
int resource_id);
// Returns the corresponding resource variable, or nullptr if none.
// WARNING: Experimental interface, subject to change.
ResourceVariable* GetResourceVariable(ResourceMap* resources, int resource_id);
} // namespace resource
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_VARIABLE_H_
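
A minimal sketch (not part of this change) of the new free-function API around ResourceVariable; the fixed resource id and the fully populated `value` tensor are assumptions.

#include "tensorflow/lite/experimental/resource/resource_variable.h"

namespace tflite {

void AssignVariable(resource::ResourceMap* resources,
                    const TfLiteTensor* value) {
  // Creates the variable the first time this resource id is seen.
  resource::CreateResourceVariableIfNotAvailable(resources, /*resource_id=*/0);
  resource::ResourceVariable* variable =
      resource::GetResourceVariable(resources, /*resource_id=*/0);
  // The variable stays uninitialized until AssignFrom() copies a tensor's
  // type, shape and data into it.
  variable->AssignFrom(value);
  TfLiteTensor* stored = variable->GetTensor();  // Non-null once initialized.
  (void)stored;
}

}  // namespace tflite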

View File

@ -0,0 +1,129 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/resource/static_hashtable.h"
#include <memory>
#include "tensorflow/lite/experimental/resource/lookup_interfaces.h"
namespace tflite {
namespace resource {
namespace internal {
template <typename KeyType, typename ValueType>
TfLiteStatus StaticHashtable<KeyType, ValueType>::Lookup(
TfLiteContext* context, const TfLiteTensor* keys, TfLiteTensor* values,
const TfLiteTensor* default_value) {
TF_LITE_ENSURE(context, is_initialized_);
const int size =
MatchingFlatSize(GetTensorShape(keys), GetTensorShape(values));
auto key_tensor_reader = TensorReader<KeyType>(keys);
auto value_tensor_writer = TensorWriter<ValueType>(values);
auto default_value_tensor_reader = TensorReader<ValueType>(default_value);
ValueType first_default_value = default_value_tensor_reader.GetData(0);
for (int i = 0; i < size; ++i) {
auto result = map_.find(key_tensor_reader.GetData(i));
if (result != map_.end()) {
value_tensor_writer.SetData(i, result->second);
} else {
value_tensor_writer.SetData(i, first_default_value);
}
}
// For the string tensor case, this writes the buffered data back to the
// actual tensor destination. Otherwise it does nothing, since scalar data
// is written into the tensor storage directly.
value_tensor_writer.Commit();
return kTfLiteOk;
}
template <typename KeyType, typename ValueType>
TfLiteStatus StaticHashtable<KeyType, ValueType>::Import(
TfLiteContext* context, const TfLiteTensor* keys,
const TfLiteTensor* values) {
// Import nodes can be invoked more than once because the converter does not
// extract the initializer graph separately from the original graph.
// Invocations after the first call are ignored.
if (is_initialized_) {
return kTfLiteOk;
}
const int size =
MatchingFlatSize(GetTensorShape(keys), GetTensorShape(values));
auto key_tensor_reader = TensorReader<KeyType>(keys);
auto value_tensor_reader = TensorReader<ValueType>(values);
for (int i = 0; i < size; ++i) {
map_.insert({key_tensor_reader.GetData(i), value_tensor_reader.GetData(i)});
}
is_initialized_ = true;
return kTfLiteOk;
}
template <typename KeyType>
LookupInterface* CreateStaticHashtableWithGivenKey(TfLiteType key_type,
TfLiteType value_type) {
switch (value_type) {
case kTfLiteInt32:
return new StaticHashtable<KeyType, int32>(key_type, value_type);
case kTfLiteString:
return new StaticHashtable<KeyType, std::string>(key_type, value_type);
case kTfLiteFloat32:
return new StaticHashtable<KeyType, float>(key_type, value_type);
default:
return nullptr;
}
}
LookupInterface* CreateStaticHashtable(TfLiteType key_type,
TfLiteType value_type) {
switch (key_type) {
case kTfLiteInt32:
return CreateStaticHashtableWithGivenKey<int32>(key_type, value_type);
case kTfLiteString:
return CreateStaticHashtableWithGivenKey<std::string>(key_type,
value_type);
default:
return nullptr;
}
}
} // namespace internal
void CreateHashtableResourceIfNotAvailable(ResourceMap* resources,
int resource_id,
TfLiteType key_dtype,
TfLiteType value_dtype) {
if (resources->count(resource_id) != 0) {
return;
}
auto* hashtable = internal::CreateStaticHashtable(key_dtype, value_dtype);
resources->emplace(resource_id, std::unique_ptr<LookupInterface>(hashtable));
}
LookupInterface* GetHashtableResource(ResourceMap* resources, int resource_id) {
auto it = resources->find(resource_id);
if (it != resources->end()) {
return static_cast<LookupInterface*>(it->second.get());
}
return nullptr;
}
} // namespace resource
} // namespace tflite
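
A minimal sketch (not part of this change) of the intended import-then-lookup flow on a shared table; the fixed resource id and the int32 key / float32 value tensors are assumptions.

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/resource/lookup_interfaces.h"

namespace tflite {

TfLiteStatus ImportThenLookup(TfLiteContext* context,
                              resource::ResourceMap* resources,
                              const TfLiteTensor* keys,
                              const TfLiteTensor* values,
                              const TfLiteTensor* query_keys,
                              const TfLiteTensor* default_value,
                              TfLiteTensor* output) {
  // Unsupported key/value type pairs leave a null entry in the map, which the
  // TF_LITE_ENSURE below catches.
  resource::CreateHashtableResourceIfNotAvailable(
      resources, /*resource_id=*/0, kTfLiteInt32, kTfLiteFloat32);
  resource::LookupInterface* table =
      resource::GetHashtableResource(resources, /*resource_id=*/0);
  TF_LITE_ENSURE(context, table != nullptr);
  TF_LITE_ENSURE_OK(context,
                    table->CheckKeyAndValueTypes(context, keys, values));
  // Only the first Import() populates the table; later calls are ignored.
  TF_LITE_ENSURE_OK(context, table->Import(context, keys, values));
  return table->Lookup(context, query_keys, output, default_value);
}

}  // namespace tflite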

View File

@ -0,0 +1,84 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_STATIC_HASHTABLE_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_STATIC_HASHTABLE_H_
#include <unordered_map>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/resource/lookup_interfaces.h"
#include "tensorflow/lite/experimental/resource/lookup_util.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
namespace resource {
namespace internal {
// A static hash table class. This hash table can be initialized only once
// during its life cycle, and it implements TensorFlow core's HashTableV2 op.
template <typename KeyType, typename ValueType>
class StaticHashtable : public tflite::resource::LookupInterface {
public:
explicit StaticHashtable(TfLiteType key_type, TfLiteType value_type)
: key_type_(key_type), value_type_(value_type) {}
~StaticHashtable() override {}
// Looks up each key of the given keys tensor in the map and copies the
// resulting data into the given values tensor. If a key has no matching
// value, the default value is written at that position instead.
TfLiteStatus Lookup(TfLiteContext* context, const TfLiteTensor* keys,
TfLiteTensor* values,
const TfLiteTensor* default_value) override;
// Inserts the given key and value tensor data into the hash table.
TfLiteStatus Import(TfLiteContext* context, const TfLiteTensor* keys,
const TfLiteTensor* values) override;
// Returns the number of items in the hash table.
size_t Size() override { return map_.size(); }
TfLiteType GetKeyType() const override { return key_type_; }
TfLiteType GetValueType() const override { return value_type_; }
TfLiteStatus CheckKeyAndValueTypes(TfLiteContext* context,
const TfLiteTensor* keys,
const TfLiteTensor* values) override {
TF_LITE_ENSURE_EQ(context, keys->type, key_type_);
TF_LITE_ENSURE_EQ(context, values->type, value_type_);
return kTfLiteOk;
}
// Returns true if the hash table is initialized.
bool IsInitialized() override { return is_initialized_; }
private:
TfLiteType key_type_;
TfLiteType value_type_;
std::unordered_map<KeyType, ValueType> map_;
bool is_initialized_ = false;
};
::tflite::resource::LookupInterface* CreateStaticHashtable(
TfLiteType key_type, TfLiteType value_type);
} // namespace internal
} // namespace resource
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_STATIC_HASHTABLE_H_

View File

@ -1,17 +0,0 @@
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "resource_variable",
srcs = [
"resource_variable.cc",
],
hdrs = [
"resource_variable.h",
],
deps = [
"//tensorflow/lite/c:c_api_internal",
],
)

View File

@ -155,7 +155,7 @@ void Interpreter::AddSubgraphs(int subgraphs_to_add,
subgraphs_.reserve(base_index + subgraphs_to_add);
for (int i = 0; i < subgraphs_to_add; ++i) {
Subgraph* subgraph = new Subgraph(error_reporter_, external_contexts_,
&subgraphs_, &resource_variables_);
&subgraphs_, &resources_);
subgraphs_.emplace_back(subgraph);
}
}

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/stderr_reporter.h"
@ -522,9 +522,8 @@ class Interpreter {
// Subgraphs
std::vector<std::unique_ptr<Subgraph>> subgraphs_;
// A map of resource variables. Owned by interpreter and shared by multiple
// subgraphs.
ResourceVariableMap resource_variables_;
// A map of resources. Owned by interpreter and shared by multiple subgraphs.
resource::ResourceMap resources_;
};
} // namespace tflite

View File

@ -528,6 +528,7 @@ cc_library(
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/experimental/resource",
"//tensorflow/lite/kernels/internal:audio_utils",
"//tensorflow/lite/kernels/internal:common",
"//tensorflow/lite/kernels/internal:compatibility",
@ -561,6 +562,7 @@ cc_library(
":op_macros",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/experimental/resource",
"//tensorflow/lite/kernels/internal:tensor",
],
)

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/resource_variable.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
@ -43,10 +44,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// everything still works fine when variable ops aren't used.
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);
const TfLiteTensor* input_variable_id_tensor =
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputVariableId);
TF_LITE_ENSURE_EQ(context, input_variable_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumElements(input_variable_id_tensor), 1);
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumElements(input_resource_id_tensor), 1);
return kTfLiteOk;
}
@ -54,21 +55,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
const TfLiteTensor* input_variable_id_tensor =
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputVariableId);
const TfLiteTensor* input_value_tensor = GetInput(context, node, kInputValue);
int variable_id = input_variable_id_tensor->data.i32[0];
auto& resource_variables = subgraph->resource_variables();
auto variable_iterator = resource_variables.find(variable_id);
if (variable_iterator == resource_variables.end()) {
auto ret = resource_variables.emplace(variable_id, ResourceVariable());
variable_iterator = ret.first;
}
auto& variable = variable_iterator->second;
variable.AssignFrom(input_value_tensor);
int resource_id = input_resource_id_tensor->data.i32[0];
auto& resources = subgraph->resources();
resource::CreateResourceVariableIfNotAvailable(&resources, resource_id);
auto* variable = resource::GetResourceVariable(&resources, resource_id);
TF_LITE_ENSURE(context, variable != nullptr);
variable->AssignFrom(input_value_tensor);
return kTfLiteOk;
}

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/resource_variable.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
@ -36,10 +37,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
const TfLiteTensor* input_variable_id_tensor =
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputVariableId);
TF_LITE_ENSURE_EQ(context, input_variable_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumElements(input_variable_id_tensor), 1);
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumElements(input_resource_id_tensor), 1);
TfLiteTensor* output = GetOutput(context, node, kOutputValue);
SetTensorToDynamic(output);
@ -50,20 +51,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
const TfLiteTensor* input_variable_id_tensor =
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputVariableId);
int variable_id = input_variable_id_tensor->data.i32[0];
auto& resource_variables = subgraph->resource_variables();
int resource_id = input_resource_id_tensor->data.i32[0];
auto& resources = subgraph->resources();
auto* variable = resource::GetResourceVariable(&resources, resource_id);
TF_LITE_ENSURE(context, variable != nullptr);
const auto& variable_iterator = resource_variables.find(variable_id);
if (variable_iterator == resource_variables.end()) {
context->ReportError(context, "Variable ID %d is read before initialized.",
variable_id);
return kTfLiteError;
}
auto& variable = variable_iterator->second;
TfLiteTensor* variable_tensor = variable.GetTensor();
TfLiteTensor* variable_tensor = variable->GetTensor();
TfLiteTensor* output = GetOutput(context, node, kOutputValue);
TF_LITE_ENSURE_EQ(context, variable_tensor->type, output->type);