Merge pull request #41791 from jonathanchu33:cache-input-signature
PiperOrigin-RevId: 324042170 Change-Id: I441dd72dc5f583c6d9e30b98b19d8276f0bd4ed8
This commit is contained in:
commit
cdbd96f307
@ -2902,6 +2902,9 @@ class Function(object):
|
|||||||
self._function_attributes = attributes or {}
|
self._function_attributes = attributes or {}
|
||||||
self._capture_by_value = capture_by_value
|
self._capture_by_value = capture_by_value
|
||||||
self.tracing_count = 0
|
self.tracing_count = 0
|
||||||
|
if self.input_signature is not None:
|
||||||
|
self._hashable_input_signature = _make_input_signature_hashable(
|
||||||
|
self.flat_input_signature)
|
||||||
|
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
# _descriptor_cache is a of instance of a class to an instance-specific
|
# _descriptor_cache is a of instance of a class to an instance-specific
|
||||||
@ -3072,10 +3075,11 @@ class Function(object):
|
|||||||
inputs = (args, kwargs) if kwargs else args
|
inputs = (args, kwargs) if kwargs else args
|
||||||
input_signature = pywrap_tfe.TFE_Py_EncodeArg(inputs,
|
input_signature = pywrap_tfe.TFE_Py_EncodeArg(inputs,
|
||||||
include_tensor_ranks_only)
|
include_tensor_ranks_only)
|
||||||
|
hashable_input_signature = _make_input_signature_hashable(input_signature)
|
||||||
else:
|
else:
|
||||||
del args, kwargs
|
del args, kwargs
|
||||||
assert not include_tensor_ranks_only
|
assert not include_tensor_ranks_only
|
||||||
input_signature = self.flat_input_signature
|
hashable_input_signature = self._hashable_input_signature
|
||||||
|
|
||||||
ctx = context.context()
|
ctx = context.context()
|
||||||
|
|
||||||
@ -3145,10 +3149,9 @@ class Function(object):
|
|||||||
else:
|
else:
|
||||||
variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
|
variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
|
||||||
|
|
||||||
return CacheKey(
|
return CacheKey(hashable_input_signature, parent_graph, device_functions,
|
||||||
_make_input_signature_hashable(input_signature), parent_graph,
|
colocation_stack, in_cross_replica_context, variable_policy,
|
||||||
device_functions, colocation_stack, in_cross_replica_context,
|
xla_context_id)
|
||||||
variable_policy, xla_context_id)
|
|
||||||
|
|
||||||
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
|
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
|
||||||
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
|
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user