From de87e628e6d89382783ea948c21fa66b182d69d3 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Fri, 1 Feb 2019 09:25:37 -0800 Subject: [PATCH] Adding a RebatchDatasetOp that uses Grappler to modify the dataset graph to mutate the batch size of the Dataset. Right now this basically does all the wiring through Grappler etc. and the logic is very minimal and basic. Subsequent CL's will make the logic expand to different scenarios. PiperOrigin-RevId: 231983575 --- .../api_def_ExperimentalRebatchDataset.pbtxt | 23 ++ .../core/grappler/optimizers/data/BUILD | 15 + .../grappler/optimizers/data/graph_utils.cc | 7 + .../grappler/optimizers/data/graph_utils.h | 4 + .../optimizers/data/graph_utils_test.cc | 15 + .../core/grappler/optimizers/data/rebatch.cc | 115 ++++++++ .../core/grappler/optimizers/data/rebatch.h | 52 ++++ tensorflow/core/kernels/data/BUILD | 24 +- .../core/kernels/data/experimental/BUILD | 16 ++ .../data/experimental/rebatch_dataset_op.cc | 92 ++++++ .../kernels/data/graph_rewrite_dataset.cc | 239 ++++++++++++++++ .../core/kernels/data/graph_rewrite_dataset.h | 92 ++++++ .../core/kernels/data/optimize_dataset_op.cc | 264 +----------------- .../core/ops/experimental_dataset_ops.cc | 8 + .../data/experimental/kernel_tests/BUILD | 14 + .../kernel_tests/rebatch_dataset_test.py | 60 ++++ .../kernel_tests/serialization/BUILD | 18 ++ .../rebatch_dataset_serialization_test.py | 41 +++ .../python/data/experimental/ops/batching.py | 31 ++ .../api/golden/v1/tensorflow.raw_ops.pbtxt | 4 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 4 + 21 files changed, 874 insertions(+), 264 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt create mode 100644 tensorflow/core/grappler/optimizers/data/rebatch.cc create mode 100644 tensorflow/core/grappler/optimizers/data/rebatch.h create mode 100644 tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc create mode 100644 tensorflow/core/kernels/data/graph_rewrite_dataset.cc create mode 100644 tensorflow/core/kernels/data/graph_rewrite_dataset.h create mode 100644 tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py create mode 100644 tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt new file mode 100644 index 00000000000..b8455308e5c --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt @@ -0,0 +1,23 @@ +op { + graph_op_name: "ExperimentalRebatchDataset" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: < FindAllGraphNodesWithOp(const string& op, diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc index 3b6d223fd36..879cecd13d3 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -228,6 +228,21 @@ TEST(GraphUtilsTest, GetInputNode) { EXPECT_EQ(GetInputNode(*node1, graph), nullptr); } +TEST(GraphUtilsTest, GetIthInputNode) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + + NodeDef* node1 = AddNode("", "A", {}, {}, &graph); + NodeDef* node2 = AddNode("", "A", {}, {}, &graph); + NodeDef* node3 = AddNode("", "A", {node1->name(), node2->name()}, {}, &graph); + + EXPECT_EQ(GetInputNode(*node3, graph), node1); + EXPECT_EQ(GetInputNode(*node3, graph, 1), node2); + EXPECT_EQ(GetInputNode(*node3, graph, 0), node1); + EXPECT_EQ(GetInputNode(*node3, graph, 2), nullptr); + EXPECT_EQ(GetInputNode(*node1, graph), nullptr); +} + TEST(GraphUtilsTest, EnsureNodeNamesUnique) { Graph g(OpRegistry::Global()); diff --git a/tensorflow/core/grappler/optimizers/data/rebatch.cc b/tensorflow/core/grappler/optimizers/data/rebatch.cc new file mode 100644 index 00000000000..187e1a62afa --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/rebatch.cc @@ -0,0 +1,115 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/rebatch.h" + +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" + +namespace tensorflow { +namespace grappler { + +Status RebatchOptimizer::Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) { + if (!config) return Status::OK(); + + num_workers_ = config->parameter_map().at("num_workers").i(); + return Status::OK(); +} + +namespace { + +constexpr char kCastOp[] = "Cast"; +constexpr char kRealDivOp[] = "RealDiv"; +constexpr char kBatchDatasetOp[] = "BatchDatasetV2"; + +NodeDef* AddCastNode(const string& input, DataType src_t, DataType dst_t, + MutableGraphView* graph) { + NodeDef cast_node; + cast_node.set_op(kCastOp); + cast_node.add_input(input); + graph_utils::SetUniqueGraphNodeName(cast_node.op(), graph->graph(), + &cast_node); + AddNodeAttr("SrcT", src_t, &cast_node); + AddNodeAttr("DstT", dst_t, &cast_node); + + return graph->AddNode(std::move(cast_node)); +} + +NodeDef* AddBinaryNode(const string& input_x, const string& input_y, + const string& op, DataType type, + MutableGraphView* graph) { + NodeDef node; + node.set_op(op); + node.add_input(input_x); + node.add_input(input_y); + graph_utils::SetUniqueGraphNodeName(op, graph->graph(), &node); + AddNodeAttr("T", type, &node); + + return graph->AddNode(std::move(node)); +} + +NodeDef* AddFloatDivNode(const string& input_x, const string& input_y, + MutableGraphView* graph) { + return AddBinaryNode(input_x, input_y, kRealDivOp, DT_FLOAT, graph); +} + +} // anonymous namespace + +Status RebatchOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + *output = item.graph; + MutableGraphView graph(output); + + absl::flat_hash_set nodes_to_delete; + for (const NodeDef& node : item.graph.node()) { + if (node.op() == kBatchDatasetOp) { + NodeDef* batch_size_node = graph_utils::GetInputNode(node, graph, 1); + NodeDef tmp_node; + tmp_node = *batch_size_node; + graph_utils::SetUniqueGraphNodeName(tmp_node.op(), graph.graph(), + &tmp_node); + NodeDef* copy_batch_size_node = graph.AddNode(std::move(tmp_node)); + NodeDef* float_copy_batch_size_node = + AddCastNode(copy_batch_size_node->name(), DT_INT64, DT_FLOAT, &graph); + NodeDef* num_worker_node = + graph_utils::AddScalarConstNode(num_workers_, &graph); + NodeDef* float_num_worker_node = + AddCastNode(num_worker_node->name(), DT_INT64, DT_FLOAT, &graph); + NodeDef* divided_batch_size_node = + AddFloatDivNode(float_copy_batch_size_node->name(), + float_num_worker_node->name(), &graph); + NodeDef* cast_new_batch_size_node = AddCastNode( + divided_batch_size_node->name(), DT_FLOAT, DT_INT64, &graph); + TF_RETURN_IF_ERROR(graph.UpdateFanouts(batch_size_node->name(), + cast_new_batch_size_node->name())); + nodes_to_delete.insert(batch_size_node->name()); + break; + } + } + TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete)); + return Status::OK(); +} + +void RebatchOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, + double result) {} + +REGISTER_GRAPH_OPTIMIZER_AS(RebatchOptimizer, "tf_data_rebatcher"); + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/rebatch.h b/tensorflow/core/grappler/optimizers/data/rebatch.h new file mode 100644 index 00000000000..f7aa69f3ff9 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/rebatch.h @@ -0,0 +1,52 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REBATCH_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REBATCH_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +// This optimizer changes the batch size of the output dataset by dividing the +// current batch size by parameter `num_workers`. Currently, this works only +// for very simple pipelines with a single BatchDatasetV2 transformation. +// +// TODO(rohanj): Extend this logic to correctly handle any input pipeline that +// uses core tf.data APIs + MapAndBatch. +class RebatchOptimizer : public CustomGraphOptimizer { + public: + RebatchOptimizer() = default; + ~RebatchOptimizer() override = default; + + string name() const override { return "tf_data_rebatcher"; } + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override; + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override; + + private: + int64 num_workers_; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REBATCH_H_ diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index c1df9b57071..cd6803fb3fa 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -548,17 +548,14 @@ tf_kernel_library( ], ) -tf_kernel_library( - name = "optimize_dataset_op", - srcs = ["optimize_dataset_op.cc"], +cc_library( + name = "graph_rewrite_dataset", + srcs = ["graph_rewrite_dataset.cc"], + hdrs = ["graph_rewrite_dataset.h"], deps = [ "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item_builder", "//tensorflow/core/grappler/clusters:virtual_cluster", @@ -569,6 +566,19 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "optimize_dataset_op", + srcs = ["optimize_dataset_op.cc"], + deps = [ + ":graph_rewrite_dataset", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + tf_kernel_library( name = "model_dataset_op", srcs = ["model_dataset_op.cc"], diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 4f7c8f156c6..9171b91a62f 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -232,6 +232,21 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "rebatch_dataset_op", + srcs = ["rebatch_dataset_op.cc"], + deps = [ + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler/optimizers/data:rebatch", + "//tensorflow/core/kernels/data:graph_rewrite_dataset", + ], +) + tf_kernel_library( name = "scan_dataset_op", srcs = ["scan_dataset_op.cc"], @@ -391,6 +406,7 @@ tf_kernel_library( ":parse_example_dataset_op", ":prefetching_kernels", ":random_dataset_op", + ":rebatch_dataset_op", ":scan_dataset_op", ":set_stats_aggregator_dataset_op", ":sleep_dataset_op", diff --git a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc new file mode 100644 index 00000000000..a95773afd12 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc @@ -0,0 +1,92 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kOptimizerName[] = "tf_data_rebatcher"; + +class RebatchDatasetOp : public UnaryDatasetOpKernel { + public: + explicit RebatchDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + int64 num_workers; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_workers", &num_workers)); + OP_REQUIRES(ctx, num_workers > 0, + errors::InvalidArgument( + "num_parallel_calls must be greater than zero.")); + + Dataset* dataset = + new Dataset(ctx, input, num_workers, output_types_, output_shapes_); + Status s = dataset->Optimize(ctx); + if (s.ok()) { + *output = dataset; + } else { + dataset->Unref(); + OP_REQUIRES_OK(ctx, s); + } + } + + private: + class Dataset : public GraphRewriteDataset { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const int64 num_workers, const DataTypeVector& output_types, + const std::vector& output_shapes) + : GraphRewriteDataset(ctx, input, output_types, output_shapes), + num_workers_(num_workers) {} + + string DebugString() const override { return "RebatchDatasetOp::Dataset"; } + + private: + RewriterConfig CreateGrapplerRewriteConfig() override { + RewriterConfig rewriter_config; + rewriter_config.add_optimizers(kOptimizerName); + rewriter_config.set_meta_optimizer_iterations( + RewriterConfig_NumIterationsType_ONE); + auto custom_optimizer = rewriter_config.add_custom_optimizers(); + custom_optimizer->set_name(kOptimizerName); + AttrValue num_workers_attr; + num_workers_attr.set_i(num_workers_); + (*custom_optimizer->mutable_parameter_map())["num_workers"] = + num_workers_attr; + return rewriter_config; + } + + const int64 num_workers_; + }; + + const int graph_def_version_; + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +REGISTER_KERNEL_BUILDER(Name("ExperimentalRebatchDataset").Device(DEVICE_CPU), + RebatchDatasetOp); + +} // anonymous namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/graph_rewrite_dataset.cc b/tensorflow/core/kernels/data/graph_rewrite_dataset.cc new file mode 100644 index 00000000000..bc4bb460101 --- /dev/null +++ b/tensorflow/core/kernels/data/graph_rewrite_dataset.cc @@ -0,0 +1,239 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h" + +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace data { + +GraphRewriteDataset::~GraphRewriteDataset() { + input_->Unref(); + if (optimized_input_) { + optimized_input_->Unref(); + } +} + +Status GraphRewriteDataset::Optimize(OpKernelContext* ctx) { + GraphDefBuilder b; + DatasetGraphDefBuilder db(&b); + Node* input_node = nullptr; + SerializationContext::Params params; + std::vector> input_list; + params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + params.input_list = &input_list; + params.optimization_only = true; + SerializationContext serialization_ctx(params); + TF_RETURN_IF_ERROR( + db.AddInputDataset(&serialization_ctx, input_, &input_node)); + string output_node = input_node->name(); + + GraphDef graph_def; + TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); + VLOG(3) << "Before optimization: " << graph_def.DebugString(); + + TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node)); + VLOG(3) << "After optimization: " << graph_def.DebugString(); + + // Instantiate the optimized input pipeline by running the optimized graph + // using the optimized function library. + TF_RETURN_IF_ERROR(ctx->function_library()->Clone(&flib_def_, &pflr_, &lib_)); + + // Create a FunctionHandleCache. + function_handle_cache_ = absl::make_unique(lib_); + + // Some functions may have been modified without having their names + // changed (for example, nested dataset graphs from FlatMap or + // Interleave). To avoid name conflicts, we remove these functions from + // flib_def_ before adding the optimized function library. + for (const FunctionDef& fd : graph_def.library().function()) { + if (flib_def_->Find(fd.signature().name()) != nullptr) { + TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(fd.signature().name())); + } + } + TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph_def.library())); + + Graph graph(OpRegistry::Global()); + TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); + std::vector outputs; + GraphRunner graph_runner(ctx->function_library()->device()); + + TF_RETURN_IF_ERROR( + graph_runner.Run(&graph, lib_, input_list, {output_node}, &outputs)); + TF_RETURN_IF_ERROR( + GetDatasetFromVariantTensor(outputs[0], &optimized_input_)); + optimized_input_->Ref(); + return Status::OK(); +} + +Status GraphRewriteDataset::AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const { + // We only serialize the optimized dataset to avoid re-running + // optimizations when the input pipeline is restored from a checkpoint. + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, optimized_input_, output)); + return Status::OK(); +} + +namespace { +void AddFakeSinks(FunctionDef* function_def) { + int counter = 0; + for (const auto& output : function_def->signature().output_arg()) { + NodeDef* node = function_def->add_node_def(); + tensorflow::grappler::function_utils::SetUniqueFunctionNodeName( + strings::StrCat("FakeSink", counter++), function_def, node); + node->set_op("Identity"); + node->add_input(function_def->ret().at(output.name())); + (*node->mutable_attr())["T"].set_type(output.type()); + + (*function_def->mutable_ret())[output.name()] = + strings::StrCat(node->name(), ":output:0"); + } +} + +void RemoveFakeSinks(FunctionDef* function_def) { + // Map from identity node names to their input tensor strings + std::map identity_map; + for (const auto& node : function_def->node_def()) { + if (node.op() == "Identity" && node.input_size() == 1) { + identity_map[node.name()] = node.input(0); + } + } + for (const auto& output_arg : function_def->signature().output_arg()) { + const string& tensor = function_def->ret().at(output_arg.name()); + const string& output_node = tensor.substr(0, tensor.find(':')); + if (identity_map.find(output_node) != identity_map.end()) { + (*function_def->mutable_ret())[output_arg.name()] = + identity_map.at(output_node); + } + } +} +} // anonymous namespace + +Status GraphRewriteDataset::ApplyOptimizations(OpKernelContext* ctx, + GraphDef* graph_def, + string* output_node) { + // Add an identity node as the fetch node, otherwise we might get + // 'placeholder is both fed and fetched' errors in some cases when using + // input list with placeholder dataset nodes. + NodeDef* node = graph_def->mutable_node()->Add(); + tensorflow::grappler::graph_utils::SetUniqueGraphNodeName("Sink", graph_def, + node); + node->set_op("Identity"); + node->add_input(*output_node); + (*node->mutable_attr())["T"].set_type(DT_VARIANT); + *output_node = node->name(); + + // Add fake sink node to graph and functions to allow rewriting the actual + // sink nodes. + // TODO(b/118820916): When MetaOptimizer adds provisions for function + // retvals to be optimizable, we will no longer need this. + for (auto& function_def : *graph_def->mutable_library()->mutable_function()) { + AddFakeSinks(&function_def); + } + + // Create metagraph. + MetaGraphDef meta_graph_def; + (*meta_graph_def.mutable_graph_def()) = *graph_def; + + // Grappler determines fetch ops from collection 'train_op'. + CollectionDef collection_def; + auto node_list = collection_def.mutable_node_list(); + node_list->add_value(*output_node); + (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def; + + // Create Grappler item. + tensorflow::grappler::ItemConfig item_config; + item_config.apply_optimizations = true; + std::unique_ptr grappler_item = + tensorflow::grappler::GrapplerItemFromMetaGraphDef( + "graph", meta_graph_def, item_config); + std::unordered_map device_map; + tensorflow::grappler::VirtualCluster cluster(device_map); + + // Run data optimizer using grappler's meta optimizer. + tensorflow::ConfigProto config; + *config.mutable_graph_options()->mutable_rewrite_options() = + CreateGrapplerRewriteConfig(); + TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer( + *grappler_item, config, ctx->device(), &cluster, graph_def)); + + // Remove fake sinks after optimizations are done. + // TODO(b/118820916): When MetaOptimizer adds provisions for function + // retvals to be optimizable, we will no longer need this. + for (auto& function_def : *graph_def->mutable_library()->mutable_function()) { + RemoveFakeSinks(&function_def); + } + + return Status::OK(); +} + +class GraphRewriteDataset::Iterator + : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + IteratorContext::Params params(ctx); + params.lib = dataset()->lib_; + params.function_handle_cache = dataset()->function_handle_cache_.get(); + return dataset()->optimized_input_->MakeIterator( + IteratorContext(std::move(params)), prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) override { + IteratorContext::Params params(ctx); + params.lib = dataset()->lib_; + params.function_handle_cache = dataset()->function_handle_cache_.get(); + return input_impl_->GetNext(IteratorContext(std::move(params)), out_tensors, + end_of_sequence); + } + + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + + Status SaveInternal(IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + return Status::OK(); + } + + private: + std::unique_ptr input_impl_; +}; + +std::unique_ptr GraphRewriteDataset::MakeIteratorInternal( + const string& prefix) const { + // We do not add a token for this dataset to the prefix. The + // prefix is used to identify checkpoint elements and since this + // dataset is excluded from the checkpoint, adding a token + // here would result in invalid checkpoint identifiers. + return absl::make_unique(Iterator::Params{this, prefix}); +} + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/graph_rewrite_dataset.h b/tensorflow/core/kernels/data/graph_rewrite_dataset.h new file mode 100644 index 00000000000..dedbdce71ff --- /dev/null +++ b/tensorflow/core/kernels/data/graph_rewrite_dataset.h @@ -0,0 +1,92 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_ + +#include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/function_handle_cache.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/grappler_item_builder.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" + +namespace tensorflow { +namespace data { + +class GraphRewriteDataset : public DatasetBase { + public: + GraphRewriteDataset(OpKernelContext* ctx, const DatasetBase* input, + const DataTypeVector& output_types, + const std::vector& output_shapes) + : DatasetBase(DatasetContext(ctx)), + optimized_input_(nullptr), + input_(input), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); + } + + ~GraphRewriteDataset() override; + + // Runs Grappler to transform the input dataset into optimized_input_ + // dataset. + Status Optimize(OpKernelContext* ctx); + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override; + + const DataTypeVector& output_dtypes() const override { return output_types_; } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + int64 Cardinality() const override { return input_->Cardinality(); } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override; + + private: + class Iterator; + + // Create a Grappler RewriteConfig proto that defines the list of + // optimizations to be run by the Grappler Meta Optimizer. + virtual RewriterConfig CreateGrapplerRewriteConfig() = 0; + + Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def, + string* output_node); + + DatasetBase* optimized_input_; + FunctionLibraryRuntime* lib_ = nullptr; + std::unique_ptr pflr_ = nullptr; + std::unique_ptr flib_def_ = nullptr; + std::unique_ptr function_handle_cache_ = nullptr; + const DatasetBase* input_; + const DataTypeVector output_types_; + const std::vector output_shapes_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_ diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 6047dc5f3f4..17094e30017 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -14,26 +14,11 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/core/common_runtime/device_mgr.h" -#include "tensorflow/core/common_runtime/graph_runner.h" -#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/device_base.h" -#include "tensorflow/core/framework/function_handle_cache.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/grappler/clusters/virtual_cluster.h" -#include "tensorflow/core/grappler/graph_view.h" -#include "tensorflow/core/grappler/grappler_item.h" -#include "tensorflow/core/grappler/grappler_item_builder.h" -#include "tensorflow/core/grappler/optimizers/data/function_utils.h" -#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" -#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" -#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h" #include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { @@ -71,235 +56,20 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public DatasetBase { + class Dataset : public GraphRewriteDataset { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, const std::vector& optimizations, const DataTypeVector& output_types, const std::vector& output_shapes) - : DatasetBase(DatasetContext(ctx)), - optimized_input_(nullptr), - input_(input), - optimizations_(optimizations), - output_types_(output_types), - output_shapes_(output_shapes) { - input_->Ref(); - } - - ~Dataset() override { - input_->Unref(); - if (optimized_input_) { - optimized_input_->Unref(); - } - } - - std::unique_ptr MakeIteratorInternal( - const string& prefix) const override { - // We do not add a token for the optimization dataset to the prefix. The - // prefix is used to identify checkpoint elements and since the - // optimization dataset is excluded from the checkpoint, adding a token - // here would result in invalid checkpoint identifiers. - return absl::make_unique(Iterator::Params{this, prefix}); - } - - Status Optimize(OpKernelContext* ctx) { - GraphDefBuilder b; - DatasetGraphDefBuilder db(&b); - Node* input_node = nullptr; - SerializationContext::Params params; - std::vector> input_list; - params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); - params.input_list = &input_list; - params.optimization_only = true; - SerializationContext serialization_ctx(params); - TF_RETURN_IF_ERROR( - db.AddInputDataset(&serialization_ctx, input_, &input_node)); - string output_node = input_node->name(); - - GraphDef graph_def; - TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); - VLOG(3) << "Before optimization: " << graph_def.DebugString(); - - TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node)); - VLOG(3) << "After optimization: " << graph_def.DebugString(); - - // Instantiate the optimized input pipeline by running the optimized graph - // using the optimized function library. - TF_RETURN_IF_ERROR( - ctx->function_library()->Clone(&flib_def_, &pflr_, &lib_)); - - // Create a FunctionHandleCache. - function_handle_cache_ = absl::make_unique(lib_); - - // Some functions may have been modified without having their names - // changed (for example, nested dataset graphs from FlatMap or - // Interleave). To avoid name conflicts, we remove these functions from - // flib_def_ before adding the optimized function library. - for (const FunctionDef& fd : graph_def.library().function()) { - if (flib_def_->Find(fd.signature().name()) != nullptr) { - TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(fd.signature().name())); - } - } - TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph_def.library())); - - Graph graph(OpRegistry::Global()); - TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); - std::vector outputs; - GraphRunner graph_runner(ctx->function_library()->device()); - - TF_RETURN_IF_ERROR( - graph_runner.Run(&graph, lib_, input_list, {output_node}, &outputs)); - TF_RETURN_IF_ERROR( - GetDatasetFromVariantTensor(outputs[0], &optimized_input_)); - optimized_input_->Ref(); - return Status::OK(); - } - - const DataTypeVector& output_dtypes() const override { - return output_types_; - } - const std::vector& output_shapes() const override { - return output_shapes_; - } + : GraphRewriteDataset(ctx, input, output_types, output_shapes), + optimizations_(optimizations) {} string DebugString() const override { return "OptimizeDatasetOp::Dataset"; } - int64 Cardinality() const override { return input_->Cardinality(); } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - // We only serialize the optimized dataset to avoid re-running - // optimizations when the input pipeline is restored from a checkpoint. - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, optimized_input_, output)); - return Status::OK(); - } - private: - class Iterator : public DatasetIterator { - public: - explicit Iterator(const Params& params) - : DatasetIterator(params) {} - - Status Initialize(IteratorContext* ctx) override { - IteratorContext::Params params(ctx); - params.lib = dataset()->lib_; - params.function_handle_cache = dataset()->function_handle_cache_.get(); - return dataset()->optimized_input_->MakeIterator( - IteratorContext(std::move(params)), prefix(), &input_impl_); - } - - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - IteratorContext::Params params(ctx); - params.lib = dataset()->lib_; - params.function_handle_cache = dataset()->function_handle_cache_.get(); - return input_impl_->GetNext(IteratorContext(std::move(params)), - out_tensors, end_of_sequence); - } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeKnownRatioNode(std::move(args), - /*ratio=*/1); - } - - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return Status::OK(); - } - - private: - std::unique_ptr input_impl_; - }; - - void AddFakeSinks(FunctionDef* function_def) { - int counter = 0; - for (const auto& output : function_def->signature().output_arg()) { - NodeDef* node = function_def->add_node_def(); - tensorflow::grappler::function_utils::SetUniqueFunctionNodeName( - strings::StrCat("FakeSink", counter++), function_def, node); - node->set_op("Identity"); - node->add_input(function_def->ret().at(output.name())); - (*node->mutable_attr())["T"].set_type(output.type()); - - (*function_def->mutable_ret())[output.name()] = - strings::StrCat(node->name(), ":output:0"); - } - } - - void RemoveFakeSinks(FunctionDef* function_def) { - // Map from identity node names to their input tensor strings - std::map identity_map; - for (const auto& node : function_def->node_def()) { - if (node.op() == "Identity" && node.input_size() == 1) { - identity_map[node.name()] = node.input(0); - } - } - for (const auto& output_arg : function_def->signature().output_arg()) { - const string& tensor = function_def->ret().at(output_arg.name()); - const string& output_node = tensor.substr(0, tensor.find(':')); - if (identity_map.find(output_node) != identity_map.end()) { - (*function_def->mutable_ret())[output_arg.name()] = - identity_map.at(output_node); - } - } - } - - Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def, - string* output_node) { - // Add an identity node as the fetch node, otherwise we might get - // 'placeholder is both fed and fetched' errors in some cases when using - // input list with placeholder dataset nodes. - NodeDef* node = graph_def->mutable_node()->Add(); - tensorflow::grappler::graph_utils::SetUniqueGraphNodeName( - "Sink", graph_def, node); - node->set_op("Identity"); - node->add_input(*output_node); - (*node->mutable_attr())["T"].set_type(DT_VARIANT); - *output_node = node->name(); - - // Add fake sink node to graph and functions to allow rewriting the actual - // sink nodes. - // TODO(b/118820916): When MetaOptimizer adds provisions for function - // retvals to be optimizable, we will no longer need this. - for (auto& function_def : - *graph_def->mutable_library()->mutable_function()) { - AddFakeSinks(&function_def); - } - - // Create metagraph. - MetaGraphDef meta_graph_def; - (*meta_graph_def.mutable_graph_def()) = *graph_def; - - // Grappler determines fetch ops from collection 'train_op'. - CollectionDef collection_def; - auto node_list = collection_def.mutable_node_list(); - node_list->add_value(*output_node); - (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def; - - // Create Grappler item. - tensorflow::grappler::ItemConfig item_config; - item_config.apply_optimizations = true; - std::unique_ptr grappler_item = - tensorflow::grappler::GrapplerItemFromMetaGraphDef( - "graph", meta_graph_def, item_config); - std::unordered_map device_map; - tensorflow::grappler::VirtualCluster cluster(device_map); - - // Run data optimizer using grappler's meta optimizer. - tensorflow::ConfigProto config; - RewriterConfig& rewriter_config = - *config.mutable_graph_options()->mutable_rewrite_options(); + RewriterConfig CreateGrapplerRewriteConfig() override { + RewriterConfig rewriter_config; rewriter_config.add_optimizers(kOptimizerName); rewriter_config.set_meta_optimizer_iterations( RewriterConfig_NumIterationsType_ONE); @@ -311,30 +81,10 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { for (const auto& opt : optimizations_) { custom_optimizations_list->add_s(opt); } - - TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer( - *grappler_item, config, ctx->device(), &cluster, graph_def)); - - // Remove fake sinks after optimizations are done. - // TODO(b/118820916): When MetaOptimizer adds provisions for function - // retvals to be optimizable, we will no longer need this. - for (auto& function_def : - *graph_def->mutable_library()->mutable_function()) { - RemoveFakeSinks(&function_def); - } - - return Status::OK(); + return rewriter_config; } - DatasetBase* optimized_input_; - FunctionLibraryRuntime* lib_ = nullptr; - std::unique_ptr pflr_ = nullptr; - std::unique_ptr flib_def_ = nullptr; - std::unique_ptr function_handle_cache_ = nullptr; - const DatasetBase* input_; const std::vector optimizations_; - const DataTypeVector output_types_; - const std::vector output_shapes_; }; const int graph_def_version_; diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index 316e405188c..41348972d38 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -190,6 +190,14 @@ REGISTER_OP("ExperimentalMapAndBatchDataset") return shape_inference::ScalarShape(c); }); +REGISTER_OP("ExperimentalRebatchDataset") + .Input("input_dataset: variant") + .Input("num_workers: int64") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("ExperimentalMapDataset") .Input("input_dataset: variant") .Input("other_arguments: Targuments") diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index 04819130642..d0e5abcd5b1 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -470,6 +470,20 @@ py_library( ], ) +py_test( + name = "rebatch_dataset_test", + size = "small", + srcs = ["rebatch_dataset_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + py_test( name = "rejection_resample_test", size = "medium", diff --git a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py new file mode 100644 index 00000000000..0dcbd563fac --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py @@ -0,0 +1,60 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the private `_RebatchDataset` transformation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +@test_util.run_all_in_graph_and_eager_modes +class RebatchDatasetTest(test_base.DatasetTestBase): + + def testBasic(self): + dataset = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True) + rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4) + self.assertEqual( + [[32]], [ts.as_list() for ts in nest.flatten(dataset.output_shapes)]) + self.assertEqual( + [[8]], + [ts.as_list() for ts in nest.flatten(rebatched_dataset.output_shapes)]) + + expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension + self.assertDatasetProduces(rebatched_dataset, expected_output) + + def testScalarInputError(self): + dataset = dataset_ops.Dataset.range(1024) + with self.assertRaisesRegexp(ValueError, "at least one dimension"): + batching._RebatchDataset(dataset, num_workers=4) + + def testUnknownBatchSizeError(self): + dataset = dataset_ops.Dataset.range(1024).batch(32) + with self.assertRaisesRegexp(ValueError, "unknown batch size datasets"): + batching._RebatchDataset(dataset, num_workers=4) + + def testNotDivisibleError(self): + dataset = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True) + with self.assertRaisesRegexp(ValueError, "not divisible by"): + batching._RebatchDataset(dataset, num_workers=5) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD index 4fd2a2ec4bf..dc168d867ae 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD @@ -408,6 +408,24 @@ py_test( ], ) +py_test( + name = "rebatch_dataset_serialization_test", + size = "small", + srcs = ["rebatch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_oss", + "no_pip", + "no_windows", + ], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "padded_batch_dataset_serialization_test", size = "medium", diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py new file mode 100644 index 00000000000..b30db589069 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py @@ -0,0 +1,41 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the _RebatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class RebatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + + def build_dataset(num_elements, batch_size): + return batching._RebatchDataset( + dataset_ops.Dataset.range(num_elements).batch( + 4 * batch_size, drop_remainder=True), + num_workers=4) + + self.run_core_tests(lambda: build_dataset(200, 10), None, 20) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py index f0cf7f0a995..39cb0a68f8d 100644 --- a/tensorflow/python/data/experimental/ops/batching.py +++ b/tensorflow/python/data/experimental/ops/batching.py @@ -645,3 +645,34 @@ def map_and_batch(map_func, num_parallel_calls, drop_remainder) return _apply_fn + + +class _RebatchDataset(dataset_ops.UnaryDataset): + """A `Dataset` that divides the batch size by `num_workers`.""" + + def __init__(self, input_dataset, num_workers): + self._input_dataset = input_dataset + output_shapes = input_dataset.output_shapes + if len(output_shapes) < 1: + raise ValueError("Input shape should have at least one dimension.") + if not output_shapes.dims[0].value: + raise ValueError("Cannot rebatch unknown batch size datasets.") + if output_shapes.dims[0].value % num_workers != 0: + raise ValueError( + "First dim of input shape: %d is not divisible by num_workers: %d" % + (output_shapes[0], num_workers)) + output_dims = [d for d in output_shapes.dims] + output_dims[0] = output_dims[0] // num_workers + output_shapes = tensor_shape.TensorShapeV1(output_dims) + self._structure = structure.convert_legacy_structure( + self._input_dataset.output_types, output_shapes, + self._input_dataset.output_classes) + variant_tensor = ged_ops.experimental_rebatch_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access + num_workers=num_workers, + **dataset_ops.flat_structure(self)) + super(_RebatchDataset, self).__init__(input_dataset, variant_tensor) + + @property + def _element_structure(self): + return self._structure diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index e7375af1822..0b2a15f5216 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1096,6 +1096,10 @@ tf_module { name: "ExperimentalRandomDataset" argspec: "args=[\'seed\', \'seed2\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "ExperimentalRebatchDataset" + argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "ExperimentalScanDataset" argspec: "args=[\'input_dataset\', \'initial_state\', \'other_arguments\', \'f\', \'output_types\', \'output_shapes\', \'preserve_cardinality\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index e7375af1822..0b2a15f5216 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1096,6 +1096,10 @@ tf_module { name: "ExperimentalRandomDataset" argspec: "args=[\'seed\', \'seed2\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "ExperimentalRebatchDataset" + argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "ExperimentalScanDataset" argspec: "args=[\'input_dataset\', \'initial_state\', \'other_arguments\', \'f\', \'output_types\', \'output_shapes\', \'preserve_cardinality\'], varargs=None, keywords=None, defaults=None"