Improve the error message when a saved function captures a symbolic Tensor without a constant value.
PiperOrigin-RevId: 253624252
This commit is contained in:
parent
d9a37895d6
commit
48bd9b6b56
@ -262,8 +262,13 @@ class _SaveableView(object):
|
||||
if (tensor_util.is_tensor(capture)
|
||||
and capture.dtype not in _UNCOPIABLE_DTYPES
|
||||
and capture not in self.captured_tensor_node_ids):
|
||||
copied_tensor = constant_op.constant(
|
||||
tensor_util.constant_value(capture))
|
||||
capture_constant_value = tensor_util.constant_value(capture)
|
||||
if capture_constant_value is None:
|
||||
raise ValueError(
|
||||
("Attempted to save a function {} which references a symbolic "
|
||||
"Tensor {} that is not a simple constant. This is not "
|
||||
"supported.").format(concrete_function.name, capture))
|
||||
copied_tensor = constant_op.constant(capture_constant_value)
|
||||
node_id = len(self.nodes)
|
||||
node = _CapturedConstant(
|
||||
eager_tensor=capture, graph_tensor=copied_tensor)
|
||||
|
@ -124,6 +124,23 @@ class SaveTest(test.TestCase):
|
||||
root = util.Checkpoint(model=sequential.Sequential([core.Dense(2)]))
|
||||
save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))
|
||||
|
||||
def test_captured_symbolic_tensor_exception(self):
|
||||
root = module.Module()
|
||||
symbolic_tensor = []
|
||||
|
||||
@def_function.function
|
||||
def captured_intermediate(x):
|
||||
symbolic_tensor.append(math_ops.add(x, x, name="a_tensor"))
|
||||
return symbolic_tensor[-1] * 2
|
||||
|
||||
captured_intermediate(constant_op.constant(1.))
|
||||
|
||||
root.f = def_function.function(lambda: symbolic_tensor[-1],
|
||||
input_signature=[])
|
||||
with self.assertRaisesRegexp(ValueError, "a_tensor"):
|
||||
save.save(root, os.path.join(self.get_temp_dir(), "saved_model"),
|
||||
signatures=root.f)
|
||||
|
||||
def test_version_information_included(self):
|
||||
root = tracking.AutoTrackable()
|
||||
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
|
Loading…
Reference in New Issue
Block a user