Remove infeed from TPU Strategy when using host training loop.
PiperOrigin-RevId: 230877075
Parent: bdca83b71c
Commit: 53d3a75b7b
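For context, here is a minimal usage sketch of the experimental flag this change wires through; it assumes the TF 1.13-era contrib layout, and the TPU name is a placeholder, so treat it as an illustration rather than part of the commit. By default the training loop now runs on the host and input data is preloaded onto the TPU devices, so no infeed queue is built; passing the private flag keeps the previous device-side loop, which still feeds data through infeed.

    # Hypothetical usage sketch (TF 1.13-era contrib APIs); "my-tpu" is a placeholder.
    from tensorflow.contrib.cluster_resolver import TPUClusterResolver
    from tensorflow.contrib.distribute import TPUStrategy

    resolver = TPUClusterResolver(tpu="my-tpu")

    # Default after this change: host training loop, data preloaded onto the
    # TPU devices, no infeed.
    strategy = TPUStrategy(resolver, steps_per_run=100)

    # Experimental opt-out: keep the old device-side loop, which uses infeed.
    legacy_strategy = TPUStrategy(resolver, steps_per_run=100,
                                  _disable_training_loop_on_host=True)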
@@ -21,6 +21,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy

from tensorflow.contrib.tpu.python.ops import tpu_ops
@@ -42,6 +43,7 @@ from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver a
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@@ -98,9 +100,9 @@ def _create_tpu_mirrored_variable(  # pylint: disable=missing-docstring
    *args, **kwargs):
  # Figure out what collections this variable should be added to.
  # We'll add the TPUMirroredVariable to those collections instead.
  collections = kwargs.pop("collections", None)
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  # TODO(jhseu): Should we have different behavior for different
@@ -139,11 +141,11 @@ def _create_tpu_mirrored_variable(  # pylint: disable=missing-docstring
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in value_list:
        l.remove(v)
    g.add_to_collections(collections, result)
    g.add_to_collections(var_collections, result)
  return result


@@ -170,19 +172,16 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
        the usecase of using a single core within a TPU cluster.
      **kwargs: Additional experimental flags. Will be removed in future.
    """
    super(TPUStrategy, self).__init__(TPUExtended(
        self, tpu_cluster_resolver, steps_per_run, device_assignment))

    self._disable_training_loop_on_host = False
    if len(kwargs) > 1:
      raise ValueError("TPUStrategy constructor only takes one experimental "
                       "flag now")
    if len(kwargs) == 1:
      if "_disable_training_loop_on_host" not in kwargs:
    elif len(kwargs) == 1 and "_disable_training_loop_on_host" not in kwargs:
      raise ValueError("TPUStrategy constructor does not support arguments: "
                       "{}".format(kwargs))
      self._disable_training_loop_on_host = (
          kwargs["_disable_training_loop_on_host"])

    super(TPUStrategy, self).__init__(TPUExtended(
        self, tpu_cluster_resolver, steps_per_run, device_assignment,
        kwargs.get("_disable_training_loop_on_host", False)))

  @property
  def steps_per_run(self):
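To make the constructor's kwargs handling above concrete, here is a hedged sketch of how individual calls behave after this change; `resolver` stands for a TPUClusterResolver from the caller's setup, and each call is shown independently.

    TPUStrategy(resolver)                                       # OK, flag defaults to False
    TPUStrategy(resolver, _disable_training_loop_on_host=True)  # OK, opts back into the device loop
    TPUStrategy(resolver, _unknown_flag=True)                   # ValueError: unsupported argument
    TPUStrategy(resolver, _disable_training_loop_on_host=True,
                _another_flag=1)                                # ValueError: only one experimental flag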
@@ -197,7 +196,8 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
               container_strategy,
               tpu_cluster_resolver=None,
               steps_per_run=None,
               device_assignment=None):
               device_assignment=None,
               disable_training_loop_on_host=False):
    super(TPUExtended, self).__init__(container_strategy)

    if tpu_cluster_resolver is None:
@@ -211,6 +211,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
    self._tpu_cluster_resolver = tpu_cluster_resolver
    self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
    self._device_assignment = device_assignment
    self._disable_training_loop_on_host = disable_training_loop_on_host

    # Device assignment is currently only supported for 1 core case.
    if self._device_assignment:
@@ -238,7 +239,9 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
    self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
    self._device_map = values.ReplicaDeviceMap(self._tpu_devices)

    # For input:
    # If the training loop is on the device, we must use the infeed, with input
    # on the host. Otherwise, we preload the data onto the TPUs.
    if disable_training_loop_on_host:
      input_device_map = values.ReplicaDeviceMap(tuple(
          self.get_host_cpu_device(hid) for hid in range(self.num_hosts)))
      worker_devices = [
@@ -247,6 +250,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
      ]
      self._input_workers = input_lib.InputWorkers(
          input_device_map, worker_devices)
    else:
      input_worker_devices = collections.OrderedDict()
      for tpu_device in self._tpu_devices:
        host_device = _get_host_for_device(tpu_device)
        input_worker_devices.setdefault(host_device, [])
        input_worker_devices[host_device].append(tpu_device)
      self._input_workers = input_lib.InputWorkers(
          self._device_map, tuple(input_worker_devices.items()))

    # TODO(sourabhbajaj): Remove this once performance of running one step
    # at a time is comparable to multiple steps.
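To make the preload path concrete, the following self-contained sketch (not from the commit) mirrors the host grouping done in the `else` branch above, using hypothetical device names in place of `self._tpu_devices` and a stand-in for `_get_host_for_device`:

    import collections

    def group_by_host(tpu_devices, get_host):
      # Map each host CPU device to the TPU devices it will feed.
      input_worker_devices = collections.OrderedDict()
      for tpu_device in tpu_devices:
        host_device = get_host(tpu_device)
        input_worker_devices.setdefault(host_device, [])
        input_worker_devices[host_device].append(tpu_device)
      return tuple(input_worker_devices.items())

    devices = ["/job:worker/task:0/device:TPU:0",
               "/job:worker/task:0/device:TPU:1",
               "/job:worker/task:1/device:TPU:0"]
    host_of = lambda d: d.split("/device:")[0] + "/device:CPU:0"  # stand-in helper
    print(group_by_host(devices, host_of))
    # (("/job:worker/task:0/device:CPU:0",
    #   ["/job:worker/task:0/device:TPU:0", "/job:worker/task:0/device:TPU:1"]),
    #  ("/job:worker/task:1/device:CPU:0",
    #   ["/job:worker/task:1/device:TPU:0"]))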
@@ -346,6 +357,113 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
  # a mechanism to infer the outputs of `fn`. Pending b/110550782.
  def _experimental_run_steps_on_iterator(
      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
    if self._disable_training_loop_on_host:
      impl = self._run_steps_on_iterator_with_device_loop
    else:
      impl = self._run_steps_on_iterator_with_host_loop

    return impl(
        fn=fn, multi_worker_iterator=multi_worker_iterator,
        iterations=iterations, initial_loop_values=initial_loop_values)

  def _run_steps_on_iterator_with_host_loop(
      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
    output_shapes = multi_worker_iterator.output_shapes
    shapes = nest.flatten(output_shapes)
    if any(not s.is_fully_defined() for s in shapes):
      raise ValueError(
          "TPU currently requires fully defined shapes. Either use "
          "set_shape() on the input tensors or use "
          "dataset.batch(..., drop_remainder=True).")

    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = input_lib.MultiStepContext()

    def run_fn(inputs):
      """Single step on the TPU device."""
      fn_result = fn(ctx, inputs)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop and TPU replicate context. This is useful in cases
    # where we might need to exit these contexts and get back to the outer
    # context to do some things, for e.g. create an op which should be
    # evaluated only once at the end of the loop on the host. One such usage
    # is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    def rewrite_fn(*args):
      """The rewritten step fn running on TPU."""
      del args

      per_replica_inputs = multi_worker_iterator.get_next()
      replicate_inputs = []
      for replica_id in range(self._num_replicas_in_sync):
        select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
        replicate_inputs.append((nest.map_structure(
            select_replica, per_replica_inputs),))

      replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

      # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
      # will flatten it in this case. If run_fn has no tensor outputs,
      # tpu.replicate returns a list of no_ops, we will keep the output as it
      # is.
      if isinstance(replicate_outputs[0], list):
        replicate_outputs = nest.flatten(replicate_outputs)

      return replicate_outputs

    # TODO(sourabhbajaj): The input to while loop should be based on the
    # output type of the step_fn
    assert isinstance(initial_loop_values, list)
    initial_loop_values = initial_loop_values * self._num_replicas_in_sync

    # Put the while loop op on host 0.
    with ops.device(self.get_host_cpu_device(0)):
      replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
                                               initial_loop_values)

    del self._outer_control_flow_context
    ctx.run_op = control_flow_ops.group(replicate_outputs)

    if isinstance(replicate_outputs, list):
      # Filter out any ops from the outputs, typically this would be the case
      # when there were no tensor outputs.
      last_step_tensor_outputs = [
          x for x in replicate_outputs if not isinstance(x, ops.Operation)
      ]

      # Outputs are currently of the structure (flattened)
      # [output0_device0, output1_device0, output2_device0,
      #  output0_device1, output1_device1, output2_device1,
      #  ...]
      # Convert this to the following structure instead: (grouped by output)
      # [[output0_device0, output0_device1],
      #  [output1_device0, output1_device1],
      #  [output2_device0, output2_device1]]
      output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
      last_step_tensor_outputs = [
          last_step_tensor_outputs[i::output_num] for i in range(output_num)
      ]
    else:
      # no tensors returned.
      last_step_tensor_outputs = []

    _set_last_step_outputs(ctx, last_step_tensor_outputs)
    return ctx

  def _run_steps_on_iterator_with_device_loop(
      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
    output_shapes = multi_worker_iterator.output_shapes
    shapes = nest.flatten(output_shapes)
    if any(not s.is_fully_defined() for s in shapes):
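The new host-loop method above regroups the flattened outputs of tpu.replicate by output index. A small standalone sketch of that reshaping, with plain Python strings standing in for tensors:

    # tpu.replicate hands back per-replica outputs flattened as
    # [out0_rep0, out1_rep0, out0_rep1, out1_rep1, ...]; the slice
    # [i::output_num] collects output i across all replicas.
    num_replicas = 3
    flat = ["out0_rep0", "out1_rep0",
            "out0_rep1", "out1_rep1",
            "out0_rep2", "out1_rep2"]
    output_num = len(flat) // num_replicas  # 2 distinct outputs per replica
    grouped = [flat[i::output_num] for i in range(output_num)]
    # grouped == [["out0_rep0", "out0_rep1", "out0_rep2"],
    #             ["out1_rep0", "out1_rep1", "out1_rep2"]]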
@@ -393,40 +511,12 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    # pylint: disable=protected-access
    if self._container_strategy()._disable_training_loop_on_host:
      replicate_inputs = [[]] * self._num_replicas_in_sync
      replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
    else:
      def rewrite_fn(*args):
        """The rewritten step fn running on TPU."""
        del args
        replicate_inputs = [[]] * self._num_replicas_in_sync
        replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

        # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
        # will flatten it in this case. If run_fn has no tensor outputs,
        # tpu.replicate returns a list of no_ops, we will keep the output as it
        # is.
        if isinstance(replicate_outputs[0], list):
          replicate_outputs = nest.flatten(replicate_outputs)

        return replicate_outputs

      # TODO(sourabhbajaj): The input to while loop should be based on the
      # output type of the step_fn
      assert isinstance(initial_loop_values, list)
      initial_loop_values = initial_loop_values * self._num_replicas_in_sync

      # Put the while loop op on host 0.
      with ops.device(self.get_host_cpu_device(0)):
        replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
                                                 initial_loop_values)

    del self._outer_control_flow_context
    ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

    if self._container_strategy()._disable_training_loop_on_host:
      # Filter out any ops from the outputs, typically this would be the case
      # when there were no tensor outputs.
      last_step_tensor_outputs = [x for x in replicate_outputs
@@ -441,47 +531,8 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
      #  [output2_device0, output2_device1]]
      last_step_tensor_outputs = [list(x) for x in
                                  zip(*last_step_tensor_outputs)]
    else:
      if isinstance(replicate_outputs, list):
        # Filter out any ops from the outputs, typically this would be the case
        # when there were no tensor outputs.
        last_step_tensor_outputs = [
            x for x in replicate_outputs if not isinstance(x, ops.Operation)
        ]

        # Outputs are currently of the structure (flattened)
        # [output0_device0, output1_device0, output2_device0,
        #  output0_device1, output1_device1, output2_device1,
        #  ...]
        # Convert this to the following structure instead: (grouped by output)
        # [[output0_device0, output0_device1],
        #  [output1_device0, output1_device1],
        #  [output2_device0, output2_device1]]
        output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
        last_step_tensor_outputs = [
            last_step_tensor_outputs[i::output_num] for i in range(output_num)
        ]
      else:
        # no tensors returned.
        last_step_tensor_outputs = []

    # Convert replicate_outputs to the original dict structure of
    # last_step_outputs.
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been reduced, take the first value
      # from the list as each value should be the same. Else return the full
      # list of values.
      # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
      # value.
      if reduce_op is not None:
        # TODO(priyag): Should this return the element or a list with 1 element
        last_step_tensor_outputs_dict[name] = output[0]
    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

    _set_last_step_outputs(ctx, last_step_tensor_outputs)
    return ctx

  def _call_for_each_replica(self, fn, args, kwargs):
@@ -744,3 +795,30 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
    ds = self._strategy
    replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
    return (ds.extended.worker_devices[replica_id],)


def _get_host_for_device(device):
  spec = tf_device.DeviceSpec.from_string(device)
  return tf_device.DeviceSpec(
      job=spec.job, replica=spec.replica, task=spec.task,
      device_type="CPU", device_index=0).to_string()


def _set_last_step_outputs(ctx, last_step_tensor_outputs):
  """Sets the last step outputs on the given context."""
  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that have already been reduced, take the first value
    # from the list as each value should be the same. Else return the full
    # list of values.
    # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
    # value.
    if reduce_op is not None:
      # TODO(priyag): Should this return the element or a list with 1 element
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
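For reference, a quick sketch of what the new `_get_host_for_device` helper produces for a hypothetical TPU device string; it simply swaps the device part for CPU:0 on the same job/replica/task:

    from tensorflow.python.framework import device as tf_device

    def host_for_device(device):
      # Same DeviceSpec round-trip as the helper above.
      spec = tf_device.DeviceSpec.from_string(device)
      return tf_device.DeviceSpec(
          job=spec.job, replica=spec.replica, task=spec.task,
          device_type="CPU", device_index=0).to_string()

    print(host_for_device("/job:worker/replica:0/task:1/device:TPU:3"))
    # /job:worker/replica:0/task:1/device:CPU:0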