From 9be4774701e1d5032e0831ac82afc143ae1251f7 Mon Sep 17 00:00:00 2001 From: Jonathan Chu Date: Mon, 27 Jul 2020 23:11:03 +0000 Subject: [PATCH 1/4] Cache hashable input signature for _cache_key --- tensorflow/python/eager/function.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index f86e2889f3d..6662105cbeb 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -2902,6 +2902,8 @@ class Function(object): self._function_attributes = attributes or {} self._capture_by_value = capture_by_value self.tracing_count = 0 + self._hashable_input_signature = _make_input_signature_hashable( + self.flat_input_signature) self._lock = threading.Lock() # _descriptor_cache is a of instance of a class to an instance-specific @@ -2940,6 +2942,11 @@ class Function(object): """Returns the flattened input signature.""" return self._function_spec.flat_input_signature + @property + def hashable_input_signature(self): + """Returns a cached hashable object for the flattened input signature.""" + return self._hashable_input_signature + def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs): """Returns a concrete function which cleans up its graph function.""" if self.input_signature: @@ -3072,10 +3079,11 @@ class Function(object): inputs = (args, kwargs) if kwargs else args input_signature = pywrap_tfe.TFE_Py_EncodeArg(inputs, include_tensor_ranks_only) + hashable_input_signature = _make_input_signature_hashable(input_signature) else: del args, kwargs assert not include_tensor_ranks_only - input_signature = self.flat_input_signature + hashable_input_signature = self.hashable_input_signature ctx = context.context() @@ -3144,9 +3152,9 @@ class Function(object): save_context.get_save_options().experimental_variable_policy) else: variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES - + return CacheKey( - _make_input_signature_hashable(input_signature), parent_graph, + hashable_input_signature, parent_graph, device_functions, colocation_stack, in_cross_replica_context, variable_policy, xla_context_id) From 9cc5633717b2b4258c0129a6f6ce430e4dceef77 Mon Sep 17 00:00:00 2001 From: Jonathan Chu Date: Tue, 28 Jul 2020 22:04:14 +0000 Subject: [PATCH 2/4] Remove public hashable_input_signature property, add conditional in initialization --- tensorflow/python/eager/function.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 6662105cbeb..0f1912a5efd 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -2902,8 +2902,9 @@ class Function(object): self._function_attributes = attributes or {} self._capture_by_value = capture_by_value self.tracing_count = 0 - self._hashable_input_signature = _make_input_signature_hashable( - self.flat_input_signature) + if self.input_signature is not None: + self._hashable_input_signature = _make_input_signature_hashable( + self.flat_input_signature) self._lock = threading.Lock() # _descriptor_cache is a of instance of a class to an instance-specific @@ -2942,11 +2943,6 @@ class Function(object): """Returns the flattened input signature.""" return self._function_spec.flat_input_signature - @property - def hashable_input_signature(self): - """Returns a cached hashable object for the flattened input signature.""" - return self._hashable_input_signature - def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs): """Returns a concrete function which cleans up its graph function.""" if self.input_signature: @@ -3083,7 +3079,8 @@ class Function(object): else: del args, kwargs assert not include_tensor_ranks_only - hashable_input_signature = self.hashable_input_signature + assert hasattr(self, '_hashable_input_signature') + hashable_input_signature = self._hashable_input_signature ctx = context.context() From 225d851dad37cb036b37b59da7d6e091e4f08006 Mon Sep 17 00:00:00 2001 From: Jonathan Chu Date: Wed, 29 Jul 2020 18:34:50 +0000 Subject: [PATCH 3/4] Remove extraneous assert --- tensorflow/python/eager/function.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 0f1912a5efd..0c06e0425cd 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -3079,7 +3079,6 @@ class Function(object): else: del args, kwargs assert not include_tensor_ranks_only - assert hasattr(self, '_hashable_input_signature') hashable_input_signature = self._hashable_input_signature ctx = context.context() From 20939a4b51e72fd8e4263e8cc79f01f63a544f62 Mon Sep 17 00:00:00 2001 From: Jonathan Chu Date: Wed, 29 Jul 2020 19:27:03 +0000 Subject: [PATCH 4/4] Remove whitespace --- tensorflow/python/eager/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 0c06e0425cd..40df3e33e27 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -3148,7 +3148,7 @@ class Function(object): save_context.get_save_options().experimental_variable_policy) else: variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES - + return CacheKey( hashable_input_signature, parent_graph, device_functions, colocation_stack, in_cross_replica_context,