diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 74df39049cc..1669508ac4f 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4043,7 +4043,6 @@ py_library( ":nn_ops", ":nn_ops_gen", ":platform_device_context", - ":rnn", ":sparse_ops", ":util", ":variables", @@ -4245,6 +4244,8 @@ py_library( ":util", ":variable_scope", ":variables", + "//tensorflow/python/keras/layers/legacy_rnn:rnn_cell_impl", + "//tensorflow/python/keras/layers/legacy_rnn:rnn_cell_wrapper_impl", ], ) diff --git a/tensorflow/python/keras/layers/legacy_rnn/BUILD b/tensorflow/python/keras/layers/legacy_rnn/BUILD new file mode 100644 index 00000000000..4d3b4a4c852 --- /dev/null +++ b/tensorflow/python/keras/layers/legacy_rnn/BUILD @@ -0,0 +1,54 @@ +# Description: +# Contains the legacy TF RNN APIs (internal TensorFlow version). + +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) + +py_library( + name = "rnn_cell_impl", + srcs = ["rnn_cell_impl.py"], + deps = [ + ":rnn_cell_wrapper_impl", + "//tensorflow/python:array_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers_base", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:platform", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/keras:activations", + "//tensorflow/python/keras:initializers", + "//tensorflow/python/keras/engine:input_spec", + "//tensorflow/python/keras/utils:tf_utils", + "//tensorflow/python/training/tracking:base", + ], +) + +py_library( + name = "rnn_cell_wrapper_impl", + srcs = ["rnn_cell_wrapper_impl.py"], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//tensorflow/python:util", + "//tensorflow/python/keras/utils:generic_utils", + ], +) diff --git a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py new file mode 100644 index 00000000000..1d03780d51b --- /dev/null +++ b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py @@ -0,0 +1,1355 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module implementing RNN Cells. 
+
+This module provides a number of basic commonly used RNN cells, such as LSTM
+(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
+operators that allow adding dropouts, projections, or embeddings for inputs.
+Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
+calling the `rnn` ops several times.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.keras import activations
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras.engine import input_spec
+from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_wrapper_impl
+from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.layers import base as base_layer
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training.tracking import base as trackable
+from tensorflow.python.util import nest
+from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.tf_export import tf_export
+
+_BIAS_VARIABLE_NAME = "bias"
+_WEIGHTS_VARIABLE_NAME = "kernel"
+
+# This can be used with self.assertRaisesRegexp for assert_like_rnncell.
+ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell"
+
+
+def _hasattr(obj, attr_name):
+  try:
+    getattr(obj, attr_name)
+  except AttributeError:
+    return False
+  else:
+    return True
+
+
+def assert_like_rnncell(cell_name, cell):
+  """Raises a TypeError if cell is not like an RNNCell.
+
+  NOTE: Do not rely on the error message (in particular in tests), which may
+  change to improve readability. Use ASSERT_LIKE_RNNCELL_ERROR_REGEXP.
+
+  Args:
+    cell_name: A string to give a meaningful error referring to the name of
+      the function argument.
+    cell: The object which should behave like an RNNCell.
+
+  Raises:
+    TypeError: A human-friendly exception.
+  """
+  conditions = [
+      _hasattr(cell, "output_size"),
+      _hasattr(cell, "state_size"),
+      _hasattr(cell, "get_initial_state") or _hasattr(cell, "zero_state"),
+      callable(cell),
+  ]
+  errors = [
+      "'output_size' property is missing", "'state_size' property is missing",
+      "either 'zero_state' or 'get_initial_state' method is required",
+      "is not callable"
+  ]
+
+  if not all(conditions):
+    errors = [error for error, cond in zip(errors, conditions) if not cond]
+    raise TypeError("The argument {!r} ({}) is not an RNNCell: {}.".format(
+        cell_name, cell, ", ".join(errors)))
+
+
+def _concat(prefix, suffix, static=False):
+  """Concat that enables int, Tensor, or TensorShape values.
+
+  This function takes a size specification, which can be an integer, a
+  TensorShape, or a Tensor, and converts it into a concatenated Tensor
+  (if static = False) or a list of integers (if static = True).
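+
+  For example (an illustrative note, not from the original docstring):
+  `_concat(32, tensor_shape.TensorShape([8]), static=True)` returns the
+  Python list `[32, 8]`, while the default `static=False` returns a rank-1
+  int32 `Tensor` holding the same values.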
+
+  Args:
+    prefix: The prefix; usually the batch size (and/or time step size).
+      (TensorShape, int, or Tensor.)
+    suffix: TensorShape, int, or Tensor.
+    static: If `True`, return a python list with possibly unknown dimensions.
+      Otherwise return a `Tensor`.
+
+  Returns:
+    shape: the concatenation of prefix and suffix.
+
+  Raises:
+    ValueError: if `suffix` is not a scalar or vector (or TensorShape).
+    ValueError: if `prefix` or `suffix` is `None` and a dynamic `Tensor` is
+      requested.
+  """
+  if isinstance(prefix, ops.Tensor):
+    p = prefix
+    p_static = tensor_util.constant_value(prefix)
+    if p.shape.ndims == 0:
+      p = array_ops.expand_dims(p, 0)
+    elif p.shape.ndims != 1:
+      raise ValueError("prefix tensor must be either a scalar or vector, "
+                       "but saw tensor: %s" % p)
+  else:
+    p = tensor_shape.as_shape(prefix)
+    p_static = p.as_list() if p.ndims is not None else None
+    p = (
+        constant_op.constant(p.as_list(), dtype=dtypes.int32)
+        if p.is_fully_defined() else None)
+  if isinstance(suffix, ops.Tensor):
+    s = suffix
+    s_static = tensor_util.constant_value(suffix)
+    if s.shape.ndims == 0:
+      s = array_ops.expand_dims(s, 0)
+    elif s.shape.ndims != 1:
+      raise ValueError("suffix tensor must be either a scalar or vector, "
+                       "but saw tensor: %s" % s)
+  else:
+    s = tensor_shape.as_shape(suffix)
+    s_static = s.as_list() if s.ndims is not None else None
+    s = (
+        constant_op.constant(s.as_list(), dtype=dtypes.int32)
+        if s.is_fully_defined() else None)
+
+  if static:
+    shape = tensor_shape.as_shape(p_static).concatenate(s_static)
+    shape = shape.as_list() if shape.ndims is not None else None
+  else:
+    if p is None or s is None:
+      raise ValueError("Provided a prefix or suffix of None: %s and %s" %
+                       (prefix, suffix))
+    shape = array_ops.concat((p, s), 0)
+  return shape
+
+
+def _zero_state_tensors(state_size, batch_size, dtype):
+  """Create tensors of zeros based on state_size, batch_size, and dtype."""
+
+  def get_state_shape(s):
+    """Combine s with batch_size to get a proper tensor shape."""
+    c = _concat(batch_size, s)
+    size = array_ops.zeros(c, dtype=dtype)
+    if not context.executing_eagerly():
+      c_static = _concat(batch_size, s, static=True)
+      size.set_shape(c_static)
+    return size
+
+  return nest.map_structure(get_state_shape, state_size)
+
+
+@tf_export(v1=["nn.rnn_cell.RNNCell"])
+class RNNCell(base_layer.Layer):
+  """Abstract object representing an RNN cell.
+
+  Every `RNNCell` must have the properties below and implement `call` with
+  the signature `(output, next_state) = call(input, state)`. The optional
+  third input argument, `scope`, is allowed for backwards compatibility
+  purposes, but should be left off for new subclasses.
+
+  This definition of cell differs from the definition used in the literature.
+  In the literature, 'cell' refers to an object with a single scalar output.
+  This definition refers to a horizontal array of such units.
+
+  An RNN cell, in the most abstract setting, is anything that has
+  a state and performs some operation that takes a matrix of inputs.
+  This operation results in an output matrix with `self.output_size` columns.
+  If `self.state_size` is an integer, this operation also results in a new
+  state matrix with `self.state_size` columns. If `self.state_size` is a
+  (possibly nested tuple of) TensorShape object(s), then it should return a
+  matching structure of Tensors having shape `[batch_size].concatenate(s)`
+  for each `s` in `self.state_size`.
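+
+  For example, a minimal subclass (an illustrative sketch, not part of the
+  original docstring) only needs `state_size`, `output_size` and `call`:
+
+  ```python
+  class EchoCell(RNNCell):
+    # Passes inputs through and carries the state forward unchanged; assumes
+    # inputs have `num_units` columns so that they match `output_size`.
+
+    def __init__(self, num_units):
+      super(EchoCell, self).__init__()
+      self._num_units = num_units
+
+    @property
+    def state_size(self):
+      return self._num_units
+
+    @property
+    def output_size(self):
+      return self._num_units
+
+    def call(self, inputs, state):
+      return inputs, state
+  ```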
+  """
+
+  def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
+    super(RNNCell, self).__init__(
+        trainable=trainable, name=name, dtype=dtype, **kwargs)
+    # Attribute that indicates whether the cell is a TF RNN cell, due to the
+    # slight difference between TF and Keras RNN cells. Notably, a TF cell does
+    # not wrap a single-tensor state in a list, whereas a Keras cell wraps the
+    # state in a list, and call() has to unwrap it.
+    self._is_tf_rnn_cell = True
+
+  def __call__(self, inputs, state, scope=None):
+    """Run this RNN cell on inputs, starting from the given state.
+
+    Args:
+      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
+      state: if `self.state_size` is an integer, this should be a `2-D Tensor`
+        with shape `[batch_size, self.state_size]`. Otherwise, if
+        `self.state_size` is a tuple of integers, this should be a tuple with
+        shapes `[batch_size, s] for s in self.state_size`.
+      scope: VariableScope for the created subgraph; defaults to class name.
+
+    Returns:
+      A pair containing:
+
+      - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
+      - New state: Either a single `2-D` tensor, or a tuple of tensors matching
+        the arity and shapes of `state`.
+    """
+    if scope is not None:
+      with vs.variable_scope(
+          scope, custom_getter=self._rnn_get_variable) as scope:
+        return super(RNNCell, self).__call__(inputs, state, scope=scope)
+    else:
+      scope_attrname = "rnncell_scope"
+      scope = getattr(self, scope_attrname, None)
+      if scope is None:
+        scope = vs.variable_scope(
+            vs.get_variable_scope(), custom_getter=self._rnn_get_variable)
+        setattr(self, scope_attrname, scope)
+      with scope:
+        return super(RNNCell, self).__call__(inputs, state)
+
+  def _rnn_get_variable(self, getter, *args, **kwargs):
+    variable = getter(*args, **kwargs)
+    if context.executing_eagerly():
+      trainable = variable._trainable  # pylint: disable=protected-access
+    else:
+      trainable = (
+          variable in tf_variables.trainable_variables() or
+          (isinstance(variable, tf_variables.PartitionedVariable) and
+           list(variable)[0] in tf_variables.trainable_variables()))
+    if trainable and all(variable is not v for v in self._trainable_weights):
+      self._trainable_weights.append(variable)
+    elif not trainable and all(
+        variable is not v for v in self._non_trainable_weights):
+      self._non_trainable_weights.append(variable)
+    return variable
+
+  @property
+  def state_size(self):
+    """size(s) of state(s) used by this cell.
+
+    It can be represented by an Integer, a TensorShape or a tuple of Integers
+    or TensorShapes.
+    """
+    raise NotImplementedError("Abstract method")
+
+  @property
+  def output_size(self):
+    """Integer or TensorShape: size of outputs produced by this cell."""
+    raise NotImplementedError("Abstract method")
+
+  def build(self, _):
+    # This tells the parent Layer object that it's OK to call
+    # self.add_variable() inside the call() method.
+    pass
+
+  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
+    if inputs is not None:
+      # Validate the given batch_size and dtype against inputs if provided.
+      inputs = ops.convert_to_tensor(inputs, name="inputs")
+      if batch_size is not None:
+        if tensor_util.is_tensor(batch_size):
+          static_batch_size = tensor_util.constant_value(
+              batch_size, partial=True)
+        else:
+          static_batch_size = batch_size
+        if inputs.shape.dims[0].value != static_batch_size:
+          raise ValueError(
+              "batch size from input tensor is different from the "
+              "input param. Input tensor batch: {}, batch_size: {}".format(
+                  inputs.shape.dims[0].value, batch_size))
+
+      if dtype is not None and inputs.dtype != dtype:
+        raise ValueError(
+            "dtype from input tensor is different from the "
+            "input param. Input tensor dtype: {}, dtype: {}".format(
+                inputs.dtype, dtype))
+
+      batch_size = inputs.shape.dims[0].value or array_ops.shape(inputs)[0]
+      dtype = inputs.dtype
+    if batch_size is None or dtype is None:
+      raise ValueError(
+          "batch_size and dtype cannot be None while constructing initial "
+          "state: batch_size={}, dtype={}".format(batch_size, dtype))
+    return self.zero_state(batch_size, dtype)
+
+  def zero_state(self, batch_size, dtype):
+    """Return zero-filled state tensor(s).
+
+    Args:
+      batch_size: int, float, or unit Tensor representing the batch size.
+      dtype: the data type to use for the state.
+
+    Returns:
+      If `state_size` is an int or TensorShape, then the return value is a
+      `N-D` tensor of shape `[batch_size, state_size]` filled with zeros.
+
+      If `state_size` is a nested list or tuple, then the return value is
+      a nested list or tuple (of the same structure) of `2-D` tensors with
+      the shapes `[batch_size, s]` for each s in `state_size`.
+    """
+    # Try to use the last cached zero_state. This is done to avoid recreating
+    # zeros, especially when eager execution is enabled.
+    state_size = self.state_size
+    is_eager = context.executing_eagerly()
+    if is_eager and _hasattr(self, "_last_zero_state"):
+      (last_state_size, last_batch_size, last_dtype,
+       last_output) = getattr(self, "_last_zero_state")
+      if (last_batch_size == batch_size and last_dtype == dtype and
+          last_state_size == state_size):
+        return last_output
+    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+      output = _zero_state_tensors(state_size, batch_size, dtype)
+    if is_eager:
+      self._last_zero_state = (state_size, batch_size, dtype, output)
+    return output
+
+  # TODO(b/134773139): Remove when contrib RNN cells implement `get_config`
+  def get_config(self):  # pylint: disable=useless-super-delegation
+    return super(RNNCell, self).get_config()
+
+
+class LayerRNNCell(RNNCell):
+  """Subclass of RNNCells that act like proper `tf.Layer` objects.
+
+  For backwards compatibility purposes, most `RNNCell` instances allow their
+  `call` methods to instantiate variables via `tf.compat.v1.get_variable`. The
+  underlying variable scope thus keeps track of any variables and returns
+  cached versions. This is atypical of `tf.layer` objects, which separate this
+  part of layer building into a `build` method that is only called once.
+
+  Here we provide a subclass for `RNNCell` objects that act exactly as
+  `Layer` objects do. They must provide a `build` method, and their
+  `call` methods must not access variables via `tf.compat.v1.get_variable`.
+  """
+
+  def __call__(self, inputs, state, scope=None, *args, **kwargs):
+    """Run this RNN cell on inputs, starting from the given state.
+
+    Args:
+      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
+      state: if `self.state_size` is an integer, this should be a `2-D Tensor`
+        with shape `[batch_size, self.state_size]`. Otherwise, if
+        `self.state_size` is a tuple of integers, this should be a tuple with
+        shapes `[batch_size, s] for s in self.state_size`.
+      scope: optional cell scope.
+      *args: Additional positional arguments.
+      **kwargs: Additional keyword arguments.
+
+    Returns:
+      A pair containing:
+
+      - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
+ - New state: Either a single `2-D` tensor, or a tuple of tensors matching + the arity and shapes of `state`. + """ + # Bypass RNNCell's variable capturing semantics for LayerRNNCell. + # Instead, it is up to subclasses to provide a proper build + # method. See the class docstring for more details. + return base_layer.Layer.__call__( + self, inputs, state, scope=scope, *args, **kwargs) + + +@tf_export(v1=["nn.rnn_cell.BasicRNNCell"]) +class BasicRNNCell(LayerRNNCell): + """The most basic RNN cell. + + Note that this cell is not optimized for performance. Please use + `tf.contrib.cudnn_rnn.CudnnRNNTanh` for better performance on GPU. + + Args: + num_units: int, The number of units in the RNN cell. + activation: Nonlinearity to use. Default: `tanh`. It could also be string + that is within Keras activation function names. + reuse: (optional) Python boolean describing whether to reuse variables in an + existing scope. If not `True`, and the existing scope already has the + given variables, an error is raised. + name: String, the name of the layer. Layers with the same name will share + weights, but to avoid mistakes we require reuse=True in such cases. + dtype: Default dtype of the layer (default of `None` means use the type of + the first input). Required when `build` is called before `call`. + **kwargs: Dict, keyword named properties for common layer attributes, like + `trainable` etc when constructing the cell from configs of get_config(). + """ + + @deprecated(None, "This class is equivalent as tf.keras.layers.SimpleRNNCell," + " and will be replaced by that in Tensorflow 2.0.") + def __init__(self, + num_units, + activation=None, + reuse=None, + name=None, + dtype=None, + **kwargs): + super(BasicRNNCell, self).__init__( + _reuse=reuse, name=name, dtype=dtype, **kwargs) + _check_supported_dtypes(self.dtype) + if context.executing_eagerly() and context.num_gpus() > 0: + logging.warn( + "%s: Note that this cell is not optimized for performance. " + "Please use tf.contrib.cudnn_rnn.CudnnRNNTanh for better " + "performance on GPU.", self) + + # Inputs must be 2-dimensional. 
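+    # (Descriptive note: the InputSpec below makes the layer check, at call
+    # time, that inputs are rank-2, i.e. shaped [batch_size, input_size].)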
+ self.input_spec = input_spec.InputSpec(ndim=2) + + self._num_units = num_units + if activation: + self._activation = activations.get(activation) + else: + self._activation = math_ops.tanh + + @property + def state_size(self): + return self._num_units + + @property + def output_size(self): + return self._num_units + + @tf_utils.shape_type_conversion + def build(self, inputs_shape): + if inputs_shape[-1] is None: + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % + str(inputs_shape)) + _check_supported_dtypes(self.dtype) + + input_depth = inputs_shape[-1] + self._kernel = self.add_variable( + _WEIGHTS_VARIABLE_NAME, + shape=[input_depth + self._num_units, self._num_units]) + self._bias = self.add_variable( + _BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=init_ops.zeros_initializer(dtype=self.dtype)) + + self.built = True + + def call(self, inputs, state): + """Most basic RNN: output = new_state = act(W * input + U * state + B).""" + _check_rnn_cell_input_dtypes([inputs, state]) + gate_inputs = math_ops.matmul( + array_ops.concat([inputs, state], 1), self._kernel) + gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) + output = self._activation(gate_inputs) + return output, output + + def get_config(self): + config = { + "num_units": self._num_units, + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(BasicRNNCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +@tf_export(v1=["nn.rnn_cell.GRUCell"]) +class GRUCell(LayerRNNCell): + """Gated Recurrent Unit cell. + + Note that this cell is not optimized for performance. Please use + `tf.contrib.cudnn_rnn.CudnnGRU` for better performance on GPU, or + `tf.contrib.rnn.GRUBlockCellV2` for better performance on CPU. + + Args: + num_units: int, The number of units in the GRU cell. + activation: Nonlinearity to use. Default: `tanh`. + reuse: (optional) Python boolean describing whether to reuse variables in an + existing scope. If not `True`, and the existing scope already has the + given variables, an error is raised. + kernel_initializer: (optional) The initializer to use for the weight and + projection matrices. + bias_initializer: (optional) The initializer to use for the bias. + name: String, the name of the layer. Layers with the same name will share + weights, but to avoid mistakes we require reuse=True in such cases. + dtype: Default dtype of the layer (default of `None` means use the type of + the first input). Required when `build` is called before `call`. + **kwargs: Dict, keyword named properties for common layer attributes, like + `trainable` etc when constructing the cell from configs of get_config(). 
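+
+  A minimal usage sketch (illustrative, not from the original docstring;
+  assumes `inputs` is a `[batch_size, input_size]` float tensor):
+
+  ```python
+  cell = tf.compat.v1.nn.rnn_cell.GRUCell(num_units=64)
+  state = cell.zero_state(batch_size=32, dtype=tf.float32)
+  output, next_state = cell(inputs, state)
+  ```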
+ + References: + Learning Phrase Representations using RNN Encoder Decoder for Statistical + Machine Translation: + [Cho et al., 2014] + (https://aclanthology.coli.uni-saarland.de/papers/D14-1179/d14-1179) + ([pdf](http://emnlp2014.org/papers/pdf/EMNLP2014179.pdf)) + """ + + @deprecated(None, "This class is equivalent as tf.keras.layers.GRUCell," + " and will be replaced by that in Tensorflow 2.0.") + def __init__(self, + num_units, + activation=None, + reuse=None, + kernel_initializer=None, + bias_initializer=None, + name=None, + dtype=None, + **kwargs): + super(GRUCell, self).__init__( + _reuse=reuse, name=name, dtype=dtype, **kwargs) + _check_supported_dtypes(self.dtype) + + if context.executing_eagerly() and context.num_gpus() > 0: + logging.warn( + "%s: Note that this cell is not optimized for performance. " + "Please use tf.contrib.cudnn_rnn.CudnnGRU for better " + "performance on GPU.", self) + # Inputs must be 2-dimensional. + self.input_spec = input_spec.InputSpec(ndim=2) + + self._num_units = num_units + if activation: + self._activation = activations.get(activation) + else: + self._activation = math_ops.tanh + self._kernel_initializer = initializers.get(kernel_initializer) + self._bias_initializer = initializers.get(bias_initializer) + + @property + def state_size(self): + return self._num_units + + @property + def output_size(self): + return self._num_units + + @tf_utils.shape_type_conversion + def build(self, inputs_shape): + if inputs_shape[-1] is None: + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % + str(inputs_shape)) + _check_supported_dtypes(self.dtype) + input_depth = inputs_shape[-1] + self._gate_kernel = self.add_variable( + "gates/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[input_depth + self._num_units, 2 * self._num_units], + initializer=self._kernel_initializer) + self._gate_bias = self.add_variable( + "gates/%s" % _BIAS_VARIABLE_NAME, + shape=[2 * self._num_units], + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.constant_initializer(1.0, dtype=self.dtype))) + self._candidate_kernel = self.add_variable( + "candidate/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[input_depth + self._num_units, self._num_units], + initializer=self._kernel_initializer) + self._candidate_bias = self.add_variable( + "candidate/%s" % _BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.zeros_initializer(dtype=self.dtype))) + + self.built = True + + def call(self, inputs, state): + """Gated recurrent unit (GRU) with nunits cells.""" + _check_rnn_cell_input_dtypes([inputs, state]) + + gate_inputs = math_ops.matmul( + array_ops.concat([inputs, state], 1), self._gate_kernel) + gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) + + value = math_ops.sigmoid(gate_inputs) + r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) + + r_state = r * state + + candidate = math_ops.matmul( + array_ops.concat([inputs, r_state], 1), self._candidate_kernel) + candidate = nn_ops.bias_add(candidate, self._candidate_bias) + + c = self._activation(candidate) + new_h = u * state + (1 - u) * c + return new_h, new_h + + def get_config(self): + config = { + "num_units": self._num_units, + "kernel_initializer": initializers.serialize(self._kernel_initializer), + "bias_initializer": initializers.serialize(self._bias_initializer), + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(GRUCell, 
 self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+
+_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))
+
+
+@tf_export(v1=["nn.rnn_cell.LSTMStateTuple"])
+class LSTMStateTuple(_LSTMStateTuple):
+  """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.
+
+  Stores two elements, `(c, h)`, in that order, where `c` is the cell state
+  and `h` is the output.
+
+  Only used when `state_is_tuple=True`.
+  """
+  __slots__ = ()
+
+  @property
+  def dtype(self):
+    (c, h) = self
+    if c.dtype != h.dtype:
+      raise TypeError("Inconsistent internal state: %s vs %s" %
+                      (str(c.dtype), str(h.dtype)))
+    return c.dtype
+
+
+@tf_export(v1=["nn.rnn_cell.BasicLSTMCell"])
+class BasicLSTMCell(LayerRNNCell):
+  """DEPRECATED: Please use `tf.compat.v1.nn.rnn_cell.LSTMCell` instead.
+
+  Basic LSTM recurrent network cell.
+
+  The implementation is based on: http://arxiv.org/abs/1409.2329.
+
+  We add forget_bias (default: 1) to the biases of the forget gate in order to
+  reduce the scale of forgetting in the beginning of the training.
+
+  It does not allow cell clipping or a projection layer, and it does not use
+  peephole connections: it is the basic baseline.
+
+  For advanced models, please use the full `tf.compat.v1.nn.rnn_cell.LSTMCell`
+  that follows.
+
+  Note that this cell is not optimized for performance. Please use
+  `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
+  `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
+  better performance on CPU.
+  """
+
+  @deprecated(None, "This class is equivalent to tf.keras.layers.LSTMCell,"
+              " and will be replaced by that in Tensorflow 2.0.")
+  def __init__(self,
+               num_units,
+               forget_bias=1.0,
+               state_is_tuple=True,
+               activation=None,
+               reuse=None,
+               name=None,
+               dtype=None,
+               **kwargs):
+    """Initialize the basic LSTM cell.
+
+    Args:
+      num_units: int, The number of units in the LSTM cell.
+      forget_bias: float, The bias added to forget gates (see above). Must be
+        set to `0.0` manually when restoring from CudnnLSTM-trained
+        checkpoints.
+      state_is_tuple: If True, accepted and returned states are 2-tuples of the
+        `c_state` and `m_state`. If False, they are concatenated along the
+        column axis. The latter behavior will soon be deprecated.
+      activation: Activation function of the inner states. Default: `tanh`. It
+        can also be a string that is within the Keras activation function
+        names.
+      reuse: (optional) Python boolean describing whether to reuse variables in
+        an existing scope. If not `True`, and the existing scope already has
+        the given variables, an error is raised.
+      name: String, the name of the layer. Layers with the same name will share
+        weights, but to avoid mistakes we require reuse=True in such cases.
+      dtype: Default dtype of the layer (default of `None` means use the type
+        of the first input). Required when `build` is called before `call`.
+      **kwargs: Dict, keyword named properties for common layer attributes,
+        like `trainable` etc when constructing the cell from configs of
+        get_config(). When restoring from CudnnLSTM-trained checkpoints, must
+        use `CudnnCompatibleLSTMCell` instead.
+    """
+    super(BasicLSTMCell, self).__init__(
+        _reuse=reuse, name=name, dtype=dtype, **kwargs)
+    _check_supported_dtypes(self.dtype)
+    if not state_is_tuple:
+      logging.warn(
+          "%s: Using a concatenated state is slower and will soon be "
+          "deprecated. 
Use state_is_tuple=True.", self) + if context.executing_eagerly() and context.num_gpus() > 0: + logging.warn( + "%s: Note that this cell is not optimized for performance. " + "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better " + "performance on GPU.", self) + + # Inputs must be 2-dimensional. + self.input_spec = input_spec.InputSpec(ndim=2) + + self._num_units = num_units + self._forget_bias = forget_bias + self._state_is_tuple = state_is_tuple + if activation: + self._activation = activations.get(activation) + else: + self._activation = math_ops.tanh + + @property + def state_size(self): + return (LSTMStateTuple(self._num_units, self._num_units) + if self._state_is_tuple else 2 * self._num_units) + + @property + def output_size(self): + return self._num_units + + @tf_utils.shape_type_conversion + def build(self, inputs_shape): + if inputs_shape[-1] is None: + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % + str(inputs_shape)) + _check_supported_dtypes(self.dtype) + input_depth = inputs_shape[-1] + h_depth = self._num_units + self._kernel = self.add_variable( + _WEIGHTS_VARIABLE_NAME, + shape=[input_depth + h_depth, 4 * self._num_units]) + self._bias = self.add_variable( + _BIAS_VARIABLE_NAME, + shape=[4 * self._num_units], + initializer=init_ops.zeros_initializer(dtype=self.dtype)) + + self.built = True + + def call(self, inputs, state): + """Long short-term memory cell (LSTM). + + Args: + inputs: `2-D` tensor with shape `[batch_size, input_size]`. + state: An `LSTMStateTuple` of state tensors, each shaped `[batch_size, + num_units]`, if `state_is_tuple` has been set to `True`. Otherwise, a + `Tensor` shaped `[batch_size, 2 * num_units]`. + + Returns: + A pair containing the new hidden state, and the new state (either a + `LSTMStateTuple` or a concatenated state, depending on + `state_is_tuple`). + """ + _check_rnn_cell_input_dtypes([inputs, state]) + + sigmoid = math_ops.sigmoid + one = constant_op.constant(1, dtype=dtypes.int32) + # Parameters of gates are concatenated into one multiply for efficiency. + if self._state_is_tuple: + c, h = state + else: + c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one) + + gate_inputs = math_ops.matmul( + array_ops.concat([inputs, h], 1), self._kernel) + gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = array_ops.split( + value=gate_inputs, num_or_size_splits=4, axis=one) + + forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype) + # Note that using `add` and `multiply` instead of `+` and `*` gives a + # performance improvement. So using those at the cost of readability. 
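+    # In equation form, the update below is (descriptive note):
+    #   new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * act(j)
+    #   new_h = act(new_c) * sigmoid(o)
+    # where act is self._activation (tanh by default).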
+ add = math_ops.add + multiply = math_ops.multiply + new_c = add( + multiply(c, sigmoid(add(f, forget_bias_tensor))), + multiply(sigmoid(i), self._activation(j))) + new_h = multiply(self._activation(new_c), sigmoid(o)) + + if self._state_is_tuple: + new_state = LSTMStateTuple(new_c, new_h) + else: + new_state = array_ops.concat([new_c, new_h], 1) + return new_h, new_state + + def get_config(self): + config = { + "num_units": self._num_units, + "forget_bias": self._forget_bias, + "state_is_tuple": self._state_is_tuple, + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(BasicLSTMCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +@tf_export(v1=["nn.rnn_cell.LSTMCell"]) +class LSTMCell(LayerRNNCell): + """Long short-term memory unit (LSTM) recurrent network cell. + + The default non-peephole implementation is based on (Gers et al., 1999). + The peephole implementation is based on (Sak et al., 2014). + + The class uses optional peep-hole connections, optional cell clipping, and + an optional projection layer. + + Note that this cell is not optimized for performance. Please use + `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or + `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for + better performance on CPU. + References: + Long short-term memory recurrent neural network architectures for large + scale acoustic modeling: + [Sak et al., 2014] + (https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html) + ([pdf] + (https://www.isca-speech.org/archive/archive_papers/interspeech_2014/i14_0338.pdf)) + Learning to forget: + [Gers et al., 1999] + (http://digital-library.theiet.org/content/conferences/10.1049/cp_19991218) + ([pdf](https://arxiv.org/pdf/1409.2329.pdf)) + Long Short-Term Memory: + [Hochreiter et al., 1997] + (https://www.mitpressjournals.org/doi/abs/10.1162/neco.1997.9.8.1735) + ([pdf](http://ml.jku.at/publications/older/3504.pdf)) + """ + + @deprecated(None, "This class is equivalent as tf.keras.layers.LSTMCell," + " and will be replaced by that in Tensorflow 2.0.") + def __init__(self, + num_units, + use_peepholes=False, + cell_clip=None, + initializer=None, + num_proj=None, + proj_clip=None, + num_unit_shards=None, + num_proj_shards=None, + forget_bias=1.0, + state_is_tuple=True, + activation=None, + reuse=None, + name=None, + dtype=None, + **kwargs): + """Initialize the parameters for an LSTM cell. + + Args: + num_units: int, The number of units in the LSTM cell. + use_peepholes: bool, set True to enable diagonal/peephole connections. + cell_clip: (optional) A float value, if provided the cell state is clipped + by this value prior to the cell output activation. + initializer: (optional) The initializer to use for the weight and + projection matrices. + num_proj: (optional) int, The output dimensionality for the projection + matrices. If None, no projection is performed. + proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is + provided, then the projected values are clipped elementwise to within + `[-proj_clip, proj_clip]`. + num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a + variable_scope partitioner instead. + num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a + variable_scope partitioner instead. + forget_bias: Biases of the forget gate are initialized by default to 1 in + order to reduce the scale of forgetting at the beginning of the + training. 
Must set it manually to `0.0` when restoring from CudnnLSTM + trained checkpoints. + state_is_tuple: If True, accepted and returned states are 2-tuples of the + `c_state` and `m_state`. If False, they are concatenated along the + column axis. This latter behavior will soon be deprecated. + activation: Activation function of the inner states. Default: `tanh`. It + could also be string that is within Keras activation function names. + reuse: (optional) Python boolean describing whether to reuse variables in + an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + name: String, the name of the layer. Layers with the same name will share + weights, but to avoid mistakes we require reuse=True in such cases. + dtype: Default dtype of the layer (default of `None` means use the type of + the first input). Required when `build` is called before `call`. + **kwargs: Dict, keyword named properties for common layer attributes, like + `trainable` etc when constructing the cell from configs of get_config(). + When restoring from CudnnLSTM-trained checkpoints, use + `CudnnCompatibleLSTMCell` instead. + """ + super(LSTMCell, self).__init__( + _reuse=reuse, name=name, dtype=dtype, **kwargs) + _check_supported_dtypes(self.dtype) + if not state_is_tuple: + logging.warn( + "%s: Using a concatenated state is slower and will soon be " + "deprecated. Use state_is_tuple=True.", self) + if num_unit_shards is not None or num_proj_shards is not None: + logging.warn( + "%s: The num_unit_shards and proj_unit_shards parameters are " + "deprecated and will be removed in Jan 2017. " + "Use a variable scope with a partitioner instead.", self) + if context.executing_eagerly() and context.num_gpus() > 0: + logging.warn( + "%s: Note that this cell is not optimized for performance. " + "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better " + "performance on GPU.", self) + + # Inputs must be 2-dimensional. 
+ self.input_spec = input_spec.InputSpec(ndim=2) + + self._num_units = num_units + self._use_peepholes = use_peepholes + self._cell_clip = cell_clip + self._initializer = initializers.get(initializer) + self._num_proj = num_proj + self._proj_clip = proj_clip + self._num_unit_shards = num_unit_shards + self._num_proj_shards = num_proj_shards + self._forget_bias = forget_bias + self._state_is_tuple = state_is_tuple + if activation: + self._activation = activations.get(activation) + else: + self._activation = math_ops.tanh + + if num_proj: + self._state_size = ( + LSTMStateTuple(num_units, num_proj) if state_is_tuple else num_units + + num_proj) + self._output_size = num_proj + else: + self._state_size = ( + LSTMStateTuple(num_units, num_units) if state_is_tuple else 2 * + num_units) + self._output_size = num_units + + @property + def state_size(self): + return self._state_size + + @property + def output_size(self): + return self._output_size + + @tf_utils.shape_type_conversion + def build(self, inputs_shape): + if inputs_shape[-1] is None: + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % + str(inputs_shape)) + _check_supported_dtypes(self.dtype) + input_depth = inputs_shape[-1] + h_depth = self._num_units if self._num_proj is None else self._num_proj + maybe_partitioner = ( + partitioned_variables.fixed_size_partitioner(self._num_unit_shards) + if self._num_unit_shards is not None else None) + self._kernel = self.add_variable( + _WEIGHTS_VARIABLE_NAME, + shape=[input_depth + h_depth, 4 * self._num_units], + initializer=self._initializer, + partitioner=maybe_partitioner) + if self.dtype is None: + initializer = init_ops.zeros_initializer + else: + initializer = init_ops.zeros_initializer(dtype=self.dtype) + self._bias = self.add_variable( + _BIAS_VARIABLE_NAME, + shape=[4 * self._num_units], + initializer=initializer) + if self._use_peepholes: + self._w_f_diag = self.add_variable( + "w_f_diag", shape=[self._num_units], initializer=self._initializer) + self._w_i_diag = self.add_variable( + "w_i_diag", shape=[self._num_units], initializer=self._initializer) + self._w_o_diag = self.add_variable( + "w_o_diag", shape=[self._num_units], initializer=self._initializer) + + if self._num_proj is not None: + maybe_proj_partitioner = ( + partitioned_variables.fixed_size_partitioner(self._num_proj_shards) + if self._num_proj_shards is not None else None) + self._proj_kernel = self.add_variable( + "projection/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[self._num_units, self._num_proj], + initializer=self._initializer, + partitioner=maybe_proj_partitioner) + + self.built = True + + def call(self, inputs, state): + """Run one step of LSTM. + + Args: + inputs: input Tensor, must be 2-D, `[batch, input_size]`. + state: if `state_is_tuple` is False, this must be a state Tensor, `2-D, + [batch, state_size]`. If `state_is_tuple` is True, this must be a tuple + of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. + + Returns: + A tuple containing: + + - A `2-D, [batch, output_dim]`, Tensor representing the output of the + LSTM after reading `inputs` when previous state was `state`. + Here output_dim is: + num_proj if num_proj was set, + num_units otherwise. + - Tensor(s) representing the new state of LSTM after reading `inputs` when + the previous state was `state`. Same type and shape(s) as `state`. + + Raises: + ValueError: If input size cannot be inferred from inputs via + static shape inference. 
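+
+    (Descriptive note, not in the original docstring: without peepholes this
+    step computes `c = sigmoid(f + forget_bias) * c_prev + sigmoid(i) *
+    activation(j)` followed by `m = sigmoid(o) * activation(c)`, with optional
+    clipping, and multiplies `m` by the projection kernel when `num_proj` is
+    set.)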
+ """ + _check_rnn_cell_input_dtypes([inputs, state]) + + num_proj = self._num_units if self._num_proj is None else self._num_proj + sigmoid = math_ops.sigmoid + + if self._state_is_tuple: + (c_prev, m_prev) = state + else: + c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) + m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) + + input_size = inputs.get_shape().with_rank(2).dims[1].value + if input_size is None: + raise ValueError("Could not infer input size from inputs.get_shape()[-1]") + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + lstm_matrix = math_ops.matmul( + array_ops.concat([inputs, m_prev], 1), self._kernel) + lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias) + + i, j, f, o = array_ops.split( + value=lstm_matrix, num_or_size_splits=4, axis=1) + # Diagonal connections + if self._use_peepholes: + c = ( + sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev + + sigmoid(i + self._w_i_diag * c_prev) * self._activation(j)) + else: + c = ( + sigmoid(f + self._forget_bias) * c_prev + + sigmoid(i) * self._activation(j)) + + if self._cell_clip is not None: + # pylint: disable=invalid-unary-operand-type + c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) + # pylint: enable=invalid-unary-operand-type + if self._use_peepholes: + m = sigmoid(o + self._w_o_diag * c) * self._activation(c) + else: + m = sigmoid(o) * self._activation(c) + + if self._num_proj is not None: + m = math_ops.matmul(m, self._proj_kernel) + + if self._proj_clip is not None: + # pylint: disable=invalid-unary-operand-type + m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) + # pylint: enable=invalid-unary-operand-type + + new_state = ( + LSTMStateTuple(c, m) + if self._state_is_tuple else array_ops.concat([c, m], 1)) + return m, new_state + + def get_config(self): + config = { + "num_units": self._num_units, + "use_peepholes": self._use_peepholes, + "cell_clip": self._cell_clip, + "initializer": initializers.serialize(self._initializer), + "num_proj": self._num_proj, + "proj_clip": self._proj_clip, + "num_unit_shards": self._num_unit_shards, + "num_proj_shards": self._num_proj_shards, + "forget_bias": self._forget_bias, + "state_is_tuple": self._state_is_tuple, + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(LSTMCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class _RNNCellWrapperV1(RNNCell): + """Base class for cells wrappers V1 compatibility. + + This class along with `_RNNCellWrapperV2` allows to define cells wrappers that + are compatible with V1 and V2, and defines helper methods for this purpose. + """ + + def __init__(self, cell, *args, **kwargs): + super(_RNNCellWrapperV1, self).__init__(*args, **kwargs) + assert_like_rnncell("cell", cell) + self.cell = cell + if isinstance(cell, trackable.Trackable): + self._track_trackable(self.cell, name="cell") + + def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): + """Calls the wrapped cell and performs the wrapping logic. + + This method is called from the wrapper's `call` or `__call__` methods. + + Args: + inputs: A tensor with wrapped cell's input. + state: A tensor or tuple of tensors with wrapped cell's state. + cell_call_fn: Wrapped cell's method to use for step computation (cell's + `__call__` or 'call' method). + **kwargs: Additional arguments. + + Returns: + A pair containing: + - Output: A tensor with cell's output. 
+ - New state: A tensor or tuple of tensors with new wrapped cell's state. + """ + raise NotImplementedError + + def __call__(self, inputs, state, scope=None): + """Runs the RNN cell step computation. + + We assume that the wrapped RNNCell is being built within its `__call__` + method. We directly use the wrapped cell's `__call__` in the overridden + wrapper `__call__` method. + + This allows to use the wrapped cell and the non-wrapped cell equivalently + when using `__call__`. + + Args: + inputs: A tensor with wrapped cell's input. + state: A tensor or tuple of tensors with wrapped cell's state. + scope: VariableScope for the subgraph created in the wrapped cells' + `__call__`. + + Returns: + A pair containing: + + - Output: A tensor with cell's output. + - New state: A tensor or tuple of tensors with new wrapped cell's state. + """ + return self._call_wrapped_cell( + inputs, state, cell_call_fn=self.cell.__call__, scope=scope) + + def get_config(self): + config = { + "cell": { + "class_name": self.cell.__class__.__name__, + "config": self.cell.get_config() + }, + } + base_config = super(_RNNCellWrapperV1, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = config.copy() + cell = config.pop("cell") + try: + assert_like_rnncell("cell", cell) + return cls(cell, **config) + except TypeError: + raise ValueError("RNNCellWrapper cannot reconstruct the wrapped cell. " + "Please overwrite the cell in the config with a RNNCell " + "instance.") + + +@tf_export(v1=["nn.rnn_cell.DropoutWrapper"]) +class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase, + _RNNCellWrapperV1): + """Operator adding dropout to inputs and outputs of the given cell.""" + + def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation + super(DropoutWrapper, self).__init__(*args, **kwargs) + + __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__ + + +@tf_export(v1=["nn.rnn_cell.ResidualWrapper"]) +class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase, + _RNNCellWrapperV1): + """RNNCell wrapper that ensures cell inputs are added to the outputs.""" + + def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation + super(ResidualWrapper, self).__init__(*args, **kwargs) + + __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__ + + +@tf_export(v1=["nn.rnn_cell.DeviceWrapper"]) +class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase, + _RNNCellWrapperV1): + + def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation + super(DeviceWrapper, self).__init__(*args, **kwargs) + + __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__ + + +@tf_export(v1=["nn.rnn_cell.MultiRNNCell"]) +class MultiRNNCell(RNNCell): + """RNN cell composed sequentially of multiple simple cells. + + Example: + + ```python + num_units = [128, 64] + cells = [BasicLSTMCell(num_units=n) for n in num_units] + stacked_rnn_cell = MultiRNNCell(cells) + ``` + """ + + @deprecated(None, "This class is equivalent as " + "tf.keras.layers.StackedRNNCells, and will be replaced by " + "that in Tensorflow 2.0.") + def __init__(self, cells, state_is_tuple=True): + """Create a RNN cell composed sequentially of a number of RNNCells. + + Args: + cells: list of RNNCells that will be composed in this order. + state_is_tuple: If True, accepted and returned states are n-tuples, where + `n = len(cells)`. 
If False, the states are all concatenated along the + column axis. This latter behavior will soon be deprecated. + + Raises: + ValueError: if cells is empty (not allowed), or at least one of the cells + returns a state tuple but the flag `state_is_tuple` is `False`. + """ + super(MultiRNNCell, self).__init__() + if not cells: + raise ValueError("Must specify at least one cell for MultiRNNCell.") + if not nest.is_sequence(cells): + raise TypeError("cells must be a list or tuple, but saw: %s." % cells) + + if len(set(id(cell) for cell in cells)) < len(cells): + logging.log_first_n( + logging.WARN, "At least two cells provided to MultiRNNCell " + "are the same object and will share weights.", 1) + + self._cells = cells + for cell_number, cell in enumerate(self._cells): + # Add Trackable dependencies on these cells so their variables get + # saved with this object when using object-based saving. + if isinstance(cell, trackable.Trackable): + # TODO(allenl): Track down non-Trackable callers. + self._track_trackable(cell, name="cell-%d" % (cell_number,)) + self._state_is_tuple = state_is_tuple + if not state_is_tuple: + if any(nest.is_sequence(c.state_size) for c in self._cells): + raise ValueError("Some cells return tuples of states, but the flag " + "state_is_tuple is not set. State sizes are: %s" % + str([c.state_size for c in self._cells])) + + @property + def state_size(self): + if self._state_is_tuple: + return tuple(cell.state_size for cell in self._cells) + else: + return sum(cell.state_size for cell in self._cells) + + @property + def output_size(self): + return self._cells[-1].output_size + + def zero_state(self, batch_size, dtype): + with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): + if self._state_is_tuple: + return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells) + else: + # We know here that state_size of each cell is not a tuple and + # presumably does not contain TensorArrays or anything else fancy + return super(MultiRNNCell, self).zero_state(batch_size, dtype) + + @property + def trainable_weights(self): + if not self.trainable: + return [] + weights = [] + for cell in self._cells: + if isinstance(cell, base_layer.Layer): + weights += cell.trainable_weights + return weights + + @property + def non_trainable_weights(self): + weights = [] + for cell in self._cells: + if isinstance(cell, base_layer.Layer): + weights += cell.non_trainable_weights + if not self.trainable: + trainable_weights = [] + for cell in self._cells: + if isinstance(cell, base_layer.Layer): + trainable_weights += cell.trainable_weights + return trainable_weights + weights + return weights + + def call(self, inputs, state): + """Run this multi-layer cell on inputs, starting from state.""" + cur_state_pos = 0 + cur_inp = inputs + new_states = [] + for i, cell in enumerate(self._cells): + with vs.variable_scope("cell_%d" % i): + if self._state_is_tuple: + if not nest.is_sequence(state): + raise ValueError( + "Expected state to be a tuple of length %d, but received: %s" % + (len(self.state_size), state)) + cur_state = state[i] + else: + cur_state = array_ops.slice(state, [0, cur_state_pos], + [-1, cell.state_size]) + cur_state_pos += cell.state_size + cur_inp, new_state = cell(cur_inp, cur_state) + new_states.append(new_state) + + new_states = ( + tuple(new_states) if self._state_is_tuple else array_ops.concat( + new_states, 1)) + + return cur_inp, new_states + + +def _check_rnn_cell_input_dtypes(inputs): + """Check whether the input tensors are with supported dtypes. 
+
+  Default RNN cells only support float and complex dtypes, since the
+  activation functions (tanh and sigmoid) only allow those types. This
+  function throws a proper error message if the input is not of a supported
+  type.
+
+  Args:
+    inputs: tensor or nested structure of tensors that are fed to the RNN cell
+      as input or state.
+
+  Raises:
+    ValueError: if any of the input tensors do not have a float or complex
+      dtype.
+  """
+  for t in nest.flatten(inputs):
+    _check_supported_dtypes(t.dtype)
+
+
+def _check_supported_dtypes(dtype):
+  if dtype is None:
+    return
+  dtype = dtypes.as_dtype(dtype)
+  if not (dtype.is_floating or dtype.is_complex):
+    raise ValueError("RNN cell only supports floating point inputs, "
+                     "but saw dtype: %s" % dtype)
diff --git a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py
new file mode 100644
index 00000000000..62a6baa5640
--- /dev/null
+++ b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py
@@ -0,0 +1,516 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module containing the implementation of RNN cell wrappers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import hashlib
+import numbers
+import sys
+import types as python_types
+import warnings
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.util import nest
+
+
+class DropoutWrapperBase(object):
+  """Operator adding dropout to inputs and outputs of the given cell."""
+
+  def __init__(self,
+               cell,
+               input_keep_prob=1.0,
+               output_keep_prob=1.0,
+               state_keep_prob=1.0,
+               variational_recurrent=False,
+               input_size=None,
+               dtype=None,
+               seed=None,
+               dropout_state_filter_visitor=None,
+               **kwargs):
+    """Create a cell with added input, state, and/or output dropout.
+
+    If `variational_recurrent` is set to `True` (**NOT** the default behavior),
+    then the same dropout mask is applied at every step, as described in:
+    [A Theoretically Grounded Application of Dropout in Recurrent
+    Neural Networks. Y. Gal, Z. Ghahramani](https://arxiv.org/abs/1512.05287).
+
+    Otherwise a different dropout mask is applied at every time step.
+
+    Note, by default (unless a custom `dropout_state_filter` is provided),
+    the memory state (`c` component of any `LSTMStateTuple`) passing through
+    a `DropoutWrapper` is never modified. This behavior is described in the
+    above article.
+
+    Args:
+      cell: an RNNCell, the cell to wrap with dropout.
+      input_keep_prob: unit Tensor or float between 0 and 1, input keep
+        probability; if it is constant and 1, no input dropout will be added.
+      output_keep_prob: unit Tensor or float between 0 and 1, output keep
+        probability; if it is constant and 1, no output dropout will be added.
+      state_keep_prob: unit Tensor or float between 0 and 1, state keep
+        probability; if it is constant and 1, no state dropout will be added.
+        State dropout is performed on the outgoing states of the cell. **Note**
+        the state components to which dropout is applied when `state_keep_prob`
+        is in `(0, 1)` are also determined by the argument
+        `dropout_state_filter_visitor` (e.g. by default dropout is never
+        applied to the `c` component of an `LSTMStateTuple`).
+      variational_recurrent: Python bool. If `True`, then the same dropout
+        pattern is applied across all time steps per run call. If this
+        parameter is set, `input_size` **must** be provided.
+      input_size: (optional) (possibly nested tuple of) `TensorShape` objects
+        containing the depth(s) of the input tensors expected to be passed in
+        to the `DropoutWrapper`. Required and used **iff**
+        `variational_recurrent = True` and `input_keep_prob < 1`.
+      dtype: (optional) The `dtype` of the input, state, and output tensors.
+        Required and used **iff** `variational_recurrent = True`.
+      seed: (optional) integer, the randomness seed.
+      dropout_state_filter_visitor: (optional), default: (see below). Function
+        that takes any hierarchical level of the state and returns a scalar or
+        depth=1 structure of Python booleans describing which terms in the
+        state should be dropped out. In addition, if the function returns
+        `True`, dropout is applied across this sublevel. If the function
+        returns `False`, dropout is not applied across this entire sublevel.
+        Default behavior: perform dropout on all terms except the memory (`c`)
+        state of `LSTMCellState` objects, and don't try to apply dropout to
+        `TensorArray` objects:
+
+        ```
+        def dropout_state_filter_visitor(s):
+          if isinstance(s, LSTMCellState):
+            # Never perform dropout on the c state.
+            return LSTMCellState(c=False, h=True)
+          elif isinstance(s, TensorArray):
+            return False
+          return True
+        ```
+      **kwargs: dict of keyword arguments for base layer.
+
+    Raises:
+      TypeError: if `cell` is not an `RNNCell`, or if
+        `dropout_state_filter_visitor` is provided but not `callable`.
+      ValueError: if any of the keep_probs are not between 0 and 1.
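+
+    A usage sketch (illustrative, not from the original docstring):
+
+    ```python
+    cell = tf.compat.v1.nn.rnn_cell.LSTMCell(64)
+    cell = tf.compat.v1.nn.rnn_cell.DropoutWrapper(
+        cell, input_keep_prob=0.9, output_keep_prob=0.9)
+    ```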
+ """ + super(DropoutWrapperBase, self).__init__(cell, dtype=dtype, **kwargs) + + if (dropout_state_filter_visitor is not None and + not callable(dropout_state_filter_visitor)): + raise TypeError("dropout_state_filter_visitor must be callable") + self._dropout_state_filter = ( + dropout_state_filter_visitor or _default_dropout_state_filter_visitor) + with ops.name_scope("DropoutWrapperInit"): + + def tensor_and_const_value(v): + tensor_value = ops.convert_to_tensor(v) + const_value = tensor_util.constant_value(tensor_value) + return (tensor_value, const_value) + + for prob, attr in [(input_keep_prob, "input_keep_prob"), + (state_keep_prob, "state_keep_prob"), + (output_keep_prob, "output_keep_prob")]: + tensor_prob, const_prob = tensor_and_const_value(prob) + if const_prob is not None: + if const_prob < 0 or const_prob > 1: + raise ValueError("Parameter %s must be between 0 and 1: %d" % + (attr, const_prob)) + setattr(self, "_%s" % attr, float(const_prob)) + else: + setattr(self, "_%s" % attr, tensor_prob) + + # Set variational_recurrent, seed before running the code below + self._variational_recurrent = variational_recurrent + self._input_size = input_size + self._seed = seed + + self._recurrent_input_noise = None + self._recurrent_state_noise = None + self._recurrent_output_noise = None + + if variational_recurrent: + if dtype is None: + raise ValueError( + "When variational_recurrent=True, dtype must be provided") + + def convert_to_batch_shape(s): + # Prepend a 1 for the batch dimension; for recurrent + # variational dropout we use the same dropout mask for all + # batch elements. + return array_ops.concat(([1], tensor_shape.TensorShape(s).as_list()), 0) + + def batch_noise(s, inner_seed): + shape = convert_to_batch_shape(s) + return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype) + + if (not isinstance(self._input_keep_prob, numbers.Real) or + self._input_keep_prob < 1.0): + if input_size is None: + raise ValueError( + "When variational_recurrent=True and input_keep_prob < 1.0 or " + "is unknown, input_size must be provided") + self._recurrent_input_noise = _enumerated_map_structure_up_to( + input_size, + lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)), + input_size) + self._recurrent_state_noise = _enumerated_map_structure_up_to( + cell.state_size, + lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)), + cell.state_size) + self._recurrent_output_noise = _enumerated_map_structure_up_to( + cell.output_size, + lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)), + cell.output_size) + + def _gen_seed(self, salt_prefix, index): + if self._seed is None: + return None + salt = "%s_%d" % (salt_prefix, index) + string = (str(self._seed) + salt).encode("utf-8") + return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF + + @property + def wrapped_cell(self): + return self.cell + + @property + def state_size(self): + return self.cell.state_size + + @property + def output_size(self): + return self.cell.output_size + + def build(self, inputs_shape): + self.cell.build(inputs_shape) + self.built = True + + def zero_state(self, batch_size, dtype): + with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): + return self.cell.zero_state(batch_size, dtype) + + def _variational_recurrent_dropout_value( + self, unused_index, value, noise, keep_prob): + """Performs dropout given the pre-calculated noise tensor.""" + # uniform [keep_prob, 1.0 + keep_prob) + random_tensor = keep_prob + noise + + # 0. 
+
+  def _dropout(self,
+               values,
+               salt_prefix,
+               recurrent_noise,
+               keep_prob,
+               shallow_filtered_substructure=None):
+    """Decides whether to perform standard dropout or recurrent dropout."""
+
+    if shallow_filtered_substructure is None:
+      # Put something so we traverse the entire structure; inside the
+      # dropout function we check to see if leaves of this are bool or not.
+      shallow_filtered_substructure = values
+
+    if not self._variational_recurrent:
+
+      def dropout(i, do_dropout, v):
+        if not isinstance(do_dropout, bool) or do_dropout:
+          return nn_ops.dropout_v2(
+              v, rate=1. - keep_prob, seed=self._gen_seed(salt_prefix, i))
+        else:
+          return v
+
+      return _enumerated_map_structure_up_to(
+          shallow_filtered_substructure, dropout,
+          *[shallow_filtered_substructure, values])
+    else:
+
+      def dropout(i, do_dropout, v, n):
+        if not isinstance(do_dropout, bool) or do_dropout:
+          return self._variational_recurrent_dropout_value(i, v, n, keep_prob)
+        else:
+          return v
+
+      return _enumerated_map_structure_up_to(
+          shallow_filtered_substructure, dropout,
+          *[shallow_filtered_substructure, values, recurrent_noise])
+
+  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
+    """Runs the wrapped cell and applies dropout.
+
+    Args:
+      inputs: A tensor with wrapped cell's input.
+      state: A tensor or tuple of tensors with wrapped cell's state.
+      cell_call_fn: Wrapped cell's method to use for step computation (cell's
+        `__call__` or `call` method).
+      **kwargs: Additional arguments.
+
+    Returns:
+      A pair containing:
+
+      - Output: A tensor with cell's output.
+      - New state: A tensor or tuple of tensors with new wrapped cell's state.
+    """
+
+    def _should_dropout(p):
+      return (not isinstance(p, float)) or p < 1
+
+    if _should_dropout(self._input_keep_prob):
+      inputs = self._dropout(inputs, "input", self._recurrent_input_noise,
+                             self._input_keep_prob)
+    output, new_state = cell_call_fn(inputs, state, **kwargs)
+    if _should_dropout(self._state_keep_prob):
+      # Identify which subsets of the state to perform dropout on and
+      # which ones to keep.
+      shallow_filtered_substructure = nest.get_traverse_shallow_structure(
+          self._dropout_state_filter, new_state)
+      new_state = self._dropout(new_state, "state",
+                                self._recurrent_state_noise,
+                                self._state_keep_prob,
+                                shallow_filtered_substructure)
+    if _should_dropout(self._output_keep_prob):
+      output = self._dropout(output, "output", self._recurrent_output_noise,
+                             self._output_keep_prob)
+    return output, new_state
+
+  def get_config(self):
+    """Returns the config of the dropout wrapper."""
+    config = {
+        "input_keep_prob": self._input_keep_prob,
+        "output_keep_prob": self._output_keep_prob,
+        "state_keep_prob": self._state_keep_prob,
+        "variational_recurrent": self._variational_recurrent,
+        "input_size": self._input_size,
+        "seed": self._seed,
+    }
+    if self._dropout_state_filter != _default_dropout_state_filter_visitor:
+      function, function_type, function_module = _serialize_function_to_config(
+          self._dropout_state_filter)
+      config.update({"dropout_fn": function,
+                     "dropout_fn_type": function_type,
+                     "dropout_fn_module": function_module})
+    base_config = super(DropoutWrapperBase, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+  @classmethod
+  def from_config(cls, config, custom_objects=None):
+    if "dropout_fn" in config:
+      config = config.copy()
+      dropout_state_filter = _parse_config_to_function(
+          config, custom_objects, "dropout_fn", "dropout_fn_type",
+          "dropout_fn_module")
+      config.pop("dropout_fn")
+      config["dropout_state_filter_visitor"] = dropout_state_filter
+    return super(DropoutWrapperBase, cls).from_config(
+        config, custom_objects=custom_objects)
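A hedged sketch of supplying a custom `dropout_state_filter_visitor` through the public TF1 wrapper. The visitor mirrors `_default_dropout_state_filter_visitor` further below; the cell size is arbitrary:

```python
import tensorflow.compat.v1 as tf

def state_filter(substate):
  # Drop out only the `h` part of an LSTM state; never touch the memory `c`.
  if isinstance(substate, tf.nn.rnn_cell.LSTMStateTuple):
    return tf.nn.rnn_cell.LSTMStateTuple(c=False, h=True)
  return True

cell = tf.nn.rnn_cell.DropoutWrapper(
    tf.nn.rnn_cell.LSTMCell(32),
    state_keep_prob=0.8,
    dropout_state_filter_visitor=state_filter)
```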
+
+
+class ResidualWrapperBase(object):
+  """RNNCell wrapper that ensures cell inputs are added to the outputs."""
+
+  def __init__(self, cell, residual_fn=None, **kwargs):
+    """Constructs a `ResidualWrapper` for `cell`.
+
+    Args:
+      cell: An instance of `RNNCell`.
+      residual_fn: (Optional) The function to map raw cell inputs and raw
+        cell outputs to the actual cell outputs of the residual network.
+        Defaults to calling nest.map_structure on (lambda i, o: i + o),
+        inputs and outputs.
+      **kwargs: dict of keyword arguments for base layer.
+    """
+    super(ResidualWrapperBase, self).__init__(cell, **kwargs)
+    self._residual_fn = residual_fn
+
+  @property
+  def state_size(self):
+    return self.cell.state_size
+
+  @property
+  def output_size(self):
+    return self.cell.output_size
+
+  def zero_state(self, batch_size, dtype):
+    with ops.name_scope(type(self).__name__ + "ZeroState",
+                        values=[batch_size]):
+      return self.cell.zero_state(batch_size, dtype)
+
+  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
+    """Runs the cell, then applies `residual_fn` to its inputs and outputs.
+
+    Args:
+      inputs: cell inputs.
+      state: cell state.
+      cell_call_fn: Wrapped cell's method to use for step computation (cell's
+        `__call__` or `call` method).
+      **kwargs: Additional arguments passed to the wrapped cell's `call`.
+
+    Returns:
+      Tuple of cell outputs and new state.
+
+    Raises:
+      TypeError: If cell inputs and outputs have different structure (type).
+      ValueError: If cell inputs and outputs have different structure (value).
+    """
+    outputs, new_state = cell_call_fn(inputs, state, **kwargs)
+
+    # Ensure shapes match
+    def assert_shape_match(inp, out):
+      inp.get_shape().assert_is_compatible_with(out.get_shape())
+
+    def default_residual_fn(inputs, outputs):
+      nest.assert_same_structure(inputs, outputs)
+      nest.map_structure(assert_shape_match, inputs, outputs)
+      return nest.map_structure(lambda inp, out: inp + out, inputs, outputs)
+
+    res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs)
+    return (res_outputs, new_state)
+
+  def get_config(self):
+    """Returns the config of the residual wrapper."""
+    if self._residual_fn is not None:
+      function, function_type, function_module = _serialize_function_to_config(
+          self._residual_fn)
+      config = {
+          "residual_fn": function,
+          "residual_fn_type": function_type,
+          "residual_fn_module": function_module
+      }
+    else:
+      config = {}
+    base_config = super(ResidualWrapperBase, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+  @classmethod
+  def from_config(cls, config, custom_objects=None):
+    if "residual_fn" in config:
+      config = config.copy()
+      residual_function = _parse_config_to_function(config, custom_objects,
+                                                    "residual_fn",
+                                                    "residual_fn_type",
+                                                    "residual_fn_module")
+      config["residual_fn"] = residual_function
+    return super(ResidualWrapperBase, cls).from_config(
+        config, custom_objects=custom_objects)
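A short usage sketch of the residual logic above (hedged: TF1 API, arbitrary sizes). Because the default `residual_fn` computes `inputs + outputs`, the cell's input and output depths must match:

```python
import tensorflow.compat.v1 as tf

cell = tf.nn.rnn_cell.ResidualWrapper(
    tf.nn.rnn_cell.GRUCell(64),
    # Optional: replace the default element-wise sum with a scaled variant.
    residual_fn=lambda inputs, outputs: 0.5 * (inputs + outputs))
# With the default residual_fn, feeding [batch, 64] inputs yields
# outputs of the same shape: inputs + cell_outputs.
```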
+ """ + outputs, new_state = cell_call_fn(inputs, state, **kwargs) + + # Ensure shapes match + def assert_shape_match(inp, out): + inp.get_shape().assert_is_compatible_with(out.get_shape()) + + def default_residual_fn(inputs, outputs): + nest.assert_same_structure(inputs, outputs) + nest.map_structure(assert_shape_match, inputs, outputs) + return nest.map_structure(lambda inp, out: inp + out, inputs, outputs) + + res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs) + return (res_outputs, new_state) + + def get_config(self): + """Returns the config of the residual wrapper.""" + if self._residual_fn is not None: + function, function_type, function_module = _serialize_function_to_config( + self._residual_fn) + config = { + "residual_fn": function, + "residual_fn_type": function_type, + "residual_fn_module": function_module + } + else: + config = {} + base_config = super(ResidualWrapperBase, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + if "residual_fn" in config: + config = config.copy() + residual_function = _parse_config_to_function(config, custom_objects, + "residual_fn", + "residual_fn_type", + "residual_fn_module") + config["residual_fn"] = residual_function + return super(ResidualWrapperBase, cls).from_config( + config, custom_objects=custom_objects) + + +class DeviceWrapperBase(object): + """Operator that ensures an RNNCell runs on a particular device.""" + + def __init__(self, cell, device, **kwargs): + """Construct a `DeviceWrapper` for `cell` with device `device`. + + Ensures the wrapped `cell` is called with `tf.device(device)`. + + Args: + cell: An instance of `RNNCell`. + device: A device string or function, for passing to `tf.device`. + **kwargs: dict of keyword arguments for base layer. 
+ """ + super(DeviceWrapperBase, self).__init__(cell, **kwargs) + self._device = device + + @property + def state_size(self): + return self.cell.state_size + + @property + def output_size(self): + return self.cell.output_size + + def zero_state(self, batch_size, dtype): + with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): + with ops.device(self._device): + return self.cell.zero_state(batch_size, dtype) + + def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): + """Run the cell on specified device.""" + with ops.device(self._device): + return cell_call_fn(inputs, state, **kwargs) + + def get_config(self): + config = {"device": self._device} + base_config = super(DeviceWrapperBase, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +def _serialize_function_to_config(function): + """Serialize the function for get_config().""" + if isinstance(function, python_types.LambdaType): + output = generic_utils.func_dump(function) + output_type = "lambda" + module = function.__module__ + elif callable(function): + output = function.__name__ + output_type = "function" + module = function.__module__ + else: + raise ValueError("Unrecognized function type for input: {}".format( + type(function))) + + return output, output_type, module + + +def _parse_config_to_function(config, custom_objects, func_attr_name, + func_type_attr_name, module_attr_name): + """Reconstruct the function from the config.""" + globs = globals() + module = config.pop(module_attr_name, None) + if module in sys.modules: + globs.update(sys.modules[module].__dict__) + elif module is not None: + # Note: we don't know the name of the function if it's a lambda. + warnings.warn("{} is not loaded, but a layer uses it. " + "It may cause errors.".format(module), UserWarning) + if custom_objects: + globs.update(custom_objects) + function_type = config.pop(func_type_attr_name) + if function_type == "function": + # Simple lookup in custom objects + function = generic_utils.deserialize_keras_object( + config[func_attr_name], + custom_objects=custom_objects, + printable_module_name="function in wrapper") + elif function_type == "lambda": + # Unsafe deserialization from bytecode + function = generic_utils.func_load( + config[func_attr_name], globs=globs) + else: + raise TypeError("Unknown function type:", function_type) + return function + + +def _default_dropout_state_filter_visitor(substate): + from tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl import LSTMStateTuple # pylint: disable=g-import-not-at-top + if isinstance(substate, LSTMStateTuple): + # Do not perform dropout on the memory state. 
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 719b06a41d5..b0b29de22ea 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -24,1332 +24,27 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-import collections
-
-from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.keras import activations
-from tensorflow.python.keras import initializers
-from tensorflow.python.keras.engine import input_spec
-from tensorflow.python.keras.utils import tf_utils
-from tensorflow.python.layers import base as base_layer
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import clip_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import partitioned_variables
-from tensorflow.python.ops import rnn_cell_wrapper_impl
-from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training.tracking import base as trackable
-from tensorflow.python.util import nest
-from tensorflow.python.util.deprecation import deprecated
-from tensorflow.python.util.tf_export import tf_export
-
-_BIAS_VARIABLE_NAME = "bias"
-_WEIGHTS_VARIABLE_NAME = "kernel"
-
-# This can be used with self.assertRaisesRegexp for assert_like_rnncell.
-ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell"
-
-
-def _hasattr(obj, attr_name):
-  try:
-    getattr(obj, attr_name)
-  except AttributeError:
-    return False
-  else:
-    return True
-
-
-def assert_like_rnncell(cell_name, cell):
-  """Raises a TypeError if cell is not like an RNNCell.
-
-  NOTE: Do not rely on the error message (in particular in tests) which can be
-  subject to change to increase readability. Use
-  ASSERT_LIKE_RNNCELL_ERROR_REGEXP.
-
-  Args:
-    cell_name: A string to give a meaningful error referencing to the name of
-      the functionargument.
-    cell: The object which should behave like an RNNCell.
-
-  Raises:
-    TypeError: A human-friendly exception.
- """ - conditions = [ - _hasattr(cell, "output_size"), - _hasattr(cell, "state_size"), - _hasattr(cell, "get_initial_state") or _hasattr(cell, "zero_state"), - callable(cell), - ] - errors = [ - "'output_size' property is missing", "'state_size' property is missing", - "either 'zero_state' or 'get_initial_state' method is required", - "is not callable" - ] - - if not all(conditions): - - errors = [error for error, cond in zip(errors, conditions) if not cond] - raise TypeError("The argument {!r} ({}) is not an RNNCell: {}.".format( - cell_name, cell, ", ".join(errors))) - - -def _concat(prefix, suffix, static=False): - """Concat that enables int, Tensor, or TensorShape values. - - This function takes a size specification, which can be an integer, a - TensorShape, or a Tensor, and converts it into a concatenated Tensor - (if static = False) or a list of integers (if static = True). - - Args: - prefix: The prefix; usually the batch size (and/or time step size). - (TensorShape, int, or Tensor.) - suffix: TensorShape, int, or Tensor. - static: If `True`, return a python list with possibly unknown dimensions. - Otherwise return a `Tensor`. - - Returns: - shape: the concatenation of prefix and suffix. - - Raises: - ValueError: if `suffix` is not a scalar or vector (or TensorShape). - ValueError: if prefix or suffix was `None` and asked for dynamic - Tensors out. - """ - if isinstance(prefix, ops.Tensor): - p = prefix - p_static = tensor_util.constant_value(prefix) - if p.shape.ndims == 0: - p = array_ops.expand_dims(p, 0) - elif p.shape.ndims != 1: - raise ValueError("prefix tensor must be either a scalar or vector, " - "but saw tensor: %s" % p) - else: - p = tensor_shape.as_shape(prefix) - p_static = p.as_list() if p.ndims is not None else None - p = ( - constant_op.constant(p.as_list(), dtype=dtypes.int32) - if p.is_fully_defined() else None) - if isinstance(suffix, ops.Tensor): - s = suffix - s_static = tensor_util.constant_value(suffix) - if s.shape.ndims == 0: - s = array_ops.expand_dims(s, 0) - elif s.shape.ndims != 1: - raise ValueError("suffix tensor must be either a scalar or vector, " - "but saw tensor: %s" % s) - else: - s = tensor_shape.as_shape(suffix) - s_static = s.as_list() if s.ndims is not None else None - s = ( - constant_op.constant(s.as_list(), dtype=dtypes.int32) - if s.is_fully_defined() else None) - - if static: - shape = tensor_shape.as_shape(p_static).concatenate(s_static) - shape = shape.as_list() if shape.ndims is not None else None - else: - if p is None or s is None: - raise ValueError("Provided a prefix or suffix of None: %s and %s" % - (prefix, suffix)) - shape = array_ops.concat((p, s), 0) - return shape - - -def _zero_state_tensors(state_size, batch_size, dtype): - """Create tensors of zeros based on state_size, batch_size, and dtype.""" - - def get_state_shape(s): - """Combine s with batch_size to get a proper tensor shape.""" - c = _concat(batch_size, s) - size = array_ops.zeros(c, dtype=dtype) - if not context.executing_eagerly(): - c_static = _concat(batch_size, s, static=True) - size.set_shape(c_static) - return size - - return nest.map_structure(get_state_shape, state_size) - - -@tf_export(v1=["nn.rnn_cell.RNNCell"]) -class RNNCell(base_layer.Layer): - """Abstract object representing an RNN cell. - - Every `RNNCell` must have the properties below and implement `call` with - the signature `(output, next_state) = call(input, state)`. 
The optional - third input argument, `scope`, is allowed for backwards compatibility - purposes; but should be left off for new subclasses. - - This definition of cell differs from the definition used in the literature. - In the literature, 'cell' refers to an object with a single scalar output. - This definition refers to a horizontal array of such units. - - An RNN cell, in the most abstract setting, is anything that has - a state and performs some operation that takes a matrix of inputs. - This operation results in an output matrix with `self.output_size` columns. - If `self.state_size` is an integer, this operation also results in a new - state matrix with `self.state_size` columns. If `self.state_size` is a - (possibly nested tuple of) TensorShape object(s), then it should return a - matching structure of Tensors having shape `[batch_size].concatenate(s)` - for each `s` in `self.batch_size`. - """ - - def __init__(self, trainable=True, name=None, dtype=None, **kwargs): - super(RNNCell, self).__init__( - trainable=trainable, name=name, dtype=dtype, **kwargs) - # Attribute that indicates whether the cell is a TF RNN cell, due the slight - # difference between TF and Keras RNN cell. Notably the state is not wrapped - # in a list for TF cell where they are single tensor state, whereas keras - # cell will wrap the state into a list, and call() will have to unwrap them. - self._is_tf_rnn_cell = True - - def __call__(self, inputs, state, scope=None): - """Run this RNN cell on inputs, starting from the given state. - - Args: - inputs: `2-D` tensor with shape `[batch_size, input_size]`. - state: if `self.state_size` is an integer, this should be a `2-D Tensor` - with shape `[batch_size, self.state_size]`. Otherwise, if - `self.state_size` is a tuple of integers, this should be a tuple with - shapes `[batch_size, s] for s in self.state_size`. - scope: VariableScope for the created subgraph; defaults to class name. - - Returns: - A pair containing: - - - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`. - - New state: Either a single `2-D` tensor, or a tuple of tensors matching - the arity and shapes of `state`. - """ - if scope is not None: - with vs.variable_scope( - scope, custom_getter=self._rnn_get_variable) as scope: - return super(RNNCell, self).__call__(inputs, state, scope=scope) - else: - scope_attrname = "rnncell_scope" - scope = getattr(self, scope_attrname, None) - if scope is None: - scope = vs.variable_scope( - vs.get_variable_scope(), custom_getter=self._rnn_get_variable) - setattr(self, scope_attrname, scope) - with scope: - return super(RNNCell, self).__call__(inputs, state) - - def _rnn_get_variable(self, getter, *args, **kwargs): - variable = getter(*args, **kwargs) - if context.executing_eagerly(): - trainable = variable._trainable # pylint: disable=protected-access - else: - trainable = ( - variable in tf_variables.trainable_variables() or - (isinstance(variable, tf_variables.PartitionedVariable) and - list(variable)[0] in tf_variables.trainable_variables())) - if trainable and all(variable is not v for v in self._trainable_weights): - self._trainable_weights.append(variable) - elif not trainable and all( - variable is not v for v in self._non_trainable_weights): - self._non_trainable_weights.append(variable) - return variable - - @property - def state_size(self): - """size(s) of state(s) used by this cell. - - It can be represented by an Integer, a TensorShape or a tuple of Integers - or TensorShapes. 
- """ - raise NotImplementedError("Abstract method") - - @property - def output_size(self): - """Integer or TensorShape: size of outputs produced by this cell.""" - raise NotImplementedError("Abstract method") - - def build(self, _): - # This tells the parent Layer object that it's OK to call - # self.add_variable() inside the call() method. - pass - - def get_initial_state(self, inputs=None, batch_size=None, dtype=None): - if inputs is not None: - # Validate the given batch_size and dtype against inputs if provided. - inputs = ops.convert_to_tensor(inputs, name="inputs") - if batch_size is not None: - if tensor_util.is_tensor(batch_size): - static_batch_size = tensor_util.constant_value( - batch_size, partial=True) - else: - static_batch_size = batch_size - if inputs.shape.dims[0].value != static_batch_size: - raise ValueError( - "batch size from input tensor is different from the " - "input param. Input tensor batch: {}, batch_size: {}".format( - inputs.shape.dims[0].value, batch_size)) - - if dtype is not None and inputs.dtype != dtype: - raise ValueError( - "dtype from input tensor is different from the " - "input param. Input tensor dtype: {}, dtype: {}".format( - inputs.dtype, dtype)) - - batch_size = inputs.shape.dims[0].value or array_ops.shape(inputs)[0] - dtype = inputs.dtype - if batch_size is None or dtype is None: - raise ValueError( - "batch_size and dtype cannot be None while constructing initial " - "state: batch_size={}, dtype={}".format(batch_size, dtype)) - return self.zero_state(batch_size, dtype) - - def zero_state(self, batch_size, dtype): - """Return zero-filled state tensor(s). - - Args: - batch_size: int, float, or unit Tensor representing the batch size. - dtype: the data type to use for the state. - - Returns: - If `state_size` is an int or TensorShape, then the return value is a - `N-D` tensor of shape `[batch_size, state_size]` filled with zeros. - - If `state_size` is a nested list or tuple, then the return value is - a nested list or tuple (of the same structure) of `2-D` tensors with - the shapes `[batch_size, s]` for each s in `state_size`. - """ - # Try to use the last cached zero_state. This is done to avoid recreating - # zeros, especially when eager execution is enabled. - state_size = self.state_size - is_eager = context.executing_eagerly() - if is_eager and _hasattr(self, "_last_zero_state"): - (last_state_size, last_batch_size, last_dtype, - last_output) = getattr(self, "_last_zero_state") - if (last_batch_size == batch_size and last_dtype == dtype and - last_state_size == state_size): - return last_output - with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): - output = _zero_state_tensors(state_size, batch_size, dtype) - if is_eager: - self._last_zero_state = (state_size, batch_size, dtype, output) - return output - - # TODO(b/134773139): Remove when contrib RNN cells implement `get_config` - def get_config(self): # pylint: disable=useless-super-delegation - return super(RNNCell, self).get_config() - - -class LayerRNNCell(RNNCell): - """Subclass of RNNCells that act like proper `tf.Layer` objects. - - For backwards compatibility purposes, most `RNNCell` instances allow their - `call` methods to instantiate variables via `tf.compat.v1.get_variable`. The - underlying - variable scope thus keeps track of any variables, and returning cached - versions. This is atypical of `tf.layer` objects, which separate this - part of layer building into a `build` method that is only called once. 
- - Here we provide a subclass for `RNNCell` objects that act exactly as - `Layer` objects do. They must provide a `build` method and their - `call` methods do not access Variables `tf.compat.v1.get_variable`. - """ - - def __call__(self, inputs, state, scope=None, *args, **kwargs): - """Run this RNN cell on inputs, starting from the given state. - - Args: - inputs: `2-D` tensor with shape `[batch_size, input_size]`. - state: if `self.state_size` is an integer, this should be a `2-D Tensor` - with shape `[batch_size, self.state_size]`. Otherwise, if - `self.state_size` is a tuple of integers, this should be a tuple with - shapes `[batch_size, s] for s in self.state_size`. - scope: optional cell scope. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. - - Returns: - A pair containing: - - - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`. - - New state: Either a single `2-D` tensor, or a tuple of tensors matching - the arity and shapes of `state`. - """ - # Bypass RNNCell's variable capturing semantics for LayerRNNCell. - # Instead, it is up to subclasses to provide a proper build - # method. See the class docstring for more details. - return base_layer.Layer.__call__( - self, inputs, state, scope=scope, *args, **kwargs) - - -@tf_export(v1=["nn.rnn_cell.BasicRNNCell"]) -class BasicRNNCell(LayerRNNCell): - """The most basic RNN cell. - - Note that this cell is not optimized for performance. Please use - `tf.contrib.cudnn_rnn.CudnnRNNTanh` for better performance on GPU. - - Args: - num_units: int, The number of units in the RNN cell. - activation: Nonlinearity to use. Default: `tanh`. It could also be string - that is within Keras activation function names. - reuse: (optional) Python boolean describing whether to reuse variables in an - existing scope. If not `True`, and the existing scope already has the - given variables, an error is raised. - name: String, the name of the layer. Layers with the same name will share - weights, but to avoid mistakes we require reuse=True in such cases. - dtype: Default dtype of the layer (default of `None` means use the type of - the first input). Required when `build` is called before `call`. - **kwargs: Dict, keyword named properties for common layer attributes, like - `trainable` etc when constructing the cell from configs of get_config(). - """ - - @deprecated(None, "This class is equivalent as tf.keras.layers.SimpleRNNCell," - " and will be replaced by that in Tensorflow 2.0.") - def __init__(self, - num_units, - activation=None, - reuse=None, - name=None, - dtype=None, - **kwargs): - super(BasicRNNCell, self).__init__( - _reuse=reuse, name=name, dtype=dtype, **kwargs) - _check_supported_dtypes(self.dtype) - if context.executing_eagerly() and context.num_gpus() > 0: - logging.warn( - "%s: Note that this cell is not optimized for performance. " - "Please use tf.contrib.cudnn_rnn.CudnnRNNTanh for better " - "performance on GPU.", self) - - # Inputs must be 2-dimensional. 
- self.input_spec = input_spec.InputSpec(ndim=2) - - self._num_units = num_units - if activation: - self._activation = activations.get(activation) - else: - self._activation = math_ops.tanh - - @property - def state_size(self): - return self._num_units - - @property - def output_size(self): - return self._num_units - - @tf_utils.shape_type_conversion - def build(self, inputs_shape): - if inputs_shape[-1] is None: - raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % - str(inputs_shape)) - _check_supported_dtypes(self.dtype) - - input_depth = inputs_shape[-1] - self._kernel = self.add_variable( - _WEIGHTS_VARIABLE_NAME, - shape=[input_depth + self._num_units, self._num_units]) - self._bias = self.add_variable( - _BIAS_VARIABLE_NAME, - shape=[self._num_units], - initializer=init_ops.zeros_initializer(dtype=self.dtype)) - - self.built = True - - def call(self, inputs, state): - """Most basic RNN: output = new_state = act(W * input + U * state + B).""" - _check_rnn_cell_input_dtypes([inputs, state]) - gate_inputs = math_ops.matmul( - array_ops.concat([inputs, state], 1), self._kernel) - gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) - output = self._activation(gate_inputs) - return output, output - - def get_config(self): - config = { - "num_units": self._num_units, - "activation": activations.serialize(self._activation), - "reuse": self._reuse, - } - base_config = super(BasicRNNCell, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - -@tf_export(v1=["nn.rnn_cell.GRUCell"]) -class GRUCell(LayerRNNCell): - """Gated Recurrent Unit cell. - - Note that this cell is not optimized for performance. Please use - `tf.contrib.cudnn_rnn.CudnnGRU` for better performance on GPU, or - `tf.contrib.rnn.GRUBlockCellV2` for better performance on CPU. - - Args: - num_units: int, The number of units in the GRU cell. - activation: Nonlinearity to use. Default: `tanh`. - reuse: (optional) Python boolean describing whether to reuse variables in an - existing scope. If not `True`, and the existing scope already has the - given variables, an error is raised. - kernel_initializer: (optional) The initializer to use for the weight and - projection matrices. - bias_initializer: (optional) The initializer to use for the bias. - name: String, the name of the layer. Layers with the same name will share - weights, but to avoid mistakes we require reuse=True in such cases. - dtype: Default dtype of the layer (default of `None` means use the type of - the first input). Required when `build` is called before `call`. - **kwargs: Dict, keyword named properties for common layer attributes, like - `trainable` etc when constructing the cell from configs of get_config(). 
- - References: - Learning Phrase Representations using RNN Encoder Decoder for Statistical - Machine Translation: - [Cho et al., 2014] - (https://aclanthology.coli.uni-saarland.de/papers/D14-1179/d14-1179) - ([pdf](http://emnlp2014.org/papers/pdf/EMNLP2014179.pdf)) - """ - - @deprecated(None, "This class is equivalent as tf.keras.layers.GRUCell," - " and will be replaced by that in Tensorflow 2.0.") - def __init__(self, - num_units, - activation=None, - reuse=None, - kernel_initializer=None, - bias_initializer=None, - name=None, - dtype=None, - **kwargs): - super(GRUCell, self).__init__( - _reuse=reuse, name=name, dtype=dtype, **kwargs) - _check_supported_dtypes(self.dtype) - - if context.executing_eagerly() and context.num_gpus() > 0: - logging.warn( - "%s: Note that this cell is not optimized for performance. " - "Please use tf.contrib.cudnn_rnn.CudnnGRU for better " - "performance on GPU.", self) - # Inputs must be 2-dimensional. - self.input_spec = input_spec.InputSpec(ndim=2) - - self._num_units = num_units - if activation: - self._activation = activations.get(activation) - else: - self._activation = math_ops.tanh - self._kernel_initializer = initializers.get(kernel_initializer) - self._bias_initializer = initializers.get(bias_initializer) - - @property - def state_size(self): - return self._num_units - - @property - def output_size(self): - return self._num_units - - @tf_utils.shape_type_conversion - def build(self, inputs_shape): - if inputs_shape[-1] is None: - raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % - str(inputs_shape)) - _check_supported_dtypes(self.dtype) - input_depth = inputs_shape[-1] - self._gate_kernel = self.add_variable( - "gates/%s" % _WEIGHTS_VARIABLE_NAME, - shape=[input_depth + self._num_units, 2 * self._num_units], - initializer=self._kernel_initializer) - self._gate_bias = self.add_variable( - "gates/%s" % _BIAS_VARIABLE_NAME, - shape=[2 * self._num_units], - initializer=(self._bias_initializer - if self._bias_initializer is not None else - init_ops.constant_initializer(1.0, dtype=self.dtype))) - self._candidate_kernel = self.add_variable( - "candidate/%s" % _WEIGHTS_VARIABLE_NAME, - shape=[input_depth + self._num_units, self._num_units], - initializer=self._kernel_initializer) - self._candidate_bias = self.add_variable( - "candidate/%s" % _BIAS_VARIABLE_NAME, - shape=[self._num_units], - initializer=(self._bias_initializer - if self._bias_initializer is not None else - init_ops.zeros_initializer(dtype=self.dtype))) - - self.built = True - - def call(self, inputs, state): - """Gated recurrent unit (GRU) with nunits cells.""" - _check_rnn_cell_input_dtypes([inputs, state]) - - gate_inputs = math_ops.matmul( - array_ops.concat([inputs, state], 1), self._gate_kernel) - gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) - - value = math_ops.sigmoid(gate_inputs) - r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) - - r_state = r * state - - candidate = math_ops.matmul( - array_ops.concat([inputs, r_state], 1), self._candidate_kernel) - candidate = nn_ops.bias_add(candidate, self._candidate_bias) - - c = self._activation(candidate) - new_h = u * state + (1 - u) * c - return new_h, new_h - - def get_config(self): - config = { - "num_units": self._num_units, - "kernel_initializer": initializers.serialize(self._kernel_initializer), - "bias_initializer": initializers.serialize(self._bias_initializer), - "activation": activations.serialize(self._activation), - "reuse": self._reuse, - } - base_config = super(GRUCell, 
self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - -_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h")) - - -@tf_export(v1=["nn.rnn_cell.LSTMStateTuple"]) -class LSTMStateTuple(_LSTMStateTuple): - """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. - - Stores two elements: `(c, h)`, in that order. Where `c` is the hidden state - and `h` is the output. - - Only used when `state_is_tuple=True`. - """ - __slots__ = () - - @property - def dtype(self): - (c, h) = self - if c.dtype != h.dtype: - raise TypeError("Inconsistent internal state: %s vs %s" % - (str(c.dtype), str(h.dtype))) - return c.dtype - - -@tf_export(v1=["nn.rnn_cell.BasicLSTMCell"]) -class BasicLSTMCell(LayerRNNCell): - """DEPRECATED: Please use `tf.compat.v1.nn.rnn_cell.LSTMCell` instead. - - Basic LSTM recurrent network cell. - - The implementation is based on - - We add forget_bias (default: 1) to the biases of the forget gate in order to - reduce the scale of forgetting in the beginning of the training. - - It does not allow cell clipping, a projection layer, and does not - use peep-hole connections: it is the basic baseline. - - For advanced models, please use the full `tf.compat.v1.nn.rnn_cell.LSTMCell` - that follows. - - Note that this cell is not optimized for performance. Please use - `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or - `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for - better performance on CPU. - """ - - @deprecated(None, "This class is equivalent as tf.keras.layers.LSTMCell," - " and will be replaced by that in Tensorflow 2.0.") - def __init__(self, - num_units, - forget_bias=1.0, - state_is_tuple=True, - activation=None, - reuse=None, - name=None, - dtype=None, - **kwargs): - """Initialize the basic LSTM cell. - - Args: - num_units: int, The number of units in the LSTM cell. - forget_bias: float, The bias added to forget gates (see above). Must set - to `0.0` manually when restoring from CudnnLSTM-trained checkpoints. - state_is_tuple: If True, accepted and returned states are 2-tuples of the - `c_state` and `m_state`. If False, they are concatenated along the - column axis. The latter behavior will soon be deprecated. - activation: Activation function of the inner states. Default: `tanh`. It - could also be string that is within Keras activation function names. - reuse: (optional) Python boolean describing whether to reuse variables in - an existing scope. If not `True`, and the existing scope already has - the given variables, an error is raised. - name: String, the name of the layer. Layers with the same name will share - weights, but to avoid mistakes we require reuse=True in such cases. - dtype: Default dtype of the layer (default of `None` means use the type of - the first input). Required when `build` is called before `call`. - **kwargs: Dict, keyword named properties for common layer attributes, like - `trainable` etc when constructing the cell from configs of get_config(). - When restoring from CudnnLSTM-trained checkpoints, must use - `CudnnCompatibleLSTMCell` instead. - """ - super(BasicLSTMCell, self).__init__( - _reuse=reuse, name=name, dtype=dtype, **kwargs) - _check_supported_dtypes(self.dtype) - if not state_is_tuple: - logging.warn( - "%s: Using a concatenated state is slower and will soon be " - "deprecated. 
Use state_is_tuple=True.", self) - if context.executing_eagerly() and context.num_gpus() > 0: - logging.warn( - "%s: Note that this cell is not optimized for performance. " - "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better " - "performance on GPU.", self) - - # Inputs must be 2-dimensional. - self.input_spec = input_spec.InputSpec(ndim=2) - - self._num_units = num_units - self._forget_bias = forget_bias - self._state_is_tuple = state_is_tuple - if activation: - self._activation = activations.get(activation) - else: - self._activation = math_ops.tanh - - @property - def state_size(self): - return (LSTMStateTuple(self._num_units, self._num_units) - if self._state_is_tuple else 2 * self._num_units) - - @property - def output_size(self): - return self._num_units - - @tf_utils.shape_type_conversion - def build(self, inputs_shape): - if inputs_shape[-1] is None: - raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % - str(inputs_shape)) - _check_supported_dtypes(self.dtype) - input_depth = inputs_shape[-1] - h_depth = self._num_units - self._kernel = self.add_variable( - _WEIGHTS_VARIABLE_NAME, - shape=[input_depth + h_depth, 4 * self._num_units]) - self._bias = self.add_variable( - _BIAS_VARIABLE_NAME, - shape=[4 * self._num_units], - initializer=init_ops.zeros_initializer(dtype=self.dtype)) - - self.built = True - - def call(self, inputs, state): - """Long short-term memory cell (LSTM). - - Args: - inputs: `2-D` tensor with shape `[batch_size, input_size]`. - state: An `LSTMStateTuple` of state tensors, each shaped `[batch_size, - num_units]`, if `state_is_tuple` has been set to `True`. Otherwise, a - `Tensor` shaped `[batch_size, 2 * num_units]`. - - Returns: - A pair containing the new hidden state, and the new state (either a - `LSTMStateTuple` or a concatenated state, depending on - `state_is_tuple`). - """ - _check_rnn_cell_input_dtypes([inputs, state]) - - sigmoid = math_ops.sigmoid - one = constant_op.constant(1, dtype=dtypes.int32) - # Parameters of gates are concatenated into one multiply for efficiency. - if self._state_is_tuple: - c, h = state - else: - c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one) - - gate_inputs = math_ops.matmul( - array_ops.concat([inputs, h], 1), self._kernel) - gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) - - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - i, j, f, o = array_ops.split( - value=gate_inputs, num_or_size_splits=4, axis=one) - - forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype) - # Note that using `add` and `multiply` instead of `+` and `*` gives a - # performance improvement. So using those at the cost of readability. 
- add = math_ops.add - multiply = math_ops.multiply - new_c = add( - multiply(c, sigmoid(add(f, forget_bias_tensor))), - multiply(sigmoid(i), self._activation(j))) - new_h = multiply(self._activation(new_c), sigmoid(o)) - - if self._state_is_tuple: - new_state = LSTMStateTuple(new_c, new_h) - else: - new_state = array_ops.concat([new_c, new_h], 1) - return new_h, new_state - - def get_config(self): - config = { - "num_units": self._num_units, - "forget_bias": self._forget_bias, - "state_is_tuple": self._state_is_tuple, - "activation": activations.serialize(self._activation), - "reuse": self._reuse, - } - base_config = super(BasicLSTMCell, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - -@tf_export(v1=["nn.rnn_cell.LSTMCell"]) -class LSTMCell(LayerRNNCell): - """Long short-term memory unit (LSTM) recurrent network cell. - - The default non-peephole implementation is based on (Gers et al., 1999). - The peephole implementation is based on (Sak et al., 2014). - - The class uses optional peep-hole connections, optional cell clipping, and - an optional projection layer. - - Note that this cell is not optimized for performance. Please use - `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or - `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for - better performance on CPU. - References: - Long short-term memory recurrent neural network architectures for large - scale acoustic modeling: - [Sak et al., 2014] - (https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html) - ([pdf] - (https://www.isca-speech.org/archive/archive_papers/interspeech_2014/i14_0338.pdf)) - Learning to forget: - [Gers et al., 1999] - (http://digital-library.theiet.org/content/conferences/10.1049/cp_19991218) - ([pdf](https://arxiv.org/pdf/1409.2329.pdf)) - Long Short-Term Memory: - [Hochreiter et al., 1997] - (https://www.mitpressjournals.org/doi/abs/10.1162/neco.1997.9.8.1735) - ([pdf](http://ml.jku.at/publications/older/3504.pdf)) - """ - - @deprecated(None, "This class is equivalent as tf.keras.layers.LSTMCell," - " and will be replaced by that in Tensorflow 2.0.") - def __init__(self, - num_units, - use_peepholes=False, - cell_clip=None, - initializer=None, - num_proj=None, - proj_clip=None, - num_unit_shards=None, - num_proj_shards=None, - forget_bias=1.0, - state_is_tuple=True, - activation=None, - reuse=None, - name=None, - dtype=None, - **kwargs): - """Initialize the parameters for an LSTM cell. - - Args: - num_units: int, The number of units in the LSTM cell. - use_peepholes: bool, set True to enable diagonal/peephole connections. - cell_clip: (optional) A float value, if provided the cell state is clipped - by this value prior to the cell output activation. - initializer: (optional) The initializer to use for the weight and - projection matrices. - num_proj: (optional) int, The output dimensionality for the projection - matrices. If None, no projection is performed. - proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is - provided, then the projected values are clipped elementwise to within - `[-proj_clip, proj_clip]`. - num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a - variable_scope partitioner instead. - num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a - variable_scope partitioner instead. - forget_bias: Biases of the forget gate are initialized by default to 1 in - order to reduce the scale of forgetting at the beginning of the - training. 
Must set it manually to `0.0` when restoring from CudnnLSTM - trained checkpoints. - state_is_tuple: If True, accepted and returned states are 2-tuples of the - `c_state` and `m_state`. If False, they are concatenated along the - column axis. This latter behavior will soon be deprecated. - activation: Activation function of the inner states. Default: `tanh`. It - could also be string that is within Keras activation function names. - reuse: (optional) Python boolean describing whether to reuse variables in - an existing scope. If not `True`, and the existing scope already has - the given variables, an error is raised. - name: String, the name of the layer. Layers with the same name will share - weights, but to avoid mistakes we require reuse=True in such cases. - dtype: Default dtype of the layer (default of `None` means use the type of - the first input). Required when `build` is called before `call`. - **kwargs: Dict, keyword named properties for common layer attributes, like - `trainable` etc when constructing the cell from configs of get_config(). - When restoring from CudnnLSTM-trained checkpoints, use - `CudnnCompatibleLSTMCell` instead. - """ - super(LSTMCell, self).__init__( - _reuse=reuse, name=name, dtype=dtype, **kwargs) - _check_supported_dtypes(self.dtype) - if not state_is_tuple: - logging.warn( - "%s: Using a concatenated state is slower and will soon be " - "deprecated. Use state_is_tuple=True.", self) - if num_unit_shards is not None or num_proj_shards is not None: - logging.warn( - "%s: The num_unit_shards and proj_unit_shards parameters are " - "deprecated and will be removed in Jan 2017. " - "Use a variable scope with a partitioner instead.", self) - if context.executing_eagerly() and context.num_gpus() > 0: - logging.warn( - "%s: Note that this cell is not optimized for performance. " - "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better " - "performance on GPU.", self) - - # Inputs must be 2-dimensional. 
- self.input_spec = input_spec.InputSpec(ndim=2) - - self._num_units = num_units - self._use_peepholes = use_peepholes - self._cell_clip = cell_clip - self._initializer = initializers.get(initializer) - self._num_proj = num_proj - self._proj_clip = proj_clip - self._num_unit_shards = num_unit_shards - self._num_proj_shards = num_proj_shards - self._forget_bias = forget_bias - self._state_is_tuple = state_is_tuple - if activation: - self._activation = activations.get(activation) - else: - self._activation = math_ops.tanh - - if num_proj: - self._state_size = ( - LSTMStateTuple(num_units, num_proj) if state_is_tuple else num_units + - num_proj) - self._output_size = num_proj - else: - self._state_size = ( - LSTMStateTuple(num_units, num_units) if state_is_tuple else 2 * - num_units) - self._output_size = num_units - - @property - def state_size(self): - return self._state_size - - @property - def output_size(self): - return self._output_size - - @tf_utils.shape_type_conversion - def build(self, inputs_shape): - if inputs_shape[-1] is None: - raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % - str(inputs_shape)) - _check_supported_dtypes(self.dtype) - input_depth = inputs_shape[-1] - h_depth = self._num_units if self._num_proj is None else self._num_proj - maybe_partitioner = ( - partitioned_variables.fixed_size_partitioner(self._num_unit_shards) - if self._num_unit_shards is not None else None) - self._kernel = self.add_variable( - _WEIGHTS_VARIABLE_NAME, - shape=[input_depth + h_depth, 4 * self._num_units], - initializer=self._initializer, - partitioner=maybe_partitioner) - if self.dtype is None: - initializer = init_ops.zeros_initializer - else: - initializer = init_ops.zeros_initializer(dtype=self.dtype) - self._bias = self.add_variable( - _BIAS_VARIABLE_NAME, - shape=[4 * self._num_units], - initializer=initializer) - if self._use_peepholes: - self._w_f_diag = self.add_variable( - "w_f_diag", shape=[self._num_units], initializer=self._initializer) - self._w_i_diag = self.add_variable( - "w_i_diag", shape=[self._num_units], initializer=self._initializer) - self._w_o_diag = self.add_variable( - "w_o_diag", shape=[self._num_units], initializer=self._initializer) - - if self._num_proj is not None: - maybe_proj_partitioner = ( - partitioned_variables.fixed_size_partitioner(self._num_proj_shards) - if self._num_proj_shards is not None else None) - self._proj_kernel = self.add_variable( - "projection/%s" % _WEIGHTS_VARIABLE_NAME, - shape=[self._num_units, self._num_proj], - initializer=self._initializer, - partitioner=maybe_proj_partitioner) - - self.built = True - - def call(self, inputs, state): - """Run one step of LSTM. - - Args: - inputs: input Tensor, must be 2-D, `[batch, input_size]`. - state: if `state_is_tuple` is False, this must be a state Tensor, `2-D, - [batch, state_size]`. If `state_is_tuple` is True, this must be a tuple - of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. - - Returns: - A tuple containing: - - - A `2-D, [batch, output_dim]`, Tensor representing the output of the - LSTM after reading `inputs` when previous state was `state`. - Here output_dim is: - num_proj if num_proj was set, - num_units otherwise. - - Tensor(s) representing the new state of LSTM after reading `inputs` when - the previous state was `state`. Same type and shape(s) as `state`. - - Raises: - ValueError: If input size cannot be inferred from inputs via - static shape inference. 
- """ - _check_rnn_cell_input_dtypes([inputs, state]) - - num_proj = self._num_units if self._num_proj is None else self._num_proj - sigmoid = math_ops.sigmoid - - if self._state_is_tuple: - (c_prev, m_prev) = state - else: - c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) - m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) - - input_size = inputs.get_shape().with_rank(2).dims[1].value - if input_size is None: - raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - lstm_matrix = math_ops.matmul( - array_ops.concat([inputs, m_prev], 1), self._kernel) - lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias) - - i, j, f, o = array_ops.split( - value=lstm_matrix, num_or_size_splits=4, axis=1) - # Diagonal connections - if self._use_peepholes: - c = ( - sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev + - sigmoid(i + self._w_i_diag * c_prev) * self._activation(j)) - else: - c = ( - sigmoid(f + self._forget_bias) * c_prev + - sigmoid(i) * self._activation(j)) - - if self._cell_clip is not None: - # pylint: disable=invalid-unary-operand-type - c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) - # pylint: enable=invalid-unary-operand-type - if self._use_peepholes: - m = sigmoid(o + self._w_o_diag * c) * self._activation(c) - else: - m = sigmoid(o) * self._activation(c) - - if self._num_proj is not None: - m = math_ops.matmul(m, self._proj_kernel) - - if self._proj_clip is not None: - # pylint: disable=invalid-unary-operand-type - m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) - # pylint: enable=invalid-unary-operand-type - - new_state = ( - LSTMStateTuple(c, m) - if self._state_is_tuple else array_ops.concat([c, m], 1)) - return m, new_state - - def get_config(self): - config = { - "num_units": self._num_units, - "use_peepholes": self._use_peepholes, - "cell_clip": self._cell_clip, - "initializer": initializers.serialize(self._initializer), - "num_proj": self._num_proj, - "proj_clip": self._proj_clip, - "num_unit_shards": self._num_unit_shards, - "num_proj_shards": self._num_proj_shards, - "forget_bias": self._forget_bias, - "state_is_tuple": self._state_is_tuple, - "activation": activations.serialize(self._activation), - "reuse": self._reuse, - } - base_config = super(LSTMCell, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - -class _RNNCellWrapperV1(RNNCell): - """Base class for cells wrappers V1 compatibility. - - This class along with `_RNNCellWrapperV2` allows to define cells wrappers that - are compatible with V1 and V2, and defines helper methods for this purpose. - """ - - def __init__(self, cell, *args, **kwargs): - super(_RNNCellWrapperV1, self).__init__(*args, **kwargs) - assert_like_rnncell("cell", cell) - self.cell = cell - if isinstance(cell, trackable.Trackable): - self._track_trackable(self.cell, name="cell") - - def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): - """Calls the wrapped cell and performs the wrapping logic. - - This method is called from the wrapper's `call` or `__call__` methods. - - Args: - inputs: A tensor with wrapped cell's input. - state: A tensor or tuple of tensors with wrapped cell's state. - cell_call_fn: Wrapped cell's method to use for step computation (cell's - `__call__` or 'call' method). - **kwargs: Additional arguments. - - Returns: - A pair containing: - - Output: A tensor with cell's output. 
- - New state: A tensor or tuple of tensors with new wrapped cell's state. - """ - raise NotImplementedError - - def __call__(self, inputs, state, scope=None): - """Runs the RNN cell step computation. - - We assume that the wrapped RNNCell is being built within its `__call__` - method. We directly use the wrapped cell's `__call__` in the overridden - wrapper `__call__` method. - - This allows to use the wrapped cell and the non-wrapped cell equivalently - when using `__call__`. - - Args: - inputs: A tensor with wrapped cell's input. - state: A tensor or tuple of tensors with wrapped cell's state. - scope: VariableScope for the subgraph created in the wrapped cells' - `__call__`. - - Returns: - A pair containing: - - - Output: A tensor with cell's output. - - New state: A tensor or tuple of tensors with new wrapped cell's state. - """ - return self._call_wrapped_cell( - inputs, state, cell_call_fn=self.cell.__call__, scope=scope) - - def get_config(self): - config = { - "cell": { - "class_name": self.cell.__class__.__name__, - "config": self.cell.get_config() - }, - } - base_config = super(_RNNCellWrapperV1, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - @classmethod - def from_config(cls, config, custom_objects=None): - config = config.copy() - cell = config.pop("cell") - try: - assert_like_rnncell("cell", cell) - return cls(cell, **config) - except TypeError: - raise ValueError("RNNCellWrapper cannot reconstruct the wrapped cell. " - "Please overwrite the cell in the config with a RNNCell " - "instance.") - - -@tf_export(v1=["nn.rnn_cell.DropoutWrapper"]) -class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase, - _RNNCellWrapperV1): - """Operator adding dropout to inputs and outputs of the given cell.""" - - def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation - super(DropoutWrapper, self).__init__(*args, **kwargs) - - __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__ - - -@tf_export(v1=["nn.rnn_cell.ResidualWrapper"]) -class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase, - _RNNCellWrapperV1): - """RNNCell wrapper that ensures cell inputs are added to the outputs.""" - - def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation - super(ResidualWrapper, self).__init__(*args, **kwargs) - - __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__ - - -@tf_export(v1=["nn.rnn_cell.DeviceWrapper"]) -class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase, - _RNNCellWrapperV1): - - def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation - super(DeviceWrapper, self).__init__(*args, **kwargs) - - __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__ - - -@tf_export(v1=["nn.rnn_cell.MultiRNNCell"]) -class MultiRNNCell(RNNCell): - """RNN cell composed sequentially of multiple simple cells. - - Example: - - ```python - num_units = [128, 64] - cells = [BasicLSTMCell(num_units=n) for n in num_units] - stacked_rnn_cell = MultiRNNCell(cells) - ``` - """ - - @deprecated(None, "This class is equivalent as " - "tf.keras.layers.StackedRNNCells, and will be replaced by " - "that in Tensorflow 2.0.") - def __init__(self, cells, state_is_tuple=True): - """Create a RNN cell composed sequentially of a number of RNNCells. - - Args: - cells: list of RNNCells that will be composed in this order. - state_is_tuple: If True, accepted and returned states are n-tuples, where - `n = len(cells)`. 
-
-
-@tf_export(v1=["nn.rnn_cell.MultiRNNCell"])
-class MultiRNNCell(RNNCell):
-  """RNN cell composed sequentially of multiple simple cells.
-
-  Example:
-
-  ```python
-  num_units = [128, 64]
-  cells = [BasicLSTMCell(num_units=n) for n in num_units]
-  stacked_rnn_cell = MultiRNNCell(cells)
-  ```
-  """
-
-  @deprecated(None, "This class is equivalent to "
-              "tf.keras.layers.StackedRNNCells, and will be replaced by "
-              "that in Tensorflow 2.0.")
-  def __init__(self, cells, state_is_tuple=True):
-    """Create an RNN cell composed sequentially of a number of RNNCells.
-
-    Args:
-      cells: list of RNNCells that will be composed in this order.
-      state_is_tuple: If True, accepted and returned states are n-tuples,
-        where `n = len(cells)`. If False, the states are all concatenated
-        along the column axis. This latter behavior will soon be deprecated.
-
-    Raises:
-      ValueError: if cells is empty (not allowed), or at least one of the
-        cells returns a state tuple but the flag `state_is_tuple` is `False`.
-    """
-    super(MultiRNNCell, self).__init__()
-    if not cells:
-      raise ValueError("Must specify at least one cell for MultiRNNCell.")
-    if not nest.is_sequence(cells):
-      raise TypeError("cells must be a list or tuple, but saw: %s." % cells)
-
-    if len(set(id(cell) for cell in cells)) < len(cells):
-      logging.log_first_n(
-          logging.WARN, "At least two cells provided to MultiRNNCell "
-          "are the same object and will share weights.", 1)
-
-    self._cells = cells
-    for cell_number, cell in enumerate(self._cells):
-      # Add Trackable dependencies on these cells so their variables get
-      # saved with this object when using object-based saving.
-      if isinstance(cell, trackable.Trackable):
-        # TODO(allenl): Track down non-Trackable callers.
-        self._track_trackable(cell, name="cell-%d" % (cell_number,))
-    self._state_is_tuple = state_is_tuple
-    if not state_is_tuple:
-      if any(nest.is_sequence(c.state_size) for c in self._cells):
-        raise ValueError("Some cells return tuples of states, but the flag "
-                         "state_is_tuple is not set. State sizes are: %s" %
-                         str([c.state_size for c in self._cells]))
-
-  @property
-  def state_size(self):
-    if self._state_is_tuple:
-      return tuple(cell.state_size for cell in self._cells)
-    else:
-      return sum(cell.state_size for cell in self._cells)
-
-  @property
-  def output_size(self):
-    return self._cells[-1].output_size
-
-  def zero_state(self, batch_size, dtype):
-    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
-      if self._state_is_tuple:
-        return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
-      else:
-        # We know here that state_size of each cell is not a tuple and
-        # presumably does not contain TensorArrays or anything else fancy.
-        return super(MultiRNNCell, self).zero_state(batch_size, dtype)
-
-  @property
-  def trainable_weights(self):
-    if not self.trainable:
-      return []
-    weights = []
-    for cell in self._cells:
-      if isinstance(cell, base_layer.Layer):
-        weights += cell.trainable_weights
-    return weights
-
-  @property
-  def non_trainable_weights(self):
-    weights = []
-    for cell in self._cells:
-      if isinstance(cell, base_layer.Layer):
-        weights += cell.non_trainable_weights
-    if not self.trainable:
-      trainable_weights = []
-      for cell in self._cells:
-        if isinstance(cell, base_layer.Layer):
-          trainable_weights += cell.trainable_weights
-      return trainable_weights + weights
-    return weights
-
-  def call(self, inputs, state):
-    """Run this multi-layer cell on inputs, starting from state."""
-    cur_state_pos = 0
-    cur_inp = inputs
-    new_states = []
-    for i, cell in enumerate(self._cells):
-      with vs.variable_scope("cell_%d" % i):
-        if self._state_is_tuple:
-          if not nest.is_sequence(state):
-            raise ValueError(
-                "Expected state to be a tuple of length %d, but received: %s" %
-                (len(self.state_size), state))
-          cur_state = state[i]
-        else:
-          cur_state = array_ops.slice(state, [0, cur_state_pos],
-                                      [-1, cell.state_size])
-          cur_state_pos += cell.state_size
-        cur_inp, new_state = cell(cur_inp, cur_state)
-        new_states.append(new_state)
-
-    new_states = (
-        tuple(new_states) if self._state_is_tuple else array_ops.concat(
-            new_states, 1))
-
-    return cur_inp, new_states
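When `state_is_tuple=False`, `call` walks a single `[batch, sum(state_size)]` tensor with a running column offset. A small NumPy sketch of that layout (hypothetical sizes, not from the patch):

```python
import numpy as np

# Layout MultiRNNCell.call assumes for state_is_tuple=False: one
# [batch, sum(state_sizes)] tensor, sliced per cell at a running offset.
state_sizes = [4, 6]                      # hypothetical per-cell state_size
batch = 2
flat_state = np.zeros([batch, sum(state_sizes)], np.float32)

offset = 0
for size in state_sizes:
    # Mirrors array_ops.slice(state, [0, offset], [-1, size]).
    cur_state = flat_state[:, offset:offset + size]
    offset += size
    assert cur_state.shape == (batch, size)
```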
-
-
-def _check_rnn_cell_input_dtypes(inputs):
-  """Check whether the input tensors have supported dtypes.
-
-  Default RNN cells only support float and complex dtypes, since the
-  activation functions (tanh and sigmoid) only allow those types. This
-  function raises a descriptive error if an input is not of a supported
-  type.
-
-  Args:
-    inputs: tensor or nested structure of tensors that are fed to the RNN
-      cell as input or state.
-
-  Raises:
-    ValueError: if any of the input tensors does not have a float or complex
-      dtype.
-  """
-  for t in nest.flatten(inputs):
-    _check_supported_dtypes(t.dtype)
-
-
-def _check_supported_dtypes(dtype):
-  if dtype is None:
-    return
-  dtype = dtypes.as_dtype(dtype)
-  if not (dtype.is_floating or dtype.is_complex):
-    raise ValueError("RNN cell only supports floating point inputs, "
-                     "but saw dtype: %s" % dtype)
+from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_impl
+
+# Remove callers that rely on these private symbols in the future.
+# pylint: disable=protected-access
+_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME
+_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME
+_concat = rnn_cell_impl._concat
+_zero_state_tensors = rnn_cell_impl._zero_state_tensors
+# pylint: enable=protected-access
+
+
+assert_like_rnncell = rnn_cell_impl.assert_like_rnncell
+ASSERT_LIKE_RNNCELL_ERROR_REGEXP = rnn_cell_impl.ASSERT_LIKE_RNNCELL_ERROR_REGEXP  # pylint: disable=line-too-long
+BasicLSTMCell = rnn_cell_impl.BasicLSTMCell
+BasicRNNCell = rnn_cell_impl.BasicRNNCell
+DeviceWrapper = rnn_cell_impl.DeviceWrapper
+DropoutWrapper = rnn_cell_impl.DropoutWrapper
+GRUCell = rnn_cell_impl.GRUCell
+LayerRNNCell = rnn_cell_impl.LayerRNNCell
+LSTMCell = rnn_cell_impl.LSTMCell
+LSTMStateTuple = rnn_cell_impl.LSTMStateTuple
+MultiRNNCell = rnn_cell_impl.MultiRNNCell
+ResidualWrapper = rnn_cell_impl.ResidualWrapper
+RNNCell = rnn_cell_impl.RNNCell
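If I read the shim above correctly, `tensorflow/python/ops/rnn_cell_impl.py` is reduced to re-exports, so the old and new import paths resolve to the same objects. A quick check (sketch; internal import paths as in this patch):

```python
# Both paths should now name the same class objects.
from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_impl as new_impl
from tensorflow.python.ops import rnn_cell_impl as old_impl

assert old_impl.LSTMCell is new_impl.LSTMCell
assert old_impl.MultiRNNCell is new_impl.MultiRNNCell
assert old_impl.assert_like_rnncell is new_impl.assert_like_rnncell
```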
diff --git a/tensorflow/python/ops/rnn_cell_wrapper_impl.py b/tensorflow/python/ops/rnn_cell_wrapper_impl.py
index 9c42caea63e..a8b1fea2a5f 100644
--- a/tensorflow/python/ops/rnn_cell_wrapper_impl.py
+++ b/tensorflow/python/ops/rnn_cell_wrapper_impl.py
@@ -17,500 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import hashlib
-import numbers
-import sys
-import types as python_types
-import warnings
+from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_wrapper_impl
 
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.keras.utils import generic_utils
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import tensor_array_ops
-from tensorflow.python.util import nest
-
-
-class DropoutWrapperBase(object):
-  """Operator adding dropout to inputs and outputs of the given cell."""
-
-  def __init__(self,
-               cell,
-               input_keep_prob=1.0,
-               output_keep_prob=1.0,
-               state_keep_prob=1.0,
-               variational_recurrent=False,
-               input_size=None,
-               dtype=None,
-               seed=None,
-               dropout_state_filter_visitor=None,
-               **kwargs):
-    """Create a cell with added input, state, and/or output dropout.
-
-    If `variational_recurrent` is set to `True` (**NOT** the default
-    behavior), then the same dropout mask is applied at every step, as
-    described in: [A Theoretically Grounded Application of Dropout in
-    Recurrent Neural Networks. Y. Gal, Z. Ghahramani]
-    (https://arxiv.org/abs/1512.05287).
-
-    Otherwise a different dropout mask is applied at every time step.
-
-    Note, by default (unless a custom `dropout_state_filter` is provided),
-    the memory state (`c` component of any `LSTMStateTuple`) passing through
-    a `DropoutWrapper` is never modified. This behavior is described in the
-    above article.
-
-    Args:
-      cell: an RNNCell.
-      input_keep_prob: unit Tensor or float between 0 and 1, input keep
-        probability; if it is constant and 1, no input dropout will be added.
-      output_keep_prob: unit Tensor or float between 0 and 1, output keep
-        probability; if it is constant and 1, no output dropout will be
-        added.
-      state_keep_prob: unit Tensor or float between 0 and 1, state keep
-        probability; if it is constant and 1, no state dropout will be added.
-        State dropout is performed on the outgoing states of the cell.
-        **Note** the state components to which dropout is applied when
-        `state_keep_prob` is in `(0, 1)` are also determined by the argument
-        `dropout_state_filter_visitor` (e.g. by default dropout is never
-        applied to the `c` component of an `LSTMStateTuple`).
-      variational_recurrent: Python bool. If `True`, then the same dropout
-        pattern is applied across all time steps per run call. If this
-        parameter is set, `input_size` **must** be provided.
-      input_size: (optional) (possibly nested tuple of) `TensorShape` objects
-        containing the depth(s) of the input tensors expected to be passed in
-        to the `DropoutWrapper`. Required and used **iff**
-        `variational_recurrent = True` and `input_keep_prob < 1`.
-      dtype: (optional) The `dtype` of the input, state, and output tensors.
-        Required and used **iff** `variational_recurrent = True`.
-      seed: (optional) integer, the randomness seed.
-      dropout_state_filter_visitor: (optional), default: (see below). Function
-        that takes any hierarchical level of the state and returns a scalar
-        or depth=1 structure of Python booleans describing which terms in the
-        state should be dropped out. In addition, if the function returns
-        `True`, dropout is applied across this sublevel. If the function
-        returns `False`, dropout is not applied across this entire sublevel.
-        Default behavior: perform dropout on all terms except the memory
-        (`c`) state of `LSTMCellState` objects, and don't try to apply
-        dropout to `TensorArray` objects:
-        ```
-        def dropout_state_filter_visitor(s):
-          if isinstance(s, LSTMCellState):
-            # Never perform dropout on the c state.
-            return LSTMCellState(c=False, h=True)
-          elif isinstance(s, TensorArray):
-            return False
-          return True
-        ```
-      **kwargs: dict of keyword arguments for the base layer.
-
-    Raises:
-      TypeError: if `cell` is not an `RNNCell`, or `keep_state_fn` is
-        provided but not `callable`.
-      ValueError: if any of the keep_probs are not between 0 and 1.
-    """
-    super(DropoutWrapperBase, self).__init__(cell, dtype=dtype, **kwargs)
-
-    if (dropout_state_filter_visitor is not None and
-        not callable(dropout_state_filter_visitor)):
-      raise TypeError("dropout_state_filter_visitor must be callable")
-    self._dropout_state_filter = (
-        dropout_state_filter_visitor or _default_dropout_state_filter_visitor)
-    with ops.name_scope("DropoutWrapperInit"):
-
-      def tensor_and_const_value(v):
-        tensor_value = ops.convert_to_tensor(v)
-        const_value = tensor_util.constant_value(tensor_value)
-        return (tensor_value, const_value)
-
-      for prob, attr in [(input_keep_prob, "input_keep_prob"),
-                         (state_keep_prob, "state_keep_prob"),
-                         (output_keep_prob, "output_keep_prob")]:
-        tensor_prob, const_prob = tensor_and_const_value(prob)
-        if const_prob is not None:
-          if const_prob < 0 or const_prob > 1:
-            raise ValueError("Parameter %s must be between 0 and 1: %d" %
-                             (attr, const_prob))
-          setattr(self, "_%s" % attr, float(const_prob))
-        else:
-          setattr(self, "_%s" % attr, tensor_prob)
-
-    # Set variational_recurrent, seed before running the code below
-    self._variational_recurrent = variational_recurrent
-    self._input_size = input_size
-    self._seed = seed
-
-    self._recurrent_input_noise = None
-    self._recurrent_state_noise = None
-    self._recurrent_output_noise = None
-
-    if variational_recurrent:
-      if dtype is None:
-        raise ValueError(
-            "When variational_recurrent=True, dtype must be provided")
-
-      def convert_to_batch_shape(s):
-        # Prepend a 1 for the batch dimension; for recurrent
-        # variational dropout we use the same dropout mask for all
-        # batch elements.
-        return array_ops.concat(([1], tensor_shape.TensorShape(s).as_list()), 0)
-
-      def batch_noise(s, inner_seed):
-        shape = convert_to_batch_shape(s)
-        return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype)
-
-      if (not isinstance(self._input_keep_prob, numbers.Real) or
-          self._input_keep_prob < 1.0):
-        if input_size is None:
-          raise ValueError(
-              "When variational_recurrent=True and input_keep_prob < 1.0 or "
-              "is unknown, input_size must be provided")
-        self._recurrent_input_noise = _enumerated_map_structure_up_to(
-            input_size,
-            lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)),
-            input_size)
-      self._recurrent_state_noise = _enumerated_map_structure_up_to(
-          cell.state_size,
-          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)),
-          cell.state_size)
-      self._recurrent_output_noise = _enumerated_map_structure_up_to(
-          cell.output_size,
-          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)),
-          cell.output_size)
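A usage sketch of the variational-recurrent mode documented above (my example, hypothetical sizes): `input_size` and `dtype` become mandatory because the per-run noise masks are sampled eagerly in the constructor:

```python
import tensorflow.compat.v1 as tf

# Variational recurrent dropout: one mask per run call, shared across time.
cell = tf.nn.rnn_cell.DropoutWrapper(
    tf.nn.rnn_cell.LSTMCell(32),
    input_keep_prob=0.8,
    state_keep_prob=0.8,
    variational_recurrent=True,
    input_size=tf.TensorShape([16]),   # depth of one input time step
    dtype=tf.float32,
    seed=42)
```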
- """ - super(DropoutWrapperBase, self).__init__(cell, dtype=dtype, **kwargs) - - if (dropout_state_filter_visitor is not None and - not callable(dropout_state_filter_visitor)): - raise TypeError("dropout_state_filter_visitor must be callable") - self._dropout_state_filter = ( - dropout_state_filter_visitor or _default_dropout_state_filter_visitor) - with ops.name_scope("DropoutWrapperInit"): - - def tensor_and_const_value(v): - tensor_value = ops.convert_to_tensor(v) - const_value = tensor_util.constant_value(tensor_value) - return (tensor_value, const_value) - - for prob, attr in [(input_keep_prob, "input_keep_prob"), - (state_keep_prob, "state_keep_prob"), - (output_keep_prob, "output_keep_prob")]: - tensor_prob, const_prob = tensor_and_const_value(prob) - if const_prob is not None: - if const_prob < 0 or const_prob > 1: - raise ValueError("Parameter %s must be between 0 and 1: %d" % - (attr, const_prob)) - setattr(self, "_%s" % attr, float(const_prob)) - else: - setattr(self, "_%s" % attr, tensor_prob) - - # Set variational_recurrent, seed before running the code below - self._variational_recurrent = variational_recurrent - self._input_size = input_size - self._seed = seed - - self._recurrent_input_noise = None - self._recurrent_state_noise = None - self._recurrent_output_noise = None - - if variational_recurrent: - if dtype is None: - raise ValueError( - "When variational_recurrent=True, dtype must be provided") - - def convert_to_batch_shape(s): - # Prepend a 1 for the batch dimension; for recurrent - # variational dropout we use the same dropout mask for all - # batch elements. - return array_ops.concat(([1], tensor_shape.TensorShape(s).as_list()), 0) - - def batch_noise(s, inner_seed): - shape = convert_to_batch_shape(s) - return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype) - - if (not isinstance(self._input_keep_prob, numbers.Real) or - self._input_keep_prob < 1.0): - if input_size is None: - raise ValueError( - "When variational_recurrent=True and input_keep_prob < 1.0 or " - "is unknown, input_size must be provided") - self._recurrent_input_noise = _enumerated_map_structure_up_to( - input_size, - lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)), - input_size) - self._recurrent_state_noise = _enumerated_map_structure_up_to( - cell.state_size, - lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)), - cell.state_size) - self._recurrent_output_noise = _enumerated_map_structure_up_to( - cell.output_size, - lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)), - cell.output_size) - - def _gen_seed(self, salt_prefix, index): - if self._seed is None: - return None - salt = "%s_%d" % (salt_prefix, index) - string = (str(self._seed) + salt).encode("utf-8") - return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF - - @property - def wrapped_cell(self): - return self.cell - - @property - def state_size(self): - return self.cell.state_size - - @property - def output_size(self): - return self.cell.output_size - - def build(self, inputs_shape): - self.cell.build(inputs_shape) - self.built = True - - def zero_state(self, batch_size, dtype): - with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): - return self.cell.zero_state(batch_size, dtype) - - def _variational_recurrent_dropout_value( - self, unused_index, value, noise, keep_prob): - """Performs dropout given the pre-calculated noise tensor.""" - # uniform [keep_prob, 1.0 + keep_prob) - random_tensor = keep_prob + noise - - # 0. 
-
-  def _dropout(self,
-               values,
-               salt_prefix,
-               recurrent_noise,
-               keep_prob,
-               shallow_filtered_substructure=None):
-    """Decides whether to perform standard dropout or recurrent dropout."""
-
-    if shallow_filtered_substructure is None:
-      # Put something so we traverse the entire structure; inside the
-      # dropout function we check to see if leafs of this are bool or not.
-      shallow_filtered_substructure = values
-
-    if not self._variational_recurrent:
-
-      def dropout(i, do_dropout, v):
-        if not isinstance(do_dropout, bool) or do_dropout:
-          return nn_ops.dropout_v2(
-              v, rate=1. - keep_prob, seed=self._gen_seed(salt_prefix, i))
-        else:
-          return v
-
-      return _enumerated_map_structure_up_to(
-          shallow_filtered_substructure, dropout,
-          *[shallow_filtered_substructure, values])
-    else:
-
-      def dropout(i, do_dropout, v, n):
-        if not isinstance(do_dropout, bool) or do_dropout:
-          return self._variational_recurrent_dropout_value(i, v, n, keep_prob)
-        else:
-          return v
-
-      return _enumerated_map_structure_up_to(
-          shallow_filtered_substructure, dropout,
-          *[shallow_filtered_substructure, values, recurrent_noise])
-
-  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
-    """Runs the wrapped cell and applies dropout.
-
-    Args:
-      inputs: A tensor with the wrapped cell's input.
-      state: A tensor or tuple of tensors with the wrapped cell's state.
-      cell_call_fn: The wrapped cell's method to use for step computation
-        (the cell's `__call__` or `call` method).
-      **kwargs: Additional arguments.
-
-    Returns:
-      A pair containing:
-
-      - Output: A tensor with the cell's output.
-      - New state: A tensor or tuple of tensors with the new wrapped cell's
-        state.
-    """
-
-    def _should_dropout(p):
-      return (not isinstance(p, float)) or p < 1
-
-    if _should_dropout(self._input_keep_prob):
-      inputs = self._dropout(inputs, "input", self._recurrent_input_noise,
-                             self._input_keep_prob)
-    output, new_state = cell_call_fn(inputs, state, **kwargs)
-    if _should_dropout(self._state_keep_prob):
-      # Identify which subsets of the state to perform dropout on and
-      # which ones to keep.
-      shallow_filtered_substructure = nest.get_traverse_shallow_structure(
-          self._dropout_state_filter, new_state)
-      new_state = self._dropout(new_state, "state", self._recurrent_state_noise,
-                                self._state_keep_prob,
-                                shallow_filtered_substructure)
-    if _should_dropout(self._output_keep_prob):
-      output = self._dropout(output, "output", self._recurrent_output_noise,
-                             self._output_keep_prob)
-    return output, new_state
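Note the asymmetry in `_should_dropout` above: a Python float of exactly 1.0 disables the dropout path at graph-construction time, while a Tensor keep-prob always takes it. A pure-Python restatement:

```python
def _should_dropout(p):
    return (not isinstance(p, float)) or p < 1

assert not _should_dropout(1.0)   # Python float 1.0: dropout elided entirely
assert _should_dropout(0.5)       # constant < 1: dropout applied
assert _should_dropout(1)         # an int is not a float: takes the path too
# A Tensor keep_prob is also not a float, so it always takes the dropout
# path, even if its runtime value happens to be 1.0.
```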
-
-  def get_config(self):
-    """Returns the config of the dropout wrapper."""
-    config = {
-        "input_keep_prob": self._input_keep_prob,
-        "output_keep_prob": self._output_keep_prob,
-        "state_keep_prob": self._state_keep_prob,
-        "variational_recurrent": self._variational_recurrent,
-        "input_size": self._input_size,
-        "seed": self._seed,
-    }
-    if self._dropout_state_filter != _default_dropout_state_filter_visitor:
-      function, function_type, function_module = _serialize_function_to_config(
-          self._dropout_state_filter)
-      config.update({"dropout_fn": function,
-                     "dropout_fn_type": function_type,
-                     "dropout_fn_module": function_module})
-    base_config = super(DropoutWrapperBase, self).get_config()
-    return dict(list(base_config.items()) + list(config.items()))
-
-  @classmethod
-  def from_config(cls, config, custom_objects=None):
-    if "dropout_fn" in config:
-      config = config.copy()
-      dropout_state_filter = _parse_config_to_function(
-          config, custom_objects, "dropout_fn", "dropout_fn_type",
-          "dropout_fn_module")
-      config.pop("dropout_fn")
-      config["dropout_state_filter_visitor"] = dropout_state_filter
-    return super(DropoutWrapperBase, cls).from_config(
-        config, custom_objects=custom_objects)
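A sketch of a custom `dropout_state_filter_visitor` (mine, not from the patch): the default visitor only exempts `c`; this variant exempts the whole `LSTMStateTuple` from state dropout:

```python
import tensorflow.compat.v1 as tf

def keep_whole_lstm_state(substate):
    # Exempt both c and h of every LSTMStateTuple from state dropout.
    if isinstance(substate, tf.nn.rnn_cell.LSTMStateTuple):
        return tf.nn.rnn_cell.LSTMStateTuple(c=False, h=False)
    return True

cell = tf.nn.rnn_cell.DropoutWrapper(
    tf.nn.rnn_cell.LSTMCell(32),
    state_keep_prob=0.9,
    dropout_state_filter_visitor=keep_whole_lstm_state)
```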
- """ - outputs, new_state = cell_call_fn(inputs, state, **kwargs) - - # Ensure shapes match - def assert_shape_match(inp, out): - inp.get_shape().assert_is_compatible_with(out.get_shape()) - - def default_residual_fn(inputs, outputs): - nest.assert_same_structure(inputs, outputs) - nest.map_structure(assert_shape_match, inputs, outputs) - return nest.map_structure(lambda inp, out: inp + out, inputs, outputs) - - res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs) - return (res_outputs, new_state) - - def get_config(self): - """Returns the config of the residual wrapper.""" - if self._residual_fn is not None: - function, function_type, function_module = _serialize_function_to_config( - self._residual_fn) - config = { - "residual_fn": function, - "residual_fn_type": function_type, - "residual_fn_module": function_module - } - else: - config = {} - base_config = super(ResidualWrapperBase, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - @classmethod - def from_config(cls, config, custom_objects=None): - if "residual_fn" in config: - config = config.copy() - residual_function = _parse_config_to_function(config, custom_objects, - "residual_fn", - "residual_fn_type", - "residual_fn_module") - config["residual_fn"] = residual_function - return super(ResidualWrapperBase, cls).from_config( - config, custom_objects=custom_objects) - - -class DeviceWrapperBase(object): - """Operator that ensures an RNNCell runs on a particular device.""" - - def __init__(self, cell, device, **kwargs): - """Construct a `DeviceWrapper` for `cell` with device `device`. - - Ensures the wrapped `cell` is called with `tf.device(device)`. - - Args: - cell: An instance of `RNNCell`. - device: A device string or function, for passing to `tf.device`. - **kwargs: dict of keyword arguments for base layer. 
- """ - super(DeviceWrapperBase, self).__init__(cell, **kwargs) - self._device = device - - @property - def state_size(self): - return self.cell.state_size - - @property - def output_size(self): - return self.cell.output_size - - def zero_state(self, batch_size, dtype): - with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): - with ops.device(self._device): - return self.cell.zero_state(batch_size, dtype) - - def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): - """Run the cell on specified device.""" - with ops.device(self._device): - return cell_call_fn(inputs, state, **kwargs) - - def get_config(self): - config = {"device": self._device} - base_config = super(DeviceWrapperBase, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - -def _serialize_function_to_config(function): - """Serialize the function for get_config().""" - if isinstance(function, python_types.LambdaType): - output = generic_utils.func_dump(function) - output_type = "lambda" - module = function.__module__ - elif callable(function): - output = function.__name__ - output_type = "function" - module = function.__module__ - else: - raise ValueError("Unrecognized function type for input: {}".format( - type(function))) - - return output, output_type, module - - -def _parse_config_to_function(config, custom_objects, func_attr_name, - func_type_attr_name, module_attr_name): - """Reconstruct the function from the config.""" - globs = globals() - module = config.pop(module_attr_name, None) - if module in sys.modules: - globs.update(sys.modules[module].__dict__) - elif module is not None: - # Note: we don't know the name of the function if it's a lambda. - warnings.warn("{} is not loaded, but a layer uses it. " - "It may cause errors.".format(module), UserWarning) - if custom_objects: - globs.update(custom_objects) - function_type = config.pop(func_type_attr_name) - if function_type == "function": - # Simple lookup in custom objects - function = generic_utils.deserialize_keras_object( - config[func_attr_name], - custom_objects=custom_objects, - printable_module_name="function in wrapper") - elif function_type == "lambda": - # Unsafe deserialization from bytecode - function = generic_utils.func_load( - config[func_attr_name], globs=globs) - else: - raise TypeError("Unknown function type:", function_type) - return function - - -def _default_dropout_state_filter_visitor(substate): - from tensorflow.python.ops.rnn_cell_impl import LSTMStateTuple # pylint: disable=g-import-not-at-top - if isinstance(substate, LSTMStateTuple): - # Do not perform dropout on the memory state. 
-
-
-def _serialize_function_to_config(function):
-  """Serialize the function for get_config()."""
-  if isinstance(function, python_types.LambdaType):
-    output = generic_utils.func_dump(function)
-    output_type = "lambda"
-    module = function.__module__
-  elif callable(function):
-    output = function.__name__
-    output_type = "function"
-    module = function.__module__
-  else:
-    raise ValueError("Unrecognized function type for input: {}".format(
-        type(function)))
-
-  return output, output_type, module
-
-
-def _parse_config_to_function(config, custom_objects, func_attr_name,
-                              func_type_attr_name, module_attr_name):
-  """Reconstruct the function from the config."""
-  globs = globals()
-  module = config.pop(module_attr_name, None)
-  if module in sys.modules:
-    globs.update(sys.modules[module].__dict__)
-  elif module is not None:
-    # Note: we don't know the name of the function if it's a lambda.
-    warnings.warn("{} is not loaded, but a layer uses it. "
-                  "It may cause errors.".format(module), UserWarning)
-  if custom_objects:
-    globs.update(custom_objects)
-  function_type = config.pop(func_type_attr_name)
-  if function_type == "function":
-    # Simple lookup in custom objects
-    function = generic_utils.deserialize_keras_object(
-        config[func_attr_name],
-        custom_objects=custom_objects,
-        printable_module_name="function in wrapper")
-  elif function_type == "lambda":
-    # Unsafe deserialization from bytecode
-    function = generic_utils.func_load(
-        config[func_attr_name], globs=globs)
-  else:
-    raise TypeError("Unknown function type:", function_type)
-  return function
-
-
-def _default_dropout_state_filter_visitor(substate):
-  from tensorflow.python.ops.rnn_cell_impl import LSTMStateTuple  # pylint: disable=g-import-not-at-top
-  if isinstance(substate, LSTMStateTuple):
-    # Do not perform dropout on the memory state.
-    return LSTMStateTuple(c=False, h=True)
-  elif isinstance(substate, tensor_array_ops.TensorArray):
-    return False
-  return True
-
-
-def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs):
-  ix = [0]
-
-  def enumerated_fn(*inner_args, **inner_kwargs):
-    r = map_fn(ix[0], *inner_args, **inner_kwargs)
-    ix[0] += 1
-    return r
-
-  return nest.map_structure_up_to(shallow_structure, enumerated_fn, *args,
-                                  **kwargs)
+DeviceWrapperBase = rnn_cell_wrapper_impl.DeviceWrapperBase
+DropoutWrapperBase = rnn_cell_wrapper_impl.DropoutWrapperBase
+ResidualWrapperBase = rnn_cell_wrapper_impl.ResidualWrapperBase
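`_enumerated_map_structure_up_to` is `nest.map_structure` plus a running leaf index (used above to derive distinct per-leaf dropout seeds). A minimal sketch of the enumeration idea (my example, using the internal `nest` module):

```python
from tensorflow.python.util import nest

structure = {"a": 1.0, "b": (2.0, 3.0)}

ix = [0]  # mutable counter cell, as in the implementation above
def enumerated(value):
    index = ix[0]
    ix[0] += 1
    return (index, value)

print(nest.map_structure(enumerated, structure))
# {'a': (0, 1.0), 'b': ((1, 2.0), (2, 3.0))}
```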
"tensorflow.nn.rnn_cell.DeviceWrapper" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt index bb6bde99e53..6d490621aa9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.nn.rnn_cell.DropoutWrapper" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt index 832ec6f6be6..a9669ff59a2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.nn.rnn_cell.GRUCell" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt index 6f471d3f811..54e517ac974 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.nn.rnn_cell.LSTMCell" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt index 1de8a55dcca..274624a6203 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.rnn_cell.LSTMStateTuple" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "c" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt index 48d17d35fbe..db31de9d754 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.rnn_cell.MultiRNNCell" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt index 5c428f658c9..2286a66efd8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt @@ -1,6 +1,6 @@ path: 
"tensorflow.nn.rnn_cell.RNNCell" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt index 629d73640f3..5570bf7af98 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.nn.rnn_cell.ResidualWrapper" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt index d047b62497c..e94252a2f6d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.RNNCellDeviceWrapper" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt index 7c6eee87429..9682fd0a29a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.RNNCellDropoutWrapper" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt index e4d9d2e9737..0cd0cebfa1d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.RNNCellResidualWrapper" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: ""