From 081c7d5add2f084aabe71c2d4da1e0de6a780ac6 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 22 Apr 2020 08:21:44 -0700 Subject: [PATCH] Simplify shape code. Fixes: #37640 PiperOrigin-RevId: 307820599 Change-Id: I74e5119798420b428353aeaf3047096507601202 --- tensorflow/python/BUILD | 2 + tensorflow/python/framework/sparse_tensor.py | 35 +------- .../python/framework/sparse_tensor_test.py | 80 +++++++++++++++++++ 3 files changed, 85 insertions(+), 32 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index edbba09ab25..c9b2ca575b8 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2477,9 +2477,11 @@ tf_py_test( main = "framework/sparse_tensor_test.py", python_version = "PY3", deps = [ + ":array_ops", ":framework", ":framework_for_generated_wrappers", ":framework_test_lib", + ":math_ops", ":platform_test", "//tensorflow/core:protos_all_py", ], diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index d085dfdab0d..ee479e67e58 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -132,38 +132,9 @@ class SparseTensor(tensor_like.TensorLike, composite_tensor.CompositeTensor): # is a VariableOp and updating users of SparseTensor. values = ops.convert_to_tensor(values, name="values") - # Can't check `if context.executing_eagerly()` here because sparse - # placeholders can still be used in eager context, when building a - # functional model. - if isinstance(indices, ops.EagerTensor): - try: - dense_shape = ops.convert_to_tensor( - dense_shape, name="dense_shape", dtype=dtypes.int64) - dense_shape_default = tensor_shape.TensorShape(dense_shape) - except ValueError: - raise ValueError("Unable to create eager SparseTensor. Check that " - "your shape is correctly defined. Eager " - "SparseTensors don't support unknown dimesions.\n" - "got shape:\n {}".format(dense_shape)) - else: - if isinstance(dense_shape, ops.Tensor): - dense_shape_default = tensor_util.constant_value_as_shape(dense_shape) - else: - dense_shape_default = [] - for dim in dense_shape: - if isinstance(dim, ops.Tensor): - # There is code passing lists of constant tensors. - dim = tensor_util.constant_value(dim) - if dim == -1: - # -1 may be passed for unknown shapes. - dim = None - - dense_shape_default.append(dim) - - dense_shape_default = tensor_shape.TensorShape(dense_shape_default) - - dense_shape = ops.convert_to_tensor( - dense_shape, name="dense_shape", dtype=dtypes.int64) + dense_shape = ops.convert_to_tensor( + dense_shape, name="dense_shape", dtype=dtypes.int64) + dense_shape_default = tensor_util.constant_value_as_shape(dense_shape) self._indices = indices self._values = values diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py index f7ecf00f29b..0d18af1fe2f 100644 --- a/tensorflow/python/framework/sparse_tensor_test.py +++ b/tensorflow/python/framework/sparse_tensor_test.py @@ -29,6 +29,8 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import googletest @@ -124,6 +126,84 @@ class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase): sparse_tensor_value.dense_shape, convertee.dense_shape) +class SparseTensorShapeTest(test_util.TensorFlowTestCase): + + def test_simple(self): + indices = [[0, 2]] + values = [1] + dense_shape = [5, 5] + sp = sparse_tensor.SparseTensor(indices, values, dense_shape) + + self.assertIsInstance(sp.shape, tensor_shape.TensorShape) + self.assertIsInstance(sp.dense_shape, ops.Tensor) + self.assertEqual(sp.shape.as_list(), [5, 5]) + + def test_unknown_shape(self): + + @def_function.function + def my_func(dense_shape): + indices = [[0, 2]] + values = [1] + sp = sparse_tensor.SparseTensor(indices, values, dense_shape) + self.assertEqual(sp.shape.as_list(), [None, None]) + return sp + + my_func.get_concrete_function( + dense_shape=tensor_spec.TensorSpec( + dtype=dtypes.int64, shape=[2,])) + + def test_partial_shape(self): + + @def_function.function + def my_func(x): + indices = [[0, 2]] + values = [1] + y = ops.convert_to_tensor(3, dtype=dtypes.int64) + dense_shape = [x, y] + sp = sparse_tensor.SparseTensor(indices, values, dense_shape) + self.assertEqual(sp.shape.as_list(), [None, 3]) + return sp + + my_func.get_concrete_function( + x=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[])) + + def test_neg_shape(self): + indices = [[0, 2]] + values = [1] + dense_shape = [-1, 5] + sp = sparse_tensor.SparseTensor(indices, values, dense_shape) + self.assertEqual(sp.shape.as_list(), [None, 5]) + + def test_unknown_tensor_shape(self): + + @def_function.function + def my_func(x): + indices = [[0, 0]] + values = [1] + dense_shape = array_ops.shape(x) + dense_shape = math_ops.cast(dense_shape, dtypes.int64) + + sp = sparse_tensor.SparseTensor(indices, values, dense_shape) + self.assertEqual(sp.shape.as_list(), [None, None]) + return sp + + my_func.get_concrete_function( + x=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[None, None])) + + def test_unknown_rank(self): + + @def_function.function + def my_func(dense_shape): + indices = [[0, 0]] + values = [1] + sp = sparse_tensor.SparseTensor(indices, values, dense_shape) + self.assertEqual(sp.shape.rank, None) + return sp + + my_func.get_concrete_function( + dense_shape=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[None])) + + @test_util.run_all_in_graph_and_eager_modes class SparseTensorSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):