Improve shape inference for tf.eye.

PiperOrigin-RevId: 216550243
This commit is contained in:
A. Unique TensorFlower 2018-10-10 10:16:38 -07:00 committed by TensorFlower Gardener
parent e09ddb4290
commit c602fc061a
3 changed files with 133 additions and 60 deletions

View File

@ -1785,6 +1785,7 @@ cuda_py_test(
size = "medium",
srcs = ["linalg_ops_test.py"],
additional_deps = [
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",

View File

@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import dtypes
@ -52,7 +55,7 @@ class CholeskySolveTest(test.TestCase):
def test_works_with_five_different_random_pos_def_matrices(self):
for n in range(1, 6):
for np_type, atol in [(np.float32, 0.05), (np.float64, 1e-5)]:
with self.test_session(use_gpu=True):
with self.session(use_gpu=True):
# Create 2 x n x n matrix
array = np.array(
[_RandomPDMatrix(n, self.rng),
@ -76,7 +79,7 @@ class LogdetTest(test.TestCase):
(np.complex64, 0.05), (np.complex128, 1e-5)]:
matrix = _RandomPDMatrix(n, self.rng, np_dtype)
_, logdet_np = np.linalg.slogdet(matrix)
with self.test_session(use_gpu=True):
with self.session(use_gpu=True):
# Create 2 x n x n matrix
# matrix = np.array(
# [_RandomPDMatrix(n, self.rng, np_dtype),
@ -89,7 +92,7 @@ class LogdetTest(test.TestCase):
(np.complex64, 0.05), (np.complex128, 1e-5)]:
matrix = (np.eye(20) * 1e-6).astype(np_dtype)
_, logdet_np = np.linalg.slogdet(matrix)
with self.test_session(use_gpu=True):
with self.session(use_gpu=True):
logdet_tf = linalg.logdet(matrix)
self.assertAllClose(logdet_np, logdet_tf.eval(), atol=atol)
@ -105,7 +108,7 @@ class SlogdetTest(test.TestCase):
(np.complex64, 0.05), (np.complex128, 1e-5)]:
matrix = _RandomPDMatrix(n, self.rng, np_dtype)
sign_np, log_abs_det_np = np.linalg.slogdet(matrix)
with self.test_session(use_gpu=True):
with self.session(use_gpu=True):
sign_tf, log_abs_det_tf = linalg.slogdet(matrix)
self.assertAllClose(log_abs_det_np, log_abs_det_tf.eval(), atol=atol)
self.assertAllClose(sign_np, sign_tf.eval(), atol=atol)
@ -115,7 +118,7 @@ class SlogdetTest(test.TestCase):
(np.complex64, 0.05), (np.complex128, 1e-5)]:
matrix = (np.eye(20) * 1e-6).astype(np_dtype)
sign_np, log_abs_det_np = np.linalg.slogdet(matrix)
with self.test_session(use_gpu=True):
with self.session(use_gpu=True):
sign_tf, log_abs_det_tf = linalg.slogdet(matrix)
self.assertAllClose(log_abs_det_np, log_abs_det_tf.eval(), atol=atol)
self.assertAllClose(sign_np, sign_tf.eval(), atol=atol)
@ -128,66 +131,126 @@ class AdjointTest(test.TestCase):
matrix_np = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j,
6 + 6j]]).astype(dtype)
expected_transposed = np.conj(matrix_np.T)
with self.cached_session():
with self.session():
matrix = ops.convert_to_tensor(matrix_np)
transposed = linalg.adjoint(matrix)
self.assertEqual((3, 2), transposed.get_shape())
self.assertAllEqual(expected_transposed, transposed.eval())
class EyeTest(test.TestCase):
pass # Will be filled in below
class EyeTest(parameterized.TestCase, test.TestCase):
def testShapeInferenceNoBatch(self):
  """Checks the static shape reported by `eye` when no batch shape is given."""
  # Only num_rows supplied: the result is reported as 2 x 2.
  square = linalg_ops.eye(num_rows=2)
  self.assertEqual((2, 2), square.shape)

  # Explicit num_columns: the result is reported as 2 x 3.
  rectangular = linalg_ops.eye(num_rows=2, num_columns=3)
  self.assertEqual((2, 3), rectangular.shape)
def _GetEyeTest(num_rows, num_columns, batch_shape, dtype):
def testShapeInferenceStaticBatch(self):
  """Checks the static shape reported by `eye` for a Python-tuple batch shape."""
  batch_shape = (2, 3)

  # The batch dimensions are expected to lead the matrix dimensions.
  square = linalg_ops.eye(num_rows=2, batch_shape=batch_shape)
  self.assertEqual((2, 3, 2, 2), square.shape)

  rectangular = linalg_ops.eye(
      num_rows=2, num_columns=3, batch_shape=batch_shape)
  self.assertEqual((2, 3, 2, 3), rectangular.shape)
def Test(self):
@parameterized.named_parameters(
    ("DynamicRow", array_ops.placeholder_with_default(2, shape=None), None),
    ("DynamicRowStaticColumn",
     array_ops.placeholder_with_default(2, shape=None),
     3),
    ("StaticRowDynamicColumn",
     2,
     array_ops.placeholder_with_default(3, shape=None)),
    ("DynamicRowDynamicColumn",
     array_ops.placeholder_with_default(2, shape=None),
     array_ops.placeholder_with_default(3, shape=None)))
def testShapeInferenceStaticBatchWith(self, num_rows, num_columns):
  """Checks shape inference when the batch is static but rows/columns may not be.

  Args:
    num_rows: A Python int or a tensor giving the number of rows.
    num_columns: A Python int, a tensor, or None, giving the number of columns.
  """
  inferred_shape = linalg_ops.eye(
      num_rows=num_rows,
      num_columns=num_columns,
      batch_shape=(2, 3)).shape

  # The rank and the leading batch dimensions are asserted in every case.
  self.assertEqual(4, inferred_shape.ndims)
  self.assertEqual((2, 3), inferred_shape[:2])

  # A matrix dimension is asserted only when it was passed as a plain Python
  # value rather than as a tensor (or omitted).
  rows_are_static = not (num_rows is None or isinstance(num_rows, ops.Tensor))
  if rows_are_static:
    self.assertEqual(2, inferred_shape[-2])

  columns_are_static = not (
      num_columns is None or isinstance(num_columns, ops.Tensor))
  if columns_are_static:
    self.assertEqual(3, inferred_shape[-1])
# Compares `linalg_ops.eye` with a NumPy reference for every combination of the
# row counts, column counts, batch shapes and dtypes enumerated below.
@parameterized.parameters(
itertools.product(
# num_rows
[0, 1, 2, 5],
# num_columns
[None, 0, 1, 2, 5],
# batch_shape
[None, [], [2], [2, 3]],
# dtype
[
dtypes.int32,
dtypes.int64,
dtypes.float32,
dtypes.float64,
dtypes.complex64,
dtypes.complex128
])
)
def test_eye_no_placeholder(self, num_rows, num_columns, batch_shape, dtype):
# NumPy reference value; with M=None, np.eye produces a square matrix.
eye_np = np.eye(num_rows, M=num_columns, dtype=dtype.as_numpy_dtype)
if batch_shape is not None:
# Replicate the reference across the batch dimensions; the trailing [1, 1]
# leaves the matrix dimensions themselves untiled.
eye_np = np.tile(eye_np, batch_shape + [1, 1])
# NOTE(review): the placeholder loop below and the `self.evaluate` call after
# it both compute `eye_tf`, and the trailing `return Test` refers to a name not
# defined in this method. This looks like two revisions of the body shown
# together; confirm which lines are live.
for use_placeholder in False, True:
# Stops before the placeholder variant when num_columns or batch_shape is None.
if use_placeholder and (num_columns is None or batch_shape is None):
return
with self.test_session(use_gpu=True) as sess:
if use_placeholder:
num_rows_placeholder = array_ops.placeholder(
dtypes.int32, name="num_rows")
num_columns_placeholder = array_ops.placeholder(
dtypes.int32, name="num_columns")
batch_shape_placeholder = array_ops.placeholder(
dtypes.int32, name="batch_shape")
eye = linalg_ops.eye(
num_rows_placeholder,
num_columns=num_columns_placeholder,
batch_shape=batch_shape_placeholder,
dtype=dtype)
# The concrete sizes are supplied only at run time through the feed dict.
eye_tf = sess.run(
eye,
feed_dict={
num_rows_placeholder: num_rows,
num_columns_placeholder: num_columns,
batch_shape_placeholder: batch_shape
})
else:
eye_tf = linalg_ops.eye(
num_rows,
num_columns=num_columns,
batch_shape=batch_shape,
dtype=dtype).eval()
self.assertAllEqual(eye_np, eye_tf)
# Variant that passes the Python values directly and evaluates the result.
eye_tf = self.evaluate(linalg_ops.eye(
num_rows,
num_columns=num_columns,
batch_shape=batch_shape,
dtype=dtype))
self.assertAllEqual(eye_np, eye_tf)
return Test
@parameterized.parameters(
    itertools.product(
        # num_rows
        [0, 1, 2, 5],
        # num_columns
        [0, 1, 2, 5],
        # batch_shape
        [[], [2], [2, 3]],
        # dtype
        [
            dtypes.int32,
            dtypes.int64,
            dtypes.float32,
            dtypes.float64,
            dtypes.complex64,
            dtypes.complex128
        ])
)
def test_eye_with_placeholder(
    self, num_rows, num_columns, batch_shape, dtype):
  """Checks `eye` when every size argument is fed through a placeholder.

  Args:
    num_rows: Python int, number of rows of each identity matrix.
    num_columns: Python int, number of columns of each identity matrix.
    batch_shape: Python list of ints giving the leading batch dimensions.
    dtype: Element dtype of the result.
  """
  # NumPy reference: one identity matrix repeated over the batch dimensions
  # (the trailing [1, 1] leaves the matrix dimensions untiled).
  expected = np.tile(
      np.eye(num_rows, M=num_columns, dtype=dtype.as_numpy_dtype),
      batch_shape + [1, 1])

  rows_ph = array_ops.placeholder(dtypes.int32, name="num_rows")
  columns_ph = array_ops.placeholder(dtypes.int32, name="num_columns")
  batch_ph = array_ops.placeholder(dtypes.int32, name="batch_shape")
  eye_op = linalg_ops.eye(
      rows_ph, num_columns=columns_ph, batch_shape=batch_ph, dtype=dtype)

  # The concrete sizes are supplied only at run time.
  feed = {
      rows_ph: num_rows,
      columns_ph: num_columns,
      batch_ph: batch_shape,
  }
  with self.session(use_gpu=True) as sess:
    actual = sess.run(eye_op, feed_dict=feed)
  self.assertAllEqual(expected, actual)
# Script entry point.
if __name__ == "__main__":
# Builds one named test per (num_rows, num_columns, batch_shape, dtype)
# combination via `_GetEyeTest`; presumably `_AddTest` attaches it to `EyeTest`
# under `name`. `_AddTest` is not visible in this view, so confirm.
for _num_rows in 0, 1, 2, 5:
for _num_columns in None, 0, 1, 2, 5:
for _batch_shape in None, [], [2], [2, 3]:
for _dtype in (dtypes.int32, dtypes.int64, dtypes.float32,
dtypes.float64, dtypes.complex64, dtypes.complex128):
name = "dtype_%s_num_rows_%s_num_column_%s_batch_shape_%s_" % (
_dtype.name, _num_rows, _num_columns, _batch_shape)
_AddTest(EyeTest, "EyeTest", name,
_GetEyeTest(_num_rows, _num_columns, _batch_shape, _dtype))
# Runs the tests defined in this module.
test.main()

View File

@ -44,22 +44,31 @@ def eye(num_rows,
is_square = num_columns is None
batch_shape = [] if batch_shape is None else batch_shape
num_columns = num_rows if num_columns is None else num_columns
if isinstance(num_rows, ops.Tensor) or isinstance(
num_columns, ops.Tensor) or isinstance(batch_shape, ops.Tensor):
batch_shape = ops.convert_to_tensor(
batch_shape, name='shape', dtype=dtypes.int32)
# We cannot statically infer what the diagonal size should be:
if (isinstance(num_rows, ops.Tensor) or
isinstance(num_columns, ops.Tensor)):
diag_size = math_ops.minimum(num_rows, num_columns)
diag_shape = array_ops.concat((batch_shape, [diag_size]), 0)
if not is_square:
shape = array_ops.concat((batch_shape, [num_rows, num_columns]), 0)
else:
# We can statically infer the diagonal size, and whether it is square.
if not isinstance(num_rows, compat.integral_types) or not isinstance(
num_columns, compat.integral_types):
raise TypeError(
'num_rows and num_columns must be positive integer values.')
batch_shape = [dim for dim in batch_shape]
is_square = num_rows == num_columns
diag_shape = batch_shape + [np.minimum(num_rows, num_columns)]
diag_size = np.minimum(num_rows, num_columns)
# We cannot statically infer the shape of the tensor.
if isinstance(batch_shape, ops.Tensor) or isinstance(diag_size, ops.Tensor):
batch_shape = ops.convert_to_tensor(
batch_shape, name='shape', dtype=dtypes.int32)
diag_shape = array_ops.concat((batch_shape, [diag_size]), axis=0)
if not is_square:
shape = array_ops.concat((batch_shape, [num_rows, num_columns]), axis=0)
# We can statically infer everything.
else:
batch_shape = list(batch_shape)
diag_shape = batch_shape + [diag_size]
if not is_square:
shape = batch_shape + [num_rows, num_columns]