Improve shape inference for tf.eye.
PiperOrigin-RevId: 216550243
This commit is contained in:
parent
e09ddb4290
commit
c602fc061a
@ -1785,6 +1785,7 @@ cuda_py_test(
|
||||
size = "medium",
|
||||
srcs = ["linalg_ops_test.py"],
|
||||
additional_deps = [
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -18,6 +18,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -52,7 +55,7 @@ class CholeskySolveTest(test.TestCase):
|
||||
def test_works_with_five_different_random_pos_def_matrices(self):
|
||||
for n in range(1, 6):
|
||||
for np_type, atol in [(np.float32, 0.05), (np.float64, 1e-5)]:
|
||||
with self.test_session(use_gpu=True):
|
||||
with self.session(use_gpu=True):
|
||||
# Create 2 x n x n matrix
|
||||
array = np.array(
|
||||
[_RandomPDMatrix(n, self.rng),
|
||||
@ -76,7 +79,7 @@ class LogdetTest(test.TestCase):
|
||||
(np.complex64, 0.05), (np.complex128, 1e-5)]:
|
||||
matrix = _RandomPDMatrix(n, self.rng, np_dtype)
|
||||
_, logdet_np = np.linalg.slogdet(matrix)
|
||||
with self.test_session(use_gpu=True):
|
||||
with self.session(use_gpu=True):
|
||||
# Create 2 x n x n matrix
|
||||
# matrix = np.array(
|
||||
# [_RandomPDMatrix(n, self.rng, np_dtype),
|
||||
@ -89,7 +92,7 @@ class LogdetTest(test.TestCase):
|
||||
(np.complex64, 0.05), (np.complex128, 1e-5)]:
|
||||
matrix = (np.eye(20) * 1e-6).astype(np_dtype)
|
||||
_, logdet_np = np.linalg.slogdet(matrix)
|
||||
with self.test_session(use_gpu=True):
|
||||
with self.session(use_gpu=True):
|
||||
logdet_tf = linalg.logdet(matrix)
|
||||
self.assertAllClose(logdet_np, logdet_tf.eval(), atol=atol)
|
||||
|
||||
@ -105,7 +108,7 @@ class SlogdetTest(test.TestCase):
|
||||
(np.complex64, 0.05), (np.complex128, 1e-5)]:
|
||||
matrix = _RandomPDMatrix(n, self.rng, np_dtype)
|
||||
sign_np, log_abs_det_np = np.linalg.slogdet(matrix)
|
||||
with self.test_session(use_gpu=True):
|
||||
with self.session(use_gpu=True):
|
||||
sign_tf, log_abs_det_tf = linalg.slogdet(matrix)
|
||||
self.assertAllClose(log_abs_det_np, log_abs_det_tf.eval(), atol=atol)
|
||||
self.assertAllClose(sign_np, sign_tf.eval(), atol=atol)
|
||||
@ -115,7 +118,7 @@ class SlogdetTest(test.TestCase):
|
||||
(np.complex64, 0.05), (np.complex128, 1e-5)]:
|
||||
matrix = (np.eye(20) * 1e-6).astype(np_dtype)
|
||||
sign_np, log_abs_det_np = np.linalg.slogdet(matrix)
|
||||
with self.test_session(use_gpu=True):
|
||||
with self.session(use_gpu=True):
|
||||
sign_tf, log_abs_det_tf = linalg.slogdet(matrix)
|
||||
self.assertAllClose(log_abs_det_np, log_abs_det_tf.eval(), atol=atol)
|
||||
self.assertAllClose(sign_np, sign_tf.eval(), atol=atol)
|
||||
@ -128,66 +131,126 @@ class AdjointTest(test.TestCase):
|
||||
matrix_np = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j,
|
||||
6 + 6j]]).astype(dtype)
|
||||
expected_transposed = np.conj(matrix_np.T)
|
||||
with self.cached_session():
|
||||
with self.session():
|
||||
matrix = ops.convert_to_tensor(matrix_np)
|
||||
transposed = linalg.adjoint(matrix)
|
||||
self.assertEqual((3, 2), transposed.get_shape())
|
||||
self.assertAllEqual(expected_transposed, transposed.eval())
|
||||
|
||||
|
||||
class EyeTest(test.TestCase):
|
||||
pass # Will be filled in below
|
||||
class EyeTest(parameterized.TestCase, test.TestCase):
|
||||
|
||||
def testShapeInferenceNoBatch(self):
|
||||
self.assertEqual((2, 2), linalg_ops.eye(num_rows=2).shape)
|
||||
self.assertEqual((2, 3), linalg_ops.eye(num_rows=2, num_columns=3).shape)
|
||||
|
||||
def _GetEyeTest(num_rows, num_columns, batch_shape, dtype):
|
||||
def testShapeInferenceStaticBatch(self):
|
||||
batch_shape = (2, 3)
|
||||
self.assertEqual(
|
||||
(2, 3, 2, 2),
|
||||
linalg_ops.eye(num_rows=2, batch_shape=batch_shape).shape)
|
||||
self.assertEqual(
|
||||
(2, 3, 2, 3),
|
||||
linalg_ops.eye(
|
||||
num_rows=2, num_columns=3, batch_shape=batch_shape).shape)
|
||||
|
||||
def Test(self):
|
||||
@parameterized.named_parameters(
|
||||
("DynamicRow", array_ops.placeholder_with_default(2, shape=None), None),
|
||||
("DynamicRowStaticColumn",
|
||||
array_ops.placeholder_with_default(2, shape=None),
|
||||
3),
|
||||
("StaticRowDynamicColumn",
|
||||
2,
|
||||
array_ops.placeholder_with_default(3, shape=None)),
|
||||
("DynamicRowDynamicColumn",
|
||||
array_ops.placeholder_with_default(2, shape=None),
|
||||
array_ops.placeholder_with_default(3, shape=None)))
|
||||
def testShapeInferenceStaticBatchWith(self, num_rows, num_columns):
|
||||
batch_shape = (2, 3)
|
||||
identity_matrix = linalg_ops.eye(
|
||||
num_rows=num_rows,
|
||||
num_columns=num_columns,
|
||||
batch_shape=batch_shape)
|
||||
self.assertEqual(4, identity_matrix.shape.ndims)
|
||||
self.assertEqual((2, 3), identity_matrix.shape[:2])
|
||||
if num_rows is not None and not isinstance(num_rows, ops.Tensor):
|
||||
self.assertEqual(2, identity_matrix.shape[-2])
|
||||
|
||||
if num_columns is not None and not isinstance(num_columns, ops.Tensor):
|
||||
self.assertEqual(3, identity_matrix.shape[-1])
|
||||
|
||||
@parameterized.parameters(
|
||||
itertools.product(
|
||||
# num_rows
|
||||
[0, 1, 2, 5],
|
||||
# num_columns
|
||||
[None, 0, 1, 2, 5],
|
||||
# batch_shape
|
||||
[None, [], [2], [2, 3]],
|
||||
# dtype
|
||||
[
|
||||
dtypes.int32,
|
||||
dtypes.int64,
|
||||
dtypes.float32,
|
||||
dtypes.float64,
|
||||
dtypes.complex64,
|
||||
dtypes.complex128
|
||||
])
|
||||
)
|
||||
def test_eye_no_placeholder(self, num_rows, num_columns, batch_shape, dtype):
|
||||
eye_np = np.eye(num_rows, M=num_columns, dtype=dtype.as_numpy_dtype)
|
||||
if batch_shape is not None:
|
||||
eye_np = np.tile(eye_np, batch_shape + [1, 1])
|
||||
for use_placeholder in False, True:
|
||||
if use_placeholder and (num_columns is None or batch_shape is None):
|
||||
return
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
if use_placeholder:
|
||||
num_rows_placeholder = array_ops.placeholder(
|
||||
dtypes.int32, name="num_rows")
|
||||
num_columns_placeholder = array_ops.placeholder(
|
||||
dtypes.int32, name="num_columns")
|
||||
batch_shape_placeholder = array_ops.placeholder(
|
||||
dtypes.int32, name="batch_shape")
|
||||
eye = linalg_ops.eye(
|
||||
num_rows_placeholder,
|
||||
num_columns=num_columns_placeholder,
|
||||
batch_shape=batch_shape_placeholder,
|
||||
dtype=dtype)
|
||||
eye_tf = sess.run(
|
||||
eye,
|
||||
feed_dict={
|
||||
num_rows_placeholder: num_rows,
|
||||
num_columns_placeholder: num_columns,
|
||||
batch_shape_placeholder: batch_shape
|
||||
})
|
||||
else:
|
||||
eye_tf = linalg_ops.eye(
|
||||
num_rows,
|
||||
num_columns=num_columns,
|
||||
batch_shape=batch_shape,
|
||||
dtype=dtype).eval()
|
||||
self.assertAllEqual(eye_np, eye_tf)
|
||||
eye_tf = self.evaluate(linalg_ops.eye(
|
||||
num_rows,
|
||||
num_columns=num_columns,
|
||||
batch_shape=batch_shape,
|
||||
dtype=dtype))
|
||||
self.assertAllEqual(eye_np, eye_tf)
|
||||
|
||||
return Test
|
||||
@parameterized.parameters(
|
||||
itertools.product(
|
||||
# num_rows
|
||||
[0, 1, 2, 5],
|
||||
# num_columns
|
||||
[0, 1, 2, 5],
|
||||
# batch_shape
|
||||
[[], [2], [2, 3]],
|
||||
# dtype
|
||||
[
|
||||
dtypes.int32,
|
||||
dtypes.int64,
|
||||
dtypes.float32,
|
||||
dtypes.float64,
|
||||
dtypes.complex64,
|
||||
dtypes.complex128
|
||||
])
|
||||
)
|
||||
def test_eye_with_placeholder(
|
||||
self, num_rows, num_columns, batch_shape, dtype):
|
||||
eye_np = np.eye(num_rows, M=num_columns, dtype=dtype.as_numpy_dtype)
|
||||
eye_np = np.tile(eye_np, batch_shape + [1, 1])
|
||||
num_rows_placeholder = array_ops.placeholder(
|
||||
dtypes.int32, name="num_rows")
|
||||
num_columns_placeholder = array_ops.placeholder(
|
||||
dtypes.int32, name="num_columns")
|
||||
batch_shape_placeholder = array_ops.placeholder(
|
||||
dtypes.int32, name="batch_shape")
|
||||
eye = linalg_ops.eye(
|
||||
num_rows_placeholder,
|
||||
num_columns=num_columns_placeholder,
|
||||
batch_shape=batch_shape_placeholder,
|
||||
dtype=dtype)
|
||||
with self.session(use_gpu=True) as sess:
|
||||
eye_tf = sess.run(
|
||||
eye,
|
||||
feed_dict={
|
||||
num_rows_placeholder: num_rows,
|
||||
num_columns_placeholder: num_columns,
|
||||
batch_shape_placeholder: batch_shape
|
||||
})
|
||||
self.assertAllEqual(eye_np, eye_tf)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for _num_rows in 0, 1, 2, 5:
|
||||
for _num_columns in None, 0, 1, 2, 5:
|
||||
for _batch_shape in None, [], [2], [2, 3]:
|
||||
for _dtype in (dtypes.int32, dtypes.int64, dtypes.float32,
|
||||
dtypes.float64, dtypes.complex64, dtypes.complex128):
|
||||
name = "dtype_%s_num_rows_%s_num_column_%s_batch_shape_%s_" % (
|
||||
_dtype.name, _num_rows, _num_columns, _batch_shape)
|
||||
_AddTest(EyeTest, "EyeTest", name,
|
||||
_GetEyeTest(_num_rows, _num_columns, _batch_shape, _dtype))
|
||||
|
||||
test.main()
|
||||
|
@ -44,22 +44,31 @@ def eye(num_rows,
|
||||
is_square = num_columns is None
|
||||
batch_shape = [] if batch_shape is None else batch_shape
|
||||
num_columns = num_rows if num_columns is None else num_columns
|
||||
if isinstance(num_rows, ops.Tensor) or isinstance(
|
||||
num_columns, ops.Tensor) or isinstance(batch_shape, ops.Tensor):
|
||||
batch_shape = ops.convert_to_tensor(
|
||||
batch_shape, name='shape', dtype=dtypes.int32)
|
||||
|
||||
# We cannot statically infer what the diagonal size should be:
|
||||
if (isinstance(num_rows, ops.Tensor) or
|
||||
isinstance(num_columns, ops.Tensor)):
|
||||
diag_size = math_ops.minimum(num_rows, num_columns)
|
||||
diag_shape = array_ops.concat((batch_shape, [diag_size]), 0)
|
||||
if not is_square:
|
||||
shape = array_ops.concat((batch_shape, [num_rows, num_columns]), 0)
|
||||
else:
|
||||
# We can statically infer the diagonal size, and whether it is square.
|
||||
if not isinstance(num_rows, compat.integral_types) or not isinstance(
|
||||
num_columns, compat.integral_types):
|
||||
raise TypeError(
|
||||
'num_rows and num_columns must be positive integer values.')
|
||||
batch_shape = [dim for dim in batch_shape]
|
||||
is_square = num_rows == num_columns
|
||||
diag_shape = batch_shape + [np.minimum(num_rows, num_columns)]
|
||||
diag_size = np.minimum(num_rows, num_columns)
|
||||
|
||||
# We cannot statically infer the shape of the tensor.
|
||||
if isinstance(batch_shape, ops.Tensor) or isinstance(diag_size, ops.Tensor):
|
||||
batch_shape = ops.convert_to_tensor(
|
||||
batch_shape, name='shape', dtype=dtypes.int32)
|
||||
diag_shape = array_ops.concat((batch_shape, [diag_size]), axis=0)
|
||||
if not is_square:
|
||||
shape = array_ops.concat((batch_shape, [num_rows, num_columns]), axis=0)
|
||||
# We can statically infer everything.
|
||||
else:
|
||||
batch_shape = list(batch_shape)
|
||||
diag_shape = batch_shape + [diag_size]
|
||||
if not is_square:
|
||||
shape = batch_shape + [num_rows, num_columns]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user