Add folding of FusedBatchNormV3.
commit 8c7305bd49
parent 08c45ed0c5
@@ -77,12 +77,15 @@ INPUT_ORDER = {
         "conv_op", "mean_op", "var_op", "beta_op", "gamma_op"
     ],
     # Order of inputs for FusedBatchNorm.
-    "FusedBatchNorm": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"]
+    "FusedBatchNorm": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"],
+    # Order of inputs for FusedBatchNormV3.
+    "FusedBatchNormV3": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"]
 }
 # Name of the attribute epsilon value is stored in.
 EPSILON_ATTR = {
     "BatchNormWithGlobalNormalization": "variance_epsilon",
-    "FusedBatchNorm": "epsilon"
+    "FusedBatchNorm": "epsilon",
+    "FusedBatchNormV3": "epsilon"
 }
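These two tables are what let the rest of the pass treat all three batch-norm ops uniformly: INPUT_ORDER records which input index of the op carries each parameter, and EPSILON_ATTR names the attribute holding the epsilon value. A quick illustrative lookup (a sketch, not part of the diff):

    # For a FusedBatchNormV3 node, the mean tensor is its fourth input.
    mean_index = INPUT_ORDER["FusedBatchNormV3"].index("mean_op")  # -> 3
    epsilon_attr = EPSILON_ATTR["FusedBatchNormV3"]                # -> "epsilon"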
@@ -210,10 +213,10 @@ def fold_batch_norms(input_graph_def):
   addition, rather than the more expensive multiple ops, and even bake the
   scaling into the convolution weights. This function identifies the typical
   pattern of batch normalization subgraphs, and performs the transformation to
-  fold the computations down into a simpler form. It currently only spots batch
-  normalization that's performed by the BatchNormWithGlobalNormalization and
-  FusedBatchNorm ops, and will need to be extended in the future to handle the
-  newer style.
+  fold the computations down into a simpler form. It currently only supports
+  batch normalization that's performed by the BatchNormWithGlobalNormalization,
+  FusedBatchNorm and FusedBatchNormV3 ops, and will need to be extended in the
+  future to handle the newer style.

   Args:
     input_graph_def: A GraphDef containing a model.
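For reference, the arithmetic being folded is the inference-time batch norm y = gamma * (x - mean) / sqrt(variance + epsilon) + beta. Because this is a per-channel affine transform, the multiply can be baked into the convolution weights and the rest collapses to a single addition. A minimal NumPy sketch of that idea (function name and layout are illustrative, not part of the library):

    import numpy as np

    def fold_into_conv(weights, gamma, beta, mean, variance, epsilon):
      # weights laid out HWIO, so the last axis is the output channel.
      scale = gamma / np.sqrt(variance + epsilon)   # per-channel multiplier
      scaled_weights = weights * scale              # broadcast over last axis
      offset = beta - mean * scale                  # per-channel additive offset
      # conv(x, scaled_weights) + offset == batch_norm(conv(x, weights))
      return scaled_weights, offset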
@@ -234,12 +237,33 @@ def fold_batch_norms(input_graph_def):
   nodes_to_skip = {}
   new_ops = []
   for node in input_graph_def.node:
-    if node.op not in ("BatchNormWithGlobalNormalization", "FusedBatchNorm"):
+    if (node.op not in ("BatchNormWithGlobalNormalization",
+                        "FusedBatchNorm", "FusedBatchNormV3")):
       continue

-    conv_op = node_from_map(input_node_map,
-                            node.input[INPUT_ORDER[node.op].index("conv_op")])
-    if conv_op.op != "Conv2D" and conv_op.op != "DepthwiseConv2dNative":
+    bias = None
+    conv_op = node_from_map(
+        input_node_map,
+        node.input[INPUT_ORDER[node.op].index("conv_op")])
+    # There might be an Add/BiasAdd op between the conv and the batchnorm,
+    # which we can fold into the mean param of the batchnorm.
+    if conv_op.op in ['BiasAdd', 'Add', 'AddV2']:
+      add_op = conv_op
+      # Follow the first input of the add to get to the conv.
+      conv_op = node_from_map(
+          input_node_map, add_op.input[0])
+      bias = node_from_map(input_node_map, add_op.input[1])
+      if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
+        # Follow the second input of the add to get to the conv.
+        conv_op = node_from_map(
+            input_node_map, add_op.input[1])
+        bias = node_from_map(input_node_map, add_op.input[0])
+      if bias and bias.op != 'Const':
+        tf_logging.warning("The bias %s after the conv %s was not a constant. "
+                           "Maybe because freeze_graph wasn't "
+                           "run first?" % (bias.name, conv_op.name))
+        continue
+    if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
       tf_logging.warning("Didn't find expected Conv2D or DepthwiseConv2dNative"
                          " input to '%s'" % node.name)
       continue
@@ -264,6 +288,10 @@ def fold_batch_norms(input_graph_def):
                          " run first?" % (node.name, mean_op))
       continue
     mean_value = values_from_const(mean_op)
+    if bias is not None:
+      # Adjust the mean of the batchnorm based on the add op in-between the
+      # conv and the batchnorm.
+      mean_value = mean_value - values_from_const(bias)
     if mean_value.shape != (channel_count,):
       tf_logging.warning("Incorrect shape for mean, found %s, expected %s,"
                          " for node %s" % (str(mean_value.shape), str(
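The bias adjustment above works because a constant add before the normalization can be absorbed into the mean: gamma * ((x + b) - mean) / sqrt(var + eps) + beta is identical to the same expression with mean replaced by (mean - b). A quick NumPy check of that identity (values are arbitrary and purely illustrative):

    import numpy as np

    rng = np.random.default_rng(0)
    x, b = rng.normal(size=(4, 2)), rng.normal(size=2)
    gamma, beta = rng.normal(size=2), rng.normal(size=2)
    mean, var, eps = rng.normal(size=2), rng.uniform(0.1, 1.0, size=2), 1e-3

    bn = lambda t, m: gamma * (t - m) / np.sqrt(var + eps) + beta
    # Folding the bias into the mean gives the same result.
    assert np.allclose(bn(x + b, mean), bn(x, mean - b))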
@@ -315,11 +343,9 @@ def fold_batch_norms(input_graph_def):
     variance_epsilon_value = node.attr[EPSILON_ATTR[node.op]].f
     nodes_to_skip[node.name] = True
     nodes_to_skip[weights_op.name] = True
-    nodes_to_skip[mean_op.name] = True
-    nodes_to_skip[var_op.name] = True
-    nodes_to_skip[beta_op.name] = True
-    nodes_to_skip[gamma_op.name] = True
     nodes_to_skip[conv_op.name] = True
+    if bias is not None:
+      nodes_to_skip[add_op.name] = True

     if scale_after_normalization(node):
       scale_value = (
@@ -346,11 +372,16 @@ def fold_batch_norms(input_graph_def):
       it.iternext()
     scaled_weights_op = node_def_pb2.NodeDef()
     scaled_weights_op.op = "Const"
-    scaled_weights_op.name = weights_op.name
+    scaled_weights_op.name = conv_op.name + '_weights'
     scaled_weights_op.attr["dtype"].CopyFrom(weights_op.attr["dtype"])
     scaled_weights_op.attr["value"].CopyFrom(
         attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
             scaled_weights, weights.dtype.type, weights.shape)))
+    # Replace the weights node with the scaled weights node.
+    for i, weights_node in enumerate(conv_op.input):
+      if weights_node == weights_op.name:
+        conv_op.input[i] = scaled_weights_op.name
+
     new_conv_op = node_def_pb2.NodeDef()
     new_conv_op.CopyFrom(conv_op)
     offset_op = node_def_pb2.NodeDef()
@@ -374,9 +405,16 @@ def fold_batch_norms(input_graph_def):
       continue
     new_node = node_def_pb2.NodeDef()
     new_node.CopyFrom(node)
+    retained_input = []
+    for input_node in new_node.input:
+      if not input_node.startswith('^') or input_node[1:] not in nodes_to_skip:
+        retained_input.append(input_node)
+    new_node.input[:] = retained_input
+
     result_graph_def.node.extend([new_node])

   result_graph_def.node.extend(new_ops)
+  result_graph_def.versions.CopyFrom(input_graph_def.versions)
   return result_graph_def
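In a GraphDef, an input of the form "^name" is a control dependency rather than a data edge. The retained_input loop above drops control edges that point at nodes the pass is deleting, so the rewritten graph never references a removed node. A small illustration with hypothetical node names:

    nodes_to_skip = {"conv/batchnorm": True}
    inputs = ["conv/weights", "^conv/batchnorm", "^keep_me"]
    retained = [name for name in inputs
                if not name.startswith("^") or name[1:] not in nodes_to_skip]
    # retained == ["conv/weights", "^keep_me"]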
@@ -233,6 +233,67 @@ class OptimizeForInferenceTest(test.TestCase):
       for node in optimized_graph_def.node:
         self.assertNotEqual("FusedBatchNorm", node.op)

+  @test_util.run_deprecated_v1
+  def testFoldFusedBatchNormsV3(self):
+    for data_format, conv2d_func in [
+        ("NHWC", nn_ops.conv2d), ("NCHW", nn_ops.conv2d),
+        ("NHWC", nn_ops.depthwise_conv2d_native),
+        ("NCHW", nn_ops.depthwise_conv2d_native)
+    ]:
+      with self.cached_session() as sess:
+        inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
+        input_op = constant_op.constant(
+            np.array(inputs),
+            shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6],
+            dtype=dtypes.float32)
+        if conv2d_func == nn_ops.conv2d:
+          weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
+          weights_op = constant_op.constant(
+              np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
+        else:
+          weights = [1, 2, 0.3, 0.4]
+          weights_op = constant_op.constant(
+              np.array(weights), shape=[1, 2, 2, 1], dtype=dtypes.float32)
+        mean_op = constant_op.constant(
+            np.array([10, 20]), shape=[2], dtype=dtypes.float32)
+        variance_op = constant_op.constant(
+            np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
+        beta_op = constant_op.constant(
+            np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
+        gamma_op = constant_op.constant(
+            np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
+        ops.get_default_graph().graph_def_versions.producer = 9
+        conv_op = conv2d_func(
+            input_op,
+            weights_op, [1, 1, 1, 1],
+            padding="SAME",
+            data_format=data_format,
+            name="conv_op")
+        gen_nn_ops.fused_batch_norm_v3(
+            conv_op,
+            gamma_op,
+            beta_op,
+            mean_op,
+            variance_op,
+            0.00001,
+            is_training=False,
+            data_format=data_format,
+            name="output")
+        original_graph_def = sess.graph_def
+        original_result = sess.run(["output:0"])
+      optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
+          original_graph_def)
+      with self.cached_session() as sess:
+        _ = importer.import_graph_def(
+            optimized_graph_def, input_map={}, name="optimized")
+        optimized_result = sess.run(["optimized/output:0"])
+
+      self.assertAllClose(
+          original_result, optimized_result, rtol=1e-04, atol=1e-06)
+
+      for node in optimized_graph_def.node:
+        self.assertNotEqual("FusedBatchNormV3", node.op)
+
   @test_util.run_deprecated_v1
   def testFuseResizePadAndConv(self):
     with self.cached_session() as sess:
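Outside the unit test, the same fold can be applied to any frozen GraphDef. A minimal usage sketch (file paths are placeholders, and the graph is assumed to have been frozen so the batch-norm parameters are constants):

    from tensorflow.core.framework import graph_pb2
    from tensorflow.python.tools import optimize_for_inference_lib

    graph_def = graph_pb2.GraphDef()
    with open("frozen_model.pb", "rb") as f:         # hypothetical input path
      graph_def.ParseFromString(f.read())

    folded = optimize_for_inference_lib.fold_batch_norms(graph_def)

    with open("frozen_model_folded.pb", "wb") as f:  # hypothetical output path
      f.write(folded.SerializeToString())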