[tf.data] Make rebatching fallback work for datasets with unknown shapes.
PiperOrigin-RevId: 264501770
This commit is contained in:
parent
6208021e3d
commit
7107f907aa
@ -200,11 +200,13 @@ Status AddConstBoolNode(bool value, FunctionDef* fdef, NodeDef** result) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AddShapeNode(const NodeDefBuilder::NodeOut& input, FunctionDef* fdef,
|
||||
NodeDef** result) {
|
||||
Status AddShapeNode(const NodeDefBuilder::NodeOut& input, DataType out_type,
|
||||
FunctionDef* fdef, NodeDef** result) {
|
||||
*result = fdef->add_node_def();
|
||||
TF_RETURN_IF_ERROR(
|
||||
NodeDefBuilder("", "Shape").Input(input).Finalize(*result));
|
||||
TF_RETURN_IF_ERROR(NodeDefBuilder("", "Shape")
|
||||
.Input(input)
|
||||
.Attr("out_type", out_type)
|
||||
.Finalize(*result));
|
||||
function_utils::SetUniqueFunctionNodeName("rebatch/shape", fdef, *result);
|
||||
return Status::OK();
|
||||
}
|
||||
@ -276,34 +278,49 @@ void SetUnknownShapes(int num_components, AttrValue* output_shapes) {
|
||||
}
|
||||
}
|
||||
|
||||
Status GetBatchDim(AttrValue output_shapes, int* batch_dim) {
|
||||
const auto& shape_0 = output_shapes.list().shape(0);
|
||||
if (shape_0.unknown_rank() || shape_0.dim(0).size() == -1) {
|
||||
// If the batch dimension is known and divisible by num_replicas, we set
|
||||
// result = batch_dim / num_replicas. If the batch dimension is unknown,
|
||||
// result = -1. If the dataset node is missing an output shapes attr,
|
||||
// or the batch dimensions of its components don't match, we return an error
|
||||
// status.
|
||||
Status GetMinibatchDimForReshape(const NodeDef& dataset_node,
|
||||
int64 num_replicas, int64* result) {
|
||||
AttrValue output_shapes;
|
||||
if (!dataset_node.attr().contains(kOutputShapesAttr)) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot use rebatching fallback when 0th dimensions of dataset "
|
||||
"components are not fully known. Component 0 has shape: ",
|
||||
shape_0.ShortDebugString());
|
||||
"Cannot use rebatching fallback when the final dataset node does not "
|
||||
"have an `output_shapes` attr. Node: ",
|
||||
dataset_node.name(), " Op: ", dataset_node.op());
|
||||
}
|
||||
output_shapes = dataset_node.attr().at(kOutputShapesAttr);
|
||||
|
||||
*batch_dim = output_shapes.list().shape(0).dim(0).size();
|
||||
|
||||
for (int i = 1; i < output_shapes.list().shape_size(); ++i) {
|
||||
// Get the batch dimension by checking the 0th dimension of all the inputs.
|
||||
int batch_dim = -1;
|
||||
for (int i = 0; i < output_shapes.list().shape_size(); ++i) {
|
||||
const auto& shape_i = output_shapes.list().shape(i);
|
||||
|
||||
if (shape_i.unknown_rank() || shape_i.dim(0).size() == -1) {
|
||||
// If unknown, ignore.
|
||||
if (shape_i.unknown_rank()) continue;
|
||||
int batch_dim_i = shape_i.dim(0).size();
|
||||
if (batch_dim_i == -1) continue;
|
||||
|
||||
// Update batch_dim with known dimension.
|
||||
if (batch_dim_i != batch_dim && batch_dim != -1) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot use rebatching fallback when 0th dimensions of dataset "
|
||||
"components are not fully known. Component ",
|
||||
i, " has shape: ", shape_i.ShortDebugString());
|
||||
}
|
||||
if (shape_i.dim(0).size() != *batch_dim) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot use rebatching fallback when 0th dimensions of dataset "
|
||||
"Cannot use rebatching fallback: 0th dimensions of dataset "
|
||||
"components don't match. Component ",
|
||||
i, " has batch dimension: ", shape_i.dim(0).size(),
|
||||
" while previous components have batch dimension: ", *batch_dim);
|
||||
i, " has batch dimension: ", batch_dim_i,
|
||||
" while previous components have batch dimension: ", batch_dim);
|
||||
}
|
||||
batch_dim = batch_dim_i;
|
||||
}
|
||||
|
||||
if (batch_dim == -1 || batch_dim % num_replicas != 0) {
|
||||
*result = -1;
|
||||
} else {
|
||||
*result = batch_dim / num_replicas;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -411,6 +428,8 @@ Status AddFlatMapNode(const string& input_dataset,
|
||||
}
|
||||
|
||||
// def flat_map_fn(*batched_components):
|
||||
// batch_size = tf.shape(batched_components[0])[0]
|
||||
// minibatch_size = (batch_size + num_replicas - 1) // num_replicas
|
||||
// ds = tf.data.Dataset.from_tensor_slices(batched_components)
|
||||
// return ds.batch(minibatch_size, drop_remainder=False)
|
||||
Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes,
|
||||
@ -439,13 +458,32 @@ Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes,
|
||||
batch_node->add_input(
|
||||
strings::StrCat(tensor_slice_node->name(), ":handle:0"));
|
||||
|
||||
// `batch_size` input
|
||||
// Here, we capture the original batch size from outside the flat map fn.
|
||||
auto* original_batch_size =
|
||||
function_utils::AddFunctionInput("captured_batch_size", result, DT_INT64);
|
||||
// `batch_size` is tf.shape(arg)[0]
|
||||
NodeDef* shape;
|
||||
TF_RETURN_IF_ERROR(AddShapeNode({tensor_slice_node->input(0), 0, dtypes[0]},
|
||||
DT_INT64, result, &shape));
|
||||
|
||||
// Const with value [0]
|
||||
NodeDef* const_vec_0;
|
||||
TF_RETURN_IF_ERROR(AddConstIntNode({0}, {1}, result, &const_vec_0));
|
||||
|
||||
// Const with value [1]
|
||||
NodeDef* const_vec_1;
|
||||
TF_RETURN_IF_ERROR(AddConstIntNode({1}, {1}, result, &const_vec_1));
|
||||
|
||||
// Extracts the 0th dimension from the shape node.
|
||||
NodeDef* original_batch_size;
|
||||
TF_RETURN_IF_ERROR(AddStridedSliceNode(
|
||||
{strings::StrCat(shape->name(), ":output"), 0, DT_INT64},
|
||||
{strings::StrCat(const_vec_0->name(), ":output"), 0, DT_INT32},
|
||||
{strings::StrCat(const_vec_1->name(), ":output"), 0, DT_INT32},
|
||||
{strings::StrCat(const_vec_1->name(), ":output"), 0, DT_INT32}, DT_INT32,
|
||||
0, 0, 0, 0, 1, result, &original_batch_size));
|
||||
|
||||
NodeDef* new_batch_size;
|
||||
TF_RETURN_IF_ERROR(MakeNewBatchSizeNode(
|
||||
original_batch_size->name(), num_replicas, result, &new_batch_size));
|
||||
strings::StrCat(original_batch_size->name(), ":output:0"), num_replicas,
|
||||
result, &new_batch_size));
|
||||
batch_node->add_input(strings::StrCat(new_batch_size->name(), ":z:0"));
|
||||
|
||||
// `drop_remainder` input
|
||||
@ -486,8 +524,6 @@ Status AppendFlatMap(const NodeDef& batch_node, int64 num_replicas,
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateFlatMapFnWithBatch(dtypes, num_replicas, &flat_map_fn));
|
||||
|
||||
int64 batch_size_index = GetBatchSizeArgIndex(batch_node);
|
||||
|
||||
NodeDef* flat_map_node;
|
||||
|
||||
AttrValue output_shapes = batch_node.attr().at(kOutputShapesAttr);
|
||||
@ -502,9 +538,8 @@ Status AppendFlatMap(const NodeDef& batch_node, int64 num_replicas,
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(batch_node.name(), ":0"),
|
||||
{batch_node.input(batch_size_index)},
|
||||
{DT_INT64}, flat_map_fn, output_shapes,
|
||||
dtypes, flib, graph, &flat_map_node));
|
||||
{}, {}, flat_map_fn, output_shapes, dtypes,
|
||||
flib, graph, &flat_map_node));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
graph->UpdateFanouts(batch_node.name(), flat_map_node->name()));
|
||||
@ -650,7 +685,7 @@ Status ReshapeComponent(int new_batch_dim, const string& arg, DataType dtype,
|
||||
|
||||
// shape = tf.shape(arg)
|
||||
NodeDef* shape;
|
||||
TF_RETURN_IF_ERROR(AddShapeNode({arg, 0, dtype}, fdef, &shape));
|
||||
TF_RETURN_IF_ERROR(AddShapeNode({arg, 0, dtype}, DT_INT32, fdef, &shape));
|
||||
|
||||
// later_dimensions = tf.shape(arg)[1:]
|
||||
NodeDef* later_dimensions;
|
||||
@ -748,26 +783,6 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_replicas,
|
||||
fetch_node = graph_utils::GetInputNode(*fetch_node, *graph, 0);
|
||||
}
|
||||
|
||||
// Note: Here, we are conservative with only using the fallback when
|
||||
// the output_shapes attr has the 0th dimension defined for every component.
|
||||
// This because the flat_map_fn will fail if the batch does not divide evenly
|
||||
// because of the use of the "Reshape" op. This ensures that the error is
|
||||
// surfaced correctly.
|
||||
AttrValue output_shapes;
|
||||
if (!fetch_node->attr().contains(kOutputShapesAttr)) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot use rebatching fallback without output_shapes attr. Node: ",
|
||||
fetch_node->name(), " Op: ", fetch_node->op());
|
||||
} else {
|
||||
output_shapes = fetch_node->attr().at(kOutputShapesAttr);
|
||||
}
|
||||
int batch_dim;
|
||||
TF_RETURN_IF_ERROR(GetBatchDim(output_shapes, &batch_dim));
|
||||
if (batch_dim % num_replicas != 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot use rebatching fallback when batch dimension doesn't divide "
|
||||
"num_replicas evenly.");
|
||||
}
|
||||
|
||||
// Create the flat map fn
|
||||
FunctionDef flat_map_fn;
|
||||
@ -779,9 +794,26 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_replicas,
|
||||
DataTypeVector output_types;
|
||||
TF_RETURN_IF_ERROR(
|
||||
graph_utils::GetDatasetOutputTypesAttr(*fetch_node, &output_types));
|
||||
TF_RETURN_IF_ERROR(CreateFlatMapFnWithReshape(batch_dim / num_replicas,
|
||||
output_types, &flat_map_fn));
|
||||
|
||||
int64 minibatch_dim;
|
||||
// If the batch dimension is known and perfectly divisible by num_replicas,
|
||||
// we use a fallback with `tf.reshape` for better performance.
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetMinibatchDimForReshape(*fetch_node, num_replicas, &minibatch_dim));
|
||||
if (minibatch_dim != -1) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateFlatMapFnWithReshape(minibatch_dim, output_types, &flat_map_fn));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateFlatMapFnWithBatch(output_types, num_replicas, &flat_map_fn));
|
||||
}
|
||||
|
||||
AttrValue output_shapes;
|
||||
if (fetch_node->attr().contains(kOutputShapesAttr)) {
|
||||
output_shapes = fetch_node->attr().at(kOutputShapesAttr);
|
||||
} else {
|
||||
SetUnknownShapes(output_types.size(), &output_shapes);
|
||||
}
|
||||
NodeDef* flat_map_node;
|
||||
TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(fetch_node->name(), ":0"),
|
||||
{}, {}, flat_map_fn, output_shapes,
|
||||
|
@ -377,11 +377,8 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
|
||||
|
||||
# The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
|
||||
# the batches of 10 (value == 1) split into minibatches of (5, 5)
|
||||
# [(batch_size, value), ...]
|
||||
pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (5, 1), (5, 1), (1, 1),
|
||||
(3, 0), (2, 0), (3, 0), (1, 0), (5, 1), (4, 1)]
|
||||
(3, 0), (2, 0), (2, 0), (2, 0), (5, 1), (4, 1)]
|
||||
expected_output = [[value] * batch_size for batch_size, value in pairs]
|
||||
self.assertDatasetProduces(dataset, expected_output)
|
||||
|
||||
@ -497,36 +494,39 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase):
|
||||
def testWithUnknownBatchDim(self):
|
||||
dataset = dataset_ops.Dataset.range(1024).batch(
|
||||
32, drop_remainder=False).apply(sleep.sleep(10))
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"Cannot use rebatching fallback"):
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
next_element = self.getNext(rebatched_dataset)
|
||||
self.evaluate(next_element())
|
||||
expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
def testWithUnknownBatchDimInSecondComponent(self):
|
||||
dataset0 = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
|
||||
dataset1 = dataset_ops.Dataset.range(1024).batch(
|
||||
32, drop_remainder=False).apply(sleep.sleep(10))
|
||||
dataset = dataset_ops.Dataset.zip((dataset0, dataset1))
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"Cannot use rebatching fallback"):
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
|
||||
next_element = self.getNext(rebatched_dataset)
|
||||
self.evaluate(next_element())
|
||||
expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
|
||||
expected_output = [(x, x) for x in expected_output]
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
def testBatchSizeNotDivisibleByNumReplicas(self):
|
||||
# This doesn't work; reshape requires tensor shape to be exactly divisible
|
||||
# by the second dim.
|
||||
dataset = dataset_ops.Dataset.range(64).batch(
|
||||
32, drop_remainder=True).apply(sleep.sleep(10))
|
||||
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"Cannot use rebatching fallback"):
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
|
||||
next_element = self.getNext(rebatched_dataset)
|
||||
self.evaluate(next_element())
|
||||
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
|
||||
|
||||
expected_output = []
|
||||
i = 0
|
||||
for _ in range(2): # number of steps
|
||||
# first four minibatches have seven elements
|
||||
for _ in range(4):
|
||||
expected_output.append([k for k in range(i, i + 7)])
|
||||
i += 7
|
||||
# last minibatch has four elements
|
||||
expected_output.append([k for k in range(i, i + 4)])
|
||||
i += 4
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
def testBatchSizesDontMatch(self):
|
||||
dataset = dataset_ops.Dataset.from_tensors((np.arange(10), np.arange(5)))
|
||||
|
Loading…
Reference in New Issue
Block a user