Cleanup some duplicated methods for UnifiedLSTM.
The methods in the parent class should work the same way. PiperOrigin-RevId: 225094141
This commit is contained in:
parent
fc220a61b7
commit
9b8005ece0
@ -2530,6 +2530,7 @@ class LSTM(RNN):
|
||||
config['implementation'] = 1
|
||||
return cls(**config)
|
||||
|
||||
|
||||
@tf_export('keras.layers.LSTM', v1=[])
|
||||
class UnifiedLSTM(LSTM):
|
||||
"""Long Short-Term Memory layer - Hochreiter 1997.
|
||||
@ -2655,8 +2656,6 @@ class UnifiedLSTM(LSTM):
|
||||
self.state_spec = [
|
||||
InputSpec(shape=(None, dim)) for dim in (self.units, self.units)
|
||||
]
|
||||
self._num_constants = None
|
||||
self._num_inputs = None
|
||||
self._dropout_mask = None
|
||||
self.could_use_cudnn = (
|
||||
activation == 'tanh' and recurrent_activation == 'sigmoid' and
|
||||
@ -2775,46 +2774,6 @@ class UnifiedLSTM(LSTM):
|
||||
else:
|
||||
return output
|
||||
|
||||
@property
|
||||
def trainable_weights(self):
|
||||
if self.trainable:
|
||||
weights = []
|
||||
weights += self.cell.trainable_weights
|
||||
return weights
|
||||
return []
|
||||
|
||||
@property
|
||||
def non_trainable_weights(self):
|
||||
if not self.trainable:
|
||||
weights = []
|
||||
weights += self.cell.non_trainable_weights
|
||||
return weights
|
||||
return []
|
||||
|
||||
@property
|
||||
def losses(self):
|
||||
losses = []
|
||||
losses += self.cell.losses
|
||||
return losses + self._losses
|
||||
|
||||
@property
|
||||
def updates(self):
|
||||
updates = []
|
||||
updates += self.cell.updates
|
||||
return updates + self._updates
|
||||
|
||||
def get_weights(self):
|
||||
weights = []
|
||||
weights += self.cell.weights
|
||||
return K.batch_get_value(weights)
|
||||
|
||||
def set_weights(self, weights):
|
||||
tuples = []
|
||||
cell_weights = weights[:len(self.cell.weights)]
|
||||
if cell_weights:
|
||||
tuples.append((self.cell.weights, cell_weights))
|
||||
K.batch_set_value(tuples)
|
||||
|
||||
|
||||
def _canonical_to_params(weights, biases, shape, transpose_weights=False):
|
||||
"""Utility function convert variable to CuDNN compatible parameter.
|
||||
|
Loading…
Reference in New Issue
Block a user