From 7d76bc4b60e020bbfd1339923ea5a7c3ab007217 Mon Sep 17 00:00:00 2001 From: Tomer Kaftan Date: Mon, 15 Jun 2020 13:56:55 -0700 Subject: [PATCH] Fix sparse kerastensors to maintain dense shape information after converting to a placeholder. PiperOrigin-RevId: 316538468 Change-Id: I8e53a7e96067a8b7edd3f57cd8a8a89eb912824b --- .../python/keras/engine/keras_tensor.py | 19 +++++++++++++++---- .../utils/composite_tensor_support_test.py | 3 +-- tensorflow/python/ops/array_ops.py | 1 + 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/keras/engine/keras_tensor.py b/tensorflow/python/keras/engine/keras_tensor.py index 4ea01da8db2..c5c0068c652 100644 --- a/tensorflow/python/keras/engine/keras_tensor.py +++ b/tensorflow/python/keras/engine/keras_tensor.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import type_spec as type_spec_module from tensorflow.python.ops import array_ops from tensorflow.python.util import nest @@ -210,10 +211,20 @@ class _KerasTensorIterator(object): def keras_tensor_to_placeholder(x): """TODO(kaftan): Docstring.""" if isinstance(x, KerasTensor): - def tensor_spec_to_placeholder(tensorspec): - return array_ops.placeholder(tensorspec.dtype, tensorspec.shape) - ph = nest.map_structure(tensor_spec_to_placeholder, x.type_spec, - expand_composites=True) + spec = x.type_spec + if isinstance(spec, sparse_tensor.SparseTensorSpec): + # nest.map_structure loses dense shape information for sparse tensors. + # So, we special-case sparse placeholder creation. + # This only preserves shape information for top-level sparse tensors; + # not for sparse tensors that are nested inside another composite + # tensor. + return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape) + + def component_to_placeholder(component): + return array_ops.placeholder(component.dtype, component.shape) + + ph = nest.map_structure( + component_to_placeholder, spec, expand_composites=True) return ph else: return x diff --git a/tensorflow/python/keras/utils/composite_tensor_support_test.py b/tensorflow/python/keras/utils/composite_tensor_support_test.py index f31558ddba8..daba188414a 100644 --- a/tensorflow/python/keras/utils/composite_tensor_support_test.py +++ b/tensorflow/python/keras/utils/composite_tensor_support_test.py @@ -603,8 +603,7 @@ class RaggedTensorInputValidationTest(keras_parameterized.TestCase, @keras_parameterized.run_with_all_model_types() -@keras_parameterized.run_all_keras_modes(always_skip_v1=True, - skip_keras_tensors=True) +@keras_parameterized.run_all_keras_modes(always_skip_v1=True) class CompositeTensorModelPredictTest(keras_parameterized.TestCase): def _normalize_shape(self, shape): diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index c77977bf7d2..1c00b81c9ca 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -3184,6 +3184,7 @@ def sparse_placeholder(dtype, shape=None, name=None): # `SparseTensor` dense_shape_default = tensor_shape.TensorShape( tuple(None if dim == -1 else dim for dim in shape)) + shape = tuple(tensor_shape.dimension_value(dim) for dim in shape) shape = tuple(-1 if dim is None else dim for dim in shape) shape = ops.convert_to_tensor( shape, dtype=dtypes.int64, name=default_shape_name)