Ensure serialized concrete functions are saved deterministically.

Previously, a cycle of saving and loading would randomly change which concrete
function was chosen given a set of input arguments.

PiperOrigin-RevId: 356844182
Change-Id: I5e2b3ee87866c62b619bde7b2df95d1f4f22d82c
This commit is contained in:
Daniel Ellis 2021-02-10 15:44:29 -08:00 committed by TensorFlower Gardener
parent 779e443222
commit 7bcef3c792
2 changed files with 38 additions and 2 deletions

View File

@ -2898,8 +2898,17 @@ class FunctionCache(object):
_FunctionGarbageCollector(self.arg_relaxed_specs)] _FunctionGarbageCollector(self.arg_relaxed_specs)]
def all_values(self): def all_values(self):
"""A set of all `ConcreteFunction` instances held by this cache.""" """A list of all `ConcreteFunction` instances held by this cache."""
return set(self.primary.values()) | set(self.arg_relaxed.values()) # We need to simultaneously make sure our returned concrete functions are
# unique *and* make sure they are returned in a deterministic order for
# serialization.
#
# TODO(b/174215821): It's likely that we ultimately would just prefer to
# choose the most specific concrete function shape given a set of
# arguments. If and when that is implemented, this logic can be revisited.
primary_functions = set(self.primary.values())
return list(self.primary.values()) + [
v for v in self.arg_relaxed.values() if v not in primary_functions]
class Function(object): class Function(object):

View File

@ -2169,5 +2169,32 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
finally: finally:
def_function.run_functions_eagerly(False) def_function.run_functions_eagerly(False)
def test_restored_model_concrete_function_is_deterministic(self):
previous_concrete_function = None
for _ in range(100):
class MyModel(module.Module):
@def_function.function
def __call__(self, x):
return x * constant_op.constant(3.0)
model = MyModel()
model(array_ops.ones((7, 3), dtype=dtypes.float32))
model.__call__.get_concrete_function(
tensor_spec.TensorSpec([None, 3], dtypes.float32))
loaded = cycle(model, 1)
# Ensure the newly loaded concrete function is the same as the previous
# after a cycle of serialization / deserialization.
new_concrete_function = loaded.__call__.get_concrete_function(
tensor_spec.TensorSpec([None, 3], dtypes.float32))
if previous_concrete_function is not None:
self.assertEqual(previous_concrete_function.pretty_printed_signature(),
new_concrete_function.pretty_printed_signature())
previous_concrete_function = new_concrete_function
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()