Fuse BN and Relu in the MKL path

parent 5992e75800
commit d4d23502bf
@@ -268,6 +268,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.dequantize = "Dequantize";
     csinfo_.fused_batch_norm = "FusedBatchNorm";
     csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
+    csinfo_.fused_batch_norm_ex = "_FusedBatchNormEx";
     csinfo_.fused_batch_norm_v2 = "FusedBatchNormV2";
     csinfo_.fused_batch_norm_grad_v2 = "FusedBatchNormGradV2";
     csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
@@ -294,6 +295,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
        "_MklDepthwiseConv2dNativeBackpropInput";
     csinfo_.mkl_depthwise_conv2d_grad_filter =
         "_MklDepthwiseConv2dNativeBackpropFilter";
+    csinfo_.mkl_fused_batch_norm_ex = "_MklFusedBatchNormEx";
     csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
     csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
     csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
@@ -476,6 +478,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
         {csinfo_.fused_batch_norm_grad_v3,
          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
+#ifdef ENABLE_MKLDNN_V1
+    rinfo_.push_back({csinfo_.fused_batch_norm_ex,
+                      csinfo_.mkl_fused_batch_norm_ex, CopyAttrsAll,
+                      FusedBatchNormExRewrite, kRewriteForLayoutPropagation});
+#endif
     rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
                       CopyAttrsFusedConv2D, FusedConv2DRewrite,
                       kRewriteForLayoutPropagation});
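For orientation: most entries in this rewrite table derive the MKL op name with mkl_op_registry::GetMklOpName, which is assumed here to simply prepend the "_Mkl" prefix to the TF op type. A minimal C++ sketch of that convention, and of why _FusedBatchNormEx gets an explicit target name instead:

    #include <string>

    // Sketch only; the real helper lives in mkl_op_registry.
    std::string GetMklOpNameSketch(const std::string& tf_op) {
      return "_Mkl" + tf_op;  // "FusedBatchNormV3" -> "_MklFusedBatchNormV3"
    }
    // Applied to "_FusedBatchNormEx" this would produce "_Mkl_FusedBatchNormEx",
    // so the rinfo_ entry above maps it explicitly to "_MklFusedBatchNormEx".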
@@ -920,6 +927,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    string dequantize;
    string fused_batch_norm;
    string fused_batch_norm_grad;
+   string fused_batch_norm_ex;
    string fused_batch_norm_v2;
    string fused_batch_norm_grad_v2;
    string fused_batch_norm_v3;
@@ -944,6 +952,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    string mkl_conv2d_with_bias;
    string mkl_depthwise_conv2d_grad_input;
    string mkl_depthwise_conv2d_grad_filter;
+   string mkl_fused_batch_norm_ex;
    string mkl_fused_conv2d;
    string mkl_fused_matmul;
    string mkl_pad_with_conv2d;
@@ -1652,6 +1661,31 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     return do_rewrite;
   }
+
+  static bool FusedBatchNormExRewrite(const Node* n) {
+    CHECK_NOTNULL(n);
+
+    int num_side_inputs;
+    TF_CHECK_OK(GetNodeAttr(n->def(), "num_side_inputs", &num_side_inputs));
+    string activation_mode;
+    TF_CHECK_OK(GetNodeAttr(n->def(), "activation_mode", &activation_mode));
+
+    // If num_side_inputs is not 0, don't rewrite the node.
+    if (num_side_inputs != 0) {
+      VLOG(1) << "FusedBatchNormExRewrite: Nodes with num_side_inputs "
+              << "greater than 0 are not optimized by Intel MKL.";
+      return false;
+    }
+
+    // If activation_mode is not 'Relu', don't rewrite the node.
+    if (activation_mode != "Relu") {
+      VLOG(1) << "FusedBatchNormExRewrite: Only Relu activation mode is "
+              << "supported by Intel MKL.";
+      return false;
+    }
+
+    return true;
+  }
+
   static bool FusedConv2DRewrite(const Node* n) {
     // MKL DNN currently doesn't support all fusions that grappler fuses
     // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
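A hedged sketch of how the pass consumes predicates such as FusedBatchNormExRewrite: every rewrite-table entry pairs an op name with a rewrite-rule function, and a node is rewritten only when its rule passes. The types below are simplified stand-ins, not the actual MklLayoutRewritePass structures:

    #include <string>

    struct NodeSketch { std::string type_string; };  // stand-in for tensorflow::Node

    struct RewriteInfoSketch {
      std::string name;      // e.g. "_FusedBatchNormEx"
      std::string new_name;  // e.g. "_MklFusedBatchNormEx"
      bool (*rewrite_rule)(const NodeSketch&);  // e.g. FusedBatchNormExRewrite
    };

    bool ShouldRewrite(const RewriteInfoSketch& ri, const NodeSketch& n) {
      // Rewrite only when the op type matches and the per-op predicate agrees.
      return n.type_string == ri.name && ri.rewrite_rule(n);
    }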
@@ -2131,9 +2165,6 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
   // Number of input slots to original op
   // Input slots are represented by .Input() calls in REGISTER_OP.
   int old_node_input_slots = old_node->op_def().input_arg_size();
-  // Actual number of inputs can be greater than or equal to number
-  // of Input slots because inputs of type list could be unfolded.
-  CHECK_GE(old_node_inputs.size(), old_node_input_slots);
   int nn_slot_idx = 0;  // slot index for inputs of new node

   // Let's copy all inputs (TF tensors) of original node to new node.
@@ -2141,13 +2172,14 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
   for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
     // An input slot could be a single tensor or a list. We need
     // to handle this case accordingly.
-    CHECK_LT(iidx, old_node_inputs.size());
     const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
     if (ArgIsList(arg)) {
       std::vector<NodeBuilder::NodeOut> new_node_inputs;
-      int N = GetTensorListLength(arg, old_node);
-      GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
-                                    &new_node_inputs);
+      int tensor_list_length = GetTensorListLength(arg, old_node);
+      if (tensor_list_length != 0) {
+        GetNodesProducingTFTensorList(old_node_inputs, &iidx,
+                                      tensor_list_length, &new_node_inputs);
+      }
       nb->Input(new_node_inputs);
       nn_slot_idx++;
     } else {
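The same zero-length guard is applied to the Mkl-metadata variant in the next hunk. The reason: a list-typed input such as _MklFusedBatchNormEx's side_input may legally have length zero (num_side_inputs = 0), so the old CHECK_GE/CHECK_LT invariants no longer hold and there are no producer edges to walk. A minimal sketch, with hypothetical stand-ins for NodeBuilder::NodeOut and the producer lookup:

    #include <vector>

    struct NodeOutSketch { int node_id; int output_index; };  // hypothetical

    void CollectListInputs(const std::vector<NodeOutSketch>& old_inputs,
                           int* iidx, int list_length,
                           std::vector<NodeOutSketch>* new_inputs) {
      // Mirrors the guard above: a zero-length list consumes no input slots,
      // and indexing old_inputs here would run past its end.
      if (list_length != 0) {
        for (int i = 0; i < list_length; ++i) {
          new_inputs->push_back(old_inputs[(*iidx)++]);
        }
      }
      // The caller still registers new_inputs (possibly empty) via nb->Input().
    }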
@@ -2180,13 +2212,14 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
   for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
     // An input slot could be a single tensor or a list. We need
     // to handle this case accordingly.
-    CHECK_LT(iidx, old_node_inputs.size());
     const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
     if (ArgIsList(arg)) {
       std::vector<NodeBuilder::NodeOut> new_node_inputs;
-      int N = GetTensorListLength(arg, old_node);
-      GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
-                                     &new_node_inputs);
+      int tensor_list_length = GetTensorListLength(arg, old_node);
+      if (tensor_list_length != 0) {
+        GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
+                                       tensor_list_length, &new_node_inputs);
+      }
       nb->Input(new_node_inputs);
       nn_slot_idx++;
     } else {
@@ -3702,6 +3735,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
       n->type_string() != csinfo_.pad_with_conv2d &&
       n->type_string() != csinfo_.pad_with_fused_conv2d &&
       n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
+      n->type_string() != csinfo_.fused_batch_norm_ex &&
       n->type_string() != csinfo_.fused_conv2d &&
       n->type_string() != csinfo_.fused_matmul &&
       !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
@@ -3108,6 +3108,112 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) {
             "B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");
 }
+
+#ifdef ENABLE_MKLDNN_V1
+#define REGISTER_TEST(NAME, T, INPUT)                                        \
+  TEST_F(MklLayoutPassTest, NAME##_##T) {                                    \
+    InitGraph("node { name: 'A' op: '" #INPUT                                \
+              "'}"                                                           \
+              "node { name: 'B' op: 'Input'}"                                \
+              "node { name: 'C' op: 'Input'}"                                \
+              "node { name: 'D' op: 'Input'}"                                \
+              "node { name: 'E' op: 'Input'}"                                \
+              "node { name: 'F' op: '_FusedBatchNormEx'"                     \
+              " attr { key: 'T' value { type: " #T                           \
+              " } }"                                                         \
+              " attr { key: 'U' value { type: DT_FLOAT } }"                  \
+              " attr { key: 'data_format' value { s: 'NCHW' } }"             \
+              " attr { key: 'epsilon' value { f: 0.0001 } }"                 \
+              " attr { key: 'num_side_inputs' value { i: 0 } }"              \
+              " attr { key: 'is_training' value { b: true } }"               \
+              " attr { key: 'activation_mode' value { s: 'Relu' } }"         \
+              " input: ['A', 'B', 'C', 'D', 'E'] }"                          \
+              "node { name: 'G' op: 'Zeta'"                                  \
+              " attr { key: 'T' value { type: " #T                           \
+              " } }"                                                         \
+              " input: ['A', 'F'] }");                                       \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(),                                 \
+              "A(" #INPUT                                                    \
+              ");B(Input);C(Input);D(Input);"                                \
+              "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);"     \
+              "DMT/_4(Const);E(Input);"                                      \
+              "F(_MklFusedBatchNormEx);G(Zeta)|A->F;A->G;"                   \
+              "A:control->DMT/_0:control;A:control->DMT/_1:control;"         \
+              "A:control->DMT/_2:control;A:control->DMT/_3:control;"         \
+              "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"              \
+              "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" \
+              "E->F:4;F->G:1");                                              \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Positive);
+#undef REGISTER_TEST
+
+// Rewrite test for _FusedBatchNormEx Op with side input
+#define REGISTER_TEST(NAME, T, INPUT)                                        \
+  TEST_F(MklLayoutPassTest, NAME##_##T) {                                    \
+    InitGraph("node { name: 'A' op: '" #INPUT                                \
+              "'}"                                                           \
+              "node { name: 'B' op: 'Input'}"                                \
+              "node { name: 'C' op: 'Input'}"                                \
+              "node { name: 'D' op: 'Input'}"                                \
+              "node { name: 'E' op: 'Input'}"                                \
+              "node { name: 'F' op: '" #INPUT                                \
+              "'}"                                                           \
+              "node { name: 'G' op: '_FusedBatchNormEx'"                     \
+              " attr { key: 'T' value { type: " #T                           \
+              " } }"                                                         \
+              " attr { key: 'U' value { type: DT_FLOAT } }"                  \
+              " attr { key: 'data_format' value { s: 'NCHW' } }"             \
+              " attr { key: 'epsilon' value { f: 0.0001 } }"                 \
+              " attr { key: 'num_side_inputs' value { i: 1 } }"              \
+              " attr { key: 'is_training' value { b: true } }"               \
+              " attr { key: 'activation_mode' value { s: 'Relu' } }"         \
+              " input: ['A', 'B', 'C', 'D', 'E', 'F'] }"                     \
+              "node { name: 'H' op: 'Zeta'"                                  \
+              " attr { key: 'T' value { type: " #T                           \
+              " } }"                                                         \
+              " input: ['A', 'G'] }");                                       \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(),                                 \
+              "A(" #INPUT                                                    \
+              ");B(Input);C(Input);D(Input);E(Input);"                       \
+              "F(" #INPUT                                                    \
+              ");G(_FusedBatchNormEx);H(Zeta)|A->G;A->H;"                    \
+              "B->G:1;C->G:2;D->G:3;E->G:4;F->G:5;G->H:1");                  \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative1);
+#undef REGISTER_TEST
+
+// Rewrite test for _FusedBatchNormEx Op with Identity activation
+#define REGISTER_TEST(NAME, T, INPUT)                                        \
+  TEST_F(MklLayoutPassTest, NAME##_##T) {                                    \
+    InitGraph("node { name: 'A' op: '" #INPUT                                \
+              "'}"                                                           \
+              "node { name: 'B' op: 'Input'}"                                \
+              "node { name: 'C' op: 'Input'}"                                \
+              "node { name: 'D' op: 'Input'}"                                \
+              "node { name: 'E' op: 'Input'}"                                \
+              "node { name: 'G' op: '_FusedBatchNormEx'"                     \
+              " attr { key: 'T' value { type: " #T                           \
+              " } }"                                                         \
+              " attr { key: 'U' value { type: DT_FLOAT } }"                  \
+              " attr { key: 'data_format' value { s: 'NCHW' } }"             \
+              " attr { key: 'epsilon' value { f: 0.0001 } }"                 \
+              " attr { key: 'num_side_inputs' value { i: 1 } }"              \
+              " attr { key: 'is_training' value { b: true } }"               \
+              " attr { key: 'activation_mode' value { s: 'Identity' } }"     \
+              " input: ['A', 'B', 'C', 'D', 'E'] }"                          \
+              "node { name: 'H' op: 'Zeta'"                                  \
+              " attr { key: 'T' value { type: " #T                           \
+              " } }"                                                         \
+              " input: ['A', 'G'] }");                                       \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(),                                 \
+              "A(" #INPUT                                                    \
+              ");B(Input);C(Input);D(Input);E(Input);"                       \
+              "G(_FusedBatchNormEx);H(Zeta)|A->G;A->H;"                      \
+              "B->G:1;C->G:2;D->G:3;E->G:4;G->H:1");                         \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative2);
+#undef REGISTER_TEST
+#endif  // ENABLE_MKLDNN_V1
+
 TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedDepthwiseConv2D_Positive) {
   InitGraph(
       "node { name: 'A' op: 'QuantizedUnsignedInt8Input'}"
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/optimizers/remapper.h"
 #include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"

 namespace tensorflow {
@@ -173,6 +174,178 @@ TEST_F(MklRemapperTest, FuseConv2DWithBiasAndAddNRelu) {
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
+
+#ifdef ENABLE_MKLDNN_V1
+TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
+  using ::tensorflow::ops::Placeholder;
+
+  for (bool is_training : {true, false}) {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+    const int num_channels = 24;
+
+    TensorShape channel_shape({num_channels});
+    TensorShape empty_shape({0});
+
+    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
+                             ops::Placeholder::Shape({2, 8, 8, num_channels}));
+    auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_FLOAT);
+    auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
+    auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
+    auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
+    auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
+
+    float epsilon = 0.1f;
+    auto fbn = ops::FusedBatchNormV3(
+        s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
+        ops::FusedBatchNormV3::IsTraining(is_training)
+            .Epsilon(epsilon)
+            .DataFormat("NHWC"));
+    auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
+    auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
+
+    auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
+    auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                             : channel_shape);
+    auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                            : channel_shape);
+
+    GrapplerItem item;
+    item.fetch = {"fetch"};
+    item.feed = {{"input", input_t},
+                 {"scale", scale_t},
+                 {"offset", offset_t},
+                 {"mean", mean_t},
+                 {"var", var_t}};
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+    // Place all nodes on CPU.
+    for (int i = 0; i < item.graph.node_size(); ++i) {
+      item.graph.mutable_node(i)->set_device("/device:CPU:0");
+    }
+
+    Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+    int found = 0;
+    for (const NodeDef& node : output.node()) {
+      if (node.name() == "relu") {
+        EXPECT_EQ(node.op(), "Identity");
+        ASSERT_EQ(node.input_size(), 1);
+        EXPECT_EQ(node.input(0), "fused_batch_norm");
+        found++;
+      }
+      if (node.name() == "fused_batch_norm") {
+        EXPECT_EQ(node.op(), "_FusedBatchNormEx");
+        ASSERT_EQ(node.input_size(), 5);
+        EXPECT_EQ(node.input(0), "input_cast");
+        EXPECT_EQ(node.input(1), "scale");
+        EXPECT_EQ(node.input(2), "offset");
+        EXPECT_EQ(node.input(3), "mean");
+        EXPECT_EQ(node.input(4), "var");
+
+        auto attr = node.attr();
+        EXPECT_EQ(attr["num_side_inputs"].i(), 0);
+        EXPECT_EQ(attr["activation_mode"].s(), "Relu");
+        found++;
+      }
+    }
+    EXPECT_EQ(found, 2);
+  }
+}
+
+TEST_F(MklRemapperTest, FuseBatchNormWithAddAndRelu) {
+  using ::tensorflow::ops::Placeholder;
+
+  for (bool is_training : {true, false}) {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+    const int num_channels = 24;
+
+    TensorShape input_shape({2, 8, 8, num_channels});
+    TensorShape channel_shape({num_channels});
+    TensorShape empty_shape({0});
+
+    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
+                             ops::Placeholder::Shape(input_shape));
+    auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_FLOAT);
+    auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
+    auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
+    auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
+    auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
+    auto side_input = Placeholder(s.WithOpName("side_input"), DT_FLOAT,
+                                  ops::Placeholder::Shape(input_shape));
+    auto side_input_cast =
+        ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_FLOAT);
+
+    float epsilon = 0.1f;
+    auto fbn = ops::FusedBatchNormV3(
+        s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
+        ops::FusedBatchNormV3::IsTraining(is_training)
+            .Epsilon(epsilon)
+            .DataFormat("NHWC"));
+    auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
+    auto relu = ops::Relu(s.WithOpName("relu"), add);
+    auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
+
+    auto input_t = GenerateRandomTensor<DT_FLOAT>(input_shape);
+    auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                             : channel_shape);
+    auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                            : channel_shape);
+    auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
+
+    GrapplerItem item;
+    item.fetch = {"fetch"};
+    item.feed = {{"input", input_t}, {"scale", scale_t},
+                 {"offset", offset_t}, {"mean", mean_t},
+                 {"var", var_t}, {"side_input", side_input_t}};
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+    // Place all nodes on CPU.
+    for (int i = 0; i < item.graph.node_size(); ++i) {
+      item.graph.mutable_node(i)->set_device("/device:CPU:0");
+    }
+
+    Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+    int found = 0;
+    for (const NodeDef& node : output.node()) {
+      if (node.name() == "add") {
+        EXPECT_EQ(node.op(), "Add");
+        ASSERT_EQ(node.input_size(), 2);
+        EXPECT_EQ(node.input(0), "fused_batch_norm");
+        EXPECT_EQ(node.input(1), "side_input_cast");
+        found++;
+      }
+      if (node.name() == "relu") {
+        EXPECT_EQ(node.op(), "Relu");
+        ASSERT_EQ(node.input_size(), 1);
+        EXPECT_EQ(node.input(0), "add");
+        found++;
+      }
+      if (node.name() == "fused_batch_norm") {
+        EXPECT_EQ(node.op(), "FusedBatchNormV3");
+        ASSERT_EQ(node.input_size(), 5);
+        EXPECT_EQ(node.input(0), "input_cast");
+        EXPECT_EQ(node.input(1), "scale");
+        EXPECT_EQ(node.input(2), "offset");
+        EXPECT_EQ(node.input(3), "mean");
+        EXPECT_EQ(node.input(4), "var");
+        found++;
+      }
+    }
+    EXPECT_EQ(found, 3);
+  }
+}
+#endif  // ENABLE_MKLDNN_V1
+
 }  // namespace grappler
 }  // namespace tensorflow
 #endif  // INTEL_MKL
@@ -741,24 +741,27 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
       [&](const utils::MutableNodeView& fused_batch_norm) -> bool {
     const auto* fused_batch_norm_node_def = fused_batch_norm.node();
     if (!IsFusedBatchNorm(*fused_batch_norm_node_def)) return false;
-    // We fuse FusedBatchNorm only on GPU, because on CPU we fuse it with
-    // contraction (MatMul or Conv2D node).
+    // We fuse FusedBatchNorm on GPU or MKL CPU.
+#ifndef ENABLE_MKLDNN_V1
     if (!NodeIsOnGpu(fused_batch_norm_node_def)) return false;
+#endif
+
     DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm_node_def, "T");
+#ifndef ENABLE_MKLDNN_V1
     if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false;
+#else
+    if (t_dtype != DT_FLOAT && t_dtype != DT_BFLOAT16) return false;
+#endif

     // Get the FusedBatchNorm training mode.
     bool is_training;
     if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
             .ok())
       return false;

     // In training mode we rely on cuDNN for computing FusedBatchNorm with side
     // inputs and activation, and it has its own limitations. In inference mode
-    // we have a custom CUDA kernel that doesn't have these constraints.
-    if (is_training) {
+    // we have a custom CUDA kernel that doesn't have these constraints.
+    if (is_training && NodeIsOnGpu(fused_batch_norm_node_def)) {
       // cuDNN only supports NHWC data layout.
       string data_format;
       if (!GetNodeAttr(*fused_batch_norm_node_def, kDataFormat, &data_format)
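The net effect of the two preprocessor blocks above can be restated as one predicate. The sketch below uses hypothetical stand-in enums; the real code queries NodeIsOnGpu and the TF DataType attribute:

    enum class DeviceKind { kCpu, kGpu };  // stand-in for node placement
    enum class DTypeSketch { kFloat, kHalf, kBfloat16 };

    bool FusedBatchNormExCandidate(DeviceKind dev, DTypeSketch t) {
    #ifndef ENABLE_MKLDNN_V1
      // Without DNNL 1.x the fusion stays GPU-only, float/half.
      if (dev != DeviceKind::kGpu) return false;
      return t == DTypeSketch::kFloat || t == DTypeSketch::kHalf;
    #else
      // With DNNL 1.x, CPU nodes also qualify; the dtypes become float/bfloat16.
      return t == DTypeSketch::kFloat || t == DTypeSketch::kBfloat16;
    #endif
    }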
@@ -810,6 +813,12 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,

   // Input to a Relu can be an Add node with FusedBatchNorm as one of the inputs
   if (IsAdd(*relu_fanin_0_node_def)) {
+    // Currently there is no CPU implementation for "FusedBatchNorm + SideInput
+    // + <Activation>".
+#ifdef ENABLE_MKLDNN_V1
+    return false;
+#endif
+
     // Check that only Relu node consumes the output of an Add node.
     if (HasControlFaninOrFanout(*relu_fanin_0_node_view) ||
         !HasAtMostOneFanoutAtPort0(*relu_fanin_0_node_view) ||
@@ -881,7 +890,11 @@ void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
   if (fused_batch_norm.op() != "FusedBatchNorm") {
     (*attr)["U"] = src_attr.at("U");
   } else {
+#ifndef ENABLE_MKLDNN_V1
     (*attr)["U"] = src_attr.at("T");
+#else
+    SetAttrValue(DT_FLOAT, &(*attr)["U"]);
+#endif
   }
 }

@@ -8135,7 +8135,14 @@ tf_mkl_kernel_library(
 tf_mkl_kernel_library(
     name = "mkl_fused_batch_norm_op",
     srcs = ["mkl_fused_batch_norm_op.cc"],
-    deps = NN_DEPS + mkl_deps(),
+    hdrs = [
+        "fused_batch_norm_op.h",
+        "no_op.h",
+    ],
+    deps = NN_DEPS + [
+        ":fused_batch_norm_op",
+        ":no_op",
+    ] + mkl_deps(),
 )

 tf_cc_test_mkl(
@@ -14,14 +14,16 @@ limitations under the License.
 ==============================================================================*/
 #ifdef INTEL_MKL
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/fused_batch_norm_op.h"
+#include "tensorflow/core/kernels/no_op.h"
 #include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/mkl_util.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

 #define GET_FLAG(bn_flag) static_cast<int>(BN_FLAGS::bn_flag)
 #define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
@@ -37,11 +39,14 @@ using BatchNormBwdPd = mkldnn::batch_normalization_backward::primitive_desc;
 namespace tensorflow {
 using CPUDevice = Eigen::ThreadPoolDevice;

+using FusedBNActivationMode = functor::FusedBatchNormActivationMode;
+
 struct MklBatchNormFwdParams {
   memory::dims src_dims;
   int depth;
   float eps;
   bool training;
+  FusedBNActivationMode activation_mode;
 #ifndef ENABLE_MKLDNN_V1
   MEMORY_FORMAT src_format;
 #else
@@ -50,14 +55,17 @@ struct MklBatchNormFwdParams {

   MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
 #ifndef ENABLE_MKLDNN_V1
-                        bool training, MEMORY_FORMAT src_format)
+                        bool training, MEMORY_FORMAT src_format,
+                        FusedBNActivationMode activation_mode)
 #else
-                        bool training, memory::desc src_md)
+                        bool training, memory::desc src_md,
+                        FusedBNActivationMode activation_mode)
 #endif  // !ENABLE_MKLDNN_V1
       : src_dims(src_dims),
         depth(depth),
         eps(eps),
         training(training),
+        activation_mode(activation_mode),
 #ifndef ENABLE_MKLDNN_V1
         src_format(src_format) {
   }
@@ -90,7 +98,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
   //   mean_data:     output data buffer of means
   //   variance_data: output data buffer of variances
   void Execute(const T* src_data, const U* weights_data, T* dst_data,
-               U* mean_data, U* variance_data) {
+               U* mean_data, U* variance_data, U* workspace_data) {
     context_.src_mem->set_data_handle(
         static_cast<void*>(const_cast<T*>(src_data)));
     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
@@ -104,6 +112,9 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
       context_.mean_mem->set_data_handle(static_cast<void*>(mean_data));
       context_.variance_mem->set_data_handle(static_cast<void*>(variance_data));
     }
+    if (workspace_data != nullptr) {
+      context_.ws_mem->set_data_handle(workspace_data);
+    }
 #ifdef ENABLE_MKLDNN_V1
     // Execute batch-normalization forward primitives.
     execute_primitives(context_.fwd_primitives, context_.fwd_stream,
@@ -123,6 +134,10 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
       context_.mean_mem->set_data_handle(DummyData);
       context_.variance_mem->set_data_handle(DummyData);
     }
+
+    if (workspace_data != nullptr) {
+      context_.ws_mem->set_data_handle(DummyData);
+    }
   }

   MEMORY_PRIMITIVE_DESC GetDstPd() const { return context_.dst_mem->GET_DESC; }
@@ -158,6 +173,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
     std::shared_ptr<mkldnn::memory> dst_mem;
     std::shared_ptr<mkldnn::memory> mean_mem;
     std::shared_ptr<mkldnn::memory> variance_mem;
+    std::shared_ptr<mkldnn::memory> ws_mem;

     // Forward BatchNorm primitive descriptor.
     std::shared_ptr<BatchNormFwdPd> fwd_pd;
@@ -179,6 +195,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
           dst_mem(nullptr),
           mean_mem(nullptr),
           variance_mem(nullptr),
+          ws_mem(nullptr),
           bn_fwd(nullptr),
           fwd_stream(nullptr) {}
   };
@@ -192,6 +209,9 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
                                   : prop_kind::forward_scoring;

 #ifdef ENABLE_MKLDNN_V1
+    if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
+      context_.flags |= GET_FLAG(fuse_norm_relu);
+    }
     // Memory descriptor
     auto src_md = fwdParams.src_md;
     // Create forward BatchNorm descriptor and primitive descriptor.
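For context, this hinges on the DNNL 1.x flag fuse_norm_relu: set on the forward descriptor, it makes the primitive apply ReLU after normalization and, in training, publish a workspace the backward pass consumes. A minimal standalone sketch of that API (plain DNNL 1.x, without the GET_FLAG/MEMORY_CONSTRUCTOR wrappers used here):

    #include "mkldnn.hpp"
    using namespace mkldnn;

    // Builds a BN+ReLU forward primitive descriptor for an example NCHW tensor.
    batch_normalization_forward::primitive_desc MakeBnReluPd(const engine& eng) {
      memory::desc src_md({2, 24, 8, 8}, memory::data_type::f32,
                          memory::format_tag::nchw);
      auto flags = normalization_flags::use_scale_shift |
                   normalization_flags::fuse_norm_relu;
      batch_normalization_forward::desc d(prop_kind::forward_training, src_md,
                                          /*epsilon=*/1e-4f, flags);
      auto pd = batch_normalization_forward::primitive_desc(d, eng);
      // pd.workspace_desc() is non-empty here and must be bound at execution.
      return pd;
    }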
@@ -229,6 +249,13 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
           m_dims, U, MEMORY_FORMAT::nc, cpu_engine_, DummyData));
     }
+
+#ifdef ENABLE_MKLDNN_V1
+    if (IS_SET(fuse_norm_relu)) {
+      context_.ws_mem.reset(new MEMORY_CONSTRUCTOR(
+          context_.fwd_pd->workspace_desc(), cpu_engine_, DummyData));
+    }
+#endif  // ENABLE_MKLDNN_V1
+
     // BatchNorm forward primitive.
     // TODO(intel-tf): Merge all the #ifdefs and simplify code
     if (!fwdParams.training && !(IS_SET(use_global_stats))) {
@@ -258,6 +285,16 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
     } else if (IS_SET(use_global_stats)) {
 #ifdef ENABLE_MKLDNN_V1
       if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
+        if (IS_SET(fuse_norm_relu)) {
+          context_.net_args.push_back(
+              {{MKLDNN_ARG_SRC, *context_.src_mem},
+               {MKLDNN_ARG_MEAN, *context_.mean_mem},
+               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
+               {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
+               {MKLDNN_ARG_DST, *context_.dst_mem},
+               { MKLDNN_ARG_WORKSPACE,
+                 *context_.ws_mem }});
+        } else {
         context_.net_args.push_back(
             {{MKLDNN_ARG_SRC, *context_.src_mem},
              {MKLDNN_ARG_MEAN, *context_.mean_mem},
@@ -265,6 +302,16 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
              {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
              { MKLDNN_ARG_DST,
               *context_.dst_mem }});
+        }
+      } else {
+        if (IS_SET(fuse_norm_relu)) {
+          context_.net_args.push_back(
+              {{MKLDNN_ARG_SRC, *context_.src_mem},
+               {MKLDNN_ARG_MEAN, *context_.mean_mem},
+               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
+               {MKLDNN_ARG_DST, *context_.dst_mem},
+               { MKLDNN_ARG_WORKSPACE,
+                 *context_.ws_mem }});
       } else {
         context_.net_args.push_back(
             {{MKLDNN_ARG_SRC, *context_.src_mem},
@@ -273,6 +320,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
              { MKLDNN_ARG_DST,
               *context_.dst_mem }});
       }
+      }
       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
 #else
       if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
@@ -291,6 +339,16 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
     } else {
 #ifdef ENABLE_MKLDNN_V1
       if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
+        if (IS_SET(fuse_norm_relu)) {
+          context_.net_args.push_back(
+              {{MKLDNN_ARG_SRC, *context_.src_mem},
+               {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
+               {MKLDNN_ARG_DST, *context_.dst_mem},
+               {MKLDNN_ARG_MEAN, *context_.mean_mem},
+               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
+               { MKLDNN_ARG_WORKSPACE,
+                 *context_.ws_mem }});
+        } else {
         context_.net_args.push_back(
             {{MKLDNN_ARG_SRC, *context_.src_mem},
              {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
@@ -298,6 +356,16 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
              {MKLDNN_ARG_MEAN, *context_.mean_mem},
              { MKLDNN_ARG_VARIANCE,
               *context_.variance_mem }});
+        }
+      } else {
+        if (IS_SET(fuse_norm_relu)) {
+          context_.net_args.push_back(
+              {{MKLDNN_ARG_SRC, *context_.src_mem},
+               {MKLDNN_ARG_DST, *context_.dst_mem},
+               {MKLDNN_ARG_MEAN, *context_.mean_mem},
+               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
+               { MKLDNN_ARG_WORKSPACE,
+                 *context_.ws_mem }});
       } else {
         context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
                                      {MKLDNN_ARG_DST, *context_.dst_mem},
@@ -305,6 +373,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
                                      { MKLDNN_ARG_VARIANCE,
                                        *context_.variance_mem }});
       }
+      }
       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
 #else
       if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
@@ -360,6 +429,7 @@ class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
     key_creator.AddAsKey<int>(fwdParams.depth);
     key_creator.AddAsKey<float>(fwdParams.eps);
     key_creator.AddAsKey<bool>(fwdParams.training);
+    key_creator.AddAsKey<FusedBNActivationMode>(fwdParams.activation_mode);
     key_creator.AddAsKey(typeid(T).name());
     key_creator.AddAsKey(typeid(U).name());
     return key_creator.GetKey();
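Folding activation_mode into the factory key matters because the cached primitive for fused BN+ReLU differs from the plain one (different descriptor flags plus a workspace argument), so two nodes that differ only in activation must not share a cache slot. A simplified sketch of the key construction assumed here:

    #include <sstream>
    #include <string>

    // Hypothetical stand-in for FactoryKeyCreator: every field that changes
    // the generated primitive must be folded into the lookup key.
    std::string FwdKeySketch(int depth, float eps, bool training,
                             int activation_mode) {
      std::ostringstream key;
      key << "bn_fwd:" << depth << ':' << eps << ':' << training << ':'
          << activation_mode;
      return key.str();
    }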
@@ -676,7 +746,8 @@ class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 // Adding a third parameter to the template to support FusedBatchNormV3
 // with MKL. This is different from default where the classes are
 // derived. Moves enabling to compile-time rather than runtime.
-template <typename Device, typename T, typename U, bool reserved_space>
+template <typename Device, typename T, typename U, bool reserved_space,
+          bool is_batch_norm_ex = false>
 class MklFusedBatchNormOp : public OpKernel {
  public:
   explicit MklFusedBatchNormOp(OpKernelConstruction* context)
@@ -696,6 +767,28 @@ class MklFusedBatchNormOp : public OpKernel {
     depth_ = 0;
     mean_values_ = nullptr;
     variance_values_ = nullptr;
+
+#ifndef ENABLE_MKLDNN_V1
+    OP_REQUIRES(context, !is_batch_norm_ex,
+                errors::InvalidArgument(
+                    "_MklFusedBatchNormEx is not supported in DNNL 0.x."));
+#endif
+    if (!is_batch_norm_ex) {
+      activation_mode_ = FusedBNActivationMode::kIdentity;
+    } else {
+      int num_side_inputs;
+      OP_REQUIRES_OK(context,
+                     context->GetAttr("num_side_inputs", &num_side_inputs));
+      // Currently _MklFusedBatchNormEx does not support "SideInput".
+      OP_REQUIRES(context, num_side_inputs == 0,
+                  errors::InvalidArgument(
+                      "_MklFusedBatchNorm does not support side input now."));
+
+      OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_));
+      OP_REQUIRES(context, activation_mode_ == FusedBNActivationMode::kRelu,
+                  errors::InvalidArgument(
+                      "_MklFusedBatchNorm only supports Relu activation"));
+    }
   }

   void Compute(OpKernelContext* context) override {
@@ -744,9 +837,12 @@ class MklFusedBatchNormOp : public OpKernel {

     // Handle the special case: input with 0 element and 0 batch size.
     Tensor* dst_tensor = nullptr;
+    TensorShape workspace_tf_shape;
     if (tf_shape_src.num_elements() == 0) {
-      HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
-                       &dst_tensor);
+      size_t workspace_bytes = 0;
+      workspace_tf_shape.AddDim(workspace_bytes);
+      HandleEmptyInput(context, tf_shape_src, workspace_tf_shape,
+                       scale_tensor.shape(), &dst_tensor);
       return;
     }

@@ -758,23 +854,16 @@ class MklFusedBatchNormOp : public OpKernel {
     // Index of output tensor(diff_src).
     const size_t kDstIndex = 0;

-    // Allocate 4 output TF tensors.
+    // Allocate 5 output TF tensors.
     Tensor* batch_mean_tensor = nullptr;
     Tensor* batch_variance_tensor = nullptr;
     Tensor* saved_mean_tensor = nullptr;
     Tensor* saved_variance_tensor = nullptr;
     Tensor* reserved_space_tensor = nullptr;
-    AllocateTFOutputs(context, scale_tensor.shape(), &batch_mean_tensor,
-                      &batch_variance_tensor, &saved_mean_tensor,
-                      &saved_variance_tensor, &reserved_space_tensor);
-
-    if (is_training_)
-      SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
-    else
-      SetMeanVariance(est_mean_tensor, est_variance_tensor);

     MklDnnData<T> src(&cpu_engine_);
     MklDnnData<U> weights(&cpu_engine_);
+    MklDnnData<U> wksp(&cpu_engine_);

     MEMORY_FORMAT dnn_fmt;
     MKL_TENSOR_FORMAT mkl_tensor_fmt;
@@ -801,6 +890,51 @@ class MklFusedBatchNormOp : public OpKernel {
                       ? dnn_shape_src.GetMklLayout()
                       : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);

+#ifdef ENABLE_MKLDNN_V1
+    MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_,
+                                    src_md, activation_mode_);
+#else
+    MklBatchNormFwdParams fwdParams(
+        src_dims, depth_, epsilon_, is_training_,
+        static_cast<MEMORY_FORMAT>(src_md.data.format), activation_mode_);
+#endif  // ENABLE_MKLDNN_V1
+    // Get forward batch-normalization op from the primitive caching pool.
+    MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd =
+        MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams);
+
+    // Allocate workspace tensor
+    U* ws_data = nullptr;
+    if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
+#ifdef ENABLE_MKLDNN_V1
+      MEMORY_PRIMITIVE_DESC workspace_pd =
+          bn_fwd->GetBatchNormFwdPd()->workspace_desc();
+      size_t workspace_bytes = workspace_pd.get_size();
+      workspace_tf_shape.AddDim(workspace_bytes);
+
+      AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
+                        &batch_mean_tensor, &batch_variance_tensor,
+                        &saved_mean_tensor, &saved_variance_tensor,
+                        &reserved_space_tensor);
+      if (reserved_space) {
+        wksp.SetUsrMem(workspace_pd, reserved_space_tensor);
+        ws_data = static_cast<U*>(wksp.GetOpMem().get_data_handle());
+      }
+#endif  // ENABLE_MKLDNN_V1
+    } else {
+      // There is actually no workspace tensor out, so we make a dummy one.
+      size_t workspace_bytes = 0;
+      workspace_tf_shape.AddDim(workspace_bytes);
+      AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
+                        &batch_mean_tensor, &batch_variance_tensor,
+                        &saved_mean_tensor, &saved_variance_tensor,
+                        &reserved_space_tensor);
+    }
+
+    if (is_training_)
+      SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
+    else
+      SetMeanVariance(est_mean_tensor, est_variance_tensor);
+
     // MKL-DNN packs scale & shift as "weights":
     // <scale>...<scale><shift>...<shift>
     weights.AllocateBuffer(2 * depth_ * sizeof(U));
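A note on the design here: the DNNL workspace (a record of which activations ReLU clamped, needed by the fused backward pass) is surfaced through FusedBatchNormV3's otherwise CPU-unused reserve_space_3 output, so its TF shape is the workspace byte count when fusion is active and zero elements otherwise. A minimal sketch of that shape rule:

    #include <cstdint>

    #include "tensorflow/core/framework/tensor_shape.h"

    // Sketch of the workspace_tf_shape logic above.
    tensorflow::TensorShape ReserveSpace3Shape(bool fused_relu,
                                               size_t workspace_bytes) {
      tensorflow::TensorShape shape;
      int64_t dim = fused_relu ? static_cast<int64_t>(workspace_bytes) : 0;
      shape.AddDim(dim);  // a zero-sized dim yields an empty tensor
      return shape;
    }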
@@ -821,18 +955,6 @@ class MklFusedBatchNormOp : public OpKernel {
            reinterpret_cast<char*>(variance_values_),
            depth_ * sizeof(U));

-#ifdef ENABLE_MKLDNN_V1
-    MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_,
-                                    src_md);
-#else
-    MklBatchNormFwdParams fwdParams(
-        src_dims, depth_, epsilon_, is_training_,
-        static_cast<MEMORY_FORMAT>(src_md.data.format));
-#endif  // ENABLE_MKLDNN_V1
-    // Get forward batch-normalization op from the primitive caching pool.
-    MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd =
-        MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams);
-
     // Check if reorder is needed for src.
     const T* src_data = nullptr;
     std::shared_ptr<BatchNormFwdPd> bn_fwd_pd = bn_fwd->GetBatchNormFwdPd();
@@ -866,7 +988,7 @@ class MklFusedBatchNormOp : public OpKernel {

     // Execute
     bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data,
-                    variance_op_data);
+                    variance_op_data, ws_data);

     float adjust_factor = 1.0;
     if (is_training_) {
@@ -924,6 +1046,7 @@ class MklFusedBatchNormOp : public OpKernel {
   U* mean_values_;
   U* variance_values_;
   size_t depth_;  // Batch normalization is performed per channel.
+  FusedBNActivationMode activation_mode_;
   engine cpu_engine_ = engine(ENGINE_CPU, 0);

   void ExtractParams(OpKernelContext* context) {
@@ -938,6 +1061,7 @@ class MklFusedBatchNormOp : public OpKernel {
   }

   void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
+                        TensorShape workspace_tf_shape,
                         TensorShape tf_shape_scale, Tensor** dst_tensor) {
     DCHECK(dst_tensor);

@@ -955,12 +1079,14 @@ class MklFusedBatchNormOp : public OpKernel {
     Tensor* saved_mean_tensor = nullptr;
     Tensor* saved_variance_tensor = nullptr;
     Tensor* reserved_space_tensor = nullptr;
-    AllocateTFOutputs(context, tf_shape_scale, &batch_mean_tensor,
-                      &batch_variance_tensor, &saved_mean_tensor,
-                      &saved_variance_tensor, &reserved_space_tensor);
+    AllocateTFOutputs(context, tf_shape_scale, workspace_tf_shape,
+                      &batch_mean_tensor, &batch_variance_tensor,
+                      &saved_mean_tensor, &saved_variance_tensor,
+                      &reserved_space_tensor);
   }

   void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale,
+                         TensorShape workspace_tf_shape,
                          Tensor** batch_mean_tensor,
                          Tensor** batch_variance_tensor,
                          Tensor** saved_mean_tensor,
@@ -1024,21 +1150,15 @@ class MklFusedBatchNormOp : public OpKernel {
     std::fill_n(saved_variance_data, num_elements, static_cast<U>(0));

     // Changes to support reserved_space_3 parameter in FusedBatchNormV3.
-    // TODO: This parameter functionality is not implemented on CPU.
-    //       It is used to hold intermediate results. So the allocated
-    //       memory is filled with 0s.
     if (reserved_space) {
       DCHECK(reserved_space_tensor != nullptr);

       MklDnnShape mkl_shape_reserved_space;
       mkl_shape_reserved_space.SetMklTensor(false);
       AllocateOutputSetMklShape(context, kReservedSpaceIndex,
-                                reserved_space_tensor, tf_shape_scale,
+                                reserved_space_tensor, workspace_tf_shape,
                                 mkl_shape_reserved_space);
       DCHECK((*reserved_space_tensor) != nullptr);
-      auto saved_reserved_space_data =
-          (*reserved_space_tensor)->flat<U>().data();
-      std::fill_n(saved_reserved_space_data, num_elements, static_cast<U>(0));
     }
   }
 };
@@ -1367,7 +1487,7 @@ class MklFusedBatchNormGradOp : public OpKernel {
           .Device(DEVICE_CPU)                                         \
           .TypeConstraint<T>("T")                                     \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),        \
-      MklFusedBatchNormOp<CPUDevice, T, T, false>);
+      MklFusedBatchNormOp<CPUDevice, T, T, false, false>);

 TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_CPU);
 TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU);
@@ -1380,7 +1500,7 @@ TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU);
           .TypeConstraint<T>("T")                                     \
           .TypeConstraint<U>("U")                                     \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),        \
-      MklFusedBatchNormOp<CPUDevice, T, U, false>);
+      MklFusedBatchNormOp<CPUDevice, T, U, false, false>);

 REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(float, float);
 REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float);
@@ -1421,12 +1541,30 @@ REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(bfloat16, float);
           .TypeConstraint<T>("T")                                     \
           .TypeConstraint<U>("U")                                     \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),        \
-      MklFusedBatchNormOp<CPUDevice, T, U, true>);
+      MklFusedBatchNormOp<CPUDevice, T, U, true, false>);             \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("_MklFusedBatchNormEx")                                    \
+          .Device(DEVICE_CPU)                                         \
+          .TypeConstraint<T>("T")                                     \
+          .TypeConstraint<U>("U")                                     \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),        \
+      MklFusedBatchNormOp<CPUDevice, T, U, true, true>);

 REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(float, float);
 REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(bfloat16, float);
 #undef REGISTER_MKL_FUSED_BATCHNORM_V3_CPU

+REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<float>("T")
+                            .TypeConstraint<float>("U"),
+                        NoOp);
+REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<bfloat16>("T")
+                            .TypeConstraint<float>("U"),
+                        NoOp);
+
 #define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(T, U) \
   REGISTER_KERNEL_BUILDER(                             \
       Name("_MklFusedBatchNormGradV3")                 \
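Why the plain _FusedBatchNormEx op also receives CPU kernels: after the grappler remap a CPU graph briefly contains _FusedBatchNormEx, and without any registered CPU kernel placement would fail before the MKL layout pass rewrites the node; the NoOp registrations act as placeholders that are never expected to execute (a reading of the change, stated here as an assumption). The resulting CPU dispatch, as a comment sketch:

    // CPU kernel selection after this change (sketch):
    //   "_FusedBatchNormEx"    -> NoOp placeholder; the layout pass rewrites
    //                             the node before it would ever run.
    //   "_MklFusedBatchNormEx" -> MklFusedBatchNormOp<CPUDevice, T, U,
    //                             /*reserved_space=*/true,
    //                             /*is_batch_norm_ex=*/true>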
@@ -1342,6 +1342,48 @@ REGISTER_OP("_MklFusedBatchNormGradV3")
     R"doc(MKL-DNN implementation of FusedBatchNormGradV3: Do not invoke this operator directly in Python.
 Graph rewrite pass is expected to invoke this operator.)doc");

+REGISTER_OP("_MklFusedBatchNormEx")
+    .Input("x: T")
+    .Input("scale: U")
+    .Input("offset: U")
+    .Input("mean: U")
+    .Input("variance: U")
+    .Input("side_input: num_side_inputs * T")
+    .Input("mkl_x: uint8")
+    .Input("mkl_scale: uint8")
+    .Input("mkl_offset: uint8")
+    .Input("mkl_mean: uint8")
+    .Input("mkl_variance: uint8")
+    .Input("mkl_side_input: num_side_inputs * uint8")
+    .Output("y: T")
+    .Output("batch_mean: U")
+    .Output("batch_variance: U")
+    .Output("reserve_space_1: U")
+    .Output("reserve_space_2: U")
+    .Output("reserve_space_3: U")
+    .Output("mkl_y: uint8")
+    .Output("mkl_batch_mean: uint8")
+    .Output("mkl_batch_variance: uint8")
+    .Output("mkl_reserve_space_1: uint8")
+    .Output("mkl_reserve_space_2: uint8")
+    .Output("mkl_reserve_space_3: uint8")
+    .Attr("T: {bfloat16, float}")
+    .Attr("U: {float}")
+    .Attr("epsilon: float = 0.0001")
+    .Attr("exponential_avg_factor: float = 1.0")
+    .Attr(GetConvnetDataFormatAttrString())
+    .Attr("num_side_inputs: int >= 0 = 0")
+    .Attr("activation_mode: string = \"Identity\"")
+    .Attr("is_training: bool = true")
+    .SetShapeFn(shape_inference::FusedBatchNormShape)
+    .Doc(R"doc(
+MKL version of FusedBatchNormEx operator. Uses MKL DNN APIs to perform fused
+batch normalization and relu.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
 }  // namespace tensorflow

 #endif  // INTEL_MKL
@@ -238,7 +238,11 @@ REGISTER_OP("_FusedBatchNormEx")
     .Output("reserve_space_1: U")
     .Output("reserve_space_2: U")
     .Output("reserve_space_3: U")
+#ifdef ENABLE_MKLDNN_V1
+    .Attr("T: {half, float, bfloat16}")
+#else
     .Attr("T: {half, float}")
+#endif
     .Attr("U: {float}")
     .Attr("epsilon: float = 0.0001")
     .Attr("exponential_avg_factor: float = 1.0")