Remove infeed from TPU Strategy when using host training loop.

PiperOrigin-RevId: 230877075
This commit is contained in:
Chris Jones 2019-01-25 03:26:19 -08:00 committed by TensorFlower Gardener
parent bdca83b71c
commit 53d3a75b7b

View File

@ -21,6 +21,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections
import copy import copy
from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.ops import tpu_ops
@ -42,6 +43,7 @@ from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver a
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import tape from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
@ -98,9 +100,9 @@ def _create_tpu_mirrored_variable( # pylint: disable=missing-docstring
*args, **kwargs): *args, **kwargs):
# Figure out what collections this variable should be added to. # Figure out what collections this variable should be added to.
# We'll add the TPUMirroredVariable to those collections instead. # We'll add the TPUMirroredVariable to those collections instead.
collections = kwargs.pop("collections", None) var_collections = kwargs.pop("collections", None)
if collections is None: if var_collections is None:
collections = [ops.GraphKeys.GLOBAL_VARIABLES] var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
kwargs["collections"] = [] kwargs["collections"] = []
# TODO(jhseu): Should we have different behavior for different # TODO(jhseu): Should we have different behavior for different
@ -139,11 +141,11 @@ def _create_tpu_mirrored_variable( # pylint: disable=missing-docstring
# "trainable" to False for next_creator() since that causes functions # "trainable" to False for next_creator() since that causes functions
# like implicit_gradients to skip those variables. # like implicit_gradients to skip those variables.
if kwargs.get("trainable", True): if kwargs.get("trainable", True):
collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
for v in value_list: for v in value_list:
l.remove(v) l.remove(v)
g.add_to_collections(collections, result) g.add_to_collections(var_collections, result)
return result return result
@ -170,19 +172,16 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
the usecase of using a single core within a TPU cluster. the usecase of using a single core within a TPU cluster.
**kwargs: Additional experimental flags. Will be removed in future. **kwargs: Additional experimental flags. Will be removed in future.
""" """
super(TPUStrategy, self).__init__(TPUExtended(
self, tpu_cluster_resolver, steps_per_run, device_assignment))
self._disable_training_loop_on_host = False
if len(kwargs) > 1: if len(kwargs) > 1:
raise ValueError("TPUStrategy constructor only takes one experimental " raise ValueError("TPUStrategy constructor only takes one experimental "
"flag now") "flag now")
if len(kwargs) == 1: elif len(kwargs) == 1 and "_disable_training_loop_on_host" not in kwargs:
if "_disable_training_loop_on_host" not in kwargs: raise ValueError("TPUStrategy constructor does not support arguments: "
raise ValueError("TPUStrategy constructor does not support arguments: " "{}".format(kwargs))
"{}".format(kwargs))
self._disable_training_loop_on_host = ( super(TPUStrategy, self).__init__(TPUExtended(
kwargs["_disable_training_loop_on_host"]) self, tpu_cluster_resolver, steps_per_run, device_assignment,
kwargs.get("_disable_training_loop_on_host", False)))
@property @property
def steps_per_run(self): def steps_per_run(self):
@ -197,7 +196,8 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
container_strategy, container_strategy,
tpu_cluster_resolver=None, tpu_cluster_resolver=None,
steps_per_run=None, steps_per_run=None,
device_assignment=None): device_assignment=None,
disable_training_loop_on_host=False):
super(TPUExtended, self).__init__(container_strategy) super(TPUExtended, self).__init__(container_strategy)
if tpu_cluster_resolver is None: if tpu_cluster_resolver is None:
@ -211,6 +211,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_cluster_resolver = tpu_cluster_resolver
self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
self._device_assignment = device_assignment self._device_assignment = device_assignment
self._disable_training_loop_on_host = disable_training_loop_on_host
# Device assignment is currently only supported for 1 core case. # Device assignment is currently only supported for 1 core case.
if self._device_assignment: if self._device_assignment:
@ -238,15 +239,25 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync] self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
self._device_map = values.ReplicaDeviceMap(self._tpu_devices) self._device_map = values.ReplicaDeviceMap(self._tpu_devices)
# For input: # If the training loop is on the device, we must use the infeed, with input
input_device_map = values.ReplicaDeviceMap(tuple( # on the host. Otherwise, we preload the data onto the TPUs.
self.get_host_cpu_device(hid) for hid in range(self.num_hosts))) if disable_training_loop_on_host:
worker_devices = [ input_device_map = values.ReplicaDeviceMap(tuple(
(self.get_host(hid), [self.get_host_cpu_device(hid)]) self.get_host_cpu_device(hid) for hid in range(self.num_hosts)))
for hid in range(self.num_hosts) worker_devices = [
] (self.get_host(hid), [self.get_host_cpu_device(hid)])
self._input_workers = input_lib.InputWorkers( for hid in range(self.num_hosts)
input_device_map, worker_devices) ]
self._input_workers = input_lib.InputWorkers(
input_device_map, worker_devices)
else:
input_worker_devices = collections.OrderedDict()
for tpu_device in self._tpu_devices:
host_device = _get_host_for_device(tpu_device)
input_worker_devices.setdefault(host_device, [])
input_worker_devices[host_device].append(tpu_device)
self._input_workers = input_lib.InputWorkers(
self._device_map, tuple(input_worker_devices.items()))
# TODO(sourabhbajaj): Remove this once performance of running one step # TODO(sourabhbajaj): Remove this once performance of running one step
# at a time is comparable to multiple steps. # at a time is comparable to multiple steps.
@ -346,6 +357,113 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
# a mechanism to infer the outputs of `fn`. Pending b/110550782. # a mechanism to infer the outputs of `fn`. Pending b/110550782.
def _experimental_run_steps_on_iterator( def _experimental_run_steps_on_iterator(
self, fn, multi_worker_iterator, iterations, initial_loop_values=None): self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
if self._disable_training_loop_on_host:
impl = self._run_steps_on_iterator_with_device_loop
else:
impl = self._run_steps_on_iterator_with_host_loop
return impl(
fn=fn, multi_worker_iterator=multi_worker_iterator,
iterations=iterations, initial_loop_values=initial_loop_values)
def _run_steps_on_iterator_with_host_loop(
self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
output_shapes = multi_worker_iterator.output_shapes
shapes = nest.flatten(output_shapes)
if any(not s.is_fully_defined() for s in shapes):
raise ValueError(
"TPU currently requires fully defined shapes. Either use "
"set_shape() on the input tensors or use "
"dataset.batch(..., drop_remainder=True).")
# Wrap `fn` for repeat.
if initial_loop_values is None:
initial_loop_values = {}
initial_loop_values = nest.flatten(initial_loop_values)
ctx = input_lib.MultiStepContext()
def run_fn(inputs):
"""Single step on the TPU device."""
fn_result = fn(ctx, inputs)
flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
if flat_last_step_outputs:
with ops.control_dependencies([fn_result]):
return [array_ops.identity(f) for f in flat_last_step_outputs]
else:
return fn_result
# We capture the control_flow_context at this point, before we run `fn`
# inside a while_loop and TPU replicate context. This is useful in cases
# where we might need to exit these contexts and get back to the outer
# context to do some things, for e.g. create an op which should be
# evaluated only once at the end of the loop on the host. One such usage
# is in creating metrics' value op.
self._outer_control_flow_context = (
ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
def rewrite_fn(*args):
"""The rewritten step fn running on TPU."""
del args
per_replica_inputs = multi_worker_iterator.get_next()
replicate_inputs = []
for replica_id in range(self._num_replicas_in_sync):
select_replica = lambda x: values.select_replica(replica_id, x) # pylint: disable=cell-var-from-loop
replicate_inputs.append((nest.map_structure(
select_replica, per_replica_inputs),))
replicate_outputs = tpu.replicate(run_fn, replicate_inputs)
# If run_fn has tensor outputs, tpu.replicate returns a list of list. We
# will flatten it in this case. If run_fn has no tensor outputs,
# tpu.replicate returns a list of no_ops, we will keep the output as it
# is.
if isinstance(replicate_outputs[0], list):
replicate_outputs = nest.flatten(replicate_outputs)
return replicate_outputs
# TODO(sourabhbajaj): The input to while loop should be based on the
# output type of the step_fn
assert isinstance(initial_loop_values, list)
initial_loop_values = initial_loop_values * self._num_replicas_in_sync
# Put the while loop op on host 0.
with ops.device(self.get_host_cpu_device(0)):
replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
initial_loop_values)
del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(replicate_outputs)
if isinstance(replicate_outputs, list):
# Filter out any ops from the outputs, typically this would be the case
# when there were no tensor outputs.
last_step_tensor_outputs = [
x for x in replicate_outputs if not isinstance(x, ops.Operation)
]
# Outputs are currently of the structure (flattened)
# [output0_device0, output1_device0, output2_device0,
# output0_device1, output1_device1, output2_device1,
# ...]
# Convert this to the following structure instead: (grouped by output)
# [[output0_device0, output0_device1],
# [output1_device0, output1_device1],
# [output2_device0, output2_device1]]
output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
last_step_tensor_outputs = [
last_step_tensor_outputs[i::output_num] for i in range(output_num)
]
else:
# no tensors returned.
last_step_tensor_outputs = []
_set_last_step_outputs(ctx, last_step_tensor_outputs)
return ctx
def _run_steps_on_iterator_with_device_loop(
self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
output_shapes = multi_worker_iterator.output_shapes output_shapes = multi_worker_iterator.output_shapes
shapes = nest.flatten(output_shapes) shapes = nest.flatten(output_shapes)
if any(not s.is_fully_defined() for s in shapes): if any(not s.is_fully_defined() for s in shapes):
@ -393,95 +511,28 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
self._outer_control_flow_context = ( self._outer_control_flow_context = (
ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
# pylint: disable=protected-access replicate_inputs = [[]] * self._num_replicas_in_sync
if self._container_strategy()._disable_training_loop_on_host: replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
replicate_inputs = [[]] * self._num_replicas_in_sync
replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
else:
def rewrite_fn(*args):
"""The rewritten step fn running on TPU."""
del args
replicate_inputs = [[]] * self._num_replicas_in_sync
replicate_outputs = tpu.replicate(run_fn, replicate_inputs)
# If run_fn has tensor outputs, tpu.replicate returns a list of list. We
# will flatten it in this case. If run_fn has no tensor outputs,
# tpu.replicate returns a list of no_ops, we will keep the output as it
# is.
if isinstance(replicate_outputs[0], list):
replicate_outputs = nest.flatten(replicate_outputs)
return replicate_outputs
# TODO(sourabhbajaj): The input to while loop should be based on the
# output type of the step_fn
assert isinstance(initial_loop_values, list)
initial_loop_values = initial_loop_values * self._num_replicas_in_sync
# Put the while loop op on host 0.
with ops.device(self.get_host_cpu_device(0)):
replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
initial_loop_values)
del self._outer_control_flow_context del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)
if self._container_strategy()._disable_training_loop_on_host: # Filter out any ops from the outputs, typically this would be the case
# Filter out any ops from the outputs, typically this would be the case # when there were no tensor outputs.
# when there were no tensor outputs. last_step_tensor_outputs = [x for x in replicate_outputs
last_step_tensor_outputs = [x for x in replicate_outputs if not isinstance(x, ops.Operation)]
if not isinstance(x, ops.Operation)]
# Outputs are currently of the structure (grouped by device) # Outputs are currently of the structure (grouped by device)
# [[output0_device0, output1_device0, output2_device0], # [[output0_device0, output1_device0, output2_device0],
# [output0_device1, output1_device1, output2_device1]] # [output0_device1, output1_device1, output2_device1]]
# Convert this to the following structure instead: (grouped by output) # Convert this to the following structure instead: (grouped by output)
# [[output0_device0, output0_device1], # [[output0_device0, output0_device1],
# [output1_device0, output1_device1], # [output1_device0, output1_device1],
# [output2_device0, output2_device1]] # [output2_device0, output2_device1]]
last_step_tensor_outputs = [list(x) for x in last_step_tensor_outputs = [list(x) for x in
zip(*last_step_tensor_outputs)] zip(*last_step_tensor_outputs)]
else:
if isinstance(replicate_outputs, list):
# Filter out any ops from the outputs, typically this would be the case
# when there were no tensor outputs.
last_step_tensor_outputs = [
x for x in replicate_outputs if not isinstance(x, ops.Operation)
]
# Outputs are currently of the structure (flattened)
# [output0_device0, output1_device0, output2_device0,
# output0_device1, output1_device1, output2_device1,
# ...]
# Convert this to the following structure instead: (grouped by output)
# [[output0_device0, output0_device1],
# [output1_device0, output1_device1],
# [output2_device0, output2_device1]]
output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
last_step_tensor_outputs = [
last_step_tensor_outputs[i::output_num] for i in range(output_num)
]
else:
# no tensors returned.
last_step_tensor_outputs = []
# Convert replicate_outputs to the original dict structure of
# last_step_outputs.
last_step_tensor_outputs_dict = nest.pack_sequence_as(
ctx.last_step_outputs, last_step_tensor_outputs)
for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access
output = last_step_tensor_outputs_dict[name]
# For outputs that have already been reduced, take the first value
# from the list as each value should be the same. Else return the full
# list of values.
# TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
# value.
if reduce_op is not None:
# TODO(priyag): Should this return the element or a list with 1 element
last_step_tensor_outputs_dict[name] = output[0]
ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access
_set_last_step_outputs(ctx, last_step_tensor_outputs)
return ctx return ctx
def _call_for_each_replica(self, fn, args, kwargs): def _call_for_each_replica(self, fn, args, kwargs):
@ -744,3 +795,30 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
ds = self._strategy ds = self._strategy
replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
return (ds.extended.worker_devices[replica_id],) return (ds.extended.worker_devices[replica_id],)
def _get_host_for_device(device):
spec = tf_device.DeviceSpec.from_string(device)
return tf_device.DeviceSpec(
job=spec.job, replica=spec.replica, task=spec.task,
device_type="CPU", device_index=0).to_string()
def _set_last_step_outputs(ctx, last_step_tensor_outputs):
"""Sets the last step outputs on the given context."""
# Convert replicate_outputs to the original dict structure of
# last_step_outputs.
last_step_tensor_outputs_dict = nest.pack_sequence_as(
ctx.last_step_outputs, last_step_tensor_outputs)
for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access
output = last_step_tensor_outputs_dict[name]
# For outputs that have already been reduced, take the first value
# from the list as each value should be the same. Else return the full
# list of values.
# TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
# value.
if reduce_op is not None:
# TODO(priyag): Should this return the element or a list with 1 element
last_step_tensor_outputs_dict[name] = output[0]
ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access