diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 4db48b45edd..6a762ee5d25 100644 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -498,6 +498,18 @@ py_test( ], ) +py_test( + name = "recurrent_test", + size = "small", + srcs = ["_impl/keras/layers/recurrent_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + py_test( name = "serialization_test", size = "small", diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/topology.py index f9be782f85e..2bcbabf19ce 100644 --- a/tensorflow/python/keras/_impl/keras/engine/topology.py +++ b/tensorflow/python/keras/_impl/keras/engine/topology.py @@ -29,6 +29,9 @@ from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.python.eager import context from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras import constraints +from tensorflow.python.keras._impl.keras import initializers +from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.utils import conv_utils from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary as print_layer_summary @@ -209,9 +212,9 @@ class Layer(tf_base_layers.Layer): dtype = K.floatx() weight = self.add_variable(name, shape, dtype=dtype, - initializer=initializer, - regularizer=regularizer, - constraint=constraint, + initializer=initializers.get(initializer), + regularizer=regularizers.get(regularizer), + constraint=constraints.get(constraint), trainable=trainable) return weight diff --git a/tensorflow/python/keras/_impl/keras/integration_test.py b/tensorflow/python/keras/_impl/keras/integration_test.py index 71100368480..871a8c73298 100644 --- a/tensorflow/python/keras/_impl/keras/integration_test.py +++ b/tensorflow/python/keras/_impl/keras/integration_test.py @@ -93,7 +93,7 @@ class KerasIntegrationTest(test.TestCase): y_test = keras.utils.to_categorical(y_test) model = keras.models.Sequential() - model.add(keras.layers.LSTM(3, return_sequences=True, + model.add(keras.layers.LSTM(5, return_sequences=True, input_shape=x_train.shape[1:])) model.add(keras.layers.GRU(y_train.shape[-1], activation='softmax')) model.compile(loss='categorical_crossentropy', diff --git a/tensorflow/python/keras/_impl/keras/layers/gru_test.py b/tensorflow/python/keras/_impl/keras/layers/gru_test.py index 03f0736161e..c57fbac41cc 100644 --- a/tensorflow/python/keras/_impl/keras/layers/gru_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/gru_test.py @@ -156,8 +156,10 @@ class GRULayerTest(test.TestCase): activity_regularizer='l1') layer.build((None, None, 2)) self.assertEqual(len(layer.losses), 3) - layer(keras.backend.variable(np.ones((2, 3, 2)))) - self.assertEqual(len(layer.losses), 4) + + x = keras.backend.variable(np.ones((2, 3, 2))) + layer(x) + self.assertEqual(len(layer.get_losses_for(x)), 1) def test_constraints_GRU(self): embedding_dim = 4 @@ -175,9 +177,9 @@ class GRULayerTest(test.TestCase): recurrent_constraint=r_constraint, bias_constraint=b_constraint) layer.build((None, None, embedding_dim)) - self.assertEqual(layer.kernel.constraint, k_constraint) - self.assertEqual(layer.recurrent_kernel.constraint, r_constraint) - 
self.assertEqual(layer.bias.constraint, b_constraint) + self.assertEqual(layer.cell.kernel.constraint, k_constraint) + self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint) + self.assertEqual(layer.cell.bias.constraint, b_constraint) def test_with_masking_layer_GRU(self): layer_class = keras.layers.GRU diff --git a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py index f43d90fec8f..8d359bf17cd 100644 --- a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py @@ -156,8 +156,9 @@ class LSTMLayerTest(test.TestCase): activity_regularizer='l1') layer.build((None, None, 2)) self.assertEqual(len(layer.losses), 3) - layer(keras.backend.variable(np.ones((2, 3, 2)))) - self.assertEqual(len(layer.losses), 4) + x = keras.backend.variable(np.ones((2, 3, 2))) + layer(x) + self.assertEqual(len(layer.get_losses_for(x)), 1) def test_constraints_LSTM(self): embedding_dim = 4 @@ -175,9 +176,9 @@ class LSTMLayerTest(test.TestCase): recurrent_constraint=r_constraint, bias_constraint=b_constraint) layer.build((None, None, embedding_dim)) - self.assertEqual(layer.kernel.constraint, k_constraint) - self.assertEqual(layer.recurrent_kernel.constraint, r_constraint) - self.assertEqual(layer.bias.constraint, b_constraint) + self.assertEqual(layer.cell.kernel.constraint, k_constraint) + self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint) + self.assertEqual(layer.cell.bias.constraint, b_constraint) def test_with_masking_layer_LSTM(self): layer_class = keras.layers.LSTM diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py index 139523403c1..2bc74d5f807 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,93 +29,2157 @@ from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg +from tensorflow.python.platform import tf_logging as logging -# pylint: disable=access-member-before-definition +class StackedRNNCells(Layer): + """Wrapper allowing a stack of RNN cells to behave as a single cell. - -def _time_distributed_dense(x, - w, - b=None, - dropout=None, - input_dim=None, - output_dim=None, - timesteps=None, - training=None): - """Apply `y . w + b` for every temporal slice y of x. + Used to implement efficient stacked RNNs. Arguments: - x: input tensor. - w: weight matrix. - b: optional bias vector. - dropout: whether to apply dropout (same dropout mask - for every temporal slice of the input). - input_dim: integer; optional dimensionality of the input. - output_dim: integer; optional dimensionality of the output. - timesteps: integer; optional number of timesteps. - training: training phase tensor or boolean. + cells: List of RNN cell instances. - Returns: - Output tensor. 
+ Examples: + + ```python + cells = [ + keras.layers.LSTMCell(output_dim), + keras.layers.LSTMCell(output_dim), + keras.layers.LSTMCell(output_dim), + ] + + inputs = keras.Input((timesteps, input_dim)) + x = keras.layers.RNN(cells)(inputs) + ``` """ - if not input_dim: - input_dim = K.shape(x)[2] - if not timesteps: - timesteps = K.shape(x)[1] - if not output_dim: - output_dim = K.shape(w)[1] - if dropout is not None and 0. < dropout < 1.: - # apply the same dropout pattern at every timestep - ones = K.ones_like(K.reshape(x[:, 0, :], (-1, input_dim))) - dropout_matrix = K.dropout(ones, dropout) - expanded_dropout_matrix = K.repeat(dropout_matrix, timesteps) - x = K.in_train_phase(x * expanded_dropout_matrix, x, training=training) + def __init__(self, cells, **kwargs): + for cell in cells: + if not hasattr(cell, 'call'): + raise ValueError('All cells must have a `call` method. ' + 'received cells:', cells) + if not hasattr(cell, 'state_size'): + raise ValueError('All cells must have a ' + '`state_size` attribute. ' + 'received cells:', cells) + self.cells = cells + super(StackedRNNCells, self).__init__(**kwargs) - # collapse time dimension and batch dimension together - x = K.reshape(x, (-1, input_dim)) - x = K.dot(x, w) - if b is not None: - x = K.bias_add(x, b) - # reshape to 3D tensor - if K.backend() == 'tensorflow': - x = K.reshape(x, K.stack([-1, timesteps, output_dim])) - x.set_shape([None, None, output_dim]) - else: - x = K.reshape(x, (-1, timesteps, output_dim)) - return x + @property + def state_size(self): + # States are a flat list + # in reverse order of the cell stack. + # This allows to preserve the requirement + # `stack.state_size[0] == output_dim`. + # e.g. states of a 2-layer LSTM would be + # `[h2, c2, h1, c1]` + # (assuming one LSTM has states [h, c]) + state_size = [] + for cell in self.cells[::-1]: + if hasattr(cell.state_size, '__len__'): + state_size += list(cell.state_size) + else: + state_size.append(cell.state_size) + return tuple(state_size) + + def call(self, inputs, states, **kwargs): + # Recover per-cell states. + nested_states = [] + for cell in self.cells[::-1]: + if hasattr(cell.state_size, '__len__'): + nested_states.append(states[:len(cell.state_size)]) + states = states[len(cell.state_size):] + else: + nested_states.append([states[0]]) + states = states[1:] + nested_states = nested_states[::-1] + + # Call the cells in order and store the returned states. + new_nested_states = [] + for cell, states in zip(self.cells, nested_states): + inputs, states = cell.call(inputs, states, **kwargs) + new_nested_states.append(states) + + # Format the new states as a flat list + # in reverse cell order. 
+    states = []
+    for cell_states in new_nested_states[::-1]:
+      states += cell_states
+    return inputs, states
+
+  def build(self, input_shape):
+    for cell in self.cells:
+      if isinstance(cell, Layer):
+        cell.build(input_shape)
+      if hasattr(cell.state_size, '__len__'):
+        output_dim = cell.state_size[0]
+      else:
+        output_dim = cell.state_size
+      input_shape = (input_shape[0], input_shape[1], output_dim)
+    self.built = True
+
+  def get_config(self):
+    cells = []
+    for cell in self.cells:
+      cells.append({
+          'class_name': cell.__class__.__name__,
+          'config': cell.get_config()
+      })
+    config = {'cells': cells}
+    base_config = super(StackedRNNCells, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+  @classmethod
+  def from_config(cls, config, custom_objects=None):
+    from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
+    cells = []
+    for cell_config in config.pop('cells'):
+      cells.append(
+          deserialize_layer(cell_config, custom_objects=custom_objects))
+    return cls(cells, **config)
+
+  @property
+  def trainable_weights(self):
+    if not self.trainable:
+      return []
+    weights = []
+    for cell in self.cells:
+      if isinstance(cell, Layer):
+        weights += cell.trainable_weights
+    return weights
+
+  @property
+  def non_trainable_weights(self):
+    weights = []
+    for cell in self.cells:
+      if isinstance(cell, Layer):
+        weights += cell.non_trainable_weights
+    if not self.trainable:
+      trainable_weights = []
+      for cell in self.cells:
+        if isinstance(cell, Layer):
+          trainable_weights += cell.trainable_weights
+      return trainable_weights + weights
+    return weights
+
+  def get_weights(self):
+    """Retrieves the weights of the stacked cells.
+
+    Returns:
+      A flat list of Numpy arrays.
+    """
+    weights = []
+    for cell in self.cells:
+      if isinstance(cell, Layer):
+        weights += cell.weights
+    return K.batch_get_value(weights)
+
+  def set_weights(self, weights):
+    """Sets the weights of the stacked cells.
+
+    Arguments:
+      weights: A list of Numpy arrays with shapes and types matching
+        the output of `get_weights()`.
+    """
+    tuples = []
+    for cell in self.cells:
+      if isinstance(cell, Layer):
+        num_param = len(cell.weights)
+        # Slice off this cell's weights without clobbering the remainder,
+        # which still has to be distributed to the cells that follow.
+        cell_weights = weights[:num_param]
+        for sw, w in zip(cell.weights, cell_weights):
+          tuples.append((sw, w))
+        weights = weights[num_param:]
+    K.batch_set_value(tuples)
+
+  @property
+  def losses(self):
+    losses = []
+    for cell in self.cells:
+      if isinstance(cell, Layer):
+        cell_losses = cell.losses
+        losses += cell_losses
+    return losses
+
+  def get_losses_for(self, inputs=None):
+    losses = []
+    for cell in self.cells:
+      if isinstance(cell, Layer):
+        cell_losses = cell.get_losses_for(inputs)
+        losses += cell_losses
+    return losses
+
+
+class RNN(Layer):
+  """Base class for recurrent layers.
+
+  Arguments:
+    cell: An RNN cell instance. An RNN cell is a class that has:
+      - a `call(input_at_t, states_at_t)` method, returning
+        `(output_at_t, states_at_t_plus_1)`. The call method of the
+        cell can also take the optional argument `constants`, see
+        section "Note on passing external constants" below.
+      - a `state_size` attribute. This can be a single integer
+        (single state) in which case it is
+        the size of the recurrent state
+        (which should be the same as the size of the cell output).
+        This can also be a list/tuple of integers
+        (one size per state). In this case, the first entry
+        (`state_size[0]`) should be the same as
+        the size of the cell output.
+      It is also possible for `cell` to be a list of RNN cell instances,
+      in which case the cells get stacked one after the other in the RNN,
+      implementing an efficient stacked RNN.
+    return_sequences: Boolean. Whether to return the last output
+      in the output sequence, or the full sequence.
+    return_state: Boolean. Whether to return the last state
+      in addition to the output.
+    go_backwards: Boolean (default False).
+      If True, process the input sequence backwards and return the
+      reversed sequence.
+    stateful: Boolean (default False). If True, the last state
+      for each sample at index i in a batch will be used as initial
+      state for the sample of index i in the following batch.
+    unroll: Boolean (default False).
+      If True, the network will be unrolled,
+      else a symbolic loop will be used.
+      Unrolling can speed up an RNN,
+      although it tends to be more memory-intensive.
+      Unrolling is only suitable for short sequences.
+    input_dim: dimensionality of the input (integer).
+      This argument (or alternatively,
+      the keyword argument `input_shape`)
+      is required when using this layer as the first layer in a model.
+    input_length: Length of input sequences, to be specified
+      when it is constant.
+      This argument is required if you are going to connect
+      `Flatten` then `Dense` layers upstream
+      (without it, the shape of the dense outputs cannot be computed).
+      Note that if the recurrent layer is not the first layer
+      in your model, you would need to specify the input length
+      at the level of the first layer
+      (e.g. via the `input_shape` argument).
+
+  Input shape:
+    3D tensor with shape `(batch_size, timesteps, input_dim)`,
+    (Optional) 2D tensors with shape `(batch_size, output_dim)`.
+
+  Output shape:
+    - if `return_state`: a list of tensors. The first tensor is
+      the output. The remaining tensors are the last states,
+      each with shape `(batch_size, units)`.
+    - if `return_sequences`: 3D tensor with shape
+      `(batch_size, timesteps, units)`.
+    - else, 2D tensor with shape `(batch_size, units)`.
+
+  # Masking
+    This layer supports masking for input data with a variable number
+    of timesteps. To introduce masks to your data,
+    use an [Embedding](embeddings.md) layer with the `mask_zero` parameter
+    set to `True`.
+
+  # Note on using statefulness in RNNs
+    You can set RNN layers to be 'stateful', which means that the states
+    computed for the samples in one batch will be reused as initial states
+    for the samples in the next batch. This assumes a one-to-one mapping
+    between samples in different successive batches.
+
+    To enable statefulness:
+      - specify `stateful=True` in the layer constructor.
+      - specify a fixed batch size for your model, by passing
+        if sequential model:
+          `batch_input_shape=(...)` to the first layer in your model.
+        else for functional model with 1 or more Input layers:
+          `batch_shape=(...)` to all the first layers in your model.
+        This is the expected shape of your inputs
+        *including the batch size*.
+        It should be a tuple of integers, e.g. `(32, 10, 100)`.
+      - specify `shuffle=False` when calling fit().
+
+    To reset the states of your model, call `.reset_states()` on either
+    a specific layer, or on your entire model.
+
+  # Note on specifying the initial state of RNNs
+    You can specify the initial state of RNN layers symbolically by
+    calling them with the keyword argument `initial_state`. The value of
+    `initial_state` should be a tensor or list of tensors representing
+    the initial state of the RNN layer.
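+
+    For example, here is a minimal sketch of the symbolic form (the
+    encoder/decoder names and sizes are illustrative only):
+
+    ```python
+    encoder_inputs = keras.Input((None, 8))
+    _, state_h, state_c = keras.layers.LSTM(32, return_state=True)(
+        encoder_inputs)
+
+    decoder_inputs = keras.Input((None, 8))
+    y = keras.layers.LSTM(32)(decoder_inputs,
+                              initial_state=[state_h, state_c])
+    ```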
+ + You can specify the initial state of RNN layers numerically by + calling `reset_states` with the keyword argument `states`. The value of + `states` should be a numpy array or list of numpy arrays representing + the initial state of the RNN layer. + + # Note on passing external constants to RNNs + You can pass "external" constants to the cell using the `constants` + keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This + requires that the `cell.call` method accepts the same keyword argument + `constants`. Such constants can be used to condition the cell + transformation on additional static inputs (not changing over time), + a.k.a. an attention mechanism. + + Examples: + + ```python + # First, let's define a RNN Cell, as a layer subclass. + + class MinimalRNNCell(keras.layers.Layer): + + def __init__(self, units, **kwargs): + self.units = units + self.state_size = units + super(MinimalRNNCell, self).__init__(**kwargs) + + def build(self, input_shape): + self.kernel = self.add_weight(shape=(input_shape[-1], self.units), + initializer='uniform', + name='kernel') + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + initializer='uniform', + name='recurrent_kernel') + self.built = True + + def call(self, inputs, states): + prev_output = states[0] + h = K.dot(inputs, self.kernel) + output = h + K.dot(prev_output, self.recurrent_kernel) + return output, [output] + + # Let's use this cell in a RNN layer: + + cell = MinimalRNNCell(32) + x = keras.Input((None, 5)) + layer = RNN(cell) + y = layer(x) + + # Here's how to use the cell to build a stacked RNN: + + cells = [MinimalRNNCell(32), MinimalRNNCell(64)] + x = keras.Input((None, 5)) + layer = RNN(cells) + y = layer(x) + ``` + """ + + def __init__(self, + cell, + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + unroll=False, + activity_regularizer=None, + **kwargs): + if isinstance(cell, (list, tuple)): + cell = StackedRNNCells(cell) + if not hasattr(cell, 'call'): + raise ValueError('`cell` should have a `call` method. 
' + 'The RNN was passed:', cell) + if not hasattr(cell, 'state_size'): + raise ValueError('The RNN cell should have ' + 'an attribute `state_size` ' + '(tuple of integers, ' + 'one integer per RNN state).') + super(RNN, self).__init__( + activity_regularizer=regularizers.get(activity_regularizer), **kwargs) + self.cell = cell + self.return_sequences = return_sequences + self.return_state = return_state + self.go_backwards = go_backwards + self.stateful = stateful + self.unroll = unroll + + self.supports_masking = True + self.input_spec = [InputSpec(ndim=3)] + self.state_spec = None + self._states = None + self.constants_spec = None + self._num_constants = None + + @property + def states(self): + if self._states is None: + if isinstance(self.cell.state_size, int): + num_states = 1 + else: + num_states = len(self.cell.state_size) + return [None for _ in range(num_states)] + return self._states + + @states.setter + def states(self, states): + self._states = states + + def _compute_output_shape(self, input_shape): + if isinstance(input_shape, list): + input_shape = input_shape[0] + input_shape = tensor_shape.TensorShape(input_shape).as_list() + + if hasattr(self.cell.state_size, '__len__'): + output_dim = self.cell.state_size[0] + else: + output_dim = self.cell.state_size + + if self.return_sequences: + output_shape = (input_shape[0], input_shape[1], output_dim) + else: + output_shape = (input_shape[0], output_dim) + + if self.return_state: + state_shape = [(input_shape[0], output_dim) for _ in self.states] + output_shape = [output_shape] + state_shape + else: + output_shape = output_shape + return tensor_shape.TensorShape(output_shape) + + def compute_mask(self, inputs, mask): + if isinstance(mask, list): + mask = mask[0] + output_mask = mask if self.return_sequences else None + if self.return_state: + state_mask = [None for _ in self.states] + return [output_mask] + state_mask + else: + return output_mask + + def build(self, input_shape): + # Note input_shape will be list of shapes of initial states and + # constants if these are passed in __call__. + if self._num_constants is not None: + constants_shape = input_shape[-self._num_constants:] # pylint: disable=invalid-unary-operand-type + else: + constants_shape = None + + if isinstance(input_shape, list): + input_shape = input_shape[0] + input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list()) + + batch_size = input_shape[0] if self.stateful else None + input_dim = input_shape[-1] + self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim)) + + # allow cell (if layer) to build before we set or validate state_spec + if isinstance(self.cell, Layer): + step_input_shape = (input_shape[0],) + input_shape[2:] + if constants_shape is not None: + self.cell.build([step_input_shape] + constants_shape) + else: + self.cell.build(step_input_shape) + + # set or validate state_spec + if hasattr(self.cell.state_size, '__len__'): + state_size = list(self.cell.state_size) + else: + state_size = [self.cell.state_size] + + if self.state_spec is not None: + # initial_state was passed in call, check compatibility + if [spec.shape[-1] for spec in self.state_spec] != state_size: + raise ValueError( + 'An initial_state was passed that is not compatible with ' + '`cell.state_size`. 
Received `state_spec`={}; ' + 'However `cell.state_size` is ' + '{}'.format(self.state_spec, self.cell.state_size)) + else: + self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size] + if self.stateful: + self.reset_states() + + def get_initial_state(self, inputs): + # build an all-zero tensor of shape (samples, output_dim) + initial_state = K.zeros_like(inputs) # (samples, timesteps, input_dim) + initial_state = K.sum(initial_state, axis=(1, 2)) # (samples,) + initial_state = K.expand_dims(initial_state) # (samples, 1) + if hasattr(self.cell.state_size, '__len__'): + return [K.tile(initial_state, [1, dim]) for dim in self.cell.state_size] + else: + return [K.tile(initial_state, [1, self.cell.state_size])] + + def __call__(self, inputs, initial_state=None, constants=None, **kwargs): + inputs, initial_state, constants = self._standardize_args( + inputs, initial_state, constants) + + if initial_state is None and constants is None: + return super(RNN, self).__call__(inputs, **kwargs) + + # If any of `initial_state` or `constants` are specified and are Keras + # tensors, then add them to the inputs and temporarily modify the + # input_spec to include them. + + additional_inputs = [] + additional_specs = [] + if initial_state is not None: + kwargs['initial_state'] = initial_state + additional_inputs += initial_state + self.state_spec = [ + InputSpec(shape=K.int_shape(state)) for state in initial_state + ] + additional_specs += self.state_spec + if constants is not None: + kwargs['constants'] = constants + additional_inputs += constants + self.constants_spec = [ + InputSpec(shape=K.int_shape(constant)) for constant in constants + ] + self._num_constants = len(constants) + additional_specs += self.constants_spec + # at this point additional_inputs cannot be empty + is_keras_tensor = hasattr(additional_inputs[0], '_keras_history') + for tensor in additional_inputs: + if hasattr(tensor, '_keras_history') != is_keras_tensor: + raise ValueError('The initial state or constants of an RNN' + ' layer cannot be specified with a mix of' + ' Keras tensors and non-Keras tensors') + + if is_keras_tensor: + # Compute the full input spec, including state and constants + full_input = [inputs] + additional_inputs + full_input_spec = self.input_spec + additional_specs + # Perform the call with temporarily replaced input_spec + original_input_spec = self.input_spec + self.input_spec = full_input_spec + output = super(RNN, self).__call__(full_input, **kwargs) + self.input_spec = original_input_spec + return output + else: + return super(RNN, self).__call__(inputs, **kwargs) + + def call(self, + inputs, + mask=None, + training=None, + initial_state=None, + constants=None): + # input shape: `(samples, time (padded with zeros), input_dim)` + # note that the .build() method of subclasses MUST define + # self.input_spec and self.state_spec with complete input shapes. + if isinstance(inputs, list): + inputs = inputs[0] + if initial_state is not None: + pass + elif self.stateful: + initial_state = self.states + else: + initial_state = self.get_initial_state(inputs) + + if isinstance(mask, list): + mask = mask[0] + + if len(initial_state) != len(self.states): + raise ValueError( + 'Layer has ' + str(len(self.states)) + ' states but was passed ' + + str(len(initial_state)) + ' initial states.') + input_shape = K.int_shape(inputs) + timesteps = input_shape[1] + if self.unroll and timesteps in [None, 1]: + raise ValueError('Cannot unroll a RNN if the ' + 'time dimension is undefined or equal to 1. 
\n' + '- If using a Sequential model, ' + 'specify the time dimension by passing ' + 'an `input_shape` or `batch_input_shape` ' + 'argument to your first layer. If your ' + 'first layer is an Embedding, you can ' + 'also use the `input_length` argument.\n' + '- If using the functional API, specify ' + 'the time dimension by passing a `shape` ' + 'or `batch_shape` argument to your Input layer.') + + kwargs = {} + if has_arg(self.cell.call, 'training'): + kwargs['training'] = training + + if constants: + if not has_arg(self.cell.call, 'constants'): + raise ValueError('RNN cell does not support constants') + + def step(inputs, states): + constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type + states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type + return self.cell.call(inputs, states, constants=constants, **kwargs) + else: + + def step(inputs, states): + return self.cell.call(inputs, states, **kwargs) + + last_output, outputs, states = K.rnn( + step, + inputs, + initial_state, + constants=constants, + go_backwards=self.go_backwards, + mask=mask, + unroll=self.unroll) + if self.stateful: + updates = [] + for i in range(len(states)): + updates.append((self.states[i], states[i])) + self.add_update(updates, inputs) + + if self.return_sequences: + output = outputs + else: + output = last_output + + # Properly set learning phase + if getattr(last_output, '_uses_learning_phase', False): + output._uses_learning_phase = True + + if self.return_state: + if not isinstance(states, (list, tuple)): + states = [states] + else: + states = list(states) + return [output] + states + else: + return output + + def _standardize_args(self, inputs, initial_state, constants): + """Standardize `__call__` arguments to a single list of tensor inputs. + + When running a model loaded from file, the input tensors + `initial_state` and `constants` can be passed to `RNN.__call__` as part + of `inputs` instead of by the dedicated keyword arguments. This method + makes sure the arguments are separated and that `initial_state` and + `constants` are lists of tensors (or None). + + Arguments: + inputs: tensor or list/tuple of tensors + initial_state: tensor or list of tensors or None + constants: tensor or list of tensors or None + + Returns: + inputs: tensor + initial_state: list of tensors or None + constants: list of tensors or None + """ + if isinstance(inputs, list): + assert initial_state is None and constants is None + if self._num_constants is not None: + constants = inputs[-self._num_constants:] # pylint: disable=invalid-unary-operand-type + inputs = inputs[:-self._num_constants] # pylint: disable=invalid-unary-operand-type + if len(inputs) > 1: + initial_state = inputs[1:] + inputs = inputs[0] + + def to_list_or_none(x): + if x is None or isinstance(x, list): + return x + if isinstance(x, tuple): + return list(x) + return [x] + + initial_state = to_list_or_none(initial_state) + constants = to_list_or_none(constants) + + return inputs, initial_state, constants + + def reset_states(self, states=None): + if not self.stateful: + raise AttributeError('Layer must be stateful.') + batch_size = self.input_spec[0].shape[0] + if not batch_size: + raise ValueError('If a RNN is stateful, it needs to know ' + 'its batch size. 
Specify the batch size ' + 'of your input tensors: \n' + '- If using a Sequential model, ' + 'specify the batch size by passing ' + 'a `batch_input_shape` ' + 'argument to your first layer.\n' + '- If using the functional API, specify ' + 'the time dimension by passing a ' + '`batch_shape` argument to your Input layer.') + # initialize state if None + if self.states[0] is None: + if hasattr(self.cell.state_size, '__len__'): + self.states = [ + K.zeros((batch_size, dim)) for dim in self.cell.state_size + ] + else: + self.states = [K.zeros((batch_size, self.cell.state_size))] + elif states is None: + if hasattr(self.cell.state_size, '__len__'): + for state, dim in zip(self.states, self.cell.state_size): + K.set_value(state, np.zeros((batch_size, dim))) + else: + K.set_value(self.states[0], np.zeros((batch_size, + self.cell.state_size))) + else: + if not isinstance(states, (list, tuple)): + states = [states] + if len(states) != len(self.states): + raise ValueError('Layer ' + self.name + ' expects ' + + str(len(self.states)) + ' states, ' + 'but it received ' + str(len(states)) + + ' state values. Input received: ' + str(states)) + for index, (value, state) in enumerate(zip(states, self.states)): + if hasattr(self.cell.state_size, '__len__'): + dim = self.cell.state_size[index] + else: + dim = self.cell.state_size + if value.shape != (batch_size, dim): + raise ValueError( + 'State ' + str(index) + ' is incompatible with layer ' + + self.name + ': expected shape=' + str( + (batch_size, dim)) + ', found shape=' + str(value.shape)) + # TODO(fchollet): consider batch calls to `set_value`. + K.set_value(state, value) + + def get_config(self): + config = { + 'return_sequences': self.return_sequences, + 'return_state': self.return_state, + 'go_backwards': self.go_backwards, + 'stateful': self.stateful, + 'unroll': self.unroll + } + if self._num_constants is not None: + config['num_constants'] = self._num_constants + + cell_config = self.cell.get_config() + config['cell'] = { + 'class_name': self.cell.__class__.__name__, + 'config': cell_config + } + base_config = super(RNN, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects) + num_constants = config.pop('num_constants', None) + layer = cls(cell, **config) + layer._num_constants = num_constants + return layer + + @property + def trainable_weights(self): + if isinstance(self.cell, Layer): + return self.cell.trainable_weights + return [] + + @property + def non_trainable_weights(self): + if isinstance(self.cell, Layer): + return self.cell.non_trainable_weights + return [] + + @property + def losses(self): + if isinstance(self.cell, Layer): + return self.cell.losses + return [] + + def get_losses_for(self, inputs=None): + if isinstance(self.cell, Layer): + cell_losses = self.cell.get_losses_for(inputs) + return cell_losses + super(RNN, self).get_losses_for(inputs) + return super(RNN, self).get_losses_for(inputs) + + +class SimpleRNNCell(Layer): + """Cell class for SimpleRNN. + + Arguments: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use + (see [activations](../activations.md)). + If you pass None, no activation is applied + (ie. "linear" activation: `a(x) = x`). 
+ use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. + (see [initializers](../initializers.md)). + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, + used for the linear transformation of the recurrent state. + (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector + (see [initializers](../initializers.md)). + kernel_regularizer: Regularizer function applied to + the `kernel` weights matrix + (see [regularizer](../regularizers.md)). + recurrent_regularizer: Regularizer function applied to + the `recurrent_kernel` weights matrix + (see [regularizer](../regularizers.md)). + bias_regularizer: Regularizer function applied to the bias vector + (see [regularizer](../regularizers.md)). + kernel_constraint: Constraint function applied to + the `kernel` weights matrix + (see [constraints](../constraints.md)). + recurrent_constraint: Constraint function applied to + the `recurrent_kernel` weights matrix + (see [constraints](../constraints.md)). + bias_constraint: Constraint function applied to the bias vector + (see [constraints](../constraints.md)). + dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the inputs. + recurrent_dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the recurrent state. + """ + + def __init__(self, + units, + activation='tanh', + use_bias=True, + kernel_initializer='glorot_uniform', + recurrent_initializer='orthogonal', + bias_initializer='zeros', + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0., + recurrent_dropout=0., + **kwargs): + super(SimpleRNNCell, self).__init__(**kwargs) + self.units = units + self.activation = activations.get(activation) + self.use_bias = use_bias + + self.kernel_initializer = initializers.get(kernel_initializer) + self.recurrent_initializer = initializers.get(recurrent_initializer) + self.bias_initializer = initializers.get(bias_initializer) + + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.recurrent_regularizer = regularizers.get(recurrent_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + + self.kernel_constraint = constraints.get(kernel_constraint) + self.recurrent_constraint = constraints.get(recurrent_constraint) + self.bias_constraint = constraints.get(bias_constraint) + + self.dropout = min(1., max(0., dropout)) + self.recurrent_dropout = min(1., max(0., recurrent_dropout)) + self.state_size = self.units + self._dropout_mask = None + self._recurrent_dropout_mask = None + + def build(self, input_shape): + self.kernel = self.add_weight( + shape=(input_shape[-1], self.units), + name='kernel', + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint) + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + name='recurrent_kernel', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + if self.use_bias: + self.bias = self.add_weight( + shape=(self.units,), + name='bias', + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + else: + self.bias = None + self.built = True + + def 
_generate_dropout_mask(self, inputs, training=None):
+    if 0 < self.dropout < 1:
+      ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
+
+      def dropped_inputs():
+        return K.dropout(ones, self.dropout)
+
+      self._dropout_mask = K.in_train_phase(
+          dropped_inputs, ones, training=training)
+    else:
+      self._dropout_mask = None
+
+  def _generate_recurrent_dropout_mask(self, inputs, training=None):
+    if 0 < self.recurrent_dropout < 1:
+      ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
+      ones = K.tile(ones, (1, self.units))
+
+      def dropped_inputs():
+        # Use the recurrent dropout rate here, not the input dropout rate.
+        return K.dropout(ones, self.recurrent_dropout)
+
+      self._recurrent_dropout_mask = K.in_train_phase(
+          dropped_inputs, ones, training=training)
+    else:
+      self._recurrent_dropout_mask = None
+
+  def call(self, inputs, states, training=None):
+    prev_output = states[0]
+    dp_mask = self._dropout_mask
+    rec_dp_mask = self._recurrent_dropout_mask
+
+    if dp_mask is not None:
+      h = K.dot(inputs * dp_mask, self.kernel)
+    else:
+      h = K.dot(inputs, self.kernel)
+    if self.bias is not None:
+      h = K.bias_add(h, self.bias)
+
+    if rec_dp_mask is not None:
+      prev_output *= rec_dp_mask
+    output = h + K.dot(prev_output, self.recurrent_kernel)
+    if self.activation is not None:
+      output = self.activation(output)
+
+    # Properly set learning phase on output tensor.
+    if 0 < self.dropout + self.recurrent_dropout:
+      if training is None:
+        output._uses_learning_phase = True
+    return output, [output]
+
+
+class SimpleRNN(RNN):
+  """Fully-connected RNN where the output is to be fed back to input.
+
+  Arguments:
+    units: Positive integer, dimensionality of the output space.
+    activation: Activation function to use
+      (see [activations](../activations.md)).
+      If you pass None, no activation is applied
+      (ie. "linear" activation: `a(x) = x`).
+    use_bias: Boolean, whether the layer uses a bias vector.
+    kernel_initializer: Initializer for the `kernel` weights matrix,
+      used for the linear transformation of the inputs.
+      (see [initializers](../initializers.md)).
+    recurrent_initializer: Initializer for the `recurrent_kernel`
+      weights matrix,
+      used for the linear transformation of the recurrent state.
+      (see [initializers](../initializers.md)).
+    bias_initializer: Initializer for the bias vector
+      (see [initializers](../initializers.md)).
+    kernel_regularizer: Regularizer function applied to
+      the `kernel` weights matrix
+      (see [regularizer](../regularizers.md)).
+    recurrent_regularizer: Regularizer function applied to
+      the `recurrent_kernel` weights matrix
+      (see [regularizer](../regularizers.md)).
+    bias_regularizer: Regularizer function applied to the bias vector
+      (see [regularizer](../regularizers.md)).
+    activity_regularizer: Regularizer function applied to
+      the output of the layer (its "activation").
+      (see [regularizer](../regularizers.md)).
+    kernel_constraint: Constraint function applied to
+      the `kernel` weights matrix
+      (see [constraints](../constraints.md)).
+    recurrent_constraint: Constraint function applied to
+      the `recurrent_kernel` weights matrix
+      (see [constraints](../constraints.md)).
+    bias_constraint: Constraint function applied to the bias vector
+      (see [constraints](../constraints.md)).
+    dropout: Float between 0 and 1.
+      Fraction of the units to drop for
+      the linear transformation of the inputs.
+    recurrent_dropout: Float between 0 and 1.
+      Fraction of the units to drop for
+      the linear transformation of the recurrent state.
+    return_sequences: Boolean. Whether to return the last output
+      in the output sequence, or the full sequence.
+ return_state: Boolean. Whether to return the last state + in addition to the output. + go_backwards: Boolean (default False). + If True, process the input sequence backwards and return the + reversed sequence. + stateful: Boolean (default False). If True, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. + unroll: Boolean (default False). + If True, the network will be unrolled, + else a symbolic loop will be used. + Unrolling can speed-up a RNN, + although it tends to be more memory-intensive. + Unrolling is only suitable for short sequences. + """ + + def __init__(self, + units, + activation='tanh', + use_bias=True, + kernel_initializer='glorot_uniform', + recurrent_initializer='orthogonal', + bias_initializer='zeros', + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0., + recurrent_dropout=0., + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + unroll=False, + **kwargs): + if 'implementation' in kwargs: + kwargs.pop('implementation') + logging.warning('The `implementation` argument ' + 'in `SimpleRNN` has been deprecated. ' + 'Please remove it from your layer call.') + cell = SimpleRNNCell( + units, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + recurrent_initializer=recurrent_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + dropout=dropout, + recurrent_dropout=recurrent_dropout) + super(SimpleRNN, self).__init__( + cell, + return_sequences=return_sequences, + return_state=return_state, + go_backwards=go_backwards, + stateful=stateful, + unroll=unroll, + activity_regularizer=regularizers.get(activity_regularizer), + **kwargs) + # self.activity_regularizer = regularizers.get(activity_regularizer) + + def call(self, inputs, mask=None, training=None, initial_state=None): + self.cell._generate_dropout_mask(inputs, training=training) + self.cell._generate_recurrent_dropout_mask(inputs, training=training) + return super(SimpleRNN, self).call( + inputs, mask=mask, training=training, initial_state=initial_state) + + @property + def units(self): + return self.cell.units + + @property + def activation(self): + return self.cell.activation + + @property + def use_bias(self): + return self.cell.use_bias + + @property + def kernel_initializer(self): + return self.cell.kernel_initializer + + @property + def recurrent_initializer(self): + return self.cell.recurrent_initializer + + @property + def bias_initializer(self): + return self.cell.bias_initializer + + @property + def kernel_regularizer(self): + return self.cell.kernel_regularizer + + @property + def recurrent_regularizer(self): + return self.cell.recurrent_regularizer + + @property + def bias_regularizer(self): + return self.cell.bias_regularizer + + @property + def kernel_constraint(self): + return self.cell.kernel_constraint + + @property + def recurrent_constraint(self): + return self.cell.recurrent_constraint + + @property + def bias_constraint(self): + return self.cell.bias_constraint + + @property + def dropout(self): + return self.cell.dropout + + @property + def 
recurrent_dropout(self): + return self.cell.recurrent_dropout + + def get_config(self): + config = { + 'units': + self.units, + 'activation': + activations.serialize(self.activation), + 'use_bias': + self.use_bias, + 'kernel_initializer': + initializers.serialize(self.kernel_initializer), + 'recurrent_initializer': + initializers.serialize(self.recurrent_initializer), + 'bias_initializer': + initializers.serialize(self.bias_initializer), + 'kernel_regularizer': + regularizers.serialize(self.kernel_regularizer), + 'recurrent_regularizer': + regularizers.serialize(self.recurrent_regularizer), + 'bias_regularizer': + regularizers.serialize(self.bias_regularizer), + 'activity_regularizer': + regularizers.serialize(self.activity_regularizer), + 'kernel_constraint': + constraints.serialize(self.kernel_constraint), + 'recurrent_constraint': + constraints.serialize(self.recurrent_constraint), + 'bias_constraint': + constraints.serialize(self.bias_constraint), + 'dropout': + self.dropout, + 'recurrent_dropout': + self.recurrent_dropout + } + base_config = super(SimpleRNN, self).get_config() + del base_config['cell'] + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config): + if 'implementation' in config: + config.pop('implementation') + return cls(**config) + + +class GRUCell(Layer): + """Cell class for the GRU layer. + + Arguments: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use + (see [activations](../activations.md)). + If you pass None, no activation is applied + (ie. "linear" activation: `a(x) = x`). + recurrent_activation: Activation function to use + for the recurrent step + (see [activations](../activations.md)). + use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. + (see [initializers](../initializers.md)). + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, + used for the linear transformation of the recurrent state. + (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector + (see [initializers](../initializers.md)). + kernel_regularizer: Regularizer function applied to + the `kernel` weights matrix + (see [regularizer](../regularizers.md)). + recurrent_regularizer: Regularizer function applied to + the `recurrent_kernel` weights matrix + (see [regularizer](../regularizers.md)). + bias_regularizer: Regularizer function applied to the bias vector + (see [regularizer](../regularizers.md)). + kernel_constraint: Constraint function applied to + the `kernel` weights matrix + (see [constraints](../constraints.md)). + recurrent_constraint: Constraint function applied to + the `recurrent_kernel` weights matrix + (see [constraints](../constraints.md)). + bias_constraint: Constraint function applied to the bias vector + (see [constraints](../constraints.md)). + dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the inputs. + recurrent_dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the recurrent state. + implementation: Implementation mode, either 1 or 2. + Mode 1 will structure its operations as a larger number of + smaller dot products and additions, whereas mode 2 will + batch them into fewer, larger operations. 
These modes will + have different performance profiles on different hardware and + for different applications. + """ + + def __init__(self, + units, + activation='tanh', + recurrent_activation='hard_sigmoid', + use_bias=True, + kernel_initializer='glorot_uniform', + recurrent_initializer='orthogonal', + bias_initializer='zeros', + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0., + recurrent_dropout=0., + implementation=1, + **kwargs): + super(GRUCell, self).__init__(**kwargs) + self.units = units + self.activation = activations.get(activation) + self.recurrent_activation = activations.get(recurrent_activation) + self.use_bias = use_bias + + self.kernel_initializer = initializers.get(kernel_initializer) + self.recurrent_initializer = initializers.get(recurrent_initializer) + self.bias_initializer = initializers.get(bias_initializer) + + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.recurrent_regularizer = regularizers.get(recurrent_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + + self.kernel_constraint = constraints.get(kernel_constraint) + self.recurrent_constraint = constraints.get(recurrent_constraint) + self.bias_constraint = constraints.get(bias_constraint) + + self.dropout = min(1., max(0., dropout)) + self.recurrent_dropout = min(1., max(0., recurrent_dropout)) + self.implementation = implementation + self.state_size = self.units + self._dropout_mask = None + self._recurrent_dropout_mask = None + + def build(self, input_shape): + input_dim = input_shape[-1] + self.kernel = self.add_weight( + shape=(input_dim, self.units * 3), + name='kernel', + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint) + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units * 3), + name='recurrent_kernel', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + + if self.use_bias: + self.bias = self.add_weight( + shape=(self.units * 3,), + name='bias', + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + else: + self.bias = None + + self.kernel_z = self.kernel[:, :self.units] + self.recurrent_kernel_z = self.recurrent_kernel[:, :self.units] + self.kernel_r = self.kernel[:, self.units:self.units * 2] + self.recurrent_kernel_r = self.recurrent_kernel[:, self.units: + self.units * 2] + self.kernel_h = self.kernel[:, self.units * 2:] + self.recurrent_kernel_h = self.recurrent_kernel[:, self.units * 2:] + + if self.use_bias: + self.bias_z = self.bias[:self.units] + self.bias_r = self.bias[self.units:self.units * 2] + self.bias_h = self.bias[self.units * 2:] + else: + self.bias_z = None + self.bias_r = None + self.bias_h = None + self.built = True + + def _generate_dropout_mask(self, inputs, training=None): + if 0 < self.dropout < 1: + ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1)) + + def dropped_inputs(): + return K.dropout(ones, self.dropout) + + self._dropout_mask = [ + K.in_train_phase(dropped_inputs, ones, training=training) + for _ in range(3) + ] + else: + self._dropout_mask = None + + def _generate_recurrent_dropout_mask(self, inputs, training=None): + if 0 < self.recurrent_dropout < 1: + ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) + ones = K.tile(ones, (1, self.units)) + + def dropped_inputs(): + return 
K.dropout(ones, self.recurrent_dropout)  # recurrent rate, not input rate
+
+      self._recurrent_dropout_mask = [
+          K.in_train_phase(dropped_inputs, ones, training=training)
+          for _ in range(3)
+      ]
+    else:
+      self._recurrent_dropout_mask = None
+
+  def call(self, inputs, states, training=None):
+    h_tm1 = states[0]  # previous memory
+
+    # dropout matrices for input units
+    dp_mask = self._dropout_mask
+    # dropout matrices for recurrent units
+    rec_dp_mask = self._recurrent_dropout_mask
+
+    if self.implementation == 1:
+      if 0. < self.dropout < 1.:
+        inputs_z = inputs * dp_mask[0]
+        inputs_r = inputs * dp_mask[1]
+        inputs_h = inputs * dp_mask[2]
+      else:
+        inputs_z = inputs
+        inputs_r = inputs
+        inputs_h = inputs
+      x_z = K.dot(inputs_z, self.kernel_z)
+      x_r = K.dot(inputs_r, self.kernel_r)
+      x_h = K.dot(inputs_h, self.kernel_h)
+      if self.use_bias:
+        x_z = K.bias_add(x_z, self.bias_z)
+        x_r = K.bias_add(x_r, self.bias_r)
+        x_h = K.bias_add(x_h, self.bias_h)
+
+      if 0. < self.recurrent_dropout < 1.:
+        h_tm1_z = h_tm1 * rec_dp_mask[0]
+        h_tm1_r = h_tm1 * rec_dp_mask[1]
+        h_tm1_h = h_tm1 * rec_dp_mask[2]
+      else:
+        h_tm1_z = h_tm1
+        h_tm1_r = h_tm1
+        h_tm1_h = h_tm1
+      z = self.recurrent_activation(
+          x_z + K.dot(h_tm1_z, self.recurrent_kernel_z))
+      r = self.recurrent_activation(
+          x_r + K.dot(h_tm1_r, self.recurrent_kernel_r))
+
+      hh = self.activation(x_h + K.dot(r * h_tm1_h, self.recurrent_kernel_h))
+    else:
+      if 0. < self.dropout < 1.:
+        inputs *= dp_mask[0]
+      matrix_x = K.dot(inputs, self.kernel)
+      if self.use_bias:
+        matrix_x = K.bias_add(matrix_x, self.bias)
+      if 0. < self.recurrent_dropout < 1.:
+        h_tm1 *= rec_dp_mask[0]
+      matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units])
+
+      x_z = matrix_x[:, :self.units]
+      x_r = matrix_x[:, self.units:2 * self.units]
+      recurrent_z = matrix_inner[:, :self.units]
+      recurrent_r = matrix_inner[:, self.units:2 * self.units]
+
+      z = self.recurrent_activation(x_z + recurrent_z)
+      r = self.recurrent_activation(x_r + recurrent_r)
+
+      x_h = matrix_x[:, 2 * self.units:]
+      recurrent_h = K.dot(r * h_tm1, self.recurrent_kernel[:, 2 * self.units:])
+      hh = self.activation(x_h + recurrent_h)
+    h = z * h_tm1 + (1 - z) * hh
+    if 0 < self.dropout + self.recurrent_dropout:
+      if training is None:
+        h._uses_learning_phase = True
+    return h, [h]
+
+
+class GRU(RNN):
+  # pylint: disable=line-too-long
+  """Gated Recurrent Unit - Cho et al. 2014.
+
+  Arguments:
+    units: Positive integer, dimensionality of the output space.
+    activation: Activation function to use
+      (see [activations](../activations.md)).
+      If you pass None, no activation is applied
+      (ie. "linear" activation: `a(x) = x`).
+    recurrent_activation: Activation function to use
+      for the recurrent step
+      (see [activations](../activations.md)).
+    use_bias: Boolean, whether the layer uses a bias vector.
+    kernel_initializer: Initializer for the `kernel` weights matrix,
+      used for the linear transformation of the inputs.
+      (see [initializers](../initializers.md)).
+    recurrent_initializer: Initializer for the `recurrent_kernel`
+      weights matrix,
+      used for the linear transformation of the recurrent state.
+      (see [initializers](../initializers.md)).
+    bias_initializer: Initializer for the bias vector
+      (see [initializers](../initializers.md)).
+    kernel_regularizer: Regularizer function applied to
+      the `kernel` weights matrix
+      (see [regularizer](../regularizers.md)).
+    recurrent_regularizer: Regularizer function applied to
+      the `recurrent_kernel` weights matrix
+      (see [regularizer](../regularizers.md)).
+ bias_regularizer: Regularizer function applied to the bias vector + (see [regularizer](../regularizers.md)). + activity_regularizer: Regularizer function applied to + the output of the layer (its "activation"). + (see [regularizer](../regularizers.md)). + kernel_constraint: Constraint function applied to + the `kernel` weights matrix + (see [constraints](../constraints.md)). + recurrent_constraint: Constraint function applied to + the `recurrent_kernel` weights matrix + (see [constraints](../constraints.md)). + bias_constraint: Constraint function applied to the bias vector + (see [constraints](../constraints.md)). + dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the inputs. + recurrent_dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the recurrent state. + implementation: Implementation mode, either 1 or 2. + Mode 1 will structure its operations as a larger number of + smaller dot products and additions, whereas mode 2 will + batch them into fewer, larger operations. These modes will + have different performance profiles on different hardware and + for different applications. + return_sequences: Boolean. Whether to return the last output. + in the output sequence, or the full sequence. + return_state: Boolean. Whether to return the last state + in addition to the output. + go_backwards: Boolean (default False). + If True, process the input sequence backwards and return the + reversed sequence. + stateful: Boolean (default False). If True, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. + unroll: Boolean (default False). + If True, the network will be unrolled, + else a symbolic loop will be used. + Unrolling can speed-up a RNN, + although it tends to be more memory-intensive. + Unrolling is only suitable for short sequences. + + References: + - [On the Properties of Neural Machine Translation: Encoder-Decoder Approaches](https://arxiv.org/abs/1409.1259) + - [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling](http://arxiv.org/abs/1412.3555v1) + - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287) + """ + # pylint: enable=line-too-long + + def __init__(self, + units, + activation='tanh', + recurrent_activation='hard_sigmoid', + use_bias=True, + kernel_initializer='glorot_uniform', + recurrent_initializer='orthogonal', + bias_initializer='zeros', + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0., + recurrent_dropout=0., + implementation=1, + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + unroll=False, + **kwargs): + if implementation == 0: + logging.warning('`implementation=0` has been deprecated, ' + 'and now defaults to `implementation=1`.' 
+ 'Please update your layer call.') + cell = GRUCell( + units, + activation=activation, + recurrent_activation=recurrent_activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + recurrent_initializer=recurrent_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + dropout=dropout, + recurrent_dropout=recurrent_dropout, + implementation=implementation) + super(GRU, self).__init__( + cell, + return_sequences=return_sequences, + return_state=return_state, + go_backwards=go_backwards, + stateful=stateful, + unroll=unroll, + **kwargs) + self.activity_regularizer = regularizers.get(activity_regularizer) + + def call(self, inputs, mask=None, training=None, initial_state=None): + self.cell._generate_dropout_mask(inputs, training=training) + self.cell._generate_recurrent_dropout_mask(inputs, training=training) + return super(GRU, self).call( + inputs, mask=mask, training=training, initial_state=initial_state) + + @property + def units(self): + return self.cell.units + + @property + def activation(self): + return self.cell.activation + + @property + def recurrent_activation(self): + return self.cell.recurrent_activation + + @property + def use_bias(self): + return self.cell.use_bias + + @property + def kernel_initializer(self): + return self.cell.kernel_initializer + + @property + def recurrent_initializer(self): + return self.cell.recurrent_initializer + + @property + def bias_initializer(self): + return self.cell.bias_initializer + + @property + def kernel_regularizer(self): + return self.cell.kernel_regularizer + + @property + def recurrent_regularizer(self): + return self.cell.recurrent_regularizer + + @property + def bias_regularizer(self): + return self.cell.bias_regularizer + + @property + def kernel_constraint(self): + return self.cell.kernel_constraint + + @property + def recurrent_constraint(self): + return self.cell.recurrent_constraint + + @property + def bias_constraint(self): + return self.cell.bias_constraint + + @property + def dropout(self): + return self.cell.dropout + + @property + def recurrent_dropout(self): + return self.cell.recurrent_dropout + + @property + def implementation(self): + return self.cell.implementation + + def get_config(self): + config = { + 'units': + self.units, + 'activation': + activations.serialize(self.activation), + 'recurrent_activation': + activations.serialize(self.recurrent_activation), + 'use_bias': + self.use_bias, + 'kernel_initializer': + initializers.serialize(self.kernel_initializer), + 'recurrent_initializer': + initializers.serialize(self.recurrent_initializer), + 'bias_initializer': + initializers.serialize(self.bias_initializer), + 'kernel_regularizer': + regularizers.serialize(self.kernel_regularizer), + 'recurrent_regularizer': + regularizers.serialize(self.recurrent_regularizer), + 'bias_regularizer': + regularizers.serialize(self.bias_regularizer), + 'activity_regularizer': + regularizers.serialize(self.activity_regularizer), + 'kernel_constraint': + constraints.serialize(self.kernel_constraint), + 'recurrent_constraint': + constraints.serialize(self.recurrent_constraint), + 'bias_constraint': + constraints.serialize(self.bias_constraint), + 'dropout': + self.dropout, + 'recurrent_dropout': + self.recurrent_dropout, + 'implementation': + self.implementation + } + base_config = 
super(GRU, self).get_config() + del base_config['cell'] + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config): + if 'implementation' in config and config['implementation'] == 0: + config['implementation'] = 1 + return cls(**config) + + +class LSTMCell(Layer): + """Cell class for the LSTM layer. + + Arguments: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use + (see [activations](../activations.md)). + If you pass None, no activation is applied + (ie. "linear" activation: `a(x) = x`). + recurrent_activation: Activation function to use + for the recurrent step + (see [activations](../activations.md)). + use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. + (see [initializers](../initializers.md)). + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, + used for the linear transformation of the recurrent state. + (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector + (see [initializers](../initializers.md)). + unit_forget_bias: Boolean. + If True, add 1 to the bias of the forget gate at initialization. + Setting it to true will also force `bias_initializer="zeros"`. + This is recommended in [Jozefowicz et + al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) + kernel_regularizer: Regularizer function applied to + the `kernel` weights matrix + (see [regularizer](../regularizers.md)). + recurrent_regularizer: Regularizer function applied to + the `recurrent_kernel` weights matrix + (see [regularizer](../regularizers.md)). + bias_regularizer: Regularizer function applied to the bias vector + (see [regularizer](../regularizers.md)). + kernel_constraint: Constraint function applied to + the `kernel` weights matrix + (see [constraints](../constraints.md)). + recurrent_constraint: Constraint function applied to + the `recurrent_kernel` weights matrix + (see [constraints](../constraints.md)). + bias_constraint: Constraint function applied to the bias vector + (see [constraints](../constraints.md)). + dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the inputs. + recurrent_dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the recurrent state. + implementation: Implementation mode, either 1 or 2. + Mode 1 will structure its operations as a larger number of + smaller dot products and additions, whereas mode 2 will + batch them into fewer, larger operations. These modes will + have different performance profiles on different hardware and + for different applications. 
+ """ + + def __init__(self, + units, + activation='tanh', + recurrent_activation='hard_sigmoid', + use_bias=True, + kernel_initializer='glorot_uniform', + recurrent_initializer='orthogonal', + bias_initializer='zeros', + unit_forget_bias=True, + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0., + recurrent_dropout=0., + implementation=1, + **kwargs): + super(LSTMCell, self).__init__(**kwargs) + self.units = units + self.activation = activations.get(activation) + self.recurrent_activation = activations.get(recurrent_activation) + self.use_bias = use_bias + + self.kernel_initializer = initializers.get(kernel_initializer) + self.recurrent_initializer = initializers.get(recurrent_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.unit_forget_bias = unit_forget_bias + + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.recurrent_regularizer = regularizers.get(recurrent_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + + self.kernel_constraint = constraints.get(kernel_constraint) + self.recurrent_constraint = constraints.get(recurrent_constraint) + self.bias_constraint = constraints.get(bias_constraint) + + self.dropout = min(1., max(0., dropout)) + self.recurrent_dropout = min(1., max(0., recurrent_dropout)) + self.implementation = implementation + self.state_size = (self.units, self.units) + self._dropout_mask = None + self._recurrent_dropout_mask = None + + def build(self, input_shape): + input_dim = input_shape[-1] + self.kernel = self.add_weight( + shape=(input_dim, self.units * 4), + name='kernel', + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint) + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units * 4), + name='recurrent_kernel', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + + if self.use_bias: + if self.unit_forget_bias: + + def bias_initializer(_, *args, **kwargs): + return K.concatenate([ + self.bias_initializer((self.units,), *args, **kwargs), + initializers.Ones()((self.units,), *args, **kwargs), + self.bias_initializer((self.units * 2,), *args, **kwargs), + ]) + else: + bias_initializer = self.bias_initializer + self.bias = self.add_weight( + shape=(self.units * 4,), + name='bias', + initializer=bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + else: + self.bias = None + + self.kernel_i = self.kernel[:, :self.units] + self.kernel_f = self.kernel[:, self.units:self.units * 2] + self.kernel_c = self.kernel[:, self.units * 2:self.units * 3] + self.kernel_o = self.kernel[:, self.units * 3:] + + self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units] + self.recurrent_kernel_f = self.recurrent_kernel[:, self.units: + self.units * 2] + self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2: + self.units * 3] + self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:] + + if self.use_bias: + self.bias_i = self.bias[:self.units] + self.bias_f = self.bias[self.units:self.units * 2] + self.bias_c = self.bias[self.units * 2:self.units * 3] + self.bias_o = self.bias[self.units * 3:] + else: + self.bias_i = None + self.bias_f = None + self.bias_c = None + self.bias_o = None + self.built = True + + def _generate_dropout_mask(self, inputs, training=None): + if 0 < 
self.dropout < 1:
+      ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
+
+      def dropped_inputs():
+        return K.dropout(ones, self.dropout)
+
+      self._dropout_mask = [
+          K.in_train_phase(dropped_inputs, ones, training=training)
+          for _ in range(4)
+      ]
+    else:
+      self._dropout_mask = None
+
+  def _generate_recurrent_dropout_mask(self, inputs, training=None):
+    if 0 < self.recurrent_dropout < 1:
+      ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
+      ones = K.tile(ones, (1, self.units))
+
+      def dropped_inputs():
+        return K.dropout(ones, self.recurrent_dropout)
+
+      self._recurrent_dropout_mask = [
+          K.in_train_phase(dropped_inputs, ones, training=training)
+          for _ in range(4)
+      ]
+    else:
+      self._recurrent_dropout_mask = None
+
+  def call(self, inputs, states, training=None):
+    # dropout matrices for input units
+    dp_mask = self._dropout_mask
+    # dropout matrices for recurrent units
+    rec_dp_mask = self._recurrent_dropout_mask
+
+    h_tm1 = states[0]  # previous memory state
+    c_tm1 = states[1]  # previous carry state
+
+    if self.implementation == 1:
+      if 0 < self.dropout < 1.:
+        inputs_i = inputs * dp_mask[0]
+        inputs_f = inputs * dp_mask[1]
+        inputs_c = inputs * dp_mask[2]
+        inputs_o = inputs * dp_mask[3]
+      else:
+        inputs_i = inputs
+        inputs_f = inputs
+        inputs_c = inputs
+        inputs_o = inputs
+      x_i = K.dot(inputs_i, self.kernel_i)
+      x_f = K.dot(inputs_f, self.kernel_f)
+      x_c = K.dot(inputs_c, self.kernel_c)
+      x_o = K.dot(inputs_o, self.kernel_o)
+      if self.use_bias:
+        x_i = K.bias_add(x_i, self.bias_i)
+        x_f = K.bias_add(x_f, self.bias_f)
+        x_c = K.bias_add(x_c, self.bias_c)
+        x_o = K.bias_add(x_o, self.bias_o)
+
+      if 0 < self.recurrent_dropout < 1.:
+        h_tm1_i = h_tm1 * rec_dp_mask[0]
+        h_tm1_f = h_tm1 * rec_dp_mask[1]
+        h_tm1_c = h_tm1 * rec_dp_mask[2]
+        h_tm1_o = h_tm1 * rec_dp_mask[3]
+      else:
+        h_tm1_i = h_tm1
+        h_tm1_f = h_tm1
+        h_tm1_c = h_tm1
+        h_tm1_o = h_tm1
+      i = self.recurrent_activation(
+          x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
+      f = self.recurrent_activation(
+          x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
+      c = f * c_tm1 + i * self.activation(
+          x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
+      o = self.recurrent_activation(
+          x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))
+    else:
+      if 0. < self.dropout < 1.:
+        inputs *= dp_mask[0]
+      z = K.dot(inputs, self.kernel)
+      if 0. < self.recurrent_dropout < 1.:
+        h_tm1 *= rec_dp_mask[0]
+      z += K.dot(h_tm1, self.recurrent_kernel)
+      if self.use_bias:
+        z = K.bias_add(z, self.bias)
+
+      z0 = z[:, :self.units]
+      z1 = z[:, self.units:2 * self.units]
+      z2 = z[:, 2 * self.units:3 * self.units]
+      z3 = z[:, 3 * self.units:]
+
+      i = self.recurrent_activation(z0)
+      f = self.recurrent_activation(z1)
+      c = f * c_tm1 + i * self.activation(z2)
+      o = self.recurrent_activation(z3)
+
+    h = o * self.activation(c)
+    if 0 < self.dropout + self.recurrent_dropout:
+      if training is None:
+        h._uses_learning_phase = True
+    return h, [h, c]
+
+
+class LSTM(RNN):
+  # pylint: disable=line-too-long
+  """Long Short-Term Memory layer - Hochreiter 1997.
+
+  Arguments:
+      units: Positive integer, dimensionality of the output space.
+      activation: Activation function to use
+          (see [activations](../activations.md)).
+          If you pass None, no activation is applied
+          (ie. "linear" activation: `a(x) = x`).
+      recurrent_activation: Activation function to use
+          for the recurrent step
+          (see [activations](../activations.md)).
+      use_bias: Boolean, whether the layer uses a bias vector.
+      kernel_initializer: Initializer for the `kernel` weights matrix,
+          used for the linear transformation of the inputs
+          (see [initializers](../initializers.md)).
+      recurrent_initializer: Initializer for the `recurrent_kernel`
+          weights matrix,
+          used for the linear transformation of the recurrent state
+          (see [initializers](../initializers.md)).
+      bias_initializer: Initializer for the bias vector
+          (see [initializers](../initializers.md)).
+      unit_forget_bias: Boolean.
+          If True, add 1 to the bias of the forget gate at initialization.
+          Setting it to true will also force `bias_initializer="zeros"`.
+          This is recommended in [Jozefowicz et
+            al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
+      kernel_regularizer: Regularizer function applied to
+          the `kernel` weights matrix
+          (see [regularizer](../regularizers.md)).
+      recurrent_regularizer: Regularizer function applied to
+          the `recurrent_kernel` weights matrix
+          (see [regularizer](../regularizers.md)).
+      bias_regularizer: Regularizer function applied to the bias vector
+          (see [regularizer](../regularizers.md)).
+      activity_regularizer: Regularizer function applied to
+          the output of the layer (its "activation")
+          (see [regularizer](../regularizers.md)).
+      kernel_constraint: Constraint function applied to
+          the `kernel` weights matrix
+          (see [constraints](../constraints.md)).
+      recurrent_constraint: Constraint function applied to
+          the `recurrent_kernel` weights matrix
+          (see [constraints](../constraints.md)).
+      bias_constraint: Constraint function applied to the bias vector
+          (see [constraints](../constraints.md)).
+      dropout: Float between 0 and 1.
+          Fraction of the units to drop for
+          the linear transformation of the inputs.
+      recurrent_dropout: Float between 0 and 1.
+          Fraction of the units to drop for
+          the linear transformation of the recurrent state.
+      implementation: Implementation mode, either 1 or 2.
+          Mode 1 will structure its operations as a larger number of
+          smaller dot products and additions, whereas mode 2 will
+          batch them into fewer, larger operations. These modes will
+          have different performance profiles on different hardware and
+          for different applications.
+      return_sequences: Boolean. Whether to return the last output
+          in the output sequence, or the full sequence.
+      return_state: Boolean. Whether to return the last state
+          in addition to the output.
+      go_backwards: Boolean (default False).
+          If True, process the input sequence backwards and return the
+          reversed sequence.
+      stateful: Boolean (default False). If True, the last state
+          for each sample at index i in a batch will be used as initial
+          state for the sample of index i in the following batch.
+      unroll: Boolean (default False).
+          If True, the network will be unrolled,
+          else a symbolic loop will be used.
+          Unrolling can speed up an RNN,
+          although it tends to be more memory-intensive.
+          Unrolling is only suitable for short sequences.
+
+  References:
+      - [Long short-term memory](http://www.bioinf.jku.at/publications/older/2604.pdf)
+      - [Learning to forget: Continual prediction with LSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015)
+      - [Supervised sequence labeling with recurrent neural networks](http://www.cs.toronto.edu/~graves/preprint.pdf)
+      - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
+  """
+  # pylint: enable=line-too-long
+
+  def __init__(self,
+               units,
+               activation='tanh',
+               recurrent_activation='hard_sigmoid',
+               use_bias=True,
+               kernel_initializer='glorot_uniform',
+               recurrent_initializer='orthogonal',
+               bias_initializer='zeros',
+               unit_forget_bias=True,
+               kernel_regularizer=None,
+               recurrent_regularizer=None,
+               bias_regularizer=None,
+               activity_regularizer=None,
+               kernel_constraint=None,
+               recurrent_constraint=None,
+               bias_constraint=None,
+               dropout=0.,
+               recurrent_dropout=0.,
+               implementation=1,
+               return_sequences=False,
+               return_state=False,
+               go_backwards=False,
+               stateful=False,
+               unroll=False,
+               **kwargs):
+    if implementation == 0:
+      logging.warning('`implementation=0` has been deprecated, '
+                      'and now defaults to `implementation=1`. '
+                      'Please update your layer call.')
+    cell = LSTMCell(
+        units,
+        activation=activation,
+        recurrent_activation=recurrent_activation,
+        use_bias=use_bias,
+        kernel_initializer=kernel_initializer,
+        recurrent_initializer=recurrent_initializer,
+        unit_forget_bias=unit_forget_bias,
+        bias_initializer=bias_initializer,
+        kernel_regularizer=kernel_regularizer,
+        recurrent_regularizer=recurrent_regularizer,
+        bias_regularizer=bias_regularizer,
+        kernel_constraint=kernel_constraint,
+        recurrent_constraint=recurrent_constraint,
+        bias_constraint=bias_constraint,
+        dropout=dropout,
+        recurrent_dropout=recurrent_dropout,
+        implementation=implementation)
+    super(LSTM, self).__init__(
+        cell,
+        return_sequences=return_sequences,
+        return_state=return_state,
+        go_backwards=go_backwards,
+        stateful=stateful,
+        unroll=unroll,
+        **kwargs)
+    self.activity_regularizer = regularizers.get(activity_regularizer)
+
+  def call(self, inputs, mask=None, training=None, initial_state=None):
+    self.cell._generate_dropout_mask(inputs, training=training)
+    self.cell._generate_recurrent_dropout_mask(inputs, training=training)
+    return super(LSTM, self).call(
+        inputs, mask=mask, training=training, initial_state=initial_state)
+
+  @property
+  def units(self):
+    return self.cell.units
+
+  @property
+  def activation(self):
+    return self.cell.activation
+
+  @property
+  def recurrent_activation(self):
+    return self.cell.recurrent_activation
+
+  @property
+  def use_bias(self):
+    return self.cell.use_bias
+
+  @property
+  def kernel_initializer(self):
+    return self.cell.kernel_initializer
+
+  @property
+  def recurrent_initializer(self):
+    return self.cell.recurrent_initializer
+
+  @property
+  def bias_initializer(self):
+    return self.cell.bias_initializer
+
+  @property
+  def unit_forget_bias(self):
+    return self.cell.unit_forget_bias
+
+  @property
+  def kernel_regularizer(self):
+    return self.cell.kernel_regularizer
+
+  @property
+  def recurrent_regularizer(self):
+    return self.cell.recurrent_regularizer
+
+  @property
+  def bias_regularizer(self):
+    return self.cell.bias_regularizer
+
+  @property
+  def kernel_constraint(self):
+    return self.cell.kernel_constraint
+
+  @property
+  def recurrent_constraint(self):
+    return self.cell.recurrent_constraint
+
+  @property
+  def bias_constraint(self):
+    return
self.cell.bias_constraint + + @property + def dropout(self): + return self.cell.dropout + + @property + def recurrent_dropout(self): + return self.cell.recurrent_dropout + + @property + def implementation(self): + return self.cell.implementation + + def get_config(self): + config = { + 'units': + self.units, + 'activation': + activations.serialize(self.activation), + 'recurrent_activation': + activations.serialize(self.recurrent_activation), + 'use_bias': + self.use_bias, + 'kernel_initializer': + initializers.serialize(self.kernel_initializer), + 'recurrent_initializer': + initializers.serialize(self.recurrent_initializer), + 'bias_initializer': + initializers.serialize(self.bias_initializer), + 'unit_forget_bias': + self.unit_forget_bias, + 'kernel_regularizer': + regularizers.serialize(self.kernel_regularizer), + 'recurrent_regularizer': + regularizers.serialize(self.recurrent_regularizer), + 'bias_regularizer': + regularizers.serialize(self.bias_regularizer), + 'activity_regularizer': + regularizers.serialize(self.activity_regularizer), + 'kernel_constraint': + constraints.serialize(self.kernel_constraint), + 'recurrent_constraint': + constraints.serialize(self.recurrent_constraint), + 'bias_constraint': + constraints.serialize(self.bias_constraint), + 'dropout': + self.dropout, + 'recurrent_dropout': + self.recurrent_dropout, + 'implementation': + self.implementation + } + base_config = super(LSTM, self).get_config() + del base_config['cell'] + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config): + if 'implementation' in config and config['implementation'] == 0: + config['implementation'] = 1 + return cls(**config) class Recurrent(Layer): - """Abstract base class for recurrent layers. + """Deprecated abstract base class for recurrent layers. - Do not use in a model -- it's not a valid layer! - Use its children classes `LSTM`, `GRU` and `SimpleRNN` instead. - - All recurrent layers (`LSTM`, `GRU`, `SimpleRNN`) also - follow the specifications of this class and accept - the keyword arguments listed below. - - Example: - - ```python - # as the first layer in a Sequential model - model = Sequential() - model.add(LSTM(32, input_shape=(10, 64))) - # now model.output_shape == (None, 32) - # note: `None` is the batch dimension. - - # for subsequent layers, no need to specify the input size: - model.add(LSTM(16)) - - # to stack recurrent layers, you must use return_sequences=True - # on any recurrent layer that feeds into another recurrent layer. - # note that you only need to specify the input size on the first layer. - model = Sequential() - model.add(LSTM(64, input_dim=64, input_length=10, return_sequences=True)) - model.add(LSTM(32, return_sequences=True)) - model.add(LSTM(10)) - ``` + It still exists because it is leveraged by the convolutional-recurrent layers. + It will be removed entirely in the future. + It was never part of the public API. + Do not use. Arguments: weights: list of Numpy arrays to set as initial weights. @@ -163,7 +2227,7 @@ class Recurrent(Layer): at the level of the first layer (e.g. via the `input_shape` argument) - Input shape:s + Input shape: 3D tensor with shape `(batch_size, timesteps, input_dim)`, (Optional) 2D tensors with shape `(batch_size, output_dim)`. 
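For readers migrating off the deprecated `Recurrent` workflow above, the replacement is the cell-based `RNN` wrapper introduced by this patch. The sketch below is modeled directly on `MinimalRNNCell` from the new `recurrent_test.py` later in this diff; it is an illustrative usage example only, not part of the patch.

```python
import numpy as np

from tensorflow.python.keras._impl import keras


class MinimalRNNCell(keras.layers.Layer):
  """Minimal cell: output = inputs . kernel + prev_output . recurrent_kernel."""

  def __init__(self, units, **kwargs):
    self.units = units
    self.state_size = units  # one state tensor of shape (batch_size, units)
    super(MinimalRNNCell, self).__init__(**kwargs)

  def build(self, input_shape):
    self.kernel = self.add_weight(
        shape=(input_shape[-1], self.units),
        initializer='uniform',
        name='kernel')
    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units),
        initializer='uniform',
        name='recurrent_kernel')
    self.built = True

  def call(self, inputs, states):
    prev_output = states[0]
    output = (keras.backend.dot(inputs, self.kernel) +
              keras.backend.dot(prev_output, self.recurrent_kernel))
    return output, [output]


x = keras.Input((None, 5))  # (batch, timesteps, features)
# A single cell wrapped in RNN behaves like a classic recurrent layer.
y = keras.layers.RNN(MinimalRNNCell(32))(x)
# A list of cells is stacked into a single layer; RNN wraps the list
# in StackedRNNCells internally.
y_stacked = keras.layers.RNN(
    [MinimalRNNCell(8), MinimalRNNCell(12), MinimalRNNCell(32)])(x)

model = keras.models.Model(x, y_stacked)
model.compile(optimizer='rmsprop', loss='mse')
model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
```

The same pattern underlies the refactored `GRU` and `LSTM` layers above: each is now a thin `RNN` subclass that owns a `GRUCell`/`LSTMCell` and forwards the cell's hyperparameters through read-only properties.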
@@ -439,832 +2503,3 @@ class Recurrent(Layer): } base_config = super(Recurrent, self).get_config() return dict(list(base_config.items()) + list(config.items())) - - -class SimpleRNN(Recurrent): - """Fully-connected RNN where the output is to be fed back to input. - - Arguments: - units: Positive integer, dimensionality of the output space. - activation: Activation function to use. - If you don't specify anything, no activation is applied - If you pass None, no activation is applied - (ie. "linear" activation: `a(x) = x`). - use_bias: Boolean, whether the layer uses a bias vector. - kernel_initializer: Initializer for the `kernel` weights matrix, - used for the linear transformation of the inputs.. - recurrent_initializer: Initializer for the `recurrent_kernel` - weights matrix, - used for the linear transformation of the recurrent state.. - bias_initializer: Initializer for the bias vector. - kernel_regularizer: Regularizer function applied to - the `kernel` weights matrix. - recurrent_regularizer: Regularizer function applied to - the `recurrent_kernel` weights matrix. - bias_regularizer: Regularizer function applied to the bias vector. - activity_regularizer: Regularizer function applied to - the output of the layer (its "activation").. - kernel_constraint: Constraint function applied to - the `kernel` weights matrix. - recurrent_constraint: Constraint function applied to - the `recurrent_kernel` weights matrix. - bias_constraint: Constraint function applied to the bias vector. - dropout: Float between 0 and 1. - Fraction of the units to drop for - the linear transformation of the inputs. - recurrent_dropout: Float between 0 and 1. - Fraction of the units to drop for - the linear transformation of the recurrent state. - - References: - - [A Theoretically Grounded Application of Dropout in Recurrent Neural - Networks](http://arxiv.org/abs/1512.05287) - """ - - def __init__(self, - units, - activation='tanh', - use_bias=True, - kernel_initializer='glorot_uniform', - recurrent_initializer='orthogonal', - bias_initializer='zeros', - kernel_regularizer=None, - recurrent_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - recurrent_constraint=None, - bias_constraint=None, - dropout=0., - recurrent_dropout=0., - **kwargs): - super(SimpleRNN, self).__init__( - activity_regularizer=regularizers.get(activity_regularizer), **kwargs) - self.units = units - self.activation = activations.get(activation) - self.use_bias = use_bias - - self.kernel_initializer = initializers.get(kernel_initializer) - self.recurrent_initializer = initializers.get(recurrent_initializer) - self.bias_initializer = initializers.get(bias_initializer) - - self.kernel_regularizer = regularizers.get(kernel_regularizer) - self.recurrent_regularizer = regularizers.get(recurrent_regularizer) - self.bias_regularizer = regularizers.get(bias_regularizer) - - self.kernel_constraint = constraints.get(kernel_constraint) - self.recurrent_constraint = constraints.get(recurrent_constraint) - self.bias_constraint = constraints.get(bias_constraint) - - self.dropout = min(1., max(0., dropout)) - self.recurrent_dropout = min(1., max(0., recurrent_dropout)) - self.state_spec = InputSpec(shape=(None, self.units)) - - def build(self, input_shape): - if isinstance(input_shape, list): - input_shape = input_shape[0] - input_shape = tensor_shape.TensorShape(input_shape).as_list() - - batch_size = input_shape[0] if self.stateful else None - self.input_dim = input_shape[2] - self.input_spec[0] = 
InputSpec(shape=(batch_size, None, self.input_dim)) - - self.states = [None] - if self.stateful: - self.reset_states() - - self.kernel = self.add_weight( - shape=(self.input_dim, self.units), - name='kernel', - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint) - self.recurrent_kernel = self.add_weight( - shape=(self.units, self.units), - name='recurrent_kernel', - initializer=self.recurrent_initializer, - regularizer=self.recurrent_regularizer, - constraint=self.recurrent_constraint) - if self.use_bias: - self.bias = self.add_weight( - shape=(self.units,), - name='bias', - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, - constraint=self.bias_constraint) - else: - self.bias = None - self.built = True - - def preprocess_input(self, inputs, training=None): - if self.implementation > 0: - return inputs - else: - input_shape = inputs.get_shape().as_list() - input_dim = input_shape[2] - timesteps = input_shape[1] - return _time_distributed_dense( - inputs, - self.kernel, - self.bias, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - - def step(self, inputs, states): - if self.implementation == 0: - h = inputs - else: - if 0 < self.dropout < 1: - h = K.dot(inputs * states[1], self.kernel) - else: - h = K.dot(inputs, self.kernel) - if self.bias is not None: - h = K.bias_add(h, self.bias) - - prev_output = states[0] - if 0 < self.recurrent_dropout < 1: - prev_output *= states[2] - output = h + K.dot(prev_output, self.recurrent_kernel) - if self.activation is not None: - output = self.activation(output) - - # Properly set learning phase on output tensor. - if 0 < self.dropout + self.recurrent_dropout: - output._uses_learning_phase = True - return output, [output] - - def get_constants(self, inputs, training=None): - constants = [] - if self.implementation != 0 and 0 < self.dropout < 1: - input_shape = K.int_shape(inputs) - input_dim = input_shape[-1] - ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) - ones = K.tile(ones, (1, int(input_dim))) - - def dropped_inputs(): - return K.dropout(ones, self.dropout) - - dp_mask = K.in_train_phase(dropped_inputs, ones, training=training) - constants.append(dp_mask) - else: - constants.append(K.cast_to_floatx(1.)) - - if 0 < self.recurrent_dropout < 1: - ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) - ones = K.tile(ones, (1, self.units)) - - def dropped_inputs(): # pylint: disable=function-redefined - return K.dropout(ones, self.recurrent_dropout) - - rec_dp_mask = K.in_train_phase(dropped_inputs, ones, training=training) - constants.append(rec_dp_mask) - else: - constants.append(K.cast_to_floatx(1.)) - return constants - - def get_config(self): - config = { - 'units': self.units, - 'activation': activations.serialize(self.activation), - 'use_bias': self.use_bias, - 'kernel_initializer': initializers.serialize(self.kernel_initializer), - 'recurrent_initializer': - initializers.serialize(self.recurrent_initializer), - 'bias_initializer': initializers.serialize(self.bias_initializer), - 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), - 'recurrent_regularizer': - regularizers.serialize(self.recurrent_regularizer), - 'bias_regularizer': regularizers.serialize(self.bias_regularizer), - 'activity_regularizer': - regularizers.serialize(self.activity_regularizer), - 'kernel_constraint': constraints.serialize(self.kernel_constraint), - 'recurrent_constraint': - constraints.serialize(self.recurrent_constraint), - 
'bias_constraint': constraints.serialize(self.bias_constraint), - 'dropout': self.dropout, - 'recurrent_dropout': self.recurrent_dropout - } - base_config = super(SimpleRNN, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - -class GRU(Recurrent): - """Gated Recurrent Unit - Cho et al. - - 2014. - - Arguments: - units: Positive integer, dimensionality of the output space. - activation: Activation function to use. - If you pass None, no activation is applied - (ie. "linear" activation: `a(x) = x`). - recurrent_activation: Activation function to use - for the recurrent step. - use_bias: Boolean, whether the layer uses a bias vector. - kernel_initializer: Initializer for the `kernel` weights matrix, - used for the linear transformation of the inputs.. - recurrent_initializer: Initializer for the `recurrent_kernel` - weights matrix, - used for the linear transformation of the recurrent state.. - bias_initializer: Initializer for the bias vector. - kernel_regularizer: Regularizer function applied to - the `kernel` weights matrix. - recurrent_regularizer: Regularizer function applied to - the `recurrent_kernel` weights matrix. - bias_regularizer: Regularizer function applied to the bias vector. - activity_regularizer: Regularizer function applied to - the output of the layer (its "activation").. - kernel_constraint: Constraint function applied to - the `kernel` weights matrix. - recurrent_constraint: Constraint function applied to - the `recurrent_kernel` weights matrix. - bias_constraint: Constraint function applied to the bias vector. - dropout: Float between 0 and 1. - Fraction of the units to drop for - the linear transformation of the inputs. - recurrent_dropout: Float between 0 and 1. - Fraction of the units to drop for - the linear transformation of the recurrent state. 
- - References: - - [On the Properties of Neural Machine Translation: Encoder-Decoder - Approaches](https://arxiv.org/abs/1409.1259) - - [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence - Modeling](http://arxiv.org/abs/1412.3555v1) - - [A Theoretically Grounded Application of Dropout in Recurrent Neural - Networks](http://arxiv.org/abs/1512.05287) - """ - - def __init__(self, - units, - activation='tanh', - recurrent_activation='hard_sigmoid', - use_bias=True, - kernel_initializer='glorot_uniform', - recurrent_initializer='orthogonal', - bias_initializer='zeros', - kernel_regularizer=None, - recurrent_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - recurrent_constraint=None, - bias_constraint=None, - dropout=0., - recurrent_dropout=0., - **kwargs): - super(GRU, self).__init__( - activity_regularizer=regularizers.get(activity_regularizer), **kwargs) - self.units = units - self.activation = activations.get(activation) - self.recurrent_activation = activations.get(recurrent_activation) - self.use_bias = use_bias - - self.kernel_initializer = initializers.get(kernel_initializer) - self.recurrent_initializer = initializers.get(recurrent_initializer) - self.bias_initializer = initializers.get(bias_initializer) - - self.kernel_regularizer = regularizers.get(kernel_regularizer) - self.recurrent_regularizer = regularizers.get(recurrent_regularizer) - self.bias_regularizer = regularizers.get(bias_regularizer) - - self.kernel_constraint = constraints.get(kernel_constraint) - self.recurrent_constraint = constraints.get(recurrent_constraint) - self.bias_constraint = constraints.get(bias_constraint) - - self.dropout = min(1., max(0., dropout)) - self.recurrent_dropout = min(1., max(0., recurrent_dropout)) - self.state_spec = InputSpec(shape=(None, self.units)) - - def build(self, input_shape): - if isinstance(input_shape, list): - input_shape = input_shape[0] - input_shape = tensor_shape.TensorShape(input_shape).as_list() - batch_size = input_shape[0] if self.stateful else None - self.input_dim = input_shape[2] - self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim)) - - self.states = [None] - if self.stateful: - self.reset_states() - - self.kernel = self.add_weight( - shape=(self.input_dim, self.units * 3), - name='kernel', - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint) - self.recurrent_kernel = self.add_weight( - shape=(self.units, self.units * 3), - name='recurrent_kernel', - initializer=self.recurrent_initializer, - regularizer=self.recurrent_regularizer, - constraint=self.recurrent_constraint) - - if self.use_bias: - self.bias = self.add_weight( - shape=(self.units * 3,), - name='bias', - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, - constraint=self.bias_constraint) - else: - self.bias = None - - self.kernel_z = self.kernel[:, :self.units] - self.recurrent_kernel_z = self.recurrent_kernel[:, :self.units] - self.kernel_r = self.kernel[:, self.units:self.units * 2] - self.recurrent_kernel_r = self.recurrent_kernel[:, self.units: - self.units * 2] - self.kernel_h = self.kernel[:, self.units * 2:] - self.recurrent_kernel_h = self.recurrent_kernel[:, self.units * 2:] - - if self.use_bias: - self.bias_z = self.bias[:self.units] - self.bias_r = self.bias[self.units:self.units * 2] - self.bias_h = self.bias[self.units * 2:] - else: - self.bias_z = None - self.bias_r = None - self.bias_h = None - self.built = True - - def 
preprocess_input(self, inputs, training=None): - if self.implementation == 0: - input_shape = inputs.get_shape().as_list() - input_dim = input_shape[2] - timesteps = input_shape[1] - - x_z = _time_distributed_dense( - inputs, - self.kernel_z, - self.bias_z, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - x_r = _time_distributed_dense( - inputs, - self.kernel_r, - self.bias_r, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - x_h = _time_distributed_dense( - inputs, - self.kernel_h, - self.bias_h, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - return K.concatenate([x_z, x_r, x_h], axis=2) - else: - return inputs - - def get_constants(self, inputs, training=None): - constants = [] - if self.implementation != 0 and 0 < self.dropout < 1: - input_shape = K.int_shape(inputs) - input_dim = input_shape[-1] - ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) - ones = K.tile(ones, (1, int(input_dim))) - - def dropped_inputs(): - return K.dropout(ones, self.dropout) - - dp_mask = [ - K.in_train_phase(dropped_inputs, ones, training=training) - for _ in range(3) - ] - constants.append(dp_mask) - else: - constants.append([K.cast_to_floatx(1.) for _ in range(3)]) - - if 0 < self.recurrent_dropout < 1: - ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) - ones = K.tile(ones, (1, self.units)) - - def dropped_inputs(): # pylint: disable=function-redefined - return K.dropout(ones, self.recurrent_dropout) - - rec_dp_mask = [ - K.in_train_phase(dropped_inputs, ones, training=training) - for _ in range(3) - ] - constants.append(rec_dp_mask) - else: - constants.append([K.cast_to_floatx(1.) for _ in range(3)]) - return constants - - def step(self, inputs, states): - h_tm1 = states[0] # previous memory - dp_mask = states[1] # dropout matrices for recurrent units - rec_dp_mask = states[2] - - if self.implementation == 2: - matrix_x = K.dot(inputs * dp_mask[0], self.kernel) - if self.use_bias: - matrix_x = K.bias_add(matrix_x, self.bias) - matrix_inner = K.dot(h_tm1 * rec_dp_mask[0], - self.recurrent_kernel[:, :2 * self.units]) - - x_z = matrix_x[:, :self.units] - x_r = matrix_x[:, self.units:2 * self.units] - recurrent_z = matrix_inner[:, :self.units] - recurrent_r = matrix_inner[:, self.units:2 * self.units] - - z = self.recurrent_activation(x_z + recurrent_z) - r = self.recurrent_activation(x_r + recurrent_r) - - x_h = matrix_x[:, 2 * self.units:] - recurrent_h = K.dot(r * h_tm1 * rec_dp_mask[0], - self.recurrent_kernel[:, 2 * self.units:]) - hh = self.activation(x_h + recurrent_h) - else: - if self.implementation == 0: - x_z = inputs[:, :self.units] - x_r = inputs[:, self.units:2 * self.units] - x_h = inputs[:, 2 * self.units:] - elif self.implementation == 1: - x_z = K.dot(inputs * dp_mask[0], self.kernel_z) - x_r = K.dot(inputs * dp_mask[1], self.kernel_r) - x_h = K.dot(inputs * dp_mask[2], self.kernel_h) - if self.use_bias: - x_z = K.bias_add(x_z, self.bias_z) - x_r = K.bias_add(x_r, self.bias_r) - x_h = K.bias_add(x_h, self.bias_h) - else: - raise ValueError('Unknown `implementation` mode.') - z = self.recurrent_activation(x_z + K.dot(h_tm1 * rec_dp_mask[0], - self.recurrent_kernel_z)) - r = self.recurrent_activation(x_r + K.dot(h_tm1 * rec_dp_mask[1], - self.recurrent_kernel_r)) - - hh = self.activation(x_h + K.dot(r * h_tm1 * rec_dp_mask[2], - self.recurrent_kernel_h)) - h = z * h_tm1 + (1 - z) * hh - if 0 < self.dropout + self.recurrent_dropout: - h._uses_learning_phase = True - return h, [h] - - def 
get_config(self): - config = { - 'units': self.units, - 'activation': activations.serialize(self.activation), - 'recurrent_activation': - activations.serialize(self.recurrent_activation), - 'use_bias': self.use_bias, - 'kernel_initializer': initializers.serialize(self.kernel_initializer), - 'recurrent_initializer': - initializers.serialize(self.recurrent_initializer), - 'bias_initializer': initializers.serialize(self.bias_initializer), - 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), - 'recurrent_regularizer': - regularizers.serialize(self.recurrent_regularizer), - 'bias_regularizer': regularizers.serialize(self.bias_regularizer), - 'activity_regularizer': - regularizers.serialize(self.activity_regularizer), - 'kernel_constraint': constraints.serialize(self.kernel_constraint), - 'recurrent_constraint': - constraints.serialize(self.recurrent_constraint), - 'bias_constraint': constraints.serialize(self.bias_constraint), - 'dropout': self.dropout, - 'recurrent_dropout': self.recurrent_dropout - } - base_config = super(GRU, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - -class LSTM(Recurrent): - """Long-Short Term Memory unit - Hochreiter 1997. - - For a step-by-step description of the algorithm, see - [this tutorial](http://deeplearning.net/tutorial/lstm.html). - - Arguments: - units: Positive integer, dimensionality of the output space. - activation: Activation function to use. - If you pass None, no activation is applied - (ie. "linear" activation: `a(x) = x`). - recurrent_activation: Activation function to use - for the recurrent step. - use_bias: Boolean, whether the layer uses a bias vector. - kernel_initializer: Initializer for the `kernel` weights matrix, - used for the linear transformation of the inputs.. - recurrent_initializer: Initializer for the `recurrent_kernel` - weights matrix, - used for the linear transformation of the recurrent state.. - bias_initializer: Initializer for the bias vector. - unit_forget_bias: Boolean. - If True, add 1 to the bias of the forget gate at initialization. - Setting it to true will also force `bias_initializer="zeros"`. - This is recommended in [Jozefowicz et - al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) - kernel_regularizer: Regularizer function applied to - the `kernel` weights matrix. - recurrent_regularizer: Regularizer function applied to - the `recurrent_kernel` weights matrix. - bias_regularizer: Regularizer function applied to the bias vector. - activity_regularizer: Regularizer function applied to - the output of the layer (its "activation").. - kernel_constraint: Constraint function applied to - the `kernel` weights matrix. - recurrent_constraint: Constraint function applied to - the `recurrent_kernel` weights matrix. - bias_constraint: Constraint function applied to the bias vector. - dropout: Float between 0 and 1. - Fraction of the units to drop for - the linear transformation of the inputs. - recurrent_dropout: Float between 0 and 1. - Fraction of the units to drop for - the linear transformation of the recurrent state. 
- - References: - - [Long short-term - memory]((http://www.bioinf.jku.at/publications/older/2604.pdf) - (original 1997 paper) - - [Supervised sequence labeling with recurrent neural - networks](http://www.cs.toronto.edu/~graves/preprint.pdf) - - [A Theoretically Grounded Application of Dropout in Recurrent Neural - Networks](http://arxiv.org/abs/1512.05287) - """ - - def __init__(self, - units, - activation='tanh', - recurrent_activation='hard_sigmoid', - use_bias=True, - kernel_initializer='glorot_uniform', - recurrent_initializer='orthogonal', - bias_initializer='zeros', - unit_forget_bias=True, - kernel_regularizer=None, - recurrent_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - recurrent_constraint=None, - bias_constraint=None, - dropout=0., - recurrent_dropout=0., - **kwargs): - super(LSTM, self).__init__( - activity_regularizer=regularizers.get(activity_regularizer), **kwargs) - self.units = units - self.activation = activations.get(activation) - self.recurrent_activation = activations.get(recurrent_activation) - self.use_bias = use_bias - - self.kernel_initializer = initializers.get(kernel_initializer) - self.recurrent_initializer = initializers.get(recurrent_initializer) - self.bias_initializer = initializers.get(bias_initializer) - self.unit_forget_bias = unit_forget_bias - - self.kernel_regularizer = regularizers.get(kernel_regularizer) - self.recurrent_regularizer = regularizers.get(recurrent_regularizer) - self.bias_regularizer = regularizers.get(bias_regularizer) - - self.kernel_constraint = constraints.get(kernel_constraint) - self.recurrent_constraint = constraints.get(recurrent_constraint) - self.bias_constraint = constraints.get(bias_constraint) - - self.dropout = min(1., max(0., dropout)) - self.recurrent_dropout = min(1., max(0., recurrent_dropout)) - self.state_spec = [ - InputSpec(shape=(None, self.units)), - InputSpec(shape=(None, self.units)) - ] - - def build(self, input_shape): - if isinstance(input_shape, list): - input_shape = input_shape[0] - input_shape = tensor_shape.TensorShape(input_shape).as_list() - batch_size = input_shape[0] if self.stateful else None - self.input_dim = input_shape[2] - self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim)) - - self.states = [None, None] - if self.stateful: - self.reset_states() - - self.kernel = self.add_weight( - shape=(self.input_dim, self.units * 4), - name='kernel', - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint) - self.recurrent_kernel = self.add_weight( - shape=(self.units, self.units * 4), - name='recurrent_kernel', - initializer=self.recurrent_initializer, - regularizer=self.recurrent_regularizer, - constraint=self.recurrent_constraint) - - if self.use_bias: - if self.unit_forget_bias: - - def bias_initializer(_, *args, **kwargs): - return K.concatenate([ - self.bias_initializer((self.units,), *args, **kwargs), - initializers.Ones()((self.units,), *args, **kwargs), - self.bias_initializer((self.units * 2,), *args, **kwargs), - ]) - else: - bias_initializer = self.bias_initializer - self.bias = self.add_weight( - shape=(self.units * 4,), - name='bias', - initializer=bias_initializer, - regularizer=self.bias_regularizer, - constraint=self.bias_constraint) - else: - self.bias = None - - self.kernel_i = self.kernel[:, :self.units] - self.kernel_f = self.kernel[:, self.units:self.units * 2] - self.kernel_c = self.kernel[:, self.units * 2:self.units * 3] - self.kernel_o = 
self.kernel[:, self.units * 3:] - - self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units] - self.recurrent_kernel_f = self.recurrent_kernel[:, self.units: - self.units * 2] - self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2: - self.units * 3] - self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:] - - if self.use_bias: - self.bias_i = self.bias[:self.units] - self.bias_f = self.bias[self.units:self.units * 2] - self.bias_c = self.bias[self.units * 2:self.units * 3] - self.bias_o = self.bias[self.units * 3:] - else: - self.bias_i = None - self.bias_f = None - self.bias_c = None - self.bias_o = None - self.built = True - - def preprocess_input(self, inputs, training=None): - if self.implementation == 0: - input_shape = inputs.get_shape().as_list() - input_dim = input_shape[2] - timesteps = input_shape[1] - - x_i = _time_distributed_dense( - inputs, - self.kernel_i, - self.bias_i, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - x_f = _time_distributed_dense( - inputs, - self.kernel_f, - self.bias_f, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - x_c = _time_distributed_dense( - inputs, - self.kernel_c, - self.bias_c, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - x_o = _time_distributed_dense( - inputs, - self.kernel_o, - self.bias_o, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - return K.concatenate([x_i, x_f, x_c, x_o], axis=2) - else: - return inputs - - def get_constants(self, inputs, training=None): - constants = [] - if self.implementation != 0 and 0 < self.dropout < 1: - input_shape = K.int_shape(inputs) - input_dim = input_shape[-1] - ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) - ones = K.tile(ones, (1, int(input_dim))) - - def dropped_inputs(): - return K.dropout(ones, self.dropout) - - dp_mask = [ - K.in_train_phase(dropped_inputs, ones, training=training) - for _ in range(4) - ] - constants.append(dp_mask) - else: - constants.append([K.cast_to_floatx(1.) for _ in range(4)]) - - if 0 < self.recurrent_dropout < 1: - ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) - ones = K.tile(ones, (1, self.units)) - - def dropped_inputs(): # pylint: disable=function-redefined - return K.dropout(ones, self.recurrent_dropout) - - rec_dp_mask = [ - K.in_train_phase(dropped_inputs, ones, training=training) - for _ in range(4) - ] - constants.append(rec_dp_mask) - else: - constants.append([K.cast_to_floatx(1.) 
for _ in range(4)]) - return constants - - def step(self, inputs, states): - h_tm1 = states[0] - c_tm1 = states[1] - dp_mask = states[2] - rec_dp_mask = states[3] - - if self.implementation == 2: - z = K.dot(inputs * dp_mask[0], self.kernel) - z += K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel) - if self.use_bias: - z = K.bias_add(z, self.bias) - - z0 = z[:, :self.units] - z1 = z[:, self.units:2 * self.units] - z2 = z[:, 2 * self.units:3 * self.units] - z3 = z[:, 3 * self.units:] - - i = self.recurrent_activation(z0) - f = self.recurrent_activation(z1) - c = f * c_tm1 + i * self.activation(z2) - o = self.recurrent_activation(z3) - else: - if self.implementation == 0: - x_i = inputs[:, :self.units] - x_f = inputs[:, self.units:2 * self.units] - x_c = inputs[:, 2 * self.units:3 * self.units] - x_o = inputs[:, 3 * self.units:] - elif self.implementation == 1: - x_i = K.dot(inputs * dp_mask[0], self.kernel_i) + self.bias_i - x_f = K.dot(inputs * dp_mask[1], self.kernel_f) + self.bias_f - x_c = K.dot(inputs * dp_mask[2], self.kernel_c) + self.bias_c - x_o = K.dot(inputs * dp_mask[3], self.kernel_o) + self.bias_o - else: - raise ValueError('Unknown `implementation` mode.') - - i = self.recurrent_activation(x_i + K.dot(h_tm1 * rec_dp_mask[0], - self.recurrent_kernel_i)) - f = self.recurrent_activation(x_f + K.dot(h_tm1 * rec_dp_mask[1], - self.recurrent_kernel_f)) - c = f * c_tm1 + i * self.activation( - x_c + K.dot(h_tm1 * rec_dp_mask[2], self.recurrent_kernel_c)) - o = self.recurrent_activation(x_o + K.dot(h_tm1 * rec_dp_mask[3], - self.recurrent_kernel_o)) - h = o * self.activation(c) - if 0 < self.dropout + self.recurrent_dropout: - h._uses_learning_phase = True - return h, [h, c] - - def get_config(self): - config = { - 'units': self.units, - 'activation': activations.serialize(self.activation), - 'recurrent_activation': - activations.serialize(self.recurrent_activation), - 'use_bias': self.use_bias, - 'kernel_initializer': initializers.serialize(self.kernel_initializer), - 'recurrent_initializer': - initializers.serialize(self.recurrent_initializer), - 'bias_initializer': initializers.serialize(self.bias_initializer), - 'unit_forget_bias': self.unit_forget_bias, - 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), - 'recurrent_regularizer': - regularizers.serialize(self.recurrent_regularizer), - 'bias_regularizer': regularizers.serialize(self.bias_regularizer), - 'activity_regularizer': - regularizers.serialize(self.activity_regularizer), - 'kernel_constraint': constraints.serialize(self.kernel_constraint), - 'recurrent_constraint': - constraints.serialize(self.recurrent_constraint), - 'bias_constraint': constraints.serialize(self.bias_constraint), - 'dropout': self.dropout, - 'recurrent_dropout': self.recurrent_dropout - } - base_config = super(LSTM, self).get_config() - return dict(list(base_config.items()) + list(config.items())) diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py new file mode 100644 index 00000000000..b1f89a30bb3 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py @@ -0,0 +1,378 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for recurrent layers functionality other than GRU, LSTM, SimpleRNN. + +See also: lstm_test.py, gru_test.py, simplernn_test.py. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.keras._impl import keras +from tensorflow.python.platform import test + + +class RNNTest(test.TestCase): + + def test_minimal_rnn_cell_non_layer(self): + + class MinimalRNNCell(object): + + def __init__(self, units, input_dim): + self.units = units + self.state_size = units + self.kernel = keras.backend.variable( + np.random.random((input_dim, units))) + + def call(self, inputs, states): + prev_output = states[0] + output = keras.backend.dot(inputs, self.kernel) + prev_output + return output, [output] + + with self.test_session(): + # Basic test case. + cell = MinimalRNNCell(32, 5) + x = keras.Input((None, 5)) + layer = keras.layers.RNN(cell) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32))) + + # Test stacking. + cells = [MinimalRNNCell(8, 5), + MinimalRNNCell(32, 8), + MinimalRNNCell(32, 32)] + layer = keras.layers.RNN(cells) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32))) + + def test_minimal_rnn_cell_non_layer_multiple_states(self): + + class MinimalRNNCell(object): + + def __init__(self, units, input_dim): + self.units = units + self.state_size = (units, units) + self.kernel = keras.backend.variable( + np.random.random((input_dim, units))) + + def call(self, inputs, states): + prev_output_1 = states[0] + prev_output_2 = states[1] + output = keras.backend.dot(inputs, self.kernel) + output += prev_output_1 + output -= prev_output_2 + return output, [output * 2, output * 3] + + with self.test_session(): + # Basic test case. + cell = MinimalRNNCell(32, 5) + x = keras.Input((None, 5)) + layer = keras.layers.RNN(cell) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32))) + + # Test stacking. 
+ cells = [MinimalRNNCell(8, 5), + MinimalRNNCell(16, 8), + MinimalRNNCell(32, 16)] + layer = keras.layers.RNN(cells) + assert layer.cell.state_size == (32, 32, 16, 16, 8, 8) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32))) + + def test_minimal_rnn_cell_layer(self): + + class MinimalRNNCell(keras.layers.Layer): + + def __init__(self, units, **kwargs): + self.units = units + self.state_size = units + super(MinimalRNNCell, self).__init__(**kwargs) + + def build(self, input_shape): + self.kernel = self.add_weight(shape=(input_shape[-1], self.units), + initializer='uniform', + name='kernel') + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + initializer='uniform', + name='recurrent_kernel') + self.built = True + + def call(self, inputs, states): + prev_output = states[0] + h = keras.backend.dot(inputs, self.kernel) + output = h + keras.backend.dot(prev_output, self.recurrent_kernel) + return output, [output] + + def get_config(self): + config = {'units': self.units} + base_config = super(MinimalRNNCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + with self.test_session(): + # Test basic case. + x = keras.Input((None, 5)) + cell = MinimalRNNCell(32) + layer = keras.layers.RNN(cell) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32))) + + # Test basic case serialization. + x_np = np.random.random((6, 5, 5)) + y_np = model.predict(x_np) + weights = model.get_weights() + config = layer.get_config() + with keras.utils.CustomObjectScope({'MinimalRNNCell': MinimalRNNCell}): + layer = keras.layers.RNN.from_config(config) + y = layer(x) + model = keras.models.Model(x, y) + model.set_weights(weights) + y_np_2 = model.predict(x_np) + self.assertAllClose(y_np, y_np_2, atol=1e-4) + + # Test stacking. + cells = [MinimalRNNCell(8), + MinimalRNNCell(12), + MinimalRNNCell(32)] + layer = keras.layers.RNN(cells) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32))) + + # Test stacked RNN serialization. 
+ x_np = np.random.random((6, 5, 5)) + y_np = model.predict(x_np) + weights = model.get_weights() + config = layer.get_config() + with keras.utils.CustomObjectScope({'MinimalRNNCell': MinimalRNNCell}): + layer = keras.layers.RNN.from_config(config) + y = layer(x) + model = keras.models.Model(x, y) + model.set_weights(weights) + y_np_2 = model.predict(x_np) + self.assertAllClose(y_np, y_np_2, atol=1e-4) + + def test_rnn_cell_with_constants_layer(self): + + class RNNCellWithConstants(keras.layers.Layer): + + def __init__(self, units, **kwargs): + self.units = units + self.state_size = units + super(RNNCellWithConstants, self).__init__(**kwargs) + + def build(self, input_shape): + if not isinstance(input_shape, list): + raise TypeError('expects constants shape') + [input_shape, constant_shape] = input_shape + # will (and should) raise if more than one constant passed + + self.input_kernel = self.add_weight( + shape=(input_shape[-1], self.units), + initializer='uniform', + name='kernel') + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + initializer='uniform', + name='recurrent_kernel') + self.constant_kernel = self.add_weight( + shape=(constant_shape[-1], self.units), + initializer='uniform', + name='constant_kernel') + self.built = True + + def call(self, inputs, states, constants): + [prev_output] = states + [constant] = constants + h_input = keras.backend.dot(inputs, self.input_kernel) + h_state = keras.backend.dot(prev_output, self.recurrent_kernel) + h_const = keras.backend.dot(constant, self.constant_kernel) + output = h_input + h_state + h_const + return output, [output] + + def get_config(self): + config = {'units': self.units} + base_config = super(RNNCellWithConstants, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + with self.test_session(): + # Test basic case. + x = keras.Input((None, 5)) + c = keras.Input((3,)) + cell = RNNCellWithConstants(32) + layer = keras.layers.RNN(cell) + y = layer(x, constants=c) + model = keras.models.Model([x, c], y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch( + [np.zeros((6, 5, 5)), np.zeros((6, 3))], + np.zeros((6, 32)) + ) + + with self.test_session(): + # Test basic case serialization. 
+ x_np = np.random.random((6, 5, 5)) + c_np = np.random.random((6, 3)) + y_np = model.predict([x_np, c_np]) + weights = model.get_weights() + config = layer.get_config() + custom_objects = {'RNNCellWithConstants': RNNCellWithConstants} + with keras.utils.CustomObjectScope(custom_objects): + layer = keras.layers.RNN.from_config(config.copy()) + y = layer(x, constants=c) + model = keras.models.Model([x, c], y) + model.set_weights(weights) + y_np_2 = model.predict([x_np, c_np]) + self.assertAllClose(y_np, y_np_2, atol=1e-4) + + with self.test_session(): + # test flat list inputs + with keras.utils.CustomObjectScope(custom_objects): + layer = keras.layers.RNN.from_config(config.copy()) + y = layer([x, c]) + model = keras.models.Model([x, c], y) + model.set_weights(weights) + y_np_3 = model.predict([x_np, c_np]) + self.assertAllClose(y_np, y_np_3, atol=1e-4) + + def test_rnn_cell_with_constants_layer_passing_initial_state(self): + + class RNNCellWithConstants(keras.layers.Layer): + + def __init__(self, units, **kwargs): + self.units = units + self.state_size = units + super(RNNCellWithConstants, self).__init__(**kwargs) + + def build(self, input_shape): + if not isinstance(input_shape, list): + raise TypeError('expects constants shape') + [input_shape, constant_shape] = input_shape + # will (and should) raise if more than one constant passed + + self.input_kernel = self.add_weight( + shape=(input_shape[-1], self.units), + initializer='uniform', + name='kernel') + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + initializer='uniform', + name='recurrent_kernel') + self.constant_kernel = self.add_weight( + shape=(constant_shape[-1], self.units), + initializer='uniform', + name='constant_kernel') + self.built = True + + def call(self, inputs, states, constants): + [prev_output] = states + [constant] = constants + h_input = keras.backend.dot(inputs, self.input_kernel) + h_state = keras.backend.dot(prev_output, self.recurrent_kernel) + h_const = keras.backend.dot(constant, self.constant_kernel) + output = h_input + h_state + h_const + return output, [output] + + def get_config(self): + config = {'units': self.units} + base_config = super(RNNCellWithConstants, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + with self.test_session(): + # Test basic case. + x = keras.Input((None, 5)) + c = keras.Input((3,)) + s = keras.Input((32,)) + cell = RNNCellWithConstants(32) + layer = keras.layers.RNN(cell) + y = layer(x, initial_state=s, constants=c) + model = keras.models.Model([x, s, c], y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch( + [np.zeros((6, 5, 5)), np.zeros((6, 32)), np.zeros((6, 3))], + np.zeros((6, 32)) + ) + + with self.test_session(): + # Test basic case serialization. 
+      x_np = np.random.random((6, 5, 5))
+      s_np = np.random.random((6, 32))
+      c_np = np.random.random((6, 3))
+      y_np = model.predict([x_np, s_np, c_np])
+      weights = model.get_weights()
+      config = layer.get_config()
+      custom_objects = {'RNNCellWithConstants': RNNCellWithConstants}
+      with keras.utils.CustomObjectScope(custom_objects):
+        layer = keras.layers.RNN.from_config(config.copy())
+      y = layer(x, initial_state=s, constants=c)
+      model = keras.models.Model([x, s, c], y)
+      model.set_weights(weights)
+      y_np_2 = model.predict([x_np, s_np, c_np])
+      self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+      # Verify that the initial state is actually used: perturbing it must
+      # change the output.
+      y_np_2_different_s = model.predict([x_np, s_np + 10., c_np])
+      with self.assertRaises(AssertionError):
+        self.assertAllClose(y_np, y_np_2_different_s, atol=1e-4)
+
+    with self.test_session():
+      # Test flat list inputs.
+      with keras.utils.CustomObjectScope(custom_objects):
+        layer = keras.layers.RNN.from_config(config.copy())
+      y = layer([x, s, c])
+      model = keras.models.Model([x, s, c], y)
+      model.set_weights(weights)
+      y_np_3 = model.predict([x_np, s_np, c_np])
+      self.assertAllClose(y_np, y_np_3, atol=1e-4)
+
+  def test_stacked_rnn_attributes(self):
+    cells = [keras.layers.LSTMCell(3),
+             keras.layers.LSTMCell(3, kernel_regularizer='l2')]
+    layer = keras.layers.RNN(cells)
+    layer.build((None, None, 5))
+
+    # Test regularization losses
+    assert len(layer.losses) == 1
+
+    # Test weights
+    assert len(layer.trainable_weights) == 6
+    cells[0].trainable = False
+    assert len(layer.trainable_weights) == 3
+    assert len(layer.non_trainable_weights) == 3
+
+    # Test `get_losses_for`
+    x = keras.Input((None, 5))
+    y = keras.backend.sum(x)
+    cells[0].add_loss(y, inputs=x)
+    assert layer.get_losses_for(x) == [y]
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py b/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
index 9833485236b..7edebdacd07 100644
--- a/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
@@ -156,8 +156,10 @@ class SimpleRNNLayerTest(test.TestCase):
         activity_regularizer='l1')
     layer.build((None, None, 2))
     self.assertEqual(len(layer.losses), 3)
-    layer(keras.backend.variable(np.ones((2, 3, 2))))
-    self.assertEqual(len(layer.losses), 4)
+
+    x = keras.backend.variable(np.ones((2, 3, 2)))
+    layer(x)
+    self.assertEqual(len(layer.get_losses_for(x)), 1)

   def test_constraints_SimpleRNN(self):
     embedding_dim = 4
@@ -175,9 +177,9 @@ class SimpleRNNLayerTest(test.TestCase):
         recurrent_constraint=r_constraint,
         bias_constraint=b_constraint)
     layer.build((None, None, embedding_dim))
-    self.assertEqual(layer.kernel.constraint, k_constraint)
-    self.assertEqual(layer.recurrent_kernel.constraint, r_constraint)
-    self.assertEqual(layer.bias.constraint, b_constraint)
+    self.assertEqual(layer.cell.kernel.constraint, k_constraint)
+    self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint)
+    self.assertEqual(layer.cell.bias.constraint, b_constraint)

   def test_with_masking_layer_SimpleRNN(self):
     layer_class = keras.layers.SimpleRNN
diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py
index acf0a5e1799..b94bf8f0f67 100644
--- a/tensorflow/python/keras/layers/__init__.py
+++ b/tensorflow/python/keras/layers/__init__.py
@@ -134,6 +134,11 @@ from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool2D
 from tensorflow.python.keras._impl.keras.layers.pooling import
GlobalMaxPool3D # Recurrent layers. +from tensorflow.python.keras._impl.keras.layers.recurrent import RNN +from tensorflow.python.keras._impl.keras.layers.recurrent import StackedRNNCells +from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNNCell +from tensorflow.python.keras._impl.keras.layers.recurrent import GRUCell +from tensorflow.python.keras._impl.keras.layers.recurrent import LSTMCell from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNN from tensorflow.python.keras._impl.keras.layers.recurrent import GRU from tensorflow.python.keras._impl.keras.layers.recurrent import LSTM diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 8c8d774b754..c71e8382e91 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -642,7 +642,7 @@ class Layer(object): for output in output_list: with ops.name_scope('ActivityRegularizer'): activity_regularization = self._activity_regularizer(output) - self.add_loss(activity_regularization) + self.add_loss(activity_regularization, inputs=inputs) if not in_deferred_mode: # TODO(fchollet): consider how masking will work with deferred mode. diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index 71eff2f9657..7ddfe37827d 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -574,6 +574,13 @@ class BaseLayerTest(test.TestCase): self.assertEqual(3, result['label'].numpy()) self.assertEqual(4.0, result['logits'].numpy()) + def testActivityRegularizer(self): + regularizer = math_ops.reduce_sum + layer = base_layers.Layer(activity_regularizer=regularizer) + x = array_ops.placeholder('int32') + layer.apply(x) + self.assertEqual(len(layer.get_losses_for(x)), 1) + class NetworkTest(test.TestCase): diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt new file mode 100644 index 00000000000..763184899ca --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -0,0 +1,179 @@ +path: "tensorflow.keras.layers.GRUCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', 
\'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git 
a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt index 92373992548..889f2cbc234 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt @@ -1,14 +1,34 @@ path: "tensorflow.keras.layers.GRU" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" + member { + name: "activation" + mtype: "" + } member { name: "activity_regularizer" mtype: "" } + member { + name: "bias_constraint" + mtype: "" + } + member { + name: "bias_initializer" + mtype: "" + } + member { + name: "bias_regularizer" + mtype: "" + } + member { + name: "dropout" + mtype: "" + } member { name: "dtype" mtype: "" @@ -17,6 +37,10 @@ tf_class { name: "graph" mtype: "" } + member { + name: "implementation" + mtype: "" + } member { name: "inbound_nodes" mtype: "" @@ -33,6 +57,18 @@ tf_class { name: "input_shape" mtype: "" } + member { + name: "kernel_constraint" + mtype: "" + } + member { + name: "kernel_initializer" + mtype: "" + } + member { + name: "kernel_regularizer" + mtype: "" + } member { name: "losses" mtype: "" @@ -65,10 +101,34 @@ tf_class { name: "output_shape" mtype: "" } + member { + name: "recurrent_activation" + mtype: "" + } + member { + name: "recurrent_constraint" + mtype: "" + } + member { + name: "recurrent_dropout" + mtype: "" + } + member { + name: "recurrent_initializer" + mtype: "" + } + member { + name: "recurrent_regularizer" + mtype: "" + } member { name: "scope_name" mtype: "" } + member { + name: "states" + mtype: "" + } member { name: "trainable_variables" mtype: "" @@ -77,10 +137,18 @@ tf_class { name: "trainable_weights" mtype: "" } + member { + name: "units" + mtype: "" + } member { name: "updates" mtype: "" } + member { + name: "use_bias" + mtype: "" + } member { name: "variables" mtype: "" @@ -91,7 +159,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], " + argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\'], " } member_method { name: "add_loss" @@ -137,10 +205,6 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: 
"get_constants" - argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -159,7 +223,7 @@ tf_class { } member_method { name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "get_output_at" @@ -181,10 +245,6 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "preprocess_input" - argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "reset_states" argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -193,8 +253,4 @@ tf_class { name: "set_weights" argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "step" - argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None" - } } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt new file mode 100644 index 00000000000..4ce7c34f6c7 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -0,0 +1,179 @@ +path: "tensorflow.keras.layers.LSTMCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: 
"args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt index 20935e2f99a..e1a1d0d58ec 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt @@ -1,14 +1,34 @@ path: "tensorflow.keras.layers.LSTM" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" + member { + name: "activation" + mtype: "" + } member { name: "activity_regularizer" mtype: "" } + member { + name: "bias_constraint" + 
mtype: "" + } + member { + name: "bias_initializer" + mtype: "" + } + member { + name: "bias_regularizer" + mtype: "" + } + member { + name: "dropout" + mtype: "" + } member { name: "dtype" mtype: "" @@ -17,6 +37,10 @@ tf_class { name: "graph" mtype: "" } + member { + name: "implementation" + mtype: "" + } member { name: "inbound_nodes" mtype: "" @@ -33,6 +57,18 @@ tf_class { name: "input_shape" mtype: "" } + member { + name: "kernel_constraint" + mtype: "" + } + member { + name: "kernel_initializer" + mtype: "" + } + member { + name: "kernel_regularizer" + mtype: "" + } member { name: "losses" mtype: "" @@ -65,10 +101,34 @@ tf_class { name: "output_shape" mtype: "" } + member { + name: "recurrent_activation" + mtype: "" + } + member { + name: "recurrent_constraint" + mtype: "" + } + member { + name: "recurrent_dropout" + mtype: "" + } + member { + name: "recurrent_initializer" + mtype: "" + } + member { + name: "recurrent_regularizer" + mtype: "" + } member { name: "scope_name" mtype: "" } + member { + name: "states" + mtype: "" + } member { name: "trainable_variables" mtype: "" @@ -77,10 +137,22 @@ tf_class { name: "trainable_weights" mtype: "" } + member { + name: "unit_forget_bias" + mtype: "" + } + member { + name: "units" + mtype: "" + } member { name: "updates" mtype: "" } + member { + name: "use_bias" + mtype: "" + } member { name: "variables" mtype: "" @@ -91,7 +163,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], " + argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\'], " } member_method { name: "add_loss" @@ -137,10 +209,6 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "get_constants" - argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -159,7 +227,7 @@ tf_class { } member_method { name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " } 
member_method { name: "get_output_at" @@ -181,10 +249,6 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "preprocess_input" - argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "reset_states" argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -193,8 +257,4 @@ tf_class { name: "set_weights" argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "step" - argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None" - } } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt new file mode 100644 index 00000000000..c7c9b10f22d --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt @@ -0,0 +1,191 @@ +path: "tensorflow.keras.layers.RNN" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "states" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'activity_regularizer\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', 
\'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\', \'constants\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_states" + argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt new file mode 100644 index 00000000000..10c7f8867cb --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -0,0 +1,179 @@ +path: "tensorflow.keras.layers.SimpleRNNCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { 
+ name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: 
"get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt index f4148fcc230..588df21088f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt @@ -1,14 +1,34 @@ path: "tensorflow.keras.layers.SimpleRNN" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" + member { + name: "activation" + mtype: "" + } member { name: "activity_regularizer" mtype: "" } + member { + name: "bias_constraint" + mtype: "" + } + member { + name: "bias_initializer" + mtype: "" + } + member { + name: "bias_regularizer" + mtype: "" + } + member { + name: "dropout" + mtype: "" + } member { name: "dtype" mtype: "" @@ -33,6 +53,18 @@ tf_class { name: "input_shape" mtype: "" } + member { + name: "kernel_constraint" + mtype: "" + } + member { + name: "kernel_initializer" + mtype: "" + } + member { + name: "kernel_regularizer" + mtype: "" + } member { name: "losses" mtype: "" @@ -65,10 +97,30 @@ tf_class { name: "output_shape" mtype: "" } + member { + name: "recurrent_constraint" + mtype: "" + } + member { + name: "recurrent_dropout" + mtype: "" + } + member { + name: "recurrent_initializer" + mtype: "" + } + member { + name: "recurrent_regularizer" + mtype: "" + } member { name: "scope_name" mtype: "" } + member { + name: "states" + mtype: "" + } member { name: "trainable_variables" mtype: "" @@ -77,10 +129,18 @@ tf_class { name: "trainable_weights" mtype: "" } + member { + name: "units" + mtype: "" + } member { name: "updates" mtype: "" } + member { + name: "use_bias" + mtype: "" + } member { name: "variables" mtype: "" @@ -91,7 +151,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], " + argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'False\', \'False\', \'False\', 
\'False\', \'False\'], " } member_method { name: "add_loss" @@ -137,10 +197,6 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "get_constants" - argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -159,7 +215,7 @@ tf_class { } member_method { name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "get_output_at" @@ -181,10 +237,6 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "preprocess_input" - argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "reset_states" argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -193,8 +245,4 @@ tf_class { name: "set_weights" argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "step" - argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None" - } } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt new file mode 100644 index 00000000000..5779e413422 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -0,0 +1,183 @@ +path: "tensorflow.keras.layers.StackedRNNCells" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cells\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', 
\'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt index 8466c3e0390..fe336c4be5a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt @@ -140,6 +140,10 @@ tf_module { name: "GRU" mtype: "" } + member { + name: "GRUCell" + mtype: "" + } member { name: "GaussianDropout" mtype: "" @@ -208,6 +212,10 @@ tf_module { name: "LSTM" mtype: "" } + member { + name: "LSTMCell" + mtype: "" + } member { name: "Lambda" mtype: "" @@ -272,6 +280,10 @@ tf_module { name: "Permute" mtype: "" } + member { + name: "RNN" + mtype: "" + } member { name: "RepeatVector" mtype: "" @@ -292,6 +304,10 @@ tf_module { name: "SimpleRNN" mtype: "" } + member { + name: "SimpleRNNCell" + mtype: "" + } member { name: 
"SpatialDropout1D" mtype: "" @@ -304,6 +320,10 @@ tf_module { name: "SpatialDropout3D" mtype: "" } + member { + name: "StackedRNNCells" + mtype: "" + } member { name: "ThresholdedReLU" mtype: "" diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index f1c207f9b68..8d4e4c23dc3 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -98,7 +98,8 @@ do_pylint() { "^tensorflow/contrib/eager/python/evaluator\.py.*\[E0202.*method-hidden "\ "^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\ "^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\ -"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable" +"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable "\ +"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition" echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""