Merge pull request #35175 from Taehun:fused_batch_norm_v3
PiperOrigin-RevId: 290787530 Change-Id: I73b10b1668cfb798bd092de02782645f69196797
Commit 632528ddd2
@@ -78,12 +78,15 @@ INPUT_ORDER = {
         "conv_op", "mean_op", "var_op", "beta_op", "gamma_op"
     ],
     # Order of inputs for FusedBatchNorm.
-    "FusedBatchNorm": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"]
+    "FusedBatchNorm": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"],
+    # Order of inputs for FusedBatchNormV3.
+    "FusedBatchNormV3": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"]
 }
 # Name of the attribute epsilon value is stored in.
 EPSILON_ATTR = {
     "BatchNormWithGlobalNormalization": "variance_epsilon",
-    "FusedBatchNorm": "epsilon"
+    "FusedBatchNorm": "epsilon",
+    "FusedBatchNormV3": "epsilon"
 }

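For orientation, here is a minimal sketch (not part of the diff) of how these two lookup tables are consumed further down in fold_batch_norms; lookup_batch_norm_input is a hypothetical helper name, while INPUT_ORDER, EPSILON_ATTR and node_from_map are the module's existing tables and utilities. Note that FusedBatchNorm and FusedBatchNormV3 share the same input ordering, so the same index lookup works for either op.

# Hypothetical helper, for illustration only.
def lookup_batch_norm_input(node, input_node_map, role):
  index = INPUT_ORDER[node.op].index(role)   # e.g. "mean_op" -> position 3
  return node_from_map(input_node_map, node.input[index])

# The epsilon value lives in a differently named attribute per op type:
#   epsilon = node.attr[EPSILON_ATTR[node.op]].f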
@@ -211,10 +214,10 @@ def fold_batch_norms(input_graph_def):
   addition, rather than the more expensive multiple ops, and even bake the
   scaling into the convolution weights. This function identifies the typical
   pattern of batch normalization subgraphs, and performs the transformation to
-  fold the computations down into a simpler form. It currently only spots batch
-  normalization that's performed by the BatchNormWithGlobalNormalization and
-  FusedBatchNorm ops, and will need to be extended in the future to handle the
-  newer style.
+  fold the computations down into a simpler form. It currently only supports
+  batch normalization that's performed by the BatchNormWithGlobalNormalization,
+  FusedBatchNorm and FusedBatchNormV3 ops, and will need to be extended in the
+  future to handle the newer style.

   Args:
     input_graph_def: A GraphDef containing a model.
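To make the docstring's claim concrete, here is a small NumPy sketch (not part of the change) of the inference-time algebra the pass relies on, using made-up per-channel values; gamma, beta, mean and var stand for the frozen batch-norm parameters.

import numpy as np

# Inference-time batch norm: y = gamma * (x - mean) / sqrt(var + eps) + beta.
# Because gamma, beta, mean and var are constants in a frozen graph, this is
# just a per-channel affine transform that can be folded into the conv.
gamma = np.array([1.0, 2.0])
beta = np.array([0.1, 0.6])
mean = np.array([10.0, 20.0])
var = np.array([0.25, 0.5])
eps = 0.001

x = np.array([[3.0, -7.0]])                # pretend conv output, 2 channels

scale = gamma / np.sqrt(var + eps)         # per-channel multiplier
offset = beta - mean * scale               # per-channel bias

bn = gamma * (x - mean) / np.sqrt(var + eps) + beta
folded = x * scale + offset                # what the rewritten graph computes
np.testing.assert_allclose(bn, folded)

Because scale is a per-output-channel constant, multiplying by it commutes with the convolution, which is why the pass can bake it directly into the convolution weights and keep only a single offset addition.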
@@ -235,12 +238,30 @@ def fold_batch_norms(input_graph_def):
   nodes_to_skip = {}
   new_ops = []
   for node in input_graph_def.node:
-    if node.op not in ("BatchNormWithGlobalNormalization", "FusedBatchNorm"):
+    if (node.op not in ("BatchNormWithGlobalNormalization", "FusedBatchNorm",
+                        "FusedBatchNormV3")):
       continue

+    bias = None
     conv_op = node_from_map(input_node_map,
                             node.input[INPUT_ORDER[node.op].index("conv_op")])
     if conv_op.op != "Conv2D" and conv_op.op != "DepthwiseConv2dNative":
+      # There might be an Add/BiasAdd op between the conv and the batchnorm,
+      # which we can fold into the mean param of the batchnorm.
+      if conv_op.op in ["BiasAdd", "Add", "AddV2"]:
+        add_op = conv_op
+        # Follow the first input of the add to get to the conv.
+        conv_op = node_from_map(input_node_map, add_op.input[0])
+        bias = node_from_map(input_node_map, add_op.input[1])
+        if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
+          # Follow the second input of the add to get to the conv.
+          conv_op = node_from_map(input_node_map, add_op.input[1])
+          bias = node_from_map(input_node_map, add_op.input[0])
+        if bias and bias.op != "Const":
+          tf_logging.warning("The bias %s after the conv %s was not a constant. "
+                             "Maybe because freeze_graph wasn't "
+                             "run first?" % (bias.name, conv_op.name))
+          continue
+    if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
       tf_logging.warning("Didn't find expected Conv2D or DepthwiseConv2dNative"
                          " input to '%s'" % node.name)
       continue
@@ -265,6 +286,10 @@ def fold_batch_norms(input_graph_def):
                        " run first?" % (node.name, mean_op))
       continue
     mean_value = values_from_const(mean_op)
+    if bias is not None:
+      # Adjust the mean of the batchnorm based on the add op in-between the conv
+      # and the batchnorm.
+      mean_value = mean_value - values_from_const(bias)
     if mean_value.shape != (channel_count,):
       tf_logging.warning("Incorrect shape for mean, found %s, expected %s,"
                          " for node %s" % (str(mean_value.shape), str(
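A quick NumPy check (illustrative only, with arbitrary values not taken from the diff) of why subtracting the folded-in bias from the moving mean is equivalent to applying the batch norm after the Add/BiasAdd:

import numpy as np

conv_out = np.array([1.0, -2.0, 3.0])      # one channel of conv output
bias, mean, var, gamma, beta, eps = 0.5, 10.0, 0.25, 2.0, 0.1, 1e-5

def bn(x, m):
  return gamma * (x - m) / np.sqrt(var + eps) + beta

# BN applied to (conv + bias) with the original mean equals BN applied to the
# conv output alone with the mean shifted by the bias.
np.testing.assert_allclose(bn(conv_out + bias, mean), bn(conv_out, mean - bias))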
@@ -316,11 +341,9 @@ def fold_batch_norms(input_graph_def):
     variance_epsilon_value = node.attr[EPSILON_ATTR[node.op]].f
     nodes_to_skip[node.name] = True
     nodes_to_skip[weights_op.name] = True
-    nodes_to_skip[mean_op.name] = True
-    nodes_to_skip[var_op.name] = True
-    nodes_to_skip[beta_op.name] = True
-    nodes_to_skip[gamma_op.name] = True
     nodes_to_skip[conv_op.name] = True
+    if bias is not None:
+      nodes_to_skip[add_op.name] = True

     if scale_after_normalization(node):
       scale_value = (
@@ -347,11 +370,16 @@ def fold_batch_norms(input_graph_def):
         it.iternext()
     scaled_weights_op = node_def_pb2.NodeDef()
     scaled_weights_op.op = "Const"
-    scaled_weights_op.name = weights_op.name
+    scaled_weights_op.name = conv_op.name + "_weights"
     scaled_weights_op.attr["dtype"].CopyFrom(weights_op.attr["dtype"])
     scaled_weights_op.attr["value"].CopyFrom(
         attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
             scaled_weights, weights.dtype.type, weights.shape)))
+    # Replace the weights node with scaled weights node
+    for i, weights_node in enumerate(conv_op.input):
+      if weights_node == weights_op.name:
+        conv_op.input[i] = scaled_weights_op.name
+
     new_conv_op = node_def_pb2.NodeDef()
     new_conv_op.CopyFrom(conv_op)
     offset_op = node_def_pb2.NodeDef()
@@ -375,9 +403,16 @@ def fold_batch_norms(input_graph_def):
       continue
     new_node = node_def_pb2.NodeDef()
     new_node.CopyFrom(node)
+    retained_input = []
+    for input_node in new_node.input:
+      if not input_node.startswith("^") or input_node[1:] not in nodes_to_skip:
+        retained_input.append(input_node)
+    new_node.input[:] = retained_input
+
     result_graph_def.node.extend([new_node])

   result_graph_def.node.extend(new_ops)
   result_graph_def.versions.CopyFrom(input_graph_def.versions)
   return result_graph_def

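Before the test changes below, a hedged usage sketch (not part of this commit) of how the updated pass would typically be applied to a frozen graph; the file name is a placeholder.

from tensorflow.core.framework import graph_pb2
from tensorflow.python.tools import optimize_for_inference_lib

graph_def = graph_pb2.GraphDef()
with open("frozen_model.pb", "rb") as f:   # hypothetical frozen-graph path
  graph_def.ParseFromString(f.read())

folded = optimize_for_inference_lib.fold_batch_norms(graph_def)

# After folding, no FusedBatchNorm/FusedBatchNormV3 nodes should remain for
# the patterns the pass recognizes (conv followed by inference-mode batch norm).
print(sum(node.op == "FusedBatchNormV3" for node in folded.node))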
@@ -233,6 +233,66 @@ class OptimizeForInferenceTest(test.TestCase):
       for node in optimized_graph_def.node:
         self.assertNotEqual("FusedBatchNorm", node.op)

+  @test_util.run_deprecated_v1
+  def testFoldFusedBatchNormsV3(self):
+    for data_format, conv2d_func in [("NHWC", nn_ops.conv2d),
+                                     ("NCHW", nn_ops.conv2d),
+                                     ("NHWC", nn_ops.depthwise_conv2d_native),
+                                     ("NCHW", nn_ops.depthwise_conv2d_native)]:
+      with self.cached_session() as sess:
+        inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
+        input_op = constant_op.constant(
+            np.array(inputs),
+            shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6],
+            dtype=dtypes.float32)
+        if conv2d_func == nn_ops.conv2d:
+          weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
+          weights_op = constant_op.constant(
+              np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
+        else:
+          weights = [1, 2, 0.3, 0.4]
+          weights_op = constant_op.constant(
+              np.array(weights), shape=[1, 2, 2, 1], dtype=dtypes.float32)
+        mean_op = constant_op.constant(
+            np.array([10, 20]), shape=[2], dtype=dtypes.float32)
+        variance_op = constant_op.constant(
+            np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
+        beta_op = constant_op.constant(
+            np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
+        gamma_op = constant_op.constant(
+            np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
+        ops.get_default_graph().graph_def_versions.producer = 9
+        conv_op = conv2d_func(
+            input_op,
+            weights_op, [1, 1, 1, 1],
+            padding="SAME",
+            data_format=data_format,
+            name="conv_op")
+        gen_nn_ops.fused_batch_norm_v3(
+            conv_op,
+            gamma_op,
+            beta_op,
+            mean_op,
+            variance_op,
+            0.00001,
+            is_training=False,
+            data_format=data_format,
+            name="output")
+        original_graph_def = sess.graph_def
+        original_result = sess.run(["output:0"])
+      optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
+          original_graph_def)
+      with self.cached_session() as sess:
+        _ = importer.import_graph_def(
+            optimized_graph_def, input_map={}, name="optimized")
+        optimized_result = sess.run(["optimized/output:0"])
+
+      self.assertAllClose(
+          original_result, optimized_result, rtol=1e-04, atol=1e-06)
+
+      for node in optimized_graph_def.node:
+        self.assertNotEqual("FusedBatchNormV3", node.op)
+
   @test_util.run_deprecated_v1
   def testFuseResizePadAndConv(self):
     with self.cached_session() as sess: