Rollfoward of cl/347928110: Do not trace legacy rnn cells using the input spec.
While this will still generate the traces, the input and state shapes are passed down from the model inputs, and will not raise the error from b/172114000. PiperOrigin-RevId: 348078264 Change-Id: I7e060a9f40b3046385d75510ca0ae8a2bbc661fd
This commit is contained in:
parent
1fbddee1b1
commit
5e6c86fdf2
tensorflow/python
@ -41,6 +41,7 @@ py_library(
|
||||
"//tensorflow/python/keras:initializers",
|
||||
"//tensorflow/python/keras/engine:input_spec",
|
||||
"//tensorflow/python/keras/legacy_tf_layers:layers_base",
|
||||
"//tensorflow/python/keras/saving",
|
||||
"//tensorflow/python/keras/utils:tf_utils",
|
||||
"//tensorflow/python/training/tracking:base",
|
||||
],
|
||||
|
@ -346,6 +346,12 @@ class RNNCell(base_layer.Layer):
|
||||
def get_config(self): # pylint: disable=useless-super-delegation
|
||||
return super(RNNCell, self).get_config()
|
||||
|
||||
@property
|
||||
def _use_input_spec_as_call_signature(self):
|
||||
# We do not store the shape information for the state argument in the call
|
||||
# function for legacy RNN cells, so do not generate an input signature.
|
||||
return False
|
||||
|
||||
|
||||
class LayerRNNCell(RNNCell):
|
||||
"""Subclass of RNNCells that act like proper `tf.Layer` objects.
|
||||
|
@ -3128,6 +3128,10 @@ cuda_py_test(
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/saved_model:load",
|
||||
"//tensorflow/python/saved_model:save",
|
||||
"//tensorflow/python/training/tracking",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
|
@ -26,14 +26,16 @@ import numpy as np
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
@ -47,6 +49,9 @@ from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.saved_model import load
|
||||
from tensorflow.python.saved_model import save
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.training.tracking import util as trackable_utils
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
@ -3060,6 +3065,29 @@ class RNNCellTest(test.TestCase, parameterized.TestCase):
|
||||
reconstructed_wrapper = wrapper_cls.from_config(config_copy)
|
||||
self.assertFalse(reconstructed_wrapper._dropout_state_filter(None))
|
||||
|
||||
def testSavedModel(self):
|
||||
if test_util.is_gpu_available():
|
||||
self.skipTest("b/175887901")
|
||||
|
||||
with self.cached_session():
|
||||
root = tracking.AutoTrackable()
|
||||
root.cell = rnn_cell_impl.LSTMCell(8)
|
||||
@def_function.function(input_signature=[tensor_spec.TensorSpec([3, 8])])
|
||||
def call(x):
|
||||
state = root.cell.zero_state(3, dtype=x.dtype)
|
||||
y, _ = root.cell(x, state)
|
||||
return y
|
||||
root.call = call
|
||||
expected = root.call(array_ops.zeros((3, 8)))
|
||||
self.evaluate(variables_lib.global_variables_initializer())
|
||||
|
||||
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
save.save(root, save_dir)
|
||||
loaded = load.load(save_dir)
|
||||
self.evaluate(variables_lib.global_variables_initializer())
|
||||
self.assertAllClose(
|
||||
expected, loaded.call(array_ops.zeros((3, 8))))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@test_util.run_all_without_tensor_float_32(
|
||||
|
Loading…
Reference in New Issue
Block a user