[tf.data] s/workers/replicas in all rebatching related files for consistency with distribution strategy naming conventions (https://github.com/tensorflow/community/blob/master/rfcs/20181016-replicator.md).
PiperOrigin-RevId: 261958155
This commit is contained in:
parent
71d73e56a2
commit
6208021e3d
@ -8,9 +8,9 @@ A variant tensor representing the input dataset.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "num_workers"
|
||||
name: "num_replicas"
|
||||
description: <<END
|
||||
A scalar representing the number of workers to distribute this batch across. As
|
||||
A scalar representing the number of replicas to distribute this batch across. As
|
||||
a result of this transformation the current batch size would end up being
|
||||
divided by this parameter.
|
||||
END
|
||||
@ -18,6 +18,6 @@ END
|
||||
summary: "Creates a dataset that changes the batch size."
|
||||
description: <<END
|
||||
Creates a dataset that changes the batch size of the dataset to current batch
|
||||
size // num_workers.
|
||||
size // num_replicas.
|
||||
END
|
||||
}
|
||||
|
@ -8,9 +8,9 @@ A variant tensor representing the input dataset.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "num_workers"
|
||||
name: "num_replicas"
|
||||
description: <<END
|
||||
A scalar representing the number of workers to distribute this batch across. As
|
||||
A scalar representing the number of replicas to distribute this batch across. As
|
||||
a result of this transformation the current batch size would end up being
|
||||
divided by this parameter.
|
||||
END
|
||||
|
@ -39,7 +39,7 @@ Status RebatchOptimizer::Init(
|
||||
return errors::InvalidArgument(
|
||||
"Cannot initialize RebatchOptimizer without config.");
|
||||
|
||||
num_workers_ = config->parameter_map().at("num_workers").i();
|
||||
num_replicas_ = config->parameter_map().at("num_replicas").i();
|
||||
use_fallback_ = config->parameter_map().at("use_fallback").b();
|
||||
return Status::OK();
|
||||
}
|
||||
@ -307,14 +307,14 @@ Status GetBatchDim(AttrValue output_shapes, int* batch_dim) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status UpdateOutputShapes(const string& node_name, int64 num_workers,
|
||||
Status UpdateOutputShapes(const string& node_name, int64 num_replicas,
|
||||
MutableGraphView* graph) {
|
||||
NodeDef* node = graph->GetNode(node_name);
|
||||
if (node->attr().contains(kOutputShapesAttr)) {
|
||||
AttrValue output_shapes = node->attr().at(kOutputShapesAttr);
|
||||
for (auto& shape : *output_shapes.mutable_list()->mutable_shape()) {
|
||||
if (!shape.unknown_rank() && shape.dim(0).size() != -1) {
|
||||
shape.mutable_dim(0)->set_size(shape.dim(0).size() / num_workers);
|
||||
shape.mutable_dim(0)->set_size(shape.dim(0).size() / num_replicas);
|
||||
}
|
||||
}
|
||||
(*node->mutable_attr())[kOutputShapesAttr] = output_shapes;
|
||||
@ -335,16 +335,16 @@ int64 GetBatchSizeArgIndex(const NodeDef& batch_node) {
|
||||
}
|
||||
|
||||
Status MakeNewBatchSizeNode(const string& global_batch_size_name,
|
||||
int64 num_workers, FunctionDef* fdef,
|
||||
int64 num_replicas, FunctionDef* fdef,
|
||||
NodeDef** result) {
|
||||
NodeDef* one_node;
|
||||
TF_RETURN_IF_ERROR(AddConstInt64Node(1, fdef, &one_node));
|
||||
NodeDef* num_workers_node;
|
||||
TF_RETURN_IF_ERROR(AddConstInt64Node(num_workers, fdef, &num_workers_node));
|
||||
NodeDef* num_replicas_node;
|
||||
TF_RETURN_IF_ERROR(AddConstInt64Node(num_replicas, fdef, &num_replicas_node));
|
||||
|
||||
NodeDef* numerator_node =
|
||||
AddBinaryNode(global_batch_size_name,
|
||||
strings::StrCat(num_workers_node->name(), ":output:0"),
|
||||
strings::StrCat(num_replicas_node->name(), ":output:0"),
|
||||
kAddOp, DT_INT64, fdef);
|
||||
numerator_node = AddBinaryNode(
|
||||
strings::StrCat(numerator_node->name(), ":z:0"),
|
||||
@ -352,14 +352,14 @@ Status MakeNewBatchSizeNode(const string& global_batch_size_name,
|
||||
|
||||
*result =
|
||||
AddBinaryNode(strings::StrCat(numerator_node->name(), ":z:0"),
|
||||
strings::StrCat(num_workers_node->name(), ":output:0"),
|
||||
strings::StrCat(num_replicas_node->name(), ":output:0"),
|
||||
kTruncateDivOp, DT_INT64, fdef);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Given a "batch" dataset node, we replace the `batch_size` input with a new
|
||||
// input that corresponds to the original input divided by `num_workers`.
|
||||
Status MutateBatchSize(const NodeDef& node, int64 num_workers,
|
||||
// input that corresponds to the original input divided by `num_replicas`.
|
||||
Status MutateBatchSize(const NodeDef& node, int64 num_replicas,
|
||||
MutableGraphView* graph) {
|
||||
// For all the batching datasets the batch_size is input number 1 except for
|
||||
// MapAndBatchDataset.
|
||||
@ -369,8 +369,8 @@ Status MutateBatchSize(const NodeDef& node, int64 num_workers,
|
||||
int64 batch_size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
graph_utils::GetScalarConstNodeValue(*batch_size_node, &batch_size));
|
||||
DCHECK_EQ(batch_size % num_workers, 0);
|
||||
batch_size = batch_size / num_workers;
|
||||
DCHECK_EQ(batch_size % num_replicas, 0);
|
||||
batch_size = batch_size / num_replicas;
|
||||
NodeDef* new_batch_size_node =
|
||||
graph_utils::AddScalarConstNode<int64>(batch_size, graph);
|
||||
// We don't call UpdateFanouts here because CSE elimination might lead to
|
||||
@ -413,8 +413,8 @@ Status AddFlatMapNode(const string& input_dataset,
|
||||
// def flat_map_fn(*batched_components):
|
||||
// ds = tf.data.Dataset.from_tensor_slices(batched_components)
|
||||
// return ds.batch(minibatch_size, drop_remainder=False)
|
||||
Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes, int64 num_workers,
|
||||
FunctionDef* result) {
|
||||
Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes,
|
||||
int64 num_replicas, FunctionDef* result) {
|
||||
NodeDef* tensor_slice_node = result->add_node_def();
|
||||
tensor_slice_node->set_op("TensorSliceDataset");
|
||||
for (int i = 0; i < dtypes.size(); ++i) {
|
||||
@ -445,7 +445,7 @@ Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes, int64 num_workers,
|
||||
function_utils::AddFunctionInput("captured_batch_size", result, DT_INT64);
|
||||
NodeDef* new_batch_size;
|
||||
TF_RETURN_IF_ERROR(MakeNewBatchSizeNode(
|
||||
original_batch_size->name(), num_workers, result, &new_batch_size));
|
||||
original_batch_size->name(), num_replicas, result, &new_batch_size));
|
||||
batch_node->add_input(strings::StrCat(new_batch_size->name(), ":z:0"));
|
||||
|
||||
// `drop_remainder` input
|
||||
@ -470,9 +470,9 @@ Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes, int64 num_workers,
|
||||
// in a step adds up to the global batch size. However, since this adds
|
||||
// additional data copies (both from_tensor_slices and batch), we only use
|
||||
// this approach when necessary, i.e. when we need to drop remainder on the
|
||||
// global batch, or when num_workers does not divide the global batch size
|
||||
// global batch, or when num_replicas does not divide the global batch size
|
||||
// evenly.
|
||||
Status AppendFlatMap(const NodeDef& batch_node, int64 num_workers,
|
||||
Status AppendFlatMap(const NodeDef& batch_node, int64 num_replicas,
|
||||
FunctionLibraryDefinition* flib, MutableGraphView* graph) {
|
||||
// `.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x).
|
||||
// batch(minibatch_size, drop_remainder=False))`
|
||||
@ -484,7 +484,7 @@ Status AppendFlatMap(const NodeDef& batch_node, int64 num_workers,
|
||||
TF_RETURN_IF_ERROR(
|
||||
graph_utils::GetDatasetOutputTypesAttr(batch_node, &dtypes));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateFlatMapFnWithBatch(dtypes, num_workers, &flat_map_fn));
|
||||
CreateFlatMapFnWithBatch(dtypes, num_replicas, &flat_map_fn));
|
||||
|
||||
int64 batch_size_index = GetBatchSizeArgIndex(batch_node);
|
||||
|
||||
@ -496,7 +496,7 @@ Status AppendFlatMap(const NodeDef& batch_node, int64 num_workers,
|
||||
// Because the flat map function uses drop_remainder = False,
|
||||
// the shape might be unknown
|
||||
auto old_dim = shape.dim(0).size();
|
||||
auto new_dim = old_dim % num_workers == 0 ? old_dim / num_workers : -1;
|
||||
auto new_dim = old_dim % num_replicas == 0 ? old_dim / num_replicas : -1;
|
||||
shape.mutable_dim(0)->set_size(new_dim);
|
||||
}
|
||||
}
|
||||
@ -514,12 +514,13 @@ Status AppendFlatMap(const NodeDef& batch_node, int64 num_workers,
|
||||
|
||||
// There are several things we do here, depending on the values of
|
||||
// batch_size and drop_remainder.
|
||||
// (1) If batch size is known and divisible by num_workers, and drop_remainder
|
||||
// (1) If batch size is known and divisible by num_replicas, and drop_remainder
|
||||
// is known to be False, we mutate the batch size directly.
|
||||
// .batch(global_batch_size) -> .batch(global_batch_size // num_workers)
|
||||
// .batch(global_batch_size) -> .batch(global_batch_size // num_replicas)
|
||||
// (2) Otherwise, we add a flat_map transformation to preserve the global batch
|
||||
// size across the workers and to preserve the drop remainder behavior.
|
||||
bool ShouldMutateBatchSizeDirectly(const NodeDef& batch_node, int64 num_workers,
|
||||
// size across the replicas and to preserve the drop remainder behavior.
|
||||
bool ShouldMutateBatchSizeDirectly(const NodeDef& batch_node,
|
||||
int64 num_replicas,
|
||||
MutableGraphView* graph) {
|
||||
int64 batch_size_arg_index = GetBatchSizeArgIndex(batch_node);
|
||||
NodeDef* batch_size_node =
|
||||
@ -528,9 +529,9 @@ bool ShouldMutateBatchSizeDirectly(const NodeDef& batch_node, int64 num_workers,
|
||||
int64 batch_size;
|
||||
Status s =
|
||||
graph_utils::GetScalarConstNodeValue(*batch_size_node, &batch_size);
|
||||
// If batch size is unknown or indivisible by num workers, we don't
|
||||
// If batch size is unknown or indivisible by num replicas, we don't
|
||||
// mutate it directly
|
||||
if (!s.ok() || batch_size % num_workers != 0) return false;
|
||||
if (!s.ok() || batch_size % num_replicas != 0) return false;
|
||||
|
||||
if (batch_node.op() == kBatchOp || batch_node.op() == kPaddedBatchOp) {
|
||||
// These ops don't have a `drop_remainder` input, and behave like
|
||||
@ -547,16 +548,16 @@ bool ShouldMutateBatchSizeDirectly(const NodeDef& batch_node, int64 num_workers,
|
||||
return s.ok() && !drop_remainder;
|
||||
}
|
||||
|
||||
Status RewriteBatchNode(const NodeDef& batch_node, int64 num_workers,
|
||||
Status RewriteBatchNode(const NodeDef& batch_node, int64 num_replicas,
|
||||
FunctionLibraryDefinition* flib,
|
||||
MutableGraphView* graph) {
|
||||
if (ShouldMutateBatchSizeDirectly(batch_node, num_workers, graph)) {
|
||||
return MutateBatchSize(batch_node, num_workers, graph);
|
||||
if (ShouldMutateBatchSizeDirectly(batch_node, num_replicas, graph)) {
|
||||
return MutateBatchSize(batch_node, num_replicas, graph);
|
||||
}
|
||||
return AppendFlatMap(batch_node, num_workers, flib, graph);
|
||||
return AppendFlatMap(batch_node, num_replicas, flib, graph);
|
||||
}
|
||||
|
||||
Status OptimizeGraph(const GrapplerItem& item, int64 num_workers,
|
||||
Status OptimizeGraph(const GrapplerItem& item, int64 num_replicas,
|
||||
bool use_fallback, GraphDef* output);
|
||||
|
||||
// Helper function that starts from a node in the graph and recurses into its
|
||||
@ -567,16 +568,16 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers,
|
||||
// as they are datasets themselves.
|
||||
// 3. Core dataset ops + Identity op: Recurses into first input parameter.
|
||||
// 4. FlatMap type mapping dataset ops: Recurses into the function definition.
|
||||
Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
|
||||
Status RecursivelyHandleOp(const NodeDef& node, int64 num_replicas,
|
||||
bool use_fallback, FunctionLibraryDefinition* flib,
|
||||
MutableGraphView* graph) {
|
||||
if (IsDatasetNodeOfType(node, kBatchDatasetOps)) {
|
||||
TF_RETURN_IF_ERROR(RewriteBatchNode(node, num_workers, flib, graph));
|
||||
TF_RETURN_IF_ERROR(RewriteBatchNode(node, num_replicas, flib, graph));
|
||||
} else if (IsDatasetNodeOfType(node, kMultipleInputsDatasetOps)) {
|
||||
// For all multiple input datasets, all inputs are datasets themselves.
|
||||
for (int i = 0; i < node.input_size(); ++i) {
|
||||
NodeDef* input_node = graph_utils::GetInputNode(node, *graph, i);
|
||||
TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_workers,
|
||||
TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_replicas,
|
||||
use_fallback, flib, graph));
|
||||
}
|
||||
} else if (IsDatasetNodeOfType(node, kPassThroughOps) || IsRetval(node)) {
|
||||
@ -584,7 +585,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
|
||||
// function body graph in place of function outputs, the input dataset is
|
||||
// input 0.
|
||||
NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
|
||||
TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_workers,
|
||||
TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_replicas,
|
||||
use_fallback, flib, graph));
|
||||
} else if (IsDatasetNodeOfType(node, kFuncDatasetOps)) {
|
||||
const string func_name =
|
||||
@ -594,7 +595,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
|
||||
TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
|
||||
*fdef, *flib, graph->graph()->versions().producer(), &f_item));
|
||||
GraphDef optimized_func_graph;
|
||||
TF_RETURN_IF_ERROR(OptimizeGraph(f_item, num_workers, use_fallback,
|
||||
TF_RETURN_IF_ERROR(OptimizeGraph(f_item, num_replicas, use_fallback,
|
||||
&optimized_func_graph));
|
||||
|
||||
// Function body optimization might have created new specialized
|
||||
@ -623,7 +624,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
|
||||
}
|
||||
// If we've successfully updated the batch size of this node or any nodes
|
||||
// in the dataset tree rooted in this node, we update the output_shapes attr.
|
||||
TF_RETURN_IF_ERROR(UpdateOutputShapes(node.name(), num_workers, graph));
|
||||
TF_RETURN_IF_ERROR(UpdateOutputShapes(node.name(), num_replicas, graph));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -689,7 +690,7 @@ Status CreateFlatMapFnWithReshape(int new_batch_dim,
|
||||
|
||||
// For each component of the dataset, we reshape it from shape
|
||||
// (old_batch_size, ...) to (-1, new_batch_size, ...)
|
||||
// where new_batch_size = (old_batch_size + num_workers - 1) // num_workers
|
||||
// where new_batch_size = (old_batch_size + num_replicas - 1) // num_replicas
|
||||
for (int i = 0; i < types.size(); ++i) {
|
||||
auto* input_arg = function_utils::AddFunctionInput(
|
||||
strings::StrCat("args_", i), result, types.at(i));
|
||||
@ -733,13 +734,13 @@ Status CreateFlatMapFnWithReshape(int new_batch_dim,
|
||||
// return tf.data.Dataset.from_tensor_slices(
|
||||
// tf.reshape(
|
||||
// x,
|
||||
// tf.concat([[-1, old_batch_dim / num_workers], tf.shape(x)[1:]], 0)
|
||||
// tf.concat([[-1, old_batch_dim / num_replicas], tf.shape(x)[1:]], 0)
|
||||
// )
|
||||
// )
|
||||
//
|
||||
// dataset = dataset.flat_map(fn)
|
||||
// ```
|
||||
Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers,
|
||||
Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_replicas,
|
||||
FunctionLibraryDefinition* flib,
|
||||
MutableGraphView* graph) {
|
||||
if (IsRetval(*fetch_node) || fetch_node->op() == kIdentityOp) {
|
||||
@ -762,10 +763,10 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers,
|
||||
}
|
||||
int batch_dim;
|
||||
TF_RETURN_IF_ERROR(GetBatchDim(output_shapes, &batch_dim));
|
||||
if (batch_dim % num_workers != 0) {
|
||||
if (batch_dim % num_replicas != 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot use rebatching fallback when batch dimension doesn't divide "
|
||||
"num_workers evenly.");
|
||||
"num_replicas evenly.");
|
||||
}
|
||||
|
||||
// Create the flat map fn
|
||||
@ -778,7 +779,7 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers,
|
||||
DataTypeVector output_types;
|
||||
TF_RETURN_IF_ERROR(
|
||||
graph_utils::GetDatasetOutputTypesAttr(*fetch_node, &output_types));
|
||||
TF_RETURN_IF_ERROR(CreateFlatMapFnWithReshape(batch_dim / num_workers,
|
||||
TF_RETURN_IF_ERROR(CreateFlatMapFnWithReshape(batch_dim / num_replicas,
|
||||
output_types, &flat_map_fn));
|
||||
|
||||
NodeDef* flat_map_node;
|
||||
@ -786,7 +787,7 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers,
|
||||
{}, {}, flat_map_fn, output_shapes,
|
||||
output_types, flib, graph, &flat_map_node));
|
||||
TF_RETURN_IF_ERROR(
|
||||
UpdateOutputShapes(flat_map_node->name(), num_workers, graph));
|
||||
UpdateOutputShapes(flat_map_node->name(), num_replicas, graph));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
graph->UpdateFanouts(fetch_node->name(), flat_map_node->name()));
|
||||
@ -797,7 +798,7 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers,
|
||||
// Helper function that given a GrapplerItem generates a mutated graph def
|
||||
// with the batch size changed. The GrapplerItem could be generated from the
|
||||
// main graph or could be a function graph.
|
||||
Status OptimizeGraph(const GrapplerItem& item, int64 num_workers,
|
||||
Status OptimizeGraph(const GrapplerItem& item, int64 num_replicas,
|
||||
bool use_fallback, GraphDef* output) {
|
||||
*output = item.graph;
|
||||
MutableGraphView graph(output);
|
||||
@ -807,8 +808,8 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers,
|
||||
NodeDef* sink_node;
|
||||
TF_RETURN_IF_ERROR(graph_utils::GetFetchNode(graph, item, &sink_node));
|
||||
|
||||
Status s =
|
||||
RecursivelyHandleOp(*sink_node, num_workers, use_fallback, &flib, &graph);
|
||||
Status s = RecursivelyHandleOp(*sink_node, num_replicas, use_fallback, &flib,
|
||||
&graph);
|
||||
if (!s.ok()) {
|
||||
if (use_fallback) {
|
||||
VLOG(1) << "Failed to rebatch by rewriting the batch transformation ("
|
||||
@ -818,7 +819,7 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers,
|
||||
*output = item.graph;
|
||||
graph = MutableGraphView(output);
|
||||
TF_RETURN_IF_ERROR(
|
||||
RebatchWithFallback(sink_node, num_workers, &flib, &graph));
|
||||
RebatchWithFallback(sink_node, num_replicas, &flib, &graph));
|
||||
} else {
|
||||
// Return the error
|
||||
return s;
|
||||
@ -837,7 +838,7 @@ Status RebatchOptimizer::OptimizeAndCollectStats(Cluster* cluster,
|
||||
*output = item.graph;
|
||||
MutableGraphView graph(output);
|
||||
|
||||
TF_RETURN_IF_ERROR(OptimizeGraph(item, num_workers_, use_fallback_, output));
|
||||
TF_RETURN_IF_ERROR(OptimizeGraph(item, num_replicas_, use_fallback_, output));
|
||||
stats->num_changes++;
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// This optimizer changes the batch size of the output dataset by dividing the
|
||||
// current batch size by parameter `num_workers`. Currently, this works only
|
||||
// current batch size by parameter `num_replicas`. Currently, this works only
|
||||
// for very simple pipelines with a single BatchDatasetV2 transformation.
|
||||
class RebatchOptimizer : public TFDataOptimizerBase {
|
||||
public:
|
||||
@ -43,7 +43,7 @@ class RebatchOptimizer : public TFDataOptimizerBase {
|
||||
const GraphDef& optimize_output, double result) override;
|
||||
|
||||
private:
|
||||
int64 num_workers_;
|
||||
int64 num_replicas_;
|
||||
bool use_fallback_;
|
||||
};
|
||||
|
||||
|
@ -36,14 +36,15 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
protected:
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) override {
|
||||
int64 num_workers;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_workers", &num_workers));
|
||||
int64 num_replicas;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ParseScalarArgument(ctx, "num_replicas", &num_replicas));
|
||||
OP_REQUIRES(
|
||||
ctx, num_workers > 0,
|
||||
errors::InvalidArgument("num_workers must be greater than zero."));
|
||||
ctx, num_replicas > 0,
|
||||
errors::InvalidArgument("num_replicas must be greater than zero."));
|
||||
|
||||
auto config_factory = [num_workers, this]() {
|
||||
return CreateConfig(num_workers, this->use_fallback_);
|
||||
auto config_factory = [num_replicas, this]() {
|
||||
return CreateConfig(num_replicas, this->use_fallback_);
|
||||
};
|
||||
|
||||
// We only want to optimize functions for some particular datasets like
|
||||
@ -56,17 +57,17 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
static RewriterConfig CreateConfig(int64 num_workers, bool use_fallback) {
|
||||
static RewriterConfig CreateConfig(int64 num_replicas, bool use_fallback) {
|
||||
RewriterConfig rewriter_config;
|
||||
rewriter_config.set_fail_on_optimizer_errors(true);
|
||||
rewriter_config.add_optimizers(kOptimizerName);
|
||||
rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
|
||||
auto custom_optimizer = rewriter_config.add_custom_optimizers();
|
||||
custom_optimizer->set_name(kOptimizerName);
|
||||
AttrValue num_workers_attr;
|
||||
num_workers_attr.set_i(num_workers);
|
||||
(*custom_optimizer->mutable_parameter_map())["num_workers"] =
|
||||
num_workers_attr;
|
||||
AttrValue num_replicas_attr;
|
||||
num_replicas_attr.set_i(num_replicas);
|
||||
(*custom_optimizer->mutable_parameter_map())["num_replicas"] =
|
||||
num_replicas_attr;
|
||||
AttrValue use_fallback_attr;
|
||||
use_fallback_attr.set_b(use_fallback);
|
||||
(*custom_optimizer->mutable_parameter_map())["use_fallback"] =
|
||||
|
@ -658,7 +658,7 @@ REGISTER_OP("RandomDataset")
|
||||
|
||||
REGISTER_OP("ExperimentalRebatchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_workers: int64")
|
||||
.Input("num_replicas: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
@ -667,7 +667,7 @@ REGISTER_OP("ExperimentalRebatchDataset")
|
||||
|
||||
REGISTER_OP("RebatchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_workers: int64")
|
||||
.Input("num_replicas: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
|
@ -58,7 +58,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testBasic(self, drop_remainder):
|
||||
dataset = dataset_ops.Dataset.range(1024).batch(
|
||||
32, drop_remainder=drop_remainder)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
self.assertEqual([[8] if drop_remainder else [None]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
|
||||
@ -67,15 +67,15 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def testScalarInputError(self):
|
||||
dataset = dataset_ops.Dataset.range(1024)
|
||||
distribute._RebatchDataset(dataset.batch(4), num_workers=4)
|
||||
distribute._RebatchDataset(dataset.batch(4), num_replicas=4)
|
||||
with self.assertRaisesRegexp(ValueError, "at least one dimension"):
|
||||
distribute._RebatchDataset(dataset, num_workers=4)
|
||||
distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
|
||||
@parameterized.named_parameters(drop_remainder_cases)
|
||||
def testBatchNotDivisibleByNumWorkers(self, drop_remainder):
|
||||
def testBatchNotDivisibleByNumReplicas(self, drop_remainder):
|
||||
dataset = dataset_ops.Dataset.range(1024).batch(
|
||||
32, drop_remainder=drop_remainder)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
|
||||
self.assertEqual([[None]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
expected_output = []
|
||||
@ -92,7 +92,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def testTupleOutput(self):
|
||||
dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
expected_output = [([k for k in range(i, i + 8)], # pylint: disable=g-complex-comprehension
|
||||
[k for k in range(i, i + 8)])
|
||||
for i in range(0, 1024, 8)]
|
||||
@ -101,7 +101,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testNestedDictionaryOutput(self):
|
||||
dataset = dataset_ops.Dataset.range(1024).map(
|
||||
lambda x: {"a": x, "b": {"c": x}}).batch(32)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
expected_output = [{"a": [k for k in range(i, i + 8)], # pylint: disable=g-complex-comprehension
|
||||
"b": {"c": [k for k in range(i, i + 8)]}}
|
||||
for i in range(0, 1024, 8)]
|
||||
@ -111,7 +111,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testFinalPartialBatch(self, drop_remainder):
|
||||
dataset = dataset_ops.Dataset.range(1032).batch(
|
||||
32, drop_remainder=drop_remainder)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
self.assertEqual([[8] if drop_remainder else [None]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
|
||||
@ -126,7 +126,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testFinalPartialBatchAfterRebatch(self, drop_remainder):
|
||||
dataset = dataset_ops.Dataset.range(34).batch(
|
||||
32, drop_remainder=drop_remainder)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
self.assertEqual([[8] if drop_remainder else [None]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
|
||||
@ -158,7 +158,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testMapAndBatch(self):
|
||||
dataset = dataset_ops.Dataset.range(1024).apply(
|
||||
batching.map_and_batch(math_ops.square, 32))
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
self.assertEqual([[None]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
expected_output = [[k**2 for k in range(i, i + 8)] # pylint: disable=g-complex-comprehension
|
||||
@ -169,7 +169,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
captured_t = variables.Variable(42)
|
||||
dataset = dataset_ops.Dataset.range(1024).apply(
|
||||
batching.map_and_batch(lambda x: captured_t, 32))
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
self.assertEqual([[None]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
expected_output = [[42 for _ in range(i, i + 8)] # pylint: disable=g-complex-comprehension
|
||||
@ -182,7 +182,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset_ops.Dataset.range(128).batch(
|
||||
4, drop_remainder=True).padded_batch(
|
||||
8, padded_shapes=[5])
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
# Each element is a list of 8 elements in which each element is a list of 5
|
||||
# elements, first four are numbers and the last one is a padded zero.
|
||||
expected_output = [[[j, j + 1, j + 2, j + 3, 0] # pylint: disable=g-complex-comprehension
|
||||
@ -202,7 +202,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset1 = dataset_ops.Dataset.range(64).batch(8)
|
||||
dataset2 = dataset_ops.Dataset.range(32).batch(8)
|
||||
dataset = dataset1.concatenate(dataset2)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
self.assertEqual([[None]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
expected_output = ([[i, i + 1] for i in range(0, 64, 2)] +
|
||||
@ -213,7 +213,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset1 = dataset_ops.Dataset.range(64).batch(16)
|
||||
dataset2 = dataset_ops.Dataset.range(32).batch(8)
|
||||
dataset = dataset1.concatenate(dataset2)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
self.assertEqual(
|
||||
[[None]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
@ -225,7 +225,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset1 = dataset_ops.Dataset.range(64).batch(8)
|
||||
dataset2 = dataset_ops.Dataset.range(32).batch(8)
|
||||
dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
self.assertEqual([[None], [None]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)]
|
||||
@ -235,7 +235,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset1 = dataset_ops.Dataset.range(64).batch(16)
|
||||
dataset2 = dataset_ops.Dataset.range(32).batch(8)
|
||||
dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
self.assertEqual([[None], [None]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
expected_output = [([2 * i, 2 * i + 1, 2 * i + 2, 2 * i + 3], [i, i + 1])
|
||||
@ -246,7 +246,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset_ops.Dataset.range(1024).batch(32).apply(sleep.sleep(10))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
rebatched_dataset = distribute._RebatchDataset(
|
||||
dataset, num_workers=4, use_fallback=False)
|
||||
dataset, num_replicas=4, use_fallback=False)
|
||||
next_element = self.getNext(rebatched_dataset)
|
||||
self.evaluate(next_element())
|
||||
|
||||
@ -256,7 +256,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
32).apply(sleep.sleep(10)))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
rebatched_dataset = distribute._RebatchDataset(
|
||||
dataset, num_workers=4, use_fallback=False)
|
||||
dataset, num_replicas=4, use_fallback=False)
|
||||
next_element = self.getNext(rebatched_dataset)
|
||||
self.evaluate(next_element())
|
||||
|
||||
@ -268,7 +268,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension
|
||||
self.assertDatasetProduces(dataset, expected_output)
|
||||
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
self.assertEqual([[None]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
# Two elements where each element is a list of 4 elements where each element
|
||||
@ -287,7 +287,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension
|
||||
self.assertDatasetProduces(dataset, expected_output)
|
||||
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
self.assertEqual([[None]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
# List of 4 elements where each element is a list of 8 numbering from 0 to
|
||||
@ -307,7 +307,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension
|
||||
self.assertDatasetProduces(dataset, expected_output)
|
||||
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
self.assertEqual([[None]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
# List of 4 elements where each element is a list of 8 numbering from 0 to
|
||||
@ -325,7 +325,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset.apply(
|
||||
grouping.group_by_window(
|
||||
key_func=lambda x: x[0] % 4, reduce_func=reduce_fn, window_size=10))
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=2)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=2)
|
||||
|
||||
self.assertEqual([[None, 3]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
@ -348,7 +348,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset.apply(
|
||||
grouping.group_by_window(
|
||||
key_func=lambda x: x, reduce_func=reduce_fn, window_size=10))
|
||||
dataset = distribute._RebatchDataset(dataset, num_workers=2)
|
||||
dataset = distribute._RebatchDataset(dataset, num_replicas=2)
|
||||
|
||||
self.assertEqual([[None]],
|
||||
[ts.as_list() for ts in _flat_shapes(dataset)])
|
||||
@ -373,7 +373,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset.apply(
|
||||
grouping.group_by_window(
|
||||
key_func=lambda x: x, reduce_func=reduce_fn, window_size=11))
|
||||
dataset = distribute._RebatchDataset(dataset, num_workers=2)
|
||||
dataset = distribute._RebatchDataset(dataset, num_replicas=2)
|
||||
|
||||
self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
|
||||
|
||||
@ -398,7 +398,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset.apply(
|
||||
grouping.group_by_window(
|
||||
key_func=lambda x: x, reduce_func=reduce_fn, window_size=11))
|
||||
dataset = distribute._RebatchDataset(dataset, num_workers=2)
|
||||
dataset = distribute._RebatchDataset(dataset, num_replicas=2)
|
||||
|
||||
self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
|
||||
|
||||
@ -412,7 +412,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testScanAfterBatch(self):
|
||||
dataset = dataset_ops.Dataset.range(40).batch(10).apply(
|
||||
scan_ops.scan(np.int64(2), lambda state, value: (state, value * state)))
|
||||
dataset = distribute._RebatchDataset(dataset, num_workers=2)
|
||||
dataset = distribute._RebatchDataset(dataset, num_replicas=2)
|
||||
|
||||
self.assertEqual([[None]],
|
||||
[ts.as_list() for ts in _flat_shapes(dataset)])
|
||||
@ -442,7 +442,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
num_epochs=1,
|
||||
drop_final_batch=False)
|
||||
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
|
||||
self.assertEqual([[None]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
@ -459,7 +459,7 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase):
|
||||
def testWithNoBatchDataset(self):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||
[[k for k in range(i, i + 32)] for i in range(0, 1024, 32)]) # pylint: disable=g-complex-comprehension
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)])
|
||||
self.assertEqual([[8]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
@ -470,7 +470,7 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase):
|
||||
def testWithUnhandledTransformation(self):
|
||||
dataset = dataset_ops.Dataset.range(1024).batch(
|
||||
32, drop_remainder=True).apply(sleep.sleep(10))
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)])
|
||||
self.assertEqual([[8]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
@ -482,7 +482,7 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase):
|
||||
dataset = dataset_ops.Dataset.range(2).flat_map(
|
||||
lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
|
||||
32, drop_remainder=True).apply(sleep.sleep(10)))
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
|
||||
self.assertEqual([[8]],
|
||||
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
|
||||
@ -500,7 +500,7 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"Cannot use rebatching fallback"):
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
next_element = self.getNext(rebatched_dataset)
|
||||
self.evaluate(next_element())
|
||||
|
||||
@ -512,11 +512,11 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"Cannot use rebatching fallback"):
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
next_element = self.getNext(rebatched_dataset)
|
||||
self.evaluate(next_element())
|
||||
|
||||
def testBatchSizeIndivisibleByNumWorkers(self):
|
||||
def testBatchSizeNotDivisibleByNumReplicas(self):
|
||||
# This doesn't work; reshape requires tensor shape to be exactly divisible
|
||||
# by the second dim.
|
||||
dataset = dataset_ops.Dataset.range(64).batch(
|
||||
@ -524,7 +524,7 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"Cannot use rebatching fallback"):
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
|
||||
next_element = self.getNext(rebatched_dataset)
|
||||
self.evaluate(next_element())
|
||||
|
||||
@ -532,7 +532,7 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase):
|
||||
dataset = dataset_ops.Dataset.from_tensors((np.arange(10), np.arange(5)))
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"Cannot use rebatching fallback"):
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
|
||||
next_element = self.getNext(rebatched_dataset)
|
||||
self.evaluate(next_element())
|
||||
|
||||
|
@ -32,7 +32,7 @@ class RebatchDatasetSerializationTest(
|
||||
return distribute._RebatchDataset(
|
||||
dataset_ops.Dataset.range(num_elements).batch(
|
||||
4 * batch_size, drop_remainder=True),
|
||||
num_workers=4)
|
||||
num_replicas=4)
|
||||
|
||||
self.run_core_tests(lambda: build_dataset(200, 10), 20)
|
||||
|
||||
|
@ -67,28 +67,28 @@ def _AutoShardDatasetV1(input_dataset, num_workers, index): # pylint: disable=i
|
||||
|
||||
|
||||
class _RebatchDataset(dataset_ops.UnaryDataset):
|
||||
"""A `Dataset` that divides the batch size by `num_workers`.
|
||||
"""A `Dataset` that divides the batch size by `num_replicas`.
|
||||
|
||||
For each batch in the input dataset, the resulting dataset will produce
|
||||
`num_replicas` minibatches whose sizes add up to the original batch size.
|
||||
"""
|
||||
|
||||
def __init__(self, input_dataset, num_workers, use_fallback=True):
|
||||
def __init__(self, input_dataset, num_replicas, use_fallback=True):
|
||||
self._input_dataset = input_dataset
|
||||
|
||||
def recalculate_output_shapes(output_shapes):
|
||||
"""Recalculates the output_shapes after dividing it by num_workers."""
|
||||
"""Recalculates the output_shapes after dividing it by num_replicas."""
|
||||
if len(output_shapes) < 1:
|
||||
raise ValueError(
|
||||
"Input shape should have at least one dimension. "
|
||||
"Perhaps your input dataset is not batched?")
|
||||
output_dims = [d.value for d in output_shapes.dims]
|
||||
|
||||
if output_dims[0] is not None and output_dims[0] % num_workers == 0:
|
||||
output_dims[0] = output_dims[0] // num_workers
|
||||
if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
|
||||
output_dims[0] = output_dims[0] // num_replicas
|
||||
else:
|
||||
# Set the batch dimension to unknown. If the global batch size does not
|
||||
# divide num_workers evenly, the minibatches may have different sizes.
|
||||
# divide evenly by num_replicas, the minibatches may have different sizes.
|
||||
output_dims[0] = None
|
||||
return tensor_shape.TensorShape(output_dims)
|
||||
|
||||
@ -102,13 +102,13 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
|
||||
if compat.forward_compatible(2019, 8, 13) or not use_fallback:
|
||||
variant_tensor = ged_ops.rebatch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
num_workers=num_workers,
|
||||
num_replicas=num_replicas,
|
||||
use_fallback=use_fallback,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = ged_ops.rebatch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
num_workers=num_workers,
|
||||
num_replicas=num_replicas,
|
||||
**self._flat_structure)
|
||||
super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
@ -1286,7 +1286,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "ExperimentalRebatchDataset"
|
||||
argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'input_dataset\', \'num_replicas\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ExperimentalScanDataset"
|
||||
@ -3010,7 +3010,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "RebatchDataset"
|
||||
argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'input_dataset\', \'num_replicas\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Reciprocal"
|
||||
|
@ -1286,7 +1286,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "ExperimentalRebatchDataset"
|
||||
argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'input_dataset\', \'num_replicas\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ExperimentalScanDataset"
|
||||
@ -3010,7 +3010,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "RebatchDataset"
|
||||
argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'input_dataset\', \'num_replicas\', \'output_types\', \'output_shapes\', \'use_fallback\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Reciprocal"
|
||||
|
Loading…
Reference in New Issue
Block a user