diff --git a/tensorflow/core/kernels/initializable_lookup_table.cc b/tensorflow/core/kernels/initializable_lookup_table.cc index b72874f723b..196c2fe95a3 100644 --- a/tensorflow/core/kernels/initializable_lookup_table.cc +++ b/tensorflow/core/kernels/initializable_lookup_table.cc @@ -68,10 +68,7 @@ Status InitializableLookupTable::Initialize(InitTableIterator& iter) { return iter.status(); } - // Prevent compiler/memory reordering of is_initialized and - // the initialization itself. - std::atomic_thread_fence(std::memory_order_release); - is_initialized_ = true; + is_initialized_.store(true, std::memory_order_release); return Status::OK(); } diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h index 5ebfef01f5a..2ff537df81c 100644 --- a/tensorflow/core/kernels/initializable_lookup_table.h +++ b/tensorflow/core/kernels/initializable_lookup_table.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ #define TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ +#include + #include "tensorflow/core/framework/lookup_interface.h" #include "tensorflow/core/platform/macros.h" @@ -71,7 +73,9 @@ class InitializableLookupTable : public LookupInterface { TensorShape value_shape() const final { return TensorShape(); } // Returns whether the table was initialized and is ready to serve lookups. - bool is_initialized() const { return is_initialized_; } + bool is_initialized() const { + return is_initialized_.load(std::memory_order_acquire); + } // Initializes the table from the given init table iterator. // @@ -156,7 +160,9 @@ class InitializableLookupTable : public LookupInterface { virtual Status AreEntriesSame(const InitTableIterator& iter, bool* result); mutex mu_; - bool is_initialized_ = false; + + private: + std::atomic is_initialized_{false}; }; // Iterator to initialize tables given 'keys' and 'values' tensors. diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h index c633b668646..f2304753638 100644 --- a/tensorflow/core/kernels/lookup_table_op.h +++ b/tensorflow/core/kernels/lookup_table_op.h @@ -180,15 +180,14 @@ class HashTable : public InitializableLookupTable { size_t size() const override { // return the size of the table only if it's initialized, otherwise 0. - if (!is_initialized_) { + if (!is_initialized()) { return 0; } - std::atomic_thread_fence(std::memory_order_acquire); return table_ ? table_->size() : 0; } Status ExportValues(OpKernelContext* context) override { - if (!is_initialized_) { + if (!is_initialized()) { return errors::Aborted("HashTable is not initialized."); } @@ -217,7 +216,7 @@ class HashTable : public InitializableLookupTable { protected: Status DoPrepare(size_t unused) override { - if (is_initialized_) { + if (is_initialized()) { return errors::Aborted("HashTable already initialized."); } if (!table_) { @@ -266,6 +265,9 @@ class HashTable : public InitializableLookupTable { } int64 MemoryUsed() const override { + if (!is_initialized()) { + return 0; + } if (table_) { const int64 num_elements = table_->size(); return num_elements * (sizeof(K) + sizeof(V));