Add hash table resource implementation into TFLite
PiperOrigin-RevId: 282460990 Change-Id: If489e2ab2d3e9ad2b4899f4ec055e611def6696d
This commit is contained in:
parent
1b17b134a9
commit
3c725ae67b
@ -215,7 +215,7 @@ cc_library(
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
||||
"//tensorflow/lite/experimental/resource_variable",
|
||||
"//tensorflow/lite/experimental/resource",
|
||||
"//tensorflow/lite/nnapi:nnapi_implementation",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
],
|
||||
|
@ -158,13 +158,13 @@ class InterpreterInfo : public GraphInfo {
|
||||
Subgraph::Subgraph(ErrorReporter* error_reporter,
|
||||
TfLiteExternalContext** external_contexts,
|
||||
std::vector<std::unique_ptr<Subgraph>>* subgraphs,
|
||||
ResourceVariableMap* resource_variables)
|
||||
resource::ResourceMap* resources)
|
||||
: external_contexts_(external_contexts),
|
||||
error_reporter_(error_reporter),
|
||||
next_execution_plan_index_to_prepare_(0),
|
||||
next_execution_plan_index_to_plan_allocation_(0),
|
||||
subgraphs_(subgraphs),
|
||||
resource_variables_(resource_variables) {
|
||||
resources_(resources) {
|
||||
context_.impl_ = static_cast<void*>(this);
|
||||
context_.ResizeTensor = ResizeTensor;
|
||||
context_.ReportError = ReportErrorC;
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/core/api/profiler.h"
|
||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||
#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
|
||||
#include "tensorflow/lite/experimental/resource/resource_base.h"
|
||||
#include "tensorflow/lite/memory_planner.h"
|
||||
#include "tensorflow/lite/util.h"
|
||||
|
||||
@ -40,7 +40,7 @@ class Subgraph {
|
||||
Subgraph(ErrorReporter* error_reporter,
|
||||
TfLiteExternalContext** external_contexts,
|
||||
std::vector<std::unique_ptr<Subgraph>>* subgraphs,
|
||||
ResourceVariableMap* resource_variables);
|
||||
resource::ResourceMap* resources);
|
||||
|
||||
Subgraph(const Subgraph&) = delete;
|
||||
|
||||
@ -166,7 +166,7 @@ class Subgraph {
|
||||
|
||||
// WARNING: Experimental interface, subject to change.
|
||||
// TODO(ycling): Move this function to an external context interface.
|
||||
ResourceVariableMap& resource_variables() { return *resource_variables_; }
|
||||
resource::ResourceMap& resources() { return *resources_; }
|
||||
|
||||
size_t tensors_size() const { return tensors_.size(); }
|
||||
|
||||
@ -635,9 +635,8 @@ class Subgraph {
|
||||
// `check_cancelled_func_`.
|
||||
void* cancellation_data_ = nullptr;
|
||||
|
||||
// A map of resource variables. Owned by interpreter and shared by multiple
|
||||
// subgraphs.
|
||||
ResourceVariableMap* resource_variables_ = nullptr;
|
||||
// A map of resources. Owned by interpreter and shared by multiple subgraphs.
|
||||
resource::ResourceMap* resources_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
25
tensorflow/lite/experimental/resource/BUILD
Normal file
25
tensorflow/lite/experimental/resource/BUILD
Normal file
@ -0,0 +1,25 @@
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "resource",
|
||||
srcs = [
|
||||
"resource_variable.cc",
|
||||
"static_hashtable.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"lookup_interfaces.h",
|
||||
"lookup_util.h",
|
||||
"resource_base.h",
|
||||
"resource_variable.h",
|
||||
"static_hashtable.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels/internal:compatibility",
|
||||
"//tensorflow/lite/kernels/internal:tensor",
|
||||
],
|
||||
)
|
64
tensorflow/lite/experimental/resource/lookup_interfaces.h
Normal file
64
tensorflow/lite/experimental/resource/lookup_interfaces.h
Normal file
@ -0,0 +1,64 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_INTERFACES_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_INTERFACES_H_
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/experimental/resource/lookup_util.h"
|
||||
#include "tensorflow/lite/experimental/resource/resource_base.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace resource {
|
||||
|
||||
/// WARNING: Experimental interface, subject to change.
|
||||
// A resource hash table interface. It's similar to TensorFlow core's
|
||||
// LookupInterface class. But it's identified with int32 ID in TFLite (instead
|
||||
// of using Resource handle like TensorFlow).
|
||||
class LookupInterface : public ResourceBase {
|
||||
public:
|
||||
virtual TfLiteStatus Lookup(TfLiteContext* context, const TfLiteTensor* keys,
|
||||
TfLiteTensor* values,
|
||||
const TfLiteTensor* default_value) = 0;
|
||||
virtual TfLiteStatus Import(TfLiteContext* context, const TfLiteTensor* keys,
|
||||
const TfLiteTensor* values) = 0;
|
||||
virtual size_t Size() = 0;
|
||||
|
||||
virtual TfLiteType GetKeyType() const = 0;
|
||||
virtual TfLiteType GetValueType() const = 0;
|
||||
virtual TfLiteStatus CheckKeyAndValueTypes(TfLiteContext* context,
|
||||
const TfLiteTensor* keys,
|
||||
const TfLiteTensor* values) = 0;
|
||||
};
|
||||
|
||||
// Creates an resource hash table, shared among all the subgraphs with the
|
||||
// given resource id if there is an existing one.
|
||||
// WARNING: Experimental interface, subject to change.
|
||||
void CreateHashtableResourceIfNotAvailable(ResourceMap* resources,
|
||||
int resource_id,
|
||||
TfLiteType key_dtype,
|
||||
TfLiteType value_dtype);
|
||||
|
||||
// Returns the corresponding resource hash table, or nullptr if none.
|
||||
// WARNING: Experimental interface, subject to change.
|
||||
LookupInterface* GetHashtableResource(ResourceMap* resources, int resource_id);
|
||||
|
||||
} // namespace resource
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_INTERFACES_H_
|
114
tensorflow/lite/experimental/resource/lookup_util.h
Normal file
114
tensorflow/lite/experimental/resource/lookup_util.h
Normal file
@ -0,0 +1,114 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace resource {
|
||||
namespace internal {
|
||||
|
||||
/// Helper class for accessing TFLite tensor data.
|
||||
template <typename T>
|
||||
class TensorReader {
|
||||
public:
|
||||
explicit TensorReader(const TfLiteTensor* input) {
|
||||
input_data_ = GetTensorData<T>(input);
|
||||
}
|
||||
|
||||
// Returns the corresponding scalar data at the given index position.
|
||||
// In here, it does not check the validity of the index should be guaranteed
|
||||
// in order not to harm the performance. Caller should take care of it.
|
||||
T GetData(int index) { return input_data_[index]; }
|
||||
|
||||
private:
|
||||
const T* input_data_;
|
||||
};
|
||||
|
||||
/// Helper class for accesing TFLite tensor data. This specialized class is for
|
||||
/// std::string type.
|
||||
template <>
|
||||
class TensorReader<std::string> {
|
||||
public:
|
||||
explicit TensorReader(const TfLiteTensor* input) : input_(input) {}
|
||||
|
||||
// Returns the corresponding string data at the given index position.
|
||||
// In here, it does not check the validity of the index should be guaranteed
|
||||
// in order not to harm the performance. Caller should take care of it.
|
||||
std::string GetData(int index) {
|
||||
auto string_ref = GetString(input_, index);
|
||||
return std::string(string_ref.str, string_ref.len);
|
||||
}
|
||||
|
||||
private:
|
||||
const TfLiteTensor* input_;
|
||||
};
|
||||
|
||||
/// WARNING: Experimental interface, subject to change.
|
||||
/// Helper class for writing TFLite tensor data.
|
||||
template <typename ValueType>
|
||||
class TensorWriter {
|
||||
public:
|
||||
explicit TensorWriter(TfLiteTensor* values) {
|
||||
output_data_ = GetTensorData<ValueType>(values);
|
||||
}
|
||||
|
||||
// Sets the given value to the given index position of the tensor storage.
|
||||
// In here, it does not check the validity of the index should be guaranteed
|
||||
// in order not to harm the performance. Caller should take care of it.
|
||||
void SetData(int index, ValueType& value) { output_data_[index] = value; }
|
||||
|
||||
// Commit updates. In this case, it does nothing since the SetData method
|
||||
// writes data directly.
|
||||
void Commit() {
|
||||
// Noop.
|
||||
}
|
||||
|
||||
private:
|
||||
ValueType* output_data_;
|
||||
};
|
||||
|
||||
/// WARNING: Experimental interface, subject to change.
|
||||
/// Helper class for writing TFLite tensor data. This specialized class is for
|
||||
/// std::string type.
|
||||
template <>
|
||||
class TensorWriter<std::string> {
|
||||
public:
|
||||
explicit TensorWriter(TfLiteTensor* values) : values_(values) {}
|
||||
|
||||
// Queues the given string value to the buffer regardless of the provided
|
||||
// index.
|
||||
// In here, it does not check the validity of the index should be guaranteed
|
||||
// in order not to harm the performance. Caller should take care of it.
|
||||
void SetData(int index, const std::string& value) {
|
||||
buf_.AddString(value.data(), value.length());
|
||||
}
|
||||
|
||||
// Commit updates. The stored data in DynamicBuffer will be written into the
|
||||
// tensor storage.
|
||||
void Commit() { buf_.WriteToTensor(values_, nullptr); }
|
||||
|
||||
private:
|
||||
TfLiteTensor* values_;
|
||||
DynamicBuffer buf_;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
} // namespace resource
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_
|
43
tensorflow/lite/experimental/resource/resource_base.h
Normal file
43
tensorflow/lite/experimental/resource/resource_base.h
Normal file
@ -0,0 +1,43 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace resource {
|
||||
|
||||
// ResourceBase is an abstract base class for resources.
|
||||
/// WARNING: Experimental interface, subject to change.
|
||||
class ResourceBase {
|
||||
public:
|
||||
explicit ResourceBase() {}
|
||||
virtual ~ResourceBase() {}
|
||||
|
||||
// Returns true if it is initialized.
|
||||
virtual bool IsInitialized() = 0;
|
||||
};
|
||||
|
||||
/// WARNING: Experimental interface, subject to change.
|
||||
using ResourceMap = std::unordered_map<int32, std::unique_ptr<ResourceBase>>;
|
||||
|
||||
} // namespace resource
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_
|
@ -13,13 +13,15 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
|
||||
#include "tensorflow/lite/experimental/resource/resource_variable.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
namespace tflite {
|
||||
namespace resource {
|
||||
|
||||
ResourceVariable::ResourceVariable() {
|
||||
memset(&tensor_, 0, sizeof(TfLiteTensor));
|
||||
@ -75,4 +77,22 @@ TfLiteStatus ResourceVariable::AssignFrom(const TfLiteTensor* tensor) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void CreateResourceVariableIfNotAvailable(ResourceMap* resources,
|
||||
int resource_id) {
|
||||
if (resources->count(resource_id) != 0) {
|
||||
return;
|
||||
}
|
||||
resources->emplace(
|
||||
resource_id, std::unique_ptr<ResourceVariable>(new ResourceVariable()));
|
||||
}
|
||||
|
||||
ResourceVariable* GetResourceVariable(ResourceMap* resources, int resource_id) {
|
||||
auto it = resources->find(resource_id);
|
||||
if (it != resources->end()) {
|
||||
return static_cast<ResourceVariable*>(it->second.get());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace resource
|
||||
} // namespace tflite
|
@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_VARIABLE_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_VARIABLE_H_
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/experimental/resource/resource_base.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace resource {
|
||||
|
||||
/// WARNING: Experimental interface, subject to change.
|
||||
// A resource variable class. It's similar to TensorFlow Resource
|
||||
@ -28,7 +28,7 @@ namespace tflite {
|
||||
//
|
||||
// TODO(b/137042749): TFLite converter cannot convert variables yet.
|
||||
// Variable functionalities are only tested with unit tests now.
|
||||
class ResourceVariable {
|
||||
class ResourceVariable : public ResourceBase {
|
||||
public:
|
||||
ResourceVariable();
|
||||
ResourceVariable(ResourceVariable&& other);
|
||||
@ -36,7 +36,7 @@ class ResourceVariable {
|
||||
ResourceVariable(const ResourceVariable&) = delete;
|
||||
ResourceVariable& operator=(const ResourceVariable&) = delete;
|
||||
|
||||
~ResourceVariable();
|
||||
~ResourceVariable() override;
|
||||
|
||||
// Assigns data from a tensor. Copies its type, shape and data over.
|
||||
TfLiteStatus AssignFrom(const TfLiteTensor* tensor);
|
||||
@ -46,6 +46,9 @@ class ResourceVariable {
|
||||
// `AssignFrom`.
|
||||
TfLiteTensor* GetTensor() { return is_initialized_ ? &tensor_ : nullptr; }
|
||||
|
||||
// Returns true if this resource variable is initialized.
|
||||
bool IsInitialized() override { return is_initialized_; }
|
||||
|
||||
private:
|
||||
// The tensor (and its buffer stored in `tensor_.data` is fully owned by
|
||||
// the `ResourceVariable` object.
|
||||
@ -55,8 +58,17 @@ class ResourceVariable {
|
||||
bool is_initialized_ = false;
|
||||
};
|
||||
|
||||
using ResourceVariableMap = std::unordered_map<int, ResourceVariable>;
|
||||
// Creates a resource variable, shared among all the subgraphs with the given
|
||||
// resource id if there is an existing one.
|
||||
// WARNING: Experimental interface, subject to change.
|
||||
void CreateResourceVariableIfNotAvailable(ResourceMap* resources,
|
||||
int resource_id);
|
||||
|
||||
// Returns the corresponding resource variable, or nullptr if none.
|
||||
// WARNING: Experimental interface, subject to change.
|
||||
ResourceVariable* GetResourceVariable(ResourceMap* resources, int resource_id);
|
||||
|
||||
} // namespace resource
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_VARIABLE_RESOURCE_VARIABLE_H_
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_VARIABLE_H_
|
129
tensorflow/lite/experimental/resource/static_hashtable.cc
Normal file
129
tensorflow/lite/experimental/resource/static_hashtable.cc
Normal file
@ -0,0 +1,129 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/resource/static_hashtable.h"
|
||||
|
||||
#include <memory>
|
||||
#include "tensorflow/lite/experimental/resource/lookup_interfaces.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace resource {
|
||||
namespace internal {
|
||||
|
||||
template <typename KeyType, typename ValueType>
|
||||
TfLiteStatus StaticHashtable<KeyType, ValueType>::Lookup(
|
||||
TfLiteContext* context, const TfLiteTensor* keys, TfLiteTensor* values,
|
||||
const TfLiteTensor* default_value) {
|
||||
TF_LITE_ENSURE(context, is_initialized_);
|
||||
const int size =
|
||||
MatchingFlatSize(GetTensorShape(keys), GetTensorShape(values));
|
||||
|
||||
auto key_tensor_reader = TensorReader<KeyType>(keys);
|
||||
auto value_tensor_writer = TensorWriter<ValueType>(values);
|
||||
auto default_value_tensor_reader = TensorReader<ValueType>(default_value);
|
||||
ValueType first_default_value = default_value_tensor_reader.GetData(0);
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
auto result = map_.find(key_tensor_reader.GetData(i));
|
||||
if (result != map_.end()) {
|
||||
value_tensor_writer.SetData(i, result->second);
|
||||
} else {
|
||||
value_tensor_writer.SetData(i, first_default_value);
|
||||
}
|
||||
}
|
||||
|
||||
// This is for a string tensor case in order to write buffer back to the
|
||||
// actual tensor destination. Otherwise, it does nothing since the scalar data
|
||||
// will be written into the tensor storage directly.
|
||||
value_tensor_writer.Commit();
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <typename KeyType, typename ValueType>
|
||||
TfLiteStatus StaticHashtable<KeyType, ValueType>::Import(
|
||||
TfLiteContext* context, const TfLiteTensor* keys,
|
||||
const TfLiteTensor* values) {
|
||||
// Import nodes can be invoked twice because the converter will not extract
|
||||
// the initializer graph separately from the original graph. The invocations
|
||||
// after the first call will be ignored.
|
||||
if (is_initialized_) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
const int size =
|
||||
MatchingFlatSize(GetTensorShape(keys), GetTensorShape(values));
|
||||
|
||||
auto key_tensor_reader = TensorReader<KeyType>(keys);
|
||||
auto value_tensor_writer = TensorReader<ValueType>(values);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
map_.insert({key_tensor_reader.GetData(i), value_tensor_writer.GetData(i)});
|
||||
}
|
||||
|
||||
is_initialized_ = true;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <typename KeyType>
|
||||
LookupInterface* CreateStaticHashtableWithGivenKey(TfLiteType key_type,
|
||||
TfLiteType value_type) {
|
||||
switch (value_type) {
|
||||
case kTfLiteInt32:
|
||||
return new StaticHashtable<KeyType, int32>(key_type, value_type);
|
||||
case kTfLiteString:
|
||||
return new StaticHashtable<KeyType, std::string>(key_type, value_type);
|
||||
case kTfLiteFloat32:
|
||||
return new StaticHashtable<KeyType, float>(key_type, value_type);
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
LookupInterface* CreateStaticHashtable(TfLiteType key_type,
|
||||
TfLiteType value_type) {
|
||||
switch (key_type) {
|
||||
case kTfLiteInt32:
|
||||
return CreateStaticHashtableWithGivenKey<int32>(key_type, value_type);
|
||||
case kTfLiteString:
|
||||
return CreateStaticHashtableWithGivenKey<std::string>(key_type,
|
||||
value_type);
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
|
||||
void CreateHashtableResourceIfNotAvailable(ResourceMap* resources,
|
||||
int resource_id,
|
||||
TfLiteType key_dtype,
|
||||
TfLiteType value_dtype) {
|
||||
if (resources->count(resource_id) != 0) {
|
||||
return;
|
||||
}
|
||||
auto* hashtable = internal::CreateStaticHashtable(key_dtype, value_dtype);
|
||||
resources->emplace(resource_id, std::unique_ptr<LookupInterface>(hashtable));
|
||||
}
|
||||
|
||||
LookupInterface* GetHashtableResource(ResourceMap* resources, int resource_id) {
|
||||
auto it = resources->find(resource_id);
|
||||
if (it != resources->end()) {
|
||||
return static_cast<LookupInterface*>(it->second.get());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace resource
|
||||
} // namespace tflite
|
84
tensorflow/lite/experimental/resource/static_hashtable.h
Normal file
84
tensorflow/lite/experimental/resource/static_hashtable.h
Normal file
@ -0,0 +1,84 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_STATIC_HASHTABLE_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_STATIC_HASHTABLE_H_
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/experimental/resource/lookup_interfaces.h"
|
||||
#include "tensorflow/lite/experimental/resource/lookup_util.h"
|
||||
#include "tensorflow/lite/experimental/resource/resource_base.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace resource {
|
||||
namespace internal {
|
||||
|
||||
// A static hash table class. This hash table allows initialization one time in
|
||||
// its life cycle. This hash table implements Tensorflow core's HashTableV2 op.
|
||||
template <typename KeyType, typename ValueType>
|
||||
class StaticHashtable : public tflite::resource::LookupInterface {
|
||||
public:
|
||||
explicit StaticHashtable(TfLiteType key_type, TfLiteType value_type)
|
||||
: key_type_(key_type), value_type_(value_type) {}
|
||||
~StaticHashtable() override {}
|
||||
|
||||
// Finds the corresponding value of the given keys tensor in the map and
|
||||
// copies the result data to the given values tensor. If there is no matching
|
||||
// value, it will write the default value into the matched position instead.
|
||||
TfLiteStatus Lookup(TfLiteContext* context, const TfLiteTensor* keys,
|
||||
TfLiteTensor* values,
|
||||
const TfLiteTensor* default_value) override;
|
||||
|
||||
// Inserts the given key and value tensor data into the hash table.
|
||||
TfLiteStatus Import(TfLiteContext* context, const TfLiteTensor* keys,
|
||||
const TfLiteTensor* values) override;
|
||||
|
||||
// Returns the item size of the hash table.
|
||||
size_t Size() override { return map_.size(); }
|
||||
|
||||
TfLiteType GetKeyType() const override { return key_type_; }
|
||||
TfLiteType GetValueType() const override { return value_type_; }
|
||||
|
||||
TfLiteStatus CheckKeyAndValueTypes(TfLiteContext* context,
|
||||
const TfLiteTensor* keys,
|
||||
const TfLiteTensor* values) override {
|
||||
TF_LITE_ENSURE_EQ(context, keys->type, key_type_);
|
||||
TF_LITE_ENSURE_EQ(context, values->type, value_type_);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Returns true if the hash table is initialized.
|
||||
bool IsInitialized() override { return is_initialized_; }
|
||||
|
||||
private:
|
||||
TfLiteType key_type_;
|
||||
TfLiteType value_type_;
|
||||
|
||||
std::unordered_map<KeyType, ValueType> map_;
|
||||
bool is_initialized_ = false;
|
||||
};
|
||||
|
||||
::tflite::resource::LookupInterface* CreateStaticHashtable(
|
||||
TfLiteType key_type, TfLiteType value_type);
|
||||
|
||||
} // namespace internal
|
||||
|
||||
} // namespace resource
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_STATIC_HASHTABLE_H_
|
@ -1,17 +0,0 @@
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "resource_variable",
|
||||
srcs = [
|
||||
"resource_variable.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"resource_variable.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
],
|
||||
)
|
@ -155,7 +155,7 @@ void Interpreter::AddSubgraphs(int subgraphs_to_add,
|
||||
subgraphs_.reserve(base_index + subgraphs_to_add);
|
||||
for (int i = 0; i < subgraphs_to_add; ++i) {
|
||||
Subgraph* subgraph = new Subgraph(error_reporter_, external_contexts_,
|
||||
&subgraphs_, &resource_variables_);
|
||||
&subgraphs_, &resources_);
|
||||
subgraphs_.emplace_back(subgraph);
|
||||
}
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/profiler.h"
|
||||
#include "tensorflow/lite/core/subgraph.h"
|
||||
#include "tensorflow/lite/experimental/resource_variable/resource_variable.h"
|
||||
#include "tensorflow/lite/experimental/resource/resource_base.h"
|
||||
#include "tensorflow/lite/external_cpu_backend_context.h"
|
||||
#include "tensorflow/lite/memory_planner.h"
|
||||
#include "tensorflow/lite/stderr_reporter.h"
|
||||
@ -522,9 +522,8 @@ class Interpreter {
|
||||
// Subgraphs
|
||||
std::vector<std::unique_ptr<Subgraph>> subgraphs_;
|
||||
|
||||
// A map of resource variables. Owned by interpreter and shared by multiple
|
||||
// subgraphs.
|
||||
ResourceVariableMap resource_variables_;
|
||||
// A map of resources. Owned by interpreter and shared by multiple subgraphs.
|
||||
resource::ResourceMap resources_;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -528,6 +528,7 @@ cc_library(
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/resource",
|
||||
"//tensorflow/lite/kernels/internal:audio_utils",
|
||||
"//tensorflow/lite/kernels/internal:common",
|
||||
"//tensorflow/lite/kernels/internal:compatibility",
|
||||
@ -561,6 +562,7 @@ cc_library(
|
||||
":op_macros",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/resource",
|
||||
"//tensorflow/lite/kernels/internal:tensor",
|
||||
],
|
||||
)
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/core/subgraph.h"
|
||||
#include "tensorflow/lite/experimental/resource/resource_variable.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
@ -43,10 +44,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// everything still works fine when variable ops aren't used.
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);
|
||||
|
||||
const TfLiteTensor* input_variable_id_tensor =
|
||||
const TfLiteTensor* input_resource_id_tensor =
|
||||
GetInput(context, node, kInputVariableId);
|
||||
TF_LITE_ENSURE_EQ(context, input_variable_id_tensor->type, kTfLiteInt32);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(input_variable_id_tensor), 1);
|
||||
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(input_resource_id_tensor), 1);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
@ -54,21 +55,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
|
||||
|
||||
const TfLiteTensor* input_variable_id_tensor =
|
||||
const TfLiteTensor* input_resource_id_tensor =
|
||||
GetInput(context, node, kInputVariableId);
|
||||
const TfLiteTensor* input_value_tensor = GetInput(context, node, kInputValue);
|
||||
|
||||
int variable_id = input_variable_id_tensor->data.i32[0];
|
||||
auto& resource_variables = subgraph->resource_variables();
|
||||
|
||||
auto variable_iterator = resource_variables.find(variable_id);
|
||||
if (variable_iterator == resource_variables.end()) {
|
||||
auto ret = resource_variables.emplace(variable_id, ResourceVariable());
|
||||
variable_iterator = ret.first;
|
||||
}
|
||||
|
||||
auto& variable = variable_iterator->second;
|
||||
variable.AssignFrom(input_value_tensor);
|
||||
int resource_id = input_resource_id_tensor->data.i32[0];
|
||||
auto& resources = subgraph->resources();
|
||||
resource::CreateResourceVariableIfNotAvailable(&resources, resource_id);
|
||||
auto* variable = resource::GetResourceVariable(&resources, resource_id);
|
||||
TF_LITE_ENSURE(context, variable != nullptr);
|
||||
variable->AssignFrom(input_value_tensor);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/core/subgraph.h"
|
||||
#include "tensorflow/lite/experimental/resource/resource_variable.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
@ -36,10 +37,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
|
||||
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
|
||||
|
||||
const TfLiteTensor* input_variable_id_tensor =
|
||||
const TfLiteTensor* input_resource_id_tensor =
|
||||
GetInput(context, node, kInputVariableId);
|
||||
TF_LITE_ENSURE_EQ(context, input_variable_id_tensor->type, kTfLiteInt32);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(input_variable_id_tensor), 1);
|
||||
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(input_resource_id_tensor), 1);
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputValue);
|
||||
SetTensorToDynamic(output);
|
||||
@ -50,20 +51,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
|
||||
|
||||
const TfLiteTensor* input_variable_id_tensor =
|
||||
const TfLiteTensor* input_resource_id_tensor =
|
||||
GetInput(context, node, kInputVariableId);
|
||||
int variable_id = input_variable_id_tensor->data.i32[0];
|
||||
auto& resource_variables = subgraph->resource_variables();
|
||||
int resource_id = input_resource_id_tensor->data.i32[0];
|
||||
auto& resources = subgraph->resources();
|
||||
auto* variable = resource::GetResourceVariable(&resources, resource_id);
|
||||
TF_LITE_ENSURE(context, variable != nullptr);
|
||||
|
||||
const auto& variable_iterator = resource_variables.find(variable_id);
|
||||
if (variable_iterator == resource_variables.end()) {
|
||||
context->ReportError(context, "Variable ID %d is read before initialized.",
|
||||
variable_id);
|
||||
return kTfLiteError;
|
||||
}
|
||||
auto& variable = variable_iterator->second;
|
||||
|
||||
TfLiteTensor* variable_tensor = variable.GetTensor();
|
||||
TfLiteTensor* variable_tensor = variable->GetTensor();
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputValue);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, variable_tensor->type, output->type);
|
||||
|
Loading…
Reference in New Issue
Block a user