Enable tests for tf.linalg.tensordot in eager mode.
PiperOrigin-RevId: 312144965 Change-Id: I2d75f7d9bd7f05aef6d1dee620dffcea66071b97
This commit is contained in:
parent
da67fcddef
commit
d4f71ff132
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
@ -41,16 +41,19 @@ def _add_test(test, test_name, fn):
|
||||
|
||||
class TensordotTest(test_lib.TestCase):
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def test_invalid_shape(self):
|
||||
a = [[1, 2], [3, 4]]
|
||||
b = [[1, 2], [3, 4], [5, 6]]
|
||||
a_axes = [1]
|
||||
b_axes = [0]
|
||||
# Invalid static shapes.
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
|
||||
math_ops.tensordot(a, b, (a_axes, b_axes))
|
||||
|
||||
# Invalid dynamic shapes.
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"Matrix size-incompatible"):
|
||||
@ -65,7 +68,7 @@ class TensordotTest(test_lib.TestCase):
|
||||
axes_ph: (a_axes, b_axes)
|
||||
})
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def test_invalid_axes(self):
|
||||
a = [[1, 2], [3, 4]]
|
||||
b = [[1, 2], [3, 4]]
|
||||
@ -77,6 +80,8 @@ class TensordotTest(test_lib.TestCase):
|
||||
with self.assertRaises(IndexError):
|
||||
math_ops.tensordot(a, b, [[0], [7]])
|
||||
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
# Invalid dynamic axes.
|
||||
a_ph = array_ops.placeholder(dtypes.float32)
|
||||
b_ph = array_ops.placeholder(dtypes.float32)
|
||||
@ -93,22 +98,22 @@ class TensordotTest(test_lib.TestCase):
|
||||
axes_ph: axes_value
|
||||
})
|
||||
|
||||
# Test case for 11950
|
||||
# Test case for https://github.com/tensorflow/tensorflow/issues/11950
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def test_valid_axis(self):
|
||||
for axes_value in [1, 2], [[1], [2]], [[], []], 0:
|
||||
with self.cached_session():
|
||||
np_a = np.ones((3, 3))
|
||||
np_b = np.array([2, 3, 1])[None, None]
|
||||
np_ans = np.tensordot(np_a, np_b, axes_value)
|
||||
np_a = np.ones((3, 3))
|
||||
np_b = np.array([2, 3, 1])[None, None]
|
||||
np_ans = np.tensordot(np_a, np_b, axes_value)
|
||||
|
||||
tf_a = array_ops.ones((3, 3), dtype=dtypes.float32)
|
||||
tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None]
|
||||
tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value)
|
||||
tf_a = array_ops.ones((3, 3), dtype=dtypes.float32)
|
||||
tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None]
|
||||
tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value)
|
||||
|
||||
self.assertAllEqual(tf_ans.shape, np_ans.shape)
|
||||
self.assertAllEqual(tf_ans, np_ans)
|
||||
self.assertAllEqual(tf_ans.shape, np_ans.shape)
|
||||
self.assertAllEqual(self.evaluate(tf_ans), np_ans)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("Shape inference test")
|
||||
def test_partial_shape_inference(self):
|
||||
for axes in ([1], [0]), 1:
|
||||
a = array_ops.placeholder(dtypes.float32)
|
||||
@ -159,7 +164,10 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
|
||||
size=np.prod(b_shape)).reshape(b_shape).astype(dtype_)
|
||||
return a, b, a_dims, b_dims
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def test_tensordot(self):
|
||||
if dynamic_shape_ and context.executing_eagerly():
|
||||
self.skipTest("Placeholders not support in eager mode")
|
||||
num_trials = min(30, num_dims_ * num_dims_)
|
||||
if dtype_ == np.float16:
|
||||
tol = 0.05
|
||||
@ -187,7 +195,10 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
|
||||
self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol)
|
||||
self.assertAllEqual(tf_ans.shape, np_ans.shape)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def test_tensordot_scalar_axes(self):
|
||||
if dynamic_shape_ and context.executing_eagerly():
|
||||
self.skipTest("Placeholders not support in eager mode")
|
||||
if num_dims_ < 1:
|
||||
self.skipTest("Not a test")
|
||||
if dtype_ == np.float16:
|
||||
@ -229,7 +240,7 @@ if __name__ == "__main__":
|
||||
for rank_b in 1, 2, 4, 5:
|
||||
for num_dims in range(0, min(rank_a, rank_b) + 1):
|
||||
# TF2 does not support placeholders under eager so we skip it
|
||||
for dynamic_shape in set([False, not tf2.enabled()]):
|
||||
for dynamic_shape in set([False, True]):
|
||||
for testcase in _get_tensordot_tests(dtype, rank_a, rank_b,
|
||||
num_dims, dynamic_shape):
|
||||
name = "%s_%s_%s_%s_%s_%s" % (testcase.__name__, dtype.__name__,
|
||||
|
Loading…
Reference in New Issue
Block a user