SparseTensor: when deserializing from component tensors, propagate the static dense_shape from the typespec.
PiperOrigin-RevId: 342338243 Change-Id: Iefaa8857f702d34521fc4ee23e1cbaf19752f53c
This commit is contained in:
parent
31c20f9e8a
commit
595da5557f
@ -342,11 +342,7 @@ class SparseTensorSpec(type_spec.BatchableTypeSpec):
|
||||
not tf2.enabled()):
|
||||
return SparseTensorValue(*tensor_list)
|
||||
else:
|
||||
result = SparseTensor(*tensor_list)
|
||||
# Augment the static dense shape with the shape carried by the spec.
|
||||
result._dense_shape_default = result._dense_shape_default.merge_with( # pylint: disable=protected-access
|
||||
self._shape)
|
||||
return result
|
||||
return SparseTensor(*tensor_list)
|
||||
|
||||
# The SparseTensorSpec tensor_list encoding uses (de)serialize_sparse ops
|
||||
# to (un)box the component tensors in a way that allows for batching &
|
||||
|
@ -290,16 +290,6 @@ class SparseTensorSpecTest(test_util.TensorFlowTestCase,
|
||||
self.assertAllEqual(st.values, st_reconstructed.values)
|
||||
self.assertAllEqual(st.dense_shape, st_reconstructed.dense_shape)
|
||||
|
||||
def testFromComponentsDynamicDenseShapeTensor(self):
|
||||
@def_function.function(input_signature=[
|
||||
sparse_tensor.SparseTensorSpec([None, 10, 100])])
|
||||
def sparse_fun(st):
|
||||
self.assertEqual(st.get_shape().as_list(), [None, 10, 100])
|
||||
return st.dense_shape
|
||||
|
||||
# Force tracing the TF function.
|
||||
_ = sparse_fun.get_concrete_function()
|
||||
|
||||
@test_util.run_v1_only("SparseTensorValue is deprecated in v2")
|
||||
def testFromNumpyComponents(self):
|
||||
indices = np.array([[0], [8]])
|
||||
|
@ -740,7 +740,7 @@ class DataTypesTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_sparse_tensors(self):
|
||||
shape = tensor_shape.TensorShape([3, 4])
|
||||
shape = tensor_shape.TensorShape([None, None])
|
||||
|
||||
def true_fn():
|
||||
return [sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]],
|
||||
|
Loading…
Reference in New Issue
Block a user