optimize_for_inference_lib.fold_batch_norms() preserves data_format (#16075)
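Fixes https://github.com/tensorflow/tensorflow/issues/15034: the BiasAdd node that replaces a folded batch norm now copies the convolution's data_format attribute.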
This commit is contained in:
parent
c24e3dd451
commit
6afe900f54
@ -349,6 +349,7 @@ def fold_batch_norms(input_graph_def):
|
||||
bias_add_op.op = "BiasAdd"
|
||||
bias_add_op.name = node.name
|
||||
bias_add_op.attr["T"].CopyFrom(conv_op.attr["T"])
|
||||
bias_add_op.attr["data_format"].CopyFrom(conv_op.attr["data_format"])
|
||||
bias_add_op.input.extend([new_conv_op.name, offset_op.name])
|
||||
new_ops.extend([scaled_weights_op, new_conv_op, offset_op, bias_add_op])
|
||||
|
||||
|
@ -173,48 +173,53 @@ class OptimizeForInferenceTest(test.TestCase):
|
||||
self.assertNotEqual("BatchNormWithGlobalNormalization", node.op)
|
||||
|
||||
def testFoldFusedBatchNorms(self):
|
||||
with self.test_session() as sess:
|
||||
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
|
||||
input_op = constant_op.constant(
|
||||
np.array(inputs), shape=[1, 1, 6, 2], dtype=dtypes.float32)
|
||||
weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
|
||||
weights_op = constant_op.constant(
|
||||
np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
|
||||
conv_op = nn_ops.conv2d(
|
||||
input_op, weights_op, [1, 1, 1, 1], padding="SAME", name="conv_op")
|
||||
mean_op = constant_op.constant(
|
||||
np.array([10, 20]), shape=[2], dtype=dtypes.float32)
|
||||
variance_op = constant_op.constant(
|
||||
np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
|
||||
beta_op = constant_op.constant(
|
||||
np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
|
||||
gamma_op = constant_op.constant(
|
||||
np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
|
||||
ops.get_default_graph().graph_def_versions.producer = 9
|
||||
gen_nn_ops._fused_batch_norm(
|
||||
conv_op,
|
||||
gamma_op,
|
||||
beta_op,
|
||||
mean_op,
|
||||
variance_op,
|
||||
0.00001,
|
||||
is_training=False,
|
||||
name="output")
|
||||
original_graph_def = sess.graph_def
|
||||
original_result = sess.run(["output:0"])
|
||||
optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
|
||||
original_graph_def)
|
||||
for data_format, use_gpu in [("NHWC", False), ("NCHW", True)]:
|
||||
with self.test_session(use_gpu=use_gpu) as sess:
|
||||
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
|
||||
input_op = constant_op.constant(
|
||||
np.array(inputs),
|
||||
shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6],
|
||||
dtype=dtypes.float32)
|
||||
weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
|
||||
weights_op = constant_op.constant(
|
||||
np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
|
||||
conv_op = nn_ops.conv2d(
|
||||
input_op, weights_op, [1, 1, 1, 1], padding="SAME",
|
||||
data_format=data_format, name="conv_op")
|
||||
mean_op = constant_op.constant(
|
||||
np.array([10, 20]), shape=[2], dtype=dtypes.float32)
|
||||
variance_op = constant_op.constant(
|
||||
np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
|
||||
beta_op = constant_op.constant(
|
||||
np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
|
||||
gamma_op = constant_op.constant(
|
||||
np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
|
||||
ops.get_default_graph().graph_def_versions.producer = 9
|
||||
gen_nn_ops._fused_batch_norm(
|
||||
conv_op,
|
||||
gamma_op,
|
||||
beta_op,
|
||||
mean_op,
|
||||
variance_op,
|
||||
0.00001,
|
||||
is_training=False,
|
||||
data_format=data_format,
|
||||
name="output")
|
||||
original_graph_def = sess.graph_def
|
||||
original_result = sess.run(["output:0"])
|
||||
optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
|
||||
original_graph_def)
|
||||
|
||||
with self.test_session() as sess:
|
||||
_ = importer.import_graph_def(
|
||||
optimized_graph_def, input_map={}, name="optimized")
|
||||
optimized_result = sess.run(["optimized/output:0"])
|
||||
with self.test_session(use_gpu=use_gpu) as sess:
|
||||
_ = importer.import_graph_def(
|
||||
optimized_graph_def, input_map={}, name="optimized")
|
||||
optimized_result = sess.run(["optimized/output:0"])
|
||||
|
||||
self.assertAllClose(
|
||||
original_result, optimized_result, rtol=1e-04, atol=1e-06)
|
||||
self.assertAllClose(
|
||||
original_result, optimized_result, rtol=1e-04, atol=1e-06)
|
||||
|
||||
for node in optimized_graph_def.node:
|
||||
self.assertNotEqual("FusedBatchNorm", node.op)
|
||||
for node in optimized_graph_def.node:
|
||||
self.assertNotEqual("FusedBatchNorm", node.op)
|
||||
|
||||
def testFuseResizePadAndConv(self):
|
||||
with self.test_session() as sess:
|
||||
|
Loading…
Reference in New Issue
Block a user