Enable tests for tf.linalg.tensordot in eager mode.
PiperOrigin-RevId: 312144965 Change-Id: I2d75f7d9bd7f05aef6d1dee620dffcea66071b97
This commit is contained in:
parent
da67fcddef
commit
d4f71ff132
@ -20,7 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
@ -41,16 +41,19 @@ def _add_test(test, test_name, fn):
|
|||||||
|
|
||||||
class TensordotTest(test_lib.TestCase):
|
class TensordotTest(test_lib.TestCase):
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||||
def test_invalid_shape(self):
|
def test_invalid_shape(self):
|
||||||
a = [[1, 2], [3, 4]]
|
a = [[1, 2], [3, 4]]
|
||||||
b = [[1, 2], [3, 4], [5, 6]]
|
b = [[1, 2], [3, 4], [5, 6]]
|
||||||
a_axes = [1]
|
a_axes = [1]
|
||||||
b_axes = [0]
|
b_axes = [0]
|
||||||
# Invalid static shapes.
|
# Invalid static shapes.
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
|
||||||
math_ops.tensordot(a, b, (a_axes, b_axes))
|
math_ops.tensordot(a, b, (a_axes, b_axes))
|
||||||
|
|
||||||
# Invalid dynamic shapes.
|
# Invalid dynamic shapes.
|
||||||
|
if context.executing_eagerly():
|
||||||
|
return
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||||
"Matrix size-incompatible"):
|
"Matrix size-incompatible"):
|
||||||
@ -65,7 +68,7 @@ class TensordotTest(test_lib.TestCase):
|
|||||||
axes_ph: (a_axes, b_axes)
|
axes_ph: (a_axes, b_axes)
|
||||||
})
|
})
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||||
def test_invalid_axes(self):
|
def test_invalid_axes(self):
|
||||||
a = [[1, 2], [3, 4]]
|
a = [[1, 2], [3, 4]]
|
||||||
b = [[1, 2], [3, 4]]
|
b = [[1, 2], [3, 4]]
|
||||||
@ -77,6 +80,8 @@ class TensordotTest(test_lib.TestCase):
|
|||||||
with self.assertRaises(IndexError):
|
with self.assertRaises(IndexError):
|
||||||
math_ops.tensordot(a, b, [[0], [7]])
|
math_ops.tensordot(a, b, [[0], [7]])
|
||||||
|
|
||||||
|
if context.executing_eagerly():
|
||||||
|
return
|
||||||
# Invalid dynamic axes.
|
# Invalid dynamic axes.
|
||||||
a_ph = array_ops.placeholder(dtypes.float32)
|
a_ph = array_ops.placeholder(dtypes.float32)
|
||||||
b_ph = array_ops.placeholder(dtypes.float32)
|
b_ph = array_ops.placeholder(dtypes.float32)
|
||||||
@ -93,22 +98,22 @@ class TensordotTest(test_lib.TestCase):
|
|||||||
axes_ph: axes_value
|
axes_ph: axes_value
|
||||||
})
|
})
|
||||||
|
|
||||||
# Test case for 11950
|
# Test case for https://github.com/tensorflow/tensorflow/issues/11950
|
||||||
|
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||||
def test_valid_axis(self):
|
def test_valid_axis(self):
|
||||||
for axes_value in [1, 2], [[1], [2]], [[], []], 0:
|
for axes_value in [1, 2], [[1], [2]], [[], []], 0:
|
||||||
with self.cached_session():
|
np_a = np.ones((3, 3))
|
||||||
np_a = np.ones((3, 3))
|
np_b = np.array([2, 3, 1])[None, None]
|
||||||
np_b = np.array([2, 3, 1])[None, None]
|
np_ans = np.tensordot(np_a, np_b, axes_value)
|
||||||
np_ans = np.tensordot(np_a, np_b, axes_value)
|
|
||||||
|
|
||||||
tf_a = array_ops.ones((3, 3), dtype=dtypes.float32)
|
tf_a = array_ops.ones((3, 3), dtype=dtypes.float32)
|
||||||
tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None]
|
tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None]
|
||||||
tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value)
|
tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value)
|
||||||
|
|
||||||
self.assertAllEqual(tf_ans.shape, np_ans.shape)
|
self.assertAllEqual(tf_ans.shape, np_ans.shape)
|
||||||
self.assertAllEqual(tf_ans, np_ans)
|
self.assertAllEqual(self.evaluate(tf_ans), np_ans)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("Shape inference test")
|
||||||
def test_partial_shape_inference(self):
|
def test_partial_shape_inference(self):
|
||||||
for axes in ([1], [0]), 1:
|
for axes in ([1], [0]), 1:
|
||||||
a = array_ops.placeholder(dtypes.float32)
|
a = array_ops.placeholder(dtypes.float32)
|
||||||
@ -159,7 +164,10 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
|
|||||||
size=np.prod(b_shape)).reshape(b_shape).astype(dtype_)
|
size=np.prod(b_shape)).reshape(b_shape).astype(dtype_)
|
||||||
return a, b, a_dims, b_dims
|
return a, b, a_dims, b_dims
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||||
def test_tensordot(self):
|
def test_tensordot(self):
|
||||||
|
if dynamic_shape_ and context.executing_eagerly():
|
||||||
|
self.skipTest("Placeholders not support in eager mode")
|
||||||
num_trials = min(30, num_dims_ * num_dims_)
|
num_trials = min(30, num_dims_ * num_dims_)
|
||||||
if dtype_ == np.float16:
|
if dtype_ == np.float16:
|
||||||
tol = 0.05
|
tol = 0.05
|
||||||
@ -187,7 +195,10 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
|
|||||||
self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol)
|
self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol)
|
||||||
self.assertAllEqual(tf_ans.shape, np_ans.shape)
|
self.assertAllEqual(tf_ans.shape, np_ans.shape)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||||
def test_tensordot_scalar_axes(self):
|
def test_tensordot_scalar_axes(self):
|
||||||
|
if dynamic_shape_ and context.executing_eagerly():
|
||||||
|
self.skipTest("Placeholders not support in eager mode")
|
||||||
if num_dims_ < 1:
|
if num_dims_ < 1:
|
||||||
self.skipTest("Not a test")
|
self.skipTest("Not a test")
|
||||||
if dtype_ == np.float16:
|
if dtype_ == np.float16:
|
||||||
@ -229,7 +240,7 @@ if __name__ == "__main__":
|
|||||||
for rank_b in 1, 2, 4, 5:
|
for rank_b in 1, 2, 4, 5:
|
||||||
for num_dims in range(0, min(rank_a, rank_b) + 1):
|
for num_dims in range(0, min(rank_a, rank_b) + 1):
|
||||||
# TF2 does not support placeholders under eager so we skip it
|
# TF2 does not support placeholders under eager so we skip it
|
||||||
for dynamic_shape in set([False, not tf2.enabled()]):
|
for dynamic_shape in set([False, True]):
|
||||||
for testcase in _get_tensordot_tests(dtype, rank_a, rank_b,
|
for testcase in _get_tensordot_tests(dtype, rank_a, rank_b,
|
||||||
num_dims, dynamic_shape):
|
num_dims, dynamic_shape):
|
||||||
name = "%s_%s_%s_%s_%s_%s" % (testcase.__name__, dtype.__name__,
|
name = "%s_%s_%s_%s_%s_%s" % (testcase.__name__, dtype.__name__,
|
||||||
|
Loading…
Reference in New Issue
Block a user