[tf.data] Implements the parallel version of BatchDataset. The dataset will do the copying elements work of BatchDataset parallelly. The CL also updates the modeling of the AsyncKnownRatio node to correctly model the dataset.

PiperOrigin-RevId: 355056363
Change-Id: Ie3326d0827783d9621b7df4324a3d80df5b29eba
This commit is contained in:
Jay Shi 2021-02-01 16:52:14 -08:00 committed by TensorFlower Gardener
parent 30749f263e
commit f3e7ae3965
33 changed files with 1539 additions and 258 deletions

View File

@ -32,6 +32,9 @@
examples in the same step.
* tf.data service supports custom data transfer protocols (other than
gRPC).
* `tf.data.Dataset.batch()` now supports `num_parallel_calls` argument,
which can be used to indicate that multiple input batches should be
computed in parallel.
## Bug Fixes and Other Changes

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "ParallelBatchDataset"
visibility: HIDDEN
}

View File

@ -516,9 +516,9 @@ class KnownRatio : public Node {
class AsyncKnownRatio : public Node {
public:
AsyncKnownRatio(Node::Args args, double ratio,
AsyncKnownRatio(Node::Args args, double ratio, double memory_ratio,
std::vector<std::shared_ptr<Parameter>> parameters)
: Node(args), ratio_(ratio) {
: Node(args), ratio_(ratio), memory_ratio_(memory_ratio) {
for (auto& parameter : parameters) {
parameters_[parameter->name] = std::move(parameter);
}
@ -534,7 +534,7 @@ class AsyncKnownRatio : public Node {
parameters.push_back(pair.second);
}
return std::make_shared<AsyncKnownRatio>(
Args{id_, name_, std::move(output)}, ratio_, parameters);
Args{id_, name_, std::move(output)}, ratio_, memory_ratio_, parameters);
}
// The input time is the sum of inherited input time and parallelism adjusted
@ -711,13 +711,14 @@ class AsyncKnownRatio : public Node {
}
if (parameter) {
if (ratio_ == 0) {
if (memory_ratio_ == 0) {
result += (*parameter)->value * AverageBufferedElementSize();
} else {
// The estimation is currently not accurate for MapAndBatchDataset for
// the maximum buffer size does not match `num_parallel_calls`
// parameter.
result += (*parameter)->value * AverageBufferedElementSize() / ratio_;
result +=
(*parameter)->value * AverageBufferedElementSize() / memory_ratio_;
}
}
return result;
@ -727,11 +728,22 @@ class AsyncKnownRatio : public Node {
TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
node_proto->set_node_class(NodeClass::ASYNC_KNOWN_RATIO);
node_proto->set_ratio(ratio_);
node_proto->set_memory_ratio(memory_ratio_);
return Status::OK();
}
private:
// Identifies how many input elements need to be created to construct an
// element for the dataset.
//
// Currently the value is 1 for PrefetchDataset and ParallelMapDataset,
// batch_size for MapAndBatchDataset and ParallelBatchDataset.
const double ratio_;
// For parallelism nodes, identifies how many parallelism calls are introduced
// by one buffered element. The value is defined to correctly estimate RAM
// budget bound with given num_parallel_calls (or buffer_size) combined with
// the estimated average size of buffered elements.
const double memory_ratio_;
};
class UnknownRatio : public Node {
@ -922,11 +934,18 @@ std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio) {
return std::make_shared<KnownRatio>(std::move(args), ratio);
}
std::shared_ptr<Node> MakeAsyncKnownRatioNode(
Node::Args args, double ratio, double memory_ratio,
std::vector<std::shared_ptr<Parameter>> parameters) {
return std::make_shared<AsyncKnownRatio>(std::move(args), ratio, memory_ratio,
std::move(parameters));
}
std::shared_ptr<Node> MakeAsyncKnownRatioNode(
Node::Args args, double ratio,
std::vector<std::shared_ptr<Parameter>> parameters) {
return std::make_shared<AsyncKnownRatio>(std::move(args), ratio,
std::move(parameters));
return MakeAsyncKnownRatioNode(std::move(args), /*ratio=*/ratio,
/*memory_ratio=*/ratio, std::move(parameters));
}
std::shared_ptr<Node> MakeSourceNode(Node::Args args) {
@ -1546,7 +1565,7 @@ Status Node::FromProto(ModelProto::Node node_proto,
break;
case NodeClass::ASYNC_KNOWN_RATIO:
restored_node = std::make_shared<AsyncKnownRatio>(
args, node_proto.ratio(),
args, node_proto.ratio(), node_proto.memory_ratio(),
/*parameters=*/std::vector<std::shared_ptr<Parameter>>());
break;
case NodeClass::UNKNOWN_RATIO:

View File

@ -605,6 +605,10 @@ std::shared_ptr<Node> MakeAsyncInterleaveManyNode(
std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio);
// AsyncKnownRatio nodes are the asynchronous version of KnownRate nodes.
std::shared_ptr<Node> MakeAsyncKnownRatioNode(
Node::Args args, double ratio, double memory_ratio,
std::vector<std::shared_ptr<Parameter>> parameters);
std::shared_ptr<Node> MakeAsyncKnownRatioNode(
Node::Args args, double ratio,
std::vector<std::shared_ptr<Parameter>> parameters);

View File

@ -88,6 +88,10 @@ message ModelProto {
// Ratio of input to output elements. This is only used by KNOWN_RATIO and
// ASYNC_KNOWN_RATIO nodes.
double ratio = 16;
// Ratio identifies how many parallelism calls are introduced by one
// buffered element. This is only used by ASYNC_KNOWN_RATIO nodes.
double memory_ratio = 17;
}
// Output node of this model.

View File

@ -30,6 +30,7 @@ tf_kernel_library(
srcs = ["batch_dataset_op.cc"],
hdrs = ["batch_dataset_op.h"],
deps = [
":dataset_utils",
":name_utils",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@ -683,6 +684,50 @@ tf_cc_test(
],
)
tf_kernel_library(
name = "parallel_batch_dataset_op",
srcs = ["parallel_batch_dataset_op.cc"],
hdrs = ["parallel_batch_dataset_op.h"],
deps = [
":dataset_utils",
":name_utils",
":stats_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/profiler/lib:traceme_encode",
],
)
tf_cc_test(
name = "parallel_batch_dataset_op_test",
size = "small",
srcs = ["parallel_batch_dataset_op_test.cc"],
deps = [
":dataset_test_base",
":dataset_utils",
":iterator_ops",
":name_utils",
":parallel_batch_dataset_op",
":range_dataset_op",
":stats_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:function_ops",
],
)
tf_kernel_library(
name = "parallel_interleave_dataset_op",
srcs = ["parallel_interleave_dataset_op.cc"],
@ -1476,6 +1521,7 @@ tf_kernel_library(
":optimize_dataset_op",
":optional_ops",
":padded_batch_dataset_op",
":parallel_batch_dataset_op",
":parallel_interleave_dataset_op",
":parallel_map_dataset_op",
":prefetch_dataset_op",

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stringprintf.h"
@ -199,67 +199,9 @@ class BatchDatasetOp::Dataset : public DatasetBase {
// respective slice locations. This would require a different GetNext()
// overload that supports zero-copy, and might make sense in an
// optimization pass.
const size_t num_tuple_components = batch_elements[0].size();
out_tensors->reserve(num_tuple_components);
const int64 num_batch_elements = batch_elements.size();
for (size_t component_index = 0; component_index < num_tuple_components;
++component_index) {
const Tensor& first_element = batch_elements[0][component_index];
TensorShape batch_component_shape({num_batch_elements});
// NOTE(mrry): Copy the shape of the first element here, because
// `first_element.shape()` will become undefined after the 0th batch
// element is moved into the output batch.
TensorShape first_element_shape(first_element.shape());
batch_component_shape.AppendShape(first_element_shape);
out_tensors->emplace_back(ctx->allocator({}), first_element.dtype(),
batch_component_shape);
if (!out_tensors->back().IsInitialized()) {
return errors::ResourceExhausted(
"Failed to allocate memory for the batch of component ",
component_index);
}
Tensor& batch_component = out_tensors->back();
// Build the output tuple component by copying one slice
// from each input element in the batch.
auto copy_element_fn = [component_index, &batch_elements,
&batch_component](int index) {
TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(
std::move(batch_elements[index][component_index]),
&batch_component, index));
return Status::OK();
};
BlockingCounter counter(num_batch_elements);
Status status;
mutex status_mu;
for (size_t i = 0; i < num_batch_elements; ++i) {
if (batch_elements[i][component_index].shape() !=
first_element_shape) {
return errors::InvalidArgument(
"Cannot batch tensors with different shapes in "
"component ",
component_index, ". First element had shape ",
first_element_shape.DebugString(), " and element ", i,
" had shape ",
batch_elements[i][component_index].shape().DebugString(), ".");
}
if (TF_PREDICT_FALSE(dataset()->parallel_copy_)) {
(*ctx->runner())(
[i, &status, &status_mu, &counter, &copy_element_fn]() {
Status s = copy_element_fn(i);
{
mutex_lock l(status_mu);
status.Update(s);
}
counter.DecrementCount();
});
} else {
status.Update(copy_element_fn(i));
counter.DecrementCount();
}
}
counter.Wait();
TF_RETURN_IF_ERROR(status);
}
TF_RETURN_IF_ERROR(CopyBatch(/*parallel_copy=*/dataset()->parallel_copy_,
ctx, out_tensors, &batch_elements));
*end_of_sequence = false;
return Status::OK();
}

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include <memory>
#include <queue>
#include "absl/container/flat_hash_map.h"
@ -30,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
@ -39,10 +41,16 @@ limitations under the License.
namespace tensorflow {
namespace data {
namespace {
constexpr char kDelimiter[] = "@@";
constexpr char kComponent[] = "component";
constexpr char kNumElements[] = "num_elements";
constexpr char kNumComponents[] = "num_components";
constexpr char kOutputSize[] = "output_size";
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";
constexpr char kOutput[] = "output";
} // namespace
Status WriteElementsToCheckpoint(
@ -699,5 +707,233 @@ void StripDevicePlacement(FunctionDefLibrary* library) {
}
}
Status CopyPartialBatch(int64 num_elements, const Tensor& value,
Tensor* output) {
switch (value.dtype()) {
#define HANDLE_TYPE(type) \
case DataTypeToEnum<type>::value: { \
auto output_t = output->flat_outer_dims<type>(); \
auto value_t = value.flat_outer_dims<type>(); \
for (size_t i = 0; i < num_elements; i++) { \
output_t.template chip<0>(i) = value_t.template chip<0>(i); \
} \
return Status::OK(); \
}
TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
default:
return errors::InvalidArgument("Unsupported data type: ",
DataTypeString(value.dtype()));
}
return Status::OK();
}
Status ReadBatch(int64 batch_size, const string& iterator_prefix,
const string& batch_prefix, IteratorContext* ctx,
IteratorStateReader* reader, std::vector<Tensor>* batch) {
int64 output_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
FullName(iterator_prefix,
strings::StrCat(batch_prefix, "_", kOutputSize)),
&output_size));
batch->reserve(output_size);
for (int i = 0; i < output_size; i++) {
Tensor t;
TF_RETURN_IF_ERROR(reader->ReadTensor(
FullName(iterator_prefix,
strings::StrCat(batch_prefix, "_", kOutput, "_", i)),
&t));
// If the batch was not full, we may have stored only the relevant slice.
// Since tensors in `BatchResult.output` are expected to have the leading
// dimension of size batch_size, we build a larger tensor and copy the slice
// read from the checkpoint into it.
if (t.dim_size(0) < batch_size) {
TensorShape component_shape(t.shape());
component_shape.set_dim(0, batch_size);
AllocatorAttributes attr;
attr.set_gpu_compatible(true);
Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape);
TF_RETURN_IF_ERROR(CopyPartialBatch(t.dim_size(0), t, &new_t));
batch->emplace_back(std::move(new_t));
} else {
batch->emplace_back(std::move(t));
}
}
return Status::OK();
}
Status WriteBatch(int64 batch_size, int64 num_elements,
const string& iterator_prefix, const string& batch_prefix,
IteratorStateWriter* writer, std::vector<Tensor>* batch) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
FullName(iterator_prefix,
strings::StrCat(batch_prefix, "_", kOutputSize)),
batch->size()));
for (int i = 0; i < batch->size(); i++) {
// If the batch is not full, we only store the first `num_elements` values.
// The rest of the batch tensor is *uninitialized* and accessing that will
// raise msan errors.
if (num_elements < batch_size) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
FullName(iterator_prefix,
strings::StrCat(batch_prefix, "_", kOutput, "_", i)),
(*batch)[i].Slice(0, num_elements)));
} else {
TF_RETURN_IF_ERROR(writer->WriteTensor(
FullName(iterator_prefix,
strings::StrCat(batch_prefix, "_", kOutput, "_", i)),
(*batch)[i]));
}
}
return Status::OK();
}
Status ReadStatus(const string& iterator_prefix, const string& prefix,
IteratorStateReader* reader, Status* status) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(
FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
&code_int));
error::Code code = static_cast<error::Code>(code_int);
if (code != error::Code::OK) {
tstring error_message;
TF_RETURN_IF_ERROR(reader->ReadScalar(
FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
&error_message));
*status = Status(code, error_message);
} else {
*status = Status::OK();
}
return Status::OK();
}
Status WriteStatus(const string& iterator_prefix, const string& prefix,
const Status& status, IteratorStateWriter* writer) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
static_cast<int64>(status.code())));
if (!status.ok()) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
status.error_message()));
}
return Status::OK();
}
Status ProcessBatch(int64 batch_size, int64 num_elements, bool drop_remainder,
const Status& status, IteratorContext* ctx,
std::vector<Tensor>* output, bool* end_of_sequence,
std::vector<Tensor>* batch) {
if (num_elements == 0) {
if (status.ok() || errors::IsOutOfRange(status)) {
*end_of_sequence = true;
return Status::OK();
} else {
*end_of_sequence = false;
return status;
}
}
if (!status.ok() && !errors::IsOutOfRange(status)) {
*end_of_sequence = false;
return status;
}
if (num_elements < batch_size) {
if (drop_remainder) {
*end_of_sequence = true;
return Status::OK();
}
for (size_t i = 0; i < batch->size(); ++i) {
TensorShape component_shape((*batch)[i].shape());
component_shape.set_dim(0, num_elements);
AllocatorAttributes attr;
attr.set_gpu_compatible(true);
output->emplace_back(ctx->allocator(attr), (*batch)[i].dtype(),
component_shape);
if (!output->back().IsInitialized()) {
return errors::ResourceExhausted(
"Failed to allocate memory for the batch of component ", i);
}
TF_RETURN_IF_ERROR(
CopyPartialBatch(num_elements, (*batch)[i], &output->back()));
}
} else {
*output = std::move(*batch);
}
*end_of_sequence = false;
return Status::OK();
}
Status CopyBatch(bool parallel_copy, IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
std::vector<std::vector<Tensor>>* batch_elements) {
const size_t num_tuple_components = (*batch_elements)[0].size();
out_tensors->reserve(num_tuple_components);
const int64 num_batch_elements = batch_elements->size();
for (size_t component_index = 0; component_index < num_tuple_components;
++component_index) {
const Tensor& first_element = (*batch_elements)[0][component_index];
TensorShape batch_component_shape({num_batch_elements});
// NOTE(mrry): Copy the shape of the first element here, because
// `first_element.shape()` will become undefined after the 0th batch element
// is moved into the output batch.
TensorShape first_element_shape(first_element.shape());
batch_component_shape.AppendShape(first_element_shape);
out_tensors->emplace_back(ctx->allocator({}), first_element.dtype(),
batch_component_shape);
if (!out_tensors->back().IsInitialized()) {
return errors::ResourceExhausted(
"Failed to allocate memory for the batch of component ",
component_index);
}
Tensor& batch_component = out_tensors->back();
// Build the output tuple component by copying one slice from each input
// element in the batch.
auto copy_element_fn = [component_index, &batch_elements,
&batch_component](int index) {
TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(
std::move((*batch_elements)[index][component_index]),
&batch_component, index));
return Status::OK();
};
Status status;
std::unique_ptr<BlockingCounter> counter;
std::unique_ptr<mutex> status_mu;
if (TF_PREDICT_FALSE(parallel_copy)) {
counter = std::make_unique<BlockingCounter>(num_batch_elements);
status_mu = std::make_unique<mutex>();
}
for (size_t i = 0; i < num_batch_elements; ++i) {
if ((*batch_elements)[i][component_index].shape() !=
first_element_shape) {
return errors::InvalidArgument(
"Cannot batch tensors with different shapes in component ",
component_index, ". First element had shape ",
first_element_shape.DebugString(), " and element ", i,
" had shape ",
(*batch_elements)[i][component_index].shape().DebugString(), ".");
}
if (TF_PREDICT_FALSE(parallel_copy)) {
(*ctx->runner())(
[i, &status, &status_mu, &counter, &copy_element_fn]() {
Status s = copy_element_fn(i);
{
mutex_lock l(*status_mu);
status.Update(s);
}
counter->DecrementCount();
});
} else {
status.Update(copy_element_fn(i));
}
}
if (TF_PREDICT_FALSE(parallel_copy)) {
counter->Wait();
}
TF_RETURN_IF_ERROR(status);
}
return Status::OK();
}
} // namespace data
} // namespace tensorflow

View File

@ -302,6 +302,40 @@ std::vector<tstring> SelectOptimizations(
// Removes device placements from the ops of all functions in `library`.
void StripDevicePlacement(FunctionDefLibrary* library);
// Copies partial of the batch output.
Status CopyPartialBatch(int64 num_elements, const Tensor& value,
Tensor* output);
// Reads a batch when restoring the iterator.
Status ReadBatch(int64 batch_size, const string& iterator_prefix,
const string& batch_prefix, IteratorContext* ctx,
IteratorStateReader* reader, std::vector<Tensor>* batch);
// Writes a batch when saving the iterator.
Status WriteBatch(int64 batch_size, int64 num_elements,
const string& iterator_prefix, const string& batch_prefix,
IteratorStateWriter* writer, std::vector<Tensor>* batch);
// Reads a status when restoring the iterator.
Status ReadStatus(const string& iterator_prefix, const string& prefix,
IteratorStateReader* reader, Status* status);
// Writes a status when saving the iterator.
Status WriteStatus(const string& iterator_prefix, const string& prefix,
const Status& status, IteratorStateWriter* writer);
// Processes a batch to output. In the case a partial batch is encountered, copy
// only partial of the batch.
Status ProcessBatch(int64 batch_size, int64 num_elements, bool drop_remainder,
const Status& status, IteratorContext* ctx,
std::vector<Tensor>* output, bool* end_of_sequence,
std::vector<Tensor>* batch);
// Copies the input elements to a batch.
Status CopyBatch(bool parallel_copy, IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
std::vector<std::vector<Tensor>>* batch_elements);
} // namespace data
} // namespace tensorflow

View File

@ -72,11 +72,7 @@ constexpr char kEndOfInput[] = "end_of_input";
constexpr char kNumCalls[] = "num_calls";
constexpr char kNumElements[] = "num_elements";
constexpr char kOutputAllocated[] = "output_allocated";
constexpr char kOutputSize[] = "output_size";
constexpr char kOutput[] = "output";
constexpr char kStatus[] = "status";
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";
// Computes ceil(x / y).
inline int64 CeilDiv(int64 x, int64 y) { return (x + y - 1) / y; }
@ -257,7 +253,17 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
return profiler::TraceMeEncode("MapAndBatchConsume",
{{"element_id", result->uid}});
});
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
// Deallocate tensors allocated for the output.
auto cleanup = gtl::MakeCleanup([result] { result->output.clear(); });
mutex_lock l(result->mu);
if (result->output_allocated) {
RecordBufferDequeue(ctx, result->output);
}
TF_RETURN_IF_ERROR(
ProcessBatch(dataset()->batch_size_, result->num_elements,
dataset()->drop_remainder_, result->status, ctx,
out_tensors, end_of_sequence, &result->output));
return Status::OK();
}
protected:
@ -477,27 +483,6 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
}
}
Status CopyPartialBatch(Tensor* output, const Tensor& value,
int64 num_elements) {
switch (value.dtype()) {
#define HANDLE_TYPE(type) \
case DataTypeToEnum<type>::value: { \
auto output_t = output->flat_outer_dims<type>(); \
auto value_t = value.flat_outer_dims<type>(); \
for (size_t i = 0; i < num_elements; i++) { \
output_t.template chip<0>(i) = value_t.template chip<0>(i); \
} \
return Status::OK(); \
}
TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
default:
return errors::InvalidArgument("Unsupported data type: ",
DataTypeString(value.dtype()));
}
return Status::OK();
}
void EnsureRunnerThreadStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
@ -536,60 +521,6 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
return Status::OK();
}
Status ProcessResult(IteratorContext* ctx,
const std::shared_ptr<BatchResult>& result,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
mutex_lock l(result->mu);
if (result->output_allocated) {
RecordBufferDequeue(ctx, result->output);
}
if (result->num_elements == 0) {
if (result->status.ok() || errors::IsOutOfRange(result->status)) {
*end_of_sequence = true;
return Status::OK();
} else {
*end_of_sequence = false;
return result->status;
}
}
if (!result->status.ok() && !errors::IsOutOfRange(result->status)) {
// Deallocate tensors allocated for the output.
result->output.clear();
*end_of_sequence = false;
return result->status;
}
if (result->num_elements < dataset()->batch_size_) {
if (dataset()->drop_remainder_) {
// Deallocate tensors allocated for the output.
result->output.clear();
*end_of_sequence = true;
return Status::OK();
}
const std::vector<Tensor>& output = result->output;
for (size_t i = 0; i < output.size(); ++i) {
TensorShape component_shape(result->output[i].shape());
component_shape.set_dim(0, result->num_elements);
AllocatorAttributes attr;
attr.set_gpu_compatible(true);
out_tensors->emplace_back(ctx->allocator(attr), output[i].dtype(),
component_shape);
if (!out_tensors->back().IsInitialized()) {
return errors::ResourceExhausted(
"Failed to allocate memory for the batch of component ", i);
}
TF_RETURN_IF_ERROR(CopyPartialBatch(&out_tensors->back(), output[i],
result->num_elements));
}
// Deallocate tensors allocated for the output.
result->output.clear();
} else {
*out_tensors = std::move(result->output);
}
*end_of_sequence = false;
return Status::OK();
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
TF_LOCKS_EXCLUDED(*mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
@ -660,116 +591,54 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
batch_results_.push_back(
std::make_shared<BatchResult>(dataset()->batch_size_));
std::shared_ptr<BatchResult> result = batch_results_.back();
string prefix = strings::StrCat(kBatchResults, "_", index);
string batch_prefix = strings::StrCat(kBatchResults, "_", index);
mutex_lock l(result->mu);
result->end_of_input = reader->Contains(
full_name(strings::StrCat(prefix, "_", kEndOfInput)));
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name(strings::StrCat(prefix, "_", kNumCalls)),
&result->num_calls));
full_name(strings::StrCat(batch_prefix, "_", kEndOfInput)));
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(prefix, "_", kNumElements)),
full_name(strings::StrCat(batch_prefix, "_", kNumCalls)),
&result->num_calls));
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(batch_prefix, "_", kNumElements)),
&result->num_elements));
result->output_allocated = reader->Contains(
full_name(strings::StrCat(prefix, "_", kOutputAllocated)));
int64 output_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(prefix, "_", kOutputSize)), &output_size));
result->output.reserve(output_size);
for (int i = 0; i < output_size; i++) {
Tensor t;
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat(prefix, "_", kOutput, "_", i)), &t));
// If the batch was not full, we may have stored only the relevant
// slice. Since tensors in `BatchResult.output` are expected to
// have the leading dimension of size batch_size, we build a larger
// tensor and copy the slice read from the checkpoint into it.
if (t.dim_size(0) < dataset()->batch_size_) {
TensorShape component_shape(t.shape());
component_shape.set_dim(0, dataset()->batch_size_);
AllocatorAttributes attr;
attr.set_gpu_compatible(true);
Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape);
TF_RETURN_IF_ERROR(CopyPartialBatch(&new_t, t, t.dim_size(0)));
result->output.emplace_back(std::move(new_t));
} else {
result->output.emplace_back(std::move(t));
}
}
TF_RETURN_IF_ERROR(ReadStatus(
reader, strings::StrCat(prefix, "_", kStatus), &result->status));
return Status::OK();
}
full_name(strings::StrCat(batch_prefix, "_", kOutputAllocated)));
Status ReadStatus(IteratorStateReader* reader, const string& prefix,
Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(prefix, "_", kCode)), &code_int));
error::Code code = static_cast<error::Code>(code_int);
if (code != error::Code::OK) {
tstring error_message;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(prefix, "_", kMessage)), &error_message));
*status = Status(code, error_message);
} else {
*status = Status::OK();
}
TF_RETURN_IF_ERROR(ReadBatch(dataset()->batch_size_, prefix(),
batch_prefix, ctx, reader, &result->output));
TF_RETURN_IF_ERROR(ReadStatus(prefix(),
strings::StrCat(batch_prefix, "_", kStatus),
reader, &result->status));
return Status::OK();
}
Status WriteBatchResult(IteratorStateWriter* writer, size_t index)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
std::shared_ptr<BatchResult> result = batch_results_[index];
string prefix = strings::StrCat(kBatchResults, "_", index);
string batch_prefix = strings::StrCat(kBatchResults, "_", index);
mutex_lock l(result->mu);
if (result->end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(prefix, "_", kEndOfInput)), ""));
full_name(strings::StrCat(batch_prefix, "_", kEndOfInput)), ""));
}
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(prefix, "_", kNumCalls)),
full_name(strings::StrCat(batch_prefix, "_", kNumCalls)),
result->num_calls));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(prefix, "_", kNumElements)),
full_name(strings::StrCat(batch_prefix, "_", kNumElements)),
result->num_elements));
if (result->output_allocated) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(prefix, "_", kOutputAllocated)), ""));
full_name(strings::StrCat(batch_prefix, "_", kOutputAllocated)),
""));
}
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(prefix, "_", kOutputSize)),
result->output.size()));
for (int i = 0; i < result->output.size(); i++) {
// If the batch is not full, we only store the first `num_elements`
// values. The rest of the batch tensor is *uninitialized* and
// accessing that will raise msan errors.
if (result->num_elements < dataset()->batch_size_) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat(prefix, "_", kOutput, "_", i)),
result->output[i].Slice(0, result->num_elements)));
} else {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat(prefix, "_", kOutput, "_", i)),
result->output[i]));
}
}
TF_RETURN_IF_ERROR(WriteStatus(
writer, strings::StrCat(prefix, "_", kStatus), result->status));
return Status::OK();
}
Status WriteStatus(IteratorStateWriter* writer, const string& prefix,
const Status& status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(WriteBatch(dataset()->batch_size_,
result->num_elements, prefix(),
batch_prefix, writer, &result->output));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(strings::StrCat(prefix, "_", kCode)),
static_cast<int64>(status.code())));
if (!status.ok()) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(prefix, "_", kMessage)),
status.error_message()));
}
WriteStatus(prefix(), strings::StrCat(batch_prefix, "_", kStatus),
result->status, writer));
return Status::OK();
}

View File

@ -0,0 +1,548 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/parallel_batch_dataset_op.h"
#include <algorithm>
#include <memory>
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env_time.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
namespace data {
/* static */ constexpr const char* const ParallelBatchDatasetOp::kDatasetType;
/* static */ constexpr const char* const ParallelBatchDatasetOp::kInputDataset;
/* static */ constexpr const char* const ParallelBatchDatasetOp::kBatchSize;
/* static */ constexpr const char* const
ParallelBatchDatasetOp::kNumParallelCalls;
/* static */ constexpr const char* const ParallelBatchDatasetOp::kDropRemainder;
/* static */ constexpr const char* const ParallelBatchDatasetOp::kOutputTypes;
/* static */ constexpr const char* const ParallelBatchDatasetOp::kOutputShapes;
namespace {
constexpr char kBatchResultsSize[] = "batch_results_size";
constexpr char kTFDataParallelBatch[] = "tf_data_parallel_batch";
constexpr char kBatchResults[] = "batch_results";
constexpr char kEndOfInput[] = "end_of_input";
constexpr char kNumElements[] = "num_elements";
constexpr char kCallFinished[] = "call_finished";
constexpr char kStatus[] = "status";
} // namespace
class ParallelBatchDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 batch_size, int64 num_parallel_calls,
bool drop_remainder, const DatasetBase* input)
: DatasetBase(DatasetContext(ctx)),
batch_size_(batch_size),
// Dataset batch is sometimes used to stack all elements in the
// dataset. In such cases, a very large batch size (e.g., INT32_MAX)
// is passed with drop_remainder set to false. Avoid OOM in such case
// by limiting `reserve()` size by 2**16.
reserve_size_(drop_remainder ? batch_size
: std::min<int64>(batch_size, 1 << 16)),
num_parallel_calls_(num_parallel_calls),
drop_remainder_(drop_remainder),
input_(input),
traceme_metadata_(
{{"autotune",
num_parallel_calls == model::kAutotune ? "true" : "false"},
{"batch_size",
strings::Printf("%lld", static_cast<long long>(batch_size))},
{"drop_remainder", drop_remainder ? "true" : "false"}}) {
input_->Ref();
const auto& input_shapes = input_->output_shapes();
output_shapes_.reserve(input_shapes.size());
for (const auto& input_shape : input_shapes) {
if (drop_remainder_ || input_->Cardinality() == kInfiniteCardinality) {
output_shapes_.emplace_back(
PartialTensorShape({batch_size_}).Concatenate(input_shape));
} else {
output_shapes_.emplace_back(
PartialTensorShape({-1}).Concatenate(input_shape));
}
}
}
~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix)});
}
const DataTypeVector& output_dtypes() const override {
return input_->output_dtypes();
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
string DebugString() const override {
name_utils::DatasetDebugStringParams params;
params.set_args(batch_size_);
return name_utils::DatasetDebugString(kDatasetType, params);
}
int64 Cardinality() const override {
int64 n = input_->Cardinality();
if (n == kInfiniteCardinality || n == kUnknownCardinality) {
return n;
}
return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
}
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
inputs->push_back(input_);
return Status::OK();
}
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
// Input: input_dataset
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
// Input: batch_size
Node* batch_size = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
// Input: num_parallel_calls
Node* num_parallel_calls = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(num_parallel_calls_, &num_parallel_calls));
// Input: drop_remainder
Node* drop_remainder = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));
TF_RETURN_IF_ERROR(b->AddDataset(
this,
{input_graph_node, batch_size, num_parallel_calls, drop_remainder}, {},
output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
mu_(std::make_shared<mutex>()),
cond_var_(std::make_shared<condition_variable>()),
num_parallel_calls_(std::make_shared<model::SharedState>(
params.dataset->num_parallel_calls_, mu_, cond_var_)) {}
~Iterator() override {
CancelThreads(/*wait=*/true);
if (deregister_fn_) deregister_fn_();
}
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(*mu_);
if (num_parallel_calls_->value == model::kAutotune) {
num_parallel_calls_->value = 1;
}
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
std::shared_ptr<BatchResult> result;
{
mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (!cancelled_ && (batch_results_.empty() ||
!batch_results_.front()->call_finished)) {
++waiting_;
RecordStop(ctx);
cond_var_->wait(l);
RecordStart(ctx);
--waiting_;
}
if (cancelled_) {
return errors::Cancelled("Iterator was cancelled");
}
std::swap(result, batch_results_.front());
batch_results_.pop_front();
cond_var_->notify_all();
}
profiler::TraceMe traceme([&] {
return profiler::TraceMeEncode("ParallelBatchConsume",
{{"element_id", result->uid}});
});
mutex_lock l(result->mu);
// Deallocate tensors allocated for the output.
auto cleanup =
gtl::MakeCleanup([result]() TF_EXCLUSIVE_LOCKS_REQUIRED(
&BatchResult::mu) { result->output.clear(); });
RecordBufferDequeue(ctx, result->output);
TF_RETURN_IF_ERROR(
ProcessBatch(dataset()->batch_size_, result->num_elements,
dataset()->drop_remainder_, result->status, ctx,
out_tensors, end_of_sequence, &result->output));
return Status::OK();
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeAsyncKnownRatioNode(
std::move(args),
/*ratio=*/dataset()->batch_size_, /*memory_ratio=*/1.0,
{model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
/*max=*/ctx->runner_threadpool_size())});
}
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
cond_var_->wait(l);
}
DCHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kBatchResultsSize),
batch_results_.size()));
for (size_t i = 0; i < batch_results_.size(); ++i) {
TF_RETURN_IF_ERROR(WriteBatchResult(writer, i));
}
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 batch_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kBatchResultsSize),
&batch_results_size));
for (int i = 0; i < batch_results_size; ++i) {
TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
}
return Status::OK();
}
TraceMeMetadata GetTraceMeMetadata() const override {
int64 parallelism = -1;
// NOTE: We only set the parallelism value if the lock can be acquired
// right away to avoid introducing tracing overhead.
if (mu_->try_lock()) {
parallelism = num_parallel_calls_->value;
mu_->unlock();
}
auto result = dataset()->traceme_metadata_;
result.push_back(std::make_pair(
"parallelism",
strings::Printf("%lld", static_cast<long long>(parallelism))));
return result;
}
// BatchResult encapsulates the output batch.
struct BatchResult {
explicit BatchResult()
: end_of_input(false),
num_elements(0),
status(Status::OK()),
call_finished(false),
uid(tensorflow::EnvTime::NowNanos()) {}
mutex mu;
bool end_of_input TF_GUARDED_BY(mu);
int64 num_elements TF_GUARDED_BY(mu);
std::vector<Tensor> output TF_GUARDED_BY(mu);
Status status TF_GUARDED_BY(mu);
bool call_finished TF_GUARDED_BY(&Iterator::mu_);
const int64 uid = -1;
};
void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<BatchResult>& result)
TF_LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
num_calls_--;
result->call_finished = true;
cond_var_->notify_all();
}
// The function fetches elements from input dataset sequentially and then
// executes the batching for different batches in parallel using the context
// runner.
void CallBatching(std::shared_ptr<IteratorContext> ctx,
const std::shared_ptr<BatchResult>& result)
TF_LOCKS_EXCLUDED(*mu_) {
profiler::TraceMe traceme([&] {
return profiler::TraceMeEncode("ParallelBatchProduce",
{{"element_id", result->uid}});
});
if (!input_impl_) {
CallCompleted(ctx, result);
return;
}
// Each row of `batch_elements` is a tuple of tensors from the input
// iterator.
auto batch_elements =
std::make_shared<std::vector<std::vector<Tensor>>>();
batch_elements->reserve(dataset()->reserve_size_);
bool end_of_input = false;
for (int i = 0; i < dataset()->batch_size_ && !end_of_input; ++i) {
std::vector<Tensor> batch_element_tuple;
Status status = input_impl_->GetNext(ctx.get(), &batch_element_tuple,
&end_of_input);
{
mutex_lock l(result->mu);
result->end_of_input = result->end_of_input || end_of_input;
result->status.Update(status);
if (result->end_of_input || !result->status.ok()) break;
}
if (!end_of_input) {
batch_elements->emplace_back(std::move(batch_element_tuple));
mutex_lock l(result->mu);
result->num_elements++;
} else {
input_impl_.reset();
}
}
if (batch_elements->empty()) {
CallCompleted(ctx, result);
DCHECK(end_of_input);
return;
}
auto copy_elements_fn = [this, ctx, result, batch_elements]() {
Status status;
{
mutex_lock l(result->mu);
status = CopyBatch(/*parallel_copy=*/false, ctx.get(),
&result->output, batch_elements.get());
result->status.Update(status);
RecordBufferEnqueue(ctx.get(), result->output);
}
CallCompleted(ctx, result);
return status;
};
(*ctx->runner())(copy_elements_fn);
}
void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(*mu_);
cancelled_ = true;
cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (wait && num_calls_ > 0) {
cond_var_->wait(l);
}
}
void EnsureRunnerThreadStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread(
kTFDataParallelBatch,
std::bind(&Iterator::RunnerThread, this, ctx_copy));
}
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
TF_LOCKS_EXCLUDED(*mu_) {
std::vector<std::shared_ptr<BatchResult>> new_calls;
RecordStart(ctx.get());
auto stop_cleanup =
gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
{
tf_shared_lock l(*mu_); // mu_ == num_parallel_calls_->mu
new_calls.reserve(num_parallel_calls_->value);
}
auto busy = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
return num_calls_ >= num_parallel_calls_->value;
};
while (true) {
{
mutex_lock l(*mu_);
while (!cancelled_ && busy()) {
RecordStop(ctx.get());
cond_var_->wait(l);
RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
while (!busy()) {
batch_results_.push_back(std::make_shared<BatchResult>());
new_calls.emplace_back(batch_results_.back());
num_calls_++;
}
}
for (const auto& call : new_calls) {
CallBatching(ctx, call);
}
new_calls.clear();
}
}
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
size_t index) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
batch_results_.push_back(std::make_shared<BatchResult>());
std::shared_ptr<BatchResult> result = batch_results_.back();
string batch_prefix = strings::StrCat(kBatchResults, "_", index);
mutex_lock l(result->mu);
result->end_of_input = reader->Contains(
full_name(strings::StrCat(batch_prefix, "_", kEndOfInput)));
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(batch_prefix, "_", kNumElements)),
&result->num_elements));
result->call_finished = reader->Contains(
full_name(strings::StrCat(batch_prefix, "_", kCallFinished)));
TF_RETURN_IF_ERROR(ReadBatch(dataset()->batch_size_, prefix(),
batch_prefix, ctx, reader, &result->output));
TF_RETURN_IF_ERROR(ReadStatus(prefix(),
strings::StrCat(batch_prefix, "_", kStatus),
reader, &result->status));
return Status::OK();
}
Status WriteBatchResult(IteratorStateWriter* writer, size_t index)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
std::shared_ptr<BatchResult> result = batch_results_[index];
string batch_prefix = strings::StrCat(kBatchResults, "_", index);
mutex_lock l(result->mu);
if (result->end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(batch_prefix, "_", kEndOfInput)), ""));
}
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(batch_prefix, "_", kNumElements)),
result->num_elements));
if (result->call_finished) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(batch_prefix, "_", kCallFinished)), ""));
}
TF_RETURN_IF_ERROR(WriteBatch(dataset()->batch_size_,
result->num_elements, prefix(),
batch_prefix, writer, &result->output));
TF_RETURN_IF_ERROR(
WriteStatus(prefix(), strings::StrCat(batch_prefix, "_", kStatus),
result->status, writer));
return Status::OK();
}
// Used for coordination between the main thread and the runner thread.
const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread and the runner thread. In
// particular, the runner thread should only schedule new calls when the
// number of in-flight calls is less than the user specified level of
// parallelism and there are slots available in the `invocation_results_`
// buffer.
const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Counts the number of outstanding calls for this batch.
int64 num_calls_ TF_GUARDED_BY(*mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the (intermediate) batch results.
std::deque<std::shared_ptr<BatchResult>> batch_results_ TF_GUARDED_BY(*mu_);
// Background thread used for coordinating input processing.
std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
// Determines whether the transformation has been cancelled.
bool cancelled_ TF_GUARDED_BY(*mu_) = false;
// Identifies the number of callers currently waiting for a batch result.
int64 waiting_ TF_GUARDED_BY(*mu_) = 0;
// Method for deregistering the cancellation callback.
std::function<void()> deregister_fn_;
};
const int64 batch_size_;
const int64 reserve_size_;
const int64 num_parallel_calls_;
const bool drop_remainder_;
const DatasetBase* const input_;
std::vector<PartialTensorShape> output_shapes_;
const TraceMeMetadata traceme_metadata_;
};
ParallelBatchDatasetOp::ParallelBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {}
void ParallelBatchDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase* input,
DatasetBase** output) {
int64 batch_size = 0;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kBatchSize, &batch_size));
OP_REQUIRES(ctx, batch_size > 0,
errors::InvalidArgument("Batch size must be greater than zero."));
int64 num_parallel_calls = 0;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kNumParallelCalls,
&num_parallel_calls));
bool drop_remainder = false;
OP_REQUIRES_OK(
ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));
*output =
new Dataset(ctx, batch_size, num_parallel_calls, drop_remainder, input);
}
namespace {
REGISTER_KERNEL_BUILDER(Name("ParallelBatchDataset").Device(DEVICE_CPU),
ParallelBatchDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -0,0 +1,47 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_BATCH_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_BATCH_DATASET_OP_H_
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
namespace tensorflow {
namespace data {
class ParallelBatchDatasetOp : public UnaryDatasetOpKernel {
public:
static constexpr const char* const kDatasetType = "ParallelBatch";
static constexpr const char* const kInputDataset = "input_dataset";
static constexpr const char* const kBatchSize = "batch_size";
static constexpr const char* const kNumParallelCalls = "num_parallel_calls";
static constexpr const char* const kDropRemainder = "drop_remainder";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";
explicit ParallelBatchDatasetOp(OpKernelConstruction* ctx);
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override;
private:
class Dataset;
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_BATCH_DATASET_OP_H_

View File

@ -0,0 +1,416 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/parallel_batch_dataset_op.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kNodeName[] = "parallel_batch_dataset";
constexpr int kOpVersion = 1;
class ParallelBatchDatasetParams : public DatasetParams {
public:
template <typename T>
ParallelBatchDatasetParams(T input_dataset_params, int64 batch_size,
int64 num_parallel_calls, bool drop_remainder,
DataTypeVector output_dtypes,
std::vector<PartialTensorShape> output_shapes,
string node_name)
: DatasetParams(std::move(output_dtypes), std::move(output_shapes),
std::move(node_name)),
batch_size_(batch_size),
num_parallel_calls_(num_parallel_calls),
drop_remainder_(drop_remainder) {
input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
op_version_ = kOpVersion;
iterator_prefix_ =
name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
input_dataset_params.iterator_prefix());
}
std::vector<Tensor> GetInputTensors() const override {
Tensor batch_size = CreateTensor<int64>(TensorShape({}), {batch_size_});
Tensor num_parallel_calls =
CreateTensor<int64>(TensorShape({}), {num_parallel_calls_});
Tensor drop_remainder =
CreateTensor<bool>(TensorShape({}), {drop_remainder_});
return {batch_size, num_parallel_calls, drop_remainder};
}
Status GetInputNames(std::vector<string>* input_names) const override {
*input_names = {ParallelBatchDatasetOp::kInputDataset,
ParallelBatchDatasetOp::kBatchSize,
ParallelBatchDatasetOp::kNumParallelCalls,
ParallelBatchDatasetOp::kDropRemainder};
return Status::OK();
}
Status GetAttributes(AttributeVector* attr_vector) const override {
*attr_vector = {{ParallelBatchDatasetOp::kOutputTypes, output_dtypes_},
{ParallelBatchDatasetOp::kOutputShapes, output_shapes_}};
return Status::OK();
};
string dataset_type() const override {
return ParallelBatchDatasetOp::kDatasetType;
}
private:
int64 batch_size_;
int64 num_parallel_calls_;
bool drop_remainder_;
};
class ParallelBatchDatasetOpTest : public DatasetOpsTestBase {};
// Test Case 1: test ParallelBatchDataset with `drop_remainder` = false and a
// batch size that can evenly split the input dataset.
ParallelBatchDatasetParams ParallelBatchDatasetParams1() {
return ParallelBatchDatasetParams(RangeDatasetParams(0, 12, 1),
/*batch_size=*/4,
/*num_parallel_calls=*/1,
/*drop_remainder=*/false,
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({4})},
/*node_name=*/kNodeName);
}
// Test Case 2: test ParallelBatchDataset with `drop_remainder` = true and a
// batch size that can evenly split the input dataset.
ParallelBatchDatasetParams ParallelBatchDatasetParams2() {
return ParallelBatchDatasetParams(RangeDatasetParams(0, 12, 1),
/*batch_size=*/4,
/*num_parallel_calls=*/1,
/*drop_remainder=*/true,
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({4})},
/*node_name=*/kNodeName);
}
// Test Case 3: test ParallelBatchDataset with `drop_remainder` = false and a
// batch size that can not evenly split the input dataset.
ParallelBatchDatasetParams ParallelBatchDatasetParams3() {
return ParallelBatchDatasetParams(
RangeDatasetParams(0, 10, 1),
/*batch_size=*/3,
/*num_parallel_calls=*/1,
/*drop_remainder=*/false,
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({-1})},
/*node_name=*/kNodeName);
}
// Test Case 4: test ParallelBatchDataset with `drop_remainder` = true and a
// batch size that can not evenly split the input dataset.
ParallelBatchDatasetParams ParallelBatchDatasetParams4() {
return ParallelBatchDatasetParams(RangeDatasetParams(0, 10, 1),
/*batch_size=*/3,
/*num_parallel_calls=*/1,
/*drop_remainder=*/true,
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({3})},
/*node_name=*/kNodeName);
}
// Test Case 5: test ParallelBatchDataset with `drop_remainder` = true and
// `batch_size` > the cardinality of the input dataset.
ParallelBatchDatasetParams ParallelBatchDatasetParams5() {
return ParallelBatchDatasetParams(
RangeDatasetParams(0, 10, 1),
/*batch_size=*/12,
/*num_parallel_calls=*/1,
/*drop_remainder=*/true,
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({12})},
/*node_name=*/kNodeName);
}
// Test Case 6: test ParallelBatchDataset with `drop_remainder` = false and
// `batch_size` > the cardinality of the input dataset.
ParallelBatchDatasetParams ParallelBatchDatasetParams6() {
return ParallelBatchDatasetParams(
RangeDatasetParams(0, 10, 1),
/*batch_size=*/12,
/*num_parallel_calls=*/1,
/*drop_remainder=*/false,
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({-1})},
/*node_name=*/kNodeName);
}
// Test Case 7: test ParallelBatchDataset with `drop_remainder` = false and
// the output of the input dataset is empty.
ParallelBatchDatasetParams ParallelBatchDatasetParams7() {
return ParallelBatchDatasetParams(RangeDatasetParams(0, 0, 1),
/*batch_size=*/4,
/*num_parallel_calls=*/1,
/*drop_remainder=*/false,
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({4})},
/*node_name=*/kNodeName);
}
// Test Case 8: test ParallelBatchDataset with `num_parallel_calls` = 2.
ParallelBatchDatasetParams ParallelBatchDatasetParams8() {
return ParallelBatchDatasetParams(RangeDatasetParams(0, 12, 1),
/*batch_size=*/4,
/*num_parallel_calls=*/2,
/*drop_remainder=*/false,
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({4})},
/*node_name=*/kNodeName);
}
// Test Case 9: test ParallelBatchDataset with `num_parallel_calls` = 4.
ParallelBatchDatasetParams ParallelBatchDatasetParams9() {
return ParallelBatchDatasetParams(RangeDatasetParams(0, 12, 1),
/*batch_size=*/4,
/*num_parallel_calls=*/4,
/*drop_remainder=*/false,
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({4})},
/*node_name=*/kNodeName);
}
// Test Case 10: test ParallelBatchDataset with an invalid batch size.
ParallelBatchDatasetParams InvalidBatchSizeParallelBatchDatasetParams() {
return ParallelBatchDatasetParams(RangeDatasetParams(0, 10, 1),
/*batch_size=*/-1,
/*num_parallel_calls=*/1,
/*drop_remainder=*/false,
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({3})},
/*node_name=*/kNodeName);
}
std::vector<GetNextTestCase<ParallelBatchDatasetParams>> GetNextTestCases() {
return {{/*dataset_params=*/ParallelBatchDatasetParams1(),
/*expected_outputs=*/
CreateTensors<int64>(TensorShape({4}),
{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
{/*dataset_params=*/ParallelBatchDatasetParams2(),
/*expected_outputs=*/
CreateTensors<int64>(TensorShape({4}),
{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
{/*dataset_params=*/ParallelBatchDatasetParams3(),
/*expected_outputs=*/
{CreateTensor<int64>(TensorShape({3}), {0, 1, 2}),
CreateTensor<int64>(TensorShape({3}), {3, 4, 5}),
CreateTensor<int64>(TensorShape({3}), {6, 7, 8}),
CreateTensor<int64>(TensorShape({1}), {9})}},
{/*dataset_params=*/ParallelBatchDatasetParams4(),
/*expected_outputs=*/
CreateTensors<int64>(TensorShape({3}),
{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}})},
{/*dataset_params=*/ParallelBatchDatasetParams5(),
/*expected_outputs=*/{}},
{/*dataset_params=*/ParallelBatchDatasetParams6(),
/*expected_outputs=*/
CreateTensors<int64>(TensorShape({10}),
{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}})},
{/*dataset_params=*/ParallelBatchDatasetParams7(),
/*expected_outputs=*/{}},
{/*dataset_params=*/ParallelBatchDatasetParams8(),
/*expected_outputs=*/
CreateTensors<int64>(TensorShape({4}),
{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
{/*dataset_params=*/ParallelBatchDatasetParams9(),
/*expected_outputs=*/
CreateTensors<int64>(TensorShape({4}),
{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})}};
}
ITERATOR_GET_NEXT_TEST_P(ParallelBatchDatasetOpTest, ParallelBatchDatasetParams,
GetNextTestCases())
TEST_F(ParallelBatchDatasetOpTest, DatasetNodeName) {
auto parallel_batch_dataset_params = ParallelBatchDatasetParams1();
TF_ASSERT_OK(Initialize(parallel_batch_dataset_params));
TF_ASSERT_OK(CheckDatasetNodeName(parallel_batch_dataset_params.node_name()));
}
TEST_F(ParallelBatchDatasetOpTest, DatasetTypeString) {
auto parallel_batch_dataset_params = ParallelBatchDatasetParams1();
TF_ASSERT_OK(Initialize(parallel_batch_dataset_params));
name_utils::OpNameParams params;
params.op_version = parallel_batch_dataset_params.op_version();
TF_ASSERT_OK(CheckDatasetTypeString(
name_utils::OpName(ParallelBatchDatasetOp::kDatasetType, params)));
}
TEST_F(ParallelBatchDatasetOpTest, DatasetOutputDtypes) {
auto parallel_batch_dataset_params = ParallelBatchDatasetParams1();
TF_ASSERT_OK(Initialize(parallel_batch_dataset_params));
TF_ASSERT_OK(CheckDatasetOutputDtypes({DT_INT64}));
}
std::vector<DatasetOutputShapesTestCase<ParallelBatchDatasetParams>>
DatasetOutputShapesTestCases() {
return {{/*dataset_params=*/ParallelBatchDatasetParams1(),
/*expected_output_shapes=*/{PartialTensorShape({4})}},
{/*dataset_params=*/ParallelBatchDatasetParams2(),
/*expected_output_shapes=*/{PartialTensorShape({4})}},
{/*dataset_params=*/ParallelBatchDatasetParams3(),
/*expected_output_shapes=*/{PartialTensorShape({-1})}},
{/*dataset_params=*/ParallelBatchDatasetParams4(),
/*expected_output_shapes=*/{PartialTensorShape({3})}},
{/*dataset_params=*/ParallelBatchDatasetParams5(),
/*expected_output_shapes=*/{PartialTensorShape({12})}},
{/*dataset_params=*/ParallelBatchDatasetParams6(),
/*expected_output_shapes=*/{PartialTensorShape({-1})}},
{/*dataset_params=*/ParallelBatchDatasetParams7(),
/*expected_output_shapes=*/{PartialTensorShape({4})}},
{/*dataset_params=*/ParallelBatchDatasetParams8(),
/*expected_output_shapes=*/{PartialTensorShape({4})}},
{/*dataset_params=*/ParallelBatchDatasetParams9(),
/*expected_output_shapes=*/{PartialTensorShape({4})}}};
}
DATASET_OUTPUT_SHAPES_TEST_P(ParallelBatchDatasetOpTest,
ParallelBatchDatasetParams,
DatasetOutputShapesTestCases())
std::vector<CardinalityTestCase<ParallelBatchDatasetParams>>
CardinalityTestCases() {
return {{/*dataset_params=*/ParallelBatchDatasetParams1(),
/*expected_cardinality=*/3},
{/*dataset_params=*/ParallelBatchDatasetParams2(),
/*expected_cardinality=*/3},
{/*dataset_params=*/ParallelBatchDatasetParams3(),
/*expected_cardinality=*/4},
{/*dataset_params=*/ParallelBatchDatasetParams4(),
/*expected_cardinality=*/3},
{/*dataset_params=*/ParallelBatchDatasetParams5(),
/*expected_cardinality=*/0},
{/*dataset_params=*/ParallelBatchDatasetParams6(),
/*expected_cardinality=*/1},
{/*dataset_params=*/ParallelBatchDatasetParams7(),
/*expected_cardinality=*/0},
{/*dataset_params=*/ParallelBatchDatasetParams8(),
/*expected_cardinality=*/3},
{/*dataset_params=*/ParallelBatchDatasetParams9(),
/*expected_cardinality=*/3}};
}
DATASET_CARDINALITY_TEST_P(ParallelBatchDatasetOpTest,
ParallelBatchDatasetParams, CardinalityTestCases())
TEST_F(ParallelBatchDatasetOpTest, IteratorOutputDtypes) {
auto parallel_batch_dataset_params = ParallelBatchDatasetParams1();
TF_ASSERT_OK(Initialize(parallel_batch_dataset_params));
TF_ASSERT_OK(CheckIteratorOutputDtypes({DT_INT64}));
}
std::vector<IteratorOutputShapesTestCase<ParallelBatchDatasetParams>>
IteratorOutputShapesTestCases() {
return {{/*dataset_params=*/ParallelBatchDatasetParams1(),
/*expected_output_shapes=*/{PartialTensorShape({4})}},
{/*dataset_params=*/ParallelBatchDatasetParams2(),
/*expected_output_shapes=*/{PartialTensorShape({4})}},
{/*dataset_params=*/ParallelBatchDatasetParams3(),
/*expected_output_shapes=*/{PartialTensorShape({-1})}},
{/*dataset_params=*/ParallelBatchDatasetParams4(),
/*expected_output_shapes=*/{PartialTensorShape({3})}},
{/*dataset_params=*/ParallelBatchDatasetParams5(),
/*expected_output_shapes=*/{PartialTensorShape({12})}},
{/*dataset_params=*/ParallelBatchDatasetParams6(),
/*expected_output_shapes=*/{PartialTensorShape({-1})}},
{/*dataset_params=*/ParallelBatchDatasetParams7(),
/*expected_output_shapes=*/{PartialTensorShape({4})}},
{/*dataset_params=*/ParallelBatchDatasetParams8(),
/*expected_output_shapes=*/{PartialTensorShape({4})}},
{/*dataset_params=*/ParallelBatchDatasetParams9(),
/*expected_output_shapes=*/{PartialTensorShape({4})}}};
}
ITERATOR_OUTPUT_SHAPES_TEST_P(ParallelBatchDatasetOpTest,
ParallelBatchDatasetParams,
IteratorOutputShapesTestCases())
TEST_F(ParallelBatchDatasetOpTest, IteratorOutputPrefix) {
auto parallel_batch_dataset_params = ParallelBatchDatasetParams1();
TF_ASSERT_OK(Initialize(parallel_batch_dataset_params));
name_utils::IteratorPrefixParams params;
params.op_version = parallel_batch_dataset_params.op_version();
TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
ParallelBatchDatasetOp::kDatasetType,
parallel_batch_dataset_params.iterator_prefix(), params)));
}
std::vector<IteratorSaveAndRestoreTestCase<ParallelBatchDatasetParams>>
IteratorSaveAndRestoreTestCases() {
return {{/*dataset_params=*/ParallelBatchDatasetParams1(),
/*breakpoints=*/{0, 1, 5},
/*expected_outputs=*/
CreateTensors<int64>(TensorShape({4}),
{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
{/*dataset_params=*/ParallelBatchDatasetParams2(),
/*breakpoints=*/{0, 1, 5},
/*expected_outputs=*/
CreateTensors<int64>(TensorShape({4}),
{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
{/*dataset_params=*/ParallelBatchDatasetParams3(),
/*breakpoints=*/{0, 1, 5},
/*expected_outputs=*/
{CreateTensor<int64>(TensorShape({3}), {0, 1, 2}),
CreateTensor<int64>(TensorShape({3}), {3, 4, 5}),
CreateTensor<int64>(TensorShape({3}), {6, 7, 8}),
CreateTensor<int64>(TensorShape({1}), {9})}},
{/*dataset_params=*/ParallelBatchDatasetParams4(),
/*breakpoints=*/{0, 1, 5},
/*expected_outputs=*/
{CreateTensor<int64>(TensorShape({3}), {0, 1, 2}),
CreateTensor<int64>(TensorShape({3}), {3, 4, 5}),
CreateTensor<int64>(TensorShape({3}), {6, 7, 8})}},
{/*dataset_params=*/ParallelBatchDatasetParams5(),
/*breakpoints=*/{0, 1, 5},
/*expected_outputs=*/{}},
{/*dataset_params=*/ParallelBatchDatasetParams6(),
/*breakpoints=*/{0, 1, 5},
/*expected_outputs=*/
{CreateTensor<int64>(TensorShape({10}),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}},
{/*dataset_params=*/ParallelBatchDatasetParams7(),
/*breakpoints=*/{0, 1, 5},
/*expected_outputs=*/{}},
{/*dataset_params=*/ParallelBatchDatasetParams8(),
/*breakpoints=*/{0, 1, 5},
/*expected_outputs=*/
CreateTensors<int64>(TensorShape({4}),
{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
{/*dataset_params=*/ParallelBatchDatasetParams9(),
/*breakpoints=*/{0, 1, 5},
/*expected_outputs=*/
CreateTensors<int64>(TensorShape({4}),
{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})}};
}
ITERATOR_SAVE_AND_RESTORE_TEST_P(ParallelBatchDatasetOpTest,
ParallelBatchDatasetParams,
IteratorSaveAndRestoreTestCases())
TEST_F(ParallelBatchDatasetOpTest, InvalidParallelBatchSize) {
auto parallel_batch_dataset_params =
InvalidBatchSizeParallelBatchDatasetParams();
EXPECT_EQ(Initialize(parallel_batch_dataset_params).code(),
tensorflow::error::INVALID_ARGUMENT);
}
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -328,6 +328,25 @@ REGISTER_OP("BatchDatasetV2")
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ParallelBatchDataset")
.Input("input_dataset: variant")
.Input("batch_size: int64")
.Input("num_parallel_calls: int64")
.Input("drop_remainder: bool")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// batch_size should be a scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
// num_parallel_calls should be a scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
// drop_remainder should be a scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ShardDataset")
.Input("input_dataset: variant")
.Input("num_shards: int64")

View File

@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import random_ops
class BatchBenchmark(benchmark_base.DatasetBenchmarkBase):
@ -67,6 +68,30 @@ class BatchBenchmark(benchmark_base.DatasetBenchmarkBase):
name="batch_element_size_%d_batch_size_%d%s" %
(element_size, batch_size, tag))
def benchmark_parallel_batch(self):
batch_size = 128
nums_parallel_calls = [None, 1, 4, 16, dataset_ops.AUTOTUNE]
num_range = 100000
def f(_):
return random_ops.random_uniform([224, 224, 3])
for num_parallel_calls in nums_parallel_calls:
num_parallel_calls_str = ("autotune"
if num_parallel_calls == dataset_ops.AUTOTUNE
else str(num_parallel_calls))
op_str = ("batch" if num_parallel_calls is None else
("parallel_batch_num_parallel_calls_%s" %
num_parallel_calls_str))
dataset = dataset_ops.Dataset.range(num_range).map(f).batch(
batch_size, num_parallel_calls=num_parallel_calls)
self.run_and_report_benchmark(
dataset,
num_elements=num_range // batch_size,
iters=1,
name="batch_size_%d_%s" % (batch_size, op_str))
if __name__ == "__main__":
benchmark_base.test.main()

View File

@ -43,9 +43,11 @@ class BatchTest(test_base.DatasetTestBase, parameterized.TestCase):
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
count=[0, 28], batch_size=[14, 15], drop_remainder=[True,
False])))
def testBasic(self, count, batch_size, drop_remainder):
count=[0, 28],
batch_size=[14, 15],
drop_remainder=[True, False],
num_parallel_calls=[None, 1, 2, 4])))
def testBasic(self, count, batch_size, drop_remainder, num_parallel_calls):
"""Tests the batch dataset logic for various input configurations.
Args:
@ -53,6 +55,8 @@ class BatchTest(test_base.DatasetTestBase, parameterized.TestCase):
batch_size: the batch size
drop_remainder: whether a smaller batch size should be produced if batch
size does not divide number of inputs evenly
num_parallel_calls: the number batches to process asynchronously in
parallel
"""
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
@ -65,7 +69,8 @@ class BatchTest(test_base.DatasetTestBase, parameterized.TestCase):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
_map_fn).repeat(count).batch(batch_size, drop_remainder)
_map_fn).repeat(count).batch(batch_size, drop_remainder,
num_parallel_calls)
get_next = self.getNext(dataset)
if drop_remainder:

View File

@ -1505,7 +1505,7 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
"""
return ShardDataset(self, num_shards, index)
def batch(self, batch_size, drop_remainder=False):
def batch(self, batch_size, drop_remainder=False, num_parallel_calls=None):
"""Combines consecutive elements of this dataset into batches.
>>> dataset = tf.data.Dataset.range(8)
@ -1532,11 +1532,21 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
whether the last batch should be dropped in the case it has fewer than
`batch_size` elements; the default behavior is not to drop the smaller
batch.
num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`,
representing the number of batches to compute asynchronously in
parallel.
If not specified, batches will be computed sequentially. If the value
`tf.data.AUTOTUNE` is used, then the number of parallel
calls is set dynamically based on available resources.
Returns:
Dataset: A `Dataset`.
"""
return BatchDataset(self, batch_size, drop_remainder)
if num_parallel_calls is None:
return BatchDataset(self, batch_size, drop_remainder)
else:
return ParallelBatchDataset(self, batch_size, drop_remainder,
num_parallel_calls)
def padded_batch(self,
batch_size,
@ -2613,9 +2623,10 @@ class DatasetV1(DatasetV2):
return DatasetV1Adapter(super(DatasetV1, self).shard(num_shards, index))
@functools.wraps(DatasetV2.batch)
def batch(self, batch_size, drop_remainder=False):
return DatasetV1Adapter(super(DatasetV1, self).batch(
batch_size, drop_remainder))
def batch(self, batch_size, drop_remainder=False, num_parallel_calls=None):
return DatasetV1Adapter(
super(DatasetV1, self).batch(batch_size, drop_remainder,
num_parallel_calls))
@functools.wraps(DatasetV2.padded_batch)
def padded_batch(self,
@ -3936,6 +3947,47 @@ class BatchDataset(UnaryDataset):
return self._structure
class ParallelBatchDataset(UnaryDataset):
"""A `Dataset` that batches contiguous elements from its input in parallel."""
def __init__(self, input_dataset, batch_size, drop_remainder,
num_parallel_calls):
"""See `Dataset.batch()` for details."""
self._input_dataset = input_dataset
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
self._drop_remainder = ops.convert_to_tensor(
drop_remainder, dtype=dtypes.bool, name="drop_remainder")
self._num_parallel_calls = ops.convert_to_tensor(
num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
constant_drop_remainder = tensor_util.constant_value(self._drop_remainder)
# pylint: disable=protected-access
if constant_drop_remainder:
# NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
# or `False` (explicitly retaining the remainder).
# pylint: disable=g-long-lambda
constant_batch_size = tensor_util.constant_value(self._batch_size)
self._structure = nest.map_structure(
lambda component_spec: component_spec._batch(constant_batch_size),
input_dataset.element_spec)
else:
self._structure = nest.map_structure(
lambda component_spec: component_spec._batch(None),
input_dataset.element_spec)
variant_tensor = gen_dataset_ops.parallel_batch_dataset(
input_dataset._variant_tensor,
batch_size=self._batch_size,
num_parallel_calls=self._num_parallel_calls,
drop_remainder=self._drop_remainder,
**self._flat_structure)
super(ParallelBatchDataset, self).__init__(input_dataset, variant_tensor)
@property
def element_spec(self):
return self._structure
class _NumpyIterator(object):
"""Iterator over a dataset with elements converted to numpy."""

View File

@ -33,7 +33,7 @@ tf_class {
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "cache"

View File

@ -35,7 +35,7 @@ tf_class {
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "cache"

View File

@ -35,7 +35,7 @@ tf_class {
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "cache"

View File

@ -35,7 +35,7 @@ tf_class {
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "cache"

View File

@ -35,7 +35,7 @@ tf_class {
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "cache"

View File

@ -35,7 +35,7 @@ tf_class {
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "cache"

View File

@ -35,7 +35,7 @@ tf_class {
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "cache"

View File

@ -2796,6 +2796,10 @@ tf_module {
name: "PaddingFIFOQueueV2"
argspec: "args=[\'component_types\', \'shapes\', \'capacity\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "ParallelBatchDataset"
argspec: "args=[\'input_dataset\', \'batch_size\', \'num_parallel_calls\', \'drop_remainder\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "ParallelConcat"
argspec: "args=[\'values\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -20,7 +20,7 @@ tf_class {
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "cache"

View File

@ -22,7 +22,7 @@ tf_class {
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "cache"

View File

@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "cache"

View File

@ -22,7 +22,7 @@ tf_class {
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "cache"

View File

@ -22,7 +22,7 @@ tf_class {
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "cache"

View File

@ -22,7 +22,7 @@ tf_class {
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "cache"

View File

@ -22,7 +22,7 @@ tf_class {
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "cache"

View File

@ -2796,6 +2796,10 @@ tf_module {
name: "PaddingFIFOQueueV2"
argspec: "args=[\'component_types\', \'shapes\', \'capacity\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'-1\', \'\', \'\', \'None\'], "
}
member_method {
name: "ParallelBatchDataset"
argspec: "args=[\'input_dataset\', \'batch_size\', \'num_parallel_calls\', \'drop_remainder\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "ParallelConcat"
argspec: "args=[\'values\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "