Fix error when attempting to expand 1D CompositeTensors.
PiperOrigin-RevId: 297909613 Change-Id: Icc958fb7aa4742e48999320a8ff49da691351831
This commit is contained in:
parent
d0c470be2e
commit
8443eb5ce4
|
@ -1256,7 +1256,8 @@ def expand_1d(data):
|
|||
"""Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s."""
|
||||
|
||||
def _expand_single_1d_tensor(t):
|
||||
if (hasattr(t, "shape") and
|
||||
# Leaves `CompositeTensor`s as-is.
|
||||
if (isinstance(t, ops.Tensor) and
|
||||
isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1):
|
||||
return array_ops.expand_dims_v2(t, axis=-1)
|
||||
return t
|
||||
|
|
|
@ -1062,6 +1062,15 @@ class TestValidationSplit(keras_parameterized.TestCase):
|
|||
self.assertIsNone(val_sw)
|
||||
|
||||
|
||||
class TestUtils(keras_parameterized.TestCase):
|
||||
|
||||
def test_expand_1d_sparse_tensors_untouched(self):
|
||||
st = sparse_tensor.SparseTensor(
|
||||
indices=[[0], [10]], values=[1, 2], dense_shape=[10])
|
||||
st = data_adapter.expand_1d(st)
|
||||
self.assertEqual(st.shape.rank, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
||||
|
|
Loading…
Reference in New Issue