From ea97139d4d00fd71e0fcb52504ed7cca3c445555 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 27 May 2020 18:57:52 -0700
Subject: [PATCH] Read from a sharded checkpoint in parallel with multiple
 threads.

PiperOrigin-RevId: 313506462
Change-Id: I1cef16cdaa9e03fd3161727614007429221089e0
---
 .../core/util/tensor_bundle/tensor_bundle.cc  | 148 +++++++-----------
 .../core/util/tensor_bundle/tensor_bundle.h   |   3 -
 2 files changed, 59 insertions(+), 92 deletions(-)

diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index ad9ee2a7c0f..e1234d330fc 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -34,7 +34,6 @@ limitations under the License.
 #include "tensorflow/core/lib/bfloat16/bfloat16.h"
 #include "tensorflow/core/lib/core/coding.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/hash/crc32c.h"
 #include "tensorflow/core/lib/io/path.h"
@@ -42,8 +41,6 @@ limitations under the License.
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/blocking_counter.h"
-#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/saved_tensor_slice_util.h"
 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
@@ -1024,106 +1021,79 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
         " to restore in slice_spec: ", slice_spec.DebugString());
   }
 
-  BlockingCounter counter(static_cast<int>(details.size()));
-  auto runner = [this, &details](std::function<void()> fn) {
-    if (details.size() > 1) {
-      // If there are multiple slices to read, perform the read in parallel
-      // using multiple threads.
-      env_->SchedClosure(fn);
-    } else {
-      fn();
-    }
-  };
-
+  // The union of the slices in "details" covers "slice_spec".  Performs the
+  // copies from each.
+  BundleEntryProto stored_slice_entry = full_tensor_entry;
   for (const auto& slice_tag_pair : details) {
-    runner([this, &slice_spec, &full_shape, &slice_tag_pair, &full_tensor_entry,
-            &full_tensor_key_string, &counter, val]() {
-      // The union of the slices in "details" covers "slice_spec".  Performs the
-      // copies from each.
-      BundleEntryProto stored_slice_entry = full_tensor_entry;
-      // Seeks for the stored slice.
-      const TensorSlice& stored_slice = slice_tag_pair.first;
+    // Seeks for the stored slice.
+    const TensorSlice& stored_slice = slice_tag_pair.first;
 
-      // We already have the entry for the full tensor, so don't query again if
-      // the slice is full.
-      if (!stored_slice.IsFull()) {
-        const string encoded_stored_slice_name =
-            checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
-                                              stored_slice);
-        mutex_lock l(mu_);
-        // `GetBundleEntryProto` will access `iter_`, so protecting it with a
-        // mutex lock.
-        status_ =
-            GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
-        if (!status_.ok()) return;
-      }
+    // We already have the entry for the full tensor, so don't query again if
+    // the slice is full.
+    if (!stored_slice.IsFull()) {
+      const string encoded_stored_slice_name =
+          checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
+                                            stored_slice);
+      status_ =
+          GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
+      if (!status_.ok()) return status_;
+    }
 
-      auto cleanup = gtl::MakeCleanup([&counter] { counter.DecrementCount(); });
+    // TODO(zongheng): should we take an OpKernelContext, so that we can call
+    // allocate_temp()?  Note that without major refactorings to Saver, it's
+    // hard for the caller of the tensor bundle module to allocate these
+    // precisely-shaped scratch storage.
 
-      // TODO(zongheng): should we take an OpKernelContext, so that we can
-      // call allocate_temp()?  Note that without major refactorings to
-      // Saver, it's hard for the caller of the tensor bundle module to
-      // allocate these precisely-shaped scratch storage.
+    // Optimization for the common case: the stored slice can be directly
+    // copied to the destination without additional slicing. This is true when
+    // either the slices are equal or when they are both full slices having the
+    // same shape.
+    TensorShape stored_slice_shape(stored_slice_entry.shape());
+    if (stored_slice == slice_spec ||
+        (stored_slice_shape == val->shape() &&
+         IsFullSlice(stored_slice, stored_slice_shape) &&
+         IsFullSlice(slice_spec, stored_slice_shape))) {
+      VLOG(1) << "Optimized for common case: directly copying into "
+                 "pre-allocated buffer; spec: "
+              << slice_spec.DebugString();
+      status_ = GetValue(stored_slice_entry, val);
+      return status_;
+    }
 
-      // Optimization for the common case: the stored slice can be directly
-      // copied to the destination without additional slicing. This is true
-      // when either the slices are equal or when they are both full slices
-      // having the same shape.
-      TensorShape stored_slice_shape(stored_slice_entry.shape());
-      if (stored_slice == slice_spec ||
-          (stored_slice_shape == val->shape() &&
-           IsFullSlice(stored_slice, stored_slice_shape) &&
-           IsFullSlice(slice_spec, stored_slice_shape))) {
-        VLOG(1) << "Optimized for common case: directly copying into "
-                   "pre-allocated buffer; spec: "
-                << slice_spec.DebugString();
-        status_ = GetValue(stored_slice_entry, val);
-        return;
-      }
+    Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
+    status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
+    if (!status_.ok()) return status_;
 
-      Tensor stored_slice_tensor(stored_slice_entry.dtype(),
-                                 stored_slice_shape);
-      status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
-      if (!status_.ok()) return;
-
-      // Copies the intersection over.
-      mutex_lock l(mu_);
-      // `CopyDataFromTensorSliceToTensorSlice` will write to `val`, so
-      // protecting it with a mutex lock.
-      const DataType common_dtype = full_tensor_entry.dtype();
-      switch (common_dtype) {
+    // Copies the intersection over.
+    const DataType common_dtype = full_tensor_entry.dtype();
+    switch (common_dtype) {
 #define HANDLE_COPY(T)                                                 \
   case DataTypeToEnum<T>::value:                                       \
     CHECK(CopyDataFromTensorSliceToTensorSlice(                        \
         full_shape, stored_slice, slice_spec,                          \
         stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
     break;
-        HANDLE_COPY(float)
-        HANDLE_COPY(double)
-        HANDLE_COPY(int32)
-        HANDLE_COPY(uint8)
-        HANDLE_COPY(int16)
-        HANDLE_COPY(int8)
-        HANDLE_COPY(complex64)
-        HANDLE_COPY(complex128)
-        HANDLE_COPY(int64)
-        HANDLE_COPY(bool)
-        HANDLE_COPY(qint32)
-        HANDLE_COPY(quint8)
-        HANDLE_COPY(qint8)
-        HANDLE_COPY(bfloat16)
-        default:
-          status_ = errors::InvalidArgument(
-              "Dtype ", DataTypeString(common_dtype), " not supported.");
-          if (!status_.ok()) return;
-      }
+
+      HANDLE_COPY(float)
+      HANDLE_COPY(double)
+      HANDLE_COPY(int32)
+      HANDLE_COPY(uint8)
+      HANDLE_COPY(int16)
+      HANDLE_COPY(int8)
+      HANDLE_COPY(complex64)
+      HANDLE_COPY(complex128)
+      HANDLE_COPY(int64)
+      HANDLE_COPY(bool)
+      HANDLE_COPY(qint32)
+      HANDLE_COPY(quint8)
+      HANDLE_COPY(qint8)
+      HANDLE_COPY(bfloat16)
+      default:
+        return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
+                                       " not supported.");
+    }
 #undef HANDLE_COPY
-    });
   }
-
-  counter.Wait();
-  TF_RETURN_IF_ERROR(status_);
-
   return Status::OK();
 }
 
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h
index 24a9c488cbb..c441000e47d 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h
@@ -306,9 +306,6 @@ class BundleReader {
   // differs from that of the current system's processor architecture.
   bool need_to_swap_bytes_;
 
-  // Protect internal states when accessing from multiple threads.
-  mutable mutex mu_;
-
   friend class TensorBundleAlignmentTest;  // For testing data alignment.
 
   TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);