From 48bd9b6b56e2afe8f8f2924a794be2a2f591ec86 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Mon, 17 Jun 2019 11:35:41 -0700 Subject: [PATCH] Improve the error message when a saved function captures a symbolic Tensor without a constant value. PiperOrigin-RevId: 253624252 --- tensorflow/python/saved_model/save.py | 9 +++++++-- tensorflow/python/saved_model/save_test.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index aaa5285d133..e82334b4923 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -262,8 +262,13 @@ class _SaveableView(object): if (tensor_util.is_tensor(capture) and capture.dtype not in _UNCOPIABLE_DTYPES and capture not in self.captured_tensor_node_ids): - copied_tensor = constant_op.constant( - tensor_util.constant_value(capture)) + capture_constant_value = tensor_util.constant_value(capture) + if capture_constant_value is None: + raise ValueError( + ("Attempted to save a function {} which references a symbolic " + "Tensor {} that is not a simple constant. This is not " + "supported.").format(concrete_function.name, capture)) + copied_tensor = constant_op.constant(capture_constant_value) node_id = len(self.nodes) node = _CapturedConstant( eager_tensor=capture, graph_tensor=copied_tensor) diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index b412fa6f145..566c508526d 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -124,6 +124,23 @@ class SaveTest(test.TestCase): root = util.Checkpoint(model=sequential.Sequential([core.Dense(2)])) save.save(root, os.path.join(self.get_temp_dir(), "saved_model")) + def test_captured_symbolic_tensor_exception(self): + root = module.Module() + symbolic_tensor = [] + + @def_function.function + def captured_intermediate(x): + symbolic_tensor.append(math_ops.add(x, x, name="a_tensor")) + return symbolic_tensor[-1] * 2 + + captured_intermediate(constant_op.constant(1.)) + + root.f = def_function.function(lambda: symbolic_tensor[-1], + input_signature=[]) + with self.assertRaisesRegexp(ValueError, "a_tensor"): + save.save(root, os.path.join(self.get_temp_dir(), "saved_model"), + signatures=root.f) + def test_version_information_included(self): root = tracking.AutoTrackable() save_dir = os.path.join(self.get_temp_dir(), "saved_model")