STT-tensorflow/tensorflow/python/ops/control_flow_ops_test.py
Yanhua Sun b2f0928940 Add DeviceIndex xla op.
DeviceIndex op: given a list of device names, this operation returns the index of the device this op runs on. In the case of XLA, where we are not executing on any device, it returns the length of the list.

PiperOrigin-RevId: 317740778
Change-Id: I0679aa0adc5508b83502eee0d2044584577ed5b4
2020-06-22 15:06:38 -07:00

1549 lines
58 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for control_flow_ops.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from absl.testing import parameterized
import numpy as np
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.python.training import momentum
from tensorflow.python.util import nest
# Two-field namedtuple; the tests in this file use it to build branch outputs
# that are (possibly nested) namedtuples.
TestTuple = collections.namedtuple("TestTuple", "a b")
# Single-field namedtuple; used to build branch outputs that are a
# one-element namedtuple.
SingletonTestTuple = collections.namedtuple("SingletonTestTuple", "a")
class GroupTestCase(test_util.TensorFlowTestCase):
  """Checks the graph nodes that `control_flow_ops.group` creates."""

  def _StripNode(self, nd):
    """Returns a NodeDef carrying only `nd`'s name, op, input and device."""
    stripped = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
    # The device is copied only when the source node actually has one.
    if nd.device:
      stripped.device = nd.device
    return stripped

  def _StripGraph(self, gd):
    """Copies `gd`, keeping only each node's name, op, input and device."""
    stripped_nodes = [self._StripNode(node) for node in gd.node]
    return graph_pb2.GraphDef(node=stripped_nodes)

  def testGroup_NoDevices(self):
    """Grouping three device-less ops yields a single "root" NoOp."""
    with ops.Graph().as_default() as graph:
      const_a = constant_op.constant(0, name="a")
      const_b = constant_op.constant(0, name="b")
      const_c = constant_op.constant(0, name="c")
      control_flow_ops.group(const_a.op, const_b.op, const_c.op, name="root")
    graph_def = graph.as_graph_def()
    self.assertProtoEquals("""
node { name: "a" op: "Const"}
node { name: "b" op: "Const"}
node { name: "c" op: "Const"}
node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" }
""", self._StripGraph(graph_def))

  def testGroup_OneDevice(self):
    """Inputs on one device: the "root" NoOp is expected on that device."""
    with ops.Graph().as_default() as graph:
      with graph.device("/task:0"):
        const_a = constant_op.constant(0, name="a")
        const_b = constant_op.constant(0, name="b")
      # The group itself is created outside the device scope.
      control_flow_ops.group(const_a.op, const_b.op, name="root")
    graph_def = graph.as_graph_def()
    self.assertProtoEquals("""
node { name: "a" op: "Const" device: "/task:0" }
node { name: "b" op: "Const" device: "/task:0" }
node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }
""", self._StripGraph(graph_def))

  def testGroup_MultiDevice(self):
    """Inputs on two devices: one NoOp per device plus a "root" NoOp."""
    with ops.Graph().as_default() as graph:
      with graph.device("/task:0"):
        const_a = constant_op.constant(0, name="a")
        const_b = constant_op.constant(0, name="b")
      with graph.device("/task:1"):
        const_c = constant_op.constant(0, name="c")
        const_d = constant_op.constant(0, name="d")
      with graph.device("/task:2"):
        control_flow_ops.group(
            const_a.op, const_b.op, const_c.op, const_d.op, name="root")
    graph_def = graph.as_graph_def()
    self.assertProtoEquals("""
node { name: "a" op: "Const" device: "/task:0"}
node { name: "b" op: "Const" device: "/task:0"}
node { name: "c" op: "Const" device: "/task:1"}
node { name: "d" op: "Const" device: "/task:1"}
node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b"
device: "/task:0" }
node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d"
device: "/task:1" }
node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1"
device: "/task:2" }
""", self._StripGraph(graph_def))

  def testPassingList(self):
    """A list of ops is accepted in place of positional op arguments."""
    with ops.Graph().as_default() as graph:
      const_a = constant_op.constant(0, name="a")
      const_b = constant_op.constant(0, name="b")
      control_flow_ops.group([const_a.op, const_b.op], name="root")
    graph_def = graph.as_graph_def()
    self.assertProtoEquals("""
node { name: "a" op: "Const"}
node { name: "b" op: "Const"}
node { name: "root" op: "NoOp" input: "^a" input: "^b" }
""", self._StripGraph(graph_def))

  @test_util.run_deprecated_v1
  def testPassingNonTensors(self):
    """Plain Python integers are rejected with a TypeError."""
    with self.assertRaises(TypeError):
      control_flow_ops.group(1, 2)
class ShapeTestCase(test_util.TensorFlowTestCase):
  """Checks that `with_dependencies` preserves the static shape."""

  def testShape(self):
    """The wrapped tensor reports the same shape as the original one."""
    tensor = constant_op.constant([1.0, 2.0])
    # `assertEqual` replaces the deprecated `assertEquals` alias.
    self.assertEqual([2], tensor.get_shape())
    self.assertEqual([2],
                     control_flow_ops.with_dependencies(
                         [constant_op.constant(1.0)], tensor).get_shape())
class WithDependenciesTestCase(test_util.TensorFlowTestCase):
  """Checks that `with_dependencies` runs its dependencies on evaluation."""

  @test_util.run_deprecated_v1
  def testTupleDependencies(self):
    """Dependencies passed as a tuple run once when the output is evaluated."""
    counter = variable_scope.get_variable(
        "my_counter", shape=[], initializer=init_ops.zeros_initializer())
    increment_counter = state_ops.assign_add(counter, 1)
    const_with_dep = control_flow_ops.with_dependencies(
        (increment_counter, constant_op.constant(42)),
        constant_op.constant(7))

    self.evaluate(variables.global_variables_initializer())
    # The counter only changes as a side effect of evaluating the output.
    self.assertEqual(0, self.evaluate(counter))
    self.assertEqual(7, self.evaluate(const_with_dep))
    self.assertEqual(1, self.evaluate(counter))

  @test_util.run_deprecated_v1
  def testListDependencies(self):
    """Dependencies passed as a list run once when the output is evaluated."""
    counter = variable_scope.get_variable(
        "my_counter", shape=[], initializer=init_ops.zeros_initializer())
    increment_counter = state_ops.assign_add(counter, 1)
    const_with_dep = control_flow_ops.with_dependencies(
        [increment_counter, constant_op.constant(42)],
        constant_op.constant(7))

    self.evaluate(variables.global_variables_initializer())
    # The counter only changes as a side effect of evaluating the output.
    self.assertEqual(0, self.evaluate(counter))
    self.assertEqual(7, self.evaluate(const_with_dep))
    self.assertEqual(1, self.evaluate(counter))
class SwitchTestCase(test_util.TensorFlowTestCase):
  """Tests for `switch` and for gradients flowing through loops and conds."""

  @test_util.run_deprecated_v1
  def testIndexedSlicesWithDenseShape(self):
    """`switch` on IndexedSlices forwards values and indices unchanged."""
    with self.cached_session():
      data = ops.IndexedSlices(
          constant_op.constant([1, 2, 3]),
          constant_op.constant([0, 1, 2]),
          dense_shape=constant_op.constant([3]))
      zero = constant_op.constant(0)
      one = constant_op.constant(1)
      less_op = math_ops.less(zero, one)
      _, switch_true = control_flow_ops.switch(data, less_op)
      self.assertAllEqual([1, 2, 3], switch_true.values.eval())
      self.assertAllEqual([0, 1, 2], switch_true.indices.eval())

  @test_util.run_deprecated_v1
  def testIndexedSlicesGradient(self):
    """Ten optimizer steps run on a loss built by a loop of lookups."""
    embedding_matrix = variable_scope.get_variable(
        "embedding_matrix", [5, 5],
        initializer=init_ops.random_normal_initializer())

    def cond(it, _):
      return it < 5

    def body(it, cost):
      embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0])
      cost += math_ops.reduce_sum(embedding)
      return it + 1, cost

    _, cost = control_flow_ops.while_loop(
        cond, body, [constant_op.constant(0),
                     constant_op.constant(0.0)])
    optimizer = momentum.MomentumOptimizer(0.1, 0.9)
    train_op = optimizer.minimize(cost)
    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      for _ in range(10):
        self.evaluate([train_op])

  def testResourceReadInLoop(self):
    """A resource variable can be read by a lookup inside a while loop."""
    embedding_matrix = variable_scope.get_variable(
        "embedding_matrix", initializer=[[2.0], [3.0]], use_resource=True)

    def cond(it, _):
      return it < 5

    def body(it, cost):
      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      cost += math_ops.reduce_sum(embedding)
      return it + 1, cost

    _, cost = control_flow_ops.while_loop(
        cond, body, [constant_op.constant(0),
                     constant_op.constant(0.0)])
    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      # Five iterations, each adding row 0 of the matrix (2.0).
      self.assertAllEqual(10.0, self.evaluate(cost))

  def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False):
    """Compares loop-with-cond gradients against an unrolled expression.

    Args:
      use_resource: forwarded to `get_variable` when creating the embedding
        matrix.
    """
    embedding_matrix = variable_scope.get_variable(
        "embedding_matrix", [5, 5],
        initializer=init_ops.random_normal_initializer(),
        use_resource=use_resource)

    def cond(it, _):
      return it < 5

    def body(it, cost):
      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      cost = control_flow_ops.cond(
          math_ops.equal(it, 3), lambda: math_ops.square(cost),
          (lambda: cost + math_ops.reduce_sum(embedding)))
      return it + 1, cost

    _, cost = control_flow_ops.while_loop(
        cond, body, [constant_op.constant(0),
                     constant_op.constant(0.0)])

    dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0]
    dynamic_grads = math_ops.segment_sum(dynamic_grads.values,
                                         dynamic_grads.indices)

    # Hand-unrolled equivalent of the loop: three additions, a square at
    # iteration 3, then one more addition.
    embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
    static = math_ops.square(
        math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) +
        math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding)
    static_grads = gradients_impl.gradients(static, [embedding_matrix])[0]
    static_grads = math_ops.segment_sum(static_grads.values,
                                        static_grads.indices)

    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(*self.evaluate([static_grads, dynamic_grads]))

  def testIndexedSlicesGradientInCondInWhileLoop(self):
    """Runs the cond-in-loop gradient check with `use_resource=False`."""
    self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=False)

  def testIndexedSlicesGradientInCondInWhileLoopResource(self):
    """Runs the cond-in-loop gradient check with `use_resource=True`."""
    self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=True)

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlicesWithShapeGradientInWhileLoop(self):
    """Gradient of a gather-in-loop w.r.t. a fixed-shape placeholder."""
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.cached_session() as sess:
        num_steps = 9

        inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps])
        initial_outputs = tensor_array_ops.TensorArray(
            dtype=dtype, size=num_steps)
        initial_i = constant_op.constant(0, dtype=dtypes.int32)

        def cond(i, _):
          return i < num_steps  # pylint: disable=cell-var-from-loop

        def body(i, outputs):
          x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = control_flow_ops.while_loop(cond, body,
                                                 [initial_i, initial_outputs])

        outputs = math_ops.reduce_sum(outputs.stack())
        r = gradients_impl.gradients([outputs], [inputs])[0]
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]})
        # `assertEqual` replaces the deprecated `assertEquals` alias.
        self.assertEqual(o, 20)
        self.assertAllEqual(grad, [1] * num_steps)

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
    """Gradient of a gather-in-loop w.r.t. a placeholder of unknown shape."""
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.cached_session() as sess:
        inputs = array_ops.placeholder(dtype=dtype)
        initial_outputs = tensor_array_ops.TensorArray(
            dtype=dtype, dynamic_size=True, size=1)
        initial_i = constant_op.constant(0, dtype=dtypes.int32)

        def cond(i, _):
          return i < array_ops.size(inputs)  # pylint: disable=cell-var-from-loop

        def body(i, outputs):
          x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = control_flow_ops.while_loop(cond, body,
                                                 [initial_i, initial_outputs])

        outputs = math_ops.reduce_sum(outputs.stack())
        r = gradients_impl.gradients([outputs], [inputs])[0]
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [1, 3, 2]})
        self.assertEqual(o, 6)
        self.assertAllEqual(grad, [1] * 3)

  @test_util.run_deprecated_v1
  def testGradientThroughSingleBranchOutsideOfContext(self):
    """With a True predicate, only the true output has a nonzero gradient."""
    x = constant_op.constant(2.)
    s = constant_op.constant(True)
    x_false, x_true = control_flow_ops.switch(x, s)
    grad_x_true = gradients_impl.gradients(x_true, x)[0]
    grad_x_false = gradients_impl.gradients(x_false, x)[0]
    self.assertEqual(self.evaluate(grad_x_true), 1.)
    self.assertEqual(self.evaluate(grad_x_false), 0.)
class CondTest(test_util.TensorFlowTestCase):
  """Tests for `control_flow_ops.cond` argument handling and results."""

  def testCondTrue(self):
    """2 < 5, so the true branch (2 * 17) is returned."""
    x = constant_op.constant(2)
    y = constant_op.constant(5)
    z = control_flow_ops.cond(
        math_ops.less(x, y),
        lambda: math_ops.multiply(x, 17),
        lambda: math_ops.add(y, 23))
    # `assertEqual` replaces the deprecated `assertEquals` alias.
    self.assertEqual(self.evaluate(z), 34)

  def testCondFalse(self):
    """2 < 1 is false, so the false branch (1 + 23) is returned."""
    x = constant_op.constant(2)
    y = constant_op.constant(1)
    z = control_flow_ops.cond(
        math_ops.less(x, y),
        lambda: math_ops.multiply(x, 17),
        lambda: math_ops.add(y, 23))
    self.assertEqual(self.evaluate(z), 24)

  def testCondTrueLegacy(self):
    """Same as testCondTrue, passing branches via `fn1`/`fn2` keywords."""
    x = constant_op.constant(2)
    y = constant_op.constant(5)
    z = control_flow_ops.cond(
        math_ops.less(x, y),
        fn1=lambda: math_ops.multiply(x, 17),
        fn2=lambda: math_ops.add(y, 23))
    self.assertEqual(self.evaluate(z), 34)

  def testCondFalseLegacy(self):
    """Same as testCondFalse, passing branches via `fn1`/`fn2` keywords."""
    x = constant_op.constant(2)
    y = constant_op.constant(1)
    z = control_flow_ops.cond(
        math_ops.less(x, y),
        fn1=lambda: math_ops.multiply(x, 17),
        fn2=lambda: math_ops.add(y, 23))
    self.assertEqual(self.evaluate(z), 24)

  @test_util.run_v1_only("Exercises Ref variables")
  def testCondModifyBoolPred(self):
    """The true branch may assign to the variable used as the predicate."""
    # We want to use the GPU here because we want to ensure that we can update
    # a boolean ref variable on the GPU.
    with test_util.use_gpu():
      bool_var = variable_scope.get_variable(
          "bool_var", dtype=dtypes.bool, initializer=True)
      cond_on_bool_var = control_flow_ops.cond(
          pred=bool_var,
          true_fn=lambda: state_ops.assign(bool_var, False),
          false_fn=lambda: True)
      self.evaluate(bool_var.initializer)
      # First run takes the true branch and flips the variable to False;
      # the second run therefore takes the false branch.
      self.assertEqual(self.evaluate(cond_on_bool_var), False)
      self.assertEqual(self.evaluate(cond_on_bool_var), True)

  def testCondMissingArg1(self):
    """Omitting the true branch raises TypeError."""
    x = constant_op.constant(1)
    with self.assertRaises(TypeError):
      control_flow_ops.cond(True, false_fn=lambda: x)

  def testCondMissingArg2(self):
    """Omitting the false branch raises TypeError."""
    x = constant_op.constant(1)
    with self.assertRaises(TypeError):
      control_flow_ops.cond(True, lambda: x)

  def testCondDuplicateArg1(self):
    """Passing both a positional true branch and `fn1` raises TypeError."""
    x = constant_op.constant(1)
    with self.assertRaises(TypeError):
      control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)

  def testCondDuplicateArg2(self):
    """Passing both a positional false branch and `fn2` raises TypeError."""
    x = constant_op.constant(1)
    with self.assertRaises(TypeError):
      control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)

  @test_util.enable_control_flow_v2
  @test_util.run_in_graph_and_eager_modes
  def testCond_gradient(self):
    """Gradients flow only to the input of the branch that was taken."""
    true_in, false_in = array_ops.constant(1.), array_ops.constant(5.)
    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(true_in)
      tape.watch(false_in)
      cond_true = control_flow_ops.cond(
          array_ops.constant(True), lambda: true_in**2., lambda: false_in**2.)
      cond_false = control_flow_ops.cond(
          array_ops.constant(False), lambda: true_in**2., lambda: false_in**2.)
    grads_true = tape.gradient(
        cond_true, [true_in, false_in], output_gradients=3.)
    grads_false = tape.gradient(
        cond_false, [true_in, false_in], output_gradients=3.)
    self.assertEqual(3. * 2. * 1., self.evaluate(grads_true[0]))
    # For the branch not taken, the expected gradient is None when executing
    # eagerly and 0 otherwise.
    self.assertEqual(None if context.executing_eagerly() else 0.,
                     self.evaluate(grads_true[1]))
    self.assertEqual(3. * 2. * 5., self.evaluate(grads_false[1]))
    self.assertEqual(None if context.executing_eagerly() else 0.,
                     self.evaluate(grads_false[0]))

  def testCondWithGroupAndSummaries(self):
    """A cond whose branches are a summary group and a no_op evaluates."""
    with ops.Graph().as_default():
      writer = summary_ops_v2.create_file_writer(self.get_temp_dir())
      with writer.as_default(), summary_ops_v2.always_record_summaries():
        op = control_flow_ops.cond(
            constant_op.constant(1) >= 0,
            lambda: control_flow_ops.group(summary_ops_v2.scalar("loss", 0.2)),
            control_flow_ops.no_op)
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(summary_ops_v2.summary_writer_initializer_op())
        self.assertEqual(self.evaluate(op), True)
class ContextTest(test_util.TensorFlowTestCase):
  """Round-trips control flow contexts through their proto form."""

  @test_util.run_deprecated_v1
  def testCondContext(self):
    """Every CondContext in the graph survives to_proto/from_proto."""
    with self.cached_session() as sess:
      x = constant_op.constant(2)
      y = constant_op.constant(5)
      control_flow_ops.cond(
          math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
          lambda: math_ops.add(y, 23))
      for op in sess.graph.get_operations():
        c = op._get_control_flow_context()
        # Ops created outside the cond have no context; skip them.
        if c:
          self.assertProtoEquals(
              c.to_proto(),
              control_flow_ops.CondContext.from_proto(c.to_proto()).to_proto())

  def _testWhileContextHelper(self, maximum_iterations=None):
    """Builds a counting while loop and round-trips each WhileContext.

    Args:
      maximum_iterations: forwarded to `while_loop`.
    """
    with self.cached_session() as sess:
      i = constant_op.constant(0)
      c = lambda i: math_ops.less(i, 10)
      b = lambda i: math_ops.add(i, 1)
      control_flow_ops.while_loop(
          c, b, [i], maximum_iterations=maximum_iterations)
      for op in sess.graph.get_operations():
        control_flow_context = op._get_control_flow_context()
        if control_flow_context:
          self.assertProtoEquals(
              control_flow_context.to_proto(),
              control_flow_ops.WhileContext.from_proto(
                  control_flow_context.to_proto()).to_proto())

  @test_util.run_deprecated_v1
  def testWhileContext(self):
    """Round-trips the while context without `maximum_iterations`."""
    self._testWhileContextHelper()

  @test_util.run_deprecated_v1
  def testWhileContextWithMaximumIterations(self):
    """Round-trips the while context with `maximum_iterations=10`."""
    self._testWhileContextHelper(maximum_iterations=10)

  @test_util.run_deprecated_v1
  def testControlContextImportScope(self):
    """`import_scope` prefixes value names; `export_scope` strips them."""

    class NoABCControlFlowContext(control_flow_ops.ControlFlowContext):
      """A noop wrapper around `ControlFlowContext`.

      `ControlFlowContext` is an ABC and therefore cannot be instantiated.
      """

      # pylint: disable=useless-super-delegation

      def to_control_flow_context_def(self, context_def, export_scope=None):
        super(NoABCControlFlowContext, self).to_control_flow_context_def(
            context_def, export_scope)

    with self.cached_session():
      constant_op.constant(0, name="a")
      constant_op.constant(2, name="test_scope/a")
      b1 = constant_op.constant(1, name="b")
      b2 = constant_op.constant(3, name="test_scope/b")

      c = NoABCControlFlowContext()
      c._values = ["a", "b"]
      c._external_values = {"a": b1}

      c_with_scope = NoABCControlFlowContext(
          values_def=c._to_values_def(), import_scope="test_scope")

      # _values and _external_values should have the scope prepended.
      # `assertEqual` replaces the deprecated `assertEquals` alias.
      self.assertEqual(
          c_with_scope._values, {"test_scope/a", "test_scope/b"})
      self.assertEqual(
          c_with_scope._external_values, {"test_scope/a": b2})

      # Calling _to_proto() with export_scope should remove "test_scope".
      self.assertProtoEquals(
          c._to_values_def(),
          c_with_scope._to_values_def(export_scope="test_scope"))
def _get_nested_shape(nested):
  """Maps every leaf of `nested` to a description of its shape.

  Args:
    nested: a (possibly nested) structure accepted by `nest.map_structure`.

  Returns:
    A structure shaped like `nested`. A `TensorArray` leaf becomes the
    `TensorArray` class itself, an `IndexedSlices` leaf becomes its
    `dense_shape`, and any other leaf becomes the result of `get_shape()`.
  """

  def _leaf_shape(leaf):
    if isinstance(leaf, tensor_array_ops.TensorArray):
      return tensor_array_ops.TensorArray
    if isinstance(leaf, ops.IndexedSlices):
      return leaf.dense_shape
    return leaf.get_shape()

  return nest.map_structure(_leaf_shape, nested)
def _create_tensor_array(size, shape):
  """Returns a float32 TensorArray of `size` zero tensors of shape `shape`.

  The array is created with `clear_after_read=False`.
  """
  array = tensor_array_ops.TensorArray(
      dtype=dtypes.float32, size=size, clear_after_read=False)
  for index in range(size):
    # `write` returns a new TensorArray object, so rebind on every step.
    array = array.write(index, array_ops.zeros(shape))
  return array
def _raw_nested_shape(nested_shape):
  """Converts each shape in `nested_shape` to a plain list of dimensions.

  Args:
    nested_shape: a (possibly nested) structure of shape descriptions.

  Returns:
    A structure shaped like `nested_shape`. A `TensorShape` of known rank
    becomes a list of its dimension values; every other leaf becomes None.
  """

  def _to_raw(shape):
    has_known_rank = (isinstance(shape, tensor_shape.TensorShape) and
                      shape.ndims is not None)
    return [dim.value for dim in shape.dims] if has_known_rank else None

  return nest.map_structure(_to_raw, nested_shape)
# TODO(yori): Add tests for indexed slices.
class DataTypesTest(test_util.TensorFlowTestCase):
  """Checks shapes and values of `cond`/`case` for many output structures."""

  def assertAllEqualNested(self, a, b):
    """Asserts that `a` and `b` are equal, recursing into lists and tuples.

    Args:
      a: a value, or a list/tuple (possibly nested) of values.
      b: the structure to compare against.
    """
    if isinstance(a, (list, tuple)):
      # `zip` silently stops at the shorter sequence, so a result with missing
      # or extra entries would otherwise pass; compare the lengths first.
      self.assertEqual(len(a), len(b))
      for entry_a, entry_b in zip(a, b):
        self.assertAllEqualNested(entry_a, entry_b)
    else:
      self.assertAllEqual(a, b)

  def _testShape(self, fn_true, fn_false, expected_shape,
                 strict=False):
    """Asserts the static output shapes of `cond` and `case`.

    Both ops are built on a boolean placeholder predicate.

    Args:
      fn_true: callable building the true branch.
      fn_false: callable building the false (or default) branch.
      expected_shape: structure of expected shapes for the outputs.
      strict: forwarded to `cond` and `case`.
    """
    condition = array_ops.placeholder(dtypes.bool)
    output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
                                        strict=strict)
    self.assertEqual(
        _raw_nested_shape(_get_nested_shape(output_cond)),
        _raw_nested_shape(expected_shape))

    output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
                                        strict=strict)
    self.assertEqual(
        _raw_nested_shape(_get_nested_shape(output_case)),
        _raw_nested_shape(expected_shape))

  def _testReturnValues(self, fn_true, fn_false, expected_value_true,
                        expected_value_false, strict=False,
                        check_cond=True, feed_dict=None):
    """Runs `cond` and `case` with the predicate fed True, then False.

    Args:
      fn_true: callable building the true branch.
      fn_false: callable building the false (or default) branch.
      expected_value_true: expected result when the predicate is True.
      expected_value_false: expected result when the predicate is False.
      strict: forwarded to `cond` and `case`.
      check_cond: when true, the `case` result is compared as well; the
        `cond` result is always compared.
      feed_dict: extra feeds merged into both runs.
    """
    if feed_dict is None:
      feed_dict = {}

    condition = array_ops.placeholder(dtypes.bool)
    output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
                                        strict=strict)
    output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
                                        strict=strict)

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      true_feed_dict = {condition: True}
      true_feed_dict.update(feed_dict)
      result_cond, result_case = sess.run([output_cond, output_case],
                                          feed_dict=true_feed_dict)
      self.assertAllEqualNested(result_cond, expected_value_true)
      if check_cond:
        self.assertAllEqualNested(result_case, expected_value_true)
      false_feed_dict = {condition: False}
      false_feed_dict.update(feed_dict)
      result_cond, result_case = sess.run([output_cond, output_case],
                                          feed_dict=false_feed_dict)
      self.assertAllEqualNested(result_cond, expected_value_false)
      if check_cond:
        self.assertAllEqualNested(result_case, expected_value_false)

  @test_util.run_deprecated_v1
  def test_int(self):
    """Branches returning Python ints, in non-strict and strict mode."""
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: 1
    fn_false = lambda: 2
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 2)
    self._testShape(fn_true, fn_false, shape, strict=True)
    self._testReturnValues(fn_true, fn_false, 1, 2, strict=True)

  @test_util.run_deprecated_v1
  def test_float(self):
    """Branches returning Python floats."""
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: 1.0
    fn_false = lambda: 2.0
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1.0, 2.0)

  @test_util.run_deprecated_v1
  def test_noop(self):
    """Both branches are `no_op`; the `case` value check is skipped."""
    shape = tensor_shape.TensorShape(None)
    self._testShape(control_flow_ops.no_op, control_flow_ops.no_op, shape)
    self._testReturnValues(control_flow_ops.no_op, control_flow_ops.no_op,
                           True, False, check_cond=False)

  @test_util.run_deprecated_v1
  def test_string(self):
    """Branches returning Python strings evaluate to bytes."""
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: "abc"
    fn_false = lambda: "xyz"
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, b"abc", b"xyz")

  @test_util.run_v1_only("b/138741991")
  def test_variable(self):
    """Branches that each create and return a variable."""
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: variables.Variable(3.0)
    fn_false = lambda: variables.Variable(4.0)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 3.0, 4.0)

  @test_util.run_v1_only("b/120553181")
  def test_none(self):
    """A None-returning branch paired with a tensor branch is rejected."""
    fn_none = lambda: None
    fn_tensor = lambda: constant_op.constant(1)

    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_none, fn_tensor)

    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_none)

  @test_util.run_deprecated_v1
  def test_tensors(self):
    """Branches returning a pair of dense tensors, for several dtypes."""

    def _build_true_branch(dtype):

      def _build():
        return (array_ops.zeros([2, 2], dtype=dtype),
                array_ops.ones([3, 3], dtype=dtype))

      return _build

    def _build_false_branch(dtype):

      def _build():
        return (array_ops.ones([2, 2], dtype=dtype),
                array_ops.zeros([3, 3], dtype=dtype))

      return _build

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = (tensor_shape.TensorShape([2, 2]),
               tensor_shape.TensorShape([3, 3]))
      fn_true = _build_true_branch(dtype)
      fn_false = _build_false_branch(dtype)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             (np.zeros([2, 2]), np.ones([3, 3])),
                             (np.ones([2, 2]), np.zeros([3, 3])))

  @test_util.run_deprecated_v1
  def test_tensors_unknown_shape(self):
    """Branches returning placeholders created with `shape=None`."""

    def _build_true_branch(dtype):
      tensor = array_ops.placeholder(dtype=dtype, shape=None)

      def _build():
        return tensor

      return _build, tensor

    def _build_false_branch(dtype):
      tensor = array_ops.placeholder(dtype=dtype, shape=None)

      def _build():
        return tensor

      return _build, tensor

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = tensor_shape.TensorShape(None)
      fn_true, true_tensor = _build_true_branch(dtype)
      fn_false, false_tensor = _build_false_branch(dtype)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             np.zeros([2, 2]), np.ones([2, 2]),
                             feed_dict={true_tensor: np.zeros([2, 2]),
                                        false_tensor: np.ones([2, 2])})

  @test_util.run_deprecated_v1
  def test_sparse_tensors(self):
    """Branches returning a one-element list holding a SparseTensor."""
    shape = tensor_shape.TensorShape([None, None])

    def true_fn():
      return [sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]],
                                         values=[1, 2], dense_shape=[3, 4])]

    def false_fn():
      return [sparse_tensor.SparseTensor(indices=[[0, 0], [2, 1]],
                                         values=[3, 4], dense_shape=[3, 4])]

    value1 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 2]],
                                             values=[1, 2], dense_shape=[3, 4])
    value2 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [2, 1]],
                                             values=[3, 4], dense_shape=[3, 4])
    # Non-strict cond is only available in v1
    if not tf2.enabled():
      self._testShape(true_fn, false_fn, shape)
      self._testReturnValues(true_fn, false_fn, value1, value2)
    self._testShape(true_fn, false_fn, [shape], strict=True)
    self._testReturnValues(true_fn, false_fn, [value1], [value2], strict=True)

  @test_util.run_deprecated_v1
  def test_tensors_with_partially_specified_shapes(self):
    """Branches returning placeholders with partially known shapes."""

    def _build_branch(dtype, shape):
      a = array_ops.placeholder(dtype=dtype, shape=shape[0])
      b = array_ops.placeholder(dtype=dtype, shape=shape[1])
      c = array_ops.placeholder(dtype=dtype, shape=shape[2])

      def _build():
        return a, b, c

      return _build, (a, b, c)

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = (tensor_shape.TensorShape([None, 2]),
               tensor_shape.TensorShape([None]),
               tensor_shape.TensorShape([3, None]))
      fn_true, true_tensors = _build_branch(dtype, shape)
      fn_false, false_tensors = _build_branch(dtype, shape)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
                             (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
                             feed_dict={true_tensors[0]: np.zeros([2, 2]),
                                        false_tensors[0]: np.zeros([2, 2]),
                                        true_tensors[1]: np.zeros([5]),
                                        false_tensors[1]: np.zeros([5]),
                                        true_tensors[2]: np.ones([3, 3]),
                                        false_tensors[2]: np.ones([3, 3])})

  @test_util.run_deprecated_v1
  def test_tensor_arrays(self):
    """Branches returning TensorArrays; only the shape check is run."""
    element_shape = tensor_shape.TensorShape([2])
    ta1 = _create_tensor_array(4, element_shape)
    ta2 = _create_tensor_array(4, element_shape)
    shape = tensor_array_ops.TensorArray
    fn_true = lambda: ta1
    fn_false = lambda: ta2
    self._testShape(fn_true, fn_false, shape)

  @test_util.run_deprecated_v1
  def test_tensor_array_reads(self):
    """Branches returning elements read from one shared TensorArray."""
    shape = tensor_shape.TensorShape([2])
    ta = _create_tensor_array(4, shape)
    fn_true = lambda: ta.read(0)
    fn_false = lambda: ta.read(1)
    self._testShape(fn_true, fn_false, shape)

  @test_util.run_v1_only("b/138741991")
  def test_list(self):
    """Branches returning a list mixing a tensor, an int and a variable."""
    shape = [tensor_shape.TensorShape([]), tensor_shape.TensorShape([]),
             tensor_shape.TensorShape([])]
    fn_true = lambda: [constant_op.constant(1), 2, variables.Variable(3.0)]
    fn_false = lambda: [constant_op.constant(3), 4, variables.Variable(5.0)]
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, [1, 2, 3.0], [3, 4, 5.0])

  @test_util.run_v1_only("Non-strict cond is only available in v1")
  def test_non_strict(self):
    """Non-strict mode: bare tensor, 1-list and 1-tuple are interchangeable."""
    shape = tensor_shape.TensorShape([])
    fn_tensor = lambda: constant_op.constant(1)
    fn_list = lambda: [constant_op.constant(2)]
    fn_tuple = lambda: (constant_op.constant(3),)
    self._testShape(fn_tensor, fn_list, shape)
    self._testShape(fn_tensor, fn_tuple, shape)
    self._testShape(fn_list, fn_tuple, shape)
    self._testReturnValues(fn_tensor, fn_list, 1, 2)
    self._testReturnValues(fn_tensor, fn_tuple, 1, 3)
    self._testReturnValues(fn_list, fn_tuple, 2, 3)

  @test_util.run_v1_only("b/120553181")
  def test_singleton_strict(self):
    """Strict mode rejects branches whose singleton structures differ."""
    fn_tensor = lambda: constant_op.constant(1)
    fn_list = lambda: [constant_op.constant(2)]
    fn_tuple = lambda: (constant_op.constant(3),)

    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_list,
                            strict=True)

    with self.assertRaises(TypeError):
      control_flow_ops.cond(constant_op.constant(True), fn_list, fn_tuple,
                            strict=True)

    with self.assertRaises(ValueError):
      control_flow_ops.case([(constant_op.constant(True), fn_tensor)], fn_list,
                            strict=True)

    with self.assertRaises(TypeError):
      control_flow_ops.case([(constant_op.constant(True), fn_list)], fn_tuple,
                            strict=True)

  @test_util.run_deprecated_v1
  def test_singleton_list(self):
    """Branches returning a one-element list."""
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: [constant_op.constant(1)]
    fn_false = lambda: [constant_op.constant(3)]
    # Non-strict cond is only available in v1
    if not tf2.enabled():
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, [shape], strict=True)
    self._testReturnValues(fn_true, fn_false, [1], [3], strict=True)

  @test_util.run_deprecated_v1
  def test_singleton_tuple(self):
    """Branches returning a one-element tuple."""
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: (constant_op.constant(1),)
    fn_false = lambda: (constant_op.constant(3),)
    # Non-strict cond is only available in v1
    if not tf2.enabled():
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, (shape,), strict=True)
    self._testReturnValues(fn_true, fn_false, (1,), (3,),
                           strict=True)

  @test_util.run_deprecated_v1
  def test_singleton_namedtuple(self):
    """Branches returning a one-field namedtuple."""
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: SingletonTestTuple(constant_op.constant(1))
    fn_false = lambda: SingletonTestTuple(constant_op.constant(3))
    # Non-strict cond is only available in v1
    if not tf2.enabled():
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, SingletonTestTuple(shape),
                    strict=True)
    self._testReturnValues(fn_true, fn_false, SingletonTestTuple(1),
                           SingletonTestTuple(3), strict=True)

  @test_util.run_deprecated_v1
  def test_tuple(self):
    """Branches returning a two-element tuple."""
    shape = (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
    fn_true = lambda: (constant_op.constant(1), 2)
    fn_false = lambda: (constant_op.constant(3), 4)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, (1, 2), (3, 4))

  @test_util.run_deprecated_v1
  def test_namedtuple(self):
    """Branches returning a two-field namedtuple."""
    shape = TestTuple(tensor_shape.TensorShape([]),
                      tensor_shape.TensorShape([]))
    fn_true = lambda: TestTuple(constant_op.constant(1), 2)
    fn_false = lambda: TestTuple(constant_op.constant(3), 4)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, TestTuple(1, 2), TestTuple(3, 4))

  @test_util.run_deprecated_v1
  def test_nested(self):
    """Branches returning a list containing a namedtuple with a nested list."""
    shape = [tensor_shape.TensorShape([]),
             TestTuple(tensor_shape.TensorShape([]),
                       [tensor_shape.TensorShape([]),
                        tensor_shape.TensorShape([])]),
             tensor_shape.TensorShape([5, 5]),
             tensor_shape.TensorShape([])]

    def true_fn():
      return [constant_op.constant(1),
              TestTuple(constant_op.constant(2), [3, 4]),
              array_ops.zeros([5, 5]), 6]

    def false_fn():
      return [constant_op.constant(11),
              TestTuple(constant_op.constant(12), [13, 14]),
              array_ops.ones([5, 5]), 16]

    self._testShape(true_fn, false_fn, shape)
    self._testReturnValues(
        true_fn, false_fn,
        [1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6],
        [11, TestTuple(12, [13, 14]),
         np.ones([5, 5]), 16])

  @test_util.run_deprecated_v1
  def test_cond_inside_while_loop(self):
    """A cond returning a namedtuple inside a loop keeps static shapes."""

    def body(i, matrix):
      result_tuple, unused_matrix = control_flow_ops.cond(
          constant_op.constant(True),
          lambda: (TestTuple(matrix * 2, matrix * 4), matrix),
          lambda: (TestTuple(matrix * 4, matrix * 2), matrix))
      return [i+1, result_tuple.a]

    iteration, matrix = control_flow_ops.while_loop(
        lambda i, matrix: i < 10,
        body,
        loop_vars=[constant_op.constant(0),
                   array_ops.ones([2, 2])])

    self.assertEqual(iteration.get_shape(), tensor_shape.TensorShape([]))
    self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2]))
@test_util.run_all_in_graph_and_eager_modes
class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for `control_flow_ops.switch_case`."""

  def make_name(self):
    """Derives a name from the last component of this test's id.

    "(" becomes "_", ")" is dropped and spaces become "_". The tests below pass
    the result as the `name` argument of `switch_case`.
    """
    name = self.id().split(".")[-1].replace("(", "_").replace(")", "")
    return name.replace(" ", "_")

  def disabled_testCase_ticklesGpuVsHostMemoryIssueWithInt32(self):
    """Runs switch_case with integer-valued branches for every branch index.

    The method name does not start with `test`, so the default unittest loader
    does not collect it.
    """
    nbranches = 5

    def make_func(bi):
      return lambda: array_ops.constant(bi * 10, name="br{}_out".format(bi))

    branches = [(i, make_func(i)) for i in range(nbranches)]
    for bi in range(nbranches):
      branch_index = array_ops.placeholder_with_default(bi, [])
      case_out = control_flow_ops.switch_case(branch_index, branches)
      self.assertEqual(bi * 10, self.evaluate(case_out))

  @parameterized.parameters((0,), (2,), (3,))
  def testCase(self, bi):
    """Selects an in-range branch and checks that it yields `bi * 10.`."""
    nbranches = 5

    def make_func(bi):
      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))

    branches = [(i, make_func(i)) for i in range(nbranches)]
    branch_index = array_ops.placeholder_with_default(bi, [])
    case_out = control_flow_ops.switch_case(
        branch_index, branches, name=self.make_name())
    self.assertEqual(bi * 10., self.evaluate(case_out))

  @parameterized.parameters((-1,), (2,), (4,), (5,), (6,))
  def testCase_withDefault(self, bi):
    """Checks that out-of-range indices produce the `default` branch output."""
    nbranches = 5

    def make_func(bi):
      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))

    branches = [(i, make_func(i)) for i in range(nbranches)]
    branch_index = array_ops.placeholder_with_default(bi, [])
    case_out = control_flow_ops.switch_case(
        branch_index, branches, default=make_func(6), name=self.make_name())
    if bi < 0 or bi >= nbranches:
      # The default branch was built with make_func(6), hence 60.
      expected = 60.
    else:
      expected = bi * 10.
    self.assertEqual(expected, self.evaluate(case_out))

  @parameterized.parameters((-1,), (0,), (3,), (5,))
  def testCase_dictWithDefault(self, bi):
    """Same as testCase_withDefault, with the branches passed as a dict."""
    nbranches = 5

    def make_func(bi):
      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))

    branches = {i: make_func(i) for i in range(nbranches)}
    branch_index = array_ops.placeholder_with_default(bi, [])
    case_out = control_flow_ops.switch_case(
        branch_index, branches, default=make_func(6), name=self.make_name())
    if bi < 0 or bi >= nbranches:
      # The default branch was built with make_func(6), hence 60.
      expected = 60.
    else:
      expected = bi * 10.
    self.assertEqual(expected, self.evaluate(case_out))

  @parameterized.parameters((-1,), (1,), (4,), (5,))
  def testCase_gradient_disable_lowering(self, bi):
    """Runs the gradient check with the lowering-disable flag set to True."""
    self._testCase_gradient(True, bi)

  @parameterized.parameters((-1,), (1,), (4,), (5,))
  def testCase_gradient_enable_lowering(self, bi):
    """Runs the gradient check with the lowering-disable flag set to False."""
    self._testCase_gradient(False, bi)

  def _testCase_gradient(self, disable_lowering, bi):
    """Checks switch_case gradients with respect to every branch input.

    Branch `i` squares `inputs[i]`, whose value is `float(i)`. No `default` is
    supplied, and the expectation encoded below is that an out-of-range index
    uses the last branch.

    Args:
      disable_lowering: Value assigned to
        `control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE` while the
        check runs.
      bi: Python int used as the default value of the branch index placeholder.
    """
    default_lowering = control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE
    control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = disable_lowering
    try:
      nbranches = 5
      # The comprehension variables are deliberately not called `bi`, so the
      # `bi` parameter used for the placeholder below is never shadowed.
      inputs = [
          array_ops.constant(float(i), name="br{}_in".format(i))
          for i in range(nbranches)
      ]

      def make_func(i):
        return lambda: inputs[i]**2.

      branches = {i: make_func(i) for i in range(nbranches)}

      branch_index = array_ops.placeholder_with_default(bi, [])
      with backprop.GradientTape() as tape:
        for x in inputs:
          tape.watch(x)
        case_out = control_flow_ops.switch_case(branch_index, branches)
      out_grad = 3.
      actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)

      # Inputs of branches that did not run are expected to get no gradient in
      # eager mode and a zero gradient in graph mode.
      unused_grad = None if context.executing_eagerly() else 0.
      expected_grads = [unused_grad] * nbranches
      used_branch_idx = nbranches - 1 if bi < 0 or bi >= nbranches - 1 else bi
      # d(x**2)/dx == 2 * x, and the input of branch i has value float(i).
      expected_grads[used_branch_idx] = out_grad * 2. * used_branch_idx
      self.assertEqual(len(expected_grads), len(actual_grads))
      for expected, actual in zip(expected_grads, actual_grads):
        self.assertEqual(expected, self.evaluate(actual))
    finally:
      # Restore the module-level flag even when an assertion above fails, so
      # the override cannot leak into tests that run afterwards.
      control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = default_lowering

  @parameterized.parameters((-2,), (2,), (5,))
  def testCase_gradient_diffShapedIntermediates(self, bi):
    """Checks gradients when each branch captures an input of another shape.

    Input `i` has shape `[i + 1]`. The expected gradient for the branch that
    runs is computed by differentiating that branch function directly.
    """
    nbranches = 5
    # The comprehension variables are deliberately not called `bi`, so the
    # `bi` parameter used for the placeholder below is never shadowed.
    inputs = [
        array_ops.constant(float(i), shape=[i + 1], name="br{}_in".format(i))
        for i in range(nbranches)
    ]

    def make_func(i):

      def f():
        x = inputs[i]**2 * inputs[i][:i + 1, None]
        return math_ops.reduce_sum(x)

      return f

    branches = {i: make_func(i) for i in range(nbranches)}

    branch_index = array_ops.placeholder_with_default(bi, [])
    with backprop.GradientTape() as tape:
      for x in inputs:
        tape.watch(x)
      case_out = control_flow_ops.switch_case(
          branch_index, branches, name=self.make_name())
    out_grad = 3.
    actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)

    # As in _testCase_gradient, an out-of-range index is expected to use the
    # last branch.
    used_bi = (nbranches - 1) if (bi < 0 or bi >= nbranches - 1) else bi
    expected_grads = []
    for input_idx in range(nbranches):
      if used_bi == input_idx:
        with backprop.GradientTape() as expected_tape:
          expected_tape.watch(inputs[used_bi])
          y = make_func(used_bi)()
        expected_grads.append(
            self.evaluate(
                expected_tape.gradient(
                    y, inputs[used_bi], output_gradients=out_grad)))
      else:
        expected_grads.append(None if context.executing_eagerly() else [0.] *
                              (input_idx + 1))

    self.assertEqual(len(expected_grads), len(actual_grads))
    for expected, actual in zip(expected_grads, actual_grads):
      if expected is None:
        self.assertIsNone(actual)
      else:
        self.assertAllEqual(expected, self.evaluate(actual))

  @test_util.run_gpu_only
  @test_util.disable_xla("Wants RunMetadata")
  def testParallelExecution(self):
    """Verify disjoint branches across while iterations are run in parallel."""
    if control_flow_v2_toggles.control_flow_v2_enabled():
      self.skipTest("b/138870290")
    if test.is_built_with_rocm():
      self.skipTest(
          "Disable subtest on ROCm due to missing Cholesky op support")

    with ops.Graph().as_default() as g:
      nbranches = 7
      matrices = array_ops.unstack(  # Ensure all are ready before while.
          array_ops.matrix_diag(
              random_ops.random_uniform([nbranches, 8, 512]) + 1e-3))

      def make_branch(i, mat, name):

        def branch_fn():
          next_i = i + 1
          with ops.device("gpu:0"):
            return next_i, math_ops.reduce_sum(
                linalg_ops.cholesky(mat, name=name + "_Cholesky"))

        return branch_fn

      def make_branches(i):
        return [make_branch(i, matrices[bi], "br{}".format(bi))
                for bi in range(nbranches)]

      def cond(i, _):
        return i < nbranches

      def body(i, result):
        with ops.device("cpu:0"):
          next_i, branch_out = control_flow_ops.switch_case(i, make_branches(i))
        return next_i, result + branch_out

      _, result = control_flow_ops.while_loop(cond, body, [0, 0.])

    run_metadata = config_pb2.RunMetadata()
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    config = config_pb2.ConfigProto(
        allow_soft_placement=False, log_device_placement=True)

    with session.Session(config=config, graph=g) as sess:
      _ = sess.run(result, options=run_options, run_metadata=run_metadata)

    # Collect the step stats of every Cholesky node that recorded a start time.
    chol_node_stats = []
    for dev_stats in run_metadata.step_stats.dev_stats:
      for node_stats in dev_stats.node_stats:
        if (node_stats.node_name.endswith("Cholesky") and
            node_stats.all_start_nanos > 0):
          chol_node_stats.append(node_stats)

    self.assertLen(chol_node_stats, nbranches)

    chol_node_stats = sorted(chol_node_stats, key=lambda stats: stats.node_name)
    op_start_nanos = [
        stats.all_start_nanos for stats in chol_node_stats
    ]
    op_end_nanos = [
        stats.all_start_nanos + stats.op_end_rel_nanos
        for stats in chol_node_stats
    ]

    def overlap(range1, range2):
      """Returns the overlap of two (start, end) spans, or 0 if disjoint."""
      s1, e1 = range1
      s2, e2 = range2
      if s1 < s2:
        return 0 if s2 > e1 else e1 - s2
      return 0 if s1 > e2 else e2 - s1

    timespans = list(zip(op_start_nanos, op_end_nanos))
    overlaps_chol0 = [overlap(timespans[0], r2) for r2 in timespans[1:]]
    # There are nbranches-1 overlaps, sometimes all nonzero, but we
    # conservatively check for at least one here, to avoid test flakiness.
    self.assertGreater(np.count_nonzero(overlaps_chol0), 0)

  def testCase_validateIndicesContiguous(self):
    """Expects a ValueError when the dict keys are 0, 2 and 4."""

    def make_func(bi):
      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))

    branches = {i: make_func(i) for i in range(0, 6, 2)}
    with self.assertRaisesRegexp(ValueError, "must form contiguous"):
      control_flow_ops.switch_case(array_ops.constant(0), branches)

  def testCase_validateIndicesDup(self):
    """Expects a ValueError when the pair list repeats the index 0."""

    def make_func(bi):
      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))

    branches = [(i, make_func(i)) for i in range(0, 6, 2)]
    branches.append((0, make_func(7)))
    with self.assertRaisesRegexp(ValueError, "must form contiguous"):
      control_flow_ops.switch_case(array_ops.constant(0), branches)

  def testCase_validateBranchIndex(self):
    """Expects a TypeError when `branch_index` is a plain Python int."""

    def make_func(bi):
      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))

    branches = {i: make_func(i) for i in range(5)}
    with self.assertRaisesRegexp(TypeError, "branch_index.*Tensor"):
      control_flow_ops.switch_case(1, branches)

  def testCase_validateNonIntKeys(self):
    """Expects a TypeError when the branch keys are constants, not ints."""

    def make_func(bi):
      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))

    branches = [(array_ops.constant(i), make_func(i)) for i in range(5)]
    with self.assertRaisesRegexp(TypeError, "must be a Python `int`"):
      control_flow_ops.switch_case(array_ops.constant(1), branches)
class ExecuteFnForDeviceTest(test_util.TensorFlowTestCase):
  """Tests for `control_flow_ops.execute_fn_for_device`."""

  def testCommonCases(self):
    """Checks values and gradients under CPU, GPU and unannotated scopes.

    The "CPU" branch doubles its input and the "GPU" branch squares it, so for
    an input of 3 the expected (value, gradient) is (6, 2) when the CPU branch
    runs and (9, 6) when the GPU branch runs.
    """

    def cpu_fn(x):
      return x + x

    def gpu_fn(x):
      return x * x

    def flexible_fn(a):
      branches = {"CPU": lambda: cpu_fn(a), "GPU": lambda: gpu_fn(a)}
      return control_flow_ops.execute_fn_for_device(branches, lambda: cpu_fn(a))

    @def_function.function
    def flexible_defun(a):
      return flexible_fn(a)

    def run_defun_and_tape(a):
      """Returns (direct result, def_function result, def_function gradient)."""
      with backprop.GradientTape() as tape:
        tape.watch(a)
        result = flexible_defun(a)
      grad = tape.gradient(result, a)
      r = flexible_fn(a)
      return r, result, grad

    a = array_ops.constant(3.)
    with ops.device("cpu:0"):
      r, result, grad = run_defun_and_tape(a)
      self.assertEqual(6., self.evaluate(r))
      self.assertEqual(6., self.evaluate(result))
      self.assertEqual([2.], self.evaluate(grad))

    if test_util.is_gpu_available():
      with ops.device("gpu:0"):
        r, result, grad = run_defun_and_tape(a)
        self.assertEqual(9., self.evaluate(r))
        self.assertEqual(9., self.evaluate(result))
        self.assertEqual([6.], self.evaluate(grad))

    # no device annotation
    r, result, grad = run_defun_and_tape(a)
    if test_util.is_gpu_available():
      self.assertEqual(9., self.evaluate(r))
      self.assertEqual(9., self.evaluate(result))
      self.assertEqual([6.], self.evaluate(grad))
    else:
      self.assertEqual(6., self.evaluate(r))
      self.assertEqual(6., self.evaluate(result))
      self.assertEqual([2.], self.evaluate(grad))

  def testCompile(self):
    """Checks the result under `experimental_compile=True` (needs a GPU)."""
    if not test_util.is_gpu_available():
      return

    def cpu_fn(x):
      return x + x

    def gpu_fn(x):
      return x * x

    @def_function.function(experimental_compile=True)
    def flexible_defun(a):
      branches = {"CPU": lambda: cpu_fn(a), "GPU": lambda: gpu_fn(a)}
      return control_flow_ops.execute_fn_for_device(branches, lambda: cpu_fn(a))

    # Always execute the default branch in xla compilation case.
    a = array_ops.constant(3.)
    r = flexible_defun(a)
    self.assertEqual(6., self.evaluate(r))

  def testFallBack(self):
    """Checks that `default_fn` runs when only a "TPU" branch is registered.

    The default is the identity, so an input of 3 must come back as 3 both
    through the `def_function` wrapper and through the direct call.
    """

    def default_fn(x):
      return x

    def tpu_fn(x):
      return x * x * x

    def flexible_fn(a):
      branches = {"TPU": lambda: tpu_fn(a)}
      return control_flow_ops.execute_fn_for_device(
          branches, default_fn=lambda: default_fn(a))

    @def_function.function
    def flexible_defun(a):
      return flexible_fn(a)

    a = array_ops.constant(3.)
    with ops.device("cpu:0"):
      # Assert on the def_function result itself.
      result_defun = flexible_defun(a)
      self.assertEqual(3., self.evaluate(result_defun))
      # execute_fn_for_device is not inside defun_function.
      result = flexible_fn(a)
      self.assertEqual(3., self.evaluate(result))

    if test_util.is_gpu_available():
      with ops.device("gpu:0"):
        result_defun = flexible_defun(a)
        self.assertEqual(3., self.evaluate(result_defun))
        # execute_fn_for_device is not inside defun_function.
        result = flexible_fn(a)
        self.assertEqual(3., self.evaluate(result))
class CaseTest(test_util.TensorFlowTestCase):
  """Tests for `control_flow_ops.case`."""

  @test_util.run_deprecated_v1
  def testCase_withDefault(self):
    """Feeds 1, 2 and 3; the last matches no predicate and expects default."""
    selector = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pred_fn_pairs = [
        (math_ops.equal(selector, 1), lambda: constant_op.constant(2)),
        (math_ops.equal(selector, 2), lambda: constant_op.constant(4)),
    ]
    fallback = lambda: constant_op.constant(6)
    result = control_flow_ops.case(
        pred_fn_pairs, default=fallback, exclusive=True)
    with self.cached_session() as sess:
      for fed, expected in ((1, 2), (2, 4), (3, 6)):
        self.assertEqual(sess.run(result, feed_dict={selector: fed}), expected)

  @test_util.run_deprecated_v1
  def testCase_multiple_matches_exclusive(self):
    """With exclusive=True, feeding a value two predicates match must fail."""
    selector = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pred_fn_pairs = [
        (math_ops.equal(selector, 1), lambda: constant_op.constant(2)),
        (math_ops.equal(selector, 2), lambda: constant_op.constant(4)),
        (math_ops.equal(selector, 2), lambda: constant_op.constant(6)),
    ]
    fallback = lambda: constant_op.constant(8)
    result = control_flow_ops.case(
        pred_fn_pairs, default=fallback, exclusive=True)
    with self.cached_session() as sess:
      for fed, expected in ((1, 2), (3, 8)):
        self.assertEqual(sess.run(result, feed_dict={selector: fed}), expected)
      # Two predicates are true for 2.
      with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
        sess.run(result, feed_dict={selector: 2})

  @test_util.run_deprecated_v1
  def testCase_multiple_matches_non_exclusive(self):
    """With exclusive=False, the first of two matching pairs is expected."""
    selector = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pred_fn_pairs = [
        (math_ops.equal(selector, 1), lambda: constant_op.constant(2)),
        (math_ops.equal(selector, 2), lambda: constant_op.constant(4)),
        (math_ops.equal(selector, 2), lambda: constant_op.constant(6)),
    ]
    fallback = lambda: constant_op.constant(8)
    result = control_flow_ops.case(
        pred_fn_pairs, default=fallback, exclusive=False)
    with self.cached_session() as sess:
      for fed, expected in ((1, 2), (2, 4), (3, 8)):
        self.assertEqual(sess.run(result, feed_dict={selector: fed}), expected)

  @test_util.run_deprecated_v1
  def testCase_withoutDefault(self):
    """Without a default, feeding a value no predicate matches must fail."""
    selector = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pred_fn_pairs = [
        (math_ops.equal(selector, 1), lambda: constant_op.constant(2)),
        (math_ops.equal(selector, 2), lambda: constant_op.constant(4)),
        (math_ops.equal(selector, 3), lambda: constant_op.constant(6)),
    ]
    result = control_flow_ops.case(pred_fn_pairs, exclusive=True)
    with self.cached_session() as sess:
      for fed, expected in ((1, 2), (2, 4), (3, 6)):
        self.assertEqual(sess.run(result, feed_dict={selector: fed}), expected)
      with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
        sess.run(result, feed_dict={selector: 4})

  @test_util.run_deprecated_v1
  def testCase_withoutDefault_oneCondition(self):
    """Single predicate, no default: the non-matching feed must fail."""
    selector = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    only_pair = (math_ops.equal(selector, 1), lambda: constant_op.constant(2))
    result = control_flow_ops.case([only_pair], exclusive=True)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(result, feed_dict={selector: 1}), 2)
      with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
        sess.run(result, feed_dict={selector: 4})

  @test_util.run_in_graph_and_eager_modes
  def testCase_dict(self):
    """Evaluates case over a constant equal to 2 and expects the 4 branch."""
    selector = constant_op.constant(2)
    pred_fn_pairs = [
        (math_ops.equal(selector, 1), lambda: constant_op.constant(2)),
        (math_ops.equal(selector, 2), lambda: constant_op.constant(4)),
    ]
    result = control_flow_ops.case(pred_fn_pairs, exclusive=True)
    self.assertEqual(4, self.evaluate(result))
class WhileLoopTestCase(test_util.TensorFlowTestCase):
  """Tests for the return structure and gradients of `while_loop`."""

  @test_util.run_in_graph_and_eager_modes
  def testWhileLoopWithSingleVariable(self):
    """Counts one loop variable from 0 up to 10."""

    def below_ten(counter):
      return math_ops.less(counter, 10)

    def increment(counter):
      return math_ops.add(counter, 1)

    start = constant_op.constant(0)
    final = control_flow_ops.while_loop(below_ten, increment, [start])

    self.assertEqual(self.evaluate(final), 10)

  @test_util.run_in_graph_and_eager_modes
  def testEagerWhileLoopWithSingleVariable_bodyReturnsTuple(self):
    """Counts to 10 with a body that returns a 1-tuple."""

    def below_ten(counter):
      return math_ops.less(counter, 10)

    def increment(counter):
      return (math_ops.add(counter, 1),)

    start = constant_op.constant(0)
    final = control_flow_ops.while_loop(below_ten, increment, [start])

    # Expect a tuple since that is what the body returns.
    self.assertEqual(self.evaluate(final), (10,))

  @test_util.run_v1_only("Unsupported in cfv2")
  def testWhileLoopSameReturnShape_False(self):
    """Loops over [tensor, []] without `return_same_structure`."""

    def below_ten(counter, _):
      return math_ops.less(counter, 10)

    def increment(counter, _):
      # Body returns a [tensor, []]
      return [math_ops.add(counter, 1), []]

    start = constant_op.constant(0)

    # Should only return the tensor.
    final = control_flow_ops.while_loop(below_ten, increment, [start, []])
    self.assertEqual(self.evaluate(final), 10)

    # Adding maximum_iterations should yield the same result.
    final = control_flow_ops.while_loop(
        below_ten, increment, [start, []], maximum_iterations=50)
    # Note: this result is still incorrect - it should be just 10.
    self.assertEqual(self.evaluate(final), [10, []])

  def testWhileLoopSameReturnShape_FalseSingleLoopVar(self):
    """Loops over a single variable without `return_same_structure`."""

    def below_ten(counter):
      return math_ops.less(counter, 10)

    def increment(counter):
      # Body return must be unpacked in this case.
      return math_ops.add(counter, 1)

    start = constant_op.constant(0)

    # Should only return the tensor.
    final = control_flow_ops.while_loop(below_ten, increment, [start])
    self.assertEqual(self.evaluate(final), 10)

    # Adding maximum_iterations should yield the same result.
    final = control_flow_ops.while_loop(
        below_ten, increment, [start], maximum_iterations=50)
    self.assertEqual(self.evaluate(final), 10)

  def testWhileLoopSameReturnShape_True(self):
    """Loops over [tensor, []] with `return_same_structure=True`."""

    def below_ten(counter, _):
      return math_ops.less(counter, 10)

    def increment(counter, _):
      # Body returns a [tensor, []]
      return [math_ops.add(counter, 1), []]

    start = constant_op.constant(0)

    # Should only return the original structure.
    final = control_flow_ops.while_loop(
        below_ten, increment, [start, []], return_same_structure=True)
    self.assertEqual(self.evaluate(final), [10, []])

    # Adding maximum_iterations should yield the same result.
    final = control_flow_ops.while_loop(
        below_ten,
        increment, [start, []],
        return_same_structure=True,
        maximum_iterations=50)
    self.assertEqual(self.evaluate(final), [10, []])

  def testWhileLoopSameReturnShape_TrueSingleLoopVar(self):
    """Loops over a single variable with `return_same_structure=True`."""

    def below_ten(counter):
      return math_ops.less(counter, 10)

    def increment(counter):
      return [math_ops.add(counter, 1)]

    start = constant_op.constant(0)

    # Should not unpack the single variable
    final = control_flow_ops.while_loop(
        below_ten, increment, [start], return_same_structure=True)
    self.assertEqual(self.evaluate(final), [10])

    # Adding maximum_iterations should yield the same result.
    final = control_flow_ops.while_loop(
        below_ten,
        increment, [start],
        return_same_structure=True,
        maximum_iterations=50)
    self.assertEqual(self.evaluate(final), [10])

  @test_util.enable_control_flow_v2
  @test_util.run_in_graph_and_eager_modes
  def testSkipsUnnecessaryCaptureGradients(self):
    """Differentiates a loop w.r.t. `x` while `y` sits behind a gradient trap.

    The loop adds `2*x + y` to its state each iteration; `y` is routed through
    a custom gradient whose backward function contains a failing assertion.
    """

    @custom_gradient.custom_gradient
    def gradient_trap(t):

      def trapped_grad(upstream):
        # Computing this gradient should fail the test
        check_ops.assert_equal(0, 1)
        return upstream

      return t, trapped_grad

    x = array_ops.constant(0.0, name="x")
    y = array_ops.constant(1.0, name="y")

    def below_ten(state):
      return state < 10.0

    def accumulate(state):
      return state + 2*x + gradient_trap(y)

    with backprop.GradientTape() as tape:
      tape.watch(x)
      initial_state = (array_ops.constant(0.0),)
      out = control_flow_ops.while_loop(below_ten, accumulate, initial_state)

    grad = tape.gradient(out, x)
    self.assertAllEqual(grad, 20.0)
class AssertTest(test_util.TensorFlowTestCase):
  """Tests for `control_flow_ops.Assert`."""

  @test_util.run_deprecated_v1
  def testAssert(self):
    """Evaluates an Assert that holds, then one that must raise."""
    small = constant_op.constant(0)
    holds = control_flow_ops.Assert(small < 10, [small, [10], [small + 1]])
    self.evaluate(holds)

    large = constant_op.constant(10)
    violated = control_flow_ops.Assert(large < 10, [large, [10], [large + 1]])
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(violated)

  @test_util.run_in_graph_and_eager_modes
  def testAssertInFunction(self):
    """Checks that an Assert inside a def_function raises when it is False."""
    # TODO(fishx): Re-enable this test for GPU.
    # NOTE(fishx): Disabled on GPU for now because multiple errors are thrown
    # there, and the root cause error is marked as "derived", so it might be
    # ignored.
    if test_util.is_gpu_available():
      self.skipTest("Skip GPU Test")

    @def_function.function
    def whiny(value):
      control_flow_ops.Assert(value, ["Raised false"])
      return constant_op.constant(5)

    with self.assertRaises(errors.InvalidArgumentError):
      failing_call = whiny(False)
      self.evaluate(failing_call)

    self.assertAllEqual(whiny(True), 5)
# Hand control to googletest's main() when this file is executed as a script.
if __name__ == "__main__":
  googletest.main()