diff --git a/RELEASE.md b/RELEASE.md
index e49397d1dcf..e8ec9067fb4 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -32,6 +32,9 @@
     examples in the same step.
     * tf.data service supports custom data transfer protocols (other than gRPC).
+    * `tf.data.Dataset.batch()` now supports a `num_parallel_calls` argument,
+      which can be used to indicate that multiple input batches should be
+      computed in parallel.
 
 ## Bug Fixes and Other Changes
 
diff --git a/tensorflow/core/api_def/base_api/api_def_ParallelBatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParallelBatchDataset.pbtxt
new file mode 100644
index 00000000000..cb949830deb
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ParallelBatchDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ParallelBatchDataset"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index 24c5e81fb2e..2bdfec6abd7 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -516,9 +516,9 @@ class KnownRatio : public Node {
 class AsyncKnownRatio : public Node {
  public:
-  AsyncKnownRatio(Node::Args args, double ratio,
+  AsyncKnownRatio(Node::Args args, double ratio, double memory_ratio,
                   std::vector<std::shared_ptr<Parameter>> parameters)
-      : Node(args), ratio_(ratio) {
+      : Node(args), ratio_(ratio), memory_ratio_(memory_ratio) {
     for (auto& parameter : parameters) {
       parameters_[parameter->name] = std::move(parameter);
     }
@@ -534,7 +534,7 @@
       parameters.push_back(pair.second);
     }
     return std::make_shared<AsyncKnownRatio>(
-        Args{id_, name_, std::move(output)}, ratio_, parameters);
+        Args{id_, name_, std::move(output)}, ratio_, memory_ratio_, parameters);
   }
 
   // The input time is the sum of inherited input time and parallelism adjusted
@@ -711,13 +711,14 @@
     }
     if (parameter) {
-      if (ratio_ == 0) {
+      if (memory_ratio_ == 0) {
         result += (*parameter)->value * AverageBufferedElementSize();
       } else {
         // The estimation is currently not accurate for MapAndBatchDataset
         // because the maximum buffer size does not match the
         // `num_parallel_calls` parameter.
-        result += (*parameter)->value * AverageBufferedElementSize() / ratio_;
+        result +=
+            (*parameter)->value * AverageBufferedElementSize() / memory_ratio_;
       }
     }
     return result;
@@ -727,11 +728,22 @@
     TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
     node_proto->set_node_class(NodeClass::ASYNC_KNOWN_RATIO);
     node_proto->set_ratio(ratio_);
+    node_proto->set_memory_ratio(memory_ratio_);
     return Status::OK();
   }
 
  private:
+  // Identifies how many input elements need to be created to construct an
+  // element of the dataset.
+  //
+  // Currently the value is 1 for PrefetchDataset and ParallelMapDataset, and
+  // batch_size for MapAndBatchDataset and ParallelBatchDataset.
   const double ratio_;
+  // For parallelism nodes, identifies how many parallel calls are introduced
+  // by one buffered element. The value is used to correctly estimate the RAM
+  // budget bound for a given num_parallel_calls (or buffer_size) combined
+  // with the estimated average size of buffered elements.
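+  //
+  // A worked example grounded in this change: ParallelBatchDataset creates
+  // its node with ratio = batch_size and memory_ratio = 1, because each of
+  // its `num_parallel_calls` in-flight calls buffers one whole output batch,
+  // while the legacy two-argument factory below keeps memory_ratio == ratio
+  // (i.e. batch_size) for MapAndBatchDataset, where `num_parallel_calls`
+  // bounds individual map invocations rather than buffered batches.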
+  const double memory_ratio_;
 };
 
 class UnknownRatio : public Node {
@@ -922,11 +934,18 @@
 std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio) {
   return std::make_shared<KnownRatio>(std::move(args), ratio);
 }
 
+std::shared_ptr<Node> MakeAsyncKnownRatioNode(
+    Node::Args args, double ratio, double memory_ratio,
+    std::vector<std::shared_ptr<Parameter>> parameters) {
+  return std::make_shared<AsyncKnownRatio>(std::move(args), ratio, memory_ratio,
+                                           std::move(parameters));
+}
+
 std::shared_ptr<Node> MakeAsyncKnownRatioNode(
     Node::Args args, double ratio,
     std::vector<std::shared_ptr<Parameter>> parameters) {
-  return std::make_shared<AsyncKnownRatio>(std::move(args), ratio,
-                                           std::move(parameters));
+  return MakeAsyncKnownRatioNode(std::move(args), /*ratio=*/ratio,
+                                 /*memory_ratio=*/ratio, std::move(parameters));
 }
 
 std::shared_ptr<Node> MakeSourceNode(Node::Args args) {
@@ -1546,7 +1565,7 @@ Status Node::FromProto(ModelProto::Node node_proto,
       break;
     case NodeClass::ASYNC_KNOWN_RATIO:
       restored_node = std::make_shared<AsyncKnownRatio>(
-          args, node_proto.ratio(),
+          args, node_proto.ratio(), node_proto.memory_ratio(),
           /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
       break;
     case NodeClass::UNKNOWN_RATIO:
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index 42ce637e6cb..e04e2c66beb 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -605,6 +605,10 @@ std::shared_ptr<Node> MakeAsyncInterleaveManyNode(
 std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio);
 
 // AsyncKnownRatio nodes are the asynchronous version of KnownRatio nodes.
+std::shared_ptr<Node> MakeAsyncKnownRatioNode(
+    Node::Args args, double ratio, double memory_ratio,
+    std::vector<std::shared_ptr<Parameter>> parameters);
+
 std::shared_ptr<Node> MakeAsyncKnownRatioNode(
     Node::Args args, double ratio,
     std::vector<std::shared_ptr<Parameter>> parameters);
diff --git a/tensorflow/core/framework/model.proto b/tensorflow/core/framework/model.proto
index b80687e5cef..a0e25657a38 100644
--- a/tensorflow/core/framework/model.proto
+++ b/tensorflow/core/framework/model.proto
@@ -88,6 +88,10 @@ message ModelProto {
     // Ratio of input to output elements. This is only used by KNOWN_RATIO and
     // ASYNC_KNOWN_RATIO nodes.
     double ratio = 16;
+
+    // Identifies how many parallel calls are introduced by one buffered
+    // element. This is only used by ASYNC_KNOWN_RATIO nodes.
+    double memory_ratio = 17;
   }
 
   // Output node of this model.
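Taken together with the new kernel below, the release note above corresponds to the following user-level usage. This is a minimal sketch, not part of the diff: the argument name comes from the RELEASE.md entry, and autotuning support follows from the `model::kAutotune` handling in the new iterator.

    import tensorflow as tf

    # Compute up to two input batches in parallel.
    dataset = tf.data.Dataset.range(12).batch(4, num_parallel_calls=2)

    # Or let tf.data pick the parallelism level; AUTOTUNE maps to
    # model::kAutotune in the kernel.
    dataset = tf.data.Dataset.range(12).batch(
        4, num_parallel_calls=tf.data.experimental.AUTOTUNE)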
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 7d2d2ef1099..d0f1c64b7b0 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -30,6 +30,7 @@ tf_kernel_library( srcs = ["batch_dataset_op.cc"], hdrs = ["batch_dataset_op.h"], deps = [ + ":dataset_utils", ":name_utils", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -683,6 +684,50 @@ tf_cc_test( ], ) +tf_kernel_library( + name = "parallel_batch_dataset_op", + srcs = ["parallel_batch_dataset_op.cc"], + hdrs = ["parallel_batch_dataset_op.h"], + deps = [ + ":dataset_utils", + ":name_utils", + ":stats_utils", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/profiler/lib:traceme_encode", + ], +) + +tf_cc_test( + name = "parallel_batch_dataset_op_test", + size = "small", + srcs = ["parallel_batch_dataset_op_test.cc"], + deps = [ + ":dataset_test_base", + ":dataset_utils", + ":iterator_ops", + ":name_utils", + ":parallel_batch_dataset_op", + ":range_dataset_op", + ":stats_utils", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:function_ops", + ], +) + tf_kernel_library( name = "parallel_interleave_dataset_op", srcs = ["parallel_interleave_dataset_op.cc"], @@ -1476,6 +1521,7 @@ tf_kernel_library( ":optimize_dataset_op", ":optional_ops", ":padded_batch_dataset_op", + ":parallel_batch_dataset_op", ":parallel_interleave_dataset_op", ":parallel_map_dataset_op", ":prefetch_dataset_op", diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index 7ea39dfe709..7959c88e9ed 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/name_utils.h" -#include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stringprintf.h" @@ -199,67 +199,9 @@ class BatchDatasetOp::Dataset : public DatasetBase { // respective slice locations. This would require a different GetNext() // overload that supports zero-copy, and might make sense in an // optimization pass. 
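+      // The batching copy is implemented by the shared `CopyBatch()` helper
+      // in dataset_utils.cc, which is also used by the new
+      // ParallelBatchDataset kernel.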
- const size_t num_tuple_components = batch_elements[0].size(); - out_tensors->reserve(num_tuple_components); - const int64 num_batch_elements = batch_elements.size(); - for (size_t component_index = 0; component_index < num_tuple_components; - ++component_index) { - const Tensor& first_element = batch_elements[0][component_index]; - TensorShape batch_component_shape({num_batch_elements}); - // NOTE(mrry): Copy the shape of the first element here, because - // `first_element.shape()` will become undefined after the 0th batch - // element is moved into the output batch. - TensorShape first_element_shape(first_element.shape()); - batch_component_shape.AppendShape(first_element_shape); - out_tensors->emplace_back(ctx->allocator({}), first_element.dtype(), - batch_component_shape); - if (!out_tensors->back().IsInitialized()) { - return errors::ResourceExhausted( - "Failed to allocate memory for the batch of component ", - component_index); - } - Tensor& batch_component = out_tensors->back(); - // Build the output tuple component by copying one slice - // from each input element in the batch. - auto copy_element_fn = [component_index, &batch_elements, - &batch_component](int index) { - TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice( - std::move(batch_elements[index][component_index]), - &batch_component, index)); - return Status::OK(); - }; - BlockingCounter counter(num_batch_elements); - Status status; - mutex status_mu; - for (size_t i = 0; i < num_batch_elements; ++i) { - if (batch_elements[i][component_index].shape() != - first_element_shape) { - return errors::InvalidArgument( - "Cannot batch tensors with different shapes in " - "component ", - component_index, ". First element had shape ", - first_element_shape.DebugString(), " and element ", i, - " had shape ", - batch_elements[i][component_index].shape().DebugString(), "."); - } - if (TF_PREDICT_FALSE(dataset()->parallel_copy_)) { - (*ctx->runner())( - [i, &status, &status_mu, &counter, ©_element_fn]() { - Status s = copy_element_fn(i); - { - mutex_lock l(status_mu); - status.Update(s); - } - counter.DecrementCount(); - }); - } else { - status.Update(copy_element_fn(i)); - counter.DecrementCount(); - } - } - counter.Wait(); - TF_RETURN_IF_ERROR(status); - } + TF_RETURN_IF_ERROR(CopyBatch(/*parallel_copy=*/dataset()->parallel_copy_, + ctx, out_tensors, &batch_elements)); + *end_of_sequence = false; return Status::OK(); } diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index 7e6d5b1ccfd..ddba4077d60 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset_utils.h" +#include #include #include "absl/container/flat_hash_map.h" @@ -30,6 +31,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/proto_serialization.h" @@ -39,10 +41,16 @@ limitations under the License. 
namespace tensorflow {
namespace data {
namespace {
+
 constexpr char kDelimiter[] = "@@";
 constexpr char kComponent[] = "component";
 constexpr char kNumElements[] = "num_elements";
 constexpr char kNumComponents[] = "num_components";
+constexpr char kOutputSize[] = "output_size";
+constexpr char kCode[] = "code";
+constexpr char kMessage[] = "msg";
+constexpr char kOutput[] = "output";
+
 }  // namespace
 
 Status WriteElementsToCheckpoint(
@@ -699,5 +707,233 @@ void StripDevicePlacement(FunctionDefLibrary* library) {
   }
 }
 
+Status CopyPartialBatch(int64 num_elements, const Tensor& value,
+                        Tensor* output) {
+  switch (value.dtype()) {
+#define HANDLE_TYPE(type)                                         \
+  case DataTypeToEnum<type>::value: {                             \
+    auto output_t = output->flat_outer_dims<type>();              \
+    auto value_t = value.flat_outer_dims<type>();                 \
+    for (size_t i = 0; i < num_elements; i++) {                   \
+      output_t.template chip<0>(i) = value_t.template chip<0>(i); \
+    }                                                             \
+    return Status::OK();                                          \
+  }
+    TF_CALL_DATASET_TYPES(HANDLE_TYPE);
+#undef HANDLE_TYPE
+    default:
+      return errors::InvalidArgument("Unsupported data type: ",
+                                     DataTypeString(value.dtype()));
+  }
+  return Status::OK();
+}
+
+Status ReadBatch(int64 batch_size, const string& iterator_prefix,
+                 const string& batch_prefix, IteratorContext* ctx,
+                 IteratorStateReader* reader, std::vector<Tensor>* batch) {
+  int64 output_size;
+  TF_RETURN_IF_ERROR(reader->ReadScalar(
+      FullName(iterator_prefix,
+               strings::StrCat(batch_prefix, "_", kOutputSize)),
+      &output_size));
+  batch->reserve(output_size);
+  for (int i = 0; i < output_size; i++) {
+    Tensor t;
+    TF_RETURN_IF_ERROR(reader->ReadTensor(
+        FullName(iterator_prefix,
+                 strings::StrCat(batch_prefix, "_", kOutput, "_", i)),
+        &t));
+    // If the batch was not full, we may have stored only the relevant slice.
+    // Since tensors in `BatchResult.output` are expected to have the leading
+    // dimension of size batch_size, we build a larger tensor and copy the
+    // slice read from the checkpoint into it.
+    if (t.dim_size(0) < batch_size) {
+      TensorShape component_shape(t.shape());
+      component_shape.set_dim(0, batch_size);
+      AllocatorAttributes attr;
+      attr.set_gpu_compatible(true);
+      Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape);
+      TF_RETURN_IF_ERROR(CopyPartialBatch(t.dim_size(0), t, &new_t));
+      batch->emplace_back(std::move(new_t));
+    } else {
+      batch->emplace_back(std::move(t));
+    }
+  }
+  return Status::OK();
+}
+
+Status WriteBatch(int64 batch_size, int64 num_elements,
+                  const string& iterator_prefix, const string& batch_prefix,
+                  IteratorStateWriter* writer, std::vector<Tensor>* batch) {
+  TF_RETURN_IF_ERROR(writer->WriteScalar(
+      FullName(iterator_prefix,
+               strings::StrCat(batch_prefix, "_", kOutputSize)),
+      batch->size()));
+  for (int i = 0; i < batch->size(); i++) {
+    // If the batch is not full, we only store the first `num_elements` values.
+    // The rest of the batch tensor is *uninitialized* and accessing that will
+    // raise msan errors.
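+    // For example, with batch_size 4 and num_elements 2, only rows [0, 2) of
+    // each component tensor are serialized here; ReadBatch() re-pads the
+    // restored tensor back to a leading dimension of batch_size.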
+    if (num_elements < batch_size) {
+      TF_RETURN_IF_ERROR(writer->WriteTensor(
+          FullName(iterator_prefix,
+                   strings::StrCat(batch_prefix, "_", kOutput, "_", i)),
+          (*batch)[i].Slice(0, num_elements)));
+    } else {
+      TF_RETURN_IF_ERROR(writer->WriteTensor(
+          FullName(iterator_prefix,
+                   strings::StrCat(batch_prefix, "_", kOutput, "_", i)),
+          (*batch)[i]));
+    }
+  }
+  return Status::OK();
+}
+
+Status ReadStatus(const string& iterator_prefix, const string& prefix,
+                  IteratorStateReader* reader, Status* status) {
+  int64 code_int;
+  TF_RETURN_IF_ERROR(reader->ReadScalar(
+      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
+      &code_int));
+  error::Code code = static_cast<error::Code>(code_int);
+
+  if (code != error::Code::OK) {
+    tstring error_message;
+    TF_RETURN_IF_ERROR(reader->ReadScalar(
+        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
+        &error_message));
+    *status = Status(code, error_message);
+  } else {
+    *status = Status::OK();
+  }
+  return Status::OK();
+}
+
+Status WriteStatus(const string& iterator_prefix, const string& prefix,
+                   const Status& status, IteratorStateWriter* writer) {
+  TF_RETURN_IF_ERROR(writer->WriteScalar(
+      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
+      static_cast<int64>(status.code())));
+  if (!status.ok()) {
+    TF_RETURN_IF_ERROR(writer->WriteScalar(
+        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
+        status.error_message()));
+  }
+  return Status::OK();
+}
+
+Status ProcessBatch(int64 batch_size, int64 num_elements, bool drop_remainder,
+                    const Status& status, IteratorContext* ctx,
+                    std::vector<Tensor>* output, bool* end_of_sequence,
+                    std::vector<Tensor>* batch) {
+  if (num_elements == 0) {
+    if (status.ok() || errors::IsOutOfRange(status)) {
+      *end_of_sequence = true;
+      return Status::OK();
+    } else {
+      *end_of_sequence = false;
+      return status;
+    }
+  }
+  if (!status.ok() && !errors::IsOutOfRange(status)) {
+    *end_of_sequence = false;
+    return status;
+  }
+  if (num_elements < batch_size) {
+    if (drop_remainder) {
+      *end_of_sequence = true;
+      return Status::OK();
+    }
+    for (size_t i = 0; i < batch->size(); ++i) {
+      TensorShape component_shape((*batch)[i].shape());
+      component_shape.set_dim(0, num_elements);
+      AllocatorAttributes attr;
+      attr.set_gpu_compatible(true);
+      output->emplace_back(ctx->allocator(attr), (*batch)[i].dtype(),
+                           component_shape);
+      if (!output->back().IsInitialized()) {
+        return errors::ResourceExhausted(
+            "Failed to allocate memory for the batch of component ", i);
+      }
+      TF_RETURN_IF_ERROR(
+          CopyPartialBatch(num_elements, (*batch)[i], &output->back()));
+    }
+  } else {
+    *output = std::move(*batch);
+  }
+  *end_of_sequence = false;
+  return Status::OK();
+}
+
+Status CopyBatch(bool parallel_copy, IteratorContext* ctx,
+                 std::vector<Tensor>* out_tensors,
+                 std::vector<std::vector<Tensor>>* batch_elements) {
+  const size_t num_tuple_components = (*batch_elements)[0].size();
+  out_tensors->reserve(num_tuple_components);
+  const int64 num_batch_elements = batch_elements->size();
+  for (size_t component_index = 0; component_index < num_tuple_components;
+       ++component_index) {
+    const Tensor& first_element = (*batch_elements)[0][component_index];
+    TensorShape batch_component_shape({num_batch_elements});
+    // NOTE(mrry): Copy the shape of the first element here, because
+    // `first_element.shape()` will become undefined after the 0th batch
+    // element is moved into the output batch.
+    TensorShape first_element_shape(first_element.shape());
+    batch_component_shape.AppendShape(first_element_shape);
+    out_tensors->emplace_back(ctx->allocator({}), first_element.dtype(),
+                              batch_component_shape);
+    if (!out_tensors->back().IsInitialized()) {
+      return errors::ResourceExhausted(
+          "Failed to allocate memory for the batch of component ",
+          component_index);
+    }
+    Tensor& batch_component = out_tensors->back();
+    // Build the output tuple component by copying one slice from each input
+    // element in the batch.
+    auto copy_element_fn = [component_index, &batch_elements,
+                            &batch_component](int index) {
+      TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(
+          std::move((*batch_elements)[index][component_index]),
+          &batch_component, index));
+      return Status::OK();
+    };
+    Status status;
+    std::unique_ptr<BlockingCounter> counter;
+    std::unique_ptr<mutex> status_mu;
+    if (TF_PREDICT_FALSE(parallel_copy)) {
+      counter = std::make_unique<BlockingCounter>(num_batch_elements);
+      status_mu = std::make_unique<mutex>();
+    }
+    for (size_t i = 0; i < num_batch_elements; ++i) {
+      if ((*batch_elements)[i][component_index].shape() !=
+          first_element_shape) {
+        return errors::InvalidArgument(
+            "Cannot batch tensors with different shapes in component ",
+            component_index, ". First element had shape ",
+            first_element_shape.DebugString(), " and element ", i,
+            " had shape ",
+            (*batch_elements)[i][component_index].shape().DebugString(), ".");
+      }
+      if (TF_PREDICT_FALSE(parallel_copy)) {
+        (*ctx->runner())(
+            [i, &status, &status_mu, &counter, &copy_element_fn]() {
+              Status s = copy_element_fn(i);
+              {
+                mutex_lock l(*status_mu);
+                status.Update(s);
+              }
+              counter->DecrementCount();
+            });
+      } else {
+        status.Update(copy_element_fn(i));
+      }
+    }
+    if (TF_PREDICT_FALSE(parallel_copy)) {
+      counter->Wait();
+    }
+    TF_RETURN_IF_ERROR(status);
+  }
+  return Status::OK();
+}
+
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 4bfee51a087..aebf76a8eee 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -302,6 +302,40 @@ std::vector<tstring> SelectOptimizations(
 // Removes device placements from the ops of all functions in `library`.
 void StripDevicePlacement(FunctionDefLibrary* library);
 
+// Copies the first `num_elements` slices of `value` along dimension 0 into
+// `output`, for restoring a partially filled batch.
+Status CopyPartialBatch(int64 num_elements, const Tensor& value,
+                        Tensor* output);
+
+// Reads a batch when restoring the iterator.
+Status ReadBatch(int64 batch_size, const string& iterator_prefix,
+                 const string& batch_prefix, IteratorContext* ctx,
+                 IteratorStateReader* reader, std::vector<Tensor>* batch);
+
+// Writes a batch when saving the iterator.
+Status WriteBatch(int64 batch_size, int64 num_elements,
+                  const string& iterator_prefix, const string& batch_prefix,
+                  IteratorStateWriter* writer, std::vector<Tensor>* batch);
+
+// Reads a status when restoring the iterator.
+Status ReadStatus(const string& iterator_prefix, const string& prefix,
+                  IteratorStateReader* reader, Status* status);
+
+// Writes a status when saving the iterator.
+Status WriteStatus(const string& iterator_prefix, const string& prefix,
+                   const Status& status, IteratorStateWriter* writer);
+
+// Processes a batch of elements into `output`. If a partial batch is
+// encountered, only the filled portion of the batch is copied.
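+//
+// Outcomes, mirroring the implementation in dataset_utils.cc: zero elements
+// with an OK or OutOfRange status signals end of sequence; any other error is
+// propagated; a partial batch is dropped when `drop_remainder` is set, and is
+// otherwise copied into tensors whose leading dimension is `num_elements`; a
+// full batch is moved into `output` without copying.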
+Status ProcessBatch(int64 batch_size, int64 num_elements, bool drop_remainder,
+                    const Status& status, IteratorContext* ctx,
+                    std::vector<Tensor>* output, bool* end_of_sequence,
+                    std::vector<Tensor>* batch);
+
+// Copies the input elements to a batch.
+Status CopyBatch(bool parallel_copy, IteratorContext* ctx,
+                 std::vector<Tensor>* out_tensors,
+                 std::vector<std::vector<Tensor>>* batch_elements);
+
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
index 5b7cfeced95..51a02ef45a2 100644
--- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
@@ -72,11 +72,7 @@ constexpr char kEndOfInput[] = "end_of_input";
 constexpr char kNumCalls[] = "num_calls";
 constexpr char kNumElements[] = "num_elements";
 constexpr char kOutputAllocated[] = "output_allocated";
-constexpr char kOutputSize[] = "output_size";
-constexpr char kOutput[] = "output";
 constexpr char kStatus[] = "status";
-constexpr char kCode[] = "code";
-constexpr char kMessage[] = "msg";
 
 // Computes ceil(x / y).
 inline int64 CeilDiv(int64 x, int64 y) { return (x + y - 1) / y; }
@@ -257,7 +253,17 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
         return profiler::TraceMeEncode("MapAndBatchConsume",
                                        {{"element_id", result->uid}});
       });
-      return ProcessResult(ctx, result, out_tensors, end_of_sequence);
+      // Deallocate tensors allocated for the output.
+      auto cleanup = gtl::MakeCleanup([result] { result->output.clear(); });
+      mutex_lock l(result->mu);
+      if (result->output_allocated) {
+        RecordBufferDequeue(ctx, result->output);
+      }
+      TF_RETURN_IF_ERROR(
+          ProcessBatch(dataset()->batch_size_, result->num_elements,
+                       dataset()->drop_remainder_, result->status, ctx,
+                       out_tensors, end_of_sequence, &result->output));
+      return Status::OK();
     }
 
    protected:
@@ -477,27 +483,6 @@
       }
     }
 
-    Status CopyPartialBatch(Tensor* output, const Tensor& value,
-                            int64 num_elements) {
-      switch (value.dtype()) {
-#define HANDLE_TYPE(type)                                         \
-  case DataTypeToEnum<type>::value: {                             \
-    auto output_t = output->flat_outer_dims<type>();              \
-    auto value_t = value.flat_outer_dims<type>();                 \
-    for (size_t i = 0; i < num_elements; i++) {                   \
-      output_t.template chip<0>(i) = value_t.template chip<0>(i); \
-    }                                                             \
-    return Status::OK();                                          \
-  }
-        TF_CALL_DATASET_TYPES(HANDLE_TYPE);
-#undef HANDLE_TYPE
-        default:
-          return errors::InvalidArgument("Unsupported data type: ",
-                                         DataTypeString(value.dtype()));
-      }
-      return Status::OK();
-    }
-
     void EnsureRunnerThreadStarted(IteratorContext* ctx)
         TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
@@ -536,60 +521,6 @@
       return Status::OK();
     }
 
-    Status ProcessResult(IteratorContext* ctx,
-                         const std::shared_ptr<BatchResult>& result,
-                         std::vector<Tensor>* out_tensors,
-                         bool* end_of_sequence) {
-      mutex_lock l(result->mu);
-      if (result->output_allocated) {
-        RecordBufferDequeue(ctx, result->output);
-      }
-      if (result->num_elements == 0) {
-        if (result->status.ok() || errors::IsOutOfRange(result->status)) {
-          *end_of_sequence = true;
-          return Status::OK();
-        } else {
-          *end_of_sequence = false;
-          return result->status;
-        }
-      }
-      if (!result->status.ok() && !errors::IsOutOfRange(result->status)) {
-        // Deallocate tensors allocated for the output.
- result->output.clear(); - *end_of_sequence = false; - return result->status; - } - if (result->num_elements < dataset()->batch_size_) { - if (dataset()->drop_remainder_) { - // Deallocate tensors allocated for the output. - result->output.clear(); - *end_of_sequence = true; - return Status::OK(); - } - const std::vector& output = result->output; - for (size_t i = 0; i < output.size(); ++i) { - TensorShape component_shape(result->output[i].shape()); - component_shape.set_dim(0, result->num_elements); - AllocatorAttributes attr; - attr.set_gpu_compatible(true); - out_tensors->emplace_back(ctx->allocator(attr), output[i].dtype(), - component_shape); - if (!out_tensors->back().IsInitialized()) { - return errors::ResourceExhausted( - "Failed to allocate memory for the batch of component ", i); - } - TF_RETURN_IF_ERROR(CopyPartialBatch(&out_tensors->back(), output[i], - result->num_elements)); - } - // Deallocate tensors allocated for the output. - result->output.clear(); - } else { - *out_tensors = std::move(result->output); - } - *end_of_sequence = false; - return Status::OK(); - } - void RunnerThread(const std::shared_ptr& ctx) TF_LOCKS_EXCLUDED(*mu_) { std::vector, int64>> new_calls; @@ -660,116 +591,54 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { batch_results_.push_back( std::make_shared(dataset()->batch_size_)); std::shared_ptr result = batch_results_.back(); - string prefix = strings::StrCat(kBatchResults, "_", index); + string batch_prefix = strings::StrCat(kBatchResults, "_", index); mutex_lock l(result->mu); result->end_of_input = reader->Contains( - full_name(strings::StrCat(prefix, "_", kEndOfInput))); - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name(strings::StrCat(prefix, "_", kNumCalls)), - &result->num_calls)); + full_name(strings::StrCat(batch_prefix, "_", kEndOfInput))); TF_RETURN_IF_ERROR(reader->ReadScalar( - full_name(strings::StrCat(prefix, "_", kNumElements)), + full_name(strings::StrCat(batch_prefix, "_", kNumCalls)), + &result->num_calls)); + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat(batch_prefix, "_", kNumElements)), &result->num_elements)); result->output_allocated = reader->Contains( - full_name(strings::StrCat(prefix, "_", kOutputAllocated))); - int64 output_size; - TF_RETURN_IF_ERROR(reader->ReadScalar( - full_name(strings::StrCat(prefix, "_", kOutputSize)), &output_size)); - result->output.reserve(output_size); - for (int i = 0; i < output_size; i++) { - Tensor t; - TF_RETURN_IF_ERROR(reader->ReadTensor( - full_name(strings::StrCat(prefix, "_", kOutput, "_", i)), &t)); - // If the batch was not full, we may have stored only the relevant - // slice. Since tensors in `BatchResult.output` are expected to - // have the leading dimension of size batch_size, we build a larger - // tensor and copy the slice read from the checkpoint into it. 
- if (t.dim_size(0) < dataset()->batch_size_) { - TensorShape component_shape(t.shape()); - component_shape.set_dim(0, dataset()->batch_size_); - AllocatorAttributes attr; - attr.set_gpu_compatible(true); - Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape); - TF_RETURN_IF_ERROR(CopyPartialBatch(&new_t, t, t.dim_size(0))); - result->output.emplace_back(std::move(new_t)); - } else { - result->output.emplace_back(std::move(t)); - } - } - TF_RETURN_IF_ERROR(ReadStatus( - reader, strings::StrCat(prefix, "_", kStatus), &result->status)); - return Status::OK(); - } + full_name(strings::StrCat(batch_prefix, "_", kOutputAllocated))); - Status ReadStatus(IteratorStateReader* reader, const string& prefix, - Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { - int64 code_int; - TF_RETURN_IF_ERROR(reader->ReadScalar( - full_name(strings::StrCat(prefix, "_", kCode)), &code_int)); - error::Code code = static_cast(code_int); - - if (code != error::Code::OK) { - tstring error_message; - TF_RETURN_IF_ERROR(reader->ReadScalar( - full_name(strings::StrCat(prefix, "_", kMessage)), &error_message)); - *status = Status(code, error_message); - } else { - *status = Status::OK(); - } + TF_RETURN_IF_ERROR(ReadBatch(dataset()->batch_size_, prefix(), + batch_prefix, ctx, reader, &result->output)); + TF_RETURN_IF_ERROR(ReadStatus(prefix(), + strings::StrCat(batch_prefix, "_", kStatus), + reader, &result->status)); return Status::OK(); } Status WriteBatchResult(IteratorStateWriter* writer, size_t index) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { std::shared_ptr result = batch_results_[index]; - string prefix = strings::StrCat(kBatchResults, "_", index); + string batch_prefix = strings::StrCat(kBatchResults, "_", index); mutex_lock l(result->mu); if (result->end_of_input) { TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat(prefix, "_", kEndOfInput)), "")); + full_name(strings::StrCat(batch_prefix, "_", kEndOfInput)), "")); } TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat(prefix, "_", kNumCalls)), + full_name(strings::StrCat(batch_prefix, "_", kNumCalls)), result->num_calls)); TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat(prefix, "_", kNumElements)), + full_name(strings::StrCat(batch_prefix, "_", kNumElements)), result->num_elements)); if (result->output_allocated) { TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat(prefix, "_", kOutputAllocated)), "")); + full_name(strings::StrCat(batch_prefix, "_", kOutputAllocated)), + "")); } - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat(prefix, "_", kOutputSize)), - result->output.size())); - for (int i = 0; i < result->output.size(); i++) { - // If the batch is not full, we only store the first `num_elements` - // values. The rest of the batch tensor is *uninitialized* and - // accessing that will raise msan errors. 
- if (result->num_elements < dataset()->batch_size_) { - TF_RETURN_IF_ERROR(writer->WriteTensor( - full_name(strings::StrCat(prefix, "_", kOutput, "_", i)), - result->output[i].Slice(0, result->num_elements))); - } else { - TF_RETURN_IF_ERROR(writer->WriteTensor( - full_name(strings::StrCat(prefix, "_", kOutput, "_", i)), - result->output[i])); - } - } - TF_RETURN_IF_ERROR(WriteStatus( - writer, strings::StrCat(prefix, "_", kStatus), result->status)); - return Status::OK(); - } - Status WriteStatus(IteratorStateWriter* writer, const string& prefix, - const Status& status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + TF_RETURN_IF_ERROR(WriteBatch(dataset()->batch_size_, + result->num_elements, prefix(), + batch_prefix, writer, &result->output)); TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name(strings::StrCat(prefix, "_", kCode)), - static_cast(status.code()))); - if (!status.ok()) { - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat(prefix, "_", kMessage)), - status.error_message())); - } + WriteStatus(prefix(), strings::StrCat(batch_prefix, "_", kStatus), + result->status, writer)); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc new file mode 100644 index 00000000000..ce2c02f2604 --- /dev/null +++ b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc @@ -0,0 +1,548 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/kernels/data/parallel_batch_dataset_op.h" + +#include +#include +#include + +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" +#include "tensorflow/core/common_runtime/metrics.h" +#include "tensorflow/core/framework/model.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/kernels/data/name_utils.h" +#include "tensorflow/core/kernels/data/stats_utils.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/env_time.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringprintf.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/profiler/lib/traceme_encode.h" +#include "tensorflow/core/util/batch_util.h" + +namespace tensorflow { +namespace data { + +/* static */ constexpr const char* const ParallelBatchDatasetOp::kDatasetType; +/* static */ constexpr const char* const ParallelBatchDatasetOp::kInputDataset; +/* static */ constexpr const char* const ParallelBatchDatasetOp::kBatchSize; +/* static */ constexpr const char* const + ParallelBatchDatasetOp::kNumParallelCalls; +/* static */ constexpr const char* const ParallelBatchDatasetOp::kDropRemainder; +/* static */ constexpr const char* const ParallelBatchDatasetOp::kOutputTypes; +/* static */ constexpr const char* const ParallelBatchDatasetOp::kOutputShapes; + +namespace { + +constexpr char kBatchResultsSize[] = "batch_results_size"; +constexpr char kTFDataParallelBatch[] = "tf_data_parallel_batch"; +constexpr char kBatchResults[] = "batch_results"; +constexpr char kEndOfInput[] = "end_of_input"; +constexpr char kNumElements[] = "num_elements"; +constexpr char kCallFinished[] = "call_finished"; +constexpr char kStatus[] = "status"; + +} // namespace + +class ParallelBatchDatasetOp::Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, int64 batch_size, int64 num_parallel_calls, + bool drop_remainder, const DatasetBase* input) + : DatasetBase(DatasetContext(ctx)), + batch_size_(batch_size), + // Dataset batch is sometimes used to stack all elements in the + // dataset. In such cases, a very large batch size (e.g., INT32_MAX) + // is passed with drop_remainder set to false. Avoid OOM in such case + // by limiting `reserve()` size by 2**16. + reserve_size_(drop_remainder ? batch_size + : std::min(batch_size, 1 << 16)), + num_parallel_calls_(num_parallel_calls), + drop_remainder_(drop_remainder), + input_(input), + traceme_metadata_( + {{"autotune", + num_parallel_calls == model::kAutotune ? "true" : "false"}, + {"batch_size", + strings::Printf("%lld", static_cast(batch_size))}, + {"drop_remainder", drop_remainder ? 
"true" : "false"}}) { + input_->Ref(); + + const auto& input_shapes = input_->output_shapes(); + output_shapes_.reserve(input_shapes.size()); + for (const auto& input_shape : input_shapes) { + if (drop_remainder_ || input_->Cardinality() == kInfiniteCardinality) { + output_shapes_.emplace_back( + PartialTensorShape({batch_size_}).Concatenate(input_shape)); + } else { + output_shapes_.emplace_back( + PartialTensorShape({-1}).Concatenate(input_shape)); + } + } + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return absl::make_unique(Iterator::Params{ + this, name_utils::IteratorPrefix(kDatasetType, prefix)}); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + name_utils::DatasetDebugStringParams params; + params.set_args(batch_size_); + return name_utils::DatasetDebugString(kDatasetType, params); + } + + int64 Cardinality() const override { + int64 n = input_->Cardinality(); + if (n == kInfiniteCardinality || n == kUnknownCardinality) { + return n; + } + return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1); + } + + Status InputDatasets(std::vector* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + + Status CheckExternalState() const override { + return input_->CheckExternalState(); + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + // Input: input_dataset + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + + // Input: batch_size + Node* batch_size = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size)); + + // Input: num_parallel_calls + Node* num_parallel_calls = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(num_parallel_calls_, &num_parallel_calls)); + + // Input: drop_remainder + Node* drop_remainder = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder)); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, + {input_graph_node, batch_size, num_parallel_calls, drop_remainder}, {}, + output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params), + mu_(std::make_shared()), + cond_var_(std::make_shared()), + num_parallel_calls_(std::make_shared( + params.dataset->num_parallel_calls_, mu_, cond_var_)) {} + + ~Iterator() override { + CancelThreads(/*wait=*/true); + if (deregister_fn_) deregister_fn_(); + } + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(*mu_); + if (num_parallel_calls_->value == model::kAutotune) { + num_parallel_calls_->value = 1; + } + TF_RETURN_IF_ERROR(RegisterCancellationCallback( + ctx->cancellation_manager(), + [this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_)); + return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + std::shared_ptr result; + { + mutex_lock l(*mu_); + EnsureRunnerThreadStarted(ctx); + while (!cancelled_ && (batch_results_.empty() || + !batch_results_.front()->call_finished)) { + ++waiting_; + RecordStop(ctx); + cond_var_->wait(l); + RecordStart(ctx); + --waiting_; + } + if (cancelled_) 
{ + return errors::Cancelled("Iterator was cancelled"); + } + std::swap(result, batch_results_.front()); + batch_results_.pop_front(); + cond_var_->notify_all(); + } + profiler::TraceMe traceme([&] { + return profiler::TraceMeEncode("ParallelBatchConsume", + {{"element_id", result->uid}}); + }); + mutex_lock l(result->mu); + // Deallocate tensors allocated for the output. + auto cleanup = + gtl::MakeCleanup([result]() TF_EXCLUSIVE_LOCKS_REQUIRED( + &BatchResult::mu) { result->output.clear(); }); + RecordBufferDequeue(ctx, result->output); + TF_RETURN_IF_ERROR( + ProcessBatch(dataset()->batch_size_, result->num_elements, + dataset()->drop_remainder_, result->status, ctx, + out_tensors, end_of_sequence, &result->output)); + return Status::OK(); + } + + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeAsyncKnownRatioNode( + std::move(args), + /*ratio=*/dataset()->batch_size_, /*memory_ratio=*/1.0, + {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1, + /*max=*/ctx->runner_threadpool_size())}); + } + + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + mutex_lock l(*mu_); + // Wait for all in-flight calls to complete. + while (num_calls_ > 0) { + cond_var_->wait(l); + } + DCHECK_EQ(num_calls_, 0); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kBatchResultsSize), + batch_results_.size())); + for (size_t i = 0; i < batch_results_.size(); ++i) { + TF_RETURN_IF_ERROR(WriteBatchResult(writer, i)); + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(*mu_); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + int64 batch_results_size; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kBatchResultsSize), + &batch_results_size)); + for (int i = 0; i < batch_results_size; ++i) { + TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i)); + } + return Status::OK(); + } + + TraceMeMetadata GetTraceMeMetadata() const override { + int64 parallelism = -1; + // NOTE: We only set the parallelism value if the lock can be acquired + // right away to avoid introducing tracing overhead. + if (mu_->try_lock()) { + parallelism = num_parallel_calls_->value; + mu_->unlock(); + } + auto result = dataset()->traceme_metadata_; + result.push_back(std::make_pair( + "parallelism", + strings::Printf("%lld", static_cast(parallelism)))); + return result; + } + + // BatchResult encapsulates the output batch. + struct BatchResult { + explicit BatchResult() + : end_of_input(false), + num_elements(0), + status(Status::OK()), + call_finished(false), + uid(tensorflow::EnvTime::NowNanos()) {} + + mutex mu; + bool end_of_input TF_GUARDED_BY(mu); + int64 num_elements TF_GUARDED_BY(mu); + std::vector output TF_GUARDED_BY(mu); + Status status TF_GUARDED_BY(mu); + bool call_finished TF_GUARDED_BY(&Iterator::mu_); + const int64 uid = -1; + }; + + void CallCompleted(const std::shared_ptr& ctx, + const std::shared_ptr& result) + TF_LOCKS_EXCLUDED(*mu_) { + mutex_lock l(*mu_); + num_calls_--; + result->call_finished = true; + cond_var_->notify_all(); + } + + // The function fetches elements from input dataset sequentially and then + // executes the batching for different batches in parallel using the context + // runner. 
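+    // Note that only the copy of already-fetched elements is handed off to
+    // the runner (see `copy_elements_fn` below); the `GetNext()` calls
+    // themselves run sequentially on this thread, which keeps elements
+    // assigned to batches in input order even when multiple batches are in
+    // flight.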
+ void CallBatching(std::shared_ptr ctx, + const std::shared_ptr& result) + TF_LOCKS_EXCLUDED(*mu_) { + profiler::TraceMe traceme([&] { + return profiler::TraceMeEncode("ParallelBatchProduce", + {{"element_id", result->uid}}); + }); + + if (!input_impl_) { + CallCompleted(ctx, result); + return; + } + + // Each row of `batch_elements` is a tuple of tensors from the input + // iterator. + auto batch_elements = + std::make_shared>>(); + batch_elements->reserve(dataset()->reserve_size_); + + bool end_of_input = false; + for (int i = 0; i < dataset()->batch_size_ && !end_of_input; ++i) { + std::vector batch_element_tuple; + Status status = input_impl_->GetNext(ctx.get(), &batch_element_tuple, + &end_of_input); + { + mutex_lock l(result->mu); + result->end_of_input = result->end_of_input || end_of_input; + result->status.Update(status); + if (result->end_of_input || !result->status.ok()) break; + } + if (!end_of_input) { + batch_elements->emplace_back(std::move(batch_element_tuple)); + mutex_lock l(result->mu); + result->num_elements++; + } else { + input_impl_.reset(); + } + } + + if (batch_elements->empty()) { + CallCompleted(ctx, result); + DCHECK(end_of_input); + return; + } + + auto copy_elements_fn = [this, ctx, result, batch_elements]() { + Status status; + { + mutex_lock l(result->mu); + status = CopyBatch(/*parallel_copy=*/false, ctx.get(), + &result->output, batch_elements.get()); + result->status.Update(status); + RecordBufferEnqueue(ctx.get(), result->output); + } + CallCompleted(ctx, result); + return status; + }; + + (*ctx->runner())(copy_elements_fn); + } + + void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) { + mutex_lock l(*mu_); + cancelled_ = true; + cond_var_->notify_all(); + // Wait for all in-flight calls to complete. + while (wait && num_calls_ > 0) { + cond_var_->wait(l); + } + } + + void EnsureRunnerThreadStarted(IteratorContext* ctx) + TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + if (!runner_thread_) { + auto ctx_copy = std::make_shared(*ctx); + runner_thread_ = ctx->StartThread( + kTFDataParallelBatch, + std::bind(&Iterator::RunnerThread, this, ctx_copy)); + } + } + + void RunnerThread(const std::shared_ptr& ctx) + TF_LOCKS_EXCLUDED(*mu_) { + std::vector> new_calls; + RecordStart(ctx.get()); + auto stop_cleanup = + gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); }); + { + tf_shared_lock l(*mu_); // mu_ == num_parallel_calls_->mu + new_calls.reserve(num_parallel_calls_->value); + } + auto busy = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool { + return num_calls_ >= num_parallel_calls_->value; + }; + while (true) { + { + mutex_lock l(*mu_); + while (!cancelled_ && busy()) { + RecordStop(ctx.get()); + cond_var_->wait(l); + RecordStart(ctx.get()); + } + + if (cancelled_) { + return; + } + + while (!busy()) { + batch_results_.push_back(std::make_shared()); + new_calls.emplace_back(batch_results_.back()); + num_calls_++; + } + } + for (const auto& call : new_calls) { + CallBatching(ctx, call); + } + new_calls.clear(); + } + } + + Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader, + size_t index) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + batch_results_.push_back(std::make_shared()); + std::shared_ptr result = batch_results_.back(); + string batch_prefix = strings::StrCat(kBatchResults, "_", index); + mutex_lock l(result->mu); + result->end_of_input = reader->Contains( + full_name(strings::StrCat(batch_prefix, "_", kEndOfInput))); + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat(batch_prefix, "_", kNumElements)), + 
&result->num_elements)); + result->call_finished = reader->Contains( + full_name(strings::StrCat(batch_prefix, "_", kCallFinished))); + + TF_RETURN_IF_ERROR(ReadBatch(dataset()->batch_size_, prefix(), + batch_prefix, ctx, reader, &result->output)); + TF_RETURN_IF_ERROR(ReadStatus(prefix(), + strings::StrCat(batch_prefix, "_", kStatus), + reader, &result->status)); + return Status::OK(); + } + + Status WriteBatchResult(IteratorStateWriter* writer, size_t index) + TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + std::shared_ptr result = batch_results_[index]; + string batch_prefix = strings::StrCat(kBatchResults, "_", index); + mutex_lock l(result->mu); + if (result->end_of_input) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat(batch_prefix, "_", kEndOfInput)), "")); + } + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat(batch_prefix, "_", kNumElements)), + result->num_elements)); + if (result->call_finished) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat(batch_prefix, "_", kCallFinished)), "")); + } + + TF_RETURN_IF_ERROR(WriteBatch(dataset()->batch_size_, + result->num_elements, prefix(), + batch_prefix, writer, &result->output)); + TF_RETURN_IF_ERROR( + WriteStatus(prefix(), strings::StrCat(batch_prefix, "_", kStatus), + result->status, writer)); + return Status::OK(); + } + + // Used for coordination between the main thread and the runner thread. + const std::shared_ptr mu_; + // Used for coordination between the main thread and the runner thread. In + // particular, the runner thread should only schedule new calls when the + // number of in-flight calls is less than the user specified level of + // parallelism and there are slots available in the `invocation_results_` + // buffer. + const std::shared_ptr cond_var_; + // Identifies the maximum number of parallel calls. + const std::shared_ptr num_parallel_calls_; + + // Counts the number of outstanding calls for this batch. + int64 num_calls_ TF_GUARDED_BY(*mu_) = 0; + std::unique_ptr input_impl_; + // Buffer for storing the (intermediate) batch results. + std::deque> batch_results_ TF_GUARDED_BY(*mu_); + // Background thread used for coordinating input processing. + std::unique_ptr runner_thread_ TF_GUARDED_BY(*mu_); + // Determines whether the transformation has been cancelled. + bool cancelled_ TF_GUARDED_BY(*mu_) = false; + // Identifies the number of callers currently waiting for a batch result. + int64 waiting_ TF_GUARDED_BY(*mu_) = 0; + + // Method for deregistering the cancellation callback. 
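+    // Registered in Initialize() via `RegisterCancellationCallback` and
+    // invoked in the destructor, so the callback cannot fire after the
+    // iterator has been destroyed.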
+ std::function deregister_fn_; + }; + + const int64 batch_size_; + const int64 reserve_size_; + const int64 num_parallel_calls_; + const bool drop_remainder_; + const DatasetBase* const input_; + std::vector output_shapes_; + const TraceMeMetadata traceme_metadata_; +}; + +ParallelBatchDatasetOp::ParallelBatchDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + +void ParallelBatchDatasetOp::MakeDataset(OpKernelContext* ctx, + DatasetBase* input, + DatasetBase** output) { + int64 batch_size = 0; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBatchSize, &batch_size)); + OP_REQUIRES(ctx, batch_size > 0, + errors::InvalidArgument("Batch size must be greater than zero.")); + + int64 num_parallel_calls = 0; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kNumParallelCalls, + &num_parallel_calls)); + + bool drop_remainder = false; + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, kDropRemainder, &drop_remainder)); + + *output = + new Dataset(ctx, batch_size, num_parallel_calls, drop_remainder, input); +} + +namespace { +REGISTER_KERNEL_BUILDER(Name("ParallelBatchDataset").Device(DEVICE_CPU), + ParallelBatchDatasetOp); +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_batch_dataset_op.h b/tensorflow/core/kernels/data/parallel_batch_dataset_op.h new file mode 100644 index 00000000000..2e2c7e690df --- /dev/null +++ b/tensorflow/core/kernels/data/parallel_batch_dataset_op.h @@ -0,0 +1,47 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_BATCH_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_BATCH_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" + +namespace tensorflow { +namespace data { + +class ParallelBatchDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "ParallelBatch"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kBatchSize = "batch_size"; + static constexpr const char* const kNumParallelCalls = "num_parallel_calls"; + static constexpr const char* const kDropRemainder = "drop_remainder"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit ParallelBatchDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_BATCH_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/data/parallel_batch_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_batch_dataset_op_test.cc new file mode 100644 index 00000000000..247337f8749 --- /dev/null +++ b/tensorflow/core/kernels/data/parallel_batch_dataset_op_test.cc @@ -0,0 +1,416 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/kernels/data/parallel_batch_dataset_op.h" + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kNodeName[] = "parallel_batch_dataset"; +constexpr int kOpVersion = 1; + +class ParallelBatchDatasetParams : public DatasetParams { + public: + template + ParallelBatchDatasetParams(T input_dataset_params, int64 batch_size, + int64 num_parallel_calls, bool drop_remainder, + DataTypeVector output_dtypes, + std::vector output_shapes, + string node_name) + : DatasetParams(std::move(output_dtypes), std::move(output_shapes), + std::move(node_name)), + batch_size_(batch_size), + num_parallel_calls_(num_parallel_calls), + drop_remainder_(drop_remainder) { + input_dataset_params_.push_back(std::make_unique(input_dataset_params)); + op_version_ = kOpVersion; + iterator_prefix_ = + name_utils::IteratorPrefix(input_dataset_params.dataset_type(), + input_dataset_params.iterator_prefix()); + } + + std::vector GetInputTensors() const override { + Tensor batch_size = CreateTensor(TensorShape({}), {batch_size_}); + Tensor num_parallel_calls = + CreateTensor(TensorShape({}), {num_parallel_calls_}); + Tensor drop_remainder = + CreateTensor(TensorShape({}), {drop_remainder_}); + return {batch_size, num_parallel_calls, drop_remainder}; + } + + Status GetInputNames(std::vector* input_names) const override { + *input_names = {ParallelBatchDatasetOp::kInputDataset, + ParallelBatchDatasetOp::kBatchSize, + ParallelBatchDatasetOp::kNumParallelCalls, + ParallelBatchDatasetOp::kDropRemainder}; + return Status::OK(); + } + + Status GetAttributes(AttributeVector* attr_vector) const override { + *attr_vector = {{ParallelBatchDatasetOp::kOutputTypes, output_dtypes_}, + {ParallelBatchDatasetOp::kOutputShapes, output_shapes_}}; + return Status::OK(); + }; + + string dataset_type() const override { + return ParallelBatchDatasetOp::kDatasetType; + } + + private: + int64 batch_size_; + int64 num_parallel_calls_; + bool drop_remainder_; +}; + +class ParallelBatchDatasetOpTest : public DatasetOpsTestBase {}; + +// Test Case 1: test ParallelBatchDataset with `drop_remainder` = false and a +// batch size that can evenly split the input dataset. +ParallelBatchDatasetParams ParallelBatchDatasetParams1() { + return ParallelBatchDatasetParams(RangeDatasetParams(0, 12, 1), + /*batch_size=*/4, + /*num_parallel_calls=*/1, + /*drop_remainder=*/false, + /*output_dtypes=*/{DT_INT64}, + /*output_shapes=*/{PartialTensorShape({4})}, + /*node_name=*/kNodeName); +} + +// Test Case 2: test ParallelBatchDataset with `drop_remainder` = true and a +// batch size that can evenly split the input dataset. +ParallelBatchDatasetParams ParallelBatchDatasetParams2() { + return ParallelBatchDatasetParams(RangeDatasetParams(0, 12, 1), + /*batch_size=*/4, + /*num_parallel_calls=*/1, + /*drop_remainder=*/true, + /*output_dtypes=*/{DT_INT64}, + /*output_shapes=*/{PartialTensorShape({4})}, + /*node_name=*/kNodeName); +} + +// Test Case 3: test ParallelBatchDataset with `drop_remainder` = false and a +// batch size that can not evenly split the input dataset. 
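+// With 10 input elements and a batch size of 3, the final batch carries a
+// single element, hence the partial ({-1}) output shape below.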
+ParallelBatchDatasetParams ParallelBatchDatasetParams3() { + return ParallelBatchDatasetParams( + RangeDatasetParams(0, 10, 1), + /*batch_size=*/3, + /*num_parallel_calls=*/1, + /*drop_remainder=*/false, + /*output_dtypes=*/{DT_INT64}, + /*output_shapes=*/{PartialTensorShape({-1})}, + /*node_name=*/kNodeName); +} + +// Test Case 4: test ParallelBatchDataset with `drop_remainder` = true and a +// batch size that can not evenly split the input dataset. +ParallelBatchDatasetParams ParallelBatchDatasetParams4() { + return ParallelBatchDatasetParams(RangeDatasetParams(0, 10, 1), + /*batch_size=*/3, + /*num_parallel_calls=*/1, + /*drop_remainder=*/true, + /*output_dtypes=*/{DT_INT64}, + /*output_shapes=*/{PartialTensorShape({3})}, + /*node_name=*/kNodeName); +} + +// Test Case 5: test ParallelBatchDataset with `drop_remainder` = true and +// `batch_size` > the cardinality of the input dataset. +ParallelBatchDatasetParams ParallelBatchDatasetParams5() { + return ParallelBatchDatasetParams( + RangeDatasetParams(0, 10, 1), + /*batch_size=*/12, + /*num_parallel_calls=*/1, + /*drop_remainder=*/true, + /*output_dtypes=*/{DT_INT64}, + /*output_shapes=*/{PartialTensorShape({12})}, + /*node_name=*/kNodeName); +} + +// Test Case 6: test ParallelBatchDataset with `drop_remainder` = false and +// `batch_size` > the cardinality of the input dataset. +ParallelBatchDatasetParams ParallelBatchDatasetParams6() { + return ParallelBatchDatasetParams( + RangeDatasetParams(0, 10, 1), + /*batch_size=*/12, + /*num_parallel_calls=*/1, + /*drop_remainder=*/false, + /*output_dtypes=*/{DT_INT64}, + /*output_shapes=*/{PartialTensorShape({-1})}, + /*node_name=*/kNodeName); +} + +// Test Case 7: test ParallelBatchDataset with `drop_remainder` = false and +// the output of the input dataset is empty. +ParallelBatchDatasetParams ParallelBatchDatasetParams7() { + return ParallelBatchDatasetParams(RangeDatasetParams(0, 0, 1), + /*batch_size=*/4, + /*num_parallel_calls=*/1, + /*drop_remainder=*/false, + /*output_dtypes=*/{DT_INT64}, + /*output_shapes=*/{PartialTensorShape({4})}, + /*node_name=*/kNodeName); +} + +// Test Case 8: test ParallelBatchDataset with `num_parallel_calls` = 2. +ParallelBatchDatasetParams ParallelBatchDatasetParams8() { + return ParallelBatchDatasetParams(RangeDatasetParams(0, 12, 1), + /*batch_size=*/4, + /*num_parallel_calls=*/2, + /*drop_remainder=*/false, + /*output_dtypes=*/{DT_INT64}, + /*output_shapes=*/{PartialTensorShape({4})}, + /*node_name=*/kNodeName); +} + +// Test Case 9: test ParallelBatchDataset with `num_parallel_calls` = 4. +ParallelBatchDatasetParams ParallelBatchDatasetParams9() { + return ParallelBatchDatasetParams(RangeDatasetParams(0, 12, 1), + /*batch_size=*/4, + /*num_parallel_calls=*/4, + /*drop_remainder=*/false, + /*output_dtypes=*/{DT_INT64}, + /*output_shapes=*/{PartialTensorShape({4})}, + /*node_name=*/kNodeName); +} + +// Test Case 10: test ParallelBatchDataset with an invalid batch size. 
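+// A batch_size of -1 should be rejected by MakeDataset(), which requires
+// batch_size > 0 ("Batch size must be greater than zero.").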
+ParallelBatchDatasetParams InvalidBatchSizeParallelBatchDatasetParams() {
+  return ParallelBatchDatasetParams(RangeDatasetParams(0, 10, 1),
+                                    /*batch_size=*/-1,
+                                    /*num_parallel_calls=*/1,
+                                    /*drop_remainder=*/false,
+                                    /*output_dtypes=*/{DT_INT64},
+                                    /*output_shapes=*/{PartialTensorShape({3})},
+                                    /*node_name=*/kNodeName);
+}
+
+std::vector<GetNextTestCase<ParallelBatchDatasetParams>> GetNextTestCases() {
+  return {{/*dataset_params=*/ParallelBatchDatasetParams1(),
+           /*expected_outputs=*/
+           CreateTensors<int64>(TensorShape({4}),
+                                {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
+          {/*dataset_params=*/ParallelBatchDatasetParams2(),
+           /*expected_outputs=*/
+           CreateTensors<int64>(TensorShape({4}),
+                                {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
+          {/*dataset_params=*/ParallelBatchDatasetParams3(),
+           /*expected_outputs=*/
+           {CreateTensor<int64>(TensorShape({3}), {0, 1, 2}),
+            CreateTensor<int64>(TensorShape({3}), {3, 4, 5}),
+            CreateTensor<int64>(TensorShape({3}), {6, 7, 8}),
+            CreateTensor<int64>(TensorShape({1}), {9})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams4(),
+           /*expected_outputs=*/
+           CreateTensors<int64>(TensorShape({3}),
+                                {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}})},
+          {/*dataset_params=*/ParallelBatchDatasetParams5(),
+           /*expected_outputs=*/{}},
+          {/*dataset_params=*/ParallelBatchDatasetParams6(),
+           /*expected_outputs=*/
+           CreateTensors<int64>(TensorShape({10}),
+                                {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}})},
+          {/*dataset_params=*/ParallelBatchDatasetParams7(),
+           /*expected_outputs=*/{}},
+          {/*dataset_params=*/ParallelBatchDatasetParams8(),
+           /*expected_outputs=*/
+           CreateTensors<int64>(TensorShape({4}),
+                                {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
+          {/*dataset_params=*/ParallelBatchDatasetParams9(),
+           /*expected_outputs=*/
+           CreateTensors<int64>(TensorShape({4}),
+                                {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})}};
+}
+
+ITERATOR_GET_NEXT_TEST_P(ParallelBatchDatasetOpTest, ParallelBatchDatasetParams,
+                         GetNextTestCases())
+
+TEST_F(ParallelBatchDatasetOpTest, DatasetNodeName) {
+  auto parallel_batch_dataset_params = ParallelBatchDatasetParams1();
+  TF_ASSERT_OK(Initialize(parallel_batch_dataset_params));
+  TF_ASSERT_OK(CheckDatasetNodeName(parallel_batch_dataset_params.node_name()));
+}
+
+TEST_F(ParallelBatchDatasetOpTest, DatasetTypeString) {
+  auto parallel_batch_dataset_params = ParallelBatchDatasetParams1();
+  TF_ASSERT_OK(Initialize(parallel_batch_dataset_params));
+  name_utils::OpNameParams params;
+  params.op_version = parallel_batch_dataset_params.op_version();
+  TF_ASSERT_OK(CheckDatasetTypeString(
+      name_utils::OpName(ParallelBatchDatasetOp::kDatasetType, params)));
+}
+
+TEST_F(ParallelBatchDatasetOpTest, DatasetOutputDtypes) {
+  auto parallel_batch_dataset_params = ParallelBatchDatasetParams1();
+  TF_ASSERT_OK(Initialize(parallel_batch_dataset_params));
+  TF_ASSERT_OK(CheckDatasetOutputDtypes({DT_INT64}));
+}
+
+std::vector<DatasetOutputShapesTestCase<ParallelBatchDatasetParams>>
+DatasetOutputShapesTestCases() {
+  return {{/*dataset_params=*/ParallelBatchDatasetParams1(),
+           /*expected_output_shapes=*/{PartialTensorShape({4})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams2(),
+           /*expected_output_shapes=*/{PartialTensorShape({4})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams3(),
+           /*expected_output_shapes=*/{PartialTensorShape({-1})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams4(),
+           /*expected_output_shapes=*/{PartialTensorShape({3})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams5(),
+           /*expected_output_shapes=*/{PartialTensorShape({12})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams6(),
+           /*expected_output_shapes=*/{PartialTensorShape({-1})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams7(),
+           /*expected_output_shapes=*/{PartialTensorShape({4})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams8(),
+           /*expected_output_shapes=*/{PartialTensorShape({4})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams9(),
+           /*expected_output_shapes=*/{PartialTensorShape({4})}}};
+}
+
+DATASET_OUTPUT_SHAPES_TEST_P(ParallelBatchDatasetOpTest,
+                             ParallelBatchDatasetParams,
+                             DatasetOutputShapesTestCases())
+
+std::vector<CardinalityTestCase<ParallelBatchDatasetParams>>
+CardinalityTestCases() {
+  return {{/*dataset_params=*/ParallelBatchDatasetParams1(),
+           /*expected_cardinality=*/3},
+          {/*dataset_params=*/ParallelBatchDatasetParams2(),
+           /*expected_cardinality=*/3},
+          {/*dataset_params=*/ParallelBatchDatasetParams3(),
+           /*expected_cardinality=*/4},
+          {/*dataset_params=*/ParallelBatchDatasetParams4(),
+           /*expected_cardinality=*/3},
+          {/*dataset_params=*/ParallelBatchDatasetParams5(),
+           /*expected_cardinality=*/0},
+          {/*dataset_params=*/ParallelBatchDatasetParams6(),
+           /*expected_cardinality=*/1},
+          {/*dataset_params=*/ParallelBatchDatasetParams7(),
+           /*expected_cardinality=*/0},
+          {/*dataset_params=*/ParallelBatchDatasetParams8(),
+           /*expected_cardinality=*/3},
+          {/*dataset_params=*/ParallelBatchDatasetParams9(),
+           /*expected_cardinality=*/3}};
+}
+
+DATASET_CARDINALITY_TEST_P(ParallelBatchDatasetOpTest,
+                           ParallelBatchDatasetParams, CardinalityTestCases())
+
+TEST_F(ParallelBatchDatasetOpTest, IteratorOutputDtypes) {
+  auto parallel_batch_dataset_params = ParallelBatchDatasetParams1();
+  TF_ASSERT_OK(Initialize(parallel_batch_dataset_params));
+  TF_ASSERT_OK(CheckIteratorOutputDtypes({DT_INT64}));
+}
+
+std::vector<IteratorOutputShapesTestCase<ParallelBatchDatasetParams>>
+IteratorOutputShapesTestCases() {
+  return {{/*dataset_params=*/ParallelBatchDatasetParams1(),
+           /*expected_output_shapes=*/{PartialTensorShape({4})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams2(),
+           /*expected_output_shapes=*/{PartialTensorShape({4})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams3(),
+           /*expected_output_shapes=*/{PartialTensorShape({-1})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams4(),
+           /*expected_output_shapes=*/{PartialTensorShape({3})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams5(),
+           /*expected_output_shapes=*/{PartialTensorShape({12})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams6(),
+           /*expected_output_shapes=*/{PartialTensorShape({-1})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams7(),
+           /*expected_output_shapes=*/{PartialTensorShape({4})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams8(),
+           /*expected_output_shapes=*/{PartialTensorShape({4})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams9(),
+           /*expected_output_shapes=*/{PartialTensorShape({4})}}};
+}
+
+ITERATOR_OUTPUT_SHAPES_TEST_P(ParallelBatchDatasetOpTest,
+                              ParallelBatchDatasetParams,
+                              IteratorOutputShapesTestCases())
+
+TEST_F(ParallelBatchDatasetOpTest, IteratorOutputPrefix) {
+  auto parallel_batch_dataset_params = ParallelBatchDatasetParams1();
+  TF_ASSERT_OK(Initialize(parallel_batch_dataset_params));
+  name_utils::IteratorPrefixParams params;
+  params.op_version = parallel_batch_dataset_params.op_version();
+  TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
+      ParallelBatchDatasetOp::kDatasetType,
+      parallel_batch_dataset_params.iterator_prefix(), params)));
+}
+
+std::vector<IteratorSaveAndRestoreTestCase<ParallelBatchDatasetParams>>
+IteratorSaveAndRestoreTestCases() {
+  return {{/*dataset_params=*/ParallelBatchDatasetParams1(),
+           /*breakpoints=*/{0, 1, 5},
+           /*expected_outputs=*/
+           CreateTensors<int64>(TensorShape({4}),
+                                {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
+          {/*dataset_params=*/ParallelBatchDatasetParams2(),
+           /*breakpoints=*/{0, 1, 5},
+           /*expected_outputs=*/
+           CreateTensors<int64>(TensorShape({4}),
+                                {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
+          {/*dataset_params=*/ParallelBatchDatasetParams3(),
+           /*breakpoints=*/{0, 1, 5},
+           /*expected_outputs=*/
+           {CreateTensor<int64>(TensorShape({3}), {0, 1, 2}),
+            CreateTensor<int64>(TensorShape({3}), {3, 4, 5}),
+            CreateTensor<int64>(TensorShape({3}), {6, 7, 8}),
+            CreateTensor<int64>(TensorShape({1}), {9})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams4(),
+           /*breakpoints=*/{0, 1, 5},
+           /*expected_outputs=*/
+           {CreateTensor<int64>(TensorShape({3}), {0, 1, 2}),
+            CreateTensor<int64>(TensorShape({3}), {3, 4, 5}),
+            CreateTensor<int64>(TensorShape({3}), {6, 7, 8})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams5(),
+           /*breakpoints=*/{0, 1, 5},
+           /*expected_outputs=*/{}},
+          {/*dataset_params=*/ParallelBatchDatasetParams6(),
+           /*breakpoints=*/{0, 1, 5},
+           /*expected_outputs=*/
+           {CreateTensor<int64>(TensorShape({10}),
+                                {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}},
+          {/*dataset_params=*/ParallelBatchDatasetParams7(),
+           /*breakpoints=*/{0, 1, 5},
+           /*expected_outputs=*/{}},
+          {/*dataset_params=*/ParallelBatchDatasetParams8(),
+           /*breakpoints=*/{0, 1, 5},
+           /*expected_outputs=*/
+           CreateTensors<int64>(TensorShape({4}),
+                                {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
+          {/*dataset_params=*/ParallelBatchDatasetParams9(),
+           /*breakpoints=*/{0, 1, 5},
+           /*expected_outputs=*/
+           CreateTensors<int64>(TensorShape({4}),
+                                {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})}};
+}
+
+ITERATOR_SAVE_AND_RESTORE_TEST_P(ParallelBatchDatasetOpTest,
+                                 ParallelBatchDatasetParams,
+                                 IteratorSaveAndRestoreTestCases())
+
+TEST_F(ParallelBatchDatasetOpTest, InvalidParallelBatchSize) {
+  auto parallel_batch_dataset_params =
+      InvalidBatchSizeParallelBatchDatasetParams();
+  EXPECT_EQ(Initialize(parallel_batch_dataset_params).code(),
+            tensorflow::error::INVALID_ARGUMENT);
+}
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index e00ebf7325a..6ceb3719630 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -328,6 +328,25 @@ REGISTER_OP("BatchDatasetV2")
       return shape_inference::ScalarShape(c);
     });
 
+REGISTER_OP("ParallelBatchDataset")
+    .Input("input_dataset: variant")
+    .Input("batch_size: int64")
+    .Input("num_parallel_calls: int64")
+    .Input("drop_remainder: bool")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // batch_size should be a scalar.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      // num_parallel_calls should be a scalar.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      // drop_remainder should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + return shape_inference::ScalarShape(c); + }); + REGISTER_OP("ShardDataset") .Input("input_dataset: variant") .Input("num_shards: int64") diff --git a/tensorflow/python/data/benchmarks/batch_benchmark.py b/tensorflow/python/data/benchmarks/batch_benchmark.py index 5b97873fc6e..e8607edae1b 100644 --- a/tensorflow/python/data/benchmarks/batch_benchmark.py +++ b/tensorflow/python/data/benchmarks/batch_benchmark.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.data.benchmarks import benchmark_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import random_ops class BatchBenchmark(benchmark_base.DatasetBenchmarkBase): @@ -67,6 +68,30 @@ class BatchBenchmark(benchmark_base.DatasetBenchmarkBase): name="batch_element_size_%d_batch_size_%d%s" % (element_size, batch_size, tag)) + def benchmark_parallel_batch(self): + batch_size = 128 + nums_parallel_calls = [None, 1, 4, 16, dataset_ops.AUTOTUNE] + num_range = 100000 + + def f(_): + return random_ops.random_uniform([224, 224, 3]) + + for num_parallel_calls in nums_parallel_calls: + num_parallel_calls_str = ("autotune" + if num_parallel_calls == dataset_ops.AUTOTUNE + else str(num_parallel_calls)) + op_str = ("batch" if num_parallel_calls is None else + ("parallel_batch_num_parallel_calls_%s" % + num_parallel_calls_str)) + + dataset = dataset_ops.Dataset.range(num_range).map(f).batch( + batch_size, num_parallel_calls=num_parallel_calls) + self.run_and_report_benchmark( + dataset, + num_elements=num_range // batch_size, + iters=1, + name="batch_size_%d_%s" % (batch_size, op_str)) + if __name__ == "__main__": benchmark_base.test.main() diff --git a/tensorflow/python/data/kernel_tests/batch_test.py b/tensorflow/python/data/kernel_tests/batch_test.py index 013f9831bd1..4fdb0c041c3 100644 --- a/tensorflow/python/data/kernel_tests/batch_test.py +++ b/tensorflow/python/data/kernel_tests/batch_test.py @@ -43,9 +43,11 @@ class BatchTest(test_base.DatasetTestBase, parameterized.TestCase): combinations.times( test_base.default_test_combinations(), combinations.combine( - count=[0, 28], batch_size=[14, 15], drop_remainder=[True, - False]))) - def testBasic(self, count, batch_size, drop_remainder): + count=[0, 28], + batch_size=[14, 15], + drop_remainder=[True, False], + num_parallel_calls=[None, 1, 2, 4]))) + def testBasic(self, count, batch_size, drop_remainder, num_parallel_calls): """Tests the batch dataset logic for various input configurations. 
     Args:
@@ -53,6 +55,8 @@ class BatchTest(test_base.DatasetTestBase, parameterized.TestCase):
       batch_size: the batch size
       drop_remainder: whether a smaller batch size should be produced if batch
        size does not divide number of inputs evenly
+      num_parallel_calls: the number of batches to process asynchronously in
+        parallel
     """

    # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
@@ -65,7 +69,8 @@ class BatchTest(test_base.DatasetTestBase, parameterized.TestCase):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

     dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
-        _map_fn).repeat(count).batch(batch_size, drop_remainder)
+        _map_fn).repeat(count).batch(batch_size, drop_remainder,
+                                     num_parallel_calls)
     get_next = self.getNext(dataset)

     if drop_remainder:
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 1c949254f48..59eb761c5d3 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1505,7 +1505,7 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
     """
     return ShardDataset(self, num_shards, index)

-  def batch(self, batch_size, drop_remainder=False):
+  def batch(self, batch_size, drop_remainder=False, num_parallel_calls=None):
     """Combines consecutive elements of this dataset into batches.

     >>> dataset = tf.data.Dataset.range(8)
@@ -1532,11 +1532,21 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
        whether the last batch should be dropped in the case it has fewer than
        `batch_size` elements; the default behavior is not to drop the smaller
        batch.
+      num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`,
+        representing the number of batches to compute asynchronously in
+        parallel.
+        If not specified, batches will be computed sequentially. If the value
+        `tf.data.AUTOTUNE` is used, then the number of parallel
+        calls is set dynamically based on available resources.

     Returns:
       Dataset: A `Dataset`.
""" - return BatchDataset(self, batch_size, drop_remainder) + if num_parallel_calls is None: + return BatchDataset(self, batch_size, drop_remainder) + else: + return ParallelBatchDataset(self, batch_size, drop_remainder, + num_parallel_calls) def padded_batch(self, batch_size, @@ -2613,9 +2623,10 @@ class DatasetV1(DatasetV2): return DatasetV1Adapter(super(DatasetV1, self).shard(num_shards, index)) @functools.wraps(DatasetV2.batch) - def batch(self, batch_size, drop_remainder=False): - return DatasetV1Adapter(super(DatasetV1, self).batch( - batch_size, drop_remainder)) + def batch(self, batch_size, drop_remainder=False, num_parallel_calls=None): + return DatasetV1Adapter( + super(DatasetV1, self).batch(batch_size, drop_remainder, + num_parallel_calls)) @functools.wraps(DatasetV2.padded_batch) def padded_batch(self, @@ -3936,6 +3947,47 @@ class BatchDataset(UnaryDataset): return self._structure +class ParallelBatchDataset(UnaryDataset): + """A `Dataset` that batches contiguous elements from its input in parallel.""" + + def __init__(self, input_dataset, batch_size, drop_remainder, + num_parallel_calls): + """See `Dataset.batch()` for details.""" + self._input_dataset = input_dataset + self._batch_size = ops.convert_to_tensor( + batch_size, dtype=dtypes.int64, name="batch_size") + self._drop_remainder = ops.convert_to_tensor( + drop_remainder, dtype=dtypes.bool, name="drop_remainder") + self._num_parallel_calls = ops.convert_to_tensor( + num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") + + constant_drop_remainder = tensor_util.constant_value(self._drop_remainder) + # pylint: disable=protected-access + if constant_drop_remainder: + # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically) + # or `False` (explicitly retaining the remainder). 
+ # pylint: disable=g-long-lambda + constant_batch_size = tensor_util.constant_value(self._batch_size) + self._structure = nest.map_structure( + lambda component_spec: component_spec._batch(constant_batch_size), + input_dataset.element_spec) + else: + self._structure = nest.map_structure( + lambda component_spec: component_spec._batch(None), + input_dataset.element_spec) + variant_tensor = gen_dataset_ops.parallel_batch_dataset( + input_dataset._variant_tensor, + batch_size=self._batch_size, + num_parallel_calls=self._num_parallel_calls, + drop_remainder=self._drop_remainder, + **self._flat_structure) + super(ParallelBatchDataset, self).__init__(input_dataset, variant_tensor) + + @property + def element_spec(self): + return self._structure + + class _NumpyIterator(object): """Iterator over a dataset with elements converted to numpy.""" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt index 92d48198ca3..51d39adc4cb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt @@ -33,7 +33,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "cache" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt index 6ef93ca9890..e6ae33a97b4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -35,7 +35,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "cache" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt index 7249cdff1be..cbfb077acfb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt @@ -35,7 +35,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "cache" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt index 79aeac1d2a1..dcae2a070ad 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt @@ -35,7 +35,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: 
"args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "cache" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt index b81d19a161f..b7e9ed060ea 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -35,7 +35,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "cache" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt index cacd6fa0d0a..bf089044286 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt @@ -35,7 +35,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "cache" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt index 3cbcda297e4..a4ccc2ac7dc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -35,7 +35,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "cache" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 2f191df658d..1cdb121acc0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -2796,6 +2796,10 @@ tf_module { name: "PaddingFIFOQueueV2" argspec: "args=[\'component_types\', \'shapes\', \'capacity\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'-1\', \'\', \'\', \'None\'], " } + member_method { + name: "ParallelBatchDataset" + argspec: "args=[\'input_dataset\', \'batch_size\', \'num_parallel_calls\', \'drop_remainder\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "ParallelConcat" argspec: "args=[\'values\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt index 1aee69e3a5f..0aaeb69f6c6 100644 --- 
a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt @@ -20,7 +20,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "cache" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt index eec1f30e679..1e034690a9b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -22,7 +22,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "cache" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt index a9d0eafa605..458da1f007f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "cache" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt index c875e08a4c3..eafac5753dd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt @@ -22,7 +22,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "cache" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt index 90c6edd4da4..7f848e9aa22 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -22,7 +22,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "cache" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt index ae2c22c237e..5151bdb81a7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt @@ -22,7 +22,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "cache" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt index 71726e9aebe..9c21b95f21e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -22,7 +22,7 @@ tf_class { } member_method { name: "batch" - argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { name: "cache" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 2f191df658d..1cdb121acc0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -2796,6 +2796,10 @@ tf_module { name: "PaddingFIFOQueueV2" argspec: "args=[\'component_types\', \'shapes\', \'capacity\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'-1\', \'\', \'\', \'None\'], " } + member_method { + name: "ParallelBatchDataset" + argspec: "args=[\'input_dataset\', \'batch_size\', \'num_parallel_calls\', \'drop_remainder\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "ParallelConcat" argspec: "args=[\'values\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
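
Taken together, the changes above add an optional `num_parallel_calls` argument to `tf.data.Dataset.batch()`. A minimal usage sketch, assuming eager execution on a build that includes this patch; the pipeline, element shapes, and parallelism values below are illustrative rather than taken from the patch:

import tensorflow as tf

# Toy input pipeline: a map producing a non-trivial element per record.
dataset = tf.data.Dataset.range(1024)
dataset = dataset.map(lambda x: tf.fill([224], x))

# Compute up to four input batches in parallel; this lowers to the new
# ParallelBatchDataset op registered above.
parallel = dataset.batch(32, drop_remainder=True, num_parallel_calls=4)

# Or let the tf.data runtime pick the parallelism dynamically.
autotuned = dataset.batch(32, num_parallel_calls=tf.data.AUTOTUNE)

# Omitting num_parallel_calls keeps the existing sequential BatchDataset
# path, so current callers are unaffected.
sequential = dataset.batch(32)

for batch in parallel.take(1):
  print(batch.shape)  # (32, 224)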
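
The `constant_drop_remainder` branch in `ParallelBatchDataset.__init__` determines whether the batch dimension of the element spec is statically known. A small sketch of the observable effect, under the same assumptions as above, with the expected results shown as comments:

import tensorflow as tf

ds = tf.data.Dataset.range(10)

# drop_remainder is statically True: every emitted batch has exactly 4
# elements, so each component spec is batched with a known leading dim.
print(ds.batch(4, drop_remainder=True, num_parallel_calls=2).element_spec)
# TensorSpec(shape=(4,), dtype=tf.int64, name=None)

# drop_remainder is False: the final partial batch may be smaller, so the
# leading dimension stays unknown (the `_batch(None)` branch).
print(ds.batch(4, num_parallel_calls=2).element_spec)
# TensorSpec(shape=(None,), dtype=tf.int64, name=None)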
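
For a quick local comparison outside the benchmark harness, a rough standalone timing loop in the spirit of the new `benchmark_parallel_batch` might look as follows. This is only a sketch: the element size and record count are illustrative, and wall-clock results will vary by machine.

import time

import tensorflow as tf

def time_batching(num_parallel_calls, batch_size=128, num_range=10000):
  # Same pipeline shape as the benchmark: a map producing a largish
  # element, followed by batch with the given parallelism.
  ds = tf.data.Dataset.range(num_range)
  ds = ds.map(lambda _: tf.random.uniform([224, 224, 3]))
  ds = ds.batch(batch_size, num_parallel_calls=num_parallel_calls)
  start = time.perf_counter()
  for _ in ds:  # Drain the dataset once.
    pass
  return time.perf_counter() - start

for calls in [None, 1, 4, tf.data.AUTOTUNE]:
  label = ("batch" if calls is None else
           "parallel_batch_autotune" if calls == tf.data.AUTOTUNE else
           "parallel_batch_%d" % calls)
  print(label, time_batching(calls))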