Remove Layer._supports_ragged_inputs property.

This property made it more difficult to create a Layer that supports
RaggedTensors, since by default every user-created Layer class was assumed to
not work with RaggedTensors.

Instead, an error message is added to common built-in Layer subclasses that don't
support RaggedTensors.

PiperOrigin-RevId: 309315394
Change-Id: Id587d99cfaa4890c41aee49ec437f96108b4fbc7
This commit is contained in:
Thomas O'Malley 2020-04-30 15:50:27 -07:00 committed by TensorFlower Gardener
parent d11820b4cf
commit 862a62d6b4
21 changed files with 14 additions and 65 deletions

View File

@ -313,7 +313,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# Provides information about which inputs are compatible with the layer.
self._input_spec = None
self.supports_masking = False
self._supports_ragged_inputs = False
self._init_set_name(name)
self._activity_regularizer = kwargs.pop('activity_regularizer', None)
@ -905,12 +904,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# are casted, not before.
input_spec.assert_input_compatibility(self.input_spec, inputs,
self.name)
if (any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list)
and not self._supports_ragged_inputs):
raise ValueError('Layer %s does not support RaggedTensors as input. '
'Inputs received: %s. You can try converting your '
'input to an uniform tensor.' % (self.name, inputs))
graph = backend.get_graph()
with graph.as_default(), backend.name_scope(self._name_scope()):
# Build layer if applicable (if the `build` method has been

View File

@ -1454,7 +1454,6 @@ class DTypeTest(keras_parameterized.TestCase):
row_splits=array_ops.constant([0, 2, 2, 3], dtype='int64'))
layer = IdentityLayer(dtype='float16')
layer._supports_ragged_inputs = True
for x in sparse, ragged:
self.assertEqual(x.dtype, 'float32')
@ -1462,19 +1461,6 @@ class DTypeTest(keras_parameterized.TestCase):
self.assertEqual(y.dtype, 'float16')
self.assertEqual(type(x), type(y))
def test_supports_ragged_inputs_attribute_error(self):
with self.assertRaisesRegexp(ValueError,
'does not support RaggedTensors'):
ragged = ragged_tensor.RaggedTensor.from_row_splits(
values=array_ops.constant([1., 2., 3.], dtype='float32'),
row_splits=array_ops.constant([0, 2, 2, 3], dtype='int64'))
model = sequential.Sequential([
input_layer.InputLayer(input_shape=(None,), ragged=True),
IdentityLayer()
])
model.compile(rmsprop.RMSprop(0.001), loss='mse')
model.train_on_batch(ragged)
@testing_utils.enable_v2_dtype_behavior
def test_passing_non_tensor(self):
layer = IdentityLayer()

View File

@ -34,6 +34,7 @@ from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.training.tracking import base as tracking
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
@ -792,6 +793,14 @@ class TrackableWeightHandler(object):
backend.get_session().run(self._assign_op, feed_dict)
def no_ragged_support(inputs, layer_name):
input_list = nest.flatten(inputs)
if any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list):
raise ValueError('Layer %s does not support RaggedTensors as input. '
'Inputs received: %s. You can try converting your '
'input to an uniform tensor.' % (layer_name, inputs))
# TODO(kathywu): This is a temporary hack. When a network of layers is revived
# from SavedModel, only the top-level layer will have losses. This causes issues
# in eager mode because the child layers may have graph losses

View File

@ -182,7 +182,6 @@ class Layer(base_layer.Layer):
# Provides information about which inputs are compatible with the layer.
self._input_spec = None
self.supports_masking = False
self._supports_ragged_inputs = False
self._init_set_name(name)
self._activity_regularizer = kwargs.pop('activity_regularizer', None)
@ -746,12 +745,6 @@ class Layer(base_layer.Layer):
# are casted, not before.
input_spec.assert_input_compatibility(self.input_spec, inputs,
self.name)
if (any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list)
and self._supports_ragged_inputs is False): # pylint: disable=g-bool-id-comparison
raise ValueError('Layer %s does not support RaggedTensors as input. '
'Inputs received: %s. You can try converting your '
'input to an uniform tensor.' % (self.name, inputs))
graph = backend.get_graph()
with graph.as_default(), backend.name_scope(self._name_scope()):
# Build layer if applicable (if the `build` method has been

View File

@ -132,7 +132,6 @@ class InputLayer(base_layer.Layer):
self.ragged = ragged
self.batch_size = batch_size
self.supports_masking = True
self._supports_ragged_inputs = True
if isinstance(input_shape, tensor_shape.TensorShape):
input_shape = tuple(input_shape.as_list())

View File

@ -51,7 +51,6 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
@ -426,10 +425,6 @@ class Network(base_layer.Layer):
self._is_graph_network = False
self.inputs = None
self.outputs = None
# Since we don't know whether the subclass model support ragged inputs,
# we leave it as True, otherwise the layer will raise error when a ragged
# tensor is called as input.
self._supports_ragged_inputs = True
@property
@trackable_layer_utils.cache_recursive_attribute('dynamic')
@ -1398,8 +1393,6 @@ class Network(base_layer.Layer):
'Note that input tensors are '
'instantiated via `tensor = tf.keras.Input(shape)`.\n'
'The tensor that caused the issue was: ' + str(x.name))
if isinstance(x, ragged_tensor.RaggedTensor):
self._supports_ragged_inputs = True
# Check compatibility of batch sizes of Input Layers.
input_batch_sizes = [

View File

@ -967,14 +967,9 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
@combinations.generate(combinations.keras_mode_combinations())
def test_composite_call_kwarg_derived_from_keras_layer(self):
# Create a test layer that accepts composite tensor inputs (note the
# 'supports_ragged_inputs = True' in the init method.)
# Create a test layer that accepts composite tensor inputs.
class MaybeAdd(layers.Layer):
def __init__(self, **kwargs):
super(MaybeAdd, self).__init__(**kwargs)
self._supports_ragged_inputs = True
def call(self, x1, x2=None):
# We need to convert this to a tensor for loss calculations -
# losses don't play nicely with ragged tensors yet.

View File

@ -37,6 +37,7 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import conv_utils
@ -830,7 +831,6 @@ class Lambda(Layer):
if mask is not None:
self.supports_masking = True
self.mask = mask
self._supports_ragged_inputs = True
self._output_shape = output_shape
# Warning on every invocation will be quite irksome in Eager mode.
@ -1177,6 +1177,7 @@ class Dense(Layer):
self.built = True
def call(self, inputs):
base_layer_utils.no_ragged_support(inputs, self.name)
rank = inputs.shape.rank
if rank is not None and rank > 2:
# Broadcasting is required for the inputs.

View File

@ -121,7 +121,6 @@ class Embedding(Layer):
self.mask_zero = mask_zero
self.supports_masking = mask_zero
self.input_length = input_length
self._supports_ragged_inputs = True
@tf_utils.shape_type_conversion
def build(self, input_shape):

View File

@ -21,6 +21,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
@ -43,7 +44,6 @@ class _Merge(Layer):
"""
super(_Merge, self).__init__(**kwargs)
self.supports_masking = True
self._supports_ragged_inputs = True
def _merge_function(self, inputs):
raise NotImplementedError
@ -651,7 +651,6 @@ class Dot(_Merge):
self.normalize = normalize
self.supports_masking = True
self._reshape_required = False
self._supports_ragged_inputs = False
@tf_utils.shape_type_conversion
def build(self, input_shape):
@ -677,6 +676,7 @@ class Dot(_Merge):
'Chosen axes: %s, %s' % (axes[0], axes[1]))
def _merge_function(self, inputs):
base_layer_utils.no_ragged_support(inputs, self.name)
if len(inputs) != 2:
raise ValueError('A `Dot` layer should be called on exactly 2 inputs')
x1 = inputs[0]

View File

@ -714,7 +714,6 @@ class GlobalPooling1D(Layer):
super(GlobalPooling1D, self).__init__(**kwargs)
self.input_spec = InputSpec(ndim=3)
self.data_format = conv_utils.normalize_data_format(data_format)
self._supports_ragged_inputs = True
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
@ -849,7 +848,6 @@ class GlobalPooling2D(Layer):
super(GlobalPooling2D, self).__init__(**kwargs)
self.data_format = conv_utils.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=4)
self._supports_ragged_inputs = True
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
@ -957,7 +955,6 @@ class GlobalPooling3D(Layer):
super(GlobalPooling3D, self).__init__(**kwargs)
self.data_format = conv_utils.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=5)
self._supports_ragged_inputs = True
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()

View File

@ -161,7 +161,6 @@ class CategoryCrossing(Layer):
self._depth_tuple = depth
elif depth is not None:
self._depth_tuple = tuple([i for i in range(1, depth + 1)])
self._supports_ragged_inputs = True
def partial_crossing(self, partial_inputs, ragged_out, sparse_out):
"""Gets the crossed output from a partial list/tuple of inputs."""

View File

@ -102,9 +102,6 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
self._sparse = sparse
self._called = False
# This layer supports RaggedTensor inputs.
self._supports_ragged_inputs = True
# We are adding these here instead of in build() since they do not depend
# on the input shape at all.
if max_tokens is None:

View File

@ -56,7 +56,6 @@ class Discretization(Layer):
def __init__(self, bins, output_mode=INTEGER, **kwargs):
super(Discretization, self).__init__(**kwargs)
self._supports_ragged_inputs = True
self.bins = bins
self.output_mode = output_mode

View File

@ -90,7 +90,6 @@ class Hashing(Layer):
super(Hashing, self).__init__(name=name, **kwargs)
self.num_bins = num_bins
self.salt = salt
self._supports_ragged_inputs = True
def call(self, inputs):
# Converts integer inputs to string.

View File

@ -150,9 +150,6 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
super(IndexLookup, self).__init__(
combiner=_IndexLookupCombiner(self.max_tokens), **kwargs)
# This layer supports RaggedTensor inputs.
self._supports_ragged_inputs = True
# If the layer's input type is int32, we can only output int32 values -
# MutableHashTable doesn't allow us to map int32->int64.
if self.dtype == dtypes.int32:

View File

@ -75,7 +75,6 @@ class Reduction(Layer):
# We temporarily turn off autocasting, as it does not apply to named call
# kwargs.
super(Reduction, self).__init__(**kwargs)
self._supports_ragged_inputs = True
def call(self, inputs, weights=None):
# If we are not weighting the inputs we can immediately reduce the data

View File

@ -302,7 +302,6 @@ class TextVectorization(CombinerPreprocessingLayer):
combiner=_TextVectorizationCombiner(
self._max_vocab_size, compute_idf=output_mode == TFIDF),
**kwargs)
self._supports_ragged_inputs = True
reserve_zero = output_mode in [None, INT]
self._index_lookup_layer = self._get_index_lookup_class()(

View File

@ -438,7 +438,6 @@ class RNN(Layer):
self._states = None
self.constants_spec = None
self._num_constants = 0
self._supports_ragged_inputs = True
if stateful:
if ds_context.has_strategy():

View File

@ -125,7 +125,6 @@ class TimeDistributed(Wrapper):
input=layer))
super(TimeDistributed, self).__init__(layer, **kwargs)
self.supports_masking = True
self._supports_ragged_inputs = True
# It is safe to use the fast, reshape-based approach with all of our
# built-in Layers.
@ -449,7 +448,6 @@ class Bidirectional(Wrapper):
self._trainable = True
self._num_constants = 0
self.input_spec = layer.input_spec
self._supports_ragged_inputs = True
def _verify_layer_config(self):
"""Ensure the forward and backward layers have valid common property."""

View File

@ -55,7 +55,6 @@ class ToDense(Layer):
def __init__(self, default_value, **kwargs):
super(ToDense, self).__init__(**kwargs)
self._default_value = default_value
self._supports_ragged_inputs = True
def call(self, inputs):
if isinstance(inputs, dict): # Dicts are no longer flattened.
@ -83,7 +82,6 @@ class ToRagged(Layer):
super(ToRagged, self).__init__(**kwargs)
self._padding = padding
self._ragged_rank = ragged_rank
self._supports_ragged_inputs = True
def call(self, inputs):
return ragged_tensor.RaggedTensor.from_tensor(