diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index 917d1c2f7bd..3881d994714 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -1256,7 +1256,8 @@ def expand_1d(data): """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s.""" def _expand_single_1d_tensor(t): - if (hasattr(t, "shape") and + # Leaves `CompositeTensor`s as-is. + if (isinstance(t, ops.Tensor) and isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1): return array_ops.expand_dims_v2(t, axis=-1) return t diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py index 346dc325f39..c0b875d016f 100644 --- a/tensorflow/python/keras/engine/data_adapter_test.py +++ b/tensorflow/python/keras/engine/data_adapter_test.py @@ -1062,6 +1062,15 @@ class TestValidationSplit(keras_parameterized.TestCase): self.assertIsNone(val_sw) +class TestUtils(keras_parameterized.TestCase): + + def test_expand_1d_sparse_tensors_untouched(self): + st = sparse_tensor.SparseTensor( + indices=[[0], [10]], values=[1, 2], dense_shape=[10]) + st = data_adapter.expand_1d(st) + self.assertEqual(st.shape.rank, 1) + + if __name__ == '__main__': ops.enable_eager_execution() test.main()