# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Library for constructing a training loop, suitable for TPUs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu_function
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
"""Builds a training loop for TPUs.
The set of loop-carried tensors corresponds to `inputs`. Both
`condition` and `body` take the current value of the loop-carried
tensors. 'body' additionally takes a tuple of infeed from
infeed_queue if infeed_queue is not None. `condition` must return a
single boolean value that determines whether iteration
continues. `body` must return an updated list of values for the
loop-carried tensors.
Args:
condition: a Python function that builds the loop condition.
body: a Python function that builds the loop body.
inputs: a list of initial values passed into the training loop, or
None (equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple
of arguments as inputs to condition.
name: (Deprecated) Does nothing.
Returns:
The final values of the loop-carried tensors.
Raises:
TypeError: if body or condition has the wrong signature.
"""
del name
# Converts inputs to Tensors.
inputs = [] if inputs is None else [ops.convert_to_tensor(x) for
x in inputs]
input_types = [x.dtype for x in inputs]
input_arity = len(inputs)
body_arg_error = xla.check_function_argument_count(
body, input_arity, infeed_queue)
if body_arg_error is not None:
if infeed_queue is None:
raise TypeError(
"Supplied loop body function cannot be called with the specified "
"inputs. You specified %d inputs: %s, but the loop body needs %s" % (
input_arity, str([i.name for i in inputs]), body_arg_error))
else:
raise TypeError(
"Supplied loop body function cannot be called with the specified "
"inputs. You specified %d inputs: %s and %d additional inputs from "
"infeed, but the computation needs %s" % (input_arity, str(
[i.name for i in inputs]), infeed_queue.number_of_tuple_elements,
body_arg_error))
condition_arg_error = xla.check_function_argument_count(
condition, input_arity, None)
if condition_arg_error is not None:
if infeed_queue is None:
raise TypeError(
"Supplied loop condition function cannot be called with the "
"specified inputs. You specified %d inputs: %s, but the loop "
"condition needs %s" % (input_arity, str([i.name for i in inputs]),
condition_arg_error))
else:
raise TypeError(
"Supplied loop condition function cannot be called with the "
"specified inputs. You specified %d inputs: %s, but the loop "
"condition needs %s. Note that infeed is not passed to the loop "
"condition." % (input_arity, str([i.name for i in inputs]),
condition_arg_error))
def condition_wrapper(*inputs):
# Discards the dummy output added for arity-0 loops.
if input_arity == 0:
inputs = []
return condition(*inputs)
def body_wrapper(*inputs):
"""Wrapper around `body` that handles infeed queues and control deps."""
inputs = list(inputs)
# Discards the dummy output added for arity-0 loops.
if input_arity == 0:
inputs = []
# Runs `body` with the dequeue_ops appended.
if infeed_queue:
number_of_shards = tpu_function.get_tpu_context().number_of_shards
if number_of_shards is None:
raise ValueError("Can't build training loop with infeed when there is "
"no tpu_shard_context. Are you building a loop or "
"graph directly rather than from inside tpu.rewrite, "
"tpu.batch_parallel, tpu.shard, or tpu.replicate?")
infeed_queue.set_number_of_shards(number_of_shards)
dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
else:
dequeue_ops = []
outputs = body(*(inputs + dequeue_ops))
# If the computation only returned one value, make it a tuple.
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)
outputs = [
o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
for o in outputs
]
# Separates the returned Operations and Tensors.
output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
output_tensors = [o for o in outputs
if not isinstance(o, ops.Operation)]
if outputs != output_tensors + output_operations:
raise ValueError(
"TPU training loop body must return zero or more Tensor values "
"followed by zero or more Operations.")
output_types = [op.dtype for op in output_tensors]
if input_types != output_types:
raise TypeError(
"Mismatch between input types and output types for training loop "
"body: {} vs {}".format(input_types, output_types))
# Add the dequeue operations to output_operations to ensure they are run
# by the loop, even if the programmer's loop body does not use them.
output_operations += dequeue_ops
# Add a dummy output, if needed.
if not output_tensors:
output_tensors = array_ops.constant(0)
if output_operations:
# TODO(phawkins): in principle this is too restrictive since it serializes
# the training loop steps. In practice it does not matter since this loop
# will be compiled by XLA.
output_tensors = control_flow_ops.tuple(output_tensors,
control_inputs=output_operations)
if tensor_tracer.TensorTracer.is_enabled():
num_replicas = tpu_function.get_tpu_context().number_of_shards
if num_replicas is None:
num_replicas = 1
tt = tensor_tracer.TensorTracer()
output_tensors = tt.trace_tpu(ops.get_default_graph(),
output_tensors, None,
num_replicas)
return output_tensors
# If the body has arity 0, add a dummy loop-carried value to which we can add
# control dependencies from any side-effecting operations.
if input_arity == 0:
inputs = [array_ops.constant(0)]
return control_flow_ops.while_loop(
condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)
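

# A minimal usage sketch (not part of the original module): it shows the
# expected signatures of `condition` and `body` for `while_loop`. The helper
# name `_example_while_loop_usage` and the toy computation are illustrative
# assumptions, not TensorFlow API; a real training loop would normally be
# built from inside `tpu.rewrite`, `tpu.shard`, or `tpu.replicate`,
# especially when an `infeed_queue` is supplied.
def _example_while_loop_usage():
  """Hypothetical example: sums the integers 0..9 with `while_loop`."""

  def _condition(i, total):
    # `condition` receives only the loop-carried tensors and must return a
    # single boolean tensor.
    del total
    return i < 10

  def _body(i, total):
    # `body` must return updated values for every loop-carried tensor, with
    # the same dtypes as the initial `inputs`.
    return [i + 1, total + i]

  # Both initial values are converted to int32 tensors by `while_loop`; the
  # return value is the list of final loop-carried tensors.
  return while_loop(_condition, _body, inputs=[0, 0])
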
def repeat(n, body, inputs=None, infeed_queue=None, name=None):
"""Builds a training loop that executes a fixed number of iterations.
The set of loop-carried tensors correspond to `inputs`.
`body` must be a function that takes and returns the values of the
loop-carried tensors.
Args:
n: the number of loop iterations
body: a Python function that builds the loop body.
inputs: a list of initial values passed into the training loop or
None (equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple
of arguments as inputs to condition.
name: (Deprecated) Does nothing.
Returns:
The final values of the loop-carried tensors.
Raises:
ValueError: if there is a type error.
"""
def _convert_to_list(xs):
if not isinstance(xs, (list, tuple)):
return [xs]
else:
return list(xs)
def cond(i, *args):
del args
return i < n
def body_wrapper(i, *args):
return [i + 1] + _convert_to_list(body(*args))
inputs = [0] if inputs is None else [0] + _convert_to_list(inputs)
outputs = while_loop(
cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name)
outputs = _convert_to_list(outputs)
if len(outputs) == 1:
# Returns the Op rather than an empty list.
return outputs[0].op
else:
return outputs[1:]
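

# A second illustrative sketch (also an assumption, not part of the original
# file): `repeat` drives `while_loop` with a synthetic step counter, so the
# per-step function only sees the user-provided loop-carried tensors. The
# helper name `_example_repeat_usage` and the toy "training step" below are
# hypothetical; a realistic step would dequeue a batch from an infeed queue
# and apply an optimizer, with the whole loop wrapped by `tpu.rewrite` or
# `tpu.shard`.
def _example_repeat_usage():
  """Hypothetical example: decays a scalar 'loss' for ten iterations."""

  def _train_step(loss):
    # The step function takes and returns the loop-carried tensors; the
    # iteration counter added by `repeat` is hidden from it.
    return loss * 0.9

  # Returns a list containing the single final loop-carried tensor.
  return repeat(10, _train_step, inputs=[1.0])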