STT-tensorflow/tensorflow/python/kernel_tests/aggregate_ops_test.py
Gaurav Jain 24f578cd66 Add @run_deprecated_v1 annotation to tests failing in v2
PiperOrigin-RevId: 223422907
2018-11-29 15:43:25 -08:00

125 lines
5.0 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for aggregate_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.core.framework import tensor_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class AddNTest(test.TestCase):
  """Tests for `math_ops.add_n` on dense tensors, placeholders and variants."""

  # AddN special-cases adding the first M inputs to make (N - M) divisible by 8,
  # after which it adds the remaining (N - M) tensors 8 at a time in a loop.
  # Test N in [1, 10] so we check each special-case from 1 to 9 and one
  # iteration of the loop.
  _MAX_N = 10

  def _supported_types(self):
    """Returns the dtypes under test, depending on GPU availability."""
    if test.is_gpu_available():
      return [
          dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64,
          dtypes.complex128, dtypes.int64
      ]
    return [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
            dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64,
            dtypes.complex128]

  def _tolerance(self, dtype):
    """Returns the rtol/atol used when comparing results of type `dtype`.

    float16 gets a looser bound (5e-3); every other dtype uses 5e-7.
    """
    return 5e-3 if dtype == dtypes.float16 else 5e-7

  def _buildData(self, shape, dtype):
    """Returns a random NumPy array of `shape` cast to `dtype`'s NumPy type.

    Args:
      shape: Sequence of ints, unpacked into `np.random.randn`.
      dtype: A TensorFlow `DType`; its `as_numpy_dtype` is the cast target.

    Returns:
      The generated array. For complex dtypes the imaginary part is ten times
      the real part.
    """
    # NOTE: randn samples a standard normal, so for integer dtypes the cast
    # leaves only small integers (many of them zero).
    data = np.random.randn(*shape).astype(dtype.as_numpy_dtype)
    # For complex types, add a value-dependent imaginary component (10x the
    # real part) so we can tell we got the right value.
    if dtype.is_complex:
      return data + 10j * data
    return data

  def testAddN(self):
    """add_n over 1.._MAX_N random (2, 2) inputs matches a NumPy sum."""
    np.random.seed(12345)
    # Results are fetched through self.evaluate, so the session object itself
    # is not bound to a name.
    with self.session(use_gpu=True):
      for dtype in self._supported_types():
        # The tolerance depends only on the dtype; compute it once per dtype.
        tol = self._tolerance(dtype)
        for count in range(1, self._MAX_N + 1):
          data = [self._buildData((2, 2), dtype) for _ in range(count)]
          actual = self.evaluate(math_ops.add_n(data))
          expected = np.sum(np.vstack(
              [np.expand_dims(d, 0) for d in data]), axis=0)
          self.assertAllClose(expected, actual, rtol=tol, atol=tol)

  @test_util.run_deprecated_v1
  def testUnknownShapes(self):
    """add_n over a shape-less placeholder repeated 1.._MAX_N times."""
    np.random.seed(12345)
    with self.session(use_gpu=True) as sess:
      for dtype in self._supported_types():
        data = self._buildData((2, 2), dtype)
        tol = self._tolerance(dtype)
        for count in range(1, self._MAX_N + 1):
          # No shape is given, so add_n sees inputs of unknown shape at graph
          # construction time; the (2, 2) data is only supplied via the feed.
          data_ph = array_ops.placeholder(dtype=dtype)
          actual = sess.run(math_ops.add_n([data_ph] * count), {data_ph: data})
          expected = np.sum(np.vstack([np.expand_dims(data, 0)] * count),
                            axis=0)
          self.assertAllClose(expected, actual, rtol=tol, atol=tol)

  @test_util.run_deprecated_v1
  def testVariant(self):
    """Smoke test: add_n over four scalar variant constants runs."""

    def create_constant_variant(value):
      """Returns a scalar variant constant holding `value` as int32 bytes."""
      return constant_op.constant(
          tensor_pb2.TensorProto(
              dtype=dtypes.variant.as_datatype_enum,
              tensor_shape=tensor_shape.TensorShape([]).as_proto(),
              variant_val=[
                  tensor_pb2.VariantTensorDataProto(
                      # Match registration in variant_op_registry.cc
                      type_name=b"int",
                      metadata=np.array(value, dtype=np.int32).tobytes())
              ]))

    # TODO(ebrevdo): Re-enable use_gpu=True once non-DMA Variant
    # copying between CPU and GPU is supported.
    with self.session(use_gpu=False):
      variant_const_3 = create_constant_variant(3)
      variant_const_4 = create_constant_variant(4)
      variant_const_5 = create_constant_variant(5)
      # 3 + 3 + 5 + 4 = 15.
      result = math_ops.add_n((variant_const_3, variant_const_3,
                               variant_const_5, variant_const_4))

      # Smoke test -- ensure this executes without trouble.
      # Right now, non-numpy-compatible objects cannot be returned from a
      # session.run call; similarly, objects that can't be converted to
      # native numpy types cannot be passed to ops.convert_to_tensor.
      # For now, run the test and examine the output to see that the result is
      # equal to 15.
      result_op = logging_ops.Print(
          result, [variant_const_3, variant_const_4, variant_const_5, result],
          message=("Variants stored an int: c(3), c(4), c(5), "
                   "add_n(c(3), c(3), c(5), c(4)): ")).op
      result_op.run()
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
  test.main()