Support nested structure as input to xla.compile() and tpu.rewrite() computation.
PiperOrigin-RevId: 226947103
parent 402da5c870
commit 98bbee7afe
@@ -34,6 +34,7 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
 from tensorflow.python.util import function_utils
+from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
 
@@ -76,7 +77,9 @@ def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
 
       All `Operation`s returned from `computation` will be executed when
       evaluating any of the returned output tensors.
-    inputs: A list of input tensors or `None` (equivalent to an empty list).
+    inputs: A list of inputs or `None` (equivalent to an empty list). Each input
+      can be a nested structure containing values that are convertible to
+      tensors.
 
   Returns:
     A list of output tensors.
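After this change, each element of `inputs` can itself be a nested structure. A minimal usage sketch; the `tensorflow.contrib.compiler` import path and the constant inputs are assumptions for illustration, not part of this diff:

    import tensorflow as tf
    from tensorflow.contrib.compiler import xla  # assumed import path

    def model(features):
      # `features` arrives with the same dict structure passed below.
      return features['x'] * features['w'] + features['b']

    inputs = [{'x': tf.constant(2.0), 'w': tf.constant(3.0),
               'b': tf.constant(1.0)}]
    (output,) = xla.compile(model, inputs=inputs)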
@@ -260,17 +263,10 @@ def _compile_internal(computation, inputs=None):
   if not isinstance(inputs, collections.Sequence):
     raise TypeError('inputs must be a list')
 
+  # Flatten inputs.
+  flat_inputs = nest.flatten(inputs)
   # Converts inputs to Tensors.
-  inputs = [ops.convert_to_tensor(x) for x in inputs]
-  input_arity = len(inputs)
-
-  arg_error = check_function_argument_count(
-      computation, input_arity, infeed_queue=None)
-  if arg_error is not None:
-    raise TypeError(
-        'Supplied computation cannot be called with the specified inputs. You '
-        'specified %d inputs: %s, but the computation needs %s' %
-        (input_arity, str([i.name for i in inputs]), arg_error))
+  flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]
 
   cluster_name = ops.get_default_graph().unique_name('cluster')
   pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
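The rewrite hinges on `nest.flatten` producing a deterministic flat list (dict entries are visited in sorted-key order), which is what makes the later re-packing step well defined. A small illustration, independent of this change; `nest` is an internal TensorFlow utility module:

    from tensorflow.python.util import nest

    structure = {'b': [2.0, 3.0], 'a': 1.0}
    flat = nest.flatten(structure)  # [1.0, 2.0, 3.0]; 'a' sorts before 'b'
    rebuilt = nest.pack_sequence_as(structure, flat)
    # rebuilt == {'a': 1.0, 'b': [2.0, 3.0]}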
@@ -280,11 +276,15 @@ def _compile_internal(computation, inputs=None):
 
   # Add identity ops so even unused inputs are 'consumed' by the
   # computation.
-  computation_inputs = [
+  flat_inputs = [
       array_ops.identity(x, name='input_{}'.format(i))
-      for i, x in enumerate(inputs)
+      for i, x in enumerate(flat_inputs)
   ]
 
+  # Re-pack flat_inputs in same structure as 'inputs'.
+  computation_inputs = nest.pack_sequence_as(
+      structure=inputs, flat_sequence=flat_inputs)
+
   # Only resource variables work inside an XLA computation, so turn on
   # resource variables for the computation.
   vscope = variable_scope.get_variable_scope()
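Taken together, the two hunks above make `_compile_internal` run a flatten / convert / identity / re-pack pipeline before handing inputs to the user computation. A standalone sketch of that flow; the names mirror the diff, while the example structure is hypothetical:

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops
    from tensorflow.python.util import nest

    inputs = [{'x': 1.0, 'y': (2.0, 3.0)}]  # hypothetical nested input list
    flat_inputs = nest.flatten(inputs)
    flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]
    flat_inputs = [
        array_ops.identity(x, name='input_{}'.format(i))
        for i, x in enumerate(flat_inputs)
    ]
    # The computation is invoked with the original nesting, not the flat list.
    computation_inputs = nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs)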
@@ -36,6 +36,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
+from tensorflow.python.util import nest
 
 
 # Operations that indicate some error in the users graph, e.g. a placeholder
@@ -487,7 +488,8 @@ def replicate(computation,
     computation: A Python function that builds the computation to replicate.
     inputs: A list of lists of input tensors or `None` (equivalent to
       `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
-      have the same number of inputs.
+      have the same number of inputs. Each input can be a nested structure
+      containing values that are convertible to tensors.
     infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
       of arguments as inputs to computation.
     device_assignment: If not `None`, a `DeviceAssignment` describing the
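A hedged usage sketch for the updated `replicate` contract; the `tf.contrib.tpu` import path and the two-replica setup are assumptions. Each replica supplies one input, and that input may be a dict, as long as every replica uses the same structure and dtypes:

    import tensorflow as tf
    from tensorflow.contrib import tpu  # assumed import path

    def step(batch):
      return batch['x'] + batch['y']

    inputs = [
        [{'x': tf.constant(1.0), 'y': tf.constant(2.0)}],  # replica 0
        [{'x': tf.constant(3.0), 'y': tf.constant(4.0)}],  # replica 1
    ]
    outputs = tpu.replicate(step, inputs=inputs)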
@@ -526,7 +528,8 @@ def split_compile_and_replicate(computation,
     computation: A Python function that builds the computation to replicate.
     inputs: A list of lists of input tensors or `None` (equivalent to
       `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
-      have the same number of inputs.
+      have the same number of inputs. Each input can be a nested structure
+      containing values that are convertible to tensors.
     infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
       of arguments as inputs to computation.
     device_assignment: If not `None`, a `DeviceAssignment` describing the
@@ -580,24 +583,32 @@ def split_compile_and_replicate(computation,
   if num_replicas == 0:
     return []
 
+  # Checks all replicas have the same structure.
+  for i in xrange(1, num_replicas):
+    nest.assert_same_structure(inputs[0], inputs[i])
+
+  # Flatten inputs.
+  flat_inputs = [
+      nest.flatten(per_replica_input) for per_replica_input in inputs
+  ]
   # Converts inputs to Tensors.
-  inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]
+  flat_inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in flat_inputs]
 
   # Verifies that all replicas have matching numbers and types of inputs
-  input_types = [x.dtype for x in inputs[0]]
-  input_arity = len(input_types)
+  flat_input_types = [x.dtype for x in flat_inputs[0]]
+  input_arity = len(inputs[0])
+  flat_input_arity = len(flat_input_types)
   for i in range(num_replicas):
     if len(inputs[i]) != input_arity:
       raise ValueError("Replicas must have the same number of inputs. "
                        "Replica 0 had {} inputs, replica {} had {} "
                        "inputs.".format(input_arity, i, len(inputs[i])))
 
-    types = [x.dtype for x in inputs[i]]
-    if types != input_types:
-      raise ValueError(
-          "Replicas must have matching input types. Replica 0 had "
-          "input types {}, replica {} had input types {}".format(
-              input_types, i, types))
+    types = [x.dtype for x in flat_inputs[i]]
+    if types != flat_input_types:
+      raise ValueError("Replicas must have matching input types. Replica 0 had "
+                       "input types {}, replica {} had input types {}".format(
+                           flat_input_types, i, types))
 
   arg_error = xla.check_function_argument_count(
       computation, input_arity, infeed_queue)
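The new structure check fails fast when replicas disagree on nesting, before any dtype comparison on the flattened lists. For instance, with an illustrative mismatch:

    from tensorflow.python.util import nest

    replica_0 = [{'x': 1.0, 'y': 2.0}]
    replica_1 = [{'x': 3.0}]  # missing key 'y'
    # Raises ValueError: the two structures do not match.
    nest.assert_same_structure(replica_0, replica_1)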
@@ -620,8 +631,8 @@ def split_compile_and_replicate(computation,
 
   # Fan-in: Builds a TPUReplicatedInput node for each input.
   computation_inputs = []
-  for i in range(0, input_arity):
-    replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
+  for i in range(0, flat_input_arity):
+    replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)]
     computation_inputs.append(
         tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
 
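The fan-in loop effectively transposes `flat_inputs` from `[replica][flat_index]` to one per-index list of replica values, each of which feeds a single `TPUReplicatedInput` node. In plain Python the indexing looks like this:

    num_replicas = 2
    flat_inputs = [['r0_a', 'r0_b'], ['r1_a', 'r1_b']]  # [replica][flat_index]
    per_input_replicas = [
        [flat_inputs[replica][i] for replica in range(num_replicas)]
        for i in range(len(flat_inputs[0]))
    ]
    # [['r0_a', 'r1_a'], ['r0_b', 'r1_b']]
    # one list per TPUReplicatedInput node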
@@ -651,6 +662,10 @@ def split_compile_and_replicate(computation,
       i.op._set_attr("_tpu_input_identity", attr_value_pb2.AttrValue(b=True))
     # pylint: enable=protected-access
 
+    # Unflatten the computation inputs to match original input structure.
+    computation_inputs = nest.pack_sequence_as(
+        structure=inputs[0], flat_sequence=computation_inputs)
+
     # If there is an infeed queue, adds the dequeued values to the
     # computation's inputs.
     if infeed_queue is not None:
@@ -1093,7 +1108,8 @@ def rewrite(computation,
       evaluating any of the returned output tensors, not just the ones returned.
     inputs: A list of input tensors or `None` (equivalent to an empty list).
     infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
-      of arguments as inputs to `computation`.
+      of arguments as inputs to `computation`. Each input can be a nested
+      structure containing values that are convertible to tensors.
     device_assignment: if not `None`, a `DeviceAssignment` describing the
       mapping between logical cores in the computation with physical cores in
       the TPU topology. May be omitted for a single-core computation, in which
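And a corresponding hedged sketch for `tpu.rewrite`; the import path is assumed as in the earlier example. The single input below is a nested tuple, which `computation` receives intact:

    import tensorflow as tf
    from tensorflow.contrib import tpu  # assumed import path

    def step(pair):
      a, b = pair
      return a * b

    output = tpu.rewrite(step, inputs=[(tf.constant(2.0), tf.constant(3.0))])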