Support nested structure as input to xla.compile() and tpu.rewrite() computation.
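
For illustration, a minimal usage sketch of the new behavior (assuming the
contrib-era `tensorflow.contrib.compiler.xla.compile` entry point; the
computation and input names are illustrative, not part of this change):

    # A minimal sketch, assuming the contrib-era `xla.compile` entry point;
    # names here are illustrative, not from this commit.
    import tensorflow as tf
    from tensorflow.contrib.compiler import xla

    def computation(features):
      # `features` arrives with the same nesting the caller supplied:
      # a dict holding one tensor and a (tensor, tensor) tuple.
      return features['x'] + features['pair'][0] * features['pair'][1]

    nested_input = {
        'x': tf.constant(1.0),
        'pair': (tf.constant(2.0), tf.constant(3.0)),
    }
    # Each element of `inputs` may now be a nested structure rather than a
    # single tensor-convertible value.
    [result] = xla.compile(computation, inputs=[nested_input])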

PiperOrigin-RevId: 226947103
Author:    Yanan Cao (2018-12-26 13:08:28 -08:00)
Committer: TensorFlower Gardener
Parent:    402da5c870
Commit:    98bbee7afe

2 changed files with 43 additions and 27 deletions

@@ -34,6 +34,7 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
 from tensorflow.python.util import function_utils
+from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
@@ -76,7 +77,9 @@ def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
       All `Operation`s returned from `computation` will be executed when
       evaluating any of the returned output tensors.
-    inputs: A list of input tensors or `None` (equivalent to an empty list).
+    inputs: A list of inputs or `None` (equivalent to an empty list). Each input
+      can be a nested structure containing values that are convertible to
+      tensors.
 
   Returns:
     A list of output tensors.
@@ -260,17 +263,10 @@ def _compile_internal(computation, inputs=None):
   if not isinstance(inputs, collections.Sequence):
     raise TypeError('inputs must be a list')
 
+  # Flatten inputs.
+  flat_inputs = nest.flatten(inputs)
   # Converts inputs to Tensors.
-  inputs = [ops.convert_to_tensor(x) for x in inputs]
-  input_arity = len(inputs)
-
-  arg_error = check_function_argument_count(
-      computation, input_arity, infeed_queue=None)
-  if arg_error is not None:
-    raise TypeError(
-        'Supplied computation cannot be called with the specified inputs. You '
-        'specified %d inputs: %s, but the computation needs %s' %
-        (input_arity, str([i.name for i in inputs]), arg_error))
+  flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]
 
   cluster_name = ops.get_default_graph().unique_name('cluster')
   pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
@@ -280,11 +276,15 @@ def _compile_internal(computation, inputs=None):
     # Add identity ops so even unused inputs are 'consumed' by the
     # computation.
-    computation_inputs = [
+    flat_inputs = [
         array_ops.identity(x, name='input_{}'.format(i))
-        for i, x in enumerate(inputs)
+        for i, x in enumerate(flat_inputs)
     ]
 
+    # Re-pack flat_inputs in same structure as 'inputs'.
+    computation_inputs = nest.pack_sequence_as(
+        structure=inputs, flat_sequence=flat_inputs)
+
     # Only resource variables work inside an XLA computation, so turn on
     # resource variables for the computation.
     vscope = variable_scope.get_variable_scope()
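
The flatten/convert/identity/re-pack sequence above can be exercised on its
own; a standalone sketch using the same internal `nest` utility this change
imports (the example values are illustrative):

    # Sketch of the round trip implemented above: flatten a nested input,
    # apply a per-tensor op, then restore the caller's structure.
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops
    from tensorflow.python.util import nest

    inputs = [{'a': 1.0, 'b': (2.0, 3.0)}]  # one nested input
    flat_inputs = nest.flatten(inputs)      # [1.0, 2.0, 3.0]
    flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]
    flat_inputs = [
        array_ops.identity(x, name='input_{}'.format(i))
        for i, x in enumerate(flat_inputs)
    ]
    # pack_sequence_as re-nests the identity outputs, so the computation is
    # called with exactly the structure the caller passed in.
    computation_inputs = nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs)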

@@ -36,6 +36,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
+from tensorflow.python.util import nest
 
 # Operations that indicate some error in the users graph, e.g. a placeholder
@@ -487,7 +488,8 @@ def replicate(computation,
     computation: A Python function that builds the computation to replicate.
     inputs: A list of lists of input tensors or `None` (equivalent to
       `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
-      have the same number of inputs.
+      have the same number of inputs. Each input can be a nested structure
+      containing values that are convertible to tensors.
     infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
       of arguments as inputs to computation.
     device_assignment: If not `None`, a `DeviceAssignment` describing the
@@ -526,7 +528,8 @@ def split_compile_and_replicate(computation,
     computation: A Python function that builds the computation to replicate.
     inputs: A list of lists of input tensors or `None` (equivalent to
       `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
-      have the same number of inputs.
+      have the same number of inputs. Each input can be a nested structure
+      containing values that are convertible to tensors.
     infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
       of arguments as inputs to computation.
     device_assignment: If not `None`, a `DeviceAssignment` describing the
@@ -580,24 +583,32 @@ def split_compile_and_replicate(computation,
   if num_replicas == 0:
     return []
 
+  # Checks all replicas have the same structure.
+  for i in xrange(1, num_replicas):
+    nest.assert_same_structure(inputs[0], inputs[i])
+
+  # Flatten inputs.
+  flat_inputs = [
+      nest.flatten(per_replica_input) for per_replica_input in inputs
+  ]
   # Converts inputs to Tensors.
-  inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]
+  flat_inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in flat_inputs]
 
   # Verifies that all replicas have matching numbers and types of inputs
-  input_types = [x.dtype for x in inputs[0]]
-  input_arity = len(input_types)
+  flat_input_types = [x.dtype for x in flat_inputs[0]]
+  input_arity = len(inputs[0])
+  flat_input_arity = len(flat_input_types)
   for i in range(num_replicas):
     if len(inputs[i]) != input_arity:
       raise ValueError("Replicas must have the same number of inputs. "
                        "Replica 0 had {} inputs, replica {} had {} "
                        "inputs.".format(input_arity, i, len(inputs[i])))
 
-    types = [x.dtype for x in inputs[i]]
-    if types != input_types:
-      raise ValueError(
-          "Replicas must have matching input types. Replica 0 had "
-          "input types {}, replica {} had input types {}".format(
-              input_types, i, types))
+    types = [x.dtype for x in flat_inputs[i]]
+    if types != flat_input_types:
+      raise ValueError("Replicas must have matching input types. Replica 0 had "
+                       "input types {}, replica {} had input types {}".format(
+                           flat_input_types, i, types))
 
   arg_error = xla.check_function_argument_count(
       computation, input_arity, infeed_queue)
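
The structure check and per-replica flattening above behave as in this
standalone sketch (plain Python values stand in for tensors; `nest` is the
internal utility this change imports):

    # Sketch: every replica must nest its inputs identically; flattening
    # then yields per-replica flat lists whose arity and dtypes (after
    # conversion) can be compared position by position.
    from tensorflow.python.util import nest

    inputs = [
        [{'a': 1.0, 'b': (2.0, 3.0)}],  # replica 0
        [{'a': 4.0, 'b': (5.0, 6.0)}],  # replica 1
    ]
    for i in range(1, len(inputs)):
      # Raises if replica i nests its inputs differently from replica 0.
      nest.assert_same_structure(inputs[0], inputs[i])
    flat_inputs = [nest.flatten(per_replica) for per_replica in inputs]
    # flat_inputs == [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]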
@@ -620,8 +631,8 @@ def split_compile_and_replicate(computation,
   # Fan-in: Builds a TPUReplicatedInput node for each input.
   computation_inputs = []
-  for i in range(0, input_arity):
-    replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
+  for i in range(0, flat_input_arity):
+    replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)]
     computation_inputs.append(
         tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
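
The fan-in loop above is effectively a transpose from `[replica][input]` to
`[input][replica]`; a sketch with strings standing in for tensors:

    # Sketch: gather input position i from every replica so one
    # TPUReplicatedInput node can feed all replicas.
    flat_inputs = [['r0_in0', 'r0_in1'],  # replica 0
                   ['r1_in0', 'r1_in1']]  # replica 1
    num_replicas = len(flat_inputs)
    flat_input_arity = len(flat_inputs[0])
    fan_in = [
        [flat_inputs[replica][i] for replica in range(num_replicas)]
        for i in range(flat_input_arity)
    ]
    # fan_in == [['r0_in0', 'r1_in0'], ['r0_in1', 'r1_in1']]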
@@ -651,6 +662,10 @@ def split_compile_and_replicate(computation,
       i.op._set_attr("_tpu_input_identity", attr_value_pb2.AttrValue(b=True))
     # pylint: enable=protected-access
 
+    # Unflatten the computation inputs to match original input structure.
+    computation_inputs = nest.pack_sequence_as(
+        structure=inputs[0], flat_sequence=computation_inputs)
+
     # If there is an infeed queue, adds the dequeued values to the
     # computation's inputs.
     if infeed_queue is not None:
@@ -1093,7 +1108,8 @@ def rewrite(computation,
       evaluating any of the returned output tensors, not just the ones returned.
     inputs: A list of input tensors or `None` (equivalent to an empty list).
     infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
-      of arguments as inputs to `computation`.
+      of arguments as inputs to `computation`. Each input can be a nested
+      structure containing values that are convertible to tensors.
     device_assignment: if not `None`, a `DeviceAssignment` describing the
       mapping between logical cores in the computation with physical cores in
       the TPU topology. May be omitted for a single-core computation, in which