diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 0b345210d4f..64e3b70a782 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -2898,8 +2898,17 @@ class FunctionCache(object): _FunctionGarbageCollector(self.arg_relaxed_specs)] def all_values(self): - """A set of all `ConcreteFunction` instances held by this cache.""" - return set(self.primary.values()) | set(self.arg_relaxed.values()) + """A list of all `ConcreteFunction` instances held by this cache.""" + # We need to simultaneously make sure our returned concrete functions are + # unique *and* make sure they are returned in a deterministic order for + # serialization. + # + # TODO(b/174215821): It's likely that we ultimately would just prefer to + # choose the most specific concrete function shape given a set of + # arguments. If and when that is implemented, this logic can be revisited. + primary_functions = set(self.primary.values()) + return list(self.primary.values()) + [ + v for v in self.arg_relaxed.values() if v not in primary_functions] class Function(object): diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index ecdc53d2572..91211fff3bd 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -2169,5 +2169,32 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase): finally: def_function.run_functions_eagerly(False) + def test_restored_model_concrete_function_is_deterministic(self): + previous_concrete_function = None + for _ in range(100): + + class MyModel(module.Module): + + @def_function.function + def __call__(self, x): + return x * constant_op.constant(3.0) + + model = MyModel() + model(array_ops.ones((7, 3), dtype=dtypes.float32)) + model.__call__.get_concrete_function( + tensor_spec.TensorSpec([None, 3], dtypes.float32)) + loaded = cycle(model, 1) + + # Ensure the newly loaded concrete function is the same as the previous + # after a cycle of serialization / deserialization. + new_concrete_function = loaded.__call__.get_concrete_function( + tensor_spec.TensorSpec([None, 3], dtypes.float32)) + if previous_concrete_function is not None: + self.assertEqual(previous_concrete_function.pretty_printed_signature(), + new_concrete_function.pretty_printed_signature()) + + previous_concrete_function = new_concrete_function + + if __name__ == "__main__": test.main()