Add finer grained check of dynamic shape inputs even maximum_shapes
is set in tpu.replicate
.
PiperOrigin-RevId: 297676611 Change-Id: Ie644ca680406d88597365f1bdb0b27c19f8b892c
This commit is contained in:
parent
7e11304e55
commit
aad9a544a4
@ -1162,6 +1162,7 @@ def split_compile_and_replicate(computation,
|
||||
for i in inputs[0]]), infeed_queue.number_of_tuple_elements,
|
||||
arg_error))
|
||||
|
||||
dynamic_shape_inputs = False
|
||||
if maximum_shapes:
|
||||
if infeed_queue:
|
||||
raise ValueError(
|
||||
@ -1178,6 +1179,8 @@ def split_compile_and_replicate(computation,
|
||||
|
||||
flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes,
|
||||
padding_spec)
|
||||
if padding_maps:
|
||||
dynamic_shape_inputs = True
|
||||
|
||||
serialized_padding_maps = []
|
||||
for padding_map in padding_maps:
|
||||
@ -1232,7 +1235,7 @@ def split_compile_and_replicate(computation,
|
||||
# inputs when dynamic padding is enabled.
|
||||
# TODO(rxsang): Use other ways except argument index in padding_map so
|
||||
# outside compilation can work with dynamic padding correctly.
|
||||
if maximum_shapes is None:
|
||||
if not dynamic_shape_inputs:
|
||||
i.op._set_attr("_tpu_input_identity",
|
||||
attr_value_pb2.AttrValue(b=True))
|
||||
# pylint: enable=protected-access
|
||||
@ -1266,9 +1269,8 @@ def split_compile_and_replicate(computation,
|
||||
kwargs["partitioner"] = None
|
||||
logging.warning(
|
||||
"Partitioned variables are not supported on TPU. Got "
|
||||
"`partitioner` that is {} for variable {}. "
|
||||
"Setting `partitioner` to `None`."
|
||||
.format(partitioner, name))
|
||||
"`partitioner` that is %s for variable %s. "
|
||||
"Setting `partitioner` to `None`.", partitioner, name)
|
||||
if saved_custom_getter is None:
|
||||
return getter(name, *args, **kwargs)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user