Extend static shape inference for SparseTensors with dense_shapes constructed using slicing.

PiperOrigin-RevId: 161132391
This commit is contained in:
A. Unique TensorFlower 2017-07-06 14:37:27 -07:00 committed by TensorFlower Gardener
parent 53604916ed
commit 335f1f14d3
2 changed files with 117 additions and 7 deletions

View File

@ -770,13 +770,46 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
# and concatenate it with `ret`.
ret = ret.concatenate(constant_value_as_shape(concat_input))
return ret
else:
ret = tensor_shape.unknown_shape(shape[0].value)
value = constant_value(tensor)
if value is not None:
ret = ret.merge_with(tensor_shape.TensorShape(
[d if d != -1 else None for d in value]))
return ret
elif tensor.op.type == "StridedSlice":
try:
begin = constant_value(tensor.op.inputs[1])
end = constant_value(tensor.op.inputs[2])
strides = constant_value(tensor.op.inputs[3])
if begin is not None and end is not None and strides is not None:
begin = begin[0]
end = end[0]
strides = strides[0]
begin_mask = tensor.op.get_attr("begin_mask")
if begin_mask == 1:
begin = None
end_mask = tensor.op.get_attr("end_mask")
if end_mask == 1:
end = None
ellipsis_mask = tensor.op.get_attr("ellipsis_mask")
new_axis_mask = tensor.op.get_attr("new_axis_mask")
shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask")
valid_attributes = (not ellipsis_mask and not new_axis_mask and
not shrink_axis_mask and
(not begin_mask or (begin_mask == 1)) and
(not end_mask or (end_mask == 1)))
if valid_attributes: # additional inputs not supported
prev = constant_value_as_shape(tensor.op.inputs[0])
prev = prev[begin:end:strides]
ret = tensor_shape.TensorShape(prev)
return ret
except ValueError: # Could come from get_attr or slicing prev.
pass
except TypeError: # Could come from slicing prev.
pass
ret = tensor_shape.unknown_shape(shape[0].value)
value = constant_value(tensor)
if value is not None:
ret = ret.merge_with(tensor_shape.TensorShape(
[d if d >= 0 else None for d in value]))
return ret
def is_tensor(x): # pylint: disable=invalid-name

View File

@ -832,6 +832,83 @@ class ConstantValueAsShapeTest(test.TestCase):
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual([16, 37, None, 48], c_val.as_list())
def testSlice(self):
tf_val = array_ops.placeholder(dtypes.int32, shape=(4,))[0:2]
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual([None, None], c_val.as_list())
# begin:end
tf_val = constant_op.constant([10, 20, 30])[1:3]
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual([20, 30], c_val.as_list())
# begin:end:stride
tf_val = array_ops.strided_slice(
constant_op.constant([10, 20, 30]), [1], [3], strides=[2])
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual([20], c_val.as_list())
# [1, 2, 16, 37, None, 48]
tf_val_orig = array_ops.concat(
[[1, 2, 16, 37], array_ops.placeholder(
dtypes.int32, shape=(1,)), [48]], 0)
# begin: no end
tf_val = tf_val_orig[2:]
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual([16, 37, None, 48], c_val.as_list())
# begin::negative slice
tf_val = tf_val_orig[2::-1]
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual([16, 2, 1], c_val.as_list())
# :end:negative slice
tf_val = tf_val_orig[:1:-2]
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual([48, 37], c_val.as_list())
# begin:end:negative slice
tf_val = tf_val_orig[3:1:-1]
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual([37, 16], c_val.as_list())
# begin:negative end:slice
tf_val = tf_val_orig[1:-3:1]
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual([2, 16], c_val.as_list())
# negative begin::slice
tf_val = tf_val_orig[-3::1]
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual([37, None, 48], c_val.as_list())
# negative begin::negative slice
tf_val = tf_val_orig[-3::-1]
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual([37, 16, 2, 1], c_val.as_list())
# negative begin:negative end:negative slice
tf_val = tf_val_orig[-3:-5:-1]
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual([37, 16], c_val.as_list())
# Do not support shape inference for additional arguments
tf_val = constant_op.constant([10, 20, 30])[...]
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual([None, None, None], c_val.as_list())
# Do not support shape inference for tensor slices.
tf_val = constant_op.constant([10, 20, 30])[
array_ops.placeholder(dtypes.int32, shape=()):]
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual(tensor_shape.unknown_shape(), c_val)
# Do not support shape inference for higher rank
with self.assertRaises(ValueError):
tf_val = constant_op.constant([[10], [20], [30]])[:, 0:]
c_val = tensor_util.constant_value_as_shape(tf_val)
if __name__ == "__main__":
test.main()