Move keras related legacy RNN test to integration test.

Tests are run under compat.v1 and graph mode.

PiperOrigin-RevId: 305138425
Change-Id: I45054928f2aab86792f908752fe76fe057398b44
This commit is contained in:
Scott Zhu 2020-04-06 16:04:52 -07:00 committed by TensorFlower Gardener
parent 79985f62fc
commit 03ee7f43e8
3 changed files with 398 additions and 357 deletions

View File

@ -31,3 +31,13 @@ tf_py_test(
"//tensorflow/python:extra_py_tests_deps",
],
)
# Integration test target for the legacy (TF1) RNN APIs with Keras cells.
tf_py_test(
    name = "legacy_rnn_test",  # Remove this target when TF 1 is deprecated.
    srcs = ["legacy_rnn_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:extra_py_tests_deps",
    ],
)

View File

@ -0,0 +1,388 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow.compat.v1 as tf
# These tests exercise the legacy (TF1) RNN APIs, which require graph mode;
# disable eager globally before any graphs are built.
tf.disable_eager_execution()
class KerasNetworkTFRNNs(tf.keras.Model):
  """Keras model wrapping a MultiRNNCell built from legacy TF LSTM cells."""

  def __init__(self, name=None):
    super(KerasNetworkTFRNNs, self).__init__(name=name)
    # Stack two single-unit legacy LSTM cells into one multi-cell.
    lstm_cells = []
    for _ in range(2):
      lstm_cells.append(tf.nn.rnn_cell.LSTMCell(1))
    self._cell = tf.nn.rnn_cell.MultiRNNCell(lstm_cells)

  def call(self, inputs):
    initial_state = self._cell.get_initial_state(inputs)
    return self._cell(inputs, initial_state)
class KerasNetworkKerasRNNs(tf.keras.Model):
  """Keras model wrapping StackedRNNCells built from Keras LSTM cells."""

  def __init__(self, name=None):
    super(KerasNetworkKerasRNNs, self).__init__(name=name)
    # Stack two single-unit Keras LSTM cells into one composite cell.
    lstm_cells = []
    for _ in range(2):
      lstm_cells.append(tf.keras.layers.LSTMCell(1))
    self._cell = tf.keras.layers.StackedRNNCells(lstm_cells)

  def call(self, inputs):
    initial_state = self._cell.get_initial_state(inputs)
    return self._cell(inputs, initial_state)
class LegacyRNNTest(tf.test.TestCase):
  """Verifies that legacy tf.nn RNN APIs interoperate with Keras RNN cells.

  All tests run under tf.compat.v1 with eager execution disabled (see the
  module top level), so each test builds a v1 graph and drives it through a
  Session via placeholders and feed dicts.
  """

  def setUp(self):
    super(LegacyRNNTest, self).setUp()
    # Fix the numpy seed so the randomly generated data from get_test_data
    # is reproducible across runs.
    self._seed = 23489
    np.random.seed(self._seed)

  def testRNNWithKerasSimpleRNNCell(self):
    """tf.nn.dynamic_rnn accepts a Keras SimpleRNNCell and can train."""
    with self.cached_session() as sess:
      input_shape = 10
      output_shape = 5
      timestep = 4
      batch = 100
      (x_train, y_train), _ = get_test_data(
          train_samples=batch,
          test_samples=0,
          input_shape=(timestep, input_shape),
          num_classes=output_shape)
      y_train = tf.keras.utils.to_categorical(y_train)
      cell = tf.keras.layers.SimpleRNNCell(output_shape)
      inputs = tf.placeholder(
          tf.float32, shape=(None, timestep, input_shape))
      predict = tf.placeholder(
          tf.float32, shape=(None, output_shape))
      outputs, state = tf.nn.dynamic_rnn(
          cell, inputs, dtype=tf.float32)
      # SimpleRNNCell has a single (non-tuple) state tensor.
      self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
      self.assertEqual(state.shape.as_list(), [None, output_shape])
      loss = tf.losses.softmax_cross_entropy(predict, state)
      train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss)
      sess.run([tf.global_variables_initializer()])
      # One training step; also fetches outputs/state to check batch sizes.
      _, outputs, state = sess.run(
          [train_op, outputs, state], {inputs: x_train, predict: y_train})
      self.assertEqual(len(outputs), batch)
      self.assertEqual(len(state), batch)

  def testRNNWithKerasGRUCell(self):
    """tf.nn.dynamic_rnn accepts a Keras GRUCell and can train."""
    with self.cached_session() as sess:
      input_shape = 10
      output_shape = 5
      timestep = 4
      batch = 100
      (x_train, y_train), _ = get_test_data(
          train_samples=batch,
          test_samples=0,
          input_shape=(timestep, input_shape),
          num_classes=output_shape)
      y_train = tf.keras.utils.to_categorical(y_train)
      cell = tf.keras.layers.GRUCell(output_shape)
      inputs = tf.placeholder(
          tf.float32, shape=(None, timestep, input_shape))
      predict = tf.placeholder(
          tf.float32, shape=(None, output_shape))
      outputs, state = tf.nn.dynamic_rnn(
          cell, inputs, dtype=tf.float32)
      # GRUCell also exposes a single (non-tuple) state tensor.
      self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
      self.assertEqual(state.shape.as_list(), [None, output_shape])
      loss = tf.losses.softmax_cross_entropy(predict, state)
      train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss)
      sess.run([tf.global_variables_initializer()])
      _, outputs, state = sess.run(
          [train_op, outputs, state], {inputs: x_train, predict: y_train})
      self.assertEqual(len(outputs), batch)
      self.assertEqual(len(state), batch)

  def testRNNWithKerasLSTMCell(self):
    """tf.nn.dynamic_rnn accepts a Keras LSTMCell (two-tensor state)."""
    with self.cached_session() as sess:
      input_shape = 10
      output_shape = 5
      timestep = 4
      batch = 100
      (x_train, y_train), _ = get_test_data(
          train_samples=batch,
          test_samples=0,
          input_shape=(timestep, input_shape),
          num_classes=output_shape)
      y_train = tf.keras.utils.to_categorical(y_train)
      cell = tf.keras.layers.LSTMCell(output_shape)
      inputs = tf.placeholder(
          tf.float32, shape=(None, timestep, input_shape))
      predict = tf.placeholder(
          tf.float32, shape=(None, output_shape))
      outputs, state = tf.nn.dynamic_rnn(
          cell, inputs, dtype=tf.float32)
      self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
      # LSTM state is a pair of tensors (hidden, cell), each batch x units.
      self.assertEqual(len(state), 2)
      self.assertEqual(state[0].shape.as_list(), [None, output_shape])
      self.assertEqual(state[1].shape.as_list(), [None, output_shape])
      loss = tf.losses.softmax_cross_entropy(predict, state[0])
      train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss)
      sess.run([tf.global_variables_initializer()])
      _, outputs, state = sess.run(
          [train_op, outputs, state], {inputs: x_train, predict: y_train})
      self.assertEqual(len(outputs), batch)
      self.assertEqual(len(state), 2)
      self.assertEqual(len(state[0]), batch)
      self.assertEqual(len(state[1]), batch)

  def testRNNWithStackKerasCell(self):
    """tf.nn.dynamic_rnn works with StackedRNNCells of two Keras LSTMs."""
    with self.cached_session() as sess:
      input_shape = 10
      output_shape = 5
      timestep = 4
      batch = 100
      (x_train, y_train), _ = get_test_data(
          train_samples=batch,
          test_samples=0,
          input_shape=(timestep, input_shape),
          num_classes=output_shape)
      y_train = tf.keras.utils.to_categorical(y_train)
      # First layer is twice as wide as the second, so the flattened state
      # shapes below can distinguish the two layers.
      cell = tf.keras.layers.StackedRNNCells(
          [tf.keras.layers.LSTMCell(2 * output_shape),
           tf.keras.layers.LSTMCell(output_shape)])
      inputs = tf.placeholder(
          tf.float32, shape=(None, timestep, input_shape))
      predict = tf.placeholder(
          tf.float32, shape=(None, output_shape))
      outputs, state = tf.nn.dynamic_rnn(
          cell, inputs, dtype=tf.float32)
      self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
      # Nested state: one (h, c) pair per LSTM layer -> 2 entries, 4 tensors
      # after flattening.
      self.assertEqual(len(state), 2)
      state = tf.nest.flatten(state)
      self.assertEqual(len(state), 4)
      self.assertEqual(state[0].shape.as_list(), [None, 2 * output_shape])
      self.assertEqual(state[1].shape.as_list(), [None, 2 * output_shape])
      self.assertEqual(state[2].shape.as_list(), [None, output_shape])
      self.assertEqual(state[3].shape.as_list(), [None, output_shape])
      loss = tf.losses.softmax_cross_entropy(predict, state[2])
      train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss)
      sess.run([tf.global_variables_initializer()])
      _, outputs, state = sess.run(
          [train_op, outputs, state], {inputs: x_train, predict: y_train})
      self.assertEqual(len(outputs), batch)
      self.assertEqual(len(state), 4)
      for s in state:
        self.assertEqual(len(s), batch)

  def testStaticRNNWithKerasSimpleRNNCell(self):
    """tf.nn.static_rnn (time-major list of inputs) with a Keras cell."""
    with self.cached_session() as sess:
      input_shape = 10
      output_shape = 5
      timestep = 4
      batch = 100
      (x_train, y_train), _ = get_test_data(
          train_samples=batch,
          test_samples=0,
          input_shape=(timestep, input_shape),
          num_classes=output_shape)
      # static_rnn takes a per-timestep list, so convert the data to
      # time-major layout: (timestep, batch, input).
      x_train = np.transpose(x_train, (1, 0, 2))
      y_train = tf.keras.utils.to_categorical(y_train)
      cell = tf.keras.layers.SimpleRNNCell(output_shape)
      inputs = [tf.placeholder(
          tf.float32, shape=(None, input_shape))] * timestep
      predict = tf.placeholder(
          tf.float32, shape=(None, output_shape))
      outputs, state = tf.nn.static_rnn(
          cell, inputs, dtype=tf.float32)
      self.assertEqual(len(outputs), timestep)
      self.assertEqual(outputs[0].shape.as_list(), [None, output_shape])
      self.assertEqual(state.shape.as_list(), [None, output_shape])
      loss = tf.losses.softmax_cross_entropy(predict, state)
      train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss)
      sess.run([tf.global_variables_initializer()])
      # Feed each timestep placeholder its slice of the time-major data.
      feed_dict = {i: d for i, d in zip(inputs, x_train)}
      feed_dict[predict] = y_train
      _, outputs, state = sess.run(
          [train_op, outputs, state], feed_dict)
      self.assertEqual(len(outputs), timestep)
      self.assertEqual(len(outputs[0]), batch)
      self.assertEqual(len(state), batch)

  def testKerasAndTFRNNLayerOutputComparison(self):
    """tf.nn.dynamic_rnn and keras.layers.RNN agree given shared weights."""
    input_shape = 10
    output_shape = 5
    timestep = 4
    batch = 20
    (x_train, _), _ = get_test_data(
        train_samples=batch,
        test_samples=0,
        input_shape=(timestep, input_shape),
        num_classes=output_shape)
    # Build one cell solely to generate a fixed weight set shared by both
    # code paths below.
    fix_weights_generator = tf.keras.layers.SimpleRNNCell(output_shape)
    fix_weights_generator.build((None, input_shape))
    weights = fix_weights_generator.get_weights()

    # Path 1: legacy dynamic_rnn driving the Keras cell.
    with self.session(graph=tf.Graph()) as sess:
      inputs = tf.placeholder(
          tf.float32, shape=(None, timestep, input_shape))
      cell = tf.keras.layers.SimpleRNNCell(output_shape)
      tf_out, tf_state = tf.nn.dynamic_rnn(
          cell, inputs, dtype=tf.float32)
      cell.set_weights(weights)
      [tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train})

    # Path 2: the Keras RNN layer wrapping the same cell type.
    with self.session(graph=tf.Graph()) as sess:
      k_input = tf.keras.Input(shape=(timestep, input_shape),
                               dtype=tf.float32)
      cell = tf.keras.layers.SimpleRNNCell(output_shape)
      layer = tf.keras.layers.RNN(
          cell, return_sequences=True, return_state=True)
      keras_out = layer(k_input)
      cell.set_weights(weights)
      k_out, k_state = sess.run(keras_out, {k_input: x_train})
    self.assertAllClose(tf_out, k_out)
    self.assertAllClose(tf_state, k_state)

  def testSimpleRNNCellAndBasicRNNCellComparison(self):
    """Keras SimpleRNNCell matches legacy BasicRNNCell numerically."""
    input_shape = 10
    output_shape = 5
    timestep = 4
    batch = 20
    (x_train, _), _ = get_test_data(
        train_samples=batch,
        test_samples=0,
        input_shape=(timestep, input_shape),
        num_classes=output_shape)
    fix_weights_generator = tf.keras.layers.SimpleRNNCell(output_shape)
    fix_weights_generator.build((None, input_shape))
    # The SimpleRNNCell contains 3 weights: kernel, recurrent_kernel, and bias
    # The BasicRNNCell contains 2 weight: kernel and bias, where kernel is
    # zipped [kernel, recurrent_kernel] in SimpleRNNCell.
    keras_weights = fix_weights_generator.get_weights()
    kernel, recurrent_kernel, bias = keras_weights
    tf_weights = [np.concatenate((kernel, recurrent_kernel)), bias]

    # Path 1: Keras SimpleRNNCell with the Keras weight layout.
    with self.session(graph=tf.Graph()) as sess:
      inputs = tf.placeholder(
          tf.float32, shape=(None, timestep, input_shape))
      cell = tf.keras.layers.SimpleRNNCell(output_shape)
      k_out, k_state = tf.nn.dynamic_rnn(
          cell, inputs, dtype=tf.float32)
      cell.set_weights(keras_weights)
      [k_out, k_state] = sess.run([k_out, k_state], {inputs: x_train})

    # Path 2: legacy BasicRNNCell with the concatenated weight layout.
    with self.session(graph=tf.Graph()) as sess:
      inputs = tf.placeholder(
          tf.float32, shape=(None, timestep, input_shape))
      cell = tf.nn.rnn_cell.BasicRNNCell(output_shape)
      tf_out, tf_state = tf.nn.dynamic_rnn(
          cell, inputs, dtype=tf.float32)
      cell.set_weights(tf_weights)
      [tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train})
    self.assertAllClose(tf_out, k_out, atol=1e-5)
    self.assertAllClose(tf_state, k_state, atol=1e-5)

  def testRNNCellSerialization(self):
    """Legacy rnn_cell cells round-trip through Keras RNN layer config."""
    for cell in [
        tf.nn.rnn_cell.LSTMCell(32, use_peepholes=True, cell_clip=True),
        tf.nn.rnn_cell.BasicLSTMCell(32, dtype=tf.float32),
        tf.nn.rnn_cell.BasicRNNCell(32, activation="relu", dtype=tf.float32),
        tf.nn.rnn_cell.GRUCell(32, dtype=tf.float32)
    ]:
      with self.cached_session():
        x = tf.keras.Input((None, 5))
        layer = tf.keras.layers.RNN(cell)
        y = layer(x)
        model = tf.keras.models.Model(x, y)
        model.compile(optimizer="rmsprop", loss="mse")

        # Test basic case serialization.
        x_np = np.random.random((6, 5, 5))
        y_np = model.predict(x_np)
        weights = model.get_weights()
        config = layer.get_config()
        # The custom_objects is important here since rnn_cell_impl is
        # not visible as a Keras layer, and also has a name conflict with
        # keras.LSTMCell and GRUCell.
        layer = tf.keras.layers.RNN.from_config(
            config,
            custom_objects={
                "BasicRNNCell": tf.nn.rnn_cell.BasicRNNCell,
                "GRUCell": tf.nn.rnn_cell.GRUCell,
                "LSTMCell": tf.nn.rnn_cell.LSTMCell,
                "BasicLSTMCell": tf.nn.rnn_cell.BasicLSTMCell
            })
        y = layer(x)
        model = tf.keras.models.Model(x, y)
        model.set_weights(weights)
        y_np_2 = model.predict(x_np)
        self.assertAllClose(y_np, y_np_2, atol=1e-4)

  def testRNNCellActsLikeKerasRNNCellInProperScope(self):
    """Legacy cells scope their variables correctly in keras_style_scope."""
    with tf.layers.experimental.keras_style_scope():
      kn1 = KerasNetworkTFRNNs(name="kn1")
      kn2 = KerasNetworkKerasRNNs(name="kn2")

    z = tf.zeros((2, 3))

    kn1(z)  # pylint:disable=not-callable
    kn2(z)  # pylint:disable=not-callable

    # Each network's variables must be namespaced under its own name.
    # pylint: disable=protected-access
    self.assertTrue(all("kn1" in v.name for v in kn1._cell.variables))
    self.assertTrue(all("kn2" in v.name for v in kn2._cell.variables))

    with tf.layers.experimental.keras_style_scope():
      kn1_new = KerasNetworkTFRNNs(name="kn1_new")
      kn2_new = KerasNetworkKerasRNNs(name="kn2_new")

    kn2_new(z)  # pylint:disable=not-callable
    # Most importantly, this doesn't fail due to variable scope reuse issues.
    kn1_new(z)  # pylint:disable=not-callable

    self.assertTrue(all("kn1_new" in v.name for v in kn1_new._cell.variables))
    self.assertTrue(all("kn2_new" in v.name for v in kn2_new._cell.variables))
def get_test_data(train_samples,
                  test_samples,
                  input_shape,
                  num_classes):
  """Generates random classification data split into train and test sets.

  Each class gets a random template of shape `input_shape`; every sample is
  its class template plus unit-variance Gaussian noise. Uses the global
  numpy RNG state, so seed beforehand for reproducibility.

  Args:
    train_samples: Number of samples in the returned training split.
    test_samples: Number of samples in the returned test split.
    input_shape: Tuple, the shape of a single sample (excluding batch).
    num_classes: Number of distinct integer labels.

  Returns:
    ((x_train, y_train), (x_test, y_test)) where x arrays are float32 of
    shape (samples,) + input_shape and y arrays are integer labels.
  """
  total = train_samples + test_samples
  templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)
  y = np.random.randint(0, num_classes, size=(total,))
  x = np.zeros((total,) + input_shape, dtype=np.float32)
  for i, label in enumerate(y):
    noise = np.random.normal(loc=0, scale=1., size=input_shape)
    x[i] = templates[label] + noise
  train = (x[:train_samples], y[:train_samples])
  test = (x[train_samples:], y[train_samples:])
  return (train, test)
if __name__ == "__main__":
  # Discovers and runs every test in this module under the TF test runner.
  tf.test.main()

View File

@ -26,7 +26,6 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import keras
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@ -34,10 +33,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import network as keras_network
from tensorflow.python.keras.utils import np_utils
from tensorflow.python.layers import base as base_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
@ -47,14 +42,11 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables as variables_lib
import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import
from tensorflow.python.ops.losses import losses
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
import tensorflow.python.ops.sparse_grad # pylint: disable=unused-import
import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import saver
from tensorflow.python.training import training
from tensorflow.python.util import nest
class Plus1RNNCell(rnn_cell_impl.RNNCell):
@ -130,28 +122,6 @@ class TensorArrayStateRNNCell(rnn_cell_impl.RNNCell):
return (input_, (state[0] + 1, new_array))
class KerasNetworkTFRNNs(keras_network.Network):
def __init__(self, name=None):
super(KerasNetworkTFRNNs, self).__init__(name=name)
self._cell = rnn_cell_impl.MultiRNNCell(
[rnn_cell_impl.LSTMCell(1) for _ in range(2)])
def call(self, inputs):
return self._cell(inputs, self._cell.get_initial_state(inputs))
class KerasNetworkKerasRNNs(keras_network.Network):
def __init__(self, name=None):
super(KerasNetworkKerasRNNs, self).__init__(name=name)
self._cell = keras.layers.StackedRNNCells(
[keras.layers.LSTMCell(1) for _ in range(2)])
def call(self, inputs):
return self._cell(inputs, self._cell.get_initial_state(inputs))
class RNNTest(test.TestCase):
def setUp(self):
@ -361,269 +331,6 @@ class RNNTest(test.TestCase):
self._assert_cell_builds(rnn_cell_impl.LSTMCell, f32, 5, 7, 3)
self._assert_cell_builds(rnn_cell_impl.LSTMCell, f64, 5, 7, 3)
@test_util.run_deprecated_v1
def testRNNWithKerasSimpleRNNCell(self):
with self.cached_session() as sess:
input_shape = 10
output_shape = 5
timestep = 4
batch = 100
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=batch,
test_samples=0,
input_shape=(timestep, input_shape),
num_classes=output_shape)
y_train = np_utils.to_categorical(y_train)
cell = keras.layers.SimpleRNNCell(output_shape)
inputs = array_ops.placeholder(
dtypes.float32, shape=(None, timestep, input_shape))
predict = array_ops.placeholder(
dtypes.float32, shape=(None, output_shape))
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32)
self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
self.assertEqual(state.shape.as_list(), [None, output_shape])
loss = losses.softmax_cross_entropy(predict, state)
train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
sess.run([variables_lib.global_variables_initializer()])
_, outputs, state = sess.run(
[train_op, outputs, state], {inputs: x_train, predict: y_train})
self.assertEqual(len(outputs), batch)
self.assertEqual(len(state), batch)
@test_util.run_deprecated_v1
def testRNNWithKerasGRUCell(self):
with self.cached_session() as sess:
input_shape = 10
output_shape = 5
timestep = 4
batch = 100
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=batch,
test_samples=0,
input_shape=(timestep, input_shape),
num_classes=output_shape)
y_train = np_utils.to_categorical(y_train)
cell = keras.layers.GRUCell(output_shape)
inputs = array_ops.placeholder(
dtypes.float32, shape=(None, timestep, input_shape))
predict = array_ops.placeholder(
dtypes.float32, shape=(None, output_shape))
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32)
self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
self.assertEqual(state.shape.as_list(), [None, output_shape])
loss = losses.softmax_cross_entropy(predict, state)
train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
sess.run([variables_lib.global_variables_initializer()])
_, outputs, state = sess.run(
[train_op, outputs, state], {inputs: x_train, predict: y_train})
self.assertEqual(len(outputs), batch)
self.assertEqual(len(state), batch)
@test_util.run_deprecated_v1
def testRNNWithKerasLSTMCell(self):
with self.cached_session() as sess:
input_shape = 10
output_shape = 5
timestep = 4
batch = 100
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=batch,
test_samples=0,
input_shape=(timestep, input_shape),
num_classes=output_shape)
y_train = np_utils.to_categorical(y_train)
cell = keras.layers.LSTMCell(output_shape)
inputs = array_ops.placeholder(
dtypes.float32, shape=(None, timestep, input_shape))
predict = array_ops.placeholder(
dtypes.float32, shape=(None, output_shape))
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32)
self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
self.assertEqual(len(state), 2)
self.assertEqual(state[0].shape.as_list(), [None, output_shape])
self.assertEqual(state[1].shape.as_list(), [None, output_shape])
loss = losses.softmax_cross_entropy(predict, state[0])
train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
sess.run([variables_lib.global_variables_initializer()])
_, outputs, state = sess.run(
[train_op, outputs, state], {inputs: x_train, predict: y_train})
self.assertEqual(len(outputs), batch)
self.assertEqual(len(state), 2)
self.assertEqual(len(state[0]), batch)
self.assertEqual(len(state[1]), batch)
@test_util.run_deprecated_v1
def testRNNWithStackKerasCell(self):
with self.cached_session() as sess:
input_shape = 10
output_shape = 5
timestep = 4
batch = 100
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=batch,
test_samples=0,
input_shape=(timestep, input_shape),
num_classes=output_shape)
y_train = np_utils.to_categorical(y_train)
cell = keras.layers.StackedRNNCells(
[keras.layers.LSTMCell(2 * output_shape),
keras.layers.LSTMCell(output_shape)])
inputs = array_ops.placeholder(
dtypes.float32, shape=(None, timestep, input_shape))
predict = array_ops.placeholder(
dtypes.float32, shape=(None, output_shape))
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32)
self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
self.assertEqual(len(state), 2)
state = nest.flatten(state)
self.assertEqual(len(state), 4)
self.assertEqual(state[0].shape.as_list(), [None, 2 * output_shape])
self.assertEqual(state[1].shape.as_list(), [None, 2 * output_shape])
self.assertEqual(state[2].shape.as_list(), [None, output_shape])
self.assertEqual(state[3].shape.as_list(), [None, output_shape])
loss = losses.softmax_cross_entropy(predict, state[2])
train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
sess.run([variables_lib.global_variables_initializer()])
_, outputs, state = sess.run(
[train_op, outputs, state], {inputs: x_train, predict: y_train})
self.assertEqual(len(outputs), batch)
self.assertEqual(len(state), 4)
for s in state:
self.assertEqual(len(s), batch)
@test_util.run_deprecated_v1
def testStaticRNNWithKerasSimpleRNNCell(self):
with self.cached_session() as sess:
input_shape = 10
output_shape = 5
timestep = 4
batch = 100
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=batch,
test_samples=0,
input_shape=(timestep, input_shape),
num_classes=output_shape)
x_train = np.transpose(x_train, (1, 0, 2))
y_train = np_utils.to_categorical(y_train)
cell = keras.layers.SimpleRNNCell(output_shape)
inputs = [array_ops.placeholder(
dtypes.float32, shape=(None, input_shape))] * timestep
predict = array_ops.placeholder(
dtypes.float32, shape=(None, output_shape))
outputs, state = rnn.static_rnn(
cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), timestep)
self.assertEqual(outputs[0].shape.as_list(), [None, output_shape])
self.assertEqual(state.shape.as_list(), [None, output_shape])
loss = losses.softmax_cross_entropy(predict, state)
train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
sess.run([variables_lib.global_variables_initializer()])
feed_dict = {i: d for i, d in zip(inputs, x_train)}
feed_dict[predict] = y_train
_, outputs, state = sess.run(
[train_op, outputs, state], feed_dict)
self.assertEqual(len(outputs), timestep)
self.assertEqual(len(outputs[0]), batch)
self.assertEqual(len(state), batch)
@test_util.run_deprecated_v1
def testKerasAndTFRNNLayerOutputComparison(self):
input_shape = 10
output_shape = 5
timestep = 4
batch = 20
(x_train, _), _ = testing_utils.get_test_data(
train_samples=batch,
test_samples=0,
input_shape=(timestep, input_shape),
num_classes=output_shape)
fix_weights_generator = keras.layers.SimpleRNNCell(output_shape)
fix_weights_generator.build((None, input_shape))
weights = fix_weights_generator.get_weights()
with self.session(graph=ops_lib.Graph()) as sess:
inputs = array_ops.placeholder(
dtypes.float32, shape=(None, timestep, input_shape))
cell = keras.layers.SimpleRNNCell(output_shape)
tf_out, tf_state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32)
cell.set_weights(weights)
[tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train})
with self.session(graph=ops_lib.Graph()) as sess:
k_input = keras.Input(shape=(timestep, input_shape),
dtype=dtypes.float32)
cell = keras.layers.SimpleRNNCell(output_shape)
layer = keras.layers.RNN(cell, return_sequences=True, return_state=True)
keras_out = layer(k_input)
cell.set_weights(weights)
k_out, k_state = sess.run(keras_out, {k_input: x_train})
self.assertAllClose(tf_out, k_out)
self.assertAllClose(tf_state, k_state)
@test_util.run_deprecated_v1
def testSimpleRNNCellAndBasicRNNCellComparison(self):
input_shape = 10
output_shape = 5
timestep = 4
batch = 20
(x_train, _), _ = testing_utils.get_test_data(
train_samples=batch,
test_samples=0,
input_shape=(timestep, input_shape),
num_classes=output_shape)
fix_weights_generator = keras.layers.SimpleRNNCell(output_shape)
fix_weights_generator.build((None, input_shape))
# The SimpleRNNCell contains 3 weights: kernel, recurrent_kernel, and bias
# The BasicRNNCell contains 2 weight: kernel and bias, where kernel is
# zipped [kernel, recurrent_kernel] in SimpleRNNCell.
keras_weights = fix_weights_generator.get_weights()
kernel, recurrent_kernel, bias = keras_weights
tf_weights = [np.concatenate((kernel, recurrent_kernel)), bias]
with self.session(graph=ops_lib.Graph()) as sess:
inputs = array_ops.placeholder(
dtypes.float32, shape=(None, timestep, input_shape))
cell = keras.layers.SimpleRNNCell(output_shape)
k_out, k_state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32)
cell.set_weights(keras_weights)
[k_out, k_state] = sess.run([k_out, k_state], {inputs: x_train})
with self.session(graph=ops_lib.Graph()) as sess:
inputs = array_ops.placeholder(
dtypes.float32, shape=(None, timestep, input_shape))
cell = rnn_cell_impl.BasicRNNCell(output_shape)
tf_out, tf_state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32)
cell.set_weights(tf_weights)
[tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train})
self.assertAllClose(tf_out, k_out, atol=1e-5)
self.assertAllClose(tf_state, k_state, atol=1e-5)
@test_util.run_deprecated_v1
def testBasicLSTMCellInterchangeWithLSTMCell(self):
with self.session(graph=ops_lib.Graph()) as sess:
@ -649,70 +356,6 @@ class RNNTest(test.TestCase):
save.restore(sess, save_path)
self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias))
# TODO(scottzhu): Look into updating for V2 Initializers.
@test_util.run_deprecated_v1
def testRNNCellSerialization(self):
for cell in [
rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True),
rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32),
rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32),
rnn_cell_impl.GRUCell(32, dtype=dtypes.float32)
]:
with self.cached_session():
x = keras.Input((None, 5))
layer = keras.layers.RNN(cell)
y = layer(x)
model = keras.models.Model(x, y)
model.compile(optimizer="rmsprop", loss="mse")
# Test basic case serialization.
x_np = np.random.random((6, 5, 5))
y_np = model.predict(x_np)
weights = model.get_weights()
config = layer.get_config()
# The custom_objects is important here since rnn_cell_impl is
# not visible as a Keras layer, and also has a name conflict with
# keras.LSTMCell and GRUCell.
layer = keras.layers.RNN.from_config(
config,
custom_objects={
"BasicRNNCell": rnn_cell_impl.BasicRNNCell,
"GRUCell": rnn_cell_impl.GRUCell,
"LSTMCell": rnn_cell_impl.LSTMCell,
"BasicLSTMCell": rnn_cell_impl.BasicLSTMCell
})
y = layer(x)
model = keras.models.Model(x, y)
model.set_weights(weights)
y_np_2 = model.predict(x_np)
self.assertAllClose(y_np, y_np_2, atol=1e-4)
def testRNNCellActsLikeKerasRNNCellInProperScope(self):
with base_layers.keras_style_scope():
kn1 = KerasNetworkTFRNNs(name="kn1")
kn2 = KerasNetworkKerasRNNs(name="kn2")
z = array_ops.zeros((2, 3))
kn1(z)
kn2(z)
# pylint: disable=protected-access
self.assertTrue(all("kn1" in v.name for v in kn1._cell.variables))
self.assertTrue(all("kn2" in v.name for v in kn2._cell.variables))
with base_layers.keras_style_scope():
kn1_new = KerasNetworkTFRNNs(name="kn1_new")
kn2_new = KerasNetworkKerasRNNs(name="kn2_new")
kn2_new(z)
# Most importantly, this doesn't fail due to variable scope reuse issues.
kn1_new(z)
self.assertTrue(all("kn1_new" in v.name for v in kn1_new._cell.variables))
self.assertTrue(all("kn2_new" in v.name for v in kn2_new._cell.variables))
######### Benchmarking RNN code