223 lines
8.7 KiB
Python
223 lines
8.7 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# =============================================================================
|
|
|
|
"""Library for constructing a training loop, suitable for TPUs."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from tensorflow.python.compiler.xla import xla
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.tpu import tensor_tracer
|
|
from tensorflow.python.tpu import tpu_function
|
|
|
|
|
|
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
  """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`.  Both
  `condition` and `body` take the current value of the loop-carried
  tensors. `body` additionally takes a tuple of infeed from
  infeed_queue if infeed_queue is not None. `condition` must return a
  single boolean value that determines whether iteration
  continues. `body` must return an updated list of values for the
  loop-carried tensors, optionally followed by Operations that are attached
  to every iteration as control dependencies.

  Args:
    condition: a Python function that builds the loop condition.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or
      None (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple
      of arguments as inputs to body. Infeed is not passed to condition.
    name: (Deprecated) Does nothing.

  Returns:
    The result of `control_flow_ops.while_loop`: the final values of the
    loop-carried tensors. When `inputs` is empty, a dummy integer
    loop-carried value is used in their place.

  Raises:
    TypeError: if body or condition has the wrong signature, or if the dtypes
      of the tensors returned by body differ from the dtypes of `inputs`.
    ValueError: if body returns a Tensor after an Operation, or if
      infeed_queue is given but the TPU context has no number of shards.
  """
  del name
  # Converts inputs to Tensors.
  inputs = [] if inputs is None else [ops.convert_to_tensor(x) for x in inputs]
  input_types = [x.dtype for x in inputs]
  input_arity = len(inputs)

  body_arg_error = xla.check_function_argument_count(
      body, input_arity, infeed_queue)
  if body_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s, but the loop body needs %s" % (
              input_arity, str([i.name for i in inputs]), body_arg_error))
    else:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s and %d additional inputs from "
          "infeed, but the computation needs %s" % (input_arity, str(
              [i.name for i in inputs]), infeed_queue.number_of_tuple_elements,
                                                    body_arg_error))

  # The condition never receives infeed, so its arity is checked without it.
  condition_arg_error = xla.check_function_argument_count(
      condition, input_arity, None)
  if condition_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s" % (input_arity, str([i.name for i in inputs]),
                                  condition_arg_error))
    else:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s. Note that infeed is not passed to the loop "
          "condition." % (input_arity, str([i.name for i in inputs]),
                          condition_arg_error))

  def condition_wrapper(*loop_vars):
    """Wrapper around `condition` that hides the arity-0 dummy value."""
    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      loop_vars = []
    return condition(*loop_vars)

  def body_wrapper(*loop_vars):
    """Wrapper around `body` that handles infeed queues and control deps."""
    # Discards the dummy output added for arity-0 loops.
    loop_vars = [] if input_arity == 0 else list(loop_vars)

    # Runs `body` with the dequeue_ops appended.
    if infeed_queue is not None:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = list(infeed_queue.generate_dequeue_op())
    else:
      dequeue_ops = []
    outputs = body(*(loop_vars + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    # All Tensors must precede all Operations. The check walks the outputs
    # positionally so that it does not depend on how `==` is defined for
    # Tensor and Operation objects.
    seen_operation = False
    for o in outputs:
      if isinstance(o, ops.Operation):
        seen_operation = True
      elif seen_operation:
        raise ValueError(
            "TPU training loop body must return zero or more Tensor values "
            "followed by zero or more Operations.")

    output_types = [op.dtype for op in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # Add the dequeue operations to output_operations to ensure they are run
    # by the loop, even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed. It is kept as a one-element list so that
    # it mirrors the one-element dummy `inputs` given to the loop below and can
    # be handed to `control_flow_ops.tuple`, which takes a list of tensors.
    if not output_tensors:
      output_tensors = [array_ops.constant(0)]

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it serializes
      # the training loop steps. In practice it does not matter since this loop
      # will be compiled by XLA.
      output_tensors = control_flow_ops.tuple(output_tensors,
                                              control_inputs=output_operations)

    if tensor_tracer.TensorTracer.is_enabled():
      num_replicas = tpu_function.get_tpu_context().number_of_shards
      if num_replicas is None:
        num_replicas = 1
      tt = tensor_tracer.TensorTracer()
      output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                    output_tensors, None,
                                    num_replicas)
    return output_tensors

  # If the body has arity 0, add a dummy loop-carried value to which we can add
  # control dependencies from any side-effecting operations.
  if input_arity == 0:
    inputs = [array_ops.constant(0)]
  return control_flow_ops.while_loop(
      condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)
|
|
|
|
|
|
def repeat(n, body, inputs=None, infeed_queue=None, name=None):
  """Builds a training loop that executes a fixed number of iterations.

  The loop is built with `while_loop`, using an iteration counter that starts
  at 0 as an extra, leading loop-carried value. The counter is incremented on
  every iteration and the loop continues while it is less than `n`. `body`
  never sees the counter: it is called with the remaining loop-carried values
  (and whatever `while_loop` appends from `infeed_queue`) and must return the
  updated loop-carried values.

  Args:
    n: the number of loop iterations
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop or
      None (equivalent to an empty list). A single non-list value is treated
      as a one-element list.
    infeed_queue: if not None, the infeed queue forwarded to `while_loop`,
      whose dequeued tuple is appended to the arguments of `body`.
    name: (Deprecated) Forwarded to `while_loop`, where it does nothing.

  Returns:
    The final values of the loop-carried tensors, without the iteration
    counter. If the counter is the only loop-carried value, the Operation
    that produces its final value is returned instead of an empty list.

  Raises:
    TypeError: propagated from `while_loop`, e.g. when `body` has the wrong
      signature or returns values whose dtypes do not match `inputs`.
    ValueError: propagated from `while_loop`.
  """

  def _as_list(values):
    """Returns `values` as a new list, wrapping a lone value."""
    return list(values) if isinstance(values, (list, tuple)) else [values]

  def _keep_going(step, *unused_loop_vars):
    """Loop condition: only the iteration counter matters."""
    return step < n

  def _counted_body(step, *loop_vars):
    """Advances the counter and runs the user-supplied `body`."""
    return [step + 1] + _as_list(body(*loop_vars))

  # The counter always occupies position 0 of the loop-carried values.
  initial_values = [0] + ([] if inputs is None else _as_list(inputs))
  results = _as_list(
      while_loop(
          _keep_going,
          _counted_body,
          inputs=initial_values,
          infeed_queue=infeed_queue,
          name=name))

  if len(results) == 1:
    # Returns the Op rather than an empty list.
    return results[0].op
  return results[1:]
|