From 1457d7ffdcee6a18619a74bdc465ffa60c0fd1ff Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Wed, 3 May 2017 14:31:45 -0800 Subject: [PATCH] sparse_ops: Preserving static shape info in sparse_reshape, sparse_reorder, sparse_add, sparse_reset_shape in the cases where all input shapes are known and do not contain implicit "-1" dimensions. Exceptions are raises when appropriate, preventing a dishonest static shape from being set. Change: 155013345 --- .../python/kernel_tests/sparse_add_op_test.py | 1 + .../python/kernel_tests/sparse_ops_test.py | 19 +++++- .../kernel_tests/sparse_reorder_op_test.py | 7 ++ .../kernel_tests/sparse_reshape_op_test.py | 13 ++++ tensorflow/python/ops/sparse_ops.py | 64 ++++++++++++++++--- 5 files changed, 92 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/kernel_tests/sparse_add_op_test.py b/tensorflow/python/kernel_tests/sparse_add_op_test.py index 874dcbabf10..555c16194e1 100644 --- a/tensorflow/python/kernel_tests/sparse_add_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_add_op_test.py @@ -88,6 +88,7 @@ class SparseAddTest(test.TestCase): for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()): for sp_b in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()): sp_sum = sparse_ops.sparse_add(sp_a, sp_b) + self.assertAllEqual((3, 3), sp_sum.get_shape()) sum_out = sess.run(sp_sum) diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py index 06d5cbaf2d0..bad11a29df0 100644 --- a/tensorflow/python/kernel_tests/sparse_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_ops_test.py @@ -328,6 +328,12 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase): return sparse_tensor.SparseTensorValue(self._IND_2_5_6, self._VAL_2_5_6, self._SHP_2_5_6) + def testStaticShapeInfoPreservedWhenNewShapeIsProvidedAndStatic(self): + sp_input = self._SparseTensor_2x5x6() + new_shape = np.array([3, 6, 7], dtype=np.int64) + sp_output = sparse_ops.sparse_reset_shape(sp_input, new_shape) + self.assertAllEqual([3, 6, 7], sp_output.get_shape()) + def testBasic(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensor_2x5x6() @@ -397,14 +403,21 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase): with self.assertRaisesOpError("x == y did not hold element-wise"): sess.run(out, feed_dict={new_shape: np.array([3, 7], dtype=np.int64)}) - def testInvalidDimensionSize(self): + def testInvalidDimensionSizeStatic(self): + sp_input = self._SparseTensor_2x5x6() + new_shape = np.array([3, 7, 5], dtype=np.int64) + + with self.assertRaisesRegexp(ValueError, "should have dimension sizes"): + sparse_ops.sparse_reset_shape(sp_input, new_shape) + + def testInvalidDimensionSizeDynamic(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensor_2x5x6() - new_shape = np.array([3, 7, 5], dtype=np.int64) + new_shape = array_ops.placeholder(dtype=dtypes.int32) out = sparse_ops.sparse_reset_shape(sp_input, new_shape) with self.assertRaisesOpError("x <= y did not hold element-wise"): - sess.run(out) + sess.run(out, feed_dict={new_shape: [3, 7, 5]}) def testInvalidDimensionSizeInputUnavailableInGraphConstruction(self): sp_input = array_ops.sparse_placeholder(dtype=dtypes.int32) diff --git a/tensorflow/python/kernel_tests/sparse_reorder_op_test.py b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py index 5136cdadead..18335d665af 100644 --- a/tensorflow/python/kernel_tests/sparse_reorder_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py @@ -48,6 +48,13 @@ class SparseReorderTest(test.TestCase): shape = np.array([5, 6]).astype(np.int64) return sparse_tensor.SparseTensorValue(ind, val, shape) + def testStaticShapeInfoPreserved(self): + sp_input = sparse_tensor.SparseTensor.from_value( + self._SparseTensorValue_5x6(np.arange(6))) + self.assertAllEqual((5, 6), sp_input.get_shape()) + sp_output = sparse_ops.sparse_reorder(sp_input) + self.assertAllEqual((5, 6), sp_output.get_shape()) + def testAlreadyInOrder(self): with self.test_session(use_gpu=False) as sess: input_val = self._SparseTensorValue_5x6(np.arange(6)) diff --git a/tensorflow/python/kernel_tests/sparse_reshape_op_test.py b/tensorflow/python/kernel_tests/sparse_reshape_op_test.py index 1bb05aa3b2a..42874ea9b7a 100644 --- a/tensorflow/python/kernel_tests/sparse_reshape_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_reshape_op_test.py @@ -50,6 +50,13 @@ class SparseReshapeTest(test.TestCase): shape = np.array([2, 3, 4]) return sparse_tensor.SparseTensorValue(ind, val, shape) + def testStaticShapeInfoPreserved(self): + sp_input = sparse_tensor.SparseTensor.from_value( + self._SparseTensorValue_5x6()) + self.assertAllEqual((5, 6), sp_input.get_shape()) + sp_output = sparse_ops.sparse_reshape(sp_input, shape=(1, 5, 2, 3)) + self.assertAllEqual((1, 5, 2, 3), sp_output.get_shape()) + def testSameShape(self): with self.test_session(use_gpu=False) as sess: input_val = self._SparseTensorValue_5x6() @@ -180,6 +187,12 @@ class SparseReshapeTest(test.TestCase): with self.assertRaisesOpError("only one output shape size may be -1"): sess.run(sp_output, {sp_input: input_val}) + def testProvideStaticallyMismatchedSizes(self): + input_val = self._SparseTensorValue_5x6() + sp_input = sparse_tensor.SparseTensor.from_value(input_val) + with self.assertRaisesRegexp(ValueError, "Cannot reshape"): + sparse_ops.sparse_reshape(sp_input, [4, 7]) + def testFeedMismatchedSizes(self): with self.test_session(use_gpu=False) as sess: sp_input = self._SparseTensorPlaceholder() diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index f8eb34aa5eb..0140a27aaa7 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -51,6 +51,7 @@ import numpy as np from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -288,12 +289,21 @@ def sparse_add(a, b, thresh=0): if all(isinstance(inp, sparse_classes) for inp in [a, b]): a = _convert_to_sparse_tensor(a) + b = _convert_to_sparse_tensor(b) thresh = ops.convert_to_tensor( thresh, dtype=a.values.dtype.real_dtype, name="thresh") output_ind, output_val, output_shape = (gen_sparse_ops._sparse_add( a.indices, a.values, a.dense_shape, b.indices, b.values, b.dense_shape, thresh)) + + # Attempt to get output_shape statically. + a.get_shape().assert_is_compatible_with(b.get_shape()) + static_shape = array_ops.broadcast_static_shape( + a.get_shape(), b.get_shape()) + if static_shape.is_fully_defined(): + output_shape = static_shape.as_list() + return sparse_tensor.SparseTensor(output_ind, output_val, output_shape) else: # swap to make `a` the SparseTensor. @@ -368,8 +378,12 @@ def sparse_reorder(sp_input, name=None): reordered_ind, reordered_val = (gen_sparse_ops._sparse_reorder( sp_input.indices, sp_input.values, sp_input.dense_shape, name=name)) - return sparse_tensor.SparseTensor(reordered_ind, reordered_val, - array_ops.identity(sp_input.dense_shape)) + if sp_input.get_shape().is_fully_defined(): + dense_shape = sp_input.get_shape().as_list() + else: + dense_shape = array_ops.identity(sp_input.dense_shape) + + return sparse_tensor.SparseTensor(reordered_ind, reordered_val, dense_shape) def sparse_reshape(sp_input, shape, name=None): @@ -416,13 +430,30 @@ def sparse_reshape(sp_input, shape, name=None): Raises: TypeError: If `sp_input` is not a `SparseTensor`. + ValueError: If argument `shape` requests a `SparseTensor` with a different + number of elements than `sp_input`. """ sp_input = _convert_to_sparse_tensor(sp_input) + shape = ops.convert_to_tensor(shape, dtype=dtypes.int64) with ops.name_scope(name, "SparseReshape", [sp_input]) as name: reshaped_ind, reshaped_shape = gen_sparse_ops._sparse_reshape( sp_input.indices, sp_input.dense_shape, shape, name=name) + reshaped_shape_const = tensor_util.constant_value(shape) + if (reshaped_shape_const is not None + and sp_input.get_shape().is_fully_defined()): + # Don't deal with inferred dimensions. That would add significant code. + if all(n >= 0 for n in reshaped_shape_const): + reshaped_size = np.prod(reshaped_shape_const) + in_shape_size = np.prod(sp_input.get_shape().as_list()) + if reshaped_size != in_shape_size: + raise ValueError( + "Cannot reshape a tensor with %d elements to shape %s " + "(%d elements)." + % (in_shape_size, reshaped_shape_const, reshaped_size)) + reshaped_shape = reshaped_shape_const + return sparse_tensor.SparseTensor( reshaped_ind, array_ops.identity(sp_input.values), reshaped_shape) @@ -986,6 +1017,8 @@ def sparse_reset_shape(sp_input, new_shape=None): TypeError: If `sp_input` is not a `SparseTensor`. ValueError: If `new_shape` represents a tensor with a different rank from that of `sp_input` (if shapes are known when graph is constructed). + ValueError: If `new_shape` is determined during graph build to have + dimension sizes that are too small. OpError: - If `new_shape` has dimension sizes that are too small. - If shapes are not known during graph construction time, and during run @@ -1009,14 +1042,27 @@ def sparse_reset_shape(sp_input, new_shape=None): # error before the sparse_tensor.SparseTensor catches it. output_shape_tensor.get_shape()[0].merge_with(in_shape.get_shape()[0]) - # For cases where shape is not known during graph construction. - output_shape_tensor = control_flow_ops.with_dependencies( - [check_ops.assert_equal( - array_ops.shape(in_shape), array_ops.shape(output_shape_tensor))], - output_shape_tensor) - output_shape_tensor = control_flow_ops.with_dependencies( - [check_ops.assert_less_equal(in_shape, output_shape_tensor)], + output_shape_tensor_const = tensor_util.constant_value( output_shape_tensor) + # For cases where all shapes are known during graph construction + if (output_shape_tensor_const is not None + and sp_input.get_shape().is_fully_defined()): + in_shape_const = np.array(sp_input.get_shape().as_list()) + if not np.all(in_shape_const <= output_shape_tensor_const): + raise ValueError( + "Requested new_shape should have dimension sizes >= sp_input.shape." + " Found new_shape (%s), sp_input.shape (%s)." + % (in_shape_const, output_shape_tensor_const)) + output_shape_tensor = output_shape_tensor_const + else: + # For cases where shape is not known during graph construction. + output_shape_tensor = control_flow_ops.with_dependencies( + [check_ops.assert_equal( + array_ops.shape(in_shape), array_ops.shape(output_shape_tensor))], + output_shape_tensor) + output_shape_tensor = control_flow_ops.with_dependencies( + [check_ops.assert_less_equal(in_shape, output_shape_tensor)], + output_shape_tensor) return sparse_tensor.SparseTensor(in_indices, in_values, output_shape_tensor)