Fix a bug where DenseHashTable
can't be saved to SavedModel, because it tries to save EagerTensors in a TF 1.x tf.Graph (used in TF 1.x style serving APIs).
The error message was: "RuntimeError: Attempting to capture an EagerTensor without building a function." PiperOrigin-RevId: 333185295 Change-Id: I05b65d2bc8c446032300faf31696316898b45ef4
This commit is contained in:
parent
c4120e6c8f
commit
568e2bef00
@ -720,9 +720,12 @@ tf_py_test(
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:lookup_ops",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:readers",
|
||||
"//tensorflow/python/saved_model:load",
|
||||
"//tensorflow/python/saved_model:save",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -38,6 +38,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -47,6 +48,8 @@ from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import load as saved_model_load
|
||||
from tensorflow.python.saved_model import save as saved_model_save
|
||||
from tensorflow.python.training import saver
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.training.tracking import graph_view
|
||||
@ -1705,6 +1708,47 @@ class DenseHashTableOpTest(test.TestCase):
|
||||
output = load_table.lookup(input_string)
|
||||
self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testSavedModelSaveRestore(self):
|
||||
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
|
||||
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
|
||||
|
||||
root = tracking.AutoTrackable()
|
||||
|
||||
default_value = -1
|
||||
empty_key = 0
|
||||
deleted_key = -1
|
||||
keys = constant_op.constant([11, 12, 13], dtypes.int64)
|
||||
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
||||
root.table = lookup_ops.DenseHashTable(
|
||||
dtypes.int64,
|
||||
dtypes.int64,
|
||||
default_value=default_value,
|
||||
empty_key=empty_key,
|
||||
deleted_key=deleted_key,
|
||||
name="t1",
|
||||
checkpoint=True,
|
||||
initial_num_buckets=32)
|
||||
|
||||
@def_function.function(
|
||||
input_signature=[tensor_spec.TensorSpec((), dtypes.int64)])
|
||||
def lookup(key):
|
||||
return root.table.lookup(key)
|
||||
|
||||
root.lookup = lookup
|
||||
|
||||
self.assertAllEqual(0, root.table.size())
|
||||
root.table.insert(keys, values)
|
||||
self.assertAllEqual(3, self.evaluate(root.table.size()))
|
||||
self.assertAllEqual(32, len(self.evaluate(root.table.export()[0])))
|
||||
|
||||
saved_model_save.save(root, save_path)
|
||||
|
||||
del root
|
||||
loaded = saved_model_load.load(save_path)
|
||||
self.assertEqual(loaded.lookup(12), 1)
|
||||
self.assertEqual(loaded.lookup(10), -1)
|
||||
|
||||
@test_util.run_v1_only("Saver V1 only")
|
||||
def testVectorSaveRestore(self):
|
||||
save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore")
|
||||
|
@ -1985,10 +1985,8 @@ class DenseHashTable(LookupInterface):
|
||||
self._checkpoint = checkpoint
|
||||
self._name = name
|
||||
|
||||
self._empty_key = ops.convert_to_tensor(
|
||||
empty_key, dtype=key_dtype, name="empty_key")
|
||||
self._deleted_key = ops.convert_to_tensor(
|
||||
deleted_key, dtype=key_dtype, name="deleted_key")
|
||||
self._empty_key = empty_key
|
||||
self._deleted_key = deleted_key
|
||||
self._shared_name = None
|
||||
if context.executing_eagerly():
|
||||
# TODO(allenl): This will leak memory due to kernel caching by the
|
||||
@ -2010,9 +2008,13 @@ class DenseHashTable(LookupInterface):
|
||||
# training to work correctly. Use the node name if no shared_name has been
|
||||
# explicitly specified.
|
||||
use_node_name_sharing = self._checkpoint and self._shared_name is None
|
||||
empty_key = ops.convert_to_tensor(
|
||||
self._empty_key, dtype=self._key_dtype, name="empty_key")
|
||||
deleted_key = ops.convert_to_tensor(
|
||||
self._deleted_key, dtype=self._key_dtype, name="deleted_key")
|
||||
table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
|
||||
empty_key=self._empty_key,
|
||||
deleted_key=self._deleted_key,
|
||||
empty_key=empty_key,
|
||||
deleted_key=deleted_key,
|
||||
shared_name=self._shared_name,
|
||||
use_node_name_sharing=use_node_name_sharing,
|
||||
value_dtype=self._value_dtype,
|
||||
|
Loading…
x
Reference in New Issue
Block a user