Add a finer-grained check of dynamic shape inputs even when maximum_shapes is set in tpu.replicate.

PiperOrigin-RevId: 297676611
Change-Id: Ie644ca680406d88597365f1bdb0b27c19f8b892c
This commit is contained in:
Ruoxin Sang 2020-02-27 13:39:06 -08:00 committed by TensorFlower Gardener
parent 7e11304e55
commit aad9a544a4

View File

@@ -1162,6 +1162,7 @@ def split_compile_and_replicate(computation,
             for i in inputs[0]]), infeed_queue.number_of_tuple_elements,
           arg_error))
+  dynamic_shape_inputs = False
   if maximum_shapes:
     if infeed_queue:
       raise ValueError(
@@ -1178,6 +1179,8 @@ def split_compile_and_replicate(computation,
     flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes,
                                                padding_spec)
+    if padding_maps:
+      dynamic_shape_inputs = True

     serialized_padding_maps = []
     for padding_map in padding_maps:
@@ -1232,7 +1235,7 @@ def split_compile_and_replicate(computation,
       # inputs when dynamic padding is enabled.
       # TODO(rxsang): Use other ways except argument index in padding_map so
       # outside compilation can work with dynamic padding correctly.
-      if maximum_shapes is None:
+      if not dynamic_shape_inputs:
         i.op._set_attr("_tpu_input_identity",
                        attr_value_pb2.AttrValue(b=True))
       # pylint: enable=protected-access
@@ -1266,9 +1269,8 @@ def split_compile_and_replicate(computation,
         kwargs["partitioner"] = None
         logging.warning(
             "Partitioned variables are not supported on TPU. Got "
-            "`partitioner` that is {} for variable {}. "
-            "Setting `partitioner` to `None`."
-            .format(partitioner, name))
+            "`partitioner` that is %s for variable %s. "
+            "Setting `partitioner` to `None`.", partitioner, name)
         if saved_custom_getter is None:
           return getter(name, *args, **kwargs)
         else: