Improve the error message when a saved function captures a symbolic Tensor without a constant value.

PiperOrigin-RevId: 253624252
This commit is contained in:
Allen Lavoie 2019-06-17 11:35:41 -07:00 committed by TensorFlower Gardener
parent d9a37895d6
commit 48bd9b6b56
2 changed files with 24 additions and 2 deletions

View File

@ -262,8 +262,13 @@ class _SaveableView(object):
if (tensor_util.is_tensor(capture)
and capture.dtype not in _UNCOPIABLE_DTYPES
and capture not in self.captured_tensor_node_ids):
copied_tensor = constant_op.constant(
tensor_util.constant_value(capture))
capture_constant_value = tensor_util.constant_value(capture)
if capture_constant_value is None:
raise ValueError(
("Attempted to save a function {} which references a symbolic "
"Tensor {} that is not a simple constant. This is not "
"supported.").format(concrete_function.name, capture))
copied_tensor = constant_op.constant(capture_constant_value)
node_id = len(self.nodes)
node = _CapturedConstant(
eager_tensor=capture, graph_tensor=copied_tensor)

View File

@ -124,6 +124,23 @@ class SaveTest(test.TestCase):
root = util.Checkpoint(model=sequential.Sequential([core.Dense(2)]))
save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))
def test_captured_symbolic_tensor_exception(self):
root = module.Module()
symbolic_tensor = []
@def_function.function
def captured_intermediate(x):
symbolic_tensor.append(math_ops.add(x, x, name="a_tensor"))
return symbolic_tensor[-1] * 2
captured_intermediate(constant_op.constant(1.))
root.f = def_function.function(lambda: symbolic_tensor[-1],
input_signature=[])
with self.assertRaisesRegexp(ValueError, "a_tensor"):
save.save(root, os.path.join(self.get_temp_dir(), "saved_model"),
signatures=root.f)
def test_version_information_included(self):
root = tracking.AutoTrackable()
save_dir = os.path.join(self.get_temp_dir(), "saved_model")