Update Fingerprint64Map to use aliases
PiperOrigin-RevId: 234022159
This commit is contained in:
parent
5b90573d41
commit
e7d9786e66
@ -132,14 +132,14 @@ class ResourceMgr {
|
||||
//
|
||||
// REQUIRES: std::is_base_of<ResourceBase, T>
|
||||
// REQUIRES: resource != nullptr
|
||||
template <typename T>
|
||||
template <typename T, bool use_dynamic_cast = false>
|
||||
Status Lookup(const string& container, const string& name,
|
||||
T** resource) const TF_MUST_USE_RESULT;
|
||||
|
||||
// Similar to Lookup, but looks up multiple resources at once, with only a
|
||||
// single lock acquisition. If containers_and_names[i] is uninitialized
|
||||
// then this function does not modify resources[i].
|
||||
template <typename T>
|
||||
template <typename T, bool use_dynamic_cast = false>
|
||||
Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
|
||||
containers_and_names,
|
||||
std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
|
||||
@ -155,7 +155,7 @@ class ResourceMgr {
|
||||
//
|
||||
// REQUIRES: std::is_base_of<ResourceBase, T>
|
||||
// REQUIRES: resource != nullptr
|
||||
template <typename T>
|
||||
template <typename T, bool use_dynamic_cast = false>
|
||||
Status LookupOrCreate(const string& container, const string& name,
|
||||
T** resource,
|
||||
std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
|
||||
@ -196,7 +196,7 @@ class ResourceMgr {
|
||||
mutable mutex mu_;
|
||||
std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool use_dynamic_cast = false>
|
||||
Status LookupInternal(const string& container, const string& name,
|
||||
T** resource) const
|
||||
SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
|
||||
@ -267,7 +267,7 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
|
||||
//
|
||||
// If the lookup is successful, the caller takes the ownership of one ref on
|
||||
// `*value`, and must call its `Unref()` method when it has finished using it.
|
||||
template <typename T>
|
||||
template <typename T, bool use_dynamic_cast = false>
|
||||
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
|
||||
|
||||
// Looks up multiple resources pointed by a sequence of resource handles. If
|
||||
@ -437,15 +437,15 @@ Status ResourceMgr::Create(const string& container, const string& name,
|
||||
return DoCreate(container, MakeTypeIndex<T>(), name, resource);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool use_dynamic_cast>
|
||||
Status ResourceMgr::Lookup(const string& container, const string& name,
|
||||
T** resource) const {
|
||||
CheckDeriveFromResourceBase<T>();
|
||||
tf_shared_lock l(mu_);
|
||||
return LookupInternal(container, name, resource);
|
||||
return LookupInternal<T, use_dynamic_cast>(container, name, resource);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool use_dynamic_cast>
|
||||
Status ResourceMgr::LookupMany(
|
||||
absl::Span<std::pair<const string*, const string*> const>
|
||||
containers_and_names,
|
||||
@ -455,8 +455,9 @@ Status ResourceMgr::LookupMany(
|
||||
resources->resize(containers_and_names.size());
|
||||
for (size_t i = 0; i < containers_and_names.size(); ++i) {
|
||||
T* resource;
|
||||
Status s = LookupInternal(*containers_and_names[i].first,
|
||||
*containers_and_names[i].second, &resource);
|
||||
Status s = LookupInternal<T, use_dynamic_cast>(
|
||||
*containers_and_names[i].first, *containers_and_names[i].second,
|
||||
&resource);
|
||||
if (s.ok()) {
|
||||
(*resources)[i].reset(resource);
|
||||
}
|
||||
@ -464,7 +465,18 @@ Status ResourceMgr::LookupMany(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Simple wrapper to allow conditional dynamic / static casts.
|
||||
template <typename T, bool use_dynamic_cast>
|
||||
struct TypeCastFunctor {
|
||||
static T* Cast(ResourceBase* r) { return static_cast<T*>(r); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct TypeCastFunctor<T, true> {
|
||||
static T* Cast(ResourceBase* r) { return dynamic_cast<T*>(r); }
|
||||
};
|
||||
|
||||
template <typename T, bool use_dynamic_cast>
|
||||
Status ResourceMgr::LookupInternal(const string& container, const string& name,
|
||||
T** resource) const {
|
||||
ResourceBase* found = nullptr;
|
||||
@ -472,12 +484,12 @@ Status ResourceMgr::LookupInternal(const string& container, const string& name,
|
||||
if (s.ok()) {
|
||||
// It's safe to down cast 'found' to T* since
|
||||
// typeid(T).hash_code() is part of the map key.
|
||||
*resource = static_cast<T*>(found);
|
||||
*resource = TypeCastFunctor<T, use_dynamic_cast>::Cast(found);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool use_dynamic_cast>
|
||||
Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
|
||||
T** resource,
|
||||
std::function<Status(T**)> creator) {
|
||||
@ -486,11 +498,11 @@ Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
|
||||
Status s;
|
||||
{
|
||||
tf_shared_lock l(mu_);
|
||||
s = LookupInternal(container, name, resource);
|
||||
s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
|
||||
if (s.ok()) return s;
|
||||
}
|
||||
mutex_lock l(mu_);
|
||||
s = LookupInternal(container, name, resource);
|
||||
s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
|
||||
if (s.ok()) return s;
|
||||
TF_RETURN_IF_ERROR(creator(resource));
|
||||
s = DoCreate(container, MakeTypeIndex<T>(), name, *resource);
|
||||
@ -566,11 +578,12 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
|
||||
return ctx->resource_manager()->Create(p.container(), p.name(), value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool use_dynamic_cast>
|
||||
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
|
||||
T** value) {
|
||||
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
|
||||
return ctx->resource_manager()->Lookup(p.container(), p.name(), value);
|
||||
return ctx->resource_manager()->Lookup<T, use_dynamic_cast>(p.container(),
|
||||
p.name(), value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -19,18 +19,6 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "table_resource_utils",
|
||||
hdrs = ["table_resource_utils.h"],
|
||||
deps = [
|
||||
":lookup_table_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
@ -57,8 +45,8 @@ tf_kernel_library(
|
||||
"fingerprint64_map_ops.cc",
|
||||
],
|
||||
deps = [
|
||||
":lookup_table_interface",
|
||||
":table_op_utils",
|
||||
":table_resource_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -15,8 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/kernels/lookup_tables/lookup_table_interface.h"
|
||||
#include "tensorflow/core/kernels/lookup_tables/table_op_utils.h"
|
||||
#include "tensorflow/core/kernels/lookup_tables/table_resource_utils.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
@ -27,41 +27,23 @@ namespace tables {
|
||||
// Map x -> (Fingerprint64(x) % num_oov_buckets) + offset.
|
||||
// num_oov_buckets and offset are node attributes provided at construction
|
||||
// time.
|
||||
template <class HeterogeneousKeyType, class ValueType>
|
||||
template <typename KeyType, typename ValueType>
|
||||
class Fingerprint64Map final
|
||||
: public LookupTableInterface<HeterogeneousKeyType, ValueType> {
|
||||
: public virtual LookupInterface<ValueType*, const KeyType&>,
|
||||
public virtual LookupWithPrefetchInterface<absl::Span<ValueType>,
|
||||
absl::Span<const KeyType>> {
|
||||
public:
|
||||
using key_type = KeyType;
|
||||
|
||||
Fingerprint64Map(int64 num_oov_buckets, int64 offset)
|
||||
: num_oov_buckets_(num_oov_buckets), offset_(offset) {}
|
||||
|
||||
mutex* GetMutex() const override { return nullptr; }
|
||||
|
||||
bool UnsafeInsertOrAssign(const HeterogeneousKeyType& key,
|
||||
const ValueType& value) override {
|
||||
return true;
|
||||
Status Lookup(const KeyType& key_to_find, ValueType* value) const override {
|
||||
*value = LookupHelper(key_to_find);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TableUnbatchedInsertStatus() const override {
|
||||
return errors::Unimplemented("Fingerprint64Map does not support inserts.");
|
||||
}
|
||||
|
||||
Status BatchInsertOrAssign(absl::Span<const HeterogeneousKeyType> keys,
|
||||
absl::Span<const ValueType> values) override {
|
||||
return errors::Unimplemented("Fingerprint64Map does not support inserts.");
|
||||
}
|
||||
|
||||
ValueType UnsafeLookupKey(
|
||||
const HeterogeneousKeyType& key_to_find) const override {
|
||||
// This can cause a downcast.
|
||||
return static_cast<ValueType>(Fingerprint64(key_to_find) %
|
||||
num_oov_buckets_) +
|
||||
offset_;
|
||||
}
|
||||
|
||||
Status TableUnbatchedLookupStatus() const override { return Status::OK(); }
|
||||
|
||||
Status BatchLookup(absl::Span<const HeterogeneousKeyType> keys,
|
||||
absl::Span<ValueType> values,
|
||||
Status Lookup(absl::Span<const KeyType> keys, absl::Span<ValueType> values,
|
||||
int64 prefetch_lookahead) const override {
|
||||
if (ABSL_PREDICT_FALSE(keys.size() != values.size())) {
|
||||
return errors::InvalidArgument(
|
||||
@ -69,31 +51,24 @@ class Fingerprint64Map final
|
||||
keys.size(), " vs ", values.size(), ").");
|
||||
}
|
||||
for (size_t i = 0; i < keys.size(); ++i) {
|
||||
values[i] = Fingerprint64Map::UnsafeLookupKey(keys[i]);
|
||||
values[i] = LookupHelper(keys[i]);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const absl::optional<const ValueType> DefaultValue() const override {
|
||||
return {};
|
||||
}
|
||||
mutex* GetMutex() const override { return nullptr; }
|
||||
|
||||
void UnsafePrefetchKey(
|
||||
const HeterogeneousKeyType& key_to_find) const override {}
|
||||
|
||||
size_t UnsafeSize() const override { return 0; }
|
||||
|
||||
Status SizeStatus() const override {
|
||||
return errors::Unimplemented(
|
||||
"Fingerprint64Map does not have a concept of size.");
|
||||
}
|
||||
|
||||
bool UnsafeContainsKey(
|
||||
const HeterogeneousKeyType& key_to_find) const override {
|
||||
return true;
|
||||
}
|
||||
string DebugString() const override { return __PRETTY_FUNCTION__; }
|
||||
|
||||
private:
|
||||
ABSL_ATTRIBUTE_ALWAYS_INLINE ValueType
|
||||
LookupHelper(const KeyType& key_to_find) const {
|
||||
// This can cause a downcast.
|
||||
return static_cast<ValueType>(Fingerprint64(key_to_find) %
|
||||
num_oov_buckets_) +
|
||||
offset_;
|
||||
}
|
||||
|
||||
const int64 num_oov_buckets_;
|
||||
const int64 offset_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Fingerprint64Map);
|
||||
@ -102,9 +77,10 @@ class Fingerprint64Map final
|
||||
template <typename Fingerprint64Map>
|
||||
struct Fingerprint64MapFactory {
|
||||
struct Functor {
|
||||
template <typename ContainerBase>
|
||||
using resource_type = Fingerprint64Map;
|
||||
|
||||
static Status AllocateContainer(OpKernelContext* ctx, OpKernel* kernel,
|
||||
ContainerBase** container) {
|
||||
Fingerprint64Map** container) {
|
||||
int64 num_oov_buckets;
|
||||
int64 offset;
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -116,24 +92,28 @@ struct Fingerprint64MapFactory {
|
||||
};
|
||||
};
|
||||
|
||||
#define REGISTER_STRING_KERNEL(table_value_dtype) \
|
||||
template <typename KeyType, typename ValueType>
|
||||
using ResourceOp = ResourceConstructionOp<
|
||||
typename Fingerprint64MapFactory<
|
||||
Fingerprint64Map<KeyType, ValueType>>::Functor,
|
||||
// These are the aliases.
|
||||
LookupInterface<ValueType*, const KeyType&>,
|
||||
LookupWithPrefetchInterface<absl::Span<ValueType>,
|
||||
absl::Span<const KeyType>>>;
|
||||
|
||||
#define REGISTER_STRING_KERNEL(ValueType) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Fingerprint64Map") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<Variant>("heterogeneous_key_dtype") \
|
||||
.TypeConstraint<table_value_dtype>("table_value_dtype"), \
|
||||
ResourceConstructionOp< \
|
||||
LookupTableInterface<absl::string_view, table_value_dtype>, \
|
||||
Fingerprint64MapFactory<Fingerprint64Map< \
|
||||
absl::string_view, table_value_dtype>>::Functor>); \
|
||||
.TypeConstraint<ValueType>("table_value_dtype"), \
|
||||
ResourceOp<absl::string_view, ValueType>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Fingerprint64Map") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<string>("heterogeneous_key_dtype") \
|
||||
.TypeConstraint<table_value_dtype>("table_value_dtype"), \
|
||||
ResourceConstructionOp<LookupTableInterface<string, table_value_dtype>, \
|
||||
Fingerprint64MapFactory<Fingerprint64Map< \
|
||||
string, table_value_dtype>>::Functor>);
|
||||
.TypeConstraint<ValueType>("table_value_dtype"), \
|
||||
ResourceOp<string, ValueType>);
|
||||
|
||||
REGISTER_STRING_KERNEL(int32);
|
||||
REGISTER_STRING_KERNEL(int64);
|
||||
|
@ -16,11 +16,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_LOOKUP_TABLE_INTERFACE_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_LOOKUP_TABLE_INTERFACE_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
@ -28,90 +23,74 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tables {
|
||||
|
||||
// Interface for key-value pair lookups with support for heterogeneous keys.
|
||||
// This class contains two main kinds of methods: methods which operate on
|
||||
// a batch of inputs and methods which do not. The latter have the prefix
|
||||
// 'Unsafe'. Clients must call the corresponding status methods to determine
|
||||
// whether they are safe to call within a code block.
|
||||
// Implementations must guarantee thread-safety when GetMutex is used to
|
||||
// synchronize method access.
|
||||
template <typename HeterogeneousKeyType, typename ValueType>
|
||||
class LookupTableInterface : public ResourceBase {
|
||||
// Interface for resources with mutable state.
|
||||
class SynchronizedInterface : public virtual ResourceBase {
|
||||
public:
|
||||
using heterogeneous_key_type = HeterogeneousKeyType;
|
||||
using value_type = ValueType;
|
||||
using key_type = heterogeneous_key_type;
|
||||
|
||||
// Return value should be used to synchronize read/write access to
|
||||
// all public methods. If null, no synchronization is needed.
|
||||
virtual mutex* GetMutex() const = 0;
|
||||
};
|
||||
|
||||
// Insert the KV pair into the underlying table. If a key equivalent to key
|
||||
// already exists in the underlying table, its corresponding value is
|
||||
// overridden. Returns true only if the key was inserted for the first time.
|
||||
// Undefined if TableUnbatchedInsertStatus() != OK.
|
||||
virtual bool UnsafeInsertOrAssign(const HeterogeneousKeyType& key,
|
||||
const ValueType& value) = 0;
|
||||
|
||||
// Returns OK if it is safe to call InsertOrAssign.
|
||||
// Once OK is returned, it is safe to call InsertOrAssign for the rest of the
|
||||
// program.
|
||||
virtual Status TableUnbatchedInsertStatus() const TF_MUST_USE_RESULT = 0;
|
||||
// Interface for containers which support batch lookups.
|
||||
template <typename ValueType, typename... KeyContext>
|
||||
class InsertOrAssignInterface : public virtual SynchronizedInterface {
|
||||
public:
|
||||
using value_type = ValueType;
|
||||
|
||||
// Stores each KV pair {keys[i], values[i]} in the underlying map, overriding
|
||||
// pre-existing pairs which have equivalent keys.
|
||||
// keys and values should have the same size.
|
||||
virtual Status BatchInsertOrAssign(
|
||||
absl::Span<const HeterogeneousKeyType> keys,
|
||||
absl::Span<const ValueType> values) = 0;
|
||||
virtual Status InsertOrAssign(KeyContext... key_context,
|
||||
ValueType values) = 0;
|
||||
};
|
||||
|
||||
// Prefetch key_to_find into implementation defined data caches.
|
||||
// Implementations are free to leave this a no-op.
|
||||
// Undefined if TableUnbatchedLookupStatus() != OK.
|
||||
virtual void UnsafePrefetchKey(
|
||||
const HeterogeneousKeyType& key_to_find) const {}
|
||||
|
||||
// Returns true if and only if the table contains key_to_find.
|
||||
// Undefined if TableUnbatchedLookupStatus() != OK.
|
||||
virtual bool UnsafeContainsKey(
|
||||
const HeterogeneousKeyType& key_to_find) const = 0;
|
||||
|
||||
// Lookup the value for key_to_find. This value must always be well-defined,
|
||||
// even when ContainsKey(key_to_find) == false. When
|
||||
// dv = DefaultValue() != absl::nullopt and ContainsKey(key_to_find) == false,
|
||||
// dv is returned.
|
||||
// Undefined if TableUnbatchedLookupStatus() != OK.
|
||||
virtual ValueType UnsafeLookupKey(
|
||||
const HeterogeneousKeyType& key_to_find) const = 0;
|
||||
|
||||
// Returns OK if it is safe to call PrefetchKey, ContainsKey, and
|
||||
// UnsafeLookupKey.
|
||||
// If OK is returned, it is safe to call these methods until the next
|
||||
// non-const method of this class is called.
|
||||
virtual Status TableUnbatchedLookupStatus() const TF_MUST_USE_RESULT = 0;
|
||||
// Interface for containers which support lookups.
|
||||
template <typename ValueType, typename... KeyContext>
|
||||
class LookupInterface : public virtual SynchronizedInterface {
|
||||
public:
|
||||
using value_type = ValueType;
|
||||
|
||||
// Lookup the values for keys and store them in values.
|
||||
// prefetch_lookahead is used to prefetch the key at index
|
||||
// i + prefetch_lookahead at the ith iteration of the implemented loop.
|
||||
// keys and values must have the same size.
|
||||
virtual Status BatchLookup(absl::Span<const HeterogeneousKeyType> keys,
|
||||
absl::Span<ValueType> values,
|
||||
virtual Status Lookup(KeyContext... key_context, ValueType values) const = 0;
|
||||
};
|
||||
|
||||
// Interface for containers which support lookups with prefetching.
|
||||
template <typename ValueType, typename... KeyContext>
|
||||
class LookupWithPrefetchInterface : public virtual SynchronizedInterface {
|
||||
public:
|
||||
using value_type = ValueType;
|
||||
|
||||
// Lookup the values for keys and store them in values.
|
||||
// prefetch_lookahead is used to prefetch the key at index
|
||||
// i + prefetch_lookahead at the ith iteration of the implemented loop.
|
||||
// keys and values must have the same size.
|
||||
virtual Status Lookup(KeyContext... key_context, ValueType values,
|
||||
int64 prefetch_lookahead) const = 0;
|
||||
};
|
||||
|
||||
// Returns the number of elements in the table.
|
||||
// Undefined if SizeStatus() != OK.
|
||||
virtual size_t UnsafeSize() const = 0;
|
||||
// Interface for containers with size concepts.
|
||||
// Implementations must guarantee thread-safety when GetMutex is used to
|
||||
// synchronize method access.
|
||||
class SizeInterface : public virtual SynchronizedInterface {
|
||||
public:
|
||||
// Returns the number of elements in the container.
|
||||
virtual uint64 Size() const = 0;
|
||||
};
|
||||
|
||||
// Returns OK if the return value of UnsafeSize() is always well-defined.
|
||||
virtual Status SizeStatus() const TF_MUST_USE_RESULT = 0;
|
||||
// Interface for tables which can be initialized from key and value arguments.
|
||||
template <typename ValueType, typename... KeyContext>
|
||||
class KeyValueTableInitializerInterface : public virtual SynchronizedInterface {
|
||||
public:
|
||||
using value_type = ValueType;
|
||||
|
||||
// If non-null value is returned, LookupKey returns that value only for keys
|
||||
// which satisfy ContainsKey(key_to_find) == false.
|
||||
virtual const absl::optional<const ValueType> DefaultValue() const = 0;
|
||||
|
||||
string DebugString() const override { return "A lookup table"; }
|
||||
|
||||
~LookupTableInterface() override = default;
|
||||
// Lookup the values for keys and store them in values.
|
||||
// prefetch_lookahead is used to prefetch the key at index
|
||||
// i + prefetch_lookahead at the ith iteration of the implemented loop.
|
||||
// keys and values must have the same size.
|
||||
virtual Status Initialize(KeyContext... key_context, ValueType values) = 0;
|
||||
};
|
||||
|
||||
} // namespace tables
|
||||
|
@ -44,11 +44,11 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tables {
|
||||
|
||||
// Create resources of type ContainerBase using the static method
|
||||
// Create resources of type ResourceType and AliasesToRegister using
|
||||
// Functor::AllocateContainer(OpKernelConstruction*, OpKernel*,
|
||||
// ContainerBase**)
|
||||
// If the resource has already been created it will be looked up.
|
||||
template <class ContainerBase, typename Functor>
|
||||
// ResourceType**). ResourceType = Functor::resource_type.
|
||||
// No-op for resources which have already been created.
|
||||
template <typename Functor, typename... AliasesToRegister>
|
||||
class ResourceConstructionOp : public OpKernel {
|
||||
public:
|
||||
explicit ResourceConstructionOp(OpKernelConstruction* ctx)
|
||||
@ -66,46 +66,86 @@ class ResourceConstructionOp : public OpKernel {
|
||||
}
|
||||
|
||||
auto creator = [ctx,
|
||||
this](ContainerBase** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
ContainerBase* container;
|
||||
auto status = Functor::AllocateContainer(ctx, this, &container);
|
||||
this](ResourceType** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
ResourceType* resource = nullptr;
|
||||
auto status = Functor::AllocateContainer(ctx, this, &resource);
|
||||
if (ABSL_PREDICT_FALSE(!status.ok())) {
|
||||
container->Unref();
|
||||
// Ideally resource is non-null only if status is OK but we try
|
||||
// to compensate here.
|
||||
if (resource != nullptr) {
|
||||
resource->Unref();
|
||||
}
|
||||
return status;
|
||||
}
|
||||
if (ctx->track_allocations()) {
|
||||
ctx->record_persistent_memory_allocation(container->MemoryUsed());
|
||||
ctx->record_persistent_memory_allocation(resource->MemoryUsed());
|
||||
}
|
||||
*ret = container;
|
||||
*ret = resource;
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
ContainerBase* container_base = nullptr;
|
||||
// Register the ResourceType alias.
|
||||
ResourceType* resource = nullptr;
|
||||
core::ScopedUnref unref_me(resource);
|
||||
OP_REQUIRES_OK(
|
||||
ctx, cinfo_.resource_manager()->template LookupOrCreate<ContainerBase>(
|
||||
cinfo_.container(), cinfo_.name(), &container_base, creator));
|
||||
core::ScopedUnref unref_me(container_base);
|
||||
ctx,
|
||||
cinfo_.resource_manager()->template LookupOrCreate<ResourceType, true>(
|
||||
cinfo_.container(), cinfo_.name(), &resource, creator));
|
||||
|
||||
// Put a handle to resource in the output tensor (the other aliases will
|
||||
// have the same handle).
|
||||
Tensor* handle;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
|
||||
handle->scalar<ResourceHandle>()() = MakeResourceHandle<ContainerBase>(
|
||||
handle->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
|
||||
ctx, cinfo_.container(), cinfo_.name());
|
||||
table_handle_set_ = true;
|
||||
|
||||
// Create other alias resources.
|
||||
Status status;
|
||||
char dummy[sizeof...(AliasesToRegister)] = {
|
||||
(status.Update(RegisterAlias<AliasesToRegister>(resource)), 0)...};
|
||||
(void)dummy;
|
||||
OP_REQUIRES_OK(ctx, status);
|
||||
}
|
||||
|
||||
~ResourceConstructionOp() override {
|
||||
// If the table object was not shared, delete it.
|
||||
if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
|
||||
if (!cinfo_.resource_manager()
|
||||
->template Delete<ContainerBase>(cinfo_.container(),
|
||||
->template Delete<ResourceType>(cinfo_.container(),
|
||||
cinfo_.name())
|
||||
.ok()) {
|
||||
// Do nothing; the resource may have been deleted by session resets.
|
||||
}
|
||||
// Attempt to delete other resource aliases.
|
||||
Status dummy_status;
|
||||
char dummy[sizeof...(AliasesToRegister)] = {
|
||||
(dummy_status.Update(DeleteAlias<AliasesToRegister>()), 0)...};
|
||||
(void)dummy;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
using ResourceType = typename Functor::resource_type;
|
||||
template <typename T>
|
||||
Status RegisterAlias(ResourceType* resource) {
|
||||
auto creator = [resource](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
*ret = resource;
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
T* alias_resource = nullptr;
|
||||
core::ScopedUnref unref_me(alias_resource);
|
||||
return cinfo_.resource_manager()->template LookupOrCreate<T, true>(
|
||||
cinfo_.container(), cinfo_.name(), &alias_resource, creator);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status DeleteAlias() {
|
||||
return cinfo_.resource_manager()->template Delete<T>(cinfo_.container(),
|
||||
cinfo_.name());
|
||||
}
|
||||
|
||||
mutex mu_;
|
||||
bool table_handle_set_ GUARDED_BY(mu_);
|
||||
ContainerInfo cinfo_;
|
||||
@ -120,8 +160,7 @@ class ResourceConstructionOp : public OpKernel {
|
||||
// If the resource has already been created it will be looked up.
|
||||
// Container must decrease the reference count of the FallbackTableBaseType*
|
||||
// constructor argument before its destructor completes.
|
||||
template <class ContainerBase, class Functor,
|
||||
class FallbackTableBaseType = ContainerBase>
|
||||
template <typename Functor, typename... AliasesToRegister>
|
||||
class TableWithFallbackConstructionOp : public OpKernel {
|
||||
public:
|
||||
explicit TableWithFallbackConstructionOp(OpKernelConstruction* ctx)
|
||||
@ -140,13 +179,14 @@ class TableWithFallbackConstructionOp : public OpKernel {
|
||||
return;
|
||||
}
|
||||
|
||||
// Look up the fallback table.
|
||||
FallbackTableBaseType* fallback_table = nullptr;
|
||||
{
|
||||
const Tensor& table_handle = ctx->input(table_int64_args.size());
|
||||
ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->resource_manager()->Lookup(handle.container(),
|
||||
handle.name(), &fallback_table));
|
||||
ctx, ctx->resource_manager()->Lookup<FallbackTableBaseType, true>(
|
||||
handle.container(), handle.name(), &fallback_table));
|
||||
}
|
||||
mutex_lock l(mu_);
|
||||
|
||||
@ -156,51 +196,93 @@ class TableWithFallbackConstructionOp : public OpKernel {
|
||||
}
|
||||
|
||||
auto creator = [ctx, this, fallback_table](
|
||||
ContainerBase** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
ResourceType** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
// container construction logic can't be merged with
|
||||
// ResourceConstructionOp because Container constructor requires an
|
||||
// input which can only be constructed if the resource manager
|
||||
// internal lock is not already held.
|
||||
ContainerBase* container;
|
||||
ResourceType* resource = nullptr;
|
||||
auto status =
|
||||
Functor::AllocateContainer(ctx, this, fallback_table, &container);
|
||||
Functor::AllocateContainer(ctx, this, fallback_table, &resource);
|
||||
if (ABSL_PREDICT_FALSE(!status.ok())) {
|
||||
container->Unref();
|
||||
// Ideally resource is non-null only if status is OK but we try
|
||||
// to compensate here.
|
||||
if (resource != nullptr) {
|
||||
resource->Unref();
|
||||
}
|
||||
return status;
|
||||
}
|
||||
if (ctx->track_allocations()) {
|
||||
ctx->record_persistent_memory_allocation(container->MemoryUsed());
|
||||
ctx->record_persistent_memory_allocation(resource->MemoryUsed());
|
||||
}
|
||||
*ret = container;
|
||||
*ret = resource;
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
ContainerBase* table = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, cinfo_.resource_manager()->template LookupOrCreate<ContainerBase>(
|
||||
cinfo_.container(), cinfo_.name(), &table, creator));
|
||||
// Register the ResourceType alias.
|
||||
ResourceType* table = nullptr;
|
||||
core::ScopedUnref unref_me(table);
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
cinfo_.resource_manager()->template LookupOrCreate<ResourceType, true>(
|
||||
cinfo_.container(), cinfo_.name(), &table, creator));
|
||||
|
||||
// Put a handle to resource in the output tensor (the other aliases will
|
||||
// have the same handle).
|
||||
Tensor* handle;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
|
||||
handle->scalar<ResourceHandle>()() = MakeResourceHandle<ContainerBase>(
|
||||
handle->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
|
||||
ctx, cinfo_.container(), cinfo_.name());
|
||||
table_handle_set_ = true;
|
||||
|
||||
// Create other alias resources.
|
||||
Status status;
|
||||
char dummy[sizeof...(AliasesToRegister)] = {
|
||||
(status.Update(RegisterAlias<AliasesToRegister>(table)), 0)...};
|
||||
(void)dummy;
|
||||
OP_REQUIRES_OK(ctx, status);
|
||||
}
|
||||
|
||||
~TableWithFallbackConstructionOp() override {
|
||||
// If the table object was not shared, delete it.
|
||||
if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
|
||||
if (!cinfo_.resource_manager()
|
||||
->template Delete<ContainerBase>(cinfo_.container(),
|
||||
->template Delete<ResourceType>(cinfo_.container(),
|
||||
cinfo_.name())
|
||||
.ok()) {
|
||||
// Do nothing; the resource may have been deleted by session resets.
|
||||
}
|
||||
// Attempt to delete other resource aliases.
|
||||
Status dummy_status;
|
||||
char dummy[sizeof...(AliasesToRegister)] = {
|
||||
(dummy_status.Update(DeleteAlias<AliasesToRegister>()), 0)...};
|
||||
(void)dummy;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
using ResourceType = typename Functor::resource_type;
|
||||
using FallbackTableBaseType = typename Functor::fallback_table_type;
|
||||
|
||||
template <typename T>
|
||||
Status RegisterAlias(ResourceType* resource) {
|
||||
auto creator = [resource](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
*ret = resource;
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
T* alias_resource = nullptr;
|
||||
core::ScopedUnref unref_me(alias_resource);
|
||||
return cinfo_.resource_manager()->template LookupOrCreate<T, true>(
|
||||
cinfo_.container(), cinfo_.name(), &alias_resource, creator);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status DeleteAlias() {
|
||||
return cinfo_.resource_manager()->template Delete<T>(cinfo_.container(),
|
||||
cinfo_.name());
|
||||
}
|
||||
|
||||
mutex mu_;
|
||||
bool table_handle_set_ GUARDED_BY(mu_);
|
||||
ContainerInfo cinfo_;
|
||||
@ -209,33 +291,29 @@ class TableWithFallbackConstructionOp : public OpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TableWithFallbackConstructionOp);
|
||||
};
|
||||
|
||||
// Used to insert tensors into a container.
|
||||
template <class Container, class InsertKeyTensorType,
|
||||
class InsertValueTensorType>
|
||||
class HeterogeneousLookupTableInsertOrAssignOp : public OpKernel {
|
||||
// Lookup a table of type ResourceAlias and insert the passed in keys and
|
||||
// values tensors using Functor::TensorInsert(keys, values, table).
|
||||
template <typename Functor,
|
||||
typename ResourceAlias = typename Functor::resource_type>
|
||||
class LookupTableInsertOp : public OpKernel {
|
||||
public:
|
||||
explicit HeterogeneousLookupTableInsertOrAssignOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx) {}
|
||||
explicit LookupTableInsertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
OpInputList table_int64_args;
|
||||
OP_REQUIRES_OK(ctx, ctx->input_list("table_int64_args", &table_int64_args));
|
||||
const size_t tensor_index_offset = table_int64_args.size();
|
||||
// Business logic for checking tensor shapes, etc, is delegated to the
|
||||
// Functor.
|
||||
const Tensor& keys = ctx->input(tensor_index_offset + 1);
|
||||
const Tensor& values = ctx->input(tensor_index_offset + 2);
|
||||
if (ABSL_PREDICT_FALSE(keys.NumElements() != values.NumElements())) {
|
||||
ctx->SetStatus(errors::InvalidArgument(
|
||||
"keys and values do not have the same number of elements: ",
|
||||
keys.NumElements(), " vs ", values.NumElements()));
|
||||
return;
|
||||
}
|
||||
|
||||
const Tensor& table_handle = ctx->input(tensor_index_offset);
|
||||
ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
|
||||
Container* table;
|
||||
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(handle.container(),
|
||||
handle.name(), &table));
|
||||
ResourceAlias* table;
|
||||
core::ScopedUnref unref_me(table);
|
||||
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
|
||||
handle.container(), handle.name(), &table));
|
||||
|
||||
int memory_used_before = 0;
|
||||
if (ctx->track_allocations()) {
|
||||
@ -244,9 +322,9 @@ class HeterogeneousLookupTableInsertOrAssignOp : public OpKernel {
|
||||
auto* mutex = table->GetMutex();
|
||||
if (mutex != nullptr) {
|
||||
mutex_lock lock(*mutex);
|
||||
OP_REQUIRES_OK(ctx, TensorInsert(keys, values, table));
|
||||
OP_REQUIRES_OK(ctx, Functor::TensorInsert(keys, values, table));
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, TensorInsert(keys, values, table));
|
||||
OP_REQUIRES_OK(ctx, Functor::TensorInsert(keys, values, table));
|
||||
}
|
||||
if (ctx->track_allocations()) {
|
||||
ctx->record_persistent_memory_allocation(table->MemoryUsed() -
|
||||
@ -255,74 +333,17 @@ class HeterogeneousLookupTableInsertOrAssignOp : public OpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
// Non-variant InsertKeyTensorType which is the same as Container::key_type.
|
||||
// No need to static_cast.
|
||||
template <typename SfinaeArg = InsertKeyTensorType>
|
||||
absl::enable_if_t<
|
||||
IsValidDataType<SfinaeArg>::value &&
|
||||
std::is_same<SfinaeArg, typename Container::key_type>::value,
|
||||
Status>
|
||||
TensorInsert(const Tensor& keys, const Tensor& values,
|
||||
Container* table) const {
|
||||
const auto keys_flat = keys.flat<SfinaeArg>();
|
||||
const auto values_flat = values.flat<InsertValueTensorType>();
|
||||
return table->BatchInsertOrAssign(
|
||||
absl::MakeSpan(keys_flat.data(), keys_flat.size()),
|
||||
absl::MakeSpan(values_flat.data(), values_flat.size()));
|
||||
}
|
||||
|
||||
// Non-variant InsertKeyTensorType which is otherwise convertible to
|
||||
// Container::key_type.
|
||||
template <typename SfinaeArg = InsertKeyTensorType>
|
||||
absl::enable_if_t<
|
||||
IsValidDataType<SfinaeArg>::value &&
|
||||
!std::is_same<SfinaeArg, typename Container::key_type>::value &&
|
||||
std::is_convertible<SfinaeArg, typename Container::key_type>::value,
|
||||
Status>
|
||||
TensorInsert(const Tensor& keys, const Tensor& values,
|
||||
Container* table) const {
|
||||
const auto keys_flat = keys.flat<InsertKeyTensorType>();
|
||||
std::vector<typename Container::key_type> keys_vec;
|
||||
const auto keys_size = keys_flat.size();
|
||||
keys_vec.reserve(keys_size);
|
||||
for (size_t i = 0; i < keys_size; ++i) {
|
||||
keys_vec.push_back(
|
||||
static_cast<typename Container::key_type>(keys_flat(i)));
|
||||
}
|
||||
const auto values_flat = values.flat<InsertValueTensorType>();
|
||||
return table->BatchInsertOrAssign(
|
||||
keys_vec, absl::MakeSpan(values_flat.data(), values_flat.size()));
|
||||
}
|
||||
|
||||
// Variant InsertKeyTensorType; the wrapped type is convertible to
|
||||
// Container::key_type.
|
||||
template <typename SfinaeArg = InsertKeyTensorType>
|
||||
absl::enable_if_t<
|
||||
!IsValidDataType<SfinaeArg>::value &&
|
||||
std::is_convertible<typename SfinaeArg::value_type,
|
||||
typename Container::key_type>::value,
|
||||
Status>
|
||||
TensorInsert(const Tensor& keys, const Tensor& values,
|
||||
Container* table) const {
|
||||
const auto keys_flat = keys.flat<Variant>();
|
||||
std::vector<typename Container::key_type> keys_vec;
|
||||
keys_vec.reserve(keys_flat.size());
|
||||
for (size_t i = 0; i < keys_flat.size(); ++i) {
|
||||
keys_vec.emplace_back(
|
||||
*keys_flat(i).get<typename SfinaeArg::value_type>());
|
||||
}
|
||||
const auto values_flat = values.flat<InsertValueTensorType>();
|
||||
return table->BatchInsertOrAssign(
|
||||
keys_vec, absl::MakeSpan(values_flat.data(), values_flat.size()));
|
||||
}
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(LookupTableInsertOp);
|
||||
};
|
||||
|
||||
// Used for tensor lookups.
|
||||
template <class Container, class LookupKeyTensorType, class ValueTensorType>
|
||||
class HeterogeneousLookupTableFindOp : public OpKernel {
|
||||
// Lookup a table of type ResourceAlias and look up the passed in keys using
|
||||
// Functor::TensorLookup(
|
||||
// table, keys, prefetch_lookahead, num_keys_per_thread, threadpool, out).
|
||||
template <typename Functor,
|
||||
typename ResourceAlias = typename Functor::resource_type>
|
||||
class LookupTableFindOp : public OpKernel {
|
||||
public:
|
||||
explicit HeterogeneousLookupTableFindOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx) {}
|
||||
explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
OpInputList table_int64_args;
|
||||
@ -370,10 +391,10 @@ class HeterogeneousLookupTableFindOp : public OpKernel {
|
||||
|
||||
const Tensor& table_handle = ctx->input(tensor_index_offset);
|
||||
ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
|
||||
Container* table;
|
||||
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(handle.container(),
|
||||
handle.name(), &table));
|
||||
ResourceAlias* table;
|
||||
core::ScopedUnref unref_me(table);
|
||||
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
|
||||
handle.container(), handle.name(), &table));
|
||||
|
||||
auto* mutex = table->GetMutex();
|
||||
auto* threadpool = ctx->device()->tensorflow_cpu_worker_threads()->workers;
|
||||
@ -382,112 +403,20 @@ class HeterogeneousLookupTableFindOp : public OpKernel {
|
||||
// writer lock here.
|
||||
mutex_lock lock(*mutex);
|
||||
OP_REQUIRES_OK(
|
||||
ctx, TensorLookup(*table, prefetch_lookahead, num_keys_per_thread,
|
||||
keys, out, threadpool));
|
||||
ctx, Functor::TensorLookup(*table, keys, prefetch_lookahead,
|
||||
num_keys_per_thread, threadpool, out));
|
||||
} else {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, TensorLookup(*table, prefetch_lookahead, num_keys_per_thread,
|
||||
keys, out, threadpool));
|
||||
ctx, Functor::TensorLookup(*table, keys, prefetch_lookahead,
|
||||
num_keys_per_thread, threadpool, out));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// keys and *values arguments to TensorLookup must have the same number of
|
||||
// elements. This is guaranteed above.
|
||||
|
||||
// 'Simple' types below are types which are not natively supported in TF.
|
||||
// Simple LookupKeyTensorType which is the same as Container::key_type.
|
||||
template <typename SfinaeArg = LookupKeyTensorType>
|
||||
absl::enable_if_t<
|
||||
IsValidDataType<SfinaeArg>::value &&
|
||||
std::is_same<SfinaeArg, typename Container::key_type>::value,
|
||||
Status>
|
||||
TensorLookup(Container& table, int64 prefetch_lookahead,
|
||||
int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
|
||||
thread::ThreadPool* threadpool) const {
|
||||
const auto keys_flat = keys.flat<LookupKeyTensorType>();
|
||||
const auto keys_size = keys_flat.size();
|
||||
auto key_span = absl::MakeSpan(keys_flat.data(), keys_size);
|
||||
auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
|
||||
values->NumElements());
|
||||
return MultithreadedTensorLookup(table, prefetch_lookahead,
|
||||
num_keys_per_thread, key_span, value_span,
|
||||
threadpool);
|
||||
}
|
||||
|
||||
// Try to implicitly convert all other simple LookupKeyTensorTypes to
|
||||
// Container::key_type.
|
||||
template <typename SfinaeArg = LookupKeyTensorType>
|
||||
absl::enable_if_t<
|
||||
IsValidDataType<SfinaeArg>::value &&
|
||||
!std::is_same<SfinaeArg, typename Container::key_type>::value,
|
||||
Status>
|
||||
TensorLookup(Container& table, int64 prefetch_lookahead,
|
||||
int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
|
||||
thread::ThreadPool* threadpool) const {
|
||||
const auto keys_flat = keys.flat<LookupKeyTensorType>();
|
||||
std::vector<typename Container::key_type> keys_vec;
|
||||
const auto keys_size = keys_flat.size();
|
||||
keys_vec.reserve(keys_size);
|
||||
for (size_t i = 0; i < keys_size; ++i) {
|
||||
keys_vec.emplace_back(keys_flat(i));
|
||||
}
|
||||
absl::Span<typename Container::key_type> key_span(keys_vec);
|
||||
auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
|
||||
values->NumElements());
|
||||
return MultithreadedTensorLookup(table, prefetch_lookahead,
|
||||
num_keys_per_thread, key_span, value_span,
|
||||
threadpool);
|
||||
}
|
||||
|
||||
// Non-simple LookupKeyTensorType. We'll try an implicit conversion to
|
||||
// Container::key_type.
|
||||
template <typename VariantSubType = LookupKeyTensorType>
|
||||
absl::enable_if_t<!IsValidDataType<VariantSubType>::value, Status>
|
||||
TensorLookup(Container& table, int64 prefetch_lookahead,
|
||||
int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
|
||||
thread::ThreadPool* threadpool) const {
|
||||
const auto keys_flat = keys.flat<Variant>();
|
||||
std::vector<typename Container::key_type> keys_vec;
|
||||
const auto keys_size = keys_flat.size();
|
||||
keys_vec.reserve(keys_size);
|
||||
for (size_t i = 0; i < keys_size; ++i) {
|
||||
keys_vec.emplace_back(
|
||||
*keys_flat(i).get<typename VariantSubType::value_type>());
|
||||
}
|
||||
absl::Span<typename Container::key_type> key_span(keys_vec);
|
||||
auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
|
||||
values->NumElements());
|
||||
return MultithreadedTensorLookup(table, prefetch_lookahead,
|
||||
num_keys_per_thread, key_span, value_span,
|
||||
threadpool);
|
||||
}
|
||||
|
||||
// Wrapper around table.BatchLookup which permits sharding across cores.
|
||||
template <typename K, typename V>
|
||||
Status MultithreadedTensorLookup(Container& table, int64 prefetch_lookahead,
|
||||
int64 num_keys_per_thread,
|
||||
absl::Span<K> keys, absl::Span<V> values,
|
||||
thread::ThreadPool* threadpool) const {
|
||||
mutex temp_mutex; // Protect status.
|
||||
Status status;
|
||||
auto lookup_keys = [&, this](int64 begin, int64 end) {
|
||||
auto temp_status = table.BatchLookup(keys.subspan(begin, end - begin),
|
||||
values.subspan(begin, end - begin),
|
||||
prefetch_lookahead);
|
||||
if (ABSL_PREDICT_FALSE(!temp_status.ok())) {
|
||||
mutex_lock lock(temp_mutex);
|
||||
status.Update(temp_status);
|
||||
}
|
||||
};
|
||||
threadpool->TransformRangeConcurrently(num_keys_per_thread /* block_size */,
|
||||
keys.size(), lookup_keys);
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
// Op that returns the size of a container.
|
||||
template <class Container>
|
||||
// Lookup a container of type ResourceAlias and return its size using
|
||||
// Functor::Size(container, &size).
|
||||
template <typename Functor,
|
||||
typename ResourceAlias = typename Functor::resource_type>
|
||||
class ContainerSizeOp : public OpKernel {
|
||||
public:
|
||||
explicit ContainerSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
@ -495,11 +424,10 @@ class ContainerSizeOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& container_handle = ctx->input(0);
|
||||
ResourceHandle handle(container_handle.scalar<ResourceHandle>()());
|
||||
Container* container;
|
||||
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(
|
||||
handle.container(), handle.name(), &container));
|
||||
ResourceAlias* container;
|
||||
core::ScopedUnref unref_me(container);
|
||||
OP_REQUIRES_OK(ctx, container->SizeStatus());
|
||||
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
|
||||
handle.container(), handle.name(), &container));
|
||||
|
||||
Tensor* out;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
|
||||
@ -507,9 +435,9 @@ class ContainerSizeOp : public OpKernel {
|
||||
auto* mutex = container->GetMutex();
|
||||
if (mutex != nullptr) {
|
||||
tf_shared_lock lock(*mutex);
|
||||
out->scalar<int64>()() = container->UnsafeSize();
|
||||
OP_REQUIRES_OK(ctx, Functor::Size(*container, &out->scalar<uint64>()()));
|
||||
} else {
|
||||
out->scalar<int64>()() = container->UnsafeSize();
|
||||
OP_REQUIRES_OK(ctx, Functor::Size(*container, &out->scalar<uint64>()()));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -1,87 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/lookup_tables/lookup_table_interface.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tables {
|
||||
|
||||
// Parent class for tables with support for multithreaded synchronization.
|
||||
template <typename HeterogeneousKeyType, typename ValueType>
|
||||
class LookupTableWithSynchronization
|
||||
: public LookupTableInterface<HeterogeneousKeyType, ValueType> {
|
||||
public:
|
||||
LookupTableWithSynchronization(bool enable_synchronization) {
|
||||
if (enable_synchronization) {
|
||||
mutex_ = absl::make_unique<mutex>();
|
||||
}
|
||||
}
|
||||
|
||||
// Mutex for synchronizing access to unsynchronized methods.
|
||||
mutex* GetMutex() const override { return mutex_.get(); }
|
||||
|
||||
private:
|
||||
// Use this for locking.
|
||||
mutable std::unique_ptr<mutex> mutex_;
|
||||
};
|
||||
|
||||
// Parent class for tables which can be constructed with arbitrary
|
||||
// lookup fallbacks.
|
||||
// Since LookupTableInterface::LookupKey assumes that all keys can be mapped
|
||||
// to values, LookupTableWithFallbackInterface allows clients to implement
|
||||
// two-stage lookups. If the first key lookup fails, clients can choose
|
||||
// to perform a fallback lookup using an externally supplied table.
|
||||
template <typename HeterogeneousKeyType, typename ValueType,
|
||||
typename FallbackTableBaseType =
|
||||
LookupTableInterface<HeterogeneousKeyType, ValueType>>
|
||||
class LookupTableWithFallbackInterface
|
||||
: public LookupTableWithSynchronization<HeterogeneousKeyType, ValueType> {
|
||||
public:
|
||||
LookupTableWithFallbackInterface(bool enable_synchronization,
|
||||
const FallbackTableBaseType* fallback_table)
|
||||
: LookupTableWithSynchronization<HeterogeneousKeyType, ValueType>(
|
||||
enable_synchronization),
|
||||
fallback_table_(fallback_table) {}
|
||||
|
||||
// Clients are required to fail when ctx is set to a not-OK status in
|
||||
// the constructor so this dereference is safe.
|
||||
const FallbackTableBaseType& fallback_table() const {
|
||||
return *fallback_table_;
|
||||
}
|
||||
|
||||
~LookupTableWithFallbackInterface() override {
|
||||
if (fallback_table_ != nullptr) {
|
||||
fallback_table_->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const FallbackTableBaseType* fallback_table_;
|
||||
};
|
||||
|
||||
} // namespace tables
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_
|
Loading…
Reference in New Issue
Block a user