Read from a sharded checkpoint in parallel with multiple threads.

PiperOrigin-RevId: 313506462
Change-Id: I1cef16cdaa9e03fd3161727614007429221089e0
Author: A. Unique TensorFlower
Committed by: TensorFlower Gardener
Date: 2020-05-27 18:57:52 -07:00
parent 102bf84e26
commit ea97139d4d
2 changed files with 59 additions and 92 deletions
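
Editorial note: the parallel read named in the title fans each slice read out to a background thread via Env::SchedClosure and joins on a BlockingCounter — the shape visible in the lines this diff removes from GetSliceValue. Below is a minimal sketch of that pattern, not the committed code: BlockingCounter, Env::SchedClosure, mutex, and Status::Update are real TensorFlow platform primitives, while ParallelReadSketch and read_one_slice are hypothetical names used only for illustration.

    // Sketch only, not the committed code: fan out one closure per slice,
    // record the first failure under a mutex, and join before returning.
    #include <functional>

    #include "tensorflow/core/platform/blocking_counter.h"
    #include "tensorflow/core/platform/env.h"
    #include "tensorflow/core/platform/mutex.h"
    #include "tensorflow/core/platform/status.h"

    namespace tensorflow {

    Status ParallelReadSketch(Env* env, int num_slices,
                              const std::function<Status(int)>& read_one_slice) {
      BlockingCounter counter(num_slices);
      mutex mu;
      Status status;  // First non-OK result wins; guarded by mu.
      for (int i = 0; i < num_slices; ++i) {
        auto work = [&, i]() {
          Status s = read_one_slice(i);
          {
            mutex_lock l(mu);
            status.Update(s);
          }
          counter.DecrementCount();
        };
        if (num_slices > 1) {
          env->SchedClosure(work);  // Multiple slices: read on background threads.
        } else {
          work();  // Single slice: skip the thread hop.
        }
      }
      counter.Wait();  // Join: block until every closure has finished.
      return status;
    }

    }  // namespace tensorflow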

tensorflow/core/util/tensor_bundle/tensor_bundle.cc

@@ -34,7 +34,6 @@ limitations under the License.
 #include "tensorflow/core/lib/bfloat16/bfloat16.h"
 #include "tensorflow/core/lib/core/coding.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/hash/crc32c.h"
 #include "tensorflow/core/lib/io/path.h"
@@ -42,8 +41,6 @@ limitations under the License.
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/blocking_counter.h"
-#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/saved_tensor_slice_util.h"
 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
@@ -1024,23 +1021,10 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
         " to restore in slice_spec: ", slice_spec.DebugString());
   }
-  BlockingCounter counter(static_cast<int>(details.size()));
-  auto runner = [this, &details](std::function<void()> fn) {
-    if (details.size() > 1) {
-      // If there are multiple slices to read, perform the read in parallel
-      // using multiple threads.
-      env_->SchedClosure(fn);
-    } else {
-      fn();
-    }
-  };
-  for (const auto& slice_tag_pair : details) {
-    runner([this, &slice_spec, &full_shape, &slice_tag_pair, &full_tensor_entry,
-            &full_tensor_key_string, &counter, val]() {
   // The union of the slices in "details" covers "slice_spec". Performs the
   // copies from each.
   BundleEntryProto stored_slice_entry = full_tensor_entry;
+  for (const auto& slice_tag_pair : details) {
     // Seeks for the stored slice.
     const TensorSlice& stored_slice = slice_tag_pair.first;
@@ -1050,25 +1034,20 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
     const string encoded_stored_slice_name =
         checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
                                           stored_slice);
-    mutex_lock l(mu_);
-    // `GetBundleEntryProto` will access `iter_`, so protecting it with a
-    // mutex lock.
     status_ =
         GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
-    if (!status_.ok()) return;
+    if (!status_.ok()) return status_;
   }
-  auto cleanup = gtl::MakeCleanup([&counter] { counter.DecrementCount(); });
-  // TODO(zongheng): should we take an OpKernelContext, so that we can
-  // call allocate_temp()? Note that without major refactorings to
-  // Saver, it's hard for the caller of the tensor bundle module to
-  // allocate these precisely-shaped scratch storage.
-  // Optimization for the common case: the stored slice can be directly
-  // copied to the destination without additional slicing. This is true
-  // when either the slices are equal or when they are both full slices
-  // having the same shape.
+  // TODO(zongheng): should we take an OpKernelContext, so that we can call
+  // allocate_temp()? Note that without major refactorings to Saver, it's
+  // hard for the caller of the tensor bundle module to allocate these
+  // precisely-shaped scratch storage.
+  // Optimization for the common case: the stored slice can be directly
+  // copied to the destination without additional slicing. This is true when
+  // either the slices are equal or when they are both full slices having the
+  // same shape.
   TensorShape stored_slice_shape(stored_slice_entry.shape());
   if (stored_slice == slice_spec ||
       (stored_slice_shape == val->shape() &&
@@ -1078,18 +1057,14 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
                "pre-allocated buffer; spec: "
             << slice_spec.DebugString();
     status_ = GetValue(stored_slice_entry, val);
-    return;
+    return status_;
   }
-  Tensor stored_slice_tensor(stored_slice_entry.dtype(),
-                             stored_slice_shape);
+  Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
   status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
-  if (!status_.ok()) return;
+  if (!status_.ok()) return status_;
   // Copies the intersection over.
-  mutex_lock l(mu_);
-  // `CopyDataFromTensorSliceToTensorSlice` will write to `val`, so
-  // protecting it with a mutex lock.
   const DataType common_dtype = full_tensor_entry.dtype();
   switch (common_dtype) {
 #define HANDLE_COPY(T)                                                    \
@@ -1098,6 +1073,7 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
                 full_shape, stored_slice, slice_spec,                      \
                 stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
     break;
+
     HANDLE_COPY(float)
     HANDLE_COPY(double)
     HANDLE_COPY(int32)
@@ -1113,17 +1089,11 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
     HANDLE_COPY(qint8)
     HANDLE_COPY(bfloat16)
     default:
-      status_ = errors::InvalidArgument(
-          "Dtype ", DataTypeString(common_dtype), " not supported.");
-      if (!status_.ok()) return;
+      return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
+                                     " not supported.");
   }
 #undef HANDLE_COPY
-    });
   }
-  counter.Wait();
-  TF_RETURN_IF_ERROR(status_);
   return Status::OK();
 }
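
Editorial note: with the fan-out removed, GetSliceValue reads slices serially and propagates the first failure by returning it directly, instead of stashing it in status_ under a lock and joining on a counter. A minimal sketch of the resulting control flow, with SliceDetail and ReadSlice as hypothetical stand-ins for the per-slice seek-and-copy body:

    #include <vector>

    #include "tensorflow/core/platform/status.h"

    namespace tensorflow {

    struct SliceDetail {};  // Hypothetical stand-in for a (slice, shard) entry.

    // Hypothetical per-slice body; in GetSliceValue this is the seek + copy.
    Status ReadSlice(const SliceDetail& d) { return Status::OK(); }

    Status ReadAllSlices(const std::vector<SliceDetail>& details) {
      for (const auto& d : details) {
        Status s = ReadSlice(d);
        if (!s.ok()) return s;  // Early return replaces status_ + counter.Wait().
      }
      return Status::OK();
    }

    }  // namespace tensorflow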

tensorflow/core/util/tensor_bundle/tensor_bundle.h

@@ -306,9 +306,6 @@ class BundleReader {
   // differs from that of the current system's processor architecture.
   bool need_to_swap_bytes_;
 
-  // Protect internal states when accessing from multiple threads.
-  mutable mutex mu_;
-
   friend class TensorBundleAlignmentTest;  // For testing data alignment.
   TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);
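
Editorial note: the header drops mu_ because single-threaded reads no longer need it. The member was declared mutable — the usual C++ idiom for letting logically-const methods take the lock. A generic sketch of that idiom (CachedValue is hypothetical, not a BundleReader API):

    #include "tensorflow/core/platform/mutex.h"

    namespace tensorflow {

    class CachedValue {
     public:
      int Get() const {
        mutex_lock l(mu_);  // Legal in a const method only because mu_ is mutable.
        return value_;
      }
      void Set(int v) {
        mutex_lock l(mu_);
        value_ = v;
      }

     private:
      mutable mutex mu_;  // Same declaration pattern as the removed BundleReader::mu_.
      int value_ = 0;
    };

    }  // namespace tensorflow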