[xla_compiler] Do not promote TF shape constant folding when input is a XLA dynamic shape.
PiperOrigin-RevId: 356904310 Change-Id: Iff329b8f81777f895333726e8ca98e2d3ad4ddb5
This commit is contained in:
parent
24720e5940
commit
443f13e41a
@ -760,7 +760,10 @@ Status XlaCompiler::CompileFunction(
|
||||
if (absl::holds_alternative<xla::Shape>(args[i].shape)) {
|
||||
xla::Shape xla_shape = absl::get<xla::Shape>(args[i].shape);
|
||||
TensorShape tensor_shape;
|
||||
if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok()) {
|
||||
// If xla_shape is dynamic, prevent constant folding by not setting
|
||||
// output_shapes.
|
||||
if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok() &&
|
||||
xla_shape.is_static()) {
|
||||
fbody->arg_nodes[i]->ClearAttr("_output_shapes");
|
||||
fbody->arg_nodes[i]->AddAttr("_output_shapes",
|
||||
std::vector<TensorShape>{tensor_shape});
|
||||
|
@ -393,6 +393,43 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
||||
train_step()
|
||||
self.assertEqual(2.0, v.numpy())
|
||||
|
||||
def test_cluster_conditional_with_dynamic_shape(self, enable_packed_var):
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
|
||||
@def_function.function
|
||||
def train_step():
|
||||
|
||||
def shape_list(tensor):
|
||||
shape = tensor.shape.as_list()
|
||||
|
||||
non_static_indexes = []
|
||||
for (index, dim) in enumerate(shape):
|
||||
if dim is None:
|
||||
non_static_indexes.append(index)
|
||||
|
||||
if not non_static_indexes:
|
||||
return shape
|
||||
|
||||
dynamic_shape = array_ops.shape(input=tensor)
|
||||
for index in non_static_indexes:
|
||||
shape[index] = dynamic_shape[index]
|
||||
|
||||
return shape
|
||||
|
||||
def step_fn(condition):
|
||||
where = array_ops.where(condition)
|
||||
if array_ops.shape(where)[0] > 0:
|
||||
tensor_shape = shape_list(where)
|
||||
d1 = tensor_shape[0]
|
||||
d2 = tensor_shape[1]
|
||||
where = array_ops.reshape(where, [d1, d2])
|
||||
return where
|
||||
|
||||
return strategy.run(step_fn, args=([True, False, True],))
|
||||
|
||||
outputs = strategy.experimental_local_results(train_step())
|
||||
self.assertAllEqual(outputs[0].numpy(), [[0], [2]])
|
||||
|
||||
def test_cluster_in_graph_and_while_body_fn(self, enable_packed_var):
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user