Read from a sharded checkpoint in parallel with multiple threads.
PiperOrigin-RevId: 313506462
Change-Id: I1cef16cdaa9e03fd3161727614007429221089e0

parent 102bf84e26
commit ea97139d4d
tensorflow/core/util/tensor_bundle
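Before the diff itself, here is a minimal standalone sketch of the pattern the change description refers to: schedule one read per stored slice, guard shared error state with a mutex, and block until every read has finished. The actual change uses TensorFlow's Env::SchedClosure, BlockingCounter, and mutex (all visible in the hunks below); the sketch uses plain std::thread and std::mutex instead, and the names SliceReadResult, ReadOneSlice, and ReadSlicesInParallel are hypothetical stand-ins, not part of the commit.

#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

struct SliceReadResult {
  bool ok = true;
  std::string error;
};

// Placeholder for the per-slice work; a real reader would copy bytes from one
// shard of the checkpoint into the destination tensor.
SliceReadResult ReadOneSlice(int slice_index) {
  (void)slice_index;
  return {true, ""};
}

// Reads every slice on its own thread, records the first failure in a
// mutex-guarded error string, and blocks until all reads are done.
bool ReadSlicesInParallel(int num_slices, std::string* first_error) {
  std::mutex mu;  // plays the role of BundleReader::mu_ in the diff
  bool all_ok = true;
  std::vector<std::thread> workers;
  workers.reserve(num_slices);
  for (int i = 0; i < num_slices; ++i) {
    workers.emplace_back([i, &mu, &all_ok, first_error]() {
      SliceReadResult result = ReadOneSlice(i);
      if (!result.ok) {
        std::lock_guard<std::mutex> lock(mu);
        if (all_ok) {  // keep only the first error, as status_ does
          all_ok = false;
          *first_error = result.error;
        }
      }
    });
  }
  // Joining all workers stands in for BlockingCounter::Wait() in the diff.
  for (std::thread& t : workers) t.join();
  return all_ok;
}

int main() {
  std::string error;
  std::cout << (ReadSlicesInParallel(4, &error) ? "ok" : error) << "\n";
  return 0;
}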
@@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/path.h"
@@ -42,8 +41,6 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"
@@ -1024,106 +1021,79 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
" to restore in slice_spec: ", slice_spec.DebugString());
}

BlockingCounter counter(static_cast<int>(details.size()));
auto runner = [this, &details](std::function<void()> fn) {
if (details.size() > 1) {
// If there are multiple slices to read, perform the read in parallel
// using multiple threads.
env_->SchedClosure(fn);
} else {
fn();
}
};

// The union of the slices in "details" covers "slice_spec". Performs the
// copies from each.
BundleEntryProto stored_slice_entry = full_tensor_entry;
for (const auto& slice_tag_pair : details) {
runner([this, &slice_spec, &full_shape, &slice_tag_pair, &full_tensor_entry,
&full_tensor_key_string, &counter, val]() {
// The union of the slices in "details" covers "slice_spec". Performs the
// copies from each.
BundleEntryProto stored_slice_entry = full_tensor_entry;
// Seeks for the stored slice.
const TensorSlice& stored_slice = slice_tag_pair.first;
// Seeks for the stored slice.
const TensorSlice& stored_slice = slice_tag_pair.first;

// We already have the entry for the full tensor, so don't query again if
// the slice is full.
if (!stored_slice.IsFull()) {
const string encoded_stored_slice_name =
checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
stored_slice);
mutex_lock l(mu_);
// `GetBundleEntryProto` will access `iter_`, so protecting it with a
// mutex lock.
status_ =
GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
if (!status_.ok()) return;
}
// We already have the entry for the full tensor, so don't query again if
// the slice is full.
if (!stored_slice.IsFull()) {
const string encoded_stored_slice_name =
checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
stored_slice);
status_ =
GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
if (!status_.ok()) return status_;
}

auto cleanup = gtl::MakeCleanup([&counter] { counter.DecrementCount(); });
// TODO(zongheng): should we take an OpKernelContext, so that we can call
// allocate_temp()? Note that without major refactorings to Saver, it's
// hard for the caller of the tensor bundle module to allocate these
// precisely-shaped scratch storage.

// TODO(zongheng): should we take an OpKernelContext, so that we can
// call allocate_temp()? Note that without major refactorings to
// Saver, it's hard for the caller of the tensor bundle module to
// allocate these precisely-shaped scratch storage.
// Optimization for the common case: the stored slice can be directly
// copied to the destination without additional slicing. This is true when
// either the slices are equal or when they are both full slices having the
// same shape.
TensorShape stored_slice_shape(stored_slice_entry.shape());
if (stored_slice == slice_spec ||
(stored_slice_shape == val->shape() &&
IsFullSlice(stored_slice, stored_slice_shape) &&
IsFullSlice(slice_spec, stored_slice_shape))) {
VLOG(1) << "Optimized for common case: directly copying into "
"pre-allocated buffer; spec: "
<< slice_spec.DebugString();
status_ = GetValue(stored_slice_entry, val);
return status_;
}

// Optimization for the common case: the stored slice can be directly
// copied to the destination without additional slicing. This is true
// when either the slices are equal or when they are both full slices
// having the same shape.
TensorShape stored_slice_shape(stored_slice_entry.shape());
if (stored_slice == slice_spec ||
(stored_slice_shape == val->shape() &&
IsFullSlice(stored_slice, stored_slice_shape) &&
IsFullSlice(slice_spec, stored_slice_shape))) {
VLOG(1) << "Optimized for common case: directly copying into "
"pre-allocated buffer; spec: "
<< slice_spec.DebugString();
status_ = GetValue(stored_slice_entry, val);
return;
}
Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
if (!status_.ok()) return status_;

Tensor stored_slice_tensor(stored_slice_entry.dtype(),
stored_slice_shape);
status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
if (!status_.ok()) return;

// Copies the intersection over.
mutex_lock l(mu_);
// `CopyDataFromTensorSliceToTensorSlice` will write to `val`, so
// protecting it with a mutex lock.
const DataType common_dtype = full_tensor_entry.dtype();
switch (common_dtype) {
// Copies the intersection over.
const DataType common_dtype = full_tensor_entry.dtype();
switch (common_dtype) {
#define HANDLE_COPY(T) \
case DataTypeToEnum<T>::value: \
CHECK(CopyDataFromTensorSliceToTensorSlice( \
full_shape, stored_slice, slice_spec, \
stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
break;
HANDLE_COPY(float)
HANDLE_COPY(double)
HANDLE_COPY(int32)
HANDLE_COPY(uint8)
HANDLE_COPY(int16)
HANDLE_COPY(int8)
HANDLE_COPY(complex64)
HANDLE_COPY(complex128)
HANDLE_COPY(int64)
HANDLE_COPY(bool)
HANDLE_COPY(qint32)
HANDLE_COPY(quint8)
HANDLE_COPY(qint8)
HANDLE_COPY(bfloat16)
default:
status_ = errors::InvalidArgument(
"Dtype ", DataTypeString(common_dtype), " not supported.");
if (!status_.ok()) return;
}

HANDLE_COPY(float)
HANDLE_COPY(double)
HANDLE_COPY(int32)
HANDLE_COPY(uint8)
HANDLE_COPY(int16)
HANDLE_COPY(int8)
HANDLE_COPY(complex64)
HANDLE_COPY(complex128)
HANDLE_COPY(int64)
HANDLE_COPY(bool)
HANDLE_COPY(qint32)
HANDLE_COPY(quint8)
HANDLE_COPY(qint8)
HANDLE_COPY(bfloat16)
default:
return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
" not supported.");
}
#undef HANDLE_COPY
});
}

counter.Wait();
TF_RETURN_IF_ERROR(status_);

return Status::OK();
}
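One detail worth noting in the hunk above: the per-slice closure has several early return statements, so the counter is decremented through gtl::MakeCleanup rather than at the end of the closure; otherwise counter.Wait() could block forever. Below is a minimal standalone illustration of that pairing, using a hypothetical CountDownLatch and ScopeGuard as stand-ins for TensorFlow's BlockingCounter and gtl::MakeCleanup; it is a sketch of the idea, not code from the commit.

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>

// Stand-in for BlockingCounter: Wait() returns once DecrementCount() has been
// called `count` times.
class CountDownLatch {
 public:
  explicit CountDownLatch(int count) : count_(count) {}
  void DecrementCount() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

// Stand-in for gtl::MakeCleanup: runs the callback when the guard goes out of
// scope, on every return path.
class ScopeGuard {
 public:
  explicit ScopeGuard(std::function<void()> fn) : fn_(std::move(fn)) {}
  ~ScopeGuard() { fn_(); }

 private:
  std::function<void()> fn_;
};

int main() {
  const int kSlices = 3;
  CountDownLatch counter(kSlices);
  std::vector<std::thread> workers;
  for (int i = 0; i < kSlices; ++i) {
    workers.emplace_back([i, &counter] {
      ScopeGuard cleanup([&counter] { counter.DecrementCount(); });
      if (i == 1) return;  // an early return still decrements via the guard
      std::cout << "read slice " << i << "\n";
    });
  }
  counter.Wait();  // unblocks only after all three decrements have happened
  for (std::thread& t : workers) t.join();
  std::cout << "all slice reads finished\n";
  return 0;
}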
@@ -306,9 +306,6 @@ class BundleReader {
// differs from that of the current system's processor architecture.
bool need_to_swap_bytes_;

// Protect internal states when accessing from multiple threads.
mutable mutex mu_;

friend class TensorBundleAlignmentTest;  // For testing data alignment.

TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);