[tf.data] Make rebatching fallback work for datasets with unknown shapes.

PiperOrigin-RevId: 264501770
Author: Rachel Lim
Date:   2019-08-20 17:12:22 -07:00
parent 6208021e3d
commit 7107f907aa
2 changed files with 109 additions and 77 deletions
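For orientation, the user-visible effect of this change, mirroring the updated testWithUnknownBatchDim in the Python test diff below (internal experimental modules, so treat this as a sketch of the test setup rather than public API):

from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.experimental.ops import sleep
from tensorflow.python.data.ops import dataset_ops

# drop_remainder=False leaves the batch (0th) dimension statically unknown, and
# the sleep transformation is what the tests use to force the rebatching
# fallback path rather than the direct batch-size rewrite.
dataset = dataset_ops.Dataset.range(1024).batch(
    32, drop_remainder=False).apply(sleep.sleep(10))

# Before this commit: InvalidArgumentError ("Cannot use rebatching fallback
# when 0th dimensions of dataset components are not fully known").
# After this commit: each batch of 32 is split into 4 minibatches of 8.
rebatched = distribute._RebatchDataset(dataset, num_replicas=4)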

@@ -200,11 +200,13 @@ Status AddConstBoolNode(bool value, FunctionDef* fdef, NodeDef** result) {
return Status::OK();
}
Status AddShapeNode(const NodeDefBuilder::NodeOut& input, FunctionDef* fdef,
NodeDef** result) {
Status AddShapeNode(const NodeDefBuilder::NodeOut& input, DataType out_type,
FunctionDef* fdef, NodeDef** result) {
*result = fdef->add_node_def();
TF_RETURN_IF_ERROR(
NodeDefBuilder("", "Shape").Input(input).Finalize(*result));
TF_RETURN_IF_ERROR(NodeDefBuilder("", "Shape")
.Input(input)
.Attr("out_type", out_type)
.Finalize(*result));
function_utils::SetUniqueFunctionNodeName("rebatch/shape", fdef, *result);
return Status::OK();
}
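The AddShapeNode change above only threads an explicit out_type attr through to the generated Shape node. In Python terms it is the difference between the two calls below; the new flat-map body later in this diff requests DT_INT64 for its batch-size arithmetic, while the reshape path keeps DT_INT32.

import tensorflow as tf

x = tf.zeros([32, 10])
tf.shape(x)                      # Shape with the default out_type, int32
tf.shape(x, out_type=tf.int64)   # Shape with out_type explicitly set to int64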
@@ -276,34 +278,49 @@ void SetUnknownShapes(int num_components, AttrValue* output_shapes) {
}
}
Status GetBatchDim(AttrValue output_shapes, int* batch_dim) {
const auto& shape_0 = output_shapes.list().shape(0);
if (shape_0.unknown_rank() || shape_0.dim(0).size() == -1) {
// If the batch dimension is known and divisible by num_replicas, we set
// result = batch_dim / num_replicas. If the batch dimension is unknown,
// result = -1. If the dataset node is missing an output shapes attr,
// or the batch dimensions of its components don't match, we return an error
// status.
Status GetMinibatchDimForReshape(const NodeDef& dataset_node,
int64 num_replicas, int64* result) {
AttrValue output_shapes;
if (!dataset_node.attr().contains(kOutputShapesAttr)) {
return errors::InvalidArgument(
"Cannot use rebatching fallback when 0th dimensions of dataset "
"components are not fully known. Component 0 has shape: ",
shape_0.ShortDebugString());
"Cannot use rebatching fallback when the final dataset node does not "
"have an `output_shapes` attr. Node: ",
dataset_node.name(), " Op: ", dataset_node.op());
}
output_shapes = dataset_node.attr().at(kOutputShapesAttr);
*batch_dim = output_shapes.list().shape(0).dim(0).size();
for (int i = 1; i < output_shapes.list().shape_size(); ++i) {
// Get the batch dimension by checking the 0th dimension of all the inputs.
int batch_dim = -1;
for (int i = 0; i < output_shapes.list().shape_size(); ++i) {
const auto& shape_i = output_shapes.list().shape(i);
if (shape_i.unknown_rank() || shape_i.dim(0).size() == -1) {
// If unknown, ignore.
if (shape_i.unknown_rank()) continue;
int batch_dim_i = shape_i.dim(0).size();
if (batch_dim_i == -1) continue;
// Update batch_dim with known dimension.
if (batch_dim_i != batch_dim && batch_dim != -1) {
return errors::InvalidArgument(
"Cannot use rebatching fallback when 0th dimensions of dataset "
"components are not fully known. Component ",
i, " has shape: ", shape_i.ShortDebugString());
}
if (shape_i.dim(0).size() != *batch_dim) {
return errors::InvalidArgument(
"Cannot use rebatching fallback when 0th dimensions of dataset "
"Cannot use rebatching fallback: 0th dimensions of dataset "
"components don't match. Component ",
i, " has batch dimension: ", shape_i.dim(0).size(),
" while previous components have batch dimension: ", *batch_dim);
i, " has batch dimension: ", batch_dim_i,
" while previous components have batch dimension: ", batch_dim);
}
batch_dim = batch_dim_i;
}
if (batch_dim == -1 || batch_dim % num_replicas != 0) {
*result = -1;
} else {
*result = batch_dim / num_replicas;
}
return Status::OK();
}
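Putting the new helper in plain Python terms, here is a hedged paraphrase. In it, output_shapes is a list of per-component shapes, with None standing in for an unknown rank and an unknown leading dimension encoded as -1 (my encoding, not the TensorShapeProto one):

def get_minibatch_dim_for_reshape(output_shapes, num_replicas):
    batch_dim = -1
    for i, shape in enumerate(output_shapes):
        if shape is None:        # unknown rank: ignore this component
            continue
        batch_dim_i = shape[0]
        if batch_dim_i == -1:    # unknown batch dimension: ignore it too
            continue
        if batch_dim != -1 and batch_dim_i != batch_dim:
            raise ValueError(
                "Component %d has batch dimension %d while previous components "
                "have batch dimension %d" % (i, batch_dim_i, batch_dim))
        batch_dim = batch_dim_i
    if batch_dim == -1 or batch_dim % num_replicas != 0:
        return -1                # caller should use the batch-based flat map
    return batch_dim // num_replicas

Unlike the removed GetBatchDim, an unknown or non-divisible batch dimension is no longer an error; it simply yields -1 so the caller can pick the dynamic fallback instead.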
@@ -411,6 +428,8 @@ Status AddFlatMapNode(const string& input_dataset,
}
// def flat_map_fn(*batched_components):
// batch_size = tf.shape(batched_components[0])[0]
// minibatch_size = (batch_size + num_replicas - 1) // num_replicas
// ds = tf.data.Dataset.from_tensor_slices(batched_components)
// return ds.batch(minibatch_size, drop_remainder=False)
Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes,
@@ -439,13 +458,32 @@ Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes,
batch_node->add_input(
strings::StrCat(tensor_slice_node->name(), ":handle:0"));
// `batch_size` input
// Here, we capture the original batch size from outside the flat map fn.
auto* original_batch_size =
function_utils::AddFunctionInput("captured_batch_size", result, DT_INT64);
// `batch_size` is tf.shape(arg)[0]
NodeDef* shape;
TF_RETURN_IF_ERROR(AddShapeNode({tensor_slice_node->input(0), 0, dtypes[0]},
DT_INT64, result, &shape));
// Const with value [0]
NodeDef* const_vec_0;
TF_RETURN_IF_ERROR(AddConstIntNode({0}, {1}, result, &const_vec_0));
// Const with value [1]
NodeDef* const_vec_1;
TF_RETURN_IF_ERROR(AddConstIntNode({1}, {1}, result, &const_vec_1));
// Extracts the 0th dimension from the shape node.
NodeDef* original_batch_size;
TF_RETURN_IF_ERROR(AddStridedSliceNode(
{strings::StrCat(shape->name(), ":output"), 0, DT_INT64},
{strings::StrCat(const_vec_0->name(), ":output"), 0, DT_INT32},
{strings::StrCat(const_vec_1->name(), ":output"), 0, DT_INT32},
{strings::StrCat(const_vec_1->name(), ":output"), 0, DT_INT32}, DT_INT32,
0, 0, 0, 0, 1, result, &original_batch_size));
NodeDef* new_batch_size;
TF_RETURN_IF_ERROR(MakeNewBatchSizeNode(
original_batch_size->name(), num_replicas, result, &new_batch_size));
strings::StrCat(original_batch_size->name(), ":output:0"), num_replicas,
result, &new_batch_size));
batch_node->add_input(strings::StrCat(new_batch_size->name(), ":z:0"));
// `drop_remainder` input
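The two hunks above build, as a FunctionDef, the flat map body sketched in the "// def flat_map_fn" comment: a Shape node (int64), a strided slice that pulls out the 0th dimension, the ceil-division MakeNewBatchSizeNode, and a batch over a TensorSliceDataset, with the previously captured batch-size argument removed. A hedged eager-mode equivalent (the helper name is mine):

import tensorflow as tf

def rebatch_fallback_with_batch(dataset, num_replicas):
    def flat_map_fn(*batched_components):
        # batch_size is computed from the data at runtime, so an unknown or
        # uneven 0th dimension is fine.
        batch_size = tf.shape(batched_components[0], out_type=tf.int64)[0]
        minibatch_size = (batch_size + num_replicas - 1) // num_replicas
        # Unwrap single-component datasets so the element structure is preserved.
        components = (batched_components if len(batched_components) > 1
                      else batched_components[0])
        ds = tf.data.Dataset.from_tensor_slices(components)
        return ds.batch(minibatch_size, drop_remainder=False)
    return dataset.flat_map(flat_map_fn)

# Example: batches of 5 with num_replicas=2 come out as minibatches of (3, 2).
ds = tf.data.Dataset.range(10).batch(5, drop_remainder=False)
rebatched = rebatch_fallback_with_batch(ds, num_replicas=2)
# Elements: [0, 1, 2], [3, 4], [5, 6, 7], [8, 9]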
@@ -486,8 +524,6 @@ Status AppendFlatMap(const NodeDef& batch_node, int64 num_replicas,
TF_RETURN_IF_ERROR(
CreateFlatMapFnWithBatch(dtypes, num_replicas, &flat_map_fn));
int64 batch_size_index = GetBatchSizeArgIndex(batch_node);
NodeDef* flat_map_node;
AttrValue output_shapes = batch_node.attr().at(kOutputShapesAttr);
@@ -502,9 +538,8 @@ Status AppendFlatMap(const NodeDef& batch_node, int64 num_replicas,
}
TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(batch_node.name(), ":0"),
{batch_node.input(batch_size_index)},
{DT_INT64}, flat_map_fn, output_shapes,
dtypes, flib, graph, &flat_map_node));
{}, {}, flat_map_fn, output_shapes, dtypes,
flib, graph, &flat_map_node));
TF_RETURN_IF_ERROR(
graph->UpdateFanouts(batch_node.name(), flat_map_node->name()));
@@ -650,7 +685,7 @@ Status ReshapeComponent(int new_batch_dim, const string& arg, DataType dtype,
// shape = tf.shape(arg)
NodeDef* shape;
TF_RETURN_IF_ERROR(AddShapeNode({arg, 0, dtype}, fdef, &shape));
TF_RETURN_IF_ERROR(AddShapeNode({arg, 0, dtype}, DT_INT32, fdef, &shape));
// later_dimensions = tf.shape(arg)[1:]
NodeDef* later_dimensions;
@@ -748,26 +783,6 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_replicas,
fetch_node = graph_utils::GetInputNode(*fetch_node, *graph, 0);
}
// Note: Here, we are conservative with only using the fallback when
// the output_shapes attr has the 0th dimension defined for every component.
// This because the flat_map_fn will fail if the batch does not divide evenly
// because of the use of the "Reshape" op. This ensures that the error is
// surfaced correctly.
AttrValue output_shapes;
if (!fetch_node->attr().contains(kOutputShapesAttr)) {
return errors::InvalidArgument(
"Cannot use rebatching fallback without output_shapes attr. Node: ",
fetch_node->name(), " Op: ", fetch_node->op());
} else {
output_shapes = fetch_node->attr().at(kOutputShapesAttr);
}
int batch_dim;
TF_RETURN_IF_ERROR(GetBatchDim(output_shapes, &batch_dim));
if (batch_dim % num_replicas != 0) {
return errors::InvalidArgument(
"Cannot use rebatching fallback when batch dimension doesn't divide "
"num_replicas evenly.");
}
// Create the flat map fn
FunctionDef flat_map_fn;
@@ -779,9 +794,26 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_replicas,
DataTypeVector output_types;
TF_RETURN_IF_ERROR(
graph_utils::GetDatasetOutputTypesAttr(*fetch_node, &output_types));
TF_RETURN_IF_ERROR(CreateFlatMapFnWithReshape(batch_dim / num_replicas,
output_types, &flat_map_fn));
int64 minibatch_dim;
// If the batch dimension is known and perfectly divisible by num_replicas,
// we use a fallback with `tf.reshape` for better performance.
TF_RETURN_IF_ERROR(
GetMinibatchDimForReshape(*fetch_node, num_replicas, &minibatch_dim));
if (minibatch_dim != -1) {
TF_RETURN_IF_ERROR(
CreateFlatMapFnWithReshape(minibatch_dim, output_types, &flat_map_fn));
} else {
TF_RETURN_IF_ERROR(
CreateFlatMapFnWithBatch(output_types, num_replicas, &flat_map_fn));
}
AttrValue output_shapes;
if (fetch_node->attr().contains(kOutputShapesAttr)) {
output_shapes = fetch_node->attr().at(kOutputShapesAttr);
} else {
SetUnknownShapes(output_types.size(), &output_shapes);
}
NodeDef* flat_map_node;
TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(fetch_node->name(), ":0"),
{}, {}, flat_map_fn, output_shapes,
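This is where RebatchWithFallback now chooses between the two flat-map bodies: the reshape-based one when GetMinibatchDimForReshape returns a real minibatch size, and the batch-based one (sketched earlier) otherwise, with output_shapes taken from the fetch node's attr or set to unknown. For reference, the reshape fast path corresponds roughly to this eager-mode sketch (my naming, and only an approximation of ReshapeComponent):

import tensorflow as tf

def reshape_split(batch, minibatch_dim):
    # [N, ...] -> [N // minibatch_dim, minibatch_dim, ...]. tf.reshape requires N
    # to divide exactly, which is why this path is only taken when the static
    # batch dimension is known and divisible by num_replicas.
    later_dims = tf.shape(batch)[1:]
    new_shape = tf.concat([[-1, minibatch_dim], later_dims], axis=0)
    return tf.data.Dataset.from_tensor_slices(tf.reshape(batch, new_shape))

# 32 rows with 4 replicas -> 4 minibatches of 8 rows each.
minibatches = reshape_split(tf.zeros([32, 10]), minibatch_dim=8)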

@@ -377,11 +377,8 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
# The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
# the batches of 10 (value == 1) split into minibatches of (5, 5)
# [(batch_size, value), ...]
pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (5, 1), (5, 1), (1, 1),
(3, 0), (2, 0), (3, 0), (1, 0), (5, 1), (4, 1)]
(3, 0), (2, 0), (2, 0), (2, 0), (5, 1), (4, 1)]
expected_output = [[value] * batch_size for batch_size, value in pairs]
self.assertDatasetProduces(dataset, expected_output)
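The updated expected values follow from per-batch ceiling division: every incoming batch of size b is cut into chunks of ceil(b / num_replicas), with a possibly smaller final chunk. A quick arithmetic check (the hunk does not show this test's num_replicas, but the (3, 2) / (5, 5) comment implies 2; the 32-element case corresponds to testBatchSizeNotDivisibleByNumReplicas further down):

import math

def split_sizes(b, num_replicas):
    chunk = math.ceil(b / num_replicas)
    sizes = []
    while b > 0:
        sizes.append(min(chunk, b))
        b -= sizes[-1]
    return sizes

assert split_sizes(5, 2) == [3, 2]
assert split_sizes(10, 2) == [5, 5]
assert split_sizes(4, 2) == [2, 2]            # the changed tail of `pairs`
assert split_sizes(32, 5) == [7, 7, 7, 7, 4]  # 4 minibatches of 7, then 4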
@@ -497,36 +494,39 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase):
def testWithUnknownBatchDim(self):
dataset = dataset_ops.Dataset.range(1024).batch(
32, drop_remainder=False).apply(sleep.sleep(10))
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Cannot use rebatching fallback"):
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
next_element = self.getNext(rebatched_dataset)
self.evaluate(next_element())
expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
self.assertDatasetProduces(rebatched_dataset, expected_output)
def testWithUnknownBatchDimInSecondComponent(self):
dataset0 = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
dataset1 = dataset_ops.Dataset.range(1024).batch(
32, drop_remainder=False).apply(sleep.sleep(10))
dataset = dataset_ops.Dataset.zip((dataset0, dataset1))
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Cannot use rebatching fallback"):
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
next_element = self.getNext(rebatched_dataset)
self.evaluate(next_element())
expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
expected_output = [(x, x) for x in expected_output]
self.assertDatasetProduces(rebatched_dataset, expected_output)
def testBatchSizeNotDivisibleByNumReplicas(self):
# This doesn't work; reshape requires tensor shape to be exactly divisible
# by the second dim.
dataset = dataset_ops.Dataset.range(64).batch(
32, drop_remainder=True).apply(sleep.sleep(10))
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Cannot use rebatching fallback"):
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
next_element = self.getNext(rebatched_dataset)
self.evaluate(next_element())
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
expected_output = []
i = 0
for _ in range(2): # number of steps
# first four minibatches have seven elements
for _ in range(4):
expected_output.append([k for k in range(i, i + 7)])
i += 7
# last minibatch has four elements
expected_output.append([k for k in range(i, i + 4)])
i += 4
self.assertDatasetProduces(rebatched_dataset, expected_output)
def testBatchSizesDontMatch(self):
dataset = dataset_ops.Dataset.from_tensors((np.arange(10), np.arange(5)))