Read from a sharded checkpoint in parallel with multiple threads.

PiperOrigin-RevId: 313506462
Change-Id: I1cef16cdaa9e03fd3161727614007429221089e0
This commit is contained in:
A. Unique TensorFlower 2020-05-27 18:57:52 -07:00 committed by TensorFlower Gardener
parent 102bf84e26
commit ea97139d4d
2 changed files with 59 additions and 92 deletions
tensorflow/core/util/tensor_bundle

View File

@@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/path.h"
@@ -42,8 +41,6 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"
@@ -1024,23 +1021,10 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
" to restore in slice_spec: ", slice_spec.DebugString());
}
BlockingCounter counter(static_cast<int>(details.size()));
auto runner = [this, &details](std::function<void()> fn) {
if (details.size() > 1) {
// If there are multiple slices to read, perform the read in parallel
// using multiple threads.
env_->SchedClosure(fn);
} else {
fn();
}
};
for (const auto& slice_tag_pair : details) {
runner([this, &slice_spec, &full_shape, &slice_tag_pair, &full_tensor_entry,
&full_tensor_key_string, &counter, val]() {
// The union of the slices in "details" covers "slice_spec". Performs the
// copies from each.
BundleEntryProto stored_slice_entry = full_tensor_entry;
for (const auto& slice_tag_pair : details) {
// Seeks for the stored slice.
const TensorSlice& stored_slice = slice_tag_pair.first;
@@ -1050,25 +1034,20 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
const string encoded_stored_slice_name =
checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
stored_slice);
mutex_lock l(mu_);
// `GetBundleEntryProto` will access `iter_`, so protecting it with a
// mutex lock.
status_ =
GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
if (!status_.ok()) return;
if (!status_.ok()) return status_;
}
auto cleanup = gtl::MakeCleanup([&counter] { counter.DecrementCount(); });
// TODO(zongheng): should we take an OpKernelContext, so that we can
// call allocate_temp()? Note that without major refactorings to
// Saver, it's hard for the caller of the tensor bundle module to
// allocate these precisely-shaped scratch storage.
// TODO(zongheng): should we take an OpKernelContext, so that we can call
// allocate_temp()? Note that without major refactorings to Saver, it's
// hard for the caller of the tensor bundle module to allocate these
// precisely-shaped scratch storage.
// Optimization for the common case: the stored slice can be directly
// copied to the destination without additional slicing. This is true
// when either the slices are equal or when they are both full slices
// having the same shape.
// copied to the destination without additional slicing. This is true when
// either the slices are equal or when they are both full slices having the
// same shape.
TensorShape stored_slice_shape(stored_slice_entry.shape());
if (stored_slice == slice_spec ||
(stored_slice_shape == val->shape() &&
@@ -1078,18 +1057,14 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
"pre-allocated buffer; spec: "
<< slice_spec.DebugString();
status_ = GetValue(stored_slice_entry, val);
return;
return status_;
}
Tensor stored_slice_tensor(stored_slice_entry.dtype(),
stored_slice_shape);
Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
if (!status_.ok()) return;
if (!status_.ok()) return status_;
// Copies the intersection over.
mutex_lock l(mu_);
// `CopyDataFromTensorSliceToTensorSlice` will write to `val`, so
// protecting it with a mutex lock.
const DataType common_dtype = full_tensor_entry.dtype();
switch (common_dtype) {
#define HANDLE_COPY(T) \
@@ -1098,6 +1073,7 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
full_shape, stored_slice, slice_spec, \
stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
break;
HANDLE_COPY(float)
HANDLE_COPY(double)
HANDLE_COPY(int32)
@@ -1113,17 +1089,11 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
HANDLE_COPY(qint8)
HANDLE_COPY(bfloat16)
default:
status_ = errors::InvalidArgument(
"Dtype ", DataTypeString(common_dtype), " not supported.");
if (!status_.ok()) return;
return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
" not supported.");
}
#undef HANDLE_COPY
});
}
counter.Wait();
TF_RETURN_IF_ERROR(status_);
return Status::OK();
}

View File

@ -306,9 +306,6 @@ class BundleReader {
// differs from that of the current system's processor architecture.
bool need_to_swap_bytes_;
// Protect internal states when accessing from multiple threads.
mutable mutex mu_;
friend class TensorBundleAlignmentTest; // For testing data alignment.
TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);