Reduce Functional.__call__ Python overhead by ~5-10%

PiperOrigin-RevId: 311775071
Change-Id: I45dd0a1ce865d6c17f7b5e292799348e1e17a91c
This commit is contained in:
Thomas O'Malley 2020-05-15 12:00:16 -07:00 committed by TensorFlower Gardener
parent 75132b735b
commit b1fc80f4a1
2 changed files with 19 additions and 31 deletions

View File

@ -469,11 +469,11 @@ class Functional(training_lib.Model):
mask: (Optional) Tensor or nested structure of Tensors.
Returns:
Two lists: output_tensors, output_masks
output_tensors
"""
inputs = self._flatten_to_reference_inputs(inputs)
if mask is None:
masks = [None for _ in range(len(inputs))]
masks = [None] * len(inputs)
else:
masks = self._flatten_to_reference_inputs(mask)
for input_t, mask in zip(inputs, masks):
@ -481,55 +481,39 @@ class Functional(training_lib.Model):
# Dictionary mapping reference tensors to computed tensors.
tensor_dict = {}
tensor_usage_count = self._tensor_usage_count
for x, y in zip(self.inputs, inputs):
y = self._conform_to_reference_input(y, ref_input=x)
x_id = str(id(x))
tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
depth_keys = list(self._nodes_by_depth.keys())
nodes_by_depth = self._nodes_by_depth
depth_keys = list(nodes_by_depth.keys())
depth_keys.sort(reverse=True)
for depth in depth_keys:
nodes = self._nodes_by_depth[depth]
nodes = nodes_by_depth[depth]
for node in nodes:
if node.is_input:
continue # Input tensors already exist.
if not all(
str(id(tensor)) in tensor_dict
for tensor in nest.flatten(node.keras_inputs)):
if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
continue # Node is not computable, try skipping.
layer = node.layer
args, kwargs = node.map_arguments(tensor_dict)
outputs = layer(*args, **kwargs)
outputs = node.layer(*args, **kwargs)
# Update tensor_dict.
for x, y in zip(nest.flatten(node.outputs), nest.flatten(outputs)):
x_id = str(id(x))
tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
for x_id, y in zip(node.flat_output_ids, nest.flatten(outputs)):
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
output_tensors = []
output_shapes = []
for x in self.outputs:
assert str(id(x)) in tensor_dict, 'Could not compute output ' + str(x)
tensor = tensor_dict[str(id(x))].pop()
output_shapes.append(x.shape)
output_tensors.append(tensor)
x_id = str(id(x))
assert x_id in tensor_dict, 'Could not compute output ' + str(x)
output_tensors.append(tensor_dict[x_id].pop())
if output_shapes is not None:
input_shapes = [x.shape for x in inputs]
try:
cache_key = tuple(tf_utils.convert_shapes(input_shapes, to_tuples=True))
self._output_shape_cache[cache_key] = nest.pack_sequence_as(
self._nested_outputs, output_shapes)
except ValueError:
# If any input has an unknown TensorShape (e.g. a sparse tensor input),
# skip caching, since the shape cannot be converted into a cache key.
pass
output_tensors = nest.pack_sequence_as(self._nested_outputs, output_tensors)
return output_tensors
return nest.pack_sequence_as(self._nested_outputs, output_tensors)
def _flatten_to_reference_inputs(self, tensors):
"""Maps `tensors` to their respective `keras.Input`."""

View File

@ -102,6 +102,10 @@ class Node(object):
tensor._keras_history = KerasHistory(
layer=layer, node_index=node_index, tensor_index=i)
# Cached for performance.
self.flat_input_ids = [str(id(t)) for t in self._keras_inputs]
self.flat_output_ids = [str(id(t)) for t in nest.flatten(self.outputs)]
@property
def keras_inputs(self):
"""Tensors input to this node that can be traced back to a `keras.Input`."""