Make sparse_to_dense_op_py_test v2 friendly

* Change op runtime error string to match with graph construction.
* Remove session scopes
* Fix invalid argument tests to be eager friendly by constructing and
  evaluating in the same line.

PiperOrigin-RevId: 320693166
Change-Id: I42ee0c4a4e141074863a366aef55f9960f0ea806
This commit is contained in:
Gaurav Jain 2020-07-10 16:00:43 -07:00 committed by TensorFlower Gardener
parent f085449f2b
commit 1956f5ad87
3 changed files with 63 additions and 124 deletions

View File

@ -63,7 +63,7 @@ class SparseToDense : public OpKernel {
const Tensor& output_shape = c->input(1); const Tensor& output_shape = c->input(1);
OP_REQUIRES( OP_REQUIRES(
c, TensorShapeUtils::IsVector(output_shape.shape()), c, TensorShapeUtils::IsVector(output_shape.shape()),
errors::InvalidArgument("output_shape should be a vector, got shape ", errors::InvalidArgument("output_shape must be rank 1, got shape ",
output_shape.shape().DebugString())); output_shape.shape().DebugString()));
OP_REQUIRES(c, output_shape.NumElements() == num_dims, OP_REQUIRES(c, output_shape.NumElements() == num_dims,
errors::InvalidArgument( errors::InvalidArgument(

View File

@ -136,7 +136,7 @@ class ScalarTest(test.TestCase):
def testSparseToDense(self): def testSparseToDense(self):
self.check(sparse_ops.sparse_to_dense, (1, 4, 7), self.check(sparse_ops.sparse_to_dense, (1, 4, 7),
'output_shape should be a vector', [0, 7, 0, 0]) 'output_shape must be rank 1', [0, 7, 0, 0])
def testTile(self): def testTile(self):
self.check(array_ops.tile, ([7], 2), 'Expected multiples to be 1-D', [7, 7]) self.check(array_ops.tile, ([7], 2), 'Expected multiples to be 1-D', [7, 7])

View File

@ -21,179 +21,119 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
def _SparseToDense(sparse_indices,
output_size,
sparse_values,
default_value,
validate_indices=True):
return sparse_ops.sparse_to_dense(
sparse_indices,
output_size,
sparse_values,
default_value=default_value,
validate_indices=validate_indices)
class SparseToDenseTest(test.TestCase): class SparseToDenseTest(test.TestCase):
@test_util.run_deprecated_v1
def testInt(self): def testInt(self):
with self.session(use_gpu=False): tf_ans = sparse_ops.sparse_to_dense([1, 3], [5], 1, 0)
tf_ans = _SparseToDense([1, 3], [5], 1, 0).eval()
np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32) np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32)
self.assertAllClose(np_ans, tf_ans) self.assertAllClose(np_ans, tf_ans)
@test_util.run_deprecated_v1
def testFloat(self): def testFloat(self):
with self.session(use_gpu=False): tf_ans = sparse_ops.sparse_to_dense([1, 3], [5], 1.0, 0.0)
tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0).eval()
np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32) np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32)
self.assertAllClose(np_ans, tf_ans) self.assertAllClose(np_ans, tf_ans)
@test_util.run_deprecated_v1
def testString(self): def testString(self):
with self.session(use_gpu=False): tf_ans = sparse_ops.sparse_to_dense([1, 3], [5], "a", "b")
tf_ans = _SparseToDense([1, 3], [5], "a", "b").eval()
np_ans = np.array(["b", "a", "b", "a", "b"]).astype(np.string_) np_ans = np.array(["b", "a", "b", "a", "b"]).astype(np.string_)
self.assertAllEqual(np_ans, tf_ans) self.assertAllEqual(np_ans, tf_ans)
@test_util.run_deprecated_v1
def testSetValue(self): def testSetValue(self):
with self.session(use_gpu=False): tf_ans = sparse_ops.sparse_to_dense([1, 3], [5], [1, 2], -1)
tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1).eval()
np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32) np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32)
self.assertAllClose(np_ans, tf_ans) self.assertAllClose(np_ans, tf_ans)
@test_util.run_deprecated_v1
def testSetSingleValue(self): def testSetSingleValue(self):
with self.session(use_gpu=False): tf_ans = sparse_ops.sparse_to_dense([1, 3], [5], 1, -1)
tf_ans = _SparseToDense([1, 3], [5], 1, -1).eval()
np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32) np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32)
self.assertAllClose(np_ans, tf_ans) self.assertAllClose(np_ans, tf_ans)
@test_util.run_deprecated_v1
def test2d(self): def test2d(self):
# pylint: disable=bad-whitespace tf_ans = sparse_ops.sparse_to_dense([[1, 3], [2, 0]], [3, 4], 1, -1)
with self.session(use_gpu=False):
tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1).eval()
np_ans = np.array([[-1, -1, -1, -1], np_ans = np.array([[-1, -1, -1, -1],
[-1, -1, -1, 1], [-1, -1, -1, 1],
[ 1, -1, -1, -1]]).astype(np.int32) [1, -1, -1, -1]]).astype(np.int32)
self.assertAllClose(np_ans, tf_ans) self.assertAllClose(np_ans, tf_ans)
@test_util.run_deprecated_v1
def testZeroDefault(self): def testZeroDefault(self):
with self.cached_session(): x = sparse_ops.sparse_to_dense(2, [4], 7)
x = sparse_ops.sparse_to_dense(2, [4], 7).eval() self.assertAllEqual(x, [0, 0, 7, 0])
self.assertAllEqual(x, [0, 0, 7, 0])
@test_util.run_deprecated_v1
def test3d(self): def test3d(self):
with self.session(use_gpu=False): tf_ans = sparse_ops.sparse_to_dense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1,
tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1).eval() -1)
np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1 np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1
np_ans[1, 3, 0] = 1 np_ans[1, 3, 0] = 1
np_ans[2, 0, 1] = 1 np_ans[2, 0, 1] = 1
self.assertAllClose(np_ans, tf_ans) self.assertAllClose(np_ans, tf_ans)
@test_util.run_deprecated_v1
def testBadShape(self): def testBadShape(self):
with self.cached_session(): with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"): "must be rank 1"):
_SparseToDense([1, 3], [[5], [3]], 1, -1) sparse_ops.sparse_to_dense([1, 3], [[5], [3]], 1, -1)
@test_util.run_deprecated_v1
def testBadValue(self): def testBadValue(self):
with self.cached_session(): with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
dense = _SparseToDense([1, 3], [5], [[5], [3]], -1) r"sparse_values has incorrect shape \[2,1\], "
with self.assertRaisesOpError( r"should be \[\] or \[2\]"):
r"sparse_values has incorrect shape \[2,1\], " self.evaluate(sparse_ops.sparse_to_dense([1, 3], [5], [[5], [3]], -1))
r"should be \[\] or \[2\]"):
self.evaluate(dense)
@test_util.run_deprecated_v1
def testBadNumValues(self): def testBadNumValues(self):
with self.cached_session(): with self.assertRaisesRegex(
dense = _SparseToDense([1, 3], [5], [1, 2, 3], -1) (ValueError, errors.InvalidArgumentError),
with self.assertRaisesOpError( r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"):
r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"): self.evaluate(sparse_ops.sparse_to_dense([1, 3], [5], [1, 2, 3], -1))
self.evaluate(dense)
@test_util.run_deprecated_v1
def testBadDefault(self): def testBadDefault(self):
with self.cached_session(): with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
dense = _SparseToDense([1, 3], [5], [1, 2], [0]) "default_value should be a scalar"):
with self.assertRaisesOpError("default_value should be a scalar"): self.evaluate(sparse_ops.sparse_to_dense([1, 3], [5], [1, 2], [0]))
self.evaluate(dense)
@test_util.run_deprecated_v1
def testOutOfBoundsIndicesWithWithoutValidation(self): def testOutOfBoundsIndicesWithWithoutValidation(self):
with self.cached_session(): with self.assertRaisesRegex(
dense = _SparseToDense( (ValueError, errors.InvalidArgumentError),
sparse_indices=[[1], [10]], r"indices\[1\] = \[10\] is out of bounds: need 0 <= index < \[5\]"):
output_size=[5], self.evaluate(
sparse_values=[-1.0, 1.0], sparse_ops.sparse_to_dense([[1], [10]], [5], [1.0, 1.0], 0.0))
default_value=0.0) # Disable checks, the allocation should still fail.
with self.assertRaisesOpError( with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
r"indices\[1\] = \[10\] is out of bounds: need 0 <= index < \[5\]"): "out of bounds"):
self.evaluate(dense) self.evaluate(
# Disable checks, the allocation should still fail. sparse_ops.sparse_to_dense([[1], [10]], [5], [-1.0, 1.0],
with self.assertRaisesOpError("out of bounds"): 0.0,
dense_without_validation = _SparseToDense( validate_indices=False))
sparse_indices=[[1], [10]],
output_size=[5],
sparse_values=[-1.0, 1.0],
default_value=0.0,
validate_indices=False)
self.evaluate(dense_without_validation)
@test_util.run_deprecated_v1
def testRepeatingIndicesWithWithoutValidation(self): def testRepeatingIndicesWithWithoutValidation(self):
with self.cached_session(): with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
dense = _SparseToDense( r"indices\[1\] = \[1\] is repeated"):
sparse_indices=[[1], [1]], self.evaluate(
output_size=[5], sparse_ops.sparse_to_dense([[1], [1]], [5], [-1.0, 1.0], 0.0))
sparse_values=[-1.0, 1.0], # Disable checks
default_value=0.0) self.evaluate(
with self.assertRaisesOpError(r"indices\[1\] = \[1\] is repeated"): sparse_ops.sparse_to_dense([[1], [1]], [5], [-1.0, 1.0],
self.evaluate(dense) 0.0,
# Disable checks validate_indices=False))
dense_without_validation = _SparseToDense(
sparse_indices=[[1], [1]],
output_size=[5],
sparse_values=[-1.0, 1.0],
default_value=0.0,
validate_indices=False)
self.evaluate(dense_without_validation)
@test_util.run_deprecated_v1
def testUnsortedIndicesWithWithoutValidation(self): def testUnsortedIndicesWithWithoutValidation(self):
with self.cached_session(): with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
dense = _SparseToDense( r"indices\[1\] = \[1\] is out of order"):
sparse_indices=[[2], [1]], self.evaluate(
output_size=[5], sparse_ops.sparse_to_dense([[2], [1]], [5], [-1.0, 1.0], 0.0))
sparse_values=[-1.0, 1.0], # Disable checks
default_value=0.0) self.evaluate(
with self.assertRaisesOpError(r"indices\[1\] = \[1\] is out of order"): sparse_ops.sparse_to_dense([[2], [1]], [5], [-1.0, 1.0],
self.evaluate(dense) 0.0,
# Disable checks validate_indices=False))
dense_without_validation = _SparseToDense(
sparse_indices=[[2], [1]],
output_size=[5],
sparse_values=[-1.0, 1.0],
default_value=0.0,
validate_indices=False)
self.evaluate(dense_without_validation)
@test_util.run_deprecated_v1
def testShapeInferenceKnownShape(self): def testShapeInferenceKnownShape(self):
with self.session(use_gpu=False): with ops.Graph().as_default():
indices = array_ops.placeholder(dtypes.int64) indices = array_ops.placeholder(dtypes.int64)
shape = [4, 5, 6] shape = [4, 5, 6]
@ -204,13 +144,12 @@ class SparseToDenseTest(test.TestCase):
output = sparse_ops.sparse_to_dense(indices, shape, 1, 0) output = sparse_ops.sparse_to_dense(indices, shape, 1, 0)
self.assertEqual(output.get_shape().as_list(), [None, None, None]) self.assertEqual(output.get_shape().as_list(), [None, None, None])
@test_util.run_deprecated_v1
def testShapeInferenceUnknownShape(self): def testShapeInferenceUnknownShape(self):
with self.session(use_gpu=False): with ops.Graph().as_default():
indices = array_ops.placeholder(dtypes.int64) indices = array_ops.placeholder(dtypes.int64)
shape = array_ops.placeholder(dtypes.int64) shape = array_ops.placeholder(dtypes.int64)
output = sparse_ops.sparse_to_dense(indices, shape, 1, 0) output = sparse_ops.sparse_to_dense(indices, shape, 1, 0)
self.assertEqual(output.get_shape().ndims, None) self.assertIsNone(output.get_shape().ndims)
if __name__ == "__main__": if __name__ == "__main__":