Merge pull request #35175 from Taehun:fused_batch_norm_v3
PiperOrigin-RevId: 290787530
Change-Id: I73b10b1668cfb798bd092de02782645f69196797
commit 632528ddd2
@@ -78,12 +78,15 @@ INPUT_ORDER = {
         "conv_op", "mean_op", "var_op", "beta_op", "gamma_op"
     ],
     # Order of inputs for FusedBatchNorm.
-    "FusedBatchNorm": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"]
+    "FusedBatchNorm": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"],
+    # Order of inputs for FusedBatchNormV3.
+    "FusedBatchNormV3": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"]
 }
 # Name of the attribute epsilon value is stored in.
 EPSILON_ATTR = {
     "BatchNormWithGlobalNormalization": "variance_epsilon",
-    "FusedBatchNorm": "epsilon"
+    "FusedBatchNorm": "epsilon",
+    "FusedBatchNormV3": "epsilon"
 }

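These two tables are consulted whenever the folding pass matches a batch-norm node, so the new "FusedBatchNormV3" entries make the V3 op resolvable in exactly the same way as the existing ops. A minimal sketch of that lookup, using a NodeDef built here purely for illustration (it is not part of the change) and assuming the usual module path for optimize_for_inference_lib:

# Hypothetical NodeDef, built only to illustrate the table lookup.
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.tools.optimize_for_inference_lib import (
    EPSILON_ATTR, INPUT_ORDER)

node = node_def_pb2.NodeDef()
node.op = "FusedBatchNormV3"
node.input.extend(["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"])
node.attr["epsilon"].f = 0.001

gamma_input = node.input[INPUT_ORDER[node.op].index("gamma_op")]  # "gamma_op"
epsilon = node.attr[EPSILON_ATTR[node.op]].f  # ~0.001 (stored as float32)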
@@ -211,10 +214,10 @@ def fold_batch_norms(input_graph_def):
   addition, rather than the more expensive multiple ops, and even bake the
   scaling into the convolution weights. This function identifies the typical
   pattern of batch normalization subgraphs, and performs the transformation to
-  fold the computations down into a simpler form. It currently only spots batch
-  normalization that's performed by the BatchNormWithGlobalNormalization and
-  FusedBatchNorm ops, and will need to be extended in the future to handle the
-  newer style.
+  fold the computations down into a simpler form. It currently only supports
+  batch normalization that's performed by the BatchNormWithGlobalNormalization,
+  FusedBatchNorm and FusedBatchNormV3 ops, and will need to be extended in the
+  future to handle the newer style.

   Args:
     input_graph_def: A GraphDef containing a model.
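For context on how this function is consumed (not shown in the diff): it is normally applied to a frozen GraphDef, either directly or via the optimize_for_inference entry point. A minimal sketch, assuming a placeholder model path and the usual module location:

import tensorflow.compat.v1 as tf
from tensorflow.python.tools import optimize_for_inference_lib

# "frozen_model.pb" is a hypothetical path; any frozen graph containing
# Conv2D/DepthwiseConv2dNative + FusedBatchNorm(V3) patterns will do.
graph_def = tf.GraphDef()
with tf.gfile.GFile("frozen_model.pb", "rb") as f:
  graph_def.ParseFromString(f.read())

folded = optimize_for_inference_lib.fold_batch_norms(graph_def)
with tf.gfile.GFile("frozen_model_folded.pb", "wb") as f:
  f.write(folded.SerializeToString())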
@@ -235,12 +238,30 @@ def fold_batch_norms(input_graph_def):
   nodes_to_skip = {}
   new_ops = []
   for node in input_graph_def.node:
-    if node.op not in ("BatchNormWithGlobalNormalization", "FusedBatchNorm"):
+    if (node.op not in ("BatchNormWithGlobalNormalization", "FusedBatchNorm",
+                        "FusedBatchNormV3")):
       continue

+    bias = None
     conv_op = node_from_map(input_node_map,
                             node.input[INPUT_ORDER[node.op].index("conv_op")])
-    if conv_op.op != "Conv2D" and conv_op.op != "DepthwiseConv2dNative":
+    # There might be an Add/BiasAdd op between the conv and the batchnorm,
+    # which we can fold into the mean param of the batchnorm.
+    if conv_op.op in ["BiasAdd", "Add", "AddV2"]:
+      add_op = conv_op
+      # Follow the first input of the add to get to the conv.
+      conv_op = node_from_map(input_node_map, add_op.input[0])
+      bias = node_from_map(input_node_map, add_op.input[1])
+      if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
+        # Follow the second input of the add to get to the conv.
+        conv_op = node_from_map(input_node_map, add_op.input[1])
+        bias = node_from_map(input_node_map, add_op.input[0])
+    if bias and bias.op != "Const":
+      tf_logging.warning("The bias %s after the conv %s was not a constant. "
+                         "Maybe because freeze_graph wasn't "
+                         "run first?" % (bias.name, conv_op.name))
+      continue
+    if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
       tf_logging.warning("Didn't find expected Conv2D or DepthwiseConv2dNative"
                          " input to '%s'" % node.name)
       continue
@@ -265,6 +286,10 @@ def fold_batch_norms(input_graph_def):
                          " run first?" % (node.name, mean_op))
       continue
     mean_value = values_from_const(mean_op)
+    if bias is not None:
+      # Adjust the mean of the batchnorm based on the add op in-between the conv
+      # and the batchnorm.
+      mean_value = mean_value - values_from_const(bias)
     if mean_value.shape != (channel_count,):
       tf_logging.warning("Incorrect shape for mean, found %s, expected %s,"
                          " for node %s" % (str(mean_value.shape), str(
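The new mean adjustment works because the batchnorm already subtracts a per-channel mean, so a constant bias added between the conv and the batchnorm can be absorbed by shifting that mean: gamma * ((x + b) - mean) / sqrt(var + eps) + beta equals gamma * (x - (mean - b)) / sqrt(var + eps) + beta. A small NumPy check of that identity (the values are illustrative, loosely borrowed from the test constants later in this change):

import numpy as np

x = np.random.randn(4, 2).astype(np.float32)     # per-channel activations
bias = np.array([0.5, -1.0], dtype=np.float32)   # constant bias between conv and batchnorm
mean = np.array([10.0, 20.0], dtype=np.float32)
var = np.array([0.25, 0.5], dtype=np.float32)
gamma = np.array([1.0, 2.0], dtype=np.float32)
beta = np.array([0.1, 0.6], dtype=np.float32)
eps = 1e-5


def bn(t, m):
  return gamma * (t - m) / np.sqrt(var + eps) + beta


np.testing.assert_allclose(
    bn(x + bias, mean),     # BiasAdd/Add followed by batchnorm
    bn(x, mean - bias),     # bias folded into the mean, as done above
    rtol=1e-5)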
@@ -316,11 +341,9 @@ def fold_batch_norms(input_graph_def):
     variance_epsilon_value = node.attr[EPSILON_ATTR[node.op]].f
     nodes_to_skip[node.name] = True
     nodes_to_skip[weights_op.name] = True
-    nodes_to_skip[mean_op.name] = True
-    nodes_to_skip[var_op.name] = True
-    nodes_to_skip[beta_op.name] = True
-    nodes_to_skip[gamma_op.name] = True
     nodes_to_skip[conv_op.name] = True
+    if bias is not None:
+      nodes_to_skip[add_op.name] = True

     if scale_after_normalization(node):
       scale_value = (
@@ -347,11 +370,16 @@ def fold_batch_norms(input_graph_def):
       it.iternext()
     scaled_weights_op = node_def_pb2.NodeDef()
     scaled_weights_op.op = "Const"
-    scaled_weights_op.name = weights_op.name
+    scaled_weights_op.name = conv_op.name + "_weights"
     scaled_weights_op.attr["dtype"].CopyFrom(weights_op.attr["dtype"])
     scaled_weights_op.attr["value"].CopyFrom(
         attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
             scaled_weights, weights.dtype.type, weights.shape)))
+    # Replace the weights node with scaled weights node
+    for i, weights_node in enumerate(conv_op.input):
+      if weights_node == weights_op.name:
+        conv_op.input[i] = scaled_weights_op.name
+
     new_conv_op = node_def_pb2.NodeDef()
     new_conv_op.CopyFrom(conv_op)
     offset_op = node_def_pb2.NodeDef()
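What the rewiring above sets up: the folded graph feeds the conv a per-output-channel scaled copy of the weights (now named after the conv rather than the original weights node) and, later in this function, a constant offset so that conv(scaled weights) + offset reproduces the batchnorm output. Assuming scale_after_normalization (gamma applied), the identity is scale = gamma / sqrt(var + eps) and offset = -mean * scale + beta. A small NumPy check, treating one convolution patch as a dot product since convolution is linear in its weights:

import numpy as np

x = np.random.randn(6).astype(np.float32)      # one flattened input patch
w = np.random.randn(6, 2).astype(np.float32)   # weights for 2 output channels
mean = np.array([10.0, 20.0], dtype=np.float32)
var = np.array([0.25, 0.5], dtype=np.float32)
gamma = np.array([1.0, 2.0], dtype=np.float32)
beta = np.array([0.1, 0.6], dtype=np.float32)
eps = 1e-5

scale = gamma / np.sqrt(var + eps)
offset = -mean * scale + beta

batchnorm_out = gamma * (x @ w - mean) / np.sqrt(var + eps) + beta
folded_out = x @ (w * scale) + offset          # scaled weights + constant offset
np.testing.assert_allclose(batchnorm_out, folded_out, rtol=1e-5)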
@@ -375,9 +403,16 @@ def fold_batch_norms(input_graph_def):
       continue
     new_node = node_def_pb2.NodeDef()
     new_node.CopyFrom(node)
+    retained_input = []
+    for input_node in new_node.input:
+      if not input_node.startswith("^") or input_node[1:] not in nodes_to_skip:
+        retained_input.append(input_node)
+    new_node.input[:] = retained_input
+
     result_graph_def.node.extend([new_node])

   result_graph_def.node.extend(new_ops)
+  result_graph_def.versions.CopyFrom(input_graph_def.versions)
   return result_graph_def

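The retained_input filtering added here drops control-dependency inputs (names prefixed with "^") that point at nodes scheduled for removal, so surviving nodes no longer reference the deleted batchnorm constants; the versions.CopyFrom line also carries the GraphDef version info over to the rewritten graph. A tiny sketch of the filtering rule, with hypothetical node names:

# "old_batchnorm" stands in for any node recorded in nodes_to_skip.
nodes_to_skip = {"old_batchnorm": True}
inputs = ["conv_op", "^old_batchnorm", "^still_needed"]

retained_input = [
    inp for inp in inputs
    if not inp.startswith("^") or inp[1:] not in nodes_to_skip
]
assert retained_input == ["conv_op", "^still_needed"]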
@@ -233,6 +233,66 @@ class OptimizeForInferenceTest(test.TestCase):
       for node in optimized_graph_def.node:
         self.assertNotEqual("FusedBatchNorm", node.op)

+  @test_util.run_deprecated_v1
+  def testFoldFusedBatchNormsV3(self):
+    for data_format, conv2d_func in [("NHWC", nn_ops.conv2d),
+                                     ("NCHW", nn_ops.conv2d),
+                                     ("NHWC", nn_ops.depthwise_conv2d_native),
+                                     ("NCHW", nn_ops.depthwise_conv2d_native)]:
+      with self.cached_session() as sess:
+        inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
+        input_op = constant_op.constant(
+            np.array(inputs),
+            shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6],
+            dtype=dtypes.float32)
+        if conv2d_func == nn_ops.conv2d:
+          weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
+          weights_op = constant_op.constant(
+              np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
+        else:
+          weights = [1, 2, 0.3, 0.4]
+          weights_op = constant_op.constant(
+              np.array(weights), shape=[1, 2, 2, 1], dtype=dtypes.float32)
+        mean_op = constant_op.constant(
+            np.array([10, 20]), shape=[2], dtype=dtypes.float32)
+        variance_op = constant_op.constant(
+            np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
+        beta_op = constant_op.constant(
+            np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
+        gamma_op = constant_op.constant(
+            np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
+        ops.get_default_graph().graph_def_versions.producer = 9
+        conv_op = conv2d_func(
+            input_op,
+            weights_op, [1, 1, 1, 1],
+            padding="SAME",
+            data_format=data_format,
+            name="conv_op")
+        gen_nn_ops.fused_batch_norm_v3(
+            conv_op,
+            gamma_op,
+            beta_op,
+            mean_op,
+            variance_op,
+            0.00001,
+            is_training=False,
+            data_format=data_format,
+            name="output")
+        original_graph_def = sess.graph_def
+        original_result = sess.run(["output:0"])
+      optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
+          original_graph_def)
+
+      with self.cached_session() as sess:
+        _ = importer.import_graph_def(
+            optimized_graph_def, input_map={}, name="optimized")
+        optimized_result = sess.run(["optimized/output:0"])
+
+      self.assertAllClose(
+          original_result, optimized_result, rtol=1e-04, atol=1e-06)
+
+      for node in optimized_graph_def.node:
+        self.assertNotEqual("FusedBatchNormV3", node.op)
+
   @test_util.run_deprecated_v1
   def testFuseResizePadAndConv(self):
     with self.cached_session() as sess:
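Outside the test harness, the same end-to-end check can be reproduced with a short standalone script; this is only a sketch, assuming TF 1.x-style graph mode via tensorflow.compat.v1 and the module paths already used by the test above:

import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.tools import optimize_for_inference_lib

tf.disable_eager_execution()

# Build a small NHWC Conv2D + FusedBatchNormV3 graph, mirroring the test's
# nn_ops.conv2d case, and record its output before folding.
with tf.Session() as sess:
  x = tf.constant(np.arange(12, dtype=np.float32), shape=[1, 1, 6, 2])
  w = tf.constant(np.ones([1, 2, 2, 2], dtype=np.float32))
  conv = tf.nn.conv2d(x, w, [1, 1, 1, 1], padding="SAME", name="conv_op")
  gen_nn_ops.fused_batch_norm_v3(
      conv,
      tf.constant([1.0, 2.0]),    # gamma
      tf.constant([0.1, 0.6]),    # beta
      tf.constant([10.0, 20.0]),  # mean
      tf.constant([0.25, 0.5]),   # variance
      0.00001,
      is_training=False,
      name="output")
  original = sess.run("output:0")
  folded_def = optimize_for_inference_lib.fold_batch_norms(sess.graph_def)

# The folded graph should contain no FusedBatchNormV3 nodes and should still
# produce the same output.
assert all(n.op != "FusedBatchNormV3" for n in folded_def.node)
with tf.Graph().as_default():
  tf.import_graph_def(folded_def, name="optimized")
  with tf.Session() as sess:
    np.testing.assert_allclose(original, sess.run("optimized/output:0"),
                               rtol=1e-4, atol=1e-6)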
|
Loading…
x
Reference in New Issue
Block a user