From 568e2bef00f24af1159a0846abf67c099ca78a21 Mon Sep 17 00:00:00 2001 From: Chenkai Kuang Date: Tue, 22 Sep 2020 16:37:59 -0700 Subject: [PATCH] Fix a bug where `DenseHashTable` can't be saved to SavedModel, because it tries to save EagerTensors in a TF 1.x tf.Graph (used in TF 1.x style serving APIs). The error message was: "RuntimeError: Attempting to capture an EagerTensor without building a function." PiperOrigin-RevId: 333185295 Change-Id: I05b65d2bc8c446032300faf31696316898b45ef4 --- tensorflow/python/kernel_tests/BUILD | 3 ++ .../python/kernel_tests/lookup_ops_test.py | 44 +++++++++++++++++++ tensorflow/python/ops/lookup_ops.py | 14 +++--- 3 files changed, 55 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index d1a7051cc56..a0210904247 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -720,9 +720,12 @@ tf_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:lookup_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_spec", "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", + "//tensorflow/python/saved_model:load", + "//tensorflow/python/saved_model:save", ], ) diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py index 045dafc3089..e4810beec31 100644 --- a/tensorflow/python/kernel_tests/lookup_ops_test.py +++ b/tensorflow/python/kernel_tests/lookup_ops_test.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -47,6 +48,8 @@ from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test +from tensorflow.python.saved_model import load as saved_model_load +from tensorflow.python.saved_model import save as saved_model_save from tensorflow.python.training import saver from tensorflow.python.training import server_lib from tensorflow.python.training.tracking import graph_view @@ -1705,6 +1708,47 @@ class DenseHashTableOpTest(test.TestCase): output = load_table.lookup(input_string) self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output)) + @test_util.run_v2_only + def testSavedModelSaveRestore(self): + save_dir = os.path.join(self.get_temp_dir(), "save_restore") + save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") + + root = tracking.AutoTrackable() + + default_value = -1 + empty_key = 0 + deleted_key = -1 + keys = constant_op.constant([11, 12, 13], dtypes.int64) + values = constant_op.constant([0, 1, 2], dtypes.int64) + root.table = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=default_value, + empty_key=empty_key, + deleted_key=deleted_key, + name="t1", + checkpoint=True, + initial_num_buckets=32) + + @def_function.function( + input_signature=[tensor_spec.TensorSpec((), dtypes.int64)]) + def lookup(key): + return root.table.lookup(key) + + root.lookup = lookup + + self.assertAllEqual(0, root.table.size()) + root.table.insert(keys, values) + self.assertAllEqual(3, self.evaluate(root.table.size())) + self.assertAllEqual(32, len(self.evaluate(root.table.export()[0]))) + + saved_model_save.save(root, save_path) + + del root + loaded = saved_model_load.load(save_path) + self.assertEqual(loaded.lookup(12), 1) + self.assertEqual(loaded.lookup(10), -1) + @test_util.run_v1_only("Saver V1 only") def testVectorSaveRestore(self): save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore") diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index 9f27ccf9a1c..e53629250c9 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -1985,10 +1985,8 @@ class DenseHashTable(LookupInterface): self._checkpoint = checkpoint self._name = name - self._empty_key = ops.convert_to_tensor( - empty_key, dtype=key_dtype, name="empty_key") - self._deleted_key = ops.convert_to_tensor( - deleted_key, dtype=key_dtype, name="deleted_key") + self._empty_key = empty_key + self._deleted_key = deleted_key self._shared_name = None if context.executing_eagerly(): # TODO(allenl): This will leak memory due to kernel caching by the @@ -2010,9 +2008,13 @@ class DenseHashTable(LookupInterface): # training to work correctly. Use the node name if no shared_name has been # explicitly specified. use_node_name_sharing = self._checkpoint and self._shared_name is None + empty_key = ops.convert_to_tensor( + self._empty_key, dtype=self._key_dtype, name="empty_key") + deleted_key = ops.convert_to_tensor( + self._deleted_key, dtype=self._key_dtype, name="deleted_key") table_ref = gen_lookup_ops.mutable_dense_hash_table_v2( - empty_key=self._empty_key, - deleted_key=self._deleted_key, + empty_key=empty_key, + deleted_key=deleted_key, shared_name=self._shared_name, use_node_name_sharing=use_node_name_sharing, value_dtype=self._value_dtype,