Fuse BN and Relu in mkl path

parent 5992e75800
commit d4d23502bf
@@ -268,6 +268,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.dequantize = "Dequantize";
     csinfo_.fused_batch_norm = "FusedBatchNorm";
     csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
+    csinfo_.fused_batch_norm_ex = "_FusedBatchNormEx";
     csinfo_.fused_batch_norm_v2 = "FusedBatchNormV2";
     csinfo_.fused_batch_norm_grad_v2 = "FusedBatchNormGradV2";
     csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
@@ -294,6 +295,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
         "_MklDepthwiseConv2dNativeBackpropInput";
     csinfo_.mkl_depthwise_conv2d_grad_filter =
         "_MklDepthwiseConv2dNativeBackpropFilter";
+    csinfo_.mkl_fused_batch_norm_ex = "_MklFusedBatchNormEx";
     csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
     csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
     csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
@@ -476,6 +478,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                       {csinfo_.fused_batch_norm_grad_v3,
                        mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
                        CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
+#ifdef ENABLE_MKLDNN_V1
+    rinfo_.push_back({csinfo_.fused_batch_norm_ex,
+                      csinfo_.mkl_fused_batch_norm_ex, CopyAttrsAll,
+                      FusedBatchNormExRewrite, kRewriteForLayoutPropagation});
+#endif
     rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
                       CopyAttrsFusedConv2D, FusedConv2DRewrite,
                       kRewriteForLayoutPropagation});
@@ -920,6 +927,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string dequantize;
     string fused_batch_norm;
     string fused_batch_norm_grad;
+    string fused_batch_norm_ex;
     string fused_batch_norm_v2;
     string fused_batch_norm_grad_v2;
     string fused_batch_norm_v3;
@@ -944,6 +952,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    string mkl_conv2d_with_bias;
    string mkl_depthwise_conv2d_grad_input;
    string mkl_depthwise_conv2d_grad_filter;
+   string mkl_fused_batch_norm_ex;
    string mkl_fused_conv2d;
    string mkl_fused_matmul;
    string mkl_pad_with_conv2d;
@@ -1652,6 +1661,31 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     return do_rewrite;
   }

+  static bool FusedBatchNormExRewrite(const Node* n) {
+    CHECK_NOTNULL(n);
+
+    int num_side_inputs;
+    TF_CHECK_OK(GetNodeAttr(n->def(), "num_side_inputs", &num_side_inputs));
+    string activation_mode;
+    TF_CHECK_OK(GetNodeAttr(n->def(), "activation_mode", &activation_mode));
+
+    // If num_side_inputs is not 0, don't rewrite the node.
+    if (num_side_inputs != 0) {
+      VLOG(1) << "FusedBatchNormExRewrite: num_side_inputs larger than 0 "
+              << "is not optimized by Intel MKL.";
+      return false;
+    }
+
+    // If the activation_mode is not 'Relu', don't rewrite the node.
+    if (activation_mode != "Relu") {
+      VLOG(1) << "FusedBatchNormExRewrite: Only Relu activation mode is "
+              << "supported by Intel MKL.";
+      return false;
+    }
+
+    return true;
+  }
+
   static bool FusedConv2DRewrite(const Node* n) {
     // MKL DNN currently doesn't support all fusions that grappler fuses
     // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
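This predicate is the entire gate for the new rewrite. As a minimal, self-contained sketch of the same decision table (plain structs stand in for `Node`/`NodeDef`; the names here are illustrative, not TensorFlow APIs):

```cpp
#include <string>

// Attribute values read off a _FusedBatchNormEx node.
struct FusedBatchNormExAttrs {
  int num_side_inputs;
  std::string activation_mode;
};

// Mirrors FusedBatchNormExRewrite: rewrite to _MklFusedBatchNormEx only
// when there is no side input and the fused activation is Relu.
bool ShouldRewriteToMkl(const FusedBatchNormExAttrs& attrs) {
  if (attrs.num_side_inputs != 0) return false;       // side inputs unsupported
  if (attrs.activation_mode != "Relu") return false;  // only BN+Relu fuses
  return true;
}

// ShouldRewriteToMkl({0, "Relu"})     -> true  (rewritten)
// ShouldRewriteToMkl({1, "Relu"})     -> false (side input present)
// ShouldRewriteToMkl({0, "Identity"}) -> false (activation not Relu)
```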
@@ -2131,9 +2165,6 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
   // Number of input slots to original op
   // Input slots are represented by .Input() calls in REGISTER_OP.
   int old_node_input_slots = old_node->op_def().input_arg_size();
-  // Actual number of inputs can be greater than or equal to number
-  // of Input slots because inputs of type list could be unfolded.
-  CHECK_GE(old_node_inputs.size(), old_node_input_slots);
   int nn_slot_idx = 0;  // slot index for inputs of new node

  // Let's copy all inputs (TF tensors) of original node to new node.
@@ -2141,13 +2172,14 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
   for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
     // An input slot could be a single tensor or a list. We need
     // to handle this case accordingly.
-    CHECK_LT(iidx, old_node_inputs.size());
     const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
     if (ArgIsList(arg)) {
       std::vector<NodeBuilder::NodeOut> new_node_inputs;
-      int N = GetTensorListLength(arg, old_node);
-      GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
-                                    &new_node_inputs);
+      int tensor_list_length = GetTensorListLength(arg, old_node);
+      if (tensor_list_length != 0) {
+        GetNodesProducingTFTensorList(old_node_inputs, &iidx,
+                                      tensor_list_length, &new_node_inputs);
+      }
       nb->Input(new_node_inputs);
       nn_slot_idx++;
     } else {
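The guard matters because `_FusedBatchNormEx` declares a list input (`side_input: num_side_inputs * T`) that is usually empty, so the unfolded input count can now be smaller than the slot count and an empty list contributes no producer nodes. A self-contained sketch of the arithmetic (illustrative names, not the pass's API):

```cpp
#include <vector>

// Each slot is either a single tensor (length 1) or a list (length >= 0).
int UnfoldedInputCount(const std::vector<int>& slot_list_lengths) {
  int count = 0;
  for (int len : slot_list_lengths) count += len;
  return count;
}

// _FusedBatchNormEx with num_side_inputs = 0 has six input slots
// (x, scale, offset, mean, variance, side_input) but only five real
// inputs, so the old CHECK_GE(inputs, slots) would have fired:
//   UnfoldedInputCount({1, 1, 1, 1, 1, 0}) == 5  // < 6 slots
```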
@@ -2180,13 +2212,14 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
   for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
     // An input slot could be a single tensor or a list. We need
     // to handle this case accordingly.
-    CHECK_LT(iidx, old_node_inputs.size());
     const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
     if (ArgIsList(arg)) {
       std::vector<NodeBuilder::NodeOut> new_node_inputs;
-      int N = GetTensorListLength(arg, old_node);
-      GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
-                                     &new_node_inputs);
+      int tensor_list_length = GetTensorListLength(arg, old_node);
+      if (tensor_list_length != 0) {
+        GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
+                                       tensor_list_length, &new_node_inputs);
+      }
       nb->Input(new_node_inputs);
       nn_slot_idx++;
     } else {
@@ -3702,6 +3735,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
       n->type_string() != csinfo_.pad_with_conv2d &&
       n->type_string() != csinfo_.pad_with_fused_conv2d &&
       n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
+      n->type_string() != csinfo_.fused_batch_norm_ex &&
       n->type_string() != csinfo_.fused_conv2d &&
       n->type_string() != csinfo_.fused_matmul &&
       !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
@@ -3108,6 +3108,112 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) {
             "B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");
 }

+#ifdef ENABLE_MKLDNN_V1
+#define REGISTER_TEST(NAME, T, INPUT)                                         \
+  TEST_F(MklLayoutPassTest, NAME##_##T) {                                     \
+    InitGraph("node { name: 'A' op: '" #INPUT                                 \
+              "'}"                                                            \
+              "node { name: 'B' op: 'Input'}"                                 \
+              "node { name: 'C' op: 'Input'}"                                 \
+              "node { name: 'D' op: 'Input'}"                                 \
+              "node { name: 'E' op: 'Input'}"                                 \
+              "node { name: 'F' op: '_FusedBatchNormEx'"                      \
+              " attr { key: 'T' value { type: " #T                            \
+              " } }"                                                          \
+              " attr { key: 'U' value { type: DT_FLOAT } }"                   \
+              " attr { key: 'data_format' value { s: 'NCHW' } }"              \
+              " attr { key: 'epsilon' value { f: 0.0001 } }"                  \
+              " attr { key: 'num_side_inputs' value { i: 0 } }"               \
+              " attr { key: 'is_training' value { b: true } }"                \
+              " attr { key: 'activation_mode' value { s: 'Relu' } }"          \
+              " input: ['A', 'B', 'C', 'D', 'E'] }"                           \
+              "node { name: 'G' op: 'Zeta'"                                   \
+              " attr { key: 'T' value { type: " #T                            \
+              " } }"                                                          \
+              " input: ['A', 'F'] }");                                        \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(),                                  \
+              "A(" #INPUT                                                     \
+              ");B(Input);C(Input);D(Input);"                                 \
+              "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);"      \
+              "DMT/_4(Const);E(Input);"                                       \
+              "F(_MklFusedBatchNormEx);G(Zeta)|A->F;A->G;"                    \
+              "A:control->DMT/_0:control;A:control->DMT/_1:control;"          \
+              "A:control->DMT/_2:control;A:control->DMT/_3:control;"          \
+              "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"               \
+              "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"  \
+              "E->F:4;F->G:1");                                               \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Positive);
+#undef REGISTER_TEST
+
+// Rewrite test for _FusedBatchNormEx Op with side input
+#define REGISTER_TEST(NAME, T, INPUT)                                         \
+  TEST_F(MklLayoutPassTest, NAME##_##T) {                                     \
+    InitGraph("node { name: 'A' op: '" #INPUT                                 \
+              "'}"                                                            \
+              "node { name: 'B' op: 'Input'}"                                 \
+              "node { name: 'C' op: 'Input'}"                                 \
+              "node { name: 'D' op: 'Input'}"                                 \
+              "node { name: 'E' op: 'Input'}"                                 \
+              "node { name: 'F' op: '" #INPUT                                 \
+              "'}"                                                            \
+              "node { name: 'G' op: '_FusedBatchNormEx'"                      \
+              " attr { key: 'T' value { type: " #T                            \
+              " } }"                                                          \
+              " attr { key: 'U' value { type: DT_FLOAT } }"                   \
+              " attr { key: 'data_format' value { s: 'NCHW' } }"              \
+              " attr { key: 'epsilon' value { f: 0.0001 } }"                  \
+              " attr { key: 'num_side_inputs' value { i: 1 } }"               \
+              " attr { key: 'is_training' value { b: true } }"                \
+              " attr { key: 'activation_mode' value { s: 'Relu' } }"          \
+              " input: ['A', 'B', 'C', 'D', 'E', 'F'] }"                      \
+              "node { name: 'H' op: 'Zeta'"                                   \
+              " attr { key: 'T' value { type: " #T                            \
+              " } }"                                                          \
+              " input: ['A', 'G'] }");                                        \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(),                                  \
+              "A(" #INPUT                                                     \
+              ");B(Input);C(Input);D(Input);E(Input);"                        \
+              "F(" #INPUT                                                     \
+              ");G(_FusedBatchNormEx);H(Zeta)|A->G;A->H;"                     \
+              "B->G:1;C->G:2;D->G:3;E->G:4;F->G:5;G->H:1");                   \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative1);
+#undef REGISTER_TEST
+
+// Rewrite test for _FusedBatchNormEx Op with Identity activation
+#define REGISTER_TEST(NAME, T, INPUT)                                         \
+  TEST_F(MklLayoutPassTest, NAME##_##T) {                                     \
+    InitGraph("node { name: 'A' op: '" #INPUT                                 \
+              "'}"                                                            \
+              "node { name: 'B' op: 'Input'}"                                 \
+              "node { name: 'C' op: 'Input'}"                                 \
+              "node { name: 'D' op: 'Input'}"                                 \
+              "node { name: 'E' op: 'Input'}"                                 \
+              "node { name: 'G' op: '_FusedBatchNormEx'"                      \
+              " attr { key: 'T' value { type: " #T                            \
+              " } }"                                                          \
+              " attr { key: 'U' value { type: DT_FLOAT } }"                   \
+              " attr { key: 'data_format' value { s: 'NCHW' } }"              \
+              " attr { key: 'epsilon' value { f: 0.0001 } }"                  \
+              " attr { key: 'num_side_inputs' value { i: 1 } }"               \
+              " attr { key: 'is_training' value { b: true } }"                \
+              " attr { key: 'activation_mode' value { s: 'Identity' } }"      \
+              " input: ['A', 'B', 'C', 'D', 'E'] }"                           \
+              "node { name: 'H' op: 'Zeta'"                                   \
+              " attr { key: 'T' value { type: " #T                            \
+              " } }"                                                          \
+              " input: ['A', 'G'] }");                                        \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(),                                  \
+              "A(" #INPUT                                                     \
+              ");B(Input);C(Input);D(Input);E(Input);"                        \
+              "G(_FusedBatchNormEx);H(Zeta)|A->G;A->H;"                       \
+              "B->G:1;C->G:2;D->G:3;E->G:4;G->H:1");                          \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative2);
+#undef REGISTER_TEST
+#endif  // ENABLE_MKLDNN_V1

 TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedDepthwiseConv2D_Positive) {
   InitGraph(
       "node { name: 'A' op: 'QuantizedUnsignedInt8Input'}"
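A short key for the expected-graph strings in these tests, inferred from the expectations themselves: `N(Op)` lists a node and its type, `X->F:n` means X feeds F's input port n, and `DMT/_k` are the dummy Mkl-metadata `Const` tensors the pass inserts. With the contiguous input ordering used here, the five data tensors of `F(_MklFusedBatchNormEx)` take ports 0-4 and their `uint8` metadata tensors take ports 5-9:

```cpp
// Port layout of the rewritten node in the positive test (illustrative):
//   0: x  (A)        5: mkl_x        (DMT/_0)
//   1: scale (B)     6: mkl_scale    (DMT/_1)
//   2: offset (C)    7: mkl_offset   (DMT/_2)
//   3: mean (D)      8: mkl_mean     (DMT/_3)
//   4: variance (E)  9: mkl_variance (DMT/_4)
```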
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/optimizers/remapper.h"
 #include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"

 namespace tensorflow {
@@ -173,6 +174,178 @@ TEST_F(MklRemapperTest, FuseConv2DWithBiasAndAddNRelu) {
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }

+#ifdef ENABLE_MKLDNN_V1
+TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
+  using ::tensorflow::ops::Placeholder;
+
+  for (bool is_training : {true, false}) {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+    const int num_channels = 24;
+
+    TensorShape channel_shape({num_channels});
+    TensorShape empty_shape({0});
+
+    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
+                             ops::Placeholder::Shape({2, 8, 8, num_channels}));
+    auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_FLOAT);
+    auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
+    auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
+    auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
+    auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
+
+    float epsilon = 0.1f;
+    auto fbn = ops::FusedBatchNormV3(
+        s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
+        ops::FusedBatchNormV3::IsTraining(is_training)
+            .Epsilon(epsilon)
+            .DataFormat("NHWC"));
+    auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
+    auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
+
+    auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
+    auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                             : channel_shape);
+    auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                            : channel_shape);
+
+    GrapplerItem item;
+    item.fetch = {"fetch"};
+    item.feed = {{"input", input_t},
+                 {"scale", scale_t},
+                 {"offset", offset_t},
+                 {"mean", mean_t},
+                 {"var", var_t}};
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+    // Place all nodes on CPU.
+    for (int i = 0; i < item.graph.node_size(); ++i) {
+      item.graph.mutable_node(i)->set_device("/device:CPU:0");
+    }
+
+    Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+    int found = 0;
+    for (const NodeDef& node : output.node()) {
+      if (node.name() == "relu") {
+        EXPECT_EQ(node.op(), "Identity");
+        ASSERT_EQ(node.input_size(), 1);
+        EXPECT_EQ(node.input(0), "fused_batch_norm");
+        found++;
+      }
+      if (node.name() == "fused_batch_norm") {
+        EXPECT_EQ(node.op(), "_FusedBatchNormEx");
+        ASSERT_EQ(node.input_size(), 5);
+        EXPECT_EQ(node.input(0), "input_cast");
+        EXPECT_EQ(node.input(1), "scale");
+        EXPECT_EQ(node.input(2), "offset");
+        EXPECT_EQ(node.input(3), "mean");
+        EXPECT_EQ(node.input(4), "var");
+
+        auto attr = node.attr();
+        EXPECT_EQ(attr["num_side_inputs"].i(), 0);
+        EXPECT_EQ(attr["activation_mode"].s(), "Relu");
+        found++;
+      }
+    }
+    EXPECT_EQ(found, 2);
+  }
+}
+
+TEST_F(MklRemapperTest, FuseBatchNormWithAddAndRelu) {
+  using ::tensorflow::ops::Placeholder;
+
+  for (bool is_training : {true, false}) {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+    const int num_channels = 24;
+
+    TensorShape input_shape({2, 8, 8, num_channels});
+    TensorShape channel_shape({num_channels});
+    TensorShape empty_shape({0});
+
+    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
+                             ops::Placeholder::Shape(input_shape));
+    auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_FLOAT);
+    auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
+    auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
+    auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
+    auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
+    auto side_input = Placeholder(s.WithOpName("side_input"), DT_FLOAT,
+                                  ops::Placeholder::Shape(input_shape));
+    auto side_input_cast =
+        ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_FLOAT);
+
+    float epsilon = 0.1f;
+    auto fbn = ops::FusedBatchNormV3(
+        s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
+        ops::FusedBatchNormV3::IsTraining(is_training)
+            .Epsilon(epsilon)
+            .DataFormat("NHWC"));
+    auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
+    auto relu = ops::Relu(s.WithOpName("relu"), add);
+    auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
+
+    auto input_t = GenerateRandomTensor<DT_FLOAT>(input_shape);
+    auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                             : channel_shape);
+    auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                            : channel_shape);
+    auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
+
+    GrapplerItem item;
+    item.fetch = {"fetch"};
+    item.feed = {{"input", input_t}, {"scale", scale_t},
+                 {"offset", offset_t}, {"mean", mean_t},
+                 {"var", var_t}, {"side_input", side_input_t}};
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+    // Place all nodes on CPU.
+    for (int i = 0; i < item.graph.node_size(); ++i) {
+      item.graph.mutable_node(i)->set_device("/device:CPU:0");
+    }
+
+    Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+    int found = 0;
+    for (const NodeDef& node : output.node()) {
+      if (node.name() == "add") {
+        EXPECT_EQ(node.op(), "Add");
+        ASSERT_EQ(node.input_size(), 2);
+        EXPECT_EQ(node.input(0), "fused_batch_norm");
+        EXPECT_EQ(node.input(1), "side_input_cast");
+        found++;
+      }
+      if (node.name() == "relu") {
+        EXPECT_EQ(node.op(), "Relu");
+        ASSERT_EQ(node.input_size(), 1);
+        EXPECT_EQ(node.input(0), "add");
+        found++;
+      }
+      if (node.name() == "fused_batch_norm") {
+        EXPECT_EQ(node.op(), "FusedBatchNormV3");
+        ASSERT_EQ(node.input_size(), 5);
+        EXPECT_EQ(node.input(0), "input_cast");
+        EXPECT_EQ(node.input(1), "scale");
+        EXPECT_EQ(node.input(2), "offset");
+        EXPECT_EQ(node.input(3), "mean");
+        EXPECT_EQ(node.input(4), "var");
+        found++;
+      }
+    }
+    EXPECT_EQ(found, 3);
+  }
+}
+#endif  // ENABLE_MKLDNN_V1

 }  // namespace grappler
 }  // namespace tensorflow
 #endif  // INTEL_MKL
@@ -741,24 +741,27 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
       [&](const utils::MutableNodeView& fused_batch_norm) -> bool {
         const auto* fused_batch_norm_node_def = fused_batch_norm.node();
         if (!IsFusedBatchNorm(*fused_batch_norm_node_def)) return false;

-        // We fuse FusedBatchNorm only on GPU, because on CPU we fuse it with
-        // contraction (MatMul or Conv2D node).
+        // We fuse FusedBatchNorm on GPU or MKL CPU.
+#ifndef ENABLE_MKLDNN_V1
         if (!NodeIsOnGpu(fused_batch_norm_node_def)) return false;
+#endif

         DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm_node_def, "T");
+#ifndef ENABLE_MKLDNN_V1
         if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false;
+#else
+        if (t_dtype != DT_FLOAT && t_dtype != DT_BFLOAT16) return false;
+#endif

         // Get the FusedBatchNorm training mode.
         bool is_training;
         if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
                 .ok())
           return false;

         // In training mode we rely on cuDNN for computing FusedBatchNorm with
         // side inputs and activation, and it has its own limitations. In
         // inference mode we have a custom CUDA kernel that doesn't have these
         // constraints.
-        if (is_training) {
+        if (is_training && NodeIsOnGpu(fused_batch_norm_node_def)) {
           // cuDNN only supports NHWC data layout.
           string data_format;
           if (!GetNodeAttr(*fused_batch_norm_node_def, kDataFormat, &data_format)
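Condensed view of the gating this hunk introduces. One consequence worth noting: in an MKL-DNN v1 build the relaxed branch applies to GPU nodes as well, so `DT_HALF` no longer passes the dtype check there. A hedged sketch (illustrative helper, not the remapper's actual API):

```cpp
bool FusedBatchNormExCandidate(DataType t, bool node_on_gpu) {
#ifndef ENABLE_MKLDNN_V1
  if (!node_on_gpu) return false;        // non-MKL builds: GPU only
  return t == DT_FLOAT || t == DT_HALF;
#else
  // MKL builds: CPU (or GPU) with float/bfloat16.
  return t == DT_FLOAT || t == DT_BFLOAT16;
#endif
}
```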
@@ -810,6 +813,12 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,

   // Input to a Relu can be an Add node with FusedBatchNorm as one of the inputs
   if (IsAdd(*relu_fanin_0_node_def)) {
+    // Currently there is no CPU implementation for "FusedBatchNorm +
+    // SideInput + <Activation>".
+#ifdef ENABLE_MKLDNN_V1
+    return false;
+#endif
+
     // Check that only Relu node consumes the output of an Add node.
     if (HasControlFaninOrFanout(*relu_fanin_0_node_view) ||
         !HasAtMostOneFanoutAtPort0(*relu_fanin_0_node_view) ||
@@ -881,7 +890,11 @@ void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
   if (fused_batch_norm.op() != "FusedBatchNorm") {
     (*attr)["U"] = src_attr.at("U");
   } else {
+#ifndef ENABLE_MKLDNN_V1
     (*attr)["U"] = src_attr.at("T");
+#else
+    SetAttrValue(DT_FLOAT, &(*attr)["U"]);
+#endif
   }
 }

@@ -8135,7 +8135,14 @@ tf_mkl_kernel_library(
 tf_mkl_kernel_library(
     name = "mkl_fused_batch_norm_op",
     srcs = ["mkl_fused_batch_norm_op.cc"],
-    deps = NN_DEPS + mkl_deps(),
+    hdrs = [
+        "fused_batch_norm_op.h",
+        "no_op.h",
+    ],
+    deps = NN_DEPS + [
+        ":fused_batch_norm_op",
+        ":no_op",
+    ] + mkl_deps(),
 )

 tf_cc_test_mkl(
@@ -14,14 +14,16 @@ limitations under the License.
 ==============================================================================*/
 #ifdef INTEL_MKL
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/fused_batch_norm_op.h"
+#include "tensorflow/core/kernels/no_op.h"
 #include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/mkl_util.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

 #define GET_FLAG(bn_flag) static_cast<int>(BN_FLAGS::bn_flag)
 #define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
@@ -37,11 +39,14 @@ using BatchNormBwdPd = mkldnn::batch_normalization_backward::primitive_desc;
 namespace tensorflow {
 using CPUDevice = Eigen::ThreadPoolDevice;

+using FusedBNActivationMode = functor::FusedBatchNormActivationMode;
+
 struct MklBatchNormFwdParams {
   memory::dims src_dims;
   int depth;
   float eps;
   bool training;
+  FusedBNActivationMode activation_mode;
 #ifndef ENABLE_MKLDNN_V1
   MEMORY_FORMAT src_format;
 #else
@@ -50,14 +55,17 @@ struct MklBatchNormFwdParams {

   MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
 #ifndef ENABLE_MKLDNN_V1
-                        bool training, MEMORY_FORMAT src_format)
+                        bool training, MEMORY_FORMAT src_format,
+                        FusedBNActivationMode activation_mode)
 #else
-                        bool training, memory::desc src_md)
+                        bool training, memory::desc src_md,
+                        FusedBNActivationMode activation_mode)
 #endif  // !ENABLE_MKLDNN_V1
       : src_dims(src_dims),
         depth(depth),
         eps(eps),
         training(training),
+        activation_mode(activation_mode),
 #ifndef ENABLE_MKLDNN_V1
         src_format(src_format) {
 }
@@ -90,7 +98,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
   // mean_data: output data buffer of means
   // variance_data: output data buffer of variances
   void Execute(const T* src_data, const U* weights_data, T* dst_data,
-               U* mean_data, U* variance_data) {
+               U* mean_data, U* variance_data, U* workspace_data) {
     context_.src_mem->set_data_handle(
         static_cast<void*>(const_cast<T*>(src_data)));
     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
@@ -104,6 +112,9 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
       context_.mean_mem->set_data_handle(static_cast<void*>(mean_data));
       context_.variance_mem->set_data_handle(static_cast<void*>(variance_data));
     }
+    if (workspace_data != nullptr) {
+      context_.ws_mem->set_data_handle(workspace_data);
+    }
 #ifdef ENABLE_MKLDNN_V1
     // Execute batch-normalization forward primitives.
     execute_primitives(context_.fwd_primitives, context_.fwd_stream,
@@ -123,6 +134,10 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
       context_.mean_mem->set_data_handle(DummyData);
       context_.variance_mem->set_data_handle(DummyData);
     }
+
+    if (workspace_data != nullptr) {
+      context_.ws_mem->set_data_handle(DummyData);
+    }
   }

   MEMORY_PRIMITIVE_DESC GetDstPd() const { return context_.dst_mem->GET_DESC; }
@@ -158,6 +173,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
     std::shared_ptr<mkldnn::memory> dst_mem;
     std::shared_ptr<mkldnn::memory> mean_mem;
     std::shared_ptr<mkldnn::memory> variance_mem;
+    std::shared_ptr<mkldnn::memory> ws_mem;

     // Forward BatchNorm primitive descriptor.
     std::shared_ptr<BatchNormFwdPd> fwd_pd;
@@ -179,6 +195,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
           dst_mem(nullptr),
           mean_mem(nullptr),
           variance_mem(nullptr),
+          ws_mem(nullptr),
           bn_fwd(nullptr),
           fwd_stream(nullptr) {}
   };
@@ -192,6 +209,9 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
                               : prop_kind::forward_scoring;

 #ifdef ENABLE_MKLDNN_V1
+    if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
+      context_.flags |= GET_FLAG(fuse_norm_relu);
+    }
     // Memory descriptor
     auto src_md = fwdParams.src_md;
     // Create forward BatchNorm descriptor and primitive descriptor.
@@ -229,6 +249,13 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
           m_dims, U, MEMORY_FORMAT::nc, cpu_engine_, DummyData));
     }

+#ifdef ENABLE_MKLDNN_V1
+    if (IS_SET(fuse_norm_relu)) {
+      context_.ws_mem.reset(new MEMORY_CONSTRUCTOR(
+          context_.fwd_pd->workspace_desc(), cpu_engine_, DummyData));
+    }
+#endif  // ENABLE_MKLDNN_V1
+
     // BatchNorm forward primitive.
     // TODO(intel-tf): Merge all the #ifdefs and simplify code
     if (!fwdParams.training && !(IS_SET(use_global_stats))) {
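Behind the flag and workspace wiring above is plain MKL-DNN v1.x API usage. A standalone sketch of what the primitive setup amounts to, assuming only the stock mkldnn.hpp API (minimal error handling, illustrative function name):

```cpp
#include "mkldnn.hpp"
using namespace mkldnn;

memory::desc FusedBnReluWorkspaceDesc(const memory::desc& src_md, float eps,
                                      const engine& eng) {
  // fuse_norm_relu asks the library for a single fused BN+ReLU primitive.
  auto flags = normalization_flags::use_scale_shift |
               normalization_flags::fuse_norm_relu;
  batch_normalization_forward::desc desc(prop_kind::forward_training, src_md,
                                         eps, flags);
  batch_normalization_forward::primitive_desc pd(desc, eng);
  // In training mode the fused primitive needs a workspace (it holds the
  // ReLU mask consumed by the backward pass); this is the memory that
  // context_.ws_mem above is created from.
  return pd.workspace_desc();
}
```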
@@ -258,20 +285,41 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
     } else if (IS_SET(use_global_stats)) {
 #ifdef ENABLE_MKLDNN_V1
       if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
-        context_.net_args.push_back(
-            {{MKLDNN_ARG_SRC, *context_.src_mem},
-             {MKLDNN_ARG_MEAN, *context_.mean_mem},
-             {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
-             {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
-             { MKLDNN_ARG_DST,
-               *context_.dst_mem }});
+        if (IS_SET(fuse_norm_relu)) {
+          context_.net_args.push_back(
+              {{MKLDNN_ARG_SRC, *context_.src_mem},
+               {MKLDNN_ARG_MEAN, *context_.mean_mem},
+               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
+               {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
+               {MKLDNN_ARG_DST, *context_.dst_mem},
+               { MKLDNN_ARG_WORKSPACE,
+                 *context_.ws_mem }});
+        } else {
+          context_.net_args.push_back(
+              {{MKLDNN_ARG_SRC, *context_.src_mem},
+               {MKLDNN_ARG_MEAN, *context_.mean_mem},
+               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
+               {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
+               { MKLDNN_ARG_DST,
+                 *context_.dst_mem }});
+        }
       } else {
-        context_.net_args.push_back(
-            {{MKLDNN_ARG_SRC, *context_.src_mem},
-             {MKLDNN_ARG_MEAN, *context_.mean_mem},
-             {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
-             { MKLDNN_ARG_DST,
-               *context_.dst_mem }});
+        if (IS_SET(fuse_norm_relu)) {
+          context_.net_args.push_back(
+              {{MKLDNN_ARG_SRC, *context_.src_mem},
+               {MKLDNN_ARG_MEAN, *context_.mean_mem},
+               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
+               {MKLDNN_ARG_DST, *context_.dst_mem},
+               { MKLDNN_ARG_WORKSPACE,
+                 *context_.ws_mem }});
+        } else {
+          context_.net_args.push_back(
+              {{MKLDNN_ARG_SRC, *context_.src_mem},
+               {MKLDNN_ARG_MEAN, *context_.mean_mem},
+               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
+               { MKLDNN_ARG_DST,
+                 *context_.dst_mem }});
+        }
       }
       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
 #else
@@ -291,19 +339,40 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
     } else {
 #ifdef ENABLE_MKLDNN_V1
       if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
-        context_.net_args.push_back(
-            {{MKLDNN_ARG_SRC, *context_.src_mem},
-             {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
-             {MKLDNN_ARG_DST, *context_.dst_mem},
-             {MKLDNN_ARG_MEAN, *context_.mean_mem},
-             { MKLDNN_ARG_VARIANCE,
-               *context_.variance_mem }});
+        if (IS_SET(fuse_norm_relu)) {
+          context_.net_args.push_back(
+              {{MKLDNN_ARG_SRC, *context_.src_mem},
+               {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
+               {MKLDNN_ARG_DST, *context_.dst_mem},
+               {MKLDNN_ARG_MEAN, *context_.mean_mem},
+               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
+               { MKLDNN_ARG_WORKSPACE,
+                 *context_.ws_mem }});
+        } else {
+          context_.net_args.push_back(
+              {{MKLDNN_ARG_SRC, *context_.src_mem},
+               {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
+               {MKLDNN_ARG_DST, *context_.dst_mem},
+               {MKLDNN_ARG_MEAN, *context_.mean_mem},
+               { MKLDNN_ARG_VARIANCE,
+                 *context_.variance_mem }});
+        }
       } else {
-        context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
-                                     {MKLDNN_ARG_DST, *context_.dst_mem},
-                                     {MKLDNN_ARG_MEAN, *context_.mean_mem},
-                                     { MKLDNN_ARG_VARIANCE,
-                                       *context_.variance_mem }});
+        if (IS_SET(fuse_norm_relu)) {
+          context_.net_args.push_back(
+              {{MKLDNN_ARG_SRC, *context_.src_mem},
+               {MKLDNN_ARG_DST, *context_.dst_mem},
+               {MKLDNN_ARG_MEAN, *context_.mean_mem},
+               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
+               { MKLDNN_ARG_WORKSPACE,
+                 *context_.ws_mem }});
+        } else {
+          context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
+                                       {MKLDNN_ARG_DST, *context_.dst_mem},
+                                       {MKLDNN_ARG_MEAN, *context_.mean_mem},
+                                       { MKLDNN_ARG_VARIANCE,
+                                         *context_.variance_mem }});
+        }
       }
       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
 #else
@@ -360,6 +429,7 @@ class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
     key_creator.AddAsKey<int>(fwdParams.depth);
     key_creator.AddAsKey<float>(fwdParams.eps);
     key_creator.AddAsKey<bool>(fwdParams.training);
+    key_creator.AddAsKey<FusedBNActivationMode>(fwdParams.activation_mode);
     key_creator.AddAsKey(typeid(T).name());
     key_creator.AddAsKey(typeid(U).name());
     return key_creator.GetKey();
@@ -676,7 +746,8 @@ class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 // Adding a third parameter to the template to support FusedBatchNormV3
 // with MKL. This is different from default where the classes are
 // derived. Moves enabling to compile-time rather than runtime.
-template <typename Device, typename T, typename U, bool reserved_space>
+template <typename Device, typename T, typename U, bool reserved_space,
+          bool is_batch_norm_ex = false>
 class MklFusedBatchNormOp : public OpKernel {
  public:
  explicit MklFusedBatchNormOp(OpKernelConstruction* context)
@@ -696,6 +767,28 @@ class MklFusedBatchNormOp : public OpKernel {
     depth_ = 0;
     mean_values_ = nullptr;
     variance_values_ = nullptr;
+
+#ifndef ENABLE_MKLDNN_V1
+    OP_REQUIRES(context, !is_batch_norm_ex,
+                errors::InvalidArgument(
+                    "_MklFusedBatchNormEx is not supported in DNNL 0.x."));
+#endif
+    if (!is_batch_norm_ex) {
+      activation_mode_ = FusedBNActivationMode::kIdentity;
+    } else {
+      int num_side_inputs;
+      OP_REQUIRES_OK(context,
+                     context->GetAttr("num_side_inputs", &num_side_inputs));
+      // Currently _MKLFusedBatchNormEx does not support "SideInput".
+      OP_REQUIRES(context, num_side_inputs == 0,
+                  errors::InvalidArgument(
+                      "_MKLFusedBatchNorm does not support side input now."));
+
+      OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_));
+      OP_REQUIRES(context, activation_mode_ == FusedBNActivationMode::kRelu,
+                  errors::InvalidArgument(
+                      "_MKLFusedBatchNorm only supports Relu activation"));
+    }
   }

   void Compute(OpKernelContext* context) override {
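`ParseActivationMode` comes from the shared fused_batch_norm_op code pulled in via the new `hdrs`/`deps` entries in the BUILD change. For reference, a sketch of the shape that helper presumably has (abbreviated and assumed; consult fused_batch_norm_op.h for the real signature):

```cpp
Status ParseActivationMode(OpKernelConstruction* context,
                           FusedBatchNormActivationMode* activation_mode) {
  string mode_string;
  TF_RETURN_IF_ERROR(context->GetAttr("activation_mode", &mode_string));
  if (mode_string == "Identity") {
    *activation_mode = FusedBatchNormActivationMode::kIdentity;
    return Status::OK();
  }
  if (mode_string == "Relu") {
    *activation_mode = FusedBatchNormActivationMode::kRelu;
    return Status::OK();
  }
  return errors::InvalidArgument("Unsupported activation mode: ", mode_string);
}
```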
@@ -744,9 +837,12 @@ class MklFusedBatchNormOp : public OpKernel {

     // Handle the special case: input with 0 element and 0 batch size.
     Tensor* dst_tensor = nullptr;
+    TensorShape workspace_tf_shape;
     if (tf_shape_src.num_elements() == 0) {
-      HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
-                       &dst_tensor);
+      size_t workspace_bytes = 0;
+      workspace_tf_shape.AddDim(workspace_bytes);
+      HandleEmptyInput(context, tf_shape_src, workspace_tf_shape,
+                       scale_tensor.shape(), &dst_tensor);
       return;
     }

@@ -758,23 +854,16 @@ class MklFusedBatchNormOp : public OpKernel {
     // Index of output tensor(diff_src).
     const size_t kDstIndex = 0;

-    // Allocate 4 output TF tensors.
+    // Allocate 5 output TF tensors.
     Tensor* batch_mean_tensor = nullptr;
     Tensor* batch_variance_tensor = nullptr;
     Tensor* saved_mean_tensor = nullptr;
     Tensor* saved_variance_tensor = nullptr;
     Tensor* reserved_space_tensor = nullptr;
-    AllocateTFOutputs(context, scale_tensor.shape(), &batch_mean_tensor,
-                      &batch_variance_tensor, &saved_mean_tensor,
-                      &saved_variance_tensor, &reserved_space_tensor);
-
-    if (is_training_)
-      SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
-    else
-      SetMeanVariance(est_mean_tensor, est_variance_tensor);

     MklDnnData<T> src(&cpu_engine_);
     MklDnnData<U> weights(&cpu_engine_);
+    MklDnnData<U> wksp(&cpu_engine_);

     MEMORY_FORMAT dnn_fmt;
     MKL_TENSOR_FORMAT mkl_tensor_fmt;
@@ -801,6 +890,51 @@ class MklFusedBatchNormOp : public OpKernel {
             ? dnn_shape_src.GetMklLayout()
             : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);

+#ifdef ENABLE_MKLDNN_V1
+    MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_,
+                                    src_md, activation_mode_);
+#else
+    MklBatchNormFwdParams fwdParams(
+        src_dims, depth_, epsilon_, is_training_,
+        static_cast<MEMORY_FORMAT>(src_md.data.format), activation_mode_);
+#endif  // ENABLE_MKLDNN_V1
+    // Get forward batch-normalization op from the primitive caching pool.
+    MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd =
+        MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams);
+
+    // Allocate workspace tensor
+    U* ws_data = nullptr;
+    if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
+#ifdef ENABLE_MKLDNN_V1
+      MEMORY_PRIMITIVE_DESC workspace_pd =
+          bn_fwd->GetBatchNormFwdPd()->workspace_desc();
+      size_t workspace_bytes = workspace_pd.get_size();
+      workspace_tf_shape.AddDim(workspace_bytes);
+
+      AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
+                        &batch_mean_tensor, &batch_variance_tensor,
+                        &saved_mean_tensor, &saved_variance_tensor,
+                        &reserved_space_tensor);
+      if (reserved_space) {
+        wksp.SetUsrMem(workspace_pd, reserved_space_tensor);
+        ws_data = static_cast<U*>(wksp.GetOpMem().get_data_handle());
+      }
+#endif  // ENABLE_MKLDNN_V1
+    } else {
+      // There is actually no workspace tensor out, so we make a dummy one.
+      size_t workspace_bytes = 0;
+      workspace_tf_shape.AddDim(workspace_bytes);
+      AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
+                        &batch_mean_tensor, &batch_variance_tensor,
+                        &saved_mean_tensor, &saved_variance_tensor,
+                        &reserved_space_tensor);
+    }
+
+    if (is_training_)
+      SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
+    else
+      SetMeanVariance(est_mean_tensor, est_variance_tensor);
+
     // MKL-DNN packs scale & shift as "weights":
     // <scale>...<scale><shift>...<shift>
     weights.AllocateBuffer(2 * depth_ * sizeof(U));
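Sizing note on the block above: `workspace_desc().get_size()` returns a byte count, but `reserve_space_3` is declared with dtype `U`, so the 1-D output ends up with one `U` element per workspace byte. A small worked example under the assumption U = float:

```cpp
// Say the primitive reports a 4096-byte workspace:
size_t workspace_bytes = 4096;               // pd.workspace_desc().get_size()
TensorShape workspace_tf_shape;
workspace_tf_shape.AddDim(workspace_bytes);  // shape {4096}, dtype U = float
// Allocated storage: 4096 * sizeof(float) = 16384 bytes. Over-sized but
// safe, and it keeps the FusedBatchNormV3 output signature intact. The
// non-Relu path uses AddDim(0), so the dummy workspace output stays empty.
```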
@@ -821,18 +955,6 @@ class MklFusedBatchNormOp : public OpKernel {
                  reinterpret_cast<char*>(variance_values_),
                  depth_ * sizeof(U));

-#ifdef ENABLE_MKLDNN_V1
-    MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_,
-                                    src_md);
-#else
-    MklBatchNormFwdParams fwdParams(
-        src_dims, depth_, epsilon_, is_training_,
-        static_cast<MEMORY_FORMAT>(src_md.data.format));
-#endif  // ENABLE_MKLDNN_V1
-    // Get forward batch-normalization op from the primitive caching pool.
-    MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd =
-        MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams);
-
     // Check if reorder is needed for src.
     const T* src_data = nullptr;
     std::shared_ptr<BatchNormFwdPd> bn_fwd_pd = bn_fwd->GetBatchNormFwdPd();
@@ -866,7 +988,7 @@ class MklFusedBatchNormOp : public OpKernel {

     // Execute
     bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data,
-                    variance_op_data);
+                    variance_op_data, ws_data);

     float adjust_factor = 1.0;
     if (is_training_) {
@@ -924,6 +1046,7 @@ class MklFusedBatchNormOp : public OpKernel {
   U* mean_values_;
   U* variance_values_;
   size_t depth_;  // Batch normalization is performed per channel.
+  FusedBNActivationMode activation_mode_;
   engine cpu_engine_ = engine(ENGINE_CPU, 0);

   void ExtractParams(OpKernelContext* context) {
@@ -938,6 +1061,7 @@ class MklFusedBatchNormOp : public OpKernel {
   }

   void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
+                        TensorShape workspace_tf_shape,
                         TensorShape tf_shape_scale, Tensor** dst_tensor) {
     DCHECK(dst_tensor);

@@ -955,12 +1079,14 @@ class MklFusedBatchNormOp : public OpKernel {
     Tensor* saved_mean_tensor = nullptr;
     Tensor* saved_variance_tensor = nullptr;
     Tensor* reserved_space_tensor = nullptr;
-    AllocateTFOutputs(context, tf_shape_scale, &batch_mean_tensor,
-                      &batch_variance_tensor, &saved_mean_tensor,
-                      &saved_variance_tensor, &reserved_space_tensor);
+    AllocateTFOutputs(context, tf_shape_scale, workspace_tf_shape,
+                      &batch_mean_tensor, &batch_variance_tensor,
+                      &saved_mean_tensor, &saved_variance_tensor,
+                      &reserved_space_tensor);
   }

   void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale,
+                         TensorShape workspace_tf_shape,
                          Tensor** batch_mean_tensor,
                          Tensor** batch_variance_tensor,
                          Tensor** saved_mean_tensor,
@@ -1024,21 +1150,15 @@ class MklFusedBatchNormOp : public OpKernel {
     std::fill_n(saved_variance_data, num_elements, static_cast<U>(0));

     // Changes to support reserved_space_3 parameter in FusedBatchNormV3.
-    // TODO: This parameter functionality is not implemented on CPU.
-    //       It is used to hold intermediate results. So the allocated
-    //       memory is filled with 0s.
     if (reserved_space) {
       DCHECK(reserved_space_tensor != nullptr);

       MklDnnShape mkl_shape_reserved_space;
       mkl_shape_reserved_space.SetMklTensor(false);
       AllocateOutputSetMklShape(context, kReservedSpaceIndex,
-                                reserved_space_tensor, tf_shape_scale,
+                                reserved_space_tensor, workspace_tf_shape,
                                 mkl_shape_reserved_space);
       DCHECK((*reserved_space_tensor) != nullptr);
-      auto saved_reserved_space_data =
-          (*reserved_space_tensor)->flat<U>().data();
-      std::fill_n(saved_reserved_space_data, num_elements, static_cast<U>(0));
     }
   }
 };
@@ -1367,7 +1487,7 @@ class MklFusedBatchNormGradOp : public OpKernel {
           .Device(DEVICE_CPU)                                        \
           .TypeConstraint<T>("T")                                    \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),       \
-      MklFusedBatchNormOp<CPUDevice, T, T, false>);
+      MklFusedBatchNormOp<CPUDevice, T, T, false, false>);

 TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_CPU);
 TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU);
@@ -1380,7 +1500,7 @@ TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU);
           .TypeConstraint<T>("T")                                    \
           .TypeConstraint<U>("U")                                    \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),       \
-      MklFusedBatchNormOp<CPUDevice, T, U, false>);
+      MklFusedBatchNormOp<CPUDevice, T, U, false, false>);

 REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(float, float);
 REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float);
@@ -1421,12 +1541,30 @@ REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(bfloat16, float);
           .TypeConstraint<T>("T")                                    \
           .TypeConstraint<U>("U")                                    \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),       \
-      MklFusedBatchNormOp<CPUDevice, T, U, true>);
+      MklFusedBatchNormOp<CPUDevice, T, U, true, false>);            \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("_MklFusedBatchNormEx")                                   \
+          .Device(DEVICE_CPU)                                        \
+          .TypeConstraint<T>("T")                                    \
+          .TypeConstraint<U>("U")                                    \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),       \
+      MklFusedBatchNormOp<CPUDevice, T, U, true, true>);

 REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(float, float);
 REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(bfloat16, float);
 #undef REGISTER_MKL_FUSED_BATCHNORM_V3_CPU

+REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<float>("T")
+                            .TypeConstraint<float>("U"),
+                        NoOp);
+REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<bfloat16>("T")
+                            .TypeConstraint<float>("U"),
+                        NoOp);
+
 #define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(T, U)               \
   REGISTER_KERNEL_BUILDER(                                           \
       Name("_MklFusedBatchNormGradV3")                               \
@@ -1342,6 +1342,48 @@ REGISTER_OP("_MklFusedBatchNormGradV3")
     R"doc(MKL-DNN implementation of FusedBatchNormGradV3: Do not invoke this operator directly in Python.
 Graph rewrite pass is expected to invoke this operator.)doc");

+REGISTER_OP("_MklFusedBatchNormEx")
+    .Input("x: T")
+    .Input("scale: U")
+    .Input("offset: U")
+    .Input("mean: U")
+    .Input("variance: U")
+    .Input("side_input: num_side_inputs * T")
+    .Input("mkl_x: uint8")
+    .Input("mkl_scale: uint8")
+    .Input("mkl_offset: uint8")
+    .Input("mkl_mean: uint8")
+    .Input("mkl_variance: uint8")
+    .Input("mkl_side_input: num_side_inputs * uint8")
+    .Output("y: T")
+    .Output("batch_mean: U")
+    .Output("batch_variance: U")
+    .Output("reserve_space_1: U")
+    .Output("reserve_space_2: U")
+    .Output("reserve_space_3: U")
+    .Output("mkl_y: uint8")
+    .Output("mkl_batch_mean: uint8")
+    .Output("mkl_batch_variance: uint8")
+    .Output("mkl_reserve_space_1: uint8")
+    .Output("mkl_reserve_space_2: uint8")
+    .Output("mkl_reserve_space_3: uint8")
+    .Attr("T: {bfloat16, float}")
+    .Attr("U: {float}")
+    .Attr("epsilon: float = 0.0001")
+    .Attr("exponential_avg_factor: float = 1.0")
+    .Attr(GetConvnetDataFormatAttrString())
+    .Attr("num_side_inputs: int >= 0 = 0")
+    .Attr("activation_mode: string = \"Identity\"")
+    .Attr("is_training: bool = true")
+    .SetShapeFn(shape_inference::FusedBatchNormShape)
+    .Doc(R"doc(
+MKL version of FusedBatchNormEx operator. Uses MKL DNN APIs to perform fused
+batch normalization and relu.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke this operator.
+)doc");
+
 }  // namespace tensorflow

 #endif  // INTEL_MKL
@@ -238,7 +238,11 @@ REGISTER_OP("_FusedBatchNormEx")
     .Output("reserve_space_1: U")
     .Output("reserve_space_2: U")
     .Output("reserve_space_3: U")
+#ifdef ENABLE_MKLDNN_V1
+    .Attr("T: {half, float, bfloat16}")
+#else
     .Attr("T: {half, float}")
+#endif
     .Attr("U: {float}")
     .Attr("epsilon: float = 0.0001")
     .Attr("exponential_avg_factor: float = 1.0")