Remove Layer._supports_ragged_inputs property.
This property made it harder to create a Layer that supports RaggedTensors, since by default every user-created Layer class was assumed not to work with RaggedTensors. Instead, an explicit error message is now raised by the common built-in Layer subclasses that do not support RaggedTensors.

PiperOrigin-RevId: 309315394
Change-Id: Id587d99cfaa4890c41aee49ec437f96108b4fbc7
parent d11820b4cf
commit 862a62d6b4
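Illustrative sketch (not part of this change's diff; the TimesTwo layer below is hypothetical): before this change, a user-defined Layer had to opt in to ragged inputs via the private flag; after it, ragged inputs flow through by default.

import tensorflow as tf

class TimesTwo(tf.keras.layers.Layer):
  """A user-defined layer whose call() happens to handle RaggedTensors."""

  def call(self, inputs):
    return inputs * 2  # Elementwise ops are defined for RaggedTensors.

ragged = tf.ragged.constant([[1., 2.], [], [3.]])
layer = TimesTwo()
# Before: ValueError('Layer ... does not support RaggedTensors as input ...')
# unless the user had set `layer._supports_ragged_inputs = True`.
# After: works by default.
print(layer(ragged))  # <tf.RaggedTensor [[2.0, 4.0], [], [6.0]]>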
@@ -313,7 +313,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     # Provides information about which inputs are compatible with the layer.
     self._input_spec = None
     self.supports_masking = False
-    self._supports_ragged_inputs = False

     self._init_set_name(name)
     self._activity_regularizer = kwargs.pop('activity_regularizer', None)
@@ -905,12 +904,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       # are casted, not before.
       input_spec.assert_input_compatibility(self.input_spec, inputs,
                                             self.name)
-      if (any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list)
-          and not self._supports_ragged_inputs):
-        raise ValueError('Layer %s does not support RaggedTensors as input. '
-                         'Inputs received: %s. You can try converting your '
-                         'input to an uniform tensor.' % (self.name, inputs))
-
       graph = backend.get_graph()
       with graph.as_default(), backend.name_scope(self._name_scope()):
         # Build layer if applicable (if the `build` method has been
@@ -1454,7 +1454,6 @@ class DTypeTest(keras_parameterized.TestCase):
         row_splits=array_ops.constant([0, 2, 2, 3], dtype='int64'))

     layer = IdentityLayer(dtype='float16')
-    layer._supports_ragged_inputs = True

     for x in sparse, ragged:
       self.assertEqual(x.dtype, 'float32')
@@ -1462,19 +1461,6 @@ class DTypeTest(keras_parameterized.TestCase):
       self.assertEqual(y.dtype, 'float16')
       self.assertEqual(type(x), type(y))

-  def test_supports_ragged_inputs_attribute_error(self):
-    with self.assertRaisesRegexp(ValueError,
-                                 'does not support RaggedTensors'):
-      ragged = ragged_tensor.RaggedTensor.from_row_splits(
-          values=array_ops.constant([1., 2., 3.], dtype='float32'),
-          row_splits=array_ops.constant([0, 2, 2, 3], dtype='int64'))
-      model = sequential.Sequential([
-          input_layer.InputLayer(input_shape=(None,), ragged=True),
-          IdentityLayer()
-      ])
-      model.compile(rmsprop.RMSprop(0.001), loss='mse')
-      model.train_on_batch(ragged)
-
   @testing_utils.enable_v2_dtype_behavior
   def test_passing_non_tensor(self):
     layer = IdentityLayer()
@@ -34,6 +34,7 @@ from tensorflow.python.ops import control_flow_v2_func_graphs
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import init_ops_v2
 from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.training.tracking import base as tracking
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_contextlib
@@ -792,6 +793,14 @@ class TrackableWeightHandler(object):
       backend.get_session().run(self._assign_op, feed_dict)


+def no_ragged_support(inputs, layer_name):
+  input_list = nest.flatten(inputs)
+  if any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list):
+    raise ValueError('Layer %s does not support RaggedTensors as input. '
+                     'Inputs received: %s. You can try converting your '
+                     'input to an uniform tensor.' % (layer_name, inputs))
+
+
 # TODO(kathywu): This is a temporary hack. When a network of layers is revived
 # from SavedModel, only the top-level layer will have losses. This causes issues
 # in eager mode because the child layers may have graph losses
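A minimal usage sketch for the new helper, assuming a hypothetical built-in layer MyLookup that cannot handle ragged inputs (this mirrors the Dense and Dot call sites further down; do_lookup is a placeholder, not a real function):

class MyLookup(Layer):
  """Hypothetical built-in layer without RaggedTensor support."""

  def call(self, inputs):
    # Reject RaggedTensors up front with the standard error message.
    base_layer_utils.no_ragged_support(inputs, self.name)
    return do_lookup(inputs)  # placeholder for the real computation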
@@ -182,7 +182,6 @@ class Layer(base_layer.Layer):
     # Provides information about which inputs are compatible with the layer.
     self._input_spec = None
     self.supports_masking = False
-    self._supports_ragged_inputs = False

     self._init_set_name(name)
     self._activity_regularizer = kwargs.pop('activity_regularizer', None)
@@ -746,12 +745,6 @@ class Layer(base_layer.Layer):
       # are casted, not before.
       input_spec.assert_input_compatibility(self.input_spec, inputs,
                                             self.name)
-      if (any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list)
-          and self._supports_ragged_inputs is False):  # pylint: disable=g-bool-id-comparison
-        raise ValueError('Layer %s does not support RaggedTensors as input. '
-                         'Inputs received: %s. You can try converting your '
-                         'input to an uniform tensor.' % (self.name, inputs))
-
       graph = backend.get_graph()
       with graph.as_default(), backend.name_scope(self._name_scope()):
         # Build layer if applicable (if the `build` method has been
@@ -132,7 +132,6 @@ class InputLayer(base_layer.Layer):
     self.ragged = ragged
     self.batch_size = batch_size
     self.supports_masking = True
-    self._supports_ragged_inputs = True

     if isinstance(input_shape, tensor_shape.TensorShape):
       input_shape = tuple(input_shape.as_list())
@@ -51,7 +51,6 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
 from tensorflow.python.keras.utils.io_utils import path_to_string
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import py_checkpoint_reader
@@ -426,10 +425,6 @@ class Network(base_layer.Layer):
     self._is_graph_network = False
     self.inputs = None
     self.outputs = None
-    # Since we don't know whether the subclass model support ragged inputs,
-    # we leave it as True, otherwise the layer will raise error when a ragged
-    # tensor is called as input.
-    self._supports_ragged_inputs = True

   @property
   @trackable_layer_utils.cache_recursive_attribute('dynamic')
@@ -1398,8 +1393,6 @@ class Network(base_layer.Layer):
                          'Note that input tensors are '
                          'instantiated via `tensor = tf.keras.Input(shape)`.\n'
                          'The tensor that caused the issue was: ' + str(x.name))
-      if isinstance(x, ragged_tensor.RaggedTensor):
-        self._supports_ragged_inputs = True

     # Check compatibility of batch sizes of Input Layers.
     input_batch_sizes = [
@@ -967,14 +967,9 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
   @combinations.generate(combinations.keras_mode_combinations())
   def test_composite_call_kwarg_derived_from_keras_layer(self):

-    # Create a test layer that accepts composite tensor inputs (note the
-    # 'supports_ragged_inputs = True' in the init method.)
+    # Create a test layer that accepts composite tensor inputs.
     class MaybeAdd(layers.Layer):

-      def __init__(self, **kwargs):
-        super(MaybeAdd, self).__init__(**kwargs)
-        self._supports_ragged_inputs = True
-
       def call(self, x1, x2=None):
         # We need to convert this to a tensor for loss calculations -
         # losses don't play nicely with ragged tensors yet.
@@ -37,6 +37,7 @@ from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import constraints
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras import regularizers
+from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine.base_layer import Layer
 from tensorflow.python.keras.engine.input_spec import InputSpec
 from tensorflow.python.keras.utils import conv_utils
@@ -830,7 +831,6 @@ class Lambda(Layer):
     if mask is not None:
       self.supports_masking = True
     self.mask = mask
-    self._supports_ragged_inputs = True
     self._output_shape = output_shape

     # Warning on every invocation will be quite irksome in Eager mode.
@@ -1177,6 +1177,7 @@ class Dense(Layer):
     self.built = True

   def call(self, inputs):
+    base_layer_utils.no_ragged_support(inputs, self.name)
     rank = inputs.shape.rank
     if rank is not None and rank > 2:
       # Broadcasting is required for the inputs.
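The effect of the guard above, as a hedged sketch (assuming eager execution; the last dimension is kept uniform so that Dense.build() succeeds and the new check in call() is what raises):

import tensorflow as tf

# Ragged middle dimension, uniform last dimension of size 2.
ragged = tf.ragged.constant([[[1., 2.]], [[3., 4.], [5., 6.]]], ragged_rank=1)
try:
  tf.keras.layers.Dense(4)(ragged)
except ValueError as e:
  print(e)  # Layer dense does not support RaggedTensors as input. ...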
@@ -121,7 +121,6 @@ class Embedding(Layer):
     self.mask_zero = mask_zero
     self.supports_masking = mask_zero
     self.input_length = input_length
-    self._supports_ragged_inputs = True

   @tf_utils.shape_type_conversion
   def build(self, input_shape):
@@ -21,6 +21,7 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine.base_layer import Layer
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
@@ -43,7 +44,6 @@ class _Merge(Layer):
     """
     super(_Merge, self).__init__(**kwargs)
     self.supports_masking = True
-    self._supports_ragged_inputs = True

   def _merge_function(self, inputs):
     raise NotImplementedError
@@ -651,7 +651,6 @@ class Dot(_Merge):
     self.normalize = normalize
     self.supports_masking = True
     self._reshape_required = False
-    self._supports_ragged_inputs = False

   @tf_utils.shape_type_conversion
   def build(self, input_shape):
@@ -677,6 +676,7 @@ class Dot(_Merge):
                        'Chosen axes: %s, %s' % (axes[0], axes[1]))

   def _merge_function(self, inputs):
+    base_layer_utils.no_ragged_support(inputs, self.name)
     if len(inputs) != 2:
       raise ValueError('A `Dot` layer should be called on exactly 2 inputs')
     x1 = inputs[0]
@@ -714,7 +714,6 @@ class GlobalPooling1D(Layer):
     super(GlobalPooling1D, self).__init__(**kwargs)
     self.input_spec = InputSpec(ndim=3)
     self.data_format = conv_utils.normalize_data_format(data_format)
-    self._supports_ragged_inputs = True

   def compute_output_shape(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape).as_list()
@@ -849,7 +848,6 @@ class GlobalPooling2D(Layer):
     super(GlobalPooling2D, self).__init__(**kwargs)
     self.data_format = conv_utils.normalize_data_format(data_format)
     self.input_spec = InputSpec(ndim=4)
-    self._supports_ragged_inputs = True

   def compute_output_shape(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape).as_list()
@@ -957,7 +955,6 @@ class GlobalPooling3D(Layer):
     super(GlobalPooling3D, self).__init__(**kwargs)
     self.data_format = conv_utils.normalize_data_format(data_format)
     self.input_spec = InputSpec(ndim=5)
-    self._supports_ragged_inputs = True

   def compute_output_shape(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape).as_list()
@@ -161,7 +161,6 @@ class CategoryCrossing(Layer):
       self._depth_tuple = depth
     elif depth is not None:
       self._depth_tuple = tuple([i for i in range(1, depth + 1)])
-    self._supports_ragged_inputs = True

   def partial_crossing(self, partial_inputs, ragged_out, sparse_out):
     """Gets the crossed output from a partial list/tuple of inputs."""
@@ -102,9 +102,6 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
     self._sparse = sparse
     self._called = False

-    # This layer supports RaggedTensor inputs.
-    self._supports_ragged_inputs = True
-
     # We are adding these here instead of in build() since they do not depend
     # on the input shape at all.
     if max_tokens is None:
@@ -56,7 +56,6 @@ class Discretization(Layer):

   def __init__(self, bins, output_mode=INTEGER, **kwargs):
     super(Discretization, self).__init__(**kwargs)
-    self._supports_ragged_inputs = True
     self.bins = bins
     self.output_mode = output_mode

@@ -90,7 +90,6 @@ class Hashing(Layer):
     super(Hashing, self).__init__(name=name, **kwargs)
     self.num_bins = num_bins
     self.salt = salt
-    self._supports_ragged_inputs = True

   def call(self, inputs):
     # Converts integer inputs to string.
@@ -150,9 +150,6 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
     super(IndexLookup, self).__init__(
         combiner=_IndexLookupCombiner(self.max_tokens), **kwargs)

-    # This layer supports RaggedTensor inputs.
-    self._supports_ragged_inputs = True
-
     # If the layer's input type is int32, we can only output int32 values -
     # MutableHashTable doesn't allow us to map int32->int64.
     if self.dtype == dtypes.int32:
@@ -75,7 +75,6 @@ class Reduction(Layer):
     # We temporarily turn off autocasting, as it does not apply to named call
     # kwargs.
     super(Reduction, self).__init__(**kwargs)
-    self._supports_ragged_inputs = True

   def call(self, inputs, weights=None):
     # If we are not weighting the inputs we can immediately reduce the data
@@ -302,7 +302,6 @@ class TextVectorization(CombinerPreprocessingLayer):
         combiner=_TextVectorizationCombiner(
             self._max_vocab_size, compute_idf=output_mode == TFIDF),
         **kwargs)
-    self._supports_ragged_inputs = True

     reserve_zero = output_mode in [None, INT]
     self._index_lookup_layer = self._get_index_lookup_class()(
@@ -438,7 +438,6 @@ class RNN(Layer):
     self._states = None
     self.constants_spec = None
     self._num_constants = 0
-    self._supports_ragged_inputs = True

     if stateful:
       if ds_context.has_strategy():
@@ -125,7 +125,6 @@ class TimeDistributed(Wrapper):
             input=layer))
     super(TimeDistributed, self).__init__(layer, **kwargs)
     self.supports_masking = True
-    self._supports_ragged_inputs = True

     # It is safe to use the fast, reshape-based approach with all of our
     # built-in Layers.
@@ -449,7 +448,6 @@ class Bidirectional(Wrapper):
     self._trainable = True
     self._num_constants = 0
     self.input_spec = layer.input_spec
-    self._supports_ragged_inputs = True

   def _verify_layer_config(self):
     """Ensure the forward and backward layers have valid common property."""
@@ -55,7 +55,6 @@ class ToDense(Layer):
   def __init__(self, default_value, **kwargs):
     super(ToDense, self).__init__(**kwargs)
     self._default_value = default_value
-    self._supports_ragged_inputs = True

   def call(self, inputs):
     if isinstance(inputs, dict): # Dicts are no longer flattened.
@@ -83,7 +82,6 @@ class ToRagged(Layer):
     super(ToRagged, self).__init__(**kwargs)
     self._padding = padding
     self._ragged_rank = ragged_rank
-    self._supports_ragged_inputs = True

   def call(self, inputs):
     return ragged_tensor.RaggedTensor.from_tensor(