sparse_ops: Preserving static shape info in sparse_reshape, sparse_reorder, sparse_add, sparse_reset_shape in the cases where all input shapes are known and do not contain implicit "-1" dimensions. Exceptions are raises when appropriate, preventing a dishonest static shape from being set.

Change: 155013345
This commit is contained in:
Ian Langmore 2017-05-03 14:31:45 -08:00 committed by TensorFlower Gardener
parent 965d620104
commit 1457d7ffdc
5 changed files with 92 additions and 12 deletions

View File

@ -88,6 +88,7 @@ class SparseAddTest(test.TestCase):
for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
for sp_b in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
sp_sum = sparse_ops.sparse_add(sp_a, sp_b)
self.assertAllEqual((3, 3), sp_sum.get_shape())
sum_out = sess.run(sp_sum)

View File

@ -328,6 +328,12 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
return sparse_tensor.SparseTensorValue(self._IND_2_5_6, self._VAL_2_5_6,
self._SHP_2_5_6)
def testStaticShapeInfoPreservedWhenNewShapeIsProvidedAndStatic(self):
sp_input = self._SparseTensor_2x5x6()
new_shape = np.array([3, 6, 7], dtype=np.int64)
sp_output = sparse_ops.sparse_reset_shape(sp_input, new_shape)
self.assertAllEqual([3, 6, 7], sp_output.get_shape())
def testBasic(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensor_2x5x6()
@ -397,14 +403,21 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
with self.assertRaisesOpError("x == y did not hold element-wise"):
sess.run(out, feed_dict={new_shape: np.array([3, 7], dtype=np.int64)})
def testInvalidDimensionSize(self):
def testInvalidDimensionSizeStatic(self):
sp_input = self._SparseTensor_2x5x6()
new_shape = np.array([3, 7, 5], dtype=np.int64)
with self.assertRaisesRegexp(ValueError, "should have dimension sizes"):
sparse_ops.sparse_reset_shape(sp_input, new_shape)
def testInvalidDimensionSizeDynamic(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensor_2x5x6()
new_shape = np.array([3, 7, 5], dtype=np.int64)
new_shape = array_ops.placeholder(dtype=dtypes.int32)
out = sparse_ops.sparse_reset_shape(sp_input, new_shape)
with self.assertRaisesOpError("x <= y did not hold element-wise"):
sess.run(out)
sess.run(out, feed_dict={new_shape: [3, 7, 5]})
def testInvalidDimensionSizeInputUnavailableInGraphConstruction(self):
sp_input = array_ops.sparse_placeholder(dtype=dtypes.int32)

View File

@ -48,6 +48,13 @@ class SparseReorderTest(test.TestCase):
shape = np.array([5, 6]).astype(np.int64)
return sparse_tensor.SparseTensorValue(ind, val, shape)
def testStaticShapeInfoPreserved(self):
sp_input = sparse_tensor.SparseTensor.from_value(
self._SparseTensorValue_5x6(np.arange(6)))
self.assertAllEqual((5, 6), sp_input.get_shape())
sp_output = sparse_ops.sparse_reorder(sp_input)
self.assertAllEqual((5, 6), sp_output.get_shape())
def testAlreadyInOrder(self):
with self.test_session(use_gpu=False) as sess:
input_val = self._SparseTensorValue_5x6(np.arange(6))

View File

@ -50,6 +50,13 @@ class SparseReshapeTest(test.TestCase):
shape = np.array([2, 3, 4])
return sparse_tensor.SparseTensorValue(ind, val, shape)
def testStaticShapeInfoPreserved(self):
sp_input = sparse_tensor.SparseTensor.from_value(
self._SparseTensorValue_5x6())
self.assertAllEqual((5, 6), sp_input.get_shape())
sp_output = sparse_ops.sparse_reshape(sp_input, shape=(1, 5, 2, 3))
self.assertAllEqual((1, 5, 2, 3), sp_output.get_shape())
def testSameShape(self):
with self.test_session(use_gpu=False) as sess:
input_val = self._SparseTensorValue_5x6()
@ -180,6 +187,12 @@ class SparseReshapeTest(test.TestCase):
with self.assertRaisesOpError("only one output shape size may be -1"):
sess.run(sp_output, {sp_input: input_val})
def testProvideStaticallyMismatchedSizes(self):
input_val = self._SparseTensorValue_5x6()
sp_input = sparse_tensor.SparseTensor.from_value(input_val)
with self.assertRaisesRegexp(ValueError, "Cannot reshape"):
sparse_ops.sparse_reshape(sp_input, [4, 7])
def testFeedMismatchedSizes(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()

View File

@ -51,6 +51,7 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@ -288,12 +289,21 @@ def sparse_add(a, b, thresh=0):
if all(isinstance(inp, sparse_classes) for inp in [a, b]):
a = _convert_to_sparse_tensor(a)
b = _convert_to_sparse_tensor(b)
thresh = ops.convert_to_tensor(
thresh, dtype=a.values.dtype.real_dtype, name="thresh")
output_ind, output_val, output_shape = (gen_sparse_ops._sparse_add(
a.indices, a.values, a.dense_shape,
b.indices, b.values, b.dense_shape,
thresh))
# Attempt to get output_shape statically.
a.get_shape().assert_is_compatible_with(b.get_shape())
static_shape = array_ops.broadcast_static_shape(
a.get_shape(), b.get_shape())
if static_shape.is_fully_defined():
output_shape = static_shape.as_list()
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
else:
# swap to make `a` the SparseTensor.
@ -368,8 +378,12 @@ def sparse_reorder(sp_input, name=None):
reordered_ind, reordered_val = (gen_sparse_ops._sparse_reorder(
sp_input.indices, sp_input.values, sp_input.dense_shape, name=name))
return sparse_tensor.SparseTensor(reordered_ind, reordered_val,
array_ops.identity(sp_input.dense_shape))
if sp_input.get_shape().is_fully_defined():
dense_shape = sp_input.get_shape().as_list()
else:
dense_shape = array_ops.identity(sp_input.dense_shape)
return sparse_tensor.SparseTensor(reordered_ind, reordered_val, dense_shape)
def sparse_reshape(sp_input, shape, name=None):
@ -416,13 +430,30 @@ def sparse_reshape(sp_input, shape, name=None):
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
ValueError: If argument `shape` requests a `SparseTensor` with a different
number of elements than `sp_input`.
"""
sp_input = _convert_to_sparse_tensor(sp_input)
shape = ops.convert_to_tensor(shape, dtype=dtypes.int64)
with ops.name_scope(name, "SparseReshape", [sp_input]) as name:
reshaped_ind, reshaped_shape = gen_sparse_ops._sparse_reshape(
sp_input.indices, sp_input.dense_shape, shape, name=name)
reshaped_shape_const = tensor_util.constant_value(shape)
if (reshaped_shape_const is not None
and sp_input.get_shape().is_fully_defined()):
# Don't deal with inferred dimensions. That would add significant code.
if all(n >= 0 for n in reshaped_shape_const):
reshaped_size = np.prod(reshaped_shape_const)
in_shape_size = np.prod(sp_input.get_shape().as_list())
if reshaped_size != in_shape_size:
raise ValueError(
"Cannot reshape a tensor with %d elements to shape %s "
"(%d elements)."
% (in_shape_size, reshaped_shape_const, reshaped_size))
reshaped_shape = reshaped_shape_const
return sparse_tensor.SparseTensor(
reshaped_ind, array_ops.identity(sp_input.values),
reshaped_shape)
@ -986,6 +1017,8 @@ def sparse_reset_shape(sp_input, new_shape=None):
TypeError: If `sp_input` is not a `SparseTensor`.
ValueError: If `new_shape` represents a tensor with a different rank from
that of `sp_input` (if shapes are known when graph is constructed).
ValueError: If `new_shape` is determined during graph build to have
dimension sizes that are too small.
OpError:
- If `new_shape` has dimension sizes that are too small.
- If shapes are not known during graph construction time, and during run
@ -1009,14 +1042,27 @@ def sparse_reset_shape(sp_input, new_shape=None):
# error before the sparse_tensor.SparseTensor catches it.
output_shape_tensor.get_shape()[0].merge_with(in_shape.get_shape()[0])
# For cases where shape is not known during graph construction.
output_shape_tensor = control_flow_ops.with_dependencies(
[check_ops.assert_equal(
array_ops.shape(in_shape), array_ops.shape(output_shape_tensor))],
output_shape_tensor)
output_shape_tensor = control_flow_ops.with_dependencies(
[check_ops.assert_less_equal(in_shape, output_shape_tensor)],
output_shape_tensor_const = tensor_util.constant_value(
output_shape_tensor)
# For cases where all shapes are known during graph construction
if (output_shape_tensor_const is not None
and sp_input.get_shape().is_fully_defined()):
in_shape_const = np.array(sp_input.get_shape().as_list())
if not np.all(in_shape_const <= output_shape_tensor_const):
raise ValueError(
"Requested new_shape should have dimension sizes >= sp_input.shape."
" Found new_shape (%s), sp_input.shape (%s)."
% (in_shape_const, output_shape_tensor_const))
output_shape_tensor = output_shape_tensor_const
else:
# For cases where shape is not known during graph construction.
output_shape_tensor = control_flow_ops.with_dependencies(
[check_ops.assert_equal(
array_ops.shape(in_shape), array_ops.shape(output_shape_tensor))],
output_shape_tensor)
output_shape_tensor = control_flow_ops.with_dependencies(
[check_ops.assert_less_equal(in_shape, output_shape_tensor)],
output_shape_tensor)
return sparse_tensor.SparseTensor(in_indices, in_values, output_shape_tensor)