diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 4a112afacd0..2783a387513 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -215,7 +215,7 @@ cc_library( "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/core/api", "//tensorflow/lite/delegates/nnapi:nnapi_delegate", - "//tensorflow/lite/experimental/resource_variable", + "//tensorflow/lite/experimental/resource", "//tensorflow/lite/nnapi:nnapi_implementation", "//tensorflow/lite/schema:schema_fbs", ], diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 2aea1b0ca48..d03f5721587 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -158,13 +158,13 @@ class InterpreterInfo : public GraphInfo { Subgraph::Subgraph(ErrorReporter* error_reporter, TfLiteExternalContext** external_contexts, std::vector>* subgraphs, - ResourceVariableMap* resource_variables) + resource::ResourceMap* resources) : external_contexts_(external_contexts), error_reporter_(error_reporter), next_execution_plan_index_to_prepare_(0), next_execution_plan_index_to_plan_allocation_(0), subgraphs_(subgraphs), - resource_variables_(resource_variables) { + resources_(resources) { context_.impl_ = static_cast(this); context_.ResizeTensor = ResizeTensor; context_.ReportError = ReportErrorC; diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 688cffee213..e6796264703 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/core/api/profiler.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" -#include "tensorflow/lite/experimental/resource_variable/resource_variable.h" +#include "tensorflow/lite/experimental/resource/resource_base.h" #include "tensorflow/lite/memory_planner.h" #include "tensorflow/lite/util.h" @@ -40,7 +40,7 @@ class Subgraph { Subgraph(ErrorReporter* error_reporter, TfLiteExternalContext** external_contexts, std::vector>* subgraphs, - ResourceVariableMap* resource_variables); + resource::ResourceMap* resources); Subgraph(const Subgraph&) = delete; @@ -166,7 +166,7 @@ class Subgraph { // WARNING: Experimental interface, subject to change. // TODO(ycling): Move this function to an external context interface. - ResourceVariableMap& resource_variables() { return *resource_variables_; } + resource::ResourceMap& resources() { return *resources_; } size_t tensors_size() const { return tensors_.size(); } @@ -635,9 +635,8 @@ class Subgraph { // `check_cancelled_func_`. void* cancellation_data_ = nullptr; - // A map of resource variables. Owned by interpreter and shared by multiple - // subgraphs. - ResourceVariableMap* resource_variables_ = nullptr; + // A map of resources. Owned by interpreter and shared by multiple subgraphs. + resource::ResourceMap* resources_ = nullptr; }; } // namespace tflite diff --git a/tensorflow/lite/experimental/resource/BUILD b/tensorflow/lite/experimental/resource/BUILD new file mode 100644 index 00000000000..b7dd0beaa70 --- /dev/null +++ b/tensorflow/lite/experimental/resource/BUILD @@ -0,0 +1,25 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "resource", + srcs = [ + "resource_variable.cc", + "static_hashtable.cc", + ], + hdrs = [ + "lookup_interfaces.h", + "lookup_util.h", + "resource_base.h", + "resource_variable.h", + "static_hashtable.h", + ], + deps = [ + "//tensorflow/lite:string_util", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels/internal:compatibility", + "//tensorflow/lite/kernels/internal:tensor", + ], +) diff --git a/tensorflow/lite/experimental/resource/lookup_interfaces.h b/tensorflow/lite/experimental/resource/lookup_interfaces.h new file mode 100644 index 00000000000..42bb8e419e6 --- /dev/null +++ b/tensorflow/lite/experimental/resource/lookup_interfaces.h @@ -0,0 +1,64 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_INTERFACES_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_INTERFACES_H_ + +#include + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/resource/lookup_util.h" +#include "tensorflow/lite/experimental/resource/resource_base.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite { +namespace resource { + +/// WARNING: Experimental interface, subject to change. +// A resource hash table interface. It's similar to TensorFlow core's +// LookupInterface class. But it's identified with int32 ID in TFLite (instead +// of using Resource handle like TensorFlow). +class LookupInterface : public ResourceBase { + public: + virtual TfLiteStatus Lookup(TfLiteContext* context, const TfLiteTensor* keys, + TfLiteTensor* values, + const TfLiteTensor* default_value) = 0; + virtual TfLiteStatus Import(TfLiteContext* context, const TfLiteTensor* keys, + const TfLiteTensor* values) = 0; + virtual size_t Size() = 0; + + virtual TfLiteType GetKeyType() const = 0; + virtual TfLiteType GetValueType() const = 0; + virtual TfLiteStatus CheckKeyAndValueTypes(TfLiteContext* context, + const TfLiteTensor* keys, + const TfLiteTensor* values) = 0; +}; + +// Creates an resource hash table, shared among all the subgraphs with the +// given resource id if there is an existing one. +// WARNING: Experimental interface, subject to change. +void CreateHashtableResourceIfNotAvailable(ResourceMap* resources, + int resource_id, + TfLiteType key_dtype, + TfLiteType value_dtype); + +// Returns the corresponding resource hash table, or nullptr if none. +// WARNING: Experimental interface, subject to change. +LookupInterface* GetHashtableResource(ResourceMap* resources, int resource_id); + +} // namespace resource +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_INTERFACES_H_ diff --git a/tensorflow/lite/experimental/resource/lookup_util.h b/tensorflow/lite/experimental/resource/lookup_util.h new file mode 100644 index 00000000000..bb2c1c53ce5 --- /dev/null +++ b/tensorflow/lite/experimental/resource/lookup_util.h @@ -0,0 +1,114 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_ + +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite { +namespace resource { +namespace internal { + +/// Helper class for accessing TFLite tensor data. +template +class TensorReader { + public: + explicit TensorReader(const TfLiteTensor* input) { + input_data_ = GetTensorData(input); + } + + // Returns the corresponding scalar data at the given index position. + // In here, it does not check the validity of the index should be guaranteed + // in order not to harm the performance. Caller should take care of it. + T GetData(int index) { return input_data_[index]; } + + private: + const T* input_data_; +}; + +/// Helper class for accesing TFLite tensor data. This specialized class is for +/// std::string type. +template <> +class TensorReader { + public: + explicit TensorReader(const TfLiteTensor* input) : input_(input) {} + + // Returns the corresponding string data at the given index position. + // In here, it does not check the validity of the index should be guaranteed + // in order not to harm the performance. Caller should take care of it. + std::string GetData(int index) { + auto string_ref = GetString(input_, index); + return std::string(string_ref.str, string_ref.len); + } + + private: + const TfLiteTensor* input_; +}; + +/// WARNING: Experimental interface, subject to change. +/// Helper class for writing TFLite tensor data. +template +class TensorWriter { + public: + explicit TensorWriter(TfLiteTensor* values) { + output_data_ = GetTensorData(values); + } + + // Sets the given value to the given index position of the tensor storage. + // In here, it does not check the validity of the index should be guaranteed + // in order not to harm the performance. Caller should take care of it. + void SetData(int index, ValueType& value) { output_data_[index] = value; } + + // Commit updates. In this case, it does nothing since the SetData method + // writes data directly. + void Commit() { + // Noop. + } + + private: + ValueType* output_data_; +}; + +/// WARNING: Experimental interface, subject to change. +/// Helper class for writing TFLite tensor data. This specialized class is for +/// std::string type. +template <> +class TensorWriter { + public: + explicit TensorWriter(TfLiteTensor* values) : values_(values) {} + + // Queues the given string value to the buffer regardless of the provided + // index. + // In here, it does not check the validity of the index should be guaranteed + // in order not to harm the performance. Caller should take care of it. + void SetData(int index, const std::string& value) { + buf_.AddString(value.data(), value.length()); + } + + // Commit updates. The stored data in DynamicBuffer will be written into the + // tensor storage. + void Commit() { buf_.WriteToTensor(values_, nullptr); } + + private: + TfLiteTensor* values_; + DynamicBuffer buf_; +}; + +} // namespace internal +} // namespace resource +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_ diff --git a/tensorflow/lite/experimental/resource/resource_base.h b/tensorflow/lite/experimental/resource/resource_base.h new file mode 100644 index 00000000000..48a00b93957 --- /dev/null +++ b/tensorflow/lite/experimental/resource/resource_base.h @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_ + +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" + +namespace tflite { +namespace resource { + +// ResourceBase is an abstract base class for resources. +/// WARNING: Experimental interface, subject to change. +class ResourceBase { + public: + explicit ResourceBase() {} + virtual ~ResourceBase() {} + + // Returns true if it is initialized. + virtual bool IsInitialized() = 0; +}; + +/// WARNING: Experimental interface, subject to change. +using ResourceMap = std::unordered_map>; + +} // namespace resource +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_ diff --git a/tensorflow/lite/experimental/resource_variable/resource_variable.cc b/tensorflow/lite/experimental/resource/resource_variable.cc similarity index 77% rename from tensorflow/lite/experimental/resource_variable/resource_variable.cc rename to tensorflow/lite/experimental/resource/resource_variable.cc index 502ca273464..c16db39047c 100644 --- a/tensorflow/lite/experimental/resource_variable/resource_variable.cc +++ b/tensorflow/lite/experimental/resource/resource_variable.cc @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/resource_variable/resource_variable.h" +#include "tensorflow/lite/experimental/resource/resource_variable.h" #include #include #include +#include namespace tflite { +namespace resource { ResourceVariable::ResourceVariable() { memset(&tensor_, 0, sizeof(TfLiteTensor)); @@ -75,4 +77,22 @@ TfLiteStatus ResourceVariable::AssignFrom(const TfLiteTensor* tensor) { return kTfLiteOk; } +void CreateResourceVariableIfNotAvailable(ResourceMap* resources, + int resource_id) { + if (resources->count(resource_id) != 0) { + return; + } + resources->emplace( + resource_id, std::unique_ptr(new ResourceVariable())); +} + +ResourceVariable* GetResourceVariable(ResourceMap* resources, int resource_id) { + auto it = resources->find(resource_id); + if (it != resources->end()) { + return static_cast(it->second.get()); + } + return nullptr; +} + +} // namespace resource } // namespace tflite diff --git a/tensorflow/lite/experimental/resource_variable/resource_variable.h b/tensorflow/lite/experimental/resource/resource_variable.h similarity index 65% rename from tensorflow/lite/experimental/resource_variable/resource_variable.h rename to tensorflow/lite/experimental/resource/resource_variable.h index 6a938489eea..1e832c79625 100644 --- a/tensorflow/lite/experimental/resource_variable/resource_variable.h +++ b/tensorflow/lite/experimental/resource/resource_variable.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_ - -#include +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_VARIABLE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_VARIABLE_H_ #include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/resource/resource_base.h" namespace tflite { +namespace resource { /// WARNING: Experimental interface, subject to change. // A resource variable class. It's similar to TensorFlow Resource @@ -28,7 +28,7 @@ namespace tflite { // // TODO(b/137042749): TFLite converter cannot convert variables yet. // Variable functionalities are only tested with unit tests now. -class ResourceVariable { +class ResourceVariable : public ResourceBase { public: ResourceVariable(); ResourceVariable(ResourceVariable&& other); @@ -36,7 +36,7 @@ class ResourceVariable { ResourceVariable(const ResourceVariable&) = delete; ResourceVariable& operator=(const ResourceVariable&) = delete; - ~ResourceVariable(); + ~ResourceVariable() override; // Assigns data from a tensor. Copies its type, shape and data over. TfLiteStatus AssignFrom(const TfLiteTensor* tensor); @@ -46,6 +46,9 @@ class ResourceVariable { // `AssignFrom`. TfLiteTensor* GetTensor() { return is_initialized_ ? &tensor_ : nullptr; } + // Returns true if this resource variable is initialized. + bool IsInitialized() override { return is_initialized_; } + private: // The tensor (and its buffer stored in `tensor_.data` is fully owned by // the `ResourceVariable` object. @@ -55,8 +58,17 @@ class ResourceVariable { bool is_initialized_ = false; }; -using ResourceVariableMap = std::unordered_map; +// Creates a resource variable, shared among all the subgraphs with the given +// resource id if there is an existing one. +// WARNING: Experimental interface, subject to change. +void CreateResourceVariableIfNotAvailable(ResourceMap* resources, + int resource_id); +// Returns the corresponding resource variable, or nullptr if none. +// WARNING: Experimental interface, subject to change. +ResourceVariable* GetResourceVariable(ResourceMap* resources, int resource_id); + +} // namespace resource } // namespace tflite -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_VARIABLE_H_ diff --git a/tensorflow/lite/experimental/resource/static_hashtable.cc b/tensorflow/lite/experimental/resource/static_hashtable.cc new file mode 100644 index 00000000000..7da85b38e87 --- /dev/null +++ b/tensorflow/lite/experimental/resource/static_hashtable.cc @@ -0,0 +1,129 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/resource/static_hashtable.h" + +#include +#include "tensorflow/lite/experimental/resource/lookup_interfaces.h" + +namespace tflite { +namespace resource { +namespace internal { + +template +TfLiteStatus StaticHashtable::Lookup( + TfLiteContext* context, const TfLiteTensor* keys, TfLiteTensor* values, + const TfLiteTensor* default_value) { + TF_LITE_ENSURE(context, is_initialized_); + const int size = + MatchingFlatSize(GetTensorShape(keys), GetTensorShape(values)); + + auto key_tensor_reader = TensorReader(keys); + auto value_tensor_writer = TensorWriter(values); + auto default_value_tensor_reader = TensorReader(default_value); + ValueType first_default_value = default_value_tensor_reader.GetData(0); + + for (int i = 0; i < size; ++i) { + auto result = map_.find(key_tensor_reader.GetData(i)); + if (result != map_.end()) { + value_tensor_writer.SetData(i, result->second); + } else { + value_tensor_writer.SetData(i, first_default_value); + } + } + + // This is for a string tensor case in order to write buffer back to the + // actual tensor destination. Otherwise, it does nothing since the scalar data + // will be written into the tensor storage directly. + value_tensor_writer.Commit(); + + return kTfLiteOk; +} + +template +TfLiteStatus StaticHashtable::Import( + TfLiteContext* context, const TfLiteTensor* keys, + const TfLiteTensor* values) { + // Import nodes can be invoked twice because the converter will not extract + // the initializer graph separately from the original graph. The invocations + // after the first call will be ignored. + if (is_initialized_) { + return kTfLiteOk; + } + + const int size = + MatchingFlatSize(GetTensorShape(keys), GetTensorShape(values)); + + auto key_tensor_reader = TensorReader(keys); + auto value_tensor_writer = TensorReader(values); + for (int i = 0; i < size; ++i) { + map_.insert({key_tensor_reader.GetData(i), value_tensor_writer.GetData(i)}); + } + + is_initialized_ = true; + return kTfLiteOk; +} + +template +LookupInterface* CreateStaticHashtableWithGivenKey(TfLiteType key_type, + TfLiteType value_type) { + switch (value_type) { + case kTfLiteInt32: + return new StaticHashtable(key_type, value_type); + case kTfLiteString: + return new StaticHashtable(key_type, value_type); + case kTfLiteFloat32: + return new StaticHashtable(key_type, value_type); + default: + return nullptr; + } +} + +LookupInterface* CreateStaticHashtable(TfLiteType key_type, + TfLiteType value_type) { + switch (key_type) { + case kTfLiteInt32: + return CreateStaticHashtableWithGivenKey(key_type, value_type); + case kTfLiteString: + return CreateStaticHashtableWithGivenKey(key_type, + value_type); + default: + return nullptr; + } +} + +} // namespace internal + +void CreateHashtableResourceIfNotAvailable(ResourceMap* resources, + int resource_id, + TfLiteType key_dtype, + TfLiteType value_dtype) { + if (resources->count(resource_id) != 0) { + return; + } + auto* hashtable = internal::CreateStaticHashtable(key_dtype, value_dtype); + resources->emplace(resource_id, std::unique_ptr(hashtable)); +} + +LookupInterface* GetHashtableResource(ResourceMap* resources, int resource_id) { + auto it = resources->find(resource_id); + if (it != resources->end()) { + return static_cast(it->second.get()); + } + return nullptr; +} + +} // namespace resource +} // namespace tflite diff --git a/tensorflow/lite/experimental/resource/static_hashtable.h b/tensorflow/lite/experimental/resource/static_hashtable.h new file mode 100644 index 00000000000..84e68b72930 --- /dev/null +++ b/tensorflow/lite/experimental/resource/static_hashtable.h @@ -0,0 +1,84 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_STATIC_HASHTABLE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_STATIC_HASHTABLE_H_ + +#include + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/resource/lookup_interfaces.h" +#include "tensorflow/lite/experimental/resource/lookup_util.h" +#include "tensorflow/lite/experimental/resource/resource_base.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite { +namespace resource { +namespace internal { + +// A static hash table class. This hash table allows initialization one time in +// its life cycle. This hash table implements Tensorflow core's HashTableV2 op. +template +class StaticHashtable : public tflite::resource::LookupInterface { + public: + explicit StaticHashtable(TfLiteType key_type, TfLiteType value_type) + : key_type_(key_type), value_type_(value_type) {} + ~StaticHashtable() override {} + + // Finds the corresponding value of the given keys tensor in the map and + // copies the result data to the given values tensor. If there is no matching + // value, it will write the default value into the matched position instead. + TfLiteStatus Lookup(TfLiteContext* context, const TfLiteTensor* keys, + TfLiteTensor* values, + const TfLiteTensor* default_value) override; + + // Inserts the given key and value tensor data into the hash table. + TfLiteStatus Import(TfLiteContext* context, const TfLiteTensor* keys, + const TfLiteTensor* values) override; + + // Returns the item size of the hash table. + size_t Size() override { return map_.size(); } + + TfLiteType GetKeyType() const override { return key_type_; } + TfLiteType GetValueType() const override { return value_type_; } + + TfLiteStatus CheckKeyAndValueTypes(TfLiteContext* context, + const TfLiteTensor* keys, + const TfLiteTensor* values) override { + TF_LITE_ENSURE_EQ(context, keys->type, key_type_); + TF_LITE_ENSURE_EQ(context, values->type, value_type_); + return kTfLiteOk; + } + + // Returns true if the hash table is initialized. + bool IsInitialized() override { return is_initialized_; } + + private: + TfLiteType key_type_; + TfLiteType value_type_; + + std::unordered_map map_; + bool is_initialized_ = false; +}; + +::tflite::resource::LookupInterface* CreateStaticHashtable( + TfLiteType key_type, TfLiteType value_type); + +} // namespace internal + +} // namespace resource +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_STATIC_HASHTABLE_H_ diff --git a/tensorflow/lite/experimental/resource_variable/BUILD b/tensorflow/lite/experimental/resource_variable/BUILD deleted file mode 100644 index af2ed19d214..00000000000 --- a/tensorflow/lite/experimental/resource_variable/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "resource_variable", - srcs = [ - "resource_variable.cc", - ], - hdrs = [ - "resource_variable.h", - ], - deps = [ - "//tensorflow/lite/c:c_api_internal", - ], -) diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc index 10d857c05e6..07a8af1aa0e 100644 --- a/tensorflow/lite/interpreter.cc +++ b/tensorflow/lite/interpreter.cc @@ -155,7 +155,7 @@ void Interpreter::AddSubgraphs(int subgraphs_to_add, subgraphs_.reserve(base_index + subgraphs_to_add); for (int i = 0; i < subgraphs_to_add; ++i) { Subgraph* subgraph = new Subgraph(error_reporter_, external_contexts_, - &subgraphs_, &resource_variables_); + &subgraphs_, &resources_); subgraphs_.emplace_back(subgraph); } } diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h index 1d2664a028c..3976c278d11 100644 --- a/tensorflow/lite/interpreter.h +++ b/tensorflow/lite/interpreter.h @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/profiler.h" #include "tensorflow/lite/core/subgraph.h" -#include "tensorflow/lite/experimental/resource_variable/resource_variable.h" +#include "tensorflow/lite/experimental/resource/resource_base.h" #include "tensorflow/lite/external_cpu_backend_context.h" #include "tensorflow/lite/memory_planner.h" #include "tensorflow/lite/stderr_reporter.h" @@ -522,9 +522,8 @@ class Interpreter { // Subgraphs std::vector> subgraphs_; - // A map of resource variables. Owned by interpreter and shared by multiple - // subgraphs. - ResourceVariableMap resource_variables_; + // A map of resources. Owned by interpreter and shared by multiple subgraphs. + resource::ResourceMap resources_; }; } // namespace tflite diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 65bb0ea062e..d5fae3c9937 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -528,6 +528,7 @@ cc_library( "//tensorflow/lite:framework", "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/resource", "//tensorflow/lite/kernels/internal:audio_utils", "//tensorflow/lite/kernels/internal:common", "//tensorflow/lite/kernels/internal:compatibility", @@ -561,6 +562,7 @@ cc_library( ":op_macros", "//tensorflow/lite:framework", "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/resource", "//tensorflow/lite/kernels/internal:tensor", ], ) diff --git a/tensorflow/lite/kernels/assign_variable.cc b/tensorflow/lite/kernels/assign_variable.cc index 099b8e16cfb..ac4ce79ffa1 100644 --- a/tensorflow/lite/kernels/assign_variable.cc +++ b/tensorflow/lite/kernels/assign_variable.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/experimental/resource/resource_variable.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" @@ -43,10 +44,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // everything still works fine when variable ops aren't used. TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0); - const TfLiteTensor* input_variable_id_tensor = + const TfLiteTensor* input_resource_id_tensor = GetInput(context, node, kInputVariableId); - TF_LITE_ENSURE_EQ(context, input_variable_id_tensor->type, kTfLiteInt32); - TF_LITE_ENSURE_EQ(context, NumElements(input_variable_id_tensor), 1); + TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, NumElements(input_resource_id_tensor), 1); return kTfLiteOk; } @@ -54,21 +55,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { Subgraph* subgraph = reinterpret_cast(context->impl_); - const TfLiteTensor* input_variable_id_tensor = + const TfLiteTensor* input_resource_id_tensor = GetInput(context, node, kInputVariableId); const TfLiteTensor* input_value_tensor = GetInput(context, node, kInputValue); - int variable_id = input_variable_id_tensor->data.i32[0]; - auto& resource_variables = subgraph->resource_variables(); - - auto variable_iterator = resource_variables.find(variable_id); - if (variable_iterator == resource_variables.end()) { - auto ret = resource_variables.emplace(variable_id, ResourceVariable()); - variable_iterator = ret.first; - } - - auto& variable = variable_iterator->second; - variable.AssignFrom(input_value_tensor); + int resource_id = input_resource_id_tensor->data.i32[0]; + auto& resources = subgraph->resources(); + resource::CreateResourceVariableIfNotAvailable(&resources, resource_id); + auto* variable = resource::GetResourceVariable(&resources, resource_id); + TF_LITE_ENSURE(context, variable != nullptr); + variable->AssignFrom(input_value_tensor); return kTfLiteOk; } diff --git a/tensorflow/lite/kernels/read_variable.cc b/tensorflow/lite/kernels/read_variable.cc index 4996bcc0b4a..891cad90e69 100644 --- a/tensorflow/lite/kernels/read_variable.cc +++ b/tensorflow/lite/kernels/read_variable.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/experimental/resource/resource_variable.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" @@ -36,10 +37,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, node->inputs->size, 1); TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); - const TfLiteTensor* input_variable_id_tensor = + const TfLiteTensor* input_resource_id_tensor = GetInput(context, node, kInputVariableId); - TF_LITE_ENSURE_EQ(context, input_variable_id_tensor->type, kTfLiteInt32); - TF_LITE_ENSURE_EQ(context, NumElements(input_variable_id_tensor), 1); + TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, NumElements(input_resource_id_tensor), 1); TfLiteTensor* output = GetOutput(context, node, kOutputValue); SetTensorToDynamic(output); @@ -50,20 +51,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { Subgraph* subgraph = reinterpret_cast(context->impl_); - const TfLiteTensor* input_variable_id_tensor = + const TfLiteTensor* input_resource_id_tensor = GetInput(context, node, kInputVariableId); - int variable_id = input_variable_id_tensor->data.i32[0]; - auto& resource_variables = subgraph->resource_variables(); + int resource_id = input_resource_id_tensor->data.i32[0]; + auto& resources = subgraph->resources(); + auto* variable = resource::GetResourceVariable(&resources, resource_id); + TF_LITE_ENSURE(context, variable != nullptr); - const auto& variable_iterator = resource_variables.find(variable_id); - if (variable_iterator == resource_variables.end()) { - context->ReportError(context, "Variable ID %d is read before initialized.", - variable_id); - return kTfLiteError; - } - auto& variable = variable_iterator->second; - - TfLiteTensor* variable_tensor = variable.GetTensor(); + TfLiteTensor* variable_tensor = variable->GetTensor(); TfLiteTensor* output = GetOutput(context, node, kOutputValue); TF_LITE_ENSURE_EQ(context, variable_tensor->type, output->type);