sparse_ops: Preserving static shape info in sparse_reshape, sparse_reorder, sparse_add, sparse_reset_shape in the cases where all input shapes are known and do not contain implicit "-1" dimensions. Exceptions are raises when appropriate, preventing a dishonest static shape from being set.
Change: 155013345
This commit is contained in:
parent
965d620104
commit
1457d7ffdc
@ -88,6 +88,7 @@ class SparseAddTest(test.TestCase):
|
||||
for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
|
||||
for sp_b in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
|
||||
sp_sum = sparse_ops.sparse_add(sp_a, sp_b)
|
||||
self.assertAllEqual((3, 3), sp_sum.get_shape())
|
||||
|
||||
sum_out = sess.run(sp_sum)
|
||||
|
||||
|
@ -328,6 +328,12 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
|
||||
return sparse_tensor.SparseTensorValue(self._IND_2_5_6, self._VAL_2_5_6,
|
||||
self._SHP_2_5_6)
|
||||
|
||||
def testStaticShapeInfoPreservedWhenNewShapeIsProvidedAndStatic(self):
|
||||
sp_input = self._SparseTensor_2x5x6()
|
||||
new_shape = np.array([3, 6, 7], dtype=np.int64)
|
||||
sp_output = sparse_ops.sparse_reset_shape(sp_input, new_shape)
|
||||
self.assertAllEqual([3, 6, 7], sp_output.get_shape())
|
||||
|
||||
def testBasic(self):
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
sp_input = self._SparseTensor_2x5x6()
|
||||
@ -397,14 +403,21 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaisesOpError("x == y did not hold element-wise"):
|
||||
sess.run(out, feed_dict={new_shape: np.array([3, 7], dtype=np.int64)})
|
||||
|
||||
def testInvalidDimensionSize(self):
|
||||
def testInvalidDimensionSizeStatic(self):
|
||||
sp_input = self._SparseTensor_2x5x6()
|
||||
new_shape = np.array([3, 7, 5], dtype=np.int64)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "should have dimension sizes"):
|
||||
sparse_ops.sparse_reset_shape(sp_input, new_shape)
|
||||
|
||||
def testInvalidDimensionSizeDynamic(self):
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
sp_input = self._SparseTensor_2x5x6()
|
||||
new_shape = np.array([3, 7, 5], dtype=np.int64)
|
||||
new_shape = array_ops.placeholder(dtype=dtypes.int32)
|
||||
out = sparse_ops.sparse_reset_shape(sp_input, new_shape)
|
||||
|
||||
with self.assertRaisesOpError("x <= y did not hold element-wise"):
|
||||
sess.run(out)
|
||||
sess.run(out, feed_dict={new_shape: [3, 7, 5]})
|
||||
|
||||
def testInvalidDimensionSizeInputUnavailableInGraphConstruction(self):
|
||||
sp_input = array_ops.sparse_placeholder(dtype=dtypes.int32)
|
||||
|
@ -48,6 +48,13 @@ class SparseReorderTest(test.TestCase):
|
||||
shape = np.array([5, 6]).astype(np.int64)
|
||||
return sparse_tensor.SparseTensorValue(ind, val, shape)
|
||||
|
||||
def testStaticShapeInfoPreserved(self):
|
||||
sp_input = sparse_tensor.SparseTensor.from_value(
|
||||
self._SparseTensorValue_5x6(np.arange(6)))
|
||||
self.assertAllEqual((5, 6), sp_input.get_shape())
|
||||
sp_output = sparse_ops.sparse_reorder(sp_input)
|
||||
self.assertAllEqual((5, 6), sp_output.get_shape())
|
||||
|
||||
def testAlreadyInOrder(self):
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
input_val = self._SparseTensorValue_5x6(np.arange(6))
|
||||
|
@ -50,6 +50,13 @@ class SparseReshapeTest(test.TestCase):
|
||||
shape = np.array([2, 3, 4])
|
||||
return sparse_tensor.SparseTensorValue(ind, val, shape)
|
||||
|
||||
def testStaticShapeInfoPreserved(self):
|
||||
sp_input = sparse_tensor.SparseTensor.from_value(
|
||||
self._SparseTensorValue_5x6())
|
||||
self.assertAllEqual((5, 6), sp_input.get_shape())
|
||||
sp_output = sparse_ops.sparse_reshape(sp_input, shape=(1, 5, 2, 3))
|
||||
self.assertAllEqual((1, 5, 2, 3), sp_output.get_shape())
|
||||
|
||||
def testSameShape(self):
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
input_val = self._SparseTensorValue_5x6()
|
||||
@ -180,6 +187,12 @@ class SparseReshapeTest(test.TestCase):
|
||||
with self.assertRaisesOpError("only one output shape size may be -1"):
|
||||
sess.run(sp_output, {sp_input: input_val})
|
||||
|
||||
def testProvideStaticallyMismatchedSizes(self):
|
||||
input_val = self._SparseTensorValue_5x6()
|
||||
sp_input = sparse_tensor.SparseTensor.from_value(input_val)
|
||||
with self.assertRaisesRegexp(ValueError, "Cannot reshape"):
|
||||
sparse_ops.sparse_reshape(sp_input, [4, 7])
|
||||
|
||||
def testFeedMismatchedSizes(self):
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
sp_input = self._SparseTensorPlaceholder()
|
||||
|
@ -51,6 +51,7 @@ import numpy as np
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -288,12 +289,21 @@ def sparse_add(a, b, thresh=0):
|
||||
|
||||
if all(isinstance(inp, sparse_classes) for inp in [a, b]):
|
||||
a = _convert_to_sparse_tensor(a)
|
||||
b = _convert_to_sparse_tensor(b)
|
||||
thresh = ops.convert_to_tensor(
|
||||
thresh, dtype=a.values.dtype.real_dtype, name="thresh")
|
||||
output_ind, output_val, output_shape = (gen_sparse_ops._sparse_add(
|
||||
a.indices, a.values, a.dense_shape,
|
||||
b.indices, b.values, b.dense_shape,
|
||||
thresh))
|
||||
|
||||
# Attempt to get output_shape statically.
|
||||
a.get_shape().assert_is_compatible_with(b.get_shape())
|
||||
static_shape = array_ops.broadcast_static_shape(
|
||||
a.get_shape(), b.get_shape())
|
||||
if static_shape.is_fully_defined():
|
||||
output_shape = static_shape.as_list()
|
||||
|
||||
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
|
||||
else:
|
||||
# swap to make `a` the SparseTensor.
|
||||
@ -368,8 +378,12 @@ def sparse_reorder(sp_input, name=None):
|
||||
reordered_ind, reordered_val = (gen_sparse_ops._sparse_reorder(
|
||||
sp_input.indices, sp_input.values, sp_input.dense_shape, name=name))
|
||||
|
||||
return sparse_tensor.SparseTensor(reordered_ind, reordered_val,
|
||||
array_ops.identity(sp_input.dense_shape))
|
||||
if sp_input.get_shape().is_fully_defined():
|
||||
dense_shape = sp_input.get_shape().as_list()
|
||||
else:
|
||||
dense_shape = array_ops.identity(sp_input.dense_shape)
|
||||
|
||||
return sparse_tensor.SparseTensor(reordered_ind, reordered_val, dense_shape)
|
||||
|
||||
|
||||
def sparse_reshape(sp_input, shape, name=None):
|
||||
@ -416,13 +430,30 @@ def sparse_reshape(sp_input, shape, name=None):
|
||||
|
||||
Raises:
|
||||
TypeError: If `sp_input` is not a `SparseTensor`.
|
||||
ValueError: If argument `shape` requests a `SparseTensor` with a different
|
||||
number of elements than `sp_input`.
|
||||
"""
|
||||
sp_input = _convert_to_sparse_tensor(sp_input)
|
||||
shape = ops.convert_to_tensor(shape, dtype=dtypes.int64)
|
||||
|
||||
with ops.name_scope(name, "SparseReshape", [sp_input]) as name:
|
||||
reshaped_ind, reshaped_shape = gen_sparse_ops._sparse_reshape(
|
||||
sp_input.indices, sp_input.dense_shape, shape, name=name)
|
||||
|
||||
reshaped_shape_const = tensor_util.constant_value(shape)
|
||||
if (reshaped_shape_const is not None
|
||||
and sp_input.get_shape().is_fully_defined()):
|
||||
# Don't deal with inferred dimensions. That would add significant code.
|
||||
if all(n >= 0 for n in reshaped_shape_const):
|
||||
reshaped_size = np.prod(reshaped_shape_const)
|
||||
in_shape_size = np.prod(sp_input.get_shape().as_list())
|
||||
if reshaped_size != in_shape_size:
|
||||
raise ValueError(
|
||||
"Cannot reshape a tensor with %d elements to shape %s "
|
||||
"(%d elements)."
|
||||
% (in_shape_size, reshaped_shape_const, reshaped_size))
|
||||
reshaped_shape = reshaped_shape_const
|
||||
|
||||
return sparse_tensor.SparseTensor(
|
||||
reshaped_ind, array_ops.identity(sp_input.values),
|
||||
reshaped_shape)
|
||||
@ -986,6 +1017,8 @@ def sparse_reset_shape(sp_input, new_shape=None):
|
||||
TypeError: If `sp_input` is not a `SparseTensor`.
|
||||
ValueError: If `new_shape` represents a tensor with a different rank from
|
||||
that of `sp_input` (if shapes are known when graph is constructed).
|
||||
ValueError: If `new_shape` is determined during graph build to have
|
||||
dimension sizes that are too small.
|
||||
OpError:
|
||||
- If `new_shape` has dimension sizes that are too small.
|
||||
- If shapes are not known during graph construction time, and during run
|
||||
@ -1009,14 +1042,27 @@ def sparse_reset_shape(sp_input, new_shape=None):
|
||||
# error before the sparse_tensor.SparseTensor catches it.
|
||||
output_shape_tensor.get_shape()[0].merge_with(in_shape.get_shape()[0])
|
||||
|
||||
# For cases where shape is not known during graph construction.
|
||||
output_shape_tensor = control_flow_ops.with_dependencies(
|
||||
[check_ops.assert_equal(
|
||||
array_ops.shape(in_shape), array_ops.shape(output_shape_tensor))],
|
||||
output_shape_tensor)
|
||||
output_shape_tensor = control_flow_ops.with_dependencies(
|
||||
[check_ops.assert_less_equal(in_shape, output_shape_tensor)],
|
||||
output_shape_tensor_const = tensor_util.constant_value(
|
||||
output_shape_tensor)
|
||||
# For cases where all shapes are known during graph construction
|
||||
if (output_shape_tensor_const is not None
|
||||
and sp_input.get_shape().is_fully_defined()):
|
||||
in_shape_const = np.array(sp_input.get_shape().as_list())
|
||||
if not np.all(in_shape_const <= output_shape_tensor_const):
|
||||
raise ValueError(
|
||||
"Requested new_shape should have dimension sizes >= sp_input.shape."
|
||||
" Found new_shape (%s), sp_input.shape (%s)."
|
||||
% (in_shape_const, output_shape_tensor_const))
|
||||
output_shape_tensor = output_shape_tensor_const
|
||||
else:
|
||||
# For cases where shape is not known during graph construction.
|
||||
output_shape_tensor = control_flow_ops.with_dependencies(
|
||||
[check_ops.assert_equal(
|
||||
array_ops.shape(in_shape), array_ops.shape(output_shape_tensor))],
|
||||
output_shape_tensor)
|
||||
output_shape_tensor = control_flow_ops.with_dependencies(
|
||||
[check_ops.assert_less_equal(in_shape, output_shape_tensor)],
|
||||
output_shape_tensor)
|
||||
|
||||
return sparse_tensor.SparseTensor(in_indices, in_values, output_shape_tensor)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user