From ef9f0e8f2f204c45462544cb0f14b9d33061de29 Mon Sep 17 00:00:00 2001
From: Rachel Lim
Date: Thu, 1 Aug 2019 17:42:39 -0700
Subject: [PATCH] [tf.data] Change the behavior of RebatchDataset when 1)
 drop_remainder = True or 2) the global batch size is not divisible by
 num_workers.

In these cases, instead of mutating the batch size directly, we add a
`.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x).batch(new_batch_size))`
after the batch. This has three effects:

1) Changes the behavior of _RebatchDataset so that at each step (num_workers
   minibatches), the total number of examples equals the global batch size
   (vs. before, when it was rounded up whenever global_batch_size was not
   divisible by num_workers).
2) Preserves the behavior of `drop_remainder` with respect to the global
   batch.
3) Is probably less performant, since from_tensor_slices and batch both
   require data copies.

PiperOrigin-RevId: 261233882
---
 .../optimizers/data/function_utils.cc         |  21 +-
 .../grappler/optimizers/data/function_utils.h |   6 +-
 .../optimizers/data/function_utils_test.cc    |  12 +
 .../grappler/optimizers/data/graph_utils.cc   |  46 +-
 .../grappler/optimizers/data/graph_utils.h    |  18 +-
 .../optimizers/data/graph_utils_test.cc       |  58 +++
 .../core/grappler/optimizers/data/rebatch.cc  | 402 ++++++++++++++----
 .../kernel_tests/rebatch_dataset_test.py      | 331 +++++++-------
 .../data/experimental/ops/distribute.py       |  16 +-
 9 files changed, 642 insertions(+), 268 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc index 20536910db1..40f4f24b03f 100644 --- a/tensorflow/core/grappler/optimizers/data/function_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc @@ -88,18 +88,27 @@ void ReplaceReferences(const string& from, const string& to, void AddFunctionOutputWithUniqueName(StringPiece prefix, StringPiece output_tensor_name, - FunctionDef* function, DataType dt) { + FunctionDef* fdef, DataType dtype) { string name = string(prefix); - int id = function->signature().output_arg_size(); - while (ContainsFunctionOutputWithName(name, *function)) { + int id = fdef->signature().output_arg_size(); + while (ContainsFunctionOutputWithName(name, *fdef)) { name = strings::StrCat(prefix, "/_", id); ++id; } - auto* output = function->mutable_signature()->mutable_output_arg()->Add(); + auto* output = fdef->mutable_signature()->mutable_output_arg()->Add(); output->set_name(name); - output->set_type(dt); + output->set_type(dtype); - (*function->mutable_ret())[name] = string(output_tensor_name); + (*fdef->mutable_ret())[name] = string(output_tensor_name); +} + +OpDef_ArgDef* AddFunctionInput(const string& name, FunctionDef* fdef, + DataType dtype) { + auto* input_arg = fdef->mutable_signature()->mutable_input_arg()->Add(); + input_arg->set_type(dtype); + input_arg->set_name(name); + + return input_arg; } NodeDef* AddNode(StringPiece name, StringPiece op, diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.h b/tensorflow/core/grappler/optimizers/data/function_utils.h index 79271e8ad0c..8941e58c558 100644 --- a/tensorflow/core/grappler/optimizers/data/function_utils.h +++ b/tensorflow/core/grappler/optimizers/data/function_utils.h @@ -61,7 +61,11 @@ void ReplaceReferences(const string& from, const string& to, FunctionDef* func); // is unique, and maps to output_tensor_name in the ret dict.
void AddFunctionOutputWithUniqueName(StringPiece prefix, StringPiece output_tensor_name, - FunctionDef* function, DataType dt); + FunctionDef* fdef, DataType dtype); + +// Adds an input to a FunctionDef. +OpDef_ArgDef* AddFunctionInput(const string& name, FunctionDef* fdef, + DataType dtype); // Adds a node to a FunctionDef. NodeDef* AddNode(StringPiece name, StringPiece op, diff --git a/tensorflow/core/grappler/optimizers/data/function_utils_test.cc b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc index 8ae0cde4cd1..9a53b00275e 100644 --- a/tensorflow/core/grappler/optimizers/data/function_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc @@ -60,6 +60,18 @@ TEST(FunctionUtilsTest, AddFunctionOutputWithUniqueName) { EXPECT_EQ(function.ret().at("y/_1"), "two"); } +TEST(FunctionUtilsTest, AddFunctionInput) { + FunctionDef fdef; + auto arg0 = AddFunctionInput("arg0", &fdef, DT_INT32); + auto arg1 = AddFunctionInput("arg1", &fdef, DT_BOOL); + EXPECT_EQ(fdef.signature().input_arg().data()[0], arg0); + EXPECT_EQ(arg0->name(), "arg0"); + EXPECT_EQ(arg0->type(), DT_INT32); + EXPECT_EQ(fdef.signature().input_arg().data()[1], arg1); + EXPECT_EQ(arg1->name(), "arg1"); + EXPECT_EQ(arg1->type(), DT_BOOL); +} + TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) { FunctionDef function = test::function::XTimesTwo(); EXPECT_FALSE(ContainsFunctionNodeWithName( diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index a11717e270a..ce56b7c3b0e 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -158,6 +158,46 @@ NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph) { graph); } +Status GetScalarConstNodeValueHelper( + const NodeDef& node, DataType dtype, + const std::function& get_value) { + if (node.op() != kConstOpName) + return errors::InvalidArgument("Node ", node.name(), + " is not a Const node. Op: ", node.op()); + + Tensor tensor; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor)); + if (!TensorShapeUtils::IsScalar(tensor.shape())) { + return errors::InvalidArgument( + "Node ", node.name(), + " should be a scalar but has shape: ", tensor.shape()); + } + + if (tensor.dtype() != dtype) { + return errors::InvalidArgument( + "Node ", node.name(), " should have type ", DataTypeString(dtype), + " but has type: ", DataTypeString(tensor.dtype())); + } + + get_value(tensor); + + return Status::OK(); +} + +template <> +Status GetScalarConstNodeValue(const NodeDef& node, int64* value) { + return GetScalarConstNodeValueHelper( + node, DT_INT64, + [value](const Tensor& tensor) { *value = tensor.scalar()(); }); +} + +template <> +Status GetScalarConstNodeValue(const NodeDef& node, bool* value) { + return GetScalarConstNodeValueHelper( + node, DT_BOOL, + [value](const Tensor& tensor) { *value = tensor.scalar()(); }); +} + bool Compare(const GraphDef& g1, const GraphDef& g2) { if (g1.node_size() != g2.node_size()) { return false; @@ -240,12 +280,12 @@ NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph, return graph.GetRegularFanin(input_port).node; } -Status GetDatasetOutputTypesAttr(const NodeDef& node, AttrValue* output_types) { +Status GetDatasetOutputTypesAttr(const NodeDef& node, + DataTypeVector* output_types) { // We don't name the output_types attr consistently, so should check for both. 
for (const string& attr_name : {"output_types", "Toutput_types"}) { if (node.attr().contains(attr_name)) { - *output_types = node.attr().at(attr_name); - return Status::OK(); + return GetNodeAttr(node, attr_name, output_types); } } return errors::InvalidArgument("Could not find output_types attr for node: ", diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h index 341eec46158..87c9831126f 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -80,6 +80,21 @@ NodeDef* AddScalarConstNode(int64 v, MutableGraphView* graph); template <> NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph); +// Retrieves the value of a const node. Returns an error +// if the node is not const, or its value is of a different type. +template +Status GetScalarConstNodeValue(const NodeDef& node, T* value) { + // is_same is an idiomatic hack for making it compile if not instantiated. + // Replacing with false will result in a compile-time error. + static_assert(!std::is_same::value, + "Invalid specialization of this method fo rtype T."); +} + +template <> +Status GetScalarConstNodeValue(const NodeDef& node, int64* value); +template <> +Status GetScalarConstNodeValue(const NodeDef& node, bool* value); + // Checks whether the two graphs are the same. bool Compare(const GraphDef& g1, const GraphDef& g2); @@ -114,7 +129,8 @@ NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph, int64 i); // Gets the attr corresponding to a dataset node's output types, if it exists. -Status GetDatasetOutputTypesAttr(const NodeDef& node, AttrValue* output_types); +Status GetDatasetOutputTypesAttr(const NodeDef& node, + DataTypeVector* output_types); // Returns the list of indices of all nodes with the given op or empty list if // no such node exists. diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc index 93df72ab623..125f2e3ea32 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -85,6 +85,64 @@ TEST(GraphUtilsTest, AddScalarConstNodeString) { EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello"); } +TEST(GraphUtilsTest, GetScalarConstNodeInt64) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + NodeDef* int64_node = AddScalarConstNode(128, &graph); + int64 result; + EXPECT_TRUE(GetScalarConstNodeValue(*int64_node, &result).ok()); + EXPECT_EQ(result, 128); +} + +TEST(GraphUtilsTest, GetScalarConstNodeBool) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + NodeDef* bool_node = AddScalarConstNode(true, &graph); + bool result; + EXPECT_TRUE(GetScalarConstNodeValue(*bool_node, &result).ok()); + EXPECT_EQ(result, true); +} + +TEST(GraphUtilsTest, GetScalarConstNodeErrorWithNonConst) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + NodeDef* non_const = AddScalarPlaceholder(DT_INT64, &graph); + int64 result; + Status s = GetScalarConstNodeValue(*non_const, &result); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_message(), + "Node Placeholder is not a Const node. 
Op: Placeholder"); +} + +TEST(GraphUtilsTest, GetScalarConstNodeErrorWithType) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + NodeDef* int64_node = AddScalarConstNode(128, &graph); + bool result; + Status s = GetScalarConstNodeValue(*int64_node, &result); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_message(), + "Node Const should have type bool but has type: int64"); +} + +TEST(GraphUtilsTest, GetScalarConstNodeErrorWithVector) { + NodeDef node; + node.set_name("Const"); + node.set_op("Const"); + + (*node.mutable_attr())["dtype"].set_type(DT_INT64); + auto tensor = (*node.mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT64); + tensor->mutable_tensor_shape()->mutable_dim()->Add()->set_size(1); + tensor->add_int64_val(128); + + int64 result; + Status s = GetScalarConstNodeValue(node, &result); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_message(), + "Node Const should be a scalar but has shape: [1]"); +} + TEST(GraphUtilsTest, Compare) { GraphDef graph_def_a; MutableGraphView graph_a(&graph_def_a); diff --git a/tensorflow/core/grappler/optimizers/data/rebatch.cc b/tensorflow/core/grappler/optimizers/data/rebatch.cc index d6e86f7a0d9..879576bf9f7 100644 --- a/tensorflow/core/grappler/optimizers/data/rebatch.cc +++ b/tensorflow/core/grappler/optimizers/data/rebatch.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/padding.h" namespace tensorflow { namespace grappler { @@ -50,14 +51,19 @@ constexpr char kConstOp[] = "Const"; constexpr char kIdentityOp[] = "Identity"; constexpr char kSubOp[] = "Sub"; constexpr char kTruncateDivOp[] = "TruncateDiv"; +constexpr char kOutputShapesAttr[] = "output_shapes"; +constexpr char kOutputTypesAttr[] = "output_types"; +constexpr char kTOutputTypesAttr[] = "Toutput_types"; +constexpr char kBatchOp[] = "BatchDataset"; +constexpr char kBatchV2Op[] = "BatchDatasetV2"; +constexpr char kPaddedBatchOp[] = "PaddedBatchDataset"; +constexpr char kPaddedBatchV2Op[] = "PaddedBatchDatasetV2"; +constexpr char kMapAndBatchOp[] = "MapAndBatchDataset"; +constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset"; constexpr std::array kBatchDatasetOps = { - "BatchDataset", - "BatchDatasetV2", - "ExperimentalMapAndBatchDataset", - "MapAndBatchDataset", - "PaddedBatchDataset", - "PaddedBatchDatasetV2"}; + kBatchOp, kBatchV2Op, kMapAndBatchOp, kExperimentalMapAndBatchOp, + kPaddedBatchOp, kPaddedBatchV2Op}; constexpr std::array kMultipleInputsDatasetOps = { "ConcatenateDataset", @@ -117,17 +123,24 @@ constexpr std::array kSourceDatasetOps = { "TFRecordDataset", }; -NodeDef* AddBinaryNode(const string& input_x, const string& input_y, - const string& op, DataType type, - MutableGraphView* graph) { +NodeDef MakeBinaryNode(const string& input_x, const string& input_y, + const string& op, DataType dtype) { NodeDef node; node.set_op(op); node.add_input(input_x); node.add_input(input_y); - graph_utils::SetUniqueGraphNodeName(op, graph->graph(), &node); - AddNodeAttr("T", type, &node); + AddNodeAttr("T", dtype, &node); - return graph->AddNode(std::move(node)); + return node; +} + +NodeDef* AddBinaryNode(const string& input_x, const string& input_y, + const string& op, DataType type, FunctionDef* fdef) { + NodeDef* node = fdef->add_node_def(); + *node = MakeBinaryNode(input_x, input_y, op, type); + 
function_utils::SetUniqueFunctionNodeName(op, fdef, node); + + return node; } // Adds a Const node to the FunctionDef. @@ -161,6 +174,30 @@ Status AddConstIntNode(gtl::ArraySlice values, const TensorShape& shape, return Status::OK(); } +Status AddConstInt64Node(int64 value, FunctionDef* fdef, NodeDef** result) { + *result = fdef->add_node_def(); + Tensor t(value); + TF_RETURN_IF_ERROR(NodeDefBuilder("", "Const") + .Attr("dtype", DT_INT64) + .Attr("value", t) + .Finalize(*result)); + function_utils::SetUniqueFunctionNodeName("rebatch/const", fdef, *result); + + return Status::OK(); +} + +Status AddConstBoolNode(bool value, FunctionDef* fdef, NodeDef** result) { + *result = fdef->add_node_def(); + Tensor t(value); + TF_RETURN_IF_ERROR(NodeDefBuilder("", "Const") + .Attr("dtype", DT_BOOL) + .Attr("value", t) + .Finalize(*result)); + function_utils::SetUniqueFunctionNodeName("rebatch/const", fdef, *result); + + return Status::OK(); +} + Status AddShapeNode(const NodeDefBuilder::NodeOut& input, FunctionDef* fdef, NodeDef** result) { *result = fdef->add_node_def(); @@ -271,58 +308,69 @@ Status GetBatchDim(AttrValue output_shapes, int* batch_dim) { Status UpdateOutputShapes(const string& node_name, int64 num_workers, MutableGraphView* graph) { NodeDef* node = graph->GetNode(node_name); - if (node->attr().contains("output_shapes")) { - AttrValue output_shapes = node->attr().at("output_shapes"); + if (node->attr().contains(kOutputShapesAttr)) { + AttrValue output_shapes = node->attr().at(kOutputShapesAttr); for (auto& shape : *output_shapes.mutable_list()->mutable_shape()) { if (!shape.unknown_rank() && shape.dim(0).size() != -1) { shape.mutable_dim(0)->set_size(shape.dim(0).size() / num_workers); } } - (*node->mutable_attr())["output_shapes"] = output_shapes; + (*node->mutable_attr())[kOutputShapesAttr] = output_shapes; } return Status::OK(); } +// Helper function to get the batch_size input node for a give batch node. +int64 GetBatchSizeArgIndex(const NodeDef& batch_node) { + if (batch_node.op() == kExperimentalMapAndBatchOp || + batch_node.op() == kMapAndBatchOp) { + // For MapAndBatch we take the 3rd last input. + return batch_node.input_size() - 3; + } + // For all the batching datasets the batch_size is input number 1 except for + // MapAndBatchDataset. + return 1; +} + +Status MakeNewBatchSizeNode(const string& global_batch_size_name, + int64 num_workers, FunctionDef* fdef, + NodeDef** result) { + NodeDef* one_node; + TF_RETURN_IF_ERROR(AddConstInt64Node(1, fdef, &one_node)); + NodeDef* num_workers_node; + TF_RETURN_IF_ERROR(AddConstInt64Node(num_workers, fdef, &num_workers_node)); + + NodeDef* numerator_node = + AddBinaryNode(global_batch_size_name, + strings::StrCat(num_workers_node->name(), ":output:0"), + kAddOp, DT_INT64, fdef); + numerator_node = AddBinaryNode( + strings::StrCat(numerator_node->name(), ":z:0"), + strings::StrCat(one_node->name(), ":output:0"), kSubOp, DT_INT64, fdef); + + *result = + AddBinaryNode(strings::StrCat(numerator_node->name(), ":z:0"), + strings::StrCat(num_workers_node->name(), ":output:0"), + kTruncateDivOp, DT_INT64, fdef); + return Status::OK(); +} + // Given a "batch" dataset node, we replace the `batch_size` input with a new -// input that corresponds to the original input divided by `num_workers`. If -// `num_workers` does not divide `batch_size` evenly, the value is rounded up. +// input that corresponds to the original input divided by `num_workers`. 
Status MutateBatchSize(const NodeDef& node, int64 num_workers, MutableGraphView* graph) { // For all the batching datasets the batch_size is input number 1 except for // MapAndBatchDataset. - int64 batch_size_arg_index = 1; - if (node.op() == "ExperimentalMapAndBatchDataset" || - node.op() == "MapAndBatchDataset") { - // For MapAndBatch we take the 3rd last input. - batch_size_arg_index = node.input_size() - 3; - } + int64 batch_size_arg_index = GetBatchSizeArgIndex(node); NodeDef* batch_size_node = graph_utils::GetInputNode(node, *graph, batch_size_arg_index); - NodeDef* new_batch_size_node; - if (batch_size_node->op() == kConstOp) { - Tensor batch_size_tensor; - TF_RETURN_IF_ERROR( - GetNodeAttr(*batch_size_node, "value", &batch_size_tensor)); - if (!TensorShapeUtils::IsScalar(batch_size_tensor.shape())) { - return errors::Internal("Batch size node shape should be scalar"); - } - int64 batch_size = batch_size_tensor.scalar()(); - batch_size = (batch_size + num_workers - 1) / num_workers; - new_batch_size_node = - graph_utils::AddScalarConstNode(batch_size, graph); - } else { - NodeDef* one_node = graph_utils::AddScalarConstNode(1, graph); - NodeDef* num_workers_node = - graph_utils::AddScalarConstNode(num_workers, graph); - NodeDef* numerator_node = - AddBinaryNode(batch_size_node->name(), num_workers_node->name(), kAddOp, - DT_INT64, graph); - numerator_node = AddBinaryNode(numerator_node->name(), one_node->name(), - kSubOp, DT_INT64, graph); - new_batch_size_node = - AddBinaryNode(numerator_node->name(), num_workers_node->name(), - kTruncateDivOp, DT_INT64, graph); - } + int64 batch_size; + TF_RETURN_IF_ERROR( + graph_utils::GetScalarConstNodeValue(*batch_size_node, &batch_size)); + DCHECK_EQ(batch_size % num_workers, 0); + batch_size = batch_size / num_workers; + NodeDef* new_batch_size_node = + graph_utils::AddScalarConstNode(batch_size, graph); // We don't call UpdateFanouts here because CSE elimination might lead to // multiple nodes sharing the same batch size constant node. This is also // why we don't delete batch_size_node as well. 
@@ -331,6 +379,181 @@ Status MutateBatchSize(const NodeDef& node, int64 num_workers, return Status::OK(); } +Status AddFlatMapNode(const string& input_dataset, + gtl::ArraySlice other_arguments, + gtl::ArraySlice t_arguments, + const FunctionDef& flat_map_fn, + const AttrValue& output_shapes, + const DataTypeVector& output_types, + FunctionLibraryDefinition* flib, MutableGraphView* graph, + NodeDef** result) { + TF_RETURN_IF_ERROR(flib->AddFunctionDef(flat_map_fn)); + AttrValue f; + f.mutable_func()->set_name(flat_map_fn.signature().name()); + + NodeDef flat_map_node; + flat_map_node.set_op("FlatMapDataset"); + flat_map_node.add_input(input_dataset); + for (const auto& arg : other_arguments) { + flat_map_node.add_input(arg); + } + AddNodeAttr("f", f, &flat_map_node); + AddNodeAttr("Targuments", t_arguments, &flat_map_node); + AddNodeAttr(kOutputShapesAttr, output_shapes, &flat_map_node); + AddNodeAttr(kOutputTypesAttr, output_types, &flat_map_node); + + graph_utils::SetUniqueGraphNodeName("rebatch/flat_map", graph->graph(), + &flat_map_node); + *result = graph->AddNode(std::move(flat_map_node)); + return Status::OK(); +} + +// def flat_map_fn(*batched_components): +// ds = tf.data.Dataset.from_tensor_slices(batched_components) +// return ds.batch(minibatch_size, drop_remainder=False) +Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes, int64 num_workers, + FunctionDef* result) { + NodeDef* tensor_slice_node = result->add_node_def(); + tensor_slice_node->set_op("TensorSliceDataset"); + for (int i = 0; i < dtypes.size(); ++i) { + auto* input_arg = function_utils::AddFunctionInput( + strings::StrCat("args_", i), result, dtypes.at(i)); + tensor_slice_node->add_input(input_arg->name()); + } + AddNodeAttr(kTOutputTypesAttr, dtypes, tensor_slice_node); + + // The output_shapes attr here doesn't make a difference, since we + // set the output_shapes of the external FlatMap node. + AttrValue shapes; + SetUnknownShapes(dtypes.size(), &shapes); + AddNodeAttr(kOutputShapesAttr, shapes, tensor_slice_node); + function_utils::SetUniqueFunctionNodeName("rebatch/from_tensor_slices", + result, tensor_slice_node); + + NodeDef* false_node; + TF_RETURN_IF_ERROR(AddConstBoolNode(false, result, &false_node)); + NodeDef* batch_node = result->add_node_def(); + batch_node->set_op(kBatchV2Op); + batch_node->add_input( + strings::StrCat(tensor_slice_node->name(), ":handle:0")); + + // `batch_size` input + // Here, we capture the original batch size from outside the flat map fn. + auto* original_batch_size = + function_utils::AddFunctionInput("captured_batch_size", result, DT_INT64); + NodeDef* new_batch_size; + TF_RETURN_IF_ERROR(MakeNewBatchSizeNode( + original_batch_size->name(), num_workers, result, &new_batch_size)); + batch_node->add_input(strings::StrCat(new_batch_size->name(), ":z:0")); + + // `drop_remainder` input + batch_node->add_input(strings::StrCat(false_node->name(), ":output:0")); + AddNodeAttr(kOutputTypesAttr, dtypes, batch_node); + AddNodeAttr(kOutputShapesAttr, shapes, batch_node); + function_utils::SetUniqueFunctionNodeName("rebatch/batch", result, + batch_node); + function_utils::AddFunctionOutputWithUniqueName( + "output", strings::StrCat(batch_node->name(), ":handle:0"), result, + DT_VARIANT); + // Because TensorSliceDataset is stateful, we set the function to stateful. + result->mutable_signature()->set_is_stateful(true); + + return Status::OK(); +} + +// Rewrite graph to add +// `.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x). 
+// batch(minibatch_size, drop_remainder=False))` +// after the batch node. This ensures that the sum of the minibatch sizes +// in a step adds up to the global batch size. However, since this adds +// additional data copies (both from_tensor_slices and batch), we only use +// this approach when necessary, i.e. when we need to drop remainder on the +// global batch, or when the global batch size does not divide num_workers +// evenly. +Status AppendFlatMap(const NodeDef& batch_node, int64 num_workers, + FunctionLibraryDefinition* flib, MutableGraphView* graph) { + // `.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x). + // batch(minibatch_size, drop_remainder=False))` + FunctionDef flat_map_fn; + FunctionDefLibrary lib = flib->ToProto(); + graph_utils::SetUniqueGraphFunctionName("rebatch/flat_map_fn", &lib, + &flat_map_fn); + DataTypeVector dtypes; + TF_RETURN_IF_ERROR( + graph_utils::GetDatasetOutputTypesAttr(batch_node, &dtypes)); + TF_RETURN_IF_ERROR( + CreateFlatMapFnWithBatch(dtypes, num_workers, &flat_map_fn)); + + int64 batch_size_index = GetBatchSizeArgIndex(batch_node); + + NodeDef* flat_map_node; + + AttrValue output_shapes = batch_node.attr().at(kOutputShapesAttr); + for (auto& shape : *output_shapes.mutable_list()->mutable_shape()) { + if (!shape.unknown_rank() && shape.dim(0).size() != -1) { + // Because the flat map function uses drop_remainder = False, + // the shape might be unknown + auto old_dim = shape.dim(0).size(); + auto new_dim = old_dim % num_workers == 0 ? old_dim / num_workers : -1; + shape.mutable_dim(0)->set_size(new_dim); + } + } + + TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(batch_node.name(), ":0"), + {batch_node.input(batch_size_index)}, + {DT_INT64}, flat_map_fn, output_shapes, + dtypes, flib, graph, &flat_map_node)); + + TF_RETURN_IF_ERROR( + graph->UpdateFanouts(batch_node.name(), flat_map_node->name())); + + return Status::OK(); +} + +// There are several things we do here, depending on the values of +// batch_size and drop_remainder. +// (1) If batch size is known and divisible by num_workers, and drop_remainder +// is known to be False, we mutate the batch size directly. +// .batch(global_batch_size) -> .batch(global_batch_size // num_workers) +// (2) Otherwise, we add a flat_map transformation to preserve the global batch +// size across the workers and to preserve the drop remainder behavior. +bool ShouldMutateBatchSizeDirectly(const NodeDef& batch_node, int64 num_workers, + MutableGraphView* graph) { + int64 batch_size_arg_index = GetBatchSizeArgIndex(batch_node); + NodeDef* batch_size_node = + graph_utils::GetInputNode(batch_node, *graph, batch_size_arg_index); + + int64 batch_size; + Status s = + graph_utils::GetScalarConstNodeValue(*batch_size_node, &batch_size); + // If batch size is unknown or indivisible by num workers, we don't + // mutate it directly + if (!s.ok() || batch_size % num_workers != 0) return false; + + if (batch_node.op() == kBatchOp || batch_node.op() == kPaddedBatchOp) { + // These ops don't have a `drop_remainder` input, and behave like + // drop_remainder is False. + return true; + } + + // drop_remainder is the final input on the other batch nodes. 
+ NodeDef* drop_remainder_node = graph_utils::GetInputNode( + batch_node, *graph, batch_node.input_size() - 1); + bool drop_remainder; + s = graph_utils::GetScalarConstNodeValue(*drop_remainder_node, + &drop_remainder); + return s.ok() && !drop_remainder; +} + +Status RewriteBatchNode(const NodeDef& batch_node, int64 num_workers, + FunctionLibraryDefinition* flib, + MutableGraphView* graph) { + if (ShouldMutateBatchSizeDirectly(batch_node, num_workers, graph)) { + return MutateBatchSize(batch_node, num_workers, graph); + } + return AppendFlatMap(batch_node, num_workers, flib, graph); +} + Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, bool use_fallback, GraphDef* output); @@ -346,7 +569,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, bool use_fallback, FunctionLibraryDefinition* flib, MutableGraphView* graph) { if (IsDatasetNodeOfType(node, kBatchDatasetOps)) { - TF_RETURN_IF_ERROR(MutateBatchSize(node, num_workers, graph)); + TF_RETURN_IF_ERROR(RewriteBatchNode(node, num_workers, flib, graph)); } else if (IsDatasetNodeOfType(node, kMultipleInputsDatasetOps)) { // For all multiple input datasets, all inputs are datasets themselves. for (int i = 0; i < node.input_size(); ++i) { @@ -403,7 +626,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, } // Add nodes to the function to reshape arg to shape (-1, new_batch_dim, ...) -Status ReshapeComponent(int new_batch_dim, StringPiece arg, DataType dtype, +Status ReshapeComponent(int new_batch_dim, const string& arg, DataType dtype, FunctionDef* fdef, string* result) { // Const with value [0] NodeDef* const_vec_0; @@ -453,47 +676,50 @@ Status ReshapeComponent(int new_batch_dim, StringPiece arg, DataType dtype, return Status::OK(); } -Status CreateFlatMapFn(int new_batch_dim, const AttrValue& types, - FunctionDef* result) { +// def flat_map_fn(*batched_components): +// return tf.data.Dataset.from_tensor_slices( +// [tf.reshape(c, (-1, new_batch_size, ...)) +// for c in batched_components]) +Status CreateFlatMapFnWithReshape(int new_batch_dim, + const DataTypeVector& types, + FunctionDef* result) { std::vector tensor_slice_dataset_inputs; // For each component of the dataset, we reshape it from shape // (old_batch_size, ...) to (-1, new_batch_size, ...) // where new_batch_size = (old_batch_size + num_workers - 1) // num_workers - for (int i = 0; i < types.list().type_size(); ++i) { - string arg = strings::StrCat("args_", i); - auto* input_arg = result->mutable_signature()->mutable_input_arg()->Add(); - input_arg->set_type(types.list().type(i)); - input_arg->set_name(arg); + for (int i = 0; i < types.size(); ++i) { + auto* input_arg = function_utils::AddFunctionInput( + strings::StrCat("args_", i), result, types.at(i)); string reshape_node_name; - TF_RETURN_IF_ERROR(ReshapeComponent( - new_batch_dim, arg, types.list().type(i), result, &reshape_node_name)); + TF_RETURN_IF_ERROR(ReshapeComponent(new_batch_dim, input_arg->name(), + types.at(i), result, + &reshape_node_name)); tensor_slice_dataset_inputs.emplace_back( - strings::StrCat(reshape_node_name, ":output"), 0, types.list().type(i)); + strings::StrCat(reshape_node_name, ":output"), 0, types.at(i)); } // The output_shapes attr here doesn't make a difference, since we // set the output_shapes of the external FlatMap node. 
AttrValue shapes; - SetUnknownShapes(types.list().type_size(), &shapes); + SetUnknownShapes(types.size(), &shapes); NodeDef* tensor_slice_dataset = result->add_node_def(); TF_RETURN_IF_ERROR(NodeDefBuilder("", "TensorSliceDataset") .Input(tensor_slice_dataset_inputs) .Attr("Toutput_types", types) - .Attr("output_shapes", shapes) + .Attr(kOutputShapesAttr, shapes) .Finalize(tensor_slice_dataset)); function_utils::SetUniqueFunctionNodeName("rebatch/tensor_slice_dataset", result, tensor_slice_dataset); - auto* output_arg = result->mutable_signature()->mutable_output_arg()->Add(); - output_arg->set_name("output"); - output_arg->set_type(DT_VARIANT); + function_utils::AddFunctionOutputWithUniqueName( + "output", strings::StrCat(tensor_slice_dataset->name(), ":handle:0"), + result, DT_VARIANT); + // Because TensorSliceDataset is stateful, we set the function to stateful. result->mutable_signature()->set_is_stateful(true); - (*result->mutable_ret())["output"] = - strings::StrCat(tensor_slice_dataset->name(), ":handle:0"); return Status::OK(); } @@ -525,12 +751,12 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers, // because of the use of the "Reshape" op. This ensures that the error is // surfaced correctly. AttrValue output_shapes; - if (!fetch_node->attr().contains("output_shapes")) { + if (!fetch_node->attr().contains(kOutputShapesAttr)) { return errors::InvalidArgument( "Cannot use rebatching fallback without output_shapes attr. Node: ", fetch_node->name(), " Op: ", fetch_node->op()); } else { - output_shapes = fetch_node->attr().at("output_shapes"); + output_shapes = fetch_node->attr().at(kOutputShapesAttr); } int batch_dim; TF_RETURN_IF_ERROR(GetBatchDim(output_shapes, &batch_dim)); @@ -543,35 +769,25 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers, // Create the flat map fn FunctionDef flat_map_fn; FunctionDefLibrary lib = flib->ToProto(); - graph_utils::SetUniqueGraphFunctionName("flat_map_fn", &lib, &flat_map_fn); + graph_utils::SetUniqueGraphFunctionName("rebatch/flat_map_fn", &lib, + &flat_map_fn); // Get types of input arguments from the output types of the final dataset. 
- AttrValue output_types; + DataTypeVector output_types; TF_RETURN_IF_ERROR( graph_utils::GetDatasetOutputTypesAttr(*fetch_node, &output_types)); + TF_RETURN_IF_ERROR(CreateFlatMapFnWithReshape(batch_dim / num_workers, + output_types, &flat_map_fn)); + + NodeDef* flat_map_node; + TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(fetch_node->name(), ":0"), + {}, {}, flat_map_fn, output_shapes, + output_types, flib, graph, &flat_map_node)); TF_RETURN_IF_ERROR( - CreateFlatMapFn(batch_dim / num_workers, output_types, &flat_map_fn)); + UpdateOutputShapes(flat_map_node->name(), num_workers, graph)); - TF_RETURN_IF_ERROR(flib->AddFunctionDef(flat_map_fn)); - AttrValue fn; - fn.mutable_func()->set_name(flat_map_fn.signature().name()); - - NodeDef flat_map_node; TF_RETURN_IF_ERROR( - NodeDefBuilder("", "FlatMapDataset") - .Input(fetch_node->name(), 0, DT_VARIANT) - .Input(std::vector()) // other_arguments - .Attr("f", fn) - .Attr("Targuments", std::vector()) - .Attr("output_types", output_types) - .Attr("output_shapes", output_shapes) - .Finalize(&flat_map_node)); - graph_utils::SetUniqueGraphNodeName("rebatch/flat_map", graph->graph(), - &flat_map_node); - NodeDef* added = graph->AddNode(std::move(flat_map_node)); - TF_RETURN_IF_ERROR(UpdateOutputShapes(added->name(), num_workers, graph)); - - TF_RETURN_IF_ERROR(graph->UpdateFanouts(fetch_node->name(), added->name())); + graph->UpdateFanouts(fetch_node->name(), flat_map_node->name())); return Status::OK(); } @@ -593,8 +809,8 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, RecursivelyHandleOp(*sink_node, num_workers, use_fallback, &flib, &graph); if (!s.ok()) { if (use_fallback) { - VLOG(1) << "Couldn't find a batch transformation. Using a fallback method" - " to rebatch dataset."; + VLOG(1) << "Failed to rebatch by rewriting the batch transformation (" + << s << "). Using a fallback method instead."; // If RecursivelyHandleOp fails, we reset `graph` to use the original, // graph, since that function may have mutated `graph`. 
*output = item.graph; diff --git a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py index c36ea688880..09eac5dda50 100644 --- a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py @@ -48,96 +48,98 @@ def _flat_shapes(dataset): return nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)) -@parameterized.named_parameters(("WithDropRemainder", True), - ("WithoutDropRemainder", False)) @test_util.run_all_in_graph_and_eager_modes -class RebatchDatasetTest(test_base.DatasetTestBase): +class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): + drop_remainder_cases = [("WithDropRemainder", True), + ("WithoutDropRemainder", False)] + + @parameterized.named_parameters(drop_remainder_cases) def testBasic(self, drop_remainder): dataset = dataset_ops.Dataset.range(1024).batch( 32, drop_remainder=drop_remainder) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) - self.assertEqual( - [[32 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(dataset)]) - self.assertEqual( - [[8 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) + self.assertEqual([[8] if drop_remainder else [None]], + [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(rebatched_dataset, expected_output) - def testScalarInputError(self, _): + def testScalarInputError(self): dataset = dataset_ops.Dataset.range(1024) + distribute._RebatchDataset(dataset.batch(4), num_workers=4) with self.assertRaisesRegexp(ValueError, "at least one dimension"): distribute._RebatchDataset(dataset, num_workers=4) - def testNotDivisible(self, drop_remainder): + @parameterized.named_parameters(drop_remainder_cases) + def testBatchNotDivisibleByNumWorkers(self, drop_remainder): dataset = dataset_ops.Dataset.range(1024).batch( 32, drop_remainder=drop_remainder) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5) - expected_output = [[k for k in range(i, i + 7)] for i in range(0, 1022, 7)] # pylint: disable=g-complex-comprehension - if not drop_remainder: - expected_output.append([1022, 1023]) + self.assertEqual([[None]], + [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) + expected_output = [] + i = 0 + for _ in range(32): # number of steps + # first four minibatches have seven elements + for _ in range(4): + expected_output.append([k for k in range(i, i + 7)]) + i += 7 + # last minibatch has four elements + expected_output.append([k for k in range(i, i + 4)]) + i += 4 self.assertDatasetProduces(rebatched_dataset, expected_output) - def testTupleOutput(self, drop_remainder): - dataset = ( - dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch( - 32, drop_remainder=drop_remainder)) + def testTupleOutput(self): + dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) expected_output = [([k for k in range(i, i + 8)], # pylint: disable=g-complex-comprehension [k for k in range(i, i + 8)]) for i in range(0, 1024, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) - def testNestedDictionaryOutput(self, drop_remainder): + def testNestedDictionaryOutput(self): dataset = 
dataset_ops.Dataset.range(1024).map( - lambda x: {"a": x, "b": {"c": x}}).batch( - 32, drop_remainder=drop_remainder) + lambda x: {"a": x, "b": {"c": x}}).batch(32) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) expected_output = [{"a": [k for k in range(i, i + 8)], # pylint: disable=g-complex-comprehension "b": {"c": [k for k in range(i, i + 8)]}} for i in range(0, 1024, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) - def testFinalPartialBatchOriginal(self, drop_remainder): + @parameterized.named_parameters(drop_remainder_cases) + def testFinalPartialBatch(self, drop_remainder): dataset = dataset_ops.Dataset.range(1032).batch( 32, drop_remainder=drop_remainder) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) - self.assertEqual( - [[32 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(dataset)]) - self.assertEqual( - [[8 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) + self.assertEqual([[8] if drop_remainder else [None]], + [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) - expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1032, 8)] # pylint: disable=g-complex-comprehension + # if drop_remainder, the final partial batch is dropped, even though it + # makes up a complete minibatch. + expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension + if not drop_remainder: + expected_output.append([k for k in range(1024, 1032)]) self.assertDatasetProduces(rebatched_dataset, expected_output) + @parameterized.named_parameters(drop_remainder_cases) def testFinalPartialBatchAfterRebatch(self, drop_remainder): dataset = dataset_ops.Dataset.range(34).batch( 32, drop_remainder=drop_remainder) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) - self.assertEqual( - [[32 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(dataset)]) - self.assertEqual( - [[8 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) + self.assertEqual([[8] if drop_remainder else [None]], + [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)] # pylint: disable=g-complex-comprehension if not drop_remainder: expected_output += [[32, 33]] self.assertDatasetProduces(rebatched_dataset, expected_output) - def testMultipleBatches(self, drop_remainder): - dataset = dataset_ops.Dataset.range(128).batch( - 4, drop_remainder=drop_remainder) - dataset = dataset.batch(8, drop_remainder=drop_remainder) - self.assertEqual( - [[8, 4]] if drop_remainder else [[None, None]], - [ts.as_list() for ts in _flat_shapes(dataset)]) + def testMultipleBatches(self): + dataset = dataset_ops.Dataset.range(128).batch(4).batch(8) + self.assertEqual([[None, None]], + [ts.as_list() for ts in _flat_shapes(dataset)]) + # Each element is a list of 8 elements where each element is a list of 4. 
expected_output = [[[j, j + 1, j + 2, j + 3] # pylint: disable=g-complex-comprehension for j in range(i, i + 32, 4)] # generates 8 elements @@ -145,39 +147,30 @@ class RebatchDatasetTest(test_base.DatasetTestBase): self.assertDatasetProduces(dataset, expected_output) rebatched_dataset = distribute._RebatchDataset(dataset, 4) - self.assertEqual( - [[2, 4]] if drop_remainder else [[None, None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) + self.assertEqual([[None, None]], + [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # Each element is a list of 2 elements where each element is a list of 4. expected_output = [[[j, j + 1, j + 2, j + 3] # pylint: disable=g-complex-comprehension for j in range(i, i + 8, 4)] # generates 2 elements for i in range(0, 128, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) - def testMapAndBatch(self, drop_remainder): + def testMapAndBatch(self): dataset = dataset_ops.Dataset.range(1024).apply( - batching.map_and_batch( - math_ops.square, 32, drop_remainder=drop_remainder)) + batching.map_and_batch(math_ops.square, 32)) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) - self.assertEqual( - [[32 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(dataset)]) - self.assertEqual( - [[8 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) + self.assertEqual([[None]], + [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [[k**2 for k in range(i, i + 8)] # pylint: disable=g-complex-comprehension for i in range(0, 1024, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) - def testMapAndBatchWithCapturedInput(self, drop_remainder): + def testMapAndBatchWithCapturedInput(self): captured_t = variables.Variable(42) dataset = dataset_ops.Dataset.range(1024).apply( - batching.map_and_batch( - lambda x: captured_t, 32, drop_remainder=drop_remainder)) + batching.map_and_batch(lambda x: captured_t, 32)) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) - self.assertEqual([[32 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(dataset)]) - self.assertEqual([[8 if drop_remainder else None]], + self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [[42 for _ in range(i, i + 8)] # pylint: disable=g-complex-comprehension for i in range(0, 1024, 8)] @@ -185,22 +178,19 @@ class RebatchDatasetTest(test_base.DatasetTestBase): self.assertDatasetProduces( rebatched_dataset, expected_output, requires_initialization=True) - def testPaddedBatch(self, drop_remainder): - dataset = dataset_ops.Dataset.range(128).batch(4).padded_batch( - 8, padded_shapes=[5], drop_remainder=drop_remainder) + def testPaddedBatch(self): + dataset = dataset_ops.Dataset.range(128).batch( + 4, drop_remainder=True).padded_batch( + 8, padded_shapes=[5]) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) - self.assertEqual( - [[8, 5]] if drop_remainder else [[None, 5]], - [ts.as_list() for ts in _flat_shapes(dataset)]) # Each element is a list of 8 elements in which each element is a list of 5 # elements, first four are numbers and the last one is a padded zero. 
expected_output = [[[j, j + 1, j + 2, j + 3, 0] # pylint: disable=g-complex-comprehension for j in range(i, i + 32, 4)] # generates 8 elements for i in range(0, 128, 32)] self.assertDatasetProduces(dataset, expected_output) - self.assertEqual( - [[2, 5]] if drop_remainder else [[None, 5]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) + self.assertEqual([[None, 5]], + [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # Each element is a list of 2 elements in which each element is a list of 5 # elements, first four are numbers and the last one is a padded zero. expected_output = [[[j, j + 1, j + 2, j + 3, 0] # pylint: disable=g-complex-comprehension @@ -208,32 +198,22 @@ class RebatchDatasetTest(test_base.DatasetTestBase): for i in range(0, 128, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) - def testConcatenate(self, drop_remainder): - dataset1 = dataset_ops.Dataset.range(64).batch( - 8, drop_remainder=drop_remainder) - dataset2 = dataset_ops.Dataset.range(32).batch( - 8, drop_remainder=drop_remainder) + def testConcatenate(self): + dataset1 = dataset_ops.Dataset.range(64).batch(8) + dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset = dataset1.concatenate(dataset2) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) - self.assertEqual( - [[8 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(dataset)]) - self.assertEqual( - [[2 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) + self.assertEqual([[None]], + [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = ([[i, i + 1] for i in range(0, 64, 2)] + [[i, i + 1] for i in range(0, 32, 2)]) self.assertDatasetProduces(rebatched_dataset, expected_output) - def testConcatenateDifferentShapes(self, drop_remainder): - dataset1 = dataset_ops.Dataset.range(64).batch( - 16, drop_remainder=drop_remainder) - dataset2 = dataset_ops.Dataset.range(32).batch( - 8, drop_remainder=drop_remainder) + def testConcatenateDifferentShapes(self): + dataset1 = dataset_ops.Dataset.range(64).batch(16) + dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset = dataset1.concatenate(dataset2) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) - self.assertEqual( - [[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) self.assertEqual( [[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) @@ -241,73 +221,56 @@ class RebatchDatasetTest(test_base.DatasetTestBase): [[i, i + 1] for i in range(0, 32, 2)]) self.assertDatasetProduces(rebatched_dataset, expected_output) - def testZip(self, drop_remainder): - dataset1 = dataset_ops.Dataset.range(64).batch( - 8, drop_remainder=drop_remainder) - dataset2 = dataset_ops.Dataset.range(32).batch( - 8, drop_remainder=drop_remainder) + def testZip(self): + dataset1 = dataset_ops.Dataset.range(64).batch(8) + dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) - self.assertEqual( - [[8], [8]] if drop_remainder else [[None], [None]], - [ts.as_list() for ts in _flat_shapes(dataset)]) - self.assertEqual( - [[2], [2]] if drop_remainder else [[None], [None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) + self.assertEqual([[None], [None]], + [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)] self.assertDatasetProduces(rebatched_dataset, 
expected_output) - def testZipDifferentShapes(self, drop_remainder): - dataset1 = dataset_ops.Dataset.range(64).batch( - 16, drop_remainder=drop_remainder) - dataset2 = dataset_ops.Dataset.range(32).batch( - 8, drop_remainder=drop_remainder) + def testZipDifferentShapes(self): + dataset1 = dataset_ops.Dataset.range(64).batch(16) + dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) - self.assertEqual( - [[16], [8]] if drop_remainder else [[None], [None]], - [ts.as_list() for ts in _flat_shapes(dataset)]) - self.assertEqual( - [[4], [2]] if drop_remainder else [[None], [None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) + self.assertEqual([[None], [None]], + [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [([2 * i, 2 * i + 1, 2 * i + 2, 2 * i + 3], [i, i + 1]) for i in range(0, 32, 2)] self.assertDatasetProduces(rebatched_dataset, expected_output) - def testUnsupportedTransformError(self, drop_remainder): - dataset = dataset_ops.Dataset.range(1024).batch( - 32, drop_remainder=drop_remainder).apply(sleep.sleep(10)) + def testUnsupportedTransformError(self): + dataset = dataset_ops.Dataset.range(1024).batch(32).apply(sleep.sleep(10)) with self.assertRaises(errors.InvalidArgumentError): rebatched_dataset = distribute._RebatchDataset( dataset, num_workers=4, use_fallback=False) next_element = self.getNext(rebatched_dataset) self.evaluate(next_element()) - def testUnsupportedTransformInFlatMapError(self, drop_remainder): + def testUnsupportedTransformInFlatMapError(self): dataset = dataset_ops.Dataset.range(2).flat_map( lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda - 32, drop_remainder=drop_remainder).apply(sleep.sleep(10))) + 32).apply(sleep.sleep(10))) with self.assertRaises(errors.InvalidArgumentError): rebatched_dataset = distribute._RebatchDataset( dataset, num_workers=4, use_fallback=False) next_element = self.getNext(rebatched_dataset) self.evaluate(next_element()) - def testFlatMapBatching(self, drop_remainder): - dataset = dataset_ops.Dataset.range( - 2).flat_map(lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda - 32, drop_remainder=drop_remainder)) - self.assertEqual( - [[32 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(dataset)]) + def testFlatMapBatching(self): + dataset = dataset_ops.Dataset.range(2).flat_map( + lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda + 32)) # Two elements where each element is range(32) expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(dataset, expected_output) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) - self.assertEqual( - [[8 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) + self.assertEqual([[None]], + [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # Two elements where each element is a list of 4 elements where each element # is a list of 8. 
expected_output = [[k for k in range(i, i + 8)] # pylint: disable=g-complex-comprehension @@ -315,21 +278,18 @@ class RebatchDatasetTest(test_base.DatasetTestBase): for i in range(0, 32, 8)] # generates 4 elements self.assertDatasetProduces(rebatched_dataset, expected_output) - def testInterleaveBatching(self, drop_remainder): - dataset = dataset_ops.Dataset.range( - 2).interleave(lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda - 32, drop_remainder=drop_remainder), cycle_length=2) - self.assertEqual( - [[32 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(dataset)]) + def testInterleaveBatching(self): + dataset = dataset_ops.Dataset.range(2).interleave( + lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda + 32), + cycle_length=2) # Two elements where each element is range(32) expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(dataset, expected_output) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) - self.assertEqual( - [[8 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) + self.assertEqual([[None]], + [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # List of 4 elements where each element is a list of 8 numbering from 0 to # 31 repeated twice. expected_output = [[k for k in range(i, i + 8)] # pylint: disable=g-complex-comprehension @@ -337,22 +297,19 @@ class RebatchDatasetTest(test_base.DatasetTestBase): for _ in range(2)] self.assertDatasetProduces(rebatched_dataset, expected_output) - def testParallelInterleaveBatching(self, drop_remainder): - dataset = dataset_ops.Dataset.range( - 2).interleave(lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda - 32, drop_remainder=drop_remainder), cycle_length=2, - num_parallel_calls=2) - self.assertEqual( - [[32 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(dataset)]) + def testParallelInterleaveBatching(self): + dataset = dataset_ops.Dataset.range(2).interleave( + lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda + 32), + cycle_length=2, + num_parallel_calls=2) # Two elements where each element is range(32) expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(dataset, expected_output) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) - self.assertEqual( - [[8 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) + self.assertEqual([[None]], + [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # List of 4 elements where each element is a list of 8 numbering from 0 to # 31 repeated twice in collated fashion i.e [0...8], [0...8] etc. 
expected_output = [[k for k in range(i, i + 8)] # pylint: disable=g-complex-comprehension @@ -360,17 +317,17 @@ class RebatchDatasetTest(test_base.DatasetTestBase): for _ in range(2)] self.assertDatasetProduces(rebatched_dataset, expected_output) - def testGroupByWindowStaticBatch(self, drop_remainder): + def testGroupByWindowStaticBatch(self): dataset = dataset_ops.Dataset.from_tensor_slices( [[array_ops.constant(i, dtype=dtypes.int64)] * 3 for i in range(40)]) reduce_fn = lambda bucket_id, ds: ds.batch( # pylint: disable=g-long-lambda - batch_size=10, drop_remainder=drop_remainder) + batch_size=10) dataset = dataset.apply( grouping.group_by_window( key_func=lambda x: x[0] % 4, reduce_func=reduce_fn, window_size=10)) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=2) - self.assertEqual([[5, 3] if drop_remainder else [None, 3]], + self.assertEqual([[None, 3]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # pylint: disable=g-complex-comprehension expected_output = [[[j + i * 4 + k * 20] * 3 @@ -379,10 +336,15 @@ class RebatchDatasetTest(test_base.DatasetTestBase): for k in range(2)] self.assertDatasetProduces(rebatched_dataset, expected_output) - def testGroupByWindowDynamicBatch(self, drop_remainder): + def testGroupByWindowDynamicBatch(self): + # {0, 1, 0, 1, ...} dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2) - reduce_fn = lambda bucket_id, ds: ds.batch( # pylint: disable=g-long-lambda - batch_size=(bucket_id + 1) * 5, drop_remainder=drop_remainder) + + def reduce_fn(key, ds): + # key == 0 -> .batch(5) + # key == 1 -> .batch(10) + return ds.batch(batch_size=(key + 1) * 5) + dataset = dataset.apply( grouping.group_by_window( key_func=lambda x: x, reduce_func=reduce_fn, window_size=10)) @@ -390,15 +352,64 @@ class RebatchDatasetTest(test_base.DatasetTestBase): self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) - pairs = [(3, 0), (3, 0), (3, 0)] - if not drop_remainder: - pairs.extend([(1, 0)]) - pairs.extend([(5, 1), (5, 1)]) + + # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and + # the batches of 10 (value == 1) split into minibatches of (5, 5) + # [(batch_size, value), ...] + pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1)] pairs = pairs * 2 expected_output = [[value] * batch_size for batch_size, value in pairs] self.assertDatasetProduces(dataset, expected_output) - def testScanAfterBatch(self, drop_remainder): + def testGroupByWindowDynamicBatchWithPartialBatch(self): + # {0, 1, 0, 1, ...} + dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2) + + def reduce_fn(key, ds): + # key == 0 -> .batch(5) + # key == 1 -> .batch(10) + return ds.batch(batch_size=(key + 1) * 5) + + dataset = dataset.apply( + grouping.group_by_window( + key_func=lambda x: x, reduce_func=reduce_fn, window_size=11)) + dataset = distribute._RebatchDataset(dataset, num_workers=2) + + self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) + + # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and + # the batches of 10 (value == 1) split into minibatches of (5, 5) + # [(batch_size, value), ...] 
+ pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (5, 1), (5, 1), (1, 1), + (3, 0), (2, 0), (3, 0), (1, 0), (5, 1), (4, 1)] + expected_output = [[value] * batch_size for batch_size, value in pairs] + self.assertDatasetProduces(dataset, expected_output) + + def testGroupByWindowDynamicBatchWithPartialBatchWithDropRemainder(self): + # This test exercises nested batch functionality, dynamic batch size + # and drop_remainder=True together. + dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2) + + def reduce_fn(key, ds): + # key == 0 -> .batch(5) + # key == 1 -> .batch(10) + return ds.batch(batch_size=(key + 1) * 5, drop_remainder=True) + + dataset = dataset.apply( + grouping.group_by_window( + key_func=lambda x: x, reduce_func=reduce_fn, window_size=11)) + dataset = distribute._RebatchDataset(dataset, num_workers=2) + + self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) + + # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and + # the batches of 10 (value == 1) split into minibatches of (5, 5) + # [(batch_size, value), ...] + pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1), (3, 0), (2, 0)] + expected_output = [[value] * batch_size for batch_size, value in pairs] + self.assertDatasetProduces(dataset, expected_output) + + def testScanAfterBatch(self): dataset = dataset_ops.Dataset.range(40).batch(10).apply( scan_ops.scan(np.int64(2), lambda state, value: (state, value * state))) dataset = distribute._RebatchDataset(dataset, num_workers=2) @@ -408,7 +419,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase): expected_output = [[i * 2 for i in range(j*5, (j+1)*5)] for j in range(8)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(dataset, expected_output) - def testMakeBatchedFeaturesDataset(self, drop_remainder): + def testMakeBatchedFeaturesDataset(self): # Set up fn = os.path.join(self.get_temp_dir(), "tf_record.txt") writer = python_io.TFRecordWriter(fn) @@ -429,13 +440,11 @@ class RebatchDatasetTest(test_base.DatasetTestBase): features={"value": parsing_ops.FixedLenFeature([], dtypes.int64)}, shuffle=False, num_epochs=1, - drop_final_batch=drop_remainder) + drop_final_batch=False) rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) - self.assertEqual([[32 if drop_remainder else None]], - [ts.as_list() for ts in _flat_shapes(dataset)]) - self.assertEqual([[8 if drop_remainder else None]], + self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [{ diff --git a/tensorflow/python/data/experimental/ops/distribute.py b/tensorflow/python/data/experimental/ops/distribute.py index b834fe8839a..f7db8491c57 100644 --- a/tensorflow/python/data/experimental/ops/distribute.py +++ b/tensorflow/python/data/experimental/ops/distribute.py @@ -74,7 +74,11 @@ def _AutoShardDatasetV1(input_dataset, num_workers, index): # pylint: disable=i class _RebatchDataset(dataset_ops.UnaryDataset): - """A `Dataset` that divides the batch size by `num_workers`.""" + """A `Dataset` that divides the batch size by `num_workers`. + + For each batch in the input dataset, the resulting dataset will produce + `num_replicas` minibatches whose sizes add up to the original batch size. + """ def __init__(self, input_dataset, num_workers, use_fallback=True): self._input_dataset = input_dataset @@ -85,8 +89,14 @@ class _RebatchDataset(dataset_ops.UnaryDataset): raise ValueError( "Input shape should have at least one dimension. 
" "Perhaps your input dataset is not batched?") - output_dims = [d for d in output_shapes.dims] - output_dims[0] = (output_dims[0] + num_workers - 1) // num_workers + output_dims = [d.value for d in output_shapes.dims] + + if output_dims[0] is not None and output_dims[0] % num_workers == 0: + output_dims[0] = output_dims[0] // num_workers + else: + # Set the batch dimension to unknown. If the global batch size does not + # divide num_workers evenly, the minibatches may have different sizes. + output_dims[0] = None return tensor_shape.TensorShape(output_dims) input_types = dataset_ops.get_legacy_output_types(self._input_dataset)