diff --git a/tensorflow/core/grappler/optimizers/data/rebatch.cc b/tensorflow/core/grappler/optimizers/data/rebatch.cc
index 821b486b884..c62088e9f3d 100644
--- a/tensorflow/core/grappler/optimizers/data/rebatch.cc
+++ b/tensorflow/core/grappler/optimizers/data/rebatch.cc
@@ -200,11 +200,13 @@ Status AddConstBoolNode(bool value, FunctionDef* fdef, NodeDef** result) {
   return Status::OK();
 }
 
-Status AddShapeNode(const NodeDefBuilder::NodeOut& input, FunctionDef* fdef,
-                    NodeDef** result) {
+Status AddShapeNode(const NodeDefBuilder::NodeOut& input, DataType out_type,
+                    FunctionDef* fdef, NodeDef** result) {
   *result = fdef->add_node_def();
-  TF_RETURN_IF_ERROR(
-      NodeDefBuilder("", "Shape").Input(input).Finalize(*result));
+  TF_RETURN_IF_ERROR(NodeDefBuilder("", "Shape")
+                         .Input(input)
+                         .Attr("out_type", out_type)
+                         .Finalize(*result));
   function_utils::SetUniqueFunctionNodeName("rebatch/shape", fdef, *result);
   return Status::OK();
 }
@@ -276,34 +278,49 @@ void SetUnknownShapes(int num_components, AttrValue* output_shapes) {
   }
 }
 
-Status GetBatchDim(AttrValue output_shapes, int* batch_dim) {
-  const auto& shape_0 = output_shapes.list().shape(0);
-  if (shape_0.unknown_rank() || shape_0.dim(0).size() == -1) {
+// If the batch dimension is known and divisible by num_replicas, we set
+// result = batch_dim / num_replicas. If the batch dimension is unknown,
+// result = -1. If the dataset node is missing an output shapes attr,
+// or the batch dimensions of its components don't match, we return an error
+// status.
+Status GetMinibatchDimForReshape(const NodeDef& dataset_node,
+                                 int64 num_replicas, int64* result) {
+  AttrValue output_shapes;
+  if (!dataset_node.attr().contains(kOutputShapesAttr)) {
     return errors::InvalidArgument(
-        "Cannot use rebatching fallback when 0th dimensions of dataset "
-        "components are not fully known. Component 0 has shape: ",
-        shape_0.ShortDebugString());
+        "Cannot use rebatching fallback when the final dataset node does not "
+        "have an `output_shapes` attr. Node: ",
+        dataset_node.name(), " Op: ", dataset_node.op());
   }
+  output_shapes = dataset_node.attr().at(kOutputShapesAttr);
 
-  *batch_dim = output_shapes.list().shape(0).dim(0).size();
-
-  for (int i = 1; i < output_shapes.list().shape_size(); ++i) {
+  // Get the batch dimension by checking the 0th dimension of all the inputs.
+  int batch_dim = -1;
+  for (int i = 0; i < output_shapes.list().shape_size(); ++i) {
     const auto& shape_i = output_shapes.list().shape(i);
-    if (shape_i.unknown_rank() || shape_i.dim(0).size() == -1) {
+    // If unknown, ignore.
+    if (shape_i.unknown_rank()) continue;
+    int batch_dim_i = shape_i.dim(0).size();
+    if (batch_dim_i == -1) continue;
+
+    // Update batch_dim with known dimension.
+    if (batch_dim_i != batch_dim && batch_dim != -1) {
       return errors::InvalidArgument(
-          "Cannot use rebatching fallback when 0th dimensions of dataset "
-          "components are not fully known. Component ",
-          i, " has shape: ", shape_i.ShortDebugString());
-    }
-    if (shape_i.dim(0).size() != *batch_dim) {
-      return errors::InvalidArgument(
-          "Cannot use rebatching fallback when 0th dimensions of dataset "
+          "Cannot use rebatching fallback: 0th dimensions of dataset "
           "components don't match. Component ",
-          i, " has batch dimension: ", shape_i.dim(0).size(),
-          " while previous components have batch dimension: ", *batch_dim);
+          i, " has batch dimension: ", batch_dim_i,
+          " while previous components have batch dimension: ", batch_dim);
     }
+    batch_dim = batch_dim_i;
   }
+
+  if (batch_dim == -1 || batch_dim % num_replicas != 0) {
+    *result = -1;
+  } else {
+    *result = batch_dim / num_replicas;
+  }
+
   return Status::OK();
 }
@@ -411,6 +428,8 @@ Status AddFlatMapNode(const string& input_dataset,
 }
 
 // def flat_map_fn(*batched_components):
+//   batch_size = tf.shape(batched_components[0])[0]
+//   minibatch_size = (batch_size + num_replicas - 1) // num_replicas
 //   ds = tf.data.Dataset.from_tensor_slices(batched_components)
 //   return ds.batch(minibatch_size, drop_remainder=False)
 Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes,
@@ -439,13 +458,32 @@ Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes,
   batch_node->add_input(
       strings::StrCat(tensor_slice_node->name(), ":handle:0"));
 
-  // `batch_size` input
-  // Here, we capture the original batch size from outside the flat map fn.
-  auto* original_batch_size =
-      function_utils::AddFunctionInput("captured_batch_size", result, DT_INT64);
+  // `batch_size` is tf.shape(arg)[0]
+  NodeDef* shape;
+  TF_RETURN_IF_ERROR(AddShapeNode({tensor_slice_node->input(0), 0, dtypes[0]},
+                                  DT_INT64, result, &shape));
+
+  // Const with value [0]
+  NodeDef* const_vec_0;
+  TF_RETURN_IF_ERROR(AddConstIntNode({0}, {1}, result, &const_vec_0));
+
+  // Const with value [1]
+  NodeDef* const_vec_1;
+  TF_RETURN_IF_ERROR(AddConstIntNode({1}, {1}, result, &const_vec_1));
+
+  // Extracts the 0th dimension from the shape node.
+  NodeDef* original_batch_size;
+  TF_RETURN_IF_ERROR(AddStridedSliceNode(
+      {strings::StrCat(shape->name(), ":output"), 0, DT_INT64},
+      {strings::StrCat(const_vec_0->name(), ":output"), 0, DT_INT32},
+      {strings::StrCat(const_vec_1->name(), ":output"), 0, DT_INT32},
+      {strings::StrCat(const_vec_1->name(), ":output"), 0, DT_INT32}, DT_INT32,
+      0, 0, 0, 0, 1, result, &original_batch_size));
+
   NodeDef* new_batch_size;
   TF_RETURN_IF_ERROR(MakeNewBatchSizeNode(
-      original_batch_size->name(), num_replicas, result, &new_batch_size));
+      strings::StrCat(original_batch_size->name(), ":output:0"), num_replicas,
+      result, &new_batch_size));
   batch_node->add_input(strings::StrCat(new_batch_size->name(), ":z:0"));
 
   // `drop_remainder` input
@@ -486,8 +524,6 @@ Status AppendFlatMap(const NodeDef& batch_node, int64 num_replicas,
   TF_RETURN_IF_ERROR(
       CreateFlatMapFnWithBatch(dtypes, num_replicas, &flat_map_fn));
 
-  int64 batch_size_index = GetBatchSizeArgIndex(batch_node);
-
   NodeDef* flat_map_node;
 
   AttrValue output_shapes = batch_node.attr().at(kOutputShapesAttr);
@@ -502,9 +538,8 @@ Status AppendFlatMap(const NodeDef& batch_node, int64 num_replicas,
   }
 
   TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(batch_node.name(), ":0"),
-                                    {batch_node.input(batch_size_index)},
-                                    {DT_INT64}, flat_map_fn, output_shapes,
-                                    dtypes, flib, graph, &flat_map_node));
+                                    {}, {}, flat_map_fn, output_shapes, dtypes,
+                                    flib, graph, &flat_map_node));
 
   TF_RETURN_IF_ERROR(
       graph->UpdateFanouts(batch_node.name(), flat_map_node->name()));
@@ -650,7 +685,7 @@ Status ReshapeComponent(int new_batch_dim, const string& arg, DataType dtype,
 
   // shape = tf.shape(arg)
   NodeDef* shape;
-  TF_RETURN_IF_ERROR(AddShapeNode({arg, 0, dtype}, fdef, &shape));
+  TF_RETURN_IF_ERROR(AddShapeNode({arg, 0, dtype}, DT_INT32, fdef, &shape));
 
   // later_dimensions = tf.shape(arg)[1:]
   NodeDef* later_dimensions;
@@ -748,26 +783,6 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_replicas,
     fetch_node = graph_utils::GetInputNode(*fetch_node, *graph, 0);
   }
 
-  // Note: Here, we are conservative with only using the fallback when
-  // the output_shapes attr has the 0th dimension defined for every component.
-  // This because the flat_map_fn will fail if the batch does not divide evenly
-  // because of the use of the "Reshape" op. This ensures that the error is
-  // surfaced correctly.
-  AttrValue output_shapes;
-  if (!fetch_node->attr().contains(kOutputShapesAttr)) {
-    return errors::InvalidArgument(
-        "Cannot use rebatching fallback without output_shapes attr. Node: ",
-        fetch_node->name(), " Op: ", fetch_node->op());
-  } else {
-    output_shapes = fetch_node->attr().at(kOutputShapesAttr);
-  }
-  int batch_dim;
-  TF_RETURN_IF_ERROR(GetBatchDim(output_shapes, &batch_dim));
-  if (batch_dim % num_replicas != 0) {
-    return errors::InvalidArgument(
-        "Cannot use rebatching fallback when batch dimension doesn't divide "
-        "num_replicas evenly.");
-  }
 
   // Create the flat map fn
   FunctionDef flat_map_fn;
@@ -779,9 +794,26 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_replicas,
   DataTypeVector output_types;
   TF_RETURN_IF_ERROR(
       graph_utils::GetDatasetOutputTypesAttr(*fetch_node, &output_types));
-  TF_RETURN_IF_ERROR(CreateFlatMapFnWithReshape(batch_dim / num_replicas,
-                                                output_types, &flat_map_fn));
+  int64 minibatch_dim;
+  // If the batch dimension is known and perfectly divisible by num_replicas,
+  // we use a fallback with `tf.reshape` for better performance.
+  TF_RETURN_IF_ERROR(
+      GetMinibatchDimForReshape(*fetch_node, num_replicas, &minibatch_dim));
+  if (minibatch_dim != -1) {
+    TF_RETURN_IF_ERROR(
+        CreateFlatMapFnWithReshape(minibatch_dim, output_types, &flat_map_fn));
+  } else {
+    TF_RETURN_IF_ERROR(
+        CreateFlatMapFnWithBatch(output_types, num_replicas, &flat_map_fn));
+  }
+
+  AttrValue output_shapes;
+  if (fetch_node->attr().contains(kOutputShapesAttr)) {
+    output_shapes = fetch_node->attr().at(kOutputShapesAttr);
+  } else {
+    SetUnknownShapes(output_types.size(), &output_shapes);
+  }
 
   NodeDef* flat_map_node;
   TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(fetch_node->name(), ":0"),
                                     {}, {}, flat_map_fn, output_shapes,
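Note on the new batch-based fallback: the FunctionDef built by CreateFlatMapFnWithBatch corresponds roughly to the Python sketch below. The function name, the fixed `num_replicas` value, and the final `flat_map` call are illustrative assumptions for this sketch, not part of the patch.

    import tensorflow as tf

    num_replicas = 4  # assumed replica count for the sketch

    def flat_map_fn(*batched_components):
      # The batch size is read from the runtime shape (the Shape + StridedSlice
      # nodes added above) instead of being captured as a function input.
      batch_size = tf.shape(batched_components[0])[0]
      # Ceiling division, as computed by MakeNewBatchSizeNode.
      minibatch_size = (batch_size + num_replicas - 1) // num_replicas
      ds = tf.data.Dataset.from_tensor_slices(batched_components)
      return ds.batch(minibatch_size, drop_remainder=False)

    # The rewritten graph behaves roughly like:
    #   rebatched = batched_dataset.flat_map(flat_map_fn)

Because the batch size is now computed inside the function, the captured `captured_batch_size` argument is no longer needed, which is why AddFlatMapNode is called with empty capture lists above.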
diff --git a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
index 02523c10479..c12d9916041 100644
--- a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
@@ -377,11 +377,8 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertEqual([[None]],
                      [ts.as_list() for ts in _flat_shapes(dataset)])
 
-    # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
-    # the batches of 10 (value == 1) split into minibatches of (5, 5)
-    # [(batch_size, value), ...]
     pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (5, 1), (5, 1), (1, 1),
-             (3, 0), (2, 0), (3, 0), (1, 0), (5, 1), (4, 1)]
+             (3, 0), (2, 0), (2, 0), (2, 0), (5, 1), (4, 1)]
     expected_output = [[value] * batch_size
                        for batch_size, value in pairs]
     self.assertDatasetProduces(dataset, expected_output)
@@ -497,36 +494,39 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase):
   def testWithUnknownBatchDim(self):
     dataset = dataset_ops.Dataset.range(1024).batch(
         32, drop_remainder=False).apply(sleep.sleep(10))
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
 
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "Cannot use rebatching fallback"):
-      rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-      next_element = self.getNext(rebatched_dataset)
-      self.evaluate(next_element())
+    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
+    self.assertDatasetProduces(rebatched_dataset, expected_output)
 
   def testWithUnknownBatchDimInSecondComponent(self):
     dataset0 = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
     dataset1 = dataset_ops.Dataset.range(1024).batch(
         32, drop_remainder=False).apply(sleep.sleep(10))
     dataset = dataset_ops.Dataset.zip((dataset0, dataset1))
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
 
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "Cannot use rebatching fallback"):
-      rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-      next_element = self.getNext(rebatched_dataset)
-      self.evaluate(next_element())
+    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
+    expected_output = [(x, x) for x in expected_output]
+    self.assertDatasetProduces(rebatched_dataset, expected_output)
 
   def testBatchSizeNotDivisibleByNumReplicas(self):
-    # This doesn't work; reshape requires tensor shape to be exactly divisible
-    # by the second dim.
     dataset = dataset_ops.Dataset.range(64).batch(
         32, drop_remainder=True).apply(sleep.sleep(10))
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "Cannot use rebatching fallback"):
-      rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
-      next_element = self.getNext(rebatched_dataset)
-      self.evaluate(next_element())
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
+
+    expected_output = []
+    i = 0
+    for _ in range(2):  # number of steps
+      # first four minibatches have seven elements
+      for _ in range(4):
+        expected_output.append([k for k in range(i, i + 7)])
+        i += 7
+      # last minibatch has four elements
+      expected_output.append([k for k in range(i, i + 4)])
+      i += 4
+    self.assertDatasetProduces(rebatched_dataset, expected_output)
 
   def testBatchSizesDontMatch(self):
     dataset = dataset_ops.Dataset.from_tensors((np.arange(10), np.arange(5)))
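For reference, the arithmetic behind the updated expectations in testBatchSizeNotDivisibleByNumReplicas can be checked with the short standalone sketch below (plain Python, illustrative only): each batch of 32 rebatched for num_replicas=5 yields minibatches of 7, 7, 7, 7, 4, since ceil(32 / 5) = 7 and the remainder is no longer dropped or rejected.

    batch_size, num_replicas = 32, 5
    minibatch_size = (batch_size + num_replicas - 1) // num_replicas  # == 7

    expected_output = []
    i = 0
    for _ in range(2):  # two full batches of 32 in range(64)
      remaining = batch_size
      while remaining > 0:
        take = min(minibatch_size, remaining)
        expected_output.append(list(range(i, i + take)))
        i += take
        remaining -= take

    # Minibatch sizes per original batch: [7, 7, 7, 7, 4]
    assert [len(b) for b in expected_output] == [7, 7, 7, 7, 4] * 2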