diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index aaa5285d133..e82334b4923 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -262,8 +262,13 @@ class _SaveableView(object): if (tensor_util.is_tensor(capture) and capture.dtype not in _UNCOPIABLE_DTYPES and capture not in self.captured_tensor_node_ids): - copied_tensor = constant_op.constant( - tensor_util.constant_value(capture)) + capture_constant_value = tensor_util.constant_value(capture) + if capture_constant_value is None: + raise ValueError( + ("Attempted to save a function {} which references a symbolic " + "Tensor {} that is not a simple constant. This is not " + "supported.").format(concrete_function.name, capture)) + copied_tensor = constant_op.constant(capture_constant_value) node_id = len(self.nodes) node = _CapturedConstant( eager_tensor=capture, graph_tensor=copied_tensor) diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index b412fa6f145..566c508526d 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -124,6 +124,23 @@ class SaveTest(test.TestCase): root = util.Checkpoint(model=sequential.Sequential([core.Dense(2)])) save.save(root, os.path.join(self.get_temp_dir(), "saved_model")) + def test_captured_symbolic_tensor_exception(self): + root = module.Module() + symbolic_tensor = [] + + @def_function.function + def captured_intermediate(x): + symbolic_tensor.append(math_ops.add(x, x, name="a_tensor")) + return symbolic_tensor[-1] * 2 + + captured_intermediate(constant_op.constant(1.)) + + root.f = def_function.function(lambda: symbolic_tensor[-1], + input_signature=[]) + with self.assertRaisesRegexp(ValueError, "a_tensor"): + save.save(root, os.path.join(self.get_temp_dir(), "saved_model"), + signatures=root.f) + def test_version_information_included(self): root = tracking.AutoTrackable() save_dir = os.path.join(self.get_temp_dir(), "saved_model")