optimize_for_inference_lib.fold_batch_norms() preserves data_format (#16075)

Fixes https://github.com/tensorflow/tensorflow/issues/15034
This commit is contained in:
yegord 2018-02-01 00:02:25 +01:00 committed by Rasmus Munk Larsen
parent c24e3dd451
commit 6afe900f54
2 changed files with 45 additions and 39 deletions

View File

@ -349,6 +349,7 @@ def fold_batch_norms(input_graph_def):
bias_add_op.op = "BiasAdd"
bias_add_op.name = node.name
bias_add_op.attr["T"].CopyFrom(conv_op.attr["T"])
bias_add_op.attr["data_format"].CopyFrom(conv_op.attr["data_format"])
bias_add_op.input.extend([new_conv_op.name, offset_op.name])
new_ops.extend([scaled_weights_op, new_conv_op, offset_op, bias_add_op])

View File

@ -173,48 +173,53 @@ class OptimizeForInferenceTest(test.TestCase):
self.assertNotEqual("BatchNormWithGlobalNormalization", node.op)
def testFoldFusedBatchNorms(self):
  """Checks that fold_batch_norms() folds a FusedBatchNorm after a Conv2D.

  For each (data_format, use_gpu) pair the test builds a small
  Conv2D -> FusedBatchNorm graph, runs it, folds the batch norm with
  optimize_for_inference_lib.fold_batch_norms(), re-imports the optimized
  graph under the "optimized" name scope and asserts that:

    * the optimized output is close to the original output, and
    * no node with op "FusedBatchNorm" remains in the optimized graph.

  Running the NCHW layout as well guards the data_format attribute being
  copied onto the generated BiasAdd node.
  """
  # NOTE(review): the NCHW case requests a GPU session; confirm how this
  # behaves on runners without a GPU.
  for data_format, use_gpu in [("NHWC", False), ("NCHW", True)]:
    with self.test_session(use_gpu=use_gpu) as sess:
      inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
      # Same 12 values in both layouts; only the declared shape differs.
      input_op = constant_op.constant(
          np.array(inputs),
          shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6],
          dtype=dtypes.float32)
      weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
      weights_op = constant_op.constant(
          np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
      conv_op = nn_ops.conv2d(
          input_op, weights_op, [1, 1, 1, 1], padding="SAME",
          data_format=data_format, name="conv_op")
      mean_op = constant_op.constant(
          np.array([10, 20]), shape=[2], dtype=dtypes.float32)
      variance_op = constant_op.constant(
          np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
      beta_op = constant_op.constant(
          np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
      gamma_op = constant_op.constant(
          np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
      # Pin the graph producer version before creating the batch norm node.
      ops.get_default_graph().graph_def_versions.producer = 9
      gen_nn_ops._fused_batch_norm(
          conv_op,
          gamma_op,
          beta_op,
          mean_op,
          variance_op,
          0.00001,
          is_training=False,
          data_format=data_format,
          name="output")
      original_graph_def = sess.graph_def
      original_result = sess.run(["output:0"])
    optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
        original_graph_def)

    # Evaluate the folded graph in a fresh session with the same device
    # setting so both results come from the same kind of placement.
    with self.test_session(use_gpu=use_gpu) as sess:
      _ = importer.import_graph_def(
          optimized_graph_def, input_map={}, name="optimized")
      optimized_result = sess.run(["optimized/output:0"])

    self.assertAllClose(
        original_result, optimized_result, rtol=1e-04, atol=1e-06)

    for node in optimized_graph_def.node:
      self.assertNotEqual("FusedBatchNorm", node.op)
def testFuseResizePadAndConv(self):
with self.test_session() as sess: