Place padding ops on TPUs if the original input nodes has already been on TPUs, so XLA op demand mode can be triggered. This will avoid unnecessary data copy between host and device.
PiperOrigin-RevId: 270940275
This commit is contained in:
parent
5e15426e3a
commit
2405e8a159
@ -77,6 +77,7 @@ _UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"])
|
||||
_MAX_WARNING_LINES = 5
|
||||
|
||||
_TPU_REPLICATE_ATTR = "_tpu_replicate"
|
||||
_POST_DEVICE_REWRITE_ATTR = "_post_device_rewrite"
|
||||
_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status"
|
||||
_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
|
||||
|
||||
@ -707,7 +708,13 @@ def _pad_all_input(inputs, padded_shapes):
|
||||
need_padding[idx][i] = True
|
||||
maximum_static_shapes[idx] = max(input_shape,
|
||||
maximum_static_shapes[idx])
|
||||
input_shape_tensors[idx].append(array_ops.shape(input_tensor))
|
||||
|
||||
# Append _POST_DEVICE_REWRITE_ATTR attributes to the real shape ops.
|
||||
real_input_shape = array_ops.shape(input_tensor)
|
||||
real_input_shape.op._set_attr( # pylint: disable=protected-access
|
||||
_POST_DEVICE_REWRITE_ATTR,
|
||||
attr_value_pb2.AttrValue(b=True))
|
||||
input_shape_tensors[idx].append(real_input_shape)
|
||||
|
||||
maximum_shapes = []
|
||||
for shapes_per_input in input_shape_tensors:
|
||||
@ -763,6 +770,12 @@ def _pad_all_input(inputs, padded_shapes):
|
||||
lambda: input_tensor)
|
||||
else:
|
||||
padded_input = array_ops.pad(input_tensor, paddings)
|
||||
|
||||
# Append _POST_DEVICE_REWRITE_ATTR attributes to all padded inputs.
|
||||
padded_input.op._set_attr( # pylint: disable=protected-access
|
||||
_POST_DEVICE_REWRITE_ATTR,
|
||||
attr_value_pb2.AttrValue(b=True))
|
||||
|
||||
padded_inputs[core_idx].append(padded_input)
|
||||
else:
|
||||
padded_inputs[core_idx].append(input_tensor)
|
||||
|
Loading…
x
Reference in New Issue
Block a user