Read from a sharded checkpoint in parallel with multiple threads.
PiperOrigin-RevId: 313506462
Change-Id: I1cef16cdaa9e03fd3161727614007429221089e0
parent 102bf84e26
commit ea97139d4d
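For context before the diff: the parallel variant of BundleReader::GetSliceValue that appears on the removed (-) side below fans each per-slice read out as a closure scheduled with env_->SchedClosure() when there is more than one slice (running it inline otherwise), joins on a BlockingCounter, and guards shared state such as status_ and the metadata iterator with a mutex. The following is a minimal standalone sketch of that fan-out/join pattern using only the C++ standard library instead of TensorFlow's Env::SchedClosure/BlockingCounter; SliceSpec, ReadOneSlice, and ReadAllSlices are illustrative names, not TensorFlow APIs.

#include <functional>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

struct SliceSpec {
  std::string name;  // Stand-in for a TensorSlice describing one stored slice.
};

// Hypothetical per-slice read: deserialize one stored slice and copy it into
// its region of the destination buffer. Returns false and fills *error on
// failure.
bool ReadOneSlice(const SliceSpec& slice, std::string* error) {
  (void)slice;
  (void)error;
  return true;
}

bool ReadAllSlices(const std::vector<SliceSpec>& slices) {
  std::mutex mu;  // Guards first_error, like mu_ guards status_/iter_ in the diff.
  std::string first_error;
  std::vector<std::thread> workers;

  // Run the closure on a worker thread only when there are multiple slices;
  // the TensorFlow version schedules it with env_->SchedClosure() instead.
  auto runner = [&](std::function<void()> fn) {
    if (slices.size() > 1) {
      workers.emplace_back(std::move(fn));
    } else {
      fn();
    }
  };

  for (const SliceSpec& slice : slices) {
    const SliceSpec* s = &slice;  // Copy a pointer so the closure stays valid past this iteration.
    runner([&mu, &first_error, s]() {
      std::string error;
      if (!ReadOneSlice(*s, &error)) {
        std::lock_guard<std::mutex> lock(mu);
        if (first_error.empty()) first_error = error;
      }
    });
  }

  // Join all workers before returning; this plays the role of
  // BlockingCounter::Wait() in the TensorFlow code.
  for (std::thread& t : workers) t.join();
  return first_error.empty();
}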
tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -34,7 +34,6 @@ limitations under the License.
 #include "tensorflow/core/lib/bfloat16/bfloat16.h"
 #include "tensorflow/core/lib/core/coding.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/hash/crc32c.h"
 #include "tensorflow/core/lib/io/path.h"
@@ -42,8 +41,6 @@ limitations under the License.
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/blocking_counter.h"
-#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/saved_tensor_slice_util.h"
 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
@@ -1024,23 +1021,10 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
         " to restore in slice_spec: ", slice_spec.DebugString());
   }
 
-  BlockingCounter counter(static_cast<int>(details.size()));
-  auto runner = [this, &details](std::function<void()> fn) {
-    if (details.size() > 1) {
-      // If there are multiple slices to read, perform the read in parallel
-      // using multiple threads.
-      env_->SchedClosure(fn);
-    } else {
-      fn();
-    }
-  };
-
-  for (const auto& slice_tag_pair : details) {
-    runner([this, &slice_spec, &full_shape, &slice_tag_pair, &full_tensor_entry,
-            &full_tensor_key_string, &counter, val]() {
   // The union of the slices in "details" covers "slice_spec". Performs the
   // copies from each.
   BundleEntryProto stored_slice_entry = full_tensor_entry;
+  for (const auto& slice_tag_pair : details) {
     // Seeks for the stored slice.
     const TensorSlice& stored_slice = slice_tag_pair.first;
 
@@ -1050,25 +1034,20 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
       const string encoded_stored_slice_name =
           checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
                                             stored_slice);
-      mutex_lock l(mu_);
-      // `GetBundleEntryProto` will access `iter_`, so protecting it with a
-      // mutex lock.
       status_ =
           GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
-      if (!status_.ok()) return;
+      if (!status_.ok()) return status_;
     }
 
-      auto cleanup = gtl::MakeCleanup([&counter] { counter.DecrementCount(); });
-
-      // TODO(zongheng): should we take an OpKernelContext, so that we can
-      // call allocate_temp()? Note that without major refactorings to
-      // Saver, it's hard for the caller of the tensor bundle module to
-      // allocate these precisely-shaped scratch storage.
+    // TODO(zongheng): should we take an OpKernelContext, so that we can call
+    // allocate_temp()? Note that without major refactorings to Saver, it's
+    // hard for the caller of the tensor bundle module to allocate these
+    // precisely-shaped scratch storage.
 
     // Optimization for the common case: the stored slice can be directly
-    // copied to the destination without additional slicing. This is true
-    // when either the slices are equal or when they are both full slices
-    // having the same shape.
+    // copied to the destination without additional slicing. This is true when
+    // either the slices are equal or when they are both full slices having the
+    // same shape.
     TensorShape stored_slice_shape(stored_slice_entry.shape());
     if (stored_slice == slice_spec ||
         (stored_slice_shape == val->shape() &&
@@ -1078,18 +1057,14 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
                  "pre-allocated buffer; spec: "
               << slice_spec.DebugString();
       status_ = GetValue(stored_slice_entry, val);
-        return;
+      return status_;
     }
 
-      Tensor stored_slice_tensor(stored_slice_entry.dtype(),
-                                 stored_slice_shape);
+    Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
     status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
-      if (!status_.ok()) return;
+    if (!status_.ok()) return status_;
 
     // Copies the intersection over.
-      mutex_lock l(mu_);
-      // `CopyDataFromTensorSliceToTensorSlice` will write to `val`, so
-      // protecting it with a mutex lock.
     const DataType common_dtype = full_tensor_entry.dtype();
     switch (common_dtype) {
 #define HANDLE_COPY(T) \
@@ -1098,6 +1073,7 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
         full_shape, stored_slice, slice_spec, \
         stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
     break;
+
       HANDLE_COPY(float)
       HANDLE_COPY(double)
       HANDLE_COPY(int32)
@@ -1113,17 +1089,11 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
       HANDLE_COPY(qint8)
       HANDLE_COPY(bfloat16)
       default:
-          status_ = errors::InvalidArgument(
-              "Dtype ", DataTypeString(common_dtype), " not supported.");
-          if (!status_.ok()) return;
+        return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
+                                       " not supported.");
     }
 #undef HANDLE_COPY
-    });
   }
-
-  counter.Wait();
-  TF_RETURN_IF_ERROR(status_);
-
   return Status::OK();
 }
 
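The removed lines above also show the bookkeeping that kept the parallel path correct: each worker closure decrements the BlockingCounter on every exit path through a scope-exit cleanup (gtl::MakeCleanup), and the caller blocks on counter.Wait() before consulting status_. A rough standalone sketch of that idea, where BlockingCounter and ScopeExit are simplified stand-ins rather than the TensorFlow classes:

#include <condition_variable>
#include <mutex>
#include <utility>

// Simplified stand-in for a blocking counter: Wait() returns once
// DecrementCount() has been called `pending` times.
class BlockingCounter {
 public:
  explicit BlockingCounter(int pending) : pending_(pending) {}
  void DecrementCount() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--pending_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return pending_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int pending_;
};

// Minimal scope-exit helper in the spirit of gtl::MakeCleanup: runs fn_ when
// the enclosing scope exits, regardless of which return path is taken.
template <typename F>
class ScopeExit {
 public:
  explicit ScopeExit(F fn) : fn_(std::move(fn)) {}
  ~ScopeExit() { fn_(); }

 private:
  F fn_;
};

void WorkerBody(BlockingCounter* counter /*, per-slice arguments */) {
  auto decrement = [counter] { counter->DecrementCount(); };
  ScopeExit<decltype(decrement)> cleanup(decrement);
  // ... read and copy one slice; any early return still decrements the
  // counter, so the Wait() in the caller cannot hang ...
}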
tensorflow/core/util/tensor_bundle/tensor_bundle.h
@@ -306,9 +306,6 @@ class BundleReader {
   // differs from that of the current system's processor architecture.
   bool need_to_swap_bytes_;
 
-  // Protect internal states when accessing from multiple threads.
-  mutable mutex mu_;
-
   friend class TensorBundleAlignmentTest;  // For testing data alignment.
 
   TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);
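The header hunk above is the matching state change: the parallel variant keeps a `mutable mutex mu_` member in BundleReader, which (per its comment) protects internal state when worker closures run concurrently. A tiny sketch of that member-mutex pattern on a hypothetical reader class, not the actual BundleReader:

#include <mutex>
#include <string>

class ShardedReader {
 public:
  // Called from several worker closures at once; every touch of the shared
  // members below happens under mu_.
  void RecordStatus(const std::string& status) {
    std::lock_guard<std::mutex> lock(mu_);
    if (status_.empty()) status_ = status;
  }

  std::string status() const {
    std::lock_guard<std::mutex> lock(mu_);  // Allowed because mu_ is mutable.
    return status_;
  }

 private:
  // `mutable` so that const methods such as status() can still lock it.
  mutable std::mutex mu_;
  std::string status_;
};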