[tf.data] Make rebatching fallback work for datasets with unknown shapes.
PiperOrigin-RevId: 264501770
Commit: 7107f907aa
Parent: 6208021e3d
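At a high level, the fallback now rewrites the final dataset into a flat_map that re-splits each incoming batch at runtime, so the batch dimension no longer has to be statically known. A rough Python equivalent of the function the rewrite builds (an illustrative sketch only; the grappler pass below constructs the corresponding graph nodes directly, and the wrapper name here is invented):

    import tensorflow as tf

    def rebatch_fallback(dataset, num_replicas):
      # Mirrors the flat_map_fn pseudocode in the comments below: recover the
      # batch size at runtime with tf.shape, then ceiling-divide by num_replicas.
      def flat_map_fn(*batched_components):
        batch_size = tf.shape(batched_components[0], out_type=tf.int64)[0]
        minibatch_size = (batch_size + num_replicas - 1) // num_replicas
        ds = tf.data.Dataset.from_tensor_slices(batched_components)
        return ds.batch(minibatch_size, drop_remainder=False)
      return dataset.flat_map(flat_map_fn)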
@@ -200,11 +200,13 @@ Status AddConstBoolNode(bool value, FunctionDef* fdef, NodeDef** result) {
   return Status::OK();
 }
 
-Status AddShapeNode(const NodeDefBuilder::NodeOut& input, FunctionDef* fdef,
-                    NodeDef** result) {
+Status AddShapeNode(const NodeDefBuilder::NodeOut& input, DataType out_type,
+                    FunctionDef* fdef, NodeDef** result) {
   *result = fdef->add_node_def();
-  TF_RETURN_IF_ERROR(
-      NodeDefBuilder("", "Shape").Input(input).Finalize(*result));
+  TF_RETURN_IF_ERROR(NodeDefBuilder("", "Shape")
+                         .Input(input)
+                         .Attr("out_type", out_type)
+                         .Finalize(*result));
   function_utils::SetUniqueFunctionNodeName("rebatch/shape", fdef, *result);
   return Status::OK();
 }
@@ -276,34 +278,49 @@ void SetUnknownShapes(int num_components, AttrValue* output_shapes) {
   }
 }
 
-Status GetBatchDim(AttrValue output_shapes, int* batch_dim) {
-  const auto& shape_0 = output_shapes.list().shape(0);
-  if (shape_0.unknown_rank() || shape_0.dim(0).size() == -1) {
+// If the batch dimension is known and divisible by num_replicas, we set
+// result = batch_dim / num_replicas. If the batch dimension is unknown,
+// result = -1. If the dataset node is missing an output shapes attr,
+// or the batch dimensions of its components don't match, we return an error
+// status.
+Status GetMinibatchDimForReshape(const NodeDef& dataset_node,
+                                 int64 num_replicas, int64* result) {
+  AttrValue output_shapes;
+  if (!dataset_node.attr().contains(kOutputShapesAttr)) {
     return errors::InvalidArgument(
-        "Cannot use rebatching fallback when 0th dimensions of dataset "
-        "components are not fully known. Component 0 has shape: ",
-        shape_0.ShortDebugString());
+        "Cannot use rebatching fallback when the final dataset node does not "
+        "have an `output_shapes` attr. Node: ",
+        dataset_node.name(), " Op: ", dataset_node.op());
   }
+  output_shapes = dataset_node.attr().at(kOutputShapesAttr);
 
-  *batch_dim = output_shapes.list().shape(0).dim(0).size();
-  for (int i = 1; i < output_shapes.list().shape_size(); ++i) {
+  // Get the batch dimension by checking the 0th dimension of all the inputs.
+  int batch_dim = -1;
+  for (int i = 0; i < output_shapes.list().shape_size(); ++i) {
     const auto& shape_i = output_shapes.list().shape(i);
 
-    if (shape_i.unknown_rank() || shape_i.dim(0).size() == -1) {
+    // If unknown, ignore.
+    if (shape_i.unknown_rank()) continue;
+    int batch_dim_i = shape_i.dim(0).size();
+    if (batch_dim_i == -1) continue;
+
+    // Update batch_dim with known dimension.
+    if (batch_dim_i != batch_dim && batch_dim != -1) {
       return errors::InvalidArgument(
-          "Cannot use rebatching fallback when 0th dimensions of dataset "
-          "components are not fully known. Component ",
-          i, " has shape: ", shape_i.ShortDebugString());
-    }
-    if (shape_i.dim(0).size() != *batch_dim) {
-      return errors::InvalidArgument(
-          "Cannot use rebatching fallback when 0th dimensions of dataset "
+          "Cannot use rebatching fallback: 0th dimensions of dataset "
           "components don't match. Component ",
-          i, " has batch dimension: ", shape_i.dim(0).size(),
-          " while previous components have batch dimension: ", *batch_dim);
+          i, " has batch dimension: ", batch_dim_i,
+          " while previous components have batch dimension: ", batch_dim);
     }
+    batch_dim = batch_dim_i;
   }
 
+  if (batch_dim == -1 || batch_dim % num_replicas != 0) {
+    *result = -1;
+  } else {
+    *result = batch_dim / num_replicas;
+  }
+
   return Status::OK();
 }
 
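In other words, the reshape-based path is now chosen only when every component's 0th dimension is statically known, consistent across components, and divisible by num_replicas; in all other cases the function reports -1 and the caller uses the runtime batch-based fallback (CreateFlatMapFnWithBatch, below). A small Python rendering of the same decision logic (illustrative only, not the grappler code):

    def minibatch_dim_for_reshape(output_shapes, num_replicas):
      # Returns batch_dim // num_replicas, or -1 meaning "use the batch fallback".
      batch_dim = -1
      for i, shape in enumerate(output_shapes):  # e.g. [None, [None, 3], [32, 3]]
        if shape is None:  # unknown rank: ignore
          continue
        dim0 = -1 if shape[0] is None else shape[0]
        if dim0 == -1:  # unknown batch dimension: ignore
          continue
        if batch_dim not in (-1, dim0):
          raise ValueError(
              "0th dimensions of dataset components don't match (component %d)" % i)
        batch_dim = dim0
      if batch_dim == -1 or batch_dim % num_replicas != 0:
        return -1  # e.g. batch 32 with num_replicas=5, or unknown batch dim
      return batch_dim // num_replicas  # e.g. batch 32 with num_replicas=4 -> 8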
@@ -411,6 +428,8 @@ Status AddFlatMapNode(const string& input_dataset,
 }
 
 // def flat_map_fn(*batched_components):
+//   batch_size = tf.shape(batched_components[0])[0]
+//   minibatch_size = (batch_size + num_replicas - 1) // num_replicas
 //   ds = tf.data.Dataset.from_tensor_slices(batched_components)
 //   return ds.batch(minibatch_size, drop_remainder=False)
 Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes,
@@ -439,13 +458,32 @@ Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes,
   batch_node->add_input(
       strings::StrCat(tensor_slice_node->name(), ":handle:0"));
 
-  // `batch_size` input
-  // Here, we capture the original batch size from outside the flat map fn.
-  auto* original_batch_size =
-      function_utils::AddFunctionInput("captured_batch_size", result, DT_INT64);
+  // `batch_size` is tf.shape(arg)[0]
+  NodeDef* shape;
+  TF_RETURN_IF_ERROR(AddShapeNode({tensor_slice_node->input(0), 0, dtypes[0]},
+                                  DT_INT64, result, &shape));
+
+  // Const with value [0]
+  NodeDef* const_vec_0;
+  TF_RETURN_IF_ERROR(AddConstIntNode({0}, {1}, result, &const_vec_0));
+
+  // Const with value [1]
+  NodeDef* const_vec_1;
+  TF_RETURN_IF_ERROR(AddConstIntNode({1}, {1}, result, &const_vec_1));
+
+  // Extracts the 0th dimension from the shape node.
+  NodeDef* original_batch_size;
+  TF_RETURN_IF_ERROR(AddStridedSliceNode(
+      {strings::StrCat(shape->name(), ":output"), 0, DT_INT64},
+      {strings::StrCat(const_vec_0->name(), ":output"), 0, DT_INT32},
+      {strings::StrCat(const_vec_1->name(), ":output"), 0, DT_INT32},
+      {strings::StrCat(const_vec_1->name(), ":output"), 0, DT_INT32}, DT_INT32,
+      0, 0, 0, 0, 1, result, &original_batch_size));
+
   NodeDef* new_batch_size;
   TF_RETURN_IF_ERROR(MakeNewBatchSizeNode(
-      original_batch_size->name(), num_replicas, result, &new_batch_size));
+      strings::StrCat(original_batch_size->name(), ":output:0"), num_replicas,
+      result, &new_batch_size));
   batch_node->add_input(strings::StrCat(new_batch_size->name(), ":z:0"));
 
   // `drop_remainder` input
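The Shape, Const, and StridedSlice nodes emitted above compute the batch size inside the generated function; in eager Python the same computation would look roughly like this (a sketch, assuming MakeNewBatchSizeNode performs the ceiling division shown in the flat_map_fn pseudocode comment):

    import tensorflow as tf

    def runtime_minibatch_size(component, num_replicas):
      # Shape with out_type int64, then a strided slice [0:1:1] to pull out the
      # 0th (batch) dimension as a length-1 vector.
      shape = tf.shape(component, out_type=tf.int64)
      original_batch_size = tf.strided_slice(shape, [0], [1], [1])
      # Ceiling division so a remainder batch still yields num_replicas pieces.
      return (original_batch_size + num_replicas - 1) // num_replicas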
@@ -486,8 +524,6 @@ Status AppendFlatMap(const NodeDef& batch_node, int64 num_replicas,
   TF_RETURN_IF_ERROR(
       CreateFlatMapFnWithBatch(dtypes, num_replicas, &flat_map_fn));
 
-  int64 batch_size_index = GetBatchSizeArgIndex(batch_node);
-
   NodeDef* flat_map_node;
 
   AttrValue output_shapes = batch_node.attr().at(kOutputShapesAttr);
@@ -502,9 +538,8 @@ Status AppendFlatMap(const NodeDef& batch_node, int64 num_replicas,
   }
 
   TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(batch_node.name(), ":0"),
-                                    {batch_node.input(batch_size_index)},
-                                    {DT_INT64}, flat_map_fn, output_shapes,
-                                    dtypes, flib, graph, &flat_map_node));
+                                    {}, {}, flat_map_fn, output_shapes, dtypes,
+                                    flib, graph, &flat_map_node));
 
   TF_RETURN_IF_ERROR(
       graph->UpdateFanouts(batch_node.name(), flat_map_node->name()));
@@ -650,7 +685,7 @@ Status ReshapeComponent(int new_batch_dim, const string& arg, DataType dtype,
 
   // shape = tf.shape(arg)
   NodeDef* shape;
-  TF_RETURN_IF_ERROR(AddShapeNode({arg, 0, dtype}, fdef, &shape));
+  TF_RETURN_IF_ERROR(AddShapeNode({arg, 0, dtype}, DT_INT32, fdef, &shape));
 
   // later_dimensions = tf.shape(arg)[1:]
   NodeDef* later_dimensions;
@@ -748,26 +783,6 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_replicas,
     fetch_node = graph_utils::GetInputNode(*fetch_node, *graph, 0);
   }
 
-  // Note: Here, we are conservative with only using the fallback when
-  // the output_shapes attr has the 0th dimension defined for every component.
-  // This because the flat_map_fn will fail if the batch does not divide evenly
-  // because of the use of the "Reshape" op. This ensures that the error is
-  // surfaced correctly.
-  AttrValue output_shapes;
-  if (!fetch_node->attr().contains(kOutputShapesAttr)) {
-    return errors::InvalidArgument(
-        "Cannot use rebatching fallback without output_shapes attr. Node: ",
-        fetch_node->name(), " Op: ", fetch_node->op());
-  } else {
-    output_shapes = fetch_node->attr().at(kOutputShapesAttr);
-  }
-  int batch_dim;
-  TF_RETURN_IF_ERROR(GetBatchDim(output_shapes, &batch_dim));
-  if (batch_dim % num_replicas != 0) {
-    return errors::InvalidArgument(
-        "Cannot use rebatching fallback when batch dimension doesn't divide "
-        "num_replicas evenly.");
-  }
 
   // Create the flat map fn
   FunctionDef flat_map_fn;
@@ -779,9 +794,26 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_replicas,
   DataTypeVector output_types;
   TF_RETURN_IF_ERROR(
       graph_utils::GetDatasetOutputTypesAttr(*fetch_node, &output_types));
-  TF_RETURN_IF_ERROR(CreateFlatMapFnWithReshape(batch_dim / num_replicas,
-                                                output_types, &flat_map_fn));
+  int64 minibatch_dim;
+  // If the batch dimension is known and perfectly divisible by num_replicas,
+  // we use a fallback with `tf.reshape` for better performance.
+  TF_RETURN_IF_ERROR(
+      GetMinibatchDimForReshape(*fetch_node, num_replicas, &minibatch_dim));
+  if (minibatch_dim != -1) {
+    TF_RETURN_IF_ERROR(
+        CreateFlatMapFnWithReshape(minibatch_dim, output_types, &flat_map_fn));
+  } else {
+    TF_RETURN_IF_ERROR(
+        CreateFlatMapFnWithBatch(output_types, num_replicas, &flat_map_fn));
+  }
+  AttrValue output_shapes;
+  if (fetch_node->attr().contains(kOutputShapesAttr)) {
+    output_shapes = fetch_node->attr().at(kOutputShapesAttr);
+  } else {
+    SetUnknownShapes(output_types.size(), &output_shapes);
+  }
 
   NodeDef* flat_map_node;
   TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(fetch_node->name(), ":0"),
                                     {}, {}, flat_map_fn, output_shapes,
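The remaining hunks update the Python rebatching tests. Cases that previously asserted an InvalidArgumentError ("Cannot use rebatching fallback") for unknown batch dimensions or batch sizes not divisible by num_replicas now assert the minibatches actually produced by the batch-based fallback. The tests apply an opaque sleep transformation after batching, which appears to be there to keep the optimizer from rewriting the batch node directly, so the fallback path is exercised. The user-visible behavior, roughly (using the module aliases from the test file):

    dataset = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=False)
    dataset = dataset.apply(sleep.sleep(10))
    # Previously raised "Cannot use rebatching fallback"; now each batch of 32
    # is split into 4 minibatches of 8 elements.
    rebatched = distribute._RebatchDataset(dataset, num_replicas=4)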
@@ -377,11 +377,8 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
 
-    # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
-    # the batches of 10 (value == 1) split into minibatches of (5, 5)
-    # [(batch_size, value), ...]
     pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (5, 1), (5, 1), (1, 1),
-             (3, 0), (2, 0), (3, 0), (1, 0), (5, 1), (4, 1)]
+             (3, 0), (2, 0), (2, 0), (2, 0), (5, 1), (4, 1)]
     expected_output = [[value] * batch_size for batch_size, value in pairs]
     self.assertDatasetProduces(dataset, expected_output)
 
@@ -497,36 +494,39 @@ class RebatchDatasetFallbackTest(test_base.DatasetTestBase):
   def testWithUnknownBatchDim(self):
     dataset = dataset_ops.Dataset.range(1024).batch(
         32, drop_remainder=False).apply(sleep.sleep(10))
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
 
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "Cannot use rebatching fallback"):
-      rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-      next_element = self.getNext(rebatched_dataset)
-      self.evaluate(next_element())
+    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
+    self.assertDatasetProduces(rebatched_dataset, expected_output)
 
   def testWithUnknownBatchDimInSecondComponent(self):
     dataset0 = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
     dataset1 = dataset_ops.Dataset.range(1024).batch(
         32, drop_remainder=False).apply(sleep.sleep(10))
     dataset = dataset_ops.Dataset.zip((dataset0, dataset1))
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
 
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "Cannot use rebatching fallback"):
-      rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-      next_element = self.getNext(rebatched_dataset)
-      self.evaluate(next_element())
+    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
+    expected_output = [(x, x) for x in expected_output]
+    self.assertDatasetProduces(rebatched_dataset, expected_output)
 
   def testBatchSizeNotDivisibleByNumReplicas(self):
-    # This doesn't work; reshape requires tensor shape to be exactly divisible
-    # by the second dim.
     dataset = dataset_ops.Dataset.range(64).batch(
         32, drop_remainder=True).apply(sleep.sleep(10))
 
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "Cannot use rebatching fallback"):
-      rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
-      next_element = self.getNext(rebatched_dataset)
-      self.evaluate(next_element())
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
+
+    expected_output = []
+    i = 0
+    for _ in range(2):  # number of steps
+      # first four minibatches have seven elements
+      for _ in range(4):
+        expected_output.append([k for k in range(i, i + 7)])
+        i += 7
+      # last minibatch has four elements
+      expected_output.append([k for k in range(i, i + 4)])
+      i += 4
+    self.assertDatasetProduces(rebatched_dataset, expected_output)
 
   def testBatchSizesDontMatch(self):
     dataset = dataset_ops.Dataset.from_tensors((np.arange(10), np.arange(5)))