diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt index b8455308e5c..d45abf5630e 100644 --- a/tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt @@ -8,9 +8,9 @@ A variant tensor representing the input dataset. END } in_arg { - name: "num_workers" + name: "num_replicas" description: <parameter_map().at("num_workers").i(); + num_replicas_ = config->parameter_map().at("num_replicas").i(); use_fallback_ = config->parameter_map().at("use_fallback").b(); return Status::OK(); } @@ -307,14 +307,14 @@ Status GetBatchDim(AttrValue output_shapes, int* batch_dim) { return Status::OK(); } -Status UpdateOutputShapes(const string& node_name, int64 num_workers, +Status UpdateOutputShapes(const string& node_name, int64 num_replicas, MutableGraphView* graph) { NodeDef* node = graph->GetNode(node_name); if (node->attr().contains(kOutputShapesAttr)) { AttrValue output_shapes = node->attr().at(kOutputShapesAttr); for (auto& shape : *output_shapes.mutable_list()->mutable_shape()) { if (!shape.unknown_rank() && shape.dim(0).size() != -1) { - shape.mutable_dim(0)->set_size(shape.dim(0).size() / num_workers); + shape.mutable_dim(0)->set_size(shape.dim(0).size() / num_replicas); } } (*node->mutable_attr())[kOutputShapesAttr] = output_shapes; @@ -335,16 +335,16 @@ int64 GetBatchSizeArgIndex(const NodeDef& batch_node) { } Status MakeNewBatchSizeNode(const string& global_batch_size_name, - int64 num_workers, FunctionDef* fdef, + int64 num_replicas, FunctionDef* fdef, NodeDef** result) { NodeDef* one_node; TF_RETURN_IF_ERROR(AddConstInt64Node(1, fdef, &one_node)); - NodeDef* num_workers_node; - TF_RETURN_IF_ERROR(AddConstInt64Node(num_workers, fdef, &num_workers_node)); + NodeDef* num_replicas_node; + TF_RETURN_IF_ERROR(AddConstInt64Node(num_replicas, fdef, &num_replicas_node)); NodeDef* numerator_node = AddBinaryNode(global_batch_size_name, - strings::StrCat(num_workers_node->name(), ":output:0"), + strings::StrCat(num_replicas_node->name(), ":output:0"), kAddOp, DT_INT64, fdef); numerator_node = AddBinaryNode( strings::StrCat(numerator_node->name(), ":z:0"), @@ -352,14 +352,14 @@ Status MakeNewBatchSizeNode(const string& global_batch_size_name, *result = AddBinaryNode(strings::StrCat(numerator_node->name(), ":z:0"), - strings::StrCat(num_workers_node->name(), ":output:0"), + strings::StrCat(num_replicas_node->name(), ":output:0"), kTruncateDivOp, DT_INT64, fdef); return Status::OK(); } // Given a "batch" dataset node, we replace the `batch_size` input with a new -// input that corresponds to the original input divided by `num_workers`. -Status MutateBatchSize(const NodeDef& node, int64 num_workers, +// input that corresponds to the original input divided by `num_replicas`. +Status MutateBatchSize(const NodeDef& node, int64 num_replicas, MutableGraphView* graph) { // For all the batching datasets the batch_size is input number 1 except for // MapAndBatchDataset. @@ -369,8 +369,8 @@ Status MutateBatchSize(const NodeDef& node, int64 num_workers, int64 batch_size; TF_RETURN_IF_ERROR( graph_utils::GetScalarConstNodeValue(*batch_size_node, &batch_size)); - DCHECK_EQ(batch_size % num_workers, 0); - batch_size = batch_size / num_workers; + DCHECK_EQ(batch_size % num_replicas, 0); + batch_size = batch_size / num_replicas; NodeDef* new_batch_size_node = graph_utils::AddScalarConstNode(batch_size, graph); // We don't call UpdateFanouts here because CSE elimination might lead to @@ -413,8 +413,8 @@ Status AddFlatMapNode(const string& input_dataset, // def flat_map_fn(*batched_components): // ds = tf.data.Dataset.from_tensor_slices(batched_components) // return ds.batch(minibatch_size, drop_remainder=False) -Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes, int64 num_workers, - FunctionDef* result) { +Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes, + int64 num_replicas, FunctionDef* result) { NodeDef* tensor_slice_node = result->add_node_def(); tensor_slice_node->set_op("TensorSliceDataset"); for (int i = 0; i < dtypes.size(); ++i) { @@ -445,7 +445,7 @@ Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes, int64 num_workers, function_utils::AddFunctionInput("captured_batch_size", result, DT_INT64); NodeDef* new_batch_size; TF_RETURN_IF_ERROR(MakeNewBatchSizeNode( - original_batch_size->name(), num_workers, result, &new_batch_size)); + original_batch_size->name(), num_replicas, result, &new_batch_size)); batch_node->add_input(strings::StrCat(new_batch_size->name(), ":z:0")); // `drop_remainder` input @@ -470,9 +470,9 @@ Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes, int64 num_workers, // in a step adds up to the global batch size. However, since this adds // additional data copies (both from_tensor_slices and batch), we only use // this approach when necessary, i.e. when we need to drop remainder on the -// global batch, or when the global batch size does not divide num_workers +// global batch, or when the global batch size does not divide num_replicas // evenly. -Status AppendFlatMap(const NodeDef& batch_node, int64 num_workers, +Status AppendFlatMap(const NodeDef& batch_node, int64 num_replicas, FunctionLibraryDefinition* flib, MutableGraphView* graph) { // `.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x). // batch(minibatch_size, drop_remainder=False))` @@ -484,7 +484,7 @@ Status AppendFlatMap(const NodeDef& batch_node, int64 num_workers, TF_RETURN_IF_ERROR( graph_utils::GetDatasetOutputTypesAttr(batch_node, &dtypes)); TF_RETURN_IF_ERROR( - CreateFlatMapFnWithBatch(dtypes, num_workers, &flat_map_fn)); + CreateFlatMapFnWithBatch(dtypes, num_replicas, &flat_map_fn)); int64 batch_size_index = GetBatchSizeArgIndex(batch_node); @@ -496,7 +496,7 @@ Status AppendFlatMap(const NodeDef& batch_node, int64 num_workers, // Because the flat map function uses drop_remainder = False, // the shape might be unknown auto old_dim = shape.dim(0).size(); - auto new_dim = old_dim % num_workers == 0 ? old_dim / num_workers : -1; + auto new_dim = old_dim % num_replicas == 0 ? old_dim / num_replicas : -1; shape.mutable_dim(0)->set_size(new_dim); } } @@ -514,12 +514,13 @@ Status AppendFlatMap(const NodeDef& batch_node, int64 num_workers, // There are several things we do here, depending on the values of // batch_size and drop_remainder. -// (1) If batch size is known and divisible by num_workers, and drop_remainder +// (1) If batch size is known and divisible by num_replicas, and drop_remainder // is known to be False, we mutate the batch size directly. -// .batch(global_batch_size) -> .batch(global_batch_size // num_workers) +// .batch(global_batch_size) -> .batch(global_batch_size // num_replicas) // (2) Otherwise, we add a flat_map transformation to preserve the global batch -// size across the workers and to preserve the drop remainder behavior. -bool ShouldMutateBatchSizeDirectly(const NodeDef& batch_node, int64 num_workers, +// size across the replicas and to preserve the drop remainder behavior. +bool ShouldMutateBatchSizeDirectly(const NodeDef& batch_node, + int64 num_replicas, MutableGraphView* graph) { int64 batch_size_arg_index = GetBatchSizeArgIndex(batch_node); NodeDef* batch_size_node = @@ -528,9 +529,9 @@ bool ShouldMutateBatchSizeDirectly(const NodeDef& batch_node, int64 num_workers, int64 batch_size; Status s = graph_utils::GetScalarConstNodeValue(*batch_size_node, &batch_size); - // If batch size is unknown or indivisible by num workers, we don't + // If batch size is unknown or indivisible by num replicas, we don't // mutate it directly - if (!s.ok() || batch_size % num_workers != 0) return false; + if (!s.ok() || batch_size % num_replicas != 0) return false; if (batch_node.op() == kBatchOp || batch_node.op() == kPaddedBatchOp) { // These ops don't have a `drop_remainder` input, and behave like @@ -547,16 +548,16 @@ bool ShouldMutateBatchSizeDirectly(const NodeDef& batch_node, int64 num_workers, return s.ok() && !drop_remainder; } -Status RewriteBatchNode(const NodeDef& batch_node, int64 num_workers, +Status RewriteBatchNode(const NodeDef& batch_node, int64 num_replicas, FunctionLibraryDefinition* flib, MutableGraphView* graph) { - if (ShouldMutateBatchSizeDirectly(batch_node, num_workers, graph)) { - return MutateBatchSize(batch_node, num_workers, graph); + if (ShouldMutateBatchSizeDirectly(batch_node, num_replicas, graph)) { + return MutateBatchSize(batch_node, num_replicas, graph); } - return AppendFlatMap(batch_node, num_workers, flib, graph); + return AppendFlatMap(batch_node, num_replicas, flib, graph); } -Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, +Status OptimizeGraph(const GrapplerItem& item, int64 num_replicas, bool use_fallback, GraphDef* output); // Helper function that starts from a node in the graph and recurses into its @@ -567,16 +568,16 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, // as they are datasets themselves. // 3. Core dataset ops + Identity op: Recurses into first input parameter. // 4. FlatMap type mapping dataset ops: Recurses into the function definition. -Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, +Status RecursivelyHandleOp(const NodeDef& node, int64 num_replicas, bool use_fallback, FunctionLibraryDefinition* flib, MutableGraphView* graph) { if (IsDatasetNodeOfType(node, kBatchDatasetOps)) { - TF_RETURN_IF_ERROR(RewriteBatchNode(node, num_workers, flib, graph)); + TF_RETURN_IF_ERROR(RewriteBatchNode(node, num_replicas, flib, graph)); } else if (IsDatasetNodeOfType(node, kMultipleInputsDatasetOps)) { // For all multiple input datasets, all inputs are datasets themselves. for (int i = 0; i < node.input_size(); ++i) { NodeDef* input_node = graph_utils::GetInputNode(node, *graph, i); - TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_workers, + TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_replicas, use_fallback, flib, graph)); } } else if (IsDatasetNodeOfType(node, kPassThroughOps) || IsRetval(node)) { @@ -584,7 +585,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, // function body graph in place of function outputs, the input dataset is // input 0. NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0); - TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_workers, + TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_replicas, use_fallback, flib, graph)); } else if (IsDatasetNodeOfType(node, kFuncDatasetOps)) { const string func_name = @@ -594,7 +595,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem( *fdef, *flib, graph->graph()->versions().producer(), &f_item)); GraphDef optimized_func_graph; - TF_RETURN_IF_ERROR(OptimizeGraph(f_item, num_workers, use_fallback, + TF_RETURN_IF_ERROR(OptimizeGraph(f_item, num_replicas, use_fallback, &optimized_func_graph)); // Function body optimization might have created new specialized @@ -623,7 +624,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, } // If we've successfully updated the batch size of this node or any nodes // in the dataset tree rooted in this node, we update the output_shapes attr. - TF_RETURN_IF_ERROR(UpdateOutputShapes(node.name(), num_workers, graph)); + TF_RETURN_IF_ERROR(UpdateOutputShapes(node.name(), num_replicas, graph)); return Status::OK(); } @@ -689,7 +690,7 @@ Status CreateFlatMapFnWithReshape(int new_batch_dim, // For each component of the dataset, we reshape it from shape // (old_batch_size, ...) to (-1, new_batch_size, ...) - // where new_batch_size = (old_batch_size + num_workers - 1) // num_workers + // where new_batch_size = (old_batch_size + num_replicas - 1) // num_replicas for (int i = 0; i < types.size(); ++i) { auto* input_arg = function_utils::AddFunctionInput( strings::StrCat("args_", i), result, types.at(i)); @@ -733,13 +734,13 @@ Status CreateFlatMapFnWithReshape(int new_batch_dim, // return tf.data.Dataset.from_tensor_slices( // tf.reshape( // x, -// tf.concat([[-1, old_batch_dim / num_workers], tf.shape(x)[1:]], 0) +// tf.concat([[-1, old_batch_dim / num_replicas], tf.shape(x)[1:]], 0) // ) // ) // // dataset = dataset.flat_map(fn) // ``` -Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers, +Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_replicas, FunctionLibraryDefinition* flib, MutableGraphView* graph) { if (IsRetval(*fetch_node) || fetch_node->op() == kIdentityOp) { @@ -762,10 +763,10 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers, } int batch_dim; TF_RETURN_IF_ERROR(GetBatchDim(output_shapes, &batch_dim)); - if (batch_dim % num_workers != 0) { + if (batch_dim % num_replicas != 0) { return errors::InvalidArgument( "Cannot use rebatching fallback when batch dimension doesn't divide " - "num_workers evenly."); + "num_replicas evenly."); } // Create the flat map fn @@ -778,7 +779,7 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers, DataTypeVector output_types; TF_RETURN_IF_ERROR( graph_utils::GetDatasetOutputTypesAttr(*fetch_node, &output_types)); - TF_RETURN_IF_ERROR(CreateFlatMapFnWithReshape(batch_dim / num_workers, + TF_RETURN_IF_ERROR(CreateFlatMapFnWithReshape(batch_dim / num_replicas, output_types, &flat_map_fn)); NodeDef* flat_map_node; @@ -786,7 +787,7 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers, {}, {}, flat_map_fn, output_shapes, output_types, flib, graph, &flat_map_node)); TF_RETURN_IF_ERROR( - UpdateOutputShapes(flat_map_node->name(), num_workers, graph)); + UpdateOutputShapes(flat_map_node->name(), num_replicas, graph)); TF_RETURN_IF_ERROR( graph->UpdateFanouts(fetch_node->name(), flat_map_node->name())); @@ -797,7 +798,7 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers, // Helper function that given a GrapplerItem generates a mutated graph def // with the batch size changed. The GrapplerItem could be generated from the // main graph or could be a function graph. -Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, +Status OptimizeGraph(const GrapplerItem& item, int64 num_replicas, bool use_fallback, GraphDef* output) { *output = item.graph; MutableGraphView graph(output); @@ -807,8 +808,8 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, NodeDef* sink_node; TF_RETURN_IF_ERROR(graph_utils::GetFetchNode(graph, item, &sink_node)); - Status s = - RecursivelyHandleOp(*sink_node, num_workers, use_fallback, &flib, &graph); + Status s = RecursivelyHandleOp(*sink_node, num_replicas, use_fallback, &flib, + &graph); if (!s.ok()) { if (use_fallback) { VLOG(1) << "Failed to rebatch by rewriting the batch transformation (" @@ -818,7 +819,7 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, *output = item.graph; graph = MutableGraphView(output); TF_RETURN_IF_ERROR( - RebatchWithFallback(sink_node, num_workers, &flib, &graph)); + RebatchWithFallback(sink_node, num_replicas, &flib, &graph)); } else { // Return the error return s; @@ -837,7 +838,7 @@ Status RebatchOptimizer::OptimizeAndCollectStats(Cluster* cluster, *output = item.graph; MutableGraphView graph(output); - TF_RETURN_IF_ERROR(OptimizeGraph(item, num_workers_, use_fallback_, output)); + TF_RETURN_IF_ERROR(OptimizeGraph(item, num_replicas_, use_fallback_, output)); stats->num_changes++; return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/data/rebatch.h b/tensorflow/core/grappler/optimizers/data/rebatch.h index 75c965824cc..028e69006e6 100644 --- a/tensorflow/core/grappler/optimizers/data/rebatch.h +++ b/tensorflow/core/grappler/optimizers/data/rebatch.h @@ -23,7 +23,7 @@ namespace tensorflow { namespace grappler { // This optimizer changes the batch size of the output dataset by dividing the -// current batch size by parameter `num_workers`. Currently, this works only +// current batch size by parameter `num_replicas`. Currently, this works only // for very simple pipelines with a single BatchDatasetV2 transformation. class RebatchOptimizer : public TFDataOptimizerBase { public: @@ -43,7 +43,7 @@ class RebatchOptimizer : public TFDataOptimizerBase { const GraphDef& optimize_output, double result) override; private: - int64 num_workers_; + int64 num_replicas_; bool use_fallback_; }; diff --git a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc index 2cc1bec447a..13d01254155 100644 --- a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc @@ -36,14 +36,15 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - int64 num_workers; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_workers", &num_workers)); + int64 num_replicas; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "num_replicas", &num_replicas)); OP_REQUIRES( - ctx, num_workers > 0, - errors::InvalidArgument("num_workers must be greater than zero.")); + ctx, num_replicas > 0, + errors::InvalidArgument("num_replicas must be greater than zero.")); - auto config_factory = [num_workers, this]() { - return CreateConfig(num_workers, this->use_fallback_); + auto config_factory = [num_replicas, this]() { + return CreateConfig(num_replicas, this->use_fallback_); }; // We only want to optimize functions for some particular datasets like @@ -56,17 +57,17 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { } private: - static RewriterConfig CreateConfig(int64 num_workers, bool use_fallback) { + static RewriterConfig CreateConfig(int64 num_replicas, bool use_fallback) { RewriterConfig rewriter_config; rewriter_config.set_fail_on_optimizer_errors(true); rewriter_config.add_optimizers(kOptimizerName); rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE); auto custom_optimizer = rewriter_config.add_custom_optimizers(); custom_optimizer->set_name(kOptimizerName); - AttrValue num_workers_attr; - num_workers_attr.set_i(num_workers); - (*custom_optimizer->mutable_parameter_map())["num_workers"] = - num_workers_attr; + AttrValue num_replicas_attr; + num_replicas_attr.set_i(num_replicas); + (*custom_optimizer->mutable_parameter_map())["num_replicas"] = + num_replicas_attr; AttrValue use_fallback_attr; use_fallback_attr.set_b(use_fallback); (*custom_optimizer->mutable_parameter_map())["use_fallback"] = diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index 5504f5e577b..68823c8b8c0 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -658,7 +658,7 @@ REGISTER_OP("RandomDataset") REGISTER_OP("ExperimentalRebatchDataset") .Input("input_dataset: variant") - .Input("num_workers: int64") + .Input("num_replicas: int64") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") @@ -667,7 +667,7 @@ REGISTER_OP("ExperimentalRebatchDataset") REGISTER_OP("RebatchDataset") .Input("input_dataset: variant") - .Input("num_workers: int64") + .Input("num_replicas: int64") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") diff --git a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py index 09eac5dda50..02523c10479 100644 --- a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py @@ -58,7 +58,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def testBasic(self, drop_remainder): dataset = dataset_ops.Dataset.range(1024).batch( 32, drop_remainder=drop_remainder) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[8] if drop_remainder else [None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) @@ -67,15 +67,15 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def testScalarInputError(self): dataset = dataset_ops.Dataset.range(1024) - distribute._RebatchDataset(dataset.batch(4), num_workers=4) + distribute._RebatchDataset(dataset.batch(4), num_replicas=4) with self.assertRaisesRegexp(ValueError, "at least one dimension"): - distribute._RebatchDataset(dataset, num_workers=4) + distribute._RebatchDataset(dataset, num_replicas=4) @parameterized.named_parameters(drop_remainder_cases) - def testBatchNotDivisibleByNumWorkers(self, drop_remainder): + def testBatchNotDivisibleByNumReplicas(self, drop_remainder): dataset = dataset_ops.Dataset.range(1024).batch( 32, drop_remainder=drop_remainder) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [] @@ -92,7 +92,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def testTupleOutput(self): dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) expected_output = [([k for k in range(i, i + 8)], # pylint: disable=g-complex-comprehension [k for k in range(i, i + 8)]) for i in range(0, 1024, 8)] @@ -101,7 +101,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def testNestedDictionaryOutput(self): dataset = dataset_ops.Dataset.range(1024).map( lambda x: {"a": x, "b": {"c": x}}).batch(32) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) expected_output = [{"a": [k for k in range(i, i + 8)], # pylint: disable=g-complex-comprehension "b": {"c": [k for k in range(i, i + 8)]}} for i in range(0, 1024, 8)] @@ -111,7 +111,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def testFinalPartialBatch(self, drop_remainder): dataset = dataset_ops.Dataset.range(1032).batch( 32, drop_remainder=drop_remainder) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[8] if drop_remainder else [None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) @@ -126,7 +126,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def testFinalPartialBatchAfterRebatch(self, drop_remainder): dataset = dataset_ops.Dataset.range(34).batch( 32, drop_remainder=drop_remainder) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[8] if drop_remainder else [None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) @@ -158,7 +158,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def testMapAndBatch(self): dataset = dataset_ops.Dataset.range(1024).apply( batching.map_and_batch(math_ops.square, 32)) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [[k**2 for k in range(i, i + 8)] # pylint: disable=g-complex-comprehension @@ -169,7 +169,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): captured_t = variables.Variable(42) dataset = dataset_ops.Dataset.range(1024).apply( batching.map_and_batch(lambda x: captured_t, 32)) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [[42 for _ in range(i, i + 8)] # pylint: disable=g-complex-comprehension @@ -182,7 +182,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset_ops.Dataset.range(128).batch( 4, drop_remainder=True).padded_batch( 8, padded_shapes=[5]) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) # Each element is a list of 8 elements in which each element is a list of 5 # elements, first four are numbers and the last one is a padded zero. expected_output = [[[j, j + 1, j + 2, j + 3, 0] # pylint: disable=g-complex-comprehension @@ -202,7 +202,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset1 = dataset_ops.Dataset.range(64).batch(8) dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset = dataset1.concatenate(dataset2) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = ([[i, i + 1] for i in range(0, 64, 2)] + @@ -213,7 +213,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset1 = dataset_ops.Dataset.range(64).batch(16) dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset = dataset1.concatenate(dataset2) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual( [[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) @@ -225,7 +225,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset1 = dataset_ops.Dataset.range(64).batch(8) dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None], [None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)] @@ -235,7 +235,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset1 = dataset_ops.Dataset.range(64).batch(16) dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None], [None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [([2 * i, 2 * i + 1, 2 * i + 2, 2 * i + 3], [i, i + 1]) @@ -246,7 +246,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset_ops.Dataset.range(1024).batch(32).apply(sleep.sleep(10)) with self.assertRaises(errors.InvalidArgumentError): rebatched_dataset = distribute._RebatchDataset( - dataset, num_workers=4, use_fallback=False) + dataset, num_replicas=4, use_fallback=False) next_element = self.getNext(rebatched_dataset) self.evaluate(next_element()) @@ -256,7 +256,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): 32).apply(sleep.sleep(10))) with self.assertRaises(errors.InvalidArgumentError): rebatched_dataset = distribute._RebatchDataset( - dataset, num_workers=4, use_fallback=False) + dataset, num_replicas=4, use_fallback=False) next_element = self.getNext(rebatched_dataset) self.evaluate(next_element()) @@ -268,7 +268,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(dataset, expected_output) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # Two elements where each element is a list of 4 elements where each element @@ -287,7 +287,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(dataset, expected_output) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # List of 4 elements where each element is a list of 8 numbering from 0 to @@ -307,7 +307,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(dataset, expected_output) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # List of 4 elements where each element is a list of 8 numbering from 0 to @@ -325,7 +325,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset.apply( grouping.group_by_window( key_func=lambda x: x[0] % 4, reduce_func=reduce_fn, window_size=10)) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=2) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=2) self.assertEqual([[None, 3]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) @@ -348,7 +348,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset.apply( grouping.group_by_window( key_func=lambda x: x, reduce_func=reduce_fn, window_size=10)) - dataset = distribute._RebatchDataset(dataset, num_workers=2) + dataset = distribute._RebatchDataset(dataset, num_replicas=2) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) @@ -373,7 +373,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset.apply( grouping.group_by_window( key_func=lambda x: x, reduce_func=reduce_fn, window_size=11)) - dataset = distribute._RebatchDataset(dataset, num_workers=2) + dataset = distribute._RebatchDataset(dataset, num_replicas=2) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) @@ -398,7 +398,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset.apply( grouping.group_by_window( key_func=lambda x: x, reduce_func=reduce_fn, window_size=11)) - dataset = distribute._RebatchDataset(dataset, num_workers=2) + dataset = distribute._RebatchDataset(dataset, num_replicas=2) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) @@ -412,7 +412,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def testScanAfterBatch(self): dataset = dataset_ops.Dataset.range(40).batch(10).apply( scan_ops.scan(np.int64(2), lambda state, value: (state, value * state))) - dataset = distribute._RebatchDataset(dataset, num_workers=2) + dataset = distribute._RebatchDataset(dataset, num_replicas=2) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) @@ -442,7 +442,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): num_epochs=1, drop_final_batch=False) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) @@ -459,7 +459,7 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase): def testWithNoBatchDataset(self): dataset = dataset_ops.Dataset.from_tensor_slices( [[k for k in range(i, i + 32)] for i in range(0, 1024, 32)]) # pylint: disable=g-complex-comprehension - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)]) self.assertEqual([[8]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) @@ -470,7 +470,7 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase): def testWithUnhandledTransformation(self): dataset = dataset_ops.Dataset.range(1024).batch( 32, drop_remainder=True).apply(sleep.sleep(10)) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)]) self.assertEqual([[8]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) @@ -482,7 +482,7 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase): dataset = dataset_ops.Dataset.range(2).flat_map( lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda 32, drop_remainder=True).apply(sleep.sleep(10))) - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[8]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) @@ -500,7 +500,7 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "Cannot use rebatching fallback"): - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) next_element = self.getNext(rebatched_dataset) self.evaluate(next_element()) @@ -512,11 +512,11 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "Cannot use rebatching fallback"): - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) next_element = self.getNext(rebatched_dataset) self.evaluate(next_element()) - def testBatchSizeIndivisibleByNumWorkers(self): + def testBatchSizeNotDivisibleByNumReplicas(self): # This doesn't work; reshape requires tensor shape to be exactly divisible # by the second dim. dataset = dataset_ops.Dataset.range(64).batch( @@ -524,7 +524,7 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "Cannot use rebatching fallback"): - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5) next_element = self.getNext(rebatched_dataset) self.evaluate(next_element()) @@ -532,7 +532,7 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase): dataset = dataset_ops.Dataset.from_tensors((np.arange(10), np.arange(5))) with self.assertRaisesRegexp(errors.InvalidArgumentError, "Cannot use rebatching fallback"): - rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5) + rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5) next_element = self.getNext(rebatched_dataset) self.evaluate(next_element()) diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py index 1f868a8eee2..0ae26927ca5 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py @@ -32,7 +32,7 @@ class RebatchDatasetSerializationTest( return distribute._RebatchDataset( dataset_ops.Dataset.range(num_elements).batch( 4 * batch_size, drop_remainder=True), - num_workers=4) + num_replicas=4) self.run_core_tests(lambda: build_dataset(200, 10), 20) diff --git a/tensorflow/python/data/experimental/ops/distribute.py b/tensorflow/python/data/experimental/ops/distribute.py index 368ba95fd34..9bbd3ef4441 100644 --- a/tensorflow/python/data/experimental/ops/distribute.py +++ b/tensorflow/python/data/experimental/ops/distribute.py @@ -67,28 +67,28 @@ def _AutoShardDatasetV1(input_dataset, num_workers, index): # pylint: disable=i class _RebatchDataset(dataset_ops.UnaryDataset): - """A `Dataset` that divides the batch size by `num_workers`. + """A `Dataset` that divides the batch size by `num_replicas`. For each batch in the input dataset, the resulting dataset will produce `num_replicas` minibatches whose sizes add up to the original batch size. """ - def __init__(self, input_dataset, num_workers, use_fallback=True): + def __init__(self, input_dataset, num_replicas, use_fallback=True): self._input_dataset = input_dataset def recalculate_output_shapes(output_shapes): - """Recalculates the output_shapes after dividing it by num_workers.""" + """Recalculates the output_shapes after dividing it by num_replicas.""" if len(output_shapes) < 1: raise ValueError( "Input shape should have at least one dimension. " "Perhaps your input dataset is not batched?") output_dims = [d.value for d in output_shapes.dims] - if output_dims[0] is not None and output_dims[0] % num_workers == 0: - output_dims[0] = output_dims[0] // num_workers + if output_dims[0] is not None and output_dims[0] % num_replicas == 0: + output_dims[0] = output_dims[0] // num_replicas else: # Set the batch dimension to unknown. If the global batch size does not - # divide num_workers evenly, the minibatches may have different sizes. + # divide num_replicas evenly, the minibatches may have different sizes. output_dims[0] = None return tensor_shape.TensorShape(output_dims) @@ -102,13 +102,13 @@ class _RebatchDataset(dataset_ops.UnaryDataset): if compat.forward_compatible(2019, 8, 13) or not use_fallback: variant_tensor = ged_ops.rebatch_dataset( self._input_dataset._variant_tensor, # pylint: disable=protected-access - num_workers=num_workers, + num_replicas=num_replicas, use_fallback=use_fallback, **self._flat_structure) else: variant_tensor = ged_ops.rebatch_dataset( self._input_dataset._variant_tensor, # pylint: disable=protected-access - num_workers=num_workers, + num_replicas=num_replicas, **self._flat_structure) super(_RebatchDataset, self).__init__(input_dataset, variant_tensor) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 8c382aa0f03..76b3ccdf193 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1286,7 +1286,7 @@ tf_module { } member_method { name: "ExperimentalRebatchDataset" - argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + argspec: "args=[\'input_dataset\', \'num_replicas\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " } member_method { name: "ExperimentalScanDataset" @@ -3010,7 +3010,7 @@ tf_module { } member_method { name: "RebatchDataset" - argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + argspec: "args=[\'input_dataset\', \'num_replicas\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " } member_method { name: "Reciprocal" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 8c382aa0f03..76b3ccdf193 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1286,7 +1286,7 @@ tf_module { } member_method { name: "ExperimentalRebatchDataset" - argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + argspec: "args=[\'input_dataset\', \'num_replicas\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " } member_method { name: "ExperimentalScanDataset" @@ -3010,7 +3010,7 @@ tf_module { } member_method { name: "RebatchDataset" - argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + argspec: "args=[\'input_dataset\', \'num_replicas\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " } member_method { name: "Reciprocal"