Update Fingerprint64Map to use aliases

PiperOrigin-RevId: 234022159
This commit is contained in:
A. Unique TensorFlower 2019-02-14 13:49:09 -08:00 committed by TensorFlower Gardener
parent 5b90573d41
commit e7d9786e66
6 changed files with 279 additions and 478 deletions

View File

@ -132,14 +132,14 @@ class ResourceMgr {
// //
// REQUIRES: std::is_base_of<ResourceBase, T> // REQUIRES: std::is_base_of<ResourceBase, T>
// REQUIRES: resource != nullptr // REQUIRES: resource != nullptr
template <typename T> template <typename T, bool use_dynamic_cast = false>
Status Lookup(const string& container, const string& name, Status Lookup(const string& container, const string& name,
T** resource) const TF_MUST_USE_RESULT; T** resource) const TF_MUST_USE_RESULT;
// Similar to Lookup, but looks up multiple resources at once, with only a // Similar to Lookup, but looks up multiple resources at once, with only a
// single lock acquisition. If containers_and_names[i] is uninitialized // single lock acquisition. If containers_and_names[i] is uninitialized
// then this function does not modify resources[i]. // then this function does not modify resources[i].
template <typename T> template <typename T, bool use_dynamic_cast = false>
Status LookupMany(absl::Span<std::pair<const string*, const string*> const> Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
containers_and_names, containers_and_names,
std::vector<std::unique_ptr<T, core::RefCountDeleter>>* std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
@ -155,7 +155,7 @@ class ResourceMgr {
// //
// REQUIRES: std::is_base_of<ResourceBase, T> // REQUIRES: std::is_base_of<ResourceBase, T>
// REQUIRES: resource != nullptr // REQUIRES: resource != nullptr
template <typename T> template <typename T, bool use_dynamic_cast = false>
Status LookupOrCreate(const string& container, const string& name, Status LookupOrCreate(const string& container, const string& name,
T** resource, T** resource,
std::function<Status(T**)> creator) TF_MUST_USE_RESULT; std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
@ -196,7 +196,7 @@ class ResourceMgr {
mutable mutex mu_; mutable mutex mu_;
std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_); std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);
template <typename T> template <typename T, bool use_dynamic_cast = false>
Status LookupInternal(const string& container, const string& name, Status LookupInternal(const string& container, const string& name,
T** resource) const T** resource) const
SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT; SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
@ -267,7 +267,7 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
// //
// If the lookup is successful, the caller takes the ownership of one ref on // If the lookup is successful, the caller takes the ownership of one ref on
// `*value`, and must call its `Unref()` method when it has finished using it. // `*value`, and must call its `Unref()` method when it has finished using it.
template <typename T> template <typename T, bool use_dynamic_cast = false>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value); Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
// Looks up multiple resources pointed by a sequence of resource handles. If // Looks up multiple resources pointed by a sequence of resource handles. If
@ -437,15 +437,15 @@ Status ResourceMgr::Create(const string& container, const string& name,
return DoCreate(container, MakeTypeIndex<T>(), name, resource); return DoCreate(container, MakeTypeIndex<T>(), name, resource);
} }
template <typename T> template <typename T, bool use_dynamic_cast>
Status ResourceMgr::Lookup(const string& container, const string& name, Status ResourceMgr::Lookup(const string& container, const string& name,
T** resource) const { T** resource) const {
CheckDeriveFromResourceBase<T>(); CheckDeriveFromResourceBase<T>();
tf_shared_lock l(mu_); tf_shared_lock l(mu_);
return LookupInternal(container, name, resource); return LookupInternal<T, use_dynamic_cast>(container, name, resource);
} }
template <typename T> template <typename T, bool use_dynamic_cast>
Status ResourceMgr::LookupMany( Status ResourceMgr::LookupMany(
absl::Span<std::pair<const string*, const string*> const> absl::Span<std::pair<const string*, const string*> const>
containers_and_names, containers_and_names,
@ -455,8 +455,9 @@ Status ResourceMgr::LookupMany(
resources->resize(containers_and_names.size()); resources->resize(containers_and_names.size());
for (size_t i = 0; i < containers_and_names.size(); ++i) { for (size_t i = 0; i < containers_and_names.size(); ++i) {
T* resource; T* resource;
Status s = LookupInternal(*containers_and_names[i].first, Status s = LookupInternal<T, use_dynamic_cast>(
*containers_and_names[i].second, &resource); *containers_and_names[i].first, *containers_and_names[i].second,
&resource);
if (s.ok()) { if (s.ok()) {
(*resources)[i].reset(resource); (*resources)[i].reset(resource);
} }
@ -464,7 +465,18 @@ Status ResourceMgr::LookupMany(
return Status::OK(); return Status::OK();
} }
// Simple wrapper to allow conditional dynamic / static casts.
template <typename T, bool use_dynamic_cast>
struct TypeCastFunctor {
static T* Cast(ResourceBase* r) { return static_cast<T*>(r); }
};
template <typename T> template <typename T>
struct TypeCastFunctor<T, true> {
static T* Cast(ResourceBase* r) { return dynamic_cast<T*>(r); }
};
template <typename T, bool use_dynamic_cast>
Status ResourceMgr::LookupInternal(const string& container, const string& name, Status ResourceMgr::LookupInternal(const string& container, const string& name,
T** resource) const { T** resource) const {
ResourceBase* found = nullptr; ResourceBase* found = nullptr;
@ -472,12 +484,12 @@ Status ResourceMgr::LookupInternal(const string& container, const string& name,
if (s.ok()) { if (s.ok()) {
// It's safe to down cast 'found' to T* since // It's safe to down cast 'found' to T* since
// typeid(T).hash_code() is part of the map key. // typeid(T).hash_code() is part of the map key.
*resource = static_cast<T*>(found); *resource = TypeCastFunctor<T, use_dynamic_cast>::Cast(found);
} }
return s; return s;
} }
template <typename T> template <typename T, bool use_dynamic_cast>
Status ResourceMgr::LookupOrCreate(const string& container, const string& name, Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
T** resource, T** resource,
std::function<Status(T**)> creator) { std::function<Status(T**)> creator) {
@ -486,11 +498,11 @@ Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
Status s; Status s;
{ {
tf_shared_lock l(mu_); tf_shared_lock l(mu_);
s = LookupInternal(container, name, resource); s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
if (s.ok()) return s; if (s.ok()) return s;
} }
mutex_lock l(mu_); mutex_lock l(mu_);
s = LookupInternal(container, name, resource); s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
if (s.ok()) return s; if (s.ok()) return s;
TF_RETURN_IF_ERROR(creator(resource)); TF_RETURN_IF_ERROR(creator(resource));
s = DoCreate(container, MakeTypeIndex<T>(), name, *resource); s = DoCreate(container, MakeTypeIndex<T>(), name, *resource);
@ -566,11 +578,12 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
return ctx->resource_manager()->Create(p.container(), p.name(), value); return ctx->resource_manager()->Create(p.container(), p.name(), value);
} }
template <typename T> template <typename T, bool use_dynamic_cast>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value) { T** value) {
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
return ctx->resource_manager()->Lookup(p.container(), p.name(), value); return ctx->resource_manager()->Lookup<T, use_dynamic_cast>(p.container(),
p.name(), value);
} }
template <typename T> template <typename T>

View File

@ -19,18 +19,6 @@ cc_library(
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "table_resource_utils",
hdrs = ["table_resource_utils.h"],
deps = [
":lookup_table_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
], ],
) )
@ -57,8 +45,8 @@ tf_kernel_library(
"fingerprint64_map_ops.cc", "fingerprint64_map_ops.cc",
], ],
deps = [ deps = [
":lookup_table_interface",
":table_op_utils", ":table_op_utils",
":table_resource_utils",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",

View File

@ -15,8 +15,8 @@ limitations under the License.
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/lookup_tables/lookup_table_interface.h"
#include "tensorflow/core/kernels/lookup_tables/table_op_utils.h" #include "tensorflow/core/kernels/lookup_tables/table_op_utils.h"
#include "tensorflow/core/kernels/lookup_tables/table_resource_utils.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
@ -27,41 +27,23 @@ namespace tables {
// Map x -> (Fingerprint64(x) % num_oov_buckets) + offset. // Map x -> (Fingerprint64(x) % num_oov_buckets) + offset.
// num_oov_buckets and offset are node attributes provided at construction // num_oov_buckets and offset are node attributes provided at construction
// time. // time.
template <class HeterogeneousKeyType, class ValueType> template <typename KeyType, typename ValueType>
class Fingerprint64Map final class Fingerprint64Map final
: public LookupTableInterface<HeterogeneousKeyType, ValueType> { : public virtual LookupInterface<ValueType*, const KeyType&>,
public virtual LookupWithPrefetchInterface<absl::Span<ValueType>,
absl::Span<const KeyType>> {
public: public:
using key_type = KeyType;
Fingerprint64Map(int64 num_oov_buckets, int64 offset) Fingerprint64Map(int64 num_oov_buckets, int64 offset)
: num_oov_buckets_(num_oov_buckets), offset_(offset) {} : num_oov_buckets_(num_oov_buckets), offset_(offset) {}
mutex* GetMutex() const override { return nullptr; } Status Lookup(const KeyType& key_to_find, ValueType* value) const override {
*value = LookupHelper(key_to_find);
bool UnsafeInsertOrAssign(const HeterogeneousKeyType& key, return Status::OK();
const ValueType& value) override {
return true;
} }
Status TableUnbatchedInsertStatus() const override { Status Lookup(absl::Span<const KeyType> keys, absl::Span<ValueType> values,
return errors::Unimplemented("Fingerprint64Map does not support inserts.");
}
Status BatchInsertOrAssign(absl::Span<const HeterogeneousKeyType> keys,
absl::Span<const ValueType> values) override {
return errors::Unimplemented("Fingerprint64Map does not support inserts.");
}
ValueType UnsafeLookupKey(
const HeterogeneousKeyType& key_to_find) const override {
// This can cause a downcast.
return static_cast<ValueType>(Fingerprint64(key_to_find) %
num_oov_buckets_) +
offset_;
}
Status TableUnbatchedLookupStatus() const override { return Status::OK(); }
Status BatchLookup(absl::Span<const HeterogeneousKeyType> keys,
absl::Span<ValueType> values,
int64 prefetch_lookahead) const override { int64 prefetch_lookahead) const override {
if (ABSL_PREDICT_FALSE(keys.size() != values.size())) { if (ABSL_PREDICT_FALSE(keys.size() != values.size())) {
return errors::InvalidArgument( return errors::InvalidArgument(
@ -69,31 +51,24 @@ class Fingerprint64Map final
keys.size(), " vs ", values.size(), ")."); keys.size(), " vs ", values.size(), ").");
} }
for (size_t i = 0; i < keys.size(); ++i) { for (size_t i = 0; i < keys.size(); ++i) {
values[i] = Fingerprint64Map::UnsafeLookupKey(keys[i]); values[i] = LookupHelper(keys[i]);
} }
return Status::OK(); return Status::OK();
} }
const absl::optional<const ValueType> DefaultValue() const override { mutex* GetMutex() const override { return nullptr; }
return {};
}
void UnsafePrefetchKey( string DebugString() const override { return __PRETTY_FUNCTION__; }
const HeterogeneousKeyType& key_to_find) const override {}
size_t UnsafeSize() const override { return 0; }
Status SizeStatus() const override {
return errors::Unimplemented(
"Fingerprint64Map does not have a concept of size.");
}
bool UnsafeContainsKey(
const HeterogeneousKeyType& key_to_find) const override {
return true;
}
private: private:
ABSL_ATTRIBUTE_ALWAYS_INLINE ValueType
LookupHelper(const KeyType& key_to_find) const {
// This can cause a downcast.
return static_cast<ValueType>(Fingerprint64(key_to_find) %
num_oov_buckets_) +
offset_;
}
const int64 num_oov_buckets_; const int64 num_oov_buckets_;
const int64 offset_; const int64 offset_;
TF_DISALLOW_COPY_AND_ASSIGN(Fingerprint64Map); TF_DISALLOW_COPY_AND_ASSIGN(Fingerprint64Map);
@ -102,9 +77,10 @@ class Fingerprint64Map final
template <typename Fingerprint64Map> template <typename Fingerprint64Map>
struct Fingerprint64MapFactory { struct Fingerprint64MapFactory {
struct Functor { struct Functor {
template <typename ContainerBase> using resource_type = Fingerprint64Map;
static Status AllocateContainer(OpKernelContext* ctx, OpKernel* kernel, static Status AllocateContainer(OpKernelContext* ctx, OpKernel* kernel,
ContainerBase** container) { Fingerprint64Map** container) {
int64 num_oov_buckets; int64 num_oov_buckets;
int64 offset; int64 offset;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
@ -116,24 +92,28 @@ struct Fingerprint64MapFactory {
}; };
}; };
#define REGISTER_STRING_KERNEL(table_value_dtype) \ template <typename KeyType, typename ValueType>
using ResourceOp = ResourceConstructionOp<
typename Fingerprint64MapFactory<
Fingerprint64Map<KeyType, ValueType>>::Functor,
// These are the aliases.
LookupInterface<ValueType*, const KeyType&>,
LookupWithPrefetchInterface<absl::Span<ValueType>,
absl::Span<const KeyType>>>;
#define REGISTER_STRING_KERNEL(ValueType) \
REGISTER_KERNEL_BUILDER( \ REGISTER_KERNEL_BUILDER( \
Name("Fingerprint64Map") \ Name("Fingerprint64Map") \
.Device(DEVICE_CPU) \ .Device(DEVICE_CPU) \
.TypeConstraint<Variant>("heterogeneous_key_dtype") \ .TypeConstraint<Variant>("heterogeneous_key_dtype") \
.TypeConstraint<table_value_dtype>("table_value_dtype"), \ .TypeConstraint<ValueType>("table_value_dtype"), \
ResourceConstructionOp< \ ResourceOp<absl::string_view, ValueType>); \
LookupTableInterface<absl::string_view, table_value_dtype>, \
Fingerprint64MapFactory<Fingerprint64Map< \
absl::string_view, table_value_dtype>>::Functor>); \
REGISTER_KERNEL_BUILDER( \ REGISTER_KERNEL_BUILDER( \
Name("Fingerprint64Map") \ Name("Fingerprint64Map") \
.Device(DEVICE_CPU) \ .Device(DEVICE_CPU) \
.TypeConstraint<string>("heterogeneous_key_dtype") \ .TypeConstraint<string>("heterogeneous_key_dtype") \
.TypeConstraint<table_value_dtype>("table_value_dtype"), \ .TypeConstraint<ValueType>("table_value_dtype"), \
ResourceConstructionOp<LookupTableInterface<string, table_value_dtype>, \ ResourceOp<string, ValueType>);
Fingerprint64MapFactory<Fingerprint64Map< \
string, table_value_dtype>>::Functor>);
REGISTER_STRING_KERNEL(int32); REGISTER_STRING_KERNEL(int32);
REGISTER_STRING_KERNEL(int64); REGISTER_STRING_KERNEL(int64);

View File

@ -16,11 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_LOOKUP_TABLE_INTERFACE_H_ #ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_LOOKUP_TABLE_INTERFACE_H_
#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_LOOKUP_TABLE_INTERFACE_H_ #define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_LOOKUP_TABLE_INTERFACE_H_
#include <cstddef>
#include <string>
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
@ -28,90 +23,74 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace tables { namespace tables {
// Interface for key-value pair lookups with support for heterogeneous keys. // Interface for resources with mutable state.
// This class contains two main kinds of methods: methods which operate on class SynchronizedInterface : public virtual ResourceBase {
// a batch of inputs and methods which do not. The latter have the prefix
// 'Unsafe'. Clients must call the corresponding status methods to determine
// whether they are safe to call within a code block.
// Implementations must guarantee thread-safety when GetMutex is used to
// synchronize method access.
template <typename HeterogeneousKeyType, typename ValueType>
class LookupTableInterface : public ResourceBase {
public: public:
using heterogeneous_key_type = HeterogeneousKeyType;
using value_type = ValueType;
using key_type = heterogeneous_key_type;
// Return value should be used to synchronize read/write access to // Return value should be used to synchronize read/write access to
// all public methods. If null, no synchronization is needed. // all public methods. If null, no synchronization is needed.
virtual mutex* GetMutex() const = 0; virtual mutex* GetMutex() const = 0;
};
// Insert the KV pair into the underlying table. If a key equivalent to key // Interface for containers which support batch lookups.
// already exists in the underlying table, its corresponding value is template <typename ValueType, typename... KeyContext>
// overridden. Returns true only if the key was inserted for the first time. class InsertOrAssignInterface : public virtual SynchronizedInterface {
// Undefined if TableUnbatchedInsertStatus() != OK. public:
virtual bool UnsafeInsertOrAssign(const HeterogeneousKeyType& key, using value_type = ValueType;
const ValueType& value) = 0;
// Returns OK if it is safe to call InsertOrAssign.
// Once OK is returned, it is safe to call InsertOrAssign for the rest of the
// program.
virtual Status TableUnbatchedInsertStatus() const TF_MUST_USE_RESULT = 0;
// Stores each KV pair {keys[i], values[i]} in the underlying map, overriding // Stores each KV pair {keys[i], values[i]} in the underlying map, overriding
// pre-existing pairs which have equivalent keys. // pre-existing pairs which have equivalent keys.
// keys and values should have the same size. // keys and values should have the same size.
virtual Status BatchInsertOrAssign( virtual Status InsertOrAssign(KeyContext... key_context,
absl::Span<const HeterogeneousKeyType> keys, ValueType values) = 0;
absl::Span<const ValueType> values) = 0; };
// Prefetch key_to_find into implementation defined data caches. // Interface for containers which support lookups.
// Implementations are free to leave this a no-op. template <typename ValueType, typename... KeyContext>
// Undefined if TableUnbatchedLookupStatus() != OK. class LookupInterface : public virtual SynchronizedInterface {
virtual void UnsafePrefetchKey( public:
const HeterogeneousKeyType& key_to_find) const {} using value_type = ValueType;
// Returns true if and only if the table contains key_to_find.
// Undefined if TableUnbatchedLookupStatus() != OK.
virtual bool UnsafeContainsKey(
const HeterogeneousKeyType& key_to_find) const = 0;
// Lookup the value for key_to_find. This value must always be well-defined,
// even when ContainsKey(key_to_find) == false. When
// dv = DefaultValue() != absl::nullopt and ContainsKey(key_to_find) == false,
// dv is returned.
// Undefined if TableUnbatchedLookupStatus() != OK.
virtual ValueType UnsafeLookupKey(
const HeterogeneousKeyType& key_to_find) const = 0;
// Returns OK if it is safe to call PrefetchKey, ContainsKey, and
// UnsafeLookupKey.
// If OK is returned, it is safe to call these methods until the next
// non-const method of this class is called.
virtual Status TableUnbatchedLookupStatus() const TF_MUST_USE_RESULT = 0;
// Lookup the values for keys and store them in values. // Lookup the values for keys and store them in values.
// prefetch_lookahead is used to prefetch the key at index // prefetch_lookahead is used to prefetch the key at index
// i + prefetch_lookahead at the ith iteration of the implemented loop. // i + prefetch_lookahead at the ith iteration of the implemented loop.
// keys and values must have the same size. // keys and values must have the same size.
virtual Status BatchLookup(absl::Span<const HeterogeneousKeyType> keys, virtual Status Lookup(KeyContext... key_context, ValueType values) const = 0;
absl::Span<ValueType> values, };
// Interface for containers which support lookups with prefetching.
template <typename ValueType, typename... KeyContext>
class LookupWithPrefetchInterface : public virtual SynchronizedInterface {
public:
using value_type = ValueType;
// Lookup the values for keys and store them in values.
// prefetch_lookahead is used to prefetch the key at index
// i + prefetch_lookahead at the ith iteration of the implemented loop.
// keys and values must have the same size.
virtual Status Lookup(KeyContext... key_context, ValueType values,
int64 prefetch_lookahead) const = 0; int64 prefetch_lookahead) const = 0;
};
// Returns the number of elements in the table. // Interface for containers with size concepts.
// Undefined if SizeStatus() != OK. // Implementations must guarantee thread-safety when GetMutex is used to
virtual size_t UnsafeSize() const = 0; // synchronize method access.
class SizeInterface : public virtual SynchronizedInterface {
public:
// Returns the number of elements in the container.
virtual uint64 Size() const = 0;
};
// Returns OK if the return value of UnsafeSize() is always well-defined. // Interface for tables which can be initialized from key and value arguments.
virtual Status SizeStatus() const TF_MUST_USE_RESULT = 0; template <typename ValueType, typename... KeyContext>
class KeyValueTableInitializerInterface : public virtual SynchronizedInterface {
public:
using value_type = ValueType;
// If non-null value is returned, LookupKey returns that value only for keys // Lookup the values for keys and store them in values.
// which satisfy ContainsKey(key_to_find) == false. // prefetch_lookahead is used to prefetch the key at index
virtual const absl::optional<const ValueType> DefaultValue() const = 0; // i + prefetch_lookahead at the ith iteration of the implemented loop.
// keys and values must have the same size.
string DebugString() const override { return "A lookup table"; } virtual Status Initialize(KeyContext... key_context, ValueType values) = 0;
~LookupTableInterface() override = default;
}; };
} // namespace tables } // namespace tables

View File

@ -44,11 +44,11 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace tables { namespace tables {
// Create resources of type ContainerBase using the static method // Create resources of type ResourceType and AliasesToRegister using
// Functor::AllocateContainer(OpKernelConstruction*, OpKernel*, // Functor::AllocateContainer(OpKernelConstruction*, OpKernel*,
// ContainerBase**) // ResourceType**). ResourceType = Functor::resource_type.
// If the resource has already been created it will be looked up. // No-op for resources which have already been created.
template <class ContainerBase, typename Functor> template <typename Functor, typename... AliasesToRegister>
class ResourceConstructionOp : public OpKernel { class ResourceConstructionOp : public OpKernel {
public: public:
explicit ResourceConstructionOp(OpKernelConstruction* ctx) explicit ResourceConstructionOp(OpKernelConstruction* ctx)
@ -66,46 +66,86 @@ class ResourceConstructionOp : public OpKernel {
} }
auto creator = [ctx, auto creator = [ctx,
this](ContainerBase** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { this](ResourceType** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
ContainerBase* container; ResourceType* resource = nullptr;
auto status = Functor::AllocateContainer(ctx, this, &container); auto status = Functor::AllocateContainer(ctx, this, &resource);
if (ABSL_PREDICT_FALSE(!status.ok())) { if (ABSL_PREDICT_FALSE(!status.ok())) {
container->Unref(); // Ideally resource is non-null only if status is OK but we try
// to compensate here.
if (resource != nullptr) {
resource->Unref();
}
return status; return status;
} }
if (ctx->track_allocations()) { if (ctx->track_allocations()) {
ctx->record_persistent_memory_allocation(container->MemoryUsed()); ctx->record_persistent_memory_allocation(resource->MemoryUsed());
} }
*ret = container; *ret = resource;
return Status::OK(); return Status::OK();
}; };
ContainerBase* container_base = nullptr; // Register the ResourceType alias.
ResourceType* resource = nullptr;
core::ScopedUnref unref_me(resource);
OP_REQUIRES_OK( OP_REQUIRES_OK(
ctx, cinfo_.resource_manager()->template LookupOrCreate<ContainerBase>( ctx,
cinfo_.container(), cinfo_.name(), &container_base, creator)); cinfo_.resource_manager()->template LookupOrCreate<ResourceType, true>(
core::ScopedUnref unref_me(container_base); cinfo_.container(), cinfo_.name(), &resource, creator));
// Put a handle to resource in the output tensor (the other aliases will
// have the same handle).
Tensor* handle; Tensor* handle;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
handle->scalar<ResourceHandle>()() = MakeResourceHandle<ContainerBase>( handle->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
ctx, cinfo_.container(), cinfo_.name()); ctx, cinfo_.container(), cinfo_.name());
table_handle_set_ = true; table_handle_set_ = true;
// Create other alias resources.
Status status;
char dummy[sizeof...(AliasesToRegister)] = {
(status.Update(RegisterAlias<AliasesToRegister>(resource)), 0)...};
(void)dummy;
OP_REQUIRES_OK(ctx, status);
} }
~ResourceConstructionOp() override { ~ResourceConstructionOp() override {
// If the table object was not shared, delete it. // If the table object was not shared, delete it.
if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) { if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
if (!cinfo_.resource_manager() if (!cinfo_.resource_manager()
->template Delete<ContainerBase>(cinfo_.container(), ->template Delete<ResourceType>(cinfo_.container(),
cinfo_.name()) cinfo_.name())
.ok()) { .ok()) {
// Do nothing; the resource may have been deleted by session resets. // Do nothing; the resource may have been deleted by session resets.
} }
// Attempt to delete other resource aliases.
Status dummy_status;
char dummy[sizeof...(AliasesToRegister)] = {
(dummy_status.Update(DeleteAlias<AliasesToRegister>()), 0)...};
(void)dummy;
} }
} }
private: private:
using ResourceType = typename Functor::resource_type;
template <typename T>
Status RegisterAlias(ResourceType* resource) {
auto creator = [resource](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = resource;
return Status::OK();
};
T* alias_resource = nullptr;
core::ScopedUnref unref_me(alias_resource);
return cinfo_.resource_manager()->template LookupOrCreate<T, true>(
cinfo_.container(), cinfo_.name(), &alias_resource, creator);
}
template <typename T>
Status DeleteAlias() {
return cinfo_.resource_manager()->template Delete<T>(cinfo_.container(),
cinfo_.name());
}
mutex mu_; mutex mu_;
bool table_handle_set_ GUARDED_BY(mu_); bool table_handle_set_ GUARDED_BY(mu_);
ContainerInfo cinfo_; ContainerInfo cinfo_;
@ -120,8 +160,7 @@ class ResourceConstructionOp : public OpKernel {
// If the resource has already been created it will be looked up. // If the resource has already been created it will be looked up.
// Container must decrease the reference count of the FallbackTableBaseType* // Container must decrease the reference count of the FallbackTableBaseType*
// constructor argument before its destructor completes. // constructor argument before its destructor completes.
template <class ContainerBase, class Functor, template <typename Functor, typename... AliasesToRegister>
class FallbackTableBaseType = ContainerBase>
class TableWithFallbackConstructionOp : public OpKernel { class TableWithFallbackConstructionOp : public OpKernel {
public: public:
explicit TableWithFallbackConstructionOp(OpKernelConstruction* ctx) explicit TableWithFallbackConstructionOp(OpKernelConstruction* ctx)
@ -140,13 +179,14 @@ class TableWithFallbackConstructionOp : public OpKernel {
return; return;
} }
// Look up the fallback table.
FallbackTableBaseType* fallback_table = nullptr; FallbackTableBaseType* fallback_table = nullptr;
{ {
const Tensor& table_handle = ctx->input(table_int64_args.size()); const Tensor& table_handle = ctx->input(table_int64_args.size());
ResourceHandle handle(table_handle.scalar<ResourceHandle>()()); ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
OP_REQUIRES_OK( OP_REQUIRES_OK(
ctx, ctx->resource_manager()->Lookup(handle.container(), ctx, ctx->resource_manager()->Lookup<FallbackTableBaseType, true>(
handle.name(), &fallback_table)); handle.container(), handle.name(), &fallback_table));
} }
mutex_lock l(mu_); mutex_lock l(mu_);
@ -156,51 +196,93 @@ class TableWithFallbackConstructionOp : public OpKernel {
} }
auto creator = [ctx, this, fallback_table]( auto creator = [ctx, this, fallback_table](
ContainerBase** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { ResourceType** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
// container construction logic can't be merged with // container construction logic can't be merged with
// ResourceConstructionOp because Container constructor requires an // ResourceConstructionOp because Container constructor requires an
// input which can only be constructed if the resource manager // input which can only be constructed if the resource manager
// internal lock is not already held. // internal lock is not already held.
ContainerBase* container; ResourceType* resource = nullptr;
auto status = auto status =
Functor::AllocateContainer(ctx, this, fallback_table, &container); Functor::AllocateContainer(ctx, this, fallback_table, &resource);
if (ABSL_PREDICT_FALSE(!status.ok())) { if (ABSL_PREDICT_FALSE(!status.ok())) {
container->Unref(); // Ideally resource is non-null only if status is OK but we try
// to compensate here.
if (resource != nullptr) {
resource->Unref();
}
return status; return status;
} }
if (ctx->track_allocations()) { if (ctx->track_allocations()) {
ctx->record_persistent_memory_allocation(container->MemoryUsed()); ctx->record_persistent_memory_allocation(resource->MemoryUsed());
} }
*ret = container; *ret = resource;
return Status::OK(); return Status::OK();
}; };
ContainerBase* table = nullptr; // Register the ResourceType alias.
OP_REQUIRES_OK( ResourceType* table = nullptr;
ctx, cinfo_.resource_manager()->template LookupOrCreate<ContainerBase>(
cinfo_.container(), cinfo_.name(), &table, creator));
core::ScopedUnref unref_me(table); core::ScopedUnref unref_me(table);
OP_REQUIRES_OK(
ctx,
cinfo_.resource_manager()->template LookupOrCreate<ResourceType, true>(
cinfo_.container(), cinfo_.name(), &table, creator));
// Put a handle to resource in the output tensor (the other aliases will
// have the same handle).
Tensor* handle; Tensor* handle;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
handle->scalar<ResourceHandle>()() = MakeResourceHandle<ContainerBase>( handle->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
ctx, cinfo_.container(), cinfo_.name()); ctx, cinfo_.container(), cinfo_.name());
table_handle_set_ = true; table_handle_set_ = true;
// Create other alias resources.
Status status;
char dummy[sizeof...(AliasesToRegister)] = {
(status.Update(RegisterAlias<AliasesToRegister>(table)), 0)...};
(void)dummy;
OP_REQUIRES_OK(ctx, status);
} }
~TableWithFallbackConstructionOp() override { ~TableWithFallbackConstructionOp() override {
// If the table object was not shared, delete it. // If the table object was not shared, delete it.
if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) { if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
if (!cinfo_.resource_manager() if (!cinfo_.resource_manager()
->template Delete<ContainerBase>(cinfo_.container(), ->template Delete<ResourceType>(cinfo_.container(),
cinfo_.name()) cinfo_.name())
.ok()) { .ok()) {
// Do nothing; the resource may have been deleted by session resets. // Do nothing; the resource may have been deleted by session resets.
} }
// Attempt to delete other resource aliases.
Status dummy_status;
char dummy[sizeof...(AliasesToRegister)] = {
(dummy_status.Update(DeleteAlias<AliasesToRegister>()), 0)...};
(void)dummy;
} }
} }
private: private:
using ResourceType = typename Functor::resource_type;
using FallbackTableBaseType = typename Functor::fallback_table_type;
template <typename T>
Status RegisterAlias(ResourceType* resource) {
auto creator = [resource](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = resource;
return Status::OK();
};
T* alias_resource = nullptr;
core::ScopedUnref unref_me(alias_resource);
return cinfo_.resource_manager()->template LookupOrCreate<T, true>(
cinfo_.container(), cinfo_.name(), &alias_resource, creator);
}
template <typename T>
Status DeleteAlias() {
return cinfo_.resource_manager()->template Delete<T>(cinfo_.container(),
cinfo_.name());
}
mutex mu_; mutex mu_;
bool table_handle_set_ GUARDED_BY(mu_); bool table_handle_set_ GUARDED_BY(mu_);
ContainerInfo cinfo_; ContainerInfo cinfo_;
@ -209,33 +291,29 @@ class TableWithFallbackConstructionOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(TableWithFallbackConstructionOp); TF_DISALLOW_COPY_AND_ASSIGN(TableWithFallbackConstructionOp);
}; };
// Used to insert tensors into a container. // Lookup a table of type ResourceAlias and insert the passed in keys and
template <class Container, class InsertKeyTensorType, // values tensors using Functor::TensorInsert(keys, values, table).
class InsertValueTensorType> template <typename Functor,
class HeterogeneousLookupTableInsertOrAssignOp : public OpKernel { typename ResourceAlias = typename Functor::resource_type>
class LookupTableInsertOp : public OpKernel {
public: public:
explicit HeterogeneousLookupTableInsertOrAssignOp(OpKernelConstruction* ctx) explicit LookupTableInsertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
: OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
OpInputList table_int64_args; OpInputList table_int64_args;
OP_REQUIRES_OK(ctx, ctx->input_list("table_int64_args", &table_int64_args)); OP_REQUIRES_OK(ctx, ctx->input_list("table_int64_args", &table_int64_args));
const size_t tensor_index_offset = table_int64_args.size(); const size_t tensor_index_offset = table_int64_args.size();
// Business logic for checking tensor shapes, etc, is delegated to the
// Functor.
const Tensor& keys = ctx->input(tensor_index_offset + 1); const Tensor& keys = ctx->input(tensor_index_offset + 1);
const Tensor& values = ctx->input(tensor_index_offset + 2); const Tensor& values = ctx->input(tensor_index_offset + 2);
if (ABSL_PREDICT_FALSE(keys.NumElements() != values.NumElements())) {
ctx->SetStatus(errors::InvalidArgument(
"keys and values do not have the same number of elements: ",
keys.NumElements(), " vs ", values.NumElements()));
return;
}
const Tensor& table_handle = ctx->input(tensor_index_offset); const Tensor& table_handle = ctx->input(tensor_index_offset);
ResourceHandle handle(table_handle.scalar<ResourceHandle>()()); ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
Container* table; ResourceAlias* table;
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(handle.container(),
handle.name(), &table));
core::ScopedUnref unref_me(table); core::ScopedUnref unref_me(table);
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
handle.container(), handle.name(), &table));
int memory_used_before = 0; int memory_used_before = 0;
if (ctx->track_allocations()) { if (ctx->track_allocations()) {
@ -244,9 +322,9 @@ class HeterogeneousLookupTableInsertOrAssignOp : public OpKernel {
auto* mutex = table->GetMutex(); auto* mutex = table->GetMutex();
if (mutex != nullptr) { if (mutex != nullptr) {
mutex_lock lock(*mutex); mutex_lock lock(*mutex);
OP_REQUIRES_OK(ctx, TensorInsert(keys, values, table)); OP_REQUIRES_OK(ctx, Functor::TensorInsert(keys, values, table));
} else { } else {
OP_REQUIRES_OK(ctx, TensorInsert(keys, values, table)); OP_REQUIRES_OK(ctx, Functor::TensorInsert(keys, values, table));
} }
if (ctx->track_allocations()) { if (ctx->track_allocations()) {
ctx->record_persistent_memory_allocation(table->MemoryUsed() - ctx->record_persistent_memory_allocation(table->MemoryUsed() -
@ -255,74 +333,17 @@ class HeterogeneousLookupTableInsertOrAssignOp : public OpKernel {
} }
private: private:
// Non-variant InsertKeyTensorType which is the same as Container::key_type. TF_DISALLOW_COPY_AND_ASSIGN(LookupTableInsertOp);
// No need to static_cast.
template <typename SfinaeArg = InsertKeyTensorType>
absl::enable_if_t<
IsValidDataType<SfinaeArg>::value &&
std::is_same<SfinaeArg, typename Container::key_type>::value,
Status>
TensorInsert(const Tensor& keys, const Tensor& values,
Container* table) const {
const auto keys_flat = keys.flat<SfinaeArg>();
const auto values_flat = values.flat<InsertValueTensorType>();
return table->BatchInsertOrAssign(
absl::MakeSpan(keys_flat.data(), keys_flat.size()),
absl::MakeSpan(values_flat.data(), values_flat.size()));
}
// Non-variant InsertKeyTensorType which is otherwise convertible to
// Container::key_type.
template <typename SfinaeArg = InsertKeyTensorType>
absl::enable_if_t<
IsValidDataType<SfinaeArg>::value &&
!std::is_same<SfinaeArg, typename Container::key_type>::value &&
std::is_convertible<SfinaeArg, typename Container::key_type>::value,
Status>
TensorInsert(const Tensor& keys, const Tensor& values,
Container* table) const {
const auto keys_flat = keys.flat<InsertKeyTensorType>();
std::vector<typename Container::key_type> keys_vec;
const auto keys_size = keys_flat.size();
keys_vec.reserve(keys_size);
for (size_t i = 0; i < keys_size; ++i) {
keys_vec.push_back(
static_cast<typename Container::key_type>(keys_flat(i)));
}
const auto values_flat = values.flat<InsertValueTensorType>();
return table->BatchInsertOrAssign(
keys_vec, absl::MakeSpan(values_flat.data(), values_flat.size()));
}
// Variant InsertKeyTensorType; the wrapped type is convertible to
// Container::key_type.
template <typename SfinaeArg = InsertKeyTensorType>
absl::enable_if_t<
!IsValidDataType<SfinaeArg>::value &&
std::is_convertible<typename SfinaeArg::value_type,
typename Container::key_type>::value,
Status>
TensorInsert(const Tensor& keys, const Tensor& values,
Container* table) const {
const auto keys_flat = keys.flat<Variant>();
std::vector<typename Container::key_type> keys_vec;
keys_vec.reserve(keys_flat.size());
for (size_t i = 0; i < keys_flat.size(); ++i) {
keys_vec.emplace_back(
*keys_flat(i).get<typename SfinaeArg::value_type>());
}
const auto values_flat = values.flat<InsertValueTensorType>();
return table->BatchInsertOrAssign(
keys_vec, absl::MakeSpan(values_flat.data(), values_flat.size()));
}
}; };
// Used for tensor lookups. // Lookup a table of type ResourceAlias and look up the passed in keys using
template <class Container, class LookupKeyTensorType, class ValueTensorType> // Functor::TensorLookup(
class HeterogeneousLookupTableFindOp : public OpKernel { // table, keys, prefetch_lookahead, num_keys_per_thread, threadpool, out).
template <typename Functor,
typename ResourceAlias = typename Functor::resource_type>
class LookupTableFindOp : public OpKernel {
public: public:
explicit HeterogeneousLookupTableFindOp(OpKernelConstruction* ctx) explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
: OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
OpInputList table_int64_args; OpInputList table_int64_args;
@ -370,10 +391,10 @@ class HeterogeneousLookupTableFindOp : public OpKernel {
const Tensor& table_handle = ctx->input(tensor_index_offset); const Tensor& table_handle = ctx->input(tensor_index_offset);
ResourceHandle handle(table_handle.scalar<ResourceHandle>()()); ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
Container* table; ResourceAlias* table;
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(handle.container(),
handle.name(), &table));
core::ScopedUnref unref_me(table); core::ScopedUnref unref_me(table);
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
handle.container(), handle.name(), &table));
auto* mutex = table->GetMutex(); auto* mutex = table->GetMutex();
auto* threadpool = ctx->device()->tensorflow_cpu_worker_threads()->workers; auto* threadpool = ctx->device()->tensorflow_cpu_worker_threads()->workers;
@ -382,112 +403,20 @@ class HeterogeneousLookupTableFindOp : public OpKernel {
// writer lock here. // writer lock here.
mutex_lock lock(*mutex); mutex_lock lock(*mutex);
OP_REQUIRES_OK( OP_REQUIRES_OK(
ctx, TensorLookup(*table, prefetch_lookahead, num_keys_per_thread, ctx, Functor::TensorLookup(*table, keys, prefetch_lookahead,
keys, out, threadpool)); num_keys_per_thread, threadpool, out));
} else { } else {
OP_REQUIRES_OK( OP_REQUIRES_OK(
ctx, TensorLookup(*table, prefetch_lookahead, num_keys_per_thread, ctx, Functor::TensorLookup(*table, keys, prefetch_lookahead,
keys, out, threadpool)); num_keys_per_thread, threadpool, out));
} }
} }
private:
// keys and *values arguments to TensorLookup must have the same number of
// elements. This is guaranteed above.
// 'Simple' types below are types which are not natively supported in TF.
// Simple LookupKeyTensorType which is the same as Container::key_type.
template <typename SfinaeArg = LookupKeyTensorType>
absl::enable_if_t<
IsValidDataType<SfinaeArg>::value &&
std::is_same<SfinaeArg, typename Container::key_type>::value,
Status>
TensorLookup(Container& table, int64 prefetch_lookahead,
int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
thread::ThreadPool* threadpool) const {
const auto keys_flat = keys.flat<LookupKeyTensorType>();
const auto keys_size = keys_flat.size();
auto key_span = absl::MakeSpan(keys_flat.data(), keys_size);
auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
values->NumElements());
return MultithreadedTensorLookup(table, prefetch_lookahead,
num_keys_per_thread, key_span, value_span,
threadpool);
}
// Try to implicitly convert all other simple LookupKeyTensorTypes to
// Container::key_type.
template <typename SfinaeArg = LookupKeyTensorType>
absl::enable_if_t<
IsValidDataType<SfinaeArg>::value &&
!std::is_same<SfinaeArg, typename Container::key_type>::value,
Status>
TensorLookup(Container& table, int64 prefetch_lookahead,
int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
thread::ThreadPool* threadpool) const {
const auto keys_flat = keys.flat<LookupKeyTensorType>();
std::vector<typename Container::key_type> keys_vec;
const auto keys_size = keys_flat.size();
keys_vec.reserve(keys_size);
for (size_t i = 0; i < keys_size; ++i) {
keys_vec.emplace_back(keys_flat(i));
}
absl::Span<typename Container::key_type> key_span(keys_vec);
auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
values->NumElements());
return MultithreadedTensorLookup(table, prefetch_lookahead,
num_keys_per_thread, key_span, value_span,
threadpool);
}
// Non-simple LookupKeyTensorType. We'll try an implicit conversion to
// Container::key_type.
template <typename VariantSubType = LookupKeyTensorType>
absl::enable_if_t<!IsValidDataType<VariantSubType>::value, Status>
TensorLookup(Container& table, int64 prefetch_lookahead,
int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
thread::ThreadPool* threadpool) const {
const auto keys_flat = keys.flat<Variant>();
std::vector<typename Container::key_type> keys_vec;
const auto keys_size = keys_flat.size();
keys_vec.reserve(keys_size);
for (size_t i = 0; i < keys_size; ++i) {
keys_vec.emplace_back(
*keys_flat(i).get<typename VariantSubType::value_type>());
}
absl::Span<typename Container::key_type> key_span(keys_vec);
auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
values->NumElements());
return MultithreadedTensorLookup(table, prefetch_lookahead,
num_keys_per_thread, key_span, value_span,
threadpool);
}
// Wrapper around table.BatchLookup which permits sharding across cores.
template <typename K, typename V>
Status MultithreadedTensorLookup(Container& table, int64 prefetch_lookahead,
int64 num_keys_per_thread,
absl::Span<K> keys, absl::Span<V> values,
thread::ThreadPool* threadpool) const {
mutex temp_mutex; // Protect status.
Status status;
auto lookup_keys = [&, this](int64 begin, int64 end) {
auto temp_status = table.BatchLookup(keys.subspan(begin, end - begin),
values.subspan(begin, end - begin),
prefetch_lookahead);
if (ABSL_PREDICT_FALSE(!temp_status.ok())) {
mutex_lock lock(temp_mutex);
status.Update(temp_status);
}
};
threadpool->TransformRangeConcurrently(num_keys_per_thread /* block_size */,
keys.size(), lookup_keys);
return status;
}
}; };
// Op that returns the size of a container. // Lookup a container of type ResourceAlias and return its size using
template <class Container> // Functor::Size(container, &size).
template <typename Functor,
typename ResourceAlias = typename Functor::resource_type>
class ContainerSizeOp : public OpKernel { class ContainerSizeOp : public OpKernel {
public: public:
explicit ContainerSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} explicit ContainerSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@ -495,11 +424,10 @@ class ContainerSizeOp : public OpKernel {
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
const Tensor& container_handle = ctx->input(0); const Tensor& container_handle = ctx->input(0);
ResourceHandle handle(container_handle.scalar<ResourceHandle>()()); ResourceHandle handle(container_handle.scalar<ResourceHandle>()());
Container* container; ResourceAlias* container;
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(
handle.container(), handle.name(), &container));
core::ScopedUnref unref_me(container); core::ScopedUnref unref_me(container);
OP_REQUIRES_OK(ctx, container->SizeStatus()); OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
handle.container(), handle.name(), &container));
Tensor* out; Tensor* out;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
@ -507,9 +435,9 @@ class ContainerSizeOp : public OpKernel {
auto* mutex = container->GetMutex(); auto* mutex = container->GetMutex();
if (mutex != nullptr) { if (mutex != nullptr) {
tf_shared_lock lock(*mutex); tf_shared_lock lock(*mutex);
out->scalar<int64>()() = container->UnsafeSize(); OP_REQUIRES_OK(ctx, Functor::Size(*container, &out->scalar<uint64>()()));
} else { } else {
out->scalar<int64>()() = container->UnsafeSize(); OP_REQUIRES_OK(ctx, Functor::Size(*container, &out->scalar<uint64>()()));
} }
} }
}; };

View File

@ -1,87 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_
#include <memory>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/lookup_tables/lookup_table_interface.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace tables {
// Parent class for tables with support for multithreaded synchronization.
template <typename HeterogeneousKeyType, typename ValueType>
class LookupTableWithSynchronization
: public LookupTableInterface<HeterogeneousKeyType, ValueType> {
public:
LookupTableWithSynchronization(bool enable_synchronization) {
if (enable_synchronization) {
mutex_ = absl::make_unique<mutex>();
}
}
// Mutex for synchronizing access to unsynchronized methods.
mutex* GetMutex() const override { return mutex_.get(); }
private:
// Use this for locking.
mutable std::unique_ptr<mutex> mutex_;
};
// Parent class for tables which can be constructed with arbitrary
// lookup fallbacks.
// Since LookupTableInterface::LookupKey assumes that all keys can be mapped
// to values, LookupTableWithFallbackInterface allows clients to implement
// two-stage lookups. If the first key lookup fails, clients can choose
// to perform a fallback lookup using an externally supplied table.
template <typename HeterogeneousKeyType, typename ValueType,
typename FallbackTableBaseType =
LookupTableInterface<HeterogeneousKeyType, ValueType>>
class LookupTableWithFallbackInterface
: public LookupTableWithSynchronization<HeterogeneousKeyType, ValueType> {
public:
LookupTableWithFallbackInterface(bool enable_synchronization,
const FallbackTableBaseType* fallback_table)
: LookupTableWithSynchronization<HeterogeneousKeyType, ValueType>(
enable_synchronization),
fallback_table_(fallback_table) {}
// Clients are required to fail when ctx is set to a not-OK status in
// the constructor so this dereference is safe.
const FallbackTableBaseType& fallback_table() const {
return *fallback_table_;
}
~LookupTableWithFallbackInterface() override {
if (fallback_table_ != nullptr) {
fallback_table_->Unref();
}
}
private:
const FallbackTableBaseType* fallback_table_;
};
} // namespace tables
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_