STT-tensorflow/tensorflow/python/framework/common_shapes_test.py
Sergei Lebedev bca5e7385f Inlined tensor_shape.{scalar,vector,matrix}
An explicit constructor call is no less clear and matches what we export via
the public API.

The functions will be removed once all the internal users are migrated.

PiperOrigin-RevId: 259620054
2019-07-23 15:33:44 -07:00

210 lines
9.3 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for common shapes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
class CommonShapesTest(test_util.TensorFlowTestCase):
  """Tests for `common_shapes.broadcast_shape` and `is_broadcast_compatible`."""

  def _assert_incompatible_broadcast(self, shape1, shape2):
    """Checks that `shape1` and `shape2` are rejected as a broadcast pair.

    When `dims` is known for both shapes, NumPy serves as a reference and must
    raise `ValueError` for the pair as well. Every check is repeated with the
    arguments swapped, because the outcome must be symmetric.

    Args:
      shape1: A `TensorShape`.
      shape2: A `TensorShape`.
    """
    shape_pairs = ((shape1, shape2), (shape2, shape1))
    if not (shape1.dims is None or shape2.dims is None):
      lhs = np.zeros(shape1.as_list())
      rhs = np.zeros(shape2.as_list())
      for first, second in ((lhs, rhs), (rhs, lhs)):
        with self.assertRaises(ValueError):
          np.broadcast(first, second)
    for first, second in shape_pairs:
      self.assertFalse(common_shapes.is_broadcast_compatible(first, second))
    for first, second in shape_pairs:
      with self.assertRaises(ValueError):
        common_shapes.broadcast_shape(first, second)

  def _assert_broadcast(self, expected, shape1, shape2):
    """Checks that broadcasting `shape1` with `shape2` yields `expected`.

    When `dims` is known for both inputs, the expectation is first
    cross-checked against `np.broadcast`. The `common_shapes.broadcast_shape`
    result is then compared with `expected` for both argument orders, because
    the outcome must be symmetric.

    Args:
      expected: The `TensorShape` the broadcast should produce.
      shape1: A `TensorShape`.
      shape2: A `TensorShape`.
    """
    both_known = shape1.dims is not None and shape2.dims is not None
    if both_known:
      expected_list = expected.as_list()
      lhs = np.zeros(shape1.as_list())
      rhs = np.zeros(shape2.as_list())
      for first, second in ((lhs, rhs), (rhs, lhs)):
        self.assertAllEqual(expected_list, np.broadcast(first, second).shape)
    for first, second in ((shape1, shape2), (shape2, shape1)):
      self.assertEqual(expected, common_shapes.broadcast_shape(first, second))

  def testBroadcast_one_dimension(self):
    # Covers rank-0 and rank-1 shapes plus the fully unknown shape.
    s1 = tensor_shape.TensorShape([5])
    s2 = tensor_shape.TensorShape([7])
    unknown = tensor_shape.unknown_shape()
    scalar = tensor_shape.TensorShape([])
    expanded_scalar = tensor_shape.TensorShape([1])

    # Broadcasting a shape with itself must give that shape back.
    for shape in (s1, s2, unknown, scalar, expanded_scalar):
      self._assert_broadcast(expected=shape, shape1=shape, shape2=shape)

    # Both [] and [1] leave the other operand's shape untouched.
    for identity in (scalar, expanded_scalar):
      for vector in (s1, s2):
        self._assert_broadcast(
            expected=vector, shape1=vector, shape2=identity)

    # An unknown operand makes the result unknown.
    for vector in (s1, s2):
      self._assert_broadcast(expected=unknown, shape1=vector, shape2=unknown)

    self._assert_broadcast(
        expected=expanded_scalar, shape1=scalar, shape2=expanded_scalar)

    self._assert_incompatible_broadcast(shape1=s1, shape2=s2)

  def testBroadcast_many_dimensions(self):
    # Covers fully known shapes of rank 0 through 2.
    unknown = tensor_shape.unknown_shape()
    shape_0 = tensor_shape.TensorShape([])
    shape_1 = tensor_shape.TensorShape([1])
    shape_4 = tensor_shape.TensorShape([4])
    shape_1x4 = tensor_shape.TensorShape([1, 4])
    shape_4x1 = tensor_shape.TensorShape([4, 1])
    shape_3x4 = tensor_shape.TensorShape([3, 4])
    shape_4x3 = tensor_shape.TensorShape([4, 3])
    shape_4x4 = tensor_shape.TensorShape([4, 4])
    non_identity = (shape_4, shape_1x4, shape_4x1, shape_3x4, shape_4x3)

    # Broadcasting a shape with itself must give that shape back.
    for shape in (shape_0, shape_1) + non_identity:
      self._assert_broadcast(expected=shape, shape1=shape, shape2=shape)

    # Both [] and [1] leave the other operand's shape untouched.
    for identity in (shape_0, shape_1):
      for shape in non_identity:
        self._assert_broadcast(expected=shape, shape1=identity, shape2=shape)

    # An unknown operand makes the result unknown.
    for shape in non_identity:
      self._assert_broadcast(expected=unknown, shape1=shape, shape2=unknown)

    # Rows are (expected, shape1, shape2); an expected value of None marks a
    # pair that must be rejected as incompatible.
    cases = (
        (shape_1x4, shape_4, shape_1x4),
        (shape_4x4, shape_4, shape_4x1),
        (shape_3x4, shape_4, shape_3x4),
        (None, shape_4, shape_4x3),
        (shape_4x4, shape_1x4, shape_4x1),
        (shape_3x4, shape_1x4, shape_3x4),
        (None, shape_1x4, shape_4x3),
        (None, shape_4x1, shape_3x4),
        (shape_4x3, shape_4x1, shape_4x3),
        (None, shape_3x4, shape_4x3),
    )
    for want, lhs, rhs in cases:
      if want is None:
        self._assert_incompatible_broadcast(shape1=lhs, shape2=rhs)
      else:
        self._assert_broadcast(expected=want, shape1=lhs, shape2=rhs)

  def _assert_same_dims(self, first_dims, second_dims):
    """Checks that two `dims` values agree.

    Either both are None, or both have the same length and the same `.value`
    at every position.

    Args:
      first_dims: A `dims` list, or None.
      second_dims: A `dims` list, or None.
    """
    if first_dims is None:
      self.assertIsNone(second_dims)
    elif second_dims is None:
      self.assertIsNone(first_dims)
    else:
      self.assertEqual(len(first_dims), len(second_dims))
      for first_dim, second_dim in zip(first_dims, second_dims):
        self.assertEqual(first_dim.value, second_dim.value)

  def _assert_broadcast_with_unknown_dims(self, expected, shape1, shape2):
    """Checks a broadcast result by comparing dimension values one by one.

    The result for (shape1, shape2) is first compared with the result for
    (shape2, shape1), because the outcome must be symmetric. It is then
    compared with `expected`.

    Args:
      expected: The `TensorShape` the broadcast should produce.
      shape1: A `TensorShape`.
      shape2: A `TensorShape`.
    """
    forward_dims = common_shapes.broadcast_shape(shape1, shape2).dims
    swapped_dims = common_shapes.broadcast_shape(shape2, shape1).dims
    self._assert_same_dims(forward_dims, swapped_dims)
    self._assert_same_dims(expected.dims, forward_dims)

  def testBroadcast_unknown_dims(self):
    # Covers shapes in which individual dimensions are unknown (None).
    unknown = tensor_shape.unknown_shape()
    shape_0 = tensor_shape.TensorShape([])
    shape_1 = tensor_shape.TensorShape([1])
    shape_4x4 = tensor_shape.TensorShape([4, 4])
    # pylint: disable=invalid-name
    shape_U = tensor_shape.TensorShape([None])
    shape_1xU = tensor_shape.TensorShape([1, None])
    shape_Ux1 = tensor_shape.TensorShape([None, 1])
    shape_4xU = tensor_shape.TensorShape([4, None])
    shape_Ux4 = tensor_shape.TensorShape([None, 4])
    shape_UxU = tensor_shape.TensorShape([None, None])
    # pylint: enable=invalid-name
    partial_shapes = (shape_U, shape_1xU, shape_Ux1, shape_4xU, shape_Ux4)

    # Broadcasting a shape with itself must give that shape back.
    for shape in partial_shapes:
      self._assert_broadcast_with_unknown_dims(
          expected=shape, shape1=shape, shape2=shape)

    # Both [] and [1] leave the other operand's shape untouched.
    for identity in (shape_0, shape_1):
      for shape in partial_shapes:
        self._assert_broadcast_with_unknown_dims(
            expected=shape, shape1=identity, shape2=shape)

    # An unknown operand makes the result unknown.
    for shape in partial_shapes:
      self._assert_broadcast_with_unknown_dims(
          expected=unknown, shape1=shape, shape2=unknown)

    # Rows are (expected, shape1, shape2).
    cases = (
        (shape_1xU, shape_U, shape_1xU),
        (shape_UxU, shape_U, shape_Ux1),
        (shape_4xU, shape_U, shape_4xU),
        (shape_Ux4, shape_U, shape_Ux4),
        (shape_UxU, shape_1xU, shape_Ux1),
        (shape_4xU, shape_1xU, shape_4xU),
        (shape_Ux4, shape_1xU, shape_Ux4),
        (shape_4xU, shape_Ux1, shape_4xU),
        (shape_Ux4, shape_Ux1, shape_Ux4),
        (shape_4x4, shape_4xU, shape_Ux4),
    )
    for want, lhs, rhs in cases:
      self._assert_broadcast_with_unknown_dims(
          expected=want, shape1=lhs, shape2=rhs)
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
  googletest.main()