Merge pull request #35175 from Taehun:fused_batch_norm_v3

PiperOrigin-RevId: 290787530
Change-Id: I73b10b1668cfb798bd092de02782645f69196797
TensorFlower Gardener 2020-01-21 11:42:30 -08:00
commit 632528ddd2
2 changed files with 108 additions and 13 deletions


@@ -78,12 +78,15 @@ INPUT_ORDER = {
"conv_op", "mean_op", "var_op", "beta_op", "gamma_op"
],
# Order of inputs for FusedBatchNorm.
"FusedBatchNorm": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"]
"FusedBatchNorm": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"],
# Order of inputs for FusedBatchNormV3.
"FusedBatchNormV3": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"]
}
# Name of the attribute epsilon value is stored in.
EPSILON_ATTR = {
"BatchNormWithGlobalNormalization": "variance_epsilon",
"FusedBatchNorm": "epsilon"
"FusedBatchNorm": "epsilon",
"FusedBatchNormV3": "epsilon"
}
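
Both tables are keyed by the batch-norm op name, so the rest of the pass can stay generic: the input list gives the position of each argument, and the attribute name says where epsilon lives. A tiny self-contained sketch of that lookup (not part of the commit; the names are illustrative, not the module's internals):

# Minimal sketch of how the two tables are consumed by the pass.
input_order = {
    "FusedBatchNormV3": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"],
}
epsilon_attr = {"FusedBatchNormV3": "epsilon"}

# A FusedBatchNormV3 node's inputs, in graph order.
node_op = "FusedBatchNormV3"
node_inputs = ["conv_op:0", "gamma:0", "beta:0", "moving_mean:0", "moving_variance:0"]

conv_input = node_inputs[input_order[node_op].index("conv_op")]
print(conv_input)             # conv_op:0
print(epsilon_attr[node_op])  # "epsilon" -- read via node.attr[...].f in the pass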
@@ -211,10 +214,10 @@ def fold_batch_norms(input_graph_def):
addition, rather than the more expensive multiple ops, and even bake the
scaling into the convolution weights. This function identifies the typical
pattern of batch normalization subgraphs, and performs the transformation to
fold the computations down into a simpler form. It currently only spots batch
normalization that's performed by the BatchNormWithGlobalNormalization and
FusedBatchNorm ops, and will need to be extended in the future to handle the
newer style.
fold the computations down into a simpler form. It currently only supports
batch normalization that's performed by the BatchNormWithGlobalNormalization,
FusedBatchNorm, and FusedBatchNormV3 ops, and will need to be extended in the
future to handle the newer style.
Args:
input_graph_def: A GraphDef containing a model.
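
For reference, the identity this fold relies on (not spelled out in the diff): in inference mode, gamma * (conv(x, W) - mean) / sqrt(var + eps) + beta equals conv(x, W * scale) + offset, where scale = gamma / sqrt(var + eps) is applied per output channel and offset = beta - mean * scale. A minimal NumPy sketch of that equivalence, with illustrative shapes and a 1x1 convolution written as a matmul (not part of the commit):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 3)).astype(np.float32)    # 8 "pixels", 3 input channels
w = rng.standard_normal((3, 2)).astype(np.float32)    # 1x1 conv == per-pixel matmul, 2 output channels
gamma = np.array([1.0, 2.0], dtype=np.float32)
beta = np.array([0.1, 0.6], dtype=np.float32)
mean = np.array([10.0, 20.0], dtype=np.float32)
var = np.array([0.25, 0.5], dtype=np.float32)
eps = 1e-5

conv = x @ w                                           # conv output before folding

# Batch norm applied after the conv (what the graph computes before the pass runs).
bn = gamma * (conv - mean) / np.sqrt(var + eps) + beta

# Folded form: scale baked into the weights, offset applied once afterwards.
scale = gamma / np.sqrt(var + eps)                     # per output channel
offset = beta - mean * scale
folded = x @ (w * scale) + offset

assert np.allclose(bn, folded, rtol=1e-4, atol=1e-6)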
@@ -235,12 +238,30 @@ def fold_batch_norms(input_graph_def):
nodes_to_skip = {}
new_ops = []
for node in input_graph_def.node:
if node.op not in ("BatchNormWithGlobalNormalization", "FusedBatchNorm"):
if (node.op not in ("BatchNormWithGlobalNormalization", "FusedBatchNorm",
"FusedBatchNormV3")):
continue
bias = None
conv_op = node_from_map(input_node_map,
node.input[INPUT_ORDER[node.op].index("conv_op")])
if conv_op.op != "Conv2D" and conv_op.op != "DepthwiseConv2dNative":
# There might be an Add/BiasAdd op between the conv and the batchnorm,
# which we can fold into the mean param of the batchnorm.
if conv_op.op in ["BiasAdd", "Add", "AddV2"]:
add_op = conv_op
# Follow the first input of the add to get to the conv.
conv_op = node_from_map(input_node_map, add_op.input[0])
bias = node_from_map(input_node_map, add_op.input[1])
if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
# Follow the second input of the add to get to the conv.
conv_op = node_from_map(input_node_map, add_op.input[1])
bias = node_from_map(input_node_map, add_op.input[0])
if bias and bias.op != "Const":
tf_logging.warning("The bias %s after the conv %s was not a constant. "
"Maybe because freeze_graph wasn't "
"run first?" % (bias.name, conv_op.name))
continue
if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
tf_logging.warning("Didn't find expected Conv2D or DepthwiseConv2dNative"
" input to '%s'" % node.name)
continue
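
The new branch above follows the batch-norm's first input and, when it lands on a BiasAdd/Add/AddV2 instead of a convolution, tries both of the add's inputs to decide which one is the conv and which one is the bias. A stand-alone sketch of that selection logic, using plain records instead of NodeDefs (names such as find_conv_and_bias are illustrative, not the module's API):

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

@dataclass
class Node:                              # stand-in for a GraphDef NodeDef
    name: str
    op: str
    input: List[str] = field(default_factory=list)

def find_conv_and_bias(first_input: Node,
                       node_map: Dict[str, Node]) -> Tuple[Node, Optional[Node]]:
    """Return (conv, bias); bias is None when no Add/BiasAdd sits in between."""
    if first_input.op not in ("BiasAdd", "Add", "AddV2"):
        return first_input, None
    a, b = node_map[first_input.input[0]], node_map[first_input.input[1]]
    if a.op in ("Conv2D", "DepthwiseConv2dNative"):
        return a, b                      # conv on the first input of the add
    return b, a                          # otherwise assume conv on the second input

node_map = {
    "weights": Node("weights", "Const"),
    "conv": Node("conv", "Conv2D", ["x", "weights"]),
    "bias": Node("bias", "Const"),
    "biasadd": Node("biasadd", "BiasAdd", ["conv", "bias"]),
}
conv, bias = find_conv_and_bias(node_map["biasadd"], node_map)
print(conv.name, bias.name)              # conv bias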
@@ -265,6 +286,10 @@ def fold_batch_norms(input_graph_def):
" run first?" % (node.name, mean_op))
continue
mean_value = values_from_const(mean_op)
if bias is not None:
# Adjust the mean of the batchnorm based on the add op in-between the conv
# and the batchnorm.
mean_value = mean_value - values_from_const(bias)
if mean_value.shape != (channel_count,):
tf_logging.warning("Incorrect shape for mean, found %s, expected %s,"
" for node %s" % (str(mean_value.shape), str(
@@ -316,11 +341,9 @@ def fold_batch_norms(input_graph_def):
variance_epsilon_value = node.attr[EPSILON_ATTR[node.op]].f
nodes_to_skip[node.name] = True
nodes_to_skip[weights_op.name] = True
nodes_to_skip[mean_op.name] = True
nodes_to_skip[var_op.name] = True
nodes_to_skip[beta_op.name] = True
nodes_to_skip[gamma_op.name] = True
nodes_to_skip[conv_op.name] = True
if bias is not None:
nodes_to_skip[add_op.name] = True
if scale_after_normalization(node):
scale_value = (
@@ -347,11 +370,16 @@ def fold_batch_norms(input_graph_def):
it.iternext()
scaled_weights_op = node_def_pb2.NodeDef()
scaled_weights_op.op = "Const"
scaled_weights_op.name = weights_op.name
scaled_weights_op.name = conv_op.name + "_weights"
scaled_weights_op.attr["dtype"].CopyFrom(weights_op.attr["dtype"])
scaled_weights_op.attr["value"].CopyFrom(
attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
scaled_weights, weights.dtype.type, weights.shape)))
# Replace the weights node with scaled weights node
for i, weights_node in enumerate(conv_op.input):
if weights_node == weights_op.name:
conv_op.input[i] = scaled_weights_op.name
new_conv_op = node_def_pb2.NodeDef()
new_conv_op.CopyFrom(conv_op)
offset_op = node_def_pb2.NodeDef()
@@ -375,9 +403,16 @@ def fold_batch_norms(input_graph_def):
continue
new_node = node_def_pb2.NodeDef()
new_node.CopyFrom(node)
retained_input = []
for input_node in new_node.input:
if not input_node.startswith("^") or input_node[1:] not in nodes_to_skip:
retained_input.append(input_node)
new_node.input[:] = retained_input
result_graph_def.node.extend([new_node])
result_graph_def.node.extend(new_ops)
result_graph_def.versions.CopyFrom(input_graph_def.versions)
return result_graph_def
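
For context, the pass is applied to a frozen GraphDef, exactly as the new test below does in-process. A minimal file-based usage sketch (the file paths are illustrative; the import assumes the standard location of the module the test calls as optimize_for_inference_lib):

from tensorflow.core.framework import graph_pb2
from tensorflow.python.tools import optimize_for_inference_lib

# Load a frozen GraphDef (variables already converted to Const nodes, e.g. via
# freeze_graph); the pass logs a warning and skips folding otherwise.
graph_def = graph_pb2.GraphDef()
with open("frozen_model.pb", "rb") as f:          # illustrative path
    graph_def.ParseFromString(f.read())

# Fold BatchNormWithGlobalNormalization / FusedBatchNorm / FusedBatchNormV3
# (optionally with a BiasAdd between the conv and the batch norm) into the
# adjacent convolution weights.
folded = optimize_for_inference_lib.fold_batch_norms(graph_def)

with open("frozen_model_folded.pb", "wb") as f:   # illustrative path
    f.write(folded.SerializeToString())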


@@ -233,6 +233,66 @@ class OptimizeForInferenceTest(test.TestCase):
for node in optimized_graph_def.node:
self.assertNotEqual("FusedBatchNorm", node.op)
@test_util.run_deprecated_v1
def testFoldFusedBatchNormsV3(self):
for data_format, conv2d_func in [("NHWC", nn_ops.conv2d),
("NCHW", nn_ops.conv2d),
("NHWC", nn_ops.depthwise_conv2d_native),
("NCHW", nn_ops.depthwise_conv2d_native)]:
with self.cached_session() as sess:
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
input_op = constant_op.constant(
np.array(inputs),
shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6],
dtype=dtypes.float32)
if conv2d_func == nn_ops.conv2d:
weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
weights_op = constant_op.constant(
np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
else:
weights = [1, 2, 0.3, 0.4]
weights_op = constant_op.constant(
np.array(weights), shape=[1, 2, 2, 1], dtype=dtypes.float32)
mean_op = constant_op.constant(
np.array([10, 20]), shape=[2], dtype=dtypes.float32)
variance_op = constant_op.constant(
np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
beta_op = constant_op.constant(
np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
gamma_op = constant_op.constant(
np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
ops.get_default_graph().graph_def_versions.producer = 9
conv_op = conv2d_func(
input_op,
weights_op, [1, 1, 1, 1],
padding="SAME",
data_format=data_format,
name="conv_op")
gen_nn_ops.fused_batch_norm_v3(
conv_op,
gamma_op,
beta_op,
mean_op,
variance_op,
0.00001,
is_training=False,
data_format=data_format,
name="output")
original_graph_def = sess.graph_def
original_result = sess.run(["output:0"])
optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
original_graph_def)
with self.cached_session() as sess:
_ = importer.import_graph_def(
optimized_graph_def, input_map={}, name="optimized")
optimized_result = sess.run(["optimized/output:0"])
self.assertAllClose(
original_result, optimized_result, rtol=1e-04, atol=1e-06)
for node in optimized_graph_def.node:
self.assertNotEqual("FusedBatchNormV3", node.op)
@test_util.run_deprecated_v1
def testFuseResizePadAndConv(self):
with self.cached_session() as sess: