Automated rollback of commit 0f65838cb9787c47f41ede1884e72a0144ad2fe0
PiperOrigin-RevId: 264946313
This commit is contained in:
parent
3444b15795
commit
095f802808
@ -733,30 +733,6 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
dataset,
|
dataset,
|
||||||
expected_output=[self.evaluate(_check(_sparse(i))) for i in range(10)])
|
expected_output=[self.evaluate(_check(_sparse(i))) for i in range(10)])
|
||||||
|
|
||||||
def testSparseMapShapeInference(self):
|
|
||||||
if not context.executing_eagerly():
|
|
||||||
self.skipTest("SparseTensor shape inference requires eager mode")
|
|
||||||
row_lengths = np.random.randint(0, 4, size=128)
|
|
||||||
values = np.ones(np.sum(row_lengths))
|
|
||||||
sparse = ragged_tensor.RaggedTensor.from_row_lengths(
|
|
||||||
values, row_lengths).to_sparse()
|
|
||||||
dataset = dataset_ops.Dataset.from_tensor_slices(sparse)
|
|
||||||
dataset = dataset.batch(32, drop_remainder=True)
|
|
||||||
dataset = dataset.map(lambda x: x)
|
|
||||||
self.assertEqual((32, 3), dataset.element_spec.shape)
|
|
||||||
|
|
||||||
def testSparseMapShapeInferencePartial(self):
|
|
||||||
if not context.executing_eagerly():
|
|
||||||
self.skipTest("SparseTensor shape inference requires eager mode")
|
|
||||||
row_lengths = np.random.randint(0, 4, size=128)
|
|
||||||
values = np.ones(np.sum(row_lengths))
|
|
||||||
sparse = ragged_tensor.RaggedTensor.from_row_lengths(
|
|
||||||
values, row_lengths).to_sparse()
|
|
||||||
dataset = dataset_ops.Dataset.from_tensor_slices(sparse)
|
|
||||||
dataset = dataset.batch(32, drop_remainder=False)
|
|
||||||
dataset = dataset.map(lambda x: x)
|
|
||||||
self.assertEqual([None, 3], dataset.element_spec.shape.as_list())
|
|
||||||
|
|
||||||
def testTensorArray(self):
|
def testTensorArray(self):
|
||||||
|
|
||||||
def _tensor_array(i):
|
def _tensor_array(i):
|
||||||
|
@ -24,7 +24,6 @@ import numpy as np
|
|||||||
from tensorflow.python import _pywrap_utils
|
from tensorflow.python import _pywrap_utils
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.framework import composite_tensor
|
from tensorflow.python.framework import composite_tensor
|
||||||
from tensorflow.python.framework import constant_op
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_like
|
from tensorflow.python.framework import tensor_like
|
||||||
@ -339,28 +338,11 @@ class SparseTensorSpec(type_spec.BatchableTypeSpec):
|
|||||||
|
|
||||||
def _from_compatible_tensor_list(self, tensor_list):
|
def _from_compatible_tensor_list(self, tensor_list):
|
||||||
tensor_list = gen_sparse_ops.deserialize_sparse(tensor_list[0], self._dtype)
|
tensor_list = gen_sparse_ops.deserialize_sparse(tensor_list[0], self._dtype)
|
||||||
indices, values, dense_shape = tensor_list
|
result = SparseTensor(*tensor_list)
|
||||||
rank = self._shape.ndims
|
rank = self._shape.ndims
|
||||||
indices.set_shape([None, rank])
|
result.indices.set_shape([None, rank])
|
||||||
# We restore the dense_shape from the SparseTypeSpec. This is necessary
|
result.dense_shape.set_shape([rank])
|
||||||
# for shape inference when using placeholder SparseTensors in function
|
return result
|
||||||
# tracing.
|
|
||||||
if self._shape.is_fully_defined():
|
|
||||||
dense_shape = ops.convert_to_tensor(
|
|
||||||
self._shape, dtype=dtypes.int64, name="shape")
|
|
||||||
elif (self._shape.rank is not None and
|
|
||||||
any(dim.value is not None for dim in self._shape.dims)):
|
|
||||||
# array_ops imports sparse_tensor.py. Local import to avoid import cycle.
|
|
||||||
from tensorflow.python.ops import array_ops # pylint: disable=g-import-not-at-top
|
|
||||||
pieces = array_ops.unstack(dense_shape, num=self._shape.rank)
|
|
||||||
for i, dim in enumerate(self._shape.dims):
|
|
||||||
if dim.value is not None:
|
|
||||||
pieces[i] = constant_op.constant(dim.value, dense_shape.dtype)
|
|
||||||
dense_shape = array_ops.stack(pieces)
|
|
||||||
else:
|
|
||||||
dense_shape.set_shape([rank])
|
|
||||||
|
|
||||||
return SparseTensor(indices, values, dense_shape)
|
|
||||||
|
|
||||||
def _batch(self, batch_size):
|
def _batch(self, batch_size):
|
||||||
return SparseTensorSpec(
|
return SparseTensorSpec(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user