From e7d9786e66c13eb978c418527746ca3d1e11fe95 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 14 Feb 2019 13:49:09 -0800
Subject: [PATCH] Update Fingerprint64Map to use aliases

PiperOrigin-RevId: 234022159
---
 tensorflow/core/framework/resource_mgr.h      |  45 ++-
 tensorflow/core/kernels/lookup_tables/BUILD   |  14 +-
 .../lookup_tables/fingerprint64_map_ops.cc    | 116 +++---
 .../lookup_tables/lookup_table_interface.h    | 117 +++---
 .../kernels/lookup_tables/table_op_utils.h    | 378 +++++++-----------
 .../lookup_tables/table_resource_utils.h      |  87 ----
 6 files changed, 279 insertions(+), 478 deletions(-)
 delete mode 100644 tensorflow/core/kernels/lookup_tables/table_resource_utils.h

diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index 18a21d744b0..da547d5829f 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -132,14 +132,14 @@ class ResourceMgr {
   //
   // REQUIRES: std::is_base_of<ResourceBase, T>
   // REQUIRES: resource != nullptr
-  template <typename T>
+  template <typename T, bool use_dynamic_cast = false>
   Status Lookup(const string& container, const string& name,
                 T** resource) const TF_MUST_USE_RESULT;
 
   // Similar to Lookup, but looks up multiple resources at once, with only a
   // single lock acquisition. If containers_and_names[i] is uninitialized
   // then this function does not modify resources[i].
-  template <typename T>
+  template <typename T, bool use_dynamic_cast = false>
   Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
                         containers_and_names,
                     std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
@@ -155,7 +155,7 @@ class ResourceMgr {
   //
   // REQUIRES: std::is_base_of<ResourceBase, T>
   // REQUIRES: resource != nullptr
-  template <typename T>
+  template <typename T, bool use_dynamic_cast = false>
   Status LookupOrCreate(const string& container, const string& name,
                         T** resource,
                         std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
@@ -196,7 +196,7 @@ class ResourceMgr {
   mutable mutex mu_;
   std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);
 
-  template <typename T>
+  template <typename T, bool use_dynamic_cast = false>
   Status LookupInternal(const string& container, const string& name,
                         T** resource) const
       SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
@@ -267,7 +267,7 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
 //
 // If the lookup is successful, the caller takes the ownership of one ref on
 // `*value`, and must call its `Unref()` method when it has finished using it.
-template <typename T>
+template <typename T, bool use_dynamic_cast = false>
 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
                       T** value);
 
 // Looks up multiple resources pointed by a sequence of resource handles. If
@@ -437,15 +437,15 @@ Status ResourceMgr::Create(const string& container, const string& name,
   return DoCreate(container, MakeTypeIndex<T>(), name, resource);
 }
 
-template <typename T>
+template <typename T, bool use_dynamic_cast>
 Status ResourceMgr::Lookup(const string& container, const string& name,
                            T** resource) const {
   CheckDeriveFromResourceBase<T>();
   tf_shared_lock l(mu_);
-  return LookupInternal(container, name, resource);
+  return LookupInternal<T, use_dynamic_cast>(container, name, resource);
 }
 
-template <typename T>
+template <typename T, bool use_dynamic_cast>
 Status ResourceMgr::LookupMany(
     absl::Span<std::pair<const string*, const string*> const>
         containers_and_names,
@@ -455,8 +455,9 @@ Status ResourceMgr::LookupMany(
   resources->resize(containers_and_names.size());
   for (size_t i = 0; i < containers_and_names.size(); ++i) {
     T* resource;
-    Status s = LookupInternal(*containers_and_names[i].first,
-                              *containers_and_names[i].second, &resource);
+    Status s = LookupInternal<T, use_dynamic_cast>(
+        *containers_and_names[i].first, *containers_and_names[i].second,
+        &resource);
     if (s.ok()) {
       (*resources)[i].reset(resource);
     }
@@ -464,7 +465,18 @@ Status ResourceMgr::LookupMany(
   return Status::OK();
 }
 
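The defaulted `use_dynamic_cast` parameter threaded through the lookup methods
above is what lets one resource object be reached through several registered
type indices. A minimal sketch of the intended call pattern; `AliasInterface`
is a hypothetical name, not part of this patch:

  // Assumes an entry for AliasInterface was registered under the same
  // container/name as the concrete table (see RegisterAlias further below).
  Status LookupViaAlias(ResourceMgr* rm, AliasInterface** out) {
    // With use_dynamic_cast = true the stored ResourceBase* is resolved via
    // dynamic_cast, which is required when ResourceBase is a virtual base
    // of the concrete type.
    return rm->Lookup<AliasInterface, /*use_dynamic_cast=*/true>(
        "container", "shared_name", out);
  }
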
+// Simple wrapper to allow conditional dynamic / static casts.
+template <typename T, bool use_dynamic_cast>
+struct TypeCastFunctor {
+  static T* Cast(ResourceBase* r) { return static_cast<T*>(r); }
+};
+
 template <typename T>
+struct TypeCastFunctor<T, true> {
+  static T* Cast(ResourceBase* r) { return dynamic_cast<T*>(r); }
+};
+
+template <typename T, bool use_dynamic_cast>
 Status ResourceMgr::LookupInternal(const string& container, const string& name,
                                    T** resource) const {
   ResourceBase* found = nullptr;
@@ -472,12 +484,12 @@ Status ResourceMgr::LookupInternal(const string& container, const string& name,
   if (s.ok()) {
     // It's safe to down cast 'found' to T* since
     // typeid(T).hash_code() is part of the map key.
-    *resource = static_cast<T*>(found);
+    *resource = TypeCastFunctor<T, use_dynamic_cast>::Cast(found);
   }
   return s;
 }
 
-template <typename T>
+template <typename T, bool use_dynamic_cast>
 Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
                                    T** resource,
                                    std::function<Status(T**)> creator) {
@@ -486,11 +498,11 @@ Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
   Status s;
   {
     tf_shared_lock l(mu_);
-    s = LookupInternal(container, name, resource);
+    s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
     if (s.ok()) return s;
   }
   mutex_lock l(mu_);
-  s = LookupInternal(container, name, resource);
+  s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
   if (s.ok()) return s;
   TF_RETURN_IF_ERROR(creator(resource));
   s = DoCreate(container, MakeTypeIndex<T>(), name, *resource);
@@ -566,11 +578,12 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
   return ctx->resource_manager()->Create(p.container(), p.name(), value);
 }
 
-template <typename T>
+template <typename T, bool use_dynamic_cast>
 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
                       T** value) {
   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
-  return ctx->resource_manager()->Lookup(p.container(), p.name(), value);
+  return ctx->resource_manager()->Lookup<T, use_dynamic_cast>(p.container(),
+                                                              p.name(), value);
 }
 
 template <typename T>
diff --git a/tensorflow/core/kernels/lookup_tables/BUILD b/tensorflow/core/kernels/lookup_tables/BUILD
index 359caf64295..5cf628ef282 100644
--- a/tensorflow/core/kernels/lookup_tables/BUILD
+++ b/tensorflow/core/kernels/lookup_tables/BUILD
@@ -19,18 +19,6 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
-        "@com_google_absl//absl/types:optional",
-        "@com_google_absl//absl/types:span",
-    ],
-)
-
-cc_library(
-    name = "table_resource_utils",
-    hdrs = ["table_resource_utils.h"],
-    deps = [
-        ":lookup_table_interface",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
     ],
 )
 
@@ -57,8 +45,8 @@ tf_kernel_library(
         "fingerprint64_map_ops.cc",
     ],
     deps = [
+        ":lookup_table_interface",
         ":table_op_utils",
-        ":table_resource_utils",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
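The kernel rewritten in the next file implements a stateless map: every key is
hashed, reduced modulo `num_oov_buckets`, and shifted by `offset`. A
self-contained sketch of that arithmetic, reusing the same `Fingerprint64`
helper; the function name is illustrative only:

  #include "tensorflow/core/platform/fingerprint.h"

  // Every key lands deterministically in [offset, offset + num_oov_buckets);
  // no per-key state exists, which is why the class below needs no insert
  // path and has no meaningful size.
  int64 Fingerprint64Bucket(absl::string_view key, int64 num_oov_buckets,
                            int64 offset) {
    return static_cast<int64>(tensorflow::Fingerprint64(key) %
                              num_oov_buckets) +
           offset;
  }
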
#include "absl/strings/string_view.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/lookup_tables/lookup_table_interface.h" #include "tensorflow/core/kernels/lookup_tables/table_op_utils.h" -#include "tensorflow/core/kernels/lookup_tables/table_resource_utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/macros.h" @@ -27,73 +27,48 @@ namespace tables { // Map x -> (Fingerprint64(x) % num_oov_buckets) + offset. // num_oov_buckets and offset are node attributes provided at construction // time. -template +template class Fingerprint64Map final - : public LookupTableInterface { + : public virtual LookupInterface, + public virtual LookupWithPrefetchInterface, + absl::Span> { public: + using key_type = KeyType; + Fingerprint64Map(int64 num_oov_buckets, int64 offset) : num_oov_buckets_(num_oov_buckets), offset_(offset) {} - mutex* GetMutex() const override { return nullptr; } - - bool UnsafeInsertOrAssign(const HeterogeneousKeyType& key, - const ValueType& value) override { - return true; + Status Lookup(const KeyType& key_to_find, ValueType* value) const override { + *value = LookupHelper(key_to_find); + return Status::OK(); } - Status TableUnbatchedInsertStatus() const override { - return errors::Unimplemented("Fingerprint64Map does not support inserts."); - } - - Status BatchInsertOrAssign(absl::Span keys, - absl::Span values) override { - return errors::Unimplemented("Fingerprint64Map does not support inserts."); - } - - ValueType UnsafeLookupKey( - const HeterogeneousKeyType& key_to_find) const override { - // This can cause a downcast. - return static_cast(Fingerprint64(key_to_find) % - num_oov_buckets_) + - offset_; - } - - Status TableUnbatchedLookupStatus() const override { return Status::OK(); } - - Status BatchLookup(absl::Span keys, - absl::Span values, - int64 prefetch_lookahead) const override { + Status Lookup(absl::Span keys, absl::Span values, + int64 prefetch_lookahead) const override { if (ABSL_PREDICT_FALSE(keys.size() != values.size())) { return errors::InvalidArgument( "keys and values do not have the same number of elements (found ", keys.size(), " vs ", values.size(), ")."); } for (size_t i = 0; i < keys.size(); ++i) { - values[i] = Fingerprint64Map::UnsafeLookupKey(keys[i]); + values[i] = LookupHelper(keys[i]); } return Status::OK(); } - const absl::optional DefaultValue() const override { - return {}; - } + mutex* GetMutex() const override { return nullptr; } - void UnsafePrefetchKey( - const HeterogeneousKeyType& key_to_find) const override {} - - size_t UnsafeSize() const override { return 0; } - - Status SizeStatus() const override { - return errors::Unimplemented( - "Fingerprint64Map does not have a concept of size."); - } - - bool UnsafeContainsKey( - const HeterogeneousKeyType& key_to_find) const override { - return true; - } + string DebugString() const override { return __PRETTY_FUNCTION__; } private: + ABSL_ATTRIBUTE_ALWAYS_INLINE ValueType + LookupHelper(const KeyType& key_to_find) const { + // This can cause a downcast. 
@@ -102,9 +77,10 @@ class Fingerprint64Map final
 template <typename Fingerprint64Map>
 struct Fingerprint64MapFactory {
   struct Functor {
-    template <typename ContainerBase>
+    using resource_type = Fingerprint64Map;
+
     static Status AllocateContainer(OpKernelContext* ctx, OpKernel* kernel,
-                                    ContainerBase** container) {
+                                    Fingerprint64Map** container) {
       int64 num_oov_buckets;
       int64 offset;
       TF_RETURN_IF_ERROR(
@@ -116,24 +92,28 @@ struct Fingerprint64MapFactory {
   };
 };
 
-#define REGISTER_STRING_KERNEL(table_value_dtype)                            \
-  REGISTER_KERNEL_BUILDER(                                                   \
-      Name("Fingerprint64Map")                                               \
-          .Device(DEVICE_CPU)                                                \
-          .TypeConstraint<string>("heterogeneous_key_dtype")                 \
-          .TypeConstraint<table_value_dtype>("table_value_dtype"),           \
-      ResourceConstructionOp<                                                \
-          LookupTableInterface<absl::string_view, table_value_dtype>,        \
-          Fingerprint64MapFactory<Fingerprint64Map<                          \
-              absl::string_view, table_value_dtype>>::Functor>);             \
-  REGISTER_KERNEL_BUILDER(                                                   \
-      Name("Fingerprint64Map")                                               \
-          .Device(DEVICE_CPU)                                                \
-          .TypeConstraint<Variant>("heterogeneous_key_dtype")                \
-          .TypeConstraint<table_value_dtype>("table_value_dtype"),           \
-      ResourceConstructionOp<LookupTableInterface<Variant, table_value_dtype>, \
-          Fingerprint64MapFactory<Fingerprint64Map<                          \
-              Variant, table_value_dtype>>::Functor>);
+template <typename KeyType, typename ValueType>
+using ResourceOp = ResourceConstructionOp<
+    typename Fingerprint64MapFactory<
+        Fingerprint64Map<KeyType, ValueType>>::Functor,
+    // These are the aliases.
+    LookupInterface<ValueType*, const KeyType&>,
+    LookupWithPrefetchInterface<absl::Span<ValueType>,
+                                absl::Span<const KeyType>>>;
+
+#define REGISTER_STRING_KERNEL(ValueType)                      \
+  REGISTER_KERNEL_BUILDER(                                     \
+      Name("Fingerprint64Map")                                 \
+          .Device(DEVICE_CPU)                                  \
+          .TypeConstraint<string>("heterogeneous_key_dtype")   \
+          .TypeConstraint<ValueType>("table_value_dtype"),     \
+      ResourceOp<absl::string_view, ValueType>);               \
+  REGISTER_KERNEL_BUILDER(                                     \
+      Name("Fingerprint64Map")                                 \
+          .Device(DEVICE_CPU)                                  \
+          .TypeConstraint<Variant>("heterogeneous_key_dtype")  \
+          .TypeConstraint<ValueType>("table_value_dtype"),     \
+      ResourceOp<absl::string_view, ValueType>);
 
 REGISTER_STRING_KERNEL(int32);
 REGISTER_STRING_KERNEL(int64);
 
diff --git a/tensorflow/core/kernels/lookup_tables/lookup_table_interface.h b/tensorflow/core/kernels/lookup_tables/lookup_table_interface.h
index 0cfe44eda79..de6705d6942 100644
--- a/tensorflow/core/kernels/lookup_tables/lookup_table_interface.h
+++ b/tensorflow/core/kernels/lookup_tables/lookup_table_interface.h
@@ -16,11 +16,6 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_LOOKUP_TABLE_INTERFACE_H_
 #define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_LOOKUP_TABLE_INTERFACE_H_
 
-#include <cstddef>
-#include <string>
-
-#include "absl/types/optional.h"
-#include "absl/types/span.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -28,90 +23,74 @@ limitations under the License.
 namespace tensorflow {
 namespace tables {
 
-// Interface for key-value pair lookups with support for heterogeneous keys.
-// This class contains two main kinds of methods: methods which operate on
-// a batch of inputs and methods which do not. The latter have the prefix
-// 'Unsafe'. Clients must call the corresponding status methods to determine
-// whether they are safe to call within a code block.
-// Implementations must guarantee thread-safety when GetMutex is used to
-// synchronize method access.
-template <typename HeterogeneousKeyType, typename ValueType>
-class LookupTableInterface : public ResourceBase {
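The replacement interfaces that follow all derive virtually from
`SynchronizedInterface` (and it from `ResourceBase`), so a table implementing
several capabilities still contains exactly one `ResourceBase` subobject and
therefore one reference count. A compilable miniature of the diamond being
avoided, with stand-in names:

  struct RefCounted {};  // stand-in for ResourceBase

  struct Synchronized : public virtual RefCounted {};
  struct LookupCap : public virtual Synchronized {};
  struct PrefetchCap : public virtual Synchronized {};

  // One shared RefCounted subobject; without `virtual` above there would be
  // two, and converting Table* to RefCounted* would be ambiguous.
  struct Table : public virtual LookupCap, public virtual PrefetchCap {};
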
+// Interface for resources with mutable state.
+class SynchronizedInterface : public virtual ResourceBase {
  public:
-  using heterogeneous_key_type = HeterogeneousKeyType;
-  using value_type = ValueType;
-  using key_type = heterogeneous_key_type;
-
   // Return value should be used to synchronize read/write access to
   // all public methods. If null, no synchronization is needed.
   virtual mutex* GetMutex() const = 0;
+};
 
-  // Insert the KV pair into the underlying table. If a key equivalent to key
-  // already exists in the underlying table, its corresponding value is
-  // overridden. Returns true only if the key was inserted for the first time.
-  // Undefined if TableUnbatchedInsertStatus() != OK.
-  virtual bool UnsafeInsertOrAssign(const HeterogeneousKeyType& key,
-                                    const ValueType& value) = 0;
-
-  // Returns OK if it is safe to call InsertOrAssign.
-  // Once OK is returned, it is safe to call InsertOrAssign for the rest of the
-  // program.
-  virtual Status TableUnbatchedInsertStatus() const TF_MUST_USE_RESULT = 0;
+// Interface for containers which support insertions.
+template <typename ValueType, typename... KeyContext>
+class InsertOrAssignInterface : public virtual SynchronizedInterface {
+ public:
+  using value_type = ValueType;
 
   // Stores each KV pair {keys[i], values[i]} in the underlying map, overriding
   // pre-existing pairs which have equivalent keys.
   // keys and values should have the same size.
-  virtual Status BatchInsertOrAssign(
-      absl::Span<const HeterogeneousKeyType> keys,
-      absl::Span<const ValueType> values) = 0;
+  virtual Status InsertOrAssign(KeyContext... key_context,
+                                ValueType values) = 0;
+};
 
-  // Prefetch key_to_find into implementation defined data caches.
-  // Implementations are free to leave this a no-op.
-  // Undefined if TableUnbatchedLookupStatus() != OK.
-  virtual void UnsafePrefetchKey(
-      const HeterogeneousKeyType& key_to_find) const {}
-
-  // Returns true if and only if the table contains key_to_find.
-  // Undefined if TableUnbatchedLookupStatus() != OK.
-  virtual bool UnsafeContainsKey(
-      const HeterogeneousKeyType& key_to_find) const = 0;
-
-  // Lookup the value for key_to_find. This value must always be well-defined,
-  // even when ContainsKey(key_to_find) == false. When
-  // dv = DefaultValue() != absl::nullopt and ContainsKey(key_to_find) == false,
-  // dv is returned.
-  // Undefined if TableUnbatchedLookupStatus() != OK.
-  virtual ValueType UnsafeLookupKey(
-      const HeterogeneousKeyType& key_to_find) const = 0;
-
-  // Returns OK if it is safe to call PrefetchKey, ContainsKey, and
-  // UnsafeLookupKey.
-  // If OK is returned, it is safe to call these methods until the next
-  // non-const method of this class is called.
-  virtual Status TableUnbatchedLookupStatus() const TF_MUST_USE_RESULT = 0;
+// Interface for containers which support lookups.
+template <typename ValueType, typename... KeyContext>
+class LookupInterface : public virtual SynchronizedInterface {
+ public:
+  using value_type = ValueType;
 
   // Lookup the values for keys and store them in values.
   // prefetch_lookahead is used to prefetch the key at index
   // i + prefetch_lookahead at the ith iteration of the implemented loop.
   // keys and values must have the same size.
-  virtual Status BatchLookup(absl::Span<const HeterogeneousKeyType> keys,
-                             absl::Span<ValueType> values,
-                             int64 prefetch_lookahead) const = 0;
+  virtual Status Lookup(KeyContext... key_context, ValueType values) const = 0;
+};
 
-  // Returns the number of elements in the table.
-  // Undefined if SizeStatus() != OK.
-  virtual size_t UnsafeSize() const = 0;
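Because `KeyContext` is a parameter pack, the same template covers scalar and
batched shapes. Spelled out for the two instantiations `Fingerprint64Map`
uses (illustrative expansion of the interface above and the prefetch variant
declared just below; not additional code in this patch):

  // LookupInterface<ValueType*, const KeyType&> declares
  //   Status Lookup(const KeyType& key_context, ValueType* values) const;
  //
  // LookupWithPrefetchInterface<absl::Span<ValueType>,
  //                             absl::Span<const KeyType>> declares
  //   Status Lookup(absl::Span<const KeyType> key_context,
  //                 absl::Span<ValueType> values,
  //                 int64 prefetch_lookahead) const;
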
+// Interface for containers which support lookups with prefetching.
+template <typename ValueType, typename... KeyContext>
+class LookupWithPrefetchInterface : public virtual SynchronizedInterface {
+ public:
+  using value_type = ValueType;
 
-  // Returns OK if the return value of UnsafeSize() is always well-defined.
-  virtual Status SizeStatus() const TF_MUST_USE_RESULT = 0;
+  // Lookup the values for keys and store them in values.
+  // prefetch_lookahead is used to prefetch the key at index
+  // i + prefetch_lookahead at the ith iteration of the implemented loop.
+  // keys and values must have the same size.
+  virtual Status Lookup(KeyContext... key_context, ValueType values,
+                        int64 prefetch_lookahead) const = 0;
+};
 
-  // If non-null value is returned, LookupKey returns that value only for keys
-  // which satisfy ContainsKey(key_to_find) == false.
-  virtual const absl::optional<ValueType> DefaultValue() const = 0;
+// Interface for containers with size concepts.
+// Implementations must guarantee thread-safety when GetMutex is used to
+// synchronize method access.
+class SizeInterface : public virtual SynchronizedInterface {
+ public:
+  // Returns the number of elements in the container.
+  virtual uint64 Size() const = 0;
+};
 
-  string DebugString() const override { return "A lookup table"; }
+// Interface for tables which can be initialized from key and value arguments.
+template <typename ValueType, typename... KeyContext>
+class KeyValueTableInitializerInterface : public virtual SynchronizedInterface {
+ public:
+  using value_type = ValueType;
 
-  ~LookupTableInterface() override = default;
+  // Initializes the table from the given keys and values.
+  // keys and values must have the same size.
+  virtual Status Initialize(KeyContext... key_context, ValueType values) = 0;
 };
 
 }  // namespace tables
diff --git a/tensorflow/core/kernels/lookup_tables/table_op_utils.h b/tensorflow/core/kernels/lookup_tables/table_op_utils.h
index ad7b0db78e6..b4b27422669 100644
--- a/tensorflow/core/kernels/lookup_tables/table_op_utils.h
+++ b/tensorflow/core/kernels/lookup_tables/table_op_utils.h
@@ -44,11 +44,11 @@ limitations under the License.
 namespace tensorflow {
 namespace tables {
 
-// Create resources of type ContainerBase using the static method
+// Create resources of type ResourceType and AliasesToRegister using
 // Functor::AllocateContainer(OpKernelConstruction*, OpKernel*,
-// ContainerBase**)
-// If the resource has already been created it will be looked up.
-template <typename ContainerBase, typename Functor>
+// ResourceType**). ResourceType = Functor::resource_type.
+// No-op for resources which have already been created.
+template <typename Functor, typename... AliasesToRegister>
 class ResourceConstructionOp : public OpKernel {
  public:
   explicit ResourceConstructionOp(OpKernelConstruction* ctx)
@@ -66,46 +66,86 @@ class ResourceConstructionOp : public OpKernel {
     }
 
     auto creator = [ctx,
-                    this](ContainerBase** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-      ContainerBase* container;
-      auto status = Functor::AllocateContainer(ctx, this, &container);
+                    this](ResourceType** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      ResourceType* resource = nullptr;
+      auto status = Functor::AllocateContainer(ctx, this, &resource);
       if (ABSL_PREDICT_FALSE(!status.ok())) {
-        container->Unref();
+        // Ideally resource is non-null only if status is OK but we try
+        // to compensate here.
+        if (resource != nullptr) {
+          resource->Unref();
+        }
         return status;
       }
       if (ctx->track_allocations()) {
-        ctx->record_persistent_memory_allocation(container->MemoryUsed());
+        ctx->record_persistent_memory_allocation(resource->MemoryUsed());
       }
-      *ret = container;
+      *ret = resource;
       return Status::OK();
     };
 
-    ContainerBase* container_base = nullptr;
+    // Register the ResourceType alias.
+    ResourceType* resource = nullptr;
+    core::ScopedUnref unref_me(resource);
     OP_REQUIRES_OK(
-        ctx, cinfo_.resource_manager()->template LookupOrCreate<ContainerBase>(
-                 cinfo_.container(), cinfo_.name(), &container_base, creator));
-    core::ScopedUnref unref_me(container_base);
+        ctx,
+        cinfo_.resource_manager()->template LookupOrCreate<ResourceType>(
+            cinfo_.container(), cinfo_.name(), &resource, creator));
 
+    // Put a handle to resource in the output tensor (the other aliases will
+    // have the same handle).
     Tensor* handle;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
-    handle->scalar<ResourceHandle>()() = MakeResourceHandle<ContainerBase>(
+    handle->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
         ctx, cinfo_.container(), cinfo_.name());
     table_handle_set_ = true;
+
+    // Create other alias resources.
+    Status status;
+    char dummy[sizeof...(AliasesToRegister)] = {
+        (status.Update(RegisterAlias<AliasesToRegister>(resource)), 0)...};
+    (void)dummy;
+    OP_REQUIRES_OK(ctx, status);
   }
 
   ~ResourceConstructionOp() override {
     // If the table object was not shared, delete it.
     if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
       if (!cinfo_.resource_manager()
-               ->template Delete<ContainerBase>(cinfo_.container(),
-                                                cinfo_.name())
+               ->template Delete<ResourceType>(cinfo_.container(),
+                                               cinfo_.name())
               .ok()) {
         // Do nothing; the resource may have been deleted by session resets.
       }
+      // Attempt to delete other resource aliases.
+      Status dummy_status;
+      char dummy[sizeof...(AliasesToRegister)] = {
+          (dummy_status.Update(DeleteAlias<AliasesToRegister>()), 0)...};
+      (void)dummy;
     }
   }
 
  private:
+  using ResourceType = typename Functor::resource_type;
+
+  template <typename T>
+  Status RegisterAlias(ResourceType* resource) {
+    auto creator = [resource](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      *ret = resource;
+      return Status::OK();
+    };
+
+    T* alias_resource = nullptr;
+    core::ScopedUnref unref_me(alias_resource);
+    return cinfo_.resource_manager()->template LookupOrCreate<T>(
+        cinfo_.container(), cinfo_.name(), &alias_resource, creator);
+  }
+
+  template <typename T>
+  Status DeleteAlias() {
+    return cinfo_.resource_manager()->template Delete<T>(cinfo_.container(),
+                                                         cinfo_.name());
+  }
+
   mutex mu_;
   bool table_handle_set_ GUARDED_BY(mu_);
   ContainerInfo cinfo_;
@@ -120,8 +160,7 @@ class ResourceConstructionOp : public OpKernel {
 // If the resource has already been created it will be looked up.
 // Container must decrease the reference count of the FallbackTableBaseType*
 // constructor argument before its destructor completes.
-template <typename ContainerBase, typename Functor,
-          typename FallbackTableBaseType>
+template <typename Functor, typename... AliasesToRegister>
 class TableWithFallbackConstructionOp : public OpKernel {
  public:
   explicit TableWithFallbackConstructionOp(OpKernelConstruction* ctx)
@@ -140,13 +179,14 @@ class TableWithFallbackConstructionOp : public OpKernel {
       return;
     }
 
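`RegisterAlias` above is applied to every member of `AliasesToRegister` with
the pre-C++17 comma-operator pack expansion
(`char dummy[...] = {(status.Update(...), 0)...};`). A standalone sketch of
that idiom; the names are illustrative:

  #include <iostream>
  #include <typeinfo>

  template <typename T>
  int Visit() {
    std::cout << typeid(T).name() << "\n";
    return 0;
  }

  template <typename... Ts>
  void VisitAll() {
    // Braced initializers are evaluated left to right, so Visit<T>() runs
    // once per pack element, in order. This requires a non-empty pack,
    // since a zero-length array is ill-formed; the ops above are expected
    // to be instantiated with at least one alias.
    char dummy[sizeof...(Ts)] = {(Visit<Ts>(), 0)...};
    (void)dummy;
  }
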
+    // Look up the fallback table.
     FallbackTableBaseType* fallback_table = nullptr;
     {
       const Tensor& table_handle = ctx->input(table_int64_args.size());
       ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
       OP_REQUIRES_OK(
-          ctx, ctx->resource_manager()->Lookup(handle.container(),
-                                               handle.name(), &fallback_table));
+          ctx, ctx->resource_manager()->Lookup<FallbackTableBaseType, true>(
+                   handle.container(), handle.name(), &fallback_table));
     }
 
     mutex_lock l(mu_);
@@ -156,51 +196,93 @@ class TableWithFallbackConstructionOp : public OpKernel {
     }
 
     auto creator = [ctx, this, fallback_table](
-                       ContainerBase** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                       ResourceType** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       // container construction logic can't be merged with
       // ResourceConstructionOp because Container constructor requires an
       // input which can only be constructed if the resource manager
       // internal lock is not already held.
-      ContainerBase* container;
+      ResourceType* resource = nullptr;
       auto status =
-          Functor::AllocateContainer(ctx, this, fallback_table, &container);
+          Functor::AllocateContainer(ctx, this, fallback_table, &resource);
       if (ABSL_PREDICT_FALSE(!status.ok())) {
-        container->Unref();
+        // Ideally resource is non-null only if status is OK but we try
+        // to compensate here.
+        if (resource != nullptr) {
+          resource->Unref();
+        }
         return status;
       }
       if (ctx->track_allocations()) {
-        ctx->record_persistent_memory_allocation(container->MemoryUsed());
+        ctx->record_persistent_memory_allocation(resource->MemoryUsed());
       }
-      *ret = container;
+      *ret = resource;
       return Status::OK();
     };
 
-    ContainerBase* table = nullptr;
-    OP_REQUIRES_OK(
-        ctx, cinfo_.resource_manager()->template LookupOrCreate<ContainerBase>(
-                 cinfo_.container(), cinfo_.name(), &table, creator));
+    // Register the ResourceType alias.
+    ResourceType* table = nullptr;
     core::ScopedUnref unref_me(table);
+    OP_REQUIRES_OK(
+        ctx,
+        cinfo_.resource_manager()->template LookupOrCreate<ResourceType>(
+            cinfo_.container(), cinfo_.name(), &table, creator));
 
+    // Put a handle to resource in the output tensor (the other aliases will
+    // have the same handle).
     Tensor* handle;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
-    handle->scalar<ResourceHandle>()() = MakeResourceHandle<ContainerBase>(
+    handle->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
         ctx, cinfo_.container(), cinfo_.name());
     table_handle_set_ = true;
+
+    // Create other alias resources.
+    Status status;
+    char dummy[sizeof...(AliasesToRegister)] = {
+        (status.Update(RegisterAlias<AliasesToRegister>(table)), 0)...};
+    (void)dummy;
+    OP_REQUIRES_OK(ctx, status);
   }
 
   ~TableWithFallbackConstructionOp() override {
     // If the table object was not shared, delete it.
     if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
       if (!cinfo_.resource_manager()
-               ->template Delete<ContainerBase>(cinfo_.container(),
-                                                cinfo_.name())
+               ->template Delete<ResourceType>(cinfo_.container(),
+                                               cinfo_.name())
              .ok()) {
         // Do nothing; the resource may have been deleted by session resets.
       }
+      // Attempt to delete other resource aliases.
+      Status dummy_status;
+      char dummy[sizeof...(AliasesToRegister)] = {
+          (dummy_status.Update(DeleteAlias<AliasesToRegister>()), 0)...};
+      (void)dummy;
     }
   }
 
  private:
+  using ResourceType = typename Functor::resource_type;
+  using FallbackTableBaseType = typename Functor::fallback_table_type;
+
+  template <typename T>
+  Status RegisterAlias(ResourceType* resource) {
+    auto creator = [resource](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      *ret = resource;
+      return Status::OK();
+    };
+
+    T* alias_resource = nullptr;
+    core::ScopedUnref unref_me(alias_resource);
+    return cinfo_.resource_manager()->template LookupOrCreate<T>(
+        cinfo_.container(), cinfo_.name(), &alias_resource, creator);
+  }
+
+  template <typename T>
+  Status DeleteAlias() {
+    return cinfo_.resource_manager()->template Delete<T>(cinfo_.container(),
+                                                         cinfo_.name());
+  }
+
   mutex mu_;
   bool table_handle_set_ GUARDED_BY(mu_);
   ContainerInfo cinfo_;
@@ -209,33 +291,29 @@ class TableWithFallbackConstructionOp : public OpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(TableWithFallbackConstructionOp);
 };
 
-// Used to insert tensors into a container.
-template <typename Container, typename InsertKeyTensorType,
-          typename InsertValueTensorType>
-class HeterogeneousLookupTableInsertOrAssignOp : public OpKernel {
+// Lookup a table of type ResourceAlias and insert the passed in keys and
+// values tensors using Functor::TensorInsert(keys, values, table).
+template <typename Functor, typename ResourceAlias>
+class LookupTableInsertOp : public OpKernel {
  public:
-  explicit HeterogeneousLookupTableInsertOrAssignOp(OpKernelConstruction* ctx)
-      : OpKernel(ctx) {}
+  explicit LookupTableInsertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
     OpInputList table_int64_args;
     OP_REQUIRES_OK(ctx, ctx->input_list("table_int64_args", &table_int64_args));
     const size_t tensor_index_offset = table_int64_args.size();
+    // Business logic for checking tensor shapes, etc, is delegated to the
+    // Functor.
     const Tensor& keys = ctx->input(tensor_index_offset + 1);
     const Tensor& values = ctx->input(tensor_index_offset + 2);
-    if (ABSL_PREDICT_FALSE(keys.NumElements() != values.NumElements())) {
-      ctx->SetStatus(errors::InvalidArgument(
-          "keys and values do not have the same number of elements: ",
-          keys.NumElements(), " vs ", values.NumElements()));
-      return;
-    }
 
     const Tensor& table_handle = ctx->input(tensor_index_offset);
     ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
-    Container* table;
-    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(handle.container(),
-                                                        handle.name(), &table));
+    ResourceAlias* table;
     core::ScopedUnref unref_me(table);
+    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
+                            handle.container(), handle.name(), &table));
 
     int memory_used_before = 0;
     if (ctx->track_allocations()) {
@@ -244,9 +322,9 @@ class HeterogeneousLookupTableInsertOrAssignOp : public OpKernel {
     auto* mutex = table->GetMutex();
     if (mutex != nullptr) {
       mutex_lock lock(*mutex);
-      OP_REQUIRES_OK(ctx, TensorInsert(keys, values, table));
+      OP_REQUIRES_OK(ctx, Functor::TensorInsert(keys, values, table));
     } else {
-      OP_REQUIRES_OK(ctx, TensorInsert(keys, values, table));
+      OP_REQUIRES_OK(ctx, Functor::TensorInsert(keys, values, table));
     }
     if (ctx->track_allocations()) {
       ctx->record_persistent_memory_allocation(table->MemoryUsed() -
@@ -255,74 +333,17 @@ class HeterogeneousLookupTableInsertOrAssignOp : public OpKernel {
   }
 
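With validation moved out of the op body, each table implementation supplies
a functor that owns both the shape checks and the tensor-to-span plumbing. A
hedged sketch of what a `Functor` for `LookupTableInsertOp` might look like;
`StringToInt64Table` is a hypothetical table type assumed to implement
`InsertOrAssignInterface<absl::Span<const int64>, absl::Span<const string>>`:

  class StringToInt64Table;  // hypothetical concrete table

  struct StringToInt64InsertFunctor {
    using resource_type = StringToInt64Table;

    template <typename TableType>
    static Status TensorInsert(const Tensor& keys, const Tensor& values,
                               TableType* table) {
      // The element-count check that used to live in the op body.
      if (ABSL_PREDICT_FALSE(keys.NumElements() != values.NumElements())) {
        return errors::InvalidArgument(
            "keys and values do not have the same number of elements: ",
            keys.NumElements(), " vs ", values.NumElements());
      }
      const auto keys_flat = keys.flat<string>();
      const auto values_flat = values.flat<int64>();
      return table->InsertOrAssign(
          absl::MakeConstSpan(keys_flat.data(), keys_flat.size()),
          absl::MakeConstSpan(values_flat.data(), values_flat.size()));
    }
  };
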
  private:
-  // Non-variant InsertKeyTensorType which is the same as Container::key_type.
-  // No need to static_cast.
-  template <typename SfinaeArg = InsertKeyTensorType>
-  absl::enable_if_t<
-      IsValidDataType<SfinaeArg>::value &&
-          std::is_same<SfinaeArg, typename Container::key_type>::value,
-      Status>
-  TensorInsert(const Tensor& keys, const Tensor& values,
-               Container* table) const {
-    const auto keys_flat = keys.flat<SfinaeArg>();
-    const auto values_flat = values.flat<InsertValueTensorType>();
-    return table->BatchInsertOrAssign(
-        absl::MakeSpan(keys_flat.data(), keys_flat.size()),
-        absl::MakeSpan(values_flat.data(), values_flat.size()));
-  }
-
-  // Non-variant InsertKeyTensorType which is otherwise convertible to
-  // Container::key_type.
-  template <typename SfinaeArg = InsertKeyTensorType>
-  absl::enable_if_t<
-      IsValidDataType<SfinaeArg>::value &&
-          !std::is_same<SfinaeArg, typename Container::key_type>::value &&
-          std::is_convertible<SfinaeArg, typename Container::key_type>::value,
-      Status>
-  TensorInsert(const Tensor& keys, const Tensor& values,
-               Container* table) const {
-    const auto keys_flat = keys.flat<SfinaeArg>();
-    std::vector<typename Container::key_type> keys_vec;
-    const auto keys_size = keys_flat.size();
-    keys_vec.reserve(keys_size);
-    for (size_t i = 0; i < keys_size; ++i) {
-      keys_vec.push_back(
-          static_cast<typename Container::key_type>(keys_flat(i)));
-    }
-    const auto values_flat = values.flat<InsertValueTensorType>();
-    return table->BatchInsertOrAssign(
-        keys_vec, absl::MakeSpan(values_flat.data(), values_flat.size()));
-  }
-
-  // Variant InsertKeyTensorType; the wrapped type is convertible to
-  // Container::key_type.
-  template <typename SfinaeArg = InsertKeyTensorType>
-  absl::enable_if_t<
-      !IsValidDataType<SfinaeArg>::value &&
-          std::is_convertible<SfinaeArg, typename Container::key_type>::value,
-      Status>
-  TensorInsert(const Tensor& keys, const Tensor& values,
-               Container* table) const {
-    const auto keys_flat = keys.flat<Variant>();
-    std::vector<typename Container::key_type> keys_vec;
-    keys_vec.reserve(keys_flat.size());
-    for (size_t i = 0; i < keys_flat.size(); ++i) {
-      keys_vec.emplace_back(
-          *keys_flat(i).get<SfinaeArg>());
-    }
-    const auto values_flat = values.flat<InsertValueTensorType>();
-    return table->BatchInsertOrAssign(
-        keys_vec, absl::MakeSpan(values_flat.data(), values_flat.size()));
-  }
+  TF_DISALLOW_COPY_AND_ASSIGN(LookupTableInsertOp);
 };
 
-// Used for tensor lookups.
-template <typename Container, typename LookupKeyTensorType,
-          typename ValueTensorType>
-class HeterogeneousLookupTableFindOp : public OpKernel {
+// Lookup a table of type ResourceAlias and look up the passed in keys using
+// Functor::TensorLookup(
+//     table, keys, prefetch_lookahead, num_keys_per_thread, threadpool, out).
+template <typename Functor, typename ResourceAlias>
+class LookupTableFindOp : public OpKernel {
  public:
-  explicit HeterogeneousLookupTableFindOp(OpKernelConstruction* ctx)
-      : OpKernel(ctx) {}
+  explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
     OpInputList table_int64_args;
@@ -370,10 +391,10 @@ class HeterogeneousLookupTableFindOp : public OpKernel {
 
     const Tensor& table_handle = ctx->input(tensor_index_offset);
     ResourceHandle handle(table_handle.scalar<ResourceHandle>()());
-    Container* table;
-    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(handle.container(),
-                                                        handle.name(), &table));
+    ResourceAlias* table;
     core::ScopedUnref unref_me(table);
+    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
+                            handle.container(), handle.name(), &table));
 
     auto* mutex = table->GetMutex();
     auto* threadpool = ctx->device()->tensorflow_cpu_worker_threads()->workers;
@@ -382,112 +403,20 @@ class HeterogeneousLookupTableFindOp : public OpKernel {
       // writer lock here.
       mutex_lock lock(*mutex);
       OP_REQUIRES_OK(
-          ctx, TensorLookup(*table, prefetch_lookahead, num_keys_per_thread,
-                            keys, out, threadpool));
+          ctx, Functor::TensorLookup(*table, keys, prefetch_lookahead,
+                                     num_keys_per_thread, threadpool, out));
     } else {
       OP_REQUIRES_OK(
-          ctx, TensorLookup(*table, prefetch_lookahead, num_keys_per_thread,
-                            keys, out, threadpool));
+          ctx, Functor::TensorLookup(*table, keys, prefetch_lookahead,
+                                     num_keys_per_thread, threadpool, out));
     }
   }
-
- private:
-  // keys and *values arguments to TensorLookup must have the same number of
-  // elements. This is guaranteed above.
-
-  // 'Simple' types below are types which are not natively supported in TF.
-  // Simple LookupKeyTensorType which is the same as Container::key_type.
-  template <typename SfinaeArg = LookupKeyTensorType>
-  absl::enable_if_t<
-      IsValidDataType<SfinaeArg>::value &&
-          std::is_same<SfinaeArg, typename Container::key_type>::value,
-      Status>
-  TensorLookup(Container& table, int64 prefetch_lookahead,
-               int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
-               thread::ThreadPool* threadpool) const {
-    const auto keys_flat = keys.flat<SfinaeArg>();
-    const auto keys_size = keys_flat.size();
-    auto key_span = absl::MakeSpan(keys_flat.data(), keys_size);
-    auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
-                                     values->NumElements());
-    return MultithreadedTensorLookup(table, prefetch_lookahead,
-                                     num_keys_per_thread, key_span, value_span,
-                                     threadpool);
-  }
-
-  // Try to implicitly convert all other simple LookupKeyTensorTypes to
-  // Container::key_type.
-  template <typename SfinaeArg = LookupKeyTensorType>
-  absl::enable_if_t<
-      IsValidDataType<SfinaeArg>::value &&
-          !std::is_same<SfinaeArg, typename Container::key_type>::value,
-      Status>
-  TensorLookup(Container& table, int64 prefetch_lookahead,
-               int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
-               thread::ThreadPool* threadpool) const {
-    const auto keys_flat = keys.flat<SfinaeArg>();
-    std::vector<typename Container::key_type> keys_vec;
-    const auto keys_size = keys_flat.size();
-    keys_vec.reserve(keys_size);
-    for (size_t i = 0; i < keys_size; ++i) {
-      keys_vec.emplace_back(keys_flat(i));
-    }
-    absl::Span<typename Container::key_type> key_span(keys_vec);
-    auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
-                                     values->NumElements());
-    return MultithreadedTensorLookup(table, prefetch_lookahead,
-                                     num_keys_per_thread, key_span, value_span,
-                                     threadpool);
-  }
-
-  // Non-simple LookupKeyTensorType. We'll try an implicit conversion to
-  // Container::key_type.
-  template <typename SfinaeArg = LookupKeyTensorType>
-  absl::enable_if_t<!IsValidDataType<SfinaeArg>::value, Status>
-  TensorLookup(Container& table, int64 prefetch_lookahead,
-               int64 num_keys_per_thread, const Tensor& keys, Tensor* values,
-               thread::ThreadPool* threadpool) const {
-    const auto keys_flat = keys.flat<Variant>();
-    std::vector<typename Container::key_type> keys_vec;
-    const auto keys_size = keys_flat.size();
-    keys_vec.reserve(keys_size);
-    for (size_t i = 0; i < keys_size; ++i) {
-      keys_vec.emplace_back(
-          *keys_flat(i).get<SfinaeArg>());
-    }
-    absl::Span<typename Container::key_type> key_span(keys_vec);
-    auto value_span = absl::MakeSpan(values->flat<ValueTensorType>().data(),
-                                     values->NumElements());
-    return MultithreadedTensorLookup(table, prefetch_lookahead,
-                                     num_keys_per_thread, key_span, value_span,
-                                     threadpool);
-  }
-
-  // Wrapper around table.BatchLookup which permits sharding across cores.
-  template <typename KeyType, typename ValueType>
-  Status MultithreadedTensorLookup(Container& table, int64 prefetch_lookahead,
-                                   int64 num_keys_per_thread,
-                                   absl::Span<KeyType> keys,
-                                   absl::Span<ValueType> values,
-                                   thread::ThreadPool* threadpool) const {
-    mutex temp_mutex;  // Protect status.
-    Status status;
-    auto lookup_keys = [&, this](int64 begin, int64 end) {
-      auto temp_status = table.BatchLookup(keys.subspan(begin, end - begin),
-                                           values.subspan(begin, end - begin),
-                                           prefetch_lookahead);
-      if (ABSL_PREDICT_FALSE(!temp_status.ok())) {
-        mutex_lock lock(temp_mutex);
-        status.Update(temp_status);
-      }
-    };
-    threadpool->TransformRangeConcurrently(num_keys_per_thread /* block_size */,
-                                           keys.size(), lookup_keys);
-    return status;
-  }
 };
 
-// Op that returns the size of a container.
-template <typename Container>
+// Lookup a container of type ResourceAlias and return its size using
+// Functor::Size(container, &size).
+template <typename Functor, typename ResourceAlias>
 class ContainerSizeOp : public OpKernel {
  public:
   explicit ContainerSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
     const Tensor& container_handle = ctx->input(0);
     ResourceHandle handle(container_handle.scalar<ResourceHandle>()());
-    Container* container;
-    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(
-                            handle.container(), handle.name(), &container));
+    ResourceAlias* container;
     core::ScopedUnref unref_me(container);
-    OP_REQUIRES_OK(ctx, container->SizeStatus());
+    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup<ResourceAlias, true>(
+                            handle.container(), handle.name(), &container));
 
     Tensor* out;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
@@ -507,9 +435,9 @@ class ContainerSizeOp : public OpKernel {
     auto* mutex = container->GetMutex();
     if (mutex != nullptr) {
       tf_shared_lock lock(*mutex);
-      out->scalar<uint64>()() = container->UnsafeSize();
+      OP_REQUIRES_OK(ctx, Functor::Size(*container, &out->scalar<uint64>()()));
     } else {
-      out->scalar<uint64>()() = container->UnsafeSize();
+      OP_REQUIRES_OK(ctx, Functor::Size(*container, &out->scalar<uint64>()()));
     }
   }
 };
diff --git a/tensorflow/core/kernels/lookup_tables/table_resource_utils.h b/tensorflow/core/kernels/lookup_tables/table_resource_utils.h
deleted file mode 100644
index 742086cb214..00000000000
--- a/tensorflow/core/kernels/lookup_tables/table_resource_utils.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_
-#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_
-
-#include <memory>
-
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/kernels/lookup_tables/lookup_table_interface.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/mutex.h"
-
-namespace tensorflow {
-namespace tables {
-
-// Parent class for tables with support for multithreaded synchronization.
-template <typename HeterogeneousKeyType, typename ValueType>
-class LookupTableWithSynchronization
-    : public LookupTableInterface<HeterogeneousKeyType, ValueType> {
- public:
-  LookupTableWithSynchronization(bool enable_synchronization) {
-    if (enable_synchronization) {
-      mutex_ = absl::make_unique<mutex>();
-    }
-  }
-
-  // Mutex for synchronizing access to unsynchronized methods.
-  mutex* GetMutex() const override { return mutex_.get(); }
-
- private:
-  // Use this for locking.
-  mutable std::unique_ptr<mutex> mutex_;
-};
-
-// Parent class for tables which can be constructed with arbitrary
-// lookup fallbacks.
-// Since LookupTableInterface::LookupKey assumes that all keys can be mapped
-// to values, LookupTableWithFallbackInterface allows clients to implement
-// two-stage lookups. If the first key lookup fails, clients can choose
-// to perform a fallback lookup using an externally supplied table.
-template <typename HeterogeneousKeyType, typename ValueType,
-          typename FallbackTableBaseType =
-              LookupTableInterface<HeterogeneousKeyType, ValueType>>
-class LookupTableWithFallbackInterface
-    : public LookupTableWithSynchronization<HeterogeneousKeyType, ValueType> {
- public:
-  LookupTableWithFallbackInterface(bool enable_synchronization,
-                                   const FallbackTableBaseType* fallback_table)
-      : LookupTableWithSynchronization<HeterogeneousKeyType, ValueType>(
-            enable_synchronization),
-        fallback_table_(fallback_table) {}
-
-  // Clients are required to fail when ctx is set to a not-OK status in
-  // the constructor so this dereference is safe.
-  const FallbackTableBaseType& fallback_table() const {
-    return *fallback_table_;
-  }
-
-  ~LookupTableWithFallbackInterface() override {
-    if (fallback_table_ != nullptr) {
-      fallback_table_->Unref();
-    }
-  }
-
- private:
-  const FallbackTableBaseType* fallback_table_;
-};
-
-}  // namespace tables
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_TABLE_RESOURCE_UTILS_H_
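Taken together, the patch makes the construction ops register one concrete
table plus one ResourceMgr entry per alias, all sharing a single handle, and
makes the consumer kernels recover the table through whichever alias they
name. A condensed sketch of a consumer-side lookup under those assumptions
(names hypothetical, not part of this patch):

  using FindAlias = LookupWithPrefetchInterface<absl::Span<int64>,
                                                absl::Span<const string>>;

  Status FindThroughAlias(OpKernelContext* ctx, const ResourceHandle& handle,
                          absl::Span<const string> keys,
                          absl::Span<int64> values) {
    FindAlias* table = nullptr;
    // use_dynamic_cast = true is what makes the alias entry usable: the
    // object stored under the alias type index is the concrete table,
    // reachable only through a virtual base.
    TF_RETURN_IF_ERROR((ctx->resource_manager()->Lookup<FindAlias, true>(
        handle.container(), handle.name(), &table)));
    core::ScopedUnref unref_me(table);

    auto* mu = table->GetMutex();
    if (mu != nullptr) {
      mutex_lock lock(*mu);
      return table->Lookup(keys, values, /*prefetch_lookahead=*/0);
    }
    return table->Lookup(keys, values, /*prefetch_lookahead=*/0);
  }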