Merge pull request #41791 from jonathanchu33:cache-input-signature
PiperOrigin-RevId: 324042170 Change-Id: I441dd72dc5f583c6d9e30b98b19d8276f0bd4ed8
This commit is contained in:
commit
cdbd96f307
@ -2902,6 +2902,9 @@ class Function(object):
|
||||
self._function_attributes = attributes or {}
|
||||
self._capture_by_value = capture_by_value
|
||||
self.tracing_count = 0
|
||||
if self.input_signature is not None:
|
||||
self._hashable_input_signature = _make_input_signature_hashable(
|
||||
self.flat_input_signature)
|
||||
|
||||
self._lock = threading.Lock()
|
||||
# _descriptor_cache is a of instance of a class to an instance-specific
|
||||
@ -3072,10 +3075,11 @@ class Function(object):
|
||||
inputs = (args, kwargs) if kwargs else args
|
||||
input_signature = pywrap_tfe.TFE_Py_EncodeArg(inputs,
|
||||
include_tensor_ranks_only)
|
||||
hashable_input_signature = _make_input_signature_hashable(input_signature)
|
||||
else:
|
||||
del args, kwargs
|
||||
assert not include_tensor_ranks_only
|
||||
input_signature = self.flat_input_signature
|
||||
hashable_input_signature = self._hashable_input_signature
|
||||
|
||||
ctx = context.context()
|
||||
|
||||
@ -3145,10 +3149,9 @@ class Function(object):
|
||||
else:
|
||||
variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
|
||||
|
||||
return CacheKey(
|
||||
_make_input_signature_hashable(input_signature), parent_graph,
|
||||
device_functions, colocation_stack, in_cross_replica_context,
|
||||
variable_policy, xla_context_id)
|
||||
return CacheKey(hashable_input_signature, parent_graph, device_functions,
|
||||
colocation_stack, in_cross_replica_context, variable_policy,
|
||||
xla_context_id)
|
||||
|
||||
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
|
||||
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user