Add ragged tensor input support to Keras via the 'ragged' arg in keras.Input.
PiperOrigin-RevId: 254256357
This commit is contained in:
parent
4ab0dc8887
commit
7ed84ad814
@ -1168,6 +1168,7 @@ tf_py_test(
|
|||||||
"//tensorflow/python:sparse_ops",
|
"//tensorflow/python:sparse_ops",
|
||||||
"//tensorflow/python:sparse_tensor",
|
"//tensorflow/python:sparse_tensor",
|
||||||
],
|
],
|
||||||
|
shard_count = 4,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
|
@ -38,7 +38,6 @@ from tensorflow.python.distribute import distribute_coordinator_context as dc_co
|
|||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.distribute import multi_worker_util
|
from tensorflow.python.distribute import multi_worker_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import composite_tensor_utils
|
|
||||||
from tensorflow.python.eager import function as eager_function
|
from tensorflow.python.eager import function as eager_function
|
||||||
from tensorflow.python.eager import lift_to_graph
|
from tensorflow.python.eager import lift_to_graph
|
||||||
from tensorflow.python.framework import composite_tensor
|
from tensorflow.python.framework import composite_tensor
|
||||||
@ -70,6 +69,7 @@ from tensorflow.python.ops import state_ops
|
|||||||
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
|
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
|
||||||
from tensorflow.python.ops import tensor_array_ops
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
from tensorflow.python.ops import variables as variables_module
|
from tensorflow.python.ops import variables as variables_module
|
||||||
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import tf_contextlib
|
from tensorflow.python.util import tf_contextlib
|
||||||
@ -958,7 +958,12 @@ def is_keras_tensor(x):
|
|||||||
|
|
||||||
|
|
||||||
@keras_export('keras.backend.placeholder')
|
@keras_export('keras.backend.placeholder')
|
||||||
def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
|
def placeholder(shape=None,
|
||||||
|
ndim=None,
|
||||||
|
dtype=None,
|
||||||
|
sparse=False,
|
||||||
|
name=None,
|
||||||
|
ragged=False):
|
||||||
"""Instantiates a placeholder tensor and returns it.
|
"""Instantiates a placeholder tensor and returns it.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
@ -970,9 +975,14 @@ def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
|
|||||||
dtype: Placeholder type.
|
dtype: Placeholder type.
|
||||||
sparse: Boolean, whether the placeholder should have a sparse type.
|
sparse: Boolean, whether the placeholder should have a sparse type.
|
||||||
name: Optional name string for the placeholder.
|
name: Optional name string for the placeholder.
|
||||||
|
ragged: Boolean, whether the placeholder should have a ragged type.
|
||||||
|
In this case, values of 'None' in the 'shape' argument represent
|
||||||
|
ragged dimensions. For more information about RaggedTensors, see this
|
||||||
|
[guide](https://www.tensorflow.org/guide/ragged_tensors).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If called with eager execution.
|
ValueError: If called with eager execution
|
||||||
|
ValueError: If called with sparse = True and ragged = True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor instance (with Keras metadata included).
|
Tensor instance (with Keras metadata included).
|
||||||
@ -985,6 +995,11 @@ def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
|
|||||||
<tf.Tensor 'Placeholder_4:0' shape=(2, 4, 5) dtype=float32>
|
<tf.Tensor 'Placeholder_4:0' shape=(2, 4, 5) dtype=float32>
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
if sparse and ragged:
|
||||||
|
raise ValueError(
|
||||||
|
'Cannot set both sparse and ragged to True when creating a placeholder.'
|
||||||
|
)
|
||||||
|
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = floatx()
|
dtype = floatx()
|
||||||
if not shape:
|
if not shape:
|
||||||
@ -993,6 +1008,20 @@ def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
|
|||||||
with get_graph().as_default():
|
with get_graph().as_default():
|
||||||
if sparse:
|
if sparse:
|
||||||
x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
|
x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
|
||||||
|
elif ragged:
|
||||||
|
ragged_rank = 0
|
||||||
|
for i in range(1, len(shape)):
|
||||||
|
if shape[i] is None:
|
||||||
|
ragged_rank += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
value_shape = shape[(ragged_rank + 1):]
|
||||||
|
|
||||||
|
x = ragged_factory_ops.placeholder(
|
||||||
|
dtype=dtype,
|
||||||
|
ragged_rank=ragged_rank,
|
||||||
|
value_shape=value_shape,
|
||||||
|
name=name)
|
||||||
else:
|
else:
|
||||||
x = array_ops.placeholder(dtype, shape=shape, name=name)
|
x = array_ops.placeholder(dtype, shape=shape, name=name)
|
||||||
return x
|
return x
|
||||||
@ -1008,6 +1037,10 @@ def is_placeholder(x):
|
|||||||
Boolean.
|
Boolean.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
if isinstance(x, composite_tensor.CompositeTensor):
|
||||||
|
flat_components = nest.flatten(x, expand_composites=True)
|
||||||
|
return py_any(is_placeholder(c) for c in flat_components)
|
||||||
|
else:
|
||||||
return x.op.type == 'Placeholder'
|
return x.op.type == 'Placeholder'
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return False
|
return False
|
||||||
@ -3108,63 +3141,6 @@ def print_tensor(x, message=''):
|
|||||||
logging_ops.print_v2(message, x, output_stream=sys.stdout)
|
logging_ops.print_v2(message, x, output_stream=sys.stdout)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def is_tensor_or_composite_tensor(value):
|
|
||||||
"""Test if a passed value object is a tensor-like or composite tensor."""
|
|
||||||
return (tensor_util.is_tensor(value) or isinstance(value, np.ndarray) or
|
|
||||||
composite_tensor_utils.is_composite_or_composite_value(value))
|
|
||||||
|
|
||||||
|
|
||||||
def _try_process_scipy_sparse_input(value):
|
|
||||||
"""Converts 'value' to a SparseTensor if it is a scipy sparse matrix.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
value: An object that may have the attributes of a scipy sparse matrix.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Either a SparseTensor based off of 'value' or 'value' itself.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
sparse_coo = value.tocoo()
|
|
||||||
row, col = sparse_coo.row, sparse_coo.col
|
|
||||||
data, shape = sparse_coo.data, sparse_coo.shape
|
|
||||||
except AttributeError:
|
|
||||||
# If we can't convert this object, it could be either a single data
|
|
||||||
# element (ie, a bool/int/float) which is OK to pass on, or something
|
|
||||||
# that we don't understand (which may or may not be OK). In either
|
|
||||||
# case, don't die here: the data standardization code will catch
|
|
||||||
# those issues.
|
|
||||||
return value
|
|
||||||
|
|
||||||
indices = np.concatenate((np.expand_dims(row, 1), np.expand_dims(col, 1)), 1)
|
|
||||||
return sparse_tensor.SparseTensor(indices, data, shape)
|
|
||||||
|
|
||||||
|
|
||||||
def try_convert_scipy_to_sparse(values):
|
|
||||||
"""Converts scipy sparse matrices in 'values' to SparseTensors, if possible.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
values: An input or list of inputs to convert. These may be TensorLikes,
|
|
||||||
ndarrays, composite tensors, or scipy sparse values.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An input or list of inputs where scipy sparse tensors have been converted
|
|
||||||
to tf.SparseTensors.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If input cannot be converted to a SparseTensor.
|
|
||||||
"""
|
|
||||||
# Convert scipy sparse data into sparse tensors.
|
|
||||||
value_structure = values
|
|
||||||
values = nest.flatten(values)
|
|
||||||
for idx, value in enumerate(values):
|
|
||||||
if not is_tensor_or_composite_tensor(value):
|
|
||||||
values[idx] = _try_process_scipy_sparse_input(value)
|
|
||||||
values = nest.pack_sequence_as(value_structure, values)
|
|
||||||
|
|
||||||
return values
|
|
||||||
|
|
||||||
|
|
||||||
# GRAPH MANIPULATION
|
# GRAPH MANIPULATION
|
||||||
|
|
||||||
|
|
||||||
@ -3194,6 +3170,7 @@ class GraphExecutionFunction(object):
|
|||||||
if not isinstance(updates, (list, tuple)):
|
if not isinstance(updates, (list, tuple)):
|
||||||
raise TypeError('`updates` in a Keras backend function '
|
raise TypeError('`updates` in a Keras backend function '
|
||||||
'should be a list or tuple.')
|
'should be a list or tuple.')
|
||||||
|
|
||||||
self._inputs_structure = inputs
|
self._inputs_structure = inputs
|
||||||
self.inputs = nest.flatten(inputs, expand_composites=True)
|
self.inputs = nest.flatten(inputs, expand_composites=True)
|
||||||
self._outputs_structure = outputs
|
self._outputs_structure = outputs
|
||||||
@ -3311,10 +3288,6 @@ class GraphExecutionFunction(object):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def __call__(self, inputs):
|
def __call__(self, inputs):
|
||||||
inputs = try_convert_scipy_to_sparse(inputs)
|
|
||||||
|
|
||||||
# Ensure that input value types match any expected composite tensor types.
|
|
||||||
# TODO(momernick): Once TensorSpecs are implemented for CTs, use that here.
|
|
||||||
inputs = nest.flatten(inputs, expand_composites=True)
|
inputs = nest.flatten(inputs, expand_composites=True)
|
||||||
|
|
||||||
session = get_session(inputs)
|
session = get_session(inputs)
|
||||||
@ -3488,10 +3461,8 @@ class EagerExecutionFunction(object):
|
|||||||
x.op.inputs[0])
|
x.op.inputs[0])
|
||||||
|
|
||||||
def __call__(self, inputs):
|
def __call__(self, inputs):
|
||||||
# Convert scipy sparse data into sparse tensors.
|
|
||||||
inputs = try_convert_scipy_to_sparse(inputs)
|
|
||||||
|
|
||||||
input_values = nest.flatten(inputs, expand_composites=True)
|
input_values = nest.flatten(inputs, expand_composites=True)
|
||||||
|
|
||||||
if self._freezable_vars_values:
|
if self._freezable_vars_values:
|
||||||
input_values = input_values + self._freezable_vars_values
|
input_values = input_values + self._freezable_vars_values
|
||||||
converted_inputs = []
|
converted_inputs = []
|
||||||
|
@ -34,12 +34,15 @@ class InputLayer(base_layer.Layer):
|
|||||||
"""Layer to be used as an entry point into a Network (a graph of layers).
|
"""Layer to be used as an entry point into a Network (a graph of layers).
|
||||||
|
|
||||||
It can either wrap an existing tensor (pass an `input_tensor` argument)
|
It can either wrap an existing tensor (pass an `input_tensor` argument)
|
||||||
or create its a placeholder tensor (pass arguments `input_shape`, and
|
or create a placeholder tensor (pass arguments `input_shape`, and
|
||||||
optionally, `dtype`).
|
optionally, `dtype`).
|
||||||
|
|
||||||
It is generally recommend to use the functional layer API via `Input`,
|
It is generally recommend to use the functional layer API via `Input`,
|
||||||
(which creates an `InputLayer`) without directly using `InputLayer`.
|
(which creates an `InputLayer`) without directly using `InputLayer`.
|
||||||
|
|
||||||
|
This class can create placeholders for tf.Tensors, tf.SparseTensors, and
|
||||||
|
tf.RaggedTensors by choosing 'sparse=True' or 'ragged=True'.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
input_shape: Shape tuple (not including the batch axis), or `TensorShape`
|
input_shape: Shape tuple (not including the batch axis), or `TensorShape`
|
||||||
instance (not including the batch axis).
|
instance (not including the batch axis).
|
||||||
@ -47,8 +50,11 @@ class InputLayer(base_layer.Layer):
|
|||||||
dtype: Datatype of the input.
|
dtype: Datatype of the input.
|
||||||
input_tensor: Optional tensor to use as layer input
|
input_tensor: Optional tensor to use as layer input
|
||||||
instead of creating a placeholder.
|
instead of creating a placeholder.
|
||||||
sparse: Boolean, whether the placeholder created
|
sparse: Boolean, whether the placeholder created is meant to be sparse.
|
||||||
is meant to be sparse.
|
ragged: Boolean, whether the placeholder created is meant to be ragged.
|
||||||
|
In this case, values of 'None' in the 'shape' argument represent
|
||||||
|
ragged dimensions. For more information about RaggedTensors, see
|
||||||
|
https://www.tensorflow.org/guide/ragged_tensors.
|
||||||
name: Name of the layer (string).
|
name: Name of the layer (string).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -59,6 +65,7 @@ class InputLayer(base_layer.Layer):
|
|||||||
input_tensor=None,
|
input_tensor=None,
|
||||||
sparse=False,
|
sparse=False,
|
||||||
name=None,
|
name=None,
|
||||||
|
ragged=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
strategy = distribution_strategy_context.get_strategy()
|
strategy = distribution_strategy_context.get_strategy()
|
||||||
if strategy and batch_size is not None and \
|
if strategy and batch_size is not None and \
|
||||||
@ -110,18 +117,12 @@ class InputLayer(base_layer.Layer):
|
|||||||
batch_input_shape = None
|
batch_input_shape = None
|
||||||
graph = backend.get_graph()
|
graph = backend.get_graph()
|
||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
# In graph mode, create a graph placeholder to call the layer on.
|
|
||||||
if sparse:
|
|
||||||
input_tensor = backend.placeholder(
|
input_tensor = backend.placeholder(
|
||||||
shape=batch_input_shape,
|
shape=batch_input_shape,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
sparse=True)
|
sparse=sparse,
|
||||||
else:
|
ragged=ragged)
|
||||||
input_tensor = backend.placeholder(
|
|
||||||
shape=batch_input_shape,
|
|
||||||
dtype=dtype,
|
|
||||||
name=self.name)
|
|
||||||
|
|
||||||
self.is_placeholder = True
|
self.is_placeholder = True
|
||||||
self._batch_input_shape = batch_input_shape
|
self._batch_input_shape = batch_input_shape
|
||||||
@ -164,6 +165,7 @@ def Input( # pylint: disable=invalid-name
|
|||||||
dtype=None,
|
dtype=None,
|
||||||
sparse=False,
|
sparse=False,
|
||||||
tensor=None,
|
tensor=None,
|
||||||
|
ragged=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""`Input()` is used to instantiate a Keras tensor.
|
"""`Input()` is used to instantiate a Keras tensor.
|
||||||
|
|
||||||
@ -184,17 +186,24 @@ def Input( # pylint: disable=invalid-name
|
|||||||
Arguments:
|
Arguments:
|
||||||
shape: A shape tuple (integers), not including the batch size.
|
shape: A shape tuple (integers), not including the batch size.
|
||||||
For instance, `shape=(32,)` indicates that the expected input
|
For instance, `shape=(32,)` indicates that the expected input
|
||||||
will be batches of 32-dimensional vectors.
|
will be batches of 32-dimensional vectors. Elements of this tuple
|
||||||
|
can be None; 'None' elements represent dimensions where the shape is
|
||||||
|
not known.
|
||||||
batch_size: optional static batch size (integer).
|
batch_size: optional static batch size (integer).
|
||||||
name: An optional name string for the layer.
|
name: An optional name string for the layer.
|
||||||
Should be unique in a model (do not reuse the same name twice).
|
Should be unique in a model (do not reuse the same name twice).
|
||||||
It will be autogenerated if it isn't provided.
|
It will be autogenerated if it isn't provided.
|
||||||
dtype: The data type expected by the input, as a string
|
dtype: The data type expected by the input, as a string
|
||||||
(`float32`, `float64`, `int32`...)
|
(`float32`, `float64`, `int32`...)
|
||||||
sparse: A boolean specifying whether the placeholder
|
sparse: A boolean specifying whether the placeholder to be created is
|
||||||
to be created is sparse.
|
sparse. Only one of 'ragged' and 'sparse' can be True.
|
||||||
tensor: Optional existing tensor to wrap into the `Input` layer.
|
tensor: Optional existing tensor to wrap into the `Input` layer.
|
||||||
If set, the layer will not create a placeholder tensor.
|
If set, the layer will not create a placeholder tensor.
|
||||||
|
ragged: A boolean specifying whether the placeholder to be created is
|
||||||
|
ragged. Only one of 'ragged' and 'sparse' can be True. In this case,
|
||||||
|
values of 'None' in the 'shape' argument represent ragged dimensions.
|
||||||
|
For more information about RaggedTensors, see
|
||||||
|
https://www.tensorflow.org/guide/ragged_tensors.
|
||||||
**kwargs: deprecated arguments support.
|
**kwargs: deprecated arguments support.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -222,6 +231,10 @@ def Input( # pylint: disable=invalid-name
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: in case of invalid arguments.
|
ValueError: in case of invalid arguments.
|
||||||
"""
|
"""
|
||||||
|
if sparse and ragged:
|
||||||
|
raise ValueError(
|
||||||
|
'Cannot set both sparse and ragged to True in a Keras input.')
|
||||||
|
|
||||||
batch_shape = None
|
batch_shape = None
|
||||||
if 'batch_shape' in kwargs:
|
if 'batch_shape' in kwargs:
|
||||||
batch_shape = kwargs.pop('batch_shape')
|
batch_shape = kwargs.pop('batch_shape')
|
||||||
@ -246,6 +259,7 @@ def Input( # pylint: disable=invalid-name
|
|||||||
name=name,
|
name=name,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
sparse=sparse,
|
sparse=sparse,
|
||||||
|
ragged=ragged,
|
||||||
input_tensor=tensor)
|
input_tensor=tensor)
|
||||||
else:
|
else:
|
||||||
input_layer = InputLayer(
|
input_layer = InputLayer(
|
||||||
@ -254,6 +268,7 @@ def Input( # pylint: disable=invalid-name
|
|||||||
name=name,
|
name=name,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
sparse=sparse,
|
sparse=sparse,
|
||||||
|
ragged=ragged,
|
||||||
input_tensor=tensor)
|
input_tensor=tensor)
|
||||||
|
|
||||||
# Return tensor including `_keras_history`.
|
# Return tensor including `_keras_history`.
|
||||||
|
@ -25,6 +25,7 @@ import numpy as np
|
|||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.data.ops import iterator_ops
|
from tensorflow.python.data.ops import iterator_ops
|
||||||
|
from tensorflow.python.data.util import structure
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.distribute import multi_worker_util
|
from tensorflow.python.distribute import multi_worker_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -33,9 +34,11 @@ from tensorflow.python.eager import monitoring
|
|||||||
from tensorflow.python.framework import composite_tensor_utils
|
from tensorflow.python.framework import composite_tensor_utils
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
|
from tensorflow.python.framework import type_spec
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.keras import losses
|
from tensorflow.python.keras import losses
|
||||||
from tensorflow.python.keras import metrics as metrics_module
|
from tensorflow.python.keras import metrics as metrics_module
|
||||||
@ -61,6 +64,10 @@ from tensorflow.python.util import nest
|
|||||||
from tensorflow.python.util import serialization
|
from tensorflow.python.util import serialization
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
|
try:
|
||||||
|
from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
|
||||||
|
except ImportError:
|
||||||
|
issparse = None
|
||||||
|
|
||||||
_keras_api_gauge = monitoring.BoolGauge('/tensorflow/api/keras',
|
_keras_api_gauge = monitoring.BoolGauge('/tensorflow/api/keras',
|
||||||
'keras api usage', 'method')
|
'keras api usage', 'method')
|
||||||
@ -2444,6 +2451,39 @@ class Model(network.Network):
|
|||||||
check_batch_axis=False, # Don't enforce the batch size.
|
check_batch_axis=False, # Don't enforce the batch size.
|
||||||
exception_prefix='input')
|
exception_prefix='input')
|
||||||
|
|
||||||
|
# Get typespecs for the input data and sanitize it if necessary.
|
||||||
|
# TODO(momernick): This should be capable of doing full input validation
|
||||||
|
# at all times - validate that this is so and refactor the standardization
|
||||||
|
# code.
|
||||||
|
if isinstance(x, dataset_ops.DatasetV2):
|
||||||
|
x_shapes = dataset_ops.get_structure(x)
|
||||||
|
# TODO(momernick): Remove this once NestedStructure goes away. Right
|
||||||
|
# now, Dataset outputs one of these instead of an actual python structure.
|
||||||
|
if isinstance(x_shapes, structure.NestedStructure):
|
||||||
|
x_shapes = x_shapes._component_specs # pylint: disable=protected-access
|
||||||
|
if isinstance(x_shapes, tuple):
|
||||||
|
# If the output of a Dataset is a tuple, we assume it's either of the
|
||||||
|
# form (x_data, y_data) or (x_data, y_data, sample_weights). In either
|
||||||
|
# case, we only care about x_data here.
|
||||||
|
x_shapes = x_shapes[0]
|
||||||
|
else:
|
||||||
|
flat_inputs = nest.flatten(x, expand_composites=False)
|
||||||
|
flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False)
|
||||||
|
converted_x = []
|
||||||
|
for (a, b) in zip(flat_inputs, flat_expected_inputs):
|
||||||
|
converted_x.append(_convert_scipy_sparse_tensor(a, b))
|
||||||
|
x = nest.pack_sequence_as(x, converted_x, expand_composites=False)
|
||||||
|
x_shapes = nest.map_structure(type_spec.type_spec_from_value, x)
|
||||||
|
|
||||||
|
# If the inputs are still a NestedStructure, then we have a dict-input to
|
||||||
|
# this model. We can't yet validate this. (It's only relevant for feature
|
||||||
|
# columns).
|
||||||
|
if not isinstance(x_shapes, structure.NestedStructure):
|
||||||
|
flat_inputs = nest.flatten(x_shapes, expand_composites=False)
|
||||||
|
flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False)
|
||||||
|
for (a, b) in zip(flat_inputs, flat_expected_inputs):
|
||||||
|
nest.assert_same_structure(a, b, expand_composites=True)
|
||||||
|
|
||||||
if y is not None:
|
if y is not None:
|
||||||
if not self._is_graph_network:
|
if not self._is_graph_network:
|
||||||
feed_output_names = self._feed_output_names
|
feed_output_names = self._feed_output_names
|
||||||
@ -3049,3 +3089,34 @@ class _TrainingTarget(object):
|
|||||||
|
|
||||||
def _is_symbolic_tensor(x):
|
def _is_symbolic_tensor(x):
|
||||||
return tensor_util.is_tensor(x) and not isinstance(x, ops.EagerTensor)
|
return tensor_util.is_tensor(x) and not isinstance(x, ops.EagerTensor)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_scipy_sparse_tensor(value, expected_input):
|
||||||
|
"""Handle scipy sparse tensor conversions.
|
||||||
|
|
||||||
|
This method takes a value 'value' and returns the proper conversion. If
|
||||||
|
value is a scipy sparse tensor and the expected input is a dense tensor,
|
||||||
|
we densify 'value'. If value is a scipy sparse tensor and the expected input
|
||||||
|
is a TF SparseTensor, we convert 'value' to a SparseTensor. If 'value' is
|
||||||
|
not a scipy sparse tensor, or scipy is not imported, we pass it through
|
||||||
|
unchanged.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
value: An object that may be a scipy sparse tensor
|
||||||
|
expected_input: The expected input placeholder.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The possibly-converted 'value'.
|
||||||
|
"""
|
||||||
|
if issparse is not None and issparse(value):
|
||||||
|
if ops.is_dense_tensor_like(expected_input):
|
||||||
|
return value.toarray()
|
||||||
|
else:
|
||||||
|
sparse_coo = value.tocoo()
|
||||||
|
row, col = sparse_coo.row, sparse_coo.col
|
||||||
|
data, shape = sparse_coo.data, sparse_coo.shape
|
||||||
|
indices = np.concatenate((np.expand_dims(row, 1), np.expand_dims(col, 1)),
|
||||||
|
1)
|
||||||
|
return sparse_tensor.SparseTensor(indices, data, shape)
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
@ -353,11 +353,16 @@ def model_iteration(model,
|
|||||||
elif shuffle:
|
elif shuffle:
|
||||||
np.random.shuffle(index_array)
|
np.random.shuffle(index_array)
|
||||||
batches = make_batches(num_samples_or_steps, batch_size)
|
batches = make_batches(num_samples_or_steps, batch_size)
|
||||||
|
|
||||||
for batch_index, (batch_start, batch_end) in enumerate(batches):
|
for batch_index, (batch_start, batch_end) in enumerate(batches):
|
||||||
batch_ids = index_array[batch_start:batch_end]
|
batch_ids = index_array[batch_start:batch_end]
|
||||||
|
|
||||||
# Slice into a batch.
|
# Slice into a batch.
|
||||||
|
if len(batches) == 1:
|
||||||
|
# If we only have one batch, do not slice. This takes care of
|
||||||
|
# composite tensors in non-Dataset modes; we currently don't support
|
||||||
|
# slicing them.
|
||||||
|
# TODO(b/133517906): Add slicing support.
|
||||||
|
ins_batch = ins
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
if ins and isinstance(ins[-1], int):
|
if ins and isinstance(ins[-1], int):
|
||||||
# Do not slice the training phase flag.
|
# Do not slice the training phase flag.
|
||||||
|
@ -382,6 +382,7 @@ def check_num_samples(ins, batch_size=None, steps=None, steps_name='steps'):
|
|||||||
' is set, the `batch_size` must be None.')
|
' is set, the `batch_size` must be None.')
|
||||||
if check_steps_argument(ins, steps, steps_name):
|
if check_steps_argument(ins, steps, steps_name):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if hasattr(ins[0], 'shape'):
|
if hasattr(ins[0], 'shape'):
|
||||||
return int(ins[0].shape[0])
|
return int(ins[0].shape[0])
|
||||||
return None # Edge case where ins == [static_learning_phase]
|
return None # Edge case where ins == [static_learning_phase]
|
||||||
@ -501,7 +502,8 @@ def standardize_input_data(data,
|
|||||||
continue
|
continue
|
||||||
data_shape = tuple(tensorshape.as_list())
|
data_shape = tuple(tensorshape.as_list())
|
||||||
elif composite_tensor_utils.is_composite_or_composite_value(data[i]):
|
elif composite_tensor_utils.is_composite_or_composite_value(data[i]):
|
||||||
data_shape = composite_tensor_utils.get_shape(data[i])
|
tensorshape = composite_tensor_utils.get_shape(data[i])
|
||||||
|
data_shape = tuple(tensorshape.as_list())
|
||||||
else:
|
else:
|
||||||
data_shape = data[i].shape
|
data_shape = data[i].shape
|
||||||
|
|
||||||
@ -591,6 +593,10 @@ def check_array_lengths(inputs, targets, weights=None):
|
|||||||
ValueError: in case of incorrectly formatted data.
|
ValueError: in case of incorrectly formatted data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def is_tensor_or_composite_tensor(x):
|
||||||
|
return tensor_util.is_tensor(
|
||||||
|
x) or composite_tensor_utils.is_composite_or_composite_value(x)
|
||||||
|
|
||||||
def set_of_lengths(x):
|
def set_of_lengths(x):
|
||||||
# Returns a set with the variation between
|
# Returns a set with the variation between
|
||||||
# different shapes, with None => 0
|
# different shapes, with None => 0
|
||||||
@ -600,7 +606,7 @@ def check_array_lengths(inputs, targets, weights=None):
|
|||||||
return set([
|
return set([
|
||||||
y.shape[0]
|
y.shape[0]
|
||||||
for y in x
|
for y in x
|
||||||
if y is not None and not tensor_util.is_tensor(y)
|
if y is not None and not is_tensor_or_composite_tensor(y)
|
||||||
])
|
])
|
||||||
|
|
||||||
set_x = set_of_lengths(inputs)
|
set_x = set_of_lengths(inputs)
|
||||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.sparse
|
import scipy.sparse
|
||||||
|
|
||||||
@ -27,6 +29,7 @@ from tensorflow.python.data.ops import dataset_ops
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
from tensorflow.python.keras.engine import input_layer
|
from tensorflow.python.keras.engine import input_layer
|
||||||
@ -35,6 +38,7 @@ from tensorflow.python.keras.layers import Layer
|
|||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import sparse_ops
|
from tensorflow.python.ops import sparse_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.ops.ragged import ragged_test_util
|
from tensorflow.python.ops.ragged import ragged_test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -51,15 +55,18 @@ class ToDense(Layer):
|
|||||||
|
|
||||||
def call(self, inputs):
|
def call(self, inputs):
|
||||||
if isinstance(inputs, ragged_tensor.RaggedTensor):
|
if isinstance(inputs, ragged_tensor.RaggedTensor):
|
||||||
return inputs.to_tensor(default_value=self._default_value)
|
output = inputs.to_tensor(default_value=self._default_value)
|
||||||
elif isinstance(inputs, sparse_tensor.SparseTensor):
|
elif isinstance(inputs, sparse_tensor.SparseTensor):
|
||||||
return sparse_ops.sparse_tensor_to_dense(
|
output = sparse_ops.sparse_tensor_to_dense(
|
||||||
inputs, default_value=self._default_value)
|
inputs, default_value=self._default_value)
|
||||||
elif isinstance(inputs, ops.Tensor):
|
elif isinstance(inputs, ops.Tensor):
|
||||||
return inputs
|
output = inputs
|
||||||
else:
|
else:
|
||||||
raise TypeError("Unexpected tensor type %s" % type(inputs).__name__)
|
raise TypeError("Unexpected tensor type %s" % type(inputs).__name__)
|
||||||
|
|
||||||
|
# Return a float so that we can compile models with this as the final layer.
|
||||||
|
return math_ops.cast(output, dtypes.float32)
|
||||||
|
|
||||||
|
|
||||||
class ToRagged(Layer):
|
class ToRagged(Layer):
|
||||||
"""Create a ragged tensor based on a given dense tensor."""
|
"""Create a ragged tensor based on a given dense tensor."""
|
||||||
@ -94,7 +101,7 @@ class _SubclassModel(keras.Model):
|
|||||||
for i, layer in enumerate(layers):
|
for i, layer in enumerate(layers):
|
||||||
setattr(self, self._layer_name_for_i(i), layer)
|
setattr(self, self._layer_name_for_i(i), layer)
|
||||||
self.num_layers = len(layers)
|
self.num_layers = len(layers)
|
||||||
if i_layer:
|
if i_layer is not None:
|
||||||
self._set_inputs(i_layer)
|
self._set_inputs(i_layer)
|
||||||
|
|
||||||
def _layer_name_for_i(self, i):
|
def _layer_name_for_i(self, i):
|
||||||
@ -131,7 +138,7 @@ def get_model_from_layers_with_input(layers,
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
if model_type == "functional":
|
if model_type == "functional":
|
||||||
if model_input:
|
if model_input is not None:
|
||||||
inputs = model_input
|
inputs = model_input
|
||||||
else:
|
else:
|
||||||
if not input_shape:
|
if not input_shape:
|
||||||
@ -267,16 +274,96 @@ class CompositeTensorOutputTest(keras_parameterized.TestCase,
|
|||||||
self.assertAllEqual(output.dense_shape, expected_dense_shape)
|
self.assertAllEqual(output.dense_shape, expected_dense_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def get_input_name(use_dict):
|
||||||
|
# Define the input name.
|
||||||
|
if not use_dict:
|
||||||
|
return None # This is the same as not setting 'name'.
|
||||||
|
elif testing_utils.get_model_type() == "subclass":
|
||||||
|
return "input_1" # Subclass models don"t support input names.
|
||||||
|
else:
|
||||||
|
return "test_input_name"
|
||||||
|
|
||||||
|
|
||||||
|
def get_steps():
|
||||||
|
# Determine the steps arg (if appropriate)
|
||||||
|
if not testing_utils.should_run_eagerly():
|
||||||
|
# CompositeTensors in graph mode are symbolic and so require a steps arg.
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_inputs(data, use_dict, use_dataset, action, input_name):
|
||||||
|
input_data, expected_output = data
|
||||||
|
# Prepare the input data.
|
||||||
|
if use_dict:
|
||||||
|
input_data = {input_name: input_data}
|
||||||
|
if use_dataset:
|
||||||
|
if action == "predict":
|
||||||
|
input_data = dataset_ops.Dataset.from_tensors(input_data)
|
||||||
|
else:
|
||||||
|
input_data = dataset_ops.Dataset.from_tensors(
|
||||||
|
(input_data, expected_output))
|
||||||
|
expected_output = None
|
||||||
|
return (input_data, expected_output)
|
||||||
|
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types
|
@keras_parameterized.run_with_all_model_types
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
*test_util.generate_combinations_with_testcase_name(
|
||||||
|
use_dict=[True, False],
|
||||||
|
use_dataset=[True, False],
|
||||||
|
action=["predict", "evaluate", "fit"]))
|
||||||
class SparseTensorInputTest(keras_parameterized.TestCase,
|
class SparseTensorInputTest(keras_parameterized.TestCase,
|
||||||
ragged_test_util.RaggedTensorTestCase):
|
ragged_test_util.RaggedTensorTestCase):
|
||||||
|
|
||||||
|
def test_sparse_tensors(self, use_dict, use_dataset, action):
|
||||||
|
data = [(sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
|
||||||
|
[1, 2, 3], [2, 1, 3]),
|
||||||
|
np.array([[[1, -1, -1]], [[2, 3, -1]]])),
|
||||||
|
(sparse_tensor.SparseTensor(
|
||||||
|
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [2, 0, 1]], [5, 6, 7, 8],
|
||||||
|
[3, 1, 4]),
|
||||||
|
np.array([[[5, -1, -1, -1]], [[6, 7, -1, -1]], [[-1, 8, -1,
|
||||||
|
-1]]]))]
|
||||||
|
# Prepare the model to test.
|
||||||
|
input_name = get_input_name(use_dict)
|
||||||
|
model_input = input_layer.Input(
|
||||||
|
shape=(1, None), sparse=True, name=input_name, dtype=dtypes.int32)
|
||||||
|
layers = [ToDense(default_value=-1)]
|
||||||
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
|
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
||||||
|
steps = get_steps()
|
||||||
|
|
||||||
|
# Prepare the input data
|
||||||
|
for data_element in data:
|
||||||
|
input_data, expected_output = prepare_inputs(data_element, use_dict,
|
||||||
|
use_dataset, action,
|
||||||
|
input_name)
|
||||||
|
# Perform the action.
|
||||||
|
if action == "predict":
|
||||||
|
result = model.predict(input_data, steps=steps)
|
||||||
|
self.assertAllEqual(expected_output, result)
|
||||||
|
if action == "evaluate":
|
||||||
|
result = model.evaluate(input_data, expected_output, steps=steps)
|
||||||
|
self.assertAllEqual(1.0, result[-1])
|
||||||
|
if action == "fit":
|
||||||
|
# TODO(momernick): What's the best way of validating that fit happened?
|
||||||
|
_ = model.fit(
|
||||||
|
input_data, expected_output, shuffle=False, steps_per_epoch=steps)
|
||||||
|
|
||||||
|
|
||||||
|
@keras_parameterized.run_with_all_model_types
|
||||||
|
@keras_parameterized.run_all_keras_modes
|
||||||
|
class ScipySparseTensorInputTest(keras_parameterized.TestCase,
|
||||||
|
ragged_test_util.RaggedTensorTestCase):
|
||||||
|
|
||||||
def test_sparse_scipy_predict_inputs_via_input_layer_args(self):
|
def test_sparse_scipy_predict_inputs_via_input_layer_args(self):
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
# Create a model that accepts a sparse input and converts the sparse tensor
|
||||||
# back to a dense tensor. Scipy sparse matrices are limited to 2D, so use
|
# back to a dense tensor. Scipy sparse matrices are limited to 2D, so use
|
||||||
# a one-dimensional shape.
|
# a one-dimensional shape; note also that scipy's default dtype is int64.
|
||||||
model_input = input_layer.Input(shape=(3,), sparse=True)
|
model_input = input_layer.Input(shape=(3,), sparse=True, dtype=dtypes.int64)
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
|
|
||||||
@ -295,8 +382,8 @@ class SparseTensorInputTest(keras_parameterized.TestCase,
|
|||||||
def test_sparse_scipy_eval_inputs(self):
|
def test_sparse_scipy_eval_inputs(self):
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
# Create a model that accepts a sparse input and converts the sparse tensor
|
||||||
# back to a dense tensor. Scipy sparse matrices are limited to 2D, so use
|
# back to a dense tensor. Scipy sparse matrices are limited to 2D, so use
|
||||||
# a one-dimensional shape.
|
# a one-dimensional shape; note also that scipy's default dtype is int64.
|
||||||
model_input = input_layer.Input(shape=(3,), sparse=True)
|
model_input = input_layer.Input(shape=(3,), sparse=True, dtype=dtypes.int64)
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
||||||
@ -317,12 +404,13 @@ class SparseTensorInputTest(keras_parameterized.TestCase,
|
|||||||
def test_sparse_scipy_predict_input_dicts_via_input_layer_args(self):
|
def test_sparse_scipy_predict_input_dicts_via_input_layer_args(self):
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
# Create a model that accepts a sparse input and converts the sparse tensor
|
||||||
# back to a dense tensor. Scipy sparse matrices are limited to 2D, so use
|
# back to a dense tensor. Scipy sparse matrices are limited to 2D, so use
|
||||||
# a one-dimensional shape.
|
# a one-dimensional shape; note also that scipy's default dtype is int64.
|
||||||
if testing_utils.get_model_type() == "subclass":
|
if testing_utils.get_model_type() == "subclass":
|
||||||
input_name = "input_1" # Subclass models don"t support input names.
|
input_name = "input_1" # Subclass models don"t support input names.
|
||||||
else:
|
else:
|
||||||
input_name = "test_input_name"
|
input_name = "test_input_name"
|
||||||
model_input = input_layer.Input(shape=(3,), sparse=True, name=input_name)
|
model_input = input_layer.Input(
|
||||||
|
shape=(3,), sparse=True, name=input_name, dtype=dtypes.int64)
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
|
|
||||||
@ -347,12 +435,13 @@ class SparseTensorInputTest(keras_parameterized.TestCase,
|
|||||||
def test_sparse_scipy_eval_input_dicts(self):
|
def test_sparse_scipy_eval_input_dicts(self):
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
# Create a model that accepts a sparse input and converts the sparse tensor
|
||||||
# back to a dense tensor. Scipy sparse matrices are limited to 2D, so use
|
# back to a dense tensor. Scipy sparse matrices are limited to 2D, so use
|
||||||
# a one-dimensional shape.
|
# a one-dimensional shape; note also that scipy's default dtype is int64.
|
||||||
if testing_utils.get_model_type() == "subclass":
|
if testing_utils.get_model_type() == "subclass":
|
||||||
input_name = "input_1" # Subclass models don"t support input names.
|
input_name = "input_1" # Subclass models don"t support input names.
|
||||||
else:
|
else:
|
||||||
input_name = "test_input_name"
|
input_name = "test_input_name"
|
||||||
model_input = input_layer.Input(shape=(3,), sparse=True, name=input_name)
|
model_input = input_layer.Input(
|
||||||
|
shape=(3,), sparse=True, name=input_name, dtype=dtypes.int64)
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
||||||
@ -375,235 +464,132 @@ class SparseTensorInputTest(keras_parameterized.TestCase,
|
|||||||
output_2 = model.evaluate(input_data_2, expected_output_2, steps=1)
|
output_2 = model.evaluate(input_data_2, expected_output_2, steps=1)
|
||||||
self.assertAllEqual(1.0, output_2[-1])
|
self.assertAllEqual(1.0, output_2[-1])
|
||||||
|
|
||||||
def test_sparse_tensor_eval_inputs(self):
|
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
@keras_parameterized.run_with_all_model_types
|
||||||
# back to a dense tensor.
|
@keras_parameterized.run_all_keras_modes
|
||||||
model_input = input_layer.Input(shape=(1, None), sparse=True)
|
@parameterized.named_parameters(
|
||||||
|
*test_util.generate_combinations_with_testcase_name(
|
||||||
|
use_dict=[True, False],
|
||||||
|
use_dataset=[True, False],
|
||||||
|
action=["predict", "evaluate", "fit"]))
|
||||||
|
class RaggedTensorInputTest(keras_parameterized.TestCase,
|
||||||
|
ragged_test_util.RaggedTensorTestCase):
|
||||||
|
|
||||||
|
def test_ragged_input(self, use_dict, use_dataset, action):
|
||||||
|
data = [(ragged_factory_ops.constant([[[1]], [[2, 3]]]),
|
||||||
|
np.array([[[1, -1]], [[2, 3]]]))]
|
||||||
|
|
||||||
|
# Prepare the model to test.
|
||||||
|
input_name = get_input_name(use_dict)
|
||||||
|
model_input = input_layer.Input(
|
||||||
|
shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32)
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
||||||
|
|
||||||
|
# Prepare the input data
|
||||||
|
for data_element in data:
|
||||||
|
input_data, expected_output = prepare_inputs(data_element, use_dict,
|
||||||
|
use_dataset, action,
|
||||||
|
input_name)
|
||||||
|
# Perform the action.
|
||||||
|
if action == "predict":
|
||||||
|
result = model.predict(input_data)
|
||||||
|
self.assertAllEqual(expected_output, result)
|
||||||
|
if action == "evaluate":
|
||||||
|
result = model.evaluate(input_data, expected_output)
|
||||||
|
self.assertAllEqual(1.0, result[-1])
|
||||||
|
if action == "fit":
|
||||||
|
# TODO(momernick): What's the best way of validating that fit happened?
|
||||||
|
_ = model.fit(input_data, expected_output, shuffle=False)
|
||||||
|
|
||||||
|
|
||||||
|
@keras_parameterized.run_with_all_model_types
|
||||||
|
@keras_parameterized.run_all_keras_modes
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
*test_util.generate_combinations_with_testcase_name(
|
||||||
|
use_dict=[True, False], use_dataset=[True, False]))
|
||||||
|
class RaggedTensorInputValidationTest(keras_parameterized.TestCase,
|
||||||
|
ragged_test_util.RaggedTensorTestCase):
|
||||||
|
|
||||||
|
def test_ragged_tensor_input_with_one_none_dimension(self, use_dict,
|
||||||
|
use_dataset):
|
||||||
# Define some input data.
|
# Define some input data.
|
||||||
input_data = sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
|
data = [(ragged_factory_ops.constant([[[1, 0]], [[2, 3]]], ragged_rank=1),
|
||||||
[1, 2, 3], [2, 1, 3])
|
np.array([[[1, 0]], [[2, 3]]]))]
|
||||||
expected_output = np.array([[[1, -1, -1]], [[2, 3, -1]]])
|
|
||||||
output = model.evaluate(input_data, expected_output, steps=1)
|
|
||||||
self.assertAllEqual(1.0, output[-1])
|
|
||||||
|
|
||||||
input_data_2 = sparse_tensor.SparseTensor(
|
# Prepare the model to test.
|
||||||
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [2, 0, 1]], [5, 6, 7, 8], [3, 1, 4])
|
input_shape = (None, 2) # RaggedTensorInputTest uses (None, None).
|
||||||
expected_output_2 = np.array([[[5, -1, -1, -1]], [[6, 7, -1, -1]],
|
input_name = get_input_name(use_dict)
|
||||||
[[-1, 8, -1, -1]]])
|
|
||||||
output_2 = model.evaluate(input_data_2, expected_output_2, steps=1)
|
|
||||||
self.assertAllEqual(1.0, output_2[-1])
|
|
||||||
|
|
||||||
def test_sparse_tensor_predict_inputs_via_input_layer_args(self):
|
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
|
||||||
# back to a dense tensor.
|
|
||||||
model_input = input_layer.Input(shape=(1, None), sparse=True)
|
|
||||||
layers = [ToDense(default_value=-1)]
|
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
|
||||||
|
|
||||||
# Define some input data.
|
|
||||||
input_data = sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
|
|
||||||
[1, 2, 3], [2, 1, 3])
|
|
||||||
expected_output = np.array([[[1, -1, -1]], [[2, 3, -1]]])
|
|
||||||
output = model.predict(input_data, steps=1)
|
|
||||||
self.assertAllEqual(expected_output, output)
|
|
||||||
|
|
||||||
input_data_2 = sparse_tensor.SparseTensor(
|
|
||||||
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [2, 0, 1]], [5, 6, 7, 8], [3, 1, 4])
|
|
||||||
expected_output_2 = np.array([[[5, -1, -1, -1]], [[6, 7, -1, -1]],
|
|
||||||
[[-1, 8, -1, -1]]])
|
|
||||||
output_2 = model.predict(input_data_2, steps=1)
|
|
||||||
self.assertAllEqual(expected_output_2, output_2)
|
|
||||||
|
|
||||||
def test_sparse_tensor_predict_input_dicts_via_input_layer_args(self):
|
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
|
||||||
# back to a dense tensor.
|
|
||||||
if testing_utils.get_model_type() == "subclass":
|
|
||||||
input_name = "input_1" # Subclass models don"t support input names.
|
|
||||||
else:
|
|
||||||
input_name = "test_input_name"
|
|
||||||
model_input = input_layer.Input(
|
model_input = input_layer.Input(
|
||||||
shape=(1, None), sparse=True, name=input_name)
|
shape=input_shape, ragged=True, name=input_name, dtype=dtypes.int32)
|
||||||
layers = [ToDense(default_value=-1)]
|
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
|
||||||
|
|
||||||
# Define some input data.
|
|
||||||
input_data = {
|
|
||||||
input_name:
|
|
||||||
sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
|
|
||||||
[1, 2, 3], [2, 1, 3])
|
|
||||||
}
|
|
||||||
expected_output = np.array([[[1, -1, -1]], [[2, 3, -1]]])
|
|
||||||
output = model.predict(input_data, steps=1)
|
|
||||||
self.assertAllEqual(expected_output, output)
|
|
||||||
|
|
||||||
input_data_2 = {
|
|
||||||
input_name:
|
|
||||||
sparse_tensor.SparseTensor(
|
|
||||||
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [2, 0, 1]], [5, 6, 7, 8],
|
|
||||||
[3, 1, 4])
|
|
||||||
}
|
|
||||||
expected_output_2 = np.array([[[5, -1, -1, -1]], [[6, 7, -1, -1]],
|
|
||||||
[[-1, 8, -1, -1]]])
|
|
||||||
output_2 = model.predict(input_data_2, steps=1)
|
|
||||||
self.assertAllEqual(expected_output_2, output_2)
|
|
||||||
|
|
||||||
def test_sparse_tensor_eval_input_dicts_via_input_layer_args(self):
|
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
|
||||||
# back to a dense tensor.
|
|
||||||
if testing_utils.get_model_type() == "subclass":
|
|
||||||
input_name = "input_1" # Subclass models don"t support input names.
|
|
||||||
else:
|
|
||||||
input_name = "test_input_name"
|
|
||||||
model_input = input_layer.Input(
|
|
||||||
shape=(1, None), sparse=True, name=input_name)
|
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
||||||
|
|
||||||
|
for data_element in data:
|
||||||
|
input_data, expected_output = prepare_inputs(
|
||||||
|
data_element,
|
||||||
|
use_dict,
|
||||||
|
use_dataset,
|
||||||
|
action="predict",
|
||||||
|
input_name=input_name)
|
||||||
|
result = model.predict(input_data)
|
||||||
|
self.assertAllEqual(expected_output, result)
|
||||||
|
|
||||||
|
def test_ragged_tensor_input_with_no_none_dimension(self, use_dict,
|
||||||
|
use_dataset):
|
||||||
# Define some input data.
|
# Define some input data.
|
||||||
input_data = {
|
data = [(ragged_factory_ops.constant([[[1, 0]], [[2, 3]]], ragged_rank=0),
|
||||||
input_name:
|
np.array([[[1, 0]], [[2, 3]]]))]
|
||||||
sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
|
|
||||||
[1, 2, 3], [2, 1, 3])
|
|
||||||
}
|
|
||||||
expected_output = np.array([[[1, -1, -1]], [[2, 3, -1]]])
|
|
||||||
output = model.evaluate(input_data, expected_output, steps=1)
|
|
||||||
self.assertAllEqual(1.0, output[-1])
|
|
||||||
|
|
||||||
input_data_2 = {
|
# Prepare the model to test.
|
||||||
input_name:
|
input_shape = (1, 2) # RaggedTensorInputTest uses (None, None).
|
||||||
sparse_tensor.SparseTensor(
|
input_name = get_input_name(use_dict)
|
||||||
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [2, 0, 1]], [5, 6, 7, 8],
|
model_input = input_layer.Input(
|
||||||
[3, 1, 4])
|
shape=input_shape, ragged=True, name=input_name, dtype=dtypes.int32)
|
||||||
}
|
|
||||||
expected_output_2 = np.array([[[5, -1, -1, -1]], [[6, 7, -1, -1]],
|
|
||||||
[[-1, 8, -1, -1]]])
|
|
||||||
output_2 = model.evaluate(input_data_2, expected_output_2, steps=1)
|
|
||||||
self.assertAllEqual(1.0, output_2[-1])
|
|
||||||
|
|
||||||
def test_sparse_tensor_dataset_predict_inputs_via_input_layer_args(self):
|
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
|
||||||
# back to a dense tensor.
|
|
||||||
model_input = input_layer.Input(shape=(1, None), sparse=True)
|
|
||||||
layers = [ToDense(default_value=-1)]
|
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
|
||||||
|
|
||||||
# Define some input data.
|
|
||||||
input_data = dataset_ops.Dataset.from_tensors(
|
|
||||||
sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]], [1, 2, 3],
|
|
||||||
[2, 1, 3]))
|
|
||||||
expected_output = np.array([[[1, -1, -1]], [[2, 3, -1]]])
|
|
||||||
output = model.predict(input_data)
|
|
||||||
self.assertAllEqual(expected_output, output)
|
|
||||||
|
|
||||||
input_data_2 = dataset_ops.Dataset.from_tensors(
|
|
||||||
sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1], [2, 0, 1]],
|
|
||||||
[5, 6, 7, 8], [3, 1, 4]))
|
|
||||||
expected_output_2 = np.array([[[5, -1, -1, -1]], [[6, 7, -1, -1]],
|
|
||||||
[[-1, 8, -1, -1]]])
|
|
||||||
output_2 = model.predict(input_data_2)
|
|
||||||
self.assertAllEqual(expected_output_2, output_2)
|
|
||||||
|
|
||||||
def test_sparse_tensor_dataset_eval_inputs_via_input_layer_args(self):
|
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
|
||||||
# back to a dense tensor.
|
|
||||||
model_input = input_layer.Input(shape=(1, None), sparse=True)
|
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
||||||
|
|
||||||
# Define some input data.
|
# The input is a symbolic tensor in non-Eager modes, so 'steps' is required
|
||||||
input_tensor = sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
|
# for that case only.
|
||||||
[1, 2, 3], [2, 1, 3])
|
steps = get_steps()
|
||||||
expected_output = np.array([[[1, -1, -1]], [[2, 3, -1]]])
|
|
||||||
input_data = dataset_ops.Dataset.from_tensors(
|
|
||||||
(input_tensor, expected_output))
|
|
||||||
output = model.evaluate(input_data)
|
|
||||||
self.assertAllEqual(1.0, output[-1])
|
|
||||||
|
|
||||||
input_tensor_2 = sparse_tensor.SparseTensor(
|
for data_element in data:
|
||||||
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [2, 0, 1]], [5, 6, 7, 8], [3, 1, 4])
|
input_data, expected_output = prepare_inputs(
|
||||||
expected_output_2 = np.array([[[5, -1, -1, -1]], [[6, 7, -1, -1]],
|
data_element,
|
||||||
[[-1, 8, -1, -1]]])
|
use_dict,
|
||||||
input_data_2 = dataset_ops.Dataset.from_tensors(
|
use_dataset,
|
||||||
(input_tensor_2, expected_output_2))
|
action="predict",
|
||||||
output_2 = model.evaluate(input_data_2)
|
input_name=input_name)
|
||||||
self.assertAllEqual(1.0, output_2[-1])
|
result = model.predict(input_data, steps=steps)
|
||||||
|
self.assertAllEqual(expected_output, result)
|
||||||
|
|
||||||
def test_sparse_tensor_dataset_dict_predict_inputs_via_input_layer_args(self):
|
def test_ragged_tensor_input_with_wrong_ragged_rank_fails(
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
self, use_dict, use_dataset):
|
||||||
# back to a dense tensor.
|
# Define some input data that will NOT match the input shape spec.
|
||||||
if testing_utils.get_model_type() == "subclass":
|
data = [(ragged_factory_ops.constant([[[1, 0]], [[2, 3]]]), None)]
|
||||||
input_name = "input_1" # Subclass models don"t support custom input names
|
|
||||||
else:
|
# Prepare the model to test.
|
||||||
input_name = "test_input_name"
|
input_shape = (None, 2) # RaggedTensorInputTest uses (None, None).
|
||||||
|
input_name = get_input_name(use_dict)
|
||||||
model_input = input_layer.Input(
|
model_input = input_layer.Input(
|
||||||
shape=(1, None), sparse=True, name=input_name)
|
shape=input_shape, ragged=True, name=input_name, dtype=dtypes.int32)
|
||||||
layers = [ToDense(default_value=-1)]
|
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
|
||||||
|
|
||||||
# Define some input data.
|
|
||||||
input_data = dataset_ops.Dataset.from_tensors({
|
|
||||||
input_name:
|
|
||||||
sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
|
|
||||||
[1, 2, 3], [2, 1, 3])
|
|
||||||
})
|
|
||||||
expected_output = np.array([[[1, -1, -1]], [[2, 3, -1]]])
|
|
||||||
output = model.predict(input_data)
|
|
||||||
self.assertAllEqual(expected_output, output)
|
|
||||||
|
|
||||||
input_data_2 = dataset_ops.Dataset.from_tensors({
|
|
||||||
input_name:
|
|
||||||
sparse_tensor.SparseTensor(
|
|
||||||
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [2, 0, 1]], [5, 6, 7, 8],
|
|
||||||
[3, 1, 4])
|
|
||||||
})
|
|
||||||
expected_output_2 = np.array([[[5, -1, -1, -1]], [[6, 7, -1, -1]],
|
|
||||||
[[-1, 8, -1, -1]]])
|
|
||||||
output_2 = model.predict(input_data_2)
|
|
||||||
self.assertAllEqual(expected_output_2, output_2)
|
|
||||||
|
|
||||||
def test_sparse_tensor_dataset_dict_eval_inputs_via_input_layer_args(self):
|
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
|
||||||
# back to a dense tensor.
|
|
||||||
if testing_utils.get_model_type() == "subclass":
|
|
||||||
input_name = "input_1" # Subclass models don"t support custom input names
|
|
||||||
else:
|
|
||||||
input_name = "test_input_name"
|
|
||||||
model_input = input_layer.Input(
|
|
||||||
shape=(1, None), sparse=True, name=input_name)
|
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
||||||
|
|
||||||
# Define some input data.
|
# Define some input data with the wrong ragged rank
|
||||||
input_tensor = {
|
for data_element in data:
|
||||||
input_name:
|
input_data, _ = prepare_inputs(
|
||||||
sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
|
data_element,
|
||||||
[1, 2, 3], [2, 1, 3])
|
use_dict,
|
||||||
}
|
use_dataset,
|
||||||
expected_output = np.array([[[1, -1, -1]], [[2, 3, -1]]])
|
action="predict",
|
||||||
input_data = dataset_ops.Dataset.from_tensors(
|
input_name=input_name)
|
||||||
(input_tensor, expected_output))
|
with self.assertRaisesRegex(ValueError, ".*don't have the same nested.*"):
|
||||||
output = model.evaluate(input_data)
|
_ = model.predict(input_data)
|
||||||
self.assertAllEqual(1.0, output[-1])
|
|
||||||
|
|
||||||
input_tensor_2 = {
|
|
||||||
input_name:
|
|
||||||
sparse_tensor.SparseTensor(
|
|
||||||
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [2, 0, 1]], [5, 6, 7, 8],
|
|
||||||
[3, 1, 4])
|
|
||||||
}
|
|
||||||
expected_output_2 = np.array([[[5, -1, -1, -1]], [[6, 7, -1, -1]],
|
|
||||||
[[-1, 8, -1, -1]]])
|
|
||||||
input_data_2 = dataset_ops.Dataset.from_tensors(
|
|
||||||
(input_tensor_2, expected_output_2))
|
|
||||||
output_2 = model.evaluate(input_data_2)
|
|
||||||
self.assertAllEqual(1.0, output_2[-1])
|
|
||||||
|
|
||||||
|
|
||||||
# CompositeTensor shape validation only happens in non-eager modes and in non-
|
# CompositeTensor shape validation only happens in non-eager modes and in non-
|
||||||
@ -614,27 +600,48 @@ class SparseTensorInputValidationTest(keras_parameterized.TestCase,
|
|||||||
ragged_test_util.RaggedTensorTestCase):
|
ragged_test_util.RaggedTensorTestCase):
|
||||||
|
|
||||||
def test_sparse_scipy_input_checks_shape(self):
|
def test_sparse_scipy_input_checks_shape(self):
|
||||||
model_input = input_layer.Input(shape=(3,), sparse=True)
|
model_input = input_layer.Input(shape=(3,), sparse=True, dtype=dtypes.int32)
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
|
|
||||||
input_data = scipy.sparse.coo_matrix(([1, 2, 3], ([0, 1, 1], [0, 0, 1])),
|
input_data = scipy.sparse.coo_matrix(([1, 2, 3], ([0, 1, 1], [0, 0, 1])),
|
||||||
shape=[2, 4])
|
shape=[2, 4])
|
||||||
with self.assertRaisesRegex(ValueError, ".*got array with shape.*"):
|
with self.assertRaisesRegex(ValueError, ".*got array with shape.*"):
|
||||||
_ = model.predict(input_data, steps=1)
|
_ = model.predict(input_data)
|
||||||
|
|
||||||
def test_sparse_tensor_input_checks_shapes(self):
|
def test_sparse_tensor_input_checks_shapes(self):
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
# Create a model that accepts a sparse input and converts the sparse tensor
|
||||||
# back to a dense tensor.
|
# back to a dense tensor.
|
||||||
model_input = input_layer.Input(shape=(2, None), sparse=True)
|
model_input = input_layer.Input(
|
||||||
|
shape=(2, None), sparse=True, dtype=dtypes.int32)
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
|
|
||||||
# Define some input data.
|
# Define some input data.
|
||||||
input_data = sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
|
input_data = sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
|
||||||
[1, 2, 3], [2, 1, 3])
|
[1, 2, 3], [2, 1, 3])
|
||||||
|
if not testing_utils.should_run_eagerly():
|
||||||
|
# This ragged tensor is actually a standard tensor (as it has no ragged
|
||||||
|
# dimensions). Because of this, graph mode models will expect a steps
|
||||||
|
# arg to be passed (as SparseTensors in graph mode are symbolic).
|
||||||
|
steps = 1
|
||||||
|
else:
|
||||||
|
steps = None
|
||||||
with self.assertRaisesRegex(ValueError, ".*got array with shape.*"):
|
with self.assertRaisesRegex(ValueError, ".*got array with shape.*"):
|
||||||
_ = model.predict(input_data, steps=1)
|
_ = model.predict(input_data, steps=steps)
|
||||||
|
|
||||||
|
def test_ragged_tensor_input_with_wrong_value_shape(self):
|
||||||
|
# Create a model that accepts a ragged input and converts it to dense.
|
||||||
|
model_input = input_layer.Input(
|
||||||
|
shape=(None, 4), ragged=True, dtype=dtypes.int32)
|
||||||
|
layers = [ToDense(default_value=-1)]
|
||||||
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
|
|
||||||
|
# Define some input data with the wrong ragged rank
|
||||||
|
input_data = ragged_factory_ops.constant([[[1, 0]], [[2, 3]]],
|
||||||
|
ragged_rank=1)
|
||||||
|
with self.assertRaisesRegex(ValueError, ".*got array with shape.*"):
|
||||||
|
_ = model.predict(input_data)
|
||||||
|
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types(
|
@keras_parameterized.run_with_all_model_types(
|
||||||
@ -648,13 +655,14 @@ class UndefinedCompositeTensorInputsTest(keras_parameterized.TestCase,
|
|||||||
# back to a dense tensor.
|
# back to a dense tensor.
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = testing_utils.get_model_from_layers(layers)
|
model = testing_utils.get_model_from_layers(layers)
|
||||||
|
steps = get_steps()
|
||||||
|
|
||||||
# Define some input data.
|
# Define some input data.
|
||||||
input_data = sparse_tensor.SparseTensor([[0, 0], [1, 0], [1, 1]], [1, 2, 3],
|
input_data = sparse_tensor.SparseTensor([[0, 0], [1, 0], [1, 1]], [1, 2, 3],
|
||||||
[2, 3])
|
[2, 3])
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, ".*All SparseTensor and RaggedTensor inputs .*"):
|
ValueError, ".*All SparseTensor and RaggedTensor inputs .*"):
|
||||||
_ = model.predict(input_data, steps=1)
|
_ = model.predict(input_data, steps=steps)
|
||||||
|
|
||||||
def test_subclass_implicit_sparse_scipy_inputs_fails(self):
|
def test_subclass_implicit_sparse_scipy_inputs_fails(self):
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
# Create a model that accepts a sparse input and converts the sparse tensor
|
||||||
@ -666,7 +674,7 @@ class UndefinedCompositeTensorInputsTest(keras_parameterized.TestCase,
|
|||||||
input_data = scipy.sparse.coo_matrix(([1, 2, 3], ([0, 1, 1], [0, 0, 1])),
|
input_data = scipy.sparse.coo_matrix(([1, 2, 3], ([0, 1, 1], [0, 0, 1])),
|
||||||
shape=[2, 3])
|
shape=[2, 3])
|
||||||
with self.assertRaisesRegex(ValueError, ".*either a single array.*"):
|
with self.assertRaisesRegex(ValueError, ".*either a single array.*"):
|
||||||
_ = model.predict(input_data, steps=1)
|
_ = model.predict(input_data)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -106,7 +106,7 @@ def get_reachable_from_inputs(inputs, targets=None):
|
|||||||
Returns:
|
Returns:
|
||||||
A set of tensors reachable from the inputs (includes the inputs themselves).
|
A set of tensors reachable from the inputs (includes the inputs themselves).
|
||||||
"""
|
"""
|
||||||
inputs = nest.flatten(inputs)
|
inputs = nest.flatten(inputs, expand_composites=True)
|
||||||
reachable = set(inputs)
|
reachable = set(inputs)
|
||||||
if targets and not isinstance(targets, set):
|
if targets and not isinstance(targets, set):
|
||||||
targets = nest.flatten(targets)
|
targets = nest.flatten(targets)
|
||||||
|
@ -342,7 +342,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "placeholder"
|
name: "placeholder"
|
||||||
argspec: "args=[\'shape\', \'ndim\', \'dtype\', \'sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
|
argspec: "args=[\'shape\', \'ndim\', \'dtype\', \'sparse\', \'name\', \'ragged\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "pool2d"
|
name: "pool2d"
|
||||||
|
@ -108,7 +108,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
|
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "add_loss"
|
name: "add_loss"
|
||||||
|
@ -418,7 +418,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "Input"
|
name: "Input"
|
||||||
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
|
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "add"
|
name: "add"
|
||||||
|
@ -86,6 +86,6 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "Input"
|
name: "Input"
|
||||||
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
|
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -338,7 +338,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "placeholder"
|
name: "placeholder"
|
||||||
argspec: "args=[\'shape\', \'ndim\', \'dtype\', \'sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
|
argspec: "args=[\'shape\', \'ndim\', \'dtype\', \'sparse\', \'name\', \'ragged\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "pool2d"
|
name: "pool2d"
|
||||||
|
@ -108,7 +108,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
|
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "add_loss"
|
name: "add_loss"
|
||||||
|
@ -410,7 +410,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "Input"
|
name: "Input"
|
||||||
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
|
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "add"
|
name: "add"
|
||||||
|
@ -86,6 +86,6 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "Input"
|
name: "Input"
|
||||||
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
|
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user