Reduce Functional.__call__ Python overhead by ~5-10%
PiperOrigin-RevId: 311775071 Change-Id: I45dd0a1ce865d6c17f7b5e292799348e1e17a91c
This commit is contained in:
parent
75132b735b
commit
b1fc80f4a1
@ -469,11 +469,11 @@ class Functional(training_lib.Model):
|
||||
mask: (Optional) Tensor or nested structure of Tensors.
|
||||
|
||||
Returns:
|
||||
Two lists: output_tensors, output_masks
|
||||
output_tensors
|
||||
"""
|
||||
inputs = self._flatten_to_reference_inputs(inputs)
|
||||
if mask is None:
|
||||
masks = [None for _ in range(len(inputs))]
|
||||
masks = [None] * len(inputs)
|
||||
else:
|
||||
masks = self._flatten_to_reference_inputs(mask)
|
||||
for input_t, mask in zip(inputs, masks):
|
||||
@ -481,55 +481,39 @@ class Functional(training_lib.Model):
|
||||
|
||||
# Dictionary mapping reference tensors to computed tensors.
|
||||
tensor_dict = {}
|
||||
tensor_usage_count = self._tensor_usage_count
|
||||
for x, y in zip(self.inputs, inputs):
|
||||
y = self._conform_to_reference_input(y, ref_input=x)
|
||||
x_id = str(id(x))
|
||||
tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
|
||||
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
||||
|
||||
depth_keys = list(self._nodes_by_depth.keys())
|
||||
nodes_by_depth = self._nodes_by_depth
|
||||
depth_keys = list(nodes_by_depth.keys())
|
||||
depth_keys.sort(reverse=True)
|
||||
|
||||
for depth in depth_keys:
|
||||
nodes = self._nodes_by_depth[depth]
|
||||
nodes = nodes_by_depth[depth]
|
||||
for node in nodes:
|
||||
if node.is_input:
|
||||
continue # Input tensors already exist.
|
||||
|
||||
if not all(
|
||||
str(id(tensor)) in tensor_dict
|
||||
for tensor in nest.flatten(node.keras_inputs)):
|
||||
if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
|
||||
continue # Node is not computable, try skipping.
|
||||
|
||||
layer = node.layer
|
||||
args, kwargs = node.map_arguments(tensor_dict)
|
||||
outputs = layer(*args, **kwargs)
|
||||
outputs = node.layer(*args, **kwargs)
|
||||
|
||||
# Update tensor_dict.
|
||||
for x, y in zip(nest.flatten(node.outputs), nest.flatten(outputs)):
|
||||
x_id = str(id(x))
|
||||
tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
|
||||
for x_id, y in zip(node.flat_output_ids, nest.flatten(outputs)):
|
||||
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
||||
|
||||
output_tensors = []
|
||||
output_shapes = []
|
||||
for x in self.outputs:
|
||||
assert str(id(x)) in tensor_dict, 'Could not compute output ' + str(x)
|
||||
tensor = tensor_dict[str(id(x))].pop()
|
||||
output_shapes.append(x.shape)
|
||||
output_tensors.append(tensor)
|
||||
x_id = str(id(x))
|
||||
assert x_id in tensor_dict, 'Could not compute output ' + str(x)
|
||||
output_tensors.append(tensor_dict[x_id].pop())
|
||||
|
||||
if output_shapes is not None:
|
||||
input_shapes = [x.shape for x in inputs]
|
||||
try:
|
||||
cache_key = tuple(tf_utils.convert_shapes(input_shapes, to_tuples=True))
|
||||
self._output_shape_cache[cache_key] = nest.pack_sequence_as(
|
||||
self._nested_outputs, output_shapes)
|
||||
except ValueError:
|
||||
# In case there are unknown TensorShape, eg for sparse tensor input,
|
||||
# We skip the caching since the shape is unknown.
|
||||
pass
|
||||
|
||||
output_tensors = nest.pack_sequence_as(self._nested_outputs, output_tensors)
|
||||
return output_tensors
|
||||
return nest.pack_sequence_as(self._nested_outputs, output_tensors)
|
||||
|
||||
def _flatten_to_reference_inputs(self, tensors):
|
||||
"""Maps `tensors` to their respective `keras.Input`."""
|
||||
|
@ -102,6 +102,10 @@ class Node(object):
|
||||
tensor._keras_history = KerasHistory(
|
||||
layer=layer, node_index=node_index, tensor_index=i)
|
||||
|
||||
# Cached for performance.
|
||||
self.flat_input_ids = [str(id(t)) for t in self._keras_inputs]
|
||||
self.flat_output_ids = [str(id(t)) for t in nest.flatten(self.outputs)]
|
||||
|
||||
@property
|
||||
def keras_inputs(self):
|
||||
"""Tensors input to this node that can be traced back to a `keras.Input`."""
|
||||
|
Loading…
Reference in New Issue
Block a user