SparseTensor: when deserializing from component tensors, propagate the static dense_shape from the typespec.

PiperOrigin-RevId: 342338243
Change-Id: Iefaa8857f702d34521fc4ee23e1cbaf19752f53c
This commit is contained in:
A. Unique TensorFlower 2020-11-13 14:24:56 -08:00 committed by TensorFlower Gardener
parent 31c20f9e8a
commit 595da5557f
3 changed files with 2 additions and 16 deletions

View File

@ -342,11 +342,7 @@ class SparseTensorSpec(type_spec.BatchableTypeSpec):
not tf2.enabled()):
return SparseTensorValue(*tensor_list)
else:
result = SparseTensor(*tensor_list)
# Augment the static dense shape with the shape carried by the spec.
result._dense_shape_default = result._dense_shape_default.merge_with( # pylint: disable=protected-access
self._shape)
return result
return SparseTensor(*tensor_list)
# The SparseTensorSpec tensor_list encoding uses (de)serialize_sparse ops
# to (un)box the component tensors in a way that allows for batching &

View File

@ -290,16 +290,6 @@ class SparseTensorSpecTest(test_util.TensorFlowTestCase,
self.assertAllEqual(st.values, st_reconstructed.values)
self.assertAllEqual(st.dense_shape, st_reconstructed.dense_shape)
def testFromComponentsDynamicDenseShapeTensor(self):
@def_function.function(input_signature=[
sparse_tensor.SparseTensorSpec([None, 10, 100])])
def sparse_fun(st):
self.assertEqual(st.get_shape().as_list(), [None, 10, 100])
return st.dense_shape
# Force tracing the TF function.
_ = sparse_fun.get_concrete_function()
@test_util.run_v1_only("SparseTensorValue is deprecated in v2")
def testFromNumpyComponents(self):
indices = np.array([[0], [8]])

View File

@ -740,7 +740,7 @@ class DataTypesTest(test_util.TensorFlowTestCase):
@test_util.run_deprecated_v1
def test_sparse_tensors(self):
shape = tensor_shape.TensorShape([3, 4])
shape = tensor_shape.TensorShape([None, None])
def true_fn():
return [sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]],