Update Fingerprint64Map to use aliases

PiperOrigin-RevId: 234022159
A. Unique TensorFlower 2019-02-14 13:49:09 -08:00 committed by TensorFlower Gardener
parent 5b90573d41
commit e7d9786e66
6 changed files with 279 additions and 478 deletions


@@ -132,14 +132,14 @@ class ResourceMgr {
//
// REQUIRES: std::is_base_of<ResourceBase, T>
// REQUIRES: resource != nullptr
template <typename T>
template <typename T, bool use_dynamic_cast = false>
Status Lookup(const string& container, const string& name,
T** resource) const TF_MUST_USE_RESULT;
// Similar to Lookup, but looks up multiple resources at once, with only a
// single lock acquisition. If containers_and_names[i] is uninitialized
// then this function does not modify resources[i].
template <typename T>
template <typename T, bool use_dynamic_cast = false>
Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
containers_and_names,
std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
@@ -155,7 +155,7 @@ class ResourceMgr {
//
// REQUIRES: std::is_base_of<ResourceBase, T>
// REQUIRES: resource != nullptr
template <typename T>
template <typename T, bool use_dynamic_cast = false>
Status LookupOrCreate(const string& container, const string& name,
T** resource,
std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
@@ -196,7 +196,7 @@ class ResourceMgr {
mutable mutex mu_;
std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);
template <typename T>
template <typename T, bool use_dynamic_cast = false>
Status LookupInternal(const string& container, const string& name,
T** resource) const
SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
@@ -267,7 +267,7 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
//
// If the lookup is successful, the caller takes the ownership of one ref on
// `*value`, and must call its `Unref()` method when it has finished using it.
template <typename T>
template <typename T, bool use_dynamic_cast = false>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
// Looks up multiple resources pointed by a sequence of resource handles. If
@@ -437,15 +437,15 @@ Status ResourceMgr::Create(const string& container, const string& name,
return DoCreate(container, MakeTypeIndex<T>(), name, resource);
}
template <typename T>
template <typename T, bool use_dynamic_cast>
Status ResourceMgr::Lookup(const string& container, const string& name,
T** resource) const {
CheckDeriveFromResourceBase<T>();
tf_shared_lock l(mu_);
return LookupInternal(container, name, resource);
return LookupInternal<T, use_dynamic_cast>(container, name, resource);
}
template <typename T>
template <typename T, bool use_dynamic_cast>
Status ResourceMgr::LookupMany(
absl::Span<std::pair<const string*, const string*> const>
containers_and_names,
@@ -455,8 +455,9 @@ Status ResourceMgr::LookupMany(
resources->resize(containers_and_names.size());
for (size_t i = 0; i < containers_and_names.size(); ++i) {
T* resource;
Status s = LookupInternal(*containers_and_names[i].first,
*containers_and_names[i].second, &resource);
Status s = LookupInternal<T, use_dynamic_cast>(
*containers_and_names[i].first, *containers_and_names[i].second,
&resource);
if (s.ok()) {
(*resources)[i].reset(resource);
}
@@ -464,7 +465,18 @@ Status ResourceMgr::LookupMany(
return Status::OK();
}
// Simple wrapper to allow conditional dynamic / static casts.
template <typename T, bool use_dynamic_cast>
struct TypeCastFunctor {
static T* Cast(ResourceBase* r) { return static_cast<T*>(r); }
};
template <typename T>
struct TypeCastFunctor<T, true> {
static T* Cast(ResourceBase* r) { return dynamic_cast<T*>(r); }
};
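A rough standalone sketch of the compile-time cast selection that TypeCastFunctor implements, with illustrative Base/Derived types that are not from the TensorFlow sources: the primary template resolves to static_cast and the true specialization swaps in dynamic_cast, so the bool template argument picks the cast with no runtime branching.

#include <cassert>

struct Base { virtual ~Base() = default; };
struct Derived : Base { int payload = 7; };

// Primary template: unchecked static_cast, relies on the caller knowing the type.
template <typename T, bool use_dynamic_cast>
struct CastFunctor {
  static T* Cast(Base* b) { return static_cast<T*>(b); }
};

// Specialization chosen when use_dynamic_cast == true: checked dynamic_cast.
template <typename T>
struct CastFunctor<T, true> {
  static T* Cast(Base* b) { return dynamic_cast<T*>(b); }
};

int main() {
  Derived d;
  Base* b = &d;
  assert(CastFunctor<Derived, false>::Cast(b) == &d);  // unchecked downcast
  assert(CastFunctor<Derived, true>::Cast(b) == &d);   // would be nullptr on a mismatch
  return 0;
}

The dynamic_cast path appears to matter for the alias lookups added later in this commit, presumably because the alias interfaces inherit virtually from ResourceBase, where a plain static_cast down to the concrete type would be ill-formed.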
template <typename T, bool use_dynamic_cast>
Status ResourceMgr::LookupInternal(const string& container, const string& name,
T** resource) const {
ResourceBase* found = nullptr;
@@ -472,12 +484,12 @@ Status ResourceMgr::LookupInternal(const string& container, const string& name,
if (s.ok()) {
// It's safe to down cast 'found' to T* since
// typeid(T).hash_code() is part of the map key.
*resource = static_cast<T*>(found);
*resource = TypeCastFunctor<T, use_dynamic_cast>::Cast(found);
}
return s;
}
template <typename T>
template <typename T, bool use_dynamic_cast>
Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
T** resource,
std::function<Status(T**)> creator) {
@@ -486,11 +498,11 @@ Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
Status s;
{
tf_shared_lock l(mu_);
s = LookupInternal(container, name, resource);
s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
if (s.ok()) return s;
}
mutex_lock l(mu_);
s = LookupInternal(container, name, resource);
s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
if (s.ok()) return s;
TF_RETURN_IF_ERROR(creator(resource));
s = DoCreate(container, MakeTypeIndex<T>(), name, *resource);
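A minimal self-contained sketch of the locking discipline LookupOrCreate uses above, with std::shared_mutex standing in for TensorFlow's mutex/tf_shared_lock and an illustrative Registry type: take a cheap shared lock for the common hit path, then retake an exclusive lock, re-check, and only create if another thread has not already won the race.

#include <iostream>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

class Registry {
 public:
  int* LookupOrCreate(const std::string& name) {
    {
      std::shared_lock<std::shared_mutex> l(mu_);  // cheap shared lock first
      auto it = map_.find(name);
      if (it != map_.end()) return &it->second;
    }
    std::unique_lock<std::shared_mutex> l(mu_);  // exclusive lock to create
    auto it = map_.find(name);                   // re-check: another thread may have created it
    if (it != map_.end()) return &it->second;
    return &map_.emplace(name, 0).first->second;
  }

 private:
  std::shared_mutex mu_;
  std::unordered_map<std::string, int> map_;
};

int main() {
  Registry r;
  *r.LookupOrCreate("counter") += 1;
  std::cout << *r.LookupOrCreate("counter") << "\n";  // prints 1
}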
@@ -566,11 +578,12 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
return ctx->resource_manager()->Create(p.container(), p.name(), value);
}
template <typename T>
template <typename T, bool use_dynamic_cast>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value) {
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
return ctx->resource_manager()->Lookup(p.container(), p.name(), value);
return ctx->resource_manager()->Lookup<T, use_dynamic_cast>(p.container(),
p.name(), value);
}
template <typename T>


@@ -19,18 +19,6 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "table_resource_utils",
hdrs = ["table_resource_utils.h"],
deps = [
":lookup_table_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
@@ -57,8 +45,8 @@ tf_kernel_library(
"fingerprint64_map_ops.cc",
],
deps = [
":lookup_table_interface",
":table_op_utils",
":table_resource_utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",


@@ -15,8 +15,8 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/lookup_tables/lookup_table_interface.h"
#include "tensorflow/core/kernels/lookup_tables/table_op_utils.h"
#include "tensorflow/core/kernels/lookup_tables/table_resource_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/macros.h"
@@ -27,73 +27,48 @@ namespace tables {
// Map x -> (Fingerprint64(x) % num_oov_buckets) + offset.
// num_oov_buckets and offset are node attributes provided at construction
// time.
template <class HeterogeneousKeyType, class ValueType>
template <typename KeyType, typename ValueType>
class Fingerprint64Map final
: public LookupTableInterface<HeterogeneousKeyType, ValueType> {
: public virtual LookupInterface<ValueType*, const KeyType&>,
public virtual LookupWithPrefetchInterface<absl::Span<ValueType>,
absl::Span<const KeyType>> {
public:
using key_type = KeyType;
Fingerprint64Map(int64 num_oov_buckets, int64 offset)
: num_oov_buckets_(num_oov_buckets), offset_(offset) {}
mutex* GetMutex() const override { return nullptr; }
bool UnsafeInsertOrAssign(const HeterogeneousKeyType& key,
const ValueType& value) override {
return true;
Status Lookup(const KeyType& key_to_find, ValueType* value) const override {
*value = LookupHelper(key_to_find);
return Status::OK();
}
Status TableUnbatchedInsertStatus() const override {
return errors::Unimplemented("Fingerprint64Map does not support inserts.");
}
Status BatchInsertOrAssign(absl::Span<const HeterogeneousKeyType> keys,
absl::Span<const ValueType> values) override {
return errors::Unimplemented("Fingerprint64Map does not support inserts.");
}
ValueType UnsafeLookupKey(
const HeterogeneousKeyType& key_to_find) const override {
// This can cause a downcast.
return static_cast<ValueType>(Fingerprint64(key_to_find) %
num_oov_buckets_) +
offset_;
}
Status TableUnbatchedLookupStatus() const override { return Status::OK(); }
Status BatchLookup(absl::Span<const HeterogeneousKeyType> keys,
absl::Span<ValueType> values,
int64 prefetch_lookahead) const override {
Status Lookup(absl::Span<const KeyType> keys, absl::Span<ValueType> values,
int64 prefetch_lookahead) const override {
if (ABSL_PREDICT_FALSE(keys.size() != values.size())) {
return errors::InvalidArgument(
"keys and values do not have the same number of elements (found ",
keys.size(), " vs ", values.size(), ").");
}
for (size_t i = 0; i < keys.size(); ++i) {
values[i] = Fingerprint64Map::UnsafeLookupKey(keys[i]);
values[i] = LookupHelper(keys[i]);
}
return Status::OK();
}
const absl::optional<const ValueType> DefaultValue() const override {
return {};
}
mutex* GetMutex() const override { return nullptr; }
void UnsafePrefetchKey(
const HeterogeneousKeyType& key_to_find) const override {}
size_t UnsafeSize() const override { return 0; }
Status SizeStatus() const override {
return errors::Unimplemented(
"Fingerprint64Map does not have a concept of size.");
}
bool UnsafeContainsKey(
const HeterogeneousKeyType& key_to_find) const override {
return true;
}
string DebugString() const override { return __PRETTY_FUNCTION__; }
private:
ABSL_ATTRIBUTE_ALWAYS_INLINE ValueType
LookupHelper(const KeyType& key_to_find) const {
// This can cause a downcast.
return static_cast<ValueType>(Fingerprint64(key_to_find) %
num_oov_buckets_) +
offset_;
}
const int64 num_oov_buckets_;
const int64 offset_;
TF_DISALLOW_COPY_AND_ASSIGN(Fingerprint64Map);
@@ -102,9 +77,10 @@ class Fingerprint64Map final
template <typename Fingerprint64Map>
struct Fingerprint64MapFactory {
struct Functor {
template <typename ContainerBase>
using resource_type = Fingerprint64Map;
static Status AllocateContainer(OpKernelContext* ctx, OpKernel* kernel,
ContainerBase** container) {
Fingerprint64Map** container) {
int64 num_oov_buckets;
int64 offset;
TF_RETURN_IF_ERROR(
@@ -116,24 +92,28 @@ struct Fingerprint64MapFactory {
};
};
#define REGISTER_STRING_KERNEL(table_value_dtype) \
REGISTER_KERNEL_BUILDER( \
Name("Fingerprint64Map") \
.Device(DEVICE_CPU) \
.TypeConstraint<Variant>("heterogeneous_key_dtype") \
.TypeConstraint<table_value_dtype>("table_value_dtype"), \
ResourceConstructionOp< \
LookupTableInterface<absl::string_view, table_value_dtype>, \
Fingerprint64MapFactory<Fingerprint64Map< \
absl::string_view, table_value_dtype>>::Functor>); \
REGISTER_KERNEL_BUILDER( \
Name("Fingerprint64Map") \
.Device(DEVICE_CPU) \
.TypeConstraint<string>("heterogeneous_key_dtype") \
.TypeConstraint<table_value_dtype>("table_value_dtype"), \
ResourceConstructionOp<LookupTableInterface<string, table_value_dtype>, \
Fingerprint64MapFactory<Fingerprint64Map< \
string, table_value_dtype>>::Functor>);
template <typename KeyType, typename ValueType>
using ResourceOp = ResourceConstructionOp<
typename Fingerprint64MapFactory<
Fingerprint64Map<KeyType, ValueType>>::Functor,
// These are the aliases.
LookupInterface<ValueType*, const KeyType&>,
LookupWithPrefetchInterface<absl::Span<ValueType>,
absl::Span<const KeyType>>>;
#define REGISTER_STRING_KERNEL(ValueType) \
REGISTER_KERNEL_BUILDER( \
Name("Fingerprint64Map") \
.Device(DEVICE_CPU) \
.TypeConstraint<Variant>("heterogeneous_key_dtype") \
.TypeConstraint<ValueType>("table_value_dtype"), \
ResourceOp<absl::string_view, ValueType>); \
REGISTER_KERNEL_BUILDER( \
Name("Fingerprint64Map") \
.Device(DEVICE_CPU) \
.TypeConstraint<string>("heterogeneous_key_dtype") \
.TypeConstraint<ValueType>("table_value_dtype"), \
ResourceOp<string, ValueType>);
REGISTER_STRING_KERNEL(int32);
REGISTER_STRING_KERNEL(int64);
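To make the stateless mapping from the class comment concrete (Map x -> (Fingerprint64(x) % num_oov_buckets) + offset), a small self-contained sketch follows; std::hash stands in for tensorflow::Fingerprint64 only so the demo compiles on its own, and the helper name is hypothetical.

#include <cstdint>
#include <functional>
#include <iostream>
#include <string>

// Stand-in for the stateless lookup rule: every key hashes into
// [offset, offset + num_oov_buckets); nothing is ever stored.
int64_t Fingerprint64MapLookup(const std::string& key, int64_t num_oov_buckets,
                               int64_t offset) {
  const uint64_t fp = std::hash<std::string>{}(key);  // Fingerprint64 in the real op
  return static_cast<int64_t>(fp % static_cast<uint64_t>(num_oov_buckets)) + offset;
}

int main() {
  // With num_oov_buckets = 100 and offset = 1000, every key lands in [1000, 1100).
  std::cout << Fingerprint64MapLookup("hello", 100, 1000) << "\n";
  std::cout << Fingerprint64MapLookup("world", 100, 1000) << "\n";
}

Because the result is a pure function of the key, there is nothing to insert and no meaningful size, which is why the registrations above only expose the lookup and prefetch-lookup aliases.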


@@ -16,11 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_LOOKUP_TABLE_INTERFACE_H_
#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_LOOKUP_TABLE_INTERFACE_H_
#include <cstddef>
#include <string>
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
@@ -28,90 +23,74 @@ limitations under the License.
namespace tensorflow {
namespace tables {
// Interface for key-value pair lookups with support for heterogeneous keys.
// This class contains two main kinds of methods: methods which operate on
// a batch of inputs and methods which do not. The latter have the prefix
// 'Unsafe'. Clients must call the corresponding status methods to determine
// whether they are safe to call within a code block.
// Implementations must guarantee thread-safety when GetMutex is used to
// synchronize method access.
template <typename HeterogeneousKeyType, typename ValueType>
class LookupTableInterface : public ResourceBase {
// Interface for resources with mutable state.
class SynchronizedInterface : public virtual ResourceBase {
public:
using heterogeneous_key_type = HeterogeneousKeyType;
using value_type = ValueType;
using key_type = heterogeneous_key_type;
// Return value should be used to synchronize read/write access to
// all public methods. If null, no synchronization is needed.
virtual mutex* GetMutex() const = 0;
};
// Insert the KV pair into the underlying table. If a key equivalent to key
// already exists in the underlying table, its corresponding value is
// overridden. Returns true only if the key was inserted for the first time.
// Undefined if TableUnbatchedInsertStatus() != OK.
virtual bool UnsafeInsertOrAssign(const HeterogeneousKeyType& key,
const ValueType& value) = 0;
// Returns OK if it is safe to call InsertOrAssign.
// Once OK is returned, it is safe to call InsertOrAssign for the rest of the
// program.
virtual Status TableUnbatchedInsertStatus() const TF_MUST_USE_RESULT = 0;
// Interface for containers which support batch lookups.
template <typename ValueType, typename... KeyContext>
class InsertOrAssignInterface : public virtual SynchronizedInterface {
public:
using value_type = ValueType;
// Stores each KV pair {keys[i], values[i]} in the underlying map, overriding
// pre-existing pairs which have equivalent keys.
// keys and values should have the same size.
virtual Status BatchInsertOrAssign(
absl::Span<const HeterogeneousKeyType> keys,
absl::Span<const ValueType> values) = 0;
virtual Status InsertOrAssign(KeyContext... key_context,
ValueType values) = 0;
};
// Prefetch key_to_find into implementation defined data caches.
// Implementations are free to leave this a no-op.
// Undefined if TableUnbatchedLookupStatus() != OK.
virtual void UnsafePrefetchKey(
const HeterogeneousKeyType& key_to_find) const {}
// Returns true if and only if the table contains key_to_find.
// Undefined if TableUnbatchedLookupStatus() != OK.
virtual bool UnsafeContainsKey(
const HeterogeneousKeyType& key_to_find) const = 0;
// Lookup the value for key_to_find. This value must always be well-defined,
// even when ContainsKey(key_to_find) == false. When
// dv = DefaultValue() != absl::nullopt and ContainsKey(key_to_find) == false,
// dv is returned.
// Undefined if TableUnbatchedLookupStatus() != OK.
virtual ValueType UnsafeLookupKey(
const HeterogeneousKeyType& key_to_find) const = 0;
// Returns OK if it is safe to call PrefetchKey, ContainsKey, and
// UnsafeLookupKey.
// If OK is returned, it is safe to call these methods until the next
// non-const method of this class is called.
virtual Status TableUnbatchedLookupStatus() const TF_MUST_USE_RESULT = 0;
// Interface for containers which support lookups.
template <typename ValueType, typename... KeyContext>
class LookupInterface : public virtual SynchronizedInterface {
public:
using value_type = ValueType;
// Lookup the values for keys and store them in values.
// prefetch_lookahead is used to prefetch the key at index
// i + prefetch_lookahead at the ith iteration of the implemented loop.
// keys and values must have the same size.
virtual Status BatchLookup(absl::Span<const HeterogeneousKeyType> keys,
absl::Span<ValueType> values,
int64 prefetch_lookahead) const = 0;
virtual Status Lookup(KeyContext... key_context, ValueType values) const = 0;
};
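A rough standalone sketch of how the KeyContext parameter pack shapes these signatures (LengthTable and the bool return type are illustrative; the real interfaces return Status and also inherit from SynchronizedInterface): instantiating with <int*, const std::string&> yields a pure virtual Lookup(const std::string&, int*) const, the same shape Fingerprint64Map overrides above.

#include <iostream>
#include <string>

// Same shape as LookupInterface: the pack expands into the leading parameters.
template <typename ValueType, typename... KeyContext>
class LookupIface {
 public:
  virtual ~LookupIface() = default;
  virtual bool Lookup(KeyContext... key_context, ValueType value) const = 0;
};

// LookupIface<int*, const std::string&> declares
//   bool Lookup(const std::string&, int*) const;
class LengthTable : public LookupIface<int*, const std::string&> {
 public:
  bool Lookup(const std::string& key, int* value) const override {
    *value = static_cast<int>(key.size());
    return true;
  }
};

int main() {
  LengthTable table;
  int v = 0;
  table.Lookup("hello", &v);
  std::cout << v << "\n";  // prints 5
}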
// Returns the number of elements in the table.
// Undefined if SizeStatus() != OK.
virtual size_t UnsafeSize() const = 0;
// Interface for containers which support lookups with prefetching.
template <typename ValueType, typename... KeyContext>
class LookupWithPrefetchInterface : public virtual SynchronizedInterface {
public:
using value_type = ValueType;
// Returns OK if the return value of UnsafeSize() is always well-defined.
virtual Status SizeStatus() const TF_MUST_USE_RESULT = 0;
// Lookup the values for keys and store them in values.
// prefetch_lookahead is used to prefetch the key at index
// i + prefetch_lookahead at the ith iteration of the implemented loop.
// keys and values must have the same size.
virtual Status Lookup(KeyContext... key_context, ValueType values,
int64 prefetch_lookahead) const = 0;
};
// If non-null value is returned, LookupKey returns that value only for keys
// which satisfy ContainsKey(key_to_find) == false.
virtual const absl::optional<const ValueType> DefaultValue() const = 0;
// Interface for containers with size concepts.
// Implementations must guarantee thread-safety when GetMutex is used to
// synchronize method access.
class SizeInterface : public virtual SynchronizedInterface {
public:
// Returns the number of elements in the container.
virtual uint64 Size() const = 0;
};
string DebugString() const override { return "A lookup table"; }
// Interface for tables which can be initialized from key and value arguments.
template <typename ValueType, typename... KeyContext>
class KeyValueTableInitializerInterface : public virtual SynchronizedInterface {
public:
using value_type = ValueType;
~LookupTableInterface() override = default;
// Lookup the values for keys and store them in values.
// prefetch_lookahead is used to prefetch the key at index
// i + prefetch_lookahead at the ith iteration of the implemented loop.
// keys and values must have the same size.
virtual Status Initialize(KeyContext... key_context, ValueType values) = 0;
};
} // namespace tables


@@ -44,11 +44,11 @@ limitations under the License.
namespace tensorflow {
namespace tables {
// Create resources of type ContainerBase using the static method
// Create resources of type ResourceType and AliasesToRegister using
// Functor::AllocateContainer(OpKernelConstruction*, OpKernel*,
// ContainerBase**)
// If the resource has already been created it will be looked up.
template <class ContainerBase, typename Functor>
// ResourceType**). ResourceType = Functor::resource_type.
// No-op for resources which have already been created.
template <typename Functor, typename... AliasesToRegister>
class ResourceConstructionOp : public OpKernel {
public:
explicit ResourceConstructionOp(OpKernelConstruction* ctx)
@@ -66,46 +66,86 @@ class ResourceConstructionOp : public OpKernel {
}
auto creator = [ctx,
this](ContainerBase** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
ContainerBase* container;
auto status = Functor::AllocateContainer(ctx, this, &container);
this](ResourceType** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
ResourceType* resource = nullptr;
auto status = Functor::AllocateContainer(ctx, this, &resource);
if (ABSL_PREDICT_FALSE(!status.ok())) {
container->Unref();
// Ideally resource is non-null only if status is OK but we try
// to compensate here.
if (resource != nullptr) {
resource->Unref();
}
return status;
}
if (ctx->track_allocations()) {
ctx->record_persistent_memory_allocation(container->MemoryUsed());
ctx->record_persistent_memory_allocation(resource->MemoryUsed());
}
*ret = container;
*ret = resource;
return Status::OK();
};
ContainerBase* container_base = nullptr;
// Register the ResourceType alias.
ResourceType* resource = nullptr;
core::ScopedUnref unref_me(resource);
OP_REQUIRES_OK(
ctx, cinfo_.resource_manager()->template LookupOrCreate<ContainerBase>(
cinfo_.container(), cinfo_.name(), &container_base, creator));
core::ScopedUnref unref_me(container_base);
ctx,
cinfo_.resource_manager()->template LookupOrCreate<ResourceType, true>(
cinfo_.container(), cinfo_.name(), &resource, creator));
// Put a handle to resource in the output tensor (the other aliases will
// have the same handle).
Tensor* handle;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
handle->scalar<ResourceHandle>()() = MakeResourceHandle<ContainerBase>(
handle->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
ctx, cinfo_.container(), cinfo_.name());
table_handle_set_ = true;
// Create other alias resources.
Status status;
char dummy[sizeof...(AliasesToRegister)] = {
(status.Update(RegisterAlias<AliasesToRegister>(resource)), 0)...};
(void)dummy;
OP_REQUIRES_OK(ctx, status);
}
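The char dummy[sizeof...(AliasesToRegister)] initializer above is the pre-C++17 idiom for calling RegisterAlias<T> once per type in the pack; here is a self-contained sketch with stand-in Status and RegisterAlias definitions (none of these names come from the TensorFlow sources).

#include <iostream>
#include <string>
#include <typeinfo>

struct Status {
  bool ok = true;
  void Update(const Status& other) { ok = ok && other.ok; }
};

template <typename T>
Status RegisterAlias() {
  std::cout << "registering alias " << typeid(T).name() << "\n";
  return Status{};
}

template <typename... Aliases>
Status RegisterAll() {
  Status status;
  // Mirrors the char-array trick above: each initializer element evaluates one
  // RegisterAlias<Alias>() call through the comma operator, strictly left to
  // right. The leading 0 keeps the array non-empty when the pack is empty.
  int dummy[] = {0, (status.Update(RegisterAlias<Aliases>()), 0)...};
  (void)dummy;
  return status;
}

int main() {
  RegisterAll<int, double, std::string>();  // prints one line per alias type
  return 0;
}

With C++17 available, a comma fold expression such as (status.Update(RegisterAlias<Aliases>()), ...) would express the same loop without the dummy array.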
~ResourceConstructionOp() override {
// If the table object was not shared, delete it.
if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
if (!cinfo_.resource_manager()
->template Delete<ContainerBase>(cinfo_.container(),
cinfo_.name())
->template Delete<ResourceType>(cinfo_.container(),
cinfo_.name())
.ok()) {
// Do nothing; the resource may have been deleted by session resets.
}
// Attempt to delete other resource aliases.
Status dummy_status;
char dummy[sizeof...(AliasesToRegister)] = {
(dummy_status.Update(DeleteAlias<AliasesToRegister>()), 0)...};
(void)dummy;
}
}
private:
using ResourceType = typename Functor::resource_type;
template <typename T>
Status RegisterAlias(ResourceType* resource) {
auto creator = [resource](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = resource;
return Status::OK();
};
T* alias_resource = nullptr;
core::ScopedUnref unref_me(alias_resource);
return cinfo_.resource_manager()->template LookupOrCreate<T, true>(
cinfo_.container(), cinfo_.name(), &alias_resource, creator);
}
template <typename T>
Status DeleteAlias() {
return cinfo_.resource_manager()->template Delete<T>(cinfo_.container(),
cinfo_.name());
}
mutex mu_;
bool table_handle_set_ GUARDED_BY(mu_);
ContainerInfo cinfo_;
@@ -120,8 +160,7 @@ class ResourceConstructionOp : public OpKernel {
// If the resource has already been created it will be looked up.
// Container must decrease the reference count of the FallbackTableBaseType*
// constructor argument before its destructor completes.
template <class ContainerBase, class Functor,
class FallbackTableBaseType = ContainerBase>
template <typename Functor, typename... AliasesToRegister>
class TableWithFallbackConstructionOp : public OpKernel {
public:
explicit TableWithFallbackConstructionOp(OpKernelConstruction* ctx)
@@ -140,13 +179,14 @@ class TableWithFallbackConstructionOp : public OpKernel {
return;
}
// Look up the fallback table.
FallbackTableBaseType* fallback_table = nullptr;
{
const Tensor& table_handle = ctx->input(table_int64_args.size());
ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
OP_REQUIRES_OK(
ctx, ctx->resource_manager()->Lookup(handle.container(),
handle.name(), &fallback_table));
ctx, ctx->resource_manager()->Lookup<FallbackTableBaseType, true>(
handle.container(), handle.name(), &fallback_table));
}
mutex_lock l(mu_);
@@ -156,51 +196,93 @@ class TableWithFallbackConstructionOp : public OpKernel {
}
auto creator = [ctx, this, fallback_table](
ContainerBase** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
ResourceType** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
// container construction logic can't be merged with
// ResourceConstructionOp because Container constructor requires an
// input which can only be constructed if the resource manager
// internal lock is not already held.
ContainerBase* container;
ResourceType* resource = nullptr;
auto status =
Functor::AllocateContainer(ctx, this, fallback_table, &container);
Functor::AllocateContainer(ctx, this, fallback_table, &resource);
if (ABSL_PREDICT_FALSE(!status.ok())) {
container->Unref();
// Ideally resource is non-null only if status is OK but we try
// to compensate here.
if (resource != nullptr) {
resource->Unref();
}
return status;
}
if (ctx->track_allocations()) {
ctx->record_persistent_memory_allocation(container->MemoryUsed());
ctx->record_persistent_memory_allocation(resource->MemoryUsed());
}
*ret = container;
*ret = resource;
return Status::OK();
};
ContainerBase* table = nullptr;
OP_REQUIRES_OK(
ctx, cinfo_.resource_manager()->template LookupOrCreate<ContainerBase>(
cinfo_.container(), cinfo_.name(), &table, creator));
// Register the ResourceType alias.
ResourceType* table = nullptr;
core::ScopedUnref unref_me(table);
OP_REQUIRES_OK(
ctx,
cinfo_.resource_manager()->template LookupOrCreate<ResourceType, true>(
cinfo_.container(), cinfo_.name(), &table, creator));
// Put a handle to resource in the output tensor (the other aliases will
// have the same handle).
Tensor* handle;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
handle->scalar<ResourceHandle>()() = MakeResourceHandle<ContainerBase>(
handle->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
ctx, cinfo_.container(), cinfo_.name());
table_handle_set_ = true;
// Create other alias resources.
Status status;
char dummy[sizeof...(AliasesToRegister)] = {
(status.Update(RegisterAlias<AliasesToRegister>(table)), 0)...};
(void)dummy;
OP_REQUIRES_OK(ctx, status);
}
~TableWithFallbackConstructionOp() override {
// If the table object was not shared, delete it.
if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
if (!cinfo_.resource_manager()
->template Delete<ContainerBase>(cinfo_.container(),
cinfo_.name())
->template Delete<ResourceType>(cinfo_.container(),
cinfo_.name())
.ok()) {
// Do nothing; the resource may have been deleted by session resets.
}
// Attempt to delete other resource aliases.
Status dummy_status;
char dummy[sizeof...(AliasesToRegister)] = {
(dummy_status.Update(DeleteAlias<AliasesToRegister>()), 0)...};
(void)dummy;
}
}
private:
using ResourceType = typename Functor::resource_type;
using FallbackTableBaseType = typename Functor::fallback_table_type;
template <typename T>
Status RegisterAlias(ResourceType* resource) {
auto creator = [resource](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = resource;
return Status::OK();
};
T* alias_resource = nullptr;
core::ScopedUnref unref_me(alias_resource);
return cinfo_.resource_manager()->template LookupOrCreate<T, true>(
cinfo_.container(), cinfo_.name(), &alias_resource, creator);
}
template <typename T>
Status DeleteAlias() {
return cinfo_.resource_manager()->template Delete<T>(cinfo_.container(),
cinfo_.name());
}
mutex mu_;
bool table_handle_set_ GUARDED_BY(mu_);
ContainerInfo cinfo_;
@@ -209,33 +291,29 @@ class TableWithFallbackConstructionOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(TableWithFallbackConstructionOp);
};
// Used to insert tensors into a container.
template <class Container, class InsertKeyTensorType,
class InsertValueTensorType>
class HeterogeneousLookupTableInsertOrAssignOp : public OpKernel {
// Lookup a table of type ResourceAlias and insert the passed in keys and
// values tensors using Functor::TensorInsert(keys, values, table).
template <typename Functor,
typename ResourceAlias = typename Functor::resource_type>
class LookupTableInsertOp : public OpKernel {
public:
explicit HeterogeneousLookupTableInsertOrAssignOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {}
explicit LookupTableInsertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
OpInputList table_int64_args;
OP_REQUIRES_OK(ctx, ctx->input_list("table_int64_args", &table_int64_args));
const size_t tensor_index_offset = table_int64_args.size();
// Business logic for checking tensor shapes, etc, is delegated to the
// Functor.
const Tensor& keys = ctx->input(tensor_index_offset + 1);
const Tensor& values = ctx->input(tensor_index_offset + 2);
if (ABSL_PREDICT_FALSE(keys.NumElements() != values.NumElements())) {
ctx->SetStatus(errors::InvalidArgument(
"keys and values do not have the same number of elements: ",
keys.NumElements(), " vs ", values.NumElements()));
return;
}
const Tensor& table_handle = ctx->input(tensor_index_offset);
ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
Container* table;
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(handle.container(),
handle.name(), &table));
ResourceAlias* table;
core::ScopedUnref unref_me(table);
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
handle.container(), handle.name(), &table));
int memory_used_before = 0;
if (ctx->track_allocations()) {
@@ -244,9 +322,9 @@ class HeterogeneousLookupTableInsertOrAssignOp : public OpKernel {
auto* mutex = table->GetMutex();
if (mutex != nullptr) {
mutex_lock lock(*mutex);
OP_REQUIRES_OK(ctx, TensorInsert(keys, values, table));
OP_REQUIRES_OK(ctx, Functor::TensorInsert(keys, values, table));
} else {
OP_REQUIRES_OK(ctx, TensorInsert(keys, values, table));
OP_REQUIRES_OK(ctx, Functor::TensorInsert(keys, values, table));
}
if (ctx->track_allocations()) {
ctx->record_persistent_memory_allocation(table->MemoryUsed() -
@@ -255,74 +333,17 @@ class HeterogeneousLookupTableInsertOrAssignOp : public OpKernel {
}
private:
// Non-variant InsertKeyTensorType which is the same as Container::key_type.
// No need to static_cast.
template <typename SfinaeArg = InsertKeyTensorType>
absl::enable_if_t<
IsValidDataType<SfinaeArg>::value &&
std::is_same<SfinaeArg, typename Container::key_type>::value,
Status>
TensorInsert(const Tensor& keys, const Tensor& values,
Container* table) const {
const auto keys_flat = keys.flat<SfinaeArg>();
const auto values_flat = values.flat<InsertValueTensorType>();
return table->BatchInsertOrAssign(
absl::MakeSpan(keys_flat.data(), keys_flat.size()),
absl::MakeSpan(values_flat.data(), values_flat.size()));
}
// Non-variant InsertKeyTensorType which is otherwise convertible to
// Container::key_type.
template <typename SfinaeArg = InsertKeyTensorType>
absl::enable_if_t<
IsValidDataType<SfinaeArg>::value &&
!std::is_same<SfinaeArg, typename Container::key_type>::value &&
std::is_convertible<SfinaeArg, typename Container::key_type>::value,
Status>
TensorInsert(const Tensor& keys, const Tensor& values,
Container* table) const {
const auto keys_flat = keys.flat<InsertKeyTensorType>();
std::vector<typename Container::key_type> keys_vec;
const auto keys_size = keys_flat.size();
keys_vec.reserve(keys_size);
for (size_t i = 0; i < keys_size; ++i) {
keys_vec.push_back(
static_cast<typename Container::key_type>(keys_flat(i)));
}
const auto values_flat = values.flat<InsertValueTensorType>();
return table->BatchInsertOrAssign(
keys_vec, absl::MakeSpan(values_flat.data(), values_flat.size()));
}
// Variant InsertKeyTensorType; the wrapped type is convertible to
// Container::key_type.
template <typename SfinaeArg = InsertKeyTensorType>
absl::enable_if_t<
!IsValidDataType<SfinaeArg>::value &&
std::is_convertible<typename SfinaeArg::value_type,
typename Container::key_type>::value,
Status>
TensorInsert(const Tensor& keys, const Tensor& values,
Container* table) const {
const auto keys_flat = keys.flat<Variant>();
std::vector<typename Container::key_type> keys_vec;
keys_vec.reserve(keys_flat.size());
for (size_t i = 0; i < keys_flat.size(); ++i) {
keys_vec.emplace_back(
*keys_flat(i).get<typename SfinaeArg::value_type>());
}
const auto values_flat = values.flat<InsertValueTensorType>();
return table->BatchInsertOrAssign(
keys_vec, absl::MakeSpan(values_flat.data(), values_flat.size()));
}
TF_DISALLOW_COPY_AND_ASSIGN(LookupTableInsertOp);
};
// Used for tensor lookups.
template <class Container, class LookupKeyTensorType, class ValueTensorType>
class HeterogeneousLookupTableFindOp : public OpKernel {
// Lookup a table of type ResourceAlias and look up the passed in keys using
// Functor::TensorLookup(
// table, keys, prefetch_lookahead, num_keys_per_thread, threadpool, out).
template <typename Functor,
typename ResourceAlias = typename Functor::resource_type>
class LookupTableFindOp : public OpKernel {
public:
explicit HeterogeneousLookupTableFindOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {}
explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
OpInputList table_int64_args;
@@ -370,10 +391,10 @@ class HeterogeneousLookupTableFindOp : public OpKernel {
const Tensor& table_handle = ctx->input(tensor_index_offset);
ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
Container* table;
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(handle.container(),
handle.name(), &table));
ResourceAlias* table;
core::ScopedUnref unref_me(table);
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
handle.container(), handle.name(), &table));
auto* mutex = table->GetMutex();
auto* threadpool = ctx->device()->tensorflow_cpu_worker_threads()->workers;
@@ -382,112 +403,20 @@ class HeterogeneousLookupTableFindOp : public OpKernel {
// writer lock here.
mutex_lock lock(*mutex);
OP_REQUIRES_OK(
ctx, TensorLookup(*table, prefetch_lookahead, num_keys_per_thread,
keys, out, threadpool));
ctx, Functor::TensorLookup(*table, keys, prefetch_lookahead,
num_keys_per_thread, threadpool, out));
} else {
OP_REQUIRES_OK(
ctx, TensorLookup(*table, prefetch_lookahead, num_keys_per_thread,
keys, out, threadpool));
ctx, Functor::TensorLookup(*table, keys, prefetch_lookahead,
num_keys_per_thread, threadpool, out));
}
}
private:
// keys and *values arguments to TensorLookup must have the same number of
// elements. This is guaranteed above.
// 'Simple' types below are types which are not natively supported in TF.
// Simple LookupKeyTensorType which is the same as Container::key_type.
template <typename SfinaeArg = LookupKeyTensorType>
absl::enable_if_t<
IsValidDataType<SfinaeArg>::value &&
std::is_same<SfinaeArg, typename Container::key_type>::value,
Status>
TensorLookup(Container& table, int64 prefetch_lookahead,
int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
thread::ThreadPool* threadpool) const {
const auto keys_flat = keys.flat<LookupKeyTensorType>();
const auto keys_size = keys_flat.size();
auto key_span = absl::MakeSpan(keys_flat.data(), keys_size);
auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
values->NumElements());
return MultithreadedTensorLookup(table, prefetch_lookahead,
num_keys_per_thread, key_span, value_span,
threadpool);
}
// Try to implicitly convert all other simple LookupKeyTensorTypes to
// Container::key_type.
template <typename SfinaeArg = LookupKeyTensorType>
absl::enable_if_t<
IsValidDataType<SfinaeArg>::value &&
!std::is_same<SfinaeArg, typename Container::key_type>::value,
Status>
TensorLookup(Container& table, int64 prefetch_lookahead,
int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
thread::ThreadPool* threadpool) const {
const auto keys_flat = keys.flat<LookupKeyTensorType>();
std::vector<typename Container::key_type> keys_vec;
const auto keys_size = keys_flat.size();
keys_vec.reserve(keys_size);
for (size_t i = 0; i < keys_size; ++i) {
keys_vec.emplace_back(keys_flat(i));
}
absl::Span<typename Container::key_type> key_span(keys_vec);
auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
values->NumElements());
return MultithreadedTensorLookup(table, prefetch_lookahead,
num_keys_per_thread, key_span, value_span,
threadpool);
}
// Non-simple LookupKeyTensorType. We'll try an implicit conversion to
// Container::key_type.
template <typename VariantSubType = LookupKeyTensorType>
absl::enable_if_t<!IsValidDataType<VariantSubType>::value, Status>
TensorLookup(Container& table, int64 prefetch_lookahead,
int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
thread::ThreadPool* threadpool) const {
const auto keys_flat = keys.flat<Variant>();
std::vector<typename Container::key_type> keys_vec;
const auto keys_size = keys_flat.size();
keys_vec.reserve(keys_size);
for (size_t i = 0; i < keys_size; ++i) {
keys_vec.emplace_back(
*keys_flat(i).get<typename VariantSubType::value_type>());
}
absl::Span<typename Container::key_type> key_span(keys_vec);
auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
values->NumElements());
return MultithreadedTensorLookup(table, prefetch_lookahead,
num_keys_per_thread, key_span, value_span,
threadpool);
}
// Wrapper around table.BatchLookup which permits sharding across cores.
template <typename K, typename V>
Status MultithreadedTensorLookup(Container& table, int64 prefetch_lookahead,
int64 num_keys_per_thread,
absl::Span<K> keys, absl::Span<V> values,
thread::ThreadPool* threadpool) const {
mutex temp_mutex; // Protect status.
Status status;
auto lookup_keys = [&, this](int64 begin, int64 end) {
auto temp_status = table.BatchLookup(keys.subspan(begin, end - begin),
values.subspan(begin, end - begin),
prefetch_lookahead);
if (ABSL_PREDICT_FALSE(!temp_status.ok())) {
mutex_lock lock(temp_mutex);
status.Update(temp_status);
}
};
threadpool->TransformRangeConcurrently(num_keys_per_thread /* block_size */,
keys.size(), lookup_keys);
return status;
}
};
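A standalone sketch of the sharded-lookup idea behind the MultithreadedTensorLookup helper removed above (the new op delegates this to Functor::TensorLookup), using std::thread in place of the threadpool's TransformRangeConcurrently; every name is illustrative. The key range is split into fixed-size blocks, each block is looked up on its own worker, and any failure is merged into a single result under a mutex.

#include <algorithm>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

bool LookupBlock(const std::vector<int>& keys, std::vector<int>& values,
                 size_t begin, size_t end) {
  for (size_t i = begin; i < end; ++i) values[i] = keys[i] * 2;  // toy lookup
  return true;
}

bool ShardedLookup(const std::vector<int>& keys, std::vector<int>& values,
                   size_t keys_per_thread) {
  std::mutex mu;  // protects ok
  bool ok = true;
  std::vector<std::thread> workers;
  for (size_t begin = 0; begin < keys.size(); begin += keys_per_thread) {
    const size_t end = std::min(begin + keys_per_thread, keys.size());
    workers.emplace_back([&, begin, end] {
      if (!LookupBlock(keys, values, begin, end)) {
        std::lock_guard<std::mutex> lock(mu);
        ok = false;
      }
    });
  }
  for (auto& w : workers) w.join();
  return ok;
}

int main() {
  std::vector<int> keys = {1, 2, 3, 4, 5};
  std::vector<int> values(keys.size());
  std::cout << ShardedLookup(keys, values, 2) << "\n";  // prints 1
  for (int v : values) std::cout << v << " ";           // 2 4 6 8 10
  std::cout << "\n";
}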
// Op that returns the size of a container.
template <class Container>
// Lookup a container of type ResourceAlias and return its size using
// Functor::Size(container, &size).
template <typename Functor,
typename ResourceAlias = typename Functor::resource_type>
class ContainerSizeOp : public OpKernel {
public:
explicit ContainerSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -495,11 +424,10 @@ class ContainerSizeOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& container_handle = ctx->input(0);
ResourceHandle handle(container_handle.scalar<ResourceHandle>()());
Container* container;
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(
handle.container(), handle.name(), &container));
ResourceAlias* container;
core::ScopedUnref unref_me(container);
OP_REQUIRES_OK(ctx, container->SizeStatus());
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
handle.container(), handle.name(), &container));
Tensor* out;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
@@ -507,9 +435,9 @@ class ContainerSizeOp : public OpKernel {
auto* mutex = container->GetMutex();
if (mutex != nullptr) {
tf_shared_lock lock(*mutex);
out->scalar<int64>()() = container->UnsafeSize();
OP_REQUIRES_OK(ctx, Functor::Size(*container, &out->scalar<uint64>()()));
} else {
out->scalar<int64>()() = container->UnsafeSize();
OP_REQUIRES_OK(ctx, Functor::Size(*container, &out->scalar<uint64>()()));
}
}
};


@@ -1,87 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_
#include <memory>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/lookup_tables/lookup_table_interface.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace tables {
// Parent class for tables with support for multithreaded synchronization.
template <typename HeterogeneousKeyType, typename ValueType>
class LookupTableWithSynchronization
: public LookupTableInterface<HeterogeneousKeyType, ValueType> {
public:
LookupTableWithSynchronization(bool enable_synchronization) {
if (enable_synchronization) {
mutex_ = absl::make_unique<mutex>();
}
}
// Mutex for synchronizing access to unsynchronized methods.
mutex* GetMutex() const override { return mutex_.get(); }
private:
// Use this for locking.
mutable std::unique_ptr<mutex> mutex_;
};
// Parent class for tables which can be constructed with arbitrary
// lookup fallbacks.
// Since LookupTableInterface::LookupKey assumes that all keys can be mapped
// to values, LookupTableWithFallbackInterface allows clients to implement
// two-stage lookups. If the first key lookup fails, clients can choose
// to perform a fallback lookup using an externally supplied table.
template <typename HeterogeneousKeyType, typename ValueType,
typename FallbackTableBaseType =
LookupTableInterface<HeterogeneousKeyType, ValueType>>
class LookupTableWithFallbackInterface
: public LookupTableWithSynchronization<HeterogeneousKeyType, ValueType> {
public:
LookupTableWithFallbackInterface(bool enable_synchronization,
const FallbackTableBaseType* fallback_table)
: LookupTableWithSynchronization<HeterogeneousKeyType, ValueType>(
enable_synchronization),
fallback_table_(fallback_table) {}
// Clients are required to fail when ctx is set to a not-OK status in
// the constructor so this dereference is safe.
const FallbackTableBaseType& fallback_table() const {
return *fallback_table_;
}
~LookupTableWithFallbackInterface() override {
if (fallback_table_ != nullptr) {
fallback_table_->Unref();
}
}
private:
const FallbackTableBaseType* fallback_table_;
};
} // namespace tables
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_