Add test case for sparse_dense_cwise shape validation.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2018-12-01 21:02:04 +00:00
parent 8584f21392
commit a7cd4dbea9

View File

@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
@ -798,6 +799,17 @@ class SparseMathOpsTest(test_util.TensorFlowTestCase):
result_tensor.values).eval() result_tensor.values).eval()
self.assertAllEqual(result_np, res_densified) self.assertAllEqual(result_np, res_densified)
@test_util.run_deprecated_v1
def testCwiseShapeValidation(self):
# Test case for GitHub 24072.
with self.session(use_gpu=False):
a = array_ops.ones([3, 4, 1], dtype=dtypes.int32)
b = sparse_tensor.SparseTensor([[0, 0, 1, 0], [0, 0, 3, 0]], [10, 20], [1, 1, 4, 2])
c = a * b
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"broadcasts dense to sparse only; got incompatible shapes"):
c.eval()
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testCwiseDivAndMul(self): def testCwiseDivAndMul(self):
np.random.seed(1618) np.random.seed(1618)