Fix sparse kerastensors to maintain dense shape information after converting to a placeholder.

PiperOrigin-RevId: 316538468
Change-Id: I8e53a7e96067a8b7edd3f57cd8a8a89eb912824b
This commit is contained in:
Tomer Kaftan 2020-06-15 13:56:55 -07:00 committed by TensorFlower Gardener
parent a8950d70bf
commit 7d76bc4b60
3 changed files with 17 additions and 6 deletions

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import type_spec as type_spec_module from tensorflow.python.framework import type_spec as type_spec_module
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest from tensorflow.python.util import nest
@ -210,10 +211,20 @@ class _KerasTensorIterator(object):
def keras_tensor_to_placeholder(x): def keras_tensor_to_placeholder(x):
"""TODO(kaftan): Docstring.""" """TODO(kaftan): Docstring."""
if isinstance(x, KerasTensor): if isinstance(x, KerasTensor):
def tensor_spec_to_placeholder(tensorspec): spec = x.type_spec
return array_ops.placeholder(tensorspec.dtype, tensorspec.shape) if isinstance(spec, sparse_tensor.SparseTensorSpec):
ph = nest.map_structure(tensor_spec_to_placeholder, x.type_spec, # nest.map_structure loses dense shape information for sparse tensors.
expand_composites=True) # So, we special-case sparse placeholder creation.
# This only preserves shape information for top-level sparse tensors;
# not for sparse tensors that are nested inside another composite
# tensor.
return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape)
def component_to_placeholder(component):
return array_ops.placeholder(component.dtype, component.shape)
ph = nest.map_structure(
component_to_placeholder, spec, expand_composites=True)
return ph return ph
else: else:
return x return x

View File

@ -603,8 +603,7 @@ class RaggedTensorInputValidationTest(keras_parameterized.TestCase,
@keras_parameterized.run_with_all_model_types() @keras_parameterized.run_with_all_model_types()
@keras_parameterized.run_all_keras_modes(always_skip_v1=True, @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
skip_keras_tensors=True)
class CompositeTensorModelPredictTest(keras_parameterized.TestCase): class CompositeTensorModelPredictTest(keras_parameterized.TestCase):
def _normalize_shape(self, shape): def _normalize_shape(self, shape):

View File

@ -3184,6 +3184,7 @@ def sparse_placeholder(dtype, shape=None, name=None):
# `SparseTensor` # `SparseTensor`
dense_shape_default = tensor_shape.TensorShape( dense_shape_default = tensor_shape.TensorShape(
tuple(None if dim == -1 else dim for dim in shape)) tuple(None if dim == -1 else dim for dim in shape))
shape = tuple(tensor_shape.dimension_value(dim) for dim in shape)
shape = tuple(-1 if dim is None else dim for dim in shape) shape = tuple(-1 if dim is None else dim for dim in shape)
shape = ops.convert_to_tensor( shape = ops.convert_to_tensor(
shape, dtype=dtypes.int64, name=default_shape_name) shape, dtype=dtypes.int64, name=default_shape_name)