STT-tensorflow/tensorflow/python/kernel_tests/map_fn_test.py
TensorFlower Gardener dfd659da71 Merge pull request from bhack:partial_decorate
PiperOrigin-RevId: 312132452
Change-Id: I6264d95cb73caf6ddca26567ddad035e94d3ac74
2020-05-18 12:38:27 -07:00

245 lines
9.2 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.kernels.functional_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.eager import def_function
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
# pylint: disable=invalid-name
def simple_scoped_fn(a, x):
  """Returns (a + x) * two, with `two` held in a scoped variable.

  The multiplier is not a Python literal: it is an int32 variable named "two",
  initialized to 2 and created under the "body" variable scope.

  Args:
    a: First addend.
    x: Second addend.

  Returns:
    The result of multiplying `a + x` by the scoped variable.
  """
  with variable_scope.variable_scope("body"):
    # Dummy variable, only here so callers can check variable scoping.
    two_initializer = init_ops.constant_initializer(2)
    scale = variable_scope.get_variable(
        "two", [], dtype=dtypes.int32, initializer=two_initializer)
    total = math_ops.add(a, x)
    return math_ops.multiply(total, scale)
@test_util.with_control_flow_v2
class MapFnTest(test.TestCase):
  """Tests for `map_fn.map_fn`."""

  @test_util.run_in_graph_and_eager_modes
  def testMap_Simple(self):
    """Maps (x + 3) * 2 over a 1-D constant tensor."""
    nums = [1, 2, 3, 4, 5, 6]
    elems = constant_op.constant(nums, name="data")
    r = map_fn.map_fn(
        lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems)
    self.assertAllEqual(
        np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))

  def testMapDtypeEager(self):
    """Checks the explicit `dtype` is honored for empty input in eager mode."""
    with context.eager_mode():
      dtype = map_fn.map_fn(lambda x: constant_op.constant(""),
                            constant_op.constant([]),
                            dtype=dtypes.string).dtype
      self.assertEqual(dtype, dtypes.string)

  def testMapSparseTensor(self):
    """Identity-maps a SparseTensor and compares all three components."""
    with self.cached_session():
      st = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0]],
          values=constant_op.constant([0, 1, 2]),
          dense_shape=[2, 2])
      result = map_fn.map_fn(lambda x: x, st)
      self.assertAllEqual(result.indices, st.indices)
      self.assertAllEqual(result.values, st.values)
      self.assertAllEqual(result.dense_shape, st.dense_shape)

  @test_util.run_in_graph_and_eager_modes
  def testMapOverScalarErrors(self):
    """Expects a ValueError when `elems` is a scalar or a list of scalars."""
    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias.
    with self.assertRaisesRegex(ValueError, "not scalars"):
      map_fn.map_fn(lambda x: x, [1, 2])
    with self.assertRaisesRegex(ValueError, "not a scalar"):
      map_fn.map_fn(lambda x: x, 1)

  @test_util.run_deprecated_v1
  def testMap_Scoped(self):
    """Checks that a variable created in the mapped fn is scoped and reusable."""
    with self.cached_session() as sess:

      def double_scoped(x):
        """2x with a dummy 2 that is scoped."""
        with variable_scope.variable_scope("body"):
          # Dummy variable, just to check that scoping works as intended.
          two = variable_scope.get_variable(
              "two", [],
              dtype=dtypes.int32,
              initializer=init_ops.constant_initializer(2))
          return math_ops.multiply(x, two)

      with variable_scope.variable_scope("root") as varscope:
        nums = [1, 2, 3, 4, 5, 6]
        elems = constant_op.constant(nums, name="data")
        doubles = np.array([2 * x for x in nums])

        r = map_fn.map_fn(double_scoped, elems)
        # Check that we have the one variable we asked for here.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        self.assertAllEqual(doubles, self.evaluate(r))

        # Now let's reuse our single variable.
        varscope.reuse_variables()
        r = map_fn.map_fn(double_scoped, elems)
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertAllEqual(doubles, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testMap_Grad(self):
    """Checks gradients of map_fn output w.r.t. a captured value and `elems`."""
    with self.cached_session():
      param = constant_op.constant(2.0)
      elems = constant_op.constant(
          [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
      y = map_fn.map_fn(
          lambda x: math_ops.multiply(math_ops.square(x), param), elems)
      # d/dparam of sum(x^2 * param) is sum(x^2) = 1 + 4 + ... + 36 = 91.
      r = gradients_impl.gradients(y, param)[0]
      self.assertAllEqual(91.0, self.evaluate(r))
      # d/dx of x^2 * param is 2 * x * param = 4 * x for param == 2.
      r = gradients_impl.gradients(y, elems)[0]
      self.assertAllEqual(
          [4.0, 8.0, 12.0, 16.0, 20.0, 24.0], self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testMap_SimpleNotTensor(self):
    """Maps (x + 3) * 2 over a NumPy array instead of a tensor."""
    nums = np.array([1, 2, 3, 4, 5, 6])
    r = map_fn.map_fn(
        lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums)
    self.assertAllEqual(
        np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testMap_SingleInputMultiOutput(self):
    """Maps one input to a 2-tuple of outputs, with a matching tuple `dtype`."""
    nums = np.array([1, 2, 3, 4, 5, 6])
    r = map_fn.map_fn(
        lambda x: ((x + 3) * 2, -(x + 3) * 2),
        nums,
        dtype=(dtypes.int64, dtypes.int64))
    self.assertEqual(2, len(r))
    self.assertEqual((6,), r[0].get_shape())
    self.assertEqual((6,), r[1].get_shape())
    received = self.evaluate(r)
    self.assertAllEqual((nums + 3) * 2, received[0])
    self.assertAllEqual(-(nums + 3) * 2, received[1])

  @test_util.run_in_graph_and_eager_modes
  def testMap_MultiOutputMismatchedDtype(self):
    """Expects a TypeError when `dtype` nesting differs from the fn output."""
    nums = np.array([1, 2, 3, 4, 5, 6])
    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias.
    with self.assertRaisesRegex(
        TypeError, r"two structures don't have the same nested structure"):
      # lambda emits tuple, but dtype is a list
      map_fn.map_fn(
          lambda x: ((x + 3) * 2, -(x + 3) * 2),
          nums,
          dtype=[dtypes.int64, dtypes.int64])

  @test_util.run_in_graph_and_eager_modes
  def testMap_MultiInputSingleOutput(self):
    """Maps a nested tuple of inputs to a single int64 output."""
    nums = np.array([1, 2, 3, 4, 5, 6])
    r = map_fn.map_fn(
        lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)),
        dtype=dtypes.int64)
    self.assertEqual((6,), r.get_shape())
    received = self.evaluate(r)
    self.assertAllEqual(nums * nums + (-nums), received)

  @test_util.run_in_graph_and_eager_modes
  def testMap_MultiInputSameStructureOutput(self):
    """Maps nested inputs to an output with the same nesting, permuted."""
    nums = np.array([1, 2, 3, 4, 5, 6])
    r = map_fn.map_fn(lambda x: (x[1][0], (x[1][1], x[0])),
                      (nums, (2 * nums, -nums)))
    # Flatten the nested result so each component can be checked by index.
    r = [r[0], r[1][0], r[1][1]]
    self.assertEqual((6,), r[0].get_shape())
    self.assertEqual((6,), r[1].get_shape())
    self.assertEqual((6,), r[2].get_shape())
    received = self.evaluate(r)
    self.assertAllEqual(2 * nums, received[0])
    self.assertAllEqual(-nums, received[1])
    self.assertAllEqual(nums, received[2])

  @test_util.run_in_graph_and_eager_modes
  def testMap_autograph_indirect(self):
    """Calls map_fn inside a `def_function.function` with a branching body."""

    def test_function(x):
      # Both branches return `x` on purpose: the test only needs a Python
      # `if` whose condition is a tensor inside the mapped function.
      cond = constant_op.constant(-1)

      if cond == 0:
        result = x
      else:
        result = x

      return result

    @def_function.function
    def map_call(x):
      return map_fn.map_fn(test_function, x)

    x = constant_op.constant([1])
    y = map_call(x)
    self.assertAllEqual([1], self.evaluate(y))

  @test_util.run_in_graph_and_eager_modes
  def testMapShape(self):
    """Checks the static shape of the result matches the evaluated shape."""
    x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
    y = map_fn.map_fn(lambda e: e, x)
    self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)

  @test_util.run_deprecated_v1
  def testMapUnknownShape(self):
    """Checks a shapeless placeholder input yields an unknown output shape."""
    x = array_ops.placeholder(dtypes.float32)
    y = map_fn.map_fn(lambda e: e, x)
    self.assertIs(None, y.get_shape().dims)

  # TODO(b/124383826): this test fails in eager: the iterable is of length 0,
  # so the body of the while loop never executes.
  @test_util.run_v1_only("b/120545219")
  def testMapEmptyScalar(self):
    """Maps a scalar-returning fn over empty input; expects shape [0]."""
    map_return = map_fn.map_fn(lambda x: 1,
                               constant_op.constant([], dtype=dtypes.int32))
    self.assertAllEqual([0], map_return.get_shape().dims)
    self.assertAllEqual([0], self.evaluate(map_return).shape)

  # TODO(b/124383826): this test fails in eager: the iterable is of length 0,
  # so the body of the while loop never executes.
  @test_util.run_v1_only("b/120545219")
  def testMapEmptyTensor(self):
    """Maps a [3, 2]-returning fn over empty input; expects shape [0, 3, 2]."""
    with self.cached_session():
      map_return = map_fn.map_fn(lambda x: array_ops.zeros([3, 2]),
                                 constant_op.constant([]))
      self.assertAllEqual([0, 3, 2], map_return.get_shape().dims)
      self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape)
if __name__ == "__main__":
  # Run the tests only when this file is executed as a script, using the
  # `test` module imported from tensorflow.python.platform.
  test.main()
# pylint: enable=invalid-name