Relax limitations on rerouting graph outputs.

- Allow output_tensors with multiple consumers in fold_batch_norms.
- Allow duplicate consumers in quantize.
- Also fix an issue with matching final layers that have batch norm.

PiperOrigin-RevId: 190873003
Suharsh Sivakumar 2018-03-28 19:21:08 -07:00 committed by TensorFlower Gardener
parent 2b41d75654
commit a5a90e6b55
2 changed files with 15 additions and 9 deletions

fold_batch_norms.py

@@ -134,9 +134,9 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
       nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
                                                      match.output_tensor)
-      if nodes_modified_count != 1:
-        raise ValueError(
-            'Unexpected inputs to op: %s' % match.output_tensor.name)
+      if nodes_modified_count == 0:
+        raise ValueError('Folding batch norms failed, %s had no outputs.' %
+                         match.output_tensor.name)
 
 
 def _FindFusedBatchNorms(graph):
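
The relaxed check reflects that match.output_tensor may feed any number of downstream ops. The following is a minimal, hedged sketch (TF 1.x contrib APIs; the toy graph and names are invented for illustration, not taken from the commit) of why the old '!= 1' test was too strict: when the replaced tensor has two consumers, reroute_ts rewires two input edges and returns 2.

import tensorflow as tf
from tensorflow.contrib import graph_editor

g = tf.Graph()
with g.as_default():
  x = tf.placeholder(tf.float32, [4], name='x')
  old_out = tf.identity(x, name='old_out')     # stands in for match.output_tensor
  new_out = tf.identity(x, name='folded_out')  # stands in for bias_add_tensor
  a = tf.nn.relu(old_out, name='consumer_a')   # two downstream consumers of old_out
  b = tf.negative(old_out, name='consumer_b')

# Rewire every consumer of old_out to read new_out instead; the return value
# counts rewired input edges, so it should be 2 here rather than exactly 1.
count = graph_editor.reroute_ts(new_out, old_out)
print(count)  # only a count of 0 now signals that folding found nothing to rewire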

quantize.py

@@ -305,7 +305,8 @@ def _FindLayersToQuantize(graph):
   # the output of the final BiasAdd must be quantized. So we treat the BiasAdd
   # as the 'activation_op' in the _LayerMatch, to ensure that it's output is
   # quantized.
-  final_layer_matcher = graph_matcher.GraphMatcher(bias_add_pattern)
+  final_layer_matcher = graph_matcher.GraphMatcher(
+      graph_matcher.OneofPattern([bias_add_pattern, folded_bias_add_pattern]))
   for match_result in final_layer_matcher.match_graph(graph):
     layer_op = match_result.get_op(layer_pattern)
     weight_tensor = match_result.get_tensor(weight_identity_pattern)
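
The OneofPattern change is what fixes final-layer matching when batch norm is present: the last layer's output op is an ordinary BiasAdd in the plain case but an Add once batch norms have been folded. A rough sketch of the idea, assuming the internal graph_matcher module from contrib/quantize (the 'Add' pattern below is a simplified stand-in for the real folded_bias_add_pattern):

from tensorflow.contrib.quantize.python import graph_matcher

bias_add_pattern = graph_matcher.OpTypePattern('BiasAdd')
folded_bias_add_pattern = graph_matcher.OpTypePattern('Add')  # simplified stand-in
final_layer_matcher = graph_matcher.GraphMatcher(
    graph_matcher.OneofPattern([bias_add_pattern, folded_bias_add_pattern]))

# match_graph(graph) yields a MatchResult per op that matches either branch,
# so final layers with and without folded batch norm both get quantized outputs.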
@@ -463,11 +464,16 @@ def _InsertQuantOp(context,
         lambda: inputs,
         name=name_prefix + '/delayed_quant')
 
-  nodes_modified_count = graph_editor.reroute_ts(
-      [quant], [inputs], can_modify=consumers)
-  if nodes_modified_count != len(consumers):
-    raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join(
-        [consumer.name for consumer in consumers]))
+  if consumers:
+    tensors_modified_count = graph_editor.reroute_ts(
+        [quant], [inputs], can_modify=consumers)
+    # Some operations can have multiple output tensors going to the same
+    # consumer. Since consumers is a set, we need to ensure that
+    # tensors_modified_count is greater than or equal to the length of the set
+    # of consumers.
+    if tensors_modified_count < len(consumers):
+      raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
+          [consumer.name for consumer in consumers]))
 
 
 def _GetContextFromOp(op):
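
A hedged illustration of the duplicate-consumer case that the new '<' comparison tolerates (again TF 1.x contrib APIs; the graph below is invented): a single consumer op can read the rerouted tensor through more than one input, so the per-edge modification count legitimately exceeds the size of the consumer set.

import tensorflow as tf
from tensorflow.contrib import graph_editor

g = tf.Graph()
with g.as_default():
  inputs = tf.placeholder(tf.float32, [4], name='inputs')
  quant = tf.identity(inputs, name='fake_quant')     # stands in for the quant tensor
  out = tf.add(inputs, inputs, name='dup_consumer')  # reads `inputs` twice

# Exclude the op that produces `quant` so its own input is not rerouted
# (illustrative cycle guard; consumers is a set of ops, as in quantize.py).
consumers = set(inputs.consumers()) - {quant.op}

tensors_modified_count = graph_editor.reroute_ts(
    [quant], [inputs], can_modify=consumers)
# Two input edges of dup_consumer are rewired while the consumer set has one
# element, so an equality check against len(consumers) would spuriously fail.
print(tensors_modified_count, len(consumers))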