From 53d3a75b7b7237b0582220913b17b811741f9b70 Mon Sep 17 00:00:00 2001
From: Chris Jones
Date: Fri, 25 Jan 2019 03:26:19 -0800
Subject: [PATCH] Remove infeed from TPU Strategy when using host training loop.

PiperOrigin-RevId: 230877075
---
 .../contrib/distribute/python/tpu_strategy.py | 294 +++++++++++-------
 1 file changed, 186 insertions(+), 108 deletions(-)

diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 5a494a6ae7b..4863bb29dae 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -21,6 +21,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import copy
 
 from tensorflow.contrib.tpu.python.ops import tpu_ops
@@ -42,6 +43,7 @@ from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver a
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
@@ -98,9 +100,9 @@ def _create_tpu_mirrored_variable(  # pylint: disable=missing-docstring
     *args, **kwargs):
   # Figure out what collections this variable should be added to.
   # We'll add the TPUMirroredVariable to those collections instead.
-  collections = kwargs.pop("collections", None)
-  if collections is None:
-    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+  var_collections = kwargs.pop("collections", None)
+  if var_collections is None:
+    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
   kwargs["collections"] = []
 
   # TODO(jhseu): Should we have different behavior for different
@@ -139,11 +141,11 @@ def _create_tpu_mirrored_variable(  # pylint: disable=missing-docstring
     # "trainable" to False for next_creator() since that causes functions
     # like implicit_gradients to skip those variables.
     if kwargs.get("trainable", True):
-      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
       l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
       for v in value_list:
         l.remove(v)
-    g.add_to_collections(collections, result)
+    g.add_to_collections(var_collections, result)
   return result
 
 
@@ -170,19 +172,16 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
         the usecase of using a single core within a TPU cluster.
       **kwargs: Additional experimental flags. Will be removed in future.
     """
-    super(TPUStrategy, self).__init__(TPUExtended(
-        self, tpu_cluster_resolver, steps_per_run, device_assignment))
-
-    self._disable_training_loop_on_host = False
     if len(kwargs) > 1:
       raise ValueError("TPUStrategy constructor only takes one experimental "
                        "flag now")
-    if len(kwargs) == 1:
-      if "_disable_training_loop_on_host" not in kwargs:
-        raise ValueError("TPUStrategy constructor does not support arguments: "
-                         "{}".format(kwargs))
-      self._disable_training_loop_on_host = (
-          kwargs["_disable_training_loop_on_host"])
+    elif len(kwargs) == 1 and "_disable_training_loop_on_host" not in kwargs:
+      raise ValueError("TPUStrategy constructor does not support arguments: "
+                       "{}".format(kwargs))
+
+    super(TPUStrategy, self).__init__(TPUExtended(
+        self, tpu_cluster_resolver, steps_per_run, device_assignment,
+        kwargs.get("_disable_training_loop_on_host", False)))
 
   @property
   def steps_per_run(self):
@@ -197,7 +196,8 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
                container_strategy,
                tpu_cluster_resolver=None,
                steps_per_run=None,
-               device_assignment=None):
+               device_assignment=None,
+               disable_training_loop_on_host=False):
     super(TPUExtended, self).__init__(container_strategy)
 
     if tpu_cluster_resolver is None:
@@ -211,6 +211,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
     self._tpu_cluster_resolver = tpu_cluster_resolver
     self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
     self._device_assignment = device_assignment
+    self._disable_training_loop_on_host = disable_training_loop_on_host
 
     # Device assignment is currently only supported for 1 core case.
     if self._device_assignment:
@@ -238,15 +239,25 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
     self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
     self._device_map = values.ReplicaDeviceMap(self._tpu_devices)
 
-    # For input:
-    input_device_map = values.ReplicaDeviceMap(tuple(
-        self.get_host_cpu_device(hid) for hid in range(self.num_hosts)))
-    worker_devices = [
-        (self.get_host(hid), [self.get_host_cpu_device(hid)])
-        for hid in range(self.num_hosts)
-    ]
-    self._input_workers = input_lib.InputWorkers(
-        input_device_map, worker_devices)
+    # If the training loop is on the device, we must use the infeed, with input
+    # on the host. Otherwise, we preload the data onto the TPUs.
+    if disable_training_loop_on_host:
+      input_device_map = values.ReplicaDeviceMap(tuple(
+          self.get_host_cpu_device(hid) for hid in range(self.num_hosts)))
+      worker_devices = [
+          (self.get_host(hid), [self.get_host_cpu_device(hid)])
+          for hid in range(self.num_hosts)
+      ]
+      self._input_workers = input_lib.InputWorkers(
+          input_device_map, worker_devices)
+    else:
+      input_worker_devices = collections.OrderedDict()
+      for tpu_device in self._tpu_devices:
+        host_device = _get_host_for_device(tpu_device)
+        input_worker_devices.setdefault(host_device, [])
+        input_worker_devices[host_device].append(tpu_device)
+      self._input_workers = input_lib.InputWorkers(
+          self._device_map, tuple(input_worker_devices.items()))
 
     # TODO(sourabhbajaj): Remove this once performance of running one step
     # at a time is comparable to multiple steps.
@@ -346,6 +357,113 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
   # a mechanism to infer the outputs of `fn`. Pending b/110550782.
   def _experimental_run_steps_on_iterator(
       self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
+    if self._disable_training_loop_on_host:
+      impl = self._run_steps_on_iterator_with_device_loop
+    else:
+      impl = self._run_steps_on_iterator_with_host_loop
+
+    return impl(
+        fn=fn, multi_worker_iterator=multi_worker_iterator,
+        iterations=iterations, initial_loop_values=initial_loop_values)
+
+  def _run_steps_on_iterator_with_host_loop(
+      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
+    output_shapes = multi_worker_iterator.output_shapes
+    shapes = nest.flatten(output_shapes)
+    if any(not s.is_fully_defined() for s in shapes):
+      raise ValueError(
+          "TPU currently requires fully defined shapes. Either use "
+          "set_shape() on the input tensors or use "
+          "dataset.batch(..., drop_remainder=True).")
+
+    # Wrap `fn` for repeat.
+    if initial_loop_values is None:
+      initial_loop_values = {}
+    initial_loop_values = nest.flatten(initial_loop_values)
+    ctx = input_lib.MultiStepContext()
+
+    def run_fn(inputs):
+      """Single step on the TPU device."""
+      fn_result = fn(ctx, inputs)
+      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
+      if flat_last_step_outputs:
+        with ops.control_dependencies([fn_result]):
+          return [array_ops.identity(f) for f in flat_last_step_outputs]
+      else:
+        return fn_result
+
+    # We capture the control_flow_context at this point, before we run `fn`
+    # inside a while_loop and TPU replicate context. This is useful in cases
+    # where we might need to exit these contexts and get back to the outer
+    # context to do some things, e.g. creating an op which should be
+    # evaluated only once at the end of the loop on the host. One such usage
+    # is in creating metrics' value op.
+    self._outer_control_flow_context = (
+        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access
+
+    def rewrite_fn(*args):
+      """The rewritten step fn running on TPU."""
+      del args
+
+      per_replica_inputs = multi_worker_iterator.get_next()
+      replicate_inputs = []
+      for replica_id in range(self._num_replicas_in_sync):
+        select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
+        replicate_inputs.append((nest.map_structure(
+            select_replica, per_replica_inputs),))
+
+      replicate_outputs = tpu.replicate(run_fn, replicate_inputs)
+
+      # If run_fn has tensor outputs, tpu.replicate returns a list of lists. We
+      # will flatten it in this case. If run_fn has no tensor outputs,
+      # tpu.replicate returns a list of no_ops, we will keep the output as it
+      # is.
+      if isinstance(replicate_outputs[0], list):
+        replicate_outputs = nest.flatten(replicate_outputs)
+
+      return replicate_outputs
+
+    # TODO(sourabhbajaj): The input to while loop should be based on the
+    # output type of the step_fn
+    assert isinstance(initial_loop_values, list)
+    initial_loop_values = initial_loop_values * self._num_replicas_in_sync
+
+    # Put the while loop op on host 0.
+    with ops.device(self.get_host_cpu_device(0)):
+      replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
+                                               initial_loop_values)
+
+    del self._outer_control_flow_context
+    ctx.run_op = control_flow_ops.group(replicate_outputs)
+
+    if isinstance(replicate_outputs, list):
+      # Filter out any ops from the outputs, typically this would be the case
+      # when there were no tensor outputs.
+      last_step_tensor_outputs = [
+          x for x in replicate_outputs if not isinstance(x, ops.Operation)
+      ]
+
+      # Outputs are currently of the structure (flattened)
+      # [output0_device0, output1_device0, output2_device0,
+      #  output0_device1, output1_device1, output2_device1,
+      #  ...]
+      # Convert this to the following structure instead: (grouped by output)
+      # [[output0_device0, output0_device1],
+      #  [output1_device0, output1_device1],
+      #  [output2_device0, output2_device1]]
+      output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
+      last_step_tensor_outputs = [
+          last_step_tensor_outputs[i::output_num] for i in range(output_num)
+      ]
+    else:
+      # no tensors returned.
+      last_step_tensor_outputs = []
+
+    _set_last_step_outputs(ctx, last_step_tensor_outputs)
+    return ctx
+
+  def _run_steps_on_iterator_with_device_loop(
+      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
     output_shapes = multi_worker_iterator.output_shapes
     shapes = nest.flatten(output_shapes)
     if any(not s.is_fully_defined() for s in shapes):
@@ -393,95 +511,28 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
     self._outer_control_flow_context = (
         ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access
 
-    # pylint: disable=protected-access
-    if self._container_strategy()._disable_training_loop_on_host:
-      replicate_inputs = [[]] * self._num_replicas_in_sync
-      replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
-    else:
-      def rewrite_fn(*args):
-        """The rewritten step fn running on TPU."""
-        del args
-        replicate_inputs = [[]] * self._num_replicas_in_sync
-        replicate_outputs = tpu.replicate(run_fn, replicate_inputs)
-
-        # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
-        # will flatten it in this case. If run_fn has no tensor outputs,
-        # tpu.replicate returns a list of no_ops, we will keep the output as it
-        # is.
-        if isinstance(replicate_outputs[0], list):
-          replicate_outputs = nest.flatten(replicate_outputs)
-
-        return replicate_outputs
-
-      # TODO(sourabhbajaj): The input to while loop should be based on the
-      # output type of the step_fn
-      assert isinstance(initial_loop_values, list)
-      initial_loop_values = initial_loop_values * self._num_replicas_in_sync
-
-      # Put the while loop op on host 0.
-      with ops.device(self.get_host_cpu_device(0)):
-        replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
-                                                 initial_loop_values)
+    replicate_inputs = [[]] * self._num_replicas_in_sync
+    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
 
     del self._outer_control_flow_context
     ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)
 
-    if self._container_strategy()._disable_training_loop_on_host:
-      # Filter out any ops from the outputs, typically this would be the case
-      # when there were no tensor outputs.
-      last_step_tensor_outputs = [x for x in replicate_outputs
-                                  if not isinstance(x, ops.Operation)]
+    # Filter out any ops from the outputs, typically this would be the case
+    # when there were no tensor outputs.
+    last_step_tensor_outputs = [x for x in replicate_outputs
+                                if not isinstance(x, ops.Operation)]
 
-      # Outputs are currently of the structure (grouped by device)
-      # [[output0_device0, output1_device0, output2_device0],
-      #  [output0_device1, output1_device1, output2_device1]]
-      # Convert this to the following structure instead: (grouped by output)
-      # [[output0_device0, output0_device1],
-      #  [output1_device0, output1_device1],
-      #  [output2_device0, output2_device1]]
-      last_step_tensor_outputs = [list(x) for x in
-                                  zip(*last_step_tensor_outputs)]
-    else:
-      if isinstance(replicate_outputs, list):
-        # Filter out any ops from the outputs, typically this would be the case
-        # when there were no tensor outputs.
-        last_step_tensor_outputs = [
-            x for x in replicate_outputs if not isinstance(x, ops.Operation)
-        ]
-
-        # Outputs are currently of the structure (flattened)
-        # [output0_device0, output1_device0, output2_device0,
-        #  output0_device1, output1_device1, output2_device1,
-        #  ...]
-        # Convert this to the following structure instead: (grouped by output)
-        # [[output0_device0, output0_device1],
-        #  [output1_device0, output1_device1],
-        #  [output2_device0, output2_device1]]
-        output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
-        last_step_tensor_outputs = [
-            last_step_tensor_outputs[i::output_num] for i in range(output_num)
-        ]
-      else:
-        # no tensors returned.
-        last_step_tensor_outputs = []
-
-    # Convert replicate_outputs to the original dict structure of
-    # last_step_outputs.
-    last_step_tensor_outputs_dict = nest.pack_sequence_as(
-        ctx.last_step_outputs, last_step_tensor_outputs)
-
-    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
-      output = last_step_tensor_outputs_dict[name]
-      # For outputs that have already been reduced, take the first value
-      # from the list as each value should be the same. Else return the full
-      # list of values.
-      # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
-      # value.
-      if reduce_op is not None:
-        # TODO(priyag): Should this return the element or a list with 1 element
-        last_step_tensor_outputs_dict[name] = output[0]
-    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
+    # Outputs are currently of the structure (grouped by device)
+    # [[output0_device0, output1_device0, output2_device0],
+    #  [output0_device1, output1_device1, output2_device1]]
+    # Convert this to the following structure instead: (grouped by output)
+    # [[output0_device0, output0_device1],
+    #  [output1_device0, output1_device1],
+    #  [output2_device0, output2_device1]]
+    last_step_tensor_outputs = [list(x) for x in
+                                zip(*last_step_tensor_outputs)]
+    _set_last_step_outputs(ctx, last_step_tensor_outputs)
     return ctx
 
   def _call_for_each_replica(self, fn, args, kwargs):
@@ -744,3 +795,30 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
     ds = self._strategy
     replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
     return (ds.extended.worker_devices[replica_id],)
+
+
+def _get_host_for_device(device):
+  spec = tf_device.DeviceSpec.from_string(device)
+  return tf_device.DeviceSpec(
+      job=spec.job, replica=spec.replica, task=spec.task,
+      device_type="CPU", device_index=0).to_string()
+
+
+def _set_last_step_outputs(ctx, last_step_tensor_outputs):
+  """Sets the last step outputs on the given context."""
+  # Convert replicate_outputs to the original dict structure of
+  # last_step_outputs.
+  last_step_tensor_outputs_dict = nest.pack_sequence_as(
+      ctx.last_step_outputs, last_step_tensor_outputs)
+
+  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
+    output = last_step_tensor_outputs_dict[name]
+    # For outputs that have already been reduced, take the first value
+    # from the list as each value should be the same. Else return the full
+    # list of values.
+    # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
+    # value.
+    if reduce_op is not None:
+      # TODO(priyag): Should this return the element or a list with 1 element
+      last_step_tensor_outputs_dict[name] = output[0]
+  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
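
In the new default path (host training loop, no infeed), `TPUExtended.__init__` maps every TPU device to its host and groups the devices per host, so input can be prefetched straight onto the TPUs. The following standalone sketch mirrors that grouping; it is plain Python with a regex standing in for `tf_device.DeviceSpec`, and the device names are invented for illustration:

import collections
import re

# Hypothetical device names for two hosts with two TPU cores each; real
# names come from the cluster resolver.
TPU_DEVICES = (
    "/job:worker/replica:0/task:0/device:TPU:0",
    "/job:worker/replica:0/task:0/device:TPU:1",
    "/job:worker/replica:0/task:1/device:TPU:0",
    "/job:worker/replica:0/task:1/device:TPU:1",
)


def get_host_for_device(device):
  """Maps a TPU device to its host's CPU:0 device, as _get_host_for_device
  does with tf_device.DeviceSpec (a regex is used here to stay standalone)."""
  match = re.match(r"(/job:[^/]+/replica:\d+/task:\d+)/device:TPU:\d+", device)
  if match is None:
    raise ValueError("Unexpected device string: %s" % device)
  return match.group(1) + "/device:CPU:0"


def group_devices_by_host(tpu_devices):
  """Groups TPU devices under their host CPU device, preserving order, like
  the input_worker_devices OrderedDict built in TPUExtended.__init__."""
  input_worker_devices = collections.OrderedDict()
  for tpu_device in tpu_devices:
    host_device = get_host_for_device(tpu_device)
    input_worker_devices.setdefault(host_device, [])
    input_worker_devices[host_device].append(tpu_device)
  return tuple(input_worker_devices.items())


if __name__ == "__main__":
  for host, devices in group_devices_by_host(TPU_DEVICES):
    print(host, "->", devices)

Running it prints each host CPU device followed by the TPU devices it feeds, which is the same (worker, devices) tuple structure handed to input_lib.InputWorkers in the patch.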
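
Both step-loop implementations end by regrouping the per-replica last-step outputs so they are keyed by output rather than by replica: the host loop gets a flat list back from training_loop.repeat and regroups it by strided slicing, while the device loop gets one list per replica and simply transposes. A small sketch of the two regroupings, with made-up output names:

# Toy stand-in values for the last-step outputs of two replicas,
# three outputs each (the names are invented for the demo).
num_replicas_in_sync = 2

# Host-loop case: results arrive flattened, grouped by replica.
flattened = ["loss_0", "acc_0", "step_0", "loss_1", "acc_1", "step_1"]
output_num = len(flattened) // num_replicas_in_sync
by_output = [flattened[i::output_num] for i in range(output_num)]
assert by_output == [["loss_0", "loss_1"], ["acc_0", "acc_1"],
                     ["step_0", "step_1"]]

# Device-loop case: results arrive already grouped by replica, so a plain
# transpose is enough.
by_replica = [["loss_0", "acc_0", "step_0"], ["loss_1", "acc_1", "step_1"]]
assert [list(x) for x in zip(*by_replica)] == by_output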