commit
04ba418175
27
RELEASE.md
27
RELEASE.md
@ -3,9 +3,8 @@
|
||||
## Major Features and Improvements
|
||||
* Added `tf.layers.conv3d_transpose` layer for spatio temporal deconvolution.
|
||||
* Added `tf.Session.make_callable()`, which provides a lower overhead means of running a similar step multiple times.
|
||||
* Added libverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo).
|
||||
* Bring `tf.feature_column.*` into the API. Non-deprecated functionality from `tf.contrib.layers.*` is moved to `tf.feature_column.*` with cosmetic changes.
|
||||
* `RNNCell` objects now subclass `tf.layers._Layer`. The strictness described
|
||||
* Added ibverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo).
|
||||
* `RNNCell` objects now subclass `tf.layers.Layer`. The strictness described
|
||||
in the TensorFlow 1.1 release is gone: The first time an RNNCell is used,
|
||||
it caches its scope. All future uses of the RNNCell will reuse variables from
|
||||
that same scope. This is a breaking change from the behavior of RNNCells
|
||||
@ -23,6 +22,28 @@
|
||||
* TensorFlow C library now available for Windows.
|
||||
* We released a new open-source version of TensorBoard.
|
||||
* [`SavedModel CLI`](https://www.tensorflow.org/versions/master/programmers_guide/saved_model_cli) tool available to inspect and execute MetaGraph in SavedModel
|
||||
* RNNCells' variable names have been renamed for consistency with Keras layers.
|
||||
Specifically, the previous variable names "weights" and "biases" have
|
||||
been changed to "kernel" and "bias", respectively.
|
||||
This may cause backward incompatibility with regard to your old
|
||||
checkpoints containing such RNN cells, in which case you can use the tool
|
||||
[checkpoint_convert script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py)
|
||||
to convert the variable names in your old checkpoints.
|
||||
* Many of the RNN functions and classes that were in the `tf.nn` namespace
|
||||
before the 1.0 release and which were moved to `tf.contrib.rnn` have now
|
||||
been moved back to the core namespace. This includes
|
||||
`RNNCell`, `LSTMCell`, `GRUCell`, and a number of other cells. These
|
||||
now reside in `tf.nn.rnn_cell` (with aliases in `tf.contrib.rnn` for backwards
|
||||
compatibility). The original `tf.nn.rnn` function is now `tf.nn.static_rnn`,
|
||||
and the bidirectional static and state saving static rnn functions are also
|
||||
now back in the `tf.nn` namespace.
|
||||
|
||||
Notable exceptions are the `EmbeddingWrapper`, `InputProjectionWrapper` and
|
||||
`OutputProjectionWrapper`, which will slowly be moved to deprecation
|
||||
in `tf.contrib.rnn`. These are inefficient wrappers that should often
|
||||
be replaced by calling `embedding_lookup` or `layers.dense` as pre- or post-
|
||||
processing of the rnn. For RNN decoding, this functionality has been replaced
|
||||
with an alternative API in `tf.contrib.seq2seq`.
|
||||
|
||||
## Breaking Changes to the API
|
||||
* `org.tensorflow.contrib.android.TensorFlowInferenceInterface` now throws exceptions where possible and has simplified method signatures.
|
||||
|
@ -961,8 +961,9 @@ typedef struct TF_WhileParams {
|
||||
// - Reference-type inputs
|
||||
// - Directly referencing external tensors from the cond/body graphs (this is
|
||||
// possible in the Python API)
|
||||
TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
|
||||
TF_Status* status);
|
||||
TF_CAPI_EXPORT extern TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs,
|
||||
int ninputs,
|
||||
TF_Status* status);
|
||||
|
||||
// Builds the while loop specified by `params` and returns the output tensors of
|
||||
// the while loop in `outputs`. `outputs` should be allocated to size
|
||||
@ -972,13 +973,14 @@ TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
|
||||
//
|
||||
// Either this or TF_AbortWhile() must be called after a successful
|
||||
// TF_NewWhile() call.
|
||||
void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
|
||||
TF_Output* outputs);
|
||||
TF_CAPI_EXPORT extern void TF_FinishWhile(const TF_WhileParams* params,
|
||||
TF_Status* status,
|
||||
TF_Output* outputs);
|
||||
|
||||
// Frees `params`s resources without building a while loop. `params` is no
|
||||
// longer valid after this returns. Either this or TF_FinishWhile() must be
|
||||
// called after a successful TF_NewWhile() call.
|
||||
void TF_AbortWhile(const TF_WhileParams* params);
|
||||
TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params);
|
||||
|
||||
// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s,
|
||||
// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
|
||||
@ -994,8 +996,9 @@ void TF_AbortWhile(const TF_WhileParams* params);
|
||||
// supports. See
|
||||
// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md
|
||||
// for instructions on how to add C++ more gradients.
|
||||
void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
|
||||
TF_Output* dx, TF_Status* status, TF_Output* dy);
|
||||
TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
|
||||
TF_Output* x, int nx, TF_Output* dx,
|
||||
TF_Status* status, TF_Output* dy);
|
||||
|
||||
// TODO(josh11b): Register OpDef, available to all operations added
|
||||
// to this graph.
|
||||
@ -1032,7 +1035,7 @@ TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph,
|
||||
//
|
||||
// If successful, populates `graph` with the contents of the Graph and
|
||||
// `meta_graph_def` with the MetaGraphDef of the loaded model.
|
||||
TF_Session* TF_LoadSessionFromSavedModel(
|
||||
TF_CAPI_EXPORT extern TF_Session* TF_LoadSessionFromSavedModel(
|
||||
const TF_SessionOptions* session_options, const TF_Buffer* run_options,
|
||||
const char* export_dir, const char* const* tags, int tags_len,
|
||||
TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status);
|
||||
|
@ -27,10 +27,17 @@ allprojects {
|
||||
}
|
||||
|
||||
dependencies {
|
||||
compile 'org.tensorflow:tensorflow-android:1.2.0-preview'
|
||||
compile 'org.tensorflow:tensorflow-android:+'
|
||||
}
|
||||
```
|
||||
|
||||
This will tell Gradle to use the
|
||||
[latest version](https://bintray.com/google/tensorflow/tensorflow-android/_latestVersion)
|
||||
of the TensorFlow AAR that has been released to
|
||||
[https://bintray.com/google/tensorflow/tensorflow-android](https://bintray.com/google/tensorflow/tensorflow-android).
|
||||
You may replace the `+` with an explicit version label if you wish to
|
||||
use a specific release of TensorFlow in your app.
|
||||
|
||||
To build the libraries yourself (if, for example, you want to support custom
|
||||
TensorFlow operators), pick your preferred approach below:
|
||||
|
||||
|
@ -41,11 +41,11 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
|
||||
__all__ = [
|
||||
@ -225,7 +225,7 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params):
|
||||
return binary_scores
|
||||
|
||||
|
||||
class CrfForwardRnnCell(core_rnn_cell.RNNCell):
|
||||
class CrfForwardRnnCell(rnn_cell.RNNCell):
|
||||
"""Computes the alpha values in a linear-chain CRF.
|
||||
|
||||
See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
|
||||
|
@ -22,7 +22,6 @@ import time
|
||||
|
||||
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
from tensorflow.contrib.rnn.python.ops import lstm_ops
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -31,6 +30,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -131,9 +131,9 @@ class CudnnRNNBenchmark(test.Benchmark):
|
||||
]
|
||||
initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
|
||||
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units=num_units, initializer=initializer, state_is_tuple=True)
|
||||
multi_cell = core_rnn_cell_impl.MultiRNNCell(
|
||||
multi_cell = rnn_cell.MultiRNNCell(
|
||||
[cell() for _ in range(num_layers)])
|
||||
outputs, final_state = core_rnn.static_rnn(
|
||||
multi_cell, inputs, dtype=dtypes.float32)
|
||||
@ -159,7 +159,7 @@ class CudnnRNNBenchmark(test.Benchmark):
|
||||
]
|
||||
cell = lambda: lstm_ops.LSTMBlockCell(num_units=num_units) # pylint: disable=cell-var-from-loop
|
||||
|
||||
multi_cell = core_rnn_cell_impl.MultiRNNCell(
|
||||
multi_cell = rnn_cell.MultiRNNCell(
|
||||
[cell() for _ in range(num_layers)])
|
||||
outputs, final_state = core_rnn.static_rnn(
|
||||
multi_cell, inputs, dtype=dtypes.float32)
|
||||
|
@ -21,11 +21,11 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.grid_rnn.python.ops import grid_rnn_cell
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -527,7 +527,7 @@ class GridRNNCellTest(test.TestCase):
|
||||
dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
|
||||
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
|
||||
@ -569,7 +569,7 @@ class GridRNNCellTest(test.TestCase):
|
||||
array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
|
||||
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
|
||||
@ -609,7 +609,7 @@ class GridRNNCellTest(test.TestCase):
|
||||
array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
|
||||
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
|
||||
@ -652,7 +652,7 @@ class GridRNNCellTest(test.TestCase):
|
||||
dtypes.float32, shape=(batch_size, input_size))
|
||||
] + (max_length - 1) * [array_ops.zeros([batch_size, input_size])])
|
||||
|
||||
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
|
||||
@ -690,7 +690,7 @@ class GridRNNCellTest(test.TestCase):
|
||||
array_ops.placeholder(dtypes.float32, shape=(None, input_size))
|
||||
]
|
||||
|
||||
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
|
||||
|
@ -31,7 +31,6 @@ from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_f
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.contrib.learn.python.learn.estimators import rnn_common
|
||||
from tensorflow.contrib.learn.python.learn.estimators import run_config
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -42,6 +41,7 @@ from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -107,7 +107,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(DynamicRnnEstimatorTest, self).setUp()
|
||||
self.rnn_cell = core_rnn_cell_impl.BasicRNNCell(self.NUM_RNN_CELL_UNITS)
|
||||
self.rnn_cell = rnn_cell.BasicRNNCell(self.NUM_RNN_CELL_UNITS)
|
||||
self.mock_target_column = MockTargetColumn(
|
||||
num_label_columns=self.NUM_LABEL_COLUMNS)
|
||||
|
||||
@ -312,19 +312,19 @@ class DynamicRnnEstimatorTest(test.TestCase):
|
||||
# A MultiRNNCell of LSTMCells is both a common choice and an interesting
|
||||
# test case, because it has two levels of nesting, with an inner class that
|
||||
# is not a plain tuple.
|
||||
cell = core_rnn_cell_impl.MultiRNNCell(
|
||||
[core_rnn_cell_impl.LSTMCell(i) for i in cell_sizes])
|
||||
cell = rnn_cell.MultiRNNCell(
|
||||
[rnn_cell.LSTMCell(i) for i in cell_sizes])
|
||||
state_dict = {
|
||||
dynamic_rnn_estimator._get_state_name(i):
|
||||
array_ops.expand_dims(math_ops.range(cell_size), 0)
|
||||
for i, cell_size in enumerate([5, 5, 3, 3, 7, 7])
|
||||
}
|
||||
expected_state = (core_rnn_cell_impl.LSTMStateTuple(
|
||||
expected_state = (rnn_cell.LSTMStateTuple(
|
||||
np.reshape(np.arange(5), [1, -1]), np.reshape(np.arange(5), [1, -1])),
|
||||
core_rnn_cell_impl.LSTMStateTuple(
|
||||
rnn_cell.LSTMStateTuple(
|
||||
np.reshape(np.arange(3), [1, -1]),
|
||||
np.reshape(np.arange(3), [1, -1])),
|
||||
core_rnn_cell_impl.LSTMStateTuple(
|
||||
rnn_cell.LSTMStateTuple(
|
||||
np.reshape(np.arange(7), [1, -1]),
|
||||
np.reshape(np.arange(7), [1, -1])))
|
||||
actual_state = dynamic_rnn_estimator.dict_to_state_tuple(state_dict, cell)
|
||||
|
@ -26,13 +26,13 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn
|
||||
from tensorflow.contrib.learn.python.learn.estimators import rnn_common
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn
|
||||
from tensorflow.contrib.training.python.training import sequence_queueing_state_saver as sqss
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.training import momentum as momentum_opt
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
@ -64,7 +64,7 @@ def construct_state_saving_rnn(cell,
|
||||
final_state: The final state output by the RNN
|
||||
"""
|
||||
with ops.name_scope(scope):
|
||||
rnn_outputs, final_state = core_rnn.static_state_saving_rnn(
|
||||
rnn_outputs, final_state = rnn.static_state_saving_rnn(
|
||||
cell=cell,
|
||||
inputs=inputs,
|
||||
state_saver=state_saver,
|
||||
|
@ -21,9 +21,9 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.learn.python.learn import ops
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -82,7 +82,7 @@ class Seq2SeqOpsTest(test.TestCase):
|
||||
array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
|
||||
]
|
||||
encoding = array_ops.placeholder(dtypes.float32, [2, 2])
|
||||
cell = core_rnn_cell_impl.GRUCell(2)
|
||||
cell = rnn_cell.GRUCell(2)
|
||||
outputs, states, sampling_outputs, sampling_states = (
|
||||
ops.rnn_decoder(decoder_inputs, encoding, cell))
|
||||
self.assertEqual(len(outputs), 3)
|
||||
|
@ -25,8 +25,7 @@ import random
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -37,6 +36,7 @@ from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import nn_impl
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
@ -51,11 +51,10 @@ class Seq2SeqTest(test.TestCase):
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
|
||||
_, enc_state = core_rnn.static_rnn(
|
||||
core_rnn_cell_impl.GRUCell(2), inp, dtype=dtypes.float32)
|
||||
_, enc_state = rnn.static_rnn(
|
||||
rnn_cell.GRUCell(2), inp, dtype=dtypes.float32)
|
||||
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
|
||||
cell = core_rnn_cell_impl.OutputProjectionWrapper(
|
||||
core_rnn_cell_impl.GRUCell(2), 4)
|
||||
cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
|
||||
dec, mem = seq2seq_lib.rnn_decoder(dec_inp, enc_state, cell)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
@ -71,8 +70,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
|
||||
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
|
||||
cell = core_rnn_cell_impl.OutputProjectionWrapper(
|
||||
core_rnn_cell_impl.GRUCell(2), 4)
|
||||
cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
|
||||
dec, mem = seq2seq_lib.basic_rnn_seq2seq(inp, dec_inp, cell)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
@ -88,8 +86,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
|
||||
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
|
||||
cell = core_rnn_cell_impl.OutputProjectionWrapper(
|
||||
core_rnn_cell_impl.GRUCell(2), 4)
|
||||
cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
|
||||
dec, mem = seq2seq_lib.tied_rnn_seq2seq(inp, dec_inp, cell)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run(dec)
|
||||
@ -105,9 +102,9 @@ class Seq2SeqTest(test.TestCase):
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
|
||||
cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
|
||||
cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
|
||||
cell = cell_fn()
|
||||
_, enc_state = core_rnn.static_rnn(cell, inp, dtype=dtypes.float32)
|
||||
_, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
|
||||
dec_inp = [
|
||||
constant_op.constant(
|
||||
i, dtypes.int32, shape=[2]) for i in range(3)
|
||||
@ -138,7 +135,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
constant_op.constant(
|
||||
i, dtypes.int32, shape=[2]) for i in range(3)
|
||||
]
|
||||
cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
|
||||
cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
|
||||
cell = cell_fn()
|
||||
dec, mem = seq2seq_lib.embedding_rnn_seq2seq(
|
||||
enc_inp,
|
||||
@ -158,7 +155,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
|
||||
# Test with state_is_tuple=False.
|
||||
with variable_scope.variable_scope("no_tuple"):
|
||||
cell_nt = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
|
||||
cell_nt = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
|
||||
dec, mem = seq2seq_lib.embedding_rnn_seq2seq(
|
||||
enc_inp,
|
||||
dec_inp,
|
||||
@ -242,9 +239,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
constant_op.constant(
|
||||
i, dtypes.int32, shape=[2]) for i in range(3)
|
||||
]
|
||||
cell = functools.partial(
|
||||
core_rnn_cell_impl.BasicLSTMCell,
|
||||
2, state_is_tuple=True)
|
||||
cell = functools.partial(rnn_cell.BasicLSTMCell, 2, state_is_tuple=True)
|
||||
dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq(
|
||||
enc_inp, dec_inp, cell(), num_symbols=5, embedding_size=2)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
@ -324,11 +319,10 @@ class Seq2SeqTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
|
||||
cell_fn = lambda: rnn_cell.GRUCell(2)
|
||||
cell = cell_fn()
|
||||
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
|
||||
enc_outputs, enc_state = core_rnn.static_rnn(
|
||||
cell, inp, dtype=dtypes.float32)
|
||||
enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
|
||||
attn_states = array_ops.concat([
|
||||
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
|
||||
], 1)
|
||||
@ -350,11 +344,10 @@ class Seq2SeqTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
|
||||
cell_fn = lambda: rnn_cell.GRUCell(2)
|
||||
cell = cell_fn()
|
||||
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
|
||||
enc_outputs, enc_state = core_rnn.static_rnn(
|
||||
cell, inp, dtype=dtypes.float32)
|
||||
enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
|
||||
attn_states = array_ops.concat([
|
||||
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
|
||||
], 1)
|
||||
@ -377,7 +370,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
|
||||
cell_fn = lambda: rnn_cell.GRUCell(2)
|
||||
cell = cell_fn()
|
||||
inp = constant_op.constant(0.5, shape=[2, 2, 2])
|
||||
enc_outputs, enc_state = rnn.dynamic_rnn(
|
||||
@ -401,7 +394,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
|
||||
cell_fn = lambda: rnn_cell.GRUCell(2)
|
||||
cell = cell_fn()
|
||||
inp = constant_op.constant(0.5, shape=[2, 2, 2])
|
||||
enc_outputs, enc_state = rnn.dynamic_rnn(
|
||||
@ -426,14 +419,13 @@ class Seq2SeqTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
single_cell = lambda: core_rnn_cell_impl.BasicLSTMCell( # pylint: disable=g-long-lambda
|
||||
single_cell = lambda: rnn_cell.BasicLSTMCell( # pylint: disable=g-long-lambda
|
||||
2, state_is_tuple=True)
|
||||
cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell( # pylint: disable=g-long-lambda
|
||||
cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda
|
||||
cells=[single_cell() for _ in range(2)], state_is_tuple=True)
|
||||
cell = cell_fn()
|
||||
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
|
||||
enc_outputs, enc_state = core_rnn.static_rnn(
|
||||
cell, inp, dtype=dtypes.float32)
|
||||
enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
|
||||
attn_states = array_ops.concat([
|
||||
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
|
||||
], 1)
|
||||
@ -459,12 +451,11 @@ class Seq2SeqTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell( # pylint: disable=g-long-lambda
|
||||
cells=[core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)])
|
||||
cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda
|
||||
cells=[rnn_cell.BasicLSTMCell(2) for _ in range(2)])
|
||||
cell = cell_fn()
|
||||
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
|
||||
enc_outputs, enc_state = core_rnn.static_rnn(
|
||||
cell, inp, dtype=dtypes.float32)
|
||||
enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
|
||||
attn_states = array_ops.concat([
|
||||
array_ops.reshape(e, [-1, 1, cell.output_size])
|
||||
for e in enc_outputs
|
||||
@ -492,10 +483,9 @@ class Seq2SeqTest(test.TestCase):
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
|
||||
cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
|
||||
cell_fn = lambda: rnn_cell.GRUCell(2)
|
||||
cell = cell_fn()
|
||||
enc_outputs, enc_state = core_rnn.static_rnn(
|
||||
cell, inp, dtype=dtypes.float32)
|
||||
enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
|
||||
attn_states = array_ops.concat([
|
||||
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
|
||||
], 1)
|
||||
@ -534,7 +524,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
constant_op.constant(
|
||||
i, dtypes.int32, shape=[2]) for i in range(3)
|
||||
]
|
||||
cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
|
||||
cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
|
||||
cell = cell_fn()
|
||||
dec, mem = seq2seq_lib.embedding_attention_seq2seq(
|
||||
enc_inp,
|
||||
@ -555,8 +545,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
# Test with state_is_tuple=False.
|
||||
with variable_scope.variable_scope("no_tuple"):
|
||||
cell_fn = functools.partial(
|
||||
core_rnn_cell_impl.BasicLSTMCell,
|
||||
2, state_is_tuple=False)
|
||||
rnn_cell.BasicLSTMCell, 2, state_is_tuple=False)
|
||||
cell_nt = cell_fn()
|
||||
dec, mem = seq2seq_lib.embedding_attention_seq2seq(
|
||||
enc_inp,
|
||||
@ -651,11 +640,10 @@ class Seq2SeqTest(test.TestCase):
|
||||
]
|
||||
dec_symbols_dict = {"0": 5, "1": 6}
|
||||
def EncCellFn():
|
||||
return core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
|
||||
return rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
|
||||
def DecCellsFn():
|
||||
return dict(
|
||||
(k, core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True))
|
||||
for k in dec_symbols_dict)
|
||||
return dict((k, rnn_cell.BasicLSTMCell(2, state_is_tuple=True))
|
||||
for k in dec_symbols_dict)
|
||||
outputs_dict, state_dict = (seq2seq_lib.one2many_rnn_seq2seq(
|
||||
enc_inp, dec_inp_dict, EncCellFn(), DecCellsFn(),
|
||||
2, dec_symbols_dict, embedding_size=2))
|
||||
@ -796,8 +784,8 @@ class Seq2SeqTest(test.TestCase):
|
||||
# """Example sequence-to-sequence model that uses GRU cells."""
|
||||
|
||||
# def GRUSeq2Seq(enc_inp, dec_inp):
|
||||
# cell = core_rnn_cell_impl.MultiRNNCell(
|
||||
# [core_rnn_cell_impl.GRUCell(24) for _ in range(2)])
|
||||
# cell = rnn_cell.MultiRNNCell(
|
||||
# [rnn_cell.GRUCell(24) for _ in range(2)])
|
||||
# return seq2seq_lib.embedding_attention_seq2seq(
|
||||
# enc_inp,
|
||||
# dec_inp,
|
||||
@ -862,9 +850,8 @@ class Seq2SeqTest(test.TestCase):
|
||||
"""Example sequence-to-sequence model that uses GRU cells."""
|
||||
|
||||
def GRUSeq2Seq(enc_inp, dec_inp):
|
||||
cell = core_rnn_cell_impl.MultiRNNCell(
|
||||
[core_rnn_cell_impl.GRUCell(24) for _ in range(2)],
|
||||
state_is_tuple=True)
|
||||
cell = rnn_cell.MultiRNNCell(
|
||||
[rnn_cell.GRUCell(24) for _ in range(2)], state_is_tuple=True)
|
||||
return seq2seq_lib.embedding_attention_seq2seq(
|
||||
enc_inp,
|
||||
dec_inp,
|
||||
@ -1040,7 +1027,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
self.assertAllClose(v_true.eval(), v_false.eval())
|
||||
|
||||
def EmbeddingRNNSeq2SeqF(enc_inp, dec_inp, feed_previous):
|
||||
cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
|
||||
cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
|
||||
return seq2seq_lib.embedding_rnn_seq2seq(
|
||||
enc_inp,
|
||||
dec_inp,
|
||||
@ -1051,7 +1038,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
feed_previous=feed_previous)
|
||||
|
||||
def EmbeddingRNNSeq2SeqNoTupleF(enc_inp, dec_inp, feed_previous):
|
||||
cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
|
||||
cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
|
||||
return seq2seq_lib.embedding_rnn_seq2seq(
|
||||
enc_inp,
|
||||
dec_inp,
|
||||
@ -1062,7 +1049,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
feed_previous=feed_previous)
|
||||
|
||||
def EmbeddingTiedRNNSeq2Seq(enc_inp, dec_inp, feed_previous):
|
||||
cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
|
||||
cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
|
||||
return seq2seq_lib.embedding_tied_rnn_seq2seq(
|
||||
enc_inp,
|
||||
dec_inp,
|
||||
@ -1072,7 +1059,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
feed_previous=feed_previous)
|
||||
|
||||
def EmbeddingTiedRNNSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
|
||||
cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
|
||||
cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
|
||||
return seq2seq_lib.embedding_tied_rnn_seq2seq(
|
||||
enc_inp,
|
||||
dec_inp,
|
||||
@ -1082,7 +1069,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
feed_previous=feed_previous)
|
||||
|
||||
def EmbeddingAttentionSeq2Seq(enc_inp, dec_inp, feed_previous):
|
||||
cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
|
||||
cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
|
||||
return seq2seq_lib.embedding_attention_seq2seq(
|
||||
enc_inp,
|
||||
dec_inp,
|
||||
@ -1093,7 +1080,7 @@ class Seq2SeqTest(test.TestCase):
|
||||
feed_previous=feed_previous)
|
||||
|
||||
def EmbeddingAttentionSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
|
||||
cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
|
||||
cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
|
||||
return seq2seq_lib.embedding_attention_seq2seq(
|
||||
enc_inp,
|
||||
dec_inp,
|
||||
|
@ -62,9 +62,7 @@ import copy
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
from six.moves import zip # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -72,11 +70,13 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell_impl
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
# TODO(ebrevdo): Remove once _linear is fully deprecated.
|
||||
linear = core_rnn_cell_impl._linear # pylint: disable=protected-access
|
||||
linear = rnn_cell_impl._linear # pylint: disable=protected-access
|
||||
|
||||
|
||||
def _extract_argmax_and_embed(embedding,
|
||||
@ -119,7 +119,7 @@ def rnn_decoder(decoder_inputs,
|
||||
Args:
|
||||
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
|
||||
initial_state: 2D Tensor with shape [batch_size x cell.state_size].
|
||||
cell: core_rnn_cell.RNNCell defining the cell function and size.
|
||||
cell: rnn_cell.RNNCell defining the cell function and size.
|
||||
loop_function: If not None, this function will be applied to the i-th output
|
||||
in order to generate the i+1-st input, and decoder_inputs will be ignored,
|
||||
except for the first element ("GO" symbol). This can be used for decoding,
|
||||
@ -170,7 +170,7 @@ def basic_rnn_seq2seq(encoder_inputs,
|
||||
Args:
|
||||
encoder_inputs: A list of 2D Tensors [batch_size x input_size].
|
||||
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
|
||||
cell: core_rnn_cell.RNNCell defining the cell function and size.
|
||||
cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
|
||||
dtype: The dtype of the initial state of the RNN cell (default: tf.float32).
|
||||
scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq".
|
||||
|
||||
@ -183,7 +183,7 @@ def basic_rnn_seq2seq(encoder_inputs,
|
||||
"""
|
||||
with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
|
||||
enc_cell = copy.deepcopy(cell)
|
||||
_, enc_state = core_rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
|
||||
_, enc_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
|
||||
return rnn_decoder(decoder_inputs, enc_state, cell)
|
||||
|
||||
|
||||
@ -202,7 +202,7 @@ def tied_rnn_seq2seq(encoder_inputs,
|
||||
Args:
|
||||
encoder_inputs: A list of 2D Tensors [batch_size x input_size].
|
||||
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
|
||||
cell: core_rnn_cell.RNNCell defining the cell function and size.
|
||||
cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
|
||||
loop_function: If not None, this function will be applied to i-th output
|
||||
in order to generate i+1-th input, and decoder_inputs will be ignored,
|
||||
except for the first element ("GO" symbol), see rnn_decoder for details.
|
||||
@ -219,7 +219,7 @@ def tied_rnn_seq2seq(encoder_inputs,
|
||||
"""
|
||||
with variable_scope.variable_scope("combined_tied_rnn_seq2seq"):
|
||||
scope = scope or "tied_rnn_seq2seq"
|
||||
_, enc_state = core_rnn.static_rnn(
|
||||
_, enc_state = rnn.static_rnn(
|
||||
cell, encoder_inputs, dtype=dtype, scope=scope)
|
||||
variable_scope.get_variable_scope().reuse_variables()
|
||||
return rnn_decoder(
|
||||
@ -244,7 +244,7 @@ def embedding_rnn_decoder(decoder_inputs,
|
||||
Args:
|
||||
decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
|
||||
initial_state: 2D Tensor [batch_size x cell.state_size].
|
||||
cell: core_rnn_cell.RNNCell defining the cell function.
|
||||
cell: tf.nn.rnn_cell.RNNCell defining the cell function.
|
||||
num_symbols: Integer, how many symbols come into the embedding.
|
||||
embedding_size: Integer, the length of the embedding vector for each symbol.
|
||||
output_projection: None or a pair (W, B) of output projection weights and
|
||||
@ -320,7 +320,7 @@ def embedding_rnn_seq2seq(encoder_inputs,
|
||||
Args:
|
||||
encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
|
||||
decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
|
||||
cell: core_rnn_cell.RNNCell defining the cell function and size.
|
||||
cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
|
||||
num_encoder_symbols: Integer; number of symbols on the encoder side.
|
||||
num_decoder_symbols: Integer; number of symbols on the decoder side.
|
||||
embedding_size: Integer, the length of the embedding vector for each symbol.
|
||||
@ -360,8 +360,7 @@ def embedding_rnn_seq2seq(encoder_inputs,
|
||||
encoder_cell,
|
||||
embedding_classes=num_encoder_symbols,
|
||||
embedding_size=embedding_size)
|
||||
_, encoder_state = core_rnn.static_rnn(
|
||||
encoder_cell, encoder_inputs, dtype=dtype)
|
||||
_, encoder_state = rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)
|
||||
|
||||
# Decoder.
|
||||
if output_projection is None:
|
||||
@ -431,7 +430,7 @@ def embedding_tied_rnn_seq2seq(encoder_inputs,
|
||||
Args:
|
||||
encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
|
||||
decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
|
||||
cell: core_rnn_cell.RNNCell defining the cell function and size.
|
||||
cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
|
||||
num_symbols: Integer; number of symbols for both encoder and decoder.
|
||||
embedding_size: Integer, the length of the embedding vector for each symbol.
|
||||
num_decoder_symbols: Integer; number of output symbols for decoder. If
|
||||
@ -560,7 +559,7 @@ def attention_decoder(decoder_inputs,
|
||||
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
|
||||
initial_state: 2D Tensor [batch_size x cell.state_size].
|
||||
attention_states: 3D Tensor [batch_size x attn_length x attn_size].
|
||||
cell: core_rnn_cell.RNNCell defining the cell function and size.
|
||||
cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
|
||||
output_size: Size of the output vectors; if None, we use cell.output_size.
|
||||
num_heads: Number of attention heads that read from attention_states.
|
||||
loop_function: If not None, this function will be applied to i-th output
|
||||
@ -720,7 +719,7 @@ def embedding_attention_decoder(decoder_inputs,
|
||||
decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
|
||||
initial_state: 2D Tensor [batch_size x cell.state_size].
|
||||
attention_states: 3D Tensor [batch_size x attn_length x attn_size].
|
||||
cell: core_rnn_cell.RNNCell defining the cell function.
|
||||
cell: tf.nn.rnn_cell.RNNCell defining the cell function.
|
||||
num_symbols: Integer, how many symbols come into the embedding.
|
||||
embedding_size: Integer, the length of the embedding vector for each symbol.
|
||||
num_heads: Number of attention heads that read from attention_states.
|
||||
@ -814,7 +813,7 @@ def embedding_attention_seq2seq(encoder_inputs,
|
||||
Args:
|
||||
encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
|
||||
decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
|
||||
cell: core_rnn_cell.RNNCell defining the cell function and size.
|
||||
cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
|
||||
num_encoder_symbols: Integer; number of symbols on the encoder side.
|
||||
num_decoder_symbols: Integer; number of symbols on the decoder side.
|
||||
embedding_size: Integer, the length of the embedding vector for each symbol.
|
||||
@ -851,7 +850,7 @@ def embedding_attention_seq2seq(encoder_inputs,
|
||||
encoder_cell,
|
||||
embedding_classes=num_encoder_symbols,
|
||||
embedding_size=embedding_size)
|
||||
encoder_outputs, encoder_state = core_rnn.static_rnn(
|
||||
encoder_outputs, encoder_state = rnn.static_rnn(
|
||||
encoder_cell, encoder_inputs, dtype=dtype)
|
||||
|
||||
# First calculate a concatenation of encoder outputs to put attention on.
|
||||
@ -937,9 +936,10 @@ def one2many_rnn_seq2seq(encoder_inputs,
|
||||
the corresponding decoder_inputs; each decoder_inputs is a list of 1D
|
||||
Tensors of shape [batch_size]; num_decoders is defined as
|
||||
len(decoder_inputs_dict).
|
||||
enc_cell: core_rnn_cell.RNNCell defining the encoder cell function and size.
|
||||
enc_cell: tf.nn.rnn_cell.RNNCell defining the encoder cell function and
|
||||
size.
|
||||
dec_cells_dict: A dictionary mapping encoder name (string) to an
|
||||
instance of core_rnn_cell.RNNCell.
|
||||
instance of tf.nn.rnn_cell.RNNCell.
|
||||
num_encoder_symbols: Integer; number of symbols on the encoder side.
|
||||
num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an
|
||||
integer specifying number of symbols for the corresponding decoder;
|
||||
@ -971,12 +971,12 @@ def one2many_rnn_seq2seq(encoder_inputs,
|
||||
outputs_dict = {}
|
||||
state_dict = {}
|
||||
|
||||
if not isinstance(enc_cell, core_rnn_cell.RNNCell):
|
||||
if not isinstance(enc_cell, rnn_cell_impl.RNNCell):
|
||||
raise TypeError("enc_cell is not an RNNCell: %s" % type(enc_cell))
|
||||
if set(dec_cells_dict) != set(decoder_inputs_dict):
|
||||
raise ValueError("keys of dec_cells_dict != keys of decodre_inputs_dict")
|
||||
for dec_cell in dec_cells_dict.values():
|
||||
if not isinstance(dec_cell, core_rnn_cell.RNNCell):
|
||||
if not isinstance(dec_cell, rnn_cell_impl.RNNCell):
|
||||
raise TypeError("dec_cell is not an RNNCell: %s" % type(dec_cell))
|
||||
|
||||
with variable_scope.variable_scope(
|
||||
@ -988,8 +988,7 @@ def one2many_rnn_seq2seq(encoder_inputs,
|
||||
enc_cell,
|
||||
embedding_classes=num_encoder_symbols,
|
||||
embedding_size=embedding_size)
|
||||
_, encoder_state = core_rnn.static_rnn(
|
||||
enc_cell, encoder_inputs, dtype=dtype)
|
||||
_, encoder_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
|
||||
|
||||
# Decoder.
|
||||
for name, decoder_inputs in decoder_inputs_dict.items():
|
||||
@ -1153,7 +1152,7 @@ def model_with_buckets(encoder_inputs,
|
||||
|
||||
The seq2seq argument is a function that defines a sequence-to-sequence model,
|
||||
e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(
|
||||
x, y, core_rnn_cell.GRUCell(24))
|
||||
x, y, rnn_cell.GRUCell(24))
|
||||
|
||||
Args:
|
||||
encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.
|
||||
|
@ -20,13 +20,13 @@ from __future__ import print_function
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
from tensorflow.contrib.framework.python.ops import variables
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variable_scope
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ def ndlstm_base_unrolled(inputs, noutput, scope=None, reverse=False):
|
||||
"""
|
||||
with variable_scope.variable_scope(scope, "SeqLstmUnrolled", [inputs]):
|
||||
length, batch_size, _ = _shape(inputs)
|
||||
lstm_cell = core_rnn_cell_impl.BasicLSTMCell(noutput, state_is_tuple=False)
|
||||
lstm_cell = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False)
|
||||
state = array_ops.zeros([batch_size, lstm_cell.state_size])
|
||||
output_u = []
|
||||
inputs_u = array_ops.unstack(inputs)
|
||||
@ -88,7 +88,7 @@ def ndlstm_base_dynamic(inputs, noutput, scope=None, reverse=False):
|
||||
# TODO(tmb) make batch size, sequence_length dynamic
|
||||
# example: sequence_length = tf.shape(inputs)[0]
|
||||
_, batch_size, _ = _shape(inputs)
|
||||
lstm_cell = core_rnn_cell_impl.BasicLSTMCell(noutput, state_is_tuple=False)
|
||||
lstm_cell = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False)
|
||||
state = array_ops.zeros([batch_size, lstm_cell.state_size])
|
||||
sequence_length = int(inputs.get_shape()[0])
|
||||
sequence_lengths = math_ops.to_int64(
|
||||
@ -145,7 +145,7 @@ def sequence_to_final(inputs, noutput, scope=None, name=None, reverse=False):
|
||||
"""
|
||||
with variable_scope.variable_scope(scope, "SequenceToFinal", [inputs]):
|
||||
length, batch_size, _ = _shape(inputs)
|
||||
lstm = core_rnn_cell_impl.BasicLSTMCell(noutput, state_is_tuple=False)
|
||||
lstm = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False)
|
||||
state = array_ops.zeros([batch_size, lstm.state_size])
|
||||
inputs_u = array_ops.unstack(inputs)
|
||||
if reverse:
|
||||
|
@ -16,21 +16,26 @@
|
||||
|
||||
See @{$python/contrib.rnn} guide.
|
||||
|
||||
# From core
|
||||
@@RNNCell
|
||||
@@BasicRNNCell
|
||||
@@BasicLSTMCell
|
||||
@@GRUCell
|
||||
@@LSTMCell
|
||||
@@LayerNormBasicLSTMCell
|
||||
@@LSTMStateTuple
|
||||
@@MultiRNNCell
|
||||
@@LSTMBlockWrapper
|
||||
@@DropoutWrapper
|
||||
@@MultiRNNCell
|
||||
@@DeviceWrapper
|
||||
@@ResidualWrapper
|
||||
|
||||
# Used to be in core, but kept in contrib.
|
||||
@@EmbeddingWrapper
|
||||
@@InputProjectionWrapper
|
||||
@@OutputProjectionWrapper
|
||||
@@DeviceWrapper
|
||||
@@ResidualWrapper
|
||||
|
||||
# Created in contrib, eventual plans to move to core.
|
||||
@@LayerNormBasicLSTMCell
|
||||
@@LSTMBlockWrapper
|
||||
@@LSTMBlockCell
|
||||
@@GRUBlockCell
|
||||
@@FusedRNNCell
|
||||
@ -48,9 +53,11 @@ See @{$python/contrib.rnn} guide.
|
||||
@@HighwayWrapper
|
||||
@@GLSTMCell
|
||||
|
||||
### RNNCell wrappers
|
||||
# RNNCell wrappers
|
||||
@@AttentionCellWrapper
|
||||
@@CompiledWrapper
|
||||
|
||||
# RNN functions
|
||||
@@static_rnn
|
||||
@@static_state_saving_rnn
|
||||
@@static_bidirectional_rnn
|
||||
@ -62,31 +69,23 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn import static_bidirectional_rnn
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn import static_rnn
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn import static_state_saving_rnn
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicLSTMCell
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicRNNCell
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import DeviceWrapper
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import DropoutWrapper
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import EmbeddingWrapper
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import GRUCell
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import InputProjectionWrapper
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMCell
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import MultiRNNCell
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import OutputProjectionWrapper
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import ResidualWrapper
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import RNNCell
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import,line-too-long
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import EmbeddingWrapper
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import InputProjectionWrapper
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import OutputProjectionWrapper
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops.fused_rnn_cell import *
|
||||
from tensorflow.contrib.rnn.python.ops.gru_ops import *
|
||||
from tensorflow.contrib.rnn.python.ops.lstm_ops import *
|
||||
from tensorflow.contrib.rnn.python.ops.rnn import *
|
||||
from tensorflow.contrib.rnn.python.ops.rnn_cell import *
|
||||
|
||||
from tensorflow.python.ops.rnn import static_bidirectional_rnn
|
||||
from tensorflow.python.ops.rnn import static_rnn
|
||||
from tensorflow.python.ops.rnn import static_state_saving_rnn
|
||||
|
||||
from tensorflow.python.ops.rnn_cell import *
|
||||
# pylint: enable=unused-import,wildcard-import,line-too-long
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
remove_undocumented(__name__, ['core_rnn_cell'])
|
||||
remove_undocumented(__name__)
|
||||
|
@ -25,8 +25,7 @@ import numpy as np
|
||||
# TODO(ebrevdo): Remove once _linear is fully deprecated.
|
||||
# pylint: disable=protected-access
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _linear as linear
|
||||
from tensorflow.contrib import rnn as contrib_rnn
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -37,12 +36,14 @@ from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell_impl
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
# pylint: enable=protected-access
|
||||
linear = rnn_cell_impl._linear
|
||||
|
||||
|
||||
class RNNCellTest(test.TestCase):
|
||||
@ -74,14 +75,12 @@ class RNNCellTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 2])
|
||||
m = array_ops.zeros([1, 2])
|
||||
cell = core_rnn_cell_impl.BasicRNNCell(2)
|
||||
cell = rnn_cell_impl.BasicRNNCell(2)
|
||||
g, _ = cell(x, m)
|
||||
self.assertEqual(
|
||||
["root/basic_rnn_cell/%s:0"
|
||||
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
||||
"root/basic_rnn_cell/%s:0"
|
||||
% core_rnn_cell_impl._BIAS_VARIABLE_NAME],
|
||||
[v.name for v in cell.trainable_variables])
|
||||
self.assertEqual([
|
||||
"root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
||||
"root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
|
||||
], [v.name for v in cell.trainable_variables])
|
||||
self.assertFalse(cell.non_trainable_variables)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
@ -100,15 +99,13 @@ class RNNCellTest(test.TestCase):
|
||||
custom_getter=not_trainable_getter):
|
||||
x = array_ops.zeros([1, 2])
|
||||
m = array_ops.zeros([1, 2])
|
||||
cell = core_rnn_cell_impl.BasicRNNCell(2)
|
||||
cell = rnn_cell_impl.BasicRNNCell(2)
|
||||
g, _ = cell(x, m)
|
||||
self.assertFalse(cell.trainable_variables)
|
||||
self.assertEqual(
|
||||
["root/basic_rnn_cell/%s:0"
|
||||
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
||||
"root/basic_rnn_cell/%s:0"
|
||||
% core_rnn_cell_impl._BIAS_VARIABLE_NAME],
|
||||
[v.name for v in cell.non_trainable_variables])
|
||||
self.assertEqual([
|
||||
"root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
||||
"root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
|
||||
], [v.name for v in cell.non_trainable_variables])
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g], {x.name: np.array([[1., 1.]]),
|
||||
@ -121,7 +118,7 @@ class RNNCellTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 2])
|
||||
m = array_ops.zeros([1, 2])
|
||||
g, _ = core_rnn_cell_impl.GRUCell(2)(x, m)
|
||||
g, _ = rnn_cell_impl.GRUCell(2)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g], {x.name: np.array([[1., 1.]]),
|
||||
@ -133,7 +130,7 @@ class RNNCellTest(test.TestCase):
|
||||
x = array_ops.zeros(
|
||||
[1, 3]) # Test GRUCell with input_size != num_units.
|
||||
m = array_ops.zeros([1, 2])
|
||||
g, _ = core_rnn_cell_impl.GRUCell(2)(x, m)
|
||||
g, _ = rnn_cell_impl.GRUCell(2)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g],
|
||||
@ -148,20 +145,23 @@ class RNNCellTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 2])
|
||||
m = array_ops.zeros([1, 8])
|
||||
cell = core_rnn_cell_impl.MultiRNNCell(
|
||||
[core_rnn_cell_impl.BasicLSTMCell(
|
||||
2, state_is_tuple=False) for _ in range(2)],
|
||||
cell = rnn_cell_impl.MultiRNNCell(
|
||||
[
|
||||
rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
|
||||
for _ in range(2)
|
||||
],
|
||||
state_is_tuple=False)
|
||||
g, out_m = cell(x, m)
|
||||
expected_variable_names = [
|
||||
"root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0"
|
||||
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
||||
"root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0"
|
||||
% core_rnn_cell_impl._BIAS_VARIABLE_NAME,
|
||||
"root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0"
|
||||
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
||||
"root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0"
|
||||
% core_rnn_cell_impl._BIAS_VARIABLE_NAME]
|
||||
"root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" %
|
||||
rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
||||
"root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" %
|
||||
rnn_cell_impl._BIAS_VARIABLE_NAME,
|
||||
"root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
|
||||
rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
||||
"root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
|
||||
rnn_cell_impl._BIAS_VARIABLE_NAME
|
||||
]
|
||||
self.assertEqual(
|
||||
expected_variable_names, [v.name for v in cell.trainable_variables])
|
||||
self.assertFalse(cell.non_trainable_variables)
|
||||
@ -185,8 +185,7 @@ class RNNCellTest(test.TestCase):
|
||||
x = array_ops.zeros(
|
||||
[1, 3]) # Test BasicLSTMCell with input_size != num_units.
|
||||
m = array_ops.zeros([1, 4])
|
||||
g, out_m = core_rnn_cell_impl.BasicLSTMCell(
|
||||
2, state_is_tuple=False)(x, m)
|
||||
g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g, out_m],
|
||||
@ -206,7 +205,7 @@ class RNNCellTest(test.TestCase):
|
||||
x = array_ops.zeros([batch_size, input_size])
|
||||
m = array_ops.zeros([batch_size - 1, state_size])
|
||||
with self.assertRaises(ValueError):
|
||||
g, out_m = core_rnn_cell_impl.BasicLSTMCell(
|
||||
g, out_m = rnn_cell_impl.BasicLSTMCell(
|
||||
num_units, state_is_tuple=False)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
sess.run([g, out_m],
|
||||
@ -225,7 +224,7 @@ class RNNCellTest(test.TestCase):
|
||||
x = array_ops.zeros([batch_size, input_size])
|
||||
m = array_ops.zeros([batch_size, state_size])
|
||||
with self.assertRaises(ValueError):
|
||||
g, out_m = core_rnn_cell_impl.BasicLSTMCell(
|
||||
g, out_m = rnn_cell_impl.BasicLSTMCell(
|
||||
num_units, state_is_tuple=False)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
sess.run([g, out_m],
|
||||
@ -239,31 +238,29 @@ class RNNCellTest(test.TestCase):
|
||||
x = array_ops.zeros([1, 2])
|
||||
m0 = (array_ops.zeros([1, 2]),) * 2
|
||||
m1 = (array_ops.zeros([1, 2]),) * 2
|
||||
cell = core_rnn_cell_impl.MultiRNNCell(
|
||||
[core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)],
|
||||
cell = rnn_cell_impl.MultiRNNCell(
|
||||
[rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)],
|
||||
state_is_tuple=True)
|
||||
self.assertTrue(isinstance(cell.state_size, tuple))
|
||||
self.assertTrue(
|
||||
isinstance(cell.state_size[0], core_rnn_cell_impl.LSTMStateTuple))
|
||||
isinstance(cell.state_size[0], rnn_cell_impl.LSTMStateTuple))
|
||||
self.assertTrue(
|
||||
isinstance(cell.state_size[1], core_rnn_cell_impl.LSTMStateTuple))
|
||||
isinstance(cell.state_size[1], rnn_cell_impl.LSTMStateTuple))
|
||||
|
||||
# Pass in regular tuples
|
||||
_, (out_m0, out_m1) = cell(x, (m0, m1))
|
||||
self.assertTrue(isinstance(out_m0, core_rnn_cell_impl.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(out_m1, core_rnn_cell_impl.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple))
|
||||
|
||||
# Pass in LSTMStateTuples
|
||||
variable_scope.get_variable_scope().reuse_variables()
|
||||
zero_state = cell.zero_state(1, dtypes.float32)
|
||||
self.assertTrue(isinstance(zero_state, tuple))
|
||||
self.assertTrue(
|
||||
isinstance(zero_state[0], core_rnn_cell_impl.LSTMStateTuple))
|
||||
self.assertTrue(
|
||||
isinstance(zero_state[1], core_rnn_cell_impl.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(zero_state[0], rnn_cell_impl.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(zero_state[1], rnn_cell_impl.LSTMStateTuple))
|
||||
_, (out_m0, out_m1) = cell(x, zero_state)
|
||||
self.assertTrue(isinstance(out_m0, core_rnn_cell_impl.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(out_m1, core_rnn_cell_impl.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple))
|
||||
|
||||
def testBasicLSTMCellWithStateTuple(self):
|
||||
with self.test_session() as sess:
|
||||
@ -272,9 +269,11 @@ class RNNCellTest(test.TestCase):
|
||||
x = array_ops.zeros([1, 2])
|
||||
m0 = array_ops.zeros([1, 4])
|
||||
m1 = array_ops.zeros([1, 4])
|
||||
cell = core_rnn_cell_impl.MultiRNNCell(
|
||||
[core_rnn_cell_impl.BasicLSTMCell(
|
||||
2, state_is_tuple=False) for _ in range(2)],
|
||||
cell = rnn_cell_impl.MultiRNNCell(
|
||||
[
|
||||
rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
|
||||
for _ in range(2)
|
||||
],
|
||||
state_is_tuple=True)
|
||||
g, (out_m0, out_m1) = cell(x, (m0, m1))
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
@ -306,7 +305,7 @@ class RNNCellTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([batch_size, input_size])
|
||||
m = array_ops.zeros([batch_size, state_size])
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell_impl.LSTMCell(
|
||||
num_units=num_units,
|
||||
num_proj=num_proj,
|
||||
forget_bias=1.0,
|
||||
@ -340,7 +339,7 @@ class RNNCellTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([batch_size, input_size])
|
||||
m = array_ops.zeros([batch_size, state_size])
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell_impl.LSTMCell(
|
||||
num_units=num_units,
|
||||
num_proj=num_proj,
|
||||
forget_bias=1.0,
|
||||
@ -358,8 +357,7 @@ class RNNCellTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 3])
|
||||
m = array_ops.zeros([1, 3])
|
||||
cell = core_rnn_cell_impl.OutputProjectionWrapper(
|
||||
core_rnn_cell_impl.GRUCell(3), 2)
|
||||
cell = contrib_rnn.OutputProjectionWrapper(rnn_cell_impl.GRUCell(3), 2)
|
||||
g, new_m = cell(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run([g, new_m], {
|
||||
@ -376,8 +374,8 @@ class RNNCellTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 2])
|
||||
m = array_ops.zeros([1, 3])
|
||||
cell = core_rnn_cell_impl.InputProjectionWrapper(
|
||||
core_rnn_cell_impl.GRUCell(3), num_proj=3)
|
||||
cell = contrib_rnn.InputProjectionWrapper(
|
||||
rnn_cell_impl.GRUCell(3), num_proj=3)
|
||||
g, new_m = cell(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
@ -394,10 +392,10 @@ class RNNCellTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 3])
|
||||
m = array_ops.zeros([1, 3])
|
||||
base_cell = core_rnn_cell_impl.GRUCell(3)
|
||||
base_cell = rnn_cell_impl.GRUCell(3)
|
||||
g, m_new = base_cell(x, m)
|
||||
variable_scope.get_variable_scope().reuse_variables()
|
||||
g_res, m_new_res = core_rnn_cell_impl.ResidualWrapper(base_cell)(x, m)
|
||||
g_res, m_new_res = rnn_cell_impl.ResidualWrapper(base_cell)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run([g, g_res, m_new, m_new_res], {
|
||||
x: np.array([[1., 1., 1.]]),
|
||||
@ -413,8 +411,7 @@ class RNNCellTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 3])
|
||||
m = array_ops.zeros([1, 3])
|
||||
cell = core_rnn_cell_impl.DeviceWrapper(
|
||||
core_rnn_cell_impl.GRUCell(3), "/cpu:14159")
|
||||
cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), "/cpu:14159")
|
||||
outputs, _ = cell(x, m)
|
||||
self.assertTrue("cpu:14159" in outputs.device.lower())
|
||||
|
||||
@ -427,8 +424,7 @@ class RNNCellTest(test.TestCase):
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 1, 3])
|
||||
cell = core_rnn_cell_impl.DeviceWrapper(
|
||||
core_rnn_cell_impl.GRUCell(3), "/gpu:0")
|
||||
cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), "/gpu:0")
|
||||
with ops.device("/cpu:0"):
|
||||
outputs, _ = rnn.dynamic_rnn(
|
||||
cell=cell, inputs=x, dtype=dtypes.float32)
|
||||
@ -446,39 +442,14 @@ class RNNCellTest(test.TestCase):
|
||||
self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name])
|
||||
self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name])
|
||||
|
||||
# def testUsingSecondCellInScopeWithExistingVariablesFails(self):
|
||||
# # This test should go away when this behavior is no longer an
|
||||
# # error (Approx. May 2017)
|
||||
# cell1 = core_rnn_cell_impl.LSTMCell(3)
|
||||
# cell2 = core_rnn_cell_impl.LSTMCell(3)
|
||||
# x = array_ops.zeros([1, 3])
|
||||
# m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
|
||||
# cell1(x, m)
|
||||
# with self.assertRaisesRegexp(ValueError, r"LSTMCell\(..., reuse=True\)"):
|
||||
# cell2(x, m)
|
||||
|
||||
# def testUsingCellInDifferentScopeFromFirstCallFails(self):
|
||||
# # This test should go away when this behavior is no longer an
|
||||
# # error (Approx. May 2017)
|
||||
# cell = core_rnn_cell_impl.LSTMCell(3)
|
||||
# x = array_ops.zeros([1, 3])
|
||||
# m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
|
||||
# with variable_scope.variable_scope("scope1"):
|
||||
# cell(x, m)
|
||||
# with variable_scope.variable_scope("scope2"):
|
||||
# with self.assertRaisesRegexp(ValueError, r"Attempt to reuse RNNCell"):
|
||||
# cell(x, m)
|
||||
|
||||
def testEmbeddingWrapper(self):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 1], dtype=dtypes.int32)
|
||||
m = array_ops.zeros([1, 2])
|
||||
embedding_cell = core_rnn_cell_impl.EmbeddingWrapper(
|
||||
core_rnn_cell_impl.GRUCell(2),
|
||||
embedding_classes=3,
|
||||
embedding_size=2)
|
||||
embedding_cell = contrib_rnn.EmbeddingWrapper(
|
||||
rnn_cell_impl.GRUCell(2), embedding_classes=3, embedding_size=2)
|
||||
self.assertEqual(embedding_cell.output_size, 2)
|
||||
g, new_m = embedding_cell(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
@ -495,9 +466,8 @@ class RNNCellTest(test.TestCase):
|
||||
with variable_scope.variable_scope("root"):
|
||||
inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64)
|
||||
input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64)
|
||||
embedding_cell = core_rnn_cell_impl.EmbeddingWrapper(
|
||||
core_rnn_cell_impl.BasicLSTMCell(
|
||||
1, state_is_tuple=True),
|
||||
embedding_cell = contrib_rnn.EmbeddingWrapper(
|
||||
rnn_cell_impl.BasicLSTMCell(1, state_is_tuple=True),
|
||||
embedding_classes=1,
|
||||
embedding_size=2)
|
||||
outputs, _ = rnn.dynamic_rnn(
|
||||
@ -515,9 +485,9 @@ class RNNCellTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 2])
|
||||
m = array_ops.zeros([1, 4])
|
||||
_, ml = core_rnn_cell_impl.MultiRNNCell(
|
||||
[core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
|
||||
state_is_tuple=False)(x, m)
|
||||
_, ml = rnn_cell_impl.MultiRNNCell(
|
||||
[rnn_cell_impl.GRUCell(2)
|
||||
for _ in range(2)], state_is_tuple=False)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(ml, {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
@ -536,13 +506,13 @@ class RNNCellTest(test.TestCase):
|
||||
|
||||
# Test incorrectness of state
|
||||
with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
|
||||
core_rnn_cell_impl.MultiRNNCell(
|
||||
[core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
|
||||
state_is_tuple=True)(x, m_bad)
|
||||
rnn_cell_impl.MultiRNNCell(
|
||||
[rnn_cell_impl.GRUCell(2)
|
||||
for _ in range(2)], state_is_tuple=True)(x, m_bad)
|
||||
|
||||
_, ml = core_rnn_cell_impl.MultiRNNCell(
|
||||
[core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
|
||||
state_is_tuple=True)(x, m_good)
|
||||
_, ml = rnn_cell_impl.MultiRNNCell(
|
||||
[rnn_cell_impl.GRUCell(2)
|
||||
for _ in range(2)], state_is_tuple=True)(x, m_good)
|
||||
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(ml, {
|
||||
@ -571,23 +541,23 @@ class DropoutWrapperTest(test.TestCase):
|
||||
time_steps = 2
|
||||
x = constant_op.constant(
|
||||
[[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32)
|
||||
m = core_rnn_cell_impl.LSTMStateTuple(
|
||||
*[constant_op.constant([[0.1, 0.1, 0.1]],
|
||||
dtype=dtypes.float32)] * 2)
|
||||
m = rnn_cell_impl.LSTMStateTuple(
|
||||
*[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32)
|
||||
] * 2)
|
||||
else:
|
||||
x = constant_op.constant(
|
||||
np.random.randn(time_steps, batch_size, 3).astype(np.float32))
|
||||
m = core_rnn_cell_impl.LSTMStateTuple(
|
||||
*[constant_op.constant([[0.1, 0.1, 0.1]] * batch_size,
|
||||
dtype=dtypes.float32)] * 2)
|
||||
m = rnn_cell_impl.LSTMStateTuple(*[
|
||||
constant_op.constant(
|
||||
[[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)
|
||||
] * 2)
|
||||
outputs, final_state = rnn.dynamic_rnn(
|
||||
cell=core_rnn_cell_impl.DropoutWrapper(
|
||||
core_rnn_cell_impl.LSTMCell(3),
|
||||
dtype=x.dtype,
|
||||
**kwargs),
|
||||
cell=rnn_cell_impl.DropoutWrapper(
|
||||
rnn_cell_impl.LSTMCell(3), dtype=x.dtype, **kwargs),
|
||||
time_major=True,
|
||||
parallel_iterations=parallel_iterations,
|
||||
inputs=x, initial_state=m)
|
||||
inputs=x,
|
||||
initial_state=m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run([outputs, final_state])
|
||||
self.assertEqual(res[0].shape, (time_steps, batch_size, 3))
|
||||
@ -775,7 +745,7 @@ class SlimRNNCellTest(test.TestCase):
|
||||
m = array_ops.zeros([1, 2])
|
||||
my_cell = functools.partial(basic_rnn_cell, num_units=2)
|
||||
# pylint: disable=protected-access
|
||||
g, _ = core_rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
|
||||
g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
|
||||
# pylint: enable=protected-access
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
@ -792,12 +762,12 @@ class SlimRNNCellTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
inputs = random_ops.random_uniform((batch_size, input_size))
|
||||
_, initial_state = basic_rnn_cell(inputs, None, num_units)
|
||||
rnn_cell = core_rnn_cell_impl.BasicRNNCell(num_units)
|
||||
rnn_cell = rnn_cell_impl.BasicRNNCell(num_units)
|
||||
outputs, state = rnn_cell(inputs, initial_state)
|
||||
variable_scope.get_variable_scope().reuse_variables()
|
||||
my_cell = functools.partial(basic_rnn_cell, num_units=num_units)
|
||||
# pylint: disable=protected-access
|
||||
slim_cell = core_rnn_cell_impl._SlimRNNCell(my_cell)
|
||||
slim_cell = rnn_cell_impl._SlimRNNCell(my_cell)
|
||||
# pylint: enable=protected-access
|
||||
slim_outputs, slim_state = slim_cell(inputs, initial_state)
|
||||
self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
|
||||
|
@ -24,9 +24,6 @@ import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.contrib import rnn as rnn_lib
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -38,6 +35,7 @@ from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
@ -153,7 +151,7 @@ class RNNTest(test.TestCase):
|
||||
cell = Plus1RNNCell()
|
||||
inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))]
|
||||
with self.assertRaisesRegexp(ValueError, "must be a vector"):
|
||||
core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=4)
|
||||
rnn.static_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=4)
|
||||
|
||||
def testRNN(self):
|
||||
cell = Plus1RNNCell()
|
||||
@ -164,7 +162,7 @@ class RNNTest(test.TestCase):
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
for out, inp in zip(outputs, inputs):
|
||||
self.assertEqual(out.get_shape(), inp.get_shape())
|
||||
@ -186,7 +184,7 @@ class RNNTest(test.TestCase):
|
||||
|
||||
def testDropout(self):
|
||||
cell = Plus1RNNCell()
|
||||
full_dropout_cell = core_rnn_cell_impl.DropoutWrapper(
|
||||
full_dropout_cell = rnn_cell.DropoutWrapper(
|
||||
cell, input_keep_prob=1e-12, seed=0)
|
||||
batch_size = 2
|
||||
input_size = 5
|
||||
@ -196,9 +194,9 @@ class RNNTest(test.TestCase):
|
||||
dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
with variable_scope.variable_scope("share_scope"):
|
||||
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
with variable_scope.variable_scope("drop_scope"):
|
||||
dropped_outputs, _ = core_rnn.static_rnn(
|
||||
dropped_outputs, _ = rnn.static_rnn(
|
||||
full_dropout_cell, inputs, dtype=dtypes.float32)
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
for out, inp in zip(outputs, inputs):
|
||||
@ -227,7 +225,7 @@ class RNNTest(test.TestCase):
|
||||
dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
with variable_scope.variable_scope("drop_scope"):
|
||||
dynamic_outputs, dynamic_state = core_rnn.static_rnn(
|
||||
dynamic_outputs, dynamic_state = rnn.static_rnn(
|
||||
cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)
|
||||
self.assertEqual(len(dynamic_outputs), len(inputs))
|
||||
|
||||
@ -297,8 +295,7 @@ class RNNTest(test.TestCase):
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
return core_rnn.static_rnn(
|
||||
cell, inputs, dtype=dtypes.float32, scope=scope)
|
||||
return rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope=scope)
|
||||
|
||||
self._testScope(factory, use_outer_scope=True)
|
||||
self._testScope(factory, use_outer_scope=False)
|
||||
@ -319,13 +316,13 @@ class LSTMTest(test.TestCase):
|
||||
with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=self._seed)
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units, initializer=initializer, state_is_tuple=False)
|
||||
inputs = max_length * [
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
outputs, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
for out in outputs:
|
||||
self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
|
||||
@ -342,7 +339,7 @@ class LSTMTest(test.TestCase):
|
||||
with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=self._seed)
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=True,
|
||||
cell_clip=0.0,
|
||||
@ -352,7 +349,7 @@ class LSTMTest(test.TestCase):
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
outputs, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
for out in outputs:
|
||||
self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
|
||||
@ -374,7 +371,7 @@ class LSTMTest(test.TestCase):
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=self._seed)
|
||||
state_saver = TestStateSaver(batch_size, 2 * num_units)
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=False,
|
||||
initializer=initializer,
|
||||
@ -384,7 +381,7 @@ class LSTMTest(test.TestCase):
|
||||
dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
with variable_scope.variable_scope("share_scope"):
|
||||
outputs, state = core_rnn.static_state_saving_rnn(
|
||||
outputs, state = rnn.static_state_saving_rnn(
|
||||
cell, inputs, state_saver=state_saver, state_name="save_lstm")
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
for out in outputs:
|
||||
@ -406,7 +403,7 @@ class LSTMTest(test.TestCase):
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=self._seed)
|
||||
state_saver = TestStateSaver(batch_size, num_units)
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=False,
|
||||
initializer=initializer,
|
||||
@ -416,7 +413,7 @@ class LSTMTest(test.TestCase):
|
||||
dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
with variable_scope.variable_scope("share_scope"):
|
||||
outputs, state = core_rnn.static_state_saving_rnn(
|
||||
outputs, state = rnn.static_state_saving_rnn(
|
||||
cell, inputs, state_saver=state_saver, state_name=("c", "m"))
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
for out in outputs:
|
||||
@ -450,14 +447,14 @@ class LSTMTest(test.TestCase):
|
||||
})
|
||||
|
||||
def _cell(i):
|
||||
return core_rnn_cell_impl.LSTMCell(
|
||||
return rnn_cell.LSTMCell(
|
||||
num_units + i,
|
||||
use_peepholes=False,
|
||||
initializer=initializer,
|
||||
state_is_tuple=True)
|
||||
|
||||
# This creates a state tuple which has 4 sub-tuples of length 2 each.
|
||||
cell = core_rnn_cell_impl.MultiRNNCell(
|
||||
cell = rnn_cell.MultiRNNCell(
|
||||
[_cell(i) for i in range(4)], state_is_tuple=True)
|
||||
|
||||
self.assertEqual(len(cell.state_size), 4)
|
||||
@ -471,7 +468,7 @@ class LSTMTest(test.TestCase):
|
||||
|
||||
state_names = (("c0", "m0"), ("c1", "m1"), ("c2", "m2"), ("c3", "m3"))
|
||||
with variable_scope.variable_scope("share_scope"):
|
||||
outputs, state = core_rnn.static_state_saving_rnn(
|
||||
outputs, state = rnn.static_state_saving_rnn(
|
||||
cell, inputs, state_saver=state_saver, state_name=state_names)
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
|
||||
@ -508,13 +505,13 @@ class LSTMTest(test.TestCase):
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, shape=(None, input_size))
|
||||
]
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj,
|
||||
initializer=initializer,
|
||||
state_is_tuple=False)
|
||||
outputs, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
|
||||
variables_lib.global_variables_initializer().run()
|
||||
@ -535,20 +532,20 @@ class LSTMTest(test.TestCase):
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, shape=(None, input_size))
|
||||
]
|
||||
cell_notuple = core_rnn_cell_impl.LSTMCell(
|
||||
cell_notuple = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj,
|
||||
initializer=initializer,
|
||||
state_is_tuple=False)
|
||||
cell_tuple = core_rnn_cell_impl.LSTMCell(
|
||||
cell_tuple = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj,
|
||||
initializer=initializer,
|
||||
state_is_tuple=True)
|
||||
with variable_scope.variable_scope("root") as scope:
|
||||
outputs_notuple, state_notuple = core_rnn.static_rnn(
|
||||
outputs_notuple, state_notuple = rnn.static_rnn(
|
||||
cell_notuple,
|
||||
inputs,
|
||||
dtype=dtypes.float32,
|
||||
@ -562,7 +559,7 @@ class LSTMTest(test.TestCase):
|
||||
# the parameters from different RNNCell instances. Right now,
|
||||
# this seems an unrealistic use case except for testing.
|
||||
cell_tuple._scope = cell_notuple._scope # pylint: disable=protected-access
|
||||
outputs_tuple, state_tuple = core_rnn.static_rnn(
|
||||
outputs_tuple, state_tuple = rnn.static_rnn(
|
||||
cell_tuple,
|
||||
inputs,
|
||||
dtype=dtypes.float32,
|
||||
@ -603,7 +600,7 @@ class LSTMTest(test.TestCase):
|
||||
dtypes.float32, shape=(None, input_size))
|
||||
]
|
||||
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj,
|
||||
@ -612,7 +609,7 @@ class LSTMTest(test.TestCase):
|
||||
initializer=initializer,
|
||||
state_is_tuple=False)
|
||||
|
||||
outputs, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
|
||||
@ -635,7 +632,7 @@ class LSTMTest(test.TestCase):
|
||||
dtypes.float64, shape=(None, input_size))
|
||||
]
|
||||
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj,
|
||||
@ -644,7 +641,7 @@ class LSTMTest(test.TestCase):
|
||||
initializer=initializer,
|
||||
state_is_tuple=False)
|
||||
|
||||
outputs, _ = core_rnn.static_rnn(
|
||||
outputs, _ = rnn.static_rnn(
|
||||
cell,
|
||||
inputs,
|
||||
initial_state=cell.zero_state(batch_size, dtypes.float64))
|
||||
@ -672,7 +669,7 @@ class LSTMTest(test.TestCase):
|
||||
]
|
||||
initializer = init_ops.constant_initializer(0.001)
|
||||
|
||||
cell_noshard = core_rnn_cell_impl.LSTMCell(
|
||||
cell_noshard = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
num_proj=num_proj,
|
||||
use_peepholes=True,
|
||||
@ -681,7 +678,7 @@ class LSTMTest(test.TestCase):
|
||||
num_proj_shards=num_proj_shards,
|
||||
state_is_tuple=False)
|
||||
|
||||
cell_shard = core_rnn_cell_impl.LSTMCell(
|
||||
cell_shard = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=True,
|
||||
initializer=initializer,
|
||||
@ -689,10 +686,10 @@ class LSTMTest(test.TestCase):
|
||||
state_is_tuple=False)
|
||||
|
||||
with variable_scope.variable_scope("noshard_scope"):
|
||||
outputs_noshard, state_noshard = core_rnn.static_rnn(
|
||||
outputs_noshard, state_noshard = rnn.static_rnn(
|
||||
cell_noshard, inputs, dtype=dtypes.float32)
|
||||
with variable_scope.variable_scope("shard_scope"):
|
||||
outputs_shard, state_shard = core_rnn.static_rnn(
|
||||
outputs_shard, state_shard = rnn.static_rnn(
|
||||
cell_shard, inputs, dtype=dtypes.float32)
|
||||
|
||||
self.assertEqual(len(outputs_noshard), len(inputs))
|
||||
@ -731,7 +728,7 @@ class LSTMTest(test.TestCase):
|
||||
dtypes.float64, shape=(None, input_size))
|
||||
]
|
||||
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj,
|
||||
@ -739,9 +736,9 @@ class LSTMTest(test.TestCase):
|
||||
num_proj_shards=num_proj_shards,
|
||||
initializer=initializer,
|
||||
state_is_tuple=False)
|
||||
dropout_cell = core_rnn_cell_impl.DropoutWrapper(cell, 0.5, seed=0)
|
||||
dropout_cell = rnn_cell.DropoutWrapper(cell, 0.5, seed=0)
|
||||
|
||||
outputs, state = core_rnn.static_rnn(
|
||||
outputs, state = rnn.static_rnn(
|
||||
dropout_cell,
|
||||
inputs,
|
||||
sequence_length=sequence_length,
|
||||
@ -776,13 +773,13 @@ class LSTMTest(test.TestCase):
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, shape=(None, input_size))
|
||||
]
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj,
|
||||
initializer=initializer,
|
||||
state_is_tuple=False)
|
||||
cell_d = core_rnn_cell_impl.LSTMCell(
|
||||
cell_d = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj,
|
||||
@ -790,11 +787,11 @@ class LSTMTest(test.TestCase):
|
||||
state_is_tuple=False)
|
||||
|
||||
with variable_scope.variable_scope("share_scope"):
|
||||
outputs0, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs0, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
with variable_scope.variable_scope("share_scope", reuse=True):
|
||||
outputs1, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs1, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
with variable_scope.variable_scope("diff_scope"):
|
||||
outputs2, _ = core_rnn.static_rnn(cell_d, inputs, dtype=dtypes.float32)
|
||||
outputs2, _ = rnn.static_rnn(cell_d, inputs, dtype=dtypes.float32)
|
||||
|
||||
variables_lib.global_variables_initializer().run()
|
||||
input_value = np.random.randn(batch_size, input_size)
|
||||
@ -823,7 +820,7 @@ class LSTMTest(test.TestCase):
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, shape=(None, input_size))
|
||||
]
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj,
|
||||
@ -832,10 +829,10 @@ class LSTMTest(test.TestCase):
|
||||
|
||||
with ops_lib.name_scope("scope0"):
|
||||
with variable_scope.variable_scope("share_scope"):
|
||||
outputs0, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs0, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
with ops_lib.name_scope("scope1"):
|
||||
with variable_scope.variable_scope("share_scope", reuse=True):
|
||||
outputs1, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs1, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
|
||||
variables_lib.global_variables_initializer().run()
|
||||
input_value = np.random.randn(batch_size, input_size)
|
||||
@ -881,7 +878,7 @@ class LSTMTest(test.TestCase):
|
||||
|
||||
def testDynamicRNNAllowsUnknownTimeDimension(self):
|
||||
inputs = array_ops.placeholder(dtypes.float32, shape=[1, None, 20])
|
||||
cell = core_rnn_cell.GRUCell(30)
|
||||
cell = rnn_cell.GRUCell(30)
|
||||
# Smoke test, this should not raise an error
|
||||
rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
|
||||
@ -900,14 +897,14 @@ class LSTMTest(test.TestCase):
|
||||
dtypes.float32, shape=(None, input_size))
|
||||
]
|
||||
inputs_c = array_ops.stack(inputs)
|
||||
cell = core_rnn_cell.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj,
|
||||
initializer=initializer,
|
||||
state_is_tuple=True)
|
||||
with variable_scope.variable_scope("root") as scope:
|
||||
outputs_static, state_static = core_rnn.static_rnn(
|
||||
outputs_static, state_static = rnn.static_rnn(
|
||||
cell,
|
||||
inputs,
|
||||
dtype=dtypes.float32,
|
||||
@ -921,8 +918,8 @@ class LSTMTest(test.TestCase):
|
||||
time_major=True,
|
||||
sequence_length=sequence_length,
|
||||
scope=scope)
|
||||
self.assertTrue(isinstance(state_static, core_rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(state_dynamic, core_rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(state_static, rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(state_dynamic, rnn_cell.LSTMStateTuple))
|
||||
self.assertEqual(state_static[0], state_static.c)
|
||||
self.assertEqual(state_static[1], state_static.h)
|
||||
self.assertEqual(state_dynamic[0], state_dynamic.c)
|
||||
@ -960,7 +957,7 @@ class LSTMTest(test.TestCase):
|
||||
inputs_c = array_ops.stack(inputs)
|
||||
|
||||
def _cell(i):
|
||||
return core_rnn_cell.LSTMCell(
|
||||
return rnn_cell.LSTMCell(
|
||||
num_units + i,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj + i,
|
||||
@ -968,7 +965,7 @@ class LSTMTest(test.TestCase):
|
||||
state_is_tuple=True)
|
||||
|
||||
# This creates a state tuple which has 4 sub-tuples of length 2 each.
|
||||
cell = core_rnn_cell.MultiRNNCell(
|
||||
cell = rnn_cell.MultiRNNCell(
|
||||
[_cell(i) for i in range(4)], state_is_tuple=True)
|
||||
|
||||
self.assertEqual(len(cell.state_size), 4)
|
||||
@ -982,7 +979,7 @@ class LSTMTest(test.TestCase):
|
||||
self.assertEqual(test_zero[i][1].get_shape()[1], cell.state_size[i][1])
|
||||
|
||||
with variable_scope.variable_scope("root") as scope:
|
||||
outputs_static, state_static = core_rnn.static_rnn(
|
||||
outputs_static, state_static = rnn.static_rnn(
|
||||
cell,
|
||||
inputs,
|
||||
dtype=dtypes.float32,
|
||||
@ -1034,7 +1031,7 @@ class LSTMTest(test.TestCase):
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=self._seed)
|
||||
|
||||
cell = core_rnn_cell.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=True,
|
||||
initializer=initializer,
|
||||
@ -1042,7 +1039,7 @@ class LSTMTest(test.TestCase):
|
||||
state_is_tuple=False)
|
||||
|
||||
with variable_scope.variable_scope("dynamic_scope"):
|
||||
outputs_static, state_static = core_rnn.static_rnn(
|
||||
outputs_static, state_static = rnn.static_rnn(
|
||||
cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)
|
||||
|
||||
feeds = {concat_inputs: input_values}
|
||||
@ -1092,7 +1089,7 @@ class LSTMTest(test.TestCase):
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=self._seed)
|
||||
|
||||
cell = core_rnn_cell.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=True,
|
||||
initializer=initializer,
|
||||
@ -1205,16 +1202,16 @@ class BidirectionalRNNTest(test.TestCase):
|
||||
-0.01, 0.01, seed=self._seed)
|
||||
sequence_length = array_ops.placeholder(
|
||||
dtypes.int64) if use_sequence_length else None
|
||||
cell_fw = core_rnn_cell_impl.LSTMCell(
|
||||
cell_fw = rnn_cell.LSTMCell(
|
||||
num_units, input_size, initializer=initializer, state_is_tuple=False)
|
||||
cell_bw = core_rnn_cell_impl.LSTMCell(
|
||||
cell_bw = rnn_cell.LSTMCell(
|
||||
num_units, input_size, initializer=initializer, state_is_tuple=False)
|
||||
inputs = max_length * [
|
||||
array_ops.placeholder(
|
||||
dtypes.float32,
|
||||
shape=(batch_size, input_size) if use_shape else (None, input_size))
|
||||
]
|
||||
outputs, state_fw, state_bw = core_rnn.static_bidirectional_rnn(
|
||||
outputs, state_fw, state_bw = rnn.static_bidirectional_rnn(
|
||||
cell_fw,
|
||||
cell_bw,
|
||||
inputs,
|
||||
@ -1337,9 +1334,9 @@ class BidirectionalRNNTest(test.TestCase):
|
||||
-0.01, 0.01, seed=self._seed)
|
||||
sequence_length = (
|
||||
array_ops.placeholder(dtypes.int64) if use_sequence_length else None)
|
||||
cell_fw = core_rnn_cell.LSTMCell(
|
||||
cell_fw = rnn_cell.LSTMCell(
|
||||
num_units, initializer=initializer, state_is_tuple=use_state_tuple)
|
||||
cell_bw = core_rnn_cell.LSTMCell(
|
||||
cell_bw = rnn_cell.LSTMCell(
|
||||
num_units, initializer=initializer, state_is_tuple=use_state_tuple)
|
||||
inputs = max_length * [
|
||||
array_ops.placeholder(
|
||||
@ -1530,7 +1527,7 @@ class MultiDimensionalLSTMTest(test.TestCase):
|
||||
# variables.
|
||||
cell = DummyMultiDimensionalLSTM(feature_dims)
|
||||
state_saver = TestStateSaver(batch_size, input_size)
|
||||
outputs_static, state_static = core_rnn.static_rnn(
|
||||
outputs_static, state_static = rnn.static_rnn(
|
||||
cell, inputs, dtype=dtypes.float32, sequence_length=sequence_length)
|
||||
outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
|
||||
cell,
|
||||
@ -1538,13 +1535,13 @@ class MultiDimensionalLSTMTest(test.TestCase):
|
||||
dtype=dtypes.float32,
|
||||
time_major=True,
|
||||
sequence_length=sequence_length)
|
||||
outputs_bid, state_fw, state_bw = core_rnn.static_bidirectional_rnn(
|
||||
outputs_bid, state_fw, state_bw = rnn.static_bidirectional_rnn(
|
||||
cell,
|
||||
cell,
|
||||
inputs_using_dim,
|
||||
dtype=dtypes.float32,
|
||||
sequence_length=sequence_length)
|
||||
outputs_sav, state_sav = core_rnn.static_state_saving_rnn(
|
||||
outputs_sav, state_sav = rnn.static_state_saving_rnn(
|
||||
cell,
|
||||
inputs_using_dim,
|
||||
sequence_length=sequence_length,
|
||||
@ -1634,15 +1631,15 @@ class NestedLSTMTest(test.TestCase):
|
||||
dtype=dtypes.float32,
|
||||
time_major=True,
|
||||
sequence_length=sequence_length)
|
||||
outputs_static, state_static = core_rnn.static_rnn(
|
||||
outputs_static, state_static = rnn.static_rnn(
|
||||
cell, inputs, dtype=dtypes.float32, sequence_length=sequence_length)
|
||||
outputs_bid, state_fw, state_bw = core_rnn.static_bidirectional_rnn(
|
||||
outputs_bid, state_fw, state_bw = rnn.static_bidirectional_rnn(
|
||||
cell,
|
||||
cell,
|
||||
inputs_using_dim,
|
||||
dtype=dtypes.float32,
|
||||
sequence_length=sequence_length)
|
||||
outputs_sav, state_sav = core_rnn.static_state_saving_rnn(
|
||||
outputs_sav, state_sav = rnn.static_state_saving_rnn(
|
||||
cell,
|
||||
inputs_using_dim,
|
||||
sequence_length=sequence_length,
|
||||
@ -1738,7 +1735,7 @@ class StateSaverRNNTest(test.TestCase):
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=self._seed)
|
||||
state_saver = TestStateSaver(batch_size, 2 * num_units)
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=False,
|
||||
initializer=initializer,
|
||||
@ -1747,7 +1744,7 @@ class StateSaverRNNTest(test.TestCase):
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
return core_rnn.static_state_saving_rnn(
|
||||
return rnn.static_state_saving_rnn(
|
||||
cell,
|
||||
inputs,
|
||||
state_saver=state_saver,
|
||||
@ -1779,7 +1776,7 @@ class GRUTest(test.TestCase):
|
||||
concat_inputs = array_ops.placeholder(
|
||||
dtypes.float32, shape=(time_steps, batch_size, input_size))
|
||||
|
||||
cell = core_rnn_cell.GRUCell(num_units=num_units)
|
||||
cell = rnn_cell.GRUCell(num_units=num_units)
|
||||
|
||||
with variable_scope.variable_scope("dynamic_scope"):
|
||||
outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
|
||||
@ -1830,7 +1827,7 @@ class GRUTest(test.TestCase):
|
||||
def factory(scope):
|
||||
concat_inputs = array_ops.placeholder(
|
||||
dtypes.float32, shape=(time_steps, batch_size, input_size))
|
||||
cell = core_rnn_cell.GRUCell(num_units=num_units)
|
||||
cell = rnn_cell.GRUCell(num_units=num_units)
|
||||
return rnn.dynamic_rnn(
|
||||
cell,
|
||||
inputs=concat_inputs,
|
||||
@ -1864,7 +1861,7 @@ class RawRNNTest(test.TestCase):
|
||||
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
||||
inputs_ta = inputs_ta.unstack(inputs)
|
||||
|
||||
cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||
|
||||
def loop_fn(time_, cell_output, cell_state, unused_loop_state):
|
||||
emit_output = cell_output # == None for time == 0
|
||||
@ -1965,7 +1962,7 @@ class RawRNNTest(test.TestCase):
|
||||
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
||||
inputs_ta = inputs_ta.unstack(inputs)
|
||||
|
||||
cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||
|
||||
def loop_fn(time_, cell_output, cell_state, loop_state):
|
||||
if cell_output is None:
|
||||
@ -2001,7 +1998,7 @@ class RawRNNTest(test.TestCase):
|
||||
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
||||
inputs_ta = inputs_ta.unstack(inputs)
|
||||
|
||||
cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||
|
||||
def loop_fn(time_, cell_output, cell_state, loop_state):
|
||||
if cell_output is None:
|
||||
@ -2044,7 +2041,7 @@ class RawRNNTest(test.TestCase):
|
||||
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
||||
inputs_ta = inputs_ta.unstack(inputs)
|
||||
|
||||
cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||
|
||||
def loop_fn(time_, cell_output, cell_state, _):
|
||||
if cell_output is None:
|
||||
@ -2113,7 +2110,7 @@ class RawRNNTest(test.TestCase):
|
||||
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
||||
inputs_ta = inputs_ta.unstack(inputs)
|
||||
|
||||
cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||
|
||||
def loop_fn(time_, cell_output, cell_state, unused_loop_state):
|
||||
emit_output = cell_output # == None for time == 0
|
||||
@ -2138,7 +2135,7 @@ class RawRNNTest(test.TestCase):
|
||||
self._testScope(factory, prefix=None, use_outer_scope=False)
|
||||
|
||||
|
||||
class DeviceWrapperCell(core_rnn_cell.RNNCell):
|
||||
class DeviceWrapperCell(rnn_cell.RNNCell):
|
||||
"""Class to ensure cell calculation happens on a specific device."""
|
||||
|
||||
def __init__(self, cell, device):
|
||||
@ -2172,7 +2169,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
|
||||
input_size = 5
|
||||
num_units = 10
|
||||
|
||||
cell = core_rnn_cell.LSTMCell(num_units, use_peepholes=True)
|
||||
cell = rnn_cell.LSTMCell(num_units, use_peepholes=True)
|
||||
gpu_cell = DeviceWrapperCell(cell, cell_device)
|
||||
inputs = np.random.randn(batch_size, time_steps,
|
||||
input_size).astype(np.float32)
|
||||
|
@ -20,14 +20,14 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
from tensorflow.contrib.rnn.python.ops import fused_rnn_cell
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -41,7 +41,7 @@ class FusedRnnCellTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=19890212)
|
||||
cell = core_rnn_cell_impl.BasicRNNCell(10)
|
||||
cell = rnn_cell.BasicRNNCell(10)
|
||||
batch_size = 5
|
||||
input_size = 20
|
||||
timelen = 15
|
||||
@ -49,7 +49,7 @@ class FusedRnnCellTest(test.TestCase):
|
||||
np.random.randn(timelen, batch_size, input_size))
|
||||
with variable_scope.variable_scope("basic", initializer=initializer):
|
||||
unpacked_inputs = array_ops.unstack(inputs)
|
||||
outputs, state = core_rnn.static_rnn(
|
||||
outputs, state = rnn.static_rnn(
|
||||
cell, unpacked_inputs, dtype=dtypes.float64)
|
||||
packed_outputs = array_ops.stack(outputs)
|
||||
basic_vars = [
|
||||
@ -65,7 +65,7 @@ class FusedRnnCellTest(test.TestCase):
|
||||
with variable_scope.variable_scope(
|
||||
"fused_static", initializer=initializer):
|
||||
fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
|
||||
core_rnn_cell_impl.BasicRNNCell(10))
|
||||
rnn_cell.BasicRNNCell(10))
|
||||
outputs, state = fused_cell(inputs, dtype=dtypes.float64)
|
||||
fused_static_vars = [
|
||||
v for v in variables.trainable_variables()
|
||||
@ -86,7 +86,7 @@ class FusedRnnCellTest(test.TestCase):
|
||||
with variable_scope.variable_scope(
|
||||
"fused_dynamic", initializer=initializer):
|
||||
fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
|
||||
core_rnn_cell_impl.BasicRNNCell(10), use_dynamic_rnn=True)
|
||||
rnn_cell.BasicRNNCell(10), use_dynamic_rnn=True)
|
||||
outputs, state = fused_cell(inputs, dtype=dtypes.float64)
|
||||
fused_dynamic_vars = [
|
||||
v for v in variables.trainable_variables()
|
||||
@ -109,8 +109,8 @@ class FusedRnnCellTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=19890213)
|
||||
fw_cell = core_rnn_cell_impl.BasicRNNCell(10)
|
||||
bw_cell = core_rnn_cell_impl.BasicRNNCell(10)
|
||||
fw_cell = rnn_cell.BasicRNNCell(10)
|
||||
bw_cell = rnn_cell.BasicRNNCell(10)
|
||||
batch_size = 5
|
||||
input_size = 20
|
||||
timelen = 15
|
||||
@ -120,7 +120,7 @@ class FusedRnnCellTest(test.TestCase):
|
||||
# test bi-directional rnn
|
||||
with variable_scope.variable_scope("basic", initializer=initializer):
|
||||
unpacked_inputs = array_ops.unstack(inputs)
|
||||
outputs, fw_state, bw_state = core_rnn.static_bidirectional_rnn(
|
||||
outputs, fw_state, bw_state = rnn.static_bidirectional_rnn(
|
||||
fw_cell, bw_cell, unpacked_inputs, dtype=dtypes.float64)
|
||||
packed_outputs = array_ops.stack(outputs)
|
||||
basic_vars = [
|
||||
@ -136,10 +136,9 @@ class FusedRnnCellTest(test.TestCase):
|
||||
|
||||
with variable_scope.variable_scope("fused", initializer=initializer):
|
||||
fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
|
||||
core_rnn_cell_impl.BasicRNNCell(10))
|
||||
rnn_cell.BasicRNNCell(10))
|
||||
fused_bw_cell = fused_rnn_cell.TimeReversedFusedRNN(
|
||||
fused_rnn_cell.FusedRNNCellAdaptor(
|
||||
core_rnn_cell_impl.BasicRNNCell(10)))
|
||||
fused_rnn_cell.FusedRNNCellAdaptor(rnn_cell.BasicRNNCell(10)))
|
||||
fw_outputs, fw_state = fused_cell(
|
||||
inputs, dtype=dtypes.float64, scope="fw")
|
||||
bw_outputs, bw_state = fused_bw_cell(
|
||||
|
@ -22,7 +22,6 @@ import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
from tensorflow.contrib.rnn.python.ops import gru_ops
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -33,6 +32,7 @@ from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -78,7 +78,7 @@ class GRUBlockCellTest(test.TestCase):
|
||||
|
||||
# Output from the basic GRU cell implementation.
|
||||
with vs.variable_scope("basic", initializer=initializer):
|
||||
output = core_rnn_cell_impl.GRUCell(cell_size)(x, h)
|
||||
output = rnn_cell.GRUCell(cell_size)(x, h)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
basic_res = sess.run([output], {x: x_value, h: h_value})
|
||||
|
||||
@ -128,7 +128,7 @@ class GRUBlockCellTest(test.TestCase):
|
||||
|
||||
# Output from the basic GRU cell implementation.
|
||||
with vs.variable_scope("basic", initializer=initializer):
|
||||
cell = core_rnn_cell_impl.GRUCell(cell_size)
|
||||
cell = rnn_cell.GRUCell(cell_size)
|
||||
outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
|
||||
cell,
|
||||
inputs=concat_x,
|
||||
@ -192,7 +192,7 @@ class GRUBlockCellTest(test.TestCase):
|
||||
|
||||
# Gradients from the basic GRU cell implementation.
|
||||
with vs.variable_scope("basic", initializer=initializer):
|
||||
output = core_rnn_cell_impl.GRUCell(cell_size)(x, h)
|
||||
output = rnn_cell.GRUCell(cell_size)(x, h)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
|
||||
all_variables = variables.global_variables()[4:8]
|
||||
@ -258,7 +258,7 @@ class GRUBlockCellTest(test.TestCase):
|
||||
|
||||
# Gradients from the basic GRU cell implementation.
|
||||
with vs.variable_scope("basic", initializer=initializer):
|
||||
cell = core_rnn_cell_impl.GRUCell(cell_size)
|
||||
cell = rnn_cell.GRUCell(cell_size)
|
||||
|
||||
outputs_dynamic, _ = rnn.dynamic_rnn(
|
||||
cell,
|
||||
@ -377,7 +377,7 @@ def training_gru_block_vs_gru_cell(batch_size,
|
||||
|
||||
# Output from the basic GRU cell implementation.
|
||||
with vs.variable_scope("basic", initializer=initializer):
|
||||
cell = core_rnn_cell_impl.GRUCell(cell_size)
|
||||
cell = rnn_cell.GRUCell(cell_size)
|
||||
|
||||
outputs_dynamic, _ = rnn.dynamic_rnn(
|
||||
cell,
|
||||
@ -448,7 +448,7 @@ def inference_gru_block_vs_gru_cell(batch_size,
|
||||
|
||||
# Output from the basic GRU cell implementation.
|
||||
with vs.variable_scope("basic", initializer=initializer):
|
||||
cell = core_rnn_cell_impl.GRUCell(cell_size)
|
||||
cell = rnn_cell.GRUCell(cell_size)
|
||||
outputs_dynamic, _ = rnn.dynamic_rnn(
|
||||
cell,
|
||||
inputs=concat_x,
|
||||
@ -497,8 +497,8 @@ def single_bprop_step_gru_block_vs_gru_cell(batch_size,
|
||||
|
||||
# Output from the basic GRU cell implementation.
|
||||
with vs.variable_scope("basic", initializer=initializer):
|
||||
output = core_rnn_cell_impl.GRUCell(cell_size)(array_ops.identity(x),
|
||||
array_ops.identity(h))
|
||||
output = rnn_cell.GRUCell(cell_size)(array_ops.identity(x),
|
||||
array_ops.identity(h))
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
grad_output_wrt_input = gradients_impl.gradients([output], h)
|
||||
basic_time_bprop = time_taken_by_op(grad_output_wrt_input, sess, iters)
|
||||
|
@ -20,8 +20,6 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
from tensorflow.contrib.rnn.python.ops import lstm_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -30,6 +28,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -66,10 +65,9 @@ class LSTMBlockCellTest(test.TestCase):
|
||||
m1 = array_ops.zeros([1, 2])
|
||||
m2 = array_ops.zeros([1, 2])
|
||||
m3 = array_ops.zeros([1, 2])
|
||||
g, ((out_m0, out_m1),
|
||||
(out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
|
||||
[lstm_ops.LSTMBlockCell(2) for _ in range(2)],
|
||||
state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
|
||||
g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
|
||||
[lstm_ops.LSTMBlockCell(2)
|
||||
for _ in range(2)], state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
@ -88,11 +86,11 @@ class LSTMBlockCellTest(test.TestCase):
|
||||
|
||||
def testCompatibleNames(self):
|
||||
with self.test_session(use_gpu=self._use_gpu, graph=ops.Graph()):
|
||||
cell = core_rnn_cell_impl.LSTMCell(10)
|
||||
pcell = core_rnn_cell_impl.LSTMCell(10, use_peepholes=True)
|
||||
cell = rnn_cell.LSTMCell(10)
|
||||
pcell = rnn_cell.LSTMCell(10, use_peepholes=True)
|
||||
inputs = [array_ops.zeros([4, 5])] * 6
|
||||
core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
|
||||
core_rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
|
||||
rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
|
||||
rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
|
||||
basic_names = {
|
||||
v.name: v.get_shape()
|
||||
for v in variables.trainable_variables()
|
||||
@ -102,8 +100,8 @@ class LSTMBlockCellTest(test.TestCase):
|
||||
cell = lstm_ops.LSTMBlockCell(10)
|
||||
pcell = lstm_ops.LSTMBlockCell(10, use_peephole=True)
|
||||
inputs = [array_ops.zeros([4, 5])] * 6
|
||||
core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
|
||||
core_rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
|
||||
rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
|
||||
rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
|
||||
block_names = {
|
||||
v.name: v.get_shape()
|
||||
for v in variables.trainable_variables()
|
||||
@ -140,11 +138,9 @@ class LSTMBlockCellTest(test.TestCase):
|
||||
m1 = array_ops.zeros([1, 2])
|
||||
m2 = array_ops.zeros([1, 2])
|
||||
m3 = array_ops.zeros([1, 2])
|
||||
g, ((out_m0, out_m1),
|
||||
(out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
|
||||
[core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
|
||||
for _ in range(2)],
|
||||
state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
|
||||
g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
|
||||
[rnn_cell.BasicLSTMCell(2, state_is_tuple=True) for _ in range(2)],
|
||||
state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
basic_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
|
||||
x.name: x_values,
|
||||
@ -159,10 +155,9 @@ class LSTMBlockCellTest(test.TestCase):
|
||||
m1 = array_ops.zeros([1, 2])
|
||||
m2 = array_ops.zeros([1, 2])
|
||||
m3 = array_ops.zeros([1, 2])
|
||||
g, ((out_m0, out_m1),
|
||||
(out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
|
||||
[lstm_ops.LSTMBlockCell(2) for _ in range(2)],
|
||||
state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
|
||||
g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
|
||||
[lstm_ops.LSTMBlockCell(2)
|
||||
for _ in range(2)], state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
block_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
|
||||
x.name: x_values,
|
||||
@ -193,12 +188,12 @@ class LSTMBlockCellTest(test.TestCase):
|
||||
m1 = array_ops.zeros([1, 2])
|
||||
m2 = array_ops.zeros([1, 2])
|
||||
m3 = array_ops.zeros([1, 2])
|
||||
g, ((out_m0, out_m1),
|
||||
(out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
|
||||
[core_rnn_cell_impl.LSTMCell(
|
||||
2, use_peepholes=True, state_is_tuple=True)
|
||||
for _ in range(2)],
|
||||
state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
|
||||
g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
|
||||
[
|
||||
rnn_cell.LSTMCell(2, use_peepholes=True, state_is_tuple=True)
|
||||
for _ in range(2)
|
||||
],
|
||||
state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
basic_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
|
||||
x.name: x_values,
|
||||
@ -213,11 +208,9 @@ class LSTMBlockCellTest(test.TestCase):
|
||||
m1 = array_ops.zeros([1, 2])
|
||||
m2 = array_ops.zeros([1, 2])
|
||||
m3 = array_ops.zeros([1, 2])
|
||||
g, ((out_m0, out_m1),
|
||||
(out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
|
||||
[lstm_ops.LSTMBlockCell(2, use_peephole=True)
|
||||
for _ in range(2)],
|
||||
state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
|
||||
g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
|
||||
[lstm_ops.LSTMBlockCell(2, use_peephole=True) for _ in range(2)],
|
||||
state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
block_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
|
||||
x.name: x_values,
|
||||
@ -247,8 +240,8 @@ class LSTMBlockCellTest(test.TestCase):
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=19890212)
|
||||
with variable_scope.variable_scope("basic", initializer=initializer):
|
||||
cell = core_rnn_cell_impl.BasicLSTMCell(cell_size, state_is_tuple=True)
|
||||
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
cell = rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
|
||||
outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
basic_outputs, basic_state = sess.run([outputs, state[0]])
|
||||
@ -321,9 +314,9 @@ class LSTMBlockCellTest(test.TestCase):
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=19890212)
|
||||
with variable_scope.variable_scope("basic", initializer=initializer):
|
||||
cell = core_rnn_cell_impl.LSTMCell(
|
||||
cell = rnn_cell.LSTMCell(
|
||||
cell_size, use_peepholes=True, state_is_tuple=True)
|
||||
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
basic_outputs, basic_state = sess.run([outputs, state[0]])
|
||||
@ -410,8 +403,8 @@ class LSTMBlockCellTest(test.TestCase):
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=19890213)
|
||||
with variable_scope.variable_scope("basic", initializer=initializer):
|
||||
cell = core_rnn_cell_impl.BasicLSTMCell(cell_size, state_is_tuple=True)
|
||||
outputs, state = core_rnn.static_rnn(
|
||||
cell = rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
|
||||
outputs, state = rnn.static_rnn(
|
||||
cell, inputs, dtype=dtypes.float32, sequence_length=seq_lengths)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
basic_outputs, basic_state = sess.run([outputs, state[0]])
|
||||
|
@ -22,8 +22,7 @@ import itertools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
from tensorflow.contrib.rnn.python.ops import rnn_cell
|
||||
from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -37,6 +36,7 @@ from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -65,7 +65,7 @@ class RNNCellTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([batch_size, input_size])
|
||||
m = array_ops.zeros([batch_size, state_size])
|
||||
output, state = rnn_cell.CoupledInputForgetGateLSTMCell(
|
||||
output, state = contrib_rnn_cell.CoupledInputForgetGateLSTMCell(
|
||||
num_units=num_units, forget_bias=1.0, state_is_tuple=False)(x, m)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run([output, state], {
|
||||
@ -94,7 +94,7 @@ class RNNCellTest(test.TestCase):
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([batch_size, input_size])
|
||||
m = array_ops.zeros([batch_size, state_size * num_shifts])
|
||||
output, state = rnn_cell.TimeFreqLSTMCell(
|
||||
output, state = contrib_rnn_cell.TimeFreqLSTMCell(
|
||||
num_units=num_units,
|
||||
feature_size=feature_size,
|
||||
frequency_skip=frequency_skip,
|
||||
@ -130,7 +130,7 @@ class RNNCellTest(test.TestCase):
|
||||
num_shifts = int((input_size - feature_size) / frequency_skip + 1)
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
cell = rnn_cell.GridLSTMCell(
|
||||
cell = contrib_rnn_cell.GridLSTMCell(
|
||||
num_units=num_units,
|
||||
feature_size=feature_size,
|
||||
frequency_skip=frequency_skip,
|
||||
@ -181,7 +181,7 @@ class RNNCellTest(test.TestCase):
|
||||
end_freqindex_list = [2, 4]
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
cell = rnn_cell.GridLSTMCell(
|
||||
cell = contrib_rnn_cell.GridLSTMCell(
|
||||
num_units=num_units,
|
||||
feature_size=feature_size,
|
||||
frequency_skip=frequency_skip,
|
||||
@ -249,7 +249,7 @@ class RNNCellTest(test.TestCase):
|
||||
with variable_scope.variable_scope(
|
||||
"state_is_tuple" + str(state_is_tuple),
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
cell = rnn_cell.GridLSTMCell(
|
||||
cell = contrib_rnn_cell.GridLSTMCell(
|
||||
num_units=num_units,
|
||||
feature_size=feature_size,
|
||||
frequency_skip=frequency_skip,
|
||||
@ -330,7 +330,7 @@ class RNNCellTest(test.TestCase):
|
||||
dtype=np.float32)
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
cell = rnn_cell.BidirectionalGridLSTMCell(
|
||||
cell = contrib_rnn_cell.BidirectionalGridLSTMCell(
|
||||
num_units=num_units,
|
||||
feature_size=feature_size,
|
||||
share_time_frequency_weights=True,
|
||||
@ -403,7 +403,7 @@ class RNNCellTest(test.TestCase):
|
||||
dtype=np.float32)
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
cell = rnn_cell.BidirectionalGridLSTMCell(
|
||||
cell = contrib_rnn_cell.BidirectionalGridLSTMCell(
|
||||
num_units=num_units,
|
||||
feature_size=feature_size,
|
||||
share_time_frequency_weights=True,
|
||||
@ -442,28 +442,28 @@ class RNNCellTest(test.TestCase):
|
||||
def testAttentionCellWrapperFailures(self):
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
"The parameter cell is not RNNCell."):
|
||||
rnn_cell.AttentionCellWrapper(None, 0)
|
||||
contrib_rnn_cell.AttentionCellWrapper(None, 0)
|
||||
|
||||
num_units = 8
|
||||
for state_is_tuple in [False, True]:
|
||||
with ops.Graph().as_default():
|
||||
lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
|
||||
lstm_cell = rnn_cell.BasicLSTMCell(
|
||||
num_units, state_is_tuple=state_is_tuple)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "attn_length should be greater than zero, got 0"):
|
||||
rnn_cell.AttentionCellWrapper(
|
||||
contrib_rnn_cell.AttentionCellWrapper(
|
||||
lstm_cell, 0, state_is_tuple=state_is_tuple)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "attn_length should be greater than zero, got -1"):
|
||||
rnn_cell.AttentionCellWrapper(
|
||||
contrib_rnn_cell.AttentionCellWrapper(
|
||||
lstm_cell, -1, state_is_tuple=state_is_tuple)
|
||||
with ops.Graph().as_default():
|
||||
lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
|
||||
num_units, state_is_tuple=True)
|
||||
lstm_cell = rnn_cell.BasicLSTMCell(num_units, state_is_tuple=True)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Cell returns tuple of states, but the flag "
|
||||
"state_is_tuple is not set. State size is: *"):
|
||||
rnn_cell.AttentionCellWrapper(lstm_cell, 4, state_is_tuple=False)
|
||||
contrib_rnn_cell.AttentionCellWrapper(
|
||||
lstm_cell, 4, state_is_tuple=False)
|
||||
|
||||
def testAttentionCellWrapperZeros(self):
|
||||
num_units = 8
|
||||
@ -475,9 +475,9 @@ class RNNCellTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope("state_is_tuple_" + str(
|
||||
state_is_tuple)):
|
||||
lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
|
||||
lstm_cell = rnn_cell.BasicLSTMCell(
|
||||
num_units, state_is_tuple=state_is_tuple)
|
||||
cell = rnn_cell.AttentionCellWrapper(
|
||||
cell = contrib_rnn_cell.AttentionCellWrapper(
|
||||
lstm_cell, attn_length, state_is_tuple=state_is_tuple)
|
||||
if state_is_tuple:
|
||||
zeros = array_ops.zeros([batch_size, num_units], dtype=np.float32)
|
||||
@ -526,9 +526,9 @@ class RNNCellTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope("state_is_tuple_" + str(
|
||||
state_is_tuple)):
|
||||
lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
|
||||
lstm_cell = rnn_cell.BasicLSTMCell(
|
||||
num_units, state_is_tuple=state_is_tuple)
|
||||
cell = rnn_cell.AttentionCellWrapper(
|
||||
cell = contrib_rnn_cell.AttentionCellWrapper(
|
||||
lstm_cell, attn_length, state_is_tuple=state_is_tuple)
|
||||
if state_is_tuple:
|
||||
zeros = constant_op.constant(
|
||||
@ -603,9 +603,9 @@ class RNNCellTest(test.TestCase):
|
||||
with variable_scope.variable_scope(
|
||||
"state_is_tuple", reuse=state_is_tuple,
|
||||
initializer=init_ops.glorot_uniform_initializer()):
|
||||
lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
|
||||
lstm_cell = rnn_cell.BasicLSTMCell(
|
||||
num_units, state_is_tuple=state_is_tuple)
|
||||
cell = rnn_cell.AttentionCellWrapper(
|
||||
cell = contrib_rnn_cell.AttentionCellWrapper(
|
||||
lstm_cell, attn_length, state_is_tuple=state_is_tuple)
|
||||
# This is legacy behavior to preserve the test. Weight
|
||||
# sharing no longer works by creating a new RNNCell in the
|
||||
@ -665,8 +665,7 @@ class RNNCellTest(test.TestCase):
|
||||
with variable_scope.variable_scope(
|
||||
"nas_test",
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
cell = rnn_cell.NASCell(
|
||||
num_units=num_units)
|
||||
cell = contrib_rnn_cell.NASCell(num_units=num_units)
|
||||
inputs = constant_op.constant(
|
||||
np.array([[1., 1., 1., 1.],
|
||||
[2., 2., 2., 2.],
|
||||
@ -677,8 +676,7 @@ class RNNCellTest(test.TestCase):
|
||||
0.1 * np.ones(
|
||||
(batch_size, num_units), dtype=np.float32),
|
||||
dtype=dtypes.float32)
|
||||
init_state = core_rnn_cell_impl.LSTMStateTuple(state_value,
|
||||
state_value)
|
||||
init_state = rnn_cell.LSTMStateTuple(state_value, state_value)
|
||||
output, state = cell(inputs, init_state)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run([output, state])
|
||||
@ -719,9 +717,7 @@ class RNNCellTest(test.TestCase):
|
||||
with variable_scope.variable_scope(
|
||||
"nas_proj_test",
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
cell = rnn_cell.NASCell(
|
||||
num_units=num_units,
|
||||
num_proj=num_proj)
|
||||
cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj)
|
||||
inputs = constant_op.constant(
|
||||
np.array([[1., 1., 1., 1.],
|
||||
[2., 2., 2., 2.],
|
||||
@ -736,8 +732,7 @@ class RNNCellTest(test.TestCase):
|
||||
0.1 * np.ones(
|
||||
(batch_size, num_proj), dtype=np.float32),
|
||||
dtype=dtypes.float32)
|
||||
init_state = core_rnn_cell_impl.LSTMStateTuple(state_value_c,
|
||||
state_value_h)
|
||||
init_state = rnn_cell.LSTMStateTuple(state_value_c, state_value_h)
|
||||
output, state = cell(inputs, init_state)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run([output, state])
|
||||
@ -767,7 +762,7 @@ class RNNCellTest(test.TestCase):
|
||||
with variable_scope.variable_scope(
|
||||
"ugrnn_cell_test",
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
cell = rnn_cell.UGRNNCell(num_units=num_units)
|
||||
cell = contrib_rnn_cell.UGRNNCell(num_units=num_units)
|
||||
inputs = constant_op.constant(
|
||||
np.array([[1., 1., 1., 1.],
|
||||
[2., 2., 2., 2.],
|
||||
@ -803,8 +798,8 @@ class RNNCellTest(test.TestCase):
|
||||
with variable_scope.variable_scope(
|
||||
"intersection_rnn_cell_test",
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
cell = rnn_cell.IntersectionRNNCell(num_units=num_units,
|
||||
num_in_proj=num_units)
|
||||
cell = contrib_rnn_cell.IntersectionRNNCell(
|
||||
num_units=num_units, num_in_proj=num_units)
|
||||
inputs = constant_op.constant(
|
||||
np.array([[1., 1., 1., 1.],
|
||||
[2., 2., 2., 2.],
|
||||
@ -826,7 +821,7 @@ class RNNCellTest(test.TestCase):
|
||||
def testIntersectionRNNCellFailure(self):
|
||||
num_units = 2
|
||||
batch_size = 3
|
||||
cell = rnn_cell.IntersectionRNNCell(num_units=num_units)
|
||||
cell = contrib_rnn_cell.IntersectionRNNCell(num_units=num_units)
|
||||
inputs = constant_op.constant(
|
||||
np.array([[1., 1., 1., 1.],
|
||||
[2., 2., 2., 2.],
|
||||
@ -862,9 +857,9 @@ class RNNCellTest(test.TestCase):
|
||||
x = array_ops.zeros([batch_size, input_size])
|
||||
c0 = array_ops.zeros([batch_size, 2])
|
||||
h0 = array_ops.zeros([batch_size, 2])
|
||||
state0 = core_rnn_cell_impl.LSTMStateTuple(c0, h0)
|
||||
output, state = rnn_cell.PhasedLSTMCell(num_units=num_units)((t, x),
|
||||
state0)
|
||||
state0 = rnn_cell.LSTMStateTuple(c0, h0)
|
||||
output, state = contrib_rnn_cell.PhasedLSTMCell(num_units=num_units)(
|
||||
(t, x), state0)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run([output, state], {
|
||||
t.name:
|
||||
@ -886,12 +881,12 @@ class RNNCellTest(test.TestCase):
|
||||
"base_cell", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 3])
|
||||
m = array_ops.zeros([1, 3])
|
||||
base_cell = core_rnn_cell_impl.GRUCell(3)
|
||||
base_cell = rnn_cell.GRUCell(3)
|
||||
g, m_new = base_cell(x, m)
|
||||
with variable_scope.variable_scope(
|
||||
"hw_cell", initializer=init_ops.constant_initializer(0.5)):
|
||||
hw_cell = rnn_cell.HighwayWrapper(
|
||||
core_rnn_cell_impl.GRUCell(3), carry_bias_init=-100.0)
|
||||
hw_cell = contrib_rnn_cell.HighwayWrapper(
|
||||
rnn_cell.GRUCell(3), carry_bias_init=-100.0)
|
||||
g_res, m_new_res = hw_cell(x, m)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run([g, g_res, m_new, m_new_res], {
|
||||
@ -915,9 +910,9 @@ class RNNCellTest(test.TestCase):
|
||||
"root1", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.ones([batch_size, num_units])
|
||||
# When number_of_groups = 1, G-LSTM is equivalent to regular LSTM
|
||||
gcell = rnn_cell.GLSTMCell(num_units=num_units,
|
||||
number_of_groups=number_of_groups)
|
||||
cell = core_rnn_cell_impl.LSTMCell(num_units=num_units)
|
||||
gcell = contrib_rnn_cell.GLSTMCell(
|
||||
num_units=num_units, number_of_groups=number_of_groups)
|
||||
cell = rnn_cell.LSTMCell(num_units=num_units)
|
||||
self.assertTrue(isinstance(gcell.state_size, tuple))
|
||||
zero_state = gcell.zero_state(batch_size=batch_size,
|
||||
dtype=dtypes.float32)
|
||||
@ -941,8 +936,8 @@ class RNNCellTest(test.TestCase):
|
||||
"root2", initializer=init_ops.constant_initializer(0.5)):
|
||||
# input for G-LSTM with 2 groups
|
||||
glstm_input = array_ops.ones([batch_size, num_units])
|
||||
gcell = rnn_cell.GLSTMCell(num_units=num_units,
|
||||
number_of_groups=number_of_groups)
|
||||
gcell = contrib_rnn_cell.GLSTMCell(
|
||||
num_units=num_units, number_of_groups=number_of_groups)
|
||||
gcell_zero_state = gcell.zero_state(batch_size=batch_size,
|
||||
dtype=dtypes.float32)
|
||||
gh, gs = gcell(glstm_input, gcell_zero_state)
|
||||
@ -950,8 +945,7 @@ class RNNCellTest(test.TestCase):
|
||||
# input for LSTM cell simulating single G-LSTM group
|
||||
lstm_input = array_ops.ones([batch_size, num_units / number_of_groups])
|
||||
# note division by number_of_groups. This cell one simulates G-LSTM group
|
||||
cell = core_rnn_cell_impl.LSTMCell(num_units=
|
||||
int(num_units / number_of_groups))
|
||||
cell = rnn_cell.LSTMCell(num_units=int(num_units / number_of_groups))
|
||||
cell_zero_state = cell.zero_state(batch_size=batch_size,
|
||||
dtype=dtypes.float32)
|
||||
h, g = cell(lstm_input, cell_zero_state)
|
||||
@ -974,13 +968,13 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
|
||||
x = array_ops.zeros([1, 2])
|
||||
c0 = array_ops.zeros([1, 2])
|
||||
h0 = array_ops.zeros([1, 2])
|
||||
state0 = core_rnn_cell_impl.LSTMStateTuple(c0, h0)
|
||||
state0 = rnn_cell.LSTMStateTuple(c0, h0)
|
||||
c1 = array_ops.zeros([1, 2])
|
||||
h1 = array_ops.zeros([1, 2])
|
||||
state1 = core_rnn_cell_impl.LSTMStateTuple(c1, h1)
|
||||
state1 = rnn_cell.LSTMStateTuple(c1, h1)
|
||||
state = (state0, state1)
|
||||
single_cell = lambda: rnn_cell.LayerNormBasicLSTMCell(2)
|
||||
cell = core_rnn_cell_impl.MultiRNNCell([single_cell() for _ in range(2)])
|
||||
single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(2)
|
||||
cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
|
||||
g, out_m = cell(x, state)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run([g, out_m], {
|
||||
@ -1015,8 +1009,8 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
|
||||
[1, 3]) # Test BasicLSTMCell with input_size != num_units.
|
||||
c = array_ops.zeros([1, 2])
|
||||
h = array_ops.zeros([1, 2])
|
||||
state = core_rnn_cell_impl.LSTMStateTuple(c, h)
|
||||
cell = rnn_cell.LayerNormBasicLSTMCell(2)
|
||||
state = rnn_cell.LSTMStateTuple(c, h)
|
||||
cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2)
|
||||
g, out_m = cell(x, state)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run([g, out_m], {
|
||||
@ -1039,12 +1033,12 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
|
||||
x = array_ops.zeros([1, 2])
|
||||
c0 = array_ops.zeros([1, 2])
|
||||
h0 = array_ops.zeros([1, 2])
|
||||
state0 = core_rnn_cell_impl.LSTMStateTuple(c0, h0)
|
||||
state0 = rnn_cell.LSTMStateTuple(c0, h0)
|
||||
c1 = array_ops.zeros([1, 2])
|
||||
h1 = array_ops.zeros([1, 2])
|
||||
state1 = core_rnn_cell_impl.LSTMStateTuple(c1, h1)
|
||||
cell = core_rnn_cell_impl.MultiRNNCell(
|
||||
[rnn_cell.LayerNormBasicLSTMCell(2) for _ in range(2)])
|
||||
state1 = rnn_cell.LSTMStateTuple(c1, h1)
|
||||
cell = rnn_cell.MultiRNNCell(
|
||||
[contrib_rnn_cell.LayerNormBasicLSTMCell(2) for _ in range(2)])
|
||||
h, (s0, s1) = cell(x, (state0, state1))
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run([h, s0, s1], {
|
||||
@ -1094,8 +1088,8 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
|
||||
x = array_ops.zeros([1, 5])
|
||||
c = array_ops.zeros([1, 5])
|
||||
h = array_ops.zeros([1, 5])
|
||||
state = core_rnn_cell_impl.LSTMStateTuple(c, h)
|
||||
cell = rnn_cell.LayerNormBasicLSTMCell(
|
||||
state = rnn_cell.LSTMStateTuple(c, h)
|
||||
cell = contrib_rnn_cell.LayerNormBasicLSTMCell(
|
||||
num_units, layer_norm=False, dropout_keep_prob=keep_prob)
|
||||
|
||||
g, s = cell(x, state)
|
||||
@ -1138,10 +1132,9 @@ def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth,
|
||||
inputs = variable_scope.get_variable(
|
||||
"inputs", initializer=random_ops.random_uniform(
|
||||
(max_time, batch_size, input_depth), seed=1))
|
||||
maybe_xla = lambda c: rnn_cell.CompiledWrapper(c) if compiled else c
|
||||
cell = core_rnn_cell_impl.MultiRNNCell(
|
||||
[maybe_xla(core_rnn_cell_impl.LSTMCell(num_units))
|
||||
for _ in range(num_layers)])
|
||||
maybe_xla = lambda c: contrib_rnn_cell.CompiledWrapper(c) if compiled else c
|
||||
cell = rnn_cell.MultiRNNCell(
|
||||
[maybe_xla(rnn_cell.LSTMCell(num_units)) for _ in range(num_layers)])
|
||||
initial_state = cell.zero_state(
|
||||
batch_size=batch_size, dtype=dtypes.float32)
|
||||
outputs, final_state = rnn.dynamic_rnn(
|
||||
@ -1219,13 +1212,13 @@ class CompiledWrapperTest(test.TestCase):
|
||||
|
||||
# Test incorrectness of state
|
||||
with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
|
||||
core_rnn_cell_impl.MultiRNNCell(
|
||||
[core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
|
||||
state_is_tuple=True)(x, m_bad)
|
||||
rnn_cell.MultiRNNCell(
|
||||
[rnn_cell.GRUCell(2)
|
||||
for _ in range(2)], state_is_tuple=True)(x, m_bad)
|
||||
|
||||
_, ml = core_rnn_cell_impl.MultiRNNCell(
|
||||
[core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
|
||||
state_is_tuple=True)(x, m_good)
|
||||
_, ml = rnn_cell.MultiRNNCell(
|
||||
[rnn_cell.GRUCell(2)
|
||||
for _ in range(2)], state_is_tuple=True)(x, m_good)
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run(ml, {
|
||||
|
@ -22,12 +22,12 @@ import itertools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
from tensorflow.contrib.rnn.python.ops import rnn
|
||||
from tensorflow.contrib.rnn.python.ops import rnn as contrib_rnn
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -58,14 +58,14 @@ class StackBidirectionalRNNTest(test.TestCase):
|
||||
dtypes.int64) if use_sequence_length else None
|
||||
|
||||
self.cells_fw = [
|
||||
core_rnn_cell_impl.LSTMCell(
|
||||
rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
input_size,
|
||||
initializer=initializer,
|
||||
state_is_tuple=False) for num_units in self.layers
|
||||
]
|
||||
self.cells_bw = [
|
||||
core_rnn_cell_impl.LSTMCell(
|
||||
rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
input_size,
|
||||
initializer=initializer,
|
||||
@ -77,7 +77,7 @@ class StackBidirectionalRNNTest(test.TestCase):
|
||||
dtypes.float32,
|
||||
shape=(batch_size, input_size) if use_shape else (None, input_size))
|
||||
]
|
||||
outputs, state_fw, state_bw = rnn.stack_bidirectional_rnn(
|
||||
outputs, state_fw, state_bw = contrib_rnn.stack_bidirectional_rnn(
|
||||
self.cells_fw,
|
||||
self.cells_bw,
|
||||
inputs,
|
||||
@ -237,14 +237,14 @@ class StackBidirectionalRNNTest(test.TestCase):
|
||||
sequence_length = array_ops.placeholder(dtypes.int64)
|
||||
|
||||
self.cells_fw = [
|
||||
core_rnn_cell_impl.LSTMCell(
|
||||
rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
input_size,
|
||||
initializer=initializer,
|
||||
state_is_tuple=False) for num_units in self.layers
|
||||
]
|
||||
self.cells_bw = [
|
||||
core_rnn_cell_impl.LSTMCell(
|
||||
rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
input_size,
|
||||
initializer=initializer,
|
||||
@ -258,7 +258,7 @@ class StackBidirectionalRNNTest(test.TestCase):
|
||||
]
|
||||
inputs_c = array_ops.stack(inputs)
|
||||
inputs_c = array_ops.transpose(inputs_c, [1, 0, 2])
|
||||
outputs, st_fw, st_bw = rnn.stack_bidirectional_dynamic_rnn(
|
||||
outputs, st_fw, st_bw = contrib_rnn.stack_bidirectional_dynamic_rnn(
|
||||
self.cells_fw,
|
||||
self.cells_bw,
|
||||
inputs_c,
|
||||
|
@ -1,357 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""RNN helpers for TensorFlow models."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell_impl
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_concat = rnn_cell_impl._concat
|
||||
_like_rnncell = rnn_cell_impl._like_rnncell
|
||||
_infer_state_dtype = rnn._infer_state_dtype
|
||||
_reverse_seq = rnn._reverse_seq
|
||||
_rnn_step = rnn._rnn_step
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def static_rnn(cell, inputs, initial_state=None, dtype=None,
|
||||
sequence_length=None, scope=None):
|
||||
"""Creates a recurrent neural network specified by RNNCell `cell`.
|
||||
|
||||
The simplest form of RNN network generated is:
|
||||
|
||||
```python
|
||||
state = cell.zero_state(...)
|
||||
outputs = []
|
||||
for input_ in inputs:
|
||||
output, state = cell(input_, state)
|
||||
outputs.append(output)
|
||||
return (outputs, state)
|
||||
```
|
||||
However, a few other options are available:
|
||||
|
||||
An initial state can be provided.
|
||||
If the sequence_length vector is provided, dynamic calculation is performed.
|
||||
This method of calculation does not compute the RNN steps past the maximum
|
||||
sequence length of the minibatch (thus saving computational time),
|
||||
and properly propagates the state at an example's sequence length
|
||||
to the final state output.
|
||||
|
||||
The dynamic calculation performed is, at time `t` for batch row `b`,
|
||||
|
||||
```python
|
||||
(output, state)(b, t) =
|
||||
(t >= sequence_length(b))
|
||||
? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
|
||||
: cell(input(b, t), state(b, t - 1))
|
||||
```
|
||||
|
||||
Args:
|
||||
cell: An instance of RNNCell.
|
||||
inputs: A length T list of inputs, each a `Tensor` of shape
|
||||
`[batch_size, input_size]`, or a nested tuple of such elements.
|
||||
initial_state: (optional) An initial state for the RNN.
|
||||
If `cell.state_size` is an integer, this must be
|
||||
a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
|
||||
If `cell.state_size` is a tuple, this should be a tuple of
|
||||
tensors having shapes `[batch_size, s] for s in cell.state_size`.
|
||||
dtype: (optional) The data type for the initial state and expected output.
|
||||
Required if initial_state is not provided or RNN state has a heterogeneous
|
||||
dtype.
|
||||
sequence_length: Specifies the length of each sequence in inputs.
|
||||
An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`.
|
||||
scope: VariableScope for the created subgraph; defaults to "rnn".
|
||||
|
||||
Returns:
|
||||
A pair (outputs, state) where:
|
||||
|
||||
- outputs is a length T list of outputs (one for each input), or a nested
|
||||
tuple of such elements.
|
||||
- state is the final state
|
||||
|
||||
Raises:
|
||||
TypeError: If `cell` is not an instance of RNNCell.
|
||||
ValueError: If `inputs` is `None` or an empty list, or if the input depth
|
||||
(column size) cannot be inferred from inputs via shape inference.
|
||||
"""
|
||||
|
||||
if not _like_rnncell(cell):
|
||||
raise TypeError("cell must be an instance of RNNCell")
|
||||
if not nest.is_sequence(inputs):
|
||||
raise TypeError("inputs must be a sequence")
|
||||
if not inputs:
|
||||
raise ValueError("inputs must not be empty")
|
||||
|
||||
outputs = []
|
||||
# Create a new scope in which the caching device is either
|
||||
# determined by the parent scope, or is set to place the cached
|
||||
# Variable using the same placement as for the rest of the RNN.
|
||||
with vs.variable_scope(scope or "rnn") as varscope:
|
||||
if varscope.caching_device is None:
|
||||
varscope.set_caching_device(lambda op: op.device)
|
||||
|
||||
# Obtain the first sequence of the input
|
||||
first_input = inputs
|
||||
while nest.is_sequence(first_input):
|
||||
first_input = first_input[0]
|
||||
|
||||
# Temporarily avoid EmbeddingWrapper and seq2seq badness
|
||||
# TODO(lukaszkaiser): remove EmbeddingWrapper
|
||||
if first_input.get_shape().ndims != 1:
|
||||
|
||||
input_shape = first_input.get_shape().with_rank_at_least(2)
|
||||
fixed_batch_size = input_shape[0]
|
||||
|
||||
flat_inputs = nest.flatten(inputs)
|
||||
for flat_input in flat_inputs:
|
||||
input_shape = flat_input.get_shape().with_rank_at_least(2)
|
||||
batch_size, input_size = input_shape[0], input_shape[1:]
|
||||
fixed_batch_size.merge_with(batch_size)
|
||||
for i, size in enumerate(input_size):
|
||||
if size.value is None:
|
||||
raise ValueError(
|
||||
"Input size (dimension %d of inputs) must be accessible via "
|
||||
"shape inference, but saw value None." % i)
|
||||
else:
|
||||
fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]
|
||||
|
||||
if fixed_batch_size.value:
|
||||
batch_size = fixed_batch_size.value
|
||||
else:
|
||||
batch_size = array_ops.shape(first_input)[0]
|
||||
if initial_state is not None:
|
||||
state = initial_state
|
||||
else:
|
||||
if not dtype:
|
||||
raise ValueError("If no initial_state is provided, "
|
||||
"dtype must be specified")
|
||||
state = cell.zero_state(batch_size, dtype)
|
||||
|
||||
if sequence_length is not None: # Prepare variables
|
||||
sequence_length = ops.convert_to_tensor(
|
||||
sequence_length, name="sequence_length")
|
||||
if sequence_length.get_shape().ndims not in (None, 1):
|
||||
raise ValueError(
|
||||
"sequence_length must be a vector of length batch_size")
|
||||
def _create_zero_output(output_size):
|
||||
# convert int to TensorShape if necessary
|
||||
size = _concat(batch_size, output_size)
|
||||
output = array_ops.zeros(
|
||||
array_ops.stack(size), _infer_state_dtype(dtype, state))
|
||||
shape = _concat(fixed_batch_size.value, output_size, static=True)
|
||||
output.set_shape(tensor_shape.TensorShape(shape))
|
||||
return output
|
||||
|
||||
output_size = cell.output_size
|
||||
flat_output_size = nest.flatten(output_size)
|
||||
flat_zero_output = tuple(
|
||||
_create_zero_output(size) for size in flat_output_size)
|
||||
zero_output = nest.pack_sequence_as(structure=output_size,
|
||||
flat_sequence=flat_zero_output)
|
||||
|
||||
sequence_length = math_ops.to_int32(sequence_length)
|
||||
min_sequence_length = math_ops.reduce_min(sequence_length)
|
||||
max_sequence_length = math_ops.reduce_max(sequence_length)
|
||||
|
||||
for time, input_ in enumerate(inputs):
|
||||
if time > 0: varscope.reuse_variables()
|
||||
# pylint: disable=cell-var-from-loop
|
||||
call_cell = lambda: cell(input_, state)
|
||||
# pylint: enable=cell-var-from-loop
|
||||
if sequence_length is not None:
|
||||
(output, state) = _rnn_step(
|
||||
time=time,
|
||||
sequence_length=sequence_length,
|
||||
min_sequence_length=min_sequence_length,
|
||||
max_sequence_length=max_sequence_length,
|
||||
zero_output=zero_output,
|
||||
state=state,
|
||||
call_cell=call_cell,
|
||||
state_size=cell.state_size)
|
||||
else:
|
||||
(output, state) = call_cell()
|
||||
|
||||
outputs.append(output)
|
||||
|
||||
return (outputs, state)
|
||||
|
||||
|
||||
def static_state_saving_rnn(cell, inputs, state_saver, state_name,
|
||||
sequence_length=None, scope=None):
|
||||
"""RNN that accepts a state saver for time-truncated RNN calculation.
|
||||
|
||||
Args:
|
||||
cell: An instance of `RNNCell`.
|
||||
inputs: A length T list of inputs, each a `Tensor` of shape
|
||||
`[batch_size, input_size]`.
|
||||
state_saver: A state saver object with methods `state` and `save_state`.
|
||||
state_name: Python string or tuple of strings. The name to use with the
|
||||
state_saver. If the cell returns tuples of states (i.e.,
|
||||
`cell.state_size` is a tuple) then `state_name` should be a tuple of
|
||||
strings having the same length as `cell.state_size`. Otherwise it should
|
||||
be a single string.
|
||||
sequence_length: (optional) An int32/int64 vector size [batch_size].
|
||||
See the documentation for rnn() for more details about sequence_length.
|
||||
scope: VariableScope for the created subgraph; defaults to "rnn".
|
||||
|
||||
Returns:
|
||||
A pair (outputs, state) where:
|
||||
outputs is a length T list of outputs (one for each input)
|
||||
states is the final state
|
||||
|
||||
Raises:
|
||||
TypeError: If `cell` is not an instance of RNNCell.
|
||||
ValueError: If `inputs` is `None` or an empty list, or if the arity and
|
||||
type of `state_name` does not match that of `cell.state_size`.
|
||||
"""
|
||||
state_size = cell.state_size
|
||||
state_is_tuple = nest.is_sequence(state_size)
|
||||
state_name_tuple = nest.is_sequence(state_name)
|
||||
|
||||
if state_is_tuple != state_name_tuple:
|
||||
raise ValueError(
|
||||
"state_name should be the same type as cell.state_size. "
|
||||
"state_name: %s, cell.state_size: %s"
|
||||
% (str(state_name), str(state_size)))
|
||||
|
||||
if state_is_tuple:
|
||||
state_name_flat = nest.flatten(state_name)
|
||||
state_size_flat = nest.flatten(state_size)
|
||||
|
||||
if len(state_name_flat) != len(state_size_flat):
|
||||
raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d"
|
||||
% (len(state_name_flat), len(state_size_flat)))
|
||||
|
||||
initial_state = nest.pack_sequence_as(
|
||||
structure=state_size,
|
||||
flat_sequence=[state_saver.state(s) for s in state_name_flat])
|
||||
else:
|
||||
initial_state = state_saver.state(state_name)
|
||||
|
||||
(outputs, state) = static_rnn(cell, inputs, initial_state=initial_state,
|
||||
sequence_length=sequence_length, scope=scope)
|
||||
|
||||
if state_is_tuple:
|
||||
flat_state = nest.flatten(state)
|
||||
state_name = nest.flatten(state_name)
|
||||
save_state = [state_saver.save_state(name, substate)
|
||||
for name, substate in zip(state_name, flat_state)]
|
||||
else:
|
||||
save_state = [state_saver.save_state(state_name, state)]
|
||||
|
||||
with ops.control_dependencies(save_state):
|
||||
last_output = outputs[-1]
|
||||
flat_last_output = nest.flatten(last_output)
|
||||
flat_last_output = [
|
||||
array_ops.identity(output) for output in flat_last_output]
|
||||
outputs[-1] = nest.pack_sequence_as(structure=last_output,
|
||||
flat_sequence=flat_last_output)
|
||||
|
||||
return (outputs, state)
|
||||
|
||||
|
||||
def static_bidirectional_rnn(cell_fw, cell_bw, inputs,
|
||||
initial_state_fw=None, initial_state_bw=None,
|
||||
dtype=None, sequence_length=None, scope=None):
|
||||
"""Creates a bidirectional recurrent neural network.
|
||||
|
||||
Similar to the unidirectional case above (rnn) but takes input and builds
|
||||
independent forward and backward RNNs with the final forward and backward
|
||||
outputs depth-concatenated, such that the output will have the format
|
||||
[time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of
|
||||
forward and backward cell must match. The initial state for both directions
|
||||
is zero by default (but can be set optionally) and no intermediate states are
|
||||
ever returned -- the network is fully unrolled for the given (passed in)
|
||||
length(s) of the sequence(s) or completely unrolled if length(s) is not given.
|
||||
|
||||
Args:
|
||||
cell_fw: An instance of RNNCell, to be used for forward direction.
|
||||
cell_bw: An instance of RNNCell, to be used for backward direction.
|
||||
inputs: A length T list of inputs, each a tensor of shape
|
||||
[batch_size, input_size], or a nested tuple of such elements.
|
||||
initial_state_fw: (optional) An initial state for the forward RNN.
|
||||
This must be a tensor of appropriate type and shape
|
||||
`[batch_size, cell_fw.state_size]`.
|
||||
If `cell_fw.state_size` is a tuple, this should be a tuple of
|
||||
tensors having shapes `[batch_size, s] for s in cell_fw.state_size`.
|
||||
initial_state_bw: (optional) Same as for `initial_state_fw`, but using
|
||||
the corresponding properties of `cell_bw`.
|
||||
dtype: (optional) The data type for the initial state. Required if
|
||||
either of the initial states are not provided.
|
||||
sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
|
||||
containing the actual lengths for each of the sequences.
|
||||
scope: VariableScope for the created subgraph; defaults to
|
||||
"bidirectional_rnn"
|
||||
|
||||
Returns:
|
||||
A tuple (outputs, output_state_fw, output_state_bw) where:
|
||||
outputs is a length `T` list of outputs (one for each input), which
|
||||
are depth-concatenated forward and backward outputs.
|
||||
output_state_fw is the final state of the forward rnn.
|
||||
output_state_bw is the final state of the backward rnn.
|
||||
|
||||
Raises:
|
||||
TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
|
||||
ValueError: If inputs is None or an empty list.
|
||||
"""
|
||||
|
||||
if not _like_rnncell(cell_fw):
|
||||
raise TypeError("cell_fw must be an instance of RNNCell")
|
||||
if not _like_rnncell(cell_bw):
|
||||
raise TypeError("cell_bw must be an instance of RNNCell")
|
||||
if not nest.is_sequence(inputs):
|
||||
raise TypeError("inputs must be a sequence")
|
||||
if not inputs:
|
||||
raise ValueError("inputs must not be empty")
|
||||
|
||||
with vs.variable_scope(scope or "bidirectional_rnn"):
|
||||
# Forward direction
|
||||
with vs.variable_scope("fw") as fw_scope:
|
||||
output_fw, output_state_fw = static_rnn(
|
||||
cell_fw, inputs, initial_state_fw, dtype,
|
||||
sequence_length, scope=fw_scope)
|
||||
|
||||
# Backward direction
|
||||
with vs.variable_scope("bw") as bw_scope:
|
||||
reversed_inputs = _reverse_seq(inputs, sequence_length)
|
||||
tmp, output_state_bw = static_rnn(
|
||||
cell_bw, reversed_inputs, initial_state_bw,
|
||||
dtype, sequence_length, scope=bw_scope)
|
||||
|
||||
output_bw = _reverse_seq(tmp, sequence_length)
|
||||
# Concat each of the forward/backward outputs
|
||||
flat_output_fw = nest.flatten(output_fw)
|
||||
flat_output_bw = nest.flatten(output_bw)
|
||||
|
||||
flat_outputs = tuple(
|
||||
array_ops.concat([fw, bw], 1)
|
||||
for fw, bw in zip(flat_output_fw, flat_output_bw))
|
||||
|
||||
outputs = nest.pack_sequence_as(structure=output_fw,
|
||||
flat_sequence=flat_outputs)
|
||||
|
||||
return (outputs, output_state_fw, output_state_bw)
|
@ -12,45 +12,219 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Module implementing RNN Cells that used to be in core.
|
||||
|
||||
"""Module for constructing RNN Cells.
|
||||
|
||||
## Base interface for all RNN Cells
|
||||
|
||||
@@RNNCell
|
||||
|
||||
## RNN Cells for use with TensorFlow's core RNN methods
|
||||
|
||||
@@BasicRNNCell
|
||||
@@BasicLSTMCell
|
||||
@@GRUCell
|
||||
@@LSTMCell
|
||||
|
||||
## Classes storing split `RNNCell` state
|
||||
|
||||
@@LSTMStateTuple
|
||||
|
||||
## RNN Cell wrappers (RNNCells that wrap other RNNCells)
|
||||
|
||||
@@MultiRNNCell
|
||||
@@DropoutWrapper
|
||||
@@EmbeddingWrapper
|
||||
@@InputProjectionWrapper
|
||||
@@OutputProjectionWrapper
|
||||
@@DeviceWrapper
|
||||
@@ResidualWrapper
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
import math
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import rnn_cell_impl
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
RNNCell = rnn_cell_impl.RNNCell # pylint: disable=invalid-name
|
||||
_linear = rnn_cell_impl._linear # pylint: disable=invalid-name, protected-access
|
||||
_like_rnncell = rnn_cell_impl._like_rnncell # pylint: disable=invalid-name, protected-access
|
||||
|
||||
|
||||
_allowed_symbols = []
|
||||
class EmbeddingWrapper(RNNCell):
|
||||
"""Operator adding input embedding to the given cell.
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
||||
Note: in many cases it may be more efficient to not use this wrapper,
|
||||
but instead concatenate the whole sequence of your inputs in time,
|
||||
do the embedding on this batch-concatenated sequence, then split it and
|
||||
feed into your RNN.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
cell,
|
||||
embedding_classes,
|
||||
embedding_size,
|
||||
initializer=None,
|
||||
reuse=None):
|
||||
"""Create a cell with an added input embedding.
|
||||
|
||||
Args:
|
||||
cell: an RNNCell, an embedding will be put before its inputs.
|
||||
embedding_classes: integer, how many symbols will be embedded.
|
||||
embedding_size: integer, the size of the vectors we embed into.
|
||||
initializer: an initializer to use when creating the embedding;
|
||||
if None, the initializer from variable scope or a default one is used.
|
||||
reuse: (optional) Python boolean describing whether to reuse variables
|
||||
in an existing scope. If not `True`, and the existing scope already has
|
||||
the given variables, an error is raised.
|
||||
|
||||
Raises:
|
||||
TypeError: if cell is not an RNNCell.
|
||||
ValueError: if embedding_classes is not positive.
|
||||
"""
|
||||
super(EmbeddingWrapper, self).__init__(_reuse=reuse)
|
||||
if not _like_rnncell(cell):
|
||||
raise TypeError("The parameter cell is not RNNCell.")
|
||||
if embedding_classes <= 0 or embedding_size <= 0:
|
||||
raise ValueError("Both embedding_classes and embedding_size must be > 0: "
|
||||
"%d, %d." % (embedding_classes, embedding_size))
|
||||
self._cell = cell
|
||||
self._embedding_classes = embedding_classes
|
||||
self._embedding_size = embedding_size
|
||||
self._initializer = initializer
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._cell.state_size
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._cell.output_size
|
||||
|
||||
def zero_state(self, batch_size, dtype):
|
||||
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||
return self._cell.zero_state(batch_size, dtype)
|
||||
|
||||
def call(self, inputs, state):
|
||||
"""Run the cell on embedded inputs."""
|
||||
with ops.device("/cpu:0"):
|
||||
if self._initializer:
|
||||
initializer = self._initializer
|
||||
elif vs.get_variable_scope().initializer:
|
||||
initializer = vs.get_variable_scope().initializer
|
||||
else:
|
||||
# Default initializer for embeddings should have variance=1.
|
||||
sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1.
|
||||
initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)
|
||||
|
||||
if isinstance(state, tuple):
|
||||
data_type = state[0].dtype
|
||||
else:
|
||||
data_type = state.dtype
|
||||
|
||||
embedding = vs.get_variable(
|
||||
"embedding", [self._embedding_classes, self._embedding_size],
|
||||
initializer=initializer,
|
||||
dtype=data_type)
|
||||
embedded = embedding_ops.embedding_lookup(embedding,
|
||||
array_ops.reshape(inputs, [-1]))
|
||||
|
||||
return self._cell(embedded, state)
|
||||
|
||||
|
||||
class InputProjectionWrapper(RNNCell):
|
||||
"""Operator adding an input projection to the given cell.
|
||||
|
||||
Note: in many cases it may be more efficient to not use this wrapper,
|
||||
but instead concatenate the whole sequence of your inputs in time,
|
||||
do the projection on this batch-concatenated sequence, then split it.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
cell,
|
||||
num_proj,
|
||||
activation=None,
|
||||
input_size=None,
|
||||
reuse=None):
|
||||
"""Create a cell with input projection.
|
||||
|
||||
Args:
|
||||
cell: an RNNCell, a projection of inputs is added before it.
|
||||
num_proj: Python integer. The dimension to project to.
|
||||
activation: (optional) an optional activation function.
|
||||
input_size: Deprecated and unused.
|
||||
reuse: (optional) Python boolean describing whether to reuse variables
|
||||
in an existing scope. If not `True`, and the existing scope already has
|
||||
the given variables, an error is raised.
|
||||
|
||||
Raises:
|
||||
TypeError: if cell is not an RNNCell.
|
||||
"""
|
||||
super(InputProjectionWrapper, self).__init__(_reuse=reuse)
|
||||
if input_size is not None:
|
||||
logging.warn("%s: The input_size parameter is deprecated.", self)
|
||||
if not _like_rnncell(cell):
|
||||
raise TypeError("The parameter cell is not RNNCell.")
|
||||
self._cell = cell
|
||||
self._num_proj = num_proj
|
||||
self._activation = activation
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._cell.state_size
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._cell.output_size
|
||||
|
||||
def zero_state(self, batch_size, dtype):
|
||||
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||
return self._cell.zero_state(batch_size, dtype)
|
||||
|
||||
def call(self, inputs, state):
|
||||
"""Run the input projection and then the cell."""
|
||||
# Default scope: "InputProjectionWrapper"
|
||||
projected = _linear(inputs, self._num_proj, True)
|
||||
if self._activation:
|
||||
projected = self._activation(projected)
|
||||
return self._cell(projected, state)
|
||||
|
||||
|
||||
class OutputProjectionWrapper(RNNCell):
|
||||
"""Operator adding an output projection to the given cell.
|
||||
|
||||
Note: in many cases it may be more efficient to not use this wrapper,
|
||||
but instead concatenate the whole sequence of your outputs in time,
|
||||
do the projection on this batch-concatenated sequence, then split it
|
||||
if needed or directly feed into a softmax.
|
||||
"""
|
||||
|
||||
def __init__(self, cell, output_size, activation=None, reuse=None):
|
||||
"""Create a cell with output projection.
|
||||
|
||||
Args:
|
||||
cell: an RNNCell, a projection to output_size is added to it.
|
||||
output_size: integer, the size of the output after projection.
|
||||
activation: (optional) an optional activation function.
|
||||
reuse: (optional) Python boolean describing whether to reuse variables
|
||||
in an existing scope. If not `True`, and the existing scope already has
|
||||
the given variables, an error is raised.
|
||||
|
||||
Raises:
|
||||
TypeError: if cell is not an RNNCell.
|
||||
ValueError: if output_size is not positive.
|
||||
"""
|
||||
super(OutputProjectionWrapper, self).__init__(_reuse=reuse)
|
||||
if not _like_rnncell(cell):
|
||||
raise TypeError("The parameter cell is not RNNCell.")
|
||||
if output_size < 1:
|
||||
raise ValueError("Parameter output_size must be > 0: %d." % output_size)
|
||||
self._cell = cell
|
||||
self._output_size = output_size
|
||||
self._activation = activation
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._cell.state_size
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._output_size
|
||||
|
||||
def zero_state(self, batch_size, dtype):
|
||||
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||
return self._cell.zero_state(batch_size, dtype)
|
||||
|
||||
def call(self, inputs, state):
|
||||
"""Run the cell and output projection on inputs, starting from state."""
|
||||
output, res_state = self._cell(inputs, state)
|
||||
projected = _linear(output, self._output_size, True)
|
||||
if self._activation:
|
||||
projected = self._activation(projected)
|
||||
return projected, res_state
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
|
||||
import abc
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn as contrib_rnn
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
|
||||
@ -116,12 +115,13 @@ class FusedRNNCellAdaptor(FusedRNNCell):
|
||||
else: # non-dynamic rnn
|
||||
if not is_list:
|
||||
inputs = array_ops.unstack(inputs)
|
||||
outputs, state = contrib_rnn.static_rnn(self._cell,
|
||||
inputs,
|
||||
initial_state=initial_state,
|
||||
dtype=dtype,
|
||||
sequence_length=sequence_length,
|
||||
scope=scope)
|
||||
outputs, state = rnn.static_rnn(
|
||||
self._cell,
|
||||
inputs,
|
||||
initial_state=initial_state,
|
||||
dtype=dtype,
|
||||
sequence_length=sequence_length,
|
||||
scope=scope)
|
||||
if not is_list:
|
||||
# Convert outputs back to tensor
|
||||
outputs = array_ops.stack(outputs)
|
||||
|
@ -18,13 +18,13 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.rnn.ops import gen_gru_ops
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
|
||||
from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import rnn_cell_impl
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
@ -94,7 +94,7 @@ def _GRUBlockCellGrad(op, *grad):
|
||||
return d_x, d_h_prev, d_w_ru, d_w_c, d_b_ru, d_b_c
|
||||
|
||||
|
||||
class GRUBlockCell(core_rnn_cell.RNNCell):
|
||||
class GRUBlockCell(rnn_cell_impl.RNNCell):
|
||||
r"""Block GRU cell implementation.
|
||||
|
||||
The implementation is based on: http://arxiv.org/abs/1406.1078
|
||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
import abc
|
||||
|
||||
from tensorflow.contrib.rnn.ops import gen_lstm_ops
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
|
||||
from tensorflow.contrib.rnn.python.ops import fused_rnn_cell
|
||||
from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -29,6 +28,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import rnn_cell_impl
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
@ -325,7 +325,7 @@ def _BlockLSTMGrad(op, *grad):
|
||||
wcf_grad, b_grad]
|
||||
|
||||
|
||||
class LSTMBlockCell(core_rnn_cell.RNNCell):
|
||||
class LSTMBlockCell(rnn_cell_impl.RNNCell):
|
||||
"""Basic LSTM recurrent network cell.
|
||||
|
||||
The implementation is based on: http://arxiv.org/abs/1409.2329.
|
||||
@ -333,7 +333,7 @@ class LSTMBlockCell(core_rnn_cell.RNNCell):
|
||||
We add `forget_bias` (default: 1) to the biases of the forget gate in order to
|
||||
reduce the scale of forgetting in the beginning of the training.
|
||||
|
||||
Unlike `core_rnn_cell.LSTMCell`, this is a monolithic op and should be much
|
||||
Unlike `rnn_cell_impl.LSTMCell`, this is a monolithic op and should be much
|
||||
faster. The weight and bias matrices should be compatible as long as the
|
||||
variable scope matches.
|
||||
"""
|
||||
@ -363,7 +363,7 @@ class LSTMBlockCell(core_rnn_cell.RNNCell):
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
|
||||
return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
@ -402,7 +402,7 @@ class LSTMBlockCell(core_rnn_cell.RNNCell):
|
||||
forget_bias=self._forget_bias,
|
||||
use_peephole=self._use_peephole)
|
||||
|
||||
new_state = core_rnn_cell.LSTMStateTuple(cs, h)
|
||||
new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
|
||||
return h, new_state
|
||||
|
||||
|
||||
@ -546,8 +546,7 @@ class LSTMBlockWrapper(fused_rnn_cell.FusedRNNCell):
|
||||
# Input was a list, so return a list
|
||||
outputs = array_ops.unstack(outputs)
|
||||
|
||||
final_state = core_rnn_cell.LSTMStateTuple(final_cell_state,
|
||||
final_output)
|
||||
final_state = rnn_cell_impl.LSTMStateTuple(final_cell_state, final_output)
|
||||
return outputs, final_state
|
||||
|
||||
def _gather_states(self, data, indices, batch_size):
|
||||
@ -569,7 +568,7 @@ class LSTMBlockFusedCell(LSTMBlockWrapper):
|
||||
We add forget_bias (default: 1) to the biases of the forget gate in order to
|
||||
reduce the scale of forgetting in the beginning of the training.
|
||||
|
||||
The variable naming is consistent with `core_rnn_cell.LSTMCell`.
|
||||
The variable naming is consistent with `rnn_cell_impl.LSTMCell`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -17,7 +17,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn as contrib_rnn
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
@ -106,7 +105,7 @@ def stack_bidirectional_rnn(cells_fw,
|
||||
initial_state_bw = initial_states_bw[i]
|
||||
|
||||
with vs.variable_scope("cell_%d" % i) as cell_scope:
|
||||
prev_layer, state_fw, state_bw = contrib_rnn.static_bidirectional_rnn(
|
||||
prev_layer, state_fw, state_bw = rnn.static_bidirectional_rnn(
|
||||
cell_fw,
|
||||
cell_bw,
|
||||
prev_layer,
|
||||
|
@ -23,8 +23,6 @@ import math
|
||||
|
||||
from tensorflow.contrib.compiler import jit
|
||||
from tensorflow.contrib.layers.python.layers import layers
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import op_def_registry
|
||||
from tensorflow.python.framework import ops
|
||||
@ -76,7 +74,7 @@ def _get_sharded_variable(name, shape, dtype, num_shards):
|
||||
return shards
|
||||
|
||||
|
||||
class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
|
||||
class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
|
||||
"""Long short-term memory unit (LSTM) recurrent network cell.
|
||||
|
||||
The default non-peephole implementation is based on:
|
||||
@ -154,14 +152,12 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
|
||||
self._reuse = reuse
|
||||
|
||||
if num_proj:
|
||||
self._state_size = (
|
||||
core_rnn_cell.LSTMStateTuple(num_units, num_proj)
|
||||
if state_is_tuple else num_units + num_proj)
|
||||
self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
|
||||
if state_is_tuple else num_units + num_proj)
|
||||
self._output_size = num_proj
|
||||
else:
|
||||
self._state_size = (
|
||||
core_rnn_cell.LSTMStateTuple(num_units, num_units)
|
||||
if state_is_tuple else 2 * num_units)
|
||||
self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)
|
||||
if state_is_tuple else 2 * num_units)
|
||||
self._output_size = num_units
|
||||
|
||||
@property
|
||||
@ -254,12 +250,12 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
|
||||
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
|
||||
# pylint: enable=invalid-unary-operand-type
|
||||
|
||||
new_state = (core_rnn_cell.LSTMStateTuple(c, m) if self._state_is_tuple else
|
||||
array_ops.concat([c, m], 1))
|
||||
new_state = (rnn_cell_impl.LSTMStateTuple(c, m)
|
||||
if self._state_is_tuple else array_ops.concat([c, m], 1))
|
||||
return m, new_state
|
||||
|
||||
|
||||
class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
|
||||
class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
|
||||
"""Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.
|
||||
|
||||
This implementation is based on:
|
||||
@ -427,7 +423,7 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
|
||||
return freq_inputs
|
||||
|
||||
|
||||
class GridLSTMCell(core_rnn_cell.RNNCell):
|
||||
class GridLSTMCell(rnn_cell_impl.RNNCell):
|
||||
"""Grid Long short-term memory unit (LSTM) recurrent network cell.
|
||||
|
||||
The default is based on:
|
||||
@ -1020,11 +1016,11 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_linear = core_rnn_cell_impl._linear
|
||||
_linear = rnn_cell_impl._linear
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
class AttentionCellWrapper(core_rnn_cell.RNNCell):
|
||||
class AttentionCellWrapper(rnn_cell_impl.RNNCell):
|
||||
"""Basic attention cell wrapper.
|
||||
|
||||
Implementation based on https://arxiv.org/abs/1409.0473.
|
||||
@ -1155,7 +1151,7 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
|
||||
return new_attns, new_attn_states
|
||||
|
||||
|
||||
class HighwayWrapper(core_rnn_cell.RNNCell):
|
||||
class HighwayWrapper(rnn_cell_impl.RNNCell):
|
||||
"""RNNCell wrapper that adds highway connection on cell input and output.
|
||||
|
||||
Based on:
|
||||
@ -1238,7 +1234,7 @@ class HighwayWrapper(core_rnn_cell.RNNCell):
|
||||
return (res_outputs, new_state)
|
||||
|
||||
|
||||
class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
|
||||
class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
|
||||
"""LSTM unit with layer normalization and recurrent dropout.
|
||||
|
||||
This class adds layer normalization and recurrent dropout to a
|
||||
@ -1300,7 +1296,7 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
|
||||
return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
@ -1350,11 +1346,11 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
|
||||
new_c = self._norm(new_c, "state")
|
||||
new_h = self._activation(new_c) * math_ops.sigmoid(o)
|
||||
|
||||
new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
|
||||
new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
|
||||
return new_h, new_state
|
||||
|
||||
|
||||
class NASCell(core_rnn_cell.RNNCell):
|
||||
class NASCell(rnn_cell_impl.RNNCell):
|
||||
"""Neural Architecture Search (NAS) recurrent network cell.
|
||||
|
||||
This implements the recurrent cell from the paper:
|
||||
@ -1388,10 +1384,10 @@ class NASCell(core_rnn_cell.RNNCell):
|
||||
self._reuse = reuse
|
||||
|
||||
if num_proj is not None:
|
||||
self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_proj)
|
||||
self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
|
||||
self._output_size = num_proj
|
||||
else:
|
||||
self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_units)
|
||||
self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
|
||||
self._output_size = num_units
|
||||
|
||||
@property
|
||||
@ -1498,11 +1494,11 @@ class NASCell(core_rnn_cell.RNNCell):
|
||||
dtype)
|
||||
new_m = math_ops.matmul(new_m, concat_w_proj)
|
||||
|
||||
new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m)
|
||||
new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m)
|
||||
return new_m, new_state
|
||||
|
||||
|
||||
class UGRNNCell(core_rnn_cell.RNNCell):
|
||||
class UGRNNCell(rnn_cell_impl.RNNCell):
|
||||
"""Update Gate Recurrent Neural Network (UGRNN) cell.
|
||||
|
||||
Compromise between a LSTM/GRU and a vanilla RNN. There is only one
|
||||
@ -1589,7 +1585,7 @@ class UGRNNCell(core_rnn_cell.RNNCell):
|
||||
return new_output, new_state
|
||||
|
||||
|
||||
class IntersectionRNNCell(core_rnn_cell.RNNCell):
|
||||
class IntersectionRNNCell(rnn_cell_impl.RNNCell):
|
||||
"""Intersection Recurrent Neural Network (+RNN) cell.
|
||||
|
||||
Architecture with coupled recurrent gate as well as coupled depth
|
||||
@ -1712,7 +1708,7 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell):
|
||||
_REGISTERED_OPS = None
|
||||
|
||||
|
||||
class CompiledWrapper(core_rnn_cell.RNNCell):
|
||||
class CompiledWrapper(rnn_cell_impl.RNNCell):
|
||||
"""Wraps step execution in an XLA JIT scope."""
|
||||
|
||||
def __init__(self, cell, compile_stateful=False):
|
||||
@ -1783,7 +1779,7 @@ def _random_exp_initializer(minval,
|
||||
return _initializer
|
||||
|
||||
|
||||
class PhasedLSTMCell(core_rnn_cell.RNNCell):
|
||||
class PhasedLSTMCell(rnn_cell_impl.RNNCell):
|
||||
"""Phased LSTM recurrent network cell.
|
||||
|
||||
https://arxiv.org/pdf/1610.09513v1.pdf
|
||||
@ -1831,7 +1827,7 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
|
||||
return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
@ -1858,13 +1854,13 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
|
||||
It stores the time.
|
||||
The second Tensor has shape [batch, features_size], and type float32.
|
||||
It stores the features.
|
||||
state: core_rnn_cell.LSTMStateTuple, state from previous timestep.
|
||||
state: rnn_cell_impl.LSTMStateTuple, state from previous timestep.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A Tensor of float32, and shape [batch_size, num_units], representing the
|
||||
output of the cell.
|
||||
- A core_rnn_cell.LSTMStateTuple, containing 2 Tensors of float32, shape
|
||||
- A rnn_cell_impl.LSTMStateTuple, containing 2 Tensors of float32, shape
|
||||
[batch_size, num_units], representing the new state and the output.
|
||||
"""
|
||||
(c_prev, h_prev) = state
|
||||
@ -1921,12 +1917,12 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
|
||||
new_c = k * new_c + (1 - k) * c_prev
|
||||
new_h = k * new_h + (1 - k) * h_prev
|
||||
|
||||
new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
|
||||
new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
|
||||
|
||||
return new_h, new_state
|
||||
|
||||
|
||||
class GLSTMCell(core_rnn_cell.RNNCell):
|
||||
class GLSTMCell(rnn_cell_impl.RNNCell):
|
||||
"""Group LSTM cell (G-LSTM).
|
||||
|
||||
The implementation is based on:
|
||||
@ -1982,10 +1978,10 @@ class GLSTMCell(core_rnn_cell.RNNCell):
|
||||
int(self._num_units / self._number_of_groups)]
|
||||
|
||||
if num_proj:
|
||||
self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_proj)
|
||||
self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
|
||||
self._output_size = num_proj
|
||||
else:
|
||||
self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_units)
|
||||
self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
|
||||
self._output_size = num_units
|
||||
|
||||
@property
|
||||
@ -2097,5 +2093,5 @@ class GLSTMCell(core_rnn_cell.RNNCell):
|
||||
with vs.variable_scope("projection"):
|
||||
m = _linear(m, self._num_proj, bias=False)
|
||||
|
||||
new_state = core_rnn_cell.LSTMStateTuple(c, m)
|
||||
new_state = rnn_cell_impl.LSTMStateTuple(c, m)
|
||||
return m, new_state
|
||||
|
@ -24,13 +24,13 @@ import functools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.rnn import core_rnn_cell
|
||||
from tensorflow.contrib.seq2seq.python.ops import decoder
|
||||
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper
|
||||
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
|
||||
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.platform import test
|
||||
@ -41,7 +41,7 @@ from tensorflow.python.util import nest
|
||||
|
||||
# for testing
|
||||
AttentionWrapperState = wrapper.AttentionWrapperState # pylint: disable=invalid-name
|
||||
LSTMStateTuple = core_rnn_cell.LSTMStateTuple # pylint: disable=invalid-name
|
||||
LSTMStateTuple = rnn_cell.LSTMStateTuple # pylint: disable=invalid-name
|
||||
BasicDecoderOutput = basic_decoder.BasicDecoderOutput # pylint: disable=invalid-name
|
||||
float32 = np.float32
|
||||
int32 = np.int32
|
||||
@ -112,7 +112,7 @@ class AttentionWrapperTest(test.TestCase):
|
||||
with vs.variable_scope(
|
||||
'root',
|
||||
initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
|
||||
cell = core_rnn_cell.LSTMCell(cell_depth)
|
||||
cell = rnn_cell.LSTMCell(cell_depth)
|
||||
cell = wrapper.AttentionWrapper(
|
||||
cell,
|
||||
attention_mechanism,
|
||||
@ -133,7 +133,7 @@ class AttentionWrapperTest(test.TestCase):
|
||||
self.assertTrue(
|
||||
isinstance(final_state, wrapper.AttentionWrapperState))
|
||||
self.assertTrue(
|
||||
isinstance(final_state.cell_state, core_rnn_cell.LSTMStateTuple))
|
||||
isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))
|
||||
|
||||
self.assertEqual((batch_size, None, attention_depth),
|
||||
tuple(final_outputs.rnn_output.get_shape().as_list()))
|
||||
|
@ -21,13 +21,13 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.rnn import core_rnn_cell
|
||||
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
|
||||
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.layers import core as layers_core
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
# pylint: enable=g-import-not-at-top
|
||||
@ -46,7 +46,7 @@ class BasicDecoderTest(test.TestCase):
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
inputs = np.random.randn(batch_size, max_time,
|
||||
input_depth).astype(np.float32)
|
||||
cell = core_rnn_cell.LSTMCell(cell_depth)
|
||||
cell = rnn_cell.LSTMCell(cell_depth)
|
||||
helper = helper_py.TrainingHelper(
|
||||
inputs, sequence_length, time_major=False)
|
||||
if use_output_layer:
|
||||
@ -77,8 +77,8 @@ class BasicDecoderTest(test.TestCase):
|
||||
constant_op.constant(0), first_inputs, first_state)
|
||||
batch_size_t = my_decoder.batch_size
|
||||
|
||||
self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(
|
||||
isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
|
||||
self.assertEqual((batch_size, expected_output_depth),
|
||||
@ -130,7 +130,7 @@ class BasicDecoderTest(test.TestCase):
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
embeddings = np.random.randn(vocabulary_size,
|
||||
input_depth).astype(np.float32)
|
||||
cell = core_rnn_cell.LSTMCell(vocabulary_size)
|
||||
cell = rnn_cell.LSTMCell(vocabulary_size)
|
||||
helper = helper_py.GreedyEmbeddingHelper(embeddings, start_tokens,
|
||||
end_token)
|
||||
my_decoder = basic_decoder.BasicDecoder(
|
||||
@ -154,8 +154,8 @@ class BasicDecoderTest(test.TestCase):
|
||||
constant_op.constant(0), first_inputs, first_state)
|
||||
batch_size_t = my_decoder.batch_size
|
||||
|
||||
self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(
|
||||
isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
|
||||
self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
|
||||
@ -202,7 +202,7 @@ class BasicDecoderTest(test.TestCase):
|
||||
embeddings = np.random.randn(
|
||||
vocabulary_size, input_depth).astype(np.float32)
|
||||
half = constant_op.constant(0.5)
|
||||
cell = core_rnn_cell.LSTMCell(vocabulary_size)
|
||||
cell = rnn_cell.LSTMCell(vocabulary_size)
|
||||
helper = helper_py.ScheduledEmbeddingTrainingHelper(
|
||||
inputs=inputs,
|
||||
sequence_length=sequence_length,
|
||||
@ -230,8 +230,8 @@ class BasicDecoderTest(test.TestCase):
|
||||
constant_op.constant(0), first_inputs, first_state)
|
||||
batch_size_t = my_decoder.batch_size
|
||||
|
||||
self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(
|
||||
isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
|
||||
self.assertEqual((batch_size, vocabulary_size),
|
||||
@ -293,7 +293,7 @@ class BasicDecoderTest(test.TestCase):
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
inputs = np.random.randn(batch_size, max_time,
|
||||
input_depth).astype(np.float32)
|
||||
cell = core_rnn_cell.LSTMCell(cell_depth)
|
||||
cell = rnn_cell.LSTMCell(cell_depth)
|
||||
sampling_probability = constant_op.constant(sampling_probability)
|
||||
|
||||
next_input_layer = None
|
||||
@ -335,8 +335,8 @@ class BasicDecoderTest(test.TestCase):
|
||||
|
||||
batch_size_t = my_decoder.batch_size
|
||||
|
||||
self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(
|
||||
isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
|
||||
self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
|
||||
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.rnn import core_rnn_cell
|
||||
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
|
||||
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
|
||||
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
|
||||
@ -32,6 +31,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.layers import core as layers_core
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -241,7 +241,7 @@ class BeamSearchDecoderTest(test.TestCase):
|
||||
|
||||
with self.test_session() as sess:
|
||||
embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
|
||||
cell = core_rnn_cell.LSTMCell(cell_depth)
|
||||
cell = rnn_cell.LSTMCell(cell_depth)
|
||||
if has_attention:
|
||||
inputs = np.random.randn(batch_size, decoder_max_time,
|
||||
input_depth).astype(np.float32)
|
||||
|
@ -21,12 +21,12 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.rnn import core_rnn_cell
|
||||
from tensorflow.contrib.seq2seq.python.ops import decoder
|
||||
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
|
||||
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.platform import test
|
||||
@ -51,7 +51,7 @@ class DynamicDecodeRNNTest(test.TestCase):
|
||||
else:
|
||||
inputs = np.random.randn(batch_size, max_time,
|
||||
input_depth).astype(np.float32)
|
||||
cell = core_rnn_cell.LSTMCell(cell_depth)
|
||||
cell = rnn_cell.LSTMCell(cell_depth)
|
||||
helper = helper_py.TrainingHelper(
|
||||
inputs, sequence_length, time_major=time_major)
|
||||
my_decoder = basic_decoder.BasicDecoder(
|
||||
@ -71,7 +71,7 @@ class DynamicDecodeRNNTest(test.TestCase):
|
||||
|
||||
self.assertTrue(
|
||||
isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
|
||||
self.assertTrue(isinstance(final_state, core_rnn_cell.LSTMStateTuple))
|
||||
self.assertTrue(isinstance(final_state, rnn_cell.LSTMStateTuple))
|
||||
|
||||
self.assertEqual(
|
||||
(batch_size,),
|
||||
@ -126,7 +126,7 @@ class DynamicDecodeRNNTest(test.TestCase):
|
||||
inputs = np.random.randn(batch_size, max_time,
|
||||
input_depth).astype(np.float32)
|
||||
|
||||
cell = core_rnn_cell.LSTMCell(cell_depth)
|
||||
cell = rnn_cell.LSTMCell(cell_depth)
|
||||
zero_state = cell.zero_state(dtype=dtypes.float32, batch_size=batch_size)
|
||||
helper = helper_py.TrainingHelper(inputs, sequence_length)
|
||||
my_decoder = basic_decoder.BasicDecoder(
|
||||
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
||||
import collections
|
||||
import math
|
||||
|
||||
from tensorflow.contrib.rnn import core_rnn_cell
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
@ -500,7 +499,7 @@ def hardmax(logits, name=None):
|
||||
math_ops.argmax(logits, -1), depth, dtype=logits.dtype)
|
||||
|
||||
|
||||
class AttentionWrapper(core_rnn_cell.RNNCell):
|
||||
class AttentionWrapper(rnn_cell_impl.RNNCell):
|
||||
"""Wraps another `RNNCell` with attention.
|
||||
"""
|
||||
|
||||
|
@ -108,9 +108,7 @@ class PrintModelAnalysisTest(test.TestCase):
|
||||
|
||||
with gfile.Open(outfile, 'r') as f:
|
||||
# pylint: disable=line-too-long
|
||||
self.assertEqual(
|
||||
'_TFProfRoot (0/451 params, 0/10.44k flops, 0B/5.28KB)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops, 0B/864B)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/1 params, 0/0 flops, 0B/0B)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/162 params, 0/0 flops, 0B/1.30KB)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/5.83k flops, 0B/432B)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/288 params, 0/0 flops, 0B/2.30KB)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/4.61k flops, 0B/384B)\n',
|
||||
f.read())
|
||||
self.assertEqual('_TFProfRoot (', f.read()[0:13])
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
def testComplexCodeView(self):
|
||||
@ -138,25 +136,28 @@ class PrintModelAnalysisTest(test.TestCase):
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
with gfile.Open(outfile, 'r') as f:
|
||||
self.assertEqual(
|
||||
'_TFProfRoot (0/2.84k params, 0/54.08k flops)\n model_analyzer_testlib.py:56:BuildFullModel:seq.append(array_... (0/1.80k params, 0/41.76k flops)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/4 params, 0/0 flops)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/648 params, 0/0 flops)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/1.15k params, 0/0 flops)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:60:BuildFullModel:cell, array_ops.c... (0/1.04k params, 0/4.13k flops)\n model_analyzer_testlib.py:62:BuildFullModel:target = array_op... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:63:BuildFullModel:loss = nn_ops.l2_... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:65:BuildFullModel:return sgd_op.min... (0/0 params, 0/8.19k flops)\n',
|
||||
f.read())
|
||||
self.assertEqual('_TFProfRoot (0', f.read()[:14])
|
||||
|
||||
self.assertLess(0, tfprof_node.total_exec_micros)
|
||||
self.assertEqual(2844, tfprof_node.total_parameters)
|
||||
self.assertEqual(54080, tfprof_node.total_float_ops)
|
||||
self.assertEqual(5, len(tfprof_node.children))
|
||||
self.assertEqual('_TFProfRoot', tfprof_node.name)
|
||||
self.assertEqual('model_analyzer_testlib.py:56:BuildFullModel:seq.append(array_...',
|
||||
tfprof_node.children[0].name)
|
||||
self.assertEqual('model_analyzer_testlib.py:60:BuildFullModel:cell, array_ops.c...',
|
||||
tfprof_node.children[1].name)
|
||||
self.assertEqual('model_analyzer_testlib.py:62:BuildFullModel:target = array_op...',
|
||||
tfprof_node.children[2].name)
|
||||
self.assertEqual('model_analyzer_testlib.py:63:BuildFullModel:loss = nn_ops.l2_...',
|
||||
tfprof_node.children[3].name)
|
||||
self.assertEqual('model_analyzer_testlib.py:65:BuildFullModel:return sgd_op.min...',
|
||||
tfprof_node.children[4].name)
|
||||
self.assertEqual(
|
||||
'model_analyzer_testlib.py:58:BuildFullModel:seq.append(array_...',
|
||||
tfprof_node.children[0].name)
|
||||
self.assertEqual(
|
||||
'model_analyzer_testlib.py:62:BuildFullModel:cell, array_ops.c...',
|
||||
tfprof_node.children[1].name)
|
||||
self.assertEqual(
|
||||
'model_analyzer_testlib.py:64:BuildFullModel:target = array_op...',
|
||||
tfprof_node.children[2].name)
|
||||
self.assertEqual(
|
||||
'model_analyzer_testlib.py:65:BuildFullModel:loss = nn_ops.l2_...',
|
||||
tfprof_node.children[3].name)
|
||||
self.assertEqual(
|
||||
'model_analyzer_testlib.py:67:BuildFullModel:return sgd_op.min...',
|
||||
tfprof_node.children[4].name)
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
def testCodeViewLeafGraphNode(self):
|
||||
|
@ -17,13 +17,15 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicRNNCell
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.training import gradient_descent
|
||||
|
||||
@ -55,7 +57,7 @@ def BuildFullModel():
|
||||
with variable_scope.variable_scope('inp_%d' % i):
|
||||
seq.append(array_ops.reshape(BuildSmallModel(), [2, 1, -1]))
|
||||
|
||||
cell = BasicRNNCell(16, 48)
|
||||
cell = rnn_cell.BasicRNNCell(16)
|
||||
out = rnn.dynamic_rnn(
|
||||
cell, array_ops.concat(seq, axis=1), dtype=dtypes.float32)[0]
|
||||
|
||||
@ -63,5 +65,3 @@ def BuildFullModel():
|
||||
loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out))
|
||||
sgd_op = gradient_descent.GradientDescentOptimizer(1e-2)
|
||||
return sgd_op.minimize(loss)
|
||||
|
||||
|
||||
|
@ -186,6 +186,6 @@ apply from: "download-models.gradle"
|
||||
|
||||
dependencies {
|
||||
if (nativeBuildSystem == 'cmake' || nativeBuildSystem == 'none') {
|
||||
compile 'org.tensorflow:tensorflow-android:1.2.0-preview'
|
||||
compile 'org.tensorflow:tensorflow-android:+'
|
||||
}
|
||||
}
|
||||
|
@ -211,6 +211,7 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Session_run(
|
||||
}
|
||||
|
||||
if (!throwExceptionIfNotOK(env, status)) {
|
||||
TF_DeleteStatus(status);
|
||||
return nullptr;
|
||||
}
|
||||
jlong* t = env->GetLongArrayElements(output_tensor_handles, nullptr);
|
||||
@ -226,5 +227,6 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Session_run(
|
||||
memcpy(elems, run_metadata->data, run_metadata->length);
|
||||
env->ReleaseByteArrayElements(ret, elems, JNI_COMMIT);
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
return ret;
|
||||
}
|
||||
|
@ -1709,14 +1709,23 @@ py_library(
|
||||
py_library(
|
||||
name = "rnn_cell",
|
||||
srcs = [
|
||||
"ops/rnn_cell.py",
|
||||
"ops/rnn_cell_impl.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":array_ops",
|
||||
":clip_ops",
|
||||
":framework_for_generated_wrappers",
|
||||
":init_ops",
|
||||
":layers_base",
|
||||
":math_ops",
|
||||
":nn_ops",
|
||||
":partitioned_variables",
|
||||
":random_ops",
|
||||
":util",
|
||||
":variable_scope",
|
||||
":variables",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -53,7 +53,7 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import gradient_descent
|
||||
|
||||
|
||||
class _RNNCellForTest(rnn_cell_impl._RNNCell): # pylint: disable=protected-access
|
||||
class _RNNCellForTest(rnn_cell_impl.RNNCell): # pylint: disable=protected-access
|
||||
"""RNN cell for testing."""
|
||||
|
||||
def __init__(self, input_output_size, state_size):
|
||||
|
@ -80,5 +80,6 @@ py_test(
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python/estimator:inputs",
|
||||
],
|
||||
)
|
||||
|
@ -140,6 +140,7 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
@ -1443,9 +1444,7 @@ class _LazyBuilder(object):
|
||||
return self._feature_tensors[key]
|
||||
|
||||
if key in self._features:
|
||||
# FeatureColumn is a raw feature.
|
||||
feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
|
||||
self._features[key])
|
||||
feature_tensor = self._get_raw_feature_as_tensor(key)
|
||||
self._feature_tensors[key] = feature_tensor
|
||||
return feature_tensor
|
||||
|
||||
@ -1464,6 +1463,55 @@ class _LazyBuilder(object):
|
||||
self._feature_tensors[column] = transformed
|
||||
return transformed
|
||||
|
||||
def _get_raw_feature_as_tensor(self, key):
|
||||
"""Gets the raw_feature (keyed by `key`) as `tensor`.
|
||||
|
||||
The raw feature is converted to (sparse) tensor and maybe expand dim.
|
||||
|
||||
For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if
|
||||
the rank is 1. This supports dynamic rank also. For rank 0 raw feature, will
|
||||
error out as it is not supported.
|
||||
|
||||
Args:
|
||||
key: A `str` key to access the raw feature.
|
||||
|
||||
Returns:
|
||||
A `Tensor` or `SparseTensor`.
|
||||
|
||||
Raises:
|
||||
ValueError: if the raw feature has rank 0.
|
||||
"""
|
||||
raw_feature = self._features[key]
|
||||
feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
|
||||
raw_feature)
|
||||
|
||||
def expand_dims(input_tensor):
|
||||
# Input_tensor must have rank 1.
|
||||
if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
|
||||
return sparse_ops.sparse_reshape(
|
||||
input_tensor, [array_ops.shape(input_tensor)[0], -1])
|
||||
else:
|
||||
return array_ops.expand_dims(input_tensor, -1)
|
||||
|
||||
rank = feature_tensor.get_shape().ndims
|
||||
if rank is not None:
|
||||
if rank == 0:
|
||||
raise ValueError(
|
||||
'Feature (key: {}) cannot have rank 0. Give: {}'.format(
|
||||
key, feature_tensor))
|
||||
return feature_tensor if rank != 1 else expand_dims(feature_tensor)
|
||||
|
||||
# Handle dynamic rank.
|
||||
with ops.control_dependencies([
|
||||
check_ops.assert_positive(
|
||||
array_ops.rank(feature_tensor),
|
||||
message='Feature (key: {}) cannot have rank 0. Given: {}'.format(
|
||||
key, feature_tensor))]):
|
||||
return control_flow_ops.cond(
|
||||
math_ops.equal(1, array_ops.rank(feature_tensor)),
|
||||
lambda: expand_dims(feature_tensor),
|
||||
lambda: feature_tensor)
|
||||
|
||||
|
||||
# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
|
||||
def _shape_offsets(shape):
|
||||
|
@ -26,6 +26,7 @@ import numpy as np
|
||||
from tensorflow.core.example import example_pb2
|
||||
from tensorflow.core.example import feature_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.estimator.inputs import numpy_io
|
||||
from tensorflow.python.feature_column import feature_column_lib as fc
|
||||
from tensorflow.python.feature_column.feature_column import _CategoricalColumn
|
||||
from tensorflow.python.feature_column.feature_column import _DenseColumn
|
||||
@ -43,6 +44,8 @@ from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import coordinator
|
||||
from tensorflow.python.training import queue_runner_impl
|
||||
|
||||
|
||||
def _initialized_session():
|
||||
@ -1504,6 +1507,131 @@ class LinearModelTest(test.TestCase):
|
||||
features['price2']: [[1.], [5.]],
|
||||
})
|
||||
|
||||
def test_with_numpy_input_fn(self):
|
||||
price = fc.numeric_column('price')
|
||||
price_buckets = fc.bucketized_column(price, boundaries=[0., 10., 100.,])
|
||||
body_style = fc.categorical_column_with_vocabulary_list(
|
||||
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
|
||||
|
||||
input_fn = numpy_io.numpy_input_fn(
|
||||
x={
|
||||
'price': np.array([-1., 2., 13., 104.]),
|
||||
'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
|
||||
},
|
||||
batch_size=2,
|
||||
shuffle=False)
|
||||
features = input_fn()
|
||||
net = fc.linear_model(features, [price_buckets, body_style])
|
||||
# self.assertEqual(1 + 3 + 5, net.shape[1])
|
||||
with _initialized_session() as sess:
|
||||
coord = coordinator.Coordinator()
|
||||
threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
|
||||
|
||||
bias = get_linear_model_bias()
|
||||
price_buckets_var = get_linear_model_column_var(price_buckets)
|
||||
body_style_var = get_linear_model_column_var(body_style)
|
||||
|
||||
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
|
||||
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
|
||||
def test_with_1d_sparse_tensor(self):
|
||||
price = fc.numeric_column('price')
|
||||
price_buckets = fc.bucketized_column(price, boundaries=[0., 10., 100.,])
|
||||
body_style = fc.categorical_column_with_vocabulary_list(
|
||||
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
|
||||
|
||||
# Provides 1-dim tensor and dense tensor.
|
||||
features = {
|
||||
'price': constant_op.constant([-1., 12.,]),
|
||||
'body-style': sparse_tensor.SparseTensor(
|
||||
indices=((0,), (1,)),
|
||||
values=('sedan', 'hardtop'),
|
||||
dense_shape=(2,)),
|
||||
}
|
||||
self.assertEqual(1, features['price'].shape.ndims)
|
||||
self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
|
||||
|
||||
net = fc.linear_model(features, [price_buckets, body_style])
|
||||
with _initialized_session() as sess:
|
||||
bias = get_linear_model_bias()
|
||||
price_buckets_var = get_linear_model_column_var(price_buckets)
|
||||
body_style_var = get_linear_model_column_var(body_style)
|
||||
|
||||
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
||||
|
||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||
price = fc.numeric_column('price')
|
||||
price_buckets = fc.bucketized_column(price, boundaries=[0., 10., 100.,])
|
||||
body_style = fc.categorical_column_with_vocabulary_list(
|
||||
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
|
||||
country = fc.categorical_column_with_vocabulary_list(
|
||||
'country', vocabulary_list=['US', 'JP', 'CA'])
|
||||
|
||||
# Provides 1-dim tensor and dense tensor.
|
||||
features = {
|
||||
'price': array_ops.placeholder(dtypes.float32),
|
||||
'body-style': array_ops.sparse_placeholder(dtypes.string),
|
||||
'country': array_ops.placeholder(dtypes.string),
|
||||
}
|
||||
self.assertIsNone(features['price'].shape.ndims)
|
||||
self.assertIsNone(features['body-style'].get_shape().ndims)
|
||||
|
||||
price_data = np.array([-1., 12.])
|
||||
body_style_data = sparse_tensor.SparseTensorValue(
|
||||
indices=((0,), (1,)),
|
||||
values=('sedan', 'hardtop'),
|
||||
dense_shape=(2,))
|
||||
|
||||
net = fc.linear_model(features, [price_buckets, body_style])
|
||||
bias = get_linear_model_bias()
|
||||
price_buckets_var = get_linear_model_column_var(price_buckets)
|
||||
body_style_var = get_linear_model_column_var(body_style)
|
||||
with _initialized_session() as sess:
|
||||
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose(
|
||||
[[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||
sess.run(net, feed_dict={
|
||||
features['price']: price_data,
|
||||
features['body-style']: body_style_data}))
|
||||
|
||||
# Dense categorical_column with unknown shape is not allowed.
|
||||
with self.assertRaisesRegexp(ValueError, 'Undefined input_tensor shape.'):
|
||||
fc.linear_model(features, [price_buckets, body_style, country])
|
||||
|
||||
def test_with_rank_0_feature(self):
|
||||
price = fc.numeric_column('price')
|
||||
features = {
|
||||
'price': constant_op.constant(0),
|
||||
}
|
||||
self.assertEqual(0, features['price'].shape.ndims)
|
||||
|
||||
# Static rank 0 should fail
|
||||
with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
|
||||
fc.linear_model(features, [price])
|
||||
|
||||
# Dynamic rank 0 should fail
|
||||
features = {
|
||||
'price': array_ops.placeholder(dtypes.float32),
|
||||
}
|
||||
net = fc.linear_model(features, [price])
|
||||
self.assertEqual(1, net.shape[1])
|
||||
with _initialized_session() as sess:
|
||||
with self.assertRaisesOpError('Feature .* cannot have rank 0'):
|
||||
sess.run(net, feed_dict={features['price']: np.array(1)})
|
||||
|
||||
|
||||
class InputLayerTest(test.TestCase):
|
||||
|
||||
@ -1663,6 +1791,180 @@ class InputLayerTest(test.TestCase):
|
||||
features['price2']: [[1.], [5.]],
|
||||
})
|
||||
|
||||
def test_with_numpy_input_fn(self):
|
||||
embedding_values = (
|
||||
(1., 2., 3., 4., 5.), # id 0
|
||||
(6., 7., 8., 9., 10.), # id 1
|
||||
(11., 12., 13., 14., 15.) # id 2
|
||||
)
|
||||
def _initializer(shape, dtype, partition_info):
|
||||
del shape, dtype, partition_info
|
||||
return embedding_values
|
||||
|
||||
# price has 1 dimension in input_layer
|
||||
price = fc.numeric_column('price')
|
||||
body_style = fc.categorical_column_with_vocabulary_list(
|
||||
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
|
||||
# one_hot_body_style has 3 dims in input_layer.
|
||||
one_hot_body_style = fc.indicator_column(body_style)
|
||||
# embedded_body_style has 5 dims in input_layer.
|
||||
embedded_body_style = fc.embedding_column(body_style, dimension=5,
|
||||
initializer=_initializer)
|
||||
|
||||
input_fn = numpy_io.numpy_input_fn(
|
||||
x={
|
||||
'price': np.array([11., 12., 13., 14.]),
|
||||
'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
|
||||
},
|
||||
batch_size=2,
|
||||
shuffle=False)
|
||||
features = input_fn()
|
||||
net = fc.input_layer(features,
|
||||
[price, one_hot_body_style, embedded_body_style])
|
||||
self.assertEqual(1 + 3 + 5, net.shape[1])
|
||||
with _initialized_session() as sess:
|
||||
coord = coordinator.Coordinator()
|
||||
threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
|
||||
|
||||
# Each row is formed by concatenating `embedded_body_style`,
|
||||
# `one_hot_body_style`, and `price` in order.
|
||||
self.assertAllEqual(
|
||||
[[11., 12., 13., 14., 15., 0., 0., 1., 11.],
|
||||
[1., 2., 3., 4., 5., 1., 0., 0., 12]],
|
||||
sess.run(net))
|
||||
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
|
||||
def test_with_1d_sparse_tensor(self):
|
||||
embedding_values = (
|
||||
(1., 2., 3., 4., 5.), # id 0
|
||||
(6., 7., 8., 9., 10.), # id 1
|
||||
(11., 12., 13., 14., 15.) # id 2
|
||||
)
|
||||
def _initializer(shape, dtype, partition_info):
|
||||
del shape, dtype, partition_info
|
||||
return embedding_values
|
||||
|
||||
# price has 1 dimension in input_layer
|
||||
price = fc.numeric_column('price')
|
||||
|
||||
# one_hot_body_style has 3 dims in input_layer.
|
||||
body_style = fc.categorical_column_with_vocabulary_list(
|
||||
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
|
||||
one_hot_body_style = fc.indicator_column(body_style)
|
||||
|
||||
# embedded_body_style has 5 dims in input_layer.
|
||||
country = fc.categorical_column_with_vocabulary_list(
|
||||
'country', vocabulary_list=['US', 'JP', 'CA'])
|
||||
embedded_country = fc.embedding_column(country, dimension=5,
|
||||
initializer=_initializer)
|
||||
|
||||
# Provides 1-dim tensor and dense tensor.
|
||||
features = {
|
||||
'price': constant_op.constant([11., 12.,]),
|
||||
'body-style': sparse_tensor.SparseTensor(
|
||||
indices=((0,), (1,)),
|
||||
values=('sedan', 'hardtop'),
|
||||
dense_shape=(2,)),
|
||||
# This is dense tensor for the categorical_column.
|
||||
'country': constant_op.constant(['CA', 'US']),
|
||||
}
|
||||
self.assertEqual(1, features['price'].shape.ndims)
|
||||
self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
|
||||
self.assertEqual(1, features['country'].shape.ndims)
|
||||
|
||||
net = fc.input_layer(features,
|
||||
[price, one_hot_body_style, embedded_country])
|
||||
self.assertEqual(1 + 3 + 5, net.shape[1])
|
||||
with _initialized_session() as sess:
|
||||
|
||||
# Each row is formed by concatenating `embedded_body_style`,
|
||||
# `one_hot_body_style`, and `price` in order.
|
||||
self.assertAllEqual(
|
||||
[[0., 0., 1., 11., 12., 13., 14., 15., 11.],
|
||||
[1., 0., 0., 1., 2., 3., 4., 5., 12.]],
|
||||
sess.run(net))
|
||||
|
||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||
embedding_values = (
|
||||
(1., 2., 3., 4., 5.), # id 0
|
||||
(6., 7., 8., 9., 10.), # id 1
|
||||
(11., 12., 13., 14., 15.) # id 2
|
||||
)
|
||||
def _initializer(shape, dtype, partition_info):
|
||||
del shape, dtype, partition_info
|
||||
return embedding_values
|
||||
|
||||
# price has 1 dimension in input_layer
|
||||
price = fc.numeric_column('price')
|
||||
|
||||
# one_hot_body_style has 3 dims in input_layer.
|
||||
body_style = fc.categorical_column_with_vocabulary_list(
|
||||
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
|
||||
one_hot_body_style = fc.indicator_column(body_style)
|
||||
|
||||
# embedded_body_style has 5 dims in input_layer.
|
||||
country = fc.categorical_column_with_vocabulary_list(
|
||||
'country', vocabulary_list=['US', 'JP', 'CA'])
|
||||
embedded_country = fc.embedding_column(country, dimension=5,
|
||||
initializer=_initializer)
|
||||
|
||||
# Provides 1-dim tensor and dense tensor.
|
||||
features = {
|
||||
'price': array_ops.placeholder(dtypes.float32),
|
||||
'body-style': array_ops.sparse_placeholder(dtypes.string),
|
||||
# This is dense tensor for the categorical_column.
|
||||
'country': array_ops.placeholder(dtypes.string),
|
||||
}
|
||||
self.assertIsNone(features['price'].shape.ndims)
|
||||
self.assertIsNone(features['body-style'].get_shape().ndims)
|
||||
self.assertIsNone(features['country'].shape.ndims)
|
||||
|
||||
price_data = np.array([11., 12.])
|
||||
body_style_data = sparse_tensor.SparseTensorValue(
|
||||
indices=((0,), (1,)),
|
||||
values=('sedan', 'hardtop'),
|
||||
dense_shape=(2,))
|
||||
|
||||
# Dense categorical_column with unknown shape is not allowed.
|
||||
with self.assertRaisesRegexp(ValueError, 'Undefined input_tensor shape.'):
|
||||
fc.input_layer(features, [price, one_hot_body_style, embedded_country])
|
||||
|
||||
net = fc.input_layer(features, [price, one_hot_body_style])
|
||||
self.assertEqual(1 + 3, net.shape[1])
|
||||
with _initialized_session() as sess:
|
||||
|
||||
# Each row is formed by concatenating `embedded_body_style`,
|
||||
# `one_hot_body_style`, and `price` in order.
|
||||
self.assertAllEqual(
|
||||
[[0., 0., 1., 11.], [1., 0., 0., 12.]],
|
||||
sess.run(net, feed_dict={
|
||||
features['price']: price_data,
|
||||
features['body-style']: body_style_data}))
|
||||
|
||||
def test_with_rank_0_feature(self):
|
||||
# price has 1 dimension in input_layer
|
||||
price = fc.numeric_column('price')
|
||||
features = {
|
||||
'price': constant_op.constant(0),
|
||||
}
|
||||
self.assertEqual(0, features['price'].shape.ndims)
|
||||
|
||||
# Static rank 0 should fail
|
||||
with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
|
||||
fc.input_layer(features, [price])
|
||||
|
||||
# Dynamic rank 0 should fail
|
||||
features = {
|
||||
'price': array_ops.placeholder(dtypes.float32),
|
||||
}
|
||||
net = fc.input_layer(features, [price])
|
||||
self.assertEqual(1, net.shape[1])
|
||||
with _initialized_session() as sess:
|
||||
with self.assertRaisesOpError('Feature .* cannot have rank 0'):
|
||||
sess.run(net, feed_dict={features['price']: np.array(1)})
|
||||
|
||||
|
||||
class MakeParseExampleSpecTest(test.TestCase):
|
||||
|
||||
|
@ -29,9 +29,11 @@ import threading
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
_portpicker_import_error = None
|
||||
try:
|
||||
import portpicker # pylint: disable=g-import-not-at-top
|
||||
except ImportError as _portpicker_import_error:
|
||||
except ImportError as _error:
|
||||
_portpicker_import_error = _error
|
||||
portpicker = None
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
@ -820,8 +822,8 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc"):
|
||||
Raises:
|
||||
ImportError: if portpicker module was not found at load time
|
||||
"""
|
||||
if not portpicker:
|
||||
raise _portpicker_import_error
|
||||
if _portpicker_import_error:
|
||||
raise _portpicker_import_error # pylint: disable=raising-bad-type
|
||||
worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
|
||||
ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
|
||||
cluster_dict = {
|
||||
|
@ -42,7 +42,7 @@ import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class Plus1RNNCell(rnn_cell_impl._RNNCell):
|
||||
class Plus1RNNCell(rnn_cell_impl.RNNCell):
|
||||
"""RNN Cell generating (output, new_state) = (input + 1, state + 1)."""
|
||||
|
||||
@property
|
||||
@ -57,6 +57,24 @@ class Plus1RNNCell(rnn_cell_impl._RNNCell):
|
||||
return (input_ + 1, state + 1)
|
||||
|
||||
|
||||
class ScalarStateRNNCell(rnn_cell_impl.RNNCell):
|
||||
"""RNN Cell generating (output, new_state) = (input + 1, state + 1)."""
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return 1
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return tensor_shape.TensorShape([])
|
||||
|
||||
def zero_state(self, batch_size, dtype):
|
||||
return array_ops.zeros([], dtype=dtypes.int32)
|
||||
|
||||
def __call__(self, input_, state, scope=None):
|
||||
return (input_, state + 1)
|
||||
|
||||
|
||||
class RNNTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -78,6 +78,9 @@ See the @{$python/nn} guide.
|
||||
@@dynamic_rnn
|
||||
@@bidirectional_dynamic_rnn
|
||||
@@raw_rnn
|
||||
@@static_rnn
|
||||
@@static_state_saving_rnn
|
||||
@@static_bidirectional_rnn
|
||||
@@ctc_loss
|
||||
@@ctc_greedy_decoder
|
||||
@@ctc_beam_search_decoder
|
||||
@ -113,14 +116,15 @@ from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
# Bring more nn-associated functionality into this package.
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.python.ops.ctc_ops import *
|
||||
from tensorflow.python.ops.nn_impl import *
|
||||
from tensorflow.python.ops.nn_ops import *
|
||||
from tensorflow.python.ops.candidate_sampling_ops import *
|
||||
from tensorflow.python.ops.embedding_ops import *
|
||||
from tensorflow.python.ops.rnn import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
||||
|
||||
# TODO(cwhipkey): sigmoid and tanh should not be exposed from tf.nn.
|
||||
@ -135,6 +139,7 @@ _allowed_symbols = [
|
||||
"lrn", # Excluded in gen_docs_combined.
|
||||
"relu_layer", # Excluded in gen_docs_combined.
|
||||
"xw_plus_b", # Excluded in gen_docs_combined.
|
||||
"rnn_cell", # rnn_cell is a submodule of tf.nn.
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols,
|
||||
|
@ -13,8 +13,16 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""RNN helpers for TensorFlow models."""
|
||||
"""RNN helpers for TensorFlow models.
|
||||
|
||||
|
||||
@@bidirectional_dynamic_rnn
|
||||
@@dynamic_rnn
|
||||
@@raw_rnn
|
||||
@@static_rnn
|
||||
@@static_state_saving_rnn
|
||||
@@static_bidirectional_rnn
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -1038,3 +1046,351 @@ def raw_rnn(cell, loop_fn,
|
||||
final_loop_state = None
|
||||
|
||||
return (emit_ta, final_state, final_loop_state)
|
||||
|
||||
|
||||
def static_rnn(cell,
|
||||
inputs,
|
||||
initial_state=None,
|
||||
dtype=None,
|
||||
sequence_length=None,
|
||||
scope=None):
|
||||
"""Creates a recurrent neural network specified by RNNCell `cell`.
|
||||
|
||||
The simplest form of RNN network generated is:
|
||||
|
||||
```python
|
||||
state = cell.zero_state(...)
|
||||
outputs = []
|
||||
for input_ in inputs:
|
||||
output, state = cell(input_, state)
|
||||
outputs.append(output)
|
||||
return (outputs, state)
|
||||
```
|
||||
However, a few other options are available:
|
||||
|
||||
An initial state can be provided.
|
||||
If the sequence_length vector is provided, dynamic calculation is performed.
|
||||
This method of calculation does not compute the RNN steps past the maximum
|
||||
sequence length of the minibatch (thus saving computational time),
|
||||
and properly propagates the state at an example's sequence length
|
||||
to the final state output.
|
||||
|
||||
The dynamic calculation performed is, at time `t` for batch row `b`,
|
||||
|
||||
```python
|
||||
(output, state)(b, t) =
|
||||
(t >= sequence_length(b))
|
||||
? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
|
||||
: cell(input(b, t), state(b, t - 1))
|
||||
```
|
||||
|
||||
Args:
|
||||
cell: An instance of RNNCell.
|
||||
inputs: A length T list of inputs, each a `Tensor` of shape
|
||||
`[batch_size, input_size]`, or a nested tuple of such elements.
|
||||
initial_state: (optional) An initial state for the RNN.
|
||||
If `cell.state_size` is an integer, this must be
|
||||
a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
|
||||
If `cell.state_size` is a tuple, this should be a tuple of
|
||||
tensors having shapes `[batch_size, s] for s in cell.state_size`.
|
||||
dtype: (optional) The data type for the initial state and expected output.
|
||||
Required if initial_state is not provided or RNN state has a heterogeneous
|
||||
dtype.
|
||||
sequence_length: Specifies the length of each sequence in inputs.
|
||||
An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`.
|
||||
scope: VariableScope for the created subgraph; defaults to "rnn".
|
||||
|
||||
Returns:
|
||||
A pair (outputs, state) where:
|
||||
|
||||
- outputs is a length T list of outputs (one for each input), or a nested
|
||||
tuple of such elements.
|
||||
- state is the final state
|
||||
|
||||
Raises:
|
||||
TypeError: If `cell` is not an instance of RNNCell.
|
||||
ValueError: If `inputs` is `None` or an empty list, or if the input depth
|
||||
(column size) cannot be inferred from inputs via shape inference.
|
||||
"""
|
||||
|
||||
if not _like_rnncell(cell):
|
||||
raise TypeError("cell must be an instance of RNNCell")
|
||||
if not nest.is_sequence(inputs):
|
||||
raise TypeError("inputs must be a sequence")
|
||||
if not inputs:
|
||||
raise ValueError("inputs must not be empty")
|
||||
|
||||
outputs = []
|
||||
# Create a new scope in which the caching device is either
|
||||
# determined by the parent scope, or is set to place the cached
|
||||
# Variable using the same placement as for the rest of the RNN.
|
||||
with vs.variable_scope(scope or "rnn") as varscope:
|
||||
if varscope.caching_device is None:
|
||||
varscope.set_caching_device(lambda op: op.device)
|
||||
|
||||
# Obtain the first sequence of the input
|
||||
first_input = inputs
|
||||
while nest.is_sequence(first_input):
|
||||
first_input = first_input[0]
|
||||
|
||||
# Temporarily avoid EmbeddingWrapper and seq2seq badness
|
||||
# TODO(lukaszkaiser): remove EmbeddingWrapper
|
||||
if first_input.get_shape().ndims != 1:
|
||||
|
||||
input_shape = first_input.get_shape().with_rank_at_least(2)
|
||||
fixed_batch_size = input_shape[0]
|
||||
|
||||
flat_inputs = nest.flatten(inputs)
|
||||
for flat_input in flat_inputs:
|
||||
input_shape = flat_input.get_shape().with_rank_at_least(2)
|
||||
batch_size, input_size = input_shape[0], input_shape[1:]
|
||||
fixed_batch_size.merge_with(batch_size)
|
||||
for i, size in enumerate(input_size):
|
||||
if size.value is None:
|
||||
raise ValueError(
|
||||
"Input size (dimension %d of inputs) must be accessible via "
|
||||
"shape inference, but saw value None." % i)
|
||||
else:
|
||||
fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]
|
||||
|
||||
if fixed_batch_size.value:
|
||||
batch_size = fixed_batch_size.value
|
||||
else:
|
||||
batch_size = array_ops.shape(first_input)[0]
|
||||
if initial_state is not None:
|
||||
state = initial_state
|
||||
else:
|
||||
if not dtype:
|
||||
raise ValueError("If no initial_state is provided, "
|
||||
"dtype must be specified")
|
||||
state = cell.zero_state(batch_size, dtype)
|
||||
|
||||
if sequence_length is not None: # Prepare variables
|
||||
sequence_length = ops.convert_to_tensor(
|
||||
sequence_length, name="sequence_length")
|
||||
if sequence_length.get_shape().ndims not in (None, 1):
|
||||
raise ValueError(
|
||||
"sequence_length must be a vector of length batch_size")
|
||||
|
||||
def _create_zero_output(output_size):
|
||||
# convert int to TensorShape if necessary
|
||||
size = _concat(batch_size, output_size)
|
||||
output = array_ops.zeros(
|
||||
array_ops.stack(size), _infer_state_dtype(dtype, state))
|
||||
shape = _concat(fixed_batch_size.value, output_size, static=True)
|
||||
output.set_shape(tensor_shape.TensorShape(shape))
|
||||
return output
|
||||
|
||||
output_size = cell.output_size
|
||||
flat_output_size = nest.flatten(output_size)
|
||||
flat_zero_output = tuple(
|
||||
_create_zero_output(size) for size in flat_output_size)
|
||||
zero_output = nest.pack_sequence_as(
|
||||
structure=output_size, flat_sequence=flat_zero_output)
|
||||
|
||||
sequence_length = math_ops.to_int32(sequence_length)
|
||||
min_sequence_length = math_ops.reduce_min(sequence_length)
|
||||
max_sequence_length = math_ops.reduce_max(sequence_length)
|
||||
|
||||
for time, input_ in enumerate(inputs):
|
||||
if time > 0:
|
||||
varscope.reuse_variables()
|
||||
# pylint: disable=cell-var-from-loop
|
||||
call_cell = lambda: cell(input_, state)
|
||||
# pylint: enable=cell-var-from-loop
|
||||
if sequence_length is not None:
|
||||
(output, state) = _rnn_step(
|
||||
time=time,
|
||||
sequence_length=sequence_length,
|
||||
min_sequence_length=min_sequence_length,
|
||||
max_sequence_length=max_sequence_length,
|
||||
zero_output=zero_output,
|
||||
state=state,
|
||||
call_cell=call_cell,
|
||||
state_size=cell.state_size)
|
||||
else:
|
||||
(output, state) = call_cell()
|
||||
|
||||
outputs.append(output)
|
||||
|
||||
return (outputs, state)
|
||||
|
||||
|
||||
def static_state_saving_rnn(cell,
|
||||
inputs,
|
||||
state_saver,
|
||||
state_name,
|
||||
sequence_length=None,
|
||||
scope=None):
|
||||
"""RNN that accepts a state saver for time-truncated RNN calculation.
|
||||
|
||||
Args:
|
||||
cell: An instance of `RNNCell`.
|
||||
inputs: A length T list of inputs, each a `Tensor` of shape
|
||||
`[batch_size, input_size]`.
|
||||
state_saver: A state saver object with methods `state` and `save_state`.
|
||||
state_name: Python string or tuple of strings. The name to use with the
|
||||
state_saver. If the cell returns tuples of states (i.e.,
|
||||
`cell.state_size` is a tuple) then `state_name` should be a tuple of
|
||||
strings having the same length as `cell.state_size`. Otherwise it should
|
||||
be a single string.
|
||||
sequence_length: (optional) An int32/int64 vector size [batch_size].
|
||||
See the documentation for rnn() for more details about sequence_length.
|
||||
scope: VariableScope for the created subgraph; defaults to "rnn".
|
||||
|
||||
Returns:
|
||||
A pair (outputs, state) where:
|
||||
outputs is a length T list of outputs (one for each input)
|
||||
states is the final state
|
||||
|
||||
Raises:
|
||||
TypeError: If `cell` is not an instance of RNNCell.
|
||||
ValueError: If `inputs` is `None` or an empty list, or if the arity and
|
||||
type of `state_name` does not match that of `cell.state_size`.
|
||||
"""
|
||||
state_size = cell.state_size
|
||||
state_is_tuple = nest.is_sequence(state_size)
|
||||
state_name_tuple = nest.is_sequence(state_name)
|
||||
|
||||
if state_is_tuple != state_name_tuple:
|
||||
raise ValueError("state_name should be the same type as cell.state_size. "
|
||||
"state_name: %s, cell.state_size: %s" % (str(state_name),
|
||||
str(state_size)))
|
||||
|
||||
if state_is_tuple:
|
||||
state_name_flat = nest.flatten(state_name)
|
||||
state_size_flat = nest.flatten(state_size)
|
||||
|
||||
if len(state_name_flat) != len(state_size_flat):
|
||||
raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d" %
|
||||
(len(state_name_flat), len(state_size_flat)))
|
||||
|
||||
initial_state = nest.pack_sequence_as(
|
||||
structure=state_size,
|
||||
flat_sequence=[state_saver.state(s) for s in state_name_flat])
|
||||
else:
|
||||
initial_state = state_saver.state(state_name)
|
||||
|
||||
(outputs, state) = static_rnn(
|
||||
cell,
|
||||
inputs,
|
||||
initial_state=initial_state,
|
||||
sequence_length=sequence_length,
|
||||
scope=scope)
|
||||
|
||||
if state_is_tuple:
|
||||
flat_state = nest.flatten(state)
|
||||
state_name = nest.flatten(state_name)
|
||||
save_state = [
|
||||
state_saver.save_state(name, substate)
|
||||
for name, substate in zip(state_name, flat_state)
|
||||
]
|
||||
else:
|
||||
save_state = [state_saver.save_state(state_name, state)]
|
||||
|
||||
with ops.control_dependencies(save_state):
|
||||
last_output = outputs[-1]
|
||||
flat_last_output = nest.flatten(last_output)
|
||||
flat_last_output = [
|
||||
array_ops.identity(output) for output in flat_last_output
|
||||
]
|
||||
outputs[-1] = nest.pack_sequence_as(
|
||||
structure=last_output, flat_sequence=flat_last_output)
|
||||
|
||||
return (outputs, state)
|
||||
|
||||
|
||||
def static_bidirectional_rnn(cell_fw,
|
||||
cell_bw,
|
||||
inputs,
|
||||
initial_state_fw=None,
|
||||
initial_state_bw=None,
|
||||
dtype=None,
|
||||
sequence_length=None,
|
||||
scope=None):
|
||||
"""Creates a bidirectional recurrent neural network.
|
||||
|
||||
Similar to the unidirectional case above (rnn) but takes input and builds
|
||||
independent forward and backward RNNs with the final forward and backward
|
||||
outputs depth-concatenated, such that the output will have the format
|
||||
[time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of
|
||||
forward and backward cell must match. The initial state for both directions
|
||||
is zero by default (but can be set optionally) and no intermediate states are
|
||||
ever returned -- the network is fully unrolled for the given (passed in)
|
||||
length(s) of the sequence(s) or completely unrolled if length(s) is not given.
|
||||
|
||||
Args:
|
||||
cell_fw: An instance of RNNCell, to be used for forward direction.
|
||||
cell_bw: An instance of RNNCell, to be used for backward direction.
|
||||
inputs: A length T list of inputs, each a tensor of shape
|
||||
[batch_size, input_size], or a nested tuple of such elements.
|
||||
initial_state_fw: (optional) An initial state for the forward RNN.
|
||||
This must be a tensor of appropriate type and shape
|
||||
`[batch_size, cell_fw.state_size]`.
|
||||
If `cell_fw.state_size` is a tuple, this should be a tuple of
|
||||
tensors having shapes `[batch_size, s] for s in cell_fw.state_size`.
|
||||
initial_state_bw: (optional) Same as for `initial_state_fw`, but using
|
||||
the corresponding properties of `cell_bw`.
|
||||
dtype: (optional) The data type for the initial state. Required if
|
||||
either of the initial states are not provided.
|
||||
sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
|
||||
containing the actual lengths for each of the sequences.
|
||||
scope: VariableScope for the created subgraph; defaults to
|
||||
"bidirectional_rnn"
|
||||
|
||||
Returns:
|
||||
A tuple (outputs, output_state_fw, output_state_bw) where:
|
||||
outputs is a length `T` list of outputs (one for each input), which
|
||||
are depth-concatenated forward and backward outputs.
|
||||
output_state_fw is the final state of the forward rnn.
|
||||
output_state_bw is the final state of the backward rnn.
|
||||
|
||||
Raises:
|
||||
TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
|
||||
ValueError: If inputs is None or an empty list.
|
||||
"""
|
||||
|
||||
if not _like_rnncell(cell_fw):
|
||||
raise TypeError("cell_fw must be an instance of RNNCell")
|
||||
if not _like_rnncell(cell_bw):
|
||||
raise TypeError("cell_bw must be an instance of RNNCell")
|
||||
if not nest.is_sequence(inputs):
|
||||
raise TypeError("inputs must be a sequence")
|
||||
if not inputs:
|
||||
raise ValueError("inputs must not be empty")
|
||||
|
||||
with vs.variable_scope(scope or "bidirectional_rnn"):
|
||||
# Forward direction
|
||||
with vs.variable_scope("fw") as fw_scope:
|
||||
output_fw, output_state_fw = static_rnn(
|
||||
cell_fw,
|
||||
inputs,
|
||||
initial_state_fw,
|
||||
dtype,
|
||||
sequence_length,
|
||||
scope=fw_scope)
|
||||
|
||||
# Backward direction
|
||||
with vs.variable_scope("bw") as bw_scope:
|
||||
reversed_inputs = _reverse_seq(inputs, sequence_length)
|
||||
tmp, output_state_bw = static_rnn(
|
||||
cell_bw,
|
||||
reversed_inputs,
|
||||
initial_state_bw,
|
||||
dtype,
|
||||
sequence_length,
|
||||
scope=bw_scope)
|
||||
|
||||
output_bw = _reverse_seq(tmp, sequence_length)
|
||||
# Concat each of the forward/backward outputs
|
||||
flat_output_fw = nest.flatten(output_fw)
|
||||
flat_output_bw = nest.flatten(output_bw)
|
||||
|
||||
flat_outputs = tuple(
|
||||
array_ops.concat([fw, bw], 1)
|
||||
for fw, bw in zip(flat_output_fw, flat_output_bw))
|
||||
|
||||
outputs = nest.pack_sequence_as(
|
||||
structure=output_fw, flat_sequence=flat_outputs)
|
||||
|
||||
return (outputs, output_state_fw, output_state_bw)
|
||||
|
51
tensorflow/python/ops/rnn_cell.py
Normal file
51
tensorflow/python/ops/rnn_cell.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Module for constructing RNN Cells.
|
||||
|
||||
## Base interface for all RNN Cells
|
||||
|
||||
@@RNNCell
|
||||
|
||||
## RNN Cells for use with TensorFlow's core RNN methods
|
||||
|
||||
@@BasicRNNCell
|
||||
@@BasicLSTMCell
|
||||
@@GRUCell
|
||||
@@LSTMCell
|
||||
|
||||
## Classes storing split `RNNCell` state
|
||||
|
||||
@@LSTMStateTuple
|
||||
|
||||
## RNN Cell wrappers (RNNCells that wrap other RNNCells)
|
||||
|
||||
@@MultiRNNCell
|
||||
@@DropoutWrapper
|
||||
@@DeviceWrapper
|
||||
@@ResidualWrapper
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.python.ops.rnn_cell_impl import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = []
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
@ -12,18 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Module implementing RNN Cells.
|
||||
|
||||
This module contains the abstract definition of a RNN cell: `_RNNCell`.
|
||||
Actual implementations of various types of RNN cells are located in
|
||||
`tensorflow.contrib`.
|
||||
This module provides a number of basic commonly used RNN cells, such as LSTM
|
||||
(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
|
||||
operators that allow adding dropouts, projections, or embeddings for inputs.
|
||||
Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
|
||||
calling the `rnn` ops several times.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import hashlib
|
||||
import numbers
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -31,11 +35,22 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.layers import base as base_layer
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import partitioned_variables
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
_BIAS_VARIABLE_NAME = "bias"
|
||||
_WEIGHTS_VARIABLE_NAME = "kernel"
|
||||
|
||||
|
||||
def _like_rnncell(cell):
|
||||
"""Checks that a given object is an RNNCell by using duck typing."""
|
||||
conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"),
|
||||
@ -115,7 +130,7 @@ def _zero_state_tensors(state_size, batch_size, dtype):
|
||||
return nest.map_structure(get_state_shape, state_size)
|
||||
|
||||
|
||||
class _RNNCell(base_layer.Layer):
|
||||
class RNNCell(base_layer.Layer):
|
||||
"""Abstract object representing an RNN cell.
|
||||
|
||||
Every `RNNCell` must have the properties below and implement `call` with
|
||||
@ -158,11 +173,11 @@ class _RNNCell(base_layer.Layer):
|
||||
if scope is not None:
|
||||
with vs.variable_scope(scope,
|
||||
custom_getter=self._rnn_get_variable) as scope:
|
||||
return super(_RNNCell, self).__call__(inputs, state, scope=scope)
|
||||
return super(RNNCell, self).__call__(inputs, state, scope=scope)
|
||||
else:
|
||||
with vs.variable_scope(vs.get_variable_scope(),
|
||||
custom_getter=self._rnn_get_variable):
|
||||
return super(_RNNCell, self).__call__(inputs, state)
|
||||
return super(RNNCell, self).__call__(inputs, state)
|
||||
|
||||
def _rnn_get_variable(self, getter, *args, **kwargs):
|
||||
variable = getter(*args, **kwargs)
|
||||
@ -212,3 +227,806 @@ class _RNNCell(base_layer.Layer):
|
||||
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||
state_size = self.state_size
|
||||
return _zero_state_tensors(state_size, batch_size, dtype)
|
||||
|
||||
|
||||
class BasicRNNCell(RNNCell):
|
||||
"""The most basic RNN cell.
|
||||
|
||||
Args:
|
||||
num_units: int, The number of units in the LSTM cell.
|
||||
activation: Nonlinearity to use. Default: `tanh`.
|
||||
reuse: (optional) Python boolean describing whether to reuse variables
|
||||
in an existing scope. If not `True`, and the existing scope already has
|
||||
the given variables, an error is raised.
|
||||
"""
|
||||
|
||||
def __init__(self, num_units, activation=None, reuse=None):
|
||||
super(BasicRNNCell, self).__init__(_reuse=reuse)
|
||||
self._num_units = num_units
|
||||
self._activation = activation or math_ops.tanh
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._num_units
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._num_units
|
||||
|
||||
def call(self, inputs, state):
|
||||
"""Most basic RNN: output = new_state = act(W * input + U * state + B)."""
|
||||
output = self._activation(_linear([inputs, state], self._num_units, True))
|
||||
return output, output
|
||||
|
||||
|
||||
class GRUCell(RNNCell):
|
||||
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
|
||||
|
||||
def __init__(self,
|
||||
num_units,
|
||||
activation=None,
|
||||
reuse=None,
|
||||
kernel_initializer=None,
|
||||
bias_initializer=None):
|
||||
super(GRUCell, self).__init__(_reuse=reuse)
|
||||
self._num_units = num_units
|
||||
self._activation = activation or math_ops.tanh
|
||||
self._kernel_initializer = kernel_initializer
|
||||
self._bias_initializer = bias_initializer
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._num_units
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._num_units
|
||||
|
||||
def call(self, inputs, state):
|
||||
"""Gated recurrent unit (GRU) with nunits cells."""
|
||||
with vs.variable_scope("gates"): # Reset gate and update gate.
|
||||
# We start with bias of 1.0 to not reset and not update.
|
||||
bias_ones = self._bias_initializer
|
||||
if self._bias_initializer is None:
|
||||
dtype = [a.dtype for a in [inputs, state]][0]
|
||||
bias_ones = init_ops.constant_initializer(1.0, dtype=dtype)
|
||||
value = math_ops.sigmoid(
|
||||
_linear([inputs, state], 2 * self._num_units, True, bias_ones,
|
||||
self._kernel_initializer))
|
||||
r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
|
||||
with vs.variable_scope("candidate"):
|
||||
c = self._activation(
|
||||
_linear([inputs, r * state], self._num_units, True,
|
||||
self._bias_initializer, self._kernel_initializer))
|
||||
new_h = u * state + (1 - u) * c
|
||||
return new_h, new_h
|
||||
|
||||
|
||||
_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))
|
||||
|
||||
|
||||
class LSTMStateTuple(_LSTMStateTuple):
|
||||
"""Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.
|
||||
|
||||
Stores two elements: `(c, h)`, in that order.
|
||||
|
||||
Only used when `state_is_tuple=True`.
|
||||
"""
|
||||
__slots__ = ()
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
(c, h) = self
|
||||
if c.dtype != h.dtype:
|
||||
raise TypeError("Inconsistent internal state: %s vs %s" %
|
||||
(str(c.dtype), str(h.dtype)))
|
||||
return c.dtype
|
||||
|
||||
|
||||
class BasicLSTMCell(RNNCell):
|
||||
"""Basic LSTM recurrent network cell.
|
||||
|
||||
The implementation is based on: http://arxiv.org/abs/1409.2329.
|
||||
|
||||
We add forget_bias (default: 1) to the biases of the forget gate in order to
|
||||
reduce the scale of forgetting in the beginning of the training.
|
||||
|
||||
It does not allow cell clipping, a projection layer, and does not
|
||||
use peep-hole connections: it is the basic baseline.
|
||||
|
||||
For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
|
||||
that follows.
|
||||
"""
|
||||
|
||||
def __init__(self, num_units, forget_bias=1.0,
|
||||
state_is_tuple=True, activation=None, reuse=None):
|
||||
"""Initialize the basic LSTM cell.
|
||||
|
||||
Args:
|
||||
num_units: int, The number of units in the LSTM cell.
|
||||
forget_bias: float, The bias added to forget gates (see above).
|
||||
state_is_tuple: If True, accepted and returned states are 2-tuples of
|
||||
the `c_state` and `m_state`. If False, they are concatenated
|
||||
along the column axis. The latter behavior will soon be deprecated.
|
||||
activation: Activation function of the inner states. Default: `tanh`.
|
||||
reuse: (optional) Python boolean describing whether to reuse variables
|
||||
in an existing scope. If not `True`, and the existing scope already has
|
||||
the given variables, an error is raised.
|
||||
"""
|
||||
super(BasicLSTMCell, self).__init__(_reuse=reuse)
|
||||
if not state_is_tuple:
|
||||
logging.warn("%s: Using a concatenated state is slower and will soon be "
|
||||
"deprecated. Use state_is_tuple=True.", self)
|
||||
self._num_units = num_units
|
||||
self._forget_bias = forget_bias
|
||||
self._state_is_tuple = state_is_tuple
|
||||
self._activation = activation or math_ops.tanh
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return (LSTMStateTuple(self._num_units, self._num_units)
|
||||
if self._state_is_tuple else 2 * self._num_units)
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._num_units
|
||||
|
||||
def call(self, inputs, state):
|
||||
"""Long short-term memory cell (LSTM)."""
|
||||
sigmoid = math_ops.sigmoid
|
||||
# Parameters of gates are concatenated into one multiply for efficiency.
|
||||
if self._state_is_tuple:
|
||||
c, h = state
|
||||
else:
|
||||
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
|
||||
|
||||
concat = _linear([inputs, h], 4 * self._num_units, True)
|
||||
|
||||
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
|
||||
i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
|
||||
|
||||
new_c = (
|
||||
c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
|
||||
new_h = self._activation(new_c) * sigmoid(o)
|
||||
|
||||
if self._state_is_tuple:
|
||||
new_state = LSTMStateTuple(new_c, new_h)
|
||||
else:
|
||||
new_state = array_ops.concat([new_c, new_h], 1)
|
||||
return new_h, new_state
|
||||
|
||||
|
||||
class LSTMCell(RNNCell):
|
||||
"""Long short-term memory unit (LSTM) recurrent network cell.
|
||||
|
||||
The default non-peephole implementation is based on:
|
||||
|
||||
http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
|
||||
|
||||
S. Hochreiter and J. Schmidhuber.
|
||||
"Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
|
||||
|
||||
The peephole implementation is based on:
|
||||
|
||||
https://research.google.com/pubs/archive/43905.pdf
|
||||
|
||||
Hasim Sak, Andrew Senior, and Francoise Beaufays.
|
||||
"Long short-term memory recurrent neural network architectures for
|
||||
large scale acoustic modeling." INTERSPEECH, 2014.
|
||||
|
||||
The class uses optional peep-hole connections, optional cell clipping, and
|
||||
an optional projection layer.
|
||||
"""
|
||||
|
||||
def __init__(self, num_units,
|
||||
use_peepholes=False, cell_clip=None,
|
||||
initializer=None, num_proj=None, proj_clip=None,
|
||||
num_unit_shards=None, num_proj_shards=None,
|
||||
forget_bias=1.0, state_is_tuple=True,
|
||||
activation=None, reuse=None):
|
||||
"""Initialize the parameters for an LSTM cell.
|
||||
|
||||
Args:
|
||||
num_units: int, The number of units in the LSTM cell
|
||||
use_peepholes: bool, set True to enable diagonal/peephole connections.
|
||||
cell_clip: (optional) A float value, if provided the cell state is clipped
|
||||
by this value prior to the cell output activation.
|
||||
initializer: (optional) The initializer to use for the weight and
|
||||
projection matrices.
|
||||
num_proj: (optional) int, The output dimensionality for the projection
|
||||
matrices. If None, no projection is performed.
|
||||
proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
|
||||
provided, then the projected values are clipped elementwise to within
|
||||
`[-proj_clip, proj_clip]`.
|
||||
num_unit_shards: Deprecated, will be removed by Jan. 2017.
|
||||
Use a variable_scope partitioner instead.
|
||||
num_proj_shards: Deprecated, will be removed by Jan. 2017.
|
||||
Use a variable_scope partitioner instead.
|
||||
forget_bias: Biases of the forget gate are initialized by default to 1
|
||||
in order to reduce the scale of forgetting at the beginning of
|
||||
the training.
|
||||
state_is_tuple: If True, accepted and returned states are 2-tuples of
|
||||
the `c_state` and `m_state`. If False, they are concatenated
|
||||
along the column axis. This latter behavior will soon be deprecated.
|
||||
activation: Activation function of the inner states. Default: `tanh`.
|
||||
reuse: (optional) Python boolean describing whether to reuse variables
|
||||
in an existing scope. If not `True`, and the existing scope already has
|
||||
the given variables, an error is raised.
|
||||
"""
|
||||
super(LSTMCell, self).__init__(_reuse=reuse)
|
||||
if not state_is_tuple:
|
||||
logging.warn("%s: Using a concatenated state is slower and will soon be "
|
||||
"deprecated. Use state_is_tuple=True.", self)
|
||||
if num_unit_shards is not None or num_proj_shards is not None:
|
||||
logging.warn(
|
||||
"%s: The num_unit_shards and proj_unit_shards parameters are "
|
||||
"deprecated and will be removed in Jan 2017. "
|
||||
"Use a variable scope with a partitioner instead.", self)
|
||||
|
||||
self._num_units = num_units
|
||||
self._use_peepholes = use_peepholes
|
||||
self._cell_clip = cell_clip
|
||||
self._initializer = initializer
|
||||
self._num_proj = num_proj
|
||||
self._proj_clip = proj_clip
|
||||
self._num_unit_shards = num_unit_shards
|
||||
self._num_proj_shards = num_proj_shards
|
||||
self._forget_bias = forget_bias
|
||||
self._state_is_tuple = state_is_tuple
|
||||
self._activation = activation or math_ops.tanh
|
||||
|
||||
if num_proj:
|
||||
self._state_size = (
|
||||
LSTMStateTuple(num_units, num_proj)
|
||||
if state_is_tuple else num_units + num_proj)
|
||||
self._output_size = num_proj
|
||||
else:
|
||||
self._state_size = (
|
||||
LSTMStateTuple(num_units, num_units)
|
||||
if state_is_tuple else 2 * num_units)
|
||||
self._output_size = num_units
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._state_size
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._output_size
|
||||
|
||||
def call(self, inputs, state):
|
||||
"""Run one step of LSTM.
|
||||
|
||||
Args:
|
||||
inputs: input Tensor, 2D, batch x num_units.
|
||||
state: if `state_is_tuple` is False, this must be a state Tensor,
|
||||
`2-D, batch x state_size`. If `state_is_tuple` is True, this must be a
|
||||
tuple of state Tensors, both `2-D`, with column sizes `c_state` and
|
||||
`m_state`.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
|
||||
- A `2-D, [batch x output_dim]`, Tensor representing the output of the
|
||||
LSTM after reading `inputs` when previous state was `state`.
|
||||
Here output_dim is:
|
||||
num_proj if num_proj was set,
|
||||
num_units otherwise.
|
||||
- Tensor(s) representing the new state of LSTM after reading `inputs` when
|
||||
the previous state was `state`. Same type and shape(s) as `state`.
|
||||
|
||||
Raises:
|
||||
ValueError: If input size cannot be inferred from inputs via
|
||||
static shape inference.
|
||||
"""
|
||||
num_proj = self._num_units if self._num_proj is None else self._num_proj
|
||||
sigmoid = math_ops.sigmoid
|
||||
|
||||
if self._state_is_tuple:
|
||||
(c_prev, m_prev) = state
|
||||
else:
|
||||
c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
|
||||
m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
|
||||
|
||||
dtype = inputs.dtype
|
||||
input_size = inputs.get_shape().with_rank(2)[1]
|
||||
if input_size.value is None:
|
||||
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
|
||||
scope = vs.get_variable_scope()
|
||||
with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
|
||||
if self._num_unit_shards is not None:
|
||||
unit_scope.set_partitioner(
|
||||
partitioned_variables.fixed_size_partitioner(
|
||||
self._num_unit_shards))
|
||||
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
|
||||
lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True)
|
||||
i, j, f, o = array_ops.split(
|
||||
value=lstm_matrix, num_or_size_splits=4, axis=1)
|
||||
# Diagonal connections
|
||||
if self._use_peepholes:
|
||||
with vs.variable_scope(unit_scope) as projection_scope:
|
||||
if self._num_unit_shards is not None:
|
||||
projection_scope.set_partitioner(None)
|
||||
w_f_diag = vs.get_variable(
|
||||
"w_f_diag", shape=[self._num_units], dtype=dtype)
|
||||
w_i_diag = vs.get_variable(
|
||||
"w_i_diag", shape=[self._num_units], dtype=dtype)
|
||||
w_o_diag = vs.get_variable(
|
||||
"w_o_diag", shape=[self._num_units], dtype=dtype)
|
||||
|
||||
if self._use_peepholes:
|
||||
c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
|
||||
sigmoid(i + w_i_diag * c_prev) * self._activation(j))
|
||||
else:
|
||||
c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
|
||||
self._activation(j))
|
||||
|
||||
if self._cell_clip is not None:
|
||||
# pylint: disable=invalid-unary-operand-type
|
||||
c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
|
||||
# pylint: enable=invalid-unary-operand-type
|
||||
if self._use_peepholes:
|
||||
m = sigmoid(o + w_o_diag * c) * self._activation(c)
|
||||
else:
|
||||
m = sigmoid(o) * self._activation(c)
|
||||
|
||||
if self._num_proj is not None:
|
||||
with vs.variable_scope("projection") as proj_scope:
|
||||
if self._num_proj_shards is not None:
|
||||
proj_scope.set_partitioner(
|
||||
partitioned_variables.fixed_size_partitioner(
|
||||
self._num_proj_shards))
|
||||
m = _linear(m, self._num_proj, bias=False)
|
||||
|
||||
if self._proj_clip is not None:
|
||||
# pylint: disable=invalid-unary-operand-type
|
||||
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
|
||||
# pylint: enable=invalid-unary-operand-type
|
||||
|
||||
new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
|
||||
array_ops.concat([c, m], 1))
|
||||
return m, new_state
|
||||
|
||||
|
||||
def _enumerated_map_structure(map_fn, *args, **kwargs):
|
||||
ix = [0]
|
||||
def enumerated_fn(*inner_args, **inner_kwargs):
|
||||
r = map_fn(ix[0], *inner_args, **inner_kwargs)
|
||||
ix[0] += 1
|
||||
return r
|
||||
return nest.map_structure(enumerated_fn, *args, **kwargs)
|
||||
|
||||
|
||||
class DropoutWrapper(RNNCell):
|
||||
"""Operator adding dropout to inputs and outputs of the given cell."""
|
||||
|
||||
def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
|
||||
state_keep_prob=1.0, variational_recurrent=False,
|
||||
input_size=None, dtype=None, seed=None):
|
||||
"""Create a cell with added input, state, and/or output dropout.
|
||||
|
||||
If `variational_recurrent` is set to `True` (**NOT** the default behavior),
|
||||
then the the same dropout mask is applied at every step, as described in:
|
||||
|
||||
Y. Gal, Z Ghahramani. "A Theoretically Grounded Application of Dropout in
|
||||
Recurrent Neural Networks". https://arxiv.org/abs/1512.05287
|
||||
|
||||
Otherwise a different dropout mask is applied at every time step.
|
||||
|
||||
Args:
|
||||
cell: an RNNCell, a projection to output_size is added to it.
|
||||
input_keep_prob: unit Tensor or float between 0 and 1, input keep
|
||||
probability; if it is constant and 1, no input dropout will be added.
|
||||
output_keep_prob: unit Tensor or float between 0 and 1, output keep
|
||||
probability; if it is constant and 1, no output dropout will be added.
|
||||
state_keep_prob: unit Tensor or float between 0 and 1, output keep
|
||||
probability; if it is constant and 1, no output dropout will be added.
|
||||
State dropout is performed on the *output* states of the cell.
|
||||
variational_recurrent: Python bool. If `True`, then the same
|
||||
dropout pattern is applied across all time steps per run call.
|
||||
If this parameter is set, `input_size` **must** be provided.
|
||||
input_size: (optional) (possibly nested tuple of) `TensorShape` objects
|
||||
containing the depth(s) of the input tensors expected to be passed in to
|
||||
the `DropoutWrapper`. Required and used **iff**
|
||||
`variational_recurrent = True` and `input_keep_prob < 1`.
|
||||
dtype: (optional) The `dtype` of the input, state, and output tensors.
|
||||
Required and used **iff** `variational_recurrent = True`.
|
||||
seed: (optional) integer, the randomness seed.
|
||||
|
||||
Raises:
|
||||
TypeError: if cell is not an RNNCell.
|
||||
ValueError: if any of the keep_probs are not between 0 and 1.
|
||||
"""
|
||||
if not _like_rnncell(cell):
|
||||
raise TypeError("The parameter cell is not a RNNCell.")
|
||||
with ops.name_scope("DropoutWrapperInit"):
|
||||
def tensor_and_const_value(v):
|
||||
tensor_value = ops.convert_to_tensor(v)
|
||||
const_value = tensor_util.constant_value(tensor_value)
|
||||
return (tensor_value, const_value)
|
||||
for prob, attr in [(input_keep_prob, "input_keep_prob"),
|
||||
(state_keep_prob, "state_keep_prob"),
|
||||
(output_keep_prob, "output_keep_prob")]:
|
||||
tensor_prob, const_prob = tensor_and_const_value(prob)
|
||||
if const_prob is not None:
|
||||
if const_prob < 0 or const_prob > 1:
|
||||
raise ValueError("Parameter %s must be between 0 and 1: %d"
|
||||
% (attr, const_prob))
|
||||
setattr(self, "_%s" % attr, float(const_prob))
|
||||
else:
|
||||
setattr(self, "_%s" % attr, tensor_prob)
|
||||
|
||||
# Set cell, variational_recurrent, seed before running the code below
|
||||
self._cell = cell
|
||||
self._variational_recurrent = variational_recurrent
|
||||
self._seed = seed
|
||||
|
||||
self._recurrent_input_noise = None
|
||||
self._recurrent_state_noise = None
|
||||
self._recurrent_output_noise = None
|
||||
|
||||
if variational_recurrent:
|
||||
if dtype is None:
|
||||
raise ValueError(
|
||||
"When variational_recurrent=True, dtype must be provided")
|
||||
|
||||
def convert_to_batch_shape(s):
|
||||
# Prepend a 1 for the batch dimension; for recurrent
|
||||
# variational dropout we use the same dropout mask for all
|
||||
# batch elements.
|
||||
return array_ops.concat(
|
||||
([1], tensor_shape.TensorShape(s).as_list()), 0)
|
||||
|
||||
def batch_noise(s, inner_seed):
|
||||
shape = convert_to_batch_shape(s)
|
||||
return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype)
|
||||
|
||||
if (not isinstance(self._input_keep_prob, numbers.Real) or
|
||||
self._input_keep_prob < 1.0):
|
||||
if input_size is None:
|
||||
raise ValueError(
|
||||
"When variational_recurrent=True and input_keep_prob < 1.0 or "
|
||||
"is unknown, input_size must be provided")
|
||||
self._recurrent_input_noise = _enumerated_map_structure(
|
||||
lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)),
|
||||
input_size)
|
||||
self._recurrent_state_noise = _enumerated_map_structure(
|
||||
lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)),
|
||||
cell.state_size)
|
||||
self._recurrent_output_noise = _enumerated_map_structure(
|
||||
lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)),
|
||||
cell.output_size)
|
||||
|
||||
def _gen_seed(self, salt_prefix, index):
|
||||
if self._seed is None:
|
||||
return None
|
||||
salt = "%s_%d" % (salt_prefix, index)
|
||||
string = (str(self._seed) + salt).encode("utf-8")
|
||||
return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._cell.state_size
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._cell.output_size
|
||||
|
||||
def zero_state(self, batch_size, dtype):
|
||||
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||
return self._cell.zero_state(batch_size, dtype)
|
||||
|
||||
def _variational_recurrent_dropout_value(
|
||||
self, index, value, noise, keep_prob):
|
||||
"""Performs dropout given the pre-calculated noise tensor."""
|
||||
# uniform [keep_prob, 1.0 + keep_prob)
|
||||
random_tensor = keep_prob + noise
|
||||
|
||||
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
|
||||
binary_tensor = math_ops.floor(random_tensor)
|
||||
ret = math_ops.div(value, keep_prob) * binary_tensor
|
||||
ret.set_shape(value.get_shape())
|
||||
return ret
|
||||
|
||||
def _dropout(self, values, salt_prefix, recurrent_noise, keep_prob):
|
||||
"""Decides whether to perform standard dropout or recurrent dropout."""
|
||||
if not self._variational_recurrent:
|
||||
def dropout(i, v):
|
||||
return nn_ops.dropout(
|
||||
v, keep_prob=keep_prob, seed=self._gen_seed(salt_prefix, i))
|
||||
return _enumerated_map_structure(dropout, values)
|
||||
else:
|
||||
def dropout(i, v, n):
|
||||
return self._variational_recurrent_dropout_value(i, v, n, keep_prob)
|
||||
return _enumerated_map_structure(dropout, values, recurrent_noise)
|
||||
|
||||
def __call__(self, inputs, state, scope=None):
|
||||
"""Run the cell with the declared dropouts."""
|
||||
def _should_dropout(p):
|
||||
return (not isinstance(p, float)) or p < 1
|
||||
|
||||
if _should_dropout(self._input_keep_prob):
|
||||
inputs = self._dropout(inputs, "input",
|
||||
self._recurrent_input_noise,
|
||||
self._input_keep_prob)
|
||||
output, new_state = self._cell(inputs, state, scope)
|
||||
if _should_dropout(self._state_keep_prob):
|
||||
new_state = self._dropout(new_state, "state",
|
||||
self._recurrent_state_noise,
|
||||
self._state_keep_prob)
|
||||
if _should_dropout(self._output_keep_prob):
|
||||
output = self._dropout(output, "output",
|
||||
self._recurrent_output_noise,
|
||||
self._output_keep_prob)
|
||||
return output, new_state
|
||||
|
||||
|
||||
class ResidualWrapper(RNNCell):
|
||||
"""RNNCell wrapper that ensures cell inputs are added to the outputs."""
|
||||
|
||||
def __init__(self, cell):
|
||||
"""Constructs a `ResidualWrapper` for `cell`.
|
||||
|
||||
Args:
|
||||
cell: An instance of `RNNCell`.
|
||||
"""
|
||||
self._cell = cell
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._cell.state_size
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._cell.output_size
|
||||
|
||||
def zero_state(self, batch_size, dtype):
|
||||
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||
return self._cell.zero_state(batch_size, dtype)
|
||||
|
||||
def __call__(self, inputs, state, scope=None):
|
||||
"""Run the cell and add its inputs to its outputs.
|
||||
|
||||
Args:
|
||||
inputs: cell inputs.
|
||||
state: cell state.
|
||||
scope: optional cell scope.
|
||||
|
||||
Returns:
|
||||
Tuple of cell outputs and new state.
|
||||
|
||||
Raises:
|
||||
TypeError: If cell inputs and outputs have different structure (type).
|
||||
ValueError: If cell inputs and outputs have different structure (value).
|
||||
"""
|
||||
outputs, new_state = self._cell(inputs, state, scope=scope)
|
||||
nest.assert_same_structure(inputs, outputs)
|
||||
# Ensure shapes match
|
||||
def assert_shape_match(inp, out):
|
||||
inp.get_shape().assert_is_compatible_with(out.get_shape())
|
||||
nest.map_structure(assert_shape_match, inputs, outputs)
|
||||
res_outputs = nest.map_structure(
|
||||
lambda inp, out: inp + out, inputs, outputs)
|
||||
return (res_outputs, new_state)
|
||||
|
||||
|
||||
class DeviceWrapper(RNNCell):
|
||||
"""Operator that ensures an RNNCell runs on a particular device."""
|
||||
|
||||
def __init__(self, cell, device):
|
||||
"""Construct a `DeviceWrapper` for `cell` with device `device`.
|
||||
|
||||
Ensures the wrapped `cell` is called with `tf.device(device)`.
|
||||
|
||||
Args:
|
||||
cell: An instance of `RNNCell`.
|
||||
device: A device string or function, for passing to `tf.device`.
|
||||
"""
|
||||
self._cell = cell
|
||||
self._device = device
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._cell.state_size
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._cell.output_size
|
||||
|
||||
def zero_state(self, batch_size, dtype):
|
||||
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||
with ops.device(self._device):
|
||||
return self._cell.zero_state(batch_size, dtype)
|
||||
|
||||
def __call__(self, inputs, state, scope=None):
|
||||
"""Run the cell on specified device."""
|
||||
with ops.device(self._device):
|
||||
return self._cell(inputs, state, scope=scope)
|
||||
|
||||
|
||||
class MultiRNNCell(RNNCell):
|
||||
"""RNN cell composed sequentially of multiple simple cells."""
|
||||
|
||||
def __init__(self, cells, state_is_tuple=True):
|
||||
"""Create a RNN cell composed sequentially of a number of RNNCells.
|
||||
|
||||
Args:
|
||||
cells: list of RNNCells that will be composed in this order.
|
||||
state_is_tuple: If True, accepted and returned states are n-tuples, where
|
||||
`n = len(cells)`. If False, the states are all
|
||||
concatenated along the column axis. This latter behavior will soon be
|
||||
deprecated.
|
||||
|
||||
Raises:
|
||||
ValueError: if cells is empty (not allowed), or at least one of the cells
|
||||
returns a state tuple but the flag `state_is_tuple` is `False`.
|
||||
"""
|
||||
super(MultiRNNCell, self).__init__()
|
||||
if not cells:
|
||||
raise ValueError("Must specify at least one cell for MultiRNNCell.")
|
||||
if not nest.is_sequence(cells):
|
||||
raise TypeError(
|
||||
"cells must be a list or tuple, but saw: %s." % cells)
|
||||
|
||||
self._cells = cells
|
||||
self._state_is_tuple = state_is_tuple
|
||||
if not state_is_tuple:
|
||||
if any(nest.is_sequence(c.state_size) for c in self._cells):
|
||||
raise ValueError("Some cells return tuples of states, but the flag "
|
||||
"state_is_tuple is not set. State sizes are: %s"
|
||||
% str([c.state_size for c in self._cells]))
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
if self._state_is_tuple:
|
||||
return tuple(cell.state_size for cell in self._cells)
|
||||
else:
|
||||
return sum([cell.state_size for cell in self._cells])
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._cells[-1].output_size
|
||||
|
||||
def zero_state(self, batch_size, dtype):
|
||||
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||
if self._state_is_tuple:
|
||||
return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
|
||||
else:
|
||||
# We know here that state_size of each cell is not a tuple and
|
||||
# presumably does not contain TensorArrays or anything else fancy
|
||||
return super(MultiRNNCell, self).zero_state(batch_size, dtype)
|
||||
|
||||
def call(self, inputs, state):
|
||||
"""Run this multi-layer cell on inputs, starting from state."""
|
||||
cur_state_pos = 0
|
||||
cur_inp = inputs
|
||||
new_states = []
|
||||
for i, cell in enumerate(self._cells):
|
||||
with vs.variable_scope("cell_%d" % i):
|
||||
if self._state_is_tuple:
|
||||
if not nest.is_sequence(state):
|
||||
raise ValueError(
|
||||
"Expected state to be a tuple of length %d, but received: %s" %
|
||||
(len(self.state_size), state))
|
||||
cur_state = state[i]
|
||||
else:
|
||||
cur_state = array_ops.slice(state, [0, cur_state_pos],
|
||||
[-1, cell.state_size])
|
||||
cur_state_pos += cell.state_size
|
||||
cur_inp, new_state = cell(cur_inp, cur_state)
|
||||
new_states.append(new_state)
|
||||
|
||||
new_states = (tuple(new_states) if self._state_is_tuple else
|
||||
array_ops.concat(new_states, 1))
|
||||
|
||||
return cur_inp, new_states
|
||||
|
||||
|
||||
class _SlimRNNCell(RNNCell):
|
||||
"""A simple wrapper for slim.rnn_cells."""
|
||||
|
||||
def __init__(self, cell_fn):
|
||||
"""Create a SlimRNNCell from a cell_fn.
|
||||
|
||||
Args:
|
||||
cell_fn: a function which takes (inputs, state, scope) and produces the
|
||||
outputs and the new_state. Additionally when called with inputs=None and
|
||||
state=None it should return (initial_outputs, initial_state).
|
||||
|
||||
Raises:
|
||||
TypeError: if cell_fn is not callable
|
||||
ValueError: if cell_fn cannot produce a valid initial state.
|
||||
"""
|
||||
if not callable(cell_fn):
|
||||
raise TypeError("cell_fn %s needs to be callable", cell_fn)
|
||||
self._cell_fn = cell_fn
|
||||
self._cell_name = cell_fn.func.__name__
|
||||
init_output, init_state = self._cell_fn(None, None)
|
||||
output_shape = init_output.get_shape()
|
||||
state_shape = init_state.get_shape()
|
||||
self._output_size = output_shape.with_rank(2)[1].value
|
||||
self._state_size = state_shape.with_rank(2)[1].value
|
||||
if self._output_size is None:
|
||||
raise ValueError("Initial output created by %s has invalid shape %s" %
|
||||
(self._cell_name, output_shape))
|
||||
if self._state_size is None:
|
||||
raise ValueError("Initial state created by %s has invalid shape %s" %
|
||||
(self._cell_name, state_shape))
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._state_size
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._output_size
|
||||
|
||||
def __call__(self, inputs, state, scope=None):
|
||||
scope = scope or self._cell_name
|
||||
output, state = self._cell_fn(inputs, state, scope=scope)
|
||||
return output, state
|
||||
|
||||
|
||||
def _linear(args,
|
||||
output_size,
|
||||
bias,
|
||||
bias_initializer=None,
|
||||
kernel_initializer=None):
|
||||
"""Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
|
||||
|
||||
Args:
|
||||
args: a 2D Tensor or a list of 2D, batch x n, Tensors.
|
||||
output_size: int, second dimension of W[i].
|
||||
bias: boolean, whether to add a bias term or not.
|
||||
bias_initializer: starting value to initialize the bias
|
||||
(default is all zeros).
|
||||
kernel_initializer: starting value to initialize the weight.
|
||||
|
||||
Returns:
|
||||
A 2D Tensor with shape [batch x output_size] equal to
|
||||
sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
|
||||
|
||||
Raises:
|
||||
ValueError: if some of the arguments has unspecified or wrong shape.
|
||||
"""
|
||||
if args is None or (nest.is_sequence(args) and not args):
|
||||
raise ValueError("`args` must be specified")
|
||||
if not nest.is_sequence(args):
|
||||
args = [args]
|
||||
|
||||
# Calculate the total size of arguments on dimension 1.
|
||||
total_arg_size = 0
|
||||
shapes = [a.get_shape() for a in args]
|
||||
for shape in shapes:
|
||||
if shape.ndims != 2:
|
||||
raise ValueError("linear is expecting 2D arguments: %s" % shapes)
|
||||
if shape[1].value is None:
|
||||
raise ValueError("linear expects shape[1] to be provided for shape %s, "
|
||||
"but saw %s" % (shape, shape[1]))
|
||||
else:
|
||||
total_arg_size += shape[1].value
|
||||
|
||||
dtype = [a.dtype for a in args][0]
|
||||
|
||||
# Now the computation.
|
||||
scope = vs.get_variable_scope()
|
||||
with vs.variable_scope(scope) as outer_scope:
|
||||
weights = vs.get_variable(
|
||||
_WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size],
|
||||
dtype=dtype,
|
||||
initializer=kernel_initializer)
|
||||
if len(args) == 1:
|
||||
res = math_ops.matmul(args[0], weights)
|
||||
else:
|
||||
res = math_ops.matmul(array_ops.concat(args, 1), weights)
|
||||
if not bias:
|
||||
return res
|
||||
with vs.variable_scope(outer_scope) as inner_scope:
|
||||
inner_scope.set_partitioner(None)
|
||||
if bias_initializer is None:
|
||||
bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
|
||||
biases = vs.get_variable(
|
||||
_BIAS_VARIABLE_NAME, [output_size],
|
||||
dtype=dtype,
|
||||
initializer=bias_initializer)
|
||||
return nn_ops.bias_add(res, biases)
|
||||
|
@ -129,19 +129,23 @@ class Scaffold(object):
|
||||
copy_from_scaffold: Optional scaffold object to copy fields from. Its
|
||||
fields will be overwritten by the provided fields in this function.
|
||||
"""
|
||||
if copy_from_scaffold:
|
||||
if copy_from_scaffold is not None:
|
||||
if not isinstance(copy_from_scaffold, Scaffold):
|
||||
raise TypeError('copy_from_scaffold is not a Scaffold instance.')
|
||||
init_op = init_op or copy_from_scaffold.init_op
|
||||
init_feed_dict = init_feed_dict or copy_from_scaffold.init_feed_dict
|
||||
# We need _coalesce since Tensor is not converted to bool automatically,
|
||||
# so the common idiom of (a or b) does not work.
|
||||
coalesce = lambda a, b: a if a is not None else b
|
||||
init_op = coalesce(init_op, copy_from_scaffold.init_op)
|
||||
init_feed_dict = coalesce(init_feed_dict,
|
||||
copy_from_scaffold.init_feed_dict)
|
||||
# Use the original init_fn provided by the user to init the new Scaffold.
|
||||
init_fn = init_fn or copy_from_scaffold._user_init_fn # pylint: disable=protected-access
|
||||
ready_op = ready_op or copy_from_scaffold.ready_op
|
||||
ready_for_local_init_op = ready_for_local_init_op or (
|
||||
copy_from_scaffold.ready_for_local_init_op)
|
||||
local_init_op = local_init_op or copy_from_scaffold.local_init_op
|
||||
summary_op = summary_op or copy_from_scaffold.summary_op
|
||||
saver = saver or copy_from_scaffold.saver
|
||||
init_fn = coalesce(init_fn, copy_from_scaffold._user_init_fn) # pylint: disable=protected-access
|
||||
ready_op = coalesce(ready_op, copy_from_scaffold.ready_op)
|
||||
ready_for_local_init_op = coalesce(
|
||||
ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op)
|
||||
local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op)
|
||||
summary_op = coalesce(summary_op, copy_from_scaffold.summary_op)
|
||||
saver = coalesce(saver, copy_from_scaffold.saver)
|
||||
|
||||
# NOTE(touts): modifying the init function to be passed the scaffold is a
|
||||
# hack to make it easy to find the saver. Is there a better way?
|
||||
@ -152,12 +156,12 @@ class Scaffold(object):
|
||||
self._init_fn = None
|
||||
|
||||
self._init_op = init_op
|
||||
self._init_feed_dict = init_feed_dict
|
||||
self._ready_op = ready_op
|
||||
self._ready_for_local_init_op = ready_for_local_init_op
|
||||
self._local_init_op = local_init_op
|
||||
self._summary_op = summary_op
|
||||
self._saver = saver
|
||||
self._init_feed_dict = init_feed_dict
|
||||
|
||||
def finalize(self):
|
||||
"""Creates operations if needed and finalizes the graph."""
|
||||
|
@ -1,5 +1,9 @@
|
||||
path: "tensorflow.nn"
|
||||
tf_module {
|
||||
member {
|
||||
name: "rnn_cell"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "all_candidate_sampler"
|
||||
argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
@ -284,6 +288,18 @@ tf_module {
|
||||
name: "sparse_softmax_cross_entropy_with_logits"
|
||||
argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "static_bidirectional_rnn"
|
||||
argspec: "args=[\'cell_fw\', \'cell_bw\', \'inputs\', \'initial_state_fw\', \'initial_state_bw\', \'dtype\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "static_rnn"
|
||||
argspec: "args=[\'cell\', \'inputs\', \'initial_state\', \'dtype\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "static_state_saving_rnn"
|
||||
argspec: "args=[\'cell\', \'inputs\', \'state_saver\', \'state_name\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "sufficient_statistics"
|
||||
argspec: "args=[\'x\', \'axes\', \'shift\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
|
||||
|
@ -0,0 +1,95 @@
|
||||
path: "tensorflow.nn.rnn_cell.BasicLSTMCell"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
|
||||
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "graph"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "scope_name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "state_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_units\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\'], varargs=None, keywords=None, defaults=[\'1.0\', \'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "zero_state"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,95 @@
|
||||
path: "tensorflow.nn.rnn_cell.BasicRNNCell"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.BasicRNNCell\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
|
||||
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "graph"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "scope_name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "state_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "zero_state"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,95 @@
|
||||
path: "tensorflow.nn.rnn_cell.DeviceWrapper"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DeviceWrapper\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
|
||||
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "graph"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "scope_name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "state_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'cell\', \'device\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "zero_state"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,95 @@
|
||||
path: "tensorflow.nn.rnn_cell.DropoutWrapper"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DropoutWrapper\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
|
||||
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "graph"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "scope_name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "state_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'cell\', \'input_keep_prob\', \'output_keep_prob\', \'state_keep_prob\', \'variational_recurrent\', \'input_size\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1.0\', \'1.0\', \'False\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "zero_state"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,95 @@
|
||||
path: "tensorflow.nn.rnn_cell.GRUCell"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.GRUCell\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
|
||||
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "graph"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "scope_name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "state_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'kernel_initializer\', \'bias_initializer\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "zero_state"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,95 @@
|
||||
path: "tensorflow.nn.rnn_cell.LSTMCell"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LSTMCell\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
|
||||
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "graph"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "scope_name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "state_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_units\', \'use_peepholes\', \'cell_clip\', \'initializer\', \'num_proj\', \'proj_clip\', \'num_unit_shards\', \'num_proj_shards\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'1.0\', \'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "zero_state"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
path: "tensorflow.nn.rnn_cell.LSTMStateTuple"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LSTMStateTuple\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LSTMStateTuple\'>"
|
||||
is_instance: "<type \'tuple\'>"
|
||||
member {
|
||||
name: "c"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dtype"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "h"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "count"
|
||||
}
|
||||
member_method {
|
||||
name: "index"
|
||||
}
|
||||
}
|
@ -0,0 +1,95 @@
|
||||
path: "tensorflow.nn.rnn_cell.MultiRNNCell"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
|
||||
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "graph"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "scope_name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "state_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "zero_state"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,94 @@
|
||||
path: "tensorflow.nn.rnn_cell.RNNCell"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
|
||||
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "graph"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "scope_name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "state_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'trainable\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'True\', \'None\', \"<dtype: \'float32\'>\"], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "zero_state"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,95 @@
|
||||
path: "tensorflow.nn.rnn_cell.ResidualWrapper"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.ResidualWrapper\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
|
||||
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "graph"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "scope_name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "state_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'cell\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "zero_state"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
43
tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.pbtxt
Normal file
43
tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.pbtxt
Normal file
@ -0,0 +1,43 @@
|
||||
path: "tensorflow.nn.rnn_cell"
|
||||
tf_module {
|
||||
member {
|
||||
name: "BasicLSTMCell"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "BasicRNNCell"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "DeviceWrapper"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "DropoutWrapper"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "GRUCell"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "LSTMCell"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "LSTMStateTuple"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "MultiRNNCell"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "RNNCell"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "ResidualWrapper"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
}
|
@ -12,12 +12,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Classes for converting parsed doc content into markdown pages."""
|
||||
"""A module for converting parsed doc content into markdown pages.
|
||||
|
||||
The adjacent `parser` module creates `PageInfo` objects, containing all data
|
||||
necessary to document an element of the TensorFlow API.
|
||||
|
||||
This module contains one public function, which handels the conversion of these
|
||||
`PageInfo` objects into a markdown string:
|
||||
|
||||
md_page = build_md_page(page_info)
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
|
||||
|
||||
def build_md_page(page_info):
|
||||
"""Given a PageInfo object, return markdown for the page.
|
||||
@ -46,7 +57,9 @@ def build_md_page(page_info):
|
||||
|
||||
def _build_function_page(page_info):
|
||||
"""Given a FunctionPageInfo object Return the page as an md string."""
|
||||
parts = ['# %s\n\n' % page_info.full_name]
|
||||
parts = [_Metadata(page_info.full_name).build_html()]
|
||||
|
||||
parts.append('# %s\n\n' % page_info.full_name)
|
||||
|
||||
if page_info.aliases:
|
||||
parts.extend('### `%s`\n' % name
|
||||
@ -70,7 +83,17 @@ def _build_function_page(page_info):
|
||||
|
||||
def _build_class_page(page_info):
|
||||
"""Given a ClassPageInfo object Return the page as an md string."""
|
||||
parts = ['# {page_info.full_name}\n\n'.format(page_info=page_info)]
|
||||
meta_data = _Metadata(page_info.full_name)
|
||||
for item in itertools.chain(
|
||||
page_info.classes,
|
||||
page_info.properties,
|
||||
page_info.methods,
|
||||
page_info.other_members):
|
||||
meta_data.append(item)
|
||||
|
||||
parts = [meta_data.build_html()]
|
||||
|
||||
parts.append('# {page_info.full_name}\n\n'.format(page_info=page_info))
|
||||
|
||||
if page_info.aliases:
|
||||
parts.extend('### `class %s`\n' % name for name in page_info.aliases)
|
||||
@ -150,8 +173,17 @@ def _build_class_page(page_info):
|
||||
|
||||
def _build_module_page(page_info):
|
||||
"""Given a ClassPageInfo object Return the page as an md string."""
|
||||
meta_data = _Metadata(page_info.full_name)
|
||||
|
||||
parts = ['# Module: {full_name}\n\n'.format(full_name=page_info.full_name)]
|
||||
# Objects with their own pages are not added to the matadata list for the
|
||||
# module, as the only thing on the module page is a link to the object's page.
|
||||
for item in page_info.other_members:
|
||||
meta_data.append(item)
|
||||
|
||||
parts = [meta_data.build_html()]
|
||||
|
||||
parts.append(
|
||||
'# Module: {full_name}\n\n'.format(full_name=page_info.full_name))
|
||||
|
||||
if page_info.aliases:
|
||||
parts.extend('### Module `%s`\n' % name for name in page_info.aliases)
|
||||
@ -261,3 +293,41 @@ def _build_function_details(function_details):
|
||||
parts.append(''.join(sub))
|
||||
|
||||
return '\n'.join(parts)
|
||||
|
||||
|
||||
class _Metadata(object):
|
||||
"""A class for building a page's Metadata block.
|
||||
|
||||
Attributes:
|
||||
name: The name of the page being described by the Metadata block.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
"""Creata a Metadata builder.
|
||||
|
||||
Args:
|
||||
name: The name of the page being described by the Metadata block.
|
||||
"""
|
||||
self.name = name
|
||||
self._content = []
|
||||
|
||||
def append(self, item):
|
||||
"""Add an item from the page to the Metadata block.
|
||||
|
||||
Args:
|
||||
item: The parsed page section to add.
|
||||
"""
|
||||
self._content.append(item.short_name)
|
||||
|
||||
def build_html(self):
|
||||
"""Return the Metadata block as an Html string."""
|
||||
schema = 'http://developers.google.com/ReferenceObject'
|
||||
parts = ['<div itemscope itemtype="%s">' % schema]
|
||||
|
||||
parts.append('<meta itemprop="name" content="%s" />' % self.name)
|
||||
for item in self._content:
|
||||
parts.append('<meta itemprop="property" content="%s"/>' % item)
|
||||
|
||||
parts.extend(['</div>', '', ''])
|
||||
|
||||
return '\n'.join(parts)
|
||||
|
Loading…
Reference in New Issue
Block a user