[xla_compiler] Do not promote TF shape constant folding when the input is an XLA dynamic shape.
PiperOrigin-RevId: 356904310
Change-Id: Iff329b8f81777f895333726e8ca98e2d3ad4ddb5
commit 443f13e41a
parent 24720e5940
@@ -760,7 +760,10 @@ Status XlaCompiler::CompileFunction(
     if (absl::holds_alternative<xla::Shape>(args[i].shape)) {
       xla::Shape xla_shape = absl::get<xla::Shape>(args[i].shape);
       TensorShape tensor_shape;
-      if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok()) {
+      // If xla_shape is dynamic, prevent constant folding by not setting
+      // output_shapes.
+      if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok() &&
+          xla_shape.is_static()) {
         fbody->arg_nodes[i]->ClearAttr("_output_shapes");
         fbody->arg_nodes[i]->AddAttr("_output_shapes",
                                      std::vector<TensorShape>{tensor_shape});
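Context for the guard above: per the new comment, recording a fixed TensorShape in the arg node's _output_shapes attribute is what allows TF shape constant folding to fire, so a dynamic XLA shape must not be written there. A minimal sketch of the static-versus-dynamic shape distinction that the new xla_shape.is_static() check hinges on, in plain TensorFlow 2.x Python (no TPU required; count_true is an illustrative name, not part of this change):

    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.bool)])
    def count_true(condition):
      # tf.where makes the leading dimension data-dependent: [num_true, 1].
      where = tf.where(condition)
      # Static view is unknown at trace time, so it cannot be folded away.
      assert where.shape.as_list() == [None, 1]
      # Dynamic view is a real tensor carrying the runtime shape.
      return tf.shape(where)[0]

    print(count_true(tf.constant([True, False, True])).numpy())  # prints 2

Folding the dynamic view to a trace-time constant would bake in whatever size the first execution happened to see, which is exactly the promotion this commit disables.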
@@ -393,6 +393,43 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
     train_step()
     self.assertEqual(2.0, v.numpy())
 
+  def test_cluster_conditional_with_dynamic_shape(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
+
+    @def_function.function
+    def train_step():
+
+      def shape_list(tensor):
+        shape = tensor.shape.as_list()
+
+        non_static_indexes = []
+        for (index, dim) in enumerate(shape):
+          if dim is None:
+            non_static_indexes.append(index)
+
+        if not non_static_indexes:
+          return shape
+
+        dynamic_shape = array_ops.shape(input=tensor)
+        for index in non_static_indexes:
+          shape[index] = dynamic_shape[index]
+
+        return shape
+
+      def step_fn(condition):
+        where = array_ops.where(condition)
+        if array_ops.shape(where)[0] > 0:
+          tensor_shape = shape_list(where)
+          d1 = tensor_shape[0]
+          d2 = tensor_shape[1]
+          where = array_ops.reshape(where, [d1, d2])
+        return where
+
+      return strategy.run(step_fn, args=([True, False, True],))
+
+    outputs = strategy.experimental_local_results(train_step())
+    self.assertAllEqual(outputs[0].numpy(), [[0], [2]])
+
   def test_cluster_in_graph_and_while_body_fn(self, enable_packed_var):
     strategy = get_tpu_strategy(enable_packed_var)
 
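The test itself needs a TPU, but the shape_list pattern it exercises is general. Below is a standalone rendition runnable on CPU, using public tf.* APIs in place of the internal array_ops/def_function modules the test imports (a sketch of the same pattern, not the test itself):

    import tensorflow as tf

    def shape_list(tensor):
      """Per dimension, a static Python int or a dynamic scalar tensor."""
      static = tensor.shape.as_list()    # e.g. [None, 1] under tf.function
      dynamic = tf.shape(tensor)         # runtime shape as an int32 tensor
      # Substitute the runtime dimension wherever the static one is unknown.
      return [dynamic[i] if dim is None else dim
              for i, dim in enumerate(static)]

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.bool)])
    def step_fn(condition):
      where = tf.where(condition)        # dynamic leading dimension
      d1, d2 = shape_list(where)
      return tf.reshape(where, [d1, d2])

    print(step_fn(tf.constant([True, False, True])).numpy())  # [[0], [2]]

Because d1 stays a tensor rather than a trace-time constant, the reshape remains shape-polymorphic, which is the behavior the guard in XlaCompiler::CompileFunction preserves for dynamic XLA shapes.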