[xla_compiler] Do not promote TF shape constant folding when input is a XLA dynamic shape.

PiperOrigin-RevId: 356904310
Change-Id: Iff329b8f81777f895333726e8ca98e2d3ad4ddb5
This commit is contained in:
Yunxing Dai 2021-02-10 22:33:00 -08:00 committed by TensorFlower Gardener
parent 24720e5940
commit 443f13e41a
2 changed files with 41 additions and 1 deletions

View File

@ -760,7 +760,10 @@ Status XlaCompiler::CompileFunction(
if (absl::holds_alternative<xla::Shape>(args[i].shape)) {
xla::Shape xla_shape = absl::get<xla::Shape>(args[i].shape);
TensorShape tensor_shape;
if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok()) {
// If xla_shape is dynamic, prevent constant folding by not setting
// output_shapes.
if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok() &&
xla_shape.is_static()) {
fbody->arg_nodes[i]->ClearAttr("_output_shapes");
fbody->arg_nodes[i]->AddAttr("_output_shapes",
std::vector<TensorShape>{tensor_shape});

View File

@ -393,6 +393,43 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
train_step()
self.assertEqual(2.0, v.numpy())
def test_cluster_conditional_with_dynamic_shape(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
@def_function.function
def train_step():
def shape_list(tensor):
shape = tensor.shape.as_list()
non_static_indexes = []
for (index, dim) in enumerate(shape):
if dim is None:
non_static_indexes.append(index)
if not non_static_indexes:
return shape
dynamic_shape = array_ops.shape(input=tensor)
for index in non_static_indexes:
shape[index] = dynamic_shape[index]
return shape
def step_fn(condition):
where = array_ops.where(condition)
if array_ops.shape(where)[0] > 0:
tensor_shape = shape_list(where)
d1 = tensor_shape[0]
d2 = tensor_shape[1]
where = array_ops.reshape(where, [d1, d2])
return where
return strategy.run(step_fn, args=([True, False, True],))
outputs = strategy.experimental_local_results(train_step())
self.assertAllEqual(outputs[0].numpy(), [[0], [2]])
def test_cluster_in_graph_and_while_body_fn(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)