Extends the fold_batch_norms transform to also fold the Mul introduced by batch normalization after fully connected layers (MatMul).

Change: 148868461
A. Unique TensorFlower 2017-03-01 02:29:36 -08:00 committed by TensorFlower Gardener
parent 7e48bada5a
commit ec86b03789
3 changed files with 75 additions and 16 deletions

tensorflow/tools/graph_transforms/README.md

@@ -341,12 +341,13 @@ Args: None \
 Prerequisites: [fold_constants](#fold_constants)
 
 This transform tries to optimize away the Mul that's introduced after a Conv2D
-when batch normalization has been used during training. It scans the graph for
-any channel-wise multiplies immediately after convolutions, and multiplies the
-convolution's weights with the Mul instead so this can be omitted at inference
-time. You'll need to make sure you run [fold_constants](#fold_constants) first,
-since the pattern can only be spotted if the normal complex expression that's
-produced by training for the Mul input is collapsed down into a simple constant.
+(or a MatMul) when batch normalization has been used during training. It scans
+the graph for any channel-wise multiplies immediately after convolutions, and
+multiplies the convolution's (or matrix multiplication's) weights with the Mul
+instead so this can be omitted at inference time. You'll need to make sure you
+run [fold_constants](#fold_constants) first, since the pattern can only be
+spotted if the normal complex expression that's produced by training for the Mul
+input is collapsed down into a simple constant.
 
 ### fold_constants

tensorflow/tools/graph_transforms/fold_batch_norms.cc

@@ -27,23 +27,24 @@ limitations under the License.
 namespace tensorflow {
 namespace graph_transforms {
 
-// Converts Conv2D ops followed by column-wise Muls into equivalent ops with the
-// Mul baked into the convolution weights, to save computation during inference.
+// Converts Conv2D or MatMul ops followed by column-wise Muls into equivalent
+// ops with the Mul baked into the convolution weights, to save computation
+// during inference.
 Status FoldBatchNorms(const GraphDef& input_graph_def,
                       const TransformFuncContext& context,
                       GraphDef* output_graph_def) {
   GraphDef replaced_graph_def;
   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
       input_graph_def,  // clang-format off
-      {"Mul",                // mul_node
+      {"Mul",                      // mul_node
        {
-        {"Conv2D",           // conv_node
+        {"Conv2D|MatMul",          // conv_node
          {
-          {"*"},             // input_node
-          {"Const"},         // weights_node
+          {"*"},                   // input_node
+          {"Const"},               // weights_node
          }
        },
-        {"Const"},           // mul_values_node
+        {"Const"},                 // mul_values_node
       }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
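The rewrite function whose parameter list appears above does the actual folding in its body (elided between these hunks): each output-channel column of the constant weights is multiplied by the matching mul_values entry. A minimal standalone sketch of that column scaling, assuming a row-major float matrix (an illustration, not the verbatim implementation):

    #include <cstdint>

    // Bake a per-output-channel scale into a [rows, cols] row-major weights
    // matrix, so the downstream channel-wise Mul can be removed.
    void ScaleWeightsColumns(float* weights, int64_t rows, int64_t cols,
                             const float* scale) {
      for (int64_t row = 0; row < rows; ++row) {
        for (int64_t col = 0; col < cols; ++col) {
          weights[row * cols + col] *= scale[col];
        }
      }
    }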
@@ -61,7 +62,8 @@ Status FoldBatchNorms(const GraphDef& input_graph_def,
         // Make sure all the inputs really are vectors, with as many entries as
         // there are columns in the weights.
-        const int64 weights_cols = weights.shape().dim_size(3);
+        const int weights_cols_index = conv_node.op() == "Conv2D" ? 3 : 1;
+        const int64 weights_cols = weights.shape().dim_size(weights_cols_index);
         if ((mul_values.shape().dims() != 1) ||
             (mul_values.shape().dim_size(0) != weights_cols)) {
           return errors::InvalidArgument(
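The ternary above encodes the one place the two cases differ: where the output-channel count lives in the weights shape. A Conv2D filter is laid out [filter_height, filter_width, in_channels, out_channels], while a MatMul weights matrix is [in_size, out_size]. A hypothetical helper (ours, assuming the MatMul's default transpose_b == false) making that explicit:

    #include <cstdint>
    #include <vector>

    // Output-channel count of a weights tensor.
    //   Conv2D filter:  [filter_height, filter_width, in_channels, out_channels]
    //   MatMul weights: [in_size, out_size]  (transpose_b assumed false)
    int64_t OutputChannels(const std::vector<int64_t>& shape, bool is_conv2d) {
      return is_conv2d ? shape[3] : shape[1];
    }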

tensorflow/tools/graph_transforms/fold_batch_norms_test.cc

@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/cc/ops/const_op.h"
 #include "tensorflow/cc/ops/image_ops.h"
 #include "tensorflow/cc/ops/math_ops.h"
+#include "tensorflow/cc/ops/nn_ops.h"
 #include "tensorflow/cc/ops/sendrecv_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
@@ -35,7 +36,7 @@ Status FoldBatchNorms(const GraphDef& input_graph_def,
 
 class FoldBatchNormsTest : public ::testing::Test {
  protected:
-  void TestFoldBatchNorms() {
+  void TestFoldBatchNormsConv2D() {
     auto root = tensorflow::Scope::NewRootScope();
     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
@@ -85,9 +86,64 @@ class FoldBatchNormsTest : public ::testing::Test {
       EXPECT_NE("Mul", node.op());
     }
   }
+
+  void TestFoldBatchNormsMatMul() {
+    auto root = tensorflow::Scope::NewRootScope();
+    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
+
+    Tensor input_data(DT_FLOAT, TensorShape({6, 2}));
+    test::FillValues<float>(
+        &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
+                      -5.0f, -3.0f, -6.0f});
+    Output input_op =
+        Const(root.WithOpName("input_op"), Input::Initializer(input_data));
+
+    Tensor weights_data(DT_FLOAT, TensorShape({2, 2}));
+    test::FillValues<float>(&weights_data, {1.0f, 2.0f, 0.3f, 0.4f});
+    Output weights_op =
+        Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
+
+    Output matmul_op =
+        MatMul(root.WithOpName("matmul_op"), input_op, weights_op);
+
+    Tensor mul_values_data(DT_FLOAT, TensorShape({2}));
+    test::FillValues<float>(&mul_values_data, {2.0f, 3.0f});
+    Output mul_values_op = Const(root.WithOpName("mul_values"),
+                                 Input::Initializer(mul_values_data));
+
+    Output mul_op = Mul(root.WithOpName("output"), matmul_op, mul_values_op);
+
+    GraphDef original_graph_def;
+    TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
+
+    std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+    TF_ASSERT_OK(original_session->Create(original_graph_def));
+    std::vector<Tensor> original_outputs;
+    TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
+
+    GraphDef fused_graph_def;
+    TF_ASSERT_OK(
+        FoldBatchNorms(original_graph_def, {{}, {"output"}}, &fused_graph_def));
+
+    std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
+    TF_ASSERT_OK(fused_session->Create(fused_graph_def));
+    std::vector<Tensor> fused_outputs;
+    TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
+
+    test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
+
+    for (const NodeDef& node : fused_graph_def.node()) {
+      EXPECT_NE("Mul", node.op());
+    }
+  }
 };
 
-TEST_F(FoldBatchNormsTest, TestFoldBatchNorms) { TestFoldBatchNorms(); }
+TEST_F(FoldBatchNormsTest, TestFoldBatchNormsConv2D) {
+  TestFoldBatchNormsConv2D();
+}
+
+TEST_F(FoldBatchNormsTest, TestFoldBatchNormsMatMul) {
+  TestFoldBatchNormsMatMul();
+}
 
 }  // namespace graph_transforms
 }  // namespace tensorflow
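As a quick check of the arithmetic the new MatMul test relies on, here is a standalone sketch (plain C++, no TensorFlow dependency) that recomputes both paths on the same values as TestFoldBatchNormsMatMul: the MatMul followed by the channel-wise Mul, and the MatMul against pre-scaled weights. The two agree exactly, well within the test's 1e-5 tolerance:

    #include <cmath>
    #include <cstdio>

    int main() {
      // Same values as TestFoldBatchNormsMatMul.
      const float x[6][2] = {{1, 4},   {2, 5},   {3, 6},
                             {-1, -4}, {-2, -5}, {-3, -6}};
      const float w[2][2] = {{1.0f, 2.0f}, {0.3f, 0.4f}};
      const float s[2] = {2.0f, 3.0f};
      for (int r = 0; r < 6; ++r) {
        for (int c = 0; c < 2; ++c) {
          const float matmul = x[r][0] * w[0][c] + x[r][1] * w[1][c];
          const float original = matmul * s[c];  // Mul applied after the MatMul.
          const float folded =  // Mul baked into the weights columns instead.
              x[r][0] * (w[0][c] * s[c]) + x[r][1] * (w[1][c] * s[c]);
          if (std::fabs(original - folded) > 1e-5f) {
            std::printf("mismatch at (%d, %d)\n", r, c);
            return 1;
          }
        }
      }
      std::printf("fold matches the unfused graph on all outputs\n");
      return 0;
    }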