Simplify shape code.
Fixes: #37640 PiperOrigin-RevId: 307820599 Change-Id: I74e5119798420b428353aeaf3047096507601202
This commit is contained in:
parent
b644a649e7
commit
081c7d5add
@ -2477,9 +2477,11 @@ tf_py_test(
|
||||
main = "framework/sparse_tensor_test.py",
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":array_ops",
|
||||
":framework",
|
||||
":framework_for_generated_wrappers",
|
||||
":framework_test_lib",
|
||||
":math_ops",
|
||||
":platform_test",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
|
||||
@ -132,38 +132,9 @@ class SparseTensor(tensor_like.TensorLike, composite_tensor.CompositeTensor):
|
||||
# is a VariableOp and updating users of SparseTensor.
|
||||
values = ops.convert_to_tensor(values, name="values")
|
||||
|
||||
# Can't check `if context.executing_eagerly()` here because sparse
|
||||
# placeholders can still be used in eager context, when building a
|
||||
# functional model.
|
||||
if isinstance(indices, ops.EagerTensor):
|
||||
try:
|
||||
dense_shape = ops.convert_to_tensor(
|
||||
dense_shape, name="dense_shape", dtype=dtypes.int64)
|
||||
dense_shape_default = tensor_shape.TensorShape(dense_shape)
|
||||
except ValueError:
|
||||
raise ValueError("Unable to create eager SparseTensor. Check that "
|
||||
"your shape is correctly defined. Eager "
|
||||
"SparseTensors don't support unknown dimesions.\n"
|
||||
"got shape:\n {}".format(dense_shape))
|
||||
else:
|
||||
if isinstance(dense_shape, ops.Tensor):
|
||||
dense_shape_default = tensor_util.constant_value_as_shape(dense_shape)
|
||||
else:
|
||||
dense_shape_default = []
|
||||
for dim in dense_shape:
|
||||
if isinstance(dim, ops.Tensor):
|
||||
# There is code passing lists of constant tensors.
|
||||
dim = tensor_util.constant_value(dim)
|
||||
if dim == -1:
|
||||
# -1 may be passed for unknown shapes.
|
||||
dim = None
|
||||
|
||||
dense_shape_default.append(dim)
|
||||
|
||||
dense_shape_default = tensor_shape.TensorShape(dense_shape_default)
|
||||
|
||||
dense_shape = ops.convert_to_tensor(
|
||||
dense_shape, name="dense_shape", dtype=dtypes.int64)
|
||||
dense_shape = ops.convert_to_tensor(
|
||||
dense_shape, name="dense_shape", dtype=dtypes.int64)
|
||||
dense_shape_default = tensor_util.constant_value_as_shape(dense_shape)
|
||||
|
||||
self._indices = indices
|
||||
self._values = values
|
||||
|
||||
@ -29,6 +29,8 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
@ -124,6 +126,84 @@ class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
|
||||
sparse_tensor_value.dense_shape, convertee.dense_shape)
|
||||
|
||||
|
||||
class SparseTensorShapeTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def test_simple(self):
|
||||
indices = [[0, 2]]
|
||||
values = [1]
|
||||
dense_shape = [5, 5]
|
||||
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
|
||||
|
||||
self.assertIsInstance(sp.shape, tensor_shape.TensorShape)
|
||||
self.assertIsInstance(sp.dense_shape, ops.Tensor)
|
||||
self.assertEqual(sp.shape.as_list(), [5, 5])
|
||||
|
||||
def test_unknown_shape(self):
|
||||
|
||||
@def_function.function
|
||||
def my_func(dense_shape):
|
||||
indices = [[0, 2]]
|
||||
values = [1]
|
||||
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
|
||||
self.assertEqual(sp.shape.as_list(), [None, None])
|
||||
return sp
|
||||
|
||||
my_func.get_concrete_function(
|
||||
dense_shape=tensor_spec.TensorSpec(
|
||||
dtype=dtypes.int64, shape=[2,]))
|
||||
|
||||
def test_partial_shape(self):
|
||||
|
||||
@def_function.function
|
||||
def my_func(x):
|
||||
indices = [[0, 2]]
|
||||
values = [1]
|
||||
y = ops.convert_to_tensor(3, dtype=dtypes.int64)
|
||||
dense_shape = [x, y]
|
||||
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
|
||||
self.assertEqual(sp.shape.as_list(), [None, 3])
|
||||
return sp
|
||||
|
||||
my_func.get_concrete_function(
|
||||
x=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[]))
|
||||
|
||||
def test_neg_shape(self):
|
||||
indices = [[0, 2]]
|
||||
values = [1]
|
||||
dense_shape = [-1, 5]
|
||||
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
|
||||
self.assertEqual(sp.shape.as_list(), [None, 5])
|
||||
|
||||
def test_unknown_tensor_shape(self):
|
||||
|
||||
@def_function.function
|
||||
def my_func(x):
|
||||
indices = [[0, 0]]
|
||||
values = [1]
|
||||
dense_shape = array_ops.shape(x)
|
||||
dense_shape = math_ops.cast(dense_shape, dtypes.int64)
|
||||
|
||||
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
|
||||
self.assertEqual(sp.shape.as_list(), [None, None])
|
||||
return sp
|
||||
|
||||
my_func.get_concrete_function(
|
||||
x=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[None, None]))
|
||||
|
||||
def test_unknown_rank(self):
|
||||
|
||||
@def_function.function
|
||||
def my_func(dense_shape):
|
||||
indices = [[0, 0]]
|
||||
values = [1]
|
||||
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
|
||||
self.assertEqual(sp.shape.rank, None)
|
||||
return sp
|
||||
|
||||
my_func.get_concrete_function(
|
||||
dense_shape=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[None]))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class SparseTensorSpecTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user