Make minimum padding size 2 in tpu.replicate.
XLA doesn't support a dynamic dimension with an upper bound of 1, so make the minimum padding size 2.

PiperOrigin-RevId: 276603306
Change-Id: I4b64f31adc4b62e160e831f05ccc2a2326ff6d6f
parent 15ee2eba23
commit 7785075046
@@ -823,13 +823,19 @@ def _pad_all_input(inputs, padded_shapes):
       paddings = []
       for i, s in enumerate(padded_shape.dims):
         if need_padding[idx][i]:
+          # The minimum padded dimension size is 2 as XLA doesn't support size
+          # 1 dynamic size.
+          minimum_dynamic_dim_size = 2
           if s.value:
             # Pad to the given maximum value.
-            padding = [0, s.value - input_shape_tensor[i]]
+            max_dim_size = max(s.value, minimum_dynamic_dim_size)
           else:
             # If maximum value is not given, then pad to the maximum dimension
             # among all the cores.
-            padding = [0, maximum_shapes[idx][i] - input_shape_tensor[i]]
+            max_dim_size = math_ops.maximum(maximum_shapes[idx][i],
+                                            minimum_dynamic_dim_size)
+          # Pad to the given maximum value.
+          padding = [0, max_dim_size - input_shape_tensor[i]]
         else:
           padding = [0, 0]
         paddings.append(padding)
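For reference, here is a minimal standalone sketch of the padding rule the hunk implements. The function padding_for_dim and its parameter names are illustrative only, not TensorFlow API; the constant mirrors minimum_dynamic_dim_size from the diff above.

    # Illustrative sketch (not TensorFlow API): pad one dimension of a
    # per-core input up to its maximum size, but never to a padded size
    # below 2, since XLA rejects a dynamic dimension with upper bound 1.
    MINIMUM_DYNAMIC_DIM_SIZE = 2

    def padding_for_dim(current_size, static_max=None, cross_core_max=None):
        """Returns [pad_before, pad_after] for a single dimension.

        current_size:   the dimension's size on this core.
        static_max:     the statically known maximum size, if any (s.value).
        cross_core_max: the maximum size observed across all cores, used
                        when no static maximum is given.
        """
        target = static_max if static_max is not None else cross_core_max
        # Clamp the padded size to at least 2 so XLA never sees an upper
        # bound of 1 on a dynamic dimension.
        max_dim_size = max(target, MINIMUM_DYNAMIC_DIM_SIZE)
        return [0, max_dim_size - current_size]

    # Example: a core holds a batch of 1; without the clamp the padded
    # (dynamic) dimension would have upper bound 1, which XLA rejects.
    assert padding_for_dim(1, cross_core_max=1) == [0, 1]  # padded size 2
    assert padding_for_dim(3, static_max=5) == [0, 2]      # padded size 5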