Simplify shape code.

Fixes: #37640
PiperOrigin-RevId: 307820599
Change-Id: I74e5119798420b428353aeaf3047096507601202
This commit is contained in:
Mark Daoust 2020-04-22 08:21:44 -07:00 committed by TensorFlower Gardener
parent b644a649e7
commit 081c7d5add
3 changed files with 85 additions and 32 deletions

View File

@ -2477,9 +2477,11 @@ tf_py_test(
main = "framework/sparse_tensor_test.py",
python_version = "PY3",
deps = [
":array_ops",
":framework",
":framework_for_generated_wrappers",
":framework_test_lib",
":math_ops",
":platform_test",
"//tensorflow/core:protos_all_py",
],

View File

@ -132,38 +132,9 @@ class SparseTensor(tensor_like.TensorLike, composite_tensor.CompositeTensor):
# is a VariableOp and updating users of SparseTensor.
values = ops.convert_to_tensor(values, name="values")
# Can't check `if context.executing_eagerly()` here because sparse
# placeholders can still be used in eager context, when building a
# functional model.
if isinstance(indices, ops.EagerTensor):
try:
dense_shape = ops.convert_to_tensor(
dense_shape, name="dense_shape", dtype=dtypes.int64)
dense_shape_default = tensor_shape.TensorShape(dense_shape)
except ValueError:
raise ValueError("Unable to create eager SparseTensor. Check that "
"your shape is correctly defined. Eager "
"SparseTensors don't support unknown dimesions.\n"
"got shape:\n {}".format(dense_shape))
else:
if isinstance(dense_shape, ops.Tensor):
dense_shape_default = tensor_util.constant_value_as_shape(dense_shape)
else:
dense_shape_default = []
for dim in dense_shape:
if isinstance(dim, ops.Tensor):
# There is code passing lists of constant tensors.
dim = tensor_util.constant_value(dim)
if dim == -1:
# -1 may be passed for unknown shapes.
dim = None
dense_shape_default.append(dim)
dense_shape_default = tensor_shape.TensorShape(dense_shape_default)
dense_shape = ops.convert_to_tensor(
dense_shape, name="dense_shape", dtype=dtypes.int64)
dense_shape = ops.convert_to_tensor(
dense_shape, name="dense_shape", dtype=dtypes.int64)
dense_shape_default = tensor_util.constant_value_as_shape(dense_shape)
self._indices = indices
self._values = values

View File

@ -29,6 +29,8 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import googletest
@ -124,6 +126,84 @@ class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
sparse_tensor_value.dense_shape, convertee.dense_shape)
class SparseTensorShapeTest(test_util.TensorFlowTestCase):
def test_simple(self):
indices = [[0, 2]]
values = [1]
dense_shape = [5, 5]
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
self.assertIsInstance(sp.shape, tensor_shape.TensorShape)
self.assertIsInstance(sp.dense_shape, ops.Tensor)
self.assertEqual(sp.shape.as_list(), [5, 5])
def test_unknown_shape(self):
@def_function.function
def my_func(dense_shape):
indices = [[0, 2]]
values = [1]
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
self.assertEqual(sp.shape.as_list(), [None, None])
return sp
my_func.get_concrete_function(
dense_shape=tensor_spec.TensorSpec(
dtype=dtypes.int64, shape=[2,]))
def test_partial_shape(self):
@def_function.function
def my_func(x):
indices = [[0, 2]]
values = [1]
y = ops.convert_to_tensor(3, dtype=dtypes.int64)
dense_shape = [x, y]
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
self.assertEqual(sp.shape.as_list(), [None, 3])
return sp
my_func.get_concrete_function(
x=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[]))
def test_neg_shape(self):
indices = [[0, 2]]
values = [1]
dense_shape = [-1, 5]
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
self.assertEqual(sp.shape.as_list(), [None, 5])
def test_unknown_tensor_shape(self):
@def_function.function
def my_func(x):
indices = [[0, 0]]
values = [1]
dense_shape = array_ops.shape(x)
dense_shape = math_ops.cast(dense_shape, dtypes.int64)
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
self.assertEqual(sp.shape.as_list(), [None, None])
return sp
my_func.get_concrete_function(
x=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[None, None]))
def test_unknown_rank(self):
@def_function.function
def my_func(dense_shape):
indices = [[0, 0]]
values = [1]
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
self.assertEqual(sp.shape.rank, None)
return sp
my_func.get_concrete_function(
dense_shape=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[None]))
@test_util.run_all_in_graph_and_eager_modes
class SparseTensorSpecTest(test_util.TensorFlowTestCase,
parameterized.TestCase):