From c47ec542111140fffd8420023bab95ae21321a12 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Tue, 28 May 2019 10:40:18 -0700 Subject: [PATCH] Internal cleanup. PiperOrigin-RevId: 250315112 --- tensorflow/python/keras/layers/cudnn_recurrent.py | 2 +- tensorflow/python/keras/layers/recurrent.py | 10 +++++----- tensorflow/python/keras/layers/wrappers.py | 11 +++++------ 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/keras/layers/cudnn_recurrent.py b/tensorflow/python/keras/layers/cudnn_recurrent.py index 37650752861..68ac8b7b277 100644 --- a/tensorflow/python/keras/layers/cudnn_recurrent.py +++ b/tensorflow/python/keras/layers/cudnn_recurrent.py @@ -77,7 +77,7 @@ class _CuDNNRNN(RNN): self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size] self.constants_spec = None self._states = None - self._num_constants = None + self._num_constants = 0 self._vector_shape = constant_op.constant([-1]) def call(self, inputs, mask=None, training=None, initial_state=None): diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 24a499f429a..ef04ac7a549 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -406,7 +406,7 @@ class RNN(Layer): self.state_spec = None self._states = None self.constants_spec = None - self._num_constants = None + self._num_constants = 0 @property def states(self): @@ -769,7 +769,7 @@ class RNN(Layer): and not isinstance(inputs, tuple)): # get initial_state from full input spec # as they could be copied to multiple GPU. - if self._num_constants is None: + if not self._num_constants: initial_state = inputs[1:] else: initial_state = inputs[1:-self._num_constants] @@ -853,7 +853,7 @@ class RNN(Layer): 'unroll': self.unroll, 'time_major': self.time_major } - if self._num_constants is not None: + if self._num_constants: config['num_constants'] = self._num_constants if self.zero_output_for_mask: config['zero_output_for_mask'] = self.zero_output_for_mask @@ -870,7 +870,7 @@ class RNN(Layer): def from_config(cls, config, custom_objects=None): from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects) - num_constants = config.pop('num_constants', None) + num_constants = config.pop('num_constants', 0) layer = cls(cell, **config) layer._num_constants = num_constants return layer @@ -2698,7 +2698,7 @@ def _standardize_args(inputs, initial_state, constants, num_constants): # could be a list of items, or list of list if the initial_state is complex # structure, and finally followed by constants which is a flat list. assert initial_state is None and constants is None - if num_constants is not None: + if num_constants: constants = inputs[-num_constants:] inputs = inputs[:-num_constants] if len(inputs) > 1: diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py index 76e2967b477..1c249aa87a0 100644 --- a/tensorflow/python/keras/layers/wrappers.py +++ b/tensorflow/python/keras/layers/wrappers.py @@ -452,7 +452,7 @@ class Bidirectional(Wrapper): self.return_state = layer.return_state self.supports_masking = True self._trainable = True - self._num_constants = None + self._num_constants = 0 # We don't want to track `layer` since we're already tracking the two copies # of it we actually run. self._setattr_tracking = False @@ -619,11 +619,10 @@ class Bidirectional(Wrapper): # forward and backward section, and be feed to layers accordingly. forward_inputs = [inputs[0]] backward_inputs = [inputs[0]] - pivot = (len(inputs) - - (self._num_constants if self._num_constants else 0)) // 2 + 1 + pivot = (len(inputs) - self._num_constants) // 2 + 1 # add forward initial state forward_inputs += inputs[1:pivot] - if self._num_constants is None: + if not self._num_constants: # add backward initial state backward_inputs += inputs[pivot:] else: @@ -722,7 +721,7 @@ class Bidirectional(Wrapper): def get_config(self): config = {'merge_mode': self.merge_mode} - if self._num_constants is not None: + if self._num_constants: config['num_constants'] = self._num_constants if hasattr(self, '_backward_layer_config'): @@ -737,7 +736,7 @@ class Bidirectional(Wrapper): def from_config(cls, config, custom_objects=None): # Instead of updating the input, create a copy and use that. config = config.copy() - num_constants = config.pop('num_constants', None) + num_constants = config.pop('num_constants', 0) backward_layer_config = config.pop('backward_layer', None) if backward_layer_config is not None: from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top