Internal cleanup.
PiperOrigin-RevId: 250315112
This commit is contained in:
parent
3638d4b89b
commit
c47ec54211
@ -77,7 +77,7 @@ class _CuDNNRNN(RNN):
|
||||
self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size]
|
||||
self.constants_spec = None
|
||||
self._states = None
|
||||
self._num_constants = None
|
||||
self._num_constants = 0
|
||||
self._vector_shape = constant_op.constant([-1])
|
||||
|
||||
def call(self, inputs, mask=None, training=None, initial_state=None):
|
||||
|
||||
@ -406,7 +406,7 @@ class RNN(Layer):
|
||||
self.state_spec = None
|
||||
self._states = None
|
||||
self.constants_spec = None
|
||||
self._num_constants = None
|
||||
self._num_constants = 0
|
||||
|
||||
@property
|
||||
def states(self):
|
||||
@ -769,7 +769,7 @@ class RNN(Layer):
|
||||
and not isinstance(inputs, tuple)):
|
||||
# get initial_state from full input spec
|
||||
# as they could be copied to multiple GPU.
|
||||
if self._num_constants is None:
|
||||
if not self._num_constants:
|
||||
initial_state = inputs[1:]
|
||||
else:
|
||||
initial_state = inputs[1:-self._num_constants]
|
||||
@ -853,7 +853,7 @@ class RNN(Layer):
|
||||
'unroll': self.unroll,
|
||||
'time_major': self.time_major
|
||||
}
|
||||
if self._num_constants is not None:
|
||||
if self._num_constants:
|
||||
config['num_constants'] = self._num_constants
|
||||
if self.zero_output_for_mask:
|
||||
config['zero_output_for_mask'] = self.zero_output_for_mask
|
||||
@ -870,7 +870,7 @@ class RNN(Layer):
|
||||
def from_config(cls, config, custom_objects=None):
|
||||
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
|
||||
cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects)
|
||||
num_constants = config.pop('num_constants', None)
|
||||
num_constants = config.pop('num_constants', 0)
|
||||
layer = cls(cell, **config)
|
||||
layer._num_constants = num_constants
|
||||
return layer
|
||||
@ -2698,7 +2698,7 @@ def _standardize_args(inputs, initial_state, constants, num_constants):
|
||||
# could be a list of items, or list of list if the initial_state is complex
|
||||
# structure, and finally followed by constants which is a flat list.
|
||||
assert initial_state is None and constants is None
|
||||
if num_constants is not None:
|
||||
if num_constants:
|
||||
constants = inputs[-num_constants:]
|
||||
inputs = inputs[:-num_constants]
|
||||
if len(inputs) > 1:
|
||||
|
||||
@ -452,7 +452,7 @@ class Bidirectional(Wrapper):
|
||||
self.return_state = layer.return_state
|
||||
self.supports_masking = True
|
||||
self._trainable = True
|
||||
self._num_constants = None
|
||||
self._num_constants = 0
|
||||
# We don't want to track `layer` since we're already tracking the two copies
|
||||
# of it we actually run.
|
||||
self._setattr_tracking = False
|
||||
@ -619,11 +619,10 @@ class Bidirectional(Wrapper):
|
||||
# forward and backward section, and be feed to layers accordingly.
|
||||
forward_inputs = [inputs[0]]
|
||||
backward_inputs = [inputs[0]]
|
||||
pivot = (len(inputs) -
|
||||
(self._num_constants if self._num_constants else 0)) // 2 + 1
|
||||
pivot = (len(inputs) - self._num_constants) // 2 + 1
|
||||
# add forward initial state
|
||||
forward_inputs += inputs[1:pivot]
|
||||
if self._num_constants is None:
|
||||
if not self._num_constants:
|
||||
# add backward initial state
|
||||
backward_inputs += inputs[pivot:]
|
||||
else:
|
||||
@ -722,7 +721,7 @@ class Bidirectional(Wrapper):
|
||||
|
||||
def get_config(self):
|
||||
config = {'merge_mode': self.merge_mode}
|
||||
if self._num_constants is not None:
|
||||
if self._num_constants:
|
||||
config['num_constants'] = self._num_constants
|
||||
|
||||
if hasattr(self, '_backward_layer_config'):
|
||||
@ -737,7 +736,7 @@ class Bidirectional(Wrapper):
|
||||
def from_config(cls, config, custom_objects=None):
|
||||
# Instead of updating the input, create a copy and use that.
|
||||
config = config.copy()
|
||||
num_constants = config.pop('num_constants', None)
|
||||
num_constants = config.pop('num_constants', 0)
|
||||
backward_layer_config = config.pop('backward_layer', None)
|
||||
if backward_layer_config is not None:
|
||||
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user