Fix error when attempting to expand 1D CompositeTensors.

PiperOrigin-RevId: 297909613
Change-Id: Icc958fb7aa4742e48999320a8ff49da691351831
This commit is contained in:
Thomas O'Malley 2020-02-28 12:56:53 -08:00 committed by TensorFlower Gardener
parent d0c470be2e
commit 8443eb5ce4
2 changed files with 11 additions and 1 deletions

View File

@ -1256,7 +1256,8 @@ def expand_1d(data):
"""Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s."""
def _expand_single_1d_tensor(t):
if (hasattr(t, "shape") and
# Leaves `CompositeTensor`s as-is.
if (isinstance(t, ops.Tensor) and
isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1):
return array_ops.expand_dims_v2(t, axis=-1)
return t

View File

@ -1062,6 +1062,15 @@ class TestValidationSplit(keras_parameterized.TestCase):
self.assertIsNone(val_sw)
class TestUtils(keras_parameterized.TestCase):
def test_expand_1d_sparse_tensors_untouched(self):
st = sparse_tensor.SparseTensor(
indices=[[0], [10]], values=[1, 2], dense_shape=[10])
st = data_adapter.expand_1d(st)
self.assertEqual(st.shape.rank, 1)
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()