Add folding of FusedBatchNormV3.
parent 08c45ed0c5
commit 8c7305bd49
@@ -77,12 +77,15 @@ INPUT_ORDER = {
         "conv_op", "mean_op", "var_op", "beta_op", "gamma_op"
     ],
     # Order of inputs for FusedBatchNorm.
-    "FusedBatchNorm": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"]
+    "FusedBatchNorm": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"],
+    # Order of inputs for FusedBatchNormV3.
+    "FusedBatchNormV3": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"]
 }
 # Name of the attribute epsilon value is stored in.
 EPSILON_ATTR = {
     "BatchNormWithGlobalNormalization": "variance_epsilon",
-    "FusedBatchNorm": "epsilon"
+    "FusedBatchNorm": "epsilon",
+    "FusedBatchNormV3": "epsilon"
 }

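For orientation, the two tables above are consulted later in fold_batch_norms to locate a batch-norm node's inputs and its epsilon attribute; both access patterns appear in the hunks below. A minimal standalone sketch of that lookup, using a hand-built NodeDef rather than a node taken from a real graph:

    from tensorflow.core.framework import node_def_pb2

    INPUT_ORDER = {
        "FusedBatchNormV3": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"]
    }
    EPSILON_ATTR = {"FusedBatchNormV3": "epsilon"}

    # Hand-built stand-in for a node pulled out of a GraphDef.
    node = node_def_pb2.NodeDef()
    node.op = "FusedBatchNormV3"
    node.input.extend(["conv_op", "gamma", "beta", "moving_mean", "moving_var"])
    node.attr["epsilon"].f = 0.001

    conv_input_name = node.input[INPUT_ORDER[node.op].index("conv_op")]
    epsilon = node.attr[EPSILON_ATTR[node.op]].f
    print(conv_input_name, epsilon)  # conv_op 0.001 (modulo float32 rounding)
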
@@ -210,10 +213,10 @@ def fold_batch_norms(input_graph_def):
   addition, rather than the more expensive multiple ops, and even bake the
   scaling into the convolution weights. This function identifies the typical
   pattern of batch normalization subgraphs, and performs the transformation to
-  fold the computations down into a simpler form. It currently only spots batch
-  normalization that's performed by the BatchNormWithGlobalNormalization and
-  FusedBatchNorm ops, and will need to be extended in the future to handle the
-  newer style.
+  fold the computations down into a simpler form. It currently only supports
+  batch normalization that's performed by the BatchNormWithGlobalNormalization,
+  FusedBatchNorm, and FusedBatchNormV3 ops, and will need to be extended in the
+  future to handle the newer style.

   Args:
     input_graph_def: A GraphDef containing a model.
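The docstring describes the folding at a high level; concretely, at inference time a batch norm is a per-channel affine transform, so it can be absorbed into the conv weights plus a single add. A quick numpy check of that identity (the constants mirror the new test added below; this is a sketch of the arithmetic, not of the pass itself):

    import numpy as np

    gamma, beta = np.array([1.0, 2.0]), np.array([0.1, 0.6])
    mean, var, eps = np.array([10.0, 20.0]), np.array([0.25, 0.5]), 0.00001

    scale = gamma / np.sqrt(var + eps)   # folded into the conv weights per channel
    offset = beta - mean * scale         # becomes one bias-style add

    conv_out = np.random.randn(1, 4, 4, 2)   # NHWC conv output, 2 channels
    bn_out = gamma * (conv_out - mean) / np.sqrt(var + eps) + beta
    assert np.allclose(bn_out, conv_out * scale + offset)
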
@@ -234,12 +237,33 @@ def fold_batch_norms(input_graph_def):
   nodes_to_skip = {}
   new_ops = []
   for node in input_graph_def.node:
-    if node.op not in ("BatchNormWithGlobalNormalization", "FusedBatchNorm"):
+    if (node.op not in ("BatchNormWithGlobalNormalization",
+                        "FusedBatchNorm", "FusedBatchNormV3")):
       continue

-    conv_op = node_from_map(input_node_map,
-                            node.input[INPUT_ORDER[node.op].index("conv_op")])
-    if conv_op.op != "Conv2D" and conv_op.op != "DepthwiseConv2dNative":
+    bias = None
+    conv_op = node_from_map(
+        input_node_map,
+        node.input[INPUT_ORDER[node.op].index("conv_op")])
+    # There might be an Add/BiasAdd op between the conv and the batchnorm,
+    # which we can fold into the mean param of the batchnorm.
+    if conv_op.op in ['BiasAdd', 'Add', 'AddV2']:
+      add_op = conv_op
+      # Follow the first input of the add to get to the conv.
+      conv_op = node_from_map(
+          input_node_map, add_op.input[0])
+      bias = node_from_map(input_node_map, add_op.input[1])
+      if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
+        # Follow the second input of the add to get to the conv.
+        conv_op = node_from_map(
+            input_node_map, add_op.input[1])
+        bias = node_from_map(input_node_map, add_op.input[0])
+      if bias and bias.op != 'Const':
+        tf_logging.warning("The bias %s after the conv %s was not a constant. "
+                           "Maybe because freeze_graph wasn't "
+                           "run first?" % (bias.name, conv_op.name))
+        continue
+    if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
       tf_logging.warning("Didn't find expected Conv2D or DepthwiseConv2dNative"
                          " input to '%s'" % node.name)
       continue
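The comment about folding the intermediate Add/BiasAdd "into the mean param of the batchnorm" relies on a small identity: adding a constant bias to the batch-norm input is the same as subtracting that bias from its mean. A numpy sketch with illustrative values (the actual subtraction happens in the next hunk):

    import numpy as np

    x = np.random.randn(1, 4, 4, 2)           # conv output, 2 channels (NHWC)
    bias = np.array([0.5, -1.5])               # the Const fed into the Add/BiasAdd
    gamma, beta = np.array([1.0, 2.0]), np.array([0.1, 0.6])
    mean, var, eps = np.array([10.0, 20.0]), np.array([0.25, 0.5]), 0.00001

    bn = lambda t, m: gamma * (t - m) / np.sqrt(var + eps) + beta
    assert np.allclose(bn(x + bias, mean), bn(x, mean - bias))
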
@ -264,6 +288,10 @@ def fold_batch_norms(input_graph_def):
|
||||
" run first?" % (node.name, mean_op))
|
||||
continue
|
||||
mean_value = values_from_const(mean_op)
|
||||
if bias is not None:
|
||||
# Adjust the mean of the batchnorm based on the add op in-between the conv
|
||||
# and the batchnorm.
|
||||
mean_value = mean_value - values_from_const(bias)
|
||||
if mean_value.shape != (channel_count,):
|
||||
tf_logging.warning("Incorrect shape for mean, found %s, expected %s,"
|
||||
" for node %s" % (str(mean_value.shape), str(
|
||||
@@ -315,11 +343,9 @@ def fold_batch_norms(input_graph_def):
     variance_epsilon_value = node.attr[EPSILON_ATTR[node.op]].f
     nodes_to_skip[node.name] = True
     nodes_to_skip[weights_op.name] = True
-    nodes_to_skip[mean_op.name] = True
-    nodes_to_skip[var_op.name] = True
-    nodes_to_skip[beta_op.name] = True
-    nodes_to_skip[gamma_op.name] = True
     nodes_to_skip[conv_op.name] = True
+    if bias is not None:
+      nodes_to_skip[add_op.name] = True

     if scale_after_normalization(node):
       scale_value = (
@@ -346,11 +372,16 @@ def fold_batch_norms(input_graph_def):
       it.iternext()
     scaled_weights_op = node_def_pb2.NodeDef()
     scaled_weights_op.op = "Const"
-    scaled_weights_op.name = weights_op.name
+    scaled_weights_op.name = conv_op.name + '_weights'
     scaled_weights_op.attr["dtype"].CopyFrom(weights_op.attr["dtype"])
     scaled_weights_op.attr["value"].CopyFrom(
         attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
             scaled_weights, weights.dtype.type, weights.shape)))
+    # Replace the weights node with scaled weights node
+    for i, weights_node in enumerate(conv_op.input):
+      if weights_node == weights_op.name:
+        conv_op.input[i] = scaled_weights_op.name
+
     new_conv_op = node_def_pb2.NodeDef()
     new_conv_op.CopyFrom(conv_op)
     offset_op = node_def_pb2.NodeDef()
@@ -374,9 +405,16 @@ def fold_batch_norms(input_graph_def):
       continue
     new_node = node_def_pb2.NodeDef()
     new_node.CopyFrom(node)
+    retained_input = []
+    for input_node in new_node.input:
+      if not input_node.startswith('^') or input_node[1:] not in nodes_to_skip:
+        retained_input.append(input_node)
+    new_node.input[:] = retained_input
+
     result_graph_def.node.extend([new_node])

   result_graph_def.node.extend(new_ops)
+  result_graph_def.versions.CopyFrom(input_graph_def.versions)
   return result_graph_def

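The retained_input filter above drops control-dependency inputs (names prefixed with '^' in a GraphDef) that point at nodes the pass folds away, so the rewritten graph carries no dangling references. A standalone sketch of just that filter, with made-up node names:

    # Made-up names; '^foo' denotes a control dependency on node 'foo'.
    nodes_to_skip = {"conv_op/batchnorm": True}
    node_inputs = ["conv_op_weights", "^conv_op/batchnorm", "^unrelated_node"]

    retained_input = [
        inp for inp in node_inputs
        if not inp.startswith("^") or inp[1:] not in nodes_to_skip
    ]
    print(retained_input)  # ['conv_op_weights', '^unrelated_node']
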
@@ -233,6 +233,67 @@ class OptimizeForInferenceTest(test.TestCase):
       for node in optimized_graph_def.node:
         self.assertNotEqual("FusedBatchNorm", node.op)

+  @test_util.run_deprecated_v1
+  def testFoldFusedBatchNormsV3(self):
+    for data_format, conv2d_func in [
+        ("NHWC", nn_ops.conv2d), ("NCHW", nn_ops.conv2d),
+        ("NHWC", nn_ops.depthwise_conv2d_native),
+        ("NCHW", nn_ops.depthwise_conv2d_native)
+    ]:
+      with self.cached_session() as sess:
+        inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
+        input_op = constant_op.constant(
+            np.array(inputs),
+            shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6],
+            dtype=dtypes.float32)
+        if conv2d_func == nn_ops.conv2d:
+          weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
+          weights_op = constant_op.constant(
+              np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
+        else:
+          weights = [1, 2, 0.3, 0.4]
+          weights_op = constant_op.constant(
+              np.array(weights), shape=[1, 2, 2, 1], dtype=dtypes.float32)
+        mean_op = constant_op.constant(
+            np.array([10, 20]), shape=[2], dtype=dtypes.float32)
+        variance_op = constant_op.constant(
+            np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
+        beta_op = constant_op.constant(
+            np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
+        gamma_op = constant_op.constant(
+            np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
+        ops.get_default_graph().graph_def_versions.producer = 9
+        conv_op = conv2d_func(
+            input_op,
+            weights_op, [1, 1, 1, 1],
+            padding="SAME",
+            data_format=data_format,
+            name="conv_op")
+        gen_nn_ops.fused_batch_norm_v3(
+            conv_op,
+            gamma_op,
+            beta_op,
+            mean_op,
+            variance_op,
+            0.00001,
+            is_training=False,
+            data_format=data_format,
+            name="output")
+        original_graph_def = sess.graph_def
+        original_result = sess.run(["output:0"])
+      optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
+          original_graph_def)
+      with self.cached_session() as sess:
+        _ = importer.import_graph_def(
+            optimized_graph_def, input_map={}, name="optimized")
+        optimized_result = sess.run(["optimized/output:0"])
+
+      self.assertAllClose(
+          original_result, optimized_result, rtol=1e-04, atol=1e-06)
+
+      for node in optimized_graph_def.node:
+        self.assertNotEqual("FusedBatchNormV3", node.op)
+
   @test_util.run_deprecated_v1
   def testFuseResizePadAndConv(self):
     with self.cached_session() as sess:
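Beyond the unit test, the entry point exercised here is a plain GraphDef-to-GraphDef transform; a minimal usage sketch (the file names and the read/write steps are illustrative, not part of this change):

    from tensorflow.core.framework import graph_pb2
    from tensorflow.python.tools import optimize_for_inference_lib

    graph_def = graph_pb2.GraphDef()
    with open("frozen_model.pb", "rb") as f:         # hypothetical frozen graph
      graph_def.ParseFromString(f.read())

    # Folds BatchNormWithGlobalNormalization, FusedBatchNorm and (with this
    # change) FusedBatchNormV3 nodes into the preceding conv weights.
    folded = optimize_for_inference_lib.fold_batch_norms(graph_def)

    with open("frozen_model_folded.pb", "wb") as f:  # hypothetical output path
      f.write(folded.SerializeToString())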