From ea97139d4d00fd71e0fcb52504ed7cca3c445555 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Wed, 27 May 2020 18:57:52 -0700 Subject: [PATCH] Read from a sharded checkpoint in parallel with multiple threads. PiperOrigin-RevId: 313506462 Change-Id: I1cef16cdaa9e03fd3161727614007429221089e0 --- .../core/util/tensor_bundle/tensor_bundle.cc | 148 +++++++----------- .../core/util/tensor_bundle/tensor_bundle.h | 3 - 2 files changed, 59 insertions(+), 92 deletions(-) diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index ad9ee2a7c0f..e1234d330fc 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/crc32c.h" #include "tensorflow/core/lib/io/path.h" @@ -42,8 +41,6 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/blocking_counter.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/saved_tensor_slice_util.h" #include "tensorflow/core/util/tensor_bundle/byte_swap.h" @@ -1024,106 +1021,79 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key, " to restore in slice_spec: ", slice_spec.DebugString()); } - BlockingCounter counter(static_cast<int>(details.size())); - auto runner = [this, &details](std::function<void()> fn) { - if (details.size() > 1) { - // If there are multiple slices to read, perform the read in parallel - // using multiple threads. - env_->SchedClosure(fn); - } else { - fn(); - } - }; - + // The union of the slices in "details" covers "slice_spec". Performs the + // copies from each. + BundleEntryProto stored_slice_entry = full_tensor_entry; for (const auto& slice_tag_pair : details) { - runner([this, &slice_spec, &full_shape, &slice_tag_pair, &full_tensor_entry, - &full_tensor_key_string, &counter, val]() { - // The union of the slices in "details" covers "slice_spec". Performs the - // copies from each. - BundleEntryProto stored_slice_entry = full_tensor_entry; - // Seeks for the stored slice. - const TensorSlice& stored_slice = slice_tag_pair.first; + // Seeks for the stored slice. + const TensorSlice& stored_slice = slice_tag_pair.first; - // We already have the entry for the full tensor, so don't query again if - // the slice is full. - if (!stored_slice.IsFull()) { - const string encoded_stored_slice_name = - checkpoint::EncodeTensorNameSlice(full_tensor_key_string, - stored_slice); - mutex_lock l(mu_); - // `GetBundleEntryProto` will access `iter_`, so protecting it with a - // mutex lock. - status_ = - GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry); - if (!status_.ok()) return; - } + // We already have the entry for the full tensor, so don't query again if + // the slice is full. + if (!stored_slice.IsFull()) { + const string encoded_stored_slice_name = + checkpoint::EncodeTensorNameSlice(full_tensor_key_string, + stored_slice); + status_ = + GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry); + if (!status_.ok()) return status_; + } - auto cleanup = gtl::MakeCleanup([&counter] { counter.DecrementCount(); }); + // TODO(zongheng): should we take an OpKernelContext, so that we can call + // allocate_temp()? Note that without major refactorings to Saver, it's + // hard for the caller of the tensor bundle module to allocate these + // precisely-shaped scratch storage. - // TODO(zongheng): should we take an OpKernelContext, so that we can - // call allocate_temp()? Note that without major refactorings to - // Saver, it's hard for the caller of the tensor bundle module to - // allocate these precisely-shaped scratch storage. + // Optimization for the common case: the stored slice can be directly + // copied to the destination without additional slicing. This is true when + // either the slices are equal or when they are both full slices having the + // same shape. + TensorShape stored_slice_shape(stored_slice_entry.shape()); + if (stored_slice == slice_spec || + (stored_slice_shape == val->shape() && + IsFullSlice(stored_slice, stored_slice_shape) && + IsFullSlice(slice_spec, stored_slice_shape))) { + VLOG(1) << "Optimized for common case: directly copying into " + "pre-allocated buffer; spec: " + << slice_spec.DebugString(); + status_ = GetValue(stored_slice_entry, val); + return status_; + } - // Optimization for the common case: the stored slice can be directly - // copied to the destination without additional slicing. This is true - // when either the slices are equal or when they are both full slices - // having the same shape. - TensorShape stored_slice_shape(stored_slice_entry.shape()); - if (stored_slice == slice_spec || - (stored_slice_shape == val->shape() && - IsFullSlice(stored_slice, stored_slice_shape) && - IsFullSlice(slice_spec, stored_slice_shape))) { - VLOG(1) << "Optimized for common case: directly copying into " - "pre-allocated buffer; spec: " - << slice_spec.DebugString(); - status_ = GetValue(stored_slice_entry, val); - return; - } + Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape); + status_ = GetValue(stored_slice_entry, &stored_slice_tensor); + if (!status_.ok()) return status_; - Tensor stored_slice_tensor(stored_slice_entry.dtype(), - stored_slice_shape); - status_ = GetValue(stored_slice_entry, &stored_slice_tensor); - if (!status_.ok()) return; - - // Copies the intersection over. - mutex_lock l(mu_); - // `CopyDataFromTensorSliceToTensorSlice` will write to `val`, so - // protecting it with a mutex lock. - const DataType common_dtype = full_tensor_entry.dtype(); - switch (common_dtype) { + // Copies the intersection over. + const DataType common_dtype = full_tensor_entry.dtype(); + switch (common_dtype) { #define HANDLE_COPY(T) \ case DataTypeToEnum<T>::value: \ CHECK(CopyDataFromTensorSliceToTensorSlice( \ full_shape, stored_slice, slice_spec, \ stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \ break; - HANDLE_COPY(float) - HANDLE_COPY(double) - HANDLE_COPY(int32) - HANDLE_COPY(uint8) - HANDLE_COPY(int16) - HANDLE_COPY(int8) - HANDLE_COPY(complex64) - HANDLE_COPY(complex128) - HANDLE_COPY(int64) - HANDLE_COPY(bool) - HANDLE_COPY(qint32) - HANDLE_COPY(quint8) - HANDLE_COPY(qint8) - HANDLE_COPY(bfloat16) - default: - status_ = errors::InvalidArgument( - "Dtype ", DataTypeString(common_dtype), " not supported."); - if (!status_.ok()) return; - } + + HANDLE_COPY(float) + HANDLE_COPY(double) + HANDLE_COPY(int32) + HANDLE_COPY(uint8) + HANDLE_COPY(int16) + HANDLE_COPY(int8) + HANDLE_COPY(complex64) + HANDLE_COPY(complex128) + HANDLE_COPY(int64) + HANDLE_COPY(bool) + HANDLE_COPY(qint32) + HANDLE_COPY(quint8) + HANDLE_COPY(qint8) + HANDLE_COPY(bfloat16) + default: + return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype), + " not supported."); + } #undef HANDLE_COPY - }); } - - counter.Wait(); - TF_RETURN_IF_ERROR(status_); - return Status::OK(); } diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h index 24a9c488cbb..c441000e47d 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h @@ -306,9 +306,6 @@ class BundleReader { // differs from that of the current system's processor architecture. bool need_to_swap_bytes_; - // Protect internal states when accessing from multiple threads. - mutable mutex mu_; - friend class TensorBundleAlignmentTest; // For testing data alignment. TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);