sparse_ops: Preserving static shape info in sparse_reshape, sparse_reorder, sparse_add, sparse_reset_shape in the cases where all input shapes are known and do not contain implicit "-1" dimensions. Exceptions are raises when appropriate, preventing a dishonest static shape from being set.
Change: 155013345
This commit is contained in:
parent
965d620104
commit
1457d7ffdc
@ -88,6 +88,7 @@ class SparseAddTest(test.TestCase):
|
|||||||
for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
|
for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
|
||||||
for sp_b in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
|
for sp_b in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
|
||||||
sp_sum = sparse_ops.sparse_add(sp_a, sp_b)
|
sp_sum = sparse_ops.sparse_add(sp_a, sp_b)
|
||||||
|
self.assertAllEqual((3, 3), sp_sum.get_shape())
|
||||||
|
|
||||||
sum_out = sess.run(sp_sum)
|
sum_out = sess.run(sp_sum)
|
||||||
|
|
||||||
|
@ -328,6 +328,12 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
|
|||||||
return sparse_tensor.SparseTensorValue(self._IND_2_5_6, self._VAL_2_5_6,
|
return sparse_tensor.SparseTensorValue(self._IND_2_5_6, self._VAL_2_5_6,
|
||||||
self._SHP_2_5_6)
|
self._SHP_2_5_6)
|
||||||
|
|
||||||
|
def testStaticShapeInfoPreservedWhenNewShapeIsProvidedAndStatic(self):
|
||||||
|
sp_input = self._SparseTensor_2x5x6()
|
||||||
|
new_shape = np.array([3, 6, 7], dtype=np.int64)
|
||||||
|
sp_output = sparse_ops.sparse_reset_shape(sp_input, new_shape)
|
||||||
|
self.assertAllEqual([3, 6, 7], sp_output.get_shape())
|
||||||
|
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
with self.test_session(use_gpu=False) as sess:
|
with self.test_session(use_gpu=False) as sess:
|
||||||
sp_input = self._SparseTensor_2x5x6()
|
sp_input = self._SparseTensor_2x5x6()
|
||||||
@ -397,14 +403,21 @@ class SparseResetShapeTest(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaisesOpError("x == y did not hold element-wise"):
|
with self.assertRaisesOpError("x == y did not hold element-wise"):
|
||||||
sess.run(out, feed_dict={new_shape: np.array([3, 7], dtype=np.int64)})
|
sess.run(out, feed_dict={new_shape: np.array([3, 7], dtype=np.int64)})
|
||||||
|
|
||||||
def testInvalidDimensionSize(self):
|
def testInvalidDimensionSizeStatic(self):
|
||||||
|
sp_input = self._SparseTensor_2x5x6()
|
||||||
|
new_shape = np.array([3, 7, 5], dtype=np.int64)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, "should have dimension sizes"):
|
||||||
|
sparse_ops.sparse_reset_shape(sp_input, new_shape)
|
||||||
|
|
||||||
|
def testInvalidDimensionSizeDynamic(self):
|
||||||
with self.test_session(use_gpu=False) as sess:
|
with self.test_session(use_gpu=False) as sess:
|
||||||
sp_input = self._SparseTensor_2x5x6()
|
sp_input = self._SparseTensor_2x5x6()
|
||||||
new_shape = np.array([3, 7, 5], dtype=np.int64)
|
new_shape = array_ops.placeholder(dtype=dtypes.int32)
|
||||||
out = sparse_ops.sparse_reset_shape(sp_input, new_shape)
|
out = sparse_ops.sparse_reset_shape(sp_input, new_shape)
|
||||||
|
|
||||||
with self.assertRaisesOpError("x <= y did not hold element-wise"):
|
with self.assertRaisesOpError("x <= y did not hold element-wise"):
|
||||||
sess.run(out)
|
sess.run(out, feed_dict={new_shape: [3, 7, 5]})
|
||||||
|
|
||||||
def testInvalidDimensionSizeInputUnavailableInGraphConstruction(self):
|
def testInvalidDimensionSizeInputUnavailableInGraphConstruction(self):
|
||||||
sp_input = array_ops.sparse_placeholder(dtype=dtypes.int32)
|
sp_input = array_ops.sparse_placeholder(dtype=dtypes.int32)
|
||||||
|
@ -48,6 +48,13 @@ class SparseReorderTest(test.TestCase):
|
|||||||
shape = np.array([5, 6]).astype(np.int64)
|
shape = np.array([5, 6]).astype(np.int64)
|
||||||
return sparse_tensor.SparseTensorValue(ind, val, shape)
|
return sparse_tensor.SparseTensorValue(ind, val, shape)
|
||||||
|
|
||||||
|
def testStaticShapeInfoPreserved(self):
|
||||||
|
sp_input = sparse_tensor.SparseTensor.from_value(
|
||||||
|
self._SparseTensorValue_5x6(np.arange(6)))
|
||||||
|
self.assertAllEqual((5, 6), sp_input.get_shape())
|
||||||
|
sp_output = sparse_ops.sparse_reorder(sp_input)
|
||||||
|
self.assertAllEqual((5, 6), sp_output.get_shape())
|
||||||
|
|
||||||
def testAlreadyInOrder(self):
|
def testAlreadyInOrder(self):
|
||||||
with self.test_session(use_gpu=False) as sess:
|
with self.test_session(use_gpu=False) as sess:
|
||||||
input_val = self._SparseTensorValue_5x6(np.arange(6))
|
input_val = self._SparseTensorValue_5x6(np.arange(6))
|
||||||
|
@ -50,6 +50,13 @@ class SparseReshapeTest(test.TestCase):
|
|||||||
shape = np.array([2, 3, 4])
|
shape = np.array([2, 3, 4])
|
||||||
return sparse_tensor.SparseTensorValue(ind, val, shape)
|
return sparse_tensor.SparseTensorValue(ind, val, shape)
|
||||||
|
|
||||||
|
def testStaticShapeInfoPreserved(self):
|
||||||
|
sp_input = sparse_tensor.SparseTensor.from_value(
|
||||||
|
self._SparseTensorValue_5x6())
|
||||||
|
self.assertAllEqual((5, 6), sp_input.get_shape())
|
||||||
|
sp_output = sparse_ops.sparse_reshape(sp_input, shape=(1, 5, 2, 3))
|
||||||
|
self.assertAllEqual((1, 5, 2, 3), sp_output.get_shape())
|
||||||
|
|
||||||
def testSameShape(self):
|
def testSameShape(self):
|
||||||
with self.test_session(use_gpu=False) as sess:
|
with self.test_session(use_gpu=False) as sess:
|
||||||
input_val = self._SparseTensorValue_5x6()
|
input_val = self._SparseTensorValue_5x6()
|
||||||
@ -180,6 +187,12 @@ class SparseReshapeTest(test.TestCase):
|
|||||||
with self.assertRaisesOpError("only one output shape size may be -1"):
|
with self.assertRaisesOpError("only one output shape size may be -1"):
|
||||||
sess.run(sp_output, {sp_input: input_val})
|
sess.run(sp_output, {sp_input: input_val})
|
||||||
|
|
||||||
|
def testProvideStaticallyMismatchedSizes(self):
|
||||||
|
input_val = self._SparseTensorValue_5x6()
|
||||||
|
sp_input = sparse_tensor.SparseTensor.from_value(input_val)
|
||||||
|
with self.assertRaisesRegexp(ValueError, "Cannot reshape"):
|
||||||
|
sparse_ops.sparse_reshape(sp_input, [4, 7])
|
||||||
|
|
||||||
def testFeedMismatchedSizes(self):
|
def testFeedMismatchedSizes(self):
|
||||||
with self.test_session(use_gpu=False) as sess:
|
with self.test_session(use_gpu=False) as sess:
|
||||||
sp_input = self._SparseTensorPlaceholder()
|
sp_input = self._SparseTensorPlaceholder()
|
||||||
|
@ -51,6 +51,7 @@ import numpy as np
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
@ -288,12 +289,21 @@ def sparse_add(a, b, thresh=0):
|
|||||||
|
|
||||||
if all(isinstance(inp, sparse_classes) for inp in [a, b]):
|
if all(isinstance(inp, sparse_classes) for inp in [a, b]):
|
||||||
a = _convert_to_sparse_tensor(a)
|
a = _convert_to_sparse_tensor(a)
|
||||||
|
b = _convert_to_sparse_tensor(b)
|
||||||
thresh = ops.convert_to_tensor(
|
thresh = ops.convert_to_tensor(
|
||||||
thresh, dtype=a.values.dtype.real_dtype, name="thresh")
|
thresh, dtype=a.values.dtype.real_dtype, name="thresh")
|
||||||
output_ind, output_val, output_shape = (gen_sparse_ops._sparse_add(
|
output_ind, output_val, output_shape = (gen_sparse_ops._sparse_add(
|
||||||
a.indices, a.values, a.dense_shape,
|
a.indices, a.values, a.dense_shape,
|
||||||
b.indices, b.values, b.dense_shape,
|
b.indices, b.values, b.dense_shape,
|
||||||
thresh))
|
thresh))
|
||||||
|
|
||||||
|
# Attempt to get output_shape statically.
|
||||||
|
a.get_shape().assert_is_compatible_with(b.get_shape())
|
||||||
|
static_shape = array_ops.broadcast_static_shape(
|
||||||
|
a.get_shape(), b.get_shape())
|
||||||
|
if static_shape.is_fully_defined():
|
||||||
|
output_shape = static_shape.as_list()
|
||||||
|
|
||||||
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
|
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
|
||||||
else:
|
else:
|
||||||
# swap to make `a` the SparseTensor.
|
# swap to make `a` the SparseTensor.
|
||||||
@ -368,8 +378,12 @@ def sparse_reorder(sp_input, name=None):
|
|||||||
reordered_ind, reordered_val = (gen_sparse_ops._sparse_reorder(
|
reordered_ind, reordered_val = (gen_sparse_ops._sparse_reorder(
|
||||||
sp_input.indices, sp_input.values, sp_input.dense_shape, name=name))
|
sp_input.indices, sp_input.values, sp_input.dense_shape, name=name))
|
||||||
|
|
||||||
return sparse_tensor.SparseTensor(reordered_ind, reordered_val,
|
if sp_input.get_shape().is_fully_defined():
|
||||||
array_ops.identity(sp_input.dense_shape))
|
dense_shape = sp_input.get_shape().as_list()
|
||||||
|
else:
|
||||||
|
dense_shape = array_ops.identity(sp_input.dense_shape)
|
||||||
|
|
||||||
|
return sparse_tensor.SparseTensor(reordered_ind, reordered_val, dense_shape)
|
||||||
|
|
||||||
|
|
||||||
def sparse_reshape(sp_input, shape, name=None):
|
def sparse_reshape(sp_input, shape, name=None):
|
||||||
@ -416,13 +430,30 @@ def sparse_reshape(sp_input, shape, name=None):
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If `sp_input` is not a `SparseTensor`.
|
TypeError: If `sp_input` is not a `SparseTensor`.
|
||||||
|
ValueError: If argument `shape` requests a `SparseTensor` with a different
|
||||||
|
number of elements than `sp_input`.
|
||||||
"""
|
"""
|
||||||
sp_input = _convert_to_sparse_tensor(sp_input)
|
sp_input = _convert_to_sparse_tensor(sp_input)
|
||||||
|
shape = ops.convert_to_tensor(shape, dtype=dtypes.int64)
|
||||||
|
|
||||||
with ops.name_scope(name, "SparseReshape", [sp_input]) as name:
|
with ops.name_scope(name, "SparseReshape", [sp_input]) as name:
|
||||||
reshaped_ind, reshaped_shape = gen_sparse_ops._sparse_reshape(
|
reshaped_ind, reshaped_shape = gen_sparse_ops._sparse_reshape(
|
||||||
sp_input.indices, sp_input.dense_shape, shape, name=name)
|
sp_input.indices, sp_input.dense_shape, shape, name=name)
|
||||||
|
|
||||||
|
reshaped_shape_const = tensor_util.constant_value(shape)
|
||||||
|
if (reshaped_shape_const is not None
|
||||||
|
and sp_input.get_shape().is_fully_defined()):
|
||||||
|
# Don't deal with inferred dimensions. That would add significant code.
|
||||||
|
if all(n >= 0 for n in reshaped_shape_const):
|
||||||
|
reshaped_size = np.prod(reshaped_shape_const)
|
||||||
|
in_shape_size = np.prod(sp_input.get_shape().as_list())
|
||||||
|
if reshaped_size != in_shape_size:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot reshape a tensor with %d elements to shape %s "
|
||||||
|
"(%d elements)."
|
||||||
|
% (in_shape_size, reshaped_shape_const, reshaped_size))
|
||||||
|
reshaped_shape = reshaped_shape_const
|
||||||
|
|
||||||
return sparse_tensor.SparseTensor(
|
return sparse_tensor.SparseTensor(
|
||||||
reshaped_ind, array_ops.identity(sp_input.values),
|
reshaped_ind, array_ops.identity(sp_input.values),
|
||||||
reshaped_shape)
|
reshaped_shape)
|
||||||
@ -986,6 +1017,8 @@ def sparse_reset_shape(sp_input, new_shape=None):
|
|||||||
TypeError: If `sp_input` is not a `SparseTensor`.
|
TypeError: If `sp_input` is not a `SparseTensor`.
|
||||||
ValueError: If `new_shape` represents a tensor with a different rank from
|
ValueError: If `new_shape` represents a tensor with a different rank from
|
||||||
that of `sp_input` (if shapes are known when graph is constructed).
|
that of `sp_input` (if shapes are known when graph is constructed).
|
||||||
|
ValueError: If `new_shape` is determined during graph build to have
|
||||||
|
dimension sizes that are too small.
|
||||||
OpError:
|
OpError:
|
||||||
- If `new_shape` has dimension sizes that are too small.
|
- If `new_shape` has dimension sizes that are too small.
|
||||||
- If shapes are not known during graph construction time, and during run
|
- If shapes are not known during graph construction time, and during run
|
||||||
@ -1009,14 +1042,27 @@ def sparse_reset_shape(sp_input, new_shape=None):
|
|||||||
# error before the sparse_tensor.SparseTensor catches it.
|
# error before the sparse_tensor.SparseTensor catches it.
|
||||||
output_shape_tensor.get_shape()[0].merge_with(in_shape.get_shape()[0])
|
output_shape_tensor.get_shape()[0].merge_with(in_shape.get_shape()[0])
|
||||||
|
|
||||||
# For cases where shape is not known during graph construction.
|
output_shape_tensor_const = tensor_util.constant_value(
|
||||||
output_shape_tensor = control_flow_ops.with_dependencies(
|
|
||||||
[check_ops.assert_equal(
|
|
||||||
array_ops.shape(in_shape), array_ops.shape(output_shape_tensor))],
|
|
||||||
output_shape_tensor)
|
|
||||||
output_shape_tensor = control_flow_ops.with_dependencies(
|
|
||||||
[check_ops.assert_less_equal(in_shape, output_shape_tensor)],
|
|
||||||
output_shape_tensor)
|
output_shape_tensor)
|
||||||
|
# For cases where all shapes are known during graph construction
|
||||||
|
if (output_shape_tensor_const is not None
|
||||||
|
and sp_input.get_shape().is_fully_defined()):
|
||||||
|
in_shape_const = np.array(sp_input.get_shape().as_list())
|
||||||
|
if not np.all(in_shape_const <= output_shape_tensor_const):
|
||||||
|
raise ValueError(
|
||||||
|
"Requested new_shape should have dimension sizes >= sp_input.shape."
|
||||||
|
" Found new_shape (%s), sp_input.shape (%s)."
|
||||||
|
% (in_shape_const, output_shape_tensor_const))
|
||||||
|
output_shape_tensor = output_shape_tensor_const
|
||||||
|
else:
|
||||||
|
# For cases where shape is not known during graph construction.
|
||||||
|
output_shape_tensor = control_flow_ops.with_dependencies(
|
||||||
|
[check_ops.assert_equal(
|
||||||
|
array_ops.shape(in_shape), array_ops.shape(output_shape_tensor))],
|
||||||
|
output_shape_tensor)
|
||||||
|
output_shape_tensor = control_flow_ops.with_dependencies(
|
||||||
|
[check_ops.assert_less_equal(in_shape, output_shape_tensor)],
|
||||||
|
output_shape_tensor)
|
||||||
|
|
||||||
return sparse_tensor.SparseTensor(in_indices, in_values, output_shape_tensor)
|
return sparse_tensor.SparseTensor(in_indices, in_values, output_shape_tensor)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user