Extend static shape inference for SparseTensors with dense_shapes constructed using slicing.
PiperOrigin-RevId: 161132391
This commit is contained in:
parent
53604916ed
commit
335f1f14d3
@ -770,12 +770,45 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
|
||||
# and concatenate it with `ret`.
|
||||
ret = ret.concatenate(constant_value_as_shape(concat_input))
|
||||
return ret
|
||||
else:
|
||||
elif tensor.op.type == "StridedSlice":
|
||||
try:
|
||||
begin = constant_value(tensor.op.inputs[1])
|
||||
end = constant_value(tensor.op.inputs[2])
|
||||
strides = constant_value(tensor.op.inputs[3])
|
||||
if begin is not None and end is not None and strides is not None:
|
||||
begin = begin[0]
|
||||
end = end[0]
|
||||
strides = strides[0]
|
||||
begin_mask = tensor.op.get_attr("begin_mask")
|
||||
if begin_mask == 1:
|
||||
begin = None
|
||||
end_mask = tensor.op.get_attr("end_mask")
|
||||
if end_mask == 1:
|
||||
end = None
|
||||
|
||||
ellipsis_mask = tensor.op.get_attr("ellipsis_mask")
|
||||
new_axis_mask = tensor.op.get_attr("new_axis_mask")
|
||||
shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask")
|
||||
valid_attributes = (not ellipsis_mask and not new_axis_mask and
|
||||
not shrink_axis_mask and
|
||||
(not begin_mask or (begin_mask == 1)) and
|
||||
(not end_mask or (end_mask == 1)))
|
||||
if valid_attributes: # additional inputs not supported
|
||||
prev = constant_value_as_shape(tensor.op.inputs[0])
|
||||
prev = prev[begin:end:strides]
|
||||
ret = tensor_shape.TensorShape(prev)
|
||||
return ret
|
||||
|
||||
except ValueError: # Could come from get_attr or slicing prev.
|
||||
pass
|
||||
except TypeError: # Could come from slicing prev.
|
||||
pass
|
||||
|
||||
ret = tensor_shape.unknown_shape(shape[0].value)
|
||||
value = constant_value(tensor)
|
||||
if value is not None:
|
||||
ret = ret.merge_with(tensor_shape.TensorShape(
|
||||
[d if d != -1 else None for d in value]))
|
||||
[d if d >= 0 else None for d in value]))
|
||||
return ret
|
||||
|
||||
|
||||
|
@ -832,6 +832,83 @@ class ConstantValueAsShapeTest(test.TestCase):
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
self.assertEqual([16, 37, None, 48], c_val.as_list())
|
||||
|
||||
def testSlice(self):
|
||||
tf_val = array_ops.placeholder(dtypes.int32, shape=(4,))[0:2]
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
self.assertEqual([None, None], c_val.as_list())
|
||||
|
||||
# begin:end
|
||||
tf_val = constant_op.constant([10, 20, 30])[1:3]
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
self.assertEqual([20, 30], c_val.as_list())
|
||||
|
||||
# begin:end:stride
|
||||
tf_val = array_ops.strided_slice(
|
||||
constant_op.constant([10, 20, 30]), [1], [3], strides=[2])
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
self.assertEqual([20], c_val.as_list())
|
||||
|
||||
# [1, 2, 16, 37, None, 48]
|
||||
tf_val_orig = array_ops.concat(
|
||||
[[1, 2, 16, 37], array_ops.placeholder(
|
||||
dtypes.int32, shape=(1,)), [48]], 0)
|
||||
|
||||
# begin: no end
|
||||
tf_val = tf_val_orig[2:]
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
self.assertEqual([16, 37, None, 48], c_val.as_list())
|
||||
|
||||
# begin::negative slice
|
||||
tf_val = tf_val_orig[2::-1]
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
self.assertEqual([16, 2, 1], c_val.as_list())
|
||||
|
||||
# :end:negative slice
|
||||
tf_val = tf_val_orig[:1:-2]
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
self.assertEqual([48, 37], c_val.as_list())
|
||||
|
||||
# begin:end:negative slice
|
||||
tf_val = tf_val_orig[3:1:-1]
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
self.assertEqual([37, 16], c_val.as_list())
|
||||
|
||||
# begin:negative end:slice
|
||||
tf_val = tf_val_orig[1:-3:1]
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
self.assertEqual([2, 16], c_val.as_list())
|
||||
|
||||
# negative begin::slice
|
||||
tf_val = tf_val_orig[-3::1]
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
self.assertEqual([37, None, 48], c_val.as_list())
|
||||
|
||||
# negative begin::negative slice
|
||||
tf_val = tf_val_orig[-3::-1]
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
self.assertEqual([37, 16, 2, 1], c_val.as_list())
|
||||
|
||||
# negative begin:negative end:negative slice
|
||||
tf_val = tf_val_orig[-3:-5:-1]
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
self.assertEqual([37, 16], c_val.as_list())
|
||||
|
||||
# Do not support shape inference for additional arguments
|
||||
tf_val = constant_op.constant([10, 20, 30])[...]
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
self.assertEqual([None, None, None], c_val.as_list())
|
||||
|
||||
# Do not support shape inference for tensor slices.
|
||||
tf_val = constant_op.constant([10, 20, 30])[
|
||||
array_ops.placeholder(dtypes.int32, shape=()):]
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
self.assertEqual(tensor_shape.unknown_shape(), c_val)
|
||||
|
||||
# Do not support shape inference for higher rank
|
||||
with self.assertRaises(ValueError):
|
||||
tf_val = constant_op.constant([[10], [20], [30]])[:, 0:]
|
||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user