From 8443eb5ce41b5d62e7fa5827be9ca5f2feaa95cb Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Fri, 28 Feb 2020 12:56:53 -0800 Subject: [PATCH] Fix error when attempting to expand 1D CompositeTensors. PiperOrigin-RevId: 297909613 Change-Id: Icc958fb7aa4742e48999320a8ff49da691351831 --- tensorflow/python/keras/engine/data_adapter.py | 3 ++- tensorflow/python/keras/engine/data_adapter_test.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index 917d1c2f7bd..3881d994714 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -1256,7 +1256,8 @@ def expand_1d(data): """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s.""" def _expand_single_1d_tensor(t): - if (hasattr(t, "shape") and + # Leaves `CompositeTensor`s as-is. + if (isinstance(t, ops.Tensor) and isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1): return array_ops.expand_dims_v2(t, axis=-1) return t diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py index 346dc325f39..c0b875d016f 100644 --- a/tensorflow/python/keras/engine/data_adapter_test.py +++ b/tensorflow/python/keras/engine/data_adapter_test.py @@ -1062,6 +1062,15 @@ class TestValidationSplit(keras_parameterized.TestCase): self.assertIsNone(val_sw) +class TestUtils(keras_parameterized.TestCase): + + def test_expand_1d_sparse_tensors_untouched(self): + st = sparse_tensor.SparseTensor( + indices=[[0], [10]], values=[1, 2], dense_shape=[10]) + st = data_adapter.expand_1d(st) + self.assertEqual(st.shape.rank, 1) + + if __name__ == '__main__': ops.enable_eager_execution() test.main()