Enable tests for tf.linalg.tensordot in eager mode.

PiperOrigin-RevId: 312144965
Change-Id: I2d75f7d9bd7f05aef6d1dee620dffcea66071b97
This commit is contained in:
A. Unique TensorFlower 2020-05-18 13:38:25 -07:00 committed by TensorFlower Gardener
parent da67fcddef
commit d4f71ff132

View File

@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python import tf2 from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl from tensorflow.python.framework import errors_impl
@ -41,16 +41,19 @@ def _add_test(test, test_name, fn):
class TensordotTest(test_lib.TestCase): class TensordotTest(test_lib.TestCase):
@test_util.run_v1_only("b/120545219") @test_util.run_in_graph_and_eager_modes(use_gpu=True)
def test_invalid_shape(self): def test_invalid_shape(self):
a = [[1, 2], [3, 4]] a = [[1, 2], [3, 4]]
b = [[1, 2], [3, 4], [5, 6]] b = [[1, 2], [3, 4], [5, 6]]
a_axes = [1] a_axes = [1]
b_axes = [0] b_axes = [0]
# Invalid static shapes. # Invalid static shapes.
with self.assertRaises(ValueError): with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
math_ops.tensordot(a, b, (a_axes, b_axes)) math_ops.tensordot(a, b, (a_axes, b_axes))
# Invalid dynamic shapes. # Invalid dynamic shapes.
if context.executing_eagerly():
return
with self.cached_session() as sess: with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"Matrix size-incompatible"): "Matrix size-incompatible"):
@ -65,7 +68,7 @@ class TensordotTest(test_lib.TestCase):
axes_ph: (a_axes, b_axes) axes_ph: (a_axes, b_axes)
}) })
@test_util.run_v1_only("b/120545219") @test_util.run_in_graph_and_eager_modes(use_gpu=True)
def test_invalid_axes(self): def test_invalid_axes(self):
a = [[1, 2], [3, 4]] a = [[1, 2], [3, 4]]
b = [[1, 2], [3, 4]] b = [[1, 2], [3, 4]]
@ -77,6 +80,8 @@ class TensordotTest(test_lib.TestCase):
with self.assertRaises(IndexError): with self.assertRaises(IndexError):
math_ops.tensordot(a, b, [[0], [7]]) math_ops.tensordot(a, b, [[0], [7]])
if context.executing_eagerly():
return
# Invalid dynamic axes. # Invalid dynamic axes.
a_ph = array_ops.placeholder(dtypes.float32) a_ph = array_ops.placeholder(dtypes.float32)
b_ph = array_ops.placeholder(dtypes.float32) b_ph = array_ops.placeholder(dtypes.float32)
@ -93,22 +98,22 @@ class TensordotTest(test_lib.TestCase):
axes_ph: axes_value axes_ph: axes_value
}) })
# Test case for 11950 # Test case for https://github.com/tensorflow/tensorflow/issues/11950
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def test_valid_axis(self): def test_valid_axis(self):
for axes_value in [1, 2], [[1], [2]], [[], []], 0: for axes_value in [1, 2], [[1], [2]], [[], []], 0:
with self.cached_session(): np_a = np.ones((3, 3))
np_a = np.ones((3, 3)) np_b = np.array([2, 3, 1])[None, None]
np_b = np.array([2, 3, 1])[None, None] np_ans = np.tensordot(np_a, np_b, axes_value)
np_ans = np.tensordot(np_a, np_b, axes_value)
tf_a = array_ops.ones((3, 3), dtype=dtypes.float32) tf_a = array_ops.ones((3, 3), dtype=dtypes.float32)
tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None] tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None]
tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value) tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value)
self.assertAllEqual(tf_ans.shape, np_ans.shape) self.assertAllEqual(tf_ans.shape, np_ans.shape)
self.assertAllEqual(tf_ans, np_ans) self.assertAllEqual(self.evaluate(tf_ans), np_ans)
@test_util.run_v1_only("b/120545219") @test_util.run_v1_only("Shape inference test")
def test_partial_shape_inference(self): def test_partial_shape_inference(self):
for axes in ([1], [0]), 1: for axes in ([1], [0]), 1:
a = array_ops.placeholder(dtypes.float32) a = array_ops.placeholder(dtypes.float32)
@ -159,7 +164,10 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
size=np.prod(b_shape)).reshape(b_shape).astype(dtype_) size=np.prod(b_shape)).reshape(b_shape).astype(dtype_)
return a, b, a_dims, b_dims return a, b, a_dims, b_dims
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def test_tensordot(self): def test_tensordot(self):
if dynamic_shape_ and context.executing_eagerly():
self.skipTest("Placeholders not support in eager mode")
num_trials = min(30, num_dims_ * num_dims_) num_trials = min(30, num_dims_ * num_dims_)
if dtype_ == np.float16: if dtype_ == np.float16:
tol = 0.05 tol = 0.05
@ -187,7 +195,10 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol) self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol)
self.assertAllEqual(tf_ans.shape, np_ans.shape) self.assertAllEqual(tf_ans.shape, np_ans.shape)
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def test_tensordot_scalar_axes(self): def test_tensordot_scalar_axes(self):
if dynamic_shape_ and context.executing_eagerly():
self.skipTest("Placeholders not support in eager mode")
if num_dims_ < 1: if num_dims_ < 1:
self.skipTest("Not a test") self.skipTest("Not a test")
if dtype_ == np.float16: if dtype_ == np.float16:
@ -229,7 +240,7 @@ if __name__ == "__main__":
for rank_b in 1, 2, 4, 5: for rank_b in 1, 2, 4, 5:
for num_dims in range(0, min(rank_a, rank_b) + 1): for num_dims in range(0, min(rank_a, rank_b) + 1):
# TF2 does not support placeholders under eager so we skip it # TF2 does not support placeholders under eager so we skip it
for dynamic_shape in set([False, not tf2.enabled()]): for dynamic_shape in set([False, True]):
for testcase in _get_tensordot_tests(dtype, rank_a, rank_b, for testcase in _get_tensordot_tests(dtype, rank_a, rank_b,
num_dims, dynamic_shape): num_dims, dynamic_shape):
name = "%s_%s_%s_%s_%s_%s" % (testcase.__name__, dtype.__name__, name = "%s_%s_%s_%s_%s_%s" % (testcase.__name__, dtype.__name__,