Remove Layer._supports_ragged_inputs property.
This property made it more difficult to create a Layer that supports RaggedTensors, since by default every user-created Layer class was assumed to not work with RaggedTensors. Instead, an error message is added to common built-in Layer subclasses that don't support RaggedTensors. PiperOrigin-RevId: 309315394 Change-Id: Id587d99cfaa4890c41aee49ec437f96108b4fbc7
This commit is contained in:
parent
d11820b4cf
commit
862a62d6b4
@ -313,7 +313,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
# Provides information about which inputs are compatible with the layer.
|
||||
self._input_spec = None
|
||||
self.supports_masking = False
|
||||
self._supports_ragged_inputs = False
|
||||
|
||||
self._init_set_name(name)
|
||||
self._activity_regularizer = kwargs.pop('activity_regularizer', None)
|
||||
@ -905,12 +904,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
# are casted, not before.
|
||||
input_spec.assert_input_compatibility(self.input_spec, inputs,
|
||||
self.name)
|
||||
if (any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list)
|
||||
and not self._supports_ragged_inputs):
|
||||
raise ValueError('Layer %s does not support RaggedTensors as input. '
|
||||
'Inputs received: %s. You can try converting your '
|
||||
'input to an uniform tensor.' % (self.name, inputs))
|
||||
|
||||
graph = backend.get_graph()
|
||||
with graph.as_default(), backend.name_scope(self._name_scope()):
|
||||
# Build layer if applicable (if the `build` method has been
|
||||
|
@ -1454,7 +1454,6 @@ class DTypeTest(keras_parameterized.TestCase):
|
||||
row_splits=array_ops.constant([0, 2, 2, 3], dtype='int64'))
|
||||
|
||||
layer = IdentityLayer(dtype='float16')
|
||||
layer._supports_ragged_inputs = True
|
||||
|
||||
for x in sparse, ragged:
|
||||
self.assertEqual(x.dtype, 'float32')
|
||||
@ -1462,19 +1461,6 @@ class DTypeTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(y.dtype, 'float16')
|
||||
self.assertEqual(type(x), type(y))
|
||||
|
||||
def test_supports_ragged_inputs_attribute_error(self):
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'does not support RaggedTensors'):
|
||||
ragged = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
values=array_ops.constant([1., 2., 3.], dtype='float32'),
|
||||
row_splits=array_ops.constant([0, 2, 2, 3], dtype='int64'))
|
||||
model = sequential.Sequential([
|
||||
input_layer.InputLayer(input_shape=(None,), ragged=True),
|
||||
IdentityLayer()
|
||||
])
|
||||
model.compile(rmsprop.RMSprop(0.001), loss='mse')
|
||||
model.train_on_batch(ragged)
|
||||
|
||||
@testing_utils.enable_v2_dtype_behavior
|
||||
def test_passing_non_tensor(self):
|
||||
layer = IdentityLayer()
|
||||
|
@ -34,6 +34,7 @@ from tensorflow.python.ops import control_flow_v2_func_graphs
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import init_ops_v2
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.training.tracking import base as tracking
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
@ -792,6 +793,14 @@ class TrackableWeightHandler(object):
|
||||
backend.get_session().run(self._assign_op, feed_dict)
|
||||
|
||||
|
||||
def no_ragged_support(inputs, layer_name):
|
||||
input_list = nest.flatten(inputs)
|
||||
if any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list):
|
||||
raise ValueError('Layer %s does not support RaggedTensors as input. '
|
||||
'Inputs received: %s. You can try converting your '
|
||||
'input to an uniform tensor.' % (layer_name, inputs))
|
||||
|
||||
|
||||
# TODO(kathywu): This is a temporary hack. When a network of layers is revived
|
||||
# from SavedModel, only the top-level layer will have losses. This causes issues
|
||||
# in eager mode because the child layers may have graph losses
|
||||
|
@ -182,7 +182,6 @@ class Layer(base_layer.Layer):
|
||||
# Provides information about which inputs are compatible with the layer.
|
||||
self._input_spec = None
|
||||
self.supports_masking = False
|
||||
self._supports_ragged_inputs = False
|
||||
|
||||
self._init_set_name(name)
|
||||
self._activity_regularizer = kwargs.pop('activity_regularizer', None)
|
||||
@ -746,12 +745,6 @@ class Layer(base_layer.Layer):
|
||||
# are casted, not before.
|
||||
input_spec.assert_input_compatibility(self.input_spec, inputs,
|
||||
self.name)
|
||||
if (any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list)
|
||||
and self._supports_ragged_inputs is False): # pylint: disable=g-bool-id-comparison
|
||||
raise ValueError('Layer %s does not support RaggedTensors as input. '
|
||||
'Inputs received: %s. You can try converting your '
|
||||
'input to an uniform tensor.' % (self.name, inputs))
|
||||
|
||||
graph = backend.get_graph()
|
||||
with graph.as_default(), backend.name_scope(self._name_scope()):
|
||||
# Build layer if applicable (if the `build` method has been
|
||||
|
@ -132,7 +132,6 @@ class InputLayer(base_layer.Layer):
|
||||
self.ragged = ragged
|
||||
self.batch_size = batch_size
|
||||
self.supports_masking = True
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
if isinstance(input_shape, tensor_shape.TensorShape):
|
||||
input_shape = tuple(input_shape.as_list())
|
||||
|
@ -51,7 +51,6 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from tensorflow.python.keras.utils.io_utils import path_to_string
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import py_checkpoint_reader
|
||||
@ -426,10 +425,6 @@ class Network(base_layer.Layer):
|
||||
self._is_graph_network = False
|
||||
self.inputs = None
|
||||
self.outputs = None
|
||||
# Since we don't know whether the subclass model support ragged inputs,
|
||||
# we leave it as True, otherwise the layer will raise error when a ragged
|
||||
# tensor is called as input.
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
@property
|
||||
@trackable_layer_utils.cache_recursive_attribute('dynamic')
|
||||
@ -1398,8 +1393,6 @@ class Network(base_layer.Layer):
|
||||
'Note that input tensors are '
|
||||
'instantiated via `tensor = tf.keras.Input(shape)`.\n'
|
||||
'The tensor that caused the issue was: ' + str(x.name))
|
||||
if isinstance(x, ragged_tensor.RaggedTensor):
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
# Check compatibility of batch sizes of Input Layers.
|
||||
input_batch_sizes = [
|
||||
|
@ -967,14 +967,9 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
@combinations.generate(combinations.keras_mode_combinations())
|
||||
def test_composite_call_kwarg_derived_from_keras_layer(self):
|
||||
|
||||
# Create a test layer that accepts composite tensor inputs (note the
|
||||
# 'supports_ragged_inputs = True' in the init method.)
|
||||
# Create a test layer that accepts composite tensor inputs.
|
||||
class MaybeAdd(layers.Layer):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(MaybeAdd, self).__init__(**kwargs)
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
def call(self, x1, x2=None):
|
||||
# We need to convert this to a tensor for loss calculations -
|
||||
# losses don't play nicely with ragged tensors yet.
|
||||
|
@ -37,6 +37,7 @@ from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import constraints
|
||||
from tensorflow.python.keras import initializers
|
||||
from tensorflow.python.keras import regularizers
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine.base_layer import Layer
|
||||
from tensorflow.python.keras.engine.input_spec import InputSpec
|
||||
from tensorflow.python.keras.utils import conv_utils
|
||||
@ -830,7 +831,6 @@ class Lambda(Layer):
|
||||
if mask is not None:
|
||||
self.supports_masking = True
|
||||
self.mask = mask
|
||||
self._supports_ragged_inputs = True
|
||||
self._output_shape = output_shape
|
||||
|
||||
# Warning on every invocation will be quite irksome in Eager mode.
|
||||
@ -1177,6 +1177,7 @@ class Dense(Layer):
|
||||
self.built = True
|
||||
|
||||
def call(self, inputs):
|
||||
base_layer_utils.no_ragged_support(inputs, self.name)
|
||||
rank = inputs.shape.rank
|
||||
if rank is not None and rank > 2:
|
||||
# Broadcasting is required for the inputs.
|
||||
|
@ -121,7 +121,6 @@ class Embedding(Layer):
|
||||
self.mask_zero = mask_zero
|
||||
self.supports_masking = mask_zero
|
||||
self.input_length = input_length
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
@tf_utils.shape_type_conversion
|
||||
def build(self, input_shape):
|
||||
|
@ -21,6 +21,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine.base_layer import Layer
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -43,7 +44,6 @@ class _Merge(Layer):
|
||||
"""
|
||||
super(_Merge, self).__init__(**kwargs)
|
||||
self.supports_masking = True
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
def _merge_function(self, inputs):
|
||||
raise NotImplementedError
|
||||
@ -651,7 +651,6 @@ class Dot(_Merge):
|
||||
self.normalize = normalize
|
||||
self.supports_masking = True
|
||||
self._reshape_required = False
|
||||
self._supports_ragged_inputs = False
|
||||
|
||||
@tf_utils.shape_type_conversion
|
||||
def build(self, input_shape):
|
||||
@ -677,6 +676,7 @@ class Dot(_Merge):
|
||||
'Chosen axes: %s, %s' % (axes[0], axes[1]))
|
||||
|
||||
def _merge_function(self, inputs):
|
||||
base_layer_utils.no_ragged_support(inputs, self.name)
|
||||
if len(inputs) != 2:
|
||||
raise ValueError('A `Dot` layer should be called on exactly 2 inputs')
|
||||
x1 = inputs[0]
|
||||
|
@ -714,7 +714,6 @@ class GlobalPooling1D(Layer):
|
||||
super(GlobalPooling1D, self).__init__(**kwargs)
|
||||
self.input_spec = InputSpec(ndim=3)
|
||||
self.data_format = conv_utils.normalize_data_format(data_format)
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
@ -849,7 +848,6 @@ class GlobalPooling2D(Layer):
|
||||
super(GlobalPooling2D, self).__init__(**kwargs)
|
||||
self.data_format = conv_utils.normalize_data_format(data_format)
|
||||
self.input_spec = InputSpec(ndim=4)
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
@ -957,7 +955,6 @@ class GlobalPooling3D(Layer):
|
||||
super(GlobalPooling3D, self).__init__(**kwargs)
|
||||
self.data_format = conv_utils.normalize_data_format(data_format)
|
||||
self.input_spec = InputSpec(ndim=5)
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
|
@ -161,7 +161,6 @@ class CategoryCrossing(Layer):
|
||||
self._depth_tuple = depth
|
||||
elif depth is not None:
|
||||
self._depth_tuple = tuple([i for i in range(1, depth + 1)])
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
def partial_crossing(self, partial_inputs, ragged_out, sparse_out):
|
||||
"""Gets the crossed output from a partial list/tuple of inputs."""
|
||||
|
@ -102,9 +102,6 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
self._sparse = sparse
|
||||
self._called = False
|
||||
|
||||
# This layer supports RaggedTensor inputs.
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
# We are adding these here instead of in build() since they do not depend
|
||||
# on the input shape at all.
|
||||
if max_tokens is None:
|
||||
|
@ -56,7 +56,6 @@ class Discretization(Layer):
|
||||
|
||||
def __init__(self, bins, output_mode=INTEGER, **kwargs):
|
||||
super(Discretization, self).__init__(**kwargs)
|
||||
self._supports_ragged_inputs = True
|
||||
self.bins = bins
|
||||
self.output_mode = output_mode
|
||||
|
||||
|
@ -90,7 +90,6 @@ class Hashing(Layer):
|
||||
super(Hashing, self).__init__(name=name, **kwargs)
|
||||
self.num_bins = num_bins
|
||||
self.salt = salt
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
def call(self, inputs):
|
||||
# Converts integer inputs to string.
|
||||
|
@ -150,9 +150,6 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
super(IndexLookup, self).__init__(
|
||||
combiner=_IndexLookupCombiner(self.max_tokens), **kwargs)
|
||||
|
||||
# This layer supports RaggedTensor inputs.
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
# If the layer's input type is int32, we can only output int32 values -
|
||||
# MutableHashTable doesn't allow us to map int32->int64.
|
||||
if self.dtype == dtypes.int32:
|
||||
|
@ -75,7 +75,6 @@ class Reduction(Layer):
|
||||
# We temporarily turn off autocasting, as it does not apply to named call
|
||||
# kwargs.
|
||||
super(Reduction, self).__init__(**kwargs)
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
def call(self, inputs, weights=None):
|
||||
# If we are not weighting the inputs we can immediately reduce the data
|
||||
|
@ -302,7 +302,6 @@ class TextVectorization(CombinerPreprocessingLayer):
|
||||
combiner=_TextVectorizationCombiner(
|
||||
self._max_vocab_size, compute_idf=output_mode == TFIDF),
|
||||
**kwargs)
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
reserve_zero = output_mode in [None, INT]
|
||||
self._index_lookup_layer = self._get_index_lookup_class()(
|
||||
|
@ -438,7 +438,6 @@ class RNN(Layer):
|
||||
self._states = None
|
||||
self.constants_spec = None
|
||||
self._num_constants = 0
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
if stateful:
|
||||
if ds_context.has_strategy():
|
||||
|
@ -125,7 +125,6 @@ class TimeDistributed(Wrapper):
|
||||
input=layer))
|
||||
super(TimeDistributed, self).__init__(layer, **kwargs)
|
||||
self.supports_masking = True
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
# It is safe to use the fast, reshape-based approach with all of our
|
||||
# built-in Layers.
|
||||
@ -449,7 +448,6 @@ class Bidirectional(Wrapper):
|
||||
self._trainable = True
|
||||
self._num_constants = 0
|
||||
self.input_spec = layer.input_spec
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
def _verify_layer_config(self):
|
||||
"""Ensure the forward and backward layers have valid common property."""
|
||||
|
@ -55,7 +55,6 @@ class ToDense(Layer):
|
||||
def __init__(self, default_value, **kwargs):
|
||||
super(ToDense, self).__init__(**kwargs)
|
||||
self._default_value = default_value
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
def call(self, inputs):
|
||||
if isinstance(inputs, dict): # Dicts are no longer flattened.
|
||||
@ -83,7 +82,6 @@ class ToRagged(Layer):
|
||||
super(ToRagged, self).__init__(**kwargs)
|
||||
self._padding = padding
|
||||
self._ragged_rank = ragged_rank
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
def call(self, inputs):
|
||||
return ragged_tensor.RaggedTensor.from_tensor(
|
||||
|
Loading…
Reference in New Issue
Block a user