commit 04ba418175

 RELEASE.md | 27
@@ -3,9 +3,8 @@
 ## Major Features and Improvements
 * Added `tf.layers.conv3d_transpose` layer for spatio temporal deconvolution.
 * Added `tf.Session.make_callable()`, which provides a lower overhead means of running a similar step multiple times.
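
A minimal sketch of the `tf.Session.make_callable()` pattern mentioned above (illustrative only, not part of the diff; assumes TensorFlow 1.2 and an arbitrary toy graph):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 3])
y = tf.reduce_sum(x * 2.0)

with tf.Session() as sess:
    # make_callable pre-builds a handle for this (fetches, feed_list) pair,
    # avoiding some per-call overhead of Session.run() for repeated steps.
    step = sess.make_callable(y, feed_list=[x])
    for _ in range(10):
        step([[1.0, 2.0, 3.0]])
```
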
-* Added libverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo).
-* Bring `tf.feature_column.*` into the API. Non-deprecated functionality from `tf.contrib.layers.*` is moved to `tf.feature_column.*` with cosmetic changes.
-* `RNNCell` objects now subclass `tf.layers._Layer`. The strictness described
+* Added ibverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo).
+* `RNNCell` objects now subclass `tf.layers.Layer`. The strictness described
 in the TensorFlow 1.1 release is gone: The first time an RNNCell is used,
 it caches its scope. All future uses of the RNNCell will reuse variables from
 that same scope. This is a breaking change from the behavior of RNNCells
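
To illustrate the scope-caching behaviour described above (a sketch, not part of the diff; the cell type and sizes are arbitrary and assume TensorFlow 1.2):

```python
import tensorflow as tf

cell = tf.nn.rnn_cell.GRUCell(32)
x = tf.placeholder(tf.float32, [None, 16])
state = cell.zero_state(tf.shape(x)[0], tf.float32)

out1, state1 = cell(x, state)
# The second call reuses the variables created by the first call: the cell
# cached its variable scope on first use, so no explicit
# variable_scope(..., reuse=True) is required for the same cell object.
out2, state2 = cell(x, state1)
```
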
@@ -23,6 +22,28 @@
 * TensorFlow C library now available for Windows.
 * We released a new open-source version of TensorBoard.
 * [`SavedModel CLI`](https://www.tensorflow.org/versions/master/programmers_guide/saved_model_cli) tool available to inspect and execute MetaGraph in SavedModel
+* RNNCells' variable names have been renamed for consistency with Keras layers.
+Specifically, the previous variable names "weights" and "biases" have
+been changed to "kernel" and "bias", respectively.
+This may cause backward incompatibility with regard to your old
+checkpoints containing such RNN cells, in which case you can use the tool
+[checkpoint_convert script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py)
+to convert the variable names in your old checkpoints.
+* Many of the RNN functions and classes that were in the `tf.nn` namespace
+before the 1.0 release and which were moved to `tf.contrib.rnn` have now
+been moved back to the core namespace. This includes
+`RNNCell`, `LSTMCell`, `GRUCell`, and a number of other cells. These
+now reside in `tf.nn.rnn_cell` (with aliases in `tf.contrib.rnn` for backwards
+compatibility). The original `tf.nn.rnn` function is now `tf.nn.static_rnn`,
+and the bidirectional static and state saving static rnn functions are also
+now back in the `tf.nn` namespace.
+
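
The move described above can be used directly from core TensorFlow (a brief illustration, not part of the diff; shapes and sizes are arbitrary):

```python
import tensorflow as tf

# Cells live in tf.nn.rnn_cell again; the static RNN is tf.nn.static_rnn.
cell = tf.nn.rnn_cell.LSTMCell(num_units=64)
inputs = [tf.placeholder(tf.float32, [None, 32]) for _ in range(10)]
outputs, final_state = tf.nn.static_rnn(cell, inputs, dtype=tf.float32)
```
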
+Notable exceptions are the `EmbeddingWrapper`, `InputProjectionWrapper` and
+`OutputProjectionWrapper`, which will slowly be moved to deprecation
+in `tf.contrib.rnn`. These are inefficient wrappers that should often
+be replaced by calling `embedding_lookup` or `layers.dense` as pre- or post-
+processing of the rnn. For RNN decoding, this functionality has been replaced
+with an alternative API in `tf.contrib.seq2seq`.
 
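
One way to apply the replacement suggested above (a sketch, not part of the diff; all names and sizes are illustrative assumptions):

```python
import tensorflow as tf

vocab_size, embed_dim, num_steps = 1000, 64, 5  # illustrative sizes
step_ids = [tf.placeholder(tf.int32, [None]) for _ in range(num_steps)]

# Pre-processing with embedding_lookup replaces EmbeddingWrapper.
embedding = tf.get_variable("embedding", [vocab_size, embed_dim])
embedded = [tf.nn.embedding_lookup(embedding, ids) for ids in step_ids]

cell = tf.nn.rnn_cell.GRUCell(128)
outputs, state = tf.nn.static_rnn(cell, embedded, dtype=tf.float32)

# Post-processing with a single dense projection (shared across time steps)
# replaces OutputProjectionWrapper.
stacked = tf.stack(outputs, axis=1)            # [batch, time, 128]
logits = tf.layers.dense(stacked, vocab_size)  # [batch, time, vocab_size]
```
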
 ## Breaking Changes to the API
 * `org.tensorflow.contrib.android.TensorFlowInferenceInterface` now throws exceptions where possible and has simplified method signatures.
@@ -961,8 +961,9 @@ typedef struct TF_WhileParams {
 // - Reference-type inputs
 // - Directly referencing external tensors from the cond/body graphs (this is
 // possible in the Python API)
-TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
-TF_Status* status);
+TF_CAPI_EXPORT extern TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs,
+int ninputs,
+TF_Status* status);
 
 // Builds the while loop specified by `params` and returns the output tensors of
 // the while loop in `outputs`. `outputs` should be allocated to size
@@ -972,13 +973,14 @@ TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
 //
 // Either this or TF_AbortWhile() must be called after a successful
 // TF_NewWhile() call.
-void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
-TF_Output* outputs);
+TF_CAPI_EXPORT extern void TF_FinishWhile(const TF_WhileParams* params,
+TF_Status* status,
+TF_Output* outputs);
 
 // Frees `params`s resources without building a while loop. `params` is no
 // longer valid after this returns. Either this or TF_FinishWhile() must be
 // called after a successful TF_NewWhile() call.
-void TF_AbortWhile(const TF_WhileParams* params);
+TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params);
 
 // Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s,
 // i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
@@ -994,8 +996,9 @@ void TF_AbortWhile(const TF_WhileParams* params);
 // supports. See
 // https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md
 // for instructions on how to add C++ more gradients.
-void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
-TF_Output* dx, TF_Status* status, TF_Output* dy);
+TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
+TF_Output* x, int nx, TF_Output* dx,
+TF_Status* status, TF_Output* dy);
 
 // TODO(josh11b): Register OpDef, available to all operations added
 // to this graph.
@@ -1032,7 +1035,7 @@ TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph,
 //
 // If successful, populates `graph` with the contents of the Graph and
 // `meta_graph_def` with the MetaGraphDef of the loaded model.
-TF_Session* TF_LoadSessionFromSavedModel(
+TF_CAPI_EXPORT extern TF_Session* TF_LoadSessionFromSavedModel(
 const TF_SessionOptions* session_options, const TF_Buffer* run_options,
 const char* export_dir, const char* const* tags, int tags_len,
 TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status);
@@ -27,10 +27,17 @@ allprojects {
 }
 
 dependencies {
-compile 'org.tensorflow:tensorflow-android:1.2.0-preview'
+compile 'org.tensorflow:tensorflow-android:+'
 }
 ```
 
+This will tell Gradle to use the
+[latest version](https://bintray.com/google/tensorflow/tensorflow-android/_latestVersion)
+of the TensorFlow AAR that has been released to
+[https://bintray.com/google/tensorflow/tensorflow-android](https://bintray.com/google/tensorflow/tensorflow-android).
+You may replace the `+` with an explicit version label if you wish to
+use a specific release of TensorFlow in your app.
+
 To build the libraries yourself (if, for example, you want to support custom
 TensorFlow operators), pick your preferred approach below:
 
@@ -41,11 +41,11 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variable_scope as vs
 
 __all__ = [
@@ -225,7 +225,7 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params):
 return binary_scores
 
 
-class CrfForwardRnnCell(core_rnn_cell.RNNCell):
+class CrfForwardRnnCell(rnn_cell.RNNCell):
 """Computes the alpha values in a linear-chain CRF.
 
 See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
@@ -22,7 +22,6 @@ import time
 
 from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
 from tensorflow.contrib.rnn.python.ops import core_rnn
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
 from tensorflow.contrib.rnn.python.ops import lstm_ops
 from tensorflow.python.client import session
 from tensorflow.python.framework import dtypes
@@ -31,6 +30,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
@@ -131,9 +131,9 @@ class CudnnRNNBenchmark(test.Benchmark):
 ]
 initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
 
-cell = core_rnn_cell_impl.LSTMCell(
+cell = rnn_cell.LSTMCell(
 num_units=num_units, initializer=initializer, state_is_tuple=True)
-multi_cell = core_rnn_cell_impl.MultiRNNCell(
+multi_cell = rnn_cell.MultiRNNCell(
 [cell() for _ in range(num_layers)])
 outputs, final_state = core_rnn.static_rnn(
 multi_cell, inputs, dtype=dtypes.float32)
@@ -159,7 +159,7 @@ class CudnnRNNBenchmark(test.Benchmark):
 ]
 cell = lambda: lstm_ops.LSTMBlockCell(num_units=num_units) # pylint: disable=cell-var-from-loop
 
-multi_cell = core_rnn_cell_impl.MultiRNNCell(
+multi_cell = rnn_cell.MultiRNNCell(
 [cell() for _ in range(num_layers)])
 outputs, final_state = core_rnn.static_rnn(
 multi_cell, inputs, dtype=dtypes.float32)
@@ -21,11 +21,11 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.contrib.grid_rnn.python.ops import grid_rnn_cell
-from tensorflow.contrib.rnn.python.ops import core_rnn
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -527,7 +527,7 @@ class GridRNNCellTest(test.TestCase):
 dtypes.float32, shape=(batch_size, input_size))
 ]
 
-outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
 
 self.assertEqual(len(outputs), len(inputs))
 self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
@@ -569,7 +569,7 @@ class GridRNNCellTest(test.TestCase):
 array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
 ]
 
-outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
 
 self.assertEqual(len(outputs), len(inputs))
 self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
@@ -609,7 +609,7 @@ class GridRNNCellTest(test.TestCase):
 array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
 ]
 
-outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
 
 self.assertEqual(len(outputs), len(inputs))
 self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
@@ -652,7 +652,7 @@ class GridRNNCellTest(test.TestCase):
 dtypes.float32, shape=(batch_size, input_size))
 ] + (max_length - 1) * [array_ops.zeros([batch_size, input_size])])
 
-outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
 
 self.assertEqual(len(outputs), len(inputs))
 self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
@@ -690,7 +690,7 @@ class GridRNNCellTest(test.TestCase):
 array_ops.placeholder(dtypes.float32, shape=(None, input_size))
 ]
 
-outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
 
 self.assertEqual(len(outputs), len(inputs))
 
@@ -31,7 +31,6 @@ from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_f
 from tensorflow.contrib.learn.python.learn.estimators import prediction_key
 from tensorflow.contrib.learn.python.learn.estimators import rnn_common
 from tensorflow.contrib.learn.python.learn.estimators import run_config
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -42,6 +41,7 @@ from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
@@ -107,7 +107,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
 
 def setUp(self):
 super(DynamicRnnEstimatorTest, self).setUp()
-self.rnn_cell = core_rnn_cell_impl.BasicRNNCell(self.NUM_RNN_CELL_UNITS)
+self.rnn_cell = rnn_cell.BasicRNNCell(self.NUM_RNN_CELL_UNITS)
 self.mock_target_column = MockTargetColumn(
 num_label_columns=self.NUM_LABEL_COLUMNS)
 
@@ -312,19 +312,19 @@ class DynamicRnnEstimatorTest(test.TestCase):
 # A MultiRNNCell of LSTMCells is both a common choice and an interesting
 # test case, because it has two levels of nesting, with an inner class that
 # is not a plain tuple.
-cell = core_rnn_cell_impl.MultiRNNCell(
-[core_rnn_cell_impl.LSTMCell(i) for i in cell_sizes])
+cell = rnn_cell.MultiRNNCell(
+[rnn_cell.LSTMCell(i) for i in cell_sizes])
 state_dict = {
 dynamic_rnn_estimator._get_state_name(i):
 array_ops.expand_dims(math_ops.range(cell_size), 0)
 for i, cell_size in enumerate([5, 5, 3, 3, 7, 7])
 }
-expected_state = (core_rnn_cell_impl.LSTMStateTuple(
+expected_state = (rnn_cell.LSTMStateTuple(
 np.reshape(np.arange(5), [1, -1]), np.reshape(np.arange(5), [1, -1])),
-core_rnn_cell_impl.LSTMStateTuple(
+rnn_cell.LSTMStateTuple(
 np.reshape(np.arange(3), [1, -1]),
 np.reshape(np.arange(3), [1, -1])),
-core_rnn_cell_impl.LSTMStateTuple(
+rnn_cell.LSTMStateTuple(
 np.reshape(np.arange(7), [1, -1]),
 np.reshape(np.arange(7), [1, -1])))
 actual_state = dynamic_rnn_estimator.dict_to_state_tuple(state_dict, cell)
@@ -26,13 +26,13 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
 from tensorflow.contrib.learn.python.learn.estimators import estimator
 from tensorflow.contrib.learn.python.learn.estimators import model_fn
 from tensorflow.contrib.learn.python.learn.estimators import rnn_common
-from tensorflow.contrib.rnn.python.ops import core_rnn
 from tensorflow.contrib.training.python.training import sequence_queueing_state_saver as sqss
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import rnn
 from tensorflow.python.training import momentum as momentum_opt
 from tensorflow.python.util import nest
 
@@ -64,7 +64,7 @@ def construct_state_saving_rnn(cell,
 final_state: The final state output by the RNN
 """
 with ops.name_scope(scope):
-rnn_outputs, final_state = core_rnn.static_state_saving_rnn(
+rnn_outputs, final_state = rnn.static_state_saving_rnn(
 cell=cell,
 inputs=inputs,
 state_saver=state_saver,
@@ -21,9 +21,9 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.contrib.learn.python.learn import ops
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.platform import test
 
 
@@ -82,7 +82,7 @@ class Seq2SeqOpsTest(test.TestCase):
 array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
 ]
 encoding = array_ops.placeholder(dtypes.float32, [2, 2])
-cell = core_rnn_cell_impl.GRUCell(2)
+cell = rnn_cell.GRUCell(2)
 outputs, states, sampling_outputs, sampling_states = (
 ops.rnn_decoder(decoder_inputs, encoding, cell))
 self.assertEqual(len(outputs), 3)
@@ -25,8 +25,7 @@ import random
 import numpy as np
 
 from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib
-from tensorflow.contrib.rnn.python.ops import core_rnn
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
+from tensorflow.contrib.rnn.python.ops import core_rnn_cell
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -37,6 +36,7 @@ from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
@@ -51,11 +51,10 @@ class Seq2SeqTest(test.TestCase):
 with variable_scope.variable_scope(
 "root", initializer=init_ops.constant_initializer(0.5)):
 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
-_, enc_state = core_rnn.static_rnn(
-core_rnn_cell_impl.GRUCell(2), inp, dtype=dtypes.float32)
+_, enc_state = rnn.static_rnn(
+rnn_cell.GRUCell(2), inp, dtype=dtypes.float32)
 dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
-cell = core_rnn_cell_impl.OutputProjectionWrapper(
-core_rnn_cell_impl.GRUCell(2), 4)
+cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
 dec, mem = seq2seq_lib.rnn_decoder(dec_inp, enc_state, cell)
 sess.run([variables.global_variables_initializer()])
 res = sess.run(dec)
@@ -71,8 +70,7 @@ class Seq2SeqTest(test.TestCase):
 "root", initializer=init_ops.constant_initializer(0.5)):
 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
 dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
-cell = core_rnn_cell_impl.OutputProjectionWrapper(
-core_rnn_cell_impl.GRUCell(2), 4)
+cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
 dec, mem = seq2seq_lib.basic_rnn_seq2seq(inp, dec_inp, cell)
 sess.run([variables.global_variables_initializer()])
 res = sess.run(dec)
@@ -88,8 +86,7 @@ class Seq2SeqTest(test.TestCase):
 "root", initializer=init_ops.constant_initializer(0.5)):
 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
 dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
-cell = core_rnn_cell_impl.OutputProjectionWrapper(
-core_rnn_cell_impl.GRUCell(2), 4)
+cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
 dec, mem = seq2seq_lib.tied_rnn_seq2seq(inp, dec_inp, cell)
 sess.run([variables.global_variables_initializer()])
 res = sess.run(dec)
@@ -105,9 +102,9 @@ class Seq2SeqTest(test.TestCase):
 with variable_scope.variable_scope(
 "root", initializer=init_ops.constant_initializer(0.5)):
 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
-cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
+cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
 cell = cell_fn()
-_, enc_state = core_rnn.static_rnn(cell, inp, dtype=dtypes.float32)
+_, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
 dec_inp = [
 constant_op.constant(
 i, dtypes.int32, shape=[2]) for i in range(3)
@@ -138,7 +135,7 @@ class Seq2SeqTest(test.TestCase):
 constant_op.constant(
 i, dtypes.int32, shape=[2]) for i in range(3)
 ]
-cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
+cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
 cell = cell_fn()
 dec, mem = seq2seq_lib.embedding_rnn_seq2seq(
 enc_inp,
@@ -158,7 +155,7 @@ class Seq2SeqTest(test.TestCase):
 
 # Test with state_is_tuple=False.
 with variable_scope.variable_scope("no_tuple"):
-cell_nt = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+cell_nt = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
 dec, mem = seq2seq_lib.embedding_rnn_seq2seq(
 enc_inp,
 dec_inp,
@@ -242,9 +239,7 @@ class Seq2SeqTest(test.TestCase):
 constant_op.constant(
 i, dtypes.int32, shape=[2]) for i in range(3)
 ]
-cell = functools.partial(
-core_rnn_cell_impl.BasicLSTMCell,
-2, state_is_tuple=True)
+cell = functools.partial(rnn_cell.BasicLSTMCell, 2, state_is_tuple=True)
 dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq(
 enc_inp, dec_inp, cell(), num_symbols=5, embedding_size=2)
 sess.run([variables.global_variables_initializer()])
@@ -324,11 +319,10 @@ class Seq2SeqTest(test.TestCase):
 with self.test_session() as sess:
 with variable_scope.variable_scope(
 "root", initializer=init_ops.constant_initializer(0.5)):
-cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+cell_fn = lambda: rnn_cell.GRUCell(2)
 cell = cell_fn()
 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
-enc_outputs, enc_state = core_rnn.static_rnn(
-cell, inp, dtype=dtypes.float32)
+enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
 attn_states = array_ops.concat([
 array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
 ], 1)
@@ -350,11 +344,10 @@ class Seq2SeqTest(test.TestCase):
 with self.test_session() as sess:
 with variable_scope.variable_scope(
 "root", initializer=init_ops.constant_initializer(0.5)):
-cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+cell_fn = lambda: rnn_cell.GRUCell(2)
 cell = cell_fn()
 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
-enc_outputs, enc_state = core_rnn.static_rnn(
-cell, inp, dtype=dtypes.float32)
+enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
 attn_states = array_ops.concat([
 array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
 ], 1)
@@ -377,7 +370,7 @@ class Seq2SeqTest(test.TestCase):
 with self.test_session() as sess:
 with variable_scope.variable_scope(
 "root", initializer=init_ops.constant_initializer(0.5)):
-cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+cell_fn = lambda: rnn_cell.GRUCell(2)
 cell = cell_fn()
 inp = constant_op.constant(0.5, shape=[2, 2, 2])
 enc_outputs, enc_state = rnn.dynamic_rnn(
@@ -401,7 +394,7 @@ class Seq2SeqTest(test.TestCase):
 with self.test_session() as sess:
 with variable_scope.variable_scope(
 "root", initializer=init_ops.constant_initializer(0.5)):
-cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+cell_fn = lambda: rnn_cell.GRUCell(2)
 cell = cell_fn()
 inp = constant_op.constant(0.5, shape=[2, 2, 2])
 enc_outputs, enc_state = rnn.dynamic_rnn(
@@ -426,14 +419,13 @@ class Seq2SeqTest(test.TestCase):
 with self.test_session() as sess:
 with variable_scope.variable_scope(
 "root", initializer=init_ops.constant_initializer(0.5)):
-single_cell = lambda: core_rnn_cell_impl.BasicLSTMCell( # pylint: disable=g-long-lambda
+single_cell = lambda: rnn_cell.BasicLSTMCell( # pylint: disable=g-long-lambda
 2, state_is_tuple=True)
-cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell( # pylint: disable=g-long-lambda
+cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda
 cells=[single_cell() for _ in range(2)], state_is_tuple=True)
 cell = cell_fn()
 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
-enc_outputs, enc_state = core_rnn.static_rnn(
-cell, inp, dtype=dtypes.float32)
+enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
 attn_states = array_ops.concat([
 array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
 ], 1)
@@ -459,12 +451,11 @@ class Seq2SeqTest(test.TestCase):
 with self.test_session() as sess:
 with variable_scope.variable_scope(
 "root", initializer=init_ops.constant_initializer(0.5)):
-cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell( # pylint: disable=g-long-lambda
-cells=[core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)])
+cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda
+cells=[rnn_cell.BasicLSTMCell(2) for _ in range(2)])
 cell = cell_fn()
 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
-enc_outputs, enc_state = core_rnn.static_rnn(
-cell, inp, dtype=dtypes.float32)
+enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
 attn_states = array_ops.concat([
 array_ops.reshape(e, [-1, 1, cell.output_size])
 for e in enc_outputs
@@ -492,10 +483,9 @@ class Seq2SeqTest(test.TestCase):
 with variable_scope.variable_scope(
 "root", initializer=init_ops.constant_initializer(0.5)):
 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
-cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+cell_fn = lambda: rnn_cell.GRUCell(2)
 cell = cell_fn()
-enc_outputs, enc_state = core_rnn.static_rnn(
-cell, inp, dtype=dtypes.float32)
+enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
 attn_states = array_ops.concat([
 array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
 ], 1)
@@ -534,7 +524,7 @@ class Seq2SeqTest(test.TestCase):
 constant_op.constant(
 i, dtypes.int32, shape=[2]) for i in range(3)
 ]
-cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
+cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
 cell = cell_fn()
 dec, mem = seq2seq_lib.embedding_attention_seq2seq(
 enc_inp,
@@ -555,8 +545,7 @@ class Seq2SeqTest(test.TestCase):
 # Test with state_is_tuple=False.
 with variable_scope.variable_scope("no_tuple"):
 cell_fn = functools.partial(
-core_rnn_cell_impl.BasicLSTMCell,
-2, state_is_tuple=False)
+rnn_cell.BasicLSTMCell, 2, state_is_tuple=False)
 cell_nt = cell_fn()
 dec, mem = seq2seq_lib.embedding_attention_seq2seq(
 enc_inp,
@@ -651,11 +640,10 @@ class Seq2SeqTest(test.TestCase):
 ]
 dec_symbols_dict = {"0": 5, "1": 6}
 def EncCellFn():
-return core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+return rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
 def DecCellsFn():
-return dict(
-(k, core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True))
-for k in dec_symbols_dict)
+return dict((k, rnn_cell.BasicLSTMCell(2, state_is_tuple=True))
+for k in dec_symbols_dict)
 outputs_dict, state_dict = (seq2seq_lib.one2many_rnn_seq2seq(
 enc_inp, dec_inp_dict, EncCellFn(), DecCellsFn(),
 2, dec_symbols_dict, embedding_size=2))
@@ -796,8 +784,8 @@ class Seq2SeqTest(test.TestCase):
 # """Example sequence-to-sequence model that uses GRU cells."""
 
 # def GRUSeq2Seq(enc_inp, dec_inp):
-# cell = core_rnn_cell_impl.MultiRNNCell(
-# [core_rnn_cell_impl.GRUCell(24) for _ in range(2)])
+# cell = rnn_cell.MultiRNNCell(
+# [rnn_cell.GRUCell(24) for _ in range(2)])
 # return seq2seq_lib.embedding_attention_seq2seq(
 # enc_inp,
 # dec_inp,
@@ -862,9 +850,8 @@ class Seq2SeqTest(test.TestCase):
 """Example sequence-to-sequence model that uses GRU cells."""
 
 def GRUSeq2Seq(enc_inp, dec_inp):
-cell = core_rnn_cell_impl.MultiRNNCell(
-[core_rnn_cell_impl.GRUCell(24) for _ in range(2)],
-state_is_tuple=True)
+cell = rnn_cell.MultiRNNCell(
+[rnn_cell.GRUCell(24) for _ in range(2)], state_is_tuple=True)
 return seq2seq_lib.embedding_attention_seq2seq(
 enc_inp,
 dec_inp,
@@ -1040,7 +1027,7 @@ class Seq2SeqTest(test.TestCase):
 self.assertAllClose(v_true.eval(), v_false.eval())
 
 def EmbeddingRNNSeq2SeqF(enc_inp, dec_inp, feed_previous):
-cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
 return seq2seq_lib.embedding_rnn_seq2seq(
 enc_inp,
 dec_inp,
@@ -1051,7 +1038,7 @@ class Seq2SeqTest(test.TestCase):
 feed_previous=feed_previous)
 
 def EmbeddingRNNSeq2SeqNoTupleF(enc_inp, dec_inp, feed_previous):
-cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
 return seq2seq_lib.embedding_rnn_seq2seq(
 enc_inp,
 dec_inp,
@@ -1062,7 +1049,7 @@ class Seq2SeqTest(test.TestCase):
 feed_previous=feed_previous)
 
 def EmbeddingTiedRNNSeq2Seq(enc_inp, dec_inp, feed_previous):
-cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
 return seq2seq_lib.embedding_tied_rnn_seq2seq(
 enc_inp,
 dec_inp,
@@ -1072,7 +1059,7 @@ class Seq2SeqTest(test.TestCase):
 feed_previous=feed_previous)
 
 def EmbeddingTiedRNNSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
-cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
 return seq2seq_lib.embedding_tied_rnn_seq2seq(
 enc_inp,
 dec_inp,
@@ -1082,7 +1069,7 @@ class Seq2SeqTest(test.TestCase):
 feed_previous=feed_previous)
 
 def EmbeddingAttentionSeq2Seq(enc_inp, dec_inp, feed_previous):
-cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
 return seq2seq_lib.embedding_attention_seq2seq(
 enc_inp,
 dec_inp,
@@ -1093,7 +1080,7 @@ class Seq2SeqTest(test.TestCase):
 feed_previous=feed_previous)
 
 def EmbeddingAttentionSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
-cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
 return seq2seq_lib.embedding_attention_seq2seq(
 enc_inp,
 dec_inp,
@@ -62,9 +62,7 @@ import copy
 from six.moves import xrange  # pylint: disable=redefined-builtin
 from six.moves import zip  # pylint: disable=redefined-builtin
 
-from tensorflow.contrib.rnn.python.ops import core_rnn
 from tensorflow.contrib.rnn.python.ops import core_rnn_cell
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -72,11 +70,13 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.util import nest
 
 # TODO(ebrevdo): Remove once _linear is fully deprecated.
-linear = core_rnn_cell_impl._linear # pylint: disable=protected-access
+linear = rnn_cell_impl._linear # pylint: disable=protected-access
 
 
 def _extract_argmax_and_embed(embedding,
@@ -119,7 +119,7 @@ def rnn_decoder(decoder_inputs,
 Args:
 decoder_inputs: A list of 2D Tensors [batch_size x input_size].
 initial_state: 2D Tensor with shape [batch_size x cell.state_size].
-cell: core_rnn_cell.RNNCell defining the cell function and size.
+cell: rnn_cell.RNNCell defining the cell function and size.
 loop_function: If not None, this function will be applied to the i-th output
 in order to generate the i+1-st input, and decoder_inputs will be ignored,
 except for the first element ("GO" symbol). This can be used for decoding,
@@ -170,7 +170,7 @@ def basic_rnn_seq2seq(encoder_inputs,
 Args:
 encoder_inputs: A list of 2D Tensors [batch_size x input_size].
 decoder_inputs: A list of 2D Tensors [batch_size x input_size].
-cell: core_rnn_cell.RNNCell defining the cell function and size.
+cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
 dtype: The dtype of the initial state of the RNN cell (default: tf.float32).
 scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq".
 
@@ -183,7 +183,7 @@ def basic_rnn_seq2seq(encoder_inputs,
 """
 with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
 enc_cell = copy.deepcopy(cell)
-_, enc_state = core_rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
+_, enc_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
 return rnn_decoder(decoder_inputs, enc_state, cell)
 
 
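
For orientation, a minimal usage sketch of `basic_rnn_seq2seq` with a core cell (illustrative only, not part of the diff; the shapes, sizes, and placeholders are assumptions):

```python
import tensorflow as tf
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib

enc_inp = [tf.placeholder(tf.float32, [None, 8]) for _ in range(4)]
dec_inp = [tf.placeholder(tf.float32, [None, 8]) for _ in range(4)]
cell = tf.nn.rnn_cell.GRUCell(16)

# The encoder's final state initializes the decoder; the call returns the
# per-step decoder outputs and the final decoder state.
dec_outputs, dec_state = seq2seq_lib.basic_rnn_seq2seq(enc_inp, dec_inp, cell)
```
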
@@ -202,7 +202,7 @@ def tied_rnn_seq2seq(encoder_inputs,
 Args:
 encoder_inputs: A list of 2D Tensors [batch_size x input_size].
 decoder_inputs: A list of 2D Tensors [batch_size x input_size].
-cell: core_rnn_cell.RNNCell defining the cell function and size.
+cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
 loop_function: If not None, this function will be applied to i-th output
 in order to generate i+1-th input, and decoder_inputs will be ignored,
 except for the first element ("GO" symbol), see rnn_decoder for details.
@@ -219,7 +219,7 @@ def tied_rnn_seq2seq(encoder_inputs,
 """
 with variable_scope.variable_scope("combined_tied_rnn_seq2seq"):
 scope = scope or "tied_rnn_seq2seq"
-_, enc_state = core_rnn.static_rnn(
+_, enc_state = rnn.static_rnn(
 cell, encoder_inputs, dtype=dtype, scope=scope)
 variable_scope.get_variable_scope().reuse_variables()
 return rnn_decoder(
@@ -244,7 +244,7 @@ def embedding_rnn_decoder(decoder_inputs,
 Args:
 decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
 initial_state: 2D Tensor [batch_size x cell.state_size].
-cell: core_rnn_cell.RNNCell defining the cell function.
+cell: tf.nn.rnn_cell.RNNCell defining the cell function.
 num_symbols: Integer, how many symbols come into the embedding.
 embedding_size: Integer, the length of the embedding vector for each symbol.
 output_projection: None or a pair (W, B) of output projection weights and
@@ -320,7 +320,7 @@ def embedding_rnn_seq2seq(encoder_inputs,
 Args:
 encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
 decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
-cell: core_rnn_cell.RNNCell defining the cell function and size.
+cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
 num_encoder_symbols: Integer; number of symbols on the encoder side.
 num_decoder_symbols: Integer; number of symbols on the decoder side.
 embedding_size: Integer, the length of the embedding vector for each symbol.
@@ -360,8 +360,7 @@ def embedding_rnn_seq2seq(encoder_inputs,
 encoder_cell,
 embedding_classes=num_encoder_symbols,
 embedding_size=embedding_size)
-_, encoder_state = core_rnn.static_rnn(
-encoder_cell, encoder_inputs, dtype=dtype)
+_, encoder_state = rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)
 
 # Decoder.
 if output_projection is None:
@@ -431,7 +430,7 @@ def embedding_tied_rnn_seq2seq(encoder_inputs,
 Args:
 encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
 decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
-cell: core_rnn_cell.RNNCell defining the cell function and size.
+cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
 num_symbols: Integer; number of symbols for both encoder and decoder.
 embedding_size: Integer, the length of the embedding vector for each symbol.
 num_decoder_symbols: Integer; number of output symbols for decoder. If
@@ -560,7 +559,7 @@ def attention_decoder(decoder_inputs,
 decoder_inputs: A list of 2D Tensors [batch_size x input_size].
 initial_state: 2D Tensor [batch_size x cell.state_size].
 attention_states: 3D Tensor [batch_size x attn_length x attn_size].
-cell: core_rnn_cell.RNNCell defining the cell function and size.
+cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
 output_size: Size of the output vectors; if None, we use cell.output_size.
 num_heads: Number of attention heads that read from attention_states.
 loop_function: If not None, this function will be applied to i-th output
@@ -720,7 +719,7 @@ def embedding_attention_decoder(decoder_inputs,
 decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
 initial_state: 2D Tensor [batch_size x cell.state_size].
 attention_states: 3D Tensor [batch_size x attn_length x attn_size].
-cell: core_rnn_cell.RNNCell defining the cell function.
+cell: tf.nn.rnn_cell.RNNCell defining the cell function.
 num_symbols: Integer, how many symbols come into the embedding.
 embedding_size: Integer, the length of the embedding vector for each symbol.
 num_heads: Number of attention heads that read from attention_states.
@@ -814,7 +813,7 @@ def embedding_attention_seq2seq(encoder_inputs,
 Args:
 encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
 decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
-cell: core_rnn_cell.RNNCell defining the cell function and size.
+cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
 num_encoder_symbols: Integer; number of symbols on the encoder side.
 num_decoder_symbols: Integer; number of symbols on the decoder side.
 embedding_size: Integer, the length of the embedding vector for each symbol.
@ -851,7 +850,7 @@ def embedding_attention_seq2seq(encoder_inputs,
|
|||||||
encoder_cell,
|
encoder_cell,
|
||||||
embedding_classes=num_encoder_symbols,
|
embedding_classes=num_encoder_symbols,
|
||||||
embedding_size=embedding_size)
|
embedding_size=embedding_size)
|
||||||
encoder_outputs, encoder_state = core_rnn.static_rnn(
|
encoder_outputs, encoder_state = rnn.static_rnn(
|
||||||
encoder_cell, encoder_inputs, dtype=dtype)
|
encoder_cell, encoder_inputs, dtype=dtype)
|
||||||
|
|
||||||
# First calculate a concatenation of encoder outputs to put attention on.
|
# First calculate a concatenation of encoder outputs to put attention on.
|
||||||
@ -937,9 +936,10 @@ def one2many_rnn_seq2seq(encoder_inputs,
|
|||||||
the corresponding decoder_inputs; each decoder_inputs is a list of 1D
|
the corresponding decoder_inputs; each decoder_inputs is a list of 1D
|
||||||
Tensors of shape [batch_size]; num_decoders is defined as
|
Tensors of shape [batch_size]; num_decoders is defined as
|
||||||
len(decoder_inputs_dict).
|
len(decoder_inputs_dict).
|
||||||
enc_cell: core_rnn_cell.RNNCell defining the encoder cell function and size.
|
enc_cell: tf.nn.rnn_cell.RNNCell defining the encoder cell function and
|
||||||
|
size.
|
||||||
dec_cells_dict: A dictionary mapping encoder name (string) to an
|
dec_cells_dict: A dictionary mapping encoder name (string) to an
|
||||||
instance of core_rnn_cell.RNNCell.
|
instance of tf.nn.rnn_cell.RNNCell.
|
||||||
num_encoder_symbols: Integer; number of symbols on the encoder side.
|
num_encoder_symbols: Integer; number of symbols on the encoder side.
|
||||||
num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an
|
num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an
|
||||||
integer specifying number of symbols for the corresponding decoder;
|
integer specifying number of symbols for the corresponding decoder;
|
||||||
@ -971,12 +971,12 @@ def one2many_rnn_seq2seq(encoder_inputs,
|
|||||||
outputs_dict = {}
|
outputs_dict = {}
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
|
|
||||||
if not isinstance(enc_cell, core_rnn_cell.RNNCell):
|
if not isinstance(enc_cell, rnn_cell_impl.RNNCell):
|
||||||
raise TypeError("enc_cell is not an RNNCell: %s" % type(enc_cell))
|
raise TypeError("enc_cell is not an RNNCell: %s" % type(enc_cell))
|
||||||
if set(dec_cells_dict) != set(decoder_inputs_dict):
|
if set(dec_cells_dict) != set(decoder_inputs_dict):
|
||||||
raise ValueError("keys of dec_cells_dict != keys of decodre_inputs_dict")
|
raise ValueError("keys of dec_cells_dict != keys of decodre_inputs_dict")
|
||||||
for dec_cell in dec_cells_dict.values():
|
for dec_cell in dec_cells_dict.values():
|
||||||
if not isinstance(dec_cell, core_rnn_cell.RNNCell):
|
if not isinstance(dec_cell, rnn_cell_impl.RNNCell):
|
||||||
raise TypeError("dec_cell is not an RNNCell: %s" % type(dec_cell))
|
raise TypeError("dec_cell is not an RNNCell: %s" % type(dec_cell))
|
||||||
|
|
||||||
with variable_scope.variable_scope(
|
with variable_scope.variable_scope(
|
||||||
@ -988,8 +988,7 @@ def one2many_rnn_seq2seq(encoder_inputs,
|
|||||||
enc_cell,
|
enc_cell,
|
||||||
embedding_classes=num_encoder_symbols,
|
embedding_classes=num_encoder_symbols,
|
||||||
embedding_size=embedding_size)
|
embedding_size=embedding_size)
|
||||||
_, encoder_state = core_rnn.static_rnn(
|
_, encoder_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
|
||||||
enc_cell, encoder_inputs, dtype=dtype)
|
|
||||||
|
|
||||||
# Decoder.
|
# Decoder.
|
||||||
for name, decoder_inputs in decoder_inputs_dict.items():
|
for name, decoder_inputs in decoder_inputs_dict.items():
|
||||||
@ -1153,7 +1152,7 @@ def model_with_buckets(encoder_inputs,
|
|||||||
|
|
||||||
The seq2seq argument is a function that defines a sequence-to-sequence model,
|
The seq2seq argument is a function that defines a sequence-to-sequence model,
|
||||||
e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(
|
e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(
|
||||||
x, y, core_rnn_cell.GRUCell(24))
|
x, y, rnn_cell.GRUCell(24))
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.
|
encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.
|
||||||
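
For orientation, a minimal usage sketch of the API these hunks migrate to. It is not part of the commit; it assumes a TensorFlow 1.2-era install where `tf.nn.rnn_cell` and `tf.contrib.legacy_seq2seq` are available, and all sizes are made up for illustration.

```python
# Illustrative only; assumes TF 1.2-era tf.nn.rnn_cell and
# tf.contrib.legacy_seq2seq. Sizes below are hypothetical.
import tensorflow as tf

batch_size, steps, vocab = 4, 5, 10
encoder_inputs = [tf.placeholder(tf.int32, [batch_size]) for _ in range(steps)]
decoder_inputs = [tf.placeholder(tf.int32, [batch_size]) for _ in range(steps)]

# The cell argument is now a core tf.nn.rnn_cell.RNNCell.
cell = tf.nn.rnn_cell.GRUCell(24)

outputs, state = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
    encoder_inputs, decoder_inputs, cell,
    num_encoder_symbols=vocab,
    num_decoder_symbols=vocab,
    embedding_size=8)
```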
@@ -20,13 +20,13 @@ from __future__ import print_function

 from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.contrib.framework.python.ops import variables
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
 from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variable_scope


@@ -52,7 +52,7 @@ def ndlstm_base_unrolled(inputs, noutput, scope=None, reverse=False):
   """
   with variable_scope.variable_scope(scope, "SeqLstmUnrolled", [inputs]):
     length, batch_size, _ = _shape(inputs)
-    lstm_cell = core_rnn_cell_impl.BasicLSTMCell(noutput, state_is_tuple=False)
+    lstm_cell = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False)
     state = array_ops.zeros([batch_size, lstm_cell.state_size])
     output_u = []
     inputs_u = array_ops.unstack(inputs)
@@ -88,7 +88,7 @@ def ndlstm_base_dynamic(inputs, noutput, scope=None, reverse=False):
     # TODO(tmb) make batch size, sequence_length dynamic
     # example: sequence_length = tf.shape(inputs)[0]
     _, batch_size, _ = _shape(inputs)
-    lstm_cell = core_rnn_cell_impl.BasicLSTMCell(noutput, state_is_tuple=False)
+    lstm_cell = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False)
     state = array_ops.zeros([batch_size, lstm_cell.state_size])
     sequence_length = int(inputs.get_shape()[0])
     sequence_lengths = math_ops.to_int64(
@@ -145,7 +145,7 @@ def sequence_to_final(inputs, noutput, scope=None, name=None, reverse=False):
   """
   with variable_scope.variable_scope(scope, "SequenceToFinal", [inputs]):
     length, batch_size, _ = _shape(inputs)
-    lstm = core_rnn_cell_impl.BasicLSTMCell(noutput, state_is_tuple=False)
+    lstm = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False)
     state = array_ops.zeros([batch_size, lstm.state_size])
     inputs_u = array_ops.unstack(inputs)
     if reverse:
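
The three hunks above all swap the cell class but keep the manual unrolling pattern. A small sketch of that pattern against the public API (hypothetical shapes, TF 1.2-era `tf.nn.rnn_cell` assumed):

```python
# Sketch of the unrolled-cell pattern; hypothetical shapes, TF 1.2-era API.
import tensorflow as tf

length, batch_size, depth, noutput = 6, 2, 3, 4
inputs = tf.placeholder(tf.float32, [length, batch_size, depth])

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False)
state = tf.zeros([batch_size, lstm_cell.state_size])  # state_size == 2 * noutput here

outputs = []
for x in tf.unstack(inputs):        # one tensor per time step
  out, state = lstm_cell(x, state)  # the cell reuses its variables after the first call
  outputs.append(out)
```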
@@ -16,21 +16,26 @@

 See @{$python/contrib.rnn} guide.

+# From core
 @@RNNCell
 @@BasicRNNCell
 @@BasicLSTMCell
 @@GRUCell
 @@LSTMCell
-@@LayerNormBasicLSTMCell
 @@LSTMStateTuple
-@@MultiRNNCell
-@@LSTMBlockWrapper
 @@DropoutWrapper
+@@MultiRNNCell
+@@DeviceWrapper
+@@ResidualWrapper

+# Used to be in core, but kept in contrib.
 @@EmbeddingWrapper
 @@InputProjectionWrapper
 @@OutputProjectionWrapper
-@@DeviceWrapper
-@@ResidualWrapper
+# Created in contrib, eventual plans to move to core.
+@@LayerNormBasicLSTMCell
+@@LSTMBlockWrapper
 @@LSTMBlockCell
 @@GRUBlockCell
 @@FusedRNNCell
@@ -48,9 +53,11 @@ See @{$python/contrib.rnn} guide.
 @@HighwayWrapper
 @@GLSTMCell

-### RNNCell wrappers
+# RNNCell wrappers
 @@AttentionCellWrapper
 @@CompiledWrapper

+# RNN functions
 @@static_rnn
 @@static_state_saving_rnn
 @@static_bidirectional_rnn
@@ -62,31 +69,23 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.contrib.rnn.python.ops.core_rnn import static_bidirectional_rnn
-from tensorflow.contrib.rnn.python.ops.core_rnn import static_rnn
-from tensorflow.contrib.rnn.python.ops.core_rnn import static_state_saving_rnn

-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicLSTMCell
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicRNNCell
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import DeviceWrapper
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import DropoutWrapper
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import EmbeddingWrapper
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import GRUCell
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import InputProjectionWrapper
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMCell
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import MultiRNNCell
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import OutputProjectionWrapper
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import ResidualWrapper
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import RNNCell

 # pylint: disable=unused-import,wildcard-import,line-too-long
+from tensorflow.contrib.rnn.python.ops.core_rnn_cell import EmbeddingWrapper
+from tensorflow.contrib.rnn.python.ops.core_rnn_cell import InputProjectionWrapper
+from tensorflow.contrib.rnn.python.ops.core_rnn_cell import OutputProjectionWrapper

 from tensorflow.contrib.rnn.python.ops.fused_rnn_cell import *
 from tensorflow.contrib.rnn.python.ops.gru_ops import *
 from tensorflow.contrib.rnn.python.ops.lstm_ops import *
 from tensorflow.contrib.rnn.python.ops.rnn import *
 from tensorflow.contrib.rnn.python.ops.rnn_cell import *

+from tensorflow.python.ops.rnn import static_bidirectional_rnn
+from tensorflow.python.ops.rnn import static_rnn
+from tensorflow.python.ops.rnn import static_state_saving_rnn
+
+from tensorflow.python.ops.rnn_cell import *
 # pylint: enable=unused-import,wildcard-import,line-too-long

 from tensorflow.python.util.all_util import remove_undocumented
-remove_undocumented(__name__, ['core_rnn_cell'])
+remove_undocumented(__name__)
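
Since `tf.contrib.rnn` now re-exports the core implementations (via the wildcard import of `tensorflow.python.ops.rnn_cell` above), both spellings should resolve to the same classes. A quick check, written as a sketch against a TF 1.2-era install:

```python
# Sketch; assumes a TF 1.2-era install. Both namespaces should expose the
# same underlying classes after this reorganization.
import tensorflow as tf

print(tf.contrib.rnn.LSTMCell is tf.nn.rnn_cell.LSTMCell)          # expected: True
print(tf.contrib.rnn.BasicRNNCell is tf.nn.rnn_cell.BasicRNNCell)  # expected: True

# Wrappers kept in contrib (e.g. EmbeddingWrapper) stay under tf.contrib.rnn.
cell = tf.contrib.rnn.EmbeddingWrapper(
    tf.nn.rnn_cell.GRUCell(2), embedding_classes=3, embedding_size=2)
```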
@@ -25,8 +25,7 @@ import numpy as np
 # TODO(ebrevdo): Remove once _linear is fully deprecated.
 # pylint: disable=protected-access

-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _linear as linear
+from tensorflow.contrib import rnn as contrib_rnn
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -37,12 +36,14 @@ from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.platform import test


 # pylint: enable=protected-access
+linear = rnn_cell_impl._linear


 class RNNCellTest(test.TestCase):
@@ -74,14 +75,12 @@ class RNNCellTest(test.TestCase):
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
         m = array_ops.zeros([1, 2])
-        cell = core_rnn_cell_impl.BasicRNNCell(2)
+        cell = rnn_cell_impl.BasicRNNCell(2)
         g, _ = cell(x, m)
-        self.assertEqual(
-            ["root/basic_rnn_cell/%s:0"
-             % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
-             "root/basic_rnn_cell/%s:0"
-             % core_rnn_cell_impl._BIAS_VARIABLE_NAME],
-            [v.name for v in cell.trainable_variables])
+        self.assertEqual([
+            "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+            "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
+        ], [v.name for v in cell.trainable_variables])
         self.assertFalse(cell.non_trainable_variables)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(
@@ -100,15 +99,13 @@ class RNNCellTest(test.TestCase):
           custom_getter=not_trainable_getter):
         x = array_ops.zeros([1, 2])
         m = array_ops.zeros([1, 2])
-        cell = core_rnn_cell_impl.BasicRNNCell(2)
+        cell = rnn_cell_impl.BasicRNNCell(2)
         g, _ = cell(x, m)
         self.assertFalse(cell.trainable_variables)
-        self.assertEqual(
-            ["root/basic_rnn_cell/%s:0"
-             % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
-             "root/basic_rnn_cell/%s:0"
-             % core_rnn_cell_impl._BIAS_VARIABLE_NAME],
-            [v.name for v in cell.non_trainable_variables])
+        self.assertEqual([
+            "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+            "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
+        ], [v.name for v in cell.non_trainable_variables])
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(
             [g], {x.name: np.array([[1., 1.]]),
@@ -121,7 +118,7 @@ class RNNCellTest(test.TestCase):
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
         m = array_ops.zeros([1, 2])
-        g, _ = core_rnn_cell_impl.GRUCell(2)(x, m)
+        g, _ = rnn_cell_impl.GRUCell(2)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(
             [g], {x.name: np.array([[1., 1.]]),
@@ -133,7 +130,7 @@ class RNNCellTest(test.TestCase):
         x = array_ops.zeros(
             [1, 3])  # Test GRUCell with input_size != num_units.
         m = array_ops.zeros([1, 2])
-        g, _ = core_rnn_cell_impl.GRUCell(2)(x, m)
+        g, _ = rnn_cell_impl.GRUCell(2)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(
             [g],
@@ -148,20 +145,23 @@ class RNNCellTest(test.TestCase):
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
         m = array_ops.zeros([1, 8])
-        cell = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.BasicLSTMCell(
-                2, state_is_tuple=False) for _ in range(2)],
+        cell = rnn_cell_impl.MultiRNNCell(
+            [
+                rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+                for _ in range(2)
+            ],
             state_is_tuple=False)
         g, out_m = cell(x, m)
         expected_variable_names = [
-            "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0"
-            % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
-            "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0"
-            % core_rnn_cell_impl._BIAS_VARIABLE_NAME,
-            "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0"
-            % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
-            "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0"
-            % core_rnn_cell_impl._BIAS_VARIABLE_NAME]
+            "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" %
+            rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+            "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" %
+            rnn_cell_impl._BIAS_VARIABLE_NAME,
+            "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
+            rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+            "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
+            rnn_cell_impl._BIAS_VARIABLE_NAME
+        ]
         self.assertEqual(
             expected_variable_names, [v.name for v in cell.trainable_variables])
         self.assertFalse(cell.non_trainable_variables)
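
The `expected_variable_names` assertions above pin down the per-cell variable names. A small sketch (hypothetical scope name, TF 1.2-era public API assumed) of how to inspect those names for any cell:

```python
# Sketch; hypothetical scope name, TF 1.2-era public API assumed.
import tensorflow as tf

with tf.variable_scope("root"):
  cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
  x = tf.zeros([1, 2])
  m = tf.zeros([1, 4])          # 2 * num_units because state_is_tuple=False
  cell(x, m)                    # first call creates the variables

for v in cell.trainable_variables:
  print(v.name)  # e.g. root/basic_lstm_cell/kernel:0 and root/basic_lstm_cell/bias:0
```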
@@ -185,8 +185,7 @@ class RNNCellTest(test.TestCase):
         x = array_ops.zeros(
             [1, 3])  # Test BasicLSTMCell with input_size != num_units.
         m = array_ops.zeros([1, 4])
-        g, out_m = core_rnn_cell_impl.BasicLSTMCell(
-            2, state_is_tuple=False)(x, m)
+        g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(
             [g, out_m],
@@ -206,7 +205,7 @@ class RNNCellTest(test.TestCase):
         x = array_ops.zeros([batch_size, input_size])
         m = array_ops.zeros([batch_size - 1, state_size])
         with self.assertRaises(ValueError):
-          g, out_m = core_rnn_cell_impl.BasicLSTMCell(
+          g, out_m = rnn_cell_impl.BasicLSTMCell(
               num_units, state_is_tuple=False)(x, m)
           sess.run([variables_lib.global_variables_initializer()])
           sess.run([g, out_m],
@@ -225,7 +224,7 @@ class RNNCellTest(test.TestCase):
         x = array_ops.zeros([batch_size, input_size])
         m = array_ops.zeros([batch_size, state_size])
         with self.assertRaises(ValueError):
-          g, out_m = core_rnn_cell_impl.BasicLSTMCell(
+          g, out_m = rnn_cell_impl.BasicLSTMCell(
               num_units, state_is_tuple=False)(x, m)
           sess.run([variables_lib.global_variables_initializer()])
           sess.run([g, out_m],
@@ -239,31 +238,29 @@ class RNNCellTest(test.TestCase):
         x = array_ops.zeros([1, 2])
         m0 = (array_ops.zeros([1, 2]),) * 2
         m1 = (array_ops.zeros([1, 2]),) * 2
-        cell = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)],
+        cell = rnn_cell_impl.MultiRNNCell(
+            [rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)],
             state_is_tuple=True)
         self.assertTrue(isinstance(cell.state_size, tuple))
         self.assertTrue(
-            isinstance(cell.state_size[0], core_rnn_cell_impl.LSTMStateTuple))
+            isinstance(cell.state_size[0], rnn_cell_impl.LSTMStateTuple))
         self.assertTrue(
-            isinstance(cell.state_size[1], core_rnn_cell_impl.LSTMStateTuple))
+            isinstance(cell.state_size[1], rnn_cell_impl.LSTMStateTuple))

         # Pass in regular tuples
         _, (out_m0, out_m1) = cell(x, (m0, m1))
-        self.assertTrue(isinstance(out_m0, core_rnn_cell_impl.LSTMStateTuple))
-        self.assertTrue(isinstance(out_m1, core_rnn_cell_impl.LSTMStateTuple))
+        self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple))
+        self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple))

         # Pass in LSTMStateTuples
         variable_scope.get_variable_scope().reuse_variables()
         zero_state = cell.zero_state(1, dtypes.float32)
         self.assertTrue(isinstance(zero_state, tuple))
-        self.assertTrue(
-            isinstance(zero_state[0], core_rnn_cell_impl.LSTMStateTuple))
-        self.assertTrue(
-            isinstance(zero_state[1], core_rnn_cell_impl.LSTMStateTuple))
+        self.assertTrue(isinstance(zero_state[0], rnn_cell_impl.LSTMStateTuple))
+        self.assertTrue(isinstance(zero_state[1], rnn_cell_impl.LSTMStateTuple))
         _, (out_m0, out_m1) = cell(x, zero_state)
-        self.assertTrue(isinstance(out_m0, core_rnn_cell_impl.LSTMStateTuple))
-        self.assertTrue(isinstance(out_m1, core_rnn_cell_impl.LSTMStateTuple))
+        self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple))
+        self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple))

   def testBasicLSTMCellWithStateTuple(self):
     with self.test_session() as sess:
@@ -272,9 +269,11 @@ class RNNCellTest(test.TestCase):
         x = array_ops.zeros([1, 2])
         m0 = array_ops.zeros([1, 4])
         m1 = array_ops.zeros([1, 4])
-        cell = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.BasicLSTMCell(
-                2, state_is_tuple=False) for _ in range(2)],
+        cell = rnn_cell_impl.MultiRNNCell(
+            [
+                rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+                for _ in range(2)
+            ],
             state_is_tuple=True)
         g, (out_m0, out_m1) = cell(x, (m0, m1))
         sess.run([variables_lib.global_variables_initializer()])
@@ -306,7 +305,7 @@ class RNNCellTest(test.TestCase):
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([batch_size, input_size])
         m = array_ops.zeros([batch_size, state_size])
-        cell = core_rnn_cell_impl.LSTMCell(
+        cell = rnn_cell_impl.LSTMCell(
             num_units=num_units,
             num_proj=num_proj,
             forget_bias=1.0,
@@ -340,7 +339,7 @@ class RNNCellTest(test.TestCase):
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([batch_size, input_size])
         m = array_ops.zeros([batch_size, state_size])
-        cell = core_rnn_cell_impl.LSTMCell(
+        cell = rnn_cell_impl.LSTMCell(
             num_units=num_units,
             num_proj=num_proj,
             forget_bias=1.0,
@@ -358,8 +357,7 @@ class RNNCellTest(test.TestCase):
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
         m = array_ops.zeros([1, 3])
-        cell = core_rnn_cell_impl.OutputProjectionWrapper(
-            core_rnn_cell_impl.GRUCell(3), 2)
+        cell = contrib_rnn.OutputProjectionWrapper(rnn_cell_impl.GRUCell(3), 2)
         g, new_m = cell(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run([g, new_m], {
@@ -376,8 +374,8 @@ class RNNCellTest(test.TestCase):
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
         m = array_ops.zeros([1, 3])
-        cell = core_rnn_cell_impl.InputProjectionWrapper(
-            core_rnn_cell_impl.GRUCell(3), num_proj=3)
+        cell = contrib_rnn.InputProjectionWrapper(
+            rnn_cell_impl.GRUCell(3), num_proj=3)
         g, new_m = cell(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(
@@ -394,10 +392,10 @@ class RNNCellTest(test.TestCase):
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
         m = array_ops.zeros([1, 3])
-        base_cell = core_rnn_cell_impl.GRUCell(3)
+        base_cell = rnn_cell_impl.GRUCell(3)
         g, m_new = base_cell(x, m)
         variable_scope.get_variable_scope().reuse_variables()
-        g_res, m_new_res = core_rnn_cell_impl.ResidualWrapper(base_cell)(x, m)
+        g_res, m_new_res = rnn_cell_impl.ResidualWrapper(base_cell)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run([g, g_res, m_new, m_new_res], {
             x: np.array([[1., 1., 1.]]),
@@ -413,8 +411,7 @@ class RNNCellTest(test.TestCase):
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
         m = array_ops.zeros([1, 3])
-        cell = core_rnn_cell_impl.DeviceWrapper(
-            core_rnn_cell_impl.GRUCell(3), "/cpu:14159")
+        cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), "/cpu:14159")
         outputs, _ = cell(x, m)
         self.assertTrue("cpu:14159" in outputs.device.lower())

@@ -427,8 +424,7 @@ class RNNCellTest(test.TestCase):
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 1, 3])
-        cell = core_rnn_cell_impl.DeviceWrapper(
-            core_rnn_cell_impl.GRUCell(3), "/gpu:0")
+        cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), "/gpu:0")
         with ops.device("/cpu:0"):
           outputs, _ = rnn.dynamic_rnn(
               cell=cell, inputs=x, dtype=dtypes.float32)
@@ -446,39 +442,14 @@ class RNNCellTest(test.TestCase):
       self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name])
       self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name])

-  # def testUsingSecondCellInScopeWithExistingVariablesFails(self):
-  #   # This test should go away when this behavior is no longer an
-  #   # error (Approx. May 2017)
-  #   cell1 = core_rnn_cell_impl.LSTMCell(3)
-  #   cell2 = core_rnn_cell_impl.LSTMCell(3)
-  #   x = array_ops.zeros([1, 3])
-  #   m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
-  #   cell1(x, m)
-  #   with self.assertRaisesRegexp(ValueError, r"LSTMCell\(..., reuse=True\)"):
-  #     cell2(x, m)
-
-  # def testUsingCellInDifferentScopeFromFirstCallFails(self):
-  #   # This test should go away when this behavior is no longer an
-  #   # error (Approx. May 2017)
-  #   cell = core_rnn_cell_impl.LSTMCell(3)
-  #   x = array_ops.zeros([1, 3])
-  #   m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
-  #   with variable_scope.variable_scope("scope1"):
-  #     cell(x, m)
-  #   with variable_scope.variable_scope("scope2"):
-  #     with self.assertRaisesRegexp(ValueError, r"Attempt to reuse RNNCell"):
-  #       cell(x, m)
-
   def testEmbeddingWrapper(self):
     with self.test_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 1], dtype=dtypes.int32)
         m = array_ops.zeros([1, 2])
-        embedding_cell = core_rnn_cell_impl.EmbeddingWrapper(
-            core_rnn_cell_impl.GRUCell(2),
-            embedding_classes=3,
-            embedding_size=2)
+        embedding_cell = contrib_rnn.EmbeddingWrapper(
+            rnn_cell_impl.GRUCell(2), embedding_classes=3, embedding_size=2)
         self.assertEqual(embedding_cell.output_size, 2)
         g, new_m = embedding_cell(x, m)
         sess.run([variables_lib.global_variables_initializer()])
@@ -495,9 +466,8 @@ class RNNCellTest(test.TestCase):
       with variable_scope.variable_scope("root"):
         inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64)
         input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64)
-        embedding_cell = core_rnn_cell_impl.EmbeddingWrapper(
-            core_rnn_cell_impl.BasicLSTMCell(
-                1, state_is_tuple=True),
+        embedding_cell = contrib_rnn.EmbeddingWrapper(
+            rnn_cell_impl.BasicLSTMCell(1, state_is_tuple=True),
             embedding_classes=1,
             embedding_size=2)
         outputs, _ = rnn.dynamic_rnn(
@@ -515,9 +485,9 @@ class RNNCellTest(test.TestCase):
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
         m = array_ops.zeros([1, 4])
-        _, ml = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
-            state_is_tuple=False)(x, m)
+        _, ml = rnn_cell_impl.MultiRNNCell(
+            [rnn_cell_impl.GRUCell(2)
+             for _ in range(2)], state_is_tuple=False)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(ml, {
             x.name: np.array([[1., 1.]]),
@@ -536,13 +506,13 @@ class RNNCellTest(test.TestCase):

         # Test incorrectness of state
         with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
-          core_rnn_cell_impl.MultiRNNCell(
-              [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
-              state_is_tuple=True)(x, m_bad)
+          rnn_cell_impl.MultiRNNCell(
+              [rnn_cell_impl.GRUCell(2)
+               for _ in range(2)], state_is_tuple=True)(x, m_bad)

-        _, ml = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
-            state_is_tuple=True)(x, m_good)
+        _, ml = rnn_cell_impl.MultiRNNCell(
+            [rnn_cell_impl.GRUCell(2)
+             for _ in range(2)], state_is_tuple=True)(x, m_good)

         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(ml, {
@@ -571,23 +541,23 @@ class DropoutWrapperTest(test.TestCase):
       time_steps = 2
       x = constant_op.constant(
           [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32)
-      m = core_rnn_cell_impl.LSTMStateTuple(
-          *[constant_op.constant([[0.1, 0.1, 0.1]],
-                                 dtype=dtypes.float32)] * 2)
+      m = rnn_cell_impl.LSTMStateTuple(
+          *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32)
+           ] * 2)
     else:
       x = constant_op.constant(
          np.random.randn(time_steps, batch_size, 3).astype(np.float32))
-      m = core_rnn_cell_impl.LSTMStateTuple(
-          *[constant_op.constant([[0.1, 0.1, 0.1]] * batch_size,
-                                 dtype=dtypes.float32)] * 2)
+      m = rnn_cell_impl.LSTMStateTuple(*[
+          constant_op.constant(
+              [[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)
+      ] * 2)
     outputs, final_state = rnn.dynamic_rnn(
-        cell=core_rnn_cell_impl.DropoutWrapper(
-            core_rnn_cell_impl.LSTMCell(3),
-            dtype=x.dtype,
-            **kwargs),
+        cell=rnn_cell_impl.DropoutWrapper(
+            rnn_cell_impl.LSTMCell(3), dtype=x.dtype, **kwargs),
         time_major=True,
         parallel_iterations=parallel_iterations,
-        inputs=x, initial_state=m)
+        inputs=x,
+        initial_state=m)
     sess.run([variables_lib.global_variables_initializer()])
     res = sess.run([outputs, final_state])
     self.assertEqual(res[0].shape, (time_steps, batch_size, 3))
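
The dropout test above drives `DropoutWrapper` through `dynamic_rnn`. A stripped-down sketch of the same pattern using the public API (hypothetical sizes and keep probabilities, TF 1.2-era assumed):

```python
# Sketch; hypothetical sizes and keep probabilities, TF 1.2-era API assumed.
import numpy as np
import tensorflow as tf

time_steps, batch_size, depth = 2, 1, 3
x = tf.constant(np.random.randn(time_steps, batch_size, depth).astype(np.float32))

cell = tf.nn.rnn_cell.DropoutWrapper(
    tf.nn.rnn_cell.LSTMCell(3),
    input_keep_prob=0.9,
    output_keep_prob=0.9)

outputs, final_state = tf.nn.dynamic_rnn(cell, x, time_major=True, dtype=tf.float32)
```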
@@ -775,7 +745,7 @@ class SlimRNNCellTest(test.TestCase):
         m = array_ops.zeros([1, 2])
         my_cell = functools.partial(basic_rnn_cell, num_units=2)
         # pylint: disable=protected-access
-        g, _ = core_rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
+        g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
         # pylint: enable=protected-access
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(
@@ -792,12 +762,12 @@ class SlimRNNCellTest(test.TestCase):
           "root", initializer=init_ops.constant_initializer(0.5)):
         inputs = random_ops.random_uniform((batch_size, input_size))
         _, initial_state = basic_rnn_cell(inputs, None, num_units)
-        rnn_cell = core_rnn_cell_impl.BasicRNNCell(num_units)
+        rnn_cell = rnn_cell_impl.BasicRNNCell(num_units)
         outputs, state = rnn_cell(inputs, initial_state)
         variable_scope.get_variable_scope().reuse_variables()
         my_cell = functools.partial(basic_rnn_cell, num_units=num_units)
         # pylint: disable=protected-access
-        slim_cell = core_rnn_cell_impl._SlimRNNCell(my_cell)
+        slim_cell = rnn_cell_impl._SlimRNNCell(my_cell)
         # pylint: enable=protected-access
         slim_outputs, slim_state = slim_cell(inputs, initial_state)
         self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
@@ -24,9 +24,6 @@ import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin

 from tensorflow.contrib import rnn as rnn_lib
-from tensorflow.contrib.rnn.python.ops import core_rnn
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -38,6 +35,7 @@ from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
@@ -153,7 +151,7 @@ class RNNTest(test.TestCase):
     cell = Plus1RNNCell()
     inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))]
     with self.assertRaisesRegexp(ValueError, "must be a vector"):
-      core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=4)
+      rnn.static_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=4)

   def testRNN(self):
     cell = Plus1RNNCell()
@@ -164,7 +162,7 @@ class RNNTest(test.TestCase):
         array_ops.placeholder(
             dtypes.float32, shape=(batch_size, input_size))
     ]
-    outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+    outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
     self.assertEqual(len(outputs), len(inputs))
     for out, inp in zip(outputs, inputs):
       self.assertEqual(out.get_shape(), inp.get_shape())
@@ -186,7 +184,7 @@ class RNNTest(test.TestCase):

   def testDropout(self):
     cell = Plus1RNNCell()
-    full_dropout_cell = core_rnn_cell_impl.DropoutWrapper(
+    full_dropout_cell = rnn_cell.DropoutWrapper(
         cell, input_keep_prob=1e-12, seed=0)
     batch_size = 2
     input_size = 5
@@ -196,9 +194,9 @@ class RNNTest(test.TestCase):
             dtypes.float32, shape=(batch_size, input_size))
     ]
     with variable_scope.variable_scope("share_scope"):
-      outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+      outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
     with variable_scope.variable_scope("drop_scope"):
-      dropped_outputs, _ = core_rnn.static_rnn(
+      dropped_outputs, _ = rnn.static_rnn(
          full_dropout_cell, inputs, dtype=dtypes.float32)
     self.assertEqual(len(outputs), len(inputs))
     for out, inp in zip(outputs, inputs):
@@ -227,7 +225,7 @@ class RNNTest(test.TestCase):
             dtypes.float32, shape=(batch_size, input_size))
     ]
     with variable_scope.variable_scope("drop_scope"):
-      dynamic_outputs, dynamic_state = core_rnn.static_rnn(
+      dynamic_outputs, dynamic_state = rnn.static_rnn(
          cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)
     self.assertEqual(len(dynamic_outputs), len(inputs))

@@ -297,8 +295,7 @@ class RNNTest(test.TestCase):
          array_ops.placeholder(
              dtypes.float32, shape=(batch_size, input_size))
      ]
-      return core_rnn.static_rnn(
-          cell, inputs, dtype=dtypes.float32, scope=scope)
+      return rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope=scope)

     self._testScope(factory, use_outer_scope=True)
     self._testScope(factory, use_outer_scope=False)
@@ -319,13 +316,13 @@ class LSTMTest(test.TestCase):
     with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
       initializer = init_ops.random_uniform_initializer(
           -0.01, 0.01, seed=self._seed)
-      cell = core_rnn_cell_impl.LSTMCell(
+      cell = rnn_cell.LSTMCell(
           num_units, initializer=initializer, state_is_tuple=False)
       inputs = max_length * [
          array_ops.placeholder(
              dtypes.float32, shape=(batch_size, input_size))
      ]
-      outputs, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+      outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
       self.assertEqual(len(outputs), len(inputs))
       for out in outputs:
         self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
@@ -342,7 +339,7 @@ class LSTMTest(test.TestCase):
     with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
       initializer = init_ops.random_uniform_initializer(
           -0.01, 0.01, seed=self._seed)
-      cell = core_rnn_cell_impl.LSTMCell(
+      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          cell_clip=0.0,
@@ -352,7 +349,7 @@ class LSTMTest(test.TestCase):
          array_ops.placeholder(
              dtypes.float32, shape=(batch_size, input_size))
      ]
-      outputs, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+      outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
       self.assertEqual(len(outputs), len(inputs))
       for out in outputs:
         self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
@@ -374,7 +371,7 @@ class LSTMTest(test.TestCase):
       initializer = init_ops.random_uniform_initializer(
           -0.01, 0.01, seed=self._seed)
       state_saver = TestStateSaver(batch_size, 2 * num_units)
-      cell = core_rnn_cell_impl.LSTMCell(
+      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=False,
          initializer=initializer,
@@ -384,7 +381,7 @@ class LSTMTest(test.TestCase):
              dtypes.float32, shape=(batch_size, input_size))
      ]
      with variable_scope.variable_scope("share_scope"):
-        outputs, state = core_rnn.static_state_saving_rnn(
+        outputs, state = rnn.static_state_saving_rnn(
            cell, inputs, state_saver=state_saver, state_name="save_lstm")
      self.assertEqual(len(outputs), len(inputs))
      for out in outputs:
@@ -406,7 +403,7 @@ class LSTMTest(test.TestCase):
       initializer = init_ops.random_uniform_initializer(
           -0.01, 0.01, seed=self._seed)
       state_saver = TestStateSaver(batch_size, num_units)
-      cell = core_rnn_cell_impl.LSTMCell(
+      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=False,
          initializer=initializer,
@@ -416,7 +413,7 @@ class LSTMTest(test.TestCase):
              dtypes.float32, shape=(batch_size, input_size))
      ]
      with variable_scope.variable_scope("share_scope"):
-        outputs, state = core_rnn.static_state_saving_rnn(
+        outputs, state = rnn.static_state_saving_rnn(
            cell, inputs, state_saver=state_saver, state_name=("c", "m"))
      self.assertEqual(len(outputs), len(inputs))
      for out in outputs:
@@ -450,14 +447,14 @@ class LSTMTest(test.TestCase):
      })

      def _cell(i):
-        return core_rnn_cell_impl.LSTMCell(
+        return rnn_cell.LSTMCell(
            num_units + i,
            use_peepholes=False,
            initializer=initializer,
            state_is_tuple=True)

      # This creates a state tuple which has 4 sub-tuples of length 2 each.
-      cell = core_rnn_cell_impl.MultiRNNCell(
+      cell = rnn_cell.MultiRNNCell(
          [_cell(i) for i in range(4)], state_is_tuple=True)

      self.assertEqual(len(cell.state_size), 4)
@@ -471,7 +468,7 @@ class LSTMTest(test.TestCase):

      state_names = (("c0", "m0"), ("c1", "m1"), ("c2", "m2"), ("c3", "m3"))
      with variable_scope.variable_scope("share_scope"):
-        outputs, state = core_rnn.static_state_saving_rnn(
+        outputs, state = rnn.static_state_saving_rnn(
            cell, inputs, state_saver=state_saver, state_name=state_names)
      self.assertEqual(len(outputs), len(inputs))

@@ -508,13 +505,13 @@ class LSTMTest(test.TestCase):
          array_ops.placeholder(
              dtypes.float32, shape=(None, input_size))
      ]
-      cell = core_rnn_cell_impl.LSTMCell(
+      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
          initializer=initializer,
          state_is_tuple=False)
-      outputs, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+      outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
      self.assertEqual(len(outputs), len(inputs))

      variables_lib.global_variables_initializer().run()
@@ -535,20 +532,20 @@ class LSTMTest(test.TestCase):
          array_ops.placeholder(
              dtypes.float32, shape=(None, input_size))
      ]
-      cell_notuple = core_rnn_cell_impl.LSTMCell(
+      cell_notuple = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
          initializer=initializer,
          state_is_tuple=False)
-      cell_tuple = core_rnn_cell_impl.LSTMCell(
+      cell_tuple = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
          initializer=initializer,
          state_is_tuple=True)
      with variable_scope.variable_scope("root") as scope:
-        outputs_notuple, state_notuple = core_rnn.static_rnn(
+        outputs_notuple, state_notuple = rnn.static_rnn(
            cell_notuple,
            inputs,
            dtype=dtypes.float32,
@@ -562,7 +559,7 @@ class LSTMTest(test.TestCase):
        # the parameters from different RNNCell instances. Right now,
        # this seems an unrealistic use case except for testing.
        cell_tuple._scope = cell_notuple._scope  # pylint: disable=protected-access
-        outputs_tuple, state_tuple = core_rnn.static_rnn(
+        outputs_tuple, state_tuple = rnn.static_rnn(
            cell_tuple,
            inputs,
            dtype=dtypes.float32,
@@ -603,7 +600,7 @@ class LSTMTest(test.TestCase):
              dtypes.float32, shape=(None, input_size))
      ]

-      cell = core_rnn_cell_impl.LSTMCell(
+      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
@@ -612,7 +609,7 @@ class LSTMTest(test.TestCase):
          initializer=initializer,
          state_is_tuple=False)

-      outputs, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+      outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)

      self.assertEqual(len(outputs), len(inputs))

@@ -635,7 +632,7 @@ class LSTMTest(test.TestCase):
              dtypes.float64, shape=(None, input_size))
      ]

-      cell = core_rnn_cell_impl.LSTMCell(
+      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
@@ -644,7 +641,7 @@ class LSTMTest(test.TestCase):
          initializer=initializer,
          state_is_tuple=False)

-      outputs, _ = core_rnn.static_rnn(
+      outputs, _ = rnn.static_rnn(
          cell,
          inputs,
          initial_state=cell.zero_state(batch_size, dtypes.float64))
|
initial_state=cell.zero_state(batch_size, dtypes.float64))
|
||||||
@ -672,7 +669,7 @@ class LSTMTest(test.TestCase):
|
|||||||
]
|
]
|
||||||
initializer = init_ops.constant_initializer(0.001)
|
initializer = init_ops.constant_initializer(0.001)
|
||||||
|
|
||||||
cell_noshard = core_rnn_cell_impl.LSTMCell(
|
cell_noshard = rnn_cell.LSTMCell(
|
||||||
num_units,
|
num_units,
|
||||||
num_proj=num_proj,
|
num_proj=num_proj,
|
||||||
use_peepholes=True,
|
use_peepholes=True,
|
||||||
@ -681,7 +678,7 @@ class LSTMTest(test.TestCase):
|
|||||||
num_proj_shards=num_proj_shards,
|
num_proj_shards=num_proj_shards,
|
||||||
state_is_tuple=False)
|
state_is_tuple=False)
|
||||||
|
|
||||||
cell_shard = core_rnn_cell_impl.LSTMCell(
|
cell_shard = rnn_cell.LSTMCell(
|
||||||
num_units,
|
num_units,
|
||||||
use_peepholes=True,
|
use_peepholes=True,
|
||||||
initializer=initializer,
|
initializer=initializer,
|
||||||
@ -689,10 +686,10 @@ class LSTMTest(test.TestCase):
|
|||||||
state_is_tuple=False)
|
state_is_tuple=False)
|
||||||
|
|
||||||
with variable_scope.variable_scope("noshard_scope"):
|
with variable_scope.variable_scope("noshard_scope"):
|
||||||
outputs_noshard, state_noshard = core_rnn.static_rnn(
|
outputs_noshard, state_noshard = rnn.static_rnn(
|
||||||
cell_noshard, inputs, dtype=dtypes.float32)
|
cell_noshard, inputs, dtype=dtypes.float32)
|
||||||
with variable_scope.variable_scope("shard_scope"):
|
with variable_scope.variable_scope("shard_scope"):
|
||||||
outputs_shard, state_shard = core_rnn.static_rnn(
|
outputs_shard, state_shard = rnn.static_rnn(
|
||||||
cell_shard, inputs, dtype=dtypes.float32)
|
cell_shard, inputs, dtype=dtypes.float32)
|
||||||
|
|
||||||
self.assertEqual(len(outputs_noshard), len(inputs))
|
self.assertEqual(len(outputs_noshard), len(inputs))
|
||||||
@ -731,7 +728,7 @@ class LSTMTest(test.TestCase):
|
|||||||
dtypes.float64, shape=(None, input_size))
|
dtypes.float64, shape=(None, input_size))
|
||||||
]
|
]
|
||||||
|
|
||||||
cell = core_rnn_cell_impl.LSTMCell(
|
cell = rnn_cell.LSTMCell(
|
||||||
num_units,
|
num_units,
|
||||||
use_peepholes=True,
|
use_peepholes=True,
|
||||||
num_proj=num_proj,
|
num_proj=num_proj,
|
||||||
@ -739,9 +736,9 @@ class LSTMTest(test.TestCase):
|
|||||||
num_proj_shards=num_proj_shards,
|
num_proj_shards=num_proj_shards,
|
||||||
initializer=initializer,
|
initializer=initializer,
|
||||||
state_is_tuple=False)
|
state_is_tuple=False)
|
||||||
dropout_cell = core_rnn_cell_impl.DropoutWrapper(cell, 0.5, seed=0)
|
dropout_cell = rnn_cell.DropoutWrapper(cell, 0.5, seed=0)
|
||||||
|
|
||||||
outputs, state = core_rnn.static_rnn(
|
outputs, state = rnn.static_rnn(
|
||||||
dropout_cell,
|
dropout_cell,
|
||||||
inputs,
|
inputs,
|
||||||
sequence_length=sequence_length,
|
sequence_length=sequence_length,
|
||||||
@ -776,13 +773,13 @@ class LSTMTest(test.TestCase):
|
|||||||
array_ops.placeholder(
|
array_ops.placeholder(
|
||||||
dtypes.float32, shape=(None, input_size))
|
dtypes.float32, shape=(None, input_size))
|
||||||
]
|
]
|
||||||
cell = core_rnn_cell_impl.LSTMCell(
|
cell = rnn_cell.LSTMCell(
|
||||||
num_units,
|
num_units,
|
||||||
use_peepholes=True,
|
use_peepholes=True,
|
||||||
num_proj=num_proj,
|
num_proj=num_proj,
|
||||||
initializer=initializer,
|
initializer=initializer,
|
||||||
state_is_tuple=False)
|
state_is_tuple=False)
|
||||||
cell_d = core_rnn_cell_impl.LSTMCell(
|
cell_d = rnn_cell.LSTMCell(
|
||||||
num_units,
|
num_units,
|
||||||
use_peepholes=True,
|
use_peepholes=True,
|
||||||
num_proj=num_proj,
|
num_proj=num_proj,
|
||||||
@ -790,11 +787,11 @@ class LSTMTest(test.TestCase):
|
|||||||
state_is_tuple=False)
|
state_is_tuple=False)
|
||||||
|
|
||||||
with variable_scope.variable_scope("share_scope"):
|
with variable_scope.variable_scope("share_scope"):
|
||||||
outputs0, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
outputs0, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||||
with variable_scope.variable_scope("share_scope", reuse=True):
|
with variable_scope.variable_scope("share_scope", reuse=True):
|
||||||
outputs1, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
outputs1, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||||
with variable_scope.variable_scope("diff_scope"):
|
with variable_scope.variable_scope("diff_scope"):
|
||||||
outputs2, _ = core_rnn.static_rnn(cell_d, inputs, dtype=dtypes.float32)
|
outputs2, _ = rnn.static_rnn(cell_d, inputs, dtype=dtypes.float32)
|
||||||
|
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
input_value = np.random.randn(batch_size, input_size)
|
input_value = np.random.randn(batch_size, input_size)
|
||||||
@ -823,7 +820,7 @@ class LSTMTest(test.TestCase):
|
|||||||
array_ops.placeholder(
|
array_ops.placeholder(
|
||||||
dtypes.float32, shape=(None, input_size))
|
dtypes.float32, shape=(None, input_size))
|
||||||
]
|
]
|
||||||
cell = core_rnn_cell_impl.LSTMCell(
|
cell = rnn_cell.LSTMCell(
|
||||||
num_units,
|
num_units,
|
||||||
use_peepholes=True,
|
use_peepholes=True,
|
||||||
num_proj=num_proj,
|
num_proj=num_proj,
|
||||||
@ -832,10 +829,10 @@ class LSTMTest(test.TestCase):
|
|||||||
|
|
||||||
with ops_lib.name_scope("scope0"):
|
with ops_lib.name_scope("scope0"):
|
||||||
with variable_scope.variable_scope("share_scope"):
|
with variable_scope.variable_scope("share_scope"):
|
||||||
outputs0, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
outputs0, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||||
with ops_lib.name_scope("scope1"):
|
with ops_lib.name_scope("scope1"):
|
||||||
with variable_scope.variable_scope("share_scope", reuse=True):
|
with variable_scope.variable_scope("share_scope", reuse=True):
|
||||||
outputs1, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
outputs1, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||||
|
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
input_value = np.random.randn(batch_size, input_size)
|
input_value = np.random.randn(batch_size, input_size)
|
||||||
@ -881,7 +878,7 @@ class LSTMTest(test.TestCase):
|
|||||||
|
|
||||||
def testDynamicRNNAllowsUnknownTimeDimension(self):
|
def testDynamicRNNAllowsUnknownTimeDimension(self):
|
||||||
inputs = array_ops.placeholder(dtypes.float32, shape=[1, None, 20])
|
inputs = array_ops.placeholder(dtypes.float32, shape=[1, None, 20])
|
||||||
cell = core_rnn_cell.GRUCell(30)
|
cell = rnn_cell.GRUCell(30)
|
||||||
# Smoke test, this should not raise an error
|
# Smoke test, this should not raise an error
|
||||||
rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
|
rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
|
||||||
|
|
||||||
@ -900,14 +897,14 @@ class LSTMTest(test.TestCase):
|
|||||||
dtypes.float32, shape=(None, input_size))
|
dtypes.float32, shape=(None, input_size))
|
||||||
]
|
]
|
||||||
inputs_c = array_ops.stack(inputs)
|
inputs_c = array_ops.stack(inputs)
|
||||||
cell = core_rnn_cell.LSTMCell(
|
cell = rnn_cell.LSTMCell(
|
||||||
num_units,
|
num_units,
|
||||||
use_peepholes=True,
|
use_peepholes=True,
|
||||||
num_proj=num_proj,
|
num_proj=num_proj,
|
||||||
initializer=initializer,
|
initializer=initializer,
|
||||||
state_is_tuple=True)
|
state_is_tuple=True)
|
||||||
with variable_scope.variable_scope("root") as scope:
|
with variable_scope.variable_scope("root") as scope:
|
||||||
outputs_static, state_static = core_rnn.static_rnn(
|
outputs_static, state_static = rnn.static_rnn(
|
||||||
cell,
|
cell,
|
||||||
inputs,
|
inputs,
|
||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
@ -921,8 +918,8 @@ class LSTMTest(test.TestCase):
|
|||||||
time_major=True,
|
time_major=True,
|
||||||
sequence_length=sequence_length,
|
sequence_length=sequence_length,
|
||||||
scope=scope)
|
scope=scope)
|
||||||
self.assertTrue(isinstance(state_static, core_rnn_cell.LSTMStateTuple))
|
self.assertTrue(isinstance(state_static, rnn_cell.LSTMStateTuple))
|
||||||
self.assertTrue(isinstance(state_dynamic, core_rnn_cell.LSTMStateTuple))
|
self.assertTrue(isinstance(state_dynamic, rnn_cell.LSTMStateTuple))
|
||||||
self.assertEqual(state_static[0], state_static.c)
|
self.assertEqual(state_static[0], state_static.c)
|
||||||
self.assertEqual(state_static[1], state_static.h)
|
self.assertEqual(state_static[1], state_static.h)
|
||||||
self.assertEqual(state_dynamic[0], state_dynamic.c)
|
self.assertEqual(state_dynamic[0], state_dynamic.c)
|
||||||
@ -960,7 +957,7 @@ class LSTMTest(test.TestCase):
|
|||||||
inputs_c = array_ops.stack(inputs)
|
inputs_c = array_ops.stack(inputs)
|
||||||
|
|
||||||
def _cell(i):
|
def _cell(i):
|
||||||
return core_rnn_cell.LSTMCell(
|
return rnn_cell.LSTMCell(
|
||||||
num_units + i,
|
num_units + i,
|
||||||
use_peepholes=True,
|
use_peepholes=True,
|
||||||
num_proj=num_proj + i,
|
num_proj=num_proj + i,
|
||||||
@ -968,7 +965,7 @@ class LSTMTest(test.TestCase):
|
|||||||
state_is_tuple=True)
|
state_is_tuple=True)
|
||||||
|
|
||||||
# This creates a state tuple which has 4 sub-tuples of length 2 each.
|
# This creates a state tuple which has 4 sub-tuples of length 2 each.
|
||||||
cell = core_rnn_cell.MultiRNNCell(
|
cell = rnn_cell.MultiRNNCell(
|
||||||
[_cell(i) for i in range(4)], state_is_tuple=True)
|
[_cell(i) for i in range(4)], state_is_tuple=True)
|
||||||
|
|
||||||
self.assertEqual(len(cell.state_size), 4)
|
self.assertEqual(len(cell.state_size), 4)
|
||||||
@ -982,7 +979,7 @@ class LSTMTest(test.TestCase):
|
|||||||
self.assertEqual(test_zero[i][1].get_shape()[1], cell.state_size[i][1])
|
self.assertEqual(test_zero[i][1].get_shape()[1], cell.state_size[i][1])
|
||||||
|
|
||||||
with variable_scope.variable_scope("root") as scope:
|
with variable_scope.variable_scope("root") as scope:
|
||||||
outputs_static, state_static = core_rnn.static_rnn(
|
outputs_static, state_static = rnn.static_rnn(
|
||||||
cell,
|
cell,
|
||||||
inputs,
|
inputs,
|
||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
@ -1034,7 +1031,7 @@ class LSTMTest(test.TestCase):
|
|||||||
initializer = init_ops.random_uniform_initializer(
|
initializer = init_ops.random_uniform_initializer(
|
||||||
-0.01, 0.01, seed=self._seed)
|
-0.01, 0.01, seed=self._seed)
|
||||||
|
|
||||||
cell = core_rnn_cell.LSTMCell(
|
cell = rnn_cell.LSTMCell(
|
||||||
num_units,
|
num_units,
|
||||||
use_peepholes=True,
|
use_peepholes=True,
|
||||||
initializer=initializer,
|
initializer=initializer,
|
||||||
@ -1042,7 +1039,7 @@ class LSTMTest(test.TestCase):
|
|||||||
state_is_tuple=False)
|
state_is_tuple=False)
|
||||||
|
|
||||||
with variable_scope.variable_scope("dynamic_scope"):
|
with variable_scope.variable_scope("dynamic_scope"):
|
||||||
outputs_static, state_static = core_rnn.static_rnn(
|
outputs_static, state_static = rnn.static_rnn(
|
||||||
cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)
|
cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)
|
||||||
|
|
||||||
feeds = {concat_inputs: input_values}
|
feeds = {concat_inputs: input_values}
|
||||||
@ -1092,7 +1089,7 @@ class LSTMTest(test.TestCase):
|
|||||||
initializer = init_ops.random_uniform_initializer(
|
initializer = init_ops.random_uniform_initializer(
|
||||||
-0.01, 0.01, seed=self._seed)
|
-0.01, 0.01, seed=self._seed)
|
||||||
|
|
||||||
cell = core_rnn_cell.LSTMCell(
|
cell = rnn_cell.LSTMCell(
|
||||||
num_units,
|
num_units,
|
||||||
use_peepholes=True,
|
use_peepholes=True,
|
||||||
initializer=initializer,
|
initializer=initializer,
|
||||||
@ -1205,16 +1202,16 @@ class BidirectionalRNNTest(test.TestCase):
|
|||||||
-0.01, 0.01, seed=self._seed)
|
-0.01, 0.01, seed=self._seed)
|
||||||
sequence_length = array_ops.placeholder(
|
sequence_length = array_ops.placeholder(
|
||||||
dtypes.int64) if use_sequence_length else None
|
dtypes.int64) if use_sequence_length else None
|
||||||
cell_fw = core_rnn_cell_impl.LSTMCell(
|
cell_fw = rnn_cell.LSTMCell(
|
||||||
num_units, input_size, initializer=initializer, state_is_tuple=False)
|
num_units, input_size, initializer=initializer, state_is_tuple=False)
|
||||||
cell_bw = core_rnn_cell_impl.LSTMCell(
|
cell_bw = rnn_cell.LSTMCell(
|
||||||
num_units, input_size, initializer=initializer, state_is_tuple=False)
|
num_units, input_size, initializer=initializer, state_is_tuple=False)
|
||||||
inputs = max_length * [
|
inputs = max_length * [
|
||||||
array_ops.placeholder(
|
array_ops.placeholder(
|
||||||
dtypes.float32,
|
dtypes.float32,
|
||||||
shape=(batch_size, input_size) if use_shape else (None, input_size))
|
shape=(batch_size, input_size) if use_shape else (None, input_size))
|
||||||
]
|
]
|
||||||
outputs, state_fw, state_bw = core_rnn.static_bidirectional_rnn(
|
outputs, state_fw, state_bw = rnn.static_bidirectional_rnn(
|
||||||
cell_fw,
|
cell_fw,
|
||||||
cell_bw,
|
cell_bw,
|
||||||
inputs,
|
inputs,
|
||||||
@ -1337,9 +1334,9 @@ class BidirectionalRNNTest(test.TestCase):
|
|||||||
-0.01, 0.01, seed=self._seed)
|
-0.01, 0.01, seed=self._seed)
|
||||||
sequence_length = (
|
sequence_length = (
|
||||||
array_ops.placeholder(dtypes.int64) if use_sequence_length else None)
|
array_ops.placeholder(dtypes.int64) if use_sequence_length else None)
|
||||||
cell_fw = core_rnn_cell.LSTMCell(
|
cell_fw = rnn_cell.LSTMCell(
|
||||||
num_units, initializer=initializer, state_is_tuple=use_state_tuple)
|
num_units, initializer=initializer, state_is_tuple=use_state_tuple)
|
||||||
cell_bw = core_rnn_cell.LSTMCell(
|
cell_bw = rnn_cell.LSTMCell(
|
||||||
num_units, initializer=initializer, state_is_tuple=use_state_tuple)
|
num_units, initializer=initializer, state_is_tuple=use_state_tuple)
|
||||||
inputs = max_length * [
|
inputs = max_length * [
|
||||||
array_ops.placeholder(
|
array_ops.placeholder(
|
||||||
@ -1530,7 +1527,7 @@ class MultiDimensionalLSTMTest(test.TestCase):
|
|||||||
# variables.
|
# variables.
|
||||||
cell = DummyMultiDimensionalLSTM(feature_dims)
|
cell = DummyMultiDimensionalLSTM(feature_dims)
|
||||||
state_saver = TestStateSaver(batch_size, input_size)
|
state_saver = TestStateSaver(batch_size, input_size)
|
||||||
outputs_static, state_static = core_rnn.static_rnn(
|
outputs_static, state_static = rnn.static_rnn(
|
||||||
cell, inputs, dtype=dtypes.float32, sequence_length=sequence_length)
|
cell, inputs, dtype=dtypes.float32, sequence_length=sequence_length)
|
||||||
outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
|
outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
|
||||||
cell,
|
cell,
|
||||||
@ -1538,13 +1535,13 @@ class MultiDimensionalLSTMTest(test.TestCase):
|
|||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
time_major=True,
|
time_major=True,
|
||||||
sequence_length=sequence_length)
|
sequence_length=sequence_length)
|
||||||
outputs_bid, state_fw, state_bw = core_rnn.static_bidirectional_rnn(
|
outputs_bid, state_fw, state_bw = rnn.static_bidirectional_rnn(
|
||||||
cell,
|
cell,
|
||||||
cell,
|
cell,
|
||||||
inputs_using_dim,
|
inputs_using_dim,
|
||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
sequence_length=sequence_length)
|
sequence_length=sequence_length)
|
||||||
outputs_sav, state_sav = core_rnn.static_state_saving_rnn(
|
outputs_sav, state_sav = rnn.static_state_saving_rnn(
|
||||||
cell,
|
cell,
|
||||||
inputs_using_dim,
|
inputs_using_dim,
|
||||||
sequence_length=sequence_length,
|
sequence_length=sequence_length,
|
||||||
@ -1634,15 +1631,15 @@ class NestedLSTMTest(test.TestCase):
|
|||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
time_major=True,
|
time_major=True,
|
||||||
sequence_length=sequence_length)
|
sequence_length=sequence_length)
|
||||||
outputs_static, state_static = core_rnn.static_rnn(
|
outputs_static, state_static = rnn.static_rnn(
|
||||||
cell, inputs, dtype=dtypes.float32, sequence_length=sequence_length)
|
cell, inputs, dtype=dtypes.float32, sequence_length=sequence_length)
|
||||||
outputs_bid, state_fw, state_bw = core_rnn.static_bidirectional_rnn(
|
outputs_bid, state_fw, state_bw = rnn.static_bidirectional_rnn(
|
||||||
cell,
|
cell,
|
||||||
cell,
|
cell,
|
||||||
inputs_using_dim,
|
inputs_using_dim,
|
||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
sequence_length=sequence_length)
|
sequence_length=sequence_length)
|
||||||
outputs_sav, state_sav = core_rnn.static_state_saving_rnn(
|
outputs_sav, state_sav = rnn.static_state_saving_rnn(
|
||||||
cell,
|
cell,
|
||||||
inputs_using_dim,
|
inputs_using_dim,
|
||||||
sequence_length=sequence_length,
|
sequence_length=sequence_length,
|
||||||
@ -1738,7 +1735,7 @@ class StateSaverRNNTest(test.TestCase):
|
|||||||
initializer = init_ops.random_uniform_initializer(
|
initializer = init_ops.random_uniform_initializer(
|
||||||
-0.01, 0.01, seed=self._seed)
|
-0.01, 0.01, seed=self._seed)
|
||||||
state_saver = TestStateSaver(batch_size, 2 * num_units)
|
state_saver = TestStateSaver(batch_size, 2 * num_units)
|
||||||
cell = core_rnn_cell_impl.LSTMCell(
|
cell = rnn_cell.LSTMCell(
|
||||||
num_units,
|
num_units,
|
||||||
use_peepholes=False,
|
use_peepholes=False,
|
||||||
initializer=initializer,
|
initializer=initializer,
|
||||||
@ -1747,7 +1744,7 @@ class StateSaverRNNTest(test.TestCase):
|
|||||||
array_ops.placeholder(
|
array_ops.placeholder(
|
||||||
dtypes.float32, shape=(batch_size, input_size))
|
dtypes.float32, shape=(batch_size, input_size))
|
||||||
]
|
]
|
||||||
return core_rnn.static_state_saving_rnn(
|
return rnn.static_state_saving_rnn(
|
||||||
cell,
|
cell,
|
||||||
inputs,
|
inputs,
|
||||||
state_saver=state_saver,
|
state_saver=state_saver,
|
||||||
@ -1779,7 +1776,7 @@ class GRUTest(test.TestCase):
|
|||||||
concat_inputs = array_ops.placeholder(
|
concat_inputs = array_ops.placeholder(
|
||||||
dtypes.float32, shape=(time_steps, batch_size, input_size))
|
dtypes.float32, shape=(time_steps, batch_size, input_size))
|
||||||
|
|
||||||
cell = core_rnn_cell.GRUCell(num_units=num_units)
|
cell = rnn_cell.GRUCell(num_units=num_units)
|
||||||
|
|
||||||
with variable_scope.variable_scope("dynamic_scope"):
|
with variable_scope.variable_scope("dynamic_scope"):
|
||||||
outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
|
outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
|
||||||
@ -1830,7 +1827,7 @@ class GRUTest(test.TestCase):
|
|||||||
def factory(scope):
|
def factory(scope):
|
||||||
concat_inputs = array_ops.placeholder(
|
concat_inputs = array_ops.placeholder(
|
||||||
dtypes.float32, shape=(time_steps, batch_size, input_size))
|
dtypes.float32, shape=(time_steps, batch_size, input_size))
|
||||||
cell = core_rnn_cell.GRUCell(num_units=num_units)
|
cell = rnn_cell.GRUCell(num_units=num_units)
|
||||||
return rnn.dynamic_rnn(
|
return rnn.dynamic_rnn(
|
||||||
cell,
|
cell,
|
||||||
inputs=concat_inputs,
|
inputs=concat_inputs,
|
||||||
@ -1864,7 +1861,7 @@ class RawRNNTest(test.TestCase):
|
|||||||
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
||||||
inputs_ta = inputs_ta.unstack(inputs)
|
inputs_ta = inputs_ta.unstack(inputs)
|
||||||
|
|
||||||
cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||||
|
|
||||||
def loop_fn(time_, cell_output, cell_state, unused_loop_state):
|
def loop_fn(time_, cell_output, cell_state, unused_loop_state):
|
||||||
emit_output = cell_output # == None for time == 0
|
emit_output = cell_output # == None for time == 0
|
||||||
@ -1965,7 +1962,7 @@ class RawRNNTest(test.TestCase):
|
|||||||
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
||||||
inputs_ta = inputs_ta.unstack(inputs)
|
inputs_ta = inputs_ta.unstack(inputs)
|
||||||
|
|
||||||
cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||||
|
|
||||||
def loop_fn(time_, cell_output, cell_state, loop_state):
|
def loop_fn(time_, cell_output, cell_state, loop_state):
|
||||||
if cell_output is None:
|
if cell_output is None:
|
||||||
@ -2001,7 +1998,7 @@ class RawRNNTest(test.TestCase):
|
|||||||
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
||||||
inputs_ta = inputs_ta.unstack(inputs)
|
inputs_ta = inputs_ta.unstack(inputs)
|
||||||
|
|
||||||
cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||||
|
|
||||||
def loop_fn(time_, cell_output, cell_state, loop_state):
|
def loop_fn(time_, cell_output, cell_state, loop_state):
|
||||||
if cell_output is None:
|
if cell_output is None:
|
||||||
@ -2044,7 +2041,7 @@ class RawRNNTest(test.TestCase):
|
|||||||
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
||||||
inputs_ta = inputs_ta.unstack(inputs)
|
inputs_ta = inputs_ta.unstack(inputs)
|
||||||
|
|
||||||
cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||||
|
|
||||||
def loop_fn(time_, cell_output, cell_state, _):
|
def loop_fn(time_, cell_output, cell_state, _):
|
||||||
if cell_output is None:
|
if cell_output is None:
|
||||||
@ -2113,7 +2110,7 @@ class RawRNNTest(test.TestCase):
|
|||||||
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
||||||
inputs_ta = inputs_ta.unstack(inputs)
|
inputs_ta = inputs_ta.unstack(inputs)
|
||||||
|
|
||||||
cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||||
|
|
||||||
def loop_fn(time_, cell_output, cell_state, unused_loop_state):
|
def loop_fn(time_, cell_output, cell_state, unused_loop_state):
|
||||||
emit_output = cell_output # == None for time == 0
|
emit_output = cell_output # == None for time == 0
|
||||||
@ -2138,7 +2135,7 @@ class RawRNNTest(test.TestCase):
|
|||||||
self._testScope(factory, prefix=None, use_outer_scope=False)
|
self._testScope(factory, prefix=None, use_outer_scope=False)
|
||||||
|
|
||||||
|
|
||||||
class DeviceWrapperCell(core_rnn_cell.RNNCell):
|
class DeviceWrapperCell(rnn_cell.RNNCell):
|
||||||
"""Class to ensure cell calculation happens on a specific device."""
|
"""Class to ensure cell calculation happens on a specific device."""
|
||||||
|
|
||||||
def __init__(self, cell, device):
|
def __init__(self, cell, device):
|
||||||
@ -2172,7 +2169,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
|
|||||||
input_size = 5
|
input_size = 5
|
||||||
num_units = 10
|
num_units = 10
|
||||||
|
|
||||||
cell = core_rnn_cell.LSTMCell(num_units, use_peepholes=True)
|
cell = rnn_cell.LSTMCell(num_units, use_peepholes=True)
|
||||||
gpu_cell = DeviceWrapperCell(cell, cell_device)
|
gpu_cell = DeviceWrapperCell(cell, cell_device)
|
||||||
inputs = np.random.randn(batch_size, time_steps,
|
inputs = np.random.randn(batch_size, time_steps,
|
||||||
input_size).astype(np.float32)
|
input_size).astype(np.float32)
|
||||||
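Taken together, the hunks above swap the contrib aliases (core_rnn, core_rnn_cell, core_rnn_cell_impl) for the core modules tensorflow.python.ops.rnn and tensorflow.python.ops.rnn_cell. A minimal sketch of the calling pattern the updated tests rely on, using only imports and arguments that already appear in this diff (illustrative only, not part of the change):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import rnn        # replaces tensorflow.contrib.rnn.python.ops.core_rnn
    from tensorflow.python.ops import rnn_cell   # replaces tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl

    cell = rnn_cell.LSTMCell(10, use_peepholes=True, state_is_tuple=True)
    # static_rnn takes a Python list of per-step tensors (time-major unrolling).
    inputs = [array_ops.zeros([4, 5])] * 6
    outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)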
@@ -20,14 +20,14 @@ from __future__ import print_function

 import numpy as np

-from tensorflow.contrib.rnn.python.ops import core_rnn
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
 from tensorflow.contrib.rnn.python.ops import fused_rnn_cell
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -41,7 +41,7 @@ class FusedRnnCellTest(test.TestCase):
 with self.test_session() as sess:
 initializer = init_ops.random_uniform_initializer(
 -0.01, 0.01, seed=19890212)
-cell = core_rnn_cell_impl.BasicRNNCell(10)
+cell = rnn_cell.BasicRNNCell(10)
 batch_size = 5
 input_size = 20
 timelen = 15
@@ -49,7 +49,7 @@ class FusedRnnCellTest(test.TestCase):
 np.random.randn(timelen, batch_size, input_size))
 with variable_scope.variable_scope("basic", initializer=initializer):
 unpacked_inputs = array_ops.unstack(inputs)
-outputs, state = core_rnn.static_rnn(
+outputs, state = rnn.static_rnn(
 cell, unpacked_inputs, dtype=dtypes.float64)
 packed_outputs = array_ops.stack(outputs)
 basic_vars = [
@@ -65,7 +65,7 @@ class FusedRnnCellTest(test.TestCase):
 with variable_scope.variable_scope(
 "fused_static", initializer=initializer):
 fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
-core_rnn_cell_impl.BasicRNNCell(10))
+rnn_cell.BasicRNNCell(10))
 outputs, state = fused_cell(inputs, dtype=dtypes.float64)
 fused_static_vars = [
 v for v in variables.trainable_variables()
@@ -86,7 +86,7 @@ class FusedRnnCellTest(test.TestCase):
 with variable_scope.variable_scope(
 "fused_dynamic", initializer=initializer):
 fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
-core_rnn_cell_impl.BasicRNNCell(10), use_dynamic_rnn=True)
+rnn_cell.BasicRNNCell(10), use_dynamic_rnn=True)
 outputs, state = fused_cell(inputs, dtype=dtypes.float64)
 fused_dynamic_vars = [
 v for v in variables.trainable_variables()
@@ -109,8 +109,8 @@ class FusedRnnCellTest(test.TestCase):
 with self.test_session() as sess:
 initializer = init_ops.random_uniform_initializer(
 -0.01, 0.01, seed=19890213)
-fw_cell = core_rnn_cell_impl.BasicRNNCell(10)
-bw_cell = core_rnn_cell_impl.BasicRNNCell(10)
+fw_cell = rnn_cell.BasicRNNCell(10)
+bw_cell = rnn_cell.BasicRNNCell(10)
 batch_size = 5
 input_size = 20
 timelen = 15
@@ -120,7 +120,7 @@ class FusedRnnCellTest(test.TestCase):
 # test bi-directional rnn
 with variable_scope.variable_scope("basic", initializer=initializer):
 unpacked_inputs = array_ops.unstack(inputs)
-outputs, fw_state, bw_state = core_rnn.static_bidirectional_rnn(
+outputs, fw_state, bw_state = rnn.static_bidirectional_rnn(
 fw_cell, bw_cell, unpacked_inputs, dtype=dtypes.float64)
 packed_outputs = array_ops.stack(outputs)
 basic_vars = [
@@ -136,10 +136,9 @@ class FusedRnnCellTest(test.TestCase):

 with variable_scope.variable_scope("fused", initializer=initializer):
 fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
-core_rnn_cell_impl.BasicRNNCell(10))
+rnn_cell.BasicRNNCell(10))
 fused_bw_cell = fused_rnn_cell.TimeReversedFusedRNN(
-fused_rnn_cell.FusedRNNCellAdaptor(
-core_rnn_cell_impl.BasicRNNCell(10)))
+fused_rnn_cell.FusedRNNCellAdaptor(rnn_cell.BasicRNNCell(10)))
 fw_outputs, fw_state = fused_cell(
 inputs, dtype=dtypes.float64, scope="fw")
 bw_outputs, bw_state = fused_bw_cell(
@@ -22,7 +22,6 @@ import time

 import numpy as np

-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
 from tensorflow.contrib.rnn.python.ops import gru_ops
 from tensorflow.python.client import session
 from tensorflow.python.framework import dtypes
@@ -33,6 +32,7 @@ from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -78,7 +78,7 @@ class GRUBlockCellTest(test.TestCase):

 # Output from the basic GRU cell implementation.
 with vs.variable_scope("basic", initializer=initializer):
-output = core_rnn_cell_impl.GRUCell(cell_size)(x, h)
+output = rnn_cell.GRUCell(cell_size)(x, h)
 sess.run([variables.global_variables_initializer()])
 basic_res = sess.run([output], {x: x_value, h: h_value})

@@ -128,7 +128,7 @@ class GRUBlockCellTest(test.TestCase):

 # Output from the basic GRU cell implementation.
 with vs.variable_scope("basic", initializer=initializer):
-cell = core_rnn_cell_impl.GRUCell(cell_size)
+cell = rnn_cell.GRUCell(cell_size)
 outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
 cell,
 inputs=concat_x,
@@ -192,7 +192,7 @@ class GRUBlockCellTest(test.TestCase):

 # Gradients from the basic GRU cell implementation.
 with vs.variable_scope("basic", initializer=initializer):
-output = core_rnn_cell_impl.GRUCell(cell_size)(x, h)
+output = rnn_cell.GRUCell(cell_size)(x, h)
 sess.run([variables.global_variables_initializer()])

 all_variables = variables.global_variables()[4:8]
@@ -258,7 +258,7 @@ class GRUBlockCellTest(test.TestCase):

 # Gradients from the basic GRU cell implementation.
 with vs.variable_scope("basic", initializer=initializer):
-cell = core_rnn_cell_impl.GRUCell(cell_size)
+cell = rnn_cell.GRUCell(cell_size)

 outputs_dynamic, _ = rnn.dynamic_rnn(
 cell,
@@ -377,7 +377,7 @@ def training_gru_block_vs_gru_cell(batch_size,

 # Output from the basic GRU cell implementation.
 with vs.variable_scope("basic", initializer=initializer):
-cell = core_rnn_cell_impl.GRUCell(cell_size)
+cell = rnn_cell.GRUCell(cell_size)

 outputs_dynamic, _ = rnn.dynamic_rnn(
 cell,
@@ -448,7 +448,7 @@ def inference_gru_block_vs_gru_cell(batch_size,

 # Output from the basic GRU cell implementation.
 with vs.variable_scope("basic", initializer=initializer):
-cell = core_rnn_cell_impl.GRUCell(cell_size)
+cell = rnn_cell.GRUCell(cell_size)
 outputs_dynamic, _ = rnn.dynamic_rnn(
 cell,
 inputs=concat_x,
@@ -497,8 +497,8 @@ def single_bprop_step_gru_block_vs_gru_cell(batch_size,

 # Output from the basic GRU cell implementation.
 with vs.variable_scope("basic", initializer=initializer):
-output = core_rnn_cell_impl.GRUCell(cell_size)(array_ops.identity(x),
+output = rnn_cell.GRUCell(cell_size)(array_ops.identity(x),
 array_ops.identity(h))
 sess.run([variables.global_variables_initializer()])
 grad_output_wrt_input = gradients_impl.gradients([output], h)
 basic_time_bprop = time_taken_by_op(grad_output_wrt_input, sess, iters)
@@ -20,8 +20,6 @@ from __future__ import print_function

 import numpy as np

-from tensorflow.contrib.rnn.python.ops import core_rnn
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
 from tensorflow.contrib.rnn.python.ops import lstm_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -30,6 +28,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -66,10 +65,9 @@ class LSTMBlockCellTest(test.TestCase):
 m1 = array_ops.zeros([1, 2])
 m2 = array_ops.zeros([1, 2])
 m3 = array_ops.zeros([1, 2])
-g, ((out_m0, out_m1),
-(out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
-[lstm_ops.LSTMBlockCell(2) for _ in range(2)],
-state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
+g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
+[lstm_ops.LSTMBlockCell(2)
+for _ in range(2)], state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
 sess.run([variables.global_variables_initializer()])
 res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
 x.name: np.array([[1., 1.]]),
@@ -88,11 +86,11 @@ class LSTMBlockCellTest(test.TestCase):

 def testCompatibleNames(self):
 with self.test_session(use_gpu=self._use_gpu, graph=ops.Graph()):
-cell = core_rnn_cell_impl.LSTMCell(10)
-pcell = core_rnn_cell_impl.LSTMCell(10, use_peepholes=True)
+cell = rnn_cell.LSTMCell(10)
+pcell = rnn_cell.LSTMCell(10, use_peepholes=True)
 inputs = [array_ops.zeros([4, 5])] * 6
-core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
-core_rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
+rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
+rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
 basic_names = {
 v.name: v.get_shape()
 for v in variables.trainable_variables()
@@ -102,8 +100,8 @@ class LSTMBlockCellTest(test.TestCase):
 cell = lstm_ops.LSTMBlockCell(10)
 pcell = lstm_ops.LSTMBlockCell(10, use_peephole=True)
 inputs = [array_ops.zeros([4, 5])] * 6
-core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
-core_rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
+rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
+rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
 block_names = {
 v.name: v.get_shape()
 for v in variables.trainable_variables()
@@ -140,11 +138,9 @@ class LSTMBlockCellTest(test.TestCase):
 m1 = array_ops.zeros([1, 2])
 m2 = array_ops.zeros([1, 2])
 m3 = array_ops.zeros([1, 2])
-g, ((out_m0, out_m1),
-(out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
-[core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
-for _ in range(2)],
-state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
+g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
+[rnn_cell.BasicLSTMCell(2, state_is_tuple=True) for _ in range(2)],
+state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
 sess.run([variables.global_variables_initializer()])
 basic_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
 x.name: x_values,
@@ -159,10 +155,9 @@ class LSTMBlockCellTest(test.TestCase):
 m1 = array_ops.zeros([1, 2])
 m2 = array_ops.zeros([1, 2])
 m3 = array_ops.zeros([1, 2])
-g, ((out_m0, out_m1),
-(out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
-[lstm_ops.LSTMBlockCell(2) for _ in range(2)],
-state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
+g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
+[lstm_ops.LSTMBlockCell(2)
+for _ in range(2)], state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
 sess.run([variables.global_variables_initializer()])
 block_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
 x.name: x_values,
@@ -193,12 +188,12 @@ class LSTMBlockCellTest(test.TestCase):
 m1 = array_ops.zeros([1, 2])
 m2 = array_ops.zeros([1, 2])
 m3 = array_ops.zeros([1, 2])
-g, ((out_m0, out_m1),
-(out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
-[core_rnn_cell_impl.LSTMCell(
-2, use_peepholes=True, state_is_tuple=True)
-for _ in range(2)],
+g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
+[
+rnn_cell.LSTMCell(2, use_peepholes=True, state_is_tuple=True)
+for _ in range(2)
+],
 state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
 sess.run([variables.global_variables_initializer()])
 basic_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
 x.name: x_values,
@@ -213,11 +208,9 @@ class LSTMBlockCellTest(test.TestCase):
 m1 = array_ops.zeros([1, 2])
 m2 = array_ops.zeros([1, 2])
 m3 = array_ops.zeros([1, 2])
-g, ((out_m0, out_m1),
-(out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
-[lstm_ops.LSTMBlockCell(2, use_peephole=True)
-for _ in range(2)],
-state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
+g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
+[lstm_ops.LSTMBlockCell(2, use_peephole=True) for _ in range(2)],
+state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
 sess.run([variables.global_variables_initializer()])
 block_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
 x.name: x_values,
@@ -247,8 +240,8 @@ class LSTMBlockCellTest(test.TestCase):
 initializer = init_ops.random_uniform_initializer(
 -0.01, 0.01, seed=19890212)
 with variable_scope.variable_scope("basic", initializer=initializer):
-cell = core_rnn_cell_impl.BasicLSTMCell(cell_size, state_is_tuple=True)
-outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+cell = rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
+outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)

 sess.run([variables.global_variables_initializer()])
 basic_outputs, basic_state = sess.run([outputs, state[0]])
@@ -321,9 +314,9 @@ class LSTMBlockCellTest(test.TestCase):
 initializer = init_ops.random_uniform_initializer(
 -0.01, 0.01, seed=19890212)
 with variable_scope.variable_scope("basic", initializer=initializer):
-cell = core_rnn_cell_impl.LSTMCell(
+cell = rnn_cell.LSTMCell(
 cell_size, use_peepholes=True, state_is_tuple=True)
-outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)

 sess.run([variables.global_variables_initializer()])
 basic_outputs, basic_state = sess.run([outputs, state[0]])
@@ -410,8 +403,8 @@ class LSTMBlockCellTest(test.TestCase):
 initializer = init_ops.random_uniform_initializer(
 -0.01, 0.01, seed=19890213)
 with variable_scope.variable_scope("basic", initializer=initializer):
-cell = core_rnn_cell_impl.BasicLSTMCell(cell_size, state_is_tuple=True)
-outputs, state = core_rnn.static_rnn(
+cell = rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
+outputs, state = rnn.static_rnn(
 cell, inputs, dtype=dtypes.float32, sequence_length=seq_lengths)
 sess.run([variables.global_variables_initializer()])
 basic_outputs, basic_state = sess.run([outputs, state[0]])
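The next file keeps the contrib-only cells but renames their import so they no longer shadow the core rnn_cell module. A small sketch of that aliasing, assuming the same internal imports the diff uses (the unit counts are arbitrary, for illustration only):

    from tensorflow.python.ops import rnn_cell
    from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell

    # The base cell comes from the core namespace; the contrib wrapper keeps its old home.
    lstm_cell = rnn_cell.BasicLSTMCell(8, state_is_tuple=True)
    attn_cell = contrib_rnn_cell.AttentionCellWrapper(lstm_cell, 4, state_is_tuple=True)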
@@ -22,8 +22,7 @@ import itertools

 import numpy as np

-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
-from tensorflow.contrib.rnn.python.ops import rnn_cell
+from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
@@ -37,6 +36,7 @@ from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -65,7 +65,7 @@ class RNNCellTest(test.TestCase):
 "root", initializer=init_ops.constant_initializer(0.5)):
 x = array_ops.zeros([batch_size, input_size])
 m = array_ops.zeros([batch_size, state_size])
-output, state = rnn_cell.CoupledInputForgetGateLSTMCell(
+output, state = contrib_rnn_cell.CoupledInputForgetGateLSTMCell(
 num_units=num_units, forget_bias=1.0, state_is_tuple=False)(x, m)
 sess.run([variables.global_variables_initializer()])
 res = sess.run([output, state], {
@@ -94,7 +94,7 @@ class RNNCellTest(test.TestCase):
 "root", initializer=init_ops.constant_initializer(0.5)):
 x = array_ops.zeros([batch_size, input_size])
 m = array_ops.zeros([batch_size, state_size * num_shifts])
-output, state = rnn_cell.TimeFreqLSTMCell(
+output, state = contrib_rnn_cell.TimeFreqLSTMCell(
 num_units=num_units,
 feature_size=feature_size,
 frequency_skip=frequency_skip,
@@ -130,7 +130,7 @@ class RNNCellTest(test.TestCase):
 num_shifts = int((input_size - feature_size) / frequency_skip + 1)
 with variable_scope.variable_scope(
 "root", initializer=init_ops.constant_initializer(0.5)):
-cell = rnn_cell.GridLSTMCell(
+cell = contrib_rnn_cell.GridLSTMCell(
 num_units=num_units,
 feature_size=feature_size,
 frequency_skip=frequency_skip,
@@ -181,7 +181,7 @@ class RNNCellTest(test.TestCase):
 end_freqindex_list = [2, 4]
 with variable_scope.variable_scope(
 "root", initializer=init_ops.constant_initializer(0.5)):
-cell = rnn_cell.GridLSTMCell(
+cell = contrib_rnn_cell.GridLSTMCell(
 num_units=num_units,
 feature_size=feature_size,
 frequency_skip=frequency_skip,
@@ -249,7 +249,7 @@ class RNNCellTest(test.TestCase):
 with variable_scope.variable_scope(
 "state_is_tuple" + str(state_is_tuple),
 initializer=init_ops.constant_initializer(0.5)):
-cell = rnn_cell.GridLSTMCell(
+cell = contrib_rnn_cell.GridLSTMCell(
 num_units=num_units,
 feature_size=feature_size,
 frequency_skip=frequency_skip,
@@ -330,7 +330,7 @@ class RNNCellTest(test.TestCase):
 dtype=np.float32)
 with variable_scope.variable_scope(
 "root", initializer=init_ops.constant_initializer(0.5)):
-cell = rnn_cell.BidirectionalGridLSTMCell(
+cell = contrib_rnn_cell.BidirectionalGridLSTMCell(
 num_units=num_units,
 feature_size=feature_size,
 share_time_frequency_weights=True,
@@ -403,7 +403,7 @@ class RNNCellTest(test.TestCase):
 dtype=np.float32)
 with variable_scope.variable_scope(
 "root", initializer=init_ops.constant_initializer(0.5)):
-cell = rnn_cell.BidirectionalGridLSTMCell(
+cell = contrib_rnn_cell.BidirectionalGridLSTMCell(
 num_units=num_units,
 feature_size=feature_size,
 share_time_frequency_weights=True,
@@ -442,28 +442,28 @@ class RNNCellTest(test.TestCase):
 def testAttentionCellWrapperFailures(self):
 with self.assertRaisesRegexp(TypeError,
 "The parameter cell is not RNNCell."):
-rnn_cell.AttentionCellWrapper(None, 0)
+contrib_rnn_cell.AttentionCellWrapper(None, 0)

 num_units = 8
 for state_is_tuple in [False, True]:
 with ops.Graph().as_default():
-lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
+lstm_cell = rnn_cell.BasicLSTMCell(
 num_units, state_is_tuple=state_is_tuple)
 with self.assertRaisesRegexp(
 ValueError, "attn_length should be greater than zero, got 0"):
-rnn_cell.AttentionCellWrapper(
+contrib_rnn_cell.AttentionCellWrapper(
 lstm_cell, 0, state_is_tuple=state_is_tuple)
 with self.assertRaisesRegexp(
 ValueError, "attn_length should be greater than zero, got -1"):
-rnn_cell.AttentionCellWrapper(
+contrib_rnn_cell.AttentionCellWrapper(
 lstm_cell, -1, state_is_tuple=state_is_tuple)
 with ops.Graph().as_default():
-lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
-num_units, state_is_tuple=True)
+lstm_cell = rnn_cell.BasicLSTMCell(num_units, state_is_tuple=True)
 with self.assertRaisesRegexp(
 ValueError, "Cell returns tuple of states, but the flag "
 "state_is_tuple is not set. State size is: *"):
-rnn_cell.AttentionCellWrapper(lstm_cell, 4, state_is_tuple=False)
+contrib_rnn_cell.AttentionCellWrapper(
+lstm_cell, 4, state_is_tuple=False)

 def testAttentionCellWrapperZeros(self):
 num_units = 8
@@ -475,9 +475,9 @@ class RNNCellTest(test.TestCase):
 with self.test_session() as sess:
 with variable_scope.variable_scope("state_is_tuple_" + str(
 state_is_tuple)):
-lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
+lstm_cell = rnn_cell.BasicLSTMCell(
 num_units, state_is_tuple=state_is_tuple)
-cell = rnn_cell.AttentionCellWrapper(
+cell = contrib_rnn_cell.AttentionCellWrapper(
 lstm_cell, attn_length, state_is_tuple=state_is_tuple)
 if state_is_tuple:
 zeros = array_ops.zeros([batch_size, num_units], dtype=np.float32)
@@ -526,9 +526,9 @@ class RNNCellTest(test.TestCase):
 with self.test_session() as sess:
 with variable_scope.variable_scope("state_is_tuple_" + str(
 state_is_tuple)):
-lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
+lstm_cell = rnn_cell.BasicLSTMCell(
 num_units, state_is_tuple=state_is_tuple)
-cell = rnn_cell.AttentionCellWrapper(
+cell = contrib_rnn_cell.AttentionCellWrapper(
 lstm_cell, attn_length, state_is_tuple=state_is_tuple)
 if state_is_tuple:
|
if state_is_tuple:
|
||||||
zeros = constant_op.constant(
|
zeros = constant_op.constant(
|
||||||
@ -603,9 +603,9 @@ class RNNCellTest(test.TestCase):
|
|||||||
with variable_scope.variable_scope(
|
with variable_scope.variable_scope(
|
||||||
"state_is_tuple", reuse=state_is_tuple,
|
"state_is_tuple", reuse=state_is_tuple,
|
||||||
initializer=init_ops.glorot_uniform_initializer()):
|
initializer=init_ops.glorot_uniform_initializer()):
|
||||||
lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
|
lstm_cell = rnn_cell.BasicLSTMCell(
|
||||||
num_units, state_is_tuple=state_is_tuple)
|
num_units, state_is_tuple=state_is_tuple)
|
||||||
cell = rnn_cell.AttentionCellWrapper(
|
cell = contrib_rnn_cell.AttentionCellWrapper(
|
||||||
lstm_cell, attn_length, state_is_tuple=state_is_tuple)
|
lstm_cell, attn_length, state_is_tuple=state_is_tuple)
|
||||||
# This is legacy behavior to preserve the test. Weight
|
# This is legacy behavior to preserve the test. Weight
|
||||||
# sharing no longer works by creating a new RNNCell in the
|
# sharing no longer works by creating a new RNNCell in the
|
||||||
@ -665,8 +665,7 @@ class RNNCellTest(test.TestCase):
|
|||||||
with variable_scope.variable_scope(
|
with variable_scope.variable_scope(
|
||||||
"nas_test",
|
"nas_test",
|
||||||
initializer=init_ops.constant_initializer(0.5)):
|
initializer=init_ops.constant_initializer(0.5)):
|
||||||
cell = rnn_cell.NASCell(
|
cell = contrib_rnn_cell.NASCell(num_units=num_units)
|
||||||
num_units=num_units)
|
|
||||||
inputs = constant_op.constant(
|
inputs = constant_op.constant(
|
||||||
np.array([[1., 1., 1., 1.],
|
np.array([[1., 1., 1., 1.],
|
||||||
[2., 2., 2., 2.],
|
[2., 2., 2., 2.],
|
||||||
@ -677,8 +676,7 @@ class RNNCellTest(test.TestCase):
|
|||||||
0.1 * np.ones(
|
0.1 * np.ones(
|
||||||
(batch_size, num_units), dtype=np.float32),
|
(batch_size, num_units), dtype=np.float32),
|
||||||
dtype=dtypes.float32)
|
dtype=dtypes.float32)
|
||||||
init_state = core_rnn_cell_impl.LSTMStateTuple(state_value,
|
init_state = rnn_cell.LSTMStateTuple(state_value, state_value)
|
||||||
state_value)
|
|
||||||
output, state = cell(inputs, init_state)
|
output, state = cell(inputs, init_state)
|
||||||
sess.run([variables.global_variables_initializer()])
|
sess.run([variables.global_variables_initializer()])
|
||||||
res = sess.run([output, state])
|
res = sess.run([output, state])
|
||||||
@ -719,9 +717,7 @@ class RNNCellTest(test.TestCase):
|
|||||||
with variable_scope.variable_scope(
|
with variable_scope.variable_scope(
|
||||||
"nas_proj_test",
|
"nas_proj_test",
|
||||||
initializer=init_ops.constant_initializer(0.5)):
|
initializer=init_ops.constant_initializer(0.5)):
|
||||||
cell = rnn_cell.NASCell(
|
cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj)
|
||||||
num_units=num_units,
|
|
||||||
num_proj=num_proj)
|
|
||||||
inputs = constant_op.constant(
|
inputs = constant_op.constant(
|
||||||
np.array([[1., 1., 1., 1.],
|
np.array([[1., 1., 1., 1.],
|
||||||
[2., 2., 2., 2.],
|
[2., 2., 2., 2.],
|
||||||
@ -736,8 +732,7 @@ class RNNCellTest(test.TestCase):
|
|||||||
0.1 * np.ones(
|
0.1 * np.ones(
|
||||||
(batch_size, num_proj), dtype=np.float32),
|
(batch_size, num_proj), dtype=np.float32),
|
||||||
dtype=dtypes.float32)
|
dtype=dtypes.float32)
|
||||||
init_state = core_rnn_cell_impl.LSTMStateTuple(state_value_c,
|
init_state = rnn_cell.LSTMStateTuple(state_value_c, state_value_h)
|
||||||
state_value_h)
|
|
||||||
output, state = cell(inputs, init_state)
|
output, state = cell(inputs, init_state)
|
||||||
sess.run([variables.global_variables_initializer()])
|
sess.run([variables.global_variables_initializer()])
|
||||||
res = sess.run([output, state])
|
res = sess.run([output, state])
|
||||||
@ -767,7 +762,7 @@ class RNNCellTest(test.TestCase):
|
|||||||
with variable_scope.variable_scope(
|
with variable_scope.variable_scope(
|
||||||
"ugrnn_cell_test",
|
"ugrnn_cell_test",
|
||||||
initializer=init_ops.constant_initializer(0.5)):
|
initializer=init_ops.constant_initializer(0.5)):
|
||||||
cell = rnn_cell.UGRNNCell(num_units=num_units)
|
cell = contrib_rnn_cell.UGRNNCell(num_units=num_units)
|
||||||
inputs = constant_op.constant(
|
inputs = constant_op.constant(
|
||||||
np.array([[1., 1., 1., 1.],
|
np.array([[1., 1., 1., 1.],
|
||||||
[2., 2., 2., 2.],
|
[2., 2., 2., 2.],
|
||||||
@ -803,8 +798,8 @@ class RNNCellTest(test.TestCase):
|
|||||||
with variable_scope.variable_scope(
|
with variable_scope.variable_scope(
|
||||||
"intersection_rnn_cell_test",
|
"intersection_rnn_cell_test",
|
||||||
initializer=init_ops.constant_initializer(0.5)):
|
initializer=init_ops.constant_initializer(0.5)):
|
||||||
cell = rnn_cell.IntersectionRNNCell(num_units=num_units,
|
cell = contrib_rnn_cell.IntersectionRNNCell(
|
||||||
num_in_proj=num_units)
|
num_units=num_units, num_in_proj=num_units)
|
||||||
inputs = constant_op.constant(
|
inputs = constant_op.constant(
|
||||||
np.array([[1., 1., 1., 1.],
|
np.array([[1., 1., 1., 1.],
|
||||||
[2., 2., 2., 2.],
|
[2., 2., 2., 2.],
|
||||||
@ -826,7 +821,7 @@ class RNNCellTest(test.TestCase):
|
|||||||
def testIntersectionRNNCellFailure(self):
|
def testIntersectionRNNCellFailure(self):
|
||||||
num_units = 2
|
num_units = 2
|
||||||
batch_size = 3
|
batch_size = 3
|
||||||
cell = rnn_cell.IntersectionRNNCell(num_units=num_units)
|
cell = contrib_rnn_cell.IntersectionRNNCell(num_units=num_units)
|
||||||
inputs = constant_op.constant(
|
inputs = constant_op.constant(
|
||||||
np.array([[1., 1., 1., 1.],
|
np.array([[1., 1., 1., 1.],
|
||||||
[2., 2., 2., 2.],
|
[2., 2., 2., 2.],
|
||||||
@ -862,9 +857,9 @@ class RNNCellTest(test.TestCase):
|
|||||||
x = array_ops.zeros([batch_size, input_size])
|
x = array_ops.zeros([batch_size, input_size])
|
||||||
c0 = array_ops.zeros([batch_size, 2])
|
c0 = array_ops.zeros([batch_size, 2])
|
||||||
h0 = array_ops.zeros([batch_size, 2])
|
h0 = array_ops.zeros([batch_size, 2])
|
||||||
state0 = core_rnn_cell_impl.LSTMStateTuple(c0, h0)
|
state0 = rnn_cell.LSTMStateTuple(c0, h0)
|
||||||
output, state = rnn_cell.PhasedLSTMCell(num_units=num_units)((t, x),
|
output, state = contrib_rnn_cell.PhasedLSTMCell(num_units=num_units)(
|
||||||
state0)
|
(t, x), state0)
|
||||||
sess.run([variables.global_variables_initializer()])
|
sess.run([variables.global_variables_initializer()])
|
||||||
res = sess.run([output, state], {
|
res = sess.run([output, state], {
|
||||||
t.name:
|
t.name:
|
||||||
@ -886,12 +881,12 @@ class RNNCellTest(test.TestCase):
|
|||||||
"base_cell", initializer=init_ops.constant_initializer(0.5)):
|
"base_cell", initializer=init_ops.constant_initializer(0.5)):
|
||||||
x = array_ops.zeros([1, 3])
|
x = array_ops.zeros([1, 3])
|
||||||
m = array_ops.zeros([1, 3])
|
m = array_ops.zeros([1, 3])
|
||||||
base_cell = core_rnn_cell_impl.GRUCell(3)
|
base_cell = rnn_cell.GRUCell(3)
|
||||||
g, m_new = base_cell(x, m)
|
g, m_new = base_cell(x, m)
|
||||||
with variable_scope.variable_scope(
|
with variable_scope.variable_scope(
|
||||||
"hw_cell", initializer=init_ops.constant_initializer(0.5)):
|
"hw_cell", initializer=init_ops.constant_initializer(0.5)):
|
||||||
hw_cell = rnn_cell.HighwayWrapper(
|
hw_cell = contrib_rnn_cell.HighwayWrapper(
|
||||||
core_rnn_cell_impl.GRUCell(3), carry_bias_init=-100.0)
|
rnn_cell.GRUCell(3), carry_bias_init=-100.0)
|
||||||
g_res, m_new_res = hw_cell(x, m)
|
g_res, m_new_res = hw_cell(x, m)
|
||||||
sess.run([variables.global_variables_initializer()])
|
sess.run([variables.global_variables_initializer()])
|
||||||
res = sess.run([g, g_res, m_new, m_new_res], {
|
res = sess.run([g, g_res, m_new, m_new_res], {
|
||||||
@ -915,9 +910,9 @@ class RNNCellTest(test.TestCase):
|
|||||||
"root1", initializer=init_ops.constant_initializer(0.5)):
|
"root1", initializer=init_ops.constant_initializer(0.5)):
|
||||||
x = array_ops.ones([batch_size, num_units])
|
x = array_ops.ones([batch_size, num_units])
|
||||||
# When number_of_groups = 1, G-LSTM is equivalent to regular LSTM
|
# When number_of_groups = 1, G-LSTM is equivalent to regular LSTM
|
||||||
gcell = rnn_cell.GLSTMCell(num_units=num_units,
|
gcell = contrib_rnn_cell.GLSTMCell(
|
||||||
number_of_groups=number_of_groups)
|
num_units=num_units, number_of_groups=number_of_groups)
|
||||||
cell = core_rnn_cell_impl.LSTMCell(num_units=num_units)
|
cell = rnn_cell.LSTMCell(num_units=num_units)
|
||||||
self.assertTrue(isinstance(gcell.state_size, tuple))
|
self.assertTrue(isinstance(gcell.state_size, tuple))
|
||||||
zero_state = gcell.zero_state(batch_size=batch_size,
|
zero_state = gcell.zero_state(batch_size=batch_size,
|
||||||
dtype=dtypes.float32)
|
dtype=dtypes.float32)
|
||||||
@ -941,8 +936,8 @@ class RNNCellTest(test.TestCase):
|
|||||||
"root2", initializer=init_ops.constant_initializer(0.5)):
|
"root2", initializer=init_ops.constant_initializer(0.5)):
|
||||||
# input for G-LSTM with 2 groups
|
# input for G-LSTM with 2 groups
|
||||||
glstm_input = array_ops.ones([batch_size, num_units])
|
glstm_input = array_ops.ones([batch_size, num_units])
|
||||||
gcell = rnn_cell.GLSTMCell(num_units=num_units,
|
gcell = contrib_rnn_cell.GLSTMCell(
|
||||||
number_of_groups=number_of_groups)
|
num_units=num_units, number_of_groups=number_of_groups)
|
||||||
gcell_zero_state = gcell.zero_state(batch_size=batch_size,
|
gcell_zero_state = gcell.zero_state(batch_size=batch_size,
|
||||||
dtype=dtypes.float32)
|
dtype=dtypes.float32)
|
||||||
gh, gs = gcell(glstm_input, gcell_zero_state)
|
gh, gs = gcell(glstm_input, gcell_zero_state)
|
||||||
@ -950,8 +945,7 @@ class RNNCellTest(test.TestCase):
|
|||||||
# input for LSTM cell simulating single G-LSTM group
|
# input for LSTM cell simulating single G-LSTM group
|
||||||
lstm_input = array_ops.ones([batch_size, num_units / number_of_groups])
|
lstm_input = array_ops.ones([batch_size, num_units / number_of_groups])
|
||||||
# note division by number_of_groups. This cell one simulates G-LSTM group
|
# note division by number_of_groups. This cell one simulates G-LSTM group
|
||||||
cell = core_rnn_cell_impl.LSTMCell(num_units=
|
cell = rnn_cell.LSTMCell(num_units=int(num_units / number_of_groups))
|
||||||
int(num_units / number_of_groups))
|
|
||||||
cell_zero_state = cell.zero_state(batch_size=batch_size,
|
cell_zero_state = cell.zero_state(batch_size=batch_size,
|
||||||
dtype=dtypes.float32)
|
dtype=dtypes.float32)
|
||||||
h, g = cell(lstm_input, cell_zero_state)
|
h, g = cell(lstm_input, cell_zero_state)
|
||||||
@ -974,13 +968,13 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
|
|||||||
x = array_ops.zeros([1, 2])
|
x = array_ops.zeros([1, 2])
|
||||||
c0 = array_ops.zeros([1, 2])
|
c0 = array_ops.zeros([1, 2])
|
||||||
h0 = array_ops.zeros([1, 2])
|
h0 = array_ops.zeros([1, 2])
|
||||||
state0 = core_rnn_cell_impl.LSTMStateTuple(c0, h0)
|
state0 = rnn_cell.LSTMStateTuple(c0, h0)
|
||||||
c1 = array_ops.zeros([1, 2])
|
c1 = array_ops.zeros([1, 2])
|
||||||
h1 = array_ops.zeros([1, 2])
|
h1 = array_ops.zeros([1, 2])
|
||||||
state1 = core_rnn_cell_impl.LSTMStateTuple(c1, h1)
|
state1 = rnn_cell.LSTMStateTuple(c1, h1)
|
||||||
state = (state0, state1)
|
state = (state0, state1)
|
||||||
single_cell = lambda: rnn_cell.LayerNormBasicLSTMCell(2)
|
single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(2)
|
||||||
cell = core_rnn_cell_impl.MultiRNNCell([single_cell() for _ in range(2)])
|
cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
|
||||||
g, out_m = cell(x, state)
|
g, out_m = cell(x, state)
|
||||||
sess.run([variables.global_variables_initializer()])
|
sess.run([variables.global_variables_initializer()])
|
||||||
res = sess.run([g, out_m], {
|
res = sess.run([g, out_m], {
|
||||||
@ -1015,8 +1009,8 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
|
|||||||
[1, 3]) # Test BasicLSTMCell with input_size != num_units.
|
[1, 3]) # Test BasicLSTMCell with input_size != num_units.
|
||||||
c = array_ops.zeros([1, 2])
|
c = array_ops.zeros([1, 2])
|
||||||
h = array_ops.zeros([1, 2])
|
h = array_ops.zeros([1, 2])
|
||||||
state = core_rnn_cell_impl.LSTMStateTuple(c, h)
|
state = rnn_cell.LSTMStateTuple(c, h)
|
||||||
cell = rnn_cell.LayerNormBasicLSTMCell(2)
|
cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2)
|
||||||
g, out_m = cell(x, state)
|
g, out_m = cell(x, state)
|
||||||
sess.run([variables.global_variables_initializer()])
|
sess.run([variables.global_variables_initializer()])
|
||||||
res = sess.run([g, out_m], {
|
res = sess.run([g, out_m], {
|
||||||
@ -1039,12 +1033,12 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
|
|||||||
x = array_ops.zeros([1, 2])
|
x = array_ops.zeros([1, 2])
|
||||||
c0 = array_ops.zeros([1, 2])
|
c0 = array_ops.zeros([1, 2])
|
||||||
h0 = array_ops.zeros([1, 2])
|
h0 = array_ops.zeros([1, 2])
|
||||||
state0 = core_rnn_cell_impl.LSTMStateTuple(c0, h0)
|
state0 = rnn_cell.LSTMStateTuple(c0, h0)
|
||||||
c1 = array_ops.zeros([1, 2])
|
c1 = array_ops.zeros([1, 2])
|
||||||
h1 = array_ops.zeros([1, 2])
|
h1 = array_ops.zeros([1, 2])
|
||||||
state1 = core_rnn_cell_impl.LSTMStateTuple(c1, h1)
|
state1 = rnn_cell.LSTMStateTuple(c1, h1)
|
||||||
cell = core_rnn_cell_impl.MultiRNNCell(
|
cell = rnn_cell.MultiRNNCell(
|
||||||
[rnn_cell.LayerNormBasicLSTMCell(2) for _ in range(2)])
|
[contrib_rnn_cell.LayerNormBasicLSTMCell(2) for _ in range(2)])
|
||||||
h, (s0, s1) = cell(x, (state0, state1))
|
h, (s0, s1) = cell(x, (state0, state1))
|
||||||
sess.run([variables.global_variables_initializer()])
|
sess.run([variables.global_variables_initializer()])
|
||||||
res = sess.run([h, s0, s1], {
|
res = sess.run([h, s0, s1], {
|
||||||
@ -1094,8 +1088,8 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
|
|||||||
x = array_ops.zeros([1, 5])
|
x = array_ops.zeros([1, 5])
|
||||||
c = array_ops.zeros([1, 5])
|
c = array_ops.zeros([1, 5])
|
||||||
h = array_ops.zeros([1, 5])
|
h = array_ops.zeros([1, 5])
|
||||||
state = core_rnn_cell_impl.LSTMStateTuple(c, h)
|
state = rnn_cell.LSTMStateTuple(c, h)
|
||||||
cell = rnn_cell.LayerNormBasicLSTMCell(
|
cell = contrib_rnn_cell.LayerNormBasicLSTMCell(
|
||||||
num_units, layer_norm=False, dropout_keep_prob=keep_prob)
|
num_units, layer_norm=False, dropout_keep_prob=keep_prob)
|
||||||
|
|
||||||
g, s = cell(x, state)
|
g, s = cell(x, state)
|
||||||
@ -1138,10 +1132,9 @@ def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth,
|
|||||||
inputs = variable_scope.get_variable(
|
inputs = variable_scope.get_variable(
|
||||||
"inputs", initializer=random_ops.random_uniform(
|
"inputs", initializer=random_ops.random_uniform(
|
||||||
(max_time, batch_size, input_depth), seed=1))
|
(max_time, batch_size, input_depth), seed=1))
|
||||||
maybe_xla = lambda c: rnn_cell.CompiledWrapper(c) if compiled else c
|
maybe_xla = lambda c: contrib_rnn_cell.CompiledWrapper(c) if compiled else c
|
||||||
cell = core_rnn_cell_impl.MultiRNNCell(
|
cell = rnn_cell.MultiRNNCell(
|
||||||
[maybe_xla(core_rnn_cell_impl.LSTMCell(num_units))
|
[maybe_xla(rnn_cell.LSTMCell(num_units)) for _ in range(num_layers)])
|
||||||
for _ in range(num_layers)])
|
|
||||||
initial_state = cell.zero_state(
|
initial_state = cell.zero_state(
|
||||||
batch_size=batch_size, dtype=dtypes.float32)
|
batch_size=batch_size, dtype=dtypes.float32)
|
||||||
outputs, final_state = rnn.dynamic_rnn(
|
outputs, final_state = rnn.dynamic_rnn(
|
||||||
@ -1219,13 +1212,13 @@ class CompiledWrapperTest(test.TestCase):
|
|||||||
|
|
||||||
# Test incorrectness of state
|
# Test incorrectness of state
|
||||||
with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
|
with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
|
||||||
core_rnn_cell_impl.MultiRNNCell(
|
rnn_cell.MultiRNNCell(
|
||||||
[core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
|
[rnn_cell.GRUCell(2)
|
||||||
state_is_tuple=True)(x, m_bad)
|
for _ in range(2)], state_is_tuple=True)(x, m_bad)
|
||||||
|
|
||||||
_, ml = core_rnn_cell_impl.MultiRNNCell(
|
_, ml = rnn_cell.MultiRNNCell(
|
||||||
[core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
|
[rnn_cell.GRUCell(2)
|
||||||
state_is_tuple=True)(x, m_good)
|
for _ in range(2)], state_is_tuple=True)(x, m_good)
|
||||||
|
|
||||||
sess.run([variables.global_variables_initializer()])
|
sess.run([variables.global_variables_initializer()])
|
||||||
res = sess.run(ml, {
|
res = sess.run(ml, {
|
||||||
|
@ -22,12 +22,12 @@ import itertools
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
from tensorflow.contrib.rnn.python.ops import rnn as contrib_rnn
|
||||||
from tensorflow.contrib.rnn.python.ops import rnn
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
|
from tensorflow.python.ops import rnn_cell
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -58,14 +58,14 @@ class StackBidirectionalRNNTest(test.TestCase):
|
|||||||
dtypes.int64) if use_sequence_length else None
|
dtypes.int64) if use_sequence_length else None
|
||||||
|
|
||||||
self.cells_fw = [
|
self.cells_fw = [
|
||||||
core_rnn_cell_impl.LSTMCell(
|
rnn_cell.LSTMCell(
|
||||||
num_units,
|
num_units,
|
||||||
input_size,
|
input_size,
|
||||||
initializer=initializer,
|
initializer=initializer,
|
||||||
state_is_tuple=False) for num_units in self.layers
|
state_is_tuple=False) for num_units in self.layers
|
||||||
]
|
]
|
||||||
self.cells_bw = [
|
self.cells_bw = [
|
||||||
core_rnn_cell_impl.LSTMCell(
|
rnn_cell.LSTMCell(
|
||||||
num_units,
|
num_units,
|
||||||
input_size,
|
input_size,
|
||||||
initializer=initializer,
|
initializer=initializer,
|
||||||
@ -77,7 +77,7 @@ class StackBidirectionalRNNTest(test.TestCase):
|
|||||||
dtypes.float32,
|
dtypes.float32,
|
||||||
shape=(batch_size, input_size) if use_shape else (None, input_size))
|
shape=(batch_size, input_size) if use_shape else (None, input_size))
|
||||||
]
|
]
|
||||||
outputs, state_fw, state_bw = rnn.stack_bidirectional_rnn(
|
outputs, state_fw, state_bw = contrib_rnn.stack_bidirectional_rnn(
|
||||||
self.cells_fw,
|
self.cells_fw,
|
||||||
self.cells_bw,
|
self.cells_bw,
|
||||||
inputs,
|
inputs,
|
||||||
@ -237,14 +237,14 @@ class StackBidirectionalRNNTest(test.TestCase):
|
|||||||
sequence_length = array_ops.placeholder(dtypes.int64)
|
sequence_length = array_ops.placeholder(dtypes.int64)
|
||||||
|
|
||||||
self.cells_fw = [
|
self.cells_fw = [
|
||||||
core_rnn_cell_impl.LSTMCell(
|
rnn_cell.LSTMCell(
|
||||||
num_units,
|
num_units,
|
||||||
input_size,
|
input_size,
|
||||||
initializer=initializer,
|
initializer=initializer,
|
||||||
state_is_tuple=False) for num_units in self.layers
|
state_is_tuple=False) for num_units in self.layers
|
||||||
]
|
]
|
||||||
self.cells_bw = [
|
self.cells_bw = [
|
||||||
core_rnn_cell_impl.LSTMCell(
|
rnn_cell.LSTMCell(
|
||||||
num_units,
|
num_units,
|
||||||
input_size,
|
input_size,
|
||||||
initializer=initializer,
|
initializer=initializer,
|
||||||
@ -258,7 +258,7 @@ class StackBidirectionalRNNTest(test.TestCase):
|
|||||||
]
|
]
|
||||||
inputs_c = array_ops.stack(inputs)
|
inputs_c = array_ops.stack(inputs)
|
||||||
inputs_c = array_ops.transpose(inputs_c, [1, 0, 2])
|
inputs_c = array_ops.transpose(inputs_c, [1, 0, 2])
|
||||||
outputs, st_fw, st_bw = rnn.stack_bidirectional_dynamic_rnn(
|
outputs, st_fw, st_bw = contrib_rnn.stack_bidirectional_dynamic_rnn(
|
||||||
self.cells_fw,
|
self.cells_fw,
|
||||||
self.cells_bw,
|
self.cells_bw,
|
||||||
inputs_c,
|
inputs_c,
|
||||||
|
@ -1,357 +0,0 @@
|
|||||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
"""RNN helpers for TensorFlow models."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
from tensorflow.python.framework import ops
|
|
||||||
from tensorflow.python.framework import tensor_shape
|
|
||||||
from tensorflow.python.ops import array_ops
|
|
||||||
from tensorflow.python.ops import math_ops
|
|
||||||
from tensorflow.python.ops import rnn
|
|
||||||
from tensorflow.python.ops import rnn_cell_impl
|
|
||||||
from tensorflow.python.ops import variable_scope as vs
|
|
||||||
from tensorflow.python.util import nest
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
_concat = rnn_cell_impl._concat
|
|
||||||
_like_rnncell = rnn_cell_impl._like_rnncell
|
|
||||||
_infer_state_dtype = rnn._infer_state_dtype
|
|
||||||
_reverse_seq = rnn._reverse_seq
|
|
||||||
_rnn_step = rnn._rnn_step
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
|
|
||||||
|
|
||||||
def static_rnn(cell, inputs, initial_state=None, dtype=None,
|
|
||||||
sequence_length=None, scope=None):
|
|
||||||
"""Creates a recurrent neural network specified by RNNCell `cell`.
|
|
||||||
|
|
||||||
The simplest form of RNN network generated is:
|
|
||||||
|
|
||||||
```python
|
|
||||||
state = cell.zero_state(...)
|
|
||||||
outputs = []
|
|
||||||
for input_ in inputs:
|
|
||||||
output, state = cell(input_, state)
|
|
||||||
outputs.append(output)
|
|
||||||
return (outputs, state)
|
|
||||||
```
|
|
||||||
However, a few other options are available:
|
|
||||||
|
|
||||||
An initial state can be provided.
|
|
||||||
If the sequence_length vector is provided, dynamic calculation is performed.
|
|
||||||
This method of calculation does not compute the RNN steps past the maximum
|
|
||||||
sequence length of the minibatch (thus saving computational time),
|
|
||||||
and properly propagates the state at an example's sequence length
|
|
||||||
to the final state output.
|
|
||||||
|
|
||||||
The dynamic calculation performed is, at time `t` for batch row `b`,
|
|
||||||
|
|
||||||
```python
|
|
||||||
(output, state)(b, t) =
|
|
||||||
(t >= sequence_length(b))
|
|
||||||
? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
|
|
||||||
: cell(input(b, t), state(b, t - 1))
|
|
||||||
```
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cell: An instance of RNNCell.
|
|
||||||
inputs: A length T list of inputs, each a `Tensor` of shape
|
|
||||||
`[batch_size, input_size]`, or a nested tuple of such elements.
|
|
||||||
initial_state: (optional) An initial state for the RNN.
|
|
||||||
If `cell.state_size` is an integer, this must be
|
|
||||||
a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
|
|
||||||
If `cell.state_size` is a tuple, this should be a tuple of
|
|
||||||
tensors having shapes `[batch_size, s] for s in cell.state_size`.
|
|
||||||
dtype: (optional) The data type for the initial state and expected output.
|
|
||||||
Required if initial_state is not provided or RNN state has a heterogeneous
|
|
||||||
dtype.
|
|
||||||
sequence_length: Specifies the length of each sequence in inputs.
|
|
||||||
An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`.
|
|
||||||
scope: VariableScope for the created subgraph; defaults to "rnn".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A pair (outputs, state) where:
|
|
||||||
|
|
||||||
- outputs is a length T list of outputs (one for each input), or a nested
|
|
||||||
tuple of such elements.
|
|
||||||
- state is the final state
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If `cell` is not an instance of RNNCell.
|
|
||||||
ValueError: If `inputs` is `None` or an empty list, or if the input depth
|
|
||||||
(column size) cannot be inferred from inputs via shape inference.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not _like_rnncell(cell):
|
|
||||||
raise TypeError("cell must be an instance of RNNCell")
|
|
||||||
if not nest.is_sequence(inputs):
|
|
||||||
raise TypeError("inputs must be a sequence")
|
|
||||||
if not inputs:
|
|
||||||
raise ValueError("inputs must not be empty")
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
# Create a new scope in which the caching device is either
|
|
||||||
# determined by the parent scope, or is set to place the cached
|
|
||||||
# Variable using the same placement as for the rest of the RNN.
|
|
||||||
with vs.variable_scope(scope or "rnn") as varscope:
|
|
||||||
if varscope.caching_device is None:
|
|
||||||
varscope.set_caching_device(lambda op: op.device)
|
|
||||||
|
|
||||||
# Obtain the first sequence of the input
|
|
||||||
first_input = inputs
|
|
||||||
while nest.is_sequence(first_input):
|
|
||||||
first_input = first_input[0]
|
|
||||||
|
|
||||||
# Temporarily avoid EmbeddingWrapper and seq2seq badness
|
|
||||||
# TODO(lukaszkaiser): remove EmbeddingWrapper
|
|
||||||
if first_input.get_shape().ndims != 1:
|
|
||||||
|
|
||||||
input_shape = first_input.get_shape().with_rank_at_least(2)
|
|
||||||
fixed_batch_size = input_shape[0]
|
|
||||||
|
|
||||||
flat_inputs = nest.flatten(inputs)
|
|
||||||
for flat_input in flat_inputs:
|
|
||||||
input_shape = flat_input.get_shape().with_rank_at_least(2)
|
|
||||||
batch_size, input_size = input_shape[0], input_shape[1:]
|
|
||||||
fixed_batch_size.merge_with(batch_size)
|
|
||||||
for i, size in enumerate(input_size):
|
|
||||||
if size.value is None:
|
|
||||||
raise ValueError(
|
|
||||||
"Input size (dimension %d of inputs) must be accessible via "
|
|
||||||
"shape inference, but saw value None." % i)
|
|
||||||
else:
|
|
||||||
fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]
|
|
||||||
|
|
||||||
if fixed_batch_size.value:
|
|
||||||
batch_size = fixed_batch_size.value
|
|
||||||
else:
|
|
||||||
batch_size = array_ops.shape(first_input)[0]
|
|
||||||
if initial_state is not None:
|
|
||||||
state = initial_state
|
|
||||||
else:
|
|
||||||
if not dtype:
|
|
||||||
raise ValueError("If no initial_state is provided, "
|
|
||||||
"dtype must be specified")
|
|
||||||
state = cell.zero_state(batch_size, dtype)
|
|
||||||
|
|
||||||
if sequence_length is not None: # Prepare variables
|
|
||||||
sequence_length = ops.convert_to_tensor(
|
|
||||||
sequence_length, name="sequence_length")
|
|
||||||
if sequence_length.get_shape().ndims not in (None, 1):
|
|
||||||
raise ValueError(
|
|
||||||
"sequence_length must be a vector of length batch_size")
|
|
||||||
def _create_zero_output(output_size):
|
|
||||||
# convert int to TensorShape if necessary
|
|
||||||
size = _concat(batch_size, output_size)
|
|
||||||
output = array_ops.zeros(
|
|
||||||
array_ops.stack(size), _infer_state_dtype(dtype, state))
|
|
||||||
shape = _concat(fixed_batch_size.value, output_size, static=True)
|
|
||||||
output.set_shape(tensor_shape.TensorShape(shape))
|
|
||||||
return output
|
|
||||||
|
|
||||||
output_size = cell.output_size
|
|
||||||
flat_output_size = nest.flatten(output_size)
|
|
||||||
flat_zero_output = tuple(
|
|
||||||
_create_zero_output(size) for size in flat_output_size)
|
|
||||||
zero_output = nest.pack_sequence_as(structure=output_size,
|
|
||||||
flat_sequence=flat_zero_output)
|
|
||||||
|
|
||||||
sequence_length = math_ops.to_int32(sequence_length)
|
|
||||||
min_sequence_length = math_ops.reduce_min(sequence_length)
|
|
||||||
max_sequence_length = math_ops.reduce_max(sequence_length)
|
|
||||||
|
|
||||||
for time, input_ in enumerate(inputs):
|
|
||||||
if time > 0: varscope.reuse_variables()
|
|
||||||
# pylint: disable=cell-var-from-loop
|
|
||||||
call_cell = lambda: cell(input_, state)
|
|
||||||
# pylint: enable=cell-var-from-loop
|
|
||||||
if sequence_length is not None:
|
|
||||||
(output, state) = _rnn_step(
|
|
||||||
time=time,
|
|
||||||
sequence_length=sequence_length,
|
|
||||||
min_sequence_length=min_sequence_length,
|
|
||||||
max_sequence_length=max_sequence_length,
|
|
||||||
zero_output=zero_output,
|
|
||||||
state=state,
|
|
||||||
call_cell=call_cell,
|
|
||||||
state_size=cell.state_size)
|
|
||||||
else:
|
|
||||||
(output, state) = call_cell()
|
|
||||||
|
|
||||||
outputs.append(output)
|
|
||||||
|
|
||||||
return (outputs, state)
|
|
||||||
|
|
||||||
|
|
||||||
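A minimal sketch of calling the `static_rnn` shown above through `tf.nn.static_rnn`, its home after this change; graph-mode TF 1.x API is assumed, and the sizes and cell choice are illustrative rather than taken from this diff.

```python
import tensorflow as tf

batch_size, num_steps, input_size, num_units = 4, 3, 5, 8
# static_rnn expects a Python list with one [batch_size, input_size] tensor per step.
inputs = [tf.placeholder(tf.float32, [batch_size, input_size])
          for _ in range(num_steps)]
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
# Unrolls the cell num_steps times; returns the per-step outputs and the final state.
outputs, final_state = tf.nn.static_rnn(cell, inputs, dtype=tf.float32)
```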
def static_state_saving_rnn(cell, inputs, state_saver, state_name,
|
|
||||||
sequence_length=None, scope=None):
|
|
||||||
"""RNN that accepts a state saver for time-truncated RNN calculation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cell: An instance of `RNNCell`.
|
|
||||||
inputs: A length T list of inputs, each a `Tensor` of shape
|
|
||||||
`[batch_size, input_size]`.
|
|
||||||
state_saver: A state saver object with methods `state` and `save_state`.
|
|
||||||
state_name: Python string or tuple of strings. The name to use with the
|
|
||||||
state_saver. If the cell returns tuples of states (i.e.,
|
|
||||||
`cell.state_size` is a tuple) then `state_name` should be a tuple of
|
|
||||||
strings having the same length as `cell.state_size`. Otherwise it should
|
|
||||||
be a single string.
|
|
||||||
sequence_length: (optional) An int32/int64 vector size [batch_size].
|
|
||||||
See the documentation for rnn() for more details about sequence_length.
|
|
||||||
scope: VariableScope for the created subgraph; defaults to "rnn".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A pair (outputs, state) where:
|
|
||||||
outputs is a length T list of outputs (one for each input)
|
|
||||||
state is the final state
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If `cell` is not an instance of RNNCell.
|
|
||||||
ValueError: If `inputs` is `None` or an empty list, or if the arity and
|
|
||||||
type of `state_name` does not match that of `cell.state_size`.
|
|
||||||
"""
|
|
||||||
state_size = cell.state_size
|
|
||||||
state_is_tuple = nest.is_sequence(state_size)
|
|
||||||
state_name_tuple = nest.is_sequence(state_name)
|
|
||||||
|
|
||||||
if state_is_tuple != state_name_tuple:
|
|
||||||
raise ValueError(
|
|
||||||
"state_name should be the same type as cell.state_size. "
|
|
||||||
"state_name: %s, cell.state_size: %s"
|
|
||||||
% (str(state_name), str(state_size)))
|
|
||||||
|
|
||||||
if state_is_tuple:
|
|
||||||
state_name_flat = nest.flatten(state_name)
|
|
||||||
state_size_flat = nest.flatten(state_size)
|
|
||||||
|
|
||||||
if len(state_name_flat) != len(state_size_flat):
|
|
||||||
raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d"
|
|
||||||
% (len(state_name_flat), len(state_size_flat)))
|
|
||||||
|
|
||||||
initial_state = nest.pack_sequence_as(
|
|
||||||
structure=state_size,
|
|
||||||
flat_sequence=[state_saver.state(s) for s in state_name_flat])
|
|
||||||
else:
|
|
||||||
initial_state = state_saver.state(state_name)
|
|
||||||
|
|
||||||
(outputs, state) = static_rnn(cell, inputs, initial_state=initial_state,
|
|
||||||
sequence_length=sequence_length, scope=scope)
|
|
||||||
|
|
||||||
if state_is_tuple:
|
|
||||||
flat_state = nest.flatten(state)
|
|
||||||
state_name = nest.flatten(state_name)
|
|
||||||
save_state = [state_saver.save_state(name, substate)
|
|
||||||
for name, substate in zip(state_name, flat_state)]
|
|
||||||
else:
|
|
||||||
save_state = [state_saver.save_state(state_name, state)]
|
|
||||||
|
|
||||||
with ops.control_dependencies(save_state):
|
|
||||||
last_output = outputs[-1]
|
|
||||||
flat_last_output = nest.flatten(last_output)
|
|
||||||
flat_last_output = [
|
|
||||||
array_ops.identity(output) for output in flat_last_output]
|
|
||||||
outputs[-1] = nest.pack_sequence_as(structure=last_output,
|
|
||||||
flat_sequence=flat_last_output)
|
|
||||||
|
|
||||||
return (outputs, state)
|
|
||||||
|
|
||||||
|
|
||||||
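`static_state_saving_rnn` above only constrains its `state_saver` to expose `state` and `save_state`; below is a toy sketch of an object satisfying that interface, purely illustrative (real pipelines would more likely obtain such an object from `tf.contrib.training.batch_sequences_with_states`).

```python
import tensorflow as tf

class ToyStateSaver(object):
  """Illustrative object with the state()/save_state() interface described above."""

  def __init__(self, batch_size, num_units):
    # One saved tensor per state_name entry, e.g. ("lstm_c", "lstm_h").
    self._states = {
        "lstm_c": tf.Variable(tf.zeros([batch_size, num_units]), trainable=False),
        "lstm_h": tf.Variable(tf.zeros([batch_size, num_units]), trainable=False),
    }

  def state(self, name):
    return self._states[name]

  def save_state(self, name, value):
    return tf.assign(self._states[name], value)
```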
def static_bidirectional_rnn(cell_fw, cell_bw, inputs,
|
|
||||||
initial_state_fw=None, initial_state_bw=None,
|
|
||||||
dtype=None, sequence_length=None, scope=None):
|
|
||||||
"""Creates a bidirectional recurrent neural network.
|
|
||||||
|
|
||||||
Similar to the unidirectional case above (rnn) but takes input and builds
|
|
||||||
independent forward and backward RNNs with the final forward and backward
|
|
||||||
outputs depth-concatenated, such that the output will have the format
|
|
||||||
[time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of
|
|
||||||
the forward and backward cells must match. The initial state for both directions
|
|
||||||
is zero by default (but can be set optionally) and no intermediate states are
|
|
||||||
ever returned -- the network is fully unrolled for the given (passed in)
|
|
||||||
length(s) of the sequence(s) or completely unrolled if length(s) is not given.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cell_fw: An instance of RNNCell, to be used for forward direction.
|
|
||||||
cell_bw: An instance of RNNCell, to be used for backward direction.
|
|
||||||
inputs: A length T list of inputs, each a tensor of shape
|
|
||||||
[batch_size, input_size], or a nested tuple of such elements.
|
|
||||||
initial_state_fw: (optional) An initial state for the forward RNN.
|
|
||||||
This must be a tensor of appropriate type and shape
|
|
||||||
`[batch_size, cell_fw.state_size]`.
|
|
||||||
If `cell_fw.state_size` is a tuple, this should be a tuple of
|
|
||||||
tensors having shapes `[batch_size, s] for s in cell_fw.state_size`.
|
|
||||||
initial_state_bw: (optional) Same as for `initial_state_fw`, but using
|
|
||||||
the corresponding properties of `cell_bw`.
|
|
||||||
dtype: (optional) The data type for the initial state. Required if
|
|
||||||
either of the initial states are not provided.
|
|
||||||
sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
|
|
||||||
containing the actual lengths for each of the sequences.
|
|
||||||
scope: VariableScope for the created subgraph; defaults to
|
|
||||||
"bidirectional_rnn"
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple (outputs, output_state_fw, output_state_bw) where:
|
|
||||||
outputs is a length `T` list of outputs (one for each input), which
|
|
||||||
are depth-concatenated forward and backward outputs.
|
|
||||||
output_state_fw is the final state of the forward rnn.
|
|
||||||
output_state_bw is the final state of the backward rnn.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
|
|
||||||
ValueError: If inputs is None or an empty list.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not _like_rnncell(cell_fw):
|
|
||||||
raise TypeError("cell_fw must be an instance of RNNCell")
|
|
||||||
if not _like_rnncell(cell_bw):
|
|
||||||
raise TypeError("cell_bw must be an instance of RNNCell")
|
|
||||||
if not nest.is_sequence(inputs):
|
|
||||||
raise TypeError("inputs must be a sequence")
|
|
||||||
if not inputs:
|
|
||||||
raise ValueError("inputs must not be empty")
|
|
||||||
|
|
||||||
with vs.variable_scope(scope or "bidirectional_rnn"):
|
|
||||||
# Forward direction
|
|
||||||
with vs.variable_scope("fw") as fw_scope:
|
|
||||||
output_fw, output_state_fw = static_rnn(
|
|
||||||
cell_fw, inputs, initial_state_fw, dtype,
|
|
||||||
sequence_length, scope=fw_scope)
|
|
||||||
|
|
||||||
# Backward direction
|
|
||||||
with vs.variable_scope("bw") as bw_scope:
|
|
||||||
reversed_inputs = _reverse_seq(inputs, sequence_length)
|
|
||||||
tmp, output_state_bw = static_rnn(
|
|
||||||
cell_bw, reversed_inputs, initial_state_bw,
|
|
||||||
dtype, sequence_length, scope=bw_scope)
|
|
||||||
|
|
||||||
output_bw = _reverse_seq(tmp, sequence_length)
|
|
||||||
# Concat each of the forward/backward outputs
|
|
||||||
flat_output_fw = nest.flatten(output_fw)
|
|
||||||
flat_output_bw = nest.flatten(output_bw)
|
|
||||||
|
|
||||||
flat_outputs = tuple(
|
|
||||||
array_ops.concat([fw, bw], 1)
|
|
||||||
for fw, bw in zip(flat_output_fw, flat_output_bw))
|
|
||||||
|
|
||||||
outputs = nest.pack_sequence_as(structure=output_fw,
|
|
||||||
flat_sequence=flat_outputs)
|
|
||||||
|
|
||||||
return (outputs, output_state_fw, output_state_bw)
|
|
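A matching hedged sketch for the `static_bidirectional_rnn` above, assumed to be callable as `tf.nn.static_bidirectional_rnn` after this change; sizes are illustrative.

```python
import tensorflow as tf

num_steps, input_size, num_units = 3, 5, 8
inputs = [tf.placeholder(tf.float32, [None, input_size]) for _ in range(num_steps)]
cell_fw = tf.nn.rnn_cell.GRUCell(num_units)
cell_bw = tf.nn.rnn_cell.GRUCell(num_units)
outputs, state_fw, state_bw = tf.nn.static_bidirectional_rnn(
    cell_fw, cell_bw, inputs, dtype=tf.float32)
# Each element of outputs is [batch, 2 * num_units]: the forward and backward
# outputs depth-concatenated, as the docstring describes.
```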
@ -12,45 +12,219 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
"""Module implementing RNN Cells that used to be in core.
|
||||||
|
|
||||||
"""Module for constructing RNN Cells.
|
|
||||||
|
|
||||||
## Base interface for all RNN Cells
|
|
||||||
|
|
||||||
@@RNNCell
|
|
||||||
|
|
||||||
## RNN Cells for use with TensorFlow's core RNN methods
|
|
||||||
|
|
||||||
@@BasicRNNCell
|
|
||||||
@@BasicLSTMCell
|
|
||||||
@@GRUCell
|
|
||||||
@@LSTMCell
|
|
||||||
|
|
||||||
## Classes storing split `RNNCell` state
|
|
||||||
|
|
||||||
@@LSTMStateTuple
|
|
||||||
|
|
||||||
## RNN Cell wrappers (RNNCells that wrap other RNNCells)
|
|
||||||
|
|
||||||
@@MultiRNNCell
|
|
||||||
@@DropoutWrapper
|
|
||||||
@@EmbeddingWrapper
|
@@EmbeddingWrapper
|
||||||
@@InputProjectionWrapper
|
@@InputProjectionWrapper
|
||||||
@@OutputProjectionWrapper
|
@@OutputProjectionWrapper
|
||||||
@@DeviceWrapper
|
|
||||||
@@ResidualWrapper
|
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
# go/tf-wildcard-import
|
import math
|
||||||
# pylint: disable=wildcard-import
|
|
||||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import *
|
from tensorflow.python.framework import ops
|
||||||
# pylint: enable=wildcard-import
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.util.all_util import remove_undocumented
|
from tensorflow.python.ops import embedding_ops
|
||||||
|
from tensorflow.python.ops import init_ops
|
||||||
|
from tensorflow.python.ops import rnn_cell_impl
|
||||||
|
from tensorflow.python.ops import variable_scope as vs
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
|
||||||
|
RNNCell = rnn_cell_impl.RNNCell # pylint: disable=invalid-name
|
||||||
|
_linear = rnn_cell_impl._linear # pylint: disable=invalid-name, protected-access
|
||||||
|
_like_rnncell = rnn_cell_impl._like_rnncell # pylint: disable=invalid-name, protected-access
|
||||||
|
|
||||||
|
|
||||||
_allowed_symbols = []
|
class EmbeddingWrapper(RNNCell):
|
||||||
|
"""Operator adding input embedding to the given cell.
|
||||||
|
|
||||||
remove_undocumented(__name__, _allowed_symbols)
|
Note: in many cases it may be more efficient to not use this wrapper,
|
||||||
|
but instead concatenate the whole sequence of your inputs in time,
|
||||||
|
do the embedding on this batch-concatenated sequence, then split it and
|
||||||
|
feed it into your RNN.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
cell,
|
||||||
|
embedding_classes,
|
||||||
|
embedding_size,
|
||||||
|
initializer=None,
|
||||||
|
reuse=None):
|
||||||
|
"""Create a cell with an added input embedding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cell: an RNNCell, an embedding will be put before its inputs.
|
||||||
|
embedding_classes: integer, how many symbols will be embedded.
|
||||||
|
embedding_size: integer, the size of the vectors we embed into.
|
||||||
|
initializer: an initializer to use when creating the embedding;
|
||||||
|
if None, the initializer from variable scope or a default one is used.
|
||||||
|
reuse: (optional) Python boolean describing whether to reuse variables
|
||||||
|
in an existing scope. If not `True`, and the existing scope already has
|
||||||
|
the given variables, an error is raised.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if cell is not an RNNCell.
|
||||||
|
ValueError: if embedding_classes is not positive.
|
||||||
|
"""
|
||||||
|
super(EmbeddingWrapper, self).__init__(_reuse=reuse)
|
||||||
|
if not _like_rnncell(cell):
|
||||||
|
raise TypeError("The parameter cell is not RNNCell.")
|
||||||
|
if embedding_classes <= 0 or embedding_size <= 0:
|
||||||
|
raise ValueError("Both embedding_classes and embedding_size must be > 0: "
|
||||||
|
"%d, %d." % (embedding_classes, embedding_size))
|
||||||
|
self._cell = cell
|
||||||
|
self._embedding_classes = embedding_classes
|
||||||
|
self._embedding_size = embedding_size
|
||||||
|
self._initializer = initializer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_size(self):
|
||||||
|
return self._cell.state_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_size(self):
|
||||||
|
return self._cell.output_size
|
||||||
|
|
||||||
|
def zero_state(self, batch_size, dtype):
|
||||||
|
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||||
|
return self._cell.zero_state(batch_size, dtype)
|
||||||
|
|
||||||
|
def call(self, inputs, state):
|
||||||
|
"""Run the cell on embedded inputs."""
|
||||||
|
with ops.device("/cpu:0"):
|
||||||
|
if self._initializer:
|
||||||
|
initializer = self._initializer
|
||||||
|
elif vs.get_variable_scope().initializer:
|
||||||
|
initializer = vs.get_variable_scope().initializer
|
||||||
|
else:
|
||||||
|
# Default initializer for embeddings should have variance=1.
|
||||||
|
sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1.
|
||||||
|
initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)
|
||||||
|
|
||||||
|
if isinstance(state, tuple):
|
||||||
|
data_type = state[0].dtype
|
||||||
|
else:
|
||||||
|
data_type = state.dtype
|
||||||
|
|
||||||
|
embedding = vs.get_variable(
|
||||||
|
"embedding", [self._embedding_classes, self._embedding_size],
|
||||||
|
initializer=initializer,
|
||||||
|
dtype=data_type)
|
||||||
|
embedded = embedding_ops.embedding_lookup(embedding,
|
||||||
|
array_ops.reshape(inputs, [-1]))
|
||||||
|
|
||||||
|
return self._cell(embedded, state)
|
||||||
|
|
||||||
|
|
||||||
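The `EmbeddingWrapper` docstring above suggests embedding the whole batch of id sequences once rather than once per step; a hedged sketch of that alternative follows, with vocabulary and sizes made up for illustration.

```python
import tensorflow as tf

vocab_size, embedding_size, num_steps, num_units = 1000, 32, 10, 64
ids = tf.placeholder(tf.int32, [None, num_steps])        # [batch, time] symbol ids
embedding = tf.get_variable("embedding", [vocab_size, embedding_size])
embedded = tf.nn.embedding_lookup(embedding, ids)        # [batch, time, embedding_size]
inputs = tf.unstack(embedded, axis=1)                    # list of [batch, embedding_size]
outputs, state = tf.nn.static_rnn(
    tf.nn.rnn_cell.GRUCell(num_units), inputs, dtype=tf.float32)
```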
|
class InputProjectionWrapper(RNNCell):
|
||||||
|
"""Operator adding an input projection to the given cell.
|
||||||
|
|
||||||
|
Note: in many cases it may be more efficient to not use this wrapper,
|
||||||
|
but instead concatenate the whole sequence of your inputs in time,
|
||||||
|
do the projection on this batch-concatenated sequence, then split it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
cell,
|
||||||
|
num_proj,
|
||||||
|
activation=None,
|
||||||
|
input_size=None,
|
||||||
|
reuse=None):
|
||||||
|
"""Create a cell with input projection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cell: an RNNCell, a projection of inputs is added before it.
|
||||||
|
num_proj: Python integer. The dimension to project to.
|
||||||
|
activation: (optional) an optional activation function.
|
||||||
|
input_size: Deprecated and unused.
|
||||||
|
reuse: (optional) Python boolean describing whether to reuse variables
|
||||||
|
in an existing scope. If not `True`, and the existing scope already has
|
||||||
|
the given variables, an error is raised.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if cell is not an RNNCell.
|
||||||
|
"""
|
||||||
|
super(InputProjectionWrapper, self).__init__(_reuse=reuse)
|
||||||
|
if input_size is not None:
|
||||||
|
logging.warn("%s: The input_size parameter is deprecated.", self)
|
||||||
|
if not _like_rnncell(cell):
|
||||||
|
raise TypeError("The parameter cell is not RNNCell.")
|
||||||
|
self._cell = cell
|
||||||
|
self._num_proj = num_proj
|
||||||
|
self._activation = activation
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_size(self):
|
||||||
|
return self._cell.state_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_size(self):
|
||||||
|
return self._cell.output_size
|
||||||
|
|
||||||
|
def zero_state(self, batch_size, dtype):
|
||||||
|
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||||
|
return self._cell.zero_state(batch_size, dtype)
|
||||||
|
|
||||||
|
def call(self, inputs, state):
|
||||||
|
"""Run the input projection and then the cell."""
|
||||||
|
# Default scope: "InputProjectionWrapper"
|
||||||
|
projected = _linear(inputs, self._num_proj, True)
|
||||||
|
if self._activation:
|
||||||
|
projected = self._activation(projected)
|
||||||
|
return self._cell(projected, state)
|
||||||
|
|
||||||
|
|
||||||
|
class OutputProjectionWrapper(RNNCell):
|
||||||
|
"""Operator adding an output projection to the given cell.
|
||||||
|
|
||||||
|
Note: in many cases it may be more efficient to not use this wrapper,
|
||||||
|
but instead concatenate the whole sequence of your outputs in time,
|
||||||
|
do the projection on this batch-concatenated sequence, then split it
|
||||||
|
if needed or directly feed into a softmax.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cell, output_size, activation=None, reuse=None):
|
||||||
|
"""Create a cell with output projection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cell: an RNNCell, a projection to output_size is added to it.
|
||||||
|
output_size: integer, the size of the output after projection.
|
||||||
|
activation: (optional) an optional activation function.
|
||||||
|
reuse: (optional) Python boolean describing whether to reuse variables
|
||||||
|
in an existing scope. If not `True`, and the existing scope already has
|
||||||
|
the given variables, an error is raised.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if cell is not an RNNCell.
|
||||||
|
ValueError: if output_size is not positive.
|
||||||
|
"""
|
||||||
|
super(OutputProjectionWrapper, self).__init__(_reuse=reuse)
|
||||||
|
if not _like_rnncell(cell):
|
||||||
|
raise TypeError("The parameter cell is not RNNCell.")
|
||||||
|
if output_size < 1:
|
||||||
|
raise ValueError("Parameter output_size must be > 0: %d." % output_size)
|
||||||
|
self._cell = cell
|
||||||
|
self._output_size = output_size
|
||||||
|
self._activation = activation
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_size(self):
|
||||||
|
return self._cell.state_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_size(self):
|
||||||
|
return self._output_size
|
||||||
|
|
||||||
|
def zero_state(self, batch_size, dtype):
|
||||||
|
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||||
|
return self._cell.zero_state(batch_size, dtype)
|
||||||
|
|
||||||
|
def call(self, inputs, state):
|
||||||
|
"""Run the cell and output projection on inputs, starting from state."""
|
||||||
|
output, res_state = self._cell(inputs, state)
|
||||||
|
projected = _linear(output, self._output_size, True)
|
||||||
|
if self._activation:
|
||||||
|
projected = self._activation(projected)
|
||||||
|
return projected, res_state
|
||||||
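`OutputProjectionWrapper` applies `_linear` at every step; its docstring notes that projecting the whole unrolled output once is often cheaper. A hedged sketch of that alternative, with illustrative sizes and variable names:

```python
import tensorflow as tf

num_steps, input_size, num_units, output_size = 10, 16, 64, 20
inputs = [tf.placeholder(tf.float32, [None, input_size]) for _ in range(num_steps)]
outputs, state = tf.nn.static_rnn(
    tf.nn.rnn_cell.GRUCell(num_units), inputs, dtype=tf.float32)
# Stack the per-step outputs, project them with a single matmul, then reshape back.
stacked = tf.reshape(tf.stack(outputs), [-1, num_units])   # [time * batch, num_units]
w = tf.get_variable("proj_w", [num_units, output_size])
b = tf.get_variable("proj_b", [output_size])
projected = tf.reshape(tf.matmul(stacked, w) + b,
                       [num_steps, -1, output_size])       # [time, batch, output_size]
```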
|
File diff suppressed because it is too large
@ -20,7 +20,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
|
||||||
from tensorflow.contrib.rnn.python.ops import core_rnn as contrib_rnn
|
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import rnn
|
from tensorflow.python.ops import rnn
|
||||||
|
|
||||||
@ -116,12 +115,13 @@ class FusedRNNCellAdaptor(FusedRNNCell):
|
|||||||
else: # non-dynamic rnn
|
else: # non-dynamic rnn
|
||||||
if not is_list:
|
if not is_list:
|
||||||
inputs = array_ops.unstack(inputs)
|
inputs = array_ops.unstack(inputs)
|
||||||
outputs, state = contrib_rnn.static_rnn(self._cell,
|
outputs, state = rnn.static_rnn(
|
||||||
inputs,
|
self._cell,
|
||||||
initial_state=initial_state,
|
inputs,
|
||||||
dtype=dtype,
|
initial_state=initial_state,
|
||||||
sequence_length=sequence_length,
|
dtype=dtype,
|
||||||
scope=scope)
|
sequence_length=sequence_length,
|
||||||
|
scope=scope)
|
||||||
if not is_list:
|
if not is_list:
|
||||||
# Convert outputs back to tensor
|
# Convert outputs back to tensor
|
||||||
outputs = array_ops.stack(outputs)
|
outputs = array_ops.stack(outputs)
|
||||||
|
@ -18,13 +18,13 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib.rnn.ops import gen_gru_ops
|
from tensorflow.contrib.rnn.ops import gen_gru_ops
|
||||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
|
|
||||||
from tensorflow.contrib.util import loader
|
from tensorflow.contrib.util import loader
|
||||||
from tensorflow.python.framework import ops
|
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.platform import resource_loader

@@ -94,7 +94,7 @@ def _GRUBlockCellGrad(op, *grad):
 return d_x, d_h_prev, d_w_ru, d_w_c, d_b_ru, d_b_c


-class GRUBlockCell(core_rnn_cell.RNNCell):
+class GRUBlockCell(rnn_cell_impl.RNNCell):
 r"""Block GRU cell implementation.

 The implementation is based on: http://arxiv.org/abs/1406.1078

@@ -20,7 +20,6 @@ from __future__ import print_function
 import abc

 from tensorflow.contrib.rnn.ops import gen_lstm_ops
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell
 from tensorflow.contrib.rnn.python.ops import fused_rnn_cell
 from tensorflow.contrib.util import loader
 from tensorflow.python.framework import dtypes

@@ -29,6 +28,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.platform import resource_loader

@@ -325,7 +325,7 @@ def _BlockLSTMGrad(op, *grad):
 wcf_grad, b_grad]


-class LSTMBlockCell(core_rnn_cell.RNNCell):
+class LSTMBlockCell(rnn_cell_impl.RNNCell):
 """Basic LSTM recurrent network cell.

 The implementation is based on: http://arxiv.org/abs/1409.2329.

@@ -333,7 +333,7 @@ class LSTMBlockCell(core_rnn_cell.RNNCell):
 We add `forget_bias` (default: 1) to the biases of the forget gate in order to
 reduce the scale of forgetting in the beginning of the training.

-Unlike `core_rnn_cell.LSTMCell`, this is a monolithic op and should be much
+Unlike `rnn_cell_impl.LSTMCell`, this is a monolithic op and should be much
 faster. The weight and bias matrices should be compatible as long as the
 variable scope matches.
 """

@@ -363,7 +363,7 @@ class LSTMBlockCell(core_rnn_cell.RNNCell):

 @property
 def state_size(self):
-return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
+return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

 @property
 def output_size(self):

@@ -402,7 +402,7 @@ class LSTMBlockCell(core_rnn_cell.RNNCell):
 forget_bias=self._forget_bias,
 use_peephole=self._use_peephole)

-new_state = core_rnn_cell.LSTMStateTuple(cs, h)
+new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
 return h, new_state


@@ -546,8 +546,7 @@ class LSTMBlockWrapper(fused_rnn_cell.FusedRNNCell):
 # Input was a list, so return a list
 outputs = array_ops.unstack(outputs)

-final_state = core_rnn_cell.LSTMStateTuple(final_cell_state,
-final_output)
+final_state = rnn_cell_impl.LSTMStateTuple(final_cell_state, final_output)
 return outputs, final_state

 def _gather_states(self, data, indices, batch_size):

@@ -569,7 +568,7 @@ class LSTMBlockFusedCell(LSTMBlockWrapper):
 We add forget_bias (default: 1) to the biases of the forget gate in order to
 reduce the scale of forgetting in the beginning of the training.

-The variable naming is consistent with `core_rnn_cell.LSTMCell`.
+The variable naming is consistent with `rnn_cell_impl.LSTMCell`.
 """

 def __init__(self,
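Across these hunks the contrib GRU/LSTM block cells stop importing `core_rnn_cell` and pick up `rnn_cell_impl` from core instead. A minimal sketch of what the relocation looks like from the importing side, assuming the TF 1.2-era module layout shown in the diff (the tensors are hypothetical placeholders):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import rnn_cell_impl

    # Hypothetical cell state and output; LSTMStateTuple now comes from
    # rnn_cell_impl rather than tf.contrib.rnn's core_rnn_cell.
    c = array_ops.zeros([1, 4], dtype=dtypes.float32)
    h = array_ops.zeros([1, 4], dtype=dtypes.float32)
    state = rnn_cell_impl.LSTMStateTuple(c, h)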
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.contrib.rnn.python.ops import core_rnn as contrib_rnn
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import rnn
 from tensorflow.python.ops import variable_scope as vs

@@ -106,7 +105,7 @@ def stack_bidirectional_rnn(cells_fw,
 initial_state_bw = initial_states_bw[i]

 with vs.variable_scope("cell_%d" % i) as cell_scope:
-prev_layer, state_fw, state_bw = contrib_rnn.static_bidirectional_rnn(
+prev_layer, state_fw, state_bw = rnn.static_bidirectional_rnn(
 cell_fw,
 cell_bw,
 prev_layer,

@@ -23,8 +23,6 @@ import math

 from tensorflow.contrib.compiler import jit
 from tensorflow.contrib.layers.python.layers import layers
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import op_def_registry
 from tensorflow.python.framework import ops

@@ -76,7 +74,7 @@ def _get_sharded_variable(name, shape, dtype, num_shards):
 return shards


-class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
+class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
 """Long short-term memory unit (LSTM) recurrent network cell.

 The default non-peephole implementation is based on:

@@ -154,14 +152,12 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
 self._reuse = reuse

 if num_proj:
-self._state_size = (
-core_rnn_cell.LSTMStateTuple(num_units, num_proj)
-if state_is_tuple else num_units + num_proj)
+self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
+if state_is_tuple else num_units + num_proj)
 self._output_size = num_proj
 else:
-self._state_size = (
-core_rnn_cell.LSTMStateTuple(num_units, num_units)
-if state_is_tuple else 2 * num_units)
+self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)
+if state_is_tuple else 2 * num_units)
 self._output_size = num_units

 @property

@@ -254,12 +250,12 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
 m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
 # pylint: enable=invalid-unary-operand-type

-new_state = (core_rnn_cell.LSTMStateTuple(c, m) if self._state_is_tuple else
-array_ops.concat([c, m], 1))
+new_state = (rnn_cell_impl.LSTMStateTuple(c, m)
+if self._state_is_tuple else array_ops.concat([c, m], 1))
 return m, new_state


-class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
+class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
 """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.

 This implementation is based on:

@@ -427,7 +423,7 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
 return freq_inputs


-class GridLSTMCell(core_rnn_cell.RNNCell):
+class GridLSTMCell(rnn_cell_impl.RNNCell):
 """Grid Long short-term memory unit (LSTM) recurrent network cell.

 The default is based on:

@@ -1020,11 +1016,11 @@ class BidirectionalGridLSTMCell(GridLSTMCell):


 # pylint: disable=protected-access
-_linear = core_rnn_cell_impl._linear
+_linear = rnn_cell_impl._linear
 # pylint: enable=protected-access


-class AttentionCellWrapper(core_rnn_cell.RNNCell):
+class AttentionCellWrapper(rnn_cell_impl.RNNCell):
 """Basic attention cell wrapper.

 Implementation based on https://arxiv.org/abs/1409.0473.

@@ -1155,7 +1151,7 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
 return new_attns, new_attn_states


-class HighwayWrapper(core_rnn_cell.RNNCell):
+class HighwayWrapper(rnn_cell_impl.RNNCell):
 """RNNCell wrapper that adds highway connection on cell input and output.

 Based on:

@@ -1238,7 +1234,7 @@ class HighwayWrapper(core_rnn_cell.RNNCell):
 return (res_outputs, new_state)


-class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
+class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
 """LSTM unit with layer normalization and recurrent dropout.

 This class adds layer normalization and recurrent dropout to a

@@ -1300,7 +1296,7 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):

 @property
 def state_size(self):
-return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
+return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

 @property
 def output_size(self):

@@ -1350,11 +1346,11 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
 new_c = self._norm(new_c, "state")
 new_h = self._activation(new_c) * math_ops.sigmoid(o)

-new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
+new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
 return new_h, new_state


-class NASCell(core_rnn_cell.RNNCell):
+class NASCell(rnn_cell_impl.RNNCell):
 """Neural Architecture Search (NAS) recurrent network cell.

 This implements the recurrent cell from the paper:

@@ -1388,10 +1384,10 @@ class NASCell(core_rnn_cell.RNNCell):
 self._reuse = reuse

 if num_proj is not None:
-self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_proj)
+self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
 self._output_size = num_proj
 else:
-self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_units)
+self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
 self._output_size = num_units

 @property

@@ -1498,11 +1494,11 @@ class NASCell(core_rnn_cell.RNNCell):
 dtype)
 new_m = math_ops.matmul(new_m, concat_w_proj)

-new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m)
+new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m)
 return new_m, new_state


-class UGRNNCell(core_rnn_cell.RNNCell):
+class UGRNNCell(rnn_cell_impl.RNNCell):
 """Update Gate Recurrent Neural Network (UGRNN) cell.

 Compromise between a LSTM/GRU and a vanilla RNN. There is only one

@@ -1589,7 +1585,7 @@ class UGRNNCell(core_rnn_cell.RNNCell):
 return new_output, new_state


-class IntersectionRNNCell(core_rnn_cell.RNNCell):
+class IntersectionRNNCell(rnn_cell_impl.RNNCell):
 """Intersection Recurrent Neural Network (+RNN) cell.

 Architecture with coupled recurrent gate as well as coupled depth

@@ -1712,7 +1708,7 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell):
 _REGISTERED_OPS = None


-class CompiledWrapper(core_rnn_cell.RNNCell):
+class CompiledWrapper(rnn_cell_impl.RNNCell):
 """Wraps step execution in an XLA JIT scope."""

 def __init__(self, cell, compile_stateful=False):

@@ -1783,7 +1779,7 @@ def _random_exp_initializer(minval,
 return _initializer


-class PhasedLSTMCell(core_rnn_cell.RNNCell):
+class PhasedLSTMCell(rnn_cell_impl.RNNCell):
 """Phased LSTM recurrent network cell.

 https://arxiv.org/pdf/1610.09513v1.pdf

@@ -1831,7 +1827,7 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):

 @property
 def state_size(self):
-return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
+return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

 @property
 def output_size(self):

@@ -1858,13 +1854,13 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
 It stores the time.
 The second Tensor has shape [batch, features_size], and type float32.
 It stores the features.
-state: core_rnn_cell.LSTMStateTuple, state from previous timestep.
+state: rnn_cell_impl.LSTMStateTuple, state from previous timestep.

 Returns:
 A tuple containing:
 - A Tensor of float32, and shape [batch_size, num_units], representing the
 output of the cell.
-- A core_rnn_cell.LSTMStateTuple, containing 2 Tensors of float32, shape
+- A rnn_cell_impl.LSTMStateTuple, containing 2 Tensors of float32, shape
 [batch_size, num_units], representing the new state and the output.
 """
 (c_prev, h_prev) = state

@@ -1921,12 +1917,12 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
 new_c = k * new_c + (1 - k) * c_prev
 new_h = k * new_h + (1 - k) * h_prev

-new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
+new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)

 return new_h, new_state


-class GLSTMCell(core_rnn_cell.RNNCell):
+class GLSTMCell(rnn_cell_impl.RNNCell):
 """Group LSTM cell (G-LSTM).

 The implementation is based on:

@@ -1982,10 +1978,10 @@ class GLSTMCell(core_rnn_cell.RNNCell):
 int(self._num_units / self._number_of_groups)]

 if num_proj:
-self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_proj)
+self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
 self._output_size = num_proj
 else:
-self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_units)
+self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
 self._output_size = num_units

 @property

@@ -2097,5 +2093,5 @@ class GLSTMCell(core_rnn_cell.RNNCell):
 with vs.variable_scope("projection"):
 m = _linear(m, self._num_proj, bias=False)

-new_state = core_rnn_cell.LSTMStateTuple(c, m)
+new_state = rnn_cell_impl.LSTMStateTuple(c, m)
 return m, new_state
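The rewritten `state_size` branches above pick between a state tuple and a flat width depending on `state_is_tuple`; only the module the tuple comes from changes in this commit. A small sketch of the two forms, with hypothetical sizes:

    from tensorflow.python.ops import rnn_cell_impl

    num_units, num_proj = 64, 32
    # state_is_tuple=True: state is an (c, h)-style named tuple.
    state_size_tuple = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
    # state_is_tuple=False: state is a single concatenated vector width.
    state_size_flat = num_units + num_proj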
@@ -24,13 +24,13 @@ import functools

 import numpy as np

-from tensorflow.contrib.rnn import core_rnn_cell
 from tensorflow.contrib.seq2seq.python.ops import decoder
 from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper
 from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
 from tensorflow.contrib.seq2seq.python.ops import basic_decoder
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variables
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.platform import test

@@ -41,7 +41,7 @@ from tensorflow.python.util import nest

 # for testing
 AttentionWrapperState = wrapper.AttentionWrapperState  # pylint: disable=invalid-name
-LSTMStateTuple = core_rnn_cell.LSTMStateTuple  # pylint: disable=invalid-name
+LSTMStateTuple = rnn_cell.LSTMStateTuple  # pylint: disable=invalid-name
 BasicDecoderOutput = basic_decoder.BasicDecoderOutput  # pylint: disable=invalid-name
 float32 = np.float32
 int32 = np.int32

@@ -112,7 +112,7 @@ class AttentionWrapperTest(test.TestCase):
 with vs.variable_scope(
 'root',
 initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
-cell = core_rnn_cell.LSTMCell(cell_depth)
+cell = rnn_cell.LSTMCell(cell_depth)
 cell = wrapper.AttentionWrapper(
 cell,
 attention_mechanism,

@@ -133,7 +133,7 @@ class AttentionWrapperTest(test.TestCase):
 self.assertTrue(
 isinstance(final_state, wrapper.AttentionWrapperState))
 self.assertTrue(
-isinstance(final_state.cell_state, core_rnn_cell.LSTMStateTuple))
+isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))

 self.assertEqual((batch_size, None, attention_depth),
 tuple(final_outputs.rnn_output.get_shape().as_list()))

@@ -21,13 +21,13 @@ from __future__ import print_function

 import numpy as np

-from tensorflow.contrib.rnn import core_rnn_cell
 from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
 from tensorflow.contrib.seq2seq.python.ops import basic_decoder
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.layers import core as layers_core
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 # pylint: enable=g-import-not-at-top

@@ -46,7 +46,7 @@ class BasicDecoderTest(test.TestCase):
 with self.test_session(use_gpu=True) as sess:
 inputs = np.random.randn(batch_size, max_time,
 input_depth).astype(np.float32)
-cell = core_rnn_cell.LSTMCell(cell_depth)
+cell = rnn_cell.LSTMCell(cell_depth)
 helper = helper_py.TrainingHelper(
 inputs, sequence_length, time_major=False)
 if use_output_layer:

@@ -77,8 +77,8 @@ class BasicDecoderTest(test.TestCase):
 constant_op.constant(0), first_inputs, first_state)
 batch_size_t = my_decoder.batch_size

-self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
-self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
+self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
+self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
 self.assertTrue(
 isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
 self.assertEqual((batch_size, expected_output_depth),

@@ -130,7 +130,7 @@ class BasicDecoderTest(test.TestCase):
 with self.test_session(use_gpu=True) as sess:
 embeddings = np.random.randn(vocabulary_size,
 input_depth).astype(np.float32)
-cell = core_rnn_cell.LSTMCell(vocabulary_size)
+cell = rnn_cell.LSTMCell(vocabulary_size)
 helper = helper_py.GreedyEmbeddingHelper(embeddings, start_tokens,
 end_token)
 my_decoder = basic_decoder.BasicDecoder(

@@ -154,8 +154,8 @@ class BasicDecoderTest(test.TestCase):
 constant_op.constant(0), first_inputs, first_state)
 batch_size_t = my_decoder.batch_size

-self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
-self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
+self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
+self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
 self.assertTrue(
 isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
 self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())

@@ -202,7 +202,7 @@ class BasicDecoderTest(test.TestCase):
 embeddings = np.random.randn(
 vocabulary_size, input_depth).astype(np.float32)
 half = constant_op.constant(0.5)
-cell = core_rnn_cell.LSTMCell(vocabulary_size)
+cell = rnn_cell.LSTMCell(vocabulary_size)
 helper = helper_py.ScheduledEmbeddingTrainingHelper(
 inputs=inputs,
 sequence_length=sequence_length,

@@ -230,8 +230,8 @@ class BasicDecoderTest(test.TestCase):
 constant_op.constant(0), first_inputs, first_state)
 batch_size_t = my_decoder.batch_size

-self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
-self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
+self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
+self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
 self.assertTrue(
 isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
 self.assertEqual((batch_size, vocabulary_size),

@@ -293,7 +293,7 @@ class BasicDecoderTest(test.TestCase):
 with self.test_session(use_gpu=True) as sess:
 inputs = np.random.randn(batch_size, max_time,
 input_depth).astype(np.float32)
-cell = core_rnn_cell.LSTMCell(cell_depth)
+cell = rnn_cell.LSTMCell(cell_depth)
 sampling_probability = constant_op.constant(sampling_probability)

 next_input_layer = None

@@ -335,8 +335,8 @@ class BasicDecoderTest(test.TestCase):

 batch_size_t = my_decoder.batch_size

-self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
-self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
+self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
+self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
 self.assertTrue(
 isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
 self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())

@@ -21,7 +21,6 @@ from __future__ import print_function

 import numpy as np

-from tensorflow.contrib.rnn import core_rnn_cell
 from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
 from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops

@@ -32,6 +31,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.layers import core as layers_core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test

@@ -241,7 +241,7 @@ class BeamSearchDecoderTest(test.TestCase):

 with self.test_session() as sess:
 embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
-cell = core_rnn_cell.LSTMCell(cell_depth)
+cell = rnn_cell.LSTMCell(cell_depth)
 if has_attention:
 inputs = np.random.randn(batch_size, decoder_max_time,
 input_depth).astype(np.float32)

@@ -21,12 +21,12 @@ from __future__ import print_function

 import numpy as np

-from tensorflow.contrib.rnn import core_rnn_cell
 from tensorflow.contrib.seq2seq.python.ops import decoder
 from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
 from tensorflow.contrib.seq2seq.python.ops import basic_decoder
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variables
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.platform import test

@@ -51,7 +51,7 @@ class DynamicDecodeRNNTest(test.TestCase):
 else:
 inputs = np.random.randn(batch_size, max_time,
 input_depth).astype(np.float32)
-cell = core_rnn_cell.LSTMCell(cell_depth)
+cell = rnn_cell.LSTMCell(cell_depth)
 helper = helper_py.TrainingHelper(
 inputs, sequence_length, time_major=time_major)
 my_decoder = basic_decoder.BasicDecoder(

@@ -71,7 +71,7 @@ class DynamicDecodeRNNTest(test.TestCase):

 self.assertTrue(
 isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
-self.assertTrue(isinstance(final_state, core_rnn_cell.LSTMStateTuple))
+self.assertTrue(isinstance(final_state, rnn_cell.LSTMStateTuple))

 self.assertEqual(
 (batch_size,),

@@ -126,7 +126,7 @@ class DynamicDecodeRNNTest(test.TestCase):
 inputs = np.random.randn(batch_size, max_time,
 input_depth).astype(np.float32)

-cell = core_rnn_cell.LSTMCell(cell_depth)
+cell = rnn_cell.LSTMCell(cell_depth)
 zero_state = cell.zero_state(dtype=dtypes.float32, batch_size=batch_size)
 helper = helper_py.TrainingHelper(inputs, sequence_length)
 my_decoder = basic_decoder.BasicDecoder(

@@ -21,7 +21,6 @@ from __future__ import print_function
 import collections
 import math

-from tensorflow.contrib.rnn import core_rnn_cell
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape

@@ -500,7 +499,7 @@ def hardmax(logits, name=None):
 math_ops.argmax(logits, -1), depth, dtype=logits.dtype)


-class AttentionWrapper(core_rnn_cell.RNNCell):
+class AttentionWrapper(rnn_cell_impl.RNNCell):
 """Wraps another `RNNCell` with attention.
 """

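In the seq2seq tests and in `AttentionWrapper` itself, the cell types now come from the core `rnn_cell` / `rnn_cell_impl` modules. A sketch mirroring the updated test setup; the attention mechanism, memory tensor, and sizes are hypothetical stand-ins, not values from the diff:

    from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import rnn_cell

    memory = array_ops.zeros([8, 10, 6], dtype=dtypes.float32)  # [batch, time, depth]
    attention_mechanism = wrapper.LuongAttention(num_units=6, memory=memory)
    cell = rnn_cell.LSTMCell(6)  # was core_rnn_cell.LSTMCell before this commit
    cell = wrapper.AttentionWrapper(cell, attention_mechanism)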
@@ -108,9 +108,7 @@ class PrintModelAnalysisTest(test.TestCase):

 with gfile.Open(outfile, 'r') as f:
 # pylint: disable=line-too-long
-self.assertEqual(
-'_TFProfRoot (0/451 params, 0/10.44k flops, 0B/5.28KB)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops, 0B/864B)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/1 params, 0/0 flops, 0B/0B)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/162 params, 0/0 flops, 0B/1.30KB)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/5.83k flops, 0B/432B)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/288 params, 0/0 flops, 0B/2.30KB)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/4.61k flops, 0B/384B)\n',
-f.read())
+self.assertEqual('_TFProfRoot (', f.read()[0:13])
 # pylint: enable=line-too-long

 def testComplexCodeView(self):

@@ -138,25 +136,28 @@ class PrintModelAnalysisTest(test.TestCase):

 # pylint: disable=line-too-long
 with gfile.Open(outfile, 'r') as f:
-self.assertEqual(
-'_TFProfRoot (0/2.84k params, 0/54.08k flops)\n model_analyzer_testlib.py:56:BuildFullModel:seq.append(array_... (0/1.80k params, 0/41.76k flops)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/4 params, 0/0 flops)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/648 params, 0/0 flops)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/1.15k params, 0/0 flops)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:60:BuildFullModel:cell, array_ops.c... (0/1.04k params, 0/4.13k flops)\n model_analyzer_testlib.py:62:BuildFullModel:target = array_op... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:63:BuildFullModel:loss = nn_ops.l2_... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:65:BuildFullModel:return sgd_op.min... (0/0 params, 0/8.19k flops)\n',
-f.read())
+self.assertEqual('_TFProfRoot (0', f.read()[:14])

 self.assertLess(0, tfprof_node.total_exec_micros)
 self.assertEqual(2844, tfprof_node.total_parameters)
 self.assertEqual(54080, tfprof_node.total_float_ops)
 self.assertEqual(5, len(tfprof_node.children))
 self.assertEqual('_TFProfRoot', tfprof_node.name)
-self.assertEqual('model_analyzer_testlib.py:56:BuildFullModel:seq.append(array_...',
-tfprof_node.children[0].name)
-self.assertEqual('model_analyzer_testlib.py:60:BuildFullModel:cell, array_ops.c...',
-tfprof_node.children[1].name)
-self.assertEqual('model_analyzer_testlib.py:62:BuildFullModel:target = array_op...',
-tfprof_node.children[2].name)
-self.assertEqual('model_analyzer_testlib.py:63:BuildFullModel:loss = nn_ops.l2_...',
-tfprof_node.children[3].name)
-self.assertEqual('model_analyzer_testlib.py:65:BuildFullModel:return sgd_op.min...',
-tfprof_node.children[4].name)
+self.assertEqual(
+'model_analyzer_testlib.py:58:BuildFullModel:seq.append(array_...',
+tfprof_node.children[0].name)
+self.assertEqual(
+'model_analyzer_testlib.py:62:BuildFullModel:cell, array_ops.c...',
+tfprof_node.children[1].name)
+self.assertEqual(
+'model_analyzer_testlib.py:64:BuildFullModel:target = array_op...',
+tfprof_node.children[2].name)
+self.assertEqual(
+'model_analyzer_testlib.py:65:BuildFullModel:loss = nn_ops.l2_...',
+tfprof_node.children[3].name)
+self.assertEqual(
+'model_analyzer_testlib.py:67:BuildFullModel:return sgd_op.min...',
+tfprof_node.children[4].name)
 # pylint: enable=line-too-long

 def testCodeViewLeafGraphNode(self):

@@ -17,13 +17,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicRNNCell
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.training import gradient_descent

@@ -55,7 +57,7 @@ def BuildFullModel():
 with variable_scope.variable_scope('inp_%d' % i):
 seq.append(array_ops.reshape(BuildSmallModel(), [2, 1, -1]))

-cell = BasicRNNCell(16, 48)
+cell = rnn_cell.BasicRNNCell(16)
 out = rnn.dynamic_rnn(
 cell, array_ops.concat(seq, axis=1), dtype=dtypes.float32)[0]

@@ -63,5 +65,3 @@ def BuildFullModel():
 loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out))
 sgd_op = gradient_descent.GradientDescentOptimizer(1e-2)
 return sgd_op.minimize(loss)
-
-
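The profiler test library now builds its RNN from the core namespace, and the cell takes a single size argument (the dropped second positional argument is an inference from the diff, which goes from `BasicRNNCell(16, 48)` to `rnn_cell.BasicRNNCell(16)`). A sketch of the updated call pattern with a hypothetical input batch:

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import rnn
    from tensorflow.python.ops import rnn_cell

    cell = rnn_cell.BasicRNNCell(16)  # only the unit count is passed now
    seq = array_ops.zeros([2, 4, 16], dtype=dtypes.float32)  # hypothetical [batch, time, depth]
    out = rnn.dynamic_rnn(cell, seq, dtype=dtypes.float32)[0]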
@@ -186,6 +186,6 @@ apply from: "download-models.gradle"

 dependencies {
 if (nativeBuildSystem == 'cmake' || nativeBuildSystem == 'none') {
-compile 'org.tensorflow:tensorflow-android:1.2.0-preview'
+compile 'org.tensorflow:tensorflow-android:+'
 }
 }

@@ -211,6 +211,7 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Session_run(
 }

 if (!throwExceptionIfNotOK(env, status)) {
+TF_DeleteStatus(status);
 return nullptr;
 }
 jlong* t = env->GetLongArrayElements(output_tensor_handles, nullptr);

@@ -226,5 +227,6 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Session_run(
 memcpy(elems, run_metadata->data, run_metadata->length);
 env->ReleaseByteArrayElements(ret, elems, JNI_COMMIT);
 }
+TF_DeleteStatus(status);
 return ret;
 }

@@ -1709,14 +1709,23 @@ py_library(
 py_library(
 name = "rnn_cell",
 srcs = [
+"ops/rnn_cell.py",
 "ops/rnn_cell_impl.py",
 ],
 srcs_version = "PY2AND3",
 deps = [
 ":array_ops",
+":clip_ops",
 ":framework_for_generated_wrappers",
+":init_ops",
 ":layers_base",
+":math_ops",
+":nn_ops",
+":partitioned_variables",
+":random_ops",
 ":util",
+":variable_scope",
+":variables",
 ],
 )

@@ -53,7 +53,7 @@ from tensorflow.python.platform import test
 from tensorflow.python.training import gradient_descent


-class _RNNCellForTest(rnn_cell_impl._RNNCell):  # pylint: disable=protected-access
+class _RNNCellForTest(rnn_cell_impl.RNNCell):  # pylint: disable=protected-access
 """RNN cell for testing."""

 def __init__(self, input_output_size, state_size):

@@ -80,5 +80,6 @@ py_test(
 "//tensorflow/python:framework_test_lib",
 "//tensorflow/python:platform_test",
 "//tensorflow/python:training",
+"//tensorflow/python/estimator:inputs",
 ],
 )

@@ -140,6 +140,7 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import lookup_ops

@@ -1443,9 +1444,7 @@ class _LazyBuilder(object):
 return self._feature_tensors[key]

 if key in self._features:
-# FeatureColumn is a raw feature.
-feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
-self._features[key])
+feature_tensor = self._get_raw_feature_as_tensor(key)
 self._feature_tensors[key] = feature_tensor
 return feature_tensor

@@ -1464,6 +1463,55 @@ class _LazyBuilder(object):
 self._feature_tensors[column] = transformed
 return transformed

+def _get_raw_feature_as_tensor(self, key):
+"""Gets the raw_feature (keyed by `key`) as `tensor`.
+
+The raw feature is converted to (sparse) tensor and maybe expand dim.
+
+For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if
+the rank is 1. This supports dynamic rank also. For rank 0 raw feature, will
+error out as it is not supported.
+
+Args:
+key: A `str` key to access the raw feature.
+
+Returns:
+A `Tensor` or `SparseTensor`.
+
+Raises:
+ValueError: if the raw feature has rank 0.
+"""
+raw_feature = self._features[key]
+feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
+raw_feature)
+
+def expand_dims(input_tensor):
+# Input_tensor must have rank 1.
+if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+return sparse_ops.sparse_reshape(
+input_tensor, [array_ops.shape(input_tensor)[0], -1])
+else:
+return array_ops.expand_dims(input_tensor, -1)
+
+rank = feature_tensor.get_shape().ndims
+if rank is not None:
+if rank == 0:
+raise ValueError(
+'Feature (key: {}) cannot have rank 0. Give: {}'.format(
+key, feature_tensor))
+return feature_tensor if rank != 1 else expand_dims(feature_tensor)
+
+# Handle dynamic rank.
+with ops.control_dependencies([
+check_ops.assert_positive(
+array_ops.rank(feature_tensor),
+message='Feature (key: {}) cannot have rank 0. Given: {}'.format(
+key, feature_tensor))]):
+return control_flow_ops.cond(
+math_ops.equal(1, array_ops.rank(feature_tensor)),
+lambda: expand_dims(feature_tensor),
+lambda: feature_tensor)
+
+
 # TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
 def _shape_offsets(shape):
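The new `_LazyBuilder._get_raw_feature_as_tensor` converts a raw feature to a (sparse) tensor and expands rank-1 inputs to rank 2, rejecting rank 0. A standalone sketch of the static-rank expansion rule it applies; this mirrors the diff but is not the library code itself:

    from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import sparse_ops

    def _expand_rank_1(feature_tensor):
      """Expands a rank-1 dense or sparse tensor to rank 2 (static-rank case)."""
      if isinstance(feature_tensor, sparse_tensor_lib.SparseTensor):
        return sparse_ops.sparse_reshape(
            feature_tensor, [array_ops.shape(feature_tensor)[0], -1])
      return array_ops.expand_dims(feature_tensor, -1)

The tests added in the next file exercise exactly this path by feeding 1-D `price` and `body-style` features to `fc.linear_model` and `fc.input_layer`.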
@ -26,6 +26,7 @@ import numpy as np
|
|||||||
from tensorflow.core.example import example_pb2
|
from tensorflow.core.example import example_pb2
|
||||||
from tensorflow.core.example import feature_pb2
|
from tensorflow.core.example import feature_pb2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
|
from tensorflow.python.estimator.inputs import numpy_io
|
||||||
from tensorflow.python.feature_column import feature_column_lib as fc
|
from tensorflow.python.feature_column import feature_column_lib as fc
|
||||||
from tensorflow.python.feature_column.feature_column import _CategoricalColumn
|
from tensorflow.python.feature_column.feature_column import _CategoricalColumn
|
||||||
from tensorflow.python.feature_column.feature_column import _DenseColumn
|
from tensorflow.python.feature_column.feature_column import _DenseColumn
|
||||||
@ -43,6 +44,8 @@ from tensorflow.python.ops import parsing_ops
|
|||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops import variables as variables_lib
|
from tensorflow.python.ops import variables as variables_lib
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.training import coordinator
|
||||||
|
from tensorflow.python.training import queue_runner_impl
|
||||||
|
|
||||||
|
|
||||||
def _initialized_session():
|
def _initialized_session():
|
||||||
@ -1504,6 +1507,131 @@ class LinearModelTest(test.TestCase):
|
|||||||
features['price2']: [[1.], [5.]],
|
features['price2']: [[1.], [5.]],
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def test_with_numpy_input_fn(self):
|
||||||
|
price = fc.numeric_column('price')
|
||||||
|
price_buckets = fc.bucketized_column(price, boundaries=[0., 10., 100.,])
|
||||||
|
body_style = fc.categorical_column_with_vocabulary_list(
|
||||||
|
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
|
||||||
|
|
||||||
|
input_fn = numpy_io.numpy_input_fn(
|
||||||
|
x={
|
||||||
|
'price': np.array([-1., 2., 13., 104.]),
|
||||||
|
'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
|
||||||
|
},
|
||||||
|
batch_size=2,
|
||||||
|
shuffle=False)
|
||||||
|
features = input_fn()
|
||||||
|
net = fc.linear_model(features, [price_buckets, body_style])
|
||||||
|
# self.assertEqual(1 + 3 + 5, net.shape[1])
|
||||||
|
with _initialized_session() as sess:
|
||||||
|
coord = coordinator.Coordinator()
|
||||||
|
threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
|
||||||
|
|
||||||
|
bias = get_linear_model_bias()
|
||||||
|
price_buckets_var = get_linear_model_column_var(price_buckets)
|
||||||
|
body_style_var = get_linear_model_column_var(body_style)
|
||||||
|
|
||||||
|
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
|
||||||
|
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||||
|
sess.run(bias.assign([5.]))
|
||||||
|
|
||||||
|
self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
|
||||||
|
|
||||||
|
coord.request_stop()
|
||||||
|
coord.join(threads)
|
||||||
|
|
||||||
|
def test_with_1d_sparse_tensor(self):
|
||||||
|
price = fc.numeric_column('price')
|
||||||
|
price_buckets = fc.bucketized_column(price, boundaries=[0., 10., 100.,])
|
||||||
|
body_style = fc.categorical_column_with_vocabulary_list(
|
||||||
|
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
|
||||||
|
|
||||||
|
# Provides 1-dim tensor and dense tensor.
|
||||||
|
features = {
|
||||||
|
'price': constant_op.constant([-1., 12.,]),
|
||||||
|
'body-style': sparse_tensor.SparseTensor(
|
||||||
|
indices=((0,), (1,)),
|
||||||
|
values=('sedan', 'hardtop'),
|
||||||
|
dense_shape=(2,)),
|
||||||
|
}
|
||||||
|
self.assertEqual(1, features['price'].shape.ndims)
|
||||||
|
self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
|
||||||
|
|
||||||
|
net = fc.linear_model(features, [price_buckets, body_style])
|
||||||
|
with _initialized_session() as sess:
|
||||||
|
bias = get_linear_model_bias()
|
||||||
|
price_buckets_var = get_linear_model_column_var(price_buckets)
|
||||||
|
body_style_var = get_linear_model_column_var(body_style)
|
||||||
|
|
||||||
|
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
|
||||||
|
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||||
|
sess.run(bias.assign([5.]))
|
||||||
|
|
||||||
|
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
||||||
|
|
||||||
|
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||||
|
price = fc.numeric_column('price')
|
||||||
|
price_buckets = fc.bucketized_column(price, boundaries=[0., 10., 100.,])
|
||||||
|
body_style = fc.categorical_column_with_vocabulary_list(
|
||||||
|
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
|
||||||
|
country = fc.categorical_column_with_vocabulary_list(
|
||||||
|
'country', vocabulary_list=['US', 'JP', 'CA'])
|
||||||
|
|
||||||
|
# Provides 1-dim tensor and dense tensor.
|
||||||
|
features = {
|
||||||
|
'price': array_ops.placeholder(dtypes.float32),
|
||||||
|
'body-style': array_ops.sparse_placeholder(dtypes.string),
|
||||||
|
'country': array_ops.placeholder(dtypes.string),
|
||||||
|
}
|
||||||
|
self.assertIsNone(features['price'].shape.ndims)
|
||||||
|
self.assertIsNone(features['body-style'].get_shape().ndims)
|
||||||
|
|
||||||
|
price_data = np.array([-1., 12.])
|
||||||
|
body_style_data = sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0,), (1,)),
|
||||||
|
values=('sedan', 'hardtop'),
|
||||||
|
dense_shape=(2,))
|
||||||
|
|
||||||
|
net = fc.linear_model(features, [price_buckets, body_style])
|
||||||
|
bias = get_linear_model_bias()
|
||||||
|
price_buckets_var = get_linear_model_column_var(price_buckets)
|
||||||
|
body_style_var = get_linear_model_column_var(body_style)
|
||||||
|
with _initialized_session() as sess:
|
||||||
|
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
|
||||||
|
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||||
|
sess.run(bias.assign([5.]))
|
||||||
|
|
||||||
|
self.assertAllClose(
|
||||||
|
[[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||||
|
sess.run(net, feed_dict={
|
||||||
|
features['price']: price_data,
|
||||||
|
features['body-style']: body_style_data}))
|
||||||
|
|
||||||
|
# Dense categorical_column with unknown shape is not allowed.
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Undefined input_tensor shape.'):
|
||||||
|
fc.linear_model(features, [price_buckets, body_style, country])
|
||||||
|
|
||||||
|
def test_with_rank_0_feature(self):
|
||||||
|
price = fc.numeric_column('price')
|
||||||
|
features = {
|
||||||
|
'price': constant_op.constant(0),
|
||||||
|
}
|
||||||
|
self.assertEqual(0, features['price'].shape.ndims)
|
||||||
|
|
||||||
|
# Static rank 0 should fail
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
|
||||||
|
fc.linear_model(features, [price])
|
||||||
|
|
||||||
|
# Dynamic rank 0 should fail
|
||||||
|
features = {
|
||||||
|
'price': array_ops.placeholder(dtypes.float32),
|
||||||
|
}
|
||||||
|
net = fc.linear_model(features, [price])
|
||||||
|
self.assertEqual(1, net.shape[1])
|
||||||
|
with _initialized_session() as sess:
|
||||||
|
with self.assertRaisesOpError('Feature .* cannot have rank 0'):
|
||||||
|
sess.run(net, feed_dict={features['price']: np.array(1)})
|
||||||
|
|
||||||
|
|
||||||
class InputLayerTest(test.TestCase):
|
class InputLayerTest(test.TestCase):
|
||||||
|
|
||||||
@ -1663,6 +1791,180 @@ class InputLayerTest(test.TestCase):
              features['price2']: [[1.], [5.]],
          })

  def test_with_numpy_input_fn(self):
    embedding_values = (
        (1., 2., 3., 4., 5.),  # id 0
        (6., 7., 8., 9., 10.),  # id 1
        (11., 12., 13., 14., 15.)  # id 2
    )
    def _initializer(shape, dtype, partition_info):
      del shape, dtype, partition_info
      return embedding_values

    # price has 1 dimension in input_layer
    price = fc.numeric_column('price')
    body_style = fc.categorical_column_with_vocabulary_list(
        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
    # one_hot_body_style has 3 dims in input_layer.
    one_hot_body_style = fc.indicator_column(body_style)
    # embedded_body_style has 5 dims in input_layer.
    embedded_body_style = fc.embedding_column(body_style, dimension=5,
                                              initializer=_initializer)

    input_fn = numpy_io.numpy_input_fn(
        x={
            'price': np.array([11., 12., 13., 14.]),
            'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
        },
        batch_size=2,
        shuffle=False)
    features = input_fn()
    net = fc.input_layer(features,
                         [price, one_hot_body_style, embedded_body_style])
    self.assertEqual(1 + 3 + 5, net.shape[1])
    with _initialized_session() as sess:
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)

      # Each row is formed by concatenating `embedded_body_style`,
      # `one_hot_body_style`, and `price` in order.
      self.assertAllEqual(
          [[11., 12., 13., 14., 15., 0., 0., 1., 11.],
           [1., 2., 3., 4., 5., 1., 0., 0., 12]],
          sess.run(net))

      coord.request_stop()
      coord.join(threads)
  def test_with_1d_sparse_tensor(self):
    embedding_values = (
        (1., 2., 3., 4., 5.),  # id 0
        (6., 7., 8., 9., 10.),  # id 1
        (11., 12., 13., 14., 15.)  # id 2
    )
    def _initializer(shape, dtype, partition_info):
      del shape, dtype, partition_info
      return embedding_values

    # price has 1 dimension in input_layer
    price = fc.numeric_column('price')

    # one_hot_body_style has 3 dims in input_layer.
    body_style = fc.categorical_column_with_vocabulary_list(
        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
    one_hot_body_style = fc.indicator_column(body_style)

    # embedded_body_style has 5 dims in input_layer.
    country = fc.categorical_column_with_vocabulary_list(
        'country', vocabulary_list=['US', 'JP', 'CA'])
    embedded_country = fc.embedding_column(country, dimension=5,
                                           initializer=_initializer)

    # Provides 1-dim tensor and dense tensor.
    features = {
        'price': constant_op.constant([11., 12.,]),
        'body-style': sparse_tensor.SparseTensor(
            indices=((0,), (1,)),
            values=('sedan', 'hardtop'),
            dense_shape=(2,)),
        # This is dense tensor for the categorical_column.
        'country': constant_op.constant(['CA', 'US']),
    }
    self.assertEqual(1, features['price'].shape.ndims)
    self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
    self.assertEqual(1, features['country'].shape.ndims)

    net = fc.input_layer(features,
                         [price, one_hot_body_style, embedded_country])
    self.assertEqual(1 + 3 + 5, net.shape[1])
    with _initialized_session() as sess:

      # Each row is formed by concatenating `embedded_body_style`,
      # `one_hot_body_style`, and `price` in order.
      self.assertAllEqual(
          [[0., 0., 1., 11., 12., 13., 14., 15., 11.],
           [1., 0., 0., 1., 2., 3., 4., 5., 12.]],
          sess.run(net))
  def test_with_1d_unknown_shape_sparse_tensor(self):
    embedding_values = (
        (1., 2., 3., 4., 5.),  # id 0
        (6., 7., 8., 9., 10.),  # id 1
        (11., 12., 13., 14., 15.)  # id 2
    )
    def _initializer(shape, dtype, partition_info):
      del shape, dtype, partition_info
      return embedding_values

    # price has 1 dimension in input_layer
    price = fc.numeric_column('price')

    # one_hot_body_style has 3 dims in input_layer.
    body_style = fc.categorical_column_with_vocabulary_list(
        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
    one_hot_body_style = fc.indicator_column(body_style)

    # embedded_body_style has 5 dims in input_layer.
    country = fc.categorical_column_with_vocabulary_list(
        'country', vocabulary_list=['US', 'JP', 'CA'])
    embedded_country = fc.embedding_column(country, dimension=5,
                                           initializer=_initializer)

    # Provides 1-dim tensor and dense tensor.
    features = {
        'price': array_ops.placeholder(dtypes.float32),
        'body-style': array_ops.sparse_placeholder(dtypes.string),
        # This is dense tensor for the categorical_column.
        'country': array_ops.placeholder(dtypes.string),
    }
    self.assertIsNone(features['price'].shape.ndims)
    self.assertIsNone(features['body-style'].get_shape().ndims)
    self.assertIsNone(features['country'].shape.ndims)

    price_data = np.array([11., 12.])
    body_style_data = sparse_tensor.SparseTensorValue(
        indices=((0,), (1,)),
        values=('sedan', 'hardtop'),
        dense_shape=(2,))

    # Dense categorical_column with unknown shape is not allowed.
    with self.assertRaisesRegexp(ValueError, 'Undefined input_tensor shape.'):
      fc.input_layer(features, [price, one_hot_body_style, embedded_country])

    net = fc.input_layer(features, [price, one_hot_body_style])
    self.assertEqual(1 + 3, net.shape[1])
    with _initialized_session() as sess:

      # Each row is formed by concatenating `embedded_body_style`,
      # `one_hot_body_style`, and `price` in order.
      self.assertAllEqual(
          [[0., 0., 1., 11.], [1., 0., 0., 12.]],
          sess.run(net, feed_dict={
              features['price']: price_data,
              features['body-style']: body_style_data}))
  def test_with_rank_0_feature(self):
    # price has 1 dimension in input_layer
    price = fc.numeric_column('price')
    features = {
        'price': constant_op.constant(0),
    }
    self.assertEqual(0, features['price'].shape.ndims)

    # Static rank 0 should fail
    with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
      fc.input_layer(features, [price])

    # Dynamic rank 0 should fail
    features = {
        'price': array_ops.placeholder(dtypes.float32),
    }
    net = fc.input_layer(features, [price])
    self.assertEqual(1, net.shape[1])
    with _initialized_session() as sess:
      with self.assertRaisesOpError('Feature .* cannot have rank 0'):
        sess.run(net, feed_dict={features['price']: np.array(1)})


class MakeParseExampleSpecTest(test.TestCase):
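# Illustrative usage sketch (not part of this commit): a minimal call to the
# input_layer function exercised by the tests above, through the public
# tf.feature_column API. Column names, values, and shapes are hypothetical.
import tensorflow as tf

price = tf.feature_column.numeric_column('price')
body_style = tf.feature_column.categorical_column_with_vocabulary_list(
    'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
features = {
    'price': tf.constant([[11.], [12.]]),
    'body-style': tf.constant([['sedan'], ['hardtop']]),
}
net = tf.feature_column.input_layer(
    features, [price, tf.feature_column.indicator_column(body_style)])
# `net` has shape [2, 1 + 3]: the numeric column plus the one-hot body style.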
@@ -29,9 +29,11 @@ import threading
import numpy as np
import six

_portpicker_import_error = None
try:
  import portpicker  # pylint: disable=g-import-not-at-top
except ImportError as _portpicker_import_error:
except ImportError as _error:
  _portpicker_import_error = _error
  portpicker = None

# pylint: disable=g-import-not-at-top
@@ -820,8 +822,8 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc"):
  Raises:
    ImportError: if portpicker module was not found at load time
  """
  if not portpicker:
  if _portpicker_import_error:
    raise _portpicker_import_error
    raise _portpicker_import_error  # pylint: disable=raising-bad-type
  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
  cluster_dict = {
@@ -42,7 +42,7 @@ import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


class Plus1RNNCell(rnn_cell_impl._RNNCell):
class Plus1RNNCell(rnn_cell_impl.RNNCell):
  """RNN Cell generating (output, new_state) = (input + 1, state + 1)."""

  @property
@@ -57,6 +57,24 @@ class Plus1RNNCell(rnn_cell_impl.RNNCell):
    return (input_ + 1, state + 1)


class ScalarStateRNNCell(rnn_cell_impl.RNNCell):
  """RNN Cell generating (output, new_state) = (input + 1, state + 1)."""

  @property
  def output_size(self):
    return 1

  @property
  def state_size(self):
    return tensor_shape.TensorShape([])

  def zero_state(self, batch_size, dtype):
    return array_ops.zeros([], dtype=dtypes.int32)

  def __call__(self, input_, state, scope=None):
    return (input_, state + 1)


class RNNTest(test.TestCase):

  def setUp(self):
@@ -78,6 +78,9 @@ See the @{$python/nn} guide.
@@dynamic_rnn
@@bidirectional_dynamic_rnn
@@raw_rnn
@@static_rnn
@@static_state_saving_rnn
@@static_bidirectional_rnn
@@ctc_loss
@@ctc_greedy_decoder
@@ctc_beam_search_decoder
@@ -113,14 +116,15 @@ from tensorflow.python.util.all_util import remove_undocumented

# Bring more nn-associated functionality into this package.
# go/tf-wildcard-import
# pylint: disable=wildcard-import
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.ops.ctc_ops import *
from tensorflow.python.ops.nn_impl import *
from tensorflow.python.ops.nn_ops import *
from tensorflow.python.ops.candidate_sampling_ops import *
from tensorflow.python.ops.embedding_ops import *
from tensorflow.python.ops.rnn import *
# pylint: enable=wildcard-import
from tensorflow.python.ops import rnn_cell
# pylint: enable=wildcard-import,unused-import


# TODO(cwhipkey): sigmoid and tanh should not be exposed from tf.nn.
@@ -135,6 +139,7 @@ _allowed_symbols = [
    "lrn",  # Excluded in gen_docs_combined.
    "relu_layer",  # Excluded in gen_docs_combined.
    "xw_plus_b",  # Excluded in gen_docs_combined.
    "rnn_cell",  # rnn_cell is a submodule of tf.nn.
]

remove_undocumented(__name__, _allowed_symbols,
@@ -13,8 +13,16 @@
# limitations under the License.
# ==============================================================================

"""RNN helpers for TensorFlow models."""
"""RNN helpers for TensorFlow models.


@@bidirectional_dynamic_rnn
@@dynamic_rnn
@@raw_rnn
@@static_rnn
@@static_state_saving_rnn
@@static_bidirectional_rnn
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -1038,3 +1046,351 @@ def raw_rnn(cell, loop_fn,
    final_loop_state = None

  return (emit_ta, final_state, final_loop_state)


def static_rnn(cell,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
  """Creates a recurrent neural network specified by RNNCell `cell`.

  The simplest form of RNN network generated is:

  ```python
    state = cell.zero_state(...)
    outputs = []
    for input_ in inputs:
      output, state = cell(input_, state)
      outputs.append(output)
    return (outputs, state)
  ```
  However, a few other options are available:

  An initial state can be provided.
  If the sequence_length vector is provided, dynamic calculation is performed.
  This method of calculation does not compute the RNN steps past the maximum
  sequence length of the minibatch (thus saving computational time),
  and properly propagates the state at an example's sequence length
  to the final state output.

  The dynamic calculation performed is, at time `t` for batch row `b`,

  ```python
    (output, state)(b, t) =
      (t >= sequence_length(b))
        ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
        : cell(input(b, t), state(b, t - 1))
  ```

  Args:
    cell: An instance of RNNCell.
    inputs: A length T list of inputs, each a `Tensor` of shape
      `[batch_size, input_size]`, or a nested tuple of such elements.
    initial_state: (optional) An initial state for the RNN.
      If `cell.state_size` is an integer, this must be
      a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
      If `cell.state_size` is a tuple, this should be a tuple of
      tensors having shapes `[batch_size, s] for s in cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    sequence_length: Specifies the length of each sequence in inputs.
      An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    - outputs is a length T list of outputs (one for each input), or a nested
      tuple of such elements.
    - state is the final state

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If `inputs` is `None` or an empty list, or if the input depth
      (column size) cannot be inferred from inputs via shape inference.
  """

  if not _like_rnncell(cell):
    raise TypeError("cell must be an instance of RNNCell")
  if not nest.is_sequence(inputs):
    raise TypeError("inputs must be a sequence")
  if not inputs:
    raise ValueError("inputs must not be empty")

  outputs = []
  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with vs.variable_scope(scope or "rnn") as varscope:
    if varscope.caching_device is None:
      varscope.set_caching_device(lambda op: op.device)

    # Obtain the first sequence of the input
    first_input = inputs
    while nest.is_sequence(first_input):
      first_input = first_input[0]

    # Temporarily avoid EmbeddingWrapper and seq2seq badness
    # TODO(lukaszkaiser): remove EmbeddingWrapper
    if first_input.get_shape().ndims != 1:

      input_shape = first_input.get_shape().with_rank_at_least(2)
      fixed_batch_size = input_shape[0]

      flat_inputs = nest.flatten(inputs)
      for flat_input in flat_inputs:
        input_shape = flat_input.get_shape().with_rank_at_least(2)
        batch_size, input_size = input_shape[0], input_shape[1:]
        fixed_batch_size.merge_with(batch_size)
        for i, size in enumerate(input_size):
          if size.value is None:
            raise ValueError(
                "Input size (dimension %d of inputs) must be accessible via "
                "shape inference, but saw value None." % i)
    else:
      fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]

    if fixed_batch_size.value:
      batch_size = fixed_batch_size.value
    else:
      batch_size = array_ops.shape(first_input)[0]
    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If no initial_state is provided, "
                         "dtype must be specified")
      state = cell.zero_state(batch_size, dtype)

    if sequence_length is not None:  # Prepare variables
      sequence_length = ops.convert_to_tensor(
          sequence_length, name="sequence_length")
      if sequence_length.get_shape().ndims not in (None, 1):
        raise ValueError(
            "sequence_length must be a vector of length batch_size")

      def _create_zero_output(output_size):
        # convert int to TensorShape if necessary
        size = _concat(batch_size, output_size)
        output = array_ops.zeros(
            array_ops.stack(size), _infer_state_dtype(dtype, state))
        shape = _concat(fixed_batch_size.value, output_size, static=True)
        output.set_shape(tensor_shape.TensorShape(shape))
        return output

      output_size = cell.output_size
      flat_output_size = nest.flatten(output_size)
      flat_zero_output = tuple(
          _create_zero_output(size) for size in flat_output_size)
      zero_output = nest.pack_sequence_as(
          structure=output_size, flat_sequence=flat_zero_output)

      sequence_length = math_ops.to_int32(sequence_length)
      min_sequence_length = math_ops.reduce_min(sequence_length)
      max_sequence_length = math_ops.reduce_max(sequence_length)

    for time, input_ in enumerate(inputs):
      if time > 0:
        varscope.reuse_variables()
      # pylint: disable=cell-var-from-loop
      call_cell = lambda: cell(input_, state)
      # pylint: enable=cell-var-from-loop
      if sequence_length is not None:
        (output, state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            min_sequence_length=min_sequence_length,
            max_sequence_length=max_sequence_length,
            zero_output=zero_output,
            state=state,
            call_cell=call_cell,
            state_size=cell.state_size)
      else:
        (output, state) = call_cell()

      outputs.append(output)

    return (outputs, state)
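# Illustrative usage sketch (not part of this commit): how the static_rnn
# function above might be called once it is exposed as tf.nn.static_rnn.
# The batch size, step count, and layer sizes below are hypothetical.
import tensorflow as tf

batch_size, num_steps, input_depth, num_units = 32, 10, 8, 64
# static_rnn expects a Python list with one [batch, depth] tensor per step.
inputs = [tf.placeholder(tf.float32, [batch_size, input_depth])
          for _ in range(num_steps)]
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
outputs, final_state = tf.nn.static_rnn(cell, inputs, dtype=tf.float32)
# `outputs` is a list of num_steps tensors of shape [batch_size, num_units].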
def static_state_saving_rnn(cell,
                            inputs,
                            state_saver,
                            state_name,
                            sequence_length=None,
                            scope=None):
  """RNN that accepts a state saver for time-truncated RNN calculation.

  Args:
    cell: An instance of `RNNCell`.
    inputs: A length T list of inputs, each a `Tensor` of shape
      `[batch_size, input_size]`.
    state_saver: A state saver object with methods `state` and `save_state`.
    state_name: Python string or tuple of strings.  The name to use with the
      state_saver. If the cell returns tuples of states (i.e.,
      `cell.state_size` is a tuple) then `state_name` should be a tuple of
      strings having the same length as `cell.state_size`.  Otherwise it should
      be a single string.
    sequence_length: (optional) An int32/int64 vector size [batch_size].
      See the documentation for rnn() for more details about sequence_length.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:
      outputs is a length T list of outputs (one for each input)
      states is the final state

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If `inputs` is `None` or an empty list, or if the arity and
      type of `state_name` does not match that of `cell.state_size`.
  """
  state_size = cell.state_size
  state_is_tuple = nest.is_sequence(state_size)
  state_name_tuple = nest.is_sequence(state_name)

  if state_is_tuple != state_name_tuple:
    raise ValueError("state_name should be the same type as cell.state_size.  "
                     "state_name: %s, cell.state_size: %s" % (str(state_name),
                                                              str(state_size)))

  if state_is_tuple:
    state_name_flat = nest.flatten(state_name)
    state_size_flat = nest.flatten(state_size)

    if len(state_name_flat) != len(state_size_flat):
      raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d" %
                       (len(state_name_flat), len(state_size_flat)))

    initial_state = nest.pack_sequence_as(
        structure=state_size,
        flat_sequence=[state_saver.state(s) for s in state_name_flat])
  else:
    initial_state = state_saver.state(state_name)

  (outputs, state) = static_rnn(
      cell,
      inputs,
      initial_state=initial_state,
      sequence_length=sequence_length,
      scope=scope)

  if state_is_tuple:
    flat_state = nest.flatten(state)
    state_name = nest.flatten(state_name)
    save_state = [
        state_saver.save_state(name, substate)
        for name, substate in zip(state_name, flat_state)
    ]
  else:
    save_state = [state_saver.save_state(state_name, state)]

  with ops.control_dependencies(save_state):
    last_output = outputs[-1]
    flat_last_output = nest.flatten(last_output)
    flat_last_output = [
        array_ops.identity(output) for output in flat_last_output
    ]
    outputs[-1] = nest.pack_sequence_as(
        structure=last_output, flat_sequence=flat_last_output)

  return (outputs, state)
def static_bidirectional_rnn(cell_fw,
                             cell_bw,
                             inputs,
                             initial_state_fw=None,
                             initial_state_bw=None,
                             dtype=None,
                             sequence_length=None,
                             scope=None):
  """Creates a bidirectional recurrent neural network.

  Similar to the unidirectional case above (rnn) but takes input and builds
  independent forward and backward RNNs with the final forward and backward
  outputs depth-concatenated, such that the output will have the format
  [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of
  forward and backward cell must match. The initial state for both directions
  is zero by default (but can be set optionally) and no intermediate states are
  ever returned -- the network is fully unrolled for the given (passed in)
  length(s) of the sequence(s) or completely unrolled if length(s) is not given.

  Args:
    cell_fw: An instance of RNNCell, to be used for forward direction.
    cell_bw: An instance of RNNCell, to be used for backward direction.
    inputs: A length T list of inputs, each a tensor of shape
      [batch_size, input_size], or a nested tuple of such elements.
    initial_state_fw: (optional) An initial state for the forward RNN.
      This must be a tensor of appropriate type and shape
      `[batch_size, cell_fw.state_size]`.
      If `cell_fw.state_size` is a tuple, this should be a tuple of
      tensors having shapes `[batch_size, s] for s in cell_fw.state_size`.
    initial_state_bw: (optional) Same as for `initial_state_fw`, but using
      the corresponding properties of `cell_bw`.
    dtype: (optional) The data type for the initial state.  Required if
      either of the initial states are not provided.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences.
    scope: VariableScope for the created subgraph; defaults to
      "bidirectional_rnn"

  Returns:
    A tuple (outputs, output_state_fw, output_state_bw) where:
      outputs is a length `T` list of outputs (one for each input), which
        are depth-concatenated forward and backward outputs.
      output_state_fw is the final state of the forward rnn.
      output_state_bw is the final state of the backward rnn.

  Raises:
    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
    ValueError: If inputs is None or an empty list.
  """

  if not _like_rnncell(cell_fw):
    raise TypeError("cell_fw must be an instance of RNNCell")
  if not _like_rnncell(cell_bw):
    raise TypeError("cell_bw must be an instance of RNNCell")
  if not nest.is_sequence(inputs):
    raise TypeError("inputs must be a sequence")
  if not inputs:
    raise ValueError("inputs must not be empty")

  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = static_rnn(
          cell_fw,
          inputs,
          initial_state_fw,
          dtype,
          sequence_length,
          scope=fw_scope)

    # Backward direction
    with vs.variable_scope("bw") as bw_scope:
      reversed_inputs = _reverse_seq(inputs, sequence_length)
      tmp, output_state_bw = static_rnn(
          cell_bw,
          reversed_inputs,
          initial_state_bw,
          dtype,
          sequence_length,
          scope=bw_scope)

  output_bw = _reverse_seq(tmp, sequence_length)
  # Concat each of the forward/backward outputs
  flat_output_fw = nest.flatten(output_fw)
  flat_output_bw = nest.flatten(output_bw)

  flat_outputs = tuple(
      array_ops.concat([fw, bw], 1)
      for fw, bw in zip(flat_output_fw, flat_output_bw))

  outputs = nest.pack_sequence_as(
      structure=output_fw, flat_sequence=flat_outputs)

  return (outputs, output_state_fw, output_state_bw)
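# Illustrative sketch (not part of this commit): depth-concatenated outputs
# from static_bidirectional_rnn, assuming it is exposed as
# tf.nn.static_bidirectional_rnn. Sizes below are hypothetical.
import tensorflow as tf

inputs = [tf.placeholder(tf.float32, [None, 16]) for _ in range(5)]
cell_fw = tf.nn.rnn_cell.GRUCell(32)
cell_bw = tf.nn.rnn_cell.GRUCell(32)
outputs, state_fw, state_bw = tf.nn.static_bidirectional_rnn(
    cell_fw, cell_bw, inputs, dtype=tf.float32)
# Each element of `outputs` has shape [batch, 64]: the forward and backward
# outputs concatenated along the feature axis.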
tensorflow/python/ops/rnn_cell.py  (new file, 51 lines)
@@ -0,0 +1,51 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for constructing RNN Cells.

## Base interface for all RNN Cells

@@RNNCell

## RNN Cells for use with TensorFlow's core RNN methods

@@BasicRNNCell
@@BasicLSTMCell
@@GRUCell
@@LSTMCell

## Classes storing split `RNNCell` state

@@LSTMStateTuple

## RNN Cell wrappers (RNNCells that wrap other RNNCells)

@@MultiRNNCell
@@DropoutWrapper
@@DeviceWrapper
@@ResidualWrapper
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.rnn_cell_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = []

remove_undocumented(__name__, _allowed_symbols)
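# Illustrative sketch (not part of this commit): with this new module, cells
# can be constructed from the core namespace instead of tf.contrib.rnn.
# The layer sizes below are hypothetical.
import tensorflow as tf

lstm = tf.nn.rnn_cell.LSTMCell(128)
stacked = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.GRUCell(64) for _ in range(2)])
# tf.contrib.rnn.LSTMCell remains available as a backwards-compatible alias.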
@@ -12,18 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Module implementing RNN Cells.

This module contains the abstract definition of a RNN cell: `_RNNCell`.
Actual implementations of various types of RNN cells are located in
`tensorflow.contrib`.
This module provides a number of basic commonly used RNN cells, such as LSTM
(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
operators that allow adding dropouts, projections, or embeddings for inputs.
Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
calling the `rnn` ops several times.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import hashlib
import numbers

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -31,11 +35,22 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"


def _like_rnncell(cell):
  """Checks that a given object is an RNNCell by using duck typing."""
  conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"),
@@ -115,7 +130,7 @@ def _zero_state_tensors(state_size, batch_size, dtype):
  return nest.map_structure(get_state_shape, state_size)


class _RNNCell(base_layer.Layer):
class RNNCell(base_layer.Layer):
  """Abstract object representing an RNN cell.

  Every `RNNCell` must have the properties below and implement `call` with
@@ -158,11 +173,11 @@ class _RNNCell(base_layer.Layer):
    if scope is not None:
      with vs.variable_scope(scope,
                             custom_getter=self._rnn_get_variable) as scope:
        return super(_RNNCell, self).__call__(inputs, state, scope=scope)
        return super(RNNCell, self).__call__(inputs, state, scope=scope)
    else:
      with vs.variable_scope(vs.get_variable_scope(),
                             custom_getter=self._rnn_get_variable):
        return super(_RNNCell, self).__call__(inputs, state)
        return super(RNNCell, self).__call__(inputs, state)

  def _rnn_get_variable(self, getter, *args, **kwargs):
    variable = getter(*args, **kwargs)
@@ -212,3 +227,806 @@ class _RNNCell(base_layer.Layer):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      state_size = self.state_size
      return _zero_state_tensors(state_size, batch_size, dtype)
class BasicRNNCell(RNNCell):
  """The most basic RNN cell.

  Args:
    num_units: int, The number of units in the LSTM cell.
    activation: Nonlinearity to use.  Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
  """

  def __init__(self, num_units, activation=None, reuse=None):
    super(BasicRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._activation = activation or math_ops.tanh

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
    output = self._activation(_linear([inputs, state], self._num_units, True))
    return output, output
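# Illustrative sketch (not part of this commit): the cell's variables are now
# created under the names "kernel" and "bias" (see _WEIGHTS_VARIABLE_NAME and
# _BIAS_VARIABLE_NAME above), matching Keras layers. Shapes are hypothetical,
# and the exact scope prefix printed may differ.
import tensorflow as tf

cell = tf.nn.rnn_cell.BasicRNNCell(num_units=4)
inputs = tf.placeholder(tf.float32, [None, 3])
state = tf.placeholder(tf.float32, [None, 4])
output, new_state = cell(inputs, state)
print([v.name for v in tf.trainable_variables()])
# Expected to include names ending in .../kernel:0 and .../bias:0 rather than
# the old "weights" and "biases".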
class GRUCell(RNNCell):
  """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""

  def __init__(self,
               num_units,
               activation=None,
               reuse=None,
               kernel_initializer=None,
               bias_initializer=None):
    super(GRUCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._activation = activation or math_ops.tanh
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Gated recurrent unit (GRU) with nunits cells."""
    with vs.variable_scope("gates"):  # Reset gate and update gate.
      # We start with bias of 1.0 to not reset and not update.
      bias_ones = self._bias_initializer
      if self._bias_initializer is None:
        dtype = [a.dtype for a in [inputs, state]][0]
        bias_ones = init_ops.constant_initializer(1.0, dtype=dtype)
      value = math_ops.sigmoid(
          _linear([inputs, state], 2 * self._num_units, True, bias_ones,
                  self._kernel_initializer))
      r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
    with vs.variable_scope("candidate"):
      c = self._activation(
          _linear([inputs, r * state], self._num_units, True,
                  self._bias_initializer, self._kernel_initializer))
    new_h = u * state + (1 - u) * c
    return new_h, new_h
_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))


class LSTMStateTuple(_LSTMStateTuple):
  """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.

  Stores two elements: `(c, h)`, in that order.

  Only used when `state_is_tuple=True`.
  """
  __slots__ = ()

  @property
  def dtype(self):
    (c, h) = self
    if c.dtype != h.dtype:
      raise TypeError("Inconsistent internal state: %s vs %s" %
                      (str(c.dtype), str(h.dtype)))
    return c.dtype
class BasicLSTMCell(RNNCell):
  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
  that follows.
  """

  def __init__(self, num_units, forget_bias=1.0,
               state_is_tuple=True, activation=None, reuse=None):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(BasicLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation or math_ops.tanh

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM)."""
    sigmoid = math_ops.sigmoid
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

    concat = _linear([inputs, h], 4 * self._num_units, True)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

    new_c = (
        c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
    new_h = self._activation(new_c) * sigmoid(o)

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state
class LSTMCell(RNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  The default non-peephole implementation is based on:

    http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf

  S. Hochreiter and J. Schmidhuber.
  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.

  The peephole implementation is based on:

    https://research.google.com/pubs/archive/43905.pdf

  Hasim Sak, Andrew Senior, and Francoise Beaufays.
  "Long short-term memory recurrent neural network architectures for
   large scale acoustic modeling." INTERSPEECH, 2014.

  The class uses optional peep-hole connections, optional cell clipping, and
  an optional projection layer.
  """

  def __init__(self, num_units,
               use_peepholes=False, cell_clip=None,
               initializer=None, num_proj=None, proj_clip=None,
               num_unit_shards=None, num_proj_shards=None,
               forget_bias=1.0, state_is_tuple=True,
               activation=None, reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      cell_clip: (optional) A float value, if provided the cell state is clipped
        by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
        provided, then the projected values are clipped elementwise to within
        `[-proj_clip, proj_clip]`.
      num_unit_shards: Deprecated, will be removed by Jan. 2017.
        Use a variable_scope partitioner instead.
      num_proj_shards: Deprecated, will be removed by Jan. 2017.
        Use a variable_scope partitioner instead.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  This latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(LSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
    if num_unit_shards is not None or num_proj_shards is not None:
      logging.warn(
          "%s: The num_unit_shards and proj_unit_shards parameters are "
          "deprecated and will be removed in Jan 2017.  "
          "Use a variable scope with a partitioner instead.", self)

    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._num_unit_shards = num_unit_shards
    self._num_proj_shards = num_proj_shards
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation or math_ops.tanh

    if num_proj:
      self._state_size = (
          LSTMStateTuple(num_units, num_proj)
          if state_is_tuple else num_units + num_proj)
      self._output_size = num_proj
    else:
      self._state_size = (
          LSTMStateTuple(num_units, num_units)
          if state_is_tuple else 2 * num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    num_proj = self._num_units if self._num_proj is None else self._num_proj
    sigmoid = math_ops.sigmoid

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
      if self._num_unit_shards is not None:
        unit_scope.set_partitioner(
            partitioned_variables.fixed_size_partitioner(
                self._num_unit_shards))
      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True)
      i, j, f, o = array_ops.split(
          value=lstm_matrix, num_or_size_splits=4, axis=1)
      # Diagonal connections
      if self._use_peepholes:
        with vs.variable_scope(unit_scope) as projection_scope:
          if self._num_unit_shards is not None:
            projection_scope.set_partitioner(None)
          w_f_diag = vs.get_variable(
              "w_f_diag", shape=[self._num_units], dtype=dtype)
          w_i_diag = vs.get_variable(
              "w_i_diag", shape=[self._num_units], dtype=dtype)
          w_o_diag = vs.get_variable(
              "w_o_diag", shape=[self._num_units], dtype=dtype)

      if self._use_peepholes:
        c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
             sigmoid(i + w_i_diag * c_prev) * self._activation(j))
      else:
        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
             self._activation(j))

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type
      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * self._activation(c)
      else:
        m = sigmoid(o) * self._activation(c)

      if self._num_proj is not None:
        with vs.variable_scope("projection") as proj_scope:
          if self._num_proj_shards is not None:
            proj_scope.set_partitioner(
                partitioned_variables.fixed_size_partitioner(
                    self._num_proj_shards))
          m = _linear(m, self._num_proj, bias=False)

        if self._proj_clip is not None:
          # pylint: disable=invalid-unary-operand-type
          m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
          # pylint: enable=invalid-unary-operand-type

    new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
                 array_ops.concat([c, m], 1))
    return m, new_state
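# Illustrative sketch (not part of this commit): an LSTMCell with a projection
# layer, so outputs and the `h` part of the state have num_proj dimensions.
# All sizes below are hypothetical.
import tensorflow as tf

cell = tf.nn.rnn_cell.LSTMCell(num_units=128, num_proj=64, use_peepholes=True)
inputs = [tf.placeholder(tf.float32, [None, 32]) for _ in range(4)]
outputs, state = tf.nn.static_rnn(cell, inputs, dtype=tf.float32)
# Each output has shape [batch, 64]; `state` is an LSTMStateTuple whose c has
# size 128 and whose h (the projected output) has size 64.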
def _enumerated_map_structure(map_fn, *args, **kwargs):
  ix = [0]
  def enumerated_fn(*inner_args, **inner_kwargs):
    r = map_fn(ix[0], *inner_args, **inner_kwargs)
    ix[0] += 1
    return r
  return nest.map_structure(enumerated_fn, *args, **kwargs)
class DropoutWrapper(RNNCell):
  """Operator adding dropout to inputs and outputs of the given cell."""

  def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
               state_keep_prob=1.0, variational_recurrent=False,
               input_size=None, dtype=None, seed=None):
    """Create a cell with added input, state, and/or output dropout.

    If `variational_recurrent` is set to `True` (**NOT** the default behavior),
    then the same dropout mask is applied at every step, as described in:

    Y. Gal, Z Ghahramani.  "A Theoretically Grounded Application of Dropout in
    Recurrent Neural Networks".  https://arxiv.org/abs/1512.05287

    Otherwise a different dropout mask is applied at every time step.

    Args:
      cell: an RNNCell, a projection to output_size is added to it.
      input_keep_prob: unit Tensor or float between 0 and 1, input keep
        probability; if it is constant and 1, no input dropout will be added.
      output_keep_prob: unit Tensor or float between 0 and 1, output keep
        probability; if it is constant and 1, no output dropout will be added.
      state_keep_prob: unit Tensor or float between 0 and 1, output keep
        probability; if it is constant and 1, no output dropout will be added.
        State dropout is performed on the *output* states of the cell.
      variational_recurrent: Python bool.  If `True`, then the same
        dropout pattern is applied across all time steps per run call.
        If this parameter is set, `input_size` **must** be provided.
      input_size: (optional) (possibly nested tuple of) `TensorShape` objects
        containing the depth(s) of the input tensors expected to be passed in to
        the `DropoutWrapper`.  Required and used **iff**
        `variational_recurrent = True` and `input_keep_prob < 1`.
      dtype: (optional) The `dtype` of the input, state, and output tensors.
        Required and used **iff** `variational_recurrent = True`.
      seed: (optional) integer, the randomness seed.

    Raises:
      TypeError: if cell is not an RNNCell.
      ValueError: if any of the keep_probs are not between 0 and 1.
    """
    if not _like_rnncell(cell):
      raise TypeError("The parameter cell is not a RNNCell.")
    with ops.name_scope("DropoutWrapperInit"):
      def tensor_and_const_value(v):
        tensor_value = ops.convert_to_tensor(v)
        const_value = tensor_util.constant_value(tensor_value)
        return (tensor_value, const_value)
      for prob, attr in [(input_keep_prob, "input_keep_prob"),
                         (state_keep_prob, "state_keep_prob"),
                         (output_keep_prob, "output_keep_prob")]:
        tensor_prob, const_prob = tensor_and_const_value(prob)
        if const_prob is not None:
          if const_prob < 0 or const_prob > 1:
            raise ValueError("Parameter %s must be between 0 and 1: %d"
                             % (attr, const_prob))
          setattr(self, "_%s" % attr, float(const_prob))
        else:
          setattr(self, "_%s" % attr, tensor_prob)

    # Set cell, variational_recurrent, seed before running the code below
    self._cell = cell
    self._variational_recurrent = variational_recurrent
    self._seed = seed

    self._recurrent_input_noise = None
    self._recurrent_state_noise = None
    self._recurrent_output_noise = None

    if variational_recurrent:
      if dtype is None:
        raise ValueError(
            "When variational_recurrent=True, dtype must be provided")

      def convert_to_batch_shape(s):
        # Prepend a 1 for the batch dimension; for recurrent
        # variational dropout we use the same dropout mask for all
        # batch elements.
        return array_ops.concat(
            ([1], tensor_shape.TensorShape(s).as_list()), 0)

      def batch_noise(s, inner_seed):
        shape = convert_to_batch_shape(s)
        return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype)

      if (not isinstance(self._input_keep_prob, numbers.Real) or
          self._input_keep_prob < 1.0):
        if input_size is None:
          raise ValueError(
              "When variational_recurrent=True and input_keep_prob < 1.0 or "
              "is unknown, input_size must be provided")
        self._recurrent_input_noise = _enumerated_map_structure(
            lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)),
            input_size)
      self._recurrent_state_noise = _enumerated_map_structure(
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)),
          cell.state_size)
      self._recurrent_output_noise = _enumerated_map_structure(
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)),
          cell.output_size)

  def _gen_seed(self, salt_prefix, index):
    if self._seed is None:
      return None
    salt = "%s_%d" % (salt_prefix, index)
    string = (str(self._seed) + salt).encode("utf-8")
    return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def _variational_recurrent_dropout_value(
      self, index, value, noise, keep_prob):
    """Performs dropout given the pre-calculated noise tensor."""
    # uniform [keep_prob, 1.0 + keep_prob)
    random_tensor = keep_prob + noise

    # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
    binary_tensor = math_ops.floor(random_tensor)
    ret = math_ops.div(value, keep_prob) * binary_tensor
    ret.set_shape(value.get_shape())
    return ret

  def _dropout(self, values, salt_prefix, recurrent_noise, keep_prob):
    """Decides whether to perform standard dropout or recurrent dropout."""
    if not self._variational_recurrent:
      def dropout(i, v):
        return nn_ops.dropout(
            v, keep_prob=keep_prob, seed=self._gen_seed(salt_prefix, i))
      return _enumerated_map_structure(dropout, values)
    else:
      def dropout(i, v, n):
        return self._variational_recurrent_dropout_value(i, v, n, keep_prob)
      return _enumerated_map_structure(dropout, values, recurrent_noise)

  def __call__(self, inputs, state, scope=None):
    """Run the cell with the declared dropouts."""
    def _should_dropout(p):
      return (not isinstance(p, float)) or p < 1

    if _should_dropout(self._input_keep_prob):
      inputs = self._dropout(inputs, "input",
                             self._recurrent_input_noise,
                             self._input_keep_prob)
    output, new_state = self._cell(inputs, state, scope)
    if _should_dropout(self._state_keep_prob):
      new_state = self._dropout(new_state, "state",
                                self._recurrent_state_noise,
                                self._state_keep_prob)
    if _should_dropout(self._output_keep_prob):
      output = self._dropout(output, "output",
                             self._recurrent_output_noise,
                             self._output_keep_prob)
    return output, new_state
class ResidualWrapper(RNNCell):
|
||||||
|
"""RNNCell wrapper that ensures cell inputs are added to the outputs."""
|
||||||
|
|
||||||
|
def __init__(self, cell):
|
||||||
|
"""Constructs a `ResidualWrapper` for `cell`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cell: An instance of `RNNCell`.
|
||||||
|
"""
|
||||||
|
self._cell = cell
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_size(self):
|
||||||
|
return self._cell.state_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_size(self):
|
||||||
|
return self._cell.output_size
|
||||||
|
|
||||||
|
def zero_state(self, batch_size, dtype):
|
||||||
|
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||||
|
return self._cell.zero_state(batch_size, dtype)
|
||||||
|
|
||||||
|
def __call__(self, inputs, state, scope=None):
|
||||||
|
"""Run the cell and add its inputs to its outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: cell inputs.
|
||||||
|
state: cell state.
|
||||||
|
scope: optional cell scope.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of cell outputs and new state.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If cell inputs and outputs have different structure (type).
|
||||||
|
ValueError: If cell inputs and outputs have different structure (value).
|
||||||
|
"""
|
||||||
|
outputs, new_state = self._cell(inputs, state, scope=scope)
|
||||||
|
nest.assert_same_structure(inputs, outputs)
|
||||||
|
# Ensure shapes match
|
||||||
|
def assert_shape_match(inp, out):
|
||||||
|
inp.get_shape().assert_is_compatible_with(out.get_shape())
|
||||||
|
nest.map_structure(assert_shape_match, inputs, outputs)
|
||||||
|
res_outputs = nest.map_structure(
|
||||||
|
lambda inp, out: inp + out, inputs, outputs)
|
||||||
|
return (res_outputs, new_state)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceWrapper(RNNCell):
|
||||||
|
"""Operator that ensures an RNNCell runs on a particular device."""
|
||||||
|
|
||||||
|
def __init__(self, cell, device):
|
||||||
|
"""Construct a `DeviceWrapper` for `cell` with device `device`.
|
||||||
|
|
||||||
|
Ensures the wrapped `cell` is called with `tf.device(device)`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cell: An instance of `RNNCell`.
|
||||||
|
device: A device string or function, for passing to `tf.device`.
|
||||||
|
"""
|
||||||
|
self._cell = cell
|
||||||
|
self._device = device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_size(self):
|
||||||
|
return self._cell.state_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_size(self):
|
||||||
|
return self._cell.output_size
|
||||||
|
|
||||||
|
def zero_state(self, batch_size, dtype):
|
||||||
|
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||||
|
with ops.device(self._device):
|
||||||
|
return self._cell.zero_state(batch_size, dtype)
|
||||||
|
|
||||||
|
def __call__(self, inputs, state, scope=None):
|
||||||
|
"""Run the cell on specified device."""
|
||||||
|
with ops.device(self._device):
|
||||||
|
return self._cell(inputs, state, scope=scope)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiRNNCell(RNNCell):
|
||||||
|
"""RNN cell composed sequentially of multiple simple cells."""
|
||||||
|
|
||||||
|
def __init__(self, cells, state_is_tuple=True):
|
||||||
|
"""Create a RNN cell composed sequentially of a number of RNNCells.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cells: list of RNNCells that will be composed in this order.
|
||||||
|
state_is_tuple: If True, accepted and returned states are n-tuples, where
|
||||||
|
`n = len(cells)`. If False, the states are all
|
||||||
|
concatenated along the column axis. This latter behavior will soon be
|
||||||
|
deprecated.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if cells is empty (not allowed), or at least one of the cells
|
||||||
|
returns a state tuple but the flag `state_is_tuple` is `False`.
|
||||||
|
"""
|
||||||
|
super(MultiRNNCell, self).__init__()
|
||||||
|
if not cells:
|
||||||
|
raise ValueError("Must specify at least one cell for MultiRNNCell.")
|
||||||
|
if not nest.is_sequence(cells):
|
||||||
|
raise TypeError(
|
||||||
|
"cells must be a list or tuple, but saw: %s." % cells)
|
||||||
|
|
||||||
|
self._cells = cells
|
||||||
|
self._state_is_tuple = state_is_tuple
|
||||||
|
if not state_is_tuple:
|
||||||
|
if any(nest.is_sequence(c.state_size) for c in self._cells):
|
||||||
|
raise ValueError("Some cells return tuples of states, but the flag "
|
||||||
|
"state_is_tuple is not set. State sizes are: %s"
|
||||||
|
% str([c.state_size for c in self._cells]))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_size(self):
|
||||||
|
if self._state_is_tuple:
|
||||||
|
return tuple(cell.state_size for cell in self._cells)
|
||||||
|
else:
|
||||||
|
return sum([cell.state_size for cell in self._cells])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_size(self):
|
||||||
|
return self._cells[-1].output_size
|
||||||
|
|
||||||
|
def zero_state(self, batch_size, dtype):
|
||||||
|
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
|
||||||
|
if self._state_is_tuple:
|
||||||
|
return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
|
||||||
|
else:
|
||||||
|
# We know here that state_size of each cell is not a tuple and
|
||||||
|
# presumably does not contain TensorArrays or anything else fancy
|
||||||
|
return super(MultiRNNCell, self).zero_state(batch_size, dtype)
|
||||||
|
|
||||||
|
def call(self, inputs, state):
|
||||||
|
"""Run this multi-layer cell on inputs, starting from state."""
|
||||||
|
cur_state_pos = 0
|
||||||
|
cur_inp = inputs
|
||||||
|
new_states = []
|
||||||
|
for i, cell in enumerate(self._cells):
|
||||||
|
with vs.variable_scope("cell_%d" % i):
|
||||||
|
if self._state_is_tuple:
|
||||||
|
if not nest.is_sequence(state):
|
||||||
|
raise ValueError(
|
||||||
|
"Expected state to be a tuple of length %d, but received: %s" %
|
||||||
|
(len(self.state_size), state))
|
||||||
|
cur_state = state[i]
|
||||||
|
else:
|
||||||
|
cur_state = array_ops.slice(state, [0, cur_state_pos],
|
||||||
|
[-1, cell.state_size])
|
||||||
|
cur_state_pos += cell.state_size
|
||||||
|
cur_inp, new_state = cell(cur_inp, cur_state)
|
||||||
|
new_states.append(new_state)
|
||||||
|
|
||||||
|
new_states = (tuple(new_states) if self._state_is_tuple else
|
||||||
|
array_ops.concat(new_states, 1))
|
||||||
|
|
||||||
|
return cur_inp, new_states
|
||||||
|
|
||||||
|
|
||||||
|
class _SlimRNNCell(RNNCell):
|
||||||
|
"""A simple wrapper for slim.rnn_cells."""
|
||||||
|
|
||||||
|
def __init__(self, cell_fn):
|
||||||
|
"""Create a SlimRNNCell from a cell_fn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cell_fn: a function which takes (inputs, state, scope) and produces the
|
||||||
|
outputs and the new_state. Additionally when called with inputs=None and
|
||||||
|
state=None it should return (initial_outputs, initial_state).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if cell_fn is not callable
|
||||||
|
ValueError: if cell_fn cannot produce a valid initial state.
|
||||||
|
"""
|
||||||
|
if not callable(cell_fn):
|
||||||
|
raise TypeError("cell_fn %s needs to be callable", cell_fn)
|
||||||
|
self._cell_fn = cell_fn
|
||||||
|
self._cell_name = cell_fn.func.__name__
|
||||||
|
init_output, init_state = self._cell_fn(None, None)
|
||||||
|
output_shape = init_output.get_shape()
|
||||||
|
state_shape = init_state.get_shape()
|
||||||
|
self._output_size = output_shape.with_rank(2)[1].value
|
||||||
|
self._state_size = state_shape.with_rank(2)[1].value
|
||||||
|
if self._output_size is None:
|
||||||
|
raise ValueError("Initial output created by %s has invalid shape %s" %
|
||||||
|
(self._cell_name, output_shape))
|
||||||
|
if self._state_size is None:
|
||||||
|
raise ValueError("Initial state created by %s has invalid shape %s" %
|
||||||
|
(self._cell_name, state_shape))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_size(self):
|
||||||
|
return self._state_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_size(self):
|
||||||
|
return self._output_size
|
||||||
|
|
||||||
|
def __call__(self, inputs, state, scope=None):
|
||||||
|
scope = scope or self._cell_name
|
||||||
|
output, state = self._cell_fn(inputs, state, scope=scope)
|
||||||
|
return output, state
|
||||||
|
|
||||||
|
|
||||||
|
def _linear(args,
|
||||||
|
output_size,
|
||||||
|
bias,
|
||||||
|
bias_initializer=None,
|
||||||
|
kernel_initializer=None):
|
||||||
|
"""Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: a 2D Tensor or a list of 2D, batch x n, Tensors.
|
||||||
|
output_size: int, second dimension of W[i].
|
||||||
|
bias: boolean, whether to add a bias term or not.
|
||||||
|
bias_initializer: starting value to initialize the bias
|
||||||
|
(default is all zeros).
|
||||||
|
kernel_initializer: starting value to initialize the weight.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 2D Tensor with shape [batch x output_size] equal to
|
||||||
|
sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if some of the arguments has unspecified or wrong shape.
|
||||||
|
"""
|
||||||
|
if args is None or (nest.is_sequence(args) and not args):
|
||||||
|
raise ValueError("`args` must be specified")
|
||||||
|
if not nest.is_sequence(args):
|
||||||
|
args = [args]
|
||||||
|
|
||||||
|
# Calculate the total size of arguments on dimension 1.
|
||||||
|
total_arg_size = 0
|
||||||
|
shapes = [a.get_shape() for a in args]
|
||||||
|
for shape in shapes:
|
||||||
|
if shape.ndims != 2:
|
||||||
|
raise ValueError("linear is expecting 2D arguments: %s" % shapes)
|
||||||
|
if shape[1].value is None:
|
||||||
|
raise ValueError("linear expects shape[1] to be provided for shape %s, "
|
||||||
|
"but saw %s" % (shape, shape[1]))
|
||||||
|
else:
|
||||||
|
total_arg_size += shape[1].value
|
||||||
|
|
||||||
|
dtype = [a.dtype for a in args][0]
|
||||||
|
|
||||||
|
# Now the computation.
|
||||||
|
scope = vs.get_variable_scope()
|
||||||
|
with vs.variable_scope(scope) as outer_scope:
|
||||||
|
weights = vs.get_variable(
|
||||||
|
_WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size],
|
||||||
|
dtype=dtype,
|
||||||
|
initializer=kernel_initializer)
|
||||||
|
if len(args) == 1:
|
||||||
|
res = math_ops.matmul(args[0], weights)
|
||||||
|
else:
|
||||||
|
res = math_ops.matmul(array_ops.concat(args, 1), weights)
|
||||||
|
if not bias:
|
||||||
|
return res
|
||||||
|
with vs.variable_scope(outer_scope) as inner_scope:
|
||||||
|
inner_scope.set_partitioner(None)
|
||||||
|
if bias_initializer is None:
|
||||||
|
bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
|
||||||
|
biases = vs.get_variable(
|
||||||
|
_BIAS_VARIABLE_NAME, [output_size],
|
||||||
|
dtype=dtype,
|
||||||
|
initializer=bias_initializer)
|
||||||
|
return nn_ops.bias_add(res, biases)
|
||||||
|
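The wrappers defined above are meant to compose with one another. As a rough, illustrative sketch (not part of this commit, and assuming the `tf.nn.rnn_cell` names this change introduces), a two-layer LSTM with variational dropout pinned to one device could be wired up like this:

import tensorflow as tf

# Hypothetical sizes: 64 input features, 128 units per layer, 20 time steps.
cells = []
for layer in range(2):
  base = tf.nn.rnn_cell.LSTMCell(128)
  cells.append(tf.nn.rnn_cell.DropoutWrapper(
      base,
      input_keep_prob=0.9,
      state_keep_prob=0.9,
      variational_recurrent=True,            # reuse one noise mask per sequence
      input_size=64 if layer == 0 else 128,  # required when input_keep_prob < 1
      dtype=tf.float32))
stacked = tf.nn.rnn_cell.DeviceWrapper(
    tf.nn.rnn_cell.MultiRNNCell(cells), "/gpu:0")

inputs = tf.placeholder(tf.float32, [None, 20, 64])  # batch x time x features
outputs, final_state = tf.nn.dynamic_rnn(stacked, inputs, dtype=tf.float32)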
@@ -129,19 +129,23 @@ class Scaffold(object):
       copy_from_scaffold: Optional scaffold object to copy fields from. Its
         fields will be overwritten by the provided fields in this function.
     """
-    if copy_from_scaffold:
+    if copy_from_scaffold is not None:
       if not isinstance(copy_from_scaffold, Scaffold):
         raise TypeError('copy_from_scaffold is not a Scaffold instance.')
-      init_op = init_op or copy_from_scaffold.init_op
-      init_feed_dict = init_feed_dict or copy_from_scaffold.init_feed_dict
+      # We need _coalesce since Tensor is not converted to bool automatically,
+      # so the common idiom of (a or b) does not work.
+      coalesce = lambda a, b: a if a is not None else b
+      init_op = coalesce(init_op, copy_from_scaffold.init_op)
+      init_feed_dict = coalesce(init_feed_dict,
+                                copy_from_scaffold.init_feed_dict)
       # Use the original init_fn provided by the user to init the new Scaffold.
-      init_fn = init_fn or copy_from_scaffold._user_init_fn  # pylint: disable=protected-access
-      ready_op = ready_op or copy_from_scaffold.ready_op
-      ready_for_local_init_op = ready_for_local_init_op or (
-          copy_from_scaffold.ready_for_local_init_op)
-      local_init_op = local_init_op or copy_from_scaffold.local_init_op
-      summary_op = summary_op or copy_from_scaffold.summary_op
-      saver = saver or copy_from_scaffold.saver
+      init_fn = coalesce(init_fn, copy_from_scaffold._user_init_fn)  # pylint: disable=protected-access
+      ready_op = coalesce(ready_op, copy_from_scaffold.ready_op)
+      ready_for_local_init_op = coalesce(
+          ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op)
+      local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op)
+      summary_op = coalesce(summary_op, copy_from_scaffold.summary_op)
+      saver = coalesce(saver, copy_from_scaffold.saver)

     # NOTE(touts): modifying the init function to be passed the scaffold is a
     # hack to make it easy to find the saver. Is there a better way?
@@ -152,12 +156,12 @@ class Scaffold(object):
     self._init_fn = None

     self._init_op = init_op
+    self._init_feed_dict = init_feed_dict
     self._ready_op = ready_op
     self._ready_for_local_init_op = ready_for_local_init_op
     self._local_init_op = local_init_op
     self._summary_op = summary_op
     self._saver = saver
-    self._init_feed_dict = init_feed_dict

   def finalize(self):
     """Creates operations if needed and finalizes the graph."""
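The `coalesce` comment in the hunk above is easiest to see in isolation. A minimal sketch, with made-up tensors rather than real Scaffold fields:

import tensorflow as tf

provided = tf.constant([])    # a Tensor handed in by the caller
default = tf.constant([1.0])  # the value copied from the other scaffold

# `provided or default` would call bool(provided), which raises
# "Using a tf.Tensor as a Python bool is not allowed" in graph mode,
# so the copy logic picks values with an explicit None check instead.
coalesce = lambda a, b: a if a is not None else b
chosen = coalesce(provided, default)  # keeps `provided`, as intended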
@@ -1,5 +1,9 @@
 path: "tensorflow.nn"
 tf_module {
+  member {
+    name: "rnn_cell"
+    mtype: "<type \'module\'>"
+  }
   member_method {
     name: "all_candidate_sampler"
     argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
@@ -284,6 +288,18 @@ tf_module {
     name: "sparse_softmax_cross_entropy_with_logits"
     argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "static_bidirectional_rnn"
+    argspec: "args=[\'cell_fw\', \'cell_bw\', \'inputs\', \'initial_state_fw\', \'initial_state_bw\', \'dtype\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "static_rnn"
+    argspec: "args=[\'cell\', \'inputs\', \'initial_state\', \'dtype\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "static_state_saving_rnn"
+    argspec: "args=[\'cell\', \'inputs\', \'state_saver\', \'state_name\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
   member_method {
     name: "sufficient_statistics"
     argspec: "args=[\'x\', \'axes\', \'shift\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
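For reference, a minimal sketch of calling the re-exported `tf.nn.static_rnn` per the argspec above (cell size and step count are made up):

import tensorflow as tf

cell = tf.nn.rnn_cell.GRUCell(32)
# static_rnn takes a Python list with one [batch, features] tensor per time step.
inputs = [tf.placeholder(tf.float32, [None, 16]) for _ in range(10)]
outputs, final_state = tf.nn.static_rnn(cell, inputs, dtype=tf.float32)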
@@ -0,0 +1,95 @@
path: "tensorflow.nn.rnn_cell.BasicLSTMCell"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell\'>"
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "graph"
    mtype: "<type \'property\'>"
  }
  member {
    name: "losses"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "output_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "scope_name"
    mtype: "<type \'property\'>"
  }
  member {
    name: "state_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "updates"
    mtype: "<type \'property\'>"
  }
  member {
    name: "variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "weights"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'num_units\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\'], varargs=None, keywords=None, defaults=[\'1.0\', \'True\', \'None\', \'None\'], "
  }
  member_method {
    name: "add_loss"
    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_update"
    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_variable"
    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
  }
  member_method {
    name: "apply"
    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "build"
    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "call"
    argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_losses_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_updates_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "zero_state"
    argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
  }
}
@@ -0,0 +1,95 @@
path: "tensorflow.nn.rnn_cell.BasicRNNCell"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.BasicRNNCell\'>"
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "graph"
    mtype: "<type \'property\'>"
  }
  member {
    name: "losses"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "output_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "scope_name"
    mtype: "<type \'property\'>"
  }
  member {
    name: "state_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "updates"
    mtype: "<type \'property\'>"
  }
  member {
    name: "variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "weights"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "add_loss"
    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_update"
    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_variable"
    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
  }
  member_method {
    name: "apply"
    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "build"
    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "call"
    argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_losses_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_updates_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "zero_state"
    argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
  }
}
@@ -0,0 +1,95 @@
path: "tensorflow.nn.rnn_cell.DeviceWrapper"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DeviceWrapper\'>"
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "graph"
    mtype: "<type \'property\'>"
  }
  member {
    name: "losses"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "output_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "scope_name"
    mtype: "<type \'property\'>"
  }
  member {
    name: "state_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "updates"
    mtype: "<type \'property\'>"
  }
  member {
    name: "variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "weights"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'cell\', \'device\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "add_loss"
    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_update"
    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_variable"
    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
  }
  member_method {
    name: "apply"
    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "build"
    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "call"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "get_losses_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_updates_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "zero_state"
    argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
  }
}
@@ -0,0 +1,95 @@
path: "tensorflow.nn.rnn_cell.DropoutWrapper"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DropoutWrapper\'>"
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "graph"
    mtype: "<type \'property\'>"
  }
  member {
    name: "losses"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "output_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "scope_name"
    mtype: "<type \'property\'>"
  }
  member {
    name: "state_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "updates"
    mtype: "<type \'property\'>"
  }
  member {
    name: "variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "weights"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'cell\', \'input_keep_prob\', \'output_keep_prob\', \'state_keep_prob\', \'variational_recurrent\', \'input_size\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1.0\', \'1.0\', \'False\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "add_loss"
    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_update"
    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_variable"
    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
  }
  member_method {
    name: "apply"
    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "build"
    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "call"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "get_losses_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_updates_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "zero_state"
    argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
  }
}
@@ -0,0 +1,95 @@
path: "tensorflow.nn.rnn_cell.GRUCell"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.GRUCell\'>"
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "graph"
    mtype: "<type \'property\'>"
  }
  member {
    name: "losses"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "output_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "scope_name"
    mtype: "<type \'property\'>"
  }
  member {
    name: "state_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "updates"
    mtype: "<type \'property\'>"
  }
  member {
    name: "variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "weights"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'kernel_initializer\', \'bias_initializer\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "add_loss"
    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_update"
    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_variable"
    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
  }
  member_method {
    name: "apply"
    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "build"
    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "call"
    argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_losses_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_updates_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "zero_state"
    argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
  }
}
@@ -0,0 +1,95 @@
path: "tensorflow.nn.rnn_cell.LSTMCell"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LSTMCell\'>"
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "graph"
    mtype: "<type \'property\'>"
  }
  member {
    name: "losses"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "output_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "scope_name"
    mtype: "<type \'property\'>"
  }
  member {
    name: "state_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "updates"
    mtype: "<type \'property\'>"
  }
  member {
    name: "variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "weights"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'num_units\', \'use_peepholes\', \'cell_clip\', \'initializer\', \'num_proj\', \'proj_clip\', \'num_unit_shards\', \'num_proj_shards\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'1.0\', \'True\', \'None\', \'None\'], "
  }
  member_method {
    name: "add_loss"
    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_update"
    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_variable"
    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
  }
  member_method {
    name: "apply"
    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "build"
    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "call"
    argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_losses_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_updates_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "zero_state"
    argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
  }
}
@@ -0,0 +1,27 @@
path: "tensorflow.nn.rnn_cell.LSTMStateTuple"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LSTMStateTuple\'>"
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LSTMStateTuple\'>"
  is_instance: "<type \'tuple\'>"
  member {
    name: "c"
    mtype: "<type \'property\'>"
  }
  member {
    name: "dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "h"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
  }
  member_method {
    name: "count"
  }
  member_method {
    name: "index"
  }
}
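A small sketch of where `LSTMStateTuple` shows up in practice (sizes are arbitrary):

import tensorflow as tf

cell = tf.nn.rnn_cell.BasicLSTMCell(8)
state = cell.zero_state(batch_size=4, dtype=tf.float32)
# `state` is an LSTMStateTuple; `c` and `h` are the properties listed above.
print(state.c.shape, state.h.shape)  # (4, 8) (4, 8)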
@@ -0,0 +1,95 @@
path: "tensorflow.nn.rnn_cell.MultiRNNCell"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "graph"
    mtype: "<type \'property\'>"
  }
  member {
    name: "losses"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "output_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "scope_name"
    mtype: "<type \'property\'>"
  }
  member {
    name: "state_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "updates"
    mtype: "<type \'property\'>"
  }
  member {
    name: "variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "weights"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
  }
  member_method {
    name: "add_loss"
    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_update"
    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_variable"
    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
  }
  member_method {
    name: "apply"
    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "build"
    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "call"
    argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_losses_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_updates_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "zero_state"
    argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
  }
}
@@ -0,0 +1,94 @@
path: "tensorflow.nn.rnn_cell.RNNCell"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "graph"
    mtype: "<type \'property\'>"
  }
  member {
    name: "losses"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "output_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "scope_name"
    mtype: "<type \'property\'>"
  }
  member {
    name: "state_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "updates"
    mtype: "<type \'property\'>"
  }
  member {
    name: "variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "weights"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'trainable\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'True\', \'None\', \"<dtype: \'float32\'>\"], "
  }
  member_method {
    name: "add_loss"
    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_update"
    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_variable"
    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
  }
  member_method {
    name: "apply"
    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "build"
    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "call"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "get_losses_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_updates_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "zero_state"
    argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
  }
}
@@ -0,0 +1,95 @@
path: "tensorflow.nn.rnn_cell.ResidualWrapper"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.ResidualWrapper\'>"
  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "graph"
    mtype: "<type \'property\'>"
  }
  member {
    name: "losses"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "output_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "scope_name"
    mtype: "<type \'property\'>"
  }
  member {
    name: "state_size"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "updates"
    mtype: "<type \'property\'>"
  }
  member {
    name: "variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "weights"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'cell\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "add_loss"
    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_update"
    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_variable"
    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
  }
  member_method {
    name: "apply"
    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "build"
    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "call"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "get_losses_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_updates_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "zero_state"
    argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
  }
}
tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.pbtxt (new file, 43 lines)
@@ -0,0 +1,43 @@
path: "tensorflow.nn.rnn_cell"
tf_module {
  member {
    name: "BasicLSTMCell"
    mtype: "<type \'type\'>"
  }
  member {
    name: "BasicRNNCell"
    mtype: "<type \'type\'>"
  }
  member {
    name: "DeviceWrapper"
    mtype: "<type \'type\'>"
  }
  member {
    name: "DropoutWrapper"
    mtype: "<type \'type\'>"
  }
  member {
    name: "GRUCell"
    mtype: "<type \'type\'>"
  }
  member {
    name: "LSTMCell"
    mtype: "<type \'type\'>"
  }
  member {
    name: "LSTMStateTuple"
    mtype: "<type \'type\'>"
  }
  member {
    name: "MultiRNNCell"
    mtype: "<type \'type\'>"
  }
  member {
    name: "RNNCell"
    mtype: "<type \'type\'>"
  }
  member {
    name: "ResidualWrapper"
    mtype: "<type \'type\'>"
  }
}
@@ -12,12 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Classes for converting parsed doc content into markdown pages."""
+"""A module for converting parsed doc content into markdown pages.
+
+The adjacent `parser` module creates `PageInfo` objects, containing all data
+necessary to document an element of the TensorFlow API.
+
+This module contains one public function, which handles the conversion of these
+`PageInfo` objects into a markdown string:
+
+  md_page = build_md_page(page_info)
+"""

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import itertools
+

 def build_md_page(page_info):
   """Given a PageInfo object, return markdown for the page.
@@ -46,7 +57,9 @@ def build_md_page(page_info):

 def _build_function_page(page_info):
   """Given a FunctionPageInfo object Return the page as an md string."""
-  parts = ['# %s\n\n' % page_info.full_name]
+  parts = [_Metadata(page_info.full_name).build_html()]
+
+  parts.append('# %s\n\n' % page_info.full_name)

   if page_info.aliases:
     parts.extend('### `%s`\n' % name
@@ -70,7 +83,17 @@ def _build_function_page(page_info):

 def _build_class_page(page_info):
   """Given a ClassPageInfo object Return the page as an md string."""
-  parts = ['# {page_info.full_name}\n\n'.format(page_info=page_info)]
+  meta_data = _Metadata(page_info.full_name)
+  for item in itertools.chain(
+      page_info.classes,
+      page_info.properties,
+      page_info.methods,
+      page_info.other_members):
+    meta_data.append(item)
+
+  parts = [meta_data.build_html()]
+
+  parts.append('# {page_info.full_name}\n\n'.format(page_info=page_info))

   if page_info.aliases:
     parts.extend('### `class %s`\n' % name for name in page_info.aliases)
@@ -150,8 +173,17 @@ def _build_class_page(page_info):

 def _build_module_page(page_info):
   """Given a ClassPageInfo object Return the page as an md string."""
+  meta_data = _Metadata(page_info.full_name)

-  parts = ['# Module: {full_name}\n\n'.format(full_name=page_info.full_name)]
+  # Objects with their own pages are not added to the metadata list for the
+  # module, as the only thing on the module page is a link to the object's page.
+  for item in page_info.other_members:
+    meta_data.append(item)
+
+  parts = [meta_data.build_html()]
+
+  parts.append(
+      '# Module: {full_name}\n\n'.format(full_name=page_info.full_name))

   if page_info.aliases:
     parts.extend('### Module `%s`\n' % name for name in page_info.aliases)
@@ -261,3 +293,41 @@ def _build_function_details(function_details):
     parts.append(''.join(sub))

   return '\n'.join(parts)
+
+
+class _Metadata(object):
+  """A class for building a page's Metadata block.
+
+  Attributes:
+    name: The name of the page being described by the Metadata block.
+  """
+
+  def __init__(self, name):
+    """Create a Metadata builder.
+
+    Args:
+      name: The name of the page being described by the Metadata block.
+    """
+    self.name = name
+    self._content = []
+
+  def append(self, item):
+    """Add an item from the page to the Metadata block.
+
+    Args:
+      item: The parsed page section to add.
+    """
+    self._content.append(item.short_name)
+
+  def build_html(self):
+    """Return the Metadata block as an HTML string."""
+    schema = 'http://developers.google.com/ReferenceObject'
+    parts = ['<div itemscope itemtype="%s">' % schema]
+
+    parts.append('<meta itemprop="name" content="%s" />' % self.name)
+    for item in self._content:
+      parts.append('<meta itemprop="property" content="%s"/>' % item)
+
+    parts.extend(['</div>', '', ''])
+
+    return '\n'.join(parts)
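To make the new `_Metadata` helper concrete, here is a rough sketch of the HTML block `build_html()` assembles, using a made-up page name and filling `_content` directly (whereas `append()` normally receives parsed page items and stores their `short_name`):

meta = _Metadata("tf.nn.rnn_cell.GRUCell")
meta._content.extend(["state_size", "zero_state"])

print(meta.build_html())
# <div itemscope itemtype="http://developers.google.com/ReferenceObject">
# <meta itemprop="name" content="tf.nn.rnn_cell.GRUCell" />
# <meta itemprop="property" content="state_size"/>
# <meta itemprop="property" content="zero_state"/>
# </div>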