Internal cleanup.

PiperOrigin-RevId: 250315112
Scott Zhu authored 2019-05-28 10:40:18 -07:00, committed by TensorFlower Gardener
parent 3638d4b89b
commit c47ec54211
3 changed files with 11 additions and 12 deletions

tensorflow/python/keras/layers/cudnn_recurrent.py

@@ -77,7 +77,7 @@ class _CuDNNRNN(RNN):
     self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size]
     self.constants_spec = None
     self._states = None
-    self._num_constants = None
+    self._num_constants = 0
     self._vector_shape = constant_op.constant([-1])
 
   def call(self, inputs, mask=None, training=None, initial_state=None):
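
The change is the same across all three files: the `_num_constants` sentinel for "no constants" switches from `None` to `0`. Both values are falsy, so a bare truthiness test takes the place of the old `is None` checks, and unlike `None`, `0` participates directly in arithmetic and slicing. A minimal standalone sketch (not TensorFlow code) of the two properties the cleanup relies on:

num_constants = 0  # new sentinel for "no constants"

# Truthiness: 0 takes the same branch the old `is None` check selected.
if not num_constants:
  print('no constants')

# Arithmetic: 0 needs no `x if x else 0` guard, unlike None.
pivot = (5 - num_constants) // 2 + 1
print(pivot)  # 3
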
tensorflow/python/keras/layers/recurrent.py

@@ -406,7 +406,7 @@ class RNN(Layer):
     self.state_spec = None
     self._states = None
     self.constants_spec = None
-    self._num_constants = None
+    self._num_constants = 0
 
   @property
   def states(self):
@@ -769,7 +769,7 @@ class RNN(Layer):
         and not isinstance(inputs, tuple)):
       # get initial_state from full input spec
       # as they could be copied to multiple GPU.
-      if self._num_constants is None:
+      if not self._num_constants:
         initial_state = inputs[1:]
       else:
         initial_state = inputs[1:-self._num_constants]
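
A subtlety this hunk preserves: with `num_constants == 0`, the `else` branch would slice with a negative zero, which empties the list rather than keeping the tail. A tiny illustration with hypothetical values:

inputs = ['x', 'state_h', 'state_c']  # hypothetical flat input list
num_constants = 0

print(inputs[1:-num_constants])  # inputs[1:0] -> [], not "to the end"
print(inputs[1:])                # ['state_h', 'state_c'], what the guard selects
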
@@ -853,7 +853,7 @@ class RNN(Layer):
         'unroll': self.unroll,
         'time_major': self.time_major
     }
-    if self._num_constants is not None:
+    if self._num_constants:
       config['num_constants'] = self._num_constants
     if self.zero_output_for_mask:
       config['zero_output_for_mask'] = self.zero_output_for_mask
@@ -870,7 +870,7 @@ class RNN(Layer):
   def from_config(cls, config, custom_objects=None):
     from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
     cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects)
-    num_constants = config.pop('num_constants', None)
+    num_constants = config.pop('num_constants', 0)
     layer = cls(cell, **config)
     layer._num_constants = num_constants
     return layer
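
Together, the `get_config`/`from_config` hunks keep serialization round-trippable: `0` is falsy, so the key is omitted from the config exactly as it was under `None`, and popping with a default of `0` restores the sentinel on load, including for configs written before this change. A standalone sketch of the round trip (plain functions mirroring the diff, not the Keras layer API):

def get_config(num_constants):
  config = {}
  if num_constants:  # 0 is falsy, so the key is omitted, as None was
    config['num_constants'] = num_constants
  return config

def from_config(config):
  return config.pop('num_constants', 0)  # absent key now defaults to 0

assert from_config(get_config(0)) == 0  # new-style "no constants"
assert from_config(get_config(2)) == 2  # real constants still round-trip
assert from_config({}) == 0             # configs from the None era load as 0
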
@@ -2698,7 +2698,7 @@ def _standardize_args(inputs, initial_state, constants, num_constants):
     # could be a list of items, or list of list if the initial_state is complex
     # structure, and finally followed by constants which is a flat list.
     assert initial_state is None and constants is None
-    if num_constants is not None:
+    if num_constants:
       constants = inputs[-num_constants:]
       inputs = inputs[:-num_constants]
     if len(inputs) > 1:
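
In `_standardize_args`, the truthiness check also guards the tail slicing: peeling constants off the end of the flat input list only makes sense for a positive count, since a zero count would empty the list. A sketch with hypothetical values:

inputs = ['x', 'init_state', 'c1', 'c2']  # hypothetical flat input list
num_constants = 2

constants = inputs[-num_constants:]  # ['c1', 'c2']
inputs = inputs[:-num_constants]     # ['x', 'init_state']
# With num_constants == 0 the guard skips this block; otherwise
# inputs[:-0] == inputs[:0] would wipe the list.
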

tensorflow/python/keras/layers/wrappers.py

@@ -452,7 +452,7 @@ class Bidirectional(Wrapper):
     self.return_state = layer.return_state
     self.supports_masking = True
     self._trainable = True
-    self._num_constants = None
+    self._num_constants = 0
     # We don't want to track `layer` since we're already tracking the two copies
     # of it we actually run.
     self._setattr_tracking = False
@@ -619,11 +619,10 @@ class Bidirectional(Wrapper):
       # forward and backward section, and be feed to layers accordingly.
       forward_inputs = [inputs[0]]
       backward_inputs = [inputs[0]]
-      pivot = (len(inputs) -
-               (self._num_constants if self._num_constants else 0)) // 2 + 1
+      pivot = (len(inputs) - self._num_constants) // 2 + 1
       # add forward initial state
       forward_inputs += inputs[1:pivot]
-      if self._num_constants is None:
+      if not self._num_constants:
         # add backward initial state
         backward_inputs += inputs[pivot:]
       else:
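
The `pivot` simplification is where the numeric sentinel pays off most directly: `0` joins the arithmetic, so the old `x if x else 0` dance disappears. A sketch assuming the flat layout `[inputs, forward states..., backward states..., constants...]`:

# 1 input + 2 forward states + 2 backward states + 2 constants (hypothetical)
inputs = ['x', 'f1', 'f2', 'b1', 'b2', 'c1', 'c2']
num_constants = 2

pivot = (len(inputs) - num_constants) // 2 + 1  # (7 - 2) // 2 + 1 == 3
print(inputs[1:pivot])  # forward states: ['f1', 'f2']

# With no constants the same formula still splits the states evenly:
inputs = ['x', 'f1', 'f2', 'b1', 'b2']
pivot = (len(inputs) - 0) // 2 + 1  # == 3
print(inputs[1:pivot], inputs[pivot:])  # ['f1', 'f2'] ['b1', 'b2']
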
@@ -722,7 +721,7 @@ class Bidirectional(Wrapper):
 
   def get_config(self):
     config = {'merge_mode': self.merge_mode}
-    if self._num_constants is not None:
+    if self._num_constants:
       config['num_constants'] = self._num_constants
 
     if hasattr(self, '_backward_layer_config'):
@@ -737,7 +736,7 @@ class Bidirectional(Wrapper):
   def from_config(cls, config, custom_objects=None):
     # Instead of updating the input, create a copy and use that.
     config = config.copy()
-    num_constants = config.pop('num_constants', None)
+    num_constants = config.pop('num_constants', 0)
     backward_layer_config = config.pop('backward_layer', None)
     if backward_layer_config is not None:
       from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top