Update Fingerprint64Map to use aliases
PiperOrigin-RevId: 234022159
parent 5b90573d41
commit e7d9786e66
@@ -132,14 +132,14 @@ class ResourceMgr {
   //
   // REQUIRES: std::is_base_of<ResourceBase, T>
   // REQUIRES: resource != nullptr
-  template <typename T>
+  template <typename T, bool use_dynamic_cast = false>
   Status Lookup(const string& container, const string& name,
                 T** resource) const TF_MUST_USE_RESULT;

   // Similar to Lookup, but looks up multiple resources at once, with only a
   // single lock acquisition. If containers_and_names[i] is uninitialized
   // then this function does not modify resources[i].
-  template <typename T>
+  template <typename T, bool use_dynamic_cast = false>
   Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
                         containers_and_names,
                     std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
@@ -155,7 +155,7 @@ class ResourceMgr {
   //
   // REQUIRES: std::is_base_of<ResourceBase, T>
   // REQUIRES: resource != nullptr
-  template <typename T>
+  template <typename T, bool use_dynamic_cast = false>
   Status LookupOrCreate(const string& container, const string& name,
                         T** resource,
                         std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
@@ -196,7 +196,7 @@ class ResourceMgr {
   mutable mutex mu_;
   std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);

-  template <typename T>
+  template <typename T, bool use_dynamic_cast = false>
   Status LookupInternal(const string& container, const string& name,
                         T** resource) const
       SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
@@ -267,7 +267,7 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
 //
 // If the lookup is successful, the caller takes the ownership of one ref on
 // `*value`, and must call its `Unref()` method when it has finished using it.
-template <typename T>
+template <typename T, bool use_dynamic_cast = false>
 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);

 // Looks up multiple resources pointed by a sequence of resource handles. If
@@ -437,15 +437,15 @@ Status ResourceMgr::Create(const string& container, const string& name,
   return DoCreate(container, MakeTypeIndex<T>(), name, resource);
 }

-template <typename T>
+template <typename T, bool use_dynamic_cast>
 Status ResourceMgr::Lookup(const string& container, const string& name,
                            T** resource) const {
   CheckDeriveFromResourceBase<T>();
   tf_shared_lock l(mu_);
-  return LookupInternal(container, name, resource);
+  return LookupInternal<T, use_dynamic_cast>(container, name, resource);
 }

-template <typename T>
+template <typename T, bool use_dynamic_cast>
 Status ResourceMgr::LookupMany(
     absl::Span<std::pair<const string*, const string*> const>
         containers_and_names,
@@ -455,8 +455,9 @@ Status ResourceMgr::LookupMany(
   resources->resize(containers_and_names.size());
   for (size_t i = 0; i < containers_and_names.size(); ++i) {
     T* resource;
-    Status s = LookupInternal(*containers_and_names[i].first,
-                              *containers_and_names[i].second, &resource);
+    Status s = LookupInternal<T, use_dynamic_cast>(
+        *containers_and_names[i].first, *containers_and_names[i].second,
+        &resource);
     if (s.ok()) {
       (*resources)[i].reset(resource);
     }
@@ -464,7 +465,18 @@ Status ResourceMgr::LookupMany(
   return Status::OK();
 }

+// Simple wrapper to allow conditional dynamic / static casts.
+template <typename T, bool use_dynamic_cast>
+struct TypeCastFunctor {
+  static T* Cast(ResourceBase* r) { return static_cast<T*>(r); }
+};
+
 template <typename T>
+struct TypeCastFunctor<T, true> {
+  static T* Cast(ResourceBase* r) { return dynamic_cast<T*>(r); }
+};
+
+template <typename T, bool use_dynamic_cast>
 Status ResourceMgr::LookupInternal(const string& container, const string& name,
                                    T** resource) const {
   ResourceBase* found = nullptr;
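Note: the `TypeCastFunctor` introduced above dispatches between `static_cast` and `dynamic_cast` at compile time through partial specialization, so the default `use_dynamic_cast = false` path stays a zero-cost static downcast. A minimal self-contained sketch of the same idiom follows; `Base`, `Widget`, and `CastHelper` are illustrative names, not part of this commit.

    #include <cassert>

    struct Base { virtual ~Base() = default; };
    struct Widget : Base { int payload = 42; };

    // Primary template: unchecked static downcast (no runtime cost).
    template <typename T, bool use_dynamic_cast>
    struct CastHelper {
      static T* Cast(Base* b) { return static_cast<T*>(b); }
    };

    // Specialization for true: checked downcast, nullptr on type mismatch.
    template <typename T>
    struct CastHelper<T, true> {
      static T* Cast(Base* b) { return dynamic_cast<T*>(b); }
    };

    int main() {
      Widget w;
      Base* b = &w;
      // The bool template argument selects the cast flavor at compile time.
      assert((CastHelper<Widget, false>::Cast(b)->payload == 42));
      assert((CastHelper<Widget, true>::Cast(b) != nullptr));
    }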
@@ -472,12 +484,12 @@ Status ResourceMgr::LookupInternal(const string& container, const string& name,
   if (s.ok()) {
     // It's safe to down cast 'found' to T* since
     // typeid(T).hash_code() is part of the map key.
-    *resource = static_cast<T*>(found);
+    *resource = TypeCastFunctor<T, use_dynamic_cast>::Cast(found);
   }
   return s;
 }

-template <typename T>
+template <typename T, bool use_dynamic_cast>
 Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
                                    T** resource,
                                    std::function<Status(T**)> creator) {
@@ -486,11 +498,11 @@ Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
   Status s;
   {
     tf_shared_lock l(mu_);
-    s = LookupInternal(container, name, resource);
+    s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
     if (s.ok()) return s;
   }
   mutex_lock l(mu_);
-  s = LookupInternal(container, name, resource);
+  s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
   if (s.ok()) return s;
   TF_RETURN_IF_ERROR(creator(resource));
   s = DoCreate(container, MakeTypeIndex<T>(), name, *resource);
@@ -566,11 +578,12 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
   return ctx->resource_manager()->Create(p.container(), p.name(), value);
 }

-template <typename T>
+template <typename T, bool use_dynamic_cast>
 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
                       T** value) {
   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
-  return ctx->resource_manager()->Lookup(p.container(), p.name(), value);
+  return ctx->resource_manager()->Lookup<T, use_dynamic_cast>(p.container(),
+                                                              p.name(), value);
 }

 template <typename T>
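Taken together, the `resource_mgr.h` hunks thread a single defaulted `bool use_dynamic_cast = false` parameter through `Lookup`, `LookupMany`, `LookupOrCreate`, `LookupInternal`, and `LookupResource`, so existing call sites compile unchanged while callers that fetch a resource through an alias type can opt into a checked cast. A small self-contained model of that API shape; `MiniResourceMgr` and `Table` are hypothetical stand-ins, not TensorFlow classes.

    #include <cassert>
    #include <string>
    #include <unordered_map>

    struct Base { virtual ~Base() = default; };
    struct Table : Base { int rows = 3; };

    // Defaulted bool template parameter: old Lookup<T>(...) call sites
    // compile unchanged; new callers may opt into a checked retrieval.
    class MiniResourceMgr {
     public:
      void Create(const std::string& name, Base* r) { resources_[name] = r; }

      template <typename T, bool use_dynamic_cast = false>
      T* Lookup(const std::string& name) const {
        auto it = resources_.find(name);
        if (it == resources_.end()) return nullptr;
        return use_dynamic_cast ? dynamic_cast<T*>(it->second)
                                : static_cast<T*>(it->second);
      }

     private:
      std::unordered_map<std::string, Base*> resources_;
    };

    int main() {
      MiniResourceMgr mgr;
      Table t;
      mgr.Create("t", &t);
      assert(mgr.Lookup<Table>("t") == &t);            // old-style call site
      assert((mgr.Lookup<Table, true>("t") == &t));    // opt-in checked lookup
    }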
@@ -19,18 +19,6 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
-        "@com_google_absl//absl/types:optional",
-        "@com_google_absl//absl/types:span",
-    ],
-)
-
-cc_library(
-    name = "table_resource_utils",
-    hdrs = ["table_resource_utils.h"],
-    deps = [
-        ":lookup_table_interface",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
     ],
 )

@@ -57,8 +45,8 @@ tf_kernel_library(
         "fingerprint64_map_ops.cc",
     ],
     deps = [
+        ":lookup_table_interface",
         ":table_op_utils",
-        ":table_resource_utils",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
@@ -15,8 +15,8 @@ limitations under the License.

 #include "absl/strings/string_view.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/lookup_tables/lookup_table_interface.h"
 #include "tensorflow/core/kernels/lookup_tables/table_op_utils.h"
-#include "tensorflow/core/kernels/lookup_tables/table_resource_utils.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/macros.h"
@@ -27,73 +27,48 @@ namespace tables {
 // Map x -> (Fingerprint64(x) % num_oov_buckets) + offset.
 // num_oov_buckets and offset are node attributes provided at construction
 // time.
-template <class HeterogeneousKeyType, class ValueType>
+template <typename KeyType, typename ValueType>
 class Fingerprint64Map final
-    : public LookupTableInterface<HeterogeneousKeyType, ValueType> {
+    : public virtual LookupInterface<ValueType*, const KeyType&>,
+      public virtual LookupWithPrefetchInterface<absl::Span<ValueType>,
+                                                 absl::Span<const KeyType>> {
  public:
+  using key_type = KeyType;
+
   Fingerprint64Map(int64 num_oov_buckets, int64 offset)
       : num_oov_buckets_(num_oov_buckets), offset_(offset) {}

-  mutex* GetMutex() const override { return nullptr; }
-
-  bool UnsafeInsertOrAssign(const HeterogeneousKeyType& key,
-                            const ValueType& value) override {
-    return true;
+  Status Lookup(const KeyType& key_to_find, ValueType* value) const override {
+    *value = LookupHelper(key_to_find);
+    return Status::OK();
   }

-  Status TableUnbatchedInsertStatus() const override {
-    return errors::Unimplemented("Fingerprint64Map does not support inserts.");
-  }
-
-  Status BatchInsertOrAssign(absl::Span<const HeterogeneousKeyType> keys,
-                             absl::Span<const ValueType> values) override {
-    return errors::Unimplemented("Fingerprint64Map does not support inserts.");
-  }
-
-  ValueType UnsafeLookupKey(
-      const HeterogeneousKeyType& key_to_find) const override {
-    // This can cause a downcast.
-    return static_cast<ValueType>(Fingerprint64(key_to_find) %
-                                  num_oov_buckets_) +
-           offset_;
-  }
-
-  Status TableUnbatchedLookupStatus() const override { return Status::OK(); }
-
-  Status BatchLookup(absl::Span<const HeterogeneousKeyType> keys,
-                     absl::Span<ValueType> values,
-                     int64 prefetch_lookahead) const override {
+  Status Lookup(absl::Span<const KeyType> keys, absl::Span<ValueType> values,
+                int64 prefetch_lookahead) const override {
     if (ABSL_PREDICT_FALSE(keys.size() != values.size())) {
       return errors::InvalidArgument(
           "keys and values do not have the same number of elements (found ",
          keys.size(), " vs ", values.size(), ").");
     }
     for (size_t i = 0; i < keys.size(); ++i) {
-      values[i] = Fingerprint64Map::UnsafeLookupKey(keys[i]);
+      values[i] = LookupHelper(keys[i]);
     }
     return Status::OK();
   }

-  const absl::optional<const ValueType> DefaultValue() const override {
-    return {};
-  }
-
-  void UnsafePrefetchKey(
-      const HeterogeneousKeyType& key_to_find) const override {}
-
-  size_t UnsafeSize() const override { return 0; }
-
-  Status SizeStatus() const override {
-    return errors::Unimplemented(
-        "Fingerprint64Map does not have a concept of size.");
-  }
-
-  bool UnsafeContainsKey(
-      const HeterogeneousKeyType& key_to_find) const override {
-    return true;
-  }
+  mutex* GetMutex() const override { return nullptr; }
+
+  string DebugString() const override { return __PRETTY_FUNCTION__; }

  private:
+  ABSL_ATTRIBUTE_ALWAYS_INLINE ValueType
+  LookupHelper(const KeyType& key_to_find) const {
+    // This can cause a downcast.
+    return static_cast<ValueType>(Fingerprint64(key_to_find) %
+                                  num_oov_buckets_) +
+           offset_;
+  }
+
   const int64 num_oov_buckets_;
   const int64 offset_;
   TF_DISALLOW_COPY_AND_ASSIGN(Fingerprint64Map);
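The rewritten `Fingerprint64Map` is a pure function of its key: every lookup computes `(Fingerprint64(key) % num_oov_buckets) + offset`, so the table is total over its key space and needs no inserts, no size, and no default value, which is exactly what the slimmed-down interface list reflects. A standalone sketch of the arithmetic; `std::hash` stands in for TensorFlow's `Fingerprint64` purely for illustration.

    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <string>

    // Stand-in for tensorflow::Fingerprint64 (any stable 64-bit hash works
    // for illustrating the arithmetic).
    uint64_t Fingerprint64(const std::string& s) {
      return std::hash<std::string>{}(s);
    }

    // Maps every key into [offset, offset + num_oov_buckets).
    int64_t Fingerprint64MapLookup(const std::string& key,
                                   int64_t num_oov_buckets, int64_t offset) {
      return static_cast<int64_t>(Fingerprint64(key) % num_oov_buckets) +
             offset;
    }

    int main() {
      const int64_t buckets = 100, offset = 10;
      const int64_t v = Fingerprint64MapLookup("some key", buckets, offset);
      assert(v >= offset && v < offset + buckets);
    }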
@@ -102,9 +77,10 @@ class Fingerprint64Map final
 template <typename Fingerprint64Map>
 struct Fingerprint64MapFactory {
   struct Functor {
-    template <typename ContainerBase>
+    using resource_type = Fingerprint64Map;
+
     static Status AllocateContainer(OpKernelContext* ctx, OpKernel* kernel,
-                                    ContainerBase** container) {
+                                    Fingerprint64Map** container) {
       int64 num_oov_buckets;
       int64 offset;
       TF_RETURN_IF_ERROR(
@@ -116,24 +92,28 @@ struct Fingerprint64MapFactory {
   };
 };

-#define REGISTER_STRING_KERNEL(table_value_dtype) \
-  REGISTER_KERNEL_BUILDER( \
-      Name("Fingerprint64Map") \
-          .Device(DEVICE_CPU) \
-          .TypeConstraint<Variant>("heterogeneous_key_dtype") \
-          .TypeConstraint<table_value_dtype>("table_value_dtype"), \
-      ResourceConstructionOp< \
-          LookupTableInterface<absl::string_view, table_value_dtype>, \
-          Fingerprint64MapFactory<Fingerprint64Map< \
-              absl::string_view, table_value_dtype>>::Functor>); \
+template <typename KeyType, typename ValueType>
+using ResourceOp = ResourceConstructionOp<
+    typename Fingerprint64MapFactory<
+        Fingerprint64Map<KeyType, ValueType>>::Functor,
+    // These are the aliases.
+    LookupInterface<ValueType*, const KeyType&>,
+    LookupWithPrefetchInterface<absl::Span<ValueType>,
+                                absl::Span<const KeyType>>>;
+
+#define REGISTER_STRING_KERNEL(ValueType) \
   REGISTER_KERNEL_BUILDER( \
       Name("Fingerprint64Map") \
           .Device(DEVICE_CPU) \
-          .TypeConstraint<string>("heterogeneous_key_dtype") \
-          .TypeConstraint<table_value_dtype>("table_value_dtype"), \
-      ResourceConstructionOp<LookupTableInterface<string, table_value_dtype>, \
-          Fingerprint64MapFactory<Fingerprint64Map< \
-              string, table_value_dtype>>::Functor>);
+          .TypeConstraint<Variant>("heterogeneous_key_dtype") \
+          .TypeConstraint<ValueType>("table_value_dtype"), \
+      ResourceOp<absl::string_view, ValueType>); \
+  REGISTER_KERNEL_BUILDER( \
+      Name("Fingerprint64Map") \
+          .Device(DEVICE_CPU) \
+          .TypeConstraint<string>("heterogeneous_key_dtype") \
+          .TypeConstraint<ValueType>("table_value_dtype"), \
+      ResourceOp<string, ValueType>);

 REGISTER_STRING_KERNEL(int32);
 REGISTER_STRING_KERNEL(int64);
@@ -16,11 +16,6 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_LOOKUP_TABLE_INTERFACE_H_
 #define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_LOOKUP_TABLE_INTERFACE_H_

-#include <cstddef>
-#include <string>
-
-#include "absl/types/optional.h"
-#include "absl/types/span.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -28,90 +23,74 @@ limitations under the License.
 namespace tensorflow {
 namespace tables {

-// Interface for key-value pair lookups with support for heterogeneous keys.
-// This class contains two main kinds of methods: methods which operate on
-// a batch of inputs and methods which do not. The latter have the prefix
-// 'Unsafe'. Clients must call the corresponding status methods to determine
-// whether they are safe to call within a code block.
-// Implementations must guarantee thread-safety when GetMutex is used to
-// synchronize method access.
-template <typename HeterogeneousKeyType, typename ValueType>
-class LookupTableInterface : public ResourceBase {
+// Interface for resources with mutable state.
+class SynchronizedInterface : public virtual ResourceBase {
  public:
-  using heterogeneous_key_type = HeterogeneousKeyType;
-  using value_type = ValueType;
-  using key_type = heterogeneous_key_type;
-
   // Return value should be used to synchronize read/write access to
   // all public methods. If null, no synchronization is needed.
   virtual mutex* GetMutex() const = 0;
+};

-  // Insert the KV pair into the underlying table. If a key equivalent to key
-  // already exists in the underlying table, its corresponding value is
-  // overridden. Returns true only if the key was inserted for the first time.
-  // Undefined if TableUnbatchedInsertStatus() != OK.
-  virtual bool UnsafeInsertOrAssign(const HeterogeneousKeyType& key,
-                                    const ValueType& value) = 0;
-
-  // Returns OK if it is safe to call InsertOrAssign.
-  // Once OK is returned, it is safe to call InsertOrAssign for the rest of the
-  // program.
-  virtual Status TableUnbatchedInsertStatus() const TF_MUST_USE_RESULT = 0;
-
+// Interface for containers which support batch lookups.
+template <typename ValueType, typename... KeyContext>
+class InsertOrAssignInterface : public virtual SynchronizedInterface {
+ public:
+  using value_type = ValueType;
+
   // Stores each KV pair {keys[i], values[i]} in the underlying map, overriding
   // pre-existing pairs which have equivalent keys.
   // keys and values should have the same size.
-  virtual Status BatchInsertOrAssign(
-      absl::Span<const HeterogeneousKeyType> keys,
-      absl::Span<const ValueType> values) = 0;
-
-  // Prefetch key_to_find into implementation defined data caches.
-  // Implementations are free to leave this a no-op.
-  // Undefined if TableUnbatchedLookupStatus() != OK.
-  virtual void UnsafePrefetchKey(
-      const HeterogeneousKeyType& key_to_find) const {}
-
-  // Returns true if and only if the table contains key_to_find.
-  // Undefined if TableUnbatchedLookupStatus() != OK.
-  virtual bool UnsafeContainsKey(
-      const HeterogeneousKeyType& key_to_find) const = 0;
-
-  // Lookup the value for key_to_find. This value must always be well-defined,
-  // even when ContainsKey(key_to_find) == false. When
-  // dv = DefaultValue() != absl::nullopt and ContainsKey(key_to_find) == false,
-  // dv is returned.
-  // Undefined if TableUnbatchedLookupStatus() != OK.
-  virtual ValueType UnsafeLookupKey(
-      const HeterogeneousKeyType& key_to_find) const = 0;
-
-  // Returns OK if it is safe to call PrefetchKey, ContainsKey, and
-  // UnsafeLookupKey.
-  // If OK is returned, it is safe to call these methods until the next
-  // non-const method of this class is called.
-  virtual Status TableUnbatchedLookupStatus() const TF_MUST_USE_RESULT = 0;
-
+  virtual Status InsertOrAssign(KeyContext... key_context,
+                                ValueType values) = 0;
+};
+
+// Interface for containers which support lookups.
+template <typename ValueType, typename... KeyContext>
+class LookupInterface : public virtual SynchronizedInterface {
+ public:
+  using value_type = ValueType;
+
   // Lookup the values for keys and store them in values.
   // prefetch_lookahead is used to prefetch the key at index
   // i + prefetch_lookahead at the ith iteration of the implemented loop.
   // keys and values must have the same size.
-  virtual Status BatchLookup(absl::Span<const HeterogeneousKeyType> keys,
-                             absl::Span<ValueType> values,
-                             int64 prefetch_lookahead) const = 0;
+  virtual Status Lookup(KeyContext... key_context, ValueType values) const = 0;
+};

-  // Returns the number of elements in the table.
-  // Undefined if SizeStatus() != OK.
-  virtual size_t UnsafeSize() const = 0;
-
-  // Returns OK if the return value of UnsafeSize() is always well-defined.
-  virtual Status SizeStatus() const TF_MUST_USE_RESULT = 0;
-
-  // If non-null value is returned, LookupKey returns that value only for keys
-  // which satisfy ContainsKey(key_to_find) == false.
-  virtual const absl::optional<const ValueType> DefaultValue() const = 0;
-
-  string DebugString() const override { return "A lookup table"; }
-
-  ~LookupTableInterface() override = default;
+// Interface for containers which support lookups with prefetching.
+template <typename ValueType, typename... KeyContext>
+class LookupWithPrefetchInterface : public virtual SynchronizedInterface {
+ public:
+  using value_type = ValueType;
+
+  // Lookup the values for keys and store them in values.
+  // prefetch_lookahead is used to prefetch the key at index
+  // i + prefetch_lookahead at the ith iteration of the implemented loop.
+  // keys and values must have the same size.
+  virtual Status Lookup(KeyContext... key_context, ValueType values,
+                        int64 prefetch_lookahead) const = 0;
+};
+
+// Interface for containers with size concepts.
+// Implementations must guarantee thread-safety when GetMutex is used to
+// synchronize method access.
+class SizeInterface : public virtual SynchronizedInterface {
+ public:
+  // Returns the number of elements in the container.
+  virtual uint64 Size() const = 0;
+};
+
+// Interface for tables which can be initialized from key and value arguments.
+template <typename ValueType, typename... KeyContext>
+class KeyValueTableInitializerInterface : public virtual SynchronizedInterface {
+ public:
+  using value_type = ValueType;
+
+  // Lookup the values for keys and store them in values.
+  // prefetch_lookahead is used to prefetch the key at index
+  // i + prefetch_lookahead at the ith iteration of the implemented loop.
+  // keys and values must have the same size.
+  virtual Status Initialize(KeyContext... key_context, ValueType values) = 0;
 };

 }  // namespace tables
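The monolithic `LookupTableInterface` is decomposed into small capability interfaces that each derive virtually from `SynchronizedInterface`, so a concrete table can mix several of them into one object sharing a single `ResourceBase` subobject. Recovering one capability from a pointer stored under another then requires a `dynamic_cast` (a `static_cast` from a virtual base is ill-formed), which is what the `use_dynamic_cast` plumbing in `resource_mgr.h` enables. A self-contained illustration with invented names:

    #include <cassert>

    struct RefCounted { virtual ~RefCounted() = default; };  // ~ ResourceBase
    struct LookupApi : virtual RefCounted {};
    struct PrefetchApi : virtual RefCounted {};
    struct ConcreteTable final : LookupApi, PrefetchApi {};

    int main() {
      ConcreteTable table;
      // Registered under one alias...
      RefCounted* stored = static_cast<LookupApi*>(&table);
      // ...and recovered under another. Only dynamic_cast can perform this
      // cast, because RefCounted is a virtual base of both alias interfaces.
      PrefetchApi* alias = dynamic_cast<PrefetchApi*>(stored);
      assert(alias != nullptr);
    }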
@@ -44,11 +44,11 @@ limitations under the License.
 namespace tensorflow {
 namespace tables {

-// Create resources of type ContainerBase using the static method
+// Create resources of type ResourceType and AliasesToRegister using
 // Functor::AllocateContainer(OpKernelConstruction*, OpKernel*,
-// ContainerBase**)
-// If the resource has already been created it will be looked up.
-template <class ContainerBase, typename Functor>
+// ResourceType**). ResourceType = Functor::resource_type.
+// No-op for resources which have already been created.
+template <typename Functor, typename... AliasesToRegister>
 class ResourceConstructionOp : public OpKernel {
  public:
   explicit ResourceConstructionOp(OpKernelConstruction* ctx)
@@ -66,46 +66,86 @@ class ResourceConstructionOp : public OpKernel {
     }

     auto creator = [ctx,
-                    this](ContainerBase** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-      ContainerBase* container;
-      auto status = Functor::AllocateContainer(ctx, this, &container);
+                    this](ResourceType** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      ResourceType* resource = nullptr;
+      auto status = Functor::AllocateContainer(ctx, this, &resource);
       if (ABSL_PREDICT_FALSE(!status.ok())) {
-        container->Unref();
+        // Ideally resource is non-null only if status is OK but we try
+        // to compensate here.
+        if (resource != nullptr) {
+          resource->Unref();
+        }
         return status;
       }
       if (ctx->track_allocations()) {
-        ctx->record_persistent_memory_allocation(container->MemoryUsed());
+        ctx->record_persistent_memory_allocation(resource->MemoryUsed());
       }
-      *ret = container;
+      *ret = resource;
       return Status::OK();
     };

-    ContainerBase* container_base = nullptr;
+    // Register the ResourceType alias.
+    ResourceType* resource = nullptr;
+    core::ScopedUnref unref_me(resource);
     OP_REQUIRES_OK(
-        ctx, cinfo_.resource_manager()->template LookupOrCreate<ContainerBase>(
-                 cinfo_.container(), cinfo_.name(), &container_base, creator));
-    core::ScopedUnref unref_me(container_base);
+        ctx,
+        cinfo_.resource_manager()->template LookupOrCreate<ResourceType, true>(
+            cinfo_.container(), cinfo_.name(), &resource, creator));

+    // Put a handle to resource in the output tensor (the other aliases will
+    // have the same handle).
     Tensor* handle;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
-    handle->scalar<ResourceHandle>()() = MakeResourceHandle<ContainerBase>(
+    handle->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
         ctx, cinfo_.container(), cinfo_.name());
     table_handle_set_ = true;
+
+    // Create other alias resources.
+    Status status;
+    char dummy[sizeof...(AliasesToRegister)] = {
+        (status.Update(RegisterAlias<AliasesToRegister>(resource)), 0)...};
+    (void)dummy;
+    OP_REQUIRES_OK(ctx, status);
   }

   ~ResourceConstructionOp() override {
     // If the table object was not shared, delete it.
     if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
       if (!cinfo_.resource_manager()
-               ->template Delete<ContainerBase>(cinfo_.container(),
-                                                cinfo_.name())
+               ->template Delete<ResourceType>(cinfo_.container(),
+                                               cinfo_.name())
                .ok()) {
         // Do nothing; the resource may have been deleted by session resets.
       }
+      // Attempt to delete other resource aliases.
+      Status dummy_status;
+      char dummy[sizeof...(AliasesToRegister)] = {
+          (dummy_status.Update(DeleteAlias<AliasesToRegister>()), 0)...};
+      (void)dummy;
     }
   }

  private:
+  using ResourceType = typename Functor::resource_type;
+
+  template <typename T>
+  Status RegisterAlias(ResourceType* resource) {
+    auto creator = [resource](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      *ret = resource;
+      return Status::OK();
+    };
+
+    T* alias_resource = nullptr;
+    core::ScopedUnref unref_me(alias_resource);
+    return cinfo_.resource_manager()->template LookupOrCreate<T, true>(
+        cinfo_.container(), cinfo_.name(), &alias_resource, creator);
+  }
+
+  template <typename T>
+  Status DeleteAlias() {
+    return cinfo_.resource_manager()->template Delete<T>(cinfo_.container(),
+                                                         cinfo_.name());
+  }
+
   mutex mu_;
   bool table_handle_set_ GUARDED_BY(mu_);
   ContainerInfo cinfo_;
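The `char dummy[sizeof...(AliasesToRegister)] = {(status.Update(...), 0)...};` lines above are the classic pre-C++17 pack-expansion fold: the array initializer evaluates one comma expression per pack element, in order, and the array itself is discarded. A self-contained sketch of the trick; `Visit` and `VisitAll` are illustrative names.

    #include <iostream>

    template <typename T>
    int Visit() {
      std::cout << "visited a type of size " << sizeof(T) << "\n";
      return 0;
    }

    template <typename... Ts>
    void VisitAll() {
      // One initializer per pack element; the comma operator discards
      // Visit's return value and yields 0 for the char element.
      char dummy[sizeof...(Ts)] = {(Visit<Ts>(), 0)...};
      (void)dummy;  // suppress unused-variable warnings
    }

    int main() { VisitAll<int, double, char>(); }  // prints three lines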
@@ -120,8 +160,7 @@ class ResourceConstructionOp : public OpKernel {
 // If the resource has already been created it will be looked up.
 // Container must decrease the reference count of the FallbackTableBaseType*
 // constructor argument before its destructor completes.
-template <class ContainerBase, class Functor,
-          class FallbackTableBaseType = ContainerBase>
+template <typename Functor, typename... AliasesToRegister>
 class TableWithFallbackConstructionOp : public OpKernel {
  public:
   explicit TableWithFallbackConstructionOp(OpKernelConstruction* ctx)
@@ -140,13 +179,14 @@ class TableWithFallbackConstructionOp : public OpKernel {
       return;
     }

+    // Look up the fallback table.
     FallbackTableBaseType* fallback_table = nullptr;
     {
       const Tensor& table_handle = ctx->input(table_int64_args.size());
       ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
       OP_REQUIRES_OK(
-          ctx, ctx->resource_manager()->Lookup(handle.container(),
-                                               handle.name(), &fallback_table));
+          ctx, ctx->resource_manager()->Lookup<FallbackTableBaseType, true>(
+                   handle.container(), handle.name(), &fallback_table));
     }
     mutex_lock l(mu_);

@@ -156,51 +196,93 @@ class TableWithFallbackConstructionOp : public OpKernel {
     }

     auto creator = [ctx, this, fallback_table](
-                       ContainerBase** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                       ResourceType** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       // container construction logic can't be merged with
       // ResourceConstructionOp because Container constructor requires an
       // input which can only be constructed if the resource manager
       // internal lock is not already held.
-      ContainerBase* container;
+      ResourceType* resource = nullptr;
       auto status =
-          Functor::AllocateContainer(ctx, this, fallback_table, &container);
+          Functor::AllocateContainer(ctx, this, fallback_table, &resource);
       if (ABSL_PREDICT_FALSE(!status.ok())) {
-        container->Unref();
+        // Ideally resource is non-null only if status is OK but we try
+        // to compensate here.
+        if (resource != nullptr) {
+          resource->Unref();
+        }
         return status;
       }
       if (ctx->track_allocations()) {
-        ctx->record_persistent_memory_allocation(container->MemoryUsed());
+        ctx->record_persistent_memory_allocation(resource->MemoryUsed());
       }
-      *ret = container;
+      *ret = resource;
       return Status::OK();
     };

-    ContainerBase* table = nullptr;
-    OP_REQUIRES_OK(
-        ctx, cinfo_.resource_manager()->template LookupOrCreate<ContainerBase>(
-                 cinfo_.container(), cinfo_.name(), &table, creator));
+    // Register the ResourceType alias.
+    ResourceType* table = nullptr;
     core::ScopedUnref unref_me(table);
+    OP_REQUIRES_OK(
+        ctx,
+        cinfo_.resource_manager()->template LookupOrCreate<ResourceType, true>(
+            cinfo_.container(), cinfo_.name(), &table, creator));

+    // Put a handle to resource in the output tensor (the other aliases will
+    // have the same handle).
     Tensor* handle;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
-    handle->scalar<ResourceHandle>()() = MakeResourceHandle<ContainerBase>(
+    handle->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
         ctx, cinfo_.container(), cinfo_.name());
     table_handle_set_ = true;
+
+    // Create other alias resources.
+    Status status;
+    char dummy[sizeof...(AliasesToRegister)] = {
+        (status.Update(RegisterAlias<AliasesToRegister>(table)), 0)...};
+    (void)dummy;
+    OP_REQUIRES_OK(ctx, status);
   }

   ~TableWithFallbackConstructionOp() override {
     // If the table object was not shared, delete it.
     if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
       if (!cinfo_.resource_manager()
-               ->template Delete<ContainerBase>(cinfo_.container(),
-                                                cinfo_.name())
+               ->template Delete<ResourceType>(cinfo_.container(),
+                                               cinfo_.name())
                .ok()) {
         // Do nothing; the resource may have been deleted by session resets.
       }
+      // Attempt to delete other resource aliases.
+      Status dummy_status;
+      char dummy[sizeof...(AliasesToRegister)] = {
+          (dummy_status.Update(DeleteAlias<AliasesToRegister>()), 0)...};
+      (void)dummy;
     }
   }

  private:
+  using ResourceType = typename Functor::resource_type;
+  using FallbackTableBaseType = typename Functor::fallback_table_type;
+
+  template <typename T>
+  Status RegisterAlias(ResourceType* resource) {
+    auto creator = [resource](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      *ret = resource;
+      return Status::OK();
+    };
+
+    T* alias_resource = nullptr;
+    core::ScopedUnref unref_me(alias_resource);
+    return cinfo_.resource_manager()->template LookupOrCreate<T, true>(
+        cinfo_.container(), cinfo_.name(), &alias_resource, creator);
+  }
+
+  template <typename T>
+  Status DeleteAlias() {
+    return cinfo_.resource_manager()->template Delete<T>(cinfo_.container(),
+                                                         cinfo_.name());
+  }
+
   mutex mu_;
   bool table_handle_set_ GUARDED_BY(mu_);
   ContainerInfo cinfo_;
@@ -209,33 +291,29 @@ class TableWithFallbackConstructionOp : public OpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(TableWithFallbackConstructionOp);
 };

-// Used to insert tensors into a container.
-template <class Container, class InsertKeyTensorType,
-          class InsertValueTensorType>
-class HeterogeneousLookupTableInsertOrAssignOp : public OpKernel {
+// Lookup a table of type ResourceAlias and insert the passed in keys and
+// values tensors using Functor::TensorInsert(keys, values, table).
+template <typename Functor,
+          typename ResourceAlias = typename Functor::resource_type>
+class LookupTableInsertOp : public OpKernel {
  public:
-  explicit HeterogeneousLookupTableInsertOrAssignOp(OpKernelConstruction* ctx)
-      : OpKernel(ctx) {}
+  explicit LookupTableInsertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

   void Compute(OpKernelContext* ctx) override {
     OpInputList table_int64_args;
     OP_REQUIRES_OK(ctx, ctx->input_list("table_int64_args", &table_int64_args));
     const size_t tensor_index_offset = table_int64_args.size();
+    // Business logic for checking tensor shapes, etc, is delegated to the
+    // Functor.
     const Tensor& keys = ctx->input(tensor_index_offset + 1);
     const Tensor& values = ctx->input(tensor_index_offset + 2);
-    if (ABSL_PREDICT_FALSE(keys.NumElements() != values.NumElements())) {
-      ctx->SetStatus(errors::InvalidArgument(
-          "keys and values do not have the same number of elements: ",
-          keys.NumElements(), " vs ", values.NumElements()));
-      return;
-    }

     const Tensor& table_handle = ctx->input(tensor_index_offset);
     ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
-    Container* table;
-    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(handle.container(),
-                                                        handle.name(), &table));
+    ResourceAlias* table;
     core::ScopedUnref unref_me(table);
+    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
+                            handle.container(), handle.name(), &table));

     int memory_used_before = 0;
     if (ctx->track_allocations()) {
@@ -244,9 +322,9 @@ class HeterogeneousLookupTableInsertOrAssignOp : public OpKernel {
     auto* mutex = table->GetMutex();
     if (mutex != nullptr) {
       mutex_lock lock(*mutex);
-      OP_REQUIRES_OK(ctx, TensorInsert(keys, values, table));
+      OP_REQUIRES_OK(ctx, Functor::TensorInsert(keys, values, table));
     } else {
-      OP_REQUIRES_OK(ctx, TensorInsert(keys, values, table));
+      OP_REQUIRES_OK(ctx, Functor::TensorInsert(keys, values, table));
     }
     if (ctx->track_allocations()) {
       ctx->record_persistent_memory_allocation(table->MemoryUsed() -
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Non-variant InsertKeyTensorType which is the same as Container::key_type.
|
TF_DISALLOW_COPY_AND_ASSIGN(LookupTableInsertOp);
|
||||||
// No need to static_cast.
|
|
||||||
template <typename SfinaeArg = InsertKeyTensorType>
|
|
||||||
absl::enable_if_t<
|
|
||||||
IsValidDataType<SfinaeArg>::value &&
|
|
||||||
std::is_same<SfinaeArg, typename Container::key_type>::value,
|
|
||||||
Status>
|
|
||||||
TensorInsert(const Tensor& keys, const Tensor& values,
|
|
||||||
Container* table) const {
|
|
||||||
const auto keys_flat = keys.flat<SfinaeArg>();
|
|
||||||
const auto values_flat = values.flat<InsertValueTensorType>();
|
|
||||||
return table->BatchInsertOrAssign(
|
|
||||||
absl::MakeSpan(keys_flat.data(), keys_flat.size()),
|
|
||||||
absl::MakeSpan(values_flat.data(), values_flat.size()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Non-variant InsertKeyTensorType which is otherwise convertible to
|
|
||||||
// Container::key_type.
|
|
||||||
template <typename SfinaeArg = InsertKeyTensorType>
|
|
||||||
absl::enable_if_t<
|
|
||||||
IsValidDataType<SfinaeArg>::value &&
|
|
||||||
!std::is_same<SfinaeArg, typename Container::key_type>::value &&
|
|
||||||
std::is_convertible<SfinaeArg, typename Container::key_type>::value,
|
|
||||||
Status>
|
|
||||||
TensorInsert(const Tensor& keys, const Tensor& values,
|
|
||||||
Container* table) const {
|
|
||||||
const auto keys_flat = keys.flat<InsertKeyTensorType>();
|
|
||||||
std::vector<typename Container::key_type> keys_vec;
|
|
||||||
const auto keys_size = keys_flat.size();
|
|
||||||
keys_vec.reserve(keys_size);
|
|
||||||
for (size_t i = 0; i < keys_size; ++i) {
|
|
||||||
keys_vec.push_back(
|
|
||||||
static_cast<typename Container::key_type>(keys_flat(i)));
|
|
||||||
}
|
|
||||||
const auto values_flat = values.flat<InsertValueTensorType>();
|
|
||||||
return table->BatchInsertOrAssign(
|
|
||||||
keys_vec, absl::MakeSpan(values_flat.data(), values_flat.size()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Variant InsertKeyTensorType; the wrapped type is convertible to
|
|
||||||
// Container::key_type.
|
|
||||||
template <typename SfinaeArg = InsertKeyTensorType>
|
|
||||||
absl::enable_if_t<
|
|
||||||
!IsValidDataType<SfinaeArg>::value &&
|
|
||||||
std::is_convertible<typename SfinaeArg::value_type,
|
|
||||||
typename Container::key_type>::value,
|
|
||||||
Status>
|
|
||||||
TensorInsert(const Tensor& keys, const Tensor& values,
|
|
||||||
Container* table) const {
|
|
||||||
const auto keys_flat = keys.flat<Variant>();
|
|
||||||
std::vector<typename Container::key_type> keys_vec;
|
|
||||||
keys_vec.reserve(keys_flat.size());
|
|
||||||
for (size_t i = 0; i < keys_flat.size(); ++i) {
|
|
||||||
keys_vec.emplace_back(
|
|
||||||
*keys_flat(i).get<typename SfinaeArg::value_type>());
|
|
||||||
}
|
|
||||||
const auto values_flat = values.flat<InsertValueTensorType>();
|
|
||||||
return table->BatchInsertOrAssign(
|
|
||||||
keys_vec, absl::MakeSpan(values_flat.data(), values_flat.size()));
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Used for tensor lookups.
|
// Lookup a table of type ResourceAlias and look up the passed in keys using
|
||||||
template <class Container, class LookupKeyTensorType, class ValueTensorType>
|
// Functor::TensorLookup(
|
||||||
class HeterogeneousLookupTableFindOp : public OpKernel {
|
// table, keys, prefetch_lookahead, num_keys_per_thread, threadpool, out).
|
||||||
|
template <typename Functor,
|
||||||
|
typename ResourceAlias = typename Functor::resource_type>
|
||||||
|
class LookupTableFindOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit HeterogeneousLookupTableFindOp(OpKernelConstruction* ctx)
|
explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||||
: OpKernel(ctx) {}
|
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
OpInputList table_int64_args;
|
OpInputList table_int64_args;
|
||||||
@@ -370,10 +391,10 @@ class HeterogeneousLookupTableFindOp : public OpKernel {

     const Tensor& table_handle = ctx->input(tensor_index_offset);
     ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
-    Container* table;
-    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(handle.container(),
-                                                        handle.name(), &table));
+    ResourceAlias* table;
     core::ScopedUnref unref_me(table);
+    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
+                            handle.container(), handle.name(), &table));

     auto* mutex = table->GetMutex();
     auto* threadpool = ctx->device()->tensorflow_cpu_worker_threads()->workers;
@@ -382,112 +403,20 @@ class HeterogeneousLookupTableFindOp : public OpKernel {
       // writer lock here.
       mutex_lock lock(*mutex);
       OP_REQUIRES_OK(
-          ctx, TensorLookup(*table, prefetch_lookahead, num_keys_per_thread,
-                            keys, out, threadpool));
+          ctx, Functor::TensorLookup(*table, keys, prefetch_lookahead,
+                                     num_keys_per_thread, threadpool, out));
     } else {
       OP_REQUIRES_OK(
-          ctx, TensorLookup(*table, prefetch_lookahead, num_keys_per_thread,
-                            keys, out, threadpool));
+          ctx, Functor::TensorLookup(*table, keys, prefetch_lookahead,
+                                     num_keys_per_thread, threadpool, out));
     }
   }
-
- private:
-  // keys and *values arguments to TensorLookup must have the same number of
-  // elements. This is guaranteed above.
-
-  // 'Simple' types below are types which are not natively supported in TF.
-  // Simple LookupKeyTensorType which is the same as Container::key_type.
-  template <typename SfinaeArg = LookupKeyTensorType>
-  absl::enable_if_t<
-      IsValidDataType<SfinaeArg>::value &&
-          std::is_same<SfinaeArg, typename Container::key_type>::value,
-      Status>
-  TensorLookup(Container& table, int64 prefetch_lookahead,
-               int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
-               thread::ThreadPool* threadpool) const {
-    const auto keys_flat = keys.flat<LookupKeyTensorType>();
-    const auto keys_size = keys_flat.size();
-    auto key_span = absl::MakeSpan(keys_flat.data(), keys_size);
-    auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
-                                     values->NumElements());
-    return MultithreadedTensorLookup(table, prefetch_lookahead,
-                                     num_keys_per_thread, key_span, value_span,
-                                     threadpool);
-  }
-
-  // Try to implicitly convert all other simple LookupKeyTensorTypes to
-  // Container::key_type.
-  template <typename SfinaeArg = LookupKeyTensorType>
-  absl::enable_if_t<
-      IsValidDataType<SfinaeArg>::value &&
-          !std::is_same<SfinaeArg, typename Container::key_type>::value,
-      Status>
-  TensorLookup(Container& table, int64 prefetch_lookahead,
-               int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
-               thread::ThreadPool* threadpool) const {
-    const auto keys_flat = keys.flat<LookupKeyTensorType>();
-    std::vector<typename Container::key_type> keys_vec;
-    const auto keys_size = keys_flat.size();
-    keys_vec.reserve(keys_size);
-    for (size_t i = 0; i < keys_size; ++i) {
-      keys_vec.emplace_back(keys_flat(i));
-    }
-    absl::Span<typename Container::key_type> key_span(keys_vec);
-    auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
-                                     values->NumElements());
-    return MultithreadedTensorLookup(table, prefetch_lookahead,
-                                     num_keys_per_thread, key_span, value_span,
-                                     threadpool);
-  }
-
-  // Non-simple LookupKeyTensorType. We'll try an implicit conversion to
-  // Container::key_type.
-  template <typename VariantSubType = LookupKeyTensorType>
-  absl::enable_if_t<!IsValidDataType<VariantSubType>::value, Status>
-  TensorLookup(Container& table, int64 prefetch_lookahead,
-               int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
-               thread::ThreadPool* threadpool) const {
-    const auto keys_flat = keys.flat<Variant>();
-    std::vector<typename Container::key_type> keys_vec;
-    const auto keys_size = keys_flat.size();
-    keys_vec.reserve(keys_size);
-    for (size_t i = 0; i < keys_size; ++i) {
-      keys_vec.emplace_back(
-          *keys_flat(i).get<typename VariantSubType::value_type>());
-    }
-    absl::Span<typename Container::key_type> key_span(keys_vec);
-    auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
-                                     values->NumElements());
-    return MultithreadedTensorLookup(table, prefetch_lookahead,
-                                     num_keys_per_thread, key_span, value_span,
-                                     threadpool);
-  }
-
-  // Wrapper around table.BatchLookup which permits sharding across cores.
-  template <typename K, typename V>
-  Status MultithreadedTensorLookup(Container& table, int64 prefetch_lookahead,
-                                   int64 num_keys_per_thread,
-                                   absl::Span<K> keys, absl::Span<V> values,
-                                   thread::ThreadPool* threadpool) const {
-    mutex temp_mutex;  // Protect status.
-    Status status;
-    auto lookup_keys = [&, this](int64 begin, int64 end) {
-      auto temp_status = table.BatchLookup(keys.subspan(begin, end - begin),
-                                           values.subspan(begin, end - begin),
-                                           prefetch_lookahead);
-      if (ABSL_PREDICT_FALSE(!temp_status.ok())) {
-        mutex_lock lock(temp_mutex);
-        status.Update(temp_status);
-      }
-    };
-    threadpool->TransformRangeConcurrently(num_keys_per_thread /* block_size */,
-                                           keys.size(), lookup_keys);
-    return status;
-  }
 };

-// Op that returns the size of a container.
-template <class Container>
+// Lookup a container of type ResourceAlias and return its size using
+// Functor::Size(container, &size).
+template <typename Functor,
+          typename ResourceAlias = typename Functor::resource_type>
 class ContainerSizeOp : public OpKernel {
  public:
   explicit ContainerSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -495,11 +424,10 @@ class ContainerSizeOp : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     const Tensor& container_handle = ctx->input(0);
     ResourceHandle handle(container_handle.scalar<ResourceHandle>()());
-    Container* container;
-    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(
-                            handle.container(), handle.name(), &container));
+    ResourceAlias* container;
     core::ScopedUnref unref_me(container);
-    OP_REQUIRES_OK(ctx, container->SizeStatus());
+    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
+                            handle.container(), handle.name(), &container));

     Tensor* out;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
@@ -507,9 +435,9 @@ class ContainerSizeOp : public OpKernel {
     auto* mutex = container->GetMutex();
     if (mutex != nullptr) {
       tf_shared_lock lock(*mutex);
-      out->scalar<int64>()() = container->UnsafeSize();
+      OP_REQUIRES_OK(ctx, Functor::Size(*container, &out->scalar<uint64>()()));
     } else {
-      out->scalar<int64>()() = container->UnsafeSize();
+      OP_REQUIRES_OK(ctx, Functor::Size(*container, &out->scalar<uint64>()()));
     }
   }
 };
@@ -1,87 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_
-#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_
-
-#include <memory>
-
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/kernels/lookup_tables/lookup_table_interface.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/mutex.h"
-
-namespace tensorflow {
-namespace tables {
-
-// Parent class for tables with support for multithreaded synchronization.
-template <typename HeterogeneousKeyType, typename ValueType>
-class LookupTableWithSynchronization
-    : public LookupTableInterface<HeterogeneousKeyType, ValueType> {
- public:
-  LookupTableWithSynchronization(bool enable_synchronization) {
-    if (enable_synchronization) {
-      mutex_ = absl::make_unique<mutex>();
-    }
-  }
-
-  // Mutex for synchronizing access to unsynchronized methods.
-  mutex* GetMutex() const override { return mutex_.get(); }
-
- private:
-  // Use this for locking.
-  mutable std::unique_ptr<mutex> mutex_;
-};
-
-// Parent class for tables which can be constructed with arbitrary
-// lookup fallbacks.
-// Since LookupTableInterface::LookupKey assumes that all keys can be mapped
-// to values, LookupTableWithFallbackInterface allows clients to implement
-// two-stage lookups. If the first key lookup fails, clients can choose
-// to perform a fallback lookup using an externally supplied table.
-template <typename HeterogeneousKeyType, typename ValueType,
-          typename FallbackTableBaseType =
-              LookupTableInterface<HeterogeneousKeyType, ValueType>>
-class LookupTableWithFallbackInterface
-    : public LookupTableWithSynchronization<HeterogeneousKeyType, ValueType> {
- public:
-  LookupTableWithFallbackInterface(bool enable_synchronization,
-                                   const FallbackTableBaseType* fallback_table)
-      : LookupTableWithSynchronization<HeterogeneousKeyType, ValueType>(
-            enable_synchronization),
-        fallback_table_(fallback_table) {}
-
-  // Clients are required to fail when ctx is set to a not-OK status in
-  // the constructor so this dereference is safe.
-  const FallbackTableBaseType& fallback_table() const {
-    return *fallback_table_;
-  }
-
-  ~LookupTableWithFallbackInterface() override {
-    if (fallback_table_ != nullptr) {
-      fallback_table_->Unref();
-    }
-  }
-
- private:
-  const FallbackTableBaseType* fallback_table_;
-};
-
-}  // namespace tables
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_