245 lines
9.2 KiB
Python
245 lines
9.2 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
|
"""Tests for tensorflow.python.ops.map_fn."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
|
|
from tensorflow.python.eager import def_function
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import sparse_tensor
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import gradients_impl
|
|
from tensorflow.python.ops import init_ops
|
|
from tensorflow.python.ops import map_fn
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops import variable_scope
|
|
from tensorflow.python.ops import variables
|
|
import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
|
|
from tensorflow.python.platform import test
|
|
|
|
|
|
# pylint: disable=invalid-name
|
|
def simple_scoped_fn(a, x):
  """Computes (a + x) * two, where `two` is a variable in scope "body".

  Args:
    a: First addend.
    x: Second addend.

  Returns:
    The product of `a + x` and the int32 variable "two", which is
    initialized to the constant 2.
  """
  with variable_scope.variable_scope("body"):
    # The variable carries no meaning beyond letting callers verify that
    # variable scoping behaves as intended.
    factor = variable_scope.get_variable(
        "two", [],
        dtype=dtypes.int32,
        initializer=init_ops.constant_initializer(2))
    total = math_ops.add(a, x)
    return math_ops.multiply(total, factor)
|
|
|
|
|
|
@test_util.with_control_flow_v2
class MapFnTest(test.TestCase):
  """Tests `map_fn.map_fn` on dense, sparse, nested and empty inputs."""

  @test_util.run_in_graph_and_eager_modes
  def testMap_Simple(self):
    """Maps x -> (x + 3) * 2 over a constant integer vector."""
    values = [1, 2, 3, 4, 5, 6]
    data = constant_op.constant(values, name="data")

    def shift_then_double(x):
      return math_ops.multiply(math_ops.add(x, 3), 2)

    mapped = map_fn.map_fn(shift_then_double, data)
    expected = np.array([(v + 3) * 2 for v in values])
    self.assertAllEqual(expected, self.evaluate(mapped))

  def testMapDtypeEager(self):
    """In eager mode the result carries the explicitly requested dtype."""
    with context.eager_mode():
      mapped = map_fn.map_fn(
          lambda x: constant_op.constant(""),
          constant_op.constant([]),
          dtype=dtypes.string)
      result_dtype = mapped.dtype
      self.assertEqual(result_dtype, dtypes.string)

  def testMapSparseTensor(self):
    """An identity map over a SparseTensor keeps all three components."""
    with self.cached_session():
      sparse_input = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0]],
          values=constant_op.constant([0, 1, 2]),
          dense_shape=[2, 2])
      mapped = map_fn.map_fn(lambda x: x, sparse_input)
      self.assertAllEqual(mapped.indices, sparse_input.indices)
      self.assertAllEqual(mapped.values, sparse_input.values)
      self.assertAllEqual(mapped.dense_shape, sparse_input.dense_shape)

  @test_util.run_in_graph_and_eager_modes
  def testMapOverScalarErrors(self):
    """Scalar `elems` (alone or in a list) raise ValueError."""

    def identity(x):
      return x

    # A list whose entries are scalars.
    with self.assertRaisesRegexp(ValueError, "not scalars"):
      map_fn.map_fn(identity, [1, 2])
    # A single bare scalar.
    with self.assertRaisesRegexp(ValueError, "not a scalar"):
      map_fn.map_fn(identity, 1)

  @test_util.run_deprecated_v1
  def testMap_Scoped(self):
    """A variable created by the mapped fn is scoped and can be reused."""
    source = [1, 2, 3, 4, 5, 6]
    with self.cached_session() as sess:

      def double_scoped(x):
        """Returns x * two, where `two` is a variable in scope "body"."""
        with variable_scope.variable_scope("body"):
          # The variable exists only so the test can inspect its scoped name.
          factor = variable_scope.get_variable(
              "two", [],
              dtype=dtypes.int32,
              initializer=init_ops.constant_initializer(2))
          return math_ops.multiply(x, factor)

      with variable_scope.variable_scope("root") as root_scope:
        data = constant_op.constant(source, name="data")
        expected = np.array([2 * v for v in source])

        first = map_fn.map_fn(double_scoped, data)
        # Exactly one trainable variable must exist, named after both scopes.
        trainable = variables.trainable_variables()
        self.assertEqual(len(trainable), 1)
        self.assertEqual(trainable[0].name, "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        self.assertAllEqual(expected, self.evaluate(first))

        # Map a second time with reuse switched on: the variable count must
        # stay at one and the result must be unchanged.
        root_scope.reuse_variables()
        second = map_fn.map_fn(double_scoped, data)
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertAllEqual(expected, self.evaluate(second))

  @test_util.run_deprecated_v1
  def testMap_Grad(self):
    """Gradients flow through map_fn to a captured tensor and to `elems`."""
    with self.cached_session():
      scale = constant_op.constant(2.0)
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")

      def scaled_square(x):
        return math_ops.multiply(math_ops.square(x), scale)

      y = map_fn.map_fn(scaled_square, elems)

      # With respect to the captured scale: 1 + 4 + 9 + 16 + 25 + 36 = 91.
      grad_scale = gradients_impl.gradients(y, scale)[0]
      self.assertAllEqual(91.0, self.evaluate(grad_scale))

      # With respect to each element: 2 * scale * x = 4 * x.
      grad_elems = gradients_impl.gradients(y, elems)[0]
      self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0],
                          self.evaluate(grad_elems))

  @test_util.run_in_graph_and_eager_modes
  def testMap_SimpleNotTensor(self):
    """`elems` may be a NumPy array rather than a tensor."""
    nums = np.array([1, 2, 3, 4, 5, 6])

    def shift_then_double(x):
      return math_ops.multiply(math_ops.add(x, 3), 2)

    mapped = map_fn.map_fn(shift_then_double, nums)
    expected = np.array([(v + 3) * 2 for v in nums])
    self.assertAllEqual(expected, self.evaluate(mapped))

  @test_util.run_in_graph_and_eager_modes
  def testMap_SingleInputMultiOutput(self):
    """One input mapped to a 2-tuple of outputs, with a tuple `dtype`."""
    nums = np.array([1, 2, 3, 4, 5, 6])

    def plus_and_minus(x):
      return ((x + 3) * 2, -(x + 3) * 2)

    outputs = map_fn.map_fn(
        plus_and_minus, nums, dtype=(dtypes.int64, dtypes.int64))
    self.assertEqual(2, len(outputs))
    for index in (0, 1):
      self.assertEqual((6,), outputs[index].get_shape())
    positive, negative = self.evaluate(outputs)
    self.assertAllEqual((nums + 3) * 2, positive)
    self.assertAllEqual(-(nums + 3) * 2, negative)

  @test_util.run_in_graph_and_eager_modes
  def testMap_MultiOutputMismatchedDtype(self):
    """A tuple-returning fn with a list `dtype` raises TypeError."""
    nums = np.array([1, 2, 3, 4, 5, 6])
    expected_message = r"two structures don't have the same nested structure"
    with self.assertRaisesRegexp(TypeError, expected_message):
      # The fn produces a tuple while `dtype` is given as a list.
      map_fn.map_fn(
          lambda x: ((x + 3) * 2, -(x + 3) * 2),
          nums,
          dtype=[dtypes.int64, dtypes.int64])

  @test_util.run_in_graph_and_eager_modes
  def testMap_MultiInputSingleOutput(self):
    """A nested tuple of inputs is reduced to a single output tensor."""
    nums = np.array([1, 2, 3, 4, 5, 6])
    nested_elems = (nums, (nums, -nums))

    def combine(x):
      return x[0] * x[1][0] + x[1][1]

    mapped = map_fn.map_fn(combine, nested_elems, dtype=dtypes.int64)
    self.assertEqual((6,), mapped.get_shape())
    actual = self.evaluate(mapped)
    self.assertAllEqual(nums * nums + (-nums), actual)

  @test_util.run_in_graph_and_eager_modes
  def testMap_MultiInputSameStructureOutput(self):
    """Output reuses the input's nesting, so no `dtype` is passed."""
    nums = np.array([1, 2, 3, 4, 5, 6])
    nested_elems = (nums, (2 * nums, -nums))

    def rotate(x):
      return (x[1][0], (x[1][1], x[0]))

    nested_out = map_fn.map_fn(rotate, nested_elems)
    flat_out = [nested_out[0], nested_out[1][0], nested_out[1][1]]
    for tensor in flat_out:
      self.assertEqual((6,), tensor.get_shape())
    actual = self.evaluate(flat_out)
    self.assertAllEqual(2 * nums, actual[0])
    self.assertAllEqual(-nums, actual[1])
    self.assertAllEqual(nums, actual[2])

  @test_util.run_in_graph_and_eager_modes
  def testMap_autograph_indirect(self):
    """Calls map_fn inside a `def_function.function` with a tensor `if`."""

    def branch_on_tensor(x):
      # Both branches yield `x`; what matters is that the condition is a
      # tensor (presumably to exercise autograph, as the test name suggests).
      flag = constant_op.constant(-1)
      if flag == 0:
        result = x
      else:
        result = x
      return result

    @def_function.function
    def mapped_call(x):
      return map_fn.map_fn(branch_on_tensor, x)

    inputs = constant_op.constant([1])
    outputs = mapped_call(inputs)
    self.assertAllEqual([1], self.evaluate(outputs))

  @test_util.run_in_graph_and_eager_modes
  def testMapShape(self):
    """The static shape of the result equals the evaluated shape."""
    matrix = constant_op.constant([[1, 2, 3], [4, 5, 6]])
    mapped = map_fn.map_fn(lambda e: e, matrix)
    static_shape = mapped.get_shape()
    self.assertAllEqual(static_shape, self.evaluate(mapped).shape)

  @test_util.run_deprecated_v1
  def testMapUnknownShape(self):
    """A shapeless placeholder input yields a result with unknown dims."""
    unknown = array_ops.placeholder(dtypes.float32)
    mapped = map_fn.map_fn(lambda e: e, unknown)
    self.assertIs(None, mapped.get_shape().dims)

  # TODO(b/124383826): this test fails in eager: the iterable is of length 0,
  # so the body of the while loop never executes.
  @test_util.run_v1_only("b/120545219")
  def testMapEmptyScalar(self):
    """Mapping a scalar-valued fn over empty input gives shape [0]."""
    empty = constant_op.constant([], dtype=dtypes.int32)
    mapped = map_fn.map_fn(lambda x: 1, empty)
    self.assertAllEqual([0], mapped.get_shape().dims)
    self.assertAllEqual([0], self.evaluate(mapped).shape)

  # TODO(b/124383826): this test fails in eager: the iterable is of length 0,
  # so the body of the while loop never executes.
  @test_util.run_v1_only("b/120545219")
  def testMapEmptyTensor(self):
    """Mapping a [3, 2]-valued fn over empty input gives shape [0, 3, 2]."""
    with self.cached_session():
      empty = constant_op.constant([])
      mapped = map_fn.map_fn(lambda x: array_ops.zeros([3, 2]), empty)
      self.assertAllEqual([0, 3, 2], mapped.get_shape().dims)
      self.assertAllEqual([0, 3, 2], self.evaluate(mapped).shape)
|
|
|
|
|
|
if __name__ == "__main__":
  # Run the tests in this module when the file is executed as a script.
  test.main()
|
|
|
|
# pylint: enable=invalid-name
|