Rollfoward of cl/347928110: Do not trace legacy rnn cells using the input spec.

While this will still generate the traces, the input and state shapes are passed down from the model inputs, and will not raise the error from b/172114000.

PiperOrigin-RevId: 348078264
Change-Id: I7e060a9f40b3046385d75510ca0ae8a2bbc661fd
This commit is contained in:
Katherine Wu 2020-12-17 13:03:12 -08:00 committed by TensorFlower Gardener
parent 1fbddee1b1
commit 5e6c86fdf2
4 changed files with 40 additions and 1 deletions
tensorflow/python
keras/layers/legacy_rnn
kernel_tests

View File

@ -41,6 +41,7 @@ py_library(
"//tensorflow/python/keras:initializers",
"//tensorflow/python/keras/engine:input_spec",
"//tensorflow/python/keras/legacy_tf_layers:layers_base",
"//tensorflow/python/keras/saving",
"//tensorflow/python/keras/utils:tf_utils",
"//tensorflow/python/training/tracking:base",
],

View File

@ -346,6 +346,12 @@ class RNNCell(base_layer.Layer):
def get_config(self): # pylint: disable=useless-super-delegation
return super(RNNCell, self).get_config()
@property
def _use_input_spec_as_call_signature(self):
# We do not store the shape information for the state argument in the call
# function for legacy RNN cells, so do not generate an input signature.
return False
class LayerRNNCell(RNNCell):
"""Subclass of RNNCells that act like proper `tf.Layer` objects.

View File

@ -3128,6 +3128,10 @@ cuda_py_test(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/saved_model:load",
"//tensorflow/python/saved_model:save",
"//tensorflow/python/training/tracking",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],

View File

@ -26,14 +26,16 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
@ -47,6 +49,9 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import nest
@ -3060,6 +3065,29 @@ class RNNCellTest(test.TestCase, parameterized.TestCase):
reconstructed_wrapper = wrapper_cls.from_config(config_copy)
self.assertFalse(reconstructed_wrapper._dropout_state_filter(None))
def testSavedModel(self):
if test_util.is_gpu_available():
self.skipTest("b/175887901")
with self.cached_session():
root = tracking.AutoTrackable()
root.cell = rnn_cell_impl.LSTMCell(8)
@def_function.function(input_signature=[tensor_spec.TensorSpec([3, 8])])
def call(x):
state = root.cell.zero_state(3, dtype=x.dtype)
y, _ = root.cell(x, state)
return y
root.call = call
expected = root.call(array_ops.zeros((3, 8)))
self.evaluate(variables_lib.global_variables_initializer())
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(root, save_dir)
loaded = load.load(save_dir)
self.evaluate(variables_lib.global_variables_initializer())
self.assertAllClose(
expected, loaded.call(array_ops.zeros((3, 8))))
@test_util.run_all_in_graph_and_eager_modes
@test_util.run_all_without_tensor_float_32(