Improve the error message when a saved function captures a symbolic Tensor without a constant value.
PiperOrigin-RevId: 253624252
This commit is contained in:
parent
d9a37895d6
commit
48bd9b6b56
@ -262,8 +262,13 @@ class _SaveableView(object):
|
|||||||
if (tensor_util.is_tensor(capture)
|
if (tensor_util.is_tensor(capture)
|
||||||
and capture.dtype not in _UNCOPIABLE_DTYPES
|
and capture.dtype not in _UNCOPIABLE_DTYPES
|
||||||
and capture not in self.captured_tensor_node_ids):
|
and capture not in self.captured_tensor_node_ids):
|
||||||
copied_tensor = constant_op.constant(
|
capture_constant_value = tensor_util.constant_value(capture)
|
||||||
tensor_util.constant_value(capture))
|
if capture_constant_value is None:
|
||||||
|
raise ValueError(
|
||||||
|
("Attempted to save a function {} which references a symbolic "
|
||||||
|
"Tensor {} that is not a simple constant. This is not "
|
||||||
|
"supported.").format(concrete_function.name, capture))
|
||||||
|
copied_tensor = constant_op.constant(capture_constant_value)
|
||||||
node_id = len(self.nodes)
|
node_id = len(self.nodes)
|
||||||
node = _CapturedConstant(
|
node = _CapturedConstant(
|
||||||
eager_tensor=capture, graph_tensor=copied_tensor)
|
eager_tensor=capture, graph_tensor=copied_tensor)
|
||||||
|
@ -124,6 +124,23 @@ class SaveTest(test.TestCase):
|
|||||||
root = util.Checkpoint(model=sequential.Sequential([core.Dense(2)]))
|
root = util.Checkpoint(model=sequential.Sequential([core.Dense(2)]))
|
||||||
save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))
|
save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))
|
||||||
|
|
||||||
|
def test_captured_symbolic_tensor_exception(self):
|
||||||
|
root = module.Module()
|
||||||
|
symbolic_tensor = []
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def captured_intermediate(x):
|
||||||
|
symbolic_tensor.append(math_ops.add(x, x, name="a_tensor"))
|
||||||
|
return symbolic_tensor[-1] * 2
|
||||||
|
|
||||||
|
captured_intermediate(constant_op.constant(1.))
|
||||||
|
|
||||||
|
root.f = def_function.function(lambda: symbolic_tensor[-1],
|
||||||
|
input_signature=[])
|
||||||
|
with self.assertRaisesRegexp(ValueError, "a_tensor"):
|
||||||
|
save.save(root, os.path.join(self.get_temp_dir(), "saved_model"),
|
||||||
|
signatures=root.f)
|
||||||
|
|
||||||
def test_version_information_included(self):
|
def test_version_information_included(self):
|
||||||
root = tracking.AutoTrackable()
|
root = tracking.AutoTrackable()
|
||||||
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||||
|
Loading…
Reference in New Issue
Block a user