Ensure serialized concrete functions are saved deterministically.
Previously, a cycle of saving and loading would randomly change which concrete function was chosen given a set of input arguments. PiperOrigin-RevId: 356844182 Change-Id: I5e2b3ee87866c62b619bde7b2df95d1f4f22d82c
This commit is contained in:
parent
779e443222
commit
7bcef3c792
@ -2898,8 +2898,17 @@ class FunctionCache(object):
|
||||
_FunctionGarbageCollector(self.arg_relaxed_specs)]
|
||||
|
||||
def all_values(self):
|
||||
"""A set of all `ConcreteFunction` instances held by this cache."""
|
||||
return set(self.primary.values()) | set(self.arg_relaxed.values())
|
||||
"""A list of all `ConcreteFunction` instances held by this cache."""
|
||||
# We need to simultaneously make sure our returned concrete functions are
|
||||
# unique *and* make sure they are returned in a deterministic order for
|
||||
# serialization.
|
||||
#
|
||||
# TODO(b/174215821): It's likely that we ultimately would just prefer to
|
||||
# choose the most specific concrete function shape given a set of
|
||||
# arguments. If and when that is implemented, this logic can be revisited.
|
||||
primary_functions = set(self.primary.values())
|
||||
return list(self.primary.values()) + [
|
||||
v for v in self.arg_relaxed.values() if v not in primary_functions]
|
||||
|
||||
|
||||
class Function(object):
|
||||
|
@ -2169,5 +2169,32 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||
finally:
|
||||
def_function.run_functions_eagerly(False)
|
||||
|
||||
def test_restored_model_concrete_function_is_deterministic(self):
|
||||
previous_concrete_function = None
|
||||
for _ in range(100):
|
||||
|
||||
class MyModel(module.Module):
|
||||
|
||||
@def_function.function
|
||||
def __call__(self, x):
|
||||
return x * constant_op.constant(3.0)
|
||||
|
||||
model = MyModel()
|
||||
model(array_ops.ones((7, 3), dtype=dtypes.float32))
|
||||
model.__call__.get_concrete_function(
|
||||
tensor_spec.TensorSpec([None, 3], dtypes.float32))
|
||||
loaded = cycle(model, 1)
|
||||
|
||||
# Ensure the newly loaded concrete function is the same as the previous
|
||||
# after a cycle of serialization / deserialization.
|
||||
new_concrete_function = loaded.__call__.get_concrete_function(
|
||||
tensor_spec.TensorSpec([None, 3], dtypes.float32))
|
||||
if previous_concrete_function is not None:
|
||||
self.assertEqual(previous_concrete_function.pretty_printed_signature(),
|
||||
new_concrete_function.pretty_printed_signature())
|
||||
|
||||
previous_concrete_function = new_concrete_function
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user