Reduce Functional.__call__ Python overhead by ~5-10%

PiperOrigin-RevId: 311775071
Change-Id: I45dd0a1ce865d6c17f7b5e292799348e1e17a91c
This commit is contained in:
Thomas O'Malley 2020-05-15 12:00:16 -07:00 committed by TensorFlower Gardener
parent 75132b735b
commit b1fc80f4a1
2 changed files with 19 additions and 31 deletions

View File

@ -469,11 +469,11 @@ class Functional(training_lib.Model):
mask: (Optional) Tensor or nested structure of Tensors. mask: (Optional) Tensor or nested structure of Tensors.
Returns: Returns:
Two lists: output_tensors, output_masks output_tensors
""" """
inputs = self._flatten_to_reference_inputs(inputs) inputs = self._flatten_to_reference_inputs(inputs)
if mask is None: if mask is None:
masks = [None for _ in range(len(inputs))] masks = [None] * len(inputs)
else: else:
masks = self._flatten_to_reference_inputs(mask) masks = self._flatten_to_reference_inputs(mask)
for input_t, mask in zip(inputs, masks): for input_t, mask in zip(inputs, masks):
@ -481,55 +481,39 @@ class Functional(training_lib.Model):
# Dictionary mapping reference tensors to computed tensors. # Dictionary mapping reference tensors to computed tensors.
tensor_dict = {} tensor_dict = {}
tensor_usage_count = self._tensor_usage_count
for x, y in zip(self.inputs, inputs): for x, y in zip(self.inputs, inputs):
y = self._conform_to_reference_input(y, ref_input=x) y = self._conform_to_reference_input(y, ref_input=x)
x_id = str(id(x)) x_id = str(id(x))
tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id] tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
depth_keys = list(self._nodes_by_depth.keys()) nodes_by_depth = self._nodes_by_depth
depth_keys = list(nodes_by_depth.keys())
depth_keys.sort(reverse=True) depth_keys.sort(reverse=True)
for depth in depth_keys: for depth in depth_keys:
nodes = self._nodes_by_depth[depth] nodes = nodes_by_depth[depth]
for node in nodes: for node in nodes:
if node.is_input: if node.is_input:
continue # Input tensors already exist. continue # Input tensors already exist.
if not all( if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
str(id(tensor)) in tensor_dict
for tensor in nest.flatten(node.keras_inputs)):
continue # Node is not computable, try skipping. continue # Node is not computable, try skipping.
layer = node.layer
args, kwargs = node.map_arguments(tensor_dict) args, kwargs = node.map_arguments(tensor_dict)
outputs = layer(*args, **kwargs) outputs = node.layer(*args, **kwargs)
# Update tensor_dict. # Update tensor_dict.
for x, y in zip(nest.flatten(node.outputs), nest.flatten(outputs)): for x_id, y in zip(node.flat_output_ids, nest.flatten(outputs)):
x_id = str(id(x)) tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
output_tensors = [] output_tensors = []
output_shapes = []
for x in self.outputs: for x in self.outputs:
assert str(id(x)) in tensor_dict, 'Could not compute output ' + str(x) x_id = str(id(x))
tensor = tensor_dict[str(id(x))].pop() assert x_id in tensor_dict, 'Could not compute output ' + str(x)
output_shapes.append(x.shape) output_tensors.append(tensor_dict[x_id].pop())
output_tensors.append(tensor)
if output_shapes is not None: return nest.pack_sequence_as(self._nested_outputs, output_tensors)
input_shapes = [x.shape for x in inputs]
try:
cache_key = tuple(tf_utils.convert_shapes(input_shapes, to_tuples=True))
self._output_shape_cache[cache_key] = nest.pack_sequence_as(
self._nested_outputs, output_shapes)
except ValueError:
# In case there are unknown TensorShape, eg for sparse tensor input,
# We skip the caching since the shape is unknown.
pass
output_tensors = nest.pack_sequence_as(self._nested_outputs, output_tensors)
return output_tensors
def _flatten_to_reference_inputs(self, tensors): def _flatten_to_reference_inputs(self, tensors):
"""Maps `tensors` to their respective `keras.Input`.""" """Maps `tensors` to their respective `keras.Input`."""

View File

@ -102,6 +102,10 @@ class Node(object):
tensor._keras_history = KerasHistory( tensor._keras_history = KerasHistory(
layer=layer, node_index=node_index, tensor_index=i) layer=layer, node_index=node_index, tensor_index=i)
# Cached for performance.
self.flat_input_ids = [str(id(t)) for t in self._keras_inputs]
self.flat_output_ids = [str(id(t)) for t in nest.flatten(self.outputs)]
@property @property
def keras_inputs(self): def keras_inputs(self):
"""Tensors input to this node that can be traced back to a `keras.Input`.""" """Tensors input to this node that can be traced back to a `keras.Input`."""