Cleanup some duplicated methods for UnifiedLSTM.

The methods in the parent class should work the same way.

PiperOrigin-RevId: 225094141
This commit is contained in:
Scott Zhu 2018-12-11 16:40:04 -08:00 committed by TensorFlower Gardener
parent fc220a61b7
commit 9b8005ece0

View File

@ -2530,6 +2530,7 @@ class LSTM(RNN):
config['implementation'] = 1 config['implementation'] = 1
return cls(**config) return cls(**config)
@tf_export('keras.layers.LSTM', v1=[]) @tf_export('keras.layers.LSTM', v1=[])
class UnifiedLSTM(LSTM): class UnifiedLSTM(LSTM):
"""Long Short-Term Memory layer - Hochreiter 1997. """Long Short-Term Memory layer - Hochreiter 1997.
@ -2655,8 +2656,6 @@ class UnifiedLSTM(LSTM):
self.state_spec = [ self.state_spec = [
InputSpec(shape=(None, dim)) for dim in (self.units, self.units) InputSpec(shape=(None, dim)) for dim in (self.units, self.units)
] ]
self._num_constants = None
self._num_inputs = None
self._dropout_mask = None self._dropout_mask = None
self.could_use_cudnn = ( self.could_use_cudnn = (
activation == 'tanh' and recurrent_activation == 'sigmoid' and activation == 'tanh' and recurrent_activation == 'sigmoid' and
@ -2775,46 +2774,6 @@ class UnifiedLSTM(LSTM):
else: else:
return output return output
@property
def trainable_weights(self):
if self.trainable:
weights = []
weights += self.cell.trainable_weights
return weights
return []
@property
def non_trainable_weights(self):
if not self.trainable:
weights = []
weights += self.cell.non_trainable_weights
return weights
return []
@property
def losses(self):
losses = []
losses += self.cell.losses
return losses + self._losses
@property
def updates(self):
updates = []
updates += self.cell.updates
return updates + self._updates
def get_weights(self):
weights = []
weights += self.cell.weights
return K.batch_get_value(weights)
def set_weights(self, weights):
tuples = []
cell_weights = weights[:len(self.cell.weights)]
if cell_weights:
tuples.append((self.cell.weights, cell_weights))
K.batch_set_value(tuples)
def _canonical_to_params(weights, biases, shape, transpose_weights=False): def _canonical_to_params(weights, biases, shape, transpose_weights=False):
"""Utility function convert variable to CuDNN compatible parameter. """Utility function convert variable to CuDNN compatible parameter.