Cleanup some duplicated methods for UnifiedLSTM.
The methods in the parent class should work the same way. PiperOrigin-RevId: 225094141
This commit is contained in:
parent
fc220a61b7
commit
9b8005ece0
@ -2530,6 +2530,7 @@ class LSTM(RNN):
|
|||||||
config['implementation'] = 1
|
config['implementation'] = 1
|
||||||
return cls(**config)
|
return cls(**config)
|
||||||
|
|
||||||
|
|
||||||
@tf_export('keras.layers.LSTM', v1=[])
|
@tf_export('keras.layers.LSTM', v1=[])
|
||||||
class UnifiedLSTM(LSTM):
|
class UnifiedLSTM(LSTM):
|
||||||
"""Long Short-Term Memory layer - Hochreiter 1997.
|
"""Long Short-Term Memory layer - Hochreiter 1997.
|
||||||
@ -2655,8 +2656,6 @@ class UnifiedLSTM(LSTM):
|
|||||||
self.state_spec = [
|
self.state_spec = [
|
||||||
InputSpec(shape=(None, dim)) for dim in (self.units, self.units)
|
InputSpec(shape=(None, dim)) for dim in (self.units, self.units)
|
||||||
]
|
]
|
||||||
self._num_constants = None
|
|
||||||
self._num_inputs = None
|
|
||||||
self._dropout_mask = None
|
self._dropout_mask = None
|
||||||
self.could_use_cudnn = (
|
self.could_use_cudnn = (
|
||||||
activation == 'tanh' and recurrent_activation == 'sigmoid' and
|
activation == 'tanh' and recurrent_activation == 'sigmoid' and
|
||||||
@ -2775,46 +2774,6 @@ class UnifiedLSTM(LSTM):
|
|||||||
else:
|
else:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@property
|
|
||||||
def trainable_weights(self):
|
|
||||||
if self.trainable:
|
|
||||||
weights = []
|
|
||||||
weights += self.cell.trainable_weights
|
|
||||||
return weights
|
|
||||||
return []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def non_trainable_weights(self):
|
|
||||||
if not self.trainable:
|
|
||||||
weights = []
|
|
||||||
weights += self.cell.non_trainable_weights
|
|
||||||
return weights
|
|
||||||
return []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def losses(self):
|
|
||||||
losses = []
|
|
||||||
losses += self.cell.losses
|
|
||||||
return losses + self._losses
|
|
||||||
|
|
||||||
@property
|
|
||||||
def updates(self):
|
|
||||||
updates = []
|
|
||||||
updates += self.cell.updates
|
|
||||||
return updates + self._updates
|
|
||||||
|
|
||||||
def get_weights(self):
|
|
||||||
weights = []
|
|
||||||
weights += self.cell.weights
|
|
||||||
return K.batch_get_value(weights)
|
|
||||||
|
|
||||||
def set_weights(self, weights):
|
|
||||||
tuples = []
|
|
||||||
cell_weights = weights[:len(self.cell.weights)]
|
|
||||||
if cell_weights:
|
|
||||||
tuples.append((self.cell.weights, cell_weights))
|
|
||||||
K.batch_set_value(tuples)
|
|
||||||
|
|
||||||
|
|
||||||
def _canonical_to_params(weights, biases, shape, transpose_weights=False):
|
def _canonical_to_params(weights, biases, shape, transpose_weights=False):
|
||||||
"""Utility function convert variable to CuDNN compatible parameter.
|
"""Utility function convert variable to CuDNN compatible parameter.
|
||||||
|
Loading…
Reference in New Issue
Block a user