Enable sparse tensor dense matmul op tests in eager mode.

PiperOrigin-RevId: 335670100
Change-Id: I959563a43a8955810e9a8672f127fc0af6aab5cd
This commit is contained in:
Penporn Koanantakool 2020-10-06 10:32:54 -07:00 committed by TensorFlower Gardener
parent 6ae2b68fd3
commit e665554b90

View File

@ -97,7 +97,6 @@ class SparseTensorDenseMatMulTest(test.TestCase):
self._testMatmul(x, y, indices_dtype=indices_dtype)
@test_util.run_deprecated_v1
def testBasic(self):
np.random.seed(127) # Repeatable results
self._testBasic(np.int32)
@ -108,7 +107,6 @@ class SparseTensorDenseMatMulTest(test.TestCase):
self._testBasic(np.int32, indices_dtype=np.int32)
self._testBasic(np.float32, indices_dtype=np.int32)
@test_util.run_deprecated_v1
def testShapeInference(self):
x = np.random.rand(10, 10)
x[np.abs(x) < 0.5] = 0 # Make it sparse
@ -116,97 +114,91 @@ class SparseTensorDenseMatMulTest(test.TestCase):
x_indices = np.vstack(np.where(x)).astype(np.int64).T
x_values = x[np.where(x)]
x_shape = x.shape
x_st = sparse_tensor.SparseTensor(x_indices, x_values, x_shape)
result = sparse_ops.sparse_tensor_dense_matmul(x_st, y)
self.assertEqual(result.get_shape(), (10, 20))
x_shape_unknown = array_ops.placeholder(dtype=dtypes.int64, shape=None)
x_st_shape_unknown = sparse_tensor.SparseTensor(x_indices, x_values,
x_shape_unknown)
result_left_shape_unknown = sparse_ops.sparse_tensor_dense_matmul(
x_st_shape_unknown, y)
self.assertEqual(result_left_shape_unknown.get_shape().as_list(),
[None, 20])
with ops.Graph().as_default():
x_st = sparse_tensor.SparseTensor(x_indices, x_values, x_shape)
result = sparse_ops.sparse_tensor_dense_matmul(x_st, y)
self.assertEqual(result.get_shape(), (10, 20))
x_shape_inconsistent = [10, 15]
x_st_shape_inconsistent = sparse_tensor.SparseTensor(x_indices, x_values,
x_shape_inconsistent)
with self.assertRaisesRegex(ValueError, "Dimensions must be equal"):
sparse_ops.sparse_tensor_dense_matmul(x_st_shape_inconsistent, y)
x_shape_unknown = array_ops.placeholder(dtype=dtypes.int64, shape=None)
x_st_shape_unknown = sparse_tensor.SparseTensor(x_indices, x_values,
x_shape_unknown)
result_left_shape_unknown = sparse_ops.sparse_tensor_dense_matmul(
x_st_shape_unknown, y)
self.assertEqual(result_left_shape_unknown.get_shape().as_list(),
[None, 20])
@test_util.deprecated_graph_mode_only
x_shape_inconsistent = [10, 15]
x_st_shape_inconsistent = sparse_tensor.SparseTensor(
x_indices, x_values, x_shape_inconsistent)
with self.assertRaisesRegex(ValueError, "Dimensions must be equal"):
sparse_ops.sparse_tensor_dense_matmul(x_st_shape_inconsistent, y)
@test_util.run_in_graph_and_eager_modes(use_gpu=False)
def testInvalidIndicesForSparseTensorDenseMatmul(self):
# Note: use_gpu=False because nice errors are only returned from CPU kernel.
with self.session(use_gpu=False):
indices = np.matrix([[1, 10]]).astype(np.int64)
values = np.array([10]).astype(np.float32)
shape = [3, 2]
sparse_t = sparse_tensor.SparseTensor(indices, values, shape)
# TODO(b/169813429): Make GPU kernel return nice errors too.
indices = np.matrix([[1, 10]]).astype(np.int64)
values = np.array([10]).astype(np.float32)
shape = [3, 2]
sparse_t = sparse_tensor.SparseTensor(indices, values, shape)
# Test multiplying by both a small and large dense matrix, to hit
# both cases in the kernel.
dense_t = np.matrix([[1] * 5, [2] * 5], dtype=np.float32)
with self.assertRaisesOpError(
"k .10. from index.0,1. out of bounds .>=2."):
self.evaluate(sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t))
dense_t = np.matrix([[1] * 500, [2] * 500], dtype=np.float32)
with self.assertRaisesOpError(
"k .10. from index.0,1. out of bounds .>=2."):
self.evaluate(sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t))
# Test multiplying by both a small and large dense matrix, to hit
# both cases in the kernel.
dense_t = np.matrix([[1] * 5, [2] * 5], dtype=np.float32)
with self.assertRaisesOpError("k .10. from index.0,1. out of bounds .>=2."):
self.evaluate(sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t))
dense_t = np.matrix([[1] * 500, [2] * 500], dtype=np.float32)
with self.assertRaisesOpError("k .10. from index.0,1. out of bounds .>=2."):
self.evaluate(sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t))
# Repeat with adjoint_a, to get a different error.
dense_t = np.matrix([[1] * 5, [2] * 5, [3] * 5], dtype=np.float32)
with self.assertRaisesOpError(
"m .10. from index.0,1. out of bounds .>=2."):
self.evaluate(
sparse_ops.sparse_tensor_dense_matmul(
sparse_t, dense_t, adjoint_a=True))
dense_t = np.matrix([[1] * 500, [2] * 500, [3] * 500], dtype=np.float32)
with self.assertRaisesOpError(
"m .10. from index.0,1. out of bounds .>=2."):
self.evaluate(
sparse_ops.sparse_tensor_dense_matmul(
sparse_t, dense_t, adjoint_a=True))
# Repeat with adjoint_a, to get a different error.
dense_t = np.matrix([[1] * 5, [2] * 5, [3] * 5], dtype=np.float32)
with self.assertRaisesOpError("m .10. from index.0,1. out of bounds .>=2."):
self.evaluate(
sparse_ops.sparse_tensor_dense_matmul(
sparse_t, dense_t, adjoint_a=True))
dense_t = np.matrix([[1] * 500, [2] * 500, [3] * 500], dtype=np.float32)
with self.assertRaisesOpError("m .10. from index.0,1. out of bounds .>=2."):
self.evaluate(
sparse_ops.sparse_tensor_dense_matmul(
sparse_t, dense_t, adjoint_a=True))
@test_util.run_gpu_only
def testInvalidIndicesForSparseTensorDenseMatmulOnGPU(self):
# Note: use_gpu=False because nice errors are only returned from CPU kerne
if not test.is_gpu_available():
return
with self.session(use_gpu=True):
indices = np.array([[1, 10]]).astype(np.int64)
values = np.array([10]).astype(np.float32)
shape = [3, 2]
sparse_t = sparse_tensor.SparseTensor(indices, values, shape)
indices = np.array([[1, 10]]).astype(np.int64)
values = np.array([10]).astype(np.float32)
shape = [3, 2]
sparse_t = sparse_tensor.SparseTensor(indices, values, shape)
# Test multiplying by both a small and large dense matrix, to hit
# both cases in the kernel.
dense_t = np.matrix([[1] * 5, [2] * 5], dtype=np.float32)
expected_t = np.array([[0] * 5, [np.nan] * 5, [0] * 5], dtype=np.float32)
self.assertAllClose(expected_t,
sparse_ops.sparse_tensor_dense_matmul(
sparse_t, dense_t))
dense_t = np.matrix([[1] * 500, [2] * 500], dtype=np.float32)
expected_t = np.array(
[[0] * 500, [np.nan] * 500, [0] * 500], dtype=np.float32)
self.assertAllClose(expected_t,
sparse_ops.sparse_tensor_dense_matmul(
sparse_t, dense_t))
# Test multiplying by both a small and large dense matrix, to hit
# both cases in the kernel.
dense_t = np.matrix([[1] * 5, [2] * 5], dtype=np.float32)
expected_t = np.array([[0] * 5, [np.nan] * 5, [0] * 5], dtype=np.float32)
self.assertAllClose(
expected_t, sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t))
dense_t = np.matrix([[1] * 500, [2] * 500], dtype=np.float32)
expected_t = np.array([[0] * 500, [np.nan] * 500, [0] * 500],
dtype=np.float32)
self.assertAllClose(
expected_t, sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t))
# Repeat with adjoint_a, now the error is that the sparse index
# is OOO w.r.t. the output. The GPU kernel can't do much here,
# so it just doesn't accumulate.
# Repeat with adjoint_a, now the error is that the sparse index
# is OOO w.r.t. the output. The GPU kernel can't do much here,
# so it just doesn't accumulate.
dense_t = np.matrix([[1] * 5, [2] * 5, [3] * 5], dtype=np.float32)
expected_t = np.array([[0] * 5, [0] * 5], dtype=np.float32)
self.assertAllClose(expected_t,
sparse_ops.sparse_tensor_dense_matmul(
sparse_t, dense_t, adjoint_a=True))
dense_t = np.matrix([[1] * 5, [2] * 5, [3] * 5], dtype=np.float32)
expected_t = np.array([[0] * 5, [0] * 5], dtype=np.float32)
self.assertAllClose(
expected_t,
sparse_ops.sparse_tensor_dense_matmul(
sparse_t, dense_t, adjoint_a=True))
dense_t = np.matrix([[1] * 500, [2] * 500, [3] * 500], dtype=np.float32)
expected_t = np.array([[0] * 500, [0] * 500], dtype=np.float32)
self.assertAllClose(expected_t,
sparse_ops.sparse_tensor_dense_matmul(
sparse_t, dense_t, adjoint_a=True))
dense_t = np.matrix([[1] * 500, [2] * 500, [3] * 500], dtype=np.float32)
expected_t = np.array([[0] * 500, [0] * 500], dtype=np.float32)
self.assertAllClose(
expected_t,
sparse_ops.sparse_tensor_dense_matmul(
sparse_t, dense_t, adjoint_a=True))
# Tests setting one dimension to be a high value.
def _testLarge(self, np_dtype):
@ -235,7 +227,6 @@ class SparseTensorDenseMatMulTest(test.TestCase):
self._testLarge(np.complex128)
# Tests random sized matrices.
@test_util.run_deprecated_v1
def testFloatRandom(self):
np.random.seed(127) # Repeatable results
for _ in range(8):