Merge pull request #41791 from jonathanchu33:cache-input-signature

PiperOrigin-RevId: 324042170
Change-Id: I441dd72dc5f583c6d9e30b98b19d8276f0bd4ed8
This commit is contained in:
TensorFlower Gardener 2020-07-30 11:23:17 -07:00
commit cdbd96f307

View File

@ -2902,6 +2902,9 @@ class Function(object):
self._function_attributes = attributes or {}
self._capture_by_value = capture_by_value
self.tracing_count = 0
if self.input_signature is not None:
self._hashable_input_signature = _make_input_signature_hashable(
self.flat_input_signature)
self._lock = threading.Lock()
# _descriptor_cache is a of instance of a class to an instance-specific
@ -3072,10 +3075,11 @@ class Function(object):
inputs = (args, kwargs) if kwargs else args
input_signature = pywrap_tfe.TFE_Py_EncodeArg(inputs,
include_tensor_ranks_only)
hashable_input_signature = _make_input_signature_hashable(input_signature)
else:
del args, kwargs
assert not include_tensor_ranks_only
input_signature = self.flat_input_signature
hashable_input_signature = self._hashable_input_signature
ctx = context.context()
@ -3145,10 +3149,9 @@ class Function(object):
else:
variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
return CacheKey(
_make_input_signature_hashable(input_signature), parent_graph,
device_functions, colocation_stack, in_cross_replica_context,
variable_policy, xla_context_id)
return CacheKey(hashable_input_signature, parent_graph, device_functions,
colocation_stack, in_cross_replica_context, variable_policy,
xla_context_id)
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
"""Create a `ConcreteFunction` from `args` and `kwargs`."""