From 98bbee7afec8e197bbe56f5ba2804a6ea75d3317 Mon Sep 17 00:00:00 2001
From: Yanan Cao
Date: Wed, 26 Dec 2018 13:08:28 -0800
Subject: [PATCH] Support nested structure as input to xla.compile() and
 tpu.rewrite() computation.

PiperOrigin-RevId: 226947103
---
 tensorflow/contrib/compiler/xla.py       | 26 +++++++-------
 tensorflow/contrib/tpu/python/tpu/tpu.py | 44 ++++++++++++++++--------
 2 files changed, 43 insertions(+), 27 deletions(-)

diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
index f867cd15b67..4e07aa511be 100644
--- a/tensorflow/contrib/compiler/xla.py
+++ b/tensorflow/contrib/compiler/xla.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
 from tensorflow.python.util import function_utils
+from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
 
@@ -76,7 +77,9 @@ def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
 
       All `Operation`s returned from `computation` will be executed when
       evaluating any of the returned output tensors.
-    inputs: A list of input tensors or `None` (equivalent to an empty list).
+    inputs: A list of inputs or `None` (equivalent to an empty list). Each input
+      can be a nested structure containing values that are convertible to
+      tensors.
 
   Returns:
     A list of output tensors.
@@ -260,17 +263,10 @@ def _compile_internal(computation, inputs=None):
   if not isinstance(inputs, collections.Sequence):
     raise TypeError('inputs must be a list')
 
+  # Flatten inputs.
+  flat_inputs = nest.flatten(inputs)
   # Converts inputs to Tensors.
-  inputs = [ops.convert_to_tensor(x) for x in inputs]
-  input_arity = len(inputs)
-
-  arg_error = check_function_argument_count(
-      computation, input_arity, infeed_queue=None)
-  if arg_error is not None:
-    raise TypeError(
-        'Supplied computation cannot be called with the specified inputs. You '
-        'specified %d inputs: %s, but the computation needs %s' %
-        (input_arity, str([i.name for i in inputs]), arg_error))
+  flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]
 
   cluster_name = ops.get_default_graph().unique_name('cluster')
   pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
@@ -280,11 +276,15 @@ def _compile_internal(computation, inputs=None):
 
     # Add identity ops so even unused inputs are 'consumed' by the
     # computation.
-    computation_inputs = [
+    flat_inputs = [
        array_ops.identity(x, name='input_{}'.format(i))
-        for i, x in enumerate(inputs)
+        for i, x in enumerate(flat_inputs)
     ]
 
+    # Re-pack flat_inputs in same structure as 'inputs'.
+    computation_inputs = nest.pack_sequence_as(
+        structure=inputs, flat_sequence=flat_inputs)
+
     # Only resource variables work inside an XLA computation, so turn on
     # resource variables for the computation.
     vscope = variable_scope.get_variable_scope()
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 9266d81cf5f..53e5539ee6d 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
+from tensorflow.python.util import nest
 
 
 # Operations that indicate some error in the users graph, e.g. a placeholder
@@ -487,7 +488,8 @@ def replicate(computation,
     computation: A Python function that builds the computation to replicate.
     inputs: A list of lists of input tensors or `None` (equivalent to
       `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
-      have the same number of inputs.
+      have the same number of inputs. Each input can be a nested structure
+      containing values that are convertible to tensors.
     infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
       of arguments as inputs to computation.
     device_assignment: If not `None`, a `DeviceAssignment` describing the
@@ -526,7 +528,8 @@ def split_compile_and_replicate(computation,
     computation: A Python function that builds the computation to replicate.
     inputs: A list of lists of input tensors or `None` (equivalent to
       `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
-      have the same number of inputs.
+      have the same number of inputs. Each input can be a nested structure
+      containing values that are convertible to tensors.
     infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
       of arguments as inputs to computation.
     device_assignment: If not `None`, a `DeviceAssignment` describing the
@@ -580,24 +583,32 @@ def split_compile_and_replicate(computation,
   if num_replicas == 0:
     return []
 
+  # Checks all replicas have the same structure.
+  for i in xrange(1, num_replicas):
+    nest.assert_same_structure(inputs[0], inputs[i])
+
+  # Flatten inputs.
+  flat_inputs = [
+      nest.flatten(per_replica_input) for per_replica_input in inputs
+  ]
   # Converts inputs to Tensors.
-  inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]
+  flat_inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in flat_inputs]
 
   # Verifies that all replicas have matching numbers and types of inputs
-  input_types = [x.dtype for x in inputs[0]]
-  input_arity = len(input_types)
+  flat_input_types = [x.dtype for x in flat_inputs[0]]
+  input_arity = len(inputs[0])
+  flat_input_arity = len(flat_input_types)
   for i in range(num_replicas):
     if len(inputs[i]) != input_arity:
       raise ValueError("Replicas must have the same number of inputs. "
                        "Replica 0 had {} inputs, replica {} had {} "
                        "inputs.".format(input_arity, i, len(inputs[i])))
 
-    types = [x.dtype for x in inputs[i]]
-    if types != input_types:
-      raise ValueError(
-          "Replicas must have matching input types. Replica 0 had "
-          "input types {}, replica {} had input types {}".format(
-              input_types, i, types))
+    types = [x.dtype for x in flat_inputs[i]]
+    if types != flat_input_types:
+      raise ValueError("Replicas must have matching input types. Replica 0 had "
+                       "input types {}, replica {} had input types {}".format(
+                           flat_input_types, i, types))
 
   arg_error = xla.check_function_argument_count(
       computation, input_arity, infeed_queue)
@@ -620,8 +631,8 @@ def split_compile_and_replicate(computation,
 
   # Fan-in: Builds a TPUReplicatedInput node for each input.
   computation_inputs = []
-  for i in range(0, input_arity):
-    replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
+  for i in range(0, flat_input_arity):
+    replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)]
     computation_inputs.append(
         tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
 
@@ -651,6 +662,10 @@ def split_compile_and_replicate(computation,
      i.op._set_attr("_tpu_input_identity", attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access
 
+    # Unflatten the computation inputs to match original input structure.
+    computation_inputs = nest.pack_sequence_as(
+        structure=inputs[0], flat_sequence=computation_inputs)
+
    # If there is an infeed queue, adds the dequeued values to the
    # computation's inputs.
    if infeed_queue is not None:
@@ -1093,7 +1108,8 @@ def rewrite(computation,
      evaluating any of the returned output tensors, not just the ones returned.
    inputs: A list of input tensors or `None` (equivalent to an empty list).
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
-      of arguments as inputs to `computation`.
+      of arguments as inputs to `computation`. Each input can be a nested
+      structure containing values that are convertible to tensors.
    device_assignment: if not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. May be omitted for a single-core computation, in which
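
Example (an illustrative sketch, not part of the patch): with this change, each
entry of `inputs` passed to xla.compile() or tpu.rewrite() may be a nested
structure (e.g. a dict or tuple) of tensor-convertible values. The structure is
flattened with nest.flatten(), its leaves are converted to tensors, and it is
re-packed with nest.pack_sequence_as() before `computation` is called. The
function and tensor names below are made up, and the snippet assumes a
TensorFlow 1.x build with XLA support.

    import tensorflow as tf
    from tensorflow.contrib.compiler import xla

    def computation(features):
      # `features` arrives with the same nested structure that was passed in:
      # here, a dict holding two tensor-convertible values.
      return features['x'] * 2.0 + features['y']

    # One computation input whose value is a nested structure (a dict).
    inputs = [{'x': tf.constant(3.0), 'y': tf.constant(4.0)}]

    # xla.compile() flattens `inputs`, converts the leaves to tensors, and
    # re-packs them before invoking `computation`.
    outputs = xla.compile(computation, inputs=inputs)

    with tf.Session() as sess:
      print(sess.run(outputs))  # expected to print something like [10.0]

tpu.rewrite(computation, inputs) accepts the same kind of nested entries, and
tpu.replicate()/split_compile_and_replicate() additionally check that every
replica's inputs share a single structure via nest.assert_same_structure().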