From 27894d659fa07ce0badbf38129484e3f6a8db533 Mon Sep 17 00:00:00 2001
From: Alkis Evlogimenos <alkis@google.com>
Date: Fri, 13 Dec 2019 01:05:21 -0800
Subject: [PATCH] Simplify and optimize fastpath for initializable lookuptable.
 - is_initialized_ is made private to avoid wrong uses - is_initialized_ is
 set once with release semantics - is_initialized_ is read with acquire
 semantics This makes readers that check is_initialized() guarantee to see an
 initialized table. On x86 this generates the same assembly. On arm this is
 better because it does not require a full memory barrier. Also check
 is_initialized in MemoryUsed().

PiperOrigin-RevId: 285354674
Change-Id: Ic91054b93c8da345a03196c5726ca1a484b6d389
---
 tensorflow/core/kernels/initializable_lookup_table.cc |  5 +----
 tensorflow/core/kernels/initializable_lookup_table.h  | 10 ++++++++--
 tensorflow/core/kernels/lookup_table_op.h             | 10 ++++++----
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/tensorflow/core/kernels/initializable_lookup_table.cc b/tensorflow/core/kernels/initializable_lookup_table.cc
index b72874f723b..196c2fe95a3 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.cc
+++ b/tensorflow/core/kernels/initializable_lookup_table.cc
@@ -68,10 +68,7 @@ Status InitializableLookupTable::Initialize(InitTableIterator& iter) {
     return iter.status();
   }
 
-  // Prevent compiler/memory reordering of is_initialized and
-  // the initialization itself.
-  std::atomic_thread_fence(std::memory_order_release);
-  is_initialized_ = true;
+  is_initialized_.store(true, std::memory_order_release);
   return Status::OK();
 }
 
diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h
index 5ebfef01f5a..2ff537df81c 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.h
+++ b/tensorflow/core/kernels/initializable_lookup_table.h
@@ -16,6 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
 #define TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
 
+#include <atomic>
+
 #include "tensorflow/core/framework/lookup_interface.h"
 #include "tensorflow/core/platform/macros.h"
 
@@ -71,7 +73,9 @@ class InitializableLookupTable : public LookupInterface {
   TensorShape value_shape() const final { return TensorShape(); }
 
   // Returns whether the table was initialized and is ready to serve lookups.
-  bool is_initialized() const { return is_initialized_; }
+  bool is_initialized() const {
+    return is_initialized_.load(std::memory_order_acquire);
+  }
 
   // Initializes the table from the given init table iterator.
   //
@@ -156,7 +160,9 @@ class InitializableLookupTable : public LookupInterface {
   virtual Status AreEntriesSame(const InitTableIterator& iter, bool* result);
 
   mutex mu_;
-  bool is_initialized_ = false;
+
+ private:
+  std::atomic<bool> is_initialized_{false};
 };
 
 // Iterator to initialize tables given 'keys' and 'values' tensors.
diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h
index c633b668646..f2304753638 100644
--- a/tensorflow/core/kernels/lookup_table_op.h
+++ b/tensorflow/core/kernels/lookup_table_op.h
@@ -180,15 +180,14 @@ class HashTable : public InitializableLookupTable {
 
   size_t size() const override {
     // return the size of the table only if it's initialized, otherwise 0.
-    if (!is_initialized_) {
+    if (!is_initialized()) {
       return 0;
     }
-    std::atomic_thread_fence(std::memory_order_acquire);
     return table_ ? table_->size() : 0;
   }
 
   Status ExportValues(OpKernelContext* context) override {
-    if (!is_initialized_) {
+    if (!is_initialized()) {
       return errors::Aborted("HashTable is not initialized.");
     }
 
@@ -217,7 +216,7 @@ class HashTable : public InitializableLookupTable {
 
  protected:
   Status DoPrepare(size_t unused) override {
-    if (is_initialized_) {
+    if (is_initialized()) {
       return errors::Aborted("HashTable already initialized.");
     }
     if (!table_) {
@@ -266,6 +265,9 @@ class HashTable : public InitializableLookupTable {
   }
 
   int64 MemoryUsed() const override {
+    if (!is_initialized()) {
+      return 0;
+    }
     if (table_) {
       const int64 num_elements = table_->size();
       return num_elements * (sizeof(K) + sizeof(V));