Fix a bug where DenseHashTable can't be saved to a SavedModel because it tries to capture EagerTensors (the eagerly converted empty_key/deleted_key values) in a TF 1.x tf.Graph used by TF 1.x-style serving APIs.

The error message was:
"RuntimeError: Attempting to capture an EagerTensor without building a function."

PiperOrigin-RevId: 333185295
Change-Id: I05b65d2bc8c446032300faf31696316898b45ef4
Authored by Chenkai Kuang on 2020-09-22 16:37:59 -07:00; committed by TensorFlower Gardener
commit 568e2bef00 (parent c4120e6c8f)
3 changed files with 55 additions and 6 deletions
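
For context, a minimal sketch of the failing pattern, mirroring the new regression test below but written against the public TF 2.x API (tf.Module, tf.lookup.experimental.DenseHashTable, and tf.saved_model.save are choices of this sketch, not part of the change; the save path is illustrative):

import tensorflow as tf  # assumes TF 2.x with eager execution enabled

root = tf.Module()
root.table = tf.lookup.experimental.DenseHashTable(
    key_dtype=tf.int64,
    value_dtype=tf.int64,
    default_value=-1,
    empty_key=0,
    deleted_key=-1)
root.table.insert(tf.constant([11, 12, 13], tf.int64),
                  tf.constant([0, 1, 2], tf.int64))

@tf.function(input_signature=[tf.TensorSpec((), tf.int64)])
def lookup(key):
  return root.table.lookup(key)

root.lookup = lookup

# Exporting builds a TF 1.x tf.Graph for the TF 1.x-style serving path; with
# empty_key/deleted_key held as EagerTensors on the table, that graph tried to
# capture them and raised the RuntimeError quoted above. After this change the
# save succeeds.
tf.saved_model.save(root, "/tmp/dense_hash_table_savedmodel")  # illustrative path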


@@ -720,9 +720,12 @@ tf_py_test(
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:lookup_ops",
         "//tensorflow/python:sparse_tensor",
+        "//tensorflow/python:tensor_spec",
         "//tensorflow/python:training",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/data/ops:readers",
+        "//tensorflow/python/saved_model:load",
+        "//tensorflow/python/saved_model:save",
     ],
 )


@@ -38,6 +38,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -47,6 +48,8 @@ from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import test
+from tensorflow.python.saved_model import load as saved_model_load
+from tensorflow.python.saved_model import save as saved_model_save
 from tensorflow.python.training import saver
 from tensorflow.python.training import server_lib
 from tensorflow.python.training.tracking import graph_view
@@ -1705,6 +1708,47 @@ class DenseHashTableOpTest(test.TestCase):
     output = load_table.lookup(input_string)
     self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))

+  @test_util.run_v2_only
+  def testSavedModelSaveRestore(self):
+    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
+    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
+
+    root = tracking.AutoTrackable()
+
+    default_value = -1
+    empty_key = 0
+    deleted_key = -1
+    keys = constant_op.constant([11, 12, 13], dtypes.int64)
+    values = constant_op.constant([0, 1, 2], dtypes.int64)
+    root.table = lookup_ops.DenseHashTable(
+        dtypes.int64,
+        dtypes.int64,
+        default_value=default_value,
+        empty_key=empty_key,
+        deleted_key=deleted_key,
+        name="t1",
+        checkpoint=True,
+        initial_num_buckets=32)
+
+    @def_function.function(
+        input_signature=[tensor_spec.TensorSpec((), dtypes.int64)])
+    def lookup(key):
+      return root.table.lookup(key)
+
+    root.lookup = lookup
+
+    self.assertAllEqual(0, root.table.size())
+    root.table.insert(keys, values)
+    self.assertAllEqual(3, self.evaluate(root.table.size()))
+    self.assertAllEqual(32, len(self.evaluate(root.table.export()[0])))
+
+    saved_model_save.save(root, save_path)
+    del root
+
+    loaded = saved_model_load.load(save_path)
+    self.assertEqual(loaded.lookup(12), 1)
+    self.assertEqual(loaded.lookup(10), -1)
+
   @test_util.run_v1_only("Saver V1 only")
   def testVectorSaveRestore(self):
     save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore")

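Before the DenseHashTable diff below, a simplified sketch of the pattern it adopts (a hypothetical class for illustration, not the real implementation): keep the raw empty_key/deleted_key values at construction time and defer ops.convert_to_tensor until the table resource op is created, so the conversion runs in whichever graph (or eager context) is actually building the resource.

from tensorflow.python.framework import ops


class DeferredKeyTable(object):
  """Illustrative only: shows the deferred-conversion idea, not DenseHashTable."""

  def __init__(self, empty_key, deleted_key, key_dtype):
    # Store plain values; converting here would pin EagerTensors to the eager
    # context and break later graph-mode resource creation.
    self._empty_key = empty_key
    self._deleted_key = deleted_key
    self._key_dtype = key_dtype

  def _create_resource(self):
    # Convert at resource-creation time, in the currently active graph or
    # eager context.
    empty_key = ops.convert_to_tensor(
        self._empty_key, dtype=self._key_dtype, name="empty_key")
    deleted_key = ops.convert_to_tensor(
        self._deleted_key, dtype=self._key_dtype, name="deleted_key")
    return empty_key, deleted_key

The same object can then be constructed eagerly and still have its resource op rebuilt inside a TF 1.x export graph without capturing stale EagerTensors.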

@@ -1985,10 +1985,8 @@ class DenseHashTable(LookupInterface):
     self._checkpoint = checkpoint
     self._name = name
-    self._empty_key = ops.convert_to_tensor(
-        empty_key, dtype=key_dtype, name="empty_key")
-    self._deleted_key = ops.convert_to_tensor(
-        deleted_key, dtype=key_dtype, name="deleted_key")
+    self._empty_key = empty_key
+    self._deleted_key = deleted_key
     self._shared_name = None
     if context.executing_eagerly():
       # TODO(allenl): This will leak memory due to kernel caching by the
@@ -2010,9 +2008,13 @@
     # training to work correctly. Use the node name if no shared_name has been
     # explicitly specified.
     use_node_name_sharing = self._checkpoint and self._shared_name is None
+    empty_key = ops.convert_to_tensor(
+        self._empty_key, dtype=self._key_dtype, name="empty_key")
+    deleted_key = ops.convert_to_tensor(
+        self._deleted_key, dtype=self._key_dtype, name="deleted_key")
     table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
-        empty_key=self._empty_key,
-        deleted_key=self._deleted_key,
+        empty_key=empty_key,
+        deleted_key=deleted_key,
         shared_name=self._shared_name,
         use_node_name_sharing=use_node_name_sharing,
         value_dtype=self._value_dtype,