Ensure serialized concrete functions are saved deterministically.
Previously, a cycle of saving and loading would randomly change which concrete function was chosen given a set of input arguments. PiperOrigin-RevId: 356844182 Change-Id: I5e2b3ee87866c62b619bde7b2df95d1f4f22d82c
This commit is contained in:
parent
779e443222
commit
7bcef3c792
@ -2898,8 +2898,17 @@ class FunctionCache(object):
|
|||||||
_FunctionGarbageCollector(self.arg_relaxed_specs)]
|
_FunctionGarbageCollector(self.arg_relaxed_specs)]
|
||||||
|
|
||||||
def all_values(self):
|
def all_values(self):
|
||||||
"""A set of all `ConcreteFunction` instances held by this cache."""
|
"""A list of all `ConcreteFunction` instances held by this cache."""
|
||||||
return set(self.primary.values()) | set(self.arg_relaxed.values())
|
# We need to simultaneously make sure our returned concrete functions are
|
||||||
|
# unique *and* make sure they are returned in a deterministic order for
|
||||||
|
# serialization.
|
||||||
|
#
|
||||||
|
# TODO(b/174215821): It's likely that we ultimately would just prefer to
|
||||||
|
# choose the most specific concrete function shape given a set of
|
||||||
|
# arguments. If and when that is implemented, this logic can be revisited.
|
||||||
|
primary_functions = set(self.primary.values())
|
||||||
|
return list(self.primary.values()) + [
|
||||||
|
v for v in self.arg_relaxed.values() if v not in primary_functions]
|
||||||
|
|
||||||
|
|
||||||
class Function(object):
|
class Function(object):
|
||||||
|
@ -2169,5 +2169,32 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
|||||||
finally:
|
finally:
|
||||||
def_function.run_functions_eagerly(False)
|
def_function.run_functions_eagerly(False)
|
||||||
|
|
||||||
|
def test_restored_model_concrete_function_is_deterministic(self):
|
||||||
|
previous_concrete_function = None
|
||||||
|
for _ in range(100):
|
||||||
|
|
||||||
|
class MyModel(module.Module):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def __call__(self, x):
|
||||||
|
return x * constant_op.constant(3.0)
|
||||||
|
|
||||||
|
model = MyModel()
|
||||||
|
model(array_ops.ones((7, 3), dtype=dtypes.float32))
|
||||||
|
model.__call__.get_concrete_function(
|
||||||
|
tensor_spec.TensorSpec([None, 3], dtypes.float32))
|
||||||
|
loaded = cycle(model, 1)
|
||||||
|
|
||||||
|
# Ensure the newly loaded concrete function is the same as the previous
|
||||||
|
# after a cycle of serialization / deserialization.
|
||||||
|
new_concrete_function = loaded.__call__.get_concrete_function(
|
||||||
|
tensor_spec.TensorSpec([None, 3], dtypes.float32))
|
||||||
|
if previous_concrete_function is not None:
|
||||||
|
self.assertEqual(previous_concrete_function.pretty_printed_signature(),
|
||||||
|
new_concrete_function.pretty_printed_signature())
|
||||||
|
|
||||||
|
previous_concrete_function = new_concrete_function
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user