Fix error when attempting to expand 1D CompositeTensors.
PiperOrigin-RevId: 297909613 Change-Id: Icc958fb7aa4742e48999320a8ff49da691351831
This commit is contained in:
parent
d0c470be2e
commit
8443eb5ce4
|
@ -1256,7 +1256,8 @@ def expand_1d(data):
|
||||||
"""Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s."""
|
"""Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s."""
|
||||||
|
|
||||||
def _expand_single_1d_tensor(t):
|
def _expand_single_1d_tensor(t):
|
||||||
if (hasattr(t, "shape") and
|
# Leaves `CompositeTensor`s as-is.
|
||||||
|
if (isinstance(t, ops.Tensor) and
|
||||||
isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1):
|
isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1):
|
||||||
return array_ops.expand_dims_v2(t, axis=-1)
|
return array_ops.expand_dims_v2(t, axis=-1)
|
||||||
return t
|
return t
|
||||||
|
|
|
@ -1062,6 +1062,15 @@ class TestValidationSplit(keras_parameterized.TestCase):
|
||||||
self.assertIsNone(val_sw)
|
self.assertIsNone(val_sw)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUtils(keras_parameterized.TestCase):
|
||||||
|
|
||||||
|
def test_expand_1d_sparse_tensors_untouched(self):
|
||||||
|
st = sparse_tensor.SparseTensor(
|
||||||
|
indices=[[0], [10]], values=[1, 2], dense_shape=[10])
|
||||||
|
st = data_adapter.expand_1d(st)
|
||||||
|
self.assertEqual(st.shape.rank, 1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
test.main()
|
test.main()
|
||||||
|
|
Loading…
Reference in New Issue