# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.kernels.functional_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import def_function as eager_def_function
from tensorflow.python.eager import function as eager_function
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.util import compat
# pylint: disable=invalid-name
def simple_scoped_fn(a, x):
"""Simple function: (a, x) -> 2(x+a), but with "2" as a variable in scope."""
with variable_scope.variable_scope("body"):
# Dummy variable, just to check that scoping works as intended.
two = variable_scope.get_variable(
"two", [],
dtype=dtypes.int32,
initializer=init_ops.constant_initializer(2))
return math_ops.multiply(math_ops.add(a, x), two)
@test_util.with_control_flow_v2
class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFoldl_Simple(self):
elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
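    # With fn(a, x) = 2 * (a + x) and no initializer, the fold starts from
    # elems[0] = 1 and steps through 6, 18, 44, 98, 208; starting from the
    # initializer 10 instead gives 22, 48, 102, 212, 434, 880.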
r = functional_ops.foldl(
lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
elems)
self.assertAllEqual(208, self.evaluate(r))
r = functional_ops.foldl(
lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
elems,
initializer=10)
self.assertAllEqual(880, self.evaluate(r))
@test_util.run_in_graph_and_eager_modes
def testFoldl_SingleInputMultiOutput(self):
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
initializer = np.array([1, -1.0])
r = functional_ops.foldl(lambda a, x: a + x, elems, initializer)
r_value = self.evaluate(r)
self.assertAllEqual(22, r_value[0])
self.assertAllEqual(20, r_value[1])
@test_util.run_in_graph_and_eager_modes
def testFoldl_MultiInputSingleOutput(self):
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
initializer = np.array(1.0)
r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems),
initializer)
self.assertAllEqual(1, self.evaluate(r))
@test_util.run_in_graph_and_eager_modes
def testFoldl_MultiInputDifferentDimsSingleOutput(self):
elems = np.array([[1.0, 1.0, 1.0], [2.0, 3.0, 4.0]])
other_elems = np.array([-1.0, 1.0])
initializer = np.array([0.0, 0.0, 0.0])
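    # Step 1: a = [0, 0, 0] + [1, 1, 1] * -1 = [-1, -1, -1];
    # step 2: a = [-1, -1, -1] + [2, 3, 4] * 1 = [1, 2, 3].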
r = functional_ops.foldl(lambda a, x: a + x[0] * x[1],
(elems, other_elems), initializer)
self.assertAllEqual([1.0, 2.0, 3.0], self.evaluate(r))
@test_util.run_deprecated_v1
def testFoldl_Scoped(self):
with self.cached_session() as sess:
with variable_scope.variable_scope("root") as varscope:
elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
r = functional_ops.foldl(simple_scoped_fn, elems)
# Check that we have the one variable we asked for here.
self.assertEqual(len(variables.trainable_variables()), 1)
self.assertEqual(variables.trainable_variables()[0].name,
"root/body/two:0")
sess.run([variables.global_variables_initializer()])
self.assertAllEqual(208, self.evaluate(r))
# Now let's reuse our single variable.
varscope.reuse_variables()
r = functional_ops.foldl(simple_scoped_fn, elems, initializer=10)
self.assertEqual(len(variables.trainable_variables()), 1)
self.assertAllEqual(880, self.evaluate(r))
@test_util.run_in_graph_and_eager_modes
def testFoldr_Simple(self):
elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
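    # foldr walks the elements right to left. Without an initializer it starts
    # from elems[-1] = 6 and steps through 22, 52, 110, 224, 450; starting from
    # the initializer 10 instead gives 32, 74, 156, 318, 640, 1282.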
r = functional_ops.foldr(
lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
elems)
self.assertAllEqual(450, self.evaluate(r))
r = functional_ops.foldr(
lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
elems,
initializer=10)
self.assertAllEqual(1282, self.evaluate(r))
@test_util.run_in_graph_and_eager_modes
def testFoldr_SingleInputMultiOutput(self):
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
initializer = np.array([1, -1.0])
r = functional_ops.foldr(lambda a, x: a + x, elems, initializer)
r_value = self.evaluate(r)
self.assertAllEqual(22, r_value[0])
self.assertAllEqual(20, r_value[1])
@test_util.run_in_graph_and_eager_modes
def testFoldr_MultiInputSingleOutput(self):
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
initializer = np.array(1.0)
r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems),
initializer)
self.assertAllEqual(1, self.evaluate(r))
@test_util.run_deprecated_v1
def testFoldr_Scoped(self):
with self.cached_session() as sess:
with variable_scope.variable_scope("root") as varscope:
elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
r = functional_ops.foldr(simple_scoped_fn, elems)
# Check that we have the one variable we asked for here.
self.assertEqual(len(variables.trainable_variables()), 1)
self.assertEqual(variables.trainable_variables()[0].name,
"root/body/two:0")
sess.run([variables.global_variables_initializer()])
self.assertAllEqual(450, self.evaluate(r))
# Now let's reuse our single variable.
varscope.reuse_variables()
r = functional_ops.foldr(simple_scoped_fn, elems, initializer=10)
self.assertEqual(len(variables.trainable_variables()), 1)
self.assertAllEqual(1282, self.evaluate(r))
# pylint: disable=unnecessary-lambda
@test_util.run_deprecated_v1
def testFold_Grad(self):
with self.cached_session():
elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = constant_op.constant(2.0, name="v")
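      # Both folds compute v * (1 * 2 * 3 * 4 * 5 * 6) = 720 * v, so the
      # gradient with respect to v is 720 in either direction.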
r = functional_ops.foldl(
lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
r = gradients_impl.gradients(r, v)[0]
self.assertAllEqual(720.0, self.evaluate(r))
r = functional_ops.foldr(
lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
r = gradients_impl.gradients(r, v)[0]
self.assertAllEqual(720.0, self.evaluate(r))
# pylint: enable=unnecessary-lambda
@test_util.run_in_graph_and_eager_modes
def testScan_Simple(self):
elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = constant_op.constant(2.0, name="v")
# pylint: disable=unnecessary-lambda
r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems)
self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r))
r = functional_ops.scan(
lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))
# pylint: enable=unnecessary-lambda
@test_util.run_in_graph_and_eager_modes
def testScan_Reverse(self):
elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = constant_op.constant(2.0, name="v")
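    # With reverse=True the scan runs from the last element backwards, so the
    # outputs are running suffix products: [720, 720, 360, 120, 30, 6] without
    # an initializer and [1440, 1440, 720, 240, 60, 12] with initializer v = 2.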
# pylint: disable=unnecessary-lambda
r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems,
reverse=True)
self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r))
r = functional_ops.scan(
lambda a, x: math_ops.multiply(a, x), elems, initializer=v,
reverse=True)
self.assertAllEqual([1440., 1440., 720., 240., 60., 12.],
self.evaluate(r))
# pylint: enable=unnecessary-lambda
@test_util.run_in_graph_and_eager_modes
def testScan_SingleInputMultiOutput(self):
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
initializer = (np.array(1.0), np.array(-1.0))
r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems,
initializer)
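    # The first accumulator is the running product of elems; the second is
    # negated on every step, so its sign alternates while its magnitude also
    # tracks the running product.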
r_value = self.evaluate(r)
self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])
@test_util.run_in_graph_and_eager_modes
def testScan_MultiInputSingleOutput(self):
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
initializer = np.array(1.0)
# Multiply a * 1 each time
r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]),
(elems + 1, -elems), initializer)
self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r))
@test_util.run_in_graph_and_eager_modes
def testScan_MultiInputSameTypeOutput(self):
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]),
(elems, -elems))
r_value = self.evaluate(r)
self.assertAllEqual(np.cumsum(elems), r_value[0])
self.assertAllEqual(np.cumsum(-elems), r_value[1])
@test_util.run_in_graph_and_eager_modes
def testScan_MultiOutputMismatchedInitializer(self):
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
initializer = np.array(1.0)
    # The initializer is a scalar but the fn returns a two-element tuple, so
    # the nested structures don't match.
with self.assertRaisesRegexp(
ValueError, "two structures don't have the same nested structure"):
functional_ops.scan(lambda a, x: (a, -a), elems, initializer)
@test_util.run_deprecated_v1
def testScan_Scoped(self):
with self.cached_session() as sess:
with variable_scope.variable_scope("root") as varscope:
elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
r = functional_ops.scan(simple_scoped_fn, elems)
# Check that we have the one variable we asked for here.
self.assertEqual(len(variables.trainable_variables()), 1)
self.assertEqual(variables.trainable_variables()[0].name,
"root/body/two:0")
sess.run([variables.global_variables_initializer()])
results = np.array([1, 6, 18, 44, 98, 208])
self.assertAllEqual(results, self.evaluate(r))
# Now let's reuse our single variable.
varscope.reuse_variables()
r = functional_ops.scan(simple_scoped_fn, elems, initializer=2)
self.assertEqual(len(variables.trainable_variables()), 1)
results = np.array([6, 16, 38, 84, 178, 368])
self.assertAllEqual(results, self.evaluate(r))
@test_util.run_in_graph_and_eager_modes
def testScanFoldl_Nested(self):
elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data")
inner_elems = constant_op.constant([0.5, 0.5], name="data")
def r_inner(a, x):
return functional_ops.foldl(
lambda b, y: b * y * x, inner_elems, initializer=a)
r = functional_ops.scan(r_inner, elems)
# t == 0 (returns 1)
# t == 1, a == 1, x == 2 (returns 1)
# t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1
# t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1
# t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25)
# t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5
# t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5
# t == 3, a == 2.25, x == 4 (returns 9)
# t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5
# t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9
self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r))
@test_util.run_deprecated_v1
def testScan_Control(self):
with self.cached_session() as sess:
s = array_ops.placeholder(dtypes.float32, shape=[None])
b = array_ops.placeholder(dtypes.bool)
with ops.control_dependencies([b]):
c = functional_ops.scan(lambda a, x: x * a, s)
self.assertAllClose(
np.array([1.0, 3.0, 9.0]), sess.run(c, {s: [1, 3, 3],
b: True}))
@test_util.run_deprecated_v1
def testScan_Grad(self):
with self.cached_session():
elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = constant_op.constant(2.0, name="v")
# pylint: disable=unnecessary-lambda
r = functional_ops.scan(
lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
# pylint: enable=unnecessary-lambda
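      # The scan emits [v, 2v, 6v, 24v, 120v, 720v]; gradients() sums over all
      # of these outputs, so dr/dv = 1 + 2 + 6 + 24 + 120 + 720 = 873.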
r = gradients_impl.gradients(r, v)[0]
self.assertAllEqual(873.0, self.evaluate(r))
@test_util.run_deprecated_v1
def testScanGradientWithPartStopGradient(self):
a = variables.Variable(0.0, name="a")
b = variables.Variable(0.0, name="b")
elems = array_ops.zeros(5)
l0, l1 = functional_ops.scan(
lambda elem_, input_: (a, b), elems, initializer=(0., 0.))
loss = l0 + array_ops.stop_gradient(l1)
grad = gradients_impl.gradients(ys=[loss], xs=[a, b])
with self.test_session(use_gpu=True) as sess:
self.evaluate(variables.global_variables_initializer())
self.evaluate(grad)
@test_util.run_in_graph_and_eager_modes
def testFoldShape(self):
x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
def fn(_, current_input):
return current_input
initializer = constant_op.constant([0, 0, 0])
y = functional_ops.foldl(fn, x, initializer=initializer)
self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
@test_util.run_in_graph_and_eager_modes
def testScanShape(self):
x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
def fn(_, current_input):
return current_input
initializer = constant_op.constant([0, 0, 0])
y = functional_ops.scan(fn, x, initializer=initializer)
self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
  # TODO(akshayka): this test fails in eager: the iterable is of length 0 so
  # the body of the while loop never executes
@test_util.run_deprecated_v1
def testScanEmptyTensor(self):
with self.cached_session():
x = functional_ops.scan(
lambda x, _: x, math_ops.range(0), initializer=array_ops.ones([2, 4]))
self.assertAllEqual([0, 2, 4], x.get_shape())
self.assertAllEqual(x.get_shape(), self.evaluate(x).shape)
@test_util.run_deprecated_v1
def testScanUnknownShape(self):
x = array_ops.placeholder(dtypes.float32)
initializer = array_ops.placeholder(dtypes.float32)
def fn(_, current_input):
return current_input
y = functional_ops.scan(fn, x, initializer=initializer)
self.assertIs(None, y.get_shape().dims)
@test_util.run_deprecated_v1
def testScanVaryingShape(self):
with self.cached_session() as sess:
x = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 2])
x_t = array_ops.transpose(x)
# scan over dimension 0 (with shape None)
result = functional_ops.scan(lambda a, x: a + x, x)
# scanned over transposed dimension 0 (with shape 2)
result_t = functional_ops.scan(lambda a, x: a + x, x_t, infer_shape=False)
# ensure gradients can be calculated
result_grad = gradients_impl.gradients(result, [x])[0]
result_t_grad = gradients_impl.gradients(result_t, [x_t])[0]
# smoke test to ensure they all evaluate
sess.run([result, result_t, result_grad, result_t_grad],
feed_dict={x: [[1.0, 2.0]]})
@test_util.run_deprecated_v1
def testRemoteFunction(self):
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2
worker, _ = test_util.create_local_cluster(
1, 1, worker_config=worker_config)
@function.Defun(dtypes.int32, dtypes.int32)
def _remote_fn(a, b):
return math_ops.multiply(a, b)
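    # remote_call runs _remote_fn on the device named by `target` (cpu:1 of
    # the worker), while the RemoteCall op itself is placed on cpu:0.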
with ops.device("/job:ps/task:0"):
a = variables.Variable(2, dtype=dtypes.int32)
b = variables.Variable(3, dtype=dtypes.int32)
with ops.device("/job:worker/replica:0/task:0/cpu:0"):
remote_op = functional_ops.remote_call(
args=[a, b],
Tout=[dtypes.int32],
f=_remote_fn,
target="/job:worker/replica:0/task:0/cpu:1")
with session.Session(worker[0].target) as sess:
self.evaluate(variables.global_variables_initializer())
mul = self.evaluate(remote_op)
self.assertEqual(mul, [6])
@test_util.run_deprecated_v1
def testRemoteFunctionDirectSession(self):
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2
@function.Defun(dtypes.int32, dtypes.int32)
def _remote_fn(a, b):
return math_ops.multiply(a, b)
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
a = variables.Variable(2, dtype=dtypes.int32)
b = variables.Variable(3, dtype=dtypes.int32)
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
remote_op = functional_ops.remote_call(
args=[a, b],
Tout=[dtypes.int32],
f=_remote_fn,
target="/job:localhost/replica:0/task:0/cpu:1")
with self.test_session(config=worker_config) as sess:
self.evaluate(variables.global_variables_initializer())
mul = self.evaluate(remote_op)
self.assertEqual(mul, [6])
@test_util.run_deprecated_v1
def testRemoteFunctionSameDeviceDirectSession(self):
@function.Defun(dtypes.int32, dtypes.int32)
def _remote_fn(a, b):
return math_ops.multiply(a, b)
with ops.device("/cpu:0"):
a = variables.Variable(2, dtype=dtypes.int32)
b = variables.Variable(3, dtype=dtypes.int32)
with ops.device("/cpu:0"):
remote_op = functional_ops.remote_call(
args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0")
with self.cached_session() as sess:
self.evaluate(variables.global_variables_initializer())
mul = self.evaluate(remote_op)
self.assertEqual(mul, [6])
@test_util.run_deprecated_v1
def testRemoteFunctionCPUGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@function.Defun(dtypes.float32, dtypes.float32)
def _remote_fn(a, b):
return math_ops.multiply(a, b)
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
a = variables.Variable(2, dtype=dtypes.float32)
b = variables.Variable(3, dtype=dtypes.float32)
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
remote_op = functional_ops.remote_call(
args=[a, b],
Tout=[dtypes.float32],
f=_remote_fn,
target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0
with self.cached_session() as sess:
self.evaluate(variables.global_variables_initializer())
mul = self.evaluate(remote_op)
self.assertEqual(mul, 9.0)
@test_util.run_deprecated_v1
def testRemoteFunctionGPUCPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@function.Defun(dtypes.float32, dtypes.float32)
def _remote_fn(a, b):
return math_ops.multiply(a, b)
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
a = variables.Variable(2, dtype=dtypes.float32)
b = variables.Variable(3, dtype=dtypes.float32)
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
remote_op = functional_ops.remote_call(
args=[a, b],
Tout=[dtypes.float32],
f=_remote_fn,
target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0
with self.cached_session() as sess:
self.evaluate(variables.global_variables_initializer())
mul = self.evaluate(remote_op)
self.assertEqual(mul, 9.0)
@test_util.run_deprecated_v1
def testRemoteFunctionGPUCPUStrings(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@function.Defun(dtypes.string)
def _remote_fn(inp):
return array_ops.identity(inp)
a = array_ops.constant("a")
with ops.device("/gpu:0"):
remote_op = functional_ops.remote_call(
args=[a], Tout=[dtypes.string], f=_remote_fn, target="/cpu:0")
with self.cached_session() as sess:
ret = self.evaluate(remote_op)
self.assertAllEqual(ret, [b"a"])
@test_util.run_deprecated_v1
def testRemoteFunctionCrossProcess(self):
workers, _ = test_util.create_local_cluster(2, 1)
@function.Defun(dtypes.float32, dtypes.float32)
def _remote_fn(a, b):
return math_ops.multiply(a, b)
with ops.device("/job:ps/task:0"):
a = variables.Variable(2, dtype=dtypes.float32)
b = variables.Variable(3, dtype=dtypes.float32)
with ops.device("/job:worker/replica:0/task:0/cpu:0"):
remote_op = functional_ops.remote_call(
args=[a, b],
Tout=[dtypes.float32],
f=_remote_fn,
target="/job:worker/replica:0/task:1/cpu:0")[0] + 3.0
with session.Session(workers[0].target) as sess:
self.evaluate(variables.global_variables_initializer())
mul = self.evaluate(remote_op)
self.assertEqual(mul, 9)
@test_util.run_deprecated_v1
def testIf(self):
@function.Defun(dtypes.float32)
def Twice(x):
return x * 2
@function.Defun(dtypes.float32)
def Thrice(x):
return x * 3 + 1
with self.test_session(use_gpu=False) as sess:
x = array_ops.placeholder(dtypes.float32)
ret = functional_ops.If(math_ops.greater(x, 0), [x], Twice, Thrice)[0]
self.assertAllEqual(sess.run(ret, feed_dict={x: 9.}), 18.)
self.assertAllEqual(sess.run(ret, feed_dict={x: -8.}), -23.)
self.assertAllEqual(sess.run(ret, feed_dict={x: 0.}), 1.)
def testWhile(self):
for use_gpu in (True, False):
with ops.Graph().as_default() as g:
@function.Defun(*[dtypes.float32] * 2)
def Cond(n, unused_x):
return n > 0
@function.Defun(*[dtypes.float32] * 2)
def Body(n, x):
return n - 1, x + n
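        # The loop counts n down to 0 while accumulating x += n, so the second
        # output is n * (n + 1) / 2: 210 for n = 20 and 5050 for n = 100.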
def Run(sess, n):
return sess.run(functional_ops.While([n, 0.], Cond, Body))[1]
with self.session(graph=g, use_gpu=use_gpu) as sess:
self.assertAllEqual(Run(sess, 20.), 210.)
self.assertAllEqual(Run(sess, 100.), 5050.)
def testToBool(self):
# For 0D tensors, the truthiness depends on whether the value is "zero".
self.assertAllEqual(gen_functional_ops.to_bool(0), False)
self.assertAllEqual(gen_functional_ops.to_bool(1), True)
self.assertAllEqual(gen_functional_ops.to_bool(42), True)
self.assertAllEqual(gen_functional_ops.to_bool(0.), False)
self.assertAllEqual(gen_functional_ops.to_bool(1.), True)
self.assertAllEqual(gen_functional_ops.to_bool(42.), True)
self.assertAllEqual(gen_functional_ops.to_bool(False), False)
self.assertAllEqual(gen_functional_ops.to_bool(True), True)
# For strings, "zero" is the empty string.
self.assertAllEqual(gen_functional_ops.to_bool(""), False)
self.assertAllEqual(gen_functional_ops.to_bool("a"), True)
# For >0D tensors, the truthiness only depends on whether there are
# elements or not.
self.assertAllEqual(gen_functional_ops.to_bool([]), False)
self.assertAllEqual(gen_functional_ops.to_bool([[]]), False)
self.assertAllEqual(gen_functional_ops.to_bool([[[]]]), False)
self.assertAllEqual(gen_functional_ops.to_bool([0]), True)
self.assertAllEqual(gen_functional_ops.to_bool([1]), True)
self.assertAllEqual(gen_functional_ops.to_bool([[0]]), True)
self.assertAllEqual(gen_functional_ops.to_bool([False]), True)
self.assertAllEqual(gen_functional_ops.to_bool([True]), True)
# Like above, but using int32 in order to ensure that int32 tensors don't get
# copied to the GPU during the application of the while.
def testWhileInt32(self):
with ops.Graph().as_default() as g:
@function.Defun(*[dtypes.int32] * 2)
def Cond(n, unused_x):
return n > 0
@function.Defun(*[dtypes.int32] * 2)
def Body(n, x):
return n - 1, x + n
def Run(sess, n):
return sess.run(functional_ops.While([n, 0], Cond, Body))[1]
with self.session(graph=g, use_gpu=True) as sess:
self.assertAllEqual(Run(sess, 20), 210)
self.assertAllEqual(Run(sess, 100), 5050)
@test_util.run_deprecated_v1
def testWhileLowering(self):
def Run(n, fetch_by_name):
for use_gpu in (True, False):
with ops.Graph().as_default() as g:
@function.Defun(*[dtypes.float32] * 2)
def Cond(n, unused_x):
return n > 0
@function.Defun(*[dtypes.float32] * 2)
def Body(n, x):
return n - 1, x + n
# outputs: [0, n*(n+1)/2]
outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while")
# `outputs` is the list of output tensors of the While op. We
# arbitrarily choose the 0th tensor to get the While op and set the
# lowering attribute on it.
outputs[0].op._set_attr("_lower_using_switch_merge",
attr_value_pb2.AttrValue(b=True))
if not fetch_by_name:
fetch = outputs[1]
else:
fetch = "my_while:1"
with self.session(graph=g, use_gpu=use_gpu) as sess:
return self.evaluate(fetch)
self.assertAllEqual(Run(20., False), 210.)
self.assertAllEqual(Run(20., True), 210.)
self.assertAllEqual(Run(100., False), 5050.)
self.assertAllEqual(Run(100., True), 5050.)
@test_util.run_v1_only("b/120545219")
@test_util.disable_xla("b/123337890") # Different error message
def testWhileError(self):
for use_gpu in (True, False):
with ops.Graph().as_default() as g:
@function.Defun(*[dtypes.float32] * 2)
def Cond(n, unused_x):
return n > 0
@function.Defun(*[dtypes.float32] * 2)
def CondReturnsTooManyArgs(n, x):
return n > 0, x
@function.Defun(*[dtypes.float32] * 2)
def Body(n, x):
return n - 1, x + n
@function.Defun(*[dtypes.float32] * 2)
def BodyReturnsTooManyArgs(n, x):
return n - 1, x + n, x
with self.session(graph=g, use_gpu=use_gpu):
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Expected a single scalar.*got 2 tensors."):
functional_ops.While([5., 0.], CondReturnsTooManyArgs,
Body)[0].eval()
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"While loop body returned 3 arguments. Expected: 2"):
functional_ops.While([5., 0.], Cond,
BodyReturnsTooManyArgs)[0].eval()
def testWhileInMultipleSubgraphs(self):
for use_gpu in (True, False):
with ops.Graph().as_default() as g:
@function.Defun(*[dtypes.float32] * 2)
def Cond(n, x): # pylint: disable=unused-argument
return n > 0
@function.Defun(*[dtypes.float32] * 2)
def Body(n, x):
return n - 1, x + n
with self.session(graph=g, use_gpu=use_gpu) as sess:
n = array_ops.placeholder(dtypes.float32)
_, result = functional_ops.While([n, 0.], Cond, Body)
c = constant_op.constant(37.)
self.assertAllEqual(210., sess.run(result, feed_dict={n: 20.}))
self.assertAllEqual(5050., sess.run(result, feed_dict={n: 100.}))
# Test that the result is the same when we run a different subgraph.
self.assertAllEqual(5050.,
sess.run([result, c], feed_dict={n: 100.})[0])
# pylint: disable=cell-var-from-loop
def testWhileCapturedInputs(self):
for use_gpu in (True, False):
with ops.Graph().as_default() as g:
v = variables.Variable(1.0)
def TestCond(n, *args):
del args
return n < 10
@function.Defun(*[dtypes.float32] * 2)
def TestUnary(n, x):
return math_ops.add(n, 1), x + n + v
@function.Defun(*[dtypes.float32] * 3)
def TestBinary(n, x, x2):
return math_ops.add(n, 1), x + n + v, x2 + v
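        # The loop body runs for n = 1..9, so x accumulates
        # sum(1..9) + 9 * v = 45 + 9 = 54 and x2 accumulates 9 * v = 9,
        # while n finishes at 10.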
with self.session(graph=g, use_gpu=use_gpu) as sess:
result_unary = functional_ops.While(
[1.0, 0.],
function.Defun(*[dtypes.float32] * 2)(TestCond), TestUnary)
result_binary = functional_ops.While(
[1.0, 0., 0.],
function.Defun(*[dtypes.float32] * 3)(TestCond), TestBinary)
self.evaluate(variables.global_variables_initializer())
assert len(result_unary) == 2
self.assertEqual([10.0, 54.0], self.evaluate(result_unary))
assert len(result_binary) == 3
self.assertEqual([10.0, 54.0, 9.0], self.evaluate(result_binary))
def TestCondCapture(n, *args):
del args
return math_ops.cast(n, dtypes.float32) + v < 10
with self.assertRaises(ValueError):
_ = functional_ops.While(
[1],
function.Defun(dtypes.int32)(TestCondCapture),
function.Defun(dtypes.int32, dtypes.float32)(TestUnary))
# pylint: enable=cell-var-from-loop
def _tfSum(self, use_gpu, rewrite_with_while):
with ops.Graph().as_default() as g:
with self.session(graph=g, use_gpu=use_gpu) as sess:
@function.Defun(dtypes.int32, dtypes.float32)
def Body(n, x):
return x + math_ops.cast(n, dtypes.float32)
xs = [
# 1 + 2 + ... + 20
functional_ops.For(
1, 21, 1, [0.], Body, rewrite_with_while=rewrite_with_while)[0],
# 100 + 99 + ... + 1
functional_ops.For(
100, 0, -1, [0.], Body, rewrite_with_while=rewrite_with_while)
[0],
]
xvals = self.evaluate(xs)
self.assertAllEqual(210, xvals[0])
self.assertAllEqual(5050, xvals[1])
def testFor(self):
for use_gpu in (True, False):
self._tfSum(use_gpu, False)
def testForWithWhile(self):
for use_gpu in (True, False):
self._tfSum(use_gpu, True)
def testForWithWhileNaming(self):
g = ops.Graph()
with g.as_default():
@function.Defun(dtypes.int32, dtypes.float32, func_name="TestBody")
def TestBody(n, x):
return x + math_ops.cast(n, dtypes.float32)
_ = functional_ops.For(
1, 21, 1, [0.], TestBody, rewrite_with_while=True)[0]
names = []
for func in g.as_graph_def().library.function:
names.append(func.signature.name)
self.assertTrue("TestBody" in names)
self.assertTrue("TestBody_Cond" in names)
self.assertTrue("TestBody_Body" in names)
@test_util.run_deprecated_v1
def testForCapturedInputs(self):
v = variables.Variable(1.0)
@function.Defun(dtypes.int32)
def TestNullary(n):
v + math_ops.cast(n, dtypes.float32) # pylint: disable=expression-not-assigned
@function.Defun(dtypes.int32, dtypes.float32)
def TestUnary(n, x):
return x + math_ops.cast(n, dtypes.float32) + v
@function.Defun(dtypes.int32, dtypes.float32, dtypes.float32)
def TestBinary(n, x, x2):
return x + math_ops.cast(n, dtypes.float32) + v, x2 + v
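    # For(1, 10, 1, ...) runs nine iterations with n = 1..9, so the unary
    # accumulator ends at sum(1..9) + 9 * v = 54 and the extra binary
    # accumulator at 9 * v = 9.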
for rewrite_with_while in (True, False):
use_gpu = not rewrite_with_while
with self.test_session(use_gpu=use_gpu) as sess:
result_nullary = functional_ops.For(
1, 10, 1, [], TestNullary,
rewrite_with_while=rewrite_with_while)
result_unary = functional_ops.For(
1, 10, 1, [0.], TestUnary,
rewrite_with_while=rewrite_with_while)
result_binary = functional_ops.For(
1, 10, 1, [0., 0.], TestBinary,
rewrite_with_while=rewrite_with_while)
self.evaluate(variables.global_variables_initializer())
assert not result_nullary
# The nullary variant doesn't return anything so we can't easily run it.
# As a total hack, fetch the operation by name and run it.
sess.run(ops.get_default_graph().get_operation_by_name(
"While" if rewrite_with_while else "For"))
assert len(result_unary) == 1
self.assertEqual([54.0], self.evaluate(result_unary))
assert len(result_binary) == 2
self.assertEqual([54.0, 9.0], self.evaluate(result_binary))
def _tfMLP(self, xval, wsval, bsval, rewrite_with_while):
# On GPU, don't rewrite using a while loop.
use_gpu = not rewrite_with_while
with self.test_session(use_gpu=use_gpu):
@function.Defun(dtypes.int32, *[dtypes.float64] * 3)
def MLP(i, a, ws, bs):
a = math_ops.tanh(math_ops.matmul(a, ws[i, :]) + bs[i, :])
return a, ws, bs
ret = functional_ops.For(
0,
wsval.shape[0],
1, [xval, wsval, bsval],
MLP,
rewrite_with_while=rewrite_with_while)[0]
return self.evaluate(ret)
def _npMLP(self, xval, wsval, bsval):
for i in range(wsval.shape[0]):
xval = np.tanh(np.dot(xval, wsval[i, :]) + bsval[i, :])
return xval
def _testForMLP(self, rewrite_with_while):
# We construct a 5-layer Multi-Layer Perceptron network here.
    # Each layer has the same number of hidden units (3), and the
# activation function is tanh(). We feed the input (xval) with
# batch size 2.
xval = np.random.normal(size=(2, 3))
wsval = np.random.normal(size=(5, 3, 3))
bsval = np.random.normal(size=(5, 3))
np_ans = self._npMLP(xval, wsval, bsval)
tf_for_ans = self._tfMLP(xval, wsval, bsval, rewrite_with_while)
self.assertAllClose(np_ans, tf_for_ans)
@test_util.run_deprecated_v1
def testForMLP(self):
self._testForMLP(False)
@test_util.run_deprecated_v1
@test_util.disable_xla(
"Test uses strided slice without compile time constant values")
def testForMLPWhile(self):
self._testForMLP(True)
@test_util.run_v1_only("b/120545219")
def testForError(self):
@function.Defun(dtypes.int32, dtypes.float32)
def Foo(i, v):
return math_ops.cast(i, dtypes.float32) + v
@function.Defun(dtypes.int32, dtypes.float32)
def ReturnsTooManyArgs(unused_i, v):
return v, v
with self.test_session(use_gpu=True):
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"must be a scalar"):
functional_ops.For([0], 10, 1, [0.0], Foo)[0].eval()
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Invalid start/limit/delta"):
functional_ops.For(0, 10, -1, [0.0], Foo)[0].eval()
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"For loop body returned 2 arguments. Expected: 1"):
functional_ops.For(0, 10, 1, [0.0], ReturnsTooManyArgs)[0].eval()
@test_util.run_deprecated_v1
def testGradient(self):
@function.Defun(dtypes.float32)
def Poly(x):
# y = 2x^3+3x^2+4x+8
return 2 * x * x * x + 3 * x * x + 4 * x + 8
@function.Defun(dtypes.float32)
def Grad(x):
# dy/dx = dy/dy * dy/dx = 1.0 * (6x^2+6x+4)
return functional_ops.Gradient([x, 1.0], Poly)[0]
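    # Sanity values: Poly(0) = 8, Grad(0) = 4; Poly(1) = 17, Grad(1) = 16.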
with self.test_session(use_gpu=False) as sess:
a = constant_op.constant(0.)
avals = [Poly(a), Grad(a)]
b = constant_op.constant(1.)
bvals = [Poly(b), Grad(b)]
self.assertAllEqual(self.evaluate(avals), [8., 4.])
self.assertAllEqual(self.evaluate(bvals), [17., 16.])
# TODO(akshayka): Replace `function.Defun` with `tf.contrib.eager.defun` in the
# below test cases.
class PartitionedCallTest(test.TestCase):
@test_util.run_deprecated_v1
def testRemoteDeviceInPartitionedCallOp(self):
workers, _ = test_util.create_local_cluster(2, 0)
worker0_device = "/job:worker/replica:0/task:0/cpu:0"
worker1_device = "/job:worker/replica:0/task:1/cpu:0"
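    # With share_cluster_devices_in_session enabled, the session created on
    # worker 0 can also see worker 1's devices, so a single function call can
    # read resource variables that live on both workers.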
@eager_def_function.function
def f(a, b):
return a + b
with session.Session(workers[0].target) as sess:
with ops.device(worker0_device):
a = variable_scope.get_variable(
"a", initializer=constant_op.constant(1.), use_resource=True)
with ops.device(worker1_device):
b = variable_scope.get_variable(
"b", initializer=constant_op.constant(1.), use_resource=True)
sess.run(variables.global_variables_initializer())
config = config_pb2.ConfigProto()
config.share_cluster_devices_in_session = True
with session.Session(workers[0].target, config=config) as sess:
res = sess.run(f(a, b))
self.assertEqual(res, 2)
@test_util.run_deprecated_v1
def testBasicSingleDevice(self):
@function.Defun(*[dtypes.float32] * 2)
def Body(x, y):
with ops.device("/cpu:0"):
a = x + x
b = y + y
return a + b
output, = self.evaluate(
functional_ops.partitioned_call(
args=[constant_op.constant(1.),
constant_op.constant(2.)], f=Body))
self.assertEqual(output, 6.)
@test_util.run_deprecated_v1
def testBasicMultiDevice(self):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
@function.Defun(*[dtypes.float32] * 2)
def Body(x, y):
# if x = 1, y = 2, ...
with ops.device("/cpu:0"):
# a:= 1 + 1 = 2
a = x + x
with ops.device("/cpu:1"):
# b:= 2 + 2 = 4
b = a + y
with ops.device("/cpu:2"):
# c:= 2 + 4 = 6
c = a + b
# a + b + c = 2 + 4 + 6 = 12
return a + b + c
with self.test_session(config=config):
output, = functional_ops.partitioned_call(
args=[constant_op.constant(1.),
constant_op.constant(2.)], f=Body)
self.assertEqual(output.eval(), 12.)
@test_util.run_deprecated_v1
def testBasicMultiDeviceGPU(self):
if not test_util.is_gpu_available():
return
@function.Defun(*[dtypes.float32] * 2)
def Body(x, y):
with ops.device("/gpu:0"):
a = x + x
b = y + y
with ops.device("/cpu:0"):
c = a + b
return c
output, = self.evaluate(
functional_ops.partitioned_call(
args=[constant_op.constant(1.),
constant_op.constant(2.)], f=Body))
self.assertEqual(output, 6.)
@test_util.run_deprecated_v1
def testBasicNoDeviceAnnotations(self):
@function.Defun(*[dtypes.float32] * 2)
def Body(x, y):
a = x + x
b = y + y
return a + b
output, = self.evaluate(
functional_ops.partitioned_call(
args=[constant_op.constant(1.),
constant_op.constant(2.)], f=Body))
self.assertEqual(output, 6.)
@test_util.run_deprecated_v1
def testShardsRunOnRequestedDevices(self):
config = config_pb2.ConfigProto(device_count={"CPU": 4})
@function.Defun()
def Body():
# Serialize DT_RESOURCE handles as DT_STRINGs, which encode the device on
# which the resource was created, so that we can verify that ops were
# actually run on the requested devices.
#
# TODO(akshayka): Provide a cleaner, more idiomatic API for obtaining the
# name of the device on which a resource lives / for determining the
# device on which an op ran.
with ops.device("/cpu:0"):
s1 = iterator_ops.Iterator.from_structure(
(dtypes.float32,)).string_handle()
with ops.device("/cpu:1"):
s2 = iterator_ops.Iterator.from_structure(
(dtypes.float32,)).string_handle()
with ops.device("/cpu:2"):
s3 = iterator_ops.Iterator.from_structure(
(dtypes.float32,)).string_handle()
return s1, s2, s3
with self.test_session(config=config, use_gpu=True) as sess:
outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body))
self.assertIn(compat.as_bytes("CPU:0"), outputs[0])
self.assertIn(compat.as_bytes("CPU:1"), outputs[1])
self.assertIn(compat.as_bytes("CPU:2"), outputs[2])
@test_util.run_deprecated_v1
def testAssignAddResourceVariable(self):
v = resource_variable_ops.ResourceVariable(1.0)
@function.Defun()
def AssignAdd():
v.assign_add(1.0)
op = functional_ops.partitioned_call(
args=AssignAdd.captured_inputs, f=AssignAdd)
_ = self.evaluate(variables.global_variables_initializer())
_ = self.evaluate(op)
value = self.evaluate(v.read_value())
self.assertEqual(value, 2.0)
@test_util.run_deprecated_v1
def testFunctionWithResourcesOnDifferentDevices(self):
if not test_util.is_gpu_available():
self.skipTest("No GPUs available.")
with ops.device("/cpu:0"):
v_cpu_zero = resource_variable_ops.ResourceVariable(
[0.0, 1.0, 2.0], name="v_cpu_zero")
with ops.device("/cpu:1"):
v_cpu_one = resource_variable_ops.ResourceVariable(
[0.0, 1.0, 2.0], name="v_cpu_one")
with ops.device("/gpu:0"):
v_gpu = resource_variable_ops.ResourceVariable(
[0.0, 1.0, 2.0], name="v_gpu")
def sum_gather():
cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_zero, [1, 2]))
also_cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_one, [1, 2]))
gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
return cpu_result, also_cpu_result, gpu_result
defined = function.Defun()(sum_gather)
with self.test_session(
config=config_pb2.ConfigProto(
allow_soft_placement=False,
log_device_placement=True,
device_count={"CPU": 2})) as sess:
self.evaluate(variables.global_variables_initializer())
expected = self.evaluate(sum_gather())
result = sess.run(
functional_ops.partitioned_call(
args=defined.captured_inputs, f=defined))
self.assertAllEqual(expected, result)
# Use an invalid executor name to test the plumbing of the executor_type attr.
@test_util.run_v1_only("b/120545219")
def testExecutorTypeAttrExecutorNotFound(self):
@function.Defun(dtypes.int32)
def AddFive(x):
return x + 5
op = functional_ops.partitioned_call(
args=[constant_op.constant([1, 2, 3], dtype=dtypes.int32)],
f=AddFive,
executor_type="NON_EXISTENT_EXECUTOR")
with self.assertRaisesRegexp(errors.NotFoundError,
"NON_EXISTENT_EXECUTOR"):
self.evaluate(op)
@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class FunctionalOpsCaseTest(test.TestCase):
def testCase(self):
@eager_function.defun
def two(x):
return x * 2
@eager_function.defun
def three(x):
return x * 3
@eager_function.defun
def four(x):
return x * 4
def f(branch, x):
tmpl = array_ops.zeros_like(x)
return array_ops.identity(gen_functional_ops.case(
branch, input=[x], Tout=[dtypes.float32],
branches=[f.get_concrete_function(tmpl)
for f in (two, three, four)])[0])
one = array_ops.ones([])
self.assertAllEqual(np.float32(2), self.evaluate(f(0, one)))
self.assertAllEqual(np.float32(3), self.evaluate(f(1, one)))
self.assertAllEqual(np.float32(4), self.evaluate(f(2, one)))
self.assertAllEqual(np.float32(4), self.evaluate(f(-1, one))) # <0 default
self.assertAllEqual(np.float32(4), self.evaluate(f(6, one))) # >=N default
@test_util.run_deprecated_v1
@test_util.disable_xla("Don't lower for XLA")
def testSkipEagerCaseLoweringPreservesNameForFetch(self):
for use_gpu in (True, False):
def Run(branch, x, fetch_by_name, use_gpu=use_gpu):
with ops.Graph().as_default() as g:
@function.Defun(dtypes.float32)
def two(x):
return -1, x * 2
@function.Defun(dtypes.float32)
def three(x):
return 0, x * 3
@function.Defun(dtypes.float32)
def four(x):
return 1, x * 4
outputs = gen_functional_ops.case(branch, input=[x],
Tout=[dtypes.int32, dtypes.float32],
branches=[two, three, four],
name="my_case")
# `outputs` is the list of output tensors of the Case op. We
# arbitrarily choose the 0th tensor to get the Case op and set the
# lowering attribute on it.
outputs[0].op._set_attr("_lower_using_switch_merge",
attr_value_pb2.AttrValue(b=True))
outputs = array_ops.identity_n(outputs)
with self.session(graph=g, use_gpu=use_gpu) as sess:
return sess.run("my_case:1" if fetch_by_name else outputs[1])
self.assertAllEqual(2 * 1., Run(0, 1., False))
self.assertAllEqual(2 * 1., Run(0, 1., True))
self.assertAllEqual(3 * 7., Run(1, 7., False))
self.assertAllEqual(3 * 7., Run(1, 7., True))
self.assertAllEqual(4 * -3., Run(2, -3., False))
self.assertAllEqual(4 * -3., Run(2, -3., True))
self.assertAllEqual(4 * -4., Run(7, -4., False)) # >= N default
self.assertAllEqual(4 * -4., Run(7, -4., True)) # >= N default
self.assertAllEqual(4 * -5., Run(-1, -5., False)) # <0 default
self.assertAllEqual(4 * -5., Run(-1, -5., True)) # <0 default
@test_util.disable_xla("Don't lower for XLA")
def testCaseLowering(self):
for use_gpu in (True, False):
@eager_function.defun
def Run(branch, x):
@function.Defun(dtypes.float32)
def two(x):
return -1, x * 2
@function.Defun(dtypes.float32)
def three(x):
return 0, x * 3
@function.Defun(dtypes.float32)
def four(x):
return 1, x * 4
outputs = gen_functional_ops.case(branch, input=[x],
Tout=[dtypes.int32, dtypes.float32],
branches=[two, three, four])
# `outputs` is the list of output tensors of the Case op. We
# arbitrarily choose the 0th tensor to get the Case op and set the
# lowering attribute on it.
outputs[0].op._set_attr("_lower_using_switch_merge",
attr_value_pb2.AttrValue(b=True))
outputs = array_ops.identity_n(outputs)
return outputs[1]
with ops.device(test.gpu_device_name() if use_gpu else "CPU:0"):
self.assertAllEqual(2 * 1., self.evaluate(Run(0, 1.)))
self.assertAllEqual(3 * 7., self.evaluate(Run(1, 7.)))
self.assertAllEqual(4 * -3., self.evaluate(Run(2, -3.)))
self.assertAllEqual(4 * -4., self.evaluate(Run(7, -4.))) # >=N default
self.assertAllEqual(4 * -5., self.evaluate(Run(-1, -5.))) # <0 default
if __name__ == "__main__":
test.main()
# pylint: enable=invalid-name