Updated use of bidirectional_rnn to API changes in TF master branch
This commit is contained in:
parent
6963eba144
commit
3e1bef80d9
@ -96,7 +96,7 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 2,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": true
|
"collapsed": false
|
||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -409,7 +409,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"has the same meaning as the `None` dimension in the shape of `y`. The `n_steps` dimension of its shape indicates the number of time-slices in the sequence. Finally, the `n_input + 2*n_input*n_context` dimension of its shape indicates the number of bins in Fourier transform `n_input` along with the number of bins in the prefix-context `n_input*n_context` and postfix-contex `n_input*n_context`.\n",
|
"has the same meaning as the `None` dimension in the shape of `y`. The `n_steps` dimension of its shape indicates the number of time-slices in the sequence. Finally, the `n_input + 2*n_input*n_context` dimension of its shape indicates the number of bins in Fourier transform `n_input` along with the number of bins in the prefix-context `n_input*n_context` and postfix-contex `n_input*n_context`.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"The next placeholders we introduce `istate_fw` and `istate_bw` corresponds to the states and cells of the forward and backward LSTM networks. As both of these are floats of dimension `n_cell_dim`, we define `istate_fw` and `istate_bw` as follows"
|
"The next placeholders we introduce `istate_fw` and `istate_bw` correspond to the initial states and cells of the forward and backward LSTM networks. As both of these are floats of dimension `n_cell_dim`, we define `istate_fw` and `istate_bw` as follows"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -420,8 +420,8 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"istate_fw = tf.placeholder(\"float\", [None, 2*n_cell_dim])\n",
|
"istate_fw = (tf.placeholder(\"float\", [None, n_cell_dim]), tf.placeholder(\"float\", [None, n_cell_dim]))\n",
|
||||||
"istate_bw = tf.placeholder(\"float\", [None, 2*n_cell_dim])"
|
"istate_bw = (tf.placeholder(\"float\", [None, n_cell_dim]), tf.placeholder(\"float\", [None, n_cell_dim]))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -520,11 +520,11 @@
|
|||||||
" layer_3 = tf.split(0, n_steps, layer_3)\n",
|
" layer_3 = tf.split(0, n_steps, layer_3)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Get lstm cell output\n",
|
" # Get lstm cell output\n",
|
||||||
" outputs = tf.nn.bidirectional_rnn(lstm_fw_cell,\n",
|
" outputs, output_state_fw, output_state_bw = tf.nn.bidirectional_rnn(cell_fw=lstm_fw_cell,\n",
|
||||||
" lstm_bw_cell,\n",
|
" cell_bw=lstm_bw_cell,\n",
|
||||||
" layer_3,\n",
|
" inputs=layer_3,\n",
|
||||||
" initial_state_fw=_istate_fw,\n",
|
" initial_state_fw=_istate_fw,\n",
|
||||||
" initial_state_bw=_istate_bw)\n",
|
" initial_state_bw=_istate_bw)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Reshape outputs from a list of n_steps tensors each of shape [batch_size, 2*n_cell_dim]\n",
|
" # Reshape outputs from a list of n_steps tensors each of shape [batch_size, 2*n_cell_dim]\n",
|
||||||
" # to a single tensor of shape [n_steps*batch_size, 2*n_cell_dim]\n",
|
" # to a single tensor of shape [n_steps*batch_size, 2*n_cell_dim]\n",
|
||||||
@ -600,11 +600,11 @@
|
|||||||
"The next line of `BiRNN`\n",
|
"The next line of `BiRNN`\n",
|
||||||
"```python\n",
|
"```python\n",
|
||||||
" # Get lstm cell output\n",
|
" # Get lstm cell output\n",
|
||||||
" outputs = tf.nn.bidirectional_rnn(lstm_fw_cell,\n",
|
" outputs, output_state_fw, output_state_bw = tf.nn.bidirectional_rnn(cell_fw=lstm_fw_cell,\n",
|
||||||
" lstm_bw_cell,\n",
|
" cell_bw=lstm_bw_cell,\n",
|
||||||
" layer_3,\n",
|
" inputs=layer_3,\n",
|
||||||
" initial_state_fw=_istate_fw,\n",
|
" initial_state_fw=_istate_fw,\n",
|
||||||
" initial_state_bw=_istate_bw)\n",
|
" initial_state_bw=_istate_bw)\n",
|
||||||
"```\n",
|
"```\n",
|
||||||
"feeds `layer_3` to the LSTM BRNN cell and obtains the LSTM BRNN output.\n",
|
"feeds `layer_3` to the LSTM BRNN cell and obtains the LSTM BRNN output.\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -1223,7 +1223,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 20,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": true
|
"collapsed": true
|
||||||
},
|
},
|
||||||
|
Loading…
Reference in New Issue
Block a user