From 335f1f14d3b17d93ed7f575546cdad035dc22426 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 6 Jul 2017 14:37:27 -0700 Subject: [PATCH] Extend static shape inference for SparseTensors with dense_shapes constructed using slicing. PiperOrigin-RevId: 161132391 --- tensorflow/python/framework/tensor_util.py | 47 +++++++++-- .../python/framework/tensor_util_test.py | 77 +++++++++++++++++++ 2 files changed, 117 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 10811100010..323802e57fe 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -770,13 +770,46 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name # and concatenate it with `ret`. ret = ret.concatenate(constant_value_as_shape(concat_input)) return ret - else: - ret = tensor_shape.unknown_shape(shape[0].value) - value = constant_value(tensor) - if value is not None: - ret = ret.merge_with(tensor_shape.TensorShape( - [d if d != -1 else None for d in value])) - return ret + elif tensor.op.type == "StridedSlice": + try: + begin = constant_value(tensor.op.inputs[1]) + end = constant_value(tensor.op.inputs[2]) + strides = constant_value(tensor.op.inputs[3]) + if begin is not None and end is not None and strides is not None: + begin = begin[0] + end = end[0] + strides = strides[0] + begin_mask = tensor.op.get_attr("begin_mask") + if begin_mask == 1: + begin = None + end_mask = tensor.op.get_attr("end_mask") + if end_mask == 1: + end = None + + ellipsis_mask = tensor.op.get_attr("ellipsis_mask") + new_axis_mask = tensor.op.get_attr("new_axis_mask") + shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask") + valid_attributes = (not ellipsis_mask and not new_axis_mask and + not shrink_axis_mask and + (not begin_mask or (begin_mask == 1)) and + (not end_mask or (end_mask == 1))) + if valid_attributes: # additional inputs not supported + prev = constant_value_as_shape(tensor.op.inputs[0]) + prev = prev[begin:end:strides] + ret = tensor_shape.TensorShape(prev) + return ret + + except ValueError: # Could come from get_attr or slicing prev. + pass + except TypeError: # Could come from slicing prev. + pass + + ret = tensor_shape.unknown_shape(shape[0].value) + value = constant_value(tensor) + if value is not None: + ret = ret.merge_with(tensor_shape.TensorShape( + [d if d >= 0 else None for d in value])) + return ret def is_tensor(x): # pylint: disable=invalid-name diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 8949702b875..b0a117a21e0 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -832,6 +832,83 @@ class ConstantValueAsShapeTest(test.TestCase): c_val = tensor_util.constant_value_as_shape(tf_val) self.assertEqual([16, 37, None, 48], c_val.as_list()) + def testSlice(self): + tf_val = array_ops.placeholder(dtypes.int32, shape=(4,))[0:2] + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([None, None], c_val.as_list()) + + # begin:end + tf_val = constant_op.constant([10, 20, 30])[1:3] + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([20, 30], c_val.as_list()) + + # begin:end:stride + tf_val = array_ops.strided_slice( + constant_op.constant([10, 20, 30]), [1], [3], strides=[2]) + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([20], c_val.as_list()) + + # [1, 2, 16, 37, None, 48] + tf_val_orig = array_ops.concat( + [[1, 2, 16, 37], array_ops.placeholder( + dtypes.int32, shape=(1,)), [48]], 0) + + # begin: no end + tf_val = tf_val_orig[2:] + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([16, 37, None, 48], c_val.as_list()) + + # begin::negative slice + tf_val = tf_val_orig[2::-1] + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([16, 2, 1], c_val.as_list()) + + # :end:negative slice + tf_val = tf_val_orig[:1:-2] + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([48, 37], c_val.as_list()) + + # begin:end:negative slice + tf_val = tf_val_orig[3:1:-1] + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([37, 16], c_val.as_list()) + + # begin:negative end:slice + tf_val = tf_val_orig[1:-3:1] + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([2, 16], c_val.as_list()) + + # negative begin::slice + tf_val = tf_val_orig[-3::1] + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([37, None, 48], c_val.as_list()) + + # negative begin::negative slice + tf_val = tf_val_orig[-3::-1] + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([37, 16, 2, 1], c_val.as_list()) + + # negative begin:negative end:negative slice + tf_val = tf_val_orig[-3:-5:-1] + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([37, 16], c_val.as_list()) + + # Do not support shape inference for additional arguments + tf_val = constant_op.constant([10, 20, 30])[...] + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([None, None, None], c_val.as_list()) + + # Do not support shape inference for tensor slices. + tf_val = constant_op.constant([10, 20, 30])[ + array_ops.placeholder(dtypes.int32, shape=()):] + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual(tensor_shape.unknown_shape(), c_val) + + # Do not support shape inference for higher rank + with self.assertRaises(ValueError): + tf_val = constant_op.constant([[10], [20], [30]])[:, 0:] + c_val = tensor_util.constant_value_as_shape(tf_val) + if __name__ == "__main__": test.main()