Read from a sharded checkpoint in parallel with multiple threads.

PiperOrigin-RevId: 313506462
Change-Id: I1cef16cdaa9e03fd3161727614007429221089e0
A. Unique TensorFlower 2020-05-27 18:57:52 -07:00 committed by TensorFlower Gardener
parent 102bf84e26
commit ea97139d4d
2 changed files with 59 additions and 92 deletions
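For context, the mechanism at issue in the diff below: each stored slice is read by a closure that runs inline when there is only one slice and is handed to the environment's thread pool (env_->SchedClosure) when there are several, with a BlockingCounter used to wait for all closures to finish. A minimal standalone sketch of that run-inline-or-in-parallel pattern, using plain std::thread in place of TensorFlow's Env::SchedClosure and a joined thread vector in place of BlockingCounter (illustrative only, not code from this change):

```cpp
#include <cstdio>
#include <functional>
#include <thread>
#include <vector>

int main() {
  std::vector<int> slices = {1, 2, 3, 4};  // stand-ins for the slices to read

  // Run the closure inline for a single work item; otherwise schedule it on
  // its own thread (the role env_->SchedClosure plays in the diff).
  std::vector<std::thread> workers;
  auto runner = [&](std::function<void()> fn) {
    if (slices.size() > 1) {
      workers.emplace_back(std::move(fn));
    } else {
      fn();
    }
  };

  std::vector<int> results(slices.size());
  for (size_t i = 0; i < slices.size(); ++i) {
    runner([&results, &slices, i] { results[i] = slices[i] * 10; });
  }

  // Wait for every closure to finish (the role of BlockingCounter::Wait).
  for (auto& t : workers) t.join();

  for (int r : results) std::printf("%d\n", r);
  return 0;
}
```

As in the sketch, the diff's runner only fans out when there is more than one slice to read, so the single-slice case stays on the calling thread.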
tensorflow/core/util/tensor_bundle


@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/path.h"
@@ -42,8 +41,6 @@ limitations under the License.
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/blocking_counter.h"
-#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/saved_tensor_slice_util.h"
 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
@@ -1024,106 +1021,79 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
                            " to restore in slice_spec: ", slice_spec.DebugString());
   }
-  BlockingCounter counter(static_cast<int>(details.size()));
-  auto runner = [this, &details](std::function<void()> fn) {
-    if (details.size() > 1) {
-      // If there are multiple slices to read, perform the read in parallel
-      // using multiple threads.
-      env_->SchedClosure(fn);
-    } else {
-      fn();
-    }
-  };
+  // The union of the slices in "details" covers "slice_spec". Performs the
+  // copies from each.
+  BundleEntryProto stored_slice_entry = full_tensor_entry;
   for (const auto& slice_tag_pair : details) {
-    runner([this, &slice_spec, &full_shape, &slice_tag_pair, &full_tensor_entry,
-            &full_tensor_key_string, &counter, val]() {
-      // The union of the slices in "details" covers "slice_spec". Performs the
-      // copies from each.
-      BundleEntryProto stored_slice_entry = full_tensor_entry;
-      // Seeks for the stored slice.
-      const TensorSlice& stored_slice = slice_tag_pair.first;
-      // We already have the entry for the full tensor, so don't query again if
-      // the slice is full.
-      if (!stored_slice.IsFull()) {
-        const string encoded_stored_slice_name =
-            checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
-                                              stored_slice);
-        mutex_lock l(mu_);
-        // `GetBundleEntryProto` will access `iter_`, so protecting it with a
-        // mutex lock.
-        status_ =
-            GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
-        if (!status_.ok()) return;
-      }
-      auto cleanup = gtl::MakeCleanup([&counter] { counter.DecrementCount(); });
-      // TODO(zongheng): should we take an OpKernelContext, so that we can
-      // call allocate_temp()? Note that without major refactorings to
-      // Saver, it's hard for the caller of the tensor bundle module to
-      // allocate these precisely-shaped scratch storage.
-      // Optimization for the common case: the stored slice can be directly
-      // copied to the destination without additional slicing. This is true
-      // when either the slices are equal or when they are both full slices
-      // having the same shape.
-      TensorShape stored_slice_shape(stored_slice_entry.shape());
-      if (stored_slice == slice_spec ||
-          (stored_slice_shape == val->shape() &&
-           IsFullSlice(stored_slice, stored_slice_shape) &&
-           IsFullSlice(slice_spec, stored_slice_shape))) {
-        VLOG(1) << "Optimized for common case: directly copying into "
-                   "pre-allocated buffer; spec: "
-                << slice_spec.DebugString();
-        status_ = GetValue(stored_slice_entry, val);
-        return;
-      }
-      Tensor stored_slice_tensor(stored_slice_entry.dtype(),
-                                 stored_slice_shape);
-      status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
-      if (!status_.ok()) return;
-      // Copies the intersection over.
-      mutex_lock l(mu_);
-      // `CopyDataFromTensorSliceToTensorSlice` will write to `val`, so
-      // protecting it with a mutex lock.
-      const DataType common_dtype = full_tensor_entry.dtype();
-      switch (common_dtype) {
+    // Seeks for the stored slice.
+    const TensorSlice& stored_slice = slice_tag_pair.first;
+    // We already have the entry for the full tensor, so don't query again if
+    // the slice is full.
+    if (!stored_slice.IsFull()) {
+      const string encoded_stored_slice_name =
+          checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
+                                            stored_slice);
+      status_ =
+          GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
+      if (!status_.ok()) return status_;
+    }
+    // TODO(zongheng): should we take an OpKernelContext, so that we can call
+    // allocate_temp()? Note that without major refactorings to Saver, it's
+    // hard for the caller of the tensor bundle module to allocate these
+    // precisely-shaped scratch storage.
+    // Optimization for the common case: the stored slice can be directly
+    // copied to the destination without additional slicing. This is true when
+    // either the slices are equal or when they are both full slices having the
+    // same shape.
+    TensorShape stored_slice_shape(stored_slice_entry.shape());
+    if (stored_slice == slice_spec ||
+        (stored_slice_shape == val->shape() &&
+         IsFullSlice(stored_slice, stored_slice_shape) &&
+         IsFullSlice(slice_spec, stored_slice_shape))) {
+      VLOG(1) << "Optimized for common case: directly copying into "
+                 "pre-allocated buffer; spec: "
+              << slice_spec.DebugString();
+      status_ = GetValue(stored_slice_entry, val);
+      return status_;
+    }
+    Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
+    status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
+    if (!status_.ok()) return status_;
+    // Copies the intersection over.
+    const DataType common_dtype = full_tensor_entry.dtype();
+    switch (common_dtype) {
 #define HANDLE_COPY(T)                                                 \
   case DataTypeToEnum<T>::value:                                       \
     CHECK(CopyDataFromTensorSliceToTensorSlice(                        \
         full_shape, stored_slice, slice_spec,                          \
         stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
     break;
-        HANDLE_COPY(float)
-        HANDLE_COPY(double)
-        HANDLE_COPY(int32)
-        HANDLE_COPY(uint8)
-        HANDLE_COPY(int16)
-        HANDLE_COPY(int8)
-        HANDLE_COPY(complex64)
-        HANDLE_COPY(complex128)
-        HANDLE_COPY(int64)
-        HANDLE_COPY(bool)
-        HANDLE_COPY(qint32)
-        HANDLE_COPY(quint8)
-        HANDLE_COPY(qint8)
-        HANDLE_COPY(bfloat16)
-        default:
-          status_ = errors::InvalidArgument(
-              "Dtype ", DataTypeString(common_dtype), " not supported.");
-          if (!status_.ok()) return;
-      }
+      HANDLE_COPY(float)
+      HANDLE_COPY(double)
+      HANDLE_COPY(int32)
+      HANDLE_COPY(uint8)
+      HANDLE_COPY(int16)
+      HANDLE_COPY(int8)
+      HANDLE_COPY(complex64)
+      HANDLE_COPY(complex128)
+      HANDLE_COPY(int64)
+      HANDLE_COPY(bool)
+      HANDLE_COPY(qint32)
+      HANDLE_COPY(quint8)
+      HANDLE_COPY(qint8)
+      HANDLE_COPY(bfloat16)
+      default:
+        return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
+                                       " not supported.");
+    }
 #undef HANDLE_COPY
-    });
   }
-  counter.Wait();
-  TF_RETURN_IF_ERROR(status_);
   return Status::OK();
 }
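The multi-threaded variant of GetSliceValue guards the reader's shared state (the metadata iterator, status_, and the output val) with the mu_ mutex, and each closure returns early once status_ already holds an error. A small self-contained sketch of that record-the-first-error-under-a-lock idiom, using plain C++ types rather than TensorFlow's Status and mutex (illustrative only):

```cpp
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Shared state touched by several worker closures, analogous to the
// reader's status_ / val members in the diff.
struct Shared {
  std::mutex mu;
  std::string first_error;  // empty string stands in for Status::OK()
};

void ReadSlice(Shared* shared, int slice_index) {
  std::lock_guard<std::mutex> lock(shared->mu);
  // Once any closure has recorded an error, later closures return early,
  // mirroring the `if (!status_.ok()) return;` checks in the diff.
  if (!shared->first_error.empty()) return;
  if (slice_index == 2) {  // pretend this slice fails to read
    shared->first_error = "failed to read slice " + std::to_string(slice_index);
  }
}

int main() {
  Shared shared;
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) workers.emplace_back(ReadSlice, &shared, i);
  for (auto& t : workers) t.join();
  return shared.first_error.empty() ? 0 : 1;
}
```

In the bundle reader the accumulated status_ is surfaced to the caller only after counter.Wait(), via TF_RETURN_IF_ERROR(status_).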

tensorflow/core/util/tensor_bundle/tensor_bundle.h

@@ -306,9 +306,6 @@ class BundleReader {
   // differs from that of the current system's processor architecture.
   bool need_to_swap_bytes_;
 
-  // Protect internal states when accessing from multiple threads.
-  mutable mutex mu_;
-
   friend class TensorBundleAlignmentTest;  // For testing data alignment.
   TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);