Support nested structure as input to xla.compile() and tpu.rewrite() computation.

PiperOrigin-RevId: 226947103
Yanan Cao 2018-12-26 13:08:28 -08:00 committed by TensorFlower Gardener
parent 402da5c870
commit 98bbee7afe
2 changed files with 43 additions and 27 deletions
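In practical terms, each element of `inputs` may now be a nested structure (dict, tuple, list) of tensor-convertible values rather than a single tensor. Below is a minimal usage sketch of what the change enables, not code from this commit; it assumes the TF 1.x contrib entry point `tensorflow.contrib.compiler.xla` and a made-up `model` function, so adjust the import to wherever `compile` lives in your build.

import tensorflow as tf
from tensorflow.contrib.compiler import xla  # assumed 1.x-era location of compile()


def model(features):
  # `features` arrives with the same nested structure (here a dict) that the
  # caller supplied in `inputs`.
  return features['x'] * 2.0 + features['y']


a = tf.constant([1.0, 2.0])
b = tf.constant([3.0, 4.0])

# Each element of `inputs` may be a nested structure of tensor-convertible
# values instead of a single tensor.
outputs = xla.compile(model, inputs=[{'x': a, 'y': b}])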


@@ -34,6 +34,7 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
 from tensorflow.python.util import function_utils
+from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
@@ -76,7 +77,9 @@ def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
       All `Operation`s returned from `computation` will be executed when
       evaluating any of the returned output tensors.
-    inputs: A list of input tensors or `None` (equivalent to an empty list).
+    inputs: A list of inputs or `None` (equivalent to an empty list). Each input
+      can be a nested structure containing values that are convertible to
+      tensors.

   Returns:
     A list of output tensors.
@@ -260,17 +263,10 @@ def _compile_internal(computation, inputs=None):
   if not isinstance(inputs, collections.Sequence):
     raise TypeError('inputs must be a list')

+  # Flatten inputs.
+  flat_inputs = nest.flatten(inputs)
   # Converts inputs to Tensors.
-  inputs = [ops.convert_to_tensor(x) for x in inputs]
-  input_arity = len(inputs)
-  arg_error = check_function_argument_count(
-      computation, input_arity, infeed_queue=None)
-  if arg_error is not None:
-    raise TypeError(
-        'Supplied computation cannot be called with the specified inputs. You '
-        'specified %d inputs: %s, but the computation needs %s' %
-        (input_arity, str([i.name for i in inputs]), arg_error))
+  flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]

   cluster_name = ops.get_default_graph().unique_name('cluster')
   pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
@@ -280,11 +276,15 @@ def _compile_internal(computation, inputs=None):
     # Add identity ops so even unused inputs are 'consumed' by the
     # computation.
-    computation_inputs = [
+    flat_inputs = [
         array_ops.identity(x, name='input_{}'.format(i))
-        for i, x in enumerate(inputs)
+        for i, x in enumerate(flat_inputs)
     ]

+    # Re-pack flat_inputs in same structure as 'inputs'.
+    computation_inputs = nest.pack_sequence_as(
+        structure=inputs, flat_sequence=flat_inputs)
+
     # Only resource variables work inside an XLA computation, so turn on
     # resource variables for the computation.
     vscope = variable_scope.get_variable_scope()
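The flatten/repack pattern used above is easiest to see on a toy value. The sketch below is illustrative only: plain Python numbers stand in for tensors, and a multiply stands in for the convert_to_tensor/identity pass over the flat list.

from tensorflow.python.util import nest

inputs = {'x': 1, 'y': (2, 3)}

# Flatten to a list of leaves (dict keys are traversed in sorted order).
flat = nest.flatten(inputs)            # [1, 2, 3]

# Per-leaf processing happens on the flat list, as in _compile_internal.
processed = [v * 10 for v in flat]

# Repack the processed leaves into the caller's original structure.
repacked = nest.pack_sequence_as(structure=inputs, flat_sequence=processed)
# repacked == {'x': 10, 'y': (20, 30)}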


@@ -36,6 +36,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
+from tensorflow.python.util import nest

 # Operations that indicate some error in the users graph, e.g. a placeholder
@@ -487,7 +488,8 @@ def replicate(computation,
     computation: A Python function that builds the computation to replicate.
     inputs: A list of lists of input tensors or `None` (equivalent to
       `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
-      have the same number of inputs.
+      have the same number of inputs. Each input can be a nested structure
+      containing values that are convertible to tensors.
     infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
       of arguments as inputs to computation.
     device_assignment: If not `None`, a `DeviceAssignment` describing the
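As with `xla.compile`, each per-replica input may now itself be a nested structure. A rough usage sketch, not from this commit, assuming the contrib-era entry point `tf.contrib.tpu.replicate`:

import tensorflow as tf
from tensorflow.contrib import tpu  # assumed 1.x-era location of replicate()


def step(batch):
  # `batch` keeps the dict structure supplied for this replica.
  return batch['x'] + batch['y']


# Two replicas, one (nested) input each; every replica must use the same
# structure.
outputs = tpu.replicate(
    step,
    inputs=[[{'x': tf.constant(1.0), 'y': tf.constant(2.0)}],
            [{'x': tf.constant(3.0), 'y': tf.constant(4.0)}]])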
@@ -526,7 +528,8 @@ def split_compile_and_replicate(computation,
     computation: A Python function that builds the computation to replicate.
     inputs: A list of lists of input tensors or `None` (equivalent to
       `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
-      have the same number of inputs.
+      have the same number of inputs. Each input can be a nested structure
+      containing values that are convertible to tensors.
     infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
       of arguments as inputs to computation.
     device_assignment: If not `None`, a `DeviceAssignment` describing the
@@ -580,24 +583,32 @@ def split_compile_and_replicate(computation,
   if num_replicas == 0:
     return []

+  # Checks all replicas have the same structure.
+  for i in xrange(1, num_replicas):
+    nest.assert_same_structure(inputs[0], inputs[i])
+
+  # Flatten inputs.
+  flat_inputs = [
+      nest.flatten(per_replica_input) for per_replica_input in inputs
+  ]
   # Converts inputs to Tensors.
-  inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]
+  flat_inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in flat_inputs]

   # Verifies that all replicas have matching numbers and types of inputs
-  input_types = [x.dtype for x in inputs[0]]
-  input_arity = len(input_types)
+  flat_input_types = [x.dtype for x in flat_inputs[0]]
+  input_arity = len(inputs[0])
+  flat_input_arity = len(flat_input_types)
   for i in range(num_replicas):
     if len(inputs[i]) != input_arity:
       raise ValueError("Replicas must have the same number of inputs. "
                        "Replica 0 had {} inputs, replica {} had {} "
                        "inputs.".format(input_arity, i, len(inputs[i])))
-    types = [x.dtype for x in inputs[i]]
-    if types != input_types:
-      raise ValueError(
-          "Replicas must have matching input types. Replica 0 had "
-          "input types {}, replica {} had input types {}".format(
-              input_types, i, types))
+    types = [x.dtype for x in flat_inputs[i]]
+    if types != flat_input_types:
+      raise ValueError("Replicas must have matching input types. Replica 0 had "
+                       "input types {}, replica {} had input types {}".format(
+                           flat_input_types, i, types))

   arg_error = xla.check_function_argument_count(
       computation, input_arity, infeed_queue)
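The new structure check fails fast, before any dtype comparison over the flat lists, when replicas disagree on nesting. A small illustration with plain Python values in place of tensors:

from tensorflow.python.util import nest

replica_0 = [{'x': 1.0, 'y': 2.0}]
replica_1 = [{'x': 1.0}]  # missing the 'y' leaf

try:
  nest.assert_same_structure(replica_0, replica_1)
except (ValueError, TypeError) as e:
  print('replica structure mismatch:', e)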
@@ -620,8 +631,8 @@ def split_compile_and_replicate(computation,
   # Fan-in: Builds a TPUReplicatedInput node for each input.
   computation_inputs = []
-  for i in range(0, input_arity):
-    replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
+  for i in range(0, flat_input_arity):
+    replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)]
     computation_inputs.append(
         tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
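The fan-in loop above effectively transposes the per-replica flat lists, so each flattened position gathers its value from every replica before a single TPUReplicatedInput node is built for it. In plain Python terms (toy strings in place of tensors):

# flat_inputs is indexed [replica][flattened_position].
flat_inputs = [['r0_a', 'r0_b'],   # replica 0
               ['r1_a', 'r1_b']]   # replica 1

per_position = [[replica[i] for replica in flat_inputs]
                for i in range(len(flat_inputs[0]))]
# per_position == [['r0_a', 'r1_a'], ['r0_b', 'r1_b']]
# Each inner list is what tpu_ops.tpu_replicated_input(...) receives.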
@@ -651,6 +662,10 @@ def split_compile_and_replicate(computation,
       i.op._set_attr("_tpu_input_identity", attr_value_pb2.AttrValue(b=True))
     # pylint: enable=protected-access

+  # Unflatten the computation inputs to match original input structure.
+  computation_inputs = nest.pack_sequence_as(
+      structure=inputs[0], flat_sequence=computation_inputs)
+
   # If there is an infeed queue, adds the dequeued values to the
   # computation's inputs.
   if infeed_queue is not None:
@@ -1093,7 +1108,8 @@ def rewrite(computation,
       evaluating any of the returned output tensors, not just the ones returned.
     inputs: A list of input tensors or `None` (equivalent to an empty list).
     infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
-      of arguments as inputs to `computation`.
+      of arguments as inputs to `computation`. Each input can be a nested
+      structure containing values that are convertible to tensors.
     device_assignment: if not `None`, a `DeviceAssignment` describing the
       mapping between logical cores in the computation with physical cores in
       the TPU topology. May be omitted for a single-core computation, in which