# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Control Flow Operations.

See the [autograph](https://www.tensorflow.org/guide/autograph) guide.
"""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import functools

import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import control_flow_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.gen_control_flow_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export

# This is to avoid a circular dependency:
# cond_v2 -> gradients_util -> control_flow_ops
cond_v2 = LazyLoader("cond_v2", globals(),
                     "tensorflow.python.ops.cond_v2")

# This is to avoid circular dependencies:
# while_v2 -> control_flow_ops
# while_v2 -> gradients_util -> control_flow_ops
while_v2 = LazyLoader("while_v2", globals(),
                      "tensorflow.python.ops.while_v2")

# We override the 'tuple' for a control flow op, so we keep python's
# existing 'tuple' for later use in this module.
_basetuple = tuple


def _summarize_eager(tensor, summarize=None):
  """Returns a summarized string representation of eager `tensor`.

  Args:
    tensor: EagerTensor to summarize
    summarize: Include this many leading elements of the flattened tensor
  """
  # Emulate the behavior of Tensor::SummarizeValue()
  if summarize is None:
    summarize = 3
  elif summarize < 0:
    summarize = array_ops.size(tensor)

  # reshape((-1,)) is the fastest way to get a flat array view
  if tensor._rank():  # pylint: disable=protected-access
    flat = tensor.numpy().reshape((-1,))
    lst = [str(x) for x in flat[:summarize]]
    if len(lst) < flat.size:
      lst.append("...")
  else:
    # tensor.numpy() returns a scalar for zero dimensional arrays
    if gen_math_ops.not_equal(summarize, 0):
      lst = [str(tensor.numpy())]
    else:
      lst = []

  return ", ".join(lst)


# pylint: disable=protected-access


# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
@tf_export("debugging.Assert", "Assert")
@dispatch.add_dispatch_support
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
  """Asserts that the given condition is true.

  If `condition` evaluates to false, print the list of tensors in `data`.
  `summarize` determines how many entries of the tensors to print.

  Args:
    condition: The condition to evaluate.
    data: The tensors to print out when condition is false.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).

  Returns:
    assert_op: An `Operation` that, when executed, raises a
    `tf.errors.InvalidArgumentError` if `condition` is not true.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    @compatibility(TF1)
    When in TF V1 mode (that is, outside `tf.function`) Assert needs a control
    dependency on the output to ensure the assertion executes:

    ```python
    # Ensure maximum element of x is smaller or equal to 1
    assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
    with tf.control_dependencies([assert_op]):
      ... code using x ...
    ```

    @end_compatibility
  """
  if context.executing_eagerly():
    if not condition:
      xs = ops.convert_n_to_tensor(data)
      data_str = [_summarize_eager(x, summarize) for x in xs]
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="Expected '%s' to be true. Summarized data: %s" %
          (condition, "\n".join(data_str)))
    return

  with ops.name_scope(name, "Assert", [condition, data]) as name:
    xs = ops.convert_n_to_tensor(data)
    if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs):
      # As a simple heuristic, we assume that string and int32 are
      # on host to avoid the need to use cond. If that is not the case,
      # we will pay the price copying the tensor to host memory.
      return gen_logging_ops._assert(condition, data, summarize, name="Assert")
    else:
      condition = ops.convert_to_tensor(condition, name="Condition")

      def true_assert():
        return gen_logging_ops._assert(
            condition, data, summarize, name="Assert")

      guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
      if context.executing_eagerly():
        return
      return guarded_assert.op


def _Identity(data, name=None):
  """Return a tensor with the same shape and contents as the input tensor.

  Args:
    data: A Tensor.
    name: A name for this operation (optional).

  Returns:
    A Tensor with the same type and value as the input Tensor.
  """
  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return gen_array_ops.ref_identity(data, name=name)
    else:
      return array_ops.identity(data, name=name)
  elif isinstance(data, composite_tensor.CompositeTensor):
    return nest.map_structure(_Identity, data, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(data))


def _NextIteration(data, name=None):
  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return ref_next_iteration(data, name=name)
    else:
      return next_iteration(data, name=name)
  elif isinstance(data, composite_tensor.CompositeTensor):
    return nest.map_structure(_NextIteration, data, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(data))


def _Enter(data,
           frame_name,
           is_constant=False,
           parallel_iterations=10,
           use_ref=True,
           use_input_shape=True,
           name=None):
  """Creates or finds a child frame, and makes `data` available to it.

  The unique `frame_name` is used by the `Executor` to identify frames. If
  `is_constant` is true, `data` is a constant in the child frame; otherwise
  it may be changed in the child frame. At most `parallel_iterations`
  iterations are run in parallel in the child frame.

  Args:
    data: The tensor to be made available to the child frame.
    frame_name: The name of the child frame.
    is_constant: If true, the output is constant within the child frame.
    parallel_iterations: The number of iterations allowed to run in parallel.
    use_ref: If true, use ref_enter if data is of ref type.
    use_input_shape: If true, set the result's shape based on data's shape.
    name: A name for this operation (optional).

  Returns:
    The same tensor as `data`.
  """
  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype and use_ref:  # pylint: disable=protected-access
      result = gen_control_flow_ops.ref_enter(
          data, frame_name, is_constant, parallel_iterations, name=name)
    else:
      result = gen_control_flow_ops.enter(
          data, frame_name, is_constant, parallel_iterations, name=name)
    if use_input_shape:
      result.set_shape(data.get_shape())
    return result
  elif isinstance(data, composite_tensor.CompositeTensor):

    def enter_component(t):
      return _Enter(t, frame_name, is_constant, parallel_iterations, use_ref,
                    use_input_shape)

    return nest.map_structure(enter_component, data, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(data))


def exit(data, name=None):  # pylint: disable=redefined-builtin
  """Exits the current frame to its parent frame.

  Exit makes its input `data` available to the parent frame.

  Args:
    data: The tensor to be made available to the parent frame.
    name: A name for this operation (optional).

  Returns:
    The same tensor as `data`.
  """
  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return gen_control_flow_ops.ref_exit(data, name)
    else:
      return gen_control_flow_ops._exit(data, name)
  elif isinstance(data, composite_tensor.CompositeTensor):
    return nest.map_structure(exit, data, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(data))


def switch(data, pred, dtype=None, name=None):
  """Forwards `data` to an output determined by `pred`.

  If `pred` is false, the `data` input is forwarded to the first output.
  Otherwise, the data goes to the second output.

  This op handles `Tensor`s and `IndexedSlices`.

  Args:
    data: The tensor to be forwarded to the appropriate output.
    pred: A scalar that specifies which output port will receive data.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `data`.
    name: A name for this operation (optional).

  Returns:
    `(output_false, output_true)`: If `pred` is true, data will be forwarded
    to `output_true`, otherwise it goes to `output_false`.
  """
  with ops.name_scope(name, "Switch", [data, pred]) as name:
    data = ops.internal_convert_to_tensor_or_composite(
        data, dtype=dtype, name="data", as_ref=True)
    pred = ops.convert_to_tensor(pred, name="pred")
    if isinstance(data, ops.Tensor):
      return gen_control_flow_ops.switch(data, pred, name=name)
    else:
      if not isinstance(data, composite_tensor.CompositeTensor):
        raise TypeError("Type %s not supported" % type(data))
      tensors = nest.flatten(data, expand_composites=True)
      mapped = [gen_control_flow_ops.switch(tensor, pred) for tensor in tensors]
      mapped_f, mapped_t = zip(*mapped)
      return (nest.pack_sequence_as(data, mapped_f, expand_composites=True),
              nest.pack_sequence_as(data, mapped_t, expand_composites=True))


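# A minimal usage sketch for `switch` (illustrative only; assumes graph mode,
# where the untaken output is "dead" and only ops downstream of the taken
# output execute):
#
#   g = ops.Graph()
#   with g.as_default():
#     data = constant_op.constant([1, 2, 3])
#     output_false, output_true = switch(data, constant_op.constant(True))
#     # `output_true` carries [1, 2, 3]; `output_false` is dead.

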
def _SwitchRefOrTensor(data, pred, name="Switch"):
  """Forwards `data` to an output determined by `pred`.

  If `pred` is false, the `data` input is forwarded to the first output.
  Otherwise, the data goes to the second output.

  This op handles `Tensor`s and `IndexedSlices`.

  Args:
    data: The tensor to be forwarded to the appropriate output.
    pred: A scalar that specifies which output port will receive data.
    name: A name for this operation (optional).

  Returns:
    `(output_false, output_true)`: If `pred` is true, data will be forwarded to
    `output_true`, otherwise it goes to `output_false`.

  Raises:
    TypeError: if data is not a Tensor or IndexedSlices
  """
  data = ops.convert_to_tensor_or_composite(data, name="data")
  # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
  # addresses the following scenario.
  #
  # Assume you execute Optimizer.apply_gradients() in a branch of a cond().
  #
  # 1. The update op is created inside a `with ops.colocate(var):` block
  #
  # 2. Some tensor `data` is captured and a switch is created in a
  #    `with ops.colocate_with(data):` block.
  #
  # with ops.colocate_with(var):
  #  with ops.colocate_with(data):
  #    op = ...
  #
  # var and data may be pinned to different devices, so we want ops
  # created within ops.colocate_with(data) to ignore the existing stack.
  with ops.colocate_with(data, ignore_existing=True):
    if isinstance(data, ops.Tensor):
      if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
        return ref_switch(data, pred, name=name)
    return switch(data, pred, name=name)


def merge(inputs, name=None):
  """Returns the value of an available element of `inputs`.

  This op tests each of the tensors in `inputs` in turn to determine if any of
  them is available. If it finds an available tensor, it returns it and its
  index in `inputs`.

  It is an error if more than one tensor in `inputs` is available. If no tensor
  in `inputs` is available, the returned tensor and index are not set.

  This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
  `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
  before merging.

  Args:
    inputs: The input tensors, at most one of which is available.
    name: A name for this operation (optional).

  Returns:
    A tuple containing the chosen input tensor and its index in `inputs`.

  Raises:
    ValueError: If any of the inputs is None, or inputs are IndexedSlices and
      some but not all have a dense_shape property.
  """
  if any(inp is None for inp in inputs):
    raise ValueError("At least one of the merge inputs is None: %s" % inputs)
  with ops.name_scope(name, "Merge", inputs) as name:
    inputs = [
        ops.internal_convert_to_tensor_or_composite(inp, as_ref=True)
        for inp in inputs
    ]
    if all(isinstance(v, ops.Tensor) for v in inputs):
      if all(v.dtype._is_ref_dtype for v in inputs):  # pylint: disable=protected-access
        return gen_control_flow_ops.ref_merge(inputs, name)
      else:
        return gen_control_flow_ops.merge(inputs, name)
    else:
      # If there is a mix of tensors and indexed slices, then convert the
      # tensors to indexed slices.
      if all(isinstance(v, (ops.IndexedSlices, ops.Tensor)) for v in inputs):
        inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)

      for v in inputs:
        if not isinstance(v, composite_tensor.CompositeTensor):
          raise TypeError("Type %s not supported" % type(v))

      for v in inputs[1:]:
        nest.assert_same_structure(inputs[0], v, expand_composites=True)

      flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs]
      merged_results = [
          gen_control_flow_ops.merge(component)
          for component in zip(*flat_inputs)
      ]
      flat_merged = [tensor for (tensor, _) in merged_results]
      chosen_index = merged_results[0][1]
      merged_inputs = nest.pack_sequence_as(
          inputs[0], flat_merged, expand_composites=True)
      return (merged_inputs, chosen_index)


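# A hand-built conditional from `switch` and `merge` (illustrative sketch;
# this pairing is essentially what `cond` assembles internally): exactly one
# merge input is live, so `merge` returns that value and its index.
#
#   g = ops.Graph()
#   with g.as_default():
#     x = constant_op.constant(10)
#     x_false, x_true = switch(x, constant_op.constant(True))
#     result, branch_index = merge([x_false + 1, x_true * 2])
#     # `result` evaluates to 20 and `branch_index` to 1.

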
# pylint: enable=protected-access


def _convert_tensorarray_to_flow(tensor_or_tensor_array):
  if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
    return tensor_or_tensor_array.flow
  else:
    return tensor_or_tensor_array


def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
  if len(tensors_or_tensorarrays) != len(tensors_or_flows):
    raise ValueError(
        "Lengths of original Tensor list and new list do not match: %d vs. %d" %
        (len(tensors_or_tensorarrays), len(tensors_or_flows)))
  return [
      tensor_array_ops.build_ta_with_new_flow(ta, t_or_flow) if isinstance(
          ta, tensor_array_ops.TensorArray) else t_or_flow
      for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
  ]


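# Sketch of the TensorArray <-> flow round trip these helpers perform
# (illustrative only): a `TensorArray` is threaded through control flow as its
# scalar `flow` tensor, then rebuilt on the other side.
#
#   ta = tensor_array_ops.TensorArray(dtypes.float32, size=2)
#   flow = _convert_tensorarray_to_flow(ta)  # a scalar float Tensor
#   ta2 = _convert_flows_to_tensorarrays([ta], [flow])[0]  # a TensorArray

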
def _ShapeLessThanOrEqual(shape1, shape2):
  if shape2.dims is None:
    return True
  if shape1.ndims != shape2.ndims:
    return False
  for dim1, dim2 in zip(shape1.dims, shape2.dims):
    if dim2.value is not None and dim1.value != dim2.value:
      return False
  return True


def _get_shape_invariant(var, shape=None):
  """Returns shape invariant(s) for the given variable.

  Args:
    var: The tensor whose shape is described.
    shape: The shape invariant for the tensor. If not specified, then a default
      shape invariant for `var` is returned.

  Returns:
    `TensorShape` or `list` of `TensorShape`: The shape invariant for `var` (if
    it is a `Tensor`), or the shape invariants for the components that comprise
    `var` (if it is a `CompositeTensor`).
  """
  if isinstance(var, composite_tensor.CompositeTensor):
    # Get a TypeSpec for `var`.
    if shape is None:
      spec = var._type_spec  # pylint: disable=protected-access
    else:
      spec = _shape_invariant_to_type_spec(var, shape)

    tensor_specs = nest.flatten(spec, expand_composites=True)
    return [tspec.shape for tspec in tensor_specs]

  elif shape is None:
    return var.shape
  elif isinstance(shape, tensor_spec.TensorSpec):
    if var.dtype != shape.dtype:
      raise TypeError("TensorSpec %r is not compatible with %r" % (shape, var))
    return shape.shape
  elif isinstance(shape, type_spec.TypeSpec):
    raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
  else:
    return shape


def _shape_invariant_to_type_spec(var, shape):
  """Converts a shape invariant to a TypeSpec.

  Args:
    var: The tensor whose shape is described by the shape invariant.
    shape: A `TypeSpec` or `TensorShape`. If `shape` is already a `TypeSpec`,
      then it is simply returned as-is.

  Returns:
    A `TypeSpec` for `var`, consistent with the given shape.
  """
  if shape is None:
    return type_spec.type_spec_from_value(var)
  elif isinstance(shape, type_spec.TypeSpec):
    if not shape.is_compatible_with(var):
      raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
    return shape
  elif not isinstance(shape, tensor_shape.TensorShape):
    raise TypeError(
        "Expected shape to be a TypeSpec, TensorShape or None, got %r for"
        " value %r" % (shape, var))

  if isinstance(var, ops.Tensor):
    return tensor_spec.TensorSpec(shape, var.dtype)

  elif isinstance(var, composite_tensor.CompositeTensor):
    try:
      return var._shape_invariant_to_type_spec(shape)  # pylint: disable=protected-access
    except NotImplementedError:
      raise TypeError(
          "To describe or constrain a %s, use a %s instead of a TensorShape." %
          (type(var).__name__, type(var._type_spec).__name__))  # pylint: disable=protected-access

  else:
    raise TypeError("Expected var to be a Tensor or CompositeTensor, got %s"
                    % var)


def _SetShapeInvariants(input_vars, enter_vars, shapes):
  """Set the shapes of the tensors in `enter_vars` to `shapes`.

  Args:
    input_vars: A list of tensors that are inputs to `enter_vars`.
    enter_vars: A list of tensors whose shapes will be set.
    shapes: A (possibly nested) list of shapes.

  Raises:
    ValueError: If any tensor in `enter_vars` has a less specific shape
      than its corresponding shape in `shapes`.
  """
  if shapes is None:
    return
  flat_shapes = nest.flatten(shapes)
  if not all(isinstance(s, tensor_shape.TensorShape) for s in flat_shapes):
    raise ValueError("`shapes` must be a (possibly nested) list of shapes.")
  # Check that the shapes of the inputs are less than the shape invariants,
  # and set the shapes of `enter_vars` to the shape invariants.
  for inp, var, shape in zip(input_vars, enter_vars, flat_shapes):
    if isinstance(var, ops.Tensor):
      if not _ShapeLessThanOrEqual(inp.get_shape(), shape):
        raise ValueError(
            "The shape invariant specified for %s is not compatible with "
            "the initial shape of the loop variable. It enters the loop "
            "with shape %s, but the specified shape invariant is %s." %
            (inp.name, inp.get_shape(), shape))
      var.set_shape(shape)
    else:
      raise TypeError("Type %s not supported" % type(var))


def _EnforceShapeInvariant(merge_var, next_var):
  """Check if the shapes of the loop variables are invariant.

  Args:
    merge_var: The list of tensors representing the initial values of the loop
      variables.
    next_var: The list of tensors representing the values of the loop variables
      after one loop iteration.

  Raises:
    ValueError: If any tensor in `merge_var` has a more specific shape than
      its corresponding tensor in `next_var`.
  """
  if isinstance(merge_var, ops.Tensor):
    m_shape = merge_var.get_shape()
    n_shape = next_var.get_shape()
    if not _ShapeLessThanOrEqual(n_shape, m_shape):
      enter = merge_var.op.inputs[0].op
      assert util.IsLoopEnter(enter)
      input_t = enter.inputs[0]
      raise ValueError(
          "Input tensor '%s' enters the loop with shape %s, but has shape %s "
          "after one iteration. To allow the shape to vary across iterations, "
          "use the `shape_invariants` argument of tf.while_loop to specify a "
          "less-specific shape." % (input_t.name, input_t.shape, n_shape))
  else:
    raise TypeError("Type %s not supported" % type(merge_var))


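# An illustrative sketch (not library code) of the failure this check guards
# against, and the `shape_invariants` fix the error message recommends:
#
#   # Fails: `x` enters the loop with shape [1] but has shape [2] after one
#   # iteration.
#   #   tf.while_loop(lambda x: tf.size(x) < 8,
#   #                 lambda x: tf.concat([x, x], 0), [tf.ones([1])])
#
#   # Works: declare the growing dimension as unknown.
#   #   tf.while_loop(lambda x: tf.size(x) < 8,
#   #                 lambda x: tf.concat([x, x], 0), [tf.ones([1])],
#   #                 shape_invariants=[tf.TensorShape([None])])

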
def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
  """Add NextIteration and back edge from v to m."""
  if isinstance(m, ops.Tensor):
    v = ops.convert_to_tensor(v)
    v = _NextIteration(v)
    if enforce_shape_invariant:
      # Make sure the shapes of loop outputs are correct. We do this before
      # calling _update_input, which will raise a less-helpful error message if
      # the types don't match.
      # TODO(skyewm): call this for other cases below (needs testing)
      _EnforceShapeInvariant(m, v)
    m.op._update_input(1, v)  # pylint: disable=protected-access
  elif isinstance(m, composite_tensor.CompositeTensor):
    # pylint: disable=protected-access
    def update_component(m_component, v_component):
      m_component.op._update_input(1, v_component)

    if isinstance(m, ops.IndexedSlices):
      v = math_ops._as_indexed_slices(v, optimize=False)
    # pylint: enable=protected-access
    v = _NextIteration(v)
    return nest.map_structure(update_component, m, v, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(m))
  return v


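# How the raw dataflow pieces compose into a loop (illustrative, heavily
# simplified sketch; real loops are built inside a `WhileContext`, which
# manages frame names, pivots, and shape invariants for you):
#
#   n = constant_op.constant(0)
#   enter_n = _Enter(n, "frame")
#   merge_n = merge([enter_n, enter_n])[0]   # second input patched below
#   pred = math_ops.less(merge_n, 10)
#   body_n = _SwitchRefOrTensor(merge_n, pred)[1] + 1
#   _AddNextAndBackEdge(merge_n, body_n)     # NextIteration + back edge
#   result = exit(_SwitchRefOrTensor(merge_n, pred)[0])

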
@six.add_metaclass(abc.ABCMeta)
class ControlFlowContext(object):
  """The base class for control flow context.

  The usage pattern is a sequence of (Enter, Exit) followed by a final
  ExitResult.

  We maintain the following state for control flow contexts during graph
  construction:
    1. graph has _control_flow_context: the current context used to
       construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
    2. op has _control_flow_context: the context to which the op belongs.
       Set at the time the op is created. Immutable.
    3. A ControlFlowContext has _outer_context: the context in which this
       context is created. Set at the time a context is created. Immutable.
    4. A ControlFlowContext has _context_stack.
       Pushed and popped by ctxt.Enter() and ctxt.Exit()
  """

  def __init__(self, values_def=None, import_scope=None):
    self._nested_contexts = []
    self._outer_context = ops.get_default_graph()._get_control_flow_context()
    if self._outer_context:
      self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
    self._context_stack = []
    if values_def:
      self._init_values_from_proto(values_def, import_scope=import_scope)
    else:
      # The names of tensors that have been already seen in this context.
      self._values = set()
      # The keys are the names of tensors referenced by but external to this
      # context. Each value is the Tensor that should be used by this context to
      # access the key value (e.g. a switch output guarding a cond input value).
      self._external_values = {}

  def _init_values_from_proto(self, values_def, import_scope=None):
    """Initializes values and external_values from `ValuesDef` protocol buffer.

    Args:
      values_def: `ValuesDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(values_def, control_flow_pb2.ValuesDef)
    self._values = set(
        ops.prepend_name_scope(value, import_scope)
        for value in values_def.values)
    g = ops.get_default_graph()
    self._external_values = {}
    for k, v in values_def.external_values.items():
      k = ops.prepend_name_scope(k, import_scope)
      self._external_values[k] = g.as_graph_element(
          ops.prepend_name_scope(v, import_scope))
    op_names = set([
        op.split(":")[0]
        for op in self._values - set(self._external_values.keys())
    ])
    for op in op_names:
      # pylint: disable=protected-access
      g.as_graph_element(op)._set_control_flow_context(self)
      # pylint: enable=protected-access

  @property
  def name(self):
    return self._name

  @property
  def outer_context(self):
    """Return the context containing this context."""
    return self._outer_context

  @property
  def grad_state(self):
    raise NotImplementedError("Abstract method")

  @property
  def back_prop(self):
    raise NotImplementedError("Abstract method")

  @abc.abstractmethod
  def to_control_flow_context_def(self, context_def, export_scope=None):
    """Serializes this into `context_def`.

    Args:
      context_def: a `ControlFlowContextDef` protocol buffer.
      export_scope: Optional `string`. Name scope to remove.
    """
    raise NotImplementedError("Abstract method")

  def _to_values_def(self, export_scope=None):
    """Converts the values to a `ValuesDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `ValuesDef` protocol buffer.
    """
    values_def = control_flow_pb2.ValuesDef()
    values_def.values.extend(
        [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
    for k, v in self._external_values.items():
      k = ops.strip_name_scope(k, export_scope)
      values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
    return values_def

  def AddName(self, name):
    self._values.add(name)

  # pylint: disable=protected-access
  def Enter(self):
    """Enter this control flow context."""
    graph = ops.get_default_graph()
    self._context_stack.append(graph._get_control_flow_context())
    graph._set_control_flow_context(self)

  def Exit(self):
    """Exit this control flow context."""
    graph = ops.get_default_graph()
    last_context = self._context_stack.pop()
    graph._set_control_flow_context(last_context)

  def EnterGradientColocation(self, op, gradient_uid):
    """Start building a gradient colocated with an op."""
    if self._outer_context:
      self._outer_context.EnterGradientColocation(op, gradient_uid)

  def ExitGradientColocation(self, op, gradient_uid):
    """Stop building a gradient colocated with an op."""
    if self._outer_context:
      self._outer_context.ExitGradientColocation(op, gradient_uid)

  def ExitResult(self, result):
    """Make a list of tensors available in the outer context."""
    if self._outer_context:
      def fn(x):
        self._outer_context.AddName(x.name)
        return x
      nest.map_structure(fn, result, expand_composites=True)

  def GetWhileContext(self):
    """Return the while context containing this context."""
    if self._outer_context:
      return self._outer_context.GetWhileContext()
    return None

  def _RemoveExternalControlEdges(self, op):
    """Remove any external control dependency on this op."""
    while_ctxt = self.GetWhileContext()
    # A control input of `op` is internal if it is in the same while
    # loop context as the enclosing while loop context of self.
    if while_ctxt is None:
      internal_control_inputs = op.control_inputs
    else:
      internal_control_inputs = []
      for x in op.control_inputs:
        ctxt = util.GetOutputContext(x)
        if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:
          internal_control_inputs.append(x)
    external_control_inputs = []
    if len(internal_control_inputs) != len(op.control_inputs):
      external_control_inputs = list(
          set(op.control_inputs) - set(internal_control_inputs))
      op._remove_all_control_inputs()
      op._add_control_inputs(internal_control_inputs)
    return internal_control_inputs, external_control_inputs

  # pylint: enable=protected-access

  def AddInnerOp(self, op):
    """Notifies a scope about an operator added to an inner scope."""
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def GetControlPivot(self):
    """Returns the pivot node for this context, or None."""
    return None

  def IsWhileContext(self):
    return False

  def IsCondContext(self):
    return False

  def IsXLAContext(self):
    return False

  def __str__(self):
    return self.name


class CondContext(ControlFlowContext):
  """The context for the conditional construct."""

  def __init__(self,
               pred=None,
               pivot=None,
               branch=None,
               name="cond_text",
               context_def=None,
               import_scope=None):
    """Creates a `CondContext`.

    Args:
      pred: The `boolean` tensor for the conditional predicate.
      pivot: The predicate tensor in this branch.
      branch: 0 or 1 representing this branch.
      name: Name of the `CondContext` python object.
      context_def: Optional `ContextDef` protocol buffer to initialize the
        `CondContext` object from.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.
    """
    self._name = ops.get_default_graph().unique_name(name)

    if context_def:
      self._init_from_proto(context_def, import_scope=import_scope)
    else:
      # Initializes the default fields.
      ControlFlowContext.__init__(self)
      self._pred = pred  # The boolean tensor for the cond predicate
      self._pivot = pivot  # The predicate tensor in this branch
      self._branch = branch  # 0 or 1 representing this branch

      # Values considered to have been already seen in this context. pred is not
      # included in this context.
      self._values.add(pred.name)
      self._external_values[pred.name] = pred
      self._values.add(pivot.name)
      pivot.op._set_control_flow_context(self)  # pylint: disable=protected-access

  def _init_from_proto(self, context_def, import_scope=None):
    """Creates a new `CondContext` from protocol buffer.

    Args:
      context_def: `CondContextDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(context_def, control_flow_pb2.CondContextDef)
    # Create from context_def.
    g = ops.get_default_graph()
    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
    self._pred = g.as_graph_element(
        ops.prepend_name_scope(context_def.pred_name, import_scope))
    self._pivot = g.as_graph_element(
        ops.prepend_name_scope(context_def.pivot_name, import_scope))
    self._branch = context_def.branch
    super(CondContext, self).__init__(
        values_def=context_def.values_def, import_scope=import_scope)

  @property
  def pred(self):
    return self._pred

  @property
  def pivot(self):
    return self._pivot

  @property
  def branch(self):
    return self._branch

  @property
  def grad_state(self):
    if self.GetWhileContext():
      return self.GetWhileContext().grad_state
    return None

  @property
  def back_prop(self):
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False

  def GetControlPivot(self):
    return self._pivot

  def to_proto(self, export_scope=None):
    """Converts a `CondContext` to a `CondContextDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `CondContextDef` protocol buffer.
    """
    if (export_scope is None or self.name.startswith(export_scope)):
      context_def = control_flow_pb2.CondContextDef()
      context_def.context_name = ops.strip_name_scope(self.name, export_scope)
      context_def.pred_name = ops.strip_name_scope(self._pred.name,
                                                   export_scope)
      context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
                                                    export_scope)
      context_def.branch = self._branch
      context_def.values_def.MergeFrom(
          super(CondContext, self)._to_values_def(export_scope))
      for nested in self._nested_contexts:
        nested_def = context_def.nested_contexts.add()
        nested.to_control_flow_context_def(nested_def)

      return context_def
    else:
      return None

  @staticmethod
  def from_proto(context_def, import_scope=None):
    """Returns a `CondContext` object created from `context_def`."""
    ret = CondContext(context_def=context_def, import_scope=import_scope)

    ret.Enter()
    for nested_def in context_def.nested_contexts:
      from_control_flow_context_def(nested_def, import_scope=import_scope)
    ret.Exit()
    return ret

  def to_control_flow_context_def(self, context_def, export_scope=None):
    context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))

  def AddValue(self, val):
    """Add `val` to the current context and its outer context recursively."""
    if val.name in self._values:
      # Use the real value if it comes from outer context. This is needed in
      # particular for nested conds.
      result = self._external_values.get(val.name)
      result = val if result is None else result
    else:
      result = val
      self._values.add(val.name)
      if self._outer_context:
        result = self._outer_context.AddValue(val)
        self._values.add(result.name)
        self._external_values[result.name] = result
      with ops.control_dependencies(None):
        result = _SwitchRefOrTensor(result, self._pred)[self._branch]
        if self._outer_context:
          self._outer_context.AddInnerOp(result.op)

      result.op.graph.prevent_fetching(result.op)
      # pylint: disable=protected-access
      result.op._set_control_flow_context(self)
      # pylint: enable=protected-access

      # Mark Switch output as seen by this context and any outer contexts,
      # just like what we do for normal op outputs in _AddOpInternal() below.
      ctxt = self
      while ctxt is not None:
        # pylint: disable=protected-access
        ctxt._values.add(result.name)
        ctxt = ctxt._outer_context
        # pylint: enable=protected-access

      self._external_values[val.name] = result
    return result

  def AddOp(self, op):
    self._AddOpInternal(op)

  def _AddOpInternal(self, op):
    """Add `op` to the current context."""
    if not op.inputs:
      # If we're in a while loop, remove any control inputs from outside the
      # loop.
      self._RemoveExternalControlEdges(op)

      if not any(
          util.OpInContext(input_op, self) for input_op in op.control_inputs):
        # pylint: disable=protected-access
        op._add_control_input(self._pivot.op)
        # pylint: enable=protected-access
    else:
      # Make each input to 'op' available in this CondContext. If an input is
      # already part of this context there's nothing to do, but if it's
      # external, AddValue() will handle adding the appropriate Switch node and
      # other bookkeeping.
      for index in range(len(op.inputs)):
        x = op.inputs[index]
        if op.type == "Merge" and x.op.type == "NextIteration":
          # Edge case: if we're importing a while loop inside this CondContext,
          # AddValue() will not correctly handle the NextIteration inputs to
          # Merge node. The problem is that the NextIteration should also be
          # part of this context, but if we're importing it won't have been
          # processed and added to the context yet, so AddValue() will try to
          # add a Switch which results in an invalid graph. Instead, we use the
          # NextIteration input as-is here, and it will eventually be added to
          # the context via AddOp().
          real_x = x
        else:
          real_x = self.AddValue(x)
        if real_x != x:
          # pylint: disable=protected-access
          op._update_input(index, real_x)
          # pylint: enable=protected-access
      # Remove any external control dependency on this op.
      self._RemoveExternalControlEdges(op)
      # pylint: disable=protected-access
      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
        op._add_control_input(self._pivot.op)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    ctxt = self
    while ctxt is not None:
      # pylint: disable=protected-access
      ctxt._values.update(output_names)
      ctxt = ctxt._outer_context
      # pylint: enable=protected-access

    if self._outer_context or not util.IsLoopExit(op):
      op.graph.prevent_fetching(op)

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def _ProcessOutputTensor(self, val):
    """Process an output tensor of a conditional branch."""
    real_val = val
    if val.name not in self._values:
      # Handle the special case of lambda: x
      self._values.add(val.name)
      if self._outer_context:
        real_val = self._outer_context.AddValue(val)
        self._values.add(real_val.name)
        self._external_values[real_val.name] = real_val
      real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
      self._external_values[val.name] = real_val
    else:
      external_val = self._external_values.get(val.name)
      if external_val is not None:
        real_val = external_val
    return real_val

  def _BuildCondTensor(self, v):
    if isinstance(v, ops.Operation):
      # Use pivot as the proxy for this op.
      return with_dependencies([v], self._pivot)
    else:
      v = nest.map_structure(
          _convert_tensorarray_to_flow, v, expand_composites=True)
      return self._ProcessOutputTensor(ops.convert_to_tensor(v))

  def BuildCondBranch(self, fn):
    """Add the subgraph defined by fn() to the graph."""
    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    original_result = fn()
    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    if len(post_summaries) > len(pre_summaries):
      new_summaries = post_summaries[len(pre_summaries):]
      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
      summary_ref[:] = pre_summaries
      with ops.control_dependencies(new_summaries):
        if original_result is None:
          return no_op(), None
        elif not isinstance(original_result, ops.Operation):
          original_result = nest.map_structure(
              array_ops.identity, original_result, expand_composites=True)
    if original_result is None:
      return None, None

    result = nest.map_structure(
        self._BuildCondTensor, original_result, expand_composites=True)
    if not isinstance(result, (list, _basetuple)):
      result = [result]
    return original_result, result

  def IsCondContext(self):
    return True


def _UnpackIfSingleton(res):
  if isinstance(res, (list, _basetuple)) and len(res) == 1:
    return res[0]
  else:
    return res


# pylint: disable=redefined-outer-name
# pylint: disable=g-doc-args
@tf_export(v1=["cond"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(
    None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
    "fn1", "fn2")
def cond(pred,
         true_fn=None,
         false_fn=None,
         strict=False,
         name=None,
         fn1=None,
         fn2=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
  `false_fn` must have the same non-zero number and type of outputs.

  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
  `false_fn` will be executed regardless of which branch is selected at runtime.

  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has frequently surprised users who expected lazier semantics.
  Consider the following simple program:

  ```python
  z = tf.multiply(a, b)
  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  ```

  If `x < y`, the `tf.add` operation will be executed and the `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed,
  unconditionally.

  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
  call to `cond`, and not at all during `Session.run()`). `cond`
  stitches together the graph fragments created during the `true_fn` and
  `false_fn` calls with some additional graph nodes to ensure that the right
  branch gets executed depending on the value of `pred`.

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
  same (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
  This behavior is disabled by passing `strict=True`.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: Optional name prefix for the returned tensors.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`. If the
    callables return a singleton list, the element is extracted from the list.

  Raises:
    TypeError: if `true_fn` or `false_fn` is not callable.
    ValueError: if `true_fn` and `false_fn` do not return the same number of
      tensors, or return tensors of different types.

  Example:

  ```python
  x = tf.constant(2)
  y = tf.constant(5)
  def f1(): return tf.multiply(x, 17)
  def f2(): return tf.add(y, 23)
  r = tf.cond(tf.less(x, y), f1, f2)
  # r is set to f1().
  # Operations in f2 (e.g., tf.add) are not executed.
  ```

  """
  # Always enable control flow v2 if building a function, regardless of toggle.
  if (util.EnableControlFlowV2(ops.get_default_graph()) and
      not context.executing_eagerly()):
    return cond_v2.cond_v2(pred, true_fn, false_fn, name)

  # We needed to make true_fn/false_fn keyword arguments for
  # backwards-compatibility. This check exists so that we can convert back to
  # having them be positional arguments.
  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
  # `fn1` and `fn2` are deleted.
  if fn1 is not None:
    if true_fn is not None:
      raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
    true_fn = fn1
  elif true_fn is None:
    raise TypeError("cond(): true_fn argument required")
  if fn2 is not None:
    if false_fn is not None:
      raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
    false_fn = fn2
  elif false_fn is None:
    raise TypeError("cond(): false_fn argument required")

  if not callable(true_fn):
    raise TypeError("true_fn must be callable.")
  if not callable(false_fn):
    raise TypeError("false_fn must be callable.")

  with ops.name_scope(name, "cond", [pred]):
    if context.executing_eagerly():
      if pred:
        result = true_fn()
      else:
        result = false_fn()
      if not strict:
        result = _UnpackIfSingleton(result)
      return result

    # Add the Switch to the graph.
    if isinstance(pred, bool):
      raise TypeError("pred must not be a Python bool")
    p_2, p_1 = switch(pred, pred)
    pivot_1 = array_ops.identity(p_1, name="switch_t")
    pivot_2 = array_ops.identity(p_2, name="switch_f")
    pred = array_ops.identity(pred, name="pred_id")
    # Disable the fetching of tensors that are only on one branch of cond.
    for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
      tensor.op.graph.prevent_fetching(tensor.op)

    # Build the graph for the true branch in a new context.
    context_t = CondContext(pred, pivot_1, branch=1)
    try:
      context_t.Enter()
      orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
      if orig_res_t is None:
        raise ValueError("true_fn must have a return value.")
      context_t.ExitResult(res_t)
    finally:
      context_t.Exit()

    # Build the graph for the false branch in a new context.
    context_f = CondContext(pred, pivot_2, branch=0)
    try:
      context_f.Enter()
      orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
      if orig_res_f is None:
        raise ValueError("false_fn must have a return value.")
      context_f.ExitResult(res_f)
    finally:
      context_f.Exit()

    if not strict:
      orig_res_t = _UnpackIfSingleton(orig_res_t)
      orig_res_f = _UnpackIfSingleton(orig_res_f)

    # Check that the return values of the two branches have the same structure.
    try:
      nest.assert_same_structure(orig_res_t, orig_res_f, expand_composites=True)
    except (TypeError, ValueError):
      nest.map_structure(_cast_indexed_slice_indices, orig_res_t, orig_res_f)
      nest.map_structure(_cast_indexed_slice_indices, res_t, res_f)
      try:
        nest.assert_same_structure(orig_res_t, orig_res_f,
                                   expand_composites=True)
      except TypeError as e:
        raise TypeError(
            "Incompatible return types of true_fn and false_fn: {}".format(e))
      except ValueError as e:
        raise ValueError(
            "Incompatible return values of true_fn and false_fn: {}".format(e))

    # Add the final merge to the graph.
    if not res_t:
      raise ValueError("true_fn and false_fn must return at least one result.")

    res_t_flat = nest.flatten(res_t, expand_composites=True)
    res_f_flat = nest.flatten(res_f, expand_composites=True)

    for (x, y) in zip(res_t_flat, res_f_flat):
      assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
      if x.dtype.base_dtype != y.dtype.base_dtype:
        raise ValueError(
            "Outputs of true_fn and false_fn must have the same type: "
            "%s, %s" % (x.dtype.name, y.dtype.name))

    merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
    merges = _convert_flows_to_tensorarrays(
        nest.flatten(orig_res_t, expand_composites=True), merges)

    # Only add non-nested conds to the collection. Any nested control flow will
    # be encapsulated in the root context.
    assert context_t.outer_context == context_f.outer_context
    if context_t.outer_context is None:
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)

    merges = nest.pack_sequence_as(
        structure=orig_res_t, flat_sequence=merges, expand_composites=True)

    # Singleton lists and tuples are automatically unpacked if strict == False.
    if not strict:
      merges = _UnpackIfSingleton(merges)
    return merges


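# A small sketch of the `strict` flag (illustrative only): with the default
# `strict=False`, a singleton branch result is unpacked to a single value;
# with `strict=True`, the list structure is preserved.
#
#   pred = math_ops.less(constant_op.constant(1), constant_op.constant(2))
#   r = cond(pred, lambda: [constant_op.constant(1)],
#            lambda: [constant_op.constant(2)])
#   # r is a Tensor: the singleton list was unpacked.
#   r = cond(pred, lambda: [constant_op.constant(1)],
#            lambda: [constant_op.constant(2)], strict=True)
#   # r is a one-element list of Tensors.

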
def _cast_indexed_slice_indices(a, b):
  """Cast IndexedSlice.indices from int32 to int64 where necessary.

  If `a` and `b` are both IndexedSlices, and their indices have different
  dtypes, then cast both their dtypes to `int64` (modifies `a` and `b`
  in-place). Otherwise, does nothing.

  Args:
    a: A value, which may be an IndexedSlices.
    b: A value, which may be an IndexedSlices.
  """
  if (isinstance(a, ops.IndexedSlices) and isinstance(b, ops.IndexedSlices)
      and a.indices.dtype != b.indices.dtype):
    # pylint: disable=protected-access
    a._indices = math_ops.cast(a.indices, dtypes.int64)
    b._indices = math_ops.cast(b.indices, dtypes.int64)


# pylint: enable=g-doc-args
# pylint: enable=redefined-outer-name


@tf_export("cond", v1=[])
|
|
@dispatch.add_dispatch_support
|
|
def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None):
|
|
"""Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
|
|
|
|
`true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
|
|
`false_fn` must have the same non-zero number and type of outputs.
|
|
|
|
**WARNING**: Any Tensors or Operations created outside of `true_fn` and
|
|
`false_fn` will be executed regardless of which branch is selected at runtime.
|
|
|
|
Although this behavior is consistent with the dataflow model of TensorFlow,
|
|
it has frequently surprised users who expected a lazier semantics.
|
|
Consider the following simple program:
|
|
|
|
```python
|
|
z = tf.multiply(a, b)
|
|
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
|
|
```
|
|
|
|
If `x < y`, the `tf.add` operation will be executed and `tf.square`
|
|
operation will not be executed. Since `z` is needed for at least one
|
|
branch of the `cond`, the `tf.multiply` operation is always executed,
|
|
unconditionally.
|
|
|
|
Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
|
|
call to `cond`, and not at all during `Session.run()`). `cond`
|
|
stitches together the graph fragments created during the `true_fn` and
|
|
`false_fn` calls with some additional graph nodes to ensure that the right
|
|
branch gets executed depending on the value of `pred`.
|
|
|
|
`tf.cond` supports nested structures as implemented in
|
|
`tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
|
|
same (possibly nested) value structure of lists, tuples, and/or named tuples.
|
|
Singleton lists and tuples form the only exceptions to this: when returned by
|
|
`true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
|
|
|
|
Note: It is illegal to "directly" use tensors created inside a cond branch
|
|
outside it, e.g. by storing a reference to a branch tensor in the python
|
|
state. If you need to use a tensor created in a branch function you should
|
|
return it as an output of the branch function and use the output from
|
|
`tf.cond` instead.
|
|
|
|
Args:
|
|
pred: A scalar determining whether to return the result of `true_fn` or
|
|
`false_fn`.
|
|
true_fn: The callable to be performed if pred is true.
|
|
false_fn: The callable to be performed if pred is false.
|
|
name: Optional name prefix for the returned tensors.
|
|
|
|
Returns:
|
|
Tensors returned by the call to either `true_fn` or `false_fn`. If the
|
|
callables return a singleton list, the element is extracted from the list.
|
|
|
|
Raises:
|
|
TypeError: if `true_fn` or `false_fn` is not callable.
|
|
ValueError: if `true_fn` and `false_fn` do not return the same number of
|
|
tensors, or return tensors of different types.
|
|
|
|
Example:
|
|
|
|
```python
|
|
x = tf.constant(2)
|
|
y = tf.constant(5)
|
|
def f1(): return tf.multiply(x, 17)
|
|
def f2(): return tf.add(y, 23)
|
|
r = tf.cond(tf.less(x, y), f1, f2)
|
|
# r is set to f1().
|
|
# Operations in f2 (e.g., tf.add) are not executed.
|
|
```
|
|
|
|
"""
|
|
return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
|
|
|
|
|
|
def _resource_safe_shape(t):
  """Returns the shape of t or the variable it points to."""
  if t.dtype == dtypes.resource:
    while t.op.inputs:
      t = t.op.inputs[0]
    return tensor_shape.TensorShape(t.op.get_attr("shape"))
  return array_ops.shape_internal(t, optimize=False)


# TODO(yuanbyu): Consider having a unified notion of context for
# not only conditionals and loops but also control dependency and
# subgraphs.
class WhileContext(ControlFlowContext):
  """The context for the loop construct."""

  def __init__(self,
               maximum_iterations=None,
               parallel_iterations=10,
               back_prop=True,
               swap_memory=False,
               name="while_context",
               grad_state=None,
               context_def=None,
               import_scope=None):
    """Creates a `WhileContext`.

    Args:
      maximum_iterations: Optional upper bound on number of loop iterations.
      parallel_iterations: The number of iterations allowed to run in parallel.
      back_prop: Whether backprop is enabled for this while loop.
      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
      name: Optional name prefix for the returned tensors.
      grad_state: The gradient loop state.
      context_def: Optional `WhileContextDef` protocol buffer to initialize the
        `WhileContext` python object from.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.
    """
    if context_def:
      self._init_from_proto(context_def, import_scope=import_scope)
    else:
      ControlFlowContext.__init__(self)
      self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
                           swap_memory, name)
    # The gradient loop state.
    self._grad_state = grad_state

  def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
                      swap_memory, name):
    """Creates a new `WhileContext` from arguments.

    Args:
      maximum_iterations: Optional upper bound on number of loop iterations.
      parallel_iterations: The number of iterations allowed to run in parallel.
      back_prop: Whether backprop is enabled for this while loop.
      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
      name: Optional name prefix for the returned tensors.

    Raises:
      ValueError: If `parallel_iterations` has invalid value.
    """
    if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):
      raise ValueError("`parallel_iterations` must be a positive integer: "
                       "%s" % parallel_iterations)
    self._name = ops.get_default_graph().unique_name(name)
    self._maximum_iterations = maximum_iterations
    self._parallel_iterations = parallel_iterations
    self._back_prop = back_prop
    self._swap_memory = swap_memory
    # We use this node to control constants created by the pred lambda.
    self._pivot_for_pred = None
    # We use this node to control constants created by the body lambda.
    self._pivot_for_body = None
    # The boolean tensor for loop termination condition. Used in code
    # generation for gradient computation
    self._pivot = None
    # The list of exit tensors for loop variables.
    self._loop_exits = []
    # The list of enter tensors for loop variables.
    self._loop_enters = []
    self._graph = ops.get_default_graph()

  def _init_from_proto(self, context_def, import_scope=None):
    """Creates a new `WhileContext` from protocol buffer.

    Args:
      context_def: `WhileContextDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(context_def, control_flow_pb2.WhileContextDef)
    # Create from context_def.
    g = ops.get_default_graph()
    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
    if context_def.maximum_iterations_name:
      self._maximum_iterations = g.as_graph_element(
          ops.prepend_name_scope(context_def.maximum_iterations_name,
                                 import_scope))
    else:
      self._maximum_iterations = None
    self._parallel_iterations = context_def.parallel_iterations
    self._back_prop = context_def.back_prop
    self._swap_memory = context_def.swap_memory
    self._pivot_for_pred = g.as_graph_element(
        ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope))
    # We use this node to control constants created by the body lambda.
    self._pivot_for_body = g.as_graph_element(
        ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope))
    # The boolean tensor for loop termination condition. Used in code
    # generation for gradient computation.
    self._pivot = g.as_graph_element(
        ops.prepend_name_scope(context_def.pivot_name, import_scope))
    # The list of exit tensors for loop variables.
    self._loop_exits = [
        g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope))
        for exit_name in context_def.loop_exit_names
    ]
    # The list of enter tensors for loop variables.
    self._loop_enters = [
        g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope))
        for enter_name in context_def.loop_enter_names
    ]
    super(WhileContext, self).__init__(
        values_def=context_def.values_def, import_scope=import_scope)

    # import_scope causes self.name to be different from the original serialized
    # context's name. Rewrite "frame_name" attrs with the new name.
    if import_scope:
      for tensor_name in self._values:
        op = g.as_graph_element(tensor_name).op
        if util.IsLoopEnter(op):
          # pylint: disable=protected-access
          op._set_attr("frame_name",
                       attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
          # pylint: enable=protected-access
    self._graph = ops.get_default_graph()

@property
|
|
def maximum_iterations(self):
|
|
"""The maximum number of iterations that will be executed."""
|
|
return self._maximum_iterations
|
|
|
|
@property
|
|
def parallel_iterations(self):
|
|
"""The number of iterations allowed to run in parallel."""
|
|
return self._parallel_iterations
|
|
|
|
@property
|
|
def back_prop(self):
|
|
"""True iff backprop is enabled for this while loop."""
|
|
return self._back_prop
|
|
|
|
@property
|
|
def swap_memory(self):
|
|
"""True iff GPU-CPU memory swap is enabled for this while loop."""
|
|
return self._swap_memory
|
|
|
|
@property
|
|
def pivot(self):
|
|
"""The boolean tensor representing the loop termination condition."""
|
|
return self._pivot
|
|
|
|
@property
|
|
def loop_enters(self):
|
|
"""The list of enter tensors for loop variables."""
|
|
return self._loop_enters
|
|
|
|
@property
|
|
def loop_exits(self):
|
|
"""The list of exit tensors for loop variables."""
|
|
return self._loop_exits
|
|
|
|
@property
|
|
def grad_state(self):
|
|
"""The gradient loop state."""
|
|
return self._grad_state
|
|
|
|
def to_proto(self, export_scope=None):
|
|
"""Converts a `WhileContext` to a `WhileContextDef` protocol buffer.
|
|
|
|
Args:
|
|
export_scope: Optional `string`. Name scope to remove.
|
|
|
|
Returns:
|
|
A `WhileContextDef` protocol buffer.
|
|
"""
|
|
if (export_scope is None or self.name.startswith(export_scope)):
|
|
context_def = control_flow_pb2.WhileContextDef()
|
|
context_def.context_name = ops.strip_name_scope(self.name, export_scope)
|
|
context_def.parallel_iterations = self._parallel_iterations
|
|
if self._maximum_iterations is not None:
|
|
context_def.maximum_iterations_name = ops.strip_name_scope(
|
|
self._maximum_iterations.name, export_scope)
|
|
context_def.back_prop = self._back_prop
|
|
context_def.swap_memory = self._swap_memory
|
|
context_def.pivot_for_pred_name = ops.strip_name_scope(
|
|
self._pivot_for_pred.name, export_scope)
|
|
context_def.pivot_for_body_name = ops.strip_name_scope(
|
|
self._pivot_for_body.name, export_scope)
|
|
context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
|
|
export_scope)
|
|
context_def.loop_exit_names.extend([
|
|
ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits
|
|
])
|
|
context_def.loop_enter_names.extend([
|
|
ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
|
|
])
|
|
context_def.values_def.MergeFrom(
|
|
super(WhileContext, self)._to_values_def(export_scope=export_scope))
|
|
for nested in self._nested_contexts:
|
|
nested_def = context_def.nested_contexts.add()
|
|
nested.to_control_flow_context_def(nested_def)
|
|
|
|
return context_def
|
|
else:
|
|
return None
|
|
|
|
def to_control_flow_context_def(self, context_def, export_scope=None):
|
|
context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
|
|
|
|
@staticmethod
|
|
def from_proto(context_def, import_scope=None):
|
|
"""Returns a `WhileContext` object created from `context_def`.
|
|
|
|
Args:
|
|
context_def: A `WhileContextDef` protocol buffer.
|
|
import_scope: Optional `string`. Name scope to add.
|
|
|
|
Returns:
|
|
A `WhileContext` Python object.
|
|
"""
|
|
ret = WhileContext(context_def=context_def, import_scope=import_scope)
|
|
ret.Enter()
|
|
for nested_def in context_def.nested_contexts:
|
|
from_control_flow_context_def(nested_def, import_scope=import_scope)
|
|
ret.Exit()
|
|
return ret
|
|
|
|
def GetWhileContext(self):
|
|
return self
|
|
|
|
def GetControlPivot(self):
|
|
if self._pivot_for_body is not None:
|
|
return self._pivot_for_body
|
|
return self._pivot_for_pred
|
|
|
|
def AddValue(self, val):
|
|
"""Add `val` to the current context and its outer context recursively."""
|
|
result = val
|
|
new_value = val.name not in self._values
|
|
# Don't treat ops in this context as new values. Usually all known values
|
|
# are in self._values, except when we're importing a while loop inside this
|
|
# WhileContext. Since there's a cycle in this case, `val` may be part of the
|
|
# imported while loop but not yet processed by this context and added to
|
|
# self._values in _AddOpInternal. We only want to process external input
|
|
# tensors to the while loop here.
|
|
new_value &= val.op._control_flow_context is not self # pylint: disable=protected-access
|
|
if new_value:
|
|
self._values.add(val.name)
|
|
|
|
# If we are in a grad context and val is from its forward context,
|
|
# use GetRealValue(), which adds the logic to save the history of
|
|
# val in forward.
|
|
grad_ctxt = ops.get_default_graph()._get_control_flow_context()
|
|
if grad_ctxt:
|
|
grad_ctxt = grad_ctxt.GetWhileContext()
|
|
if grad_ctxt.grad_state:
|
|
forward_ctxt = util.GetWhileContext(val.op)
|
|
if util.IsLoopExit(val.op):
|
|
forward_ctxt = forward_ctxt.outer_context
|
|
if forward_ctxt:
|
|
forward_ctxt = forward_ctxt.GetWhileContext()
|
|
if forward_ctxt == grad_ctxt.grad_state.forward_context:
|
|
real_val = grad_ctxt.grad_state.GetRealValue(val)
|
|
self._external_values[val.name] = real_val
|
|
return real_val
|
|
|
|
if self._outer_context is not None:
|
|
result = self._outer_context.AddValue(val)
|
|
# Create an Enter to make `result` known to this loop context.
|
|
with ops.control_dependencies(None):
|
|
enter = _Enter(
|
|
result,
|
|
self._name,
|
|
is_constant=True,
|
|
parallel_iterations=self._parallel_iterations)
|
|
enter.graph.prevent_feeding(enter)
|
|
if self._outer_context:
|
|
self._outer_context.AddInnerOp(enter.op)
|
|
# Fix the control inputs and control flow context of these enter ops.
|
|
self._FixControlInputsAndContext([enter])
|
|
|
|
# Add `enter` in this context.
|
|
self._values.add(enter.name)
|
|
self._external_values[val.name] = enter
|
|
result = enter
|
|
else:
|
|
actual_val = self._external_values.get(val.name)
|
|
if actual_val is not None:
|
|
result = actual_val
|
|
return result
|
|
|
|
def AddOp(self, op):
|
|
"""Add `op` to the current context."""
|
|
# For a reduction op, if op is in a grad context and its input is from
|
|
# its forward context, moving op to the forward context means we would
|
|
# store the tensor after the reduction as opposed to the tensor before
|
|
# reduction, and therefore could significantly reduce memory consumption.
|
|
# For now, we do this only for a few ops.
|
|
#
|
|
# If in XLA context, do not move constant ops to forward pass as pushing to
|
|
# and popping from a stack removes the constant property of an op and breaks
|
|
# XLA compilation, which requires certain inputs to be constant for certain
|
|
# ops.
|
|
if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}:
|
|
grad_ctxt = ops.get_default_graph()._get_control_flow_context()
|
|
if grad_ctxt:
|
|
grad_ctxt = grad_ctxt.GetWhileContext()
|
|
if grad_ctxt.grad_state:
|
|
op_input_forward_ctxt = util.GetWhileContext(op.inputs[0].op)
|
|
if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context:
|
|
op_input_ctxt = op.inputs[0].op._get_control_flow_context()
|
|
op._set_control_flow_context(op_input_ctxt)
|
|
op_input_ctxt._AddOpInternal(op)
|
|
return
|
|
self._AddOpInternal(op)
|
|
|
|
def _AddOpInternal(self, op):
|
|
"""Add `op` to the current context.
|
|
|
|
We move any external control dependencies of the op to the loop pivot, to
|
|
ensure they get executed.
|
|
"""
|
|
# This is needed to prevent frame mismatch errors where there are Const
|
|
# nodes inside tf.function in v1 while_loop and inlining is turned on.
|
|
if op.type in ["PartitionedCall", "StatefulPartitionedCall"]:
|
|
op._add_control_input(self.GetControlPivot().op) # pylint: disable=protected-access
|
|
if not op.inputs:
|
|
# Remove any external control dependency on this op
|
|
control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
|
|
# Add a control edge from the control pivot to this op.
|
|
if not control_inputs:
|
|
# pylint: disable=protected-access
|
|
op._add_control_input(self.GetControlPivot().op)
|
|
# pylint: enable=protected-access
|
|
for x in op.outputs:
|
|
self._values.add(x.name)
|
|
else:
|
|
for index in range(len(op.inputs)):
|
|
x = op.inputs[index]
|
|
real_x = self.AddValue(x)
|
|
if real_x != x:
|
|
op._update_input(index, real_x) # pylint: disable=protected-access
|
|
# Remove any external control dependency on this op.
|
|
_, external_inputs = self._RemoveExternalControlEdges(op)
|
|
# Add a control dependency to prevent loop invariants from
|
|
# enabling ops that should not be executed.
|
|
self._MaybeAddControlDependency(op)
|
|
for x in op.outputs:
|
|
self._values.add(x.name)
|
|
if external_inputs:
|
|
# Use an identity to pull control inputs as data inputs. Note that we
|
|
# ignore ops which don't have outputs. TODO(apassos): fix that
|
|
with ops.control_dependencies(None):
|
|
self.Enter()
|
|
external_inputs = [
|
|
array_ops.identity(x.outputs[0]).op
|
|
for x in external_inputs
|
|
if x.outputs
|
|
]
|
|
self.Exit()
|
|
op._add_control_inputs(external_inputs) # pylint: disable=protected-access
|
|
if self._outer_context or not util.IsLoopExit(op):
|
|
op.graph.prevent_fetching(op)
|
|
for x in op.outputs:
|
|
op.graph.prevent_feeding(x)
|
|
|
|
if self._outer_context:
|
|
self._outer_context.AddInnerOp(op)
|
|
|
|
def _MaybeAddControlDependency(self, op):
|
|
"""Add a control input to the op if it only depends on loop invariants."""
|
|
|
|
def _IsOpFree(op):
|
|
"""Determines if `op` needs a control dependency."""
|
|
if op.control_inputs:
|
|
return False
|
|
# pylint: disable=protected-access
|
|
if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
|
|
return True
|
|
# pylint: enable=protected-access
|
|
for x in op.inputs:
|
|
if not util.IsLoopConstantEnter(x.op):
|
|
return False
|
|
return True
|
|
|
|
if _IsOpFree(op):
|
|
# pylint: disable=protected-access
|
|
op._add_control_input(self.GetControlPivot().op)
|
|
# pylint: enable=protected-access
|
|
|
|
def AddForwardLoopCounter(self, outer_grad_state):
|
|
"""Adds a loop that counts the number of iterations.
|
|
|
|
This is added to the forward loop at the time when we start to
|
|
create the loop for backprop gradient computation. Called in
|
|
the outer context of this forward context.
|
|
|
|
The pseudocode is:
|
|
`n = 0; while (_pivot) { n++; }`
|
|
|
|
Note that a control dependency is added to `n` to ensure the correct
|
|
execution order of stack push ops.
|
|
|
|
Args:
|
|
outer_grad_state: The outer grad state. None if not nested.
|
|
|
|
Returns:
|
|
The number of iterations taken by the forward loop and the loop index.
|
|
"""
    n = constant_op.constant(0, name="f_count")
    if outer_grad_state is not None:
      # Force the stack pushes of i-th execution of an inner loop to be ordered
      # before the pushes of (i+1)-th execution of the same inner loop.
      outer_add_op = outer_grad_state.forward_index.op.inputs[0].op
      n.op._add_control_input(outer_add_op)  # pylint: disable=protected-access

    self.Enter()
    self.AddName(n.name)
    enter_n = _Enter(
        n,
        self._name,
        is_constant=False,
        parallel_iterations=self._parallel_iterations,
        name="f_count")
    self.loop_enters.append(enter_n)

    merge_n = merge([enter_n, enter_n])[0]
    switch_n = switch(merge_n, self._pivot)

    index = math_ops.add(switch_n[1], 1)
    next_n = _NextIteration(index)
    merge_n.op._update_input(1, next_n)

    total_iterations = exit(switch_n[0], name="f_count")
    self.loop_exits.append(total_iterations)
    self.ExitResult([total_iterations])
    self.Exit()
    return total_iterations, next_n

  def AddBackpropLoopCounter(self, count, outer_grad_state):
    """Add the backprop loop that controls the iterations.

    This is added to the backprop loop. It is used to control the loop
    termination of the backprop loop. Called in the outer context of
    this grad context.

    The pseudocode is:
      `n = count; while (n >= 1) { n--; }`

    Note that a control dependency is added to `final_zero` to ensure the
    correct execution order of stack pop ops.

    Args:
      count: The number of iterations for backprop.
      outer_grad_state: The outer grad state. None if not nested.

    Returns:
      The loop index.
    """
    in_separate_functions = count.graph is not ops.get_default_graph()
    if in_separate_functions:
      # Brings the count into this graph
      count = array_ops.identity(count)
    else:
      # TODO(apassos) XLA expects this constant to be created outside the loop,
      # so doing that for now.
      one = constant_op.constant(1, name="b_count")

    self.Enter()
    self.AddName(count.name)
    enter_count = _Enter(
        count,
        self._name,
        is_constant=False,
        parallel_iterations=self._parallel_iterations,
        name="b_count")
    self.loop_enters.append(enter_count)

    merge_count = merge([enter_count, enter_count])[0]
    self._pivot_for_pred = merge_count

    if in_separate_functions:
      one = constant_op.constant(1, name="b_count")
    pred = math_ops.greater_equal(merge_count, one)
    self._pivot = loop_cond(pred, name="b_count")
    switch_count = switch(merge_count, self._pivot)

    index = math_ops.subtract(switch_count[1], one)
    self._pivot_for_body = index
    next_count = _NextIteration(index)
    merge_count.op._update_input(1, next_count)

    final_zero = exit(switch_count[0], name="b_count")
    self.loop_exits.append(final_zero)
    if outer_grad_state is not None:
      # Force the stack pops of i-th execution of an inner loop to be ordered
      # before the pops of (i+1)-th execution of the same inner loop.
      # pylint: disable=protected-access
      outer_grad_state.grad_sync._add_control_input(final_zero.op)
      # pylint: enable=protected-access

    self.ExitResult([final_zero])
    self.Exit()
    return next_count

  def AddBackpropAccumulator(self, op, grad):
    """Add an accumulation loop for every loop invariant.

    This is added to the backprop loop. It is used to accumulate partial
    gradients within each loop iteration. Called when in the gradient while
    context.

    The pseudocode is:
      ```
      acc = 0.0;
      while (_pivot) {
        acc += grad;
      }
      ```

    Args:
      op: The Enter op for a loop invariant.
      grad: The partial gradient of an iteration for a loop invariant.

    Returns:
      The gradient for a loop invariant.
    """
    self.Exit()
    # Create a zeros tensor with the right shape for acc. If we don't
    # know the full shape statically, we will have to get the shape
    # dynamically from the forward inference. Getting the shape right
    # for the zeros is only needed for the base case when the loop exits
    # without running any iterations.
    shape = grad.get_shape()
    if shape.is_fully_defined():
      if self.outer_context:
        self.outer_context.Enter()
      acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
      if self.outer_context:
        self.outer_context.Exit()
    else:
      value = op.inputs[0]
      if (isinstance(self.outer_context, WhileContext) and
          self.outer_context.grad_state is not None):
        # We are in a nested while loop.
        forward_ctxt = self.grad_state.forward_context
        forward_ctxt.outer_context.Enter()
        zeros_shape = array_ops.shape_internal(value, optimize=False)
        forward_ctxt.outer_context.Exit()
        outer_grad_state = self.grad_state.outer_grad_state
        history_zeros_shape = outer_grad_state.AddForwardAccumulator(
            zeros_shape)
        self.outer_context.Enter()
        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
            history_zeros_shape, zeros_shape)
        acc = array_ops.zeros(real_shape, grad.dtype)
        self.outer_context.Exit()
      else:
        if self.outer_context:
          self.outer_context.Enter()
        zeros_shape = array_ops.shape_internal(value, optimize=False)
        acc = array_ops.zeros(zeros_shape, grad.dtype)
        if self.outer_context:
          self.outer_context.Exit()

    self.Enter()
    self.AddName(acc.name)
    enter_acc = _Enter(
        acc,
        self._name,
        is_constant=False,
        parallel_iterations=self._parallel_iterations,
        name="b_acc")
    self.loop_enters.append(enter_acc)

    merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
    switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)

    add_acc = math_ops.add(switch_acc_true, grad)
    next_acc = _NextIteration(add_acc)
    merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access

    result_acc = exit(switch_acc_false, name="b_acc")
    self.loop_exits.append(result_acc)
    self.ExitResult([result_acc])
    return result_acc

  def AddBackpropIndexedSlicesAccumulator(self, op, grad):
    """This is used for accumulating gradients that are IndexedSlices.

    This is essentially the equivalent of AddBackpropAccumulator but optimized
    for things like updating embeddings from within a while loop.

    Args:
      op: The Enter op for a loop invariant.
      grad: The partial gradients represented as an IndexedSlices.

    Returns:
      The accumulated IndexedSlices gradient of the loop invariant.
    """
    values = grad.values
    indices = grad.indices
    dense_shape = grad.dense_shape

    self.Exit()
    if self.outer_context:
      self.outer_context.Enter()
    if values.get_shape().is_fully_defined():
      values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] +
                                              values.get_shape().dims[1:])
      if self.outer_context:
        self.outer_context.Enter()
      values_acc = constant_op.constant(
          0, values.dtype, shape=values_shape, name="b_acc")
      if self.outer_context:
        self.outer_context.Exit()
    else:
      values_shape = _resource_safe_shape(op.inputs[0])[1:]
      values_shape = array_ops.concat([[1], values_shape], 0)
      values_acc = array_ops.zeros(values_shape, dtype=values.dtype)
    indices_acc = constant_op.constant([0], indices.dtype)
    shape_acc = None
    if dense_shape is not None:
      if dense_shape.get_shape().is_fully_defined():
        if self.outer_context:
          self.outer_context.Enter()
        shape_acc = constant_op.constant(
            0, dense_shape.dtype, shape=dense_shape.get_shape())
        if self.outer_context:
          self.outer_context.Exit()
      else:
        shape_acc = array_ops.zeros_like(
            array_ops.shape_internal(
                op.inputs[0], optimize=False, out_type=dense_shape.dtype),
            optimize=False)

    if self.outer_context:
      self.outer_context.Exit()

    self.Enter()
    self.AddName(values_acc.name)
    self.AddName(indices_acc.name)
    init_acc = [indices_acc, values_acc]
    if shape_acc is not None:
      self.AddName(shape_acc.name)
      init_acc.append(shape_acc)

    # Set use_input_shape=False since the accumulator tensors will grow in
    # size. If use_input_shape=True, the _update_input call below will result in
    # incompatible shapes.
    enter_acc = [
        _Enter(
            x,
            self._name,
            is_constant=False,
            parallel_iterations=self._parallel_iterations,
            use_input_shape=False,
            name="b_acc") for x in init_acc
    ]
    # Manually set appropriate partial shapes.
    enter_acc[0].set_shape([None])
    if values_acc.shape.dims is not None:
      enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:])
    self.loop_enters.extend(enter_acc)

    merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
    switch_acc = [switch(x, self._pivot) for x in merge_acc]

    # The actual accumulation.
    acc_indexed_slices = [
        array_ops.concat([xa[1], xv], 0)
        for xa, xv in zip(switch_acc[:2], [indices, values])
    ]
    if shape_acc is not None:
      # For the shape we just keep the maximum
      acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))

    next_acc = [_NextIteration(x) for x in acc_indexed_slices]
    for xm, xn in zip(merge_acc, next_acc):
      xm.op._update_input(1, xn)  # pylint: disable=protected-access

    exit_acc = [exit(x[0], name="b_acc") for x in switch_acc]
    self.loop_exits.extend(exit_acc)

    self.ExitResult(exit_acc)
    return ops.IndexedSlices(
        indices=exit_acc[0],
        values=exit_acc[1],
        dense_shape=exit_acc[2] if shape_acc is not None else None)

  def _InitializeValues(self, values):
    """Makes the values known to this context."""
    self._values = set()
    for x in values:
      if isinstance(x, ops.Tensor):
        self._values.add(x.name)
      else:
        raise TypeError("Type %s not supported" % type(x))

  def _BuildLoop(self, pred, body, original_loop_vars, loop_vars,
                 shape_invariants):
    """Core: Add the loop termination condition and body to the graph."""
    flat_loop_vars = nest.flatten(original_loop_vars, expand_composites=True)

    # Let the context know the loop variables so the loop variables
    # would be added in the outer contexts properly.
    self._InitializeValues(loop_vars)
    real_vars = loop_vars
    if self._outer_context:
      real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
    with ops.control_dependencies(None):
      enter_vars = [
          _Enter(
              x,
              self._name,
              is_constant=False,
              parallel_iterations=self._parallel_iterations,
              use_input_shape=(shape_invariants is None)) for x in real_vars
      ]
      for x in enter_vars:
        x.graph.prevent_feeding(x)
        if self._outer_context:
          self._outer_context.AddInnerOp(x.op)

    # Finds the closest enclosing non-None control pivot.
    outer_context = self._outer_context
    control_pivot = None
    while outer_context is not None and control_pivot is None:
      control_pivot = outer_context.GetControlPivot()
      # pylint: disable=protected-access
      outer_context = outer_context._outer_context
      # pylint: enable=protected-access

    if control_pivot is not None:
      for var in enter_vars:
        if util.IsLoopConstantEnter(var.op.inputs[0].op):
          # pylint: disable=protected-access
          var.op._add_control_input(control_pivot.op)
          # pylint: enable=protected-access
    _SetShapeInvariants(real_vars, enter_vars, shape_invariants)

    # Fix the control inputs and control flow context of these enter ops.
    self._FixControlInputsAndContext(enter_vars)
    self._InitializeValues(enter_vars)
    self._loop_enters = enter_vars

    merge_vars = [merge([x, x])[0] for x in enter_vars]
    self._pivot_for_pred = merge_vars[0]

    # Build the graph for pred.
    merge_vars_with_tensor_arrays = (
        _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars))
    packed_vars = nest.pack_sequence_as(
        structure=original_loop_vars,
        flat_sequence=merge_vars_with_tensor_arrays,
        expand_composites=True)
    c = ops.convert_to_tensor(pred(*packed_vars))
    self._pivot = loop_cond(c, name="LoopCond")
    switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]

    # Build the graph for body.
    vars_for_body = [_Identity(x[1]) for x in switch_vars]
    self._pivot_for_body = vars_for_body[0]
    # Convert TensorArray flow variables inside the context back into
    # their associated TensorArrays for calling the body.
    vars_for_body_with_tensor_arrays = (
        _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body))
    packed_vars_for_body = nest.pack_sequence_as(
        structure=original_loop_vars,
        flat_sequence=vars_for_body_with_tensor_arrays,
        expand_composites=True)
    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    body_result = body(*packed_vars_for_body)
    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    if not nest.is_sequence_or_composite(body_result):
      body_result = [body_result]
    if len(post_summaries) > len(pre_summaries):
      new_summaries = post_summaries[len(pre_summaries):]
      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
      summary_ref[:] = pre_summaries
      with ops.control_dependencies(new_summaries):

        def map_fn(x):
          # TODO(apassos) figure out how to trigger with tensor arrays as well
          if isinstance(x, tensor_array_ops.TensorArray):
            return x
          return array_ops.identity(x)

        body_result = nest.map_structure(
            map_fn, body_result, expand_composites=True)

    # Compare the structure types of input and output of body.
    # For backwards compatibility, the first layer is forced to a list
    # during this comparison, because inputs are typically lists and
    # outputs of the body are typically tuples.
    nest.assert_same_structure(
        list(packed_vars_for_body), list(body_result), expand_composites=True)

    # Store body_result to keep track of TensorArrays returned by body
    original_body_result = body_result
    # Convert TensorArrays returned by body into their flow variables
    result = nest.map_structure(
        _convert_tensorarray_to_flow,
        nest.flatten(body_result, expand_composites=True),
        expand_composites=True)
    result = ops.convert_n_to_tensor_or_composite(result)

    # Add NextIteration and the back edges to complete the loop.
    if len(merge_vars) != len(result):
      raise ValueError("Number of inputs and outputs of body must match "
                       "loop_vars: %d, %d" % (len(merge_vars), len(result)))
    next_vars = []
    for m, v in zip(merge_vars, result):
      next_vars.append(_AddNextAndBackEdge(m, v))

    # Add the exit ops.
    exit_vars = [exit(x[0]) for x in switch_vars]
    self._loop_exits = exit_vars

    # Exit the loop.
    self.ExitResult(exit_vars)

    return original_body_result, exit_vars

  def BuildLoop(self, pred, body, loop_vars, shape_invariants,
                return_same_structure):
    """Add the loop termination condition and body to the graph."""

    # Keep original_loop_vars to identify which are TensorArrays
    original_loop_vars = loop_vars
    # Convert TensorArrays to their flow variables
    loop_vars = nest.map_structure(
        _convert_tensorarray_to_flow,
        nest.flatten(loop_vars, expand_composites=False),
        expand_composites=True)
    loop_vars = ops.convert_n_to_tensor_or_composite(loop_vars)
    if shape_invariants is None:
      shape_invariants = nest.map_structure(
          _get_shape_invariant, loop_vars, expand_composites=False)
    loop_vars = nest.flatten(loop_vars, expand_composites=True)
    try:
      self.Enter()
      # _BuildLoop calls _update_input in several places. _mutation_lock()
      # ensures a Session.run call cannot occur between creating and mutating
      # new ops.
      with ops.get_default_graph()._mutation_lock():  # pylint: disable=protected-access
        original_body_result, exit_vars = self._BuildLoop(
            pred, body, original_loop_vars, loop_vars, shape_invariants)
    finally:
      self.Exit()

    flat_result = nest.flatten(original_body_result, expand_composites=True)
    # Convert TensorArray flow variables outside the context back into
    # their associated TensorArrays for returning to caller.
    exit_vars_with_tensor_arrays = (
        _convert_flows_to_tensorarrays(flat_result, exit_vars))
    packed_exit_vars = nest.pack_sequence_as(
        structure=original_body_result,
        flat_sequence=exit_vars_with_tensor_arrays,
        expand_composites=True)

    if return_same_structure:
      return packed_exit_vars
    else:
      return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars

  def _FixControlInputsAndContext(self, enters):
    graph = ops.get_default_graph()
    # pylint: disable=protected-access
    for e in enters:
      if isinstance(e, ops.Tensor):
        xs = [e]
      else:
        raise TypeError("Type %s not supported" % type(e))
      for x in xs:
        inp_op = x.op.inputs[0].op
        control_inputs = graph._control_dependencies_for_inputs([inp_op])
        outer_control_inputs = []
        for op in control_inputs:
          # We need to keep control inputs that are in any ancestor
          # ControlFlowContext, and within outer WhileContext.
          keep_as_control_input = True
          op_ctxt = util.GetOutputContext(op)
          outer_ctxt = self.outer_context
          outer_while_context = (None if outer_ctxt is None else
                                 outer_ctxt.GetWhileContext())
          while outer_ctxt != op_ctxt:
            if outer_ctxt is None or outer_ctxt == outer_while_context:
              keep_as_control_input = False
              break
            outer_ctxt = outer_ctxt.outer_context
          if keep_as_control_input:
            outer_control_inputs.append(op)
        x.op._set_control_flow_context(self)
        x.op._add_control_inputs(outer_control_inputs)
        graph._record_op_seen_by_control_dependencies(x.op)
    # pylint: enable=protected-access

  def IsWhileContext(self):
    return True


# @TODO(b/133606651) Replace "shape_invariants" with "loop_vars_signature".
# pylint: disable=redefined-outer-name
@tf_export("while_loop", v1=[])
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))""",
    warn_once=True,
    back_prop=False)
def while_loop_v2(cond,
                  body,
                  loop_vars,
                  shape_invariants=None,
                  parallel_iterations=10,
                  back_prop=True,
                  swap_memory=False,
                  maximum_iterations=None,
                  name=None):
  """Repeat `body` while the condition `cond` is true.

  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
  `cond` and `body`. `cond` and `body` both take as many arguments as there are
  `loop_vars`.

  In addition to regular Tensors or IndexedSlices, the body may accept and
  return TensorArray objects. The flows of the TensorArray objects will
  be appropriately forwarded between loops and during gradient calculations.

  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
  stitches together the graph fragments created during the `cond` and `body`
  calls with some additional graph nodes to create the graph flow that
  repeats `body` until `cond` returns false.

  For correctness, `tf.while_loop()` strictly enforces shape invariants for
  the loop variables. A shape invariant is a (possibly partial) shape that
  is unchanged across the iterations of the loop. An error will be raised
  if the shape of a loop variable after an iteration is determined to be more
  general than or incompatible with its shape invariant. For example, a shape
  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
  compatible with [11, 17]. By default (if the argument `shape_invariants` is
  not specified), it is assumed that the initial shape of each tensor in
  `loop_vars` is the same in every iteration. The `shape_invariants` argument
  allows the caller to specify a less specific shape invariant for each loop
  variable, which is needed if the shape varies between iterations. The
  `tf.Tensor.set_shape`
  function may also be used in the `body` function to indicate that
  the output loop variable has a particular shape. The shape invariants for
  SparseTensor and IndexedSlices are treated specially as follows:

  a) If a loop variable is a SparseTensor, the shape invariant must be
  TensorShape([r]) where r is the rank of the dense tensor represented
  by the sparse tensor. It means the shapes of the three tensors of the
  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
  is the shape of the SparseTensor.dense_shape property. It must be the shape of
  a vector.

  b) If a loop variable is an IndexedSlices, the shape invariant must be
  a shape invariant of the values tensor of the IndexedSlices. It means
  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
  [shape.ndims]).

  `while_loop` implements non-strict semantics, enabling multiple iterations
  to run in parallel. The maximum number of parallel iterations can be
  controlled by `parallel_iterations`, which gives users some control over
  memory consumption and execution order. For correct programs, `while_loop`
  should return the same result for any parallel_iterations > 0.

  For training, TensorFlow stores the tensors that are produced in the
  forward inference and are needed in back propagation. These tensors are a
  main source of memory consumption and often cause OOM errors when training
  on GPUs. When the flag swap_memory is true, we swap out these tensors from
  GPU to CPU. This, for example, allows us to train RNN models with very long
  sequences and large batches.

  Args:
    cond: A callable that represents the termination condition of the loop.
    body: A callable that represents the loop body.
    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
      `Tensor`, and `TensorArray` objects.
    shape_invariants: The shape invariants for the loop variables.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
    maximum_iterations: Optional maximum number of iterations of the while loop
      to run. If provided, the `cond` output is AND-ed with an additional
      condition ensuring the number of iterations executed is no greater than
      `maximum_iterations`.
    name: Optional name prefix for the returned tensors.

  Returns:
    The output tensors for the loop variables after the loop. The return value
    has the same structure as `loop_vars`.

  Raises:
    TypeError: if `cond` or `body` is not callable.
    ValueError: if `loop_vars` is empty.

  Example:

  ```python
  i = tf.constant(0)
  c = lambda i: tf.less(i, 10)
  b = lambda i: (tf.add(i, 1), )
  r = tf.while_loop(c, b, [i])
  ```

  Example with nesting and a namedtuple:

  ```python
  import collections
  Pair = collections.namedtuple('Pair', 'j, k')
  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
  c = lambda i, p: i < 10
  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
  ijk_final = tf.while_loop(c, b, ijk_0)
  ```

  Example using shape_invariants:

  ```python
  i0 = tf.constant(0)
  m0 = tf.ones([2, 2])
  c = lambda i, m: i < 10
  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
  tf.while_loop(
      c, b, loop_vars=[i0, m0],
      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
  ```
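
  Example using `maximum_iterations` (a minimal sketch: the loop is cut off
  after 5 iterations even though `cond` is still true, so `r[0]` evaluates
  to 5):

  ```python
  i = tf.constant(0)
  c = lambda i: tf.less(i, 10)
  b = lambda i: (tf.add(i, 1),)
  r = tf.while_loop(c, b, [i], maximum_iterations=5)
  ```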

  Example which demonstrates non-strict semantics: In the following
  example, the final value of the counter `i` does not depend on `x`. So
  the `while_loop` can increment the counter parallel to updates of `x`.
  However, because the loop counter at one loop iteration depends
  on the value at the previous iteration, the loop counter itself cannot
  be incremented in parallel. Hence if we just want the final value of the
  counter (which we print on the line `print(sess.run(i))`), then
  `x` will never be incremented, but the counter will be updated on a
  single thread. Conversely, if we want the value of the output (which we
  print on the line `print(sess.run(out).shape)`), then the counter may be
  incremented on its own thread, while `x` can be incremented in
  parallel on a separate thread. In the extreme case, it is conceivable
  that the thread incrementing the counter runs until completion before
  `x` is incremented even a single time. The only thing that can never
  happen is that the thread updating `x` can never get ahead of the
  counter thread because the thread incrementing `x` depends on the value
  of the counter.

  ```python
  import tensorflow as tf

  n = 10000
  x = tf.constant(list(range(n)))
  c = lambda i, x: i < n
  b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
                                                                       [i], "x:"))
  i, out = tf.while_loop(c, b, (0, x))
  with tf.compat.v1.Session() as sess:
      print(sess.run(i))  # prints [0] ... [9999]

      # The following line may increment the counter and x in parallel.
      # The counter thread may get ahead of the other thread, but not the
      # other way around. So you may see things like
      # [9996] x:[9987]
      # meaning that the counter thread is on iteration 9996,
      # while the other thread is on iteration 9987
      print(sess.run(out).shape)
  ```

  """
  return while_loop(
      cond=cond,
      body=body,
      loop_vars=loop_vars,
      shape_invariants=shape_invariants,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name,
      maximum_iterations=maximum_iterations,
      return_same_structure=True)


# pylint: disable=redefined-outer-name
@tf_export(v1=["while_loop"])
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               back_prop=True,
               swap_memory=False,
               name=None,
               maximum_iterations=None,
               return_same_structure=False):
  """Repeat `body` while the condition `cond` is true.

  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
  `cond` and `body`. `cond` and `body` both take as many arguments as there are
  `loop_vars`.

  In addition to regular Tensors or IndexedSlices, the body may accept and
  return TensorArray objects. The flows of the TensorArray objects will
  be appropriately forwarded between loops and during gradient calculations.

  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
  stitches together the graph fragments created during the `cond` and `body`
  calls with some additional graph nodes to create the graph flow that
  repeats `body` until `cond` returns false.

  For correctness, `tf.while_loop()` strictly enforces shape invariants for
  the loop variables. A shape invariant is a (possibly partial) shape that
  is unchanged across the iterations of the loop. An error will be raised
  if the shape of a loop variable after an iteration is determined to be more
  general than or incompatible with its shape invariant. For example, a shape
  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
  compatible with [11, 17]. By default (if the argument `shape_invariants` is
  not specified), it is assumed that the initial shape of each tensor in
  `loop_vars` is the same in every iteration. The `shape_invariants` argument
  allows the caller to specify a less specific shape invariant for each loop
  variable, which is needed if the shape varies between iterations. The
  `tf.Tensor.set_shape`
  function may also be used in the `body` function to indicate that
  the output loop variable has a particular shape. The shape invariants for
  SparseTensor and IndexedSlices are treated specially as follows:

  a) If a loop variable is a SparseTensor, the shape invariant must be
  TensorShape([r]) where r is the rank of the dense tensor represented
  by the sparse tensor. It means the shapes of the three tensors of the
  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
  is the shape of the SparseTensor.dense_shape property. It must be the shape of
  a vector.

  b) If a loop variable is an IndexedSlices, the shape invariant must be
  a shape invariant of the values tensor of the IndexedSlices. It means
  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
  [shape.ndims]).

  `while_loop` implements non-strict semantics, enabling multiple iterations
  to run in parallel. The maximum number of parallel iterations can be
  controlled by `parallel_iterations`, which gives users some control over
  memory consumption and execution order. For correct programs, `while_loop`
  should return the same result for any parallel_iterations > 0.

  For training, TensorFlow stores the tensors that are produced in the
  forward inference and are needed in back propagation. These tensors are a
  main source of memory consumption and often cause OOM errors when training
  on GPUs. When the flag swap_memory is true, we swap out these tensors from
  GPU to CPU. This, for example, allows us to train RNN models with very long
  sequences and large batches.

  Args:
    cond: A callable that represents the termination condition of the loop.
    body: A callable that represents the loop body.
    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
      `Tensor`, and `TensorArray` objects.
    shape_invariants: The shape invariants for the loop variables.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer.
    back_prop: Whether backprop is enabled for this while loop.
    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
    name: Optional name prefix for the returned tensors.
    maximum_iterations: Optional maximum number of iterations of the while loop
      to run. If provided, the `cond` output is AND-ed with an additional
      condition ensuring the number of iterations executed is no greater than
      `maximum_iterations`.
    return_same_structure: If True, output has same structure as `loop_vars`. If
      eager execution is enabled, this is ignored (and always treated as True).

  Returns:
    The output tensors for the loop variables after the loop.
    If `return_same_structure` is True, the return value has the same
    structure as `loop_vars`.
    If `return_same_structure` is False, the return value is a Tensor,
    TensorArray or IndexedSlice if the length of `loop_vars` is 1, or a list
    otherwise.

  Raises:
    TypeError: if `cond` or `body` is not callable.
    ValueError: if `loop_vars` is empty.

  Example:

  ```python
  i = tf.constant(0)
  c = lambda i: tf.less(i, 10)
  b = lambda i: tf.add(i, 1)
  r = tf.while_loop(c, b, [i])
  ```

  Example with nesting and a namedtuple:

  ```python
  import collections
  Pair = collections.namedtuple('Pair', 'j, k')
  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
  c = lambda i, p: i < 10
  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
  ijk_final = tf.while_loop(c, b, ijk_0)
  ```

  Example using shape_invariants:

  ```python
  i0 = tf.constant(0)
  m0 = tf.ones([2, 2])
  c = lambda i, m: i < 10
  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
  tf.while_loop(
      c, b, loop_vars=[i0, m0],
      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
  ```
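
  Example using `maximum_iterations` (a minimal sketch: the loop is cut off
  after 5 iterations even though `c` would still return true):

  ```python
  i = tf.constant(0)
  c = lambda i: tf.less(i, 10)
  b = lambda i: tf.add(i, 1)
  r = tf.while_loop(c, b, [i], maximum_iterations=5)  # r evaluates to 5
  ```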

  Example which demonstrates non-strict semantics: In the following
  example, the final value of the counter `i` does not depend on `x`. So
  the `while_loop` can increment the counter parallel to updates of `x`.
  However, because the loop counter at one loop iteration depends
  on the value at the previous iteration, the loop counter itself cannot
  be incremented in parallel. Hence if we just want the final value of the
  counter (which we print on the line `print(sess.run(i))`), then
  `x` will never be incremented, but the counter will be updated on a
  single thread. Conversely, if we want the value of the output (which we
  print on the line `print(sess.run(out).shape)`), then the counter may be
  incremented on its own thread, while `x` can be incremented in
  parallel on a separate thread. In the extreme case, it is conceivable
  that the thread incrementing the counter runs until completion before
  `x` is incremented even a single time. The only thing that can never
  happen is that the thread updating `x` can never get ahead of the
  counter thread because the thread incrementing `x` depends on the value
  of the counter.

  ```python
  import tensorflow as tf

  n = 10000
  x = tf.constant(list(range(n)))
  c = lambda i, x: i < n
  b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
                                                                       [i], "x:"))
  i, out = tf.while_loop(c, b, (0, x))
  with tf.compat.v1.Session() as sess:
      print(sess.run(i))  # prints [0] ... [9999]

      # The following line may increment the counter and x in parallel.
      # The counter thread may get ahead of the other thread, but not the
      # other way around. So you may see things like
      # [9996] x:[9987]
      # meaning that the counter thread is on iteration 9996,
      # while the other thread is on iteration 9987
      print(sess.run(out).shape)
  ```

  """
  if not callable(cond):
    raise TypeError("cond must be callable.")
  if not callable(body):
    raise TypeError("body must be callable.")
  if parallel_iterations < 1:
    raise TypeError("parallel_iterations must be a positive integer.")

  # Always enable control flow v2 if building a function, regardless of toggle.
  executing_eagerly = context.executing_eagerly()
  if (util.EnableControlFlowV2(ops.get_default_graph()) and
      not executing_eagerly):
    return while_v2.while_loop(
        cond,
        body,
        loop_vars,
        shape_invariants=shape_invariants,
        parallel_iterations=parallel_iterations,
        maximum_iterations=maximum_iterations,
        name=name,
        return_same_structure=return_same_structure,
        back_prop=back_prop)

  with ops.name_scope(name, "while", loop_vars):
    if not loop_vars:
      raise ValueError("No loop variables provided")
    try_to_pack = (len(loop_vars) == 1 and not return_same_structure)
    if maximum_iterations is not None:
      maximum_iterations = ops.convert_to_tensor(
          maximum_iterations, name="maximum_iterations")
      if maximum_iterations.shape.ndims != 0:
        raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
                         maximum_iterations.shape)

      if executing_eagerly:
        counter = 0
        maximum_iterations = int(maximum_iterations.numpy())
      else:
        counter = constant_op.constant(
            0, dtype=maximum_iterations.dtype, name="iteration_counter")
      orig_cond = cond
      orig_body = body
      if try_to_pack:
        loop_vars = (counter, loop_vars[0])
        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
            math_ops.logical_and(i < maximum_iterations, orig_cond(lv)))
        body = lambda i, lv: (i + 1, orig_body(lv))
      else:
        loop_vars = (counter, loop_vars)
        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
            math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
        body = lambda i, lv: (i + 1, orig_body(*lv))
      try_to_pack = False

    if executing_eagerly:
      packed = False  # whether the body result was packed into a 1-item tuple

      loop_var_structure = nest.map_structure(type_spec.type_spec_from_value,
                                              list(loop_vars))
      while cond(*loop_vars):
        loop_vars = body(*loop_vars)
        if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
          packed = True
          loop_vars = (loop_vars,)
        nest.assert_same_structure(loop_var_structure, list(loop_vars))

      def convert(x):
        if isinstance(x, tensor_array_ops.TensorArray):
          return x
        return ops.convert_to_tensor(x)

      loop_vars = nest.map_structure(convert, loop_vars, expand_composites=True)
      if maximum_iterations is not None:
        return loop_vars[1]
      else:
        return loop_vars[0] if packed else loop_vars

    if shape_invariants is not None:
      if maximum_iterations is not None:
        shape_invariants = (tensor_shape.TensorShape([]), shape_invariants)

      nest.assert_same_structure(
          loop_vars, shape_invariants, expand_composites=False)
      shape_invariants = nest.map_structure(
          _get_shape_invariant,
          loop_vars,
          shape_invariants,
          expand_composites=False)

    loop_context = WhileContext(
        maximum_iterations=maximum_iterations,
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory)
    # Only add non-nested loops to the collection. Any nested control flow will
    # be encapsulated in the root context.
    if loop_context.outer_context is None:
      ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
                                    return_same_structure)
    if maximum_iterations is not None:
      return result[1]
    else:
      return result


# pylint: enable=redefined-outer-name


def _AsTensorList(x, p):
  """Return x as a list of Tensors or IndexedSlices.

  For entries of `x` that are Operations, this returns an Identity of `p`
  with a dependency on the operation.

  Args:
    x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
    p: A Tensor to return for entries in `x` that are Operations.

  Returns:
    A list of Tensors or IndexedSlices.
  """
  if not isinstance(x, (list, _basetuple)):
    x = [x]

  l = []
  for v in x:
    if isinstance(v, ops.Operation):
      v = with_dependencies([v], p)
    v = ops.convert_to_tensor_or_composite(v)
    if isinstance(v, ops.Tensor):
      l.append(array_ops.identity(v))
    else:
      l.append(
          ops.IndexedSlices(
              array_ops.identity(v.values), array_ops.identity(v.indices)))
  return l


def _CheckResults(a, b):
  assert len(a) == len(b), (
      "Values returned by a() and b() must have the same length.")
  for x, y in zip(a, b):
    assert x.dtype == y.dtype, (
        "Values returned by a() [%s] and b() [%s] must have "
        "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))


def with_dependencies(dependencies, output_tensor, name=None):
  """Produces the content of `output_tensor` only after `dependencies`.

  In some cases, a user may want the output of an operation to be
  consumed externally only after some other dependencies have run
  first. This function returns `output_tensor`, but only after all
  operations in `dependencies` have run. Note that the op that produces
  `output_tensor` itself may still run before the dependencies; only the
  value returned here is gated on them.

  See also `tf.tuple` and `tf.group`.
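
  Example (a minimal sketch; `v` is a hypothetical variable created for
  illustration):

  ```python
  v = tf.compat.v1.Variable(1.0)
  update = v.assign_add(1.0)
  x = tf.constant(3.0)
  # `out` has the value of `x`, but is only available after `update` has run.
  out = with_dependencies([update], x)
  ```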

  Args:
    dependencies: Iterable of operations to run before this op finishes.
    output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
    name: (Optional) A name for this operation.

  Returns:
    Same as `output_tensor`.

  Raises:
    TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
  """
  if context.executing_eagerly():
    return output_tensor
  with ops.name_scope(name, "control_dependency",
                      list(dependencies) + [output_tensor]) as name:
    with ops.colocate_with(output_tensor):
      with ops.control_dependencies(dependencies):
        output_tensor = ops.convert_to_tensor_or_composite(output_tensor)
        if isinstance(output_tensor, ops.Tensor):
          return _Identity(output_tensor, name=name)
        else:
          return ops.IndexedSlices(
              _Identity(output_tensor.values, name=name), output_tensor.indices,
              output_tensor.dense_shape)


def _GroupControlDeps(dev, deps, name=None):
  with ops.control_dependencies(deps):
    if dev is None:
      return no_op(name=name)
    else:
      with ops.device(dev):
        return no_op(name=name)


# TODO(touts): Accept "inputs" as a list.
@tf_export("group")
def group(*inputs, **kwargs):
  """Create an op that groups multiple operations.

  When this op finishes, all ops in `inputs` have finished. This op has no
  output.

  Note: *In TensorFlow 2 with eager and/or Autograph, you should not need
  this method, as code executes in your expected order.* Only use tf.group when
  working with v1-style code or in a graph context such as inside `Dataset.map`.

  When operating in a v1-style graph context, ops are not executed in the same
  order as specified in the code; TensorFlow will attempt to execute ops in
  parallel or in an order convenient to the result it is computing. `tf.group`
  allows you to request that one or more results finish before execution
  continues.

  `tf.group` creates a single op (of type `NoOp`), and then adds appropriate
  control dependencies. Thus, `c = tf.group(a, b)` will compute the same graph
  as this:

      with tf.control_dependencies([a, b]):
          c = tf.no_op()

  See also `tf.tuple` and
  `tf.control_dependencies`.

  Args:
    *inputs: Zero or more tensors to group.
    name: A name for this operation (optional).

  Returns:
    An Operation that executes all its inputs.

  Raises:
    ValueError: If an unknown keyword argument is provided.
  """
  if context.executing_eagerly():
    return None
  name = kwargs.pop("name", None)
  if kwargs:
    raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
  with ops.name_scope(name, "group_deps", inputs) as name:
    # Grouping no inputs means do nothing
    if not inputs:
      return no_op(name=name)

    # Sorts *inputs according to their devices.
    ops_on_device = {}  # device -> operations specified on the device.
    for inp in nest.flatten(inputs, expand_composites=True):
      if not hasattr(inp, "device"):
raise TypeError("Expected tf.group() expected Tensor arguments not "
|
|
"'%s' with type '%s'" % (inp, type(inp)))
|
|
dev = inp.device
|
|
if dev in ops_on_device:
|
|
ops_on_device[dev].append(inp)
|
|
else:
|
|
ops_on_device[dev] = [inp]
|
|
if len(ops_on_device) == 1:
|
|
# 1-level tree. The root node is the returned NoOp node.
|
|
(dev, deps), = ops_on_device.items()
|
|
return _GroupControlDeps(dev, deps, name=name)
|
|
|
|
# 2-level tree. The root node is the returned NoOp node.
|
|
# deps contains 1 NoOp node for each device.
|
|
deps = []
|
|
|
|
def device_key(dev):
|
|
"""A sort key that allows None to be compared to strings."""
|
|
return "" if dev is None else dev
|
|
|
|
for dev in sorted(ops_on_device, key=device_key):
|
|
deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
|
|
|
|
with ops.control_dependencies(deps):
|
|
return no_op(name=name)


@tf_export("tuple", v1=[])
@dispatch.add_dispatch_support
def tuple_v2(tensors, control_inputs=None, name=None):
  """Group tensors together.

  This creates a tuple of tensors with the same values as the `tensors`
  argument, except that the value of each tensor is only returned after the
  values of all tensors have been computed.

  `control_inputs` contains additional ops that have to finish before this op
  finishes, but whose outputs are not returned.

  This can be used as a "join" mechanism for parallel computations: all the
  argument tensors can be computed in parallel, but the values of any tensor
  returned by `tuple` are only available after all the parallel computations
  are done.

  See also `tf.group` and
  `tf.control_dependencies`.
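
  For example, a minimal sketch in graph mode (the tensors `x` and `y` and the
  session are illustrative assumptions, not part of this API):

  ```python
  g = tf.Graph()
  with g.as_default():
    x = tf.square(tf.constant(2.0))
    y = tf.square(tf.constant(3.0))
    # Neither gated value is returned until both squares have been computed.
    gated_x, gated_y = tf.tuple([x, y])
  with tf.compat.v1.Session(graph=g) as sess:
    print(sess.run([gated_x, gated_y]))  # [4.0, 9.0]
  ```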

  Args:
    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
    control_inputs: List of additional ops to finish before returning.
    name: (optional) A name to use as a `name_scope` for the operation.

  Returns:
    Same as `tensors`.

  Raises:
    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
      objects.

  """
  return tuple(tensors=tensors, name=name, control_inputs=control_inputs)  # pylint: disable=redefined-builtin


@tf_export(v1=["tuple"])
@dispatch.add_dispatch_support
def tuple(tensors, name=None, control_inputs=None):  # pylint: disable=redefined-builtin
  """Group tensors together.

  This creates a tuple of tensors with the same values as the `tensors`
  argument, except that the value of each tensor is only returned after the
  values of all tensors have been computed.

  `control_inputs` contains additional ops that have to finish before this op
  finishes, but whose outputs are not returned.

  This can be used as a "join" mechanism for parallel computations: all the
  argument tensors can be computed in parallel, but the values of any tensor
  returned by `tuple` are only available after all the parallel computations
  are done.

  See also `tf.group` and
  `tf.control_dependencies`.
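
  For example, a minimal sketch using `control_inputs` (v1 graph mode; the
  names `counter`, `increment`, `a`, and `b` are illustrative assumptions):

  ```python
  counter = tf.compat.v1.get_variable("counter", initializer=0)
  increment = tf.compat.v1.assign_add(counter, 1)
  a = tf.constant(10)
  b = tf.constant(20)
  # `a` and `b` are only returned after `increment` has also run.
  a_out, b_out = tf.compat.v1.tuple([a, b], control_inputs=[increment])
  ```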

  Args:
    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
    name: (optional) A name to use as a `name_scope` for the operation.
    control_inputs: List of additional ops to finish before returning.

  Returns:
    Same as `tensors`.

  Raises:
    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
      objects.

  """
  if context.executing_eagerly():
    return tensors
  with ops.name_scope(name, "tuple", tensors) as name:
    tensors = [
        t if (isinstance(t, ops.Operation) or tensor_util.is_tensor(t) or
              t is None) else ops.convert_to_tensor(t) for t in tensors
    ]
    gating_ops = [
        t if isinstance(t, ops.Operation) else t.op
        for t in tensors
        if t is not None
    ]
    if control_inputs:
      for c in control_inputs:
        if isinstance(c, ops.Tensor):
          c = c.op
        elif not isinstance(c, ops.Operation):
          raise TypeError("Control input must be Operation or Tensor: %s" % c)
        gating_ops.append(c)
    # Note: to ensure a deterministic ordering in the serialized pbtxt, we
    # must sort the gating ops here.
    gating_ops = sorted(set(gating_ops), key=lambda op: op._id)  # Uniquify ops.
    if not gating_ops:
      raise ValueError("Must have at least one Tensor: %s" % tensors)
    gate = group(*gating_ops)
    tpl = []
    for t in tensors:
      if tensor_util.is_tensor(t):
        tpl.append(with_dependencies([gate], t))
      elif isinstance(t, ops.Operation):
        with ops.control_dependencies([gate]):
          tpl.append(group(t))
      else:
        tpl.append(None)
    return tpl


def _assert_at_most_n_true(predicates, n, msg):
  """Returns an Assert op that checks that at most n predicates are True.

  Args:
    predicates: list of bool scalar tensors.
    n: maximum number of true predicates allowed.
    msg: Error message.
  """
  preds_c = array_ops.stack(predicates, name="preds_c")
  num_true_conditions = math_ops.reduce_sum(
      math_ops.cast(preds_c, dtypes.int32), name="num_true_conds")
  condition = math_ops.less_equal(num_true_conditions,
                                  constant_op.constant(n, name="n_true_conds"))
  preds_names = ", ".join(getattr(p, "name", "?") for p in predicates)
  error_msg = [
      "%s: more than %d conditions (%s) evaluated as True:" %
      (msg, n, preds_names), preds_c
  ]
  return Assert(condition, data=error_msg, summarize=len(predicates))
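

# For illustration only, a minimal sketch of how the helper above is used by
# the case machinery (the names `p0`, `p1`, and `branch_fn` are hypothetical):
#
#   check = _assert_at_most_n_true([p0, p1], n=1, msg="Input error")
#   with ops.control_dependencies([check]):
#     result = branch_fn()  # Only runs after the assertion has been checked.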


def _case_create_default_action(predicates, actions):
  """Creates a default action for a list of actions and their predicates.

  It uses the input actions to select an arbitrary one as the default and
  makes sure that the corresponding predicates have valid values.

  Args:
    predicates: a list of bool scalar tensors.
    actions: a list of callable objects which return tensors.

  Returns:
    A tuple of the default action callable, the remaining predicates, and the
    remaining actions.
  """
  k = len(predicates) - 1  # could pick any
  predicate, action = predicates[k], actions[k]
  other_predicates, other_actions = predicates[:k], actions[:k]

  def default_action():
    others_msg = ("Implementation error: "
                  "selected default action #%d was called, but some of the "
                  "other predicates are True: " % k)
    default_msg = ("Input error: "
                   "none of the conditions evaluated as True:",
                   array_ops.stack(predicates, name="preds_c"))
    with ops.control_dependencies([
        _assert_at_most_n_true(other_predicates, n=0, msg=others_msg),
        Assert(predicate, data=default_msg)
    ]):
      return action()

  return default_action, other_predicates, other_actions
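

# For illustration (hypothetical values): given predicates [p0, p1, p2] and
# actions [f0, f1, f2], the helper above returns (default, [p0, p1], [f0, f1])
# where `default` wraps f2 with runtime assertions that p2 is True and that
# p0 and p1 are both False before calling it.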


def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name,
                                       allow_python_preds):
  """Verifies input arguments for the case function.

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
      callable which returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for the case operation.
    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
      addition to boolean Tensors.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
      callable.

  Returns:
    a tuple <list of scalar bool tensors, list of callables>.
  """
  if not isinstance(pred_fn_pairs, (list, _basetuple, dict)):
    raise TypeError("pred_fn_pairs must be a list, tuple, or dict")

  if isinstance(pred_fn_pairs, collections.OrderedDict):
    pred_fn_pairs = pred_fn_pairs.items()
  elif isinstance(pred_fn_pairs, dict):
    if context.executing_eagerly():
      # No name to sort on in eager mode. Use dictionary traversal order,
      # which is nondeterministic in versions of Python < 3.6.
      if not exclusive:
        raise ValueError("Unordered dictionaries are not supported for the "
                         "`pred_fn_pairs` argument when `exclusive=False` and "
                         "eager mode is enabled.")
      pred_fn_pairs = list(pred_fn_pairs.items())
    else:
      pred_fn_pairs = sorted(
          pred_fn_pairs.items(), key=lambda item: item[0].name)
      if not exclusive:
        logging.warn(
            "%s: An unordered dictionary of predicate/fn pairs was "
            "provided, but exclusive=False. The order of conditional "
            "tests is deterministic but not guaranteed.", name)
  for pred_fn_pair in pred_fn_pairs:
    if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2:
      raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
    pred, fn = pred_fn_pair

    if isinstance(pred, ops.Tensor):
      if pred.dtype != dtypes.bool:
        raise TypeError("pred must be Tensor of type bool: %s" % pred.name)
    elif not allow_python_preds:
      raise TypeError("pred must be a Tensor, got: %s" % pred)
    elif not isinstance(pred, bool):
      raise TypeError("pred must be a Tensor or bool, got: %s" % pred)

    if not callable(fn):
      raise TypeError("fn for pred %s must be callable." % pred.name)

  predicates, actions = zip(*pred_fn_pairs)
  return predicates, actions


def _case_helper(cond_fn,
                 pred_fn_pairs,
                 default,
                 exclusive,
                 name,
                 allow_python_preds=False,
                 **cond_kwargs):
  """Implementation of case that allows for different cond functions.

  Args:
    cond_fn: method that has signature and semantics of `cond` above.
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
      callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for this operation (optional).
    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
      addition to boolean Tensors.
    **cond_kwargs: keyword arguments that will be passed to `cond_fn`.

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
      callable.
  """
  predicates, actions = _case_verify_and_canonicalize_args(
      pred_fn_pairs, exclusive, name, allow_python_preds)
  with ops.name_scope(name, "case", [predicates]):
    if default is None:
      default, predicates, actions = _case_create_default_action(
          predicates, actions)
    fn = default
    # To evaluate conditions in their original order we create the nested
    # conditions in reverse:
    #   cond_fn(c[0], true_fn=..., false_fn=cond_fn(c[1], ...))
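    # For example (illustrative), with two pairs the final `fn` behaves like:
    #   lambda: cond_fn(c[0], true_fn=a[0],
    #                   false_fn=lambda: cond_fn(c[1], true_fn=a[1],
    #                                            false_fn=default))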
    for predicate, action in reversed(list(zip(predicates, actions))):
      fn = functools.partial(
          cond_fn, predicate, true_fn=action, false_fn=fn, **cond_kwargs)
    if exclusive:
      with ops.control_dependencies([
          _assert_at_most_n_true(
              predicates, n=1, msg="Input error: exclusive=True")
      ]):
        return fn()
    else:
      return fn()


def _indexed_case_verify_and_canonicalize_args(branch_fns, default,
                                               branch_index):
  """Verifies input arguments for the indexed case function.

  Args:
    branch_fns: Dict or list of pairs of an `int` and a callable which
      returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    branch_index: int `Tensor` which selects the corresponding entry in
      `branch_fns`.

  Raises:
    TypeError: If `branch_fns` is not a list/dictionary.
    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
      callables.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
      callable.

  Returns:
    branch_fns: validated list of callables for each branch (default last).
  """
  if not isinstance(branch_index, ops.Tensor):
    raise TypeError("branch_index must be a Tensor, got {}".format(
        type(branch_index)))
  if not branch_index.dtype.is_integer:
    raise TypeError("branch_index must be an integer Tensor, got {}".format(
        branch_index.dtype))

  if not branch_fns:
    raise ValueError("Must provide at least one item in branch_fns")
  if not isinstance(branch_fns, (list, _basetuple, dict)):
    raise TypeError("branch_fns must be a list, tuple, or dict")

  if isinstance(branch_fns, dict):
    branch_fns = branch_fns.items()

  if all(callable(fn) for fn in branch_fns):
    branch_fns = list(enumerate(branch_fns))

  for key_fn_pair in branch_fns:
    if not isinstance(key_fn_pair, _basetuple) or len(key_fn_pair) != 2:
      raise TypeError("Each entry in branch_fns must be a 2-tuple")
    key, branch_fn = key_fn_pair

    if not isinstance(key, int):
      raise TypeError("key must be a Python `int`, got {}".format(type(key)))

    if not callable(branch_fn):
      raise TypeError("fn for key {} must be callable.".format(key))

  keys = [p[0] for p in branch_fns]
  if min(keys) < 0 or max(keys) >= len(keys) or len(set(keys)) != len(keys):
    raise ValueError(
        "branch indices (keys) must form a contiguous range of [0, {}) but "
        "found {{{}}}".format(len(keys), ",".join(map(str, sorted(keys)))))
  actions = [p[1] for p in sorted(branch_fns)]
  if default is not None:
    actions.append(default)
  return actions
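

# For illustration (hypothetical inputs): passing either `[f0, f1, f2]` or
# `{1: f1, 0: f0, 2: f2}` to the helper above yields `[f0, f1, f2]`, with
# `default` appended last when provided; keys such as {0, 2, 3} are rejected
# because they do not form a contiguous range starting at 0.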


def _indexed_case_helper(branch_fns,
                         default,
                         branch_index,
                         name,
                         lower_using_switch_merge=None):
  """Implementation of case that emits the n-way indexed Case op.

  Args:
    branch_fns: Dict or list of pairs of an `int` key and a callable which
      returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    branch_index: int `Tensor` which selects the corresponding entry in
      `branch_fns`.
    name: A name for this operation (optional).
    lower_using_switch_merge: Lower this op using switch merge ops (optional).

  Returns:
    The tensors returned by the pair whose key matched branch_index, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `branch_fns` is not a list/dictionary.
    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
      callables.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
      callable.
  """
  branch_fns = _indexed_case_verify_and_canonicalize_args(
      branch_fns, default, branch_index)
  with ops.name_scope(name, "case", [branch_index]):
    if context.executing_eagerly() and not hasattr(branch_index, "graph"):
      # In eager mode, out-of-range branch indices fall through to the last
      # branch, which holds `default` when one was provided.
      branch_index = array_ops.where(
          math_ops.less(branch_index, 0)
          | math_ops.greater_equal(branch_index, len(branch_fns)),
          len(branch_fns) - 1, branch_index)
      return branch_fns[int(branch_index)]()
    return cond_v2.indexed_case(
        branch_index,
        branch_fns,
        lower_using_switch_merge=lower_using_switch_merge)


@tf_export("case", v1=[])
@dispatch.add_dispatch_support
def case_v2(pred_fn_pairs,
            default=None,
            exclusive=False,
            strict=False,
            name="case"):
  """Create a case operation.

  See also `tf.switch_case`.

  The `pred_fn_pairs` parameter is a list of N pairs.
  Each pair contains a boolean scalar tensor and a python callable that
  creates the tensors to be returned if the boolean evaluates to True.
  `default` is a callable generating a list of tensors. All the callables
  in `pred_fn_pairs` as well as `default` (if provided) should return the same
  number and types of tensors.

  If `exclusive==True`, all predicates are evaluated, and an exception is
  thrown if more than one of the predicates evaluates to `True`.
  If `exclusive==False`, execution stops at the first predicate which
  evaluates to True, and the tensors generated by the corresponding function
  are returned immediately. If none of the predicates evaluate to True, this
  operation returns the tensors generated by `default`.

  `tf.case` supports nested structures as implemented in
  `tf.nest`. All of the callables must return the same
  (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  a callable, they are implicitly unpacked to single values. This
  behavior is disabled by passing `strict=True`.

  @compatibility(v2)
  `pred_fn_pairs` could be a dictionary in v1. However, tf.Tensor and
  tf.Variable are no longer hashable in v2, so cannot be used as a key for a
  dictionary. Please use a list or a tuple instead.
  @end_compatibility

  **Example 1:**

  Pseudocode:

  ```
  if (x < y) return 17;
  else return 23;
  ```

  Expressions:

  ```python
  f1 = lambda: tf.constant(17)
  f2 = lambda: tf.constant(23)
  r = tf.case([(tf.less(x, y), f1)], default=f2)
  ```

  **Example 2:**

  Pseudocode:

  ```
  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
  if (x < y) return 17;
  else if (x > z) return 23;
  else return -1;
  ```

  Expressions:

  ```python
  def f1(): return tf.constant(17)
  def f2(): return tf.constant(23)
  def f3(): return tf.constant(-1)
  r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)],
              default=f3, exclusive=True)
  ```

  Args:
    pred_fn_pairs: List of pairs of a boolean scalar tensor and a callable
      which returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/tuple.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
      callable.
  """
  return _case_helper(
      cond,
      pred_fn_pairs,
      default,
      exclusive,
      name,
      allow_python_preds=False,
      strict=strict)


@tf_export(v1=["case"])
@dispatch.add_dispatch_support
def case(pred_fn_pairs,
         default=None,
         exclusive=False,
         strict=False,
         name="case"):
  """Create a case operation.

  See also `tf.switch_case`.

  The `pred_fn_pairs` parameter is a dict or list of N pairs.
  Each pair contains a boolean scalar tensor and a python callable that
  creates the tensors to be returned if the boolean evaluates to True.
  `default` is a callable generating a list of tensors. All the callables
  in `pred_fn_pairs` as well as `default` (if provided) should return the same
  number and types of tensors.

  If `exclusive==True`, all predicates are evaluated, and an exception is
  thrown if more than one of the predicates evaluates to `True`.
  If `exclusive==False`, execution stops at the first predicate which
  evaluates to True, and the tensors generated by the corresponding function
  are returned immediately. If none of the predicates evaluate to True, this
  operation returns the tensors generated by `default`.

  `tf.case` supports nested structures as implemented in
  `tf.contrib.framework.nest`. All of the callables must return the same
  (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  a callable, they are implicitly unpacked to single values. This
  behavior is disabled by passing `strict=True`.

  If an unordered dictionary is used for `pred_fn_pairs`, the order of the
  conditional tests is not guaranteed. However, the order is guaranteed to be
  deterministic, so that variables created in conditional branches are created
  in fixed order across runs.

  @compatibility(eager)
  Unordered dictionaries are not supported in eager mode when `exclusive=False`.
  Use a list of tuples instead.
  @end_compatibility

  **Example 1:**

  Pseudocode:

  ```
  if (x < y) return 17;
  else return 23;
  ```

  Expressions:

  ```python
  f1 = lambda: tf.constant(17)
  f2 = lambda: tf.constant(23)
  r = tf.case([(tf.less(x, y), f1)], default=f2)
  ```

  **Example 2:**

  Pseudocode:

  ```
  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
  if (x < y) return 17;
  else if (x > z) return 23;
  else return -1;
  ```

  Expressions:

  ```python
  def f1(): return tf.constant(17)
  def f2(): return tf.constant(23)
  def f3(): return tf.constant(-1)
  r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
              default=f3, exclusive=True)
  ```

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
      callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
      callable.
  """
  return _case_helper(
      cond,
      pred_fn_pairs,
      default,
      exclusive,
      name,
      allow_python_preds=False,
      strict=strict)


@tf_export("switch_case")
def switch_case(branch_index,
                branch_fns,
                default=None,
                name="switch_case"):
  """Create a switch/case operation, i.e. an integer-indexed conditional.

  See also `tf.case`.

  This op can be substantially more efficient than `tf.case` when exactly one
  branch will be selected. `tf.switch_case` is more like a C++ switch/case
  statement than `tf.case`, which is more like an if/elif/elif/else chain.

  The `branch_fns` parameter is either a dict from `int` to callables, or list
  of (`int`, callable) pairs, or simply a list of callables (in which case the
  index is implicitly the key). The `branch_index` `Tensor` is used to select an
  element in `branch_fns` with matching `int` key, falling back to `default`
  if none match, or `max(keys)` if no `default` is provided. The keys must form
  a contiguous set from `0` to `len(branch_fns) - 1`.

  `tf.switch_case` supports nested structures as implemented in `tf.nest`. All
  callables must return the same (possibly nested) value structure of lists,
  tuples, and/or named tuples.

  **Example:**

  Pseudocode:

  ```c++
  switch (branch_index) {  // c-style switch
    case 0: return 17;
    case 1: return 31;
    default: return -1;
  }
  ```
  or
  ```python
  branches = {0: lambda: 17, 1: lambda: 31}
  branches.get(branch_index, lambda: -1)()
  ```

  Expressions:

  ```python
  def f1(): return tf.constant(17)
  def f2(): return tf.constant(31)
  def f3(): return tf.constant(-1)
  r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3)
  # Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3})
  ```

  Args:
    branch_index: An int Tensor specifying which of `branch_fns` should be
      executed.
    branch_fns: A `dict` mapping `int`s to callables, or a `list` of
      (`int`, callable) pairs, or simply a list of callables (in which case the
      index serves as the key). Each callable must return a matching structure
      of tensors.
    default: Optional callable that returns a structure of tensors.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the callable identified by `branch_index`, or those
    returned by `default` if no key matches and `default` was provided, or those
    returned by the max-keyed `branch_fn` if no `default` is provided.

  Raises:
    TypeError: If `branch_fns` is not a list/dictionary.
    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
      callables.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
      callable.
  """
  return _indexed_case_helper(branch_fns, default, branch_index, name)


def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"):
  """Executes one of the provided callables based on the device placement.

  This API is used when the implementation of a high-level function depends on
  the underlying device placement. It takes a dictionary mapping device types
  ("CPU", "GPU", "TPU", etc.) to callables. When the type of the device this
  op runs on matches a key in `device_branch_fns`, the corresponding callable
  is executed, falling back to `default_fn` if none matches.

  **Example:**

  ```python
  def f1(): return tf.constant(1)
  def f2(): return tf.constant(2)
  r = tf.execute_fn_for_device({"CPU": f1, "GPU": f2}, default_fn=f1)
  ```
  `r` evaluates to 1 when running on a CPU, to 2 when running on a GPU, and to
  1 when running on any other device type.

  Args:
    device_branch_fns: a dictionary of device types to the callables. Each
      callable must return a matching structure of tensors.
    default_fn: fallback callable when the underlying device does not match any
      key in the `device_branch_fns`.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the callable identified by device type during
    execution, or those returned by `default_fn` if no key matches.
  """

  device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()}
  branch_fns = list(device_branch_fns_upper.values())
  devices = list(device_branch_fns_upper.keys())
  device_index = gen_functional_ops.device_index(device_names=devices)
  return _indexed_case_helper(
      branch_fns,
      default_fn,
      device_index,
      name,
      lower_using_switch_merge=False)


class XLAControlFlowContext(ControlFlowContext):
  """Base class for XLA and TPU control flow contexts."""

  def __init__(self):
    super(XLAControlFlowContext, self).__init__()
    self._name = "XLAControlFlowContext"

  def to_control_flow_context_def(self, context_def, export_scope=None):
    # pylint: disable=useless-super-delegation
    # NOTE(slebedev): the method is required by `ControlFlowContext`.
    super(XLAControlFlowContext,
          self).to_control_flow_context_def(context_def, export_scope)

  def IsXLAContext(self):
    return True

  def AddOp(self, _):
    pass

  def AddValue(self, x):
    return x

  def RequiresUniqueFunctionRetracing(self):
    """Returns whether the tf.function should be retraced if the context changes.
    """
    return False


def from_control_flow_context_def(context_def, import_scope=None):
  """Deserializes `context_def` into the appropriate ControlFlowContext.

  Args:
    context_def: ControlFlowContextDef proto.
    import_scope: Optional `string`. Name scope to add.

  Returns:
    A ControlFlowContext subclass.
  """
  if context_def.HasField("cond_ctxt"):
    return CondContext.from_proto(
        context_def.cond_ctxt, import_scope=import_scope)
  if context_def.HasField("while_ctxt"):
    return WhileContext.from_proto(
        context_def.while_ctxt, import_scope=import_scope)
  raise NotImplementedError("Unknown ControlFlowContextDef field: %s" %
                            context_def.WhichOneof("ctxt"))


ops.register_proto_function(
    ops.GraphKeys.COND_CONTEXT,
    proto_type=control_flow_pb2.CondContextDef,
    to_proto=CondContext.to_proto,
    from_proto=CondContext.from_proto)

ops.register_proto_function(
    ops.GraphKeys.WHILE_CONTEXT,
    proto_type=control_flow_pb2.WhileContextDef,
    to_proto=WhileContext.to_proto,
    from_proto=WhileContext.from_proto)