Place padding ops on TPUs if the original input nodes have already been placed on TPUs, so XLA on-demand mode can be triggered. This avoids unnecessary data copies between host and device.

PiperOrigin-RevId: 270940275
Author: Ruoxin Sang (2019-09-24 10:53:49 -07:00), committed by TensorFlower Gardener
parent 5e15426e3a
commit 2405e8a159


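The diff below introduces the _post_device_rewrite attribute and stamps it onto the shape and pad ops it creates. A minimal sketch of that tagging pattern, assuming a standalone TF1-style graph (the placeholder and its shape are invented for illustration; only the _set_attr call with attr_value_pb2.AttrValue(b=True) mirrors the actual change):

import tensorflow.compat.v1 as tf
from tensorflow.core.framework import attr_value_pb2

_POST_DEVICE_REWRITE_ATTR = "_post_device_rewrite"

with tf.Graph().as_default():
  # Hypothetical per-core input; in tpu.py this is one of the replica inputs.
  x = tf.placeholder(tf.float32, shape=[None, 128])
  real_input_shape = tf.shape(x)
  # Tag the shape op so the post-placement rewrite keeps it on the same
  # device as its TPU input instead of computing it on the host.
  real_input_shape.op._set_attr(  # pylint: disable=protected-access
      _POST_DEVICE_REWRITE_ATTR, attr_value_pb2.AttrValue(b=True))
  print(real_input_shape.op.get_attr(_POST_DEVICE_REWRITE_ATTR))  # True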
@@ -77,6 +77,7 @@ _UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"])
_MAX_WARNING_LINES = 5
_TPU_REPLICATE_ATTR = "_tpu_replicate"
_POST_DEVICE_REWRITE_ATTR = "_post_device_rewrite"
_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status"
_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
@@ -707,7 +708,13 @@ def _pad_all_input(inputs, padded_shapes):
need_padding[idx][i] = True
maximum_static_shapes[idx] = max(input_shape,
maximum_static_shapes[idx])
input_shape_tensors[idx].append(array_ops.shape(input_tensor))
# Append _POST_DEVICE_REWRITE_ATTR attributes to the real shape ops.
real_input_shape = array_ops.shape(input_tensor)
real_input_shape.op._set_attr( # pylint: disable=protected-access
_POST_DEVICE_REWRITE_ATTR,
attr_value_pb2.AttrValue(b=True))
input_shape_tensors[idx].append(real_input_shape)
maximum_shapes = []
for shapes_per_input in input_shape_tensors:
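For context, the shape tensors collected above are reduced to an element-wise maximum in the loop that follows, and each input is later padded up to that maximum. A simplified, self-contained sketch of that idea (an illustrative helper, not the real _pad_all_input, which also tracks static shapes and padding metadata):

import tensorflow as tf

def pad_to_maximum_shape(per_core_inputs):
  """Pads each per-core tensor up to the element-wise maximum dynamic shape."""
  shapes = [tf.shape(t) for t in per_core_inputs]      # one shape vector per core
  maximum_shape = tf.reduce_max(tf.stack(shapes), axis=0)
  padded = []
  for t in per_core_inputs:
    extra = maximum_shape - tf.shape(t)                 # amount to add per dimension
    paddings = tf.stack([tf.zeros_like(extra), extra], axis=1)  # pad at the end only
    padded.append(tf.pad(t, paddings))
  return padded

Called with two ragged batches such as [tf.zeros([3, 5]), tf.zeros([4, 2])], both outputs come back with shape [4, 5].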
@@ -763,6 +770,12 @@ def _pad_all_input(inputs, padded_shapes):
lambda: input_tensor)
else:
padded_input = array_ops.pad(input_tensor, paddings)
# Append _POST_DEVICE_REWRITE_ATTR attributes to all padded inputs.
padded_input.op._set_attr( # pylint: disable=protected-access
_POST_DEVICE_REWRITE_ATTR,
attr_value_pb2.AttrValue(b=True))
padded_inputs[core_idx].append(padded_input)
else:
padded_inputs[core_idx].append(input_tensor)
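One way to sanity-check the new attribute (purely illustrative, not part of this commit) is to tag a pad op the same way in a scratch graph and confirm the attribute lands in the NodeDef that the post-placement rewrite inspects:

import tensorflow.compat.v1 as tf
from tensorflow.core.framework import attr_value_pb2

with tf.Graph().as_default() as g:
  x = tf.placeholder(tf.float32, shape=[None, 8])
  padded_input = tf.pad(x, [[0, 3], [0, 0]])
  padded_input.op._set_attr(  # pylint: disable=protected-access
      "_post_device_rewrite", attr_value_pb2.AttrValue(b=True))

  tagged = [op.name for op in g.get_operations()
            if "_post_device_rewrite" in op.node_def.attr]
  print(tagged)  # the Pad op shows up as tagged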