Make sparse_to_dense_op_py_test v2 friendly
* Change op runtime error string to match with graph construction. * Remove session scopes * Fix invalid argument tests to be eager friendly by constructing and evaluating in the same line. PiperOrigin-RevId: 320693166 Change-Id: I42ee0c4a4e141074863a366aef55f9960f0ea806
This commit is contained in:
		
							parent
							
								
									f085449f2b
								
							
						
					
					
						commit
						1956f5ad87
					
				| @ -63,7 +63,7 @@ class SparseToDense : public OpKernel { | |||||||
|     const Tensor& output_shape = c->input(1); |     const Tensor& output_shape = c->input(1); | ||||||
|     OP_REQUIRES( |     OP_REQUIRES( | ||||||
|         c, TensorShapeUtils::IsVector(output_shape.shape()), |         c, TensorShapeUtils::IsVector(output_shape.shape()), | ||||||
|         errors::InvalidArgument("output_shape should be a vector, got shape ", |         errors::InvalidArgument("output_shape must be rank 1, got shape ", | ||||||
|                                 output_shape.shape().DebugString())); |                                 output_shape.shape().DebugString())); | ||||||
|     OP_REQUIRES(c, output_shape.NumElements() == num_dims, |     OP_REQUIRES(c, output_shape.NumElements() == num_dims, | ||||||
|                 errors::InvalidArgument( |                 errors::InvalidArgument( | ||||||
|  | |||||||
| @ -136,7 +136,7 @@ class ScalarTest(test.TestCase): | |||||||
| 
 | 
 | ||||||
|   def testSparseToDense(self): |   def testSparseToDense(self): | ||||||
|     self.check(sparse_ops.sparse_to_dense, (1, 4, 7), |     self.check(sparse_ops.sparse_to_dense, (1, 4, 7), | ||||||
|                'output_shape should be a vector', [0, 7, 0, 0]) |                'output_shape must be rank 1', [0, 7, 0, 0]) | ||||||
| 
 | 
 | ||||||
|   def testTile(self): |   def testTile(self): | ||||||
|     self.check(array_ops.tile, ([7], 2), 'Expected multiples to be 1-D', [7, 7]) |     self.check(array_ops.tile, ([7], 2), 'Expected multiples to be 1-D', [7, 7]) | ||||||
|  | |||||||
| @ -21,179 +21,119 @@ from __future__ import print_function | |||||||
| import numpy as np | import numpy as np | ||||||
| 
 | 
 | ||||||
| from tensorflow.python.framework import dtypes | from tensorflow.python.framework import dtypes | ||||||
| from tensorflow.python.framework import test_util | from tensorflow.python.framework import errors | ||||||
|  | from tensorflow.python.framework import ops | ||||||
| from tensorflow.python.ops import array_ops | from tensorflow.python.ops import array_ops | ||||||
| from tensorflow.python.ops import sparse_ops | from tensorflow.python.ops import sparse_ops | ||||||
| from tensorflow.python.platform import test | from tensorflow.python.platform import test | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _SparseToDense(sparse_indices, |  | ||||||
|                    output_size, |  | ||||||
|                    sparse_values, |  | ||||||
|                    default_value, |  | ||||||
|                    validate_indices=True): |  | ||||||
|   return sparse_ops.sparse_to_dense( |  | ||||||
|       sparse_indices, |  | ||||||
|       output_size, |  | ||||||
|       sparse_values, |  | ||||||
|       default_value=default_value, |  | ||||||
|       validate_indices=validate_indices) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class SparseToDenseTest(test.TestCase): | class SparseToDenseTest(test.TestCase): | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testInt(self): |   def testInt(self): | ||||||
|     with self.session(use_gpu=False): |     tf_ans = sparse_ops.sparse_to_dense([1, 3], [5], 1, 0) | ||||||
|       tf_ans = _SparseToDense([1, 3], [5], 1, 0).eval() |  | ||||||
|     np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32) |     np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32) | ||||||
|     self.assertAllClose(np_ans, tf_ans) |     self.assertAllClose(np_ans, tf_ans) | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testFloat(self): |   def testFloat(self): | ||||||
|     with self.session(use_gpu=False): |     tf_ans = sparse_ops.sparse_to_dense([1, 3], [5], 1.0, 0.0) | ||||||
|       tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0).eval() |  | ||||||
|     np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32) |     np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32) | ||||||
|     self.assertAllClose(np_ans, tf_ans) |     self.assertAllClose(np_ans, tf_ans) | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testString(self): |   def testString(self): | ||||||
|     with self.session(use_gpu=False): |     tf_ans = sparse_ops.sparse_to_dense([1, 3], [5], "a", "b") | ||||||
|       tf_ans = _SparseToDense([1, 3], [5], "a", "b").eval() |  | ||||||
|     np_ans = np.array(["b", "a", "b", "a", "b"]).astype(np.string_) |     np_ans = np.array(["b", "a", "b", "a", "b"]).astype(np.string_) | ||||||
|     self.assertAllEqual(np_ans, tf_ans) |     self.assertAllEqual(np_ans, tf_ans) | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testSetValue(self): |   def testSetValue(self): | ||||||
|     with self.session(use_gpu=False): |     tf_ans = sparse_ops.sparse_to_dense([1, 3], [5], [1, 2], -1) | ||||||
|       tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1).eval() |  | ||||||
|     np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32) |     np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32) | ||||||
|     self.assertAllClose(np_ans, tf_ans) |     self.assertAllClose(np_ans, tf_ans) | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testSetSingleValue(self): |   def testSetSingleValue(self): | ||||||
|     with self.session(use_gpu=False): |     tf_ans = sparse_ops.sparse_to_dense([1, 3], [5], 1, -1) | ||||||
|       tf_ans = _SparseToDense([1, 3], [5], 1, -1).eval() |  | ||||||
|     np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32) |     np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32) | ||||||
|     self.assertAllClose(np_ans, tf_ans) |     self.assertAllClose(np_ans, tf_ans) | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def test2d(self): |   def test2d(self): | ||||||
|     # pylint: disable=bad-whitespace |     tf_ans = sparse_ops.sparse_to_dense([[1, 3], [2, 0]], [3, 4], 1, -1) | ||||||
|     with self.session(use_gpu=False): |  | ||||||
|       tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1).eval() |  | ||||||
|     np_ans = np.array([[-1, -1, -1, -1], |     np_ans = np.array([[-1, -1, -1, -1], | ||||||
|                        [-1, -1, -1,  1], |                        [-1, -1, -1, 1], | ||||||
|                        [ 1, -1, -1, -1]]).astype(np.int32) |                        [1, -1, -1, -1]]).astype(np.int32) | ||||||
|     self.assertAllClose(np_ans, tf_ans) |     self.assertAllClose(np_ans, tf_ans) | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testZeroDefault(self): |   def testZeroDefault(self): | ||||||
|     with self.cached_session(): |     x = sparse_ops.sparse_to_dense(2, [4], 7) | ||||||
|       x = sparse_ops.sparse_to_dense(2, [4], 7).eval() |     self.assertAllEqual(x, [0, 0, 7, 0]) | ||||||
|       self.assertAllEqual(x, [0, 0, 7, 0]) |  | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def test3d(self): |   def test3d(self): | ||||||
|     with self.session(use_gpu=False): |     tf_ans = sparse_ops.sparse_to_dense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, | ||||||
|       tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1).eval() |                                         -1) | ||||||
|     np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1 |     np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1 | ||||||
|     np_ans[1, 3, 0] = 1 |     np_ans[1, 3, 0] = 1 | ||||||
|     np_ans[2, 0, 1] = 1 |     np_ans[2, 0, 1] = 1 | ||||||
|     self.assertAllClose(np_ans, tf_ans) |     self.assertAllClose(np_ans, tf_ans) | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testBadShape(self): |   def testBadShape(self): | ||||||
|     with self.cached_session(): |     with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), | ||||||
|       with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"): |                                 "must be rank 1"): | ||||||
|         _SparseToDense([1, 3], [[5], [3]], 1, -1) |       sparse_ops.sparse_to_dense([1, 3], [[5], [3]], 1, -1) | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testBadValue(self): |   def testBadValue(self): | ||||||
|     with self.cached_session(): |     with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), | ||||||
|       dense = _SparseToDense([1, 3], [5], [[5], [3]], -1) |                                 r"sparse_values has incorrect shape \[2,1\], " | ||||||
|       with self.assertRaisesOpError( |                                 r"should be \[\] or \[2\]"): | ||||||
|           r"sparse_values has incorrect shape \[2,1\], " |       self.evaluate(sparse_ops.sparse_to_dense([1, 3], [5], [[5], [3]], -1)) | ||||||
|           r"should be \[\] or \[2\]"): |  | ||||||
|         self.evaluate(dense) |  | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testBadNumValues(self): |   def testBadNumValues(self): | ||||||
|     with self.cached_session(): |     with self.assertRaisesRegex( | ||||||
|       dense = _SparseToDense([1, 3], [5], [1, 2, 3], -1) |         (ValueError, errors.InvalidArgumentError), | ||||||
|       with self.assertRaisesOpError( |         r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"): | ||||||
|           r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"): |       self.evaluate(sparse_ops.sparse_to_dense([1, 3], [5], [1, 2, 3], -1)) | ||||||
|         self.evaluate(dense) |  | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testBadDefault(self): |   def testBadDefault(self): | ||||||
|     with self.cached_session(): |     with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), | ||||||
|       dense = _SparseToDense([1, 3], [5], [1, 2], [0]) |                                 "default_value should be a scalar"): | ||||||
|       with self.assertRaisesOpError("default_value should be a scalar"): |       self.evaluate(sparse_ops.sparse_to_dense([1, 3], [5], [1, 2], [0])) | ||||||
|         self.evaluate(dense) |  | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testOutOfBoundsIndicesWithWithoutValidation(self): |   def testOutOfBoundsIndicesWithWithoutValidation(self): | ||||||
|     with self.cached_session(): |     with self.assertRaisesRegex( | ||||||
|       dense = _SparseToDense( |         (ValueError, errors.InvalidArgumentError), | ||||||
|           sparse_indices=[[1], [10]], |         r"indices\[1\] = \[10\] is out of bounds: need 0 <= index < \[5\]"): | ||||||
|           output_size=[5], |       self.evaluate( | ||||||
|           sparse_values=[-1.0, 1.0], |           sparse_ops.sparse_to_dense([[1], [10]], [5], [1.0, 1.0], 0.0)) | ||||||
|           default_value=0.0) |     # Disable checks, the allocation should still fail. | ||||||
|       with self.assertRaisesOpError( |     with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), | ||||||
|           r"indices\[1\] = \[10\] is out of bounds: need 0 <= index < \[5\]"): |                                 "out of bounds"): | ||||||
|         self.evaluate(dense) |       self.evaluate( | ||||||
|       # Disable checks, the allocation should still fail. |           sparse_ops.sparse_to_dense([[1], [10]], [5], [-1.0, 1.0], | ||||||
|       with self.assertRaisesOpError("out of bounds"): |                                      0.0, | ||||||
|         dense_without_validation = _SparseToDense( |                                      validate_indices=False)) | ||||||
|             sparse_indices=[[1], [10]], |  | ||||||
|             output_size=[5], |  | ||||||
|             sparse_values=[-1.0, 1.0], |  | ||||||
|             default_value=0.0, |  | ||||||
|             validate_indices=False) |  | ||||||
|         self.evaluate(dense_without_validation) |  | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testRepeatingIndicesWithWithoutValidation(self): |   def testRepeatingIndicesWithWithoutValidation(self): | ||||||
|     with self.cached_session(): |     with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), | ||||||
|       dense = _SparseToDense( |                                 r"indices\[1\] = \[1\] is repeated"): | ||||||
|           sparse_indices=[[1], [1]], |       self.evaluate( | ||||||
|           output_size=[5], |           sparse_ops.sparse_to_dense([[1], [1]], [5], [-1.0, 1.0], 0.0)) | ||||||
|           sparse_values=[-1.0, 1.0], |     # Disable checks | ||||||
|           default_value=0.0) |     self.evaluate( | ||||||
|       with self.assertRaisesOpError(r"indices\[1\] = \[1\] is repeated"): |         sparse_ops.sparse_to_dense([[1], [1]], [5], [-1.0, 1.0], | ||||||
|         self.evaluate(dense) |                                    0.0, | ||||||
|       # Disable checks |                                    validate_indices=False)) | ||||||
|       dense_without_validation = _SparseToDense( |  | ||||||
|           sparse_indices=[[1], [1]], |  | ||||||
|           output_size=[5], |  | ||||||
|           sparse_values=[-1.0, 1.0], |  | ||||||
|           default_value=0.0, |  | ||||||
|           validate_indices=False) |  | ||||||
|       self.evaluate(dense_without_validation) |  | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testUnsortedIndicesWithWithoutValidation(self): |   def testUnsortedIndicesWithWithoutValidation(self): | ||||||
|     with self.cached_session(): |     with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), | ||||||
|       dense = _SparseToDense( |                                 r"indices\[1\] = \[1\] is out of order"): | ||||||
|           sparse_indices=[[2], [1]], |       self.evaluate( | ||||||
|           output_size=[5], |           sparse_ops.sparse_to_dense([[2], [1]], [5], [-1.0, 1.0], 0.0)) | ||||||
|           sparse_values=[-1.0, 1.0], |     # Disable checks | ||||||
|           default_value=0.0) |     self.evaluate( | ||||||
|       with self.assertRaisesOpError(r"indices\[1\] = \[1\] is out of order"): |         sparse_ops.sparse_to_dense([[2], [1]], [5], [-1.0, 1.0], | ||||||
|         self.evaluate(dense) |                                    0.0, | ||||||
|       # Disable checks |                                    validate_indices=False)) | ||||||
|       dense_without_validation = _SparseToDense( |  | ||||||
|           sparse_indices=[[2], [1]], |  | ||||||
|           output_size=[5], |  | ||||||
|           sparse_values=[-1.0, 1.0], |  | ||||||
|           default_value=0.0, |  | ||||||
|           validate_indices=False) |  | ||||||
|       self.evaluate(dense_without_validation) |  | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testShapeInferenceKnownShape(self): |   def testShapeInferenceKnownShape(self): | ||||||
|     with self.session(use_gpu=False): |     with ops.Graph().as_default(): | ||||||
|       indices = array_ops.placeholder(dtypes.int64) |       indices = array_ops.placeholder(dtypes.int64) | ||||||
| 
 | 
 | ||||||
|       shape = [4, 5, 6] |       shape = [4, 5, 6] | ||||||
| @ -204,13 +144,12 @@ class SparseToDenseTest(test.TestCase): | |||||||
|       output = sparse_ops.sparse_to_dense(indices, shape, 1, 0) |       output = sparse_ops.sparse_to_dense(indices, shape, 1, 0) | ||||||
|       self.assertEqual(output.get_shape().as_list(), [None, None, None]) |       self.assertEqual(output.get_shape().as_list(), [None, None, None]) | ||||||
| 
 | 
 | ||||||
|   @test_util.run_deprecated_v1 |  | ||||||
|   def testShapeInferenceUnknownShape(self): |   def testShapeInferenceUnknownShape(self): | ||||||
|     with self.session(use_gpu=False): |     with ops.Graph().as_default(): | ||||||
|       indices = array_ops.placeholder(dtypes.int64) |       indices = array_ops.placeholder(dtypes.int64) | ||||||
|       shape = array_ops.placeholder(dtypes.int64) |       shape = array_ops.placeholder(dtypes.int64) | ||||||
|       output = sparse_ops.sparse_to_dense(indices, shape, 1, 0) |       output = sparse_ops.sparse_to_dense(indices, shape, 1, 0) | ||||||
|       self.assertEqual(output.get_shape().ndims, None) |       self.assertIsNone(output.get_shape().ndims) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user