Fuse BN and ReLU in the MKL path

ShengYang1 2020-03-29 16:23:03 +08:00
parent 5992e75800
commit d4d23502bf
8 changed files with 601 additions and 84 deletions

View File

@ -268,6 +268,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.dequantize = "Dequantize";
csinfo_.fused_batch_norm = "FusedBatchNorm";
csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
csinfo_.fused_batch_norm_ex = "_FusedBatchNormEx";
csinfo_.fused_batch_norm_v2 = "FusedBatchNormV2";
csinfo_.fused_batch_norm_grad_v2 = "FusedBatchNormGradV2";
csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
@ -294,6 +295,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
"_MklDepthwiseConv2dNativeBackpropInput";
csinfo_.mkl_depthwise_conv2d_grad_filter =
"_MklDepthwiseConv2dNativeBackpropFilter";
csinfo_.mkl_fused_batch_norm_ex = "_MklFusedBatchNormEx";
csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
@ -476,6 +478,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
{csinfo_.fused_batch_norm_grad_v3,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
#ifdef ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.fused_batch_norm_ex,
csinfo_.mkl_fused_batch_norm_ex, CopyAttrsAll,
FusedBatchNormExRewrite, kRewriteForLayoutPropagation});
#endif
rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
CopyAttrsFusedConv2D, FusedConv2DRewrite,
kRewriteForLayoutPropagation});
@ -920,6 +927,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string dequantize;
string fused_batch_norm;
string fused_batch_norm_grad;
string fused_batch_norm_ex;
string fused_batch_norm_v2;
string fused_batch_norm_grad_v2;
string fused_batch_norm_v3;
@ -944,6 +952,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string mkl_conv2d_with_bias;
string mkl_depthwise_conv2d_grad_input;
string mkl_depthwise_conv2d_grad_filter;
string mkl_fused_batch_norm_ex;
string mkl_fused_conv2d;
string mkl_fused_matmul;
string mkl_pad_with_conv2d;
@ -1652,6 +1661,31 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return do_rewrite;
}
static bool FusedBatchNormExRewrite(const Node* n) {
CHECK_NOTNULL(n);
int num_side_inputs;
TF_CHECK_OK(GetNodeAttr(n->def(), "num_side_inputs", &num_side_inputs));
string activation_mode;
TF_CHECK_OK(GetNodeAttr(n->def(), "activation_mode", &activation_mode));
// If num_side_inputs is not 0, don't rewrite the node.
if (num_side_inputs != 0) {
VLOG(1) << "FusedBatchNormExRewrite: nodes with num_side_inputs > 0 "
<< "are not optimized by Intel MKL.";
return false;
}
// If the activation_mode is not 'Relu', don't rewrite the node.
if (activation_mode != "Relu") {
VLOG(1) << "FusedBatchNormExRewrite: Only Relu activation mode is "
<< "supported by Intel MKL.";
return false;
}
return true;
}
static bool FusedConv2DRewrite(const Node* n) {
// MKL DNN currently doesn't support all fusions that grappler fuses
// together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
@ -2131,9 +2165,6 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
// Number of input slots to original op
// Input slots are represented by .Input() calls in REGISTER_OP.
int old_node_input_slots = old_node->op_def().input_arg_size();
// Actual number of inputs can be greater than or equal to number
// of Input slots because inputs of type list could be unfolded.
CHECK_GE(old_node_inputs.size(), old_node_input_slots);
int nn_slot_idx = 0; // slot index for inputs of new node
// Let's copy all inputs (TF tensors) of original node to new node.
@ -2141,13 +2172,14 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
// An input slot could be a single tensor or a list. We need
// to handle this case accordingly.
CHECK_LT(iidx, old_node_inputs.size());
const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
if (ArgIsList(arg)) {
std::vector<NodeBuilder::NodeOut> new_node_inputs;
int N = GetTensorListLength(arg, old_node);
GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
&new_node_inputs);
int tensor_list_length = GetTensorListLength(arg, old_node);
if (tensor_list_length != 0) {
GetNodesProducingTFTensorList(old_node_inputs, &iidx,
tensor_list_length, &new_node_inputs);
}
nb->Input(new_node_inputs);
nn_slot_idx++;
} else {
@ -2180,13 +2212,14 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
// An input slot could be a single tensor or a list. We need
// to handle this case accordingly.
CHECK_LT(iidx, old_node_inputs.size());
const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
if (ArgIsList(arg)) {
std::vector<NodeBuilder::NodeOut> new_node_inputs;
int N = GetTensorListLength(arg, old_node);
GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
&new_node_inputs);
int tensor_list_length = GetTensorListLength(arg, old_node);
if (tensor_list_length != 0) {
GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
tensor_list_length, &new_node_inputs);
}
nb->Input(new_node_inputs);
nn_slot_idx++;
} else {
@ -3702,6 +3735,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
n->type_string() != csinfo_.pad_with_conv2d &&
n->type_string() != csinfo_.pad_with_fused_conv2d &&
n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
n->type_string() != csinfo_.fused_batch_norm_ex &&
n->type_string() != csinfo_.fused_conv2d &&
n->type_string() != csinfo_.fused_matmul &&
!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),

View File

@ -3108,6 +3108,112 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) {
"B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");
}
#ifdef ENABLE_MKLDNN_V1
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT \
"'}" \
"node { name: 'B' op: 'Input'}" \
"node { name: 'C' op: 'Input'}" \
"node { name: 'D' op: 'Input'}" \
"node { name: 'E' op: 'Input'}" \
"node { name: 'F' op: '_FusedBatchNormEx'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" attr { key: 'U' value { type: DT_FLOAT } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'num_side_inputs' value { i: 0 } }" \
" attr { key: 'is_training' value { b: true } }" \
" attr { key: 'activation_mode' value { s: 'Relu' } }" \
" input: ['A', 'B', 'C', 'D', 'E'] }" \
"node { name: 'G' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" input: ['A', 'F'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT \
");B(Input);C(Input);D(Input);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);" \
"DMT/_4(Const);E(Input);" \
"F(_MklFusedBatchNormEx);G(Zeta)|A->F;A->G;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;A:control->DMT/_3:control;" \
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" \
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" \
"E->F:4;F->G:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Positive);
#undef REGISTER_TEST
// Rewrite test for _FusedBatchNormEx Op with side input
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT \
"'}" \
"node { name: 'B' op: 'Input'}" \
"node { name: 'C' op: 'Input'}" \
"node { name: 'D' op: 'Input'}" \
"node { name: 'E' op: 'Input'}" \
"node { name: 'F' op: '" #INPUT \
"'}" \
"node { name: 'G' op: '_FusedBatchNormEx'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" attr { key: 'U' value { type: DT_FLOAT } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'num_side_inputs' value { i: 1 } }" \
" attr { key: 'is_training' value { b: true } }" \
" attr { key: 'activation_mode' value { s: 'Relu' } }" \
" input: ['A', 'B', 'C', 'D', 'E', 'F'] }" \
"node { name: 'H' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" input: ['A', 'G'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT \
");B(Input);C(Input);D(Input);E(Input);" \
"F(" #INPUT \
");G(_FusedBatchNormEx);H(Zeta)|A->G;A->H;" \
"B->G:1;C->G:2;D->G:3;E->G:4;F->G:5;G->H:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative1);
#undef REGISTER_TEST
// Rewrite test for _FusedBatchNormEx Op with Identity activation
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT \
"'}" \
"node { name: 'B' op: 'Input'}" \
"node { name: 'C' op: 'Input'}" \
"node { name: 'D' op: 'Input'}" \
"node { name: 'E' op: 'Input'}" \
"node { name: 'G' op: '_FusedBatchNormEx'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" attr { key: 'U' value { type: DT_FLOAT } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'num_side_inputs' value { i: 1 } }" \
" attr { key: 'is_training' value { b: true } }" \
" attr { key: 'activation_mode' value { s: 'Identity' } }" \
" input: ['A', 'B', 'C', 'D', 'E'] }" \
"node { name: 'H' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" input: ['A', 'G'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT \
");B(Input);C(Input);D(Input);E(Input);" \
"G(_FusedBatchNormEx);H(Zeta)|A->G;A->H;" \
"B->G:1;C->G:2;D->G:3;E->G:4;G->H:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative2);
#undef REGISTER_TEST
#endif // ENABLE_MKLDNN_V1
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedDepthwiseConv2D_Positive) {
InitGraph(
"node { name: 'A' op: 'QuantizedUnsignedInt8Input'}"

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -173,6 +174,178 @@ TEST_F(MklRemapperTest, FuseConv2DWithBiasAndAddNRelu) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
#ifdef ENABLE_MKLDNN_V1
TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
using ::tensorflow::ops::Placeholder;
for (bool is_training : {true, false}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
const int num_channels = 24;
TensorShape channel_shape({num_channels});
TensorShape empty_shape({0});
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
ops::Placeholder::Shape({2, 8, 8, num_channels}));
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_FLOAT);
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
float epsilon = 0.1f;
auto fbn = ops::FusedBatchNormV3(
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
ops::FusedBatchNormV3::IsTraining(is_training)
.Epsilon(epsilon)
.DataFormat("NHWC"));
auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
: channel_shape);
auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
: channel_shape);
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_t},
{"scale", scale_t},
{"offset", offset_t},
{"mean", mean_t},
{"var", var_t}};
TF_ASSERT_OK(s.ToGraphDef(&item.graph));
// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
}
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholder shapes
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "relu") {
EXPECT_EQ(node.op(), "Identity");
ASSERT_EQ(node.input_size(), 1);
EXPECT_EQ(node.input(0), "fused_batch_norm");
found++;
}
if (node.name() == "fused_batch_norm") {
EXPECT_EQ(node.op(), "_FusedBatchNormEx");
ASSERT_EQ(node.input_size(), 5);
EXPECT_EQ(node.input(0), "input_cast");
EXPECT_EQ(node.input(1), "scale");
EXPECT_EQ(node.input(2), "offset");
EXPECT_EQ(node.input(3), "mean");
EXPECT_EQ(node.input(4), "var");
auto attr = node.attr();
EXPECT_EQ(attr["num_side_inputs"].i(), 0);
EXPECT_EQ(attr["activation_mode"].s(), "Relu");
found++;
}
}
EXPECT_EQ(found, 2);
}
}
TEST_F(MklRemapperTest, FuseBatchNormWithAddAndRelu) {
using ::tensorflow::ops::Placeholder;
for (bool is_training : {true, false}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
const int num_channels = 24;
TensorShape input_shape({2, 8, 8, num_channels});
TensorShape channel_shape({num_channels});
TensorShape empty_shape({0});
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
ops::Placeholder::Shape(input_shape));
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_FLOAT);
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
auto side_input = Placeholder(s.WithOpName("side_input"), DT_FLOAT,
ops::Placeholder::Shape(input_shape));
auto side_input_cast =
ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_FLOAT);
float epsilon = 0.1f;
auto fbn = ops::FusedBatchNormV3(
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
ops::FusedBatchNormV3::IsTraining(is_training)
.Epsilon(epsilon)
.DataFormat("NHWC"));
auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
auto relu = ops::Relu(s.WithOpName("relu"), add);
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
auto input_t = GenerateRandomTensor<DT_FLOAT>(input_shape);
auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
: channel_shape);
auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
: channel_shape);
auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_t}, {"scale", scale_t},
{"offset", offset_t}, {"mean", mean_t},
{"var", var_t}, {"side_input", side_input_t}};
TF_ASSERT_OK(s.ToGraphDef(&item.graph));
// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
}
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholder shapes
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "add") {
EXPECT_EQ(node.op(), "Add");
ASSERT_EQ(node.input_size(), 2);
EXPECT_EQ(node.input(0), "fused_batch_norm");
EXPECT_EQ(node.input(1), "side_input_cast");
found++;
}
if (node.name() == "relu") {
EXPECT_EQ(node.op(), "Relu");
ASSERT_EQ(node.input_size(), 1);
EXPECT_EQ(node.input(0), "add");
found++;
}
if (node.name() == "fused_batch_norm") {
EXPECT_EQ(node.op(), "FusedBatchNormV3");
ASSERT_EQ(node.input_size(), 5);
EXPECT_EQ(node.input(0), "input_cast");
EXPECT_EQ(node.input(1), "scale");
EXPECT_EQ(node.input(2), "offset");
EXPECT_EQ(node.input(3), "mean");
EXPECT_EQ(node.input(4), "var");
found++;
}
}
EXPECT_EQ(found, 3);
}
}
#endif // ENABLE_MKLDNN_V1
} // namespace grappler
} // namespace tensorflow
#endif // INTEL_MKL

View File

@ -741,24 +741,27 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
[&](const utils::MutableNodeView& fused_batch_norm) -> bool {
const auto* fused_batch_norm_node_def = fused_batch_norm.node();
if (!IsFusedBatchNorm(*fused_batch_norm_node_def)) return false;
// We fuse FusedBatchNorm only on GPU, because on CPU we fuse it with
// contraction (MatMul or Conv2D node).
// We fuse FusedBatchNorm on GPU or MKL CPU.
#ifndef ENABLE_MKLDNN_V1
if (!NodeIsOnGpu(fused_batch_norm_node_def)) return false;
#endif
DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm_node_def, "T");
#ifndef ENABLE_MKLDNN_V1
if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false;
#else
if (t_dtype != DT_FLOAT && t_dtype != DT_BFLOAT16) return false;
#endif
// Get the FusedBatchNorm training mode.
bool is_training;
if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
.ok())
return false;
// In training mode we rely on cuDNN for computing FusedBatchNorm with side
// inputs and activation, and it has its own limitations. In inference mode
// we have a custom CUDA kernel that doesn't have these constraints.
if (is_training) {
if (is_training && NodeIsOnGpu(fused_batch_norm_node_def)) {
// cuDNN only supports NHWC data layout.
string data_format;
if (!GetNodeAttr(*fused_batch_norm_node_def, kDataFormat, &data_format)
@ -810,6 +813,12 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
// Input to a Relu can be an Add node with FusedBatchNorm as one of the inputs
if (IsAdd(*relu_fanin_0_node_def)) {
// Currently there is no CPU implementation for "FusedBatchNorm + SideInput +
// <Activation>".
#ifdef ENABLE_MKLDNN_V1
return false;
#endif
// Check that only Relu node consumes the output of an Add node.
if (HasControlFaninOrFanout(*relu_fanin_0_node_view) ||
!HasAtMostOneFanoutAtPort0(*relu_fanin_0_node_view) ||
@ -881,7 +890,11 @@ void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
if (fused_batch_norm.op() != "FusedBatchNorm") {
(*attr)["U"] = src_attr.at("U");
} else {
#ifndef ENABLE_MKLDNN_V1
(*attr)["U"] = src_attr.at("T");
#else
SetAttrValue(DT_FLOAT, &(*attr)["U"]);
#endif
}
}
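
The device/dtype gate above ends up spread across several #ifdef branches. Condensed into a single predicate it reads roughly as below; this is a sketch only, FusionAllowed is a hypothetical helper rather than a function in remapper.cc, and the additional is_training/NHWC and side-input checks are omitted:

#include "tensorflow/core/framework/types.pb.h"

// Sketch: with ENABLE_MKLDNN_V1 the _FusedBatchNormEx fusion is considered on
// CPU as well (float or bfloat16 only); without it the node must be on a GPU
// and have dtype float or half.
bool FusionAllowed(bool node_is_on_gpu, tensorflow::DataType t) {
#ifdef ENABLE_MKLDNN_V1
  return t == tensorflow::DT_FLOAT || t == tensorflow::DT_BFLOAT16;
#else
  return node_is_on_gpu &&
         (t == tensorflow::DT_FLOAT || t == tensorflow::DT_HALF);
#endif
}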

View File

@ -8135,7 +8135,14 @@ tf_mkl_kernel_library(
tf_mkl_kernel_library(
name = "mkl_fused_batch_norm_op",
srcs = ["mkl_fused_batch_norm_op.cc"],
deps = NN_DEPS + mkl_deps(),
hdrs = [
"fused_batch_norm_op.h",
"no_op.h",
],
deps = NN_DEPS + [
":fused_batch_norm_op",
":no_op",
] + mkl_deps(),
)
tf_cc_test_mkl(

View File

@ -14,14 +14,16 @@ limitations under the License.
==============================================================================*/
#ifdef INTEL_MKL
#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#define GET_FLAG(bn_flag) static_cast<int>(BN_FLAGS::bn_flag)
#define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
@ -37,11 +39,14 @@ using BatchNormBwdPd = mkldnn::batch_normalization_backward::primitive_desc;
namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
using FusedBNActivationMode = functor::FusedBatchNormActivationMode;
struct MklBatchNormFwdParams {
memory::dims src_dims;
int depth;
float eps;
bool training;
FusedBNActivationMode activation_mode;
#ifndef ENABLE_MKLDNN_V1
MEMORY_FORMAT src_format;
#else
@ -50,14 +55,17 @@ struct MklBatchNormFwdParams {
MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
#ifndef ENABLE_MKLDNN_V1
bool training, MEMORY_FORMAT src_format)
bool training, MEMORY_FORMAT src_format,
FusedBNActivationMode activation_mode)
#else
bool training, memory::desc src_md)
bool training, memory::desc src_md,
FusedBNActivationMode activation_mode)
#endif // !ENABLE_MKLDNN_V1
: src_dims(src_dims),
depth(depth),
eps(eps),
training(training),
activation_mode(activation_mode),
#ifndef ENABLE_MKLDNN_V1
src_format(src_format) {
}
@ -90,7 +98,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
// mean_data: output data buffer of means
// variance_data: output data buffer of variances
void Execute(const T* src_data, const U* weights_data, T* dst_data,
U* mean_data, U* variance_data) {
U* mean_data, U* variance_data, U* workspace_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
@ -104,6 +112,9 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
context_.mean_mem->set_data_handle(static_cast<void*>(mean_data));
context_.variance_mem->set_data_handle(static_cast<void*>(variance_data));
}
if (workspace_data != nullptr) {
context_.ws_mem->set_data_handle(workspace_data);
}
#ifdef ENABLE_MKLDNN_V1
// Execute batch-normalization forward primitives.
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
@ -123,6 +134,10 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
context_.mean_mem->set_data_handle(DummyData);
context_.variance_mem->set_data_handle(DummyData);
}
if (workspace_data != nullptr) {
context_.ws_mem->set_data_handle(DummyData);
}
}
MEMORY_PRIMITIVE_DESC GetDstPd() const { return context_.dst_mem->GET_DESC; }
@ -158,6 +173,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
std::shared_ptr<mkldnn::memory> dst_mem;
std::shared_ptr<mkldnn::memory> mean_mem;
std::shared_ptr<mkldnn::memory> variance_mem;
std::shared_ptr<mkldnn::memory> ws_mem;
// Forward BatchNorm primitive descriptor.
std::shared_ptr<BatchNormFwdPd> fwd_pd;
@ -179,6 +195,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
dst_mem(nullptr),
mean_mem(nullptr),
variance_mem(nullptr),
ws_mem(nullptr),
bn_fwd(nullptr),
fwd_stream(nullptr) {}
};
@ -192,6 +209,9 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
: prop_kind::forward_scoring;
#ifdef ENABLE_MKLDNN_V1
if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
context_.flags |= GET_FLAG(fuse_norm_relu);
}
// Memory descriptor
auto src_md = fwdParams.src_md;
// Create forward BatchNorm descriptor and primitive descriptor.
@ -229,6 +249,13 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
m_dims, U, MEMORY_FORMAT::nc, cpu_engine_, DummyData));
}
#ifdef ENABLE_MKLDNN_V1
if (IS_SET(fuse_norm_relu)) {
context_.ws_mem.reset(new MEMORY_CONSTRUCTOR(
context_.fwd_pd->workspace_desc(), cpu_engine_, DummyData));
}
#endif // ENABLE_MKLDNN_V1
// BatchNorm forward primitive.
// TODO(intel-tf): Merge all the #ifdefs and simplify code
if (!fwdParams.training && !(IS_SET(use_global_stats))) {
@ -258,20 +285,41 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
} else if (IS_SET(use_global_stats)) {
#ifdef ENABLE_MKLDNN_V1
if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
context_.net_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_MEAN, *context_.mean_mem},
{MKLDNN_ARG_VARIANCE, *context_.variance_mem},
{MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
if (IS_SET(fuse_norm_relu)) {
context_.net_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_MEAN, *context_.mean_mem},
{MKLDNN_ARG_VARIANCE, *context_.variance_mem},
{MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
{MKLDNN_ARG_DST, *context_.dst_mem},
{ MKLDNN_ARG_WORKSPACE,
*context_.ws_mem }});
} else {
context_.net_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_MEAN, *context_.mean_mem},
{MKLDNN_ARG_VARIANCE, *context_.variance_mem},
{MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
}
} else {
context_.net_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_MEAN, *context_.mean_mem},
{MKLDNN_ARG_VARIANCE, *context_.variance_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
if (IS_SET(fuse_norm_relu)) {
context_.net_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_MEAN, *context_.mean_mem},
{MKLDNN_ARG_VARIANCE, *context_.variance_mem},
{MKLDNN_ARG_DST, *context_.dst_mem},
{ MKLDNN_ARG_WORKSPACE,
*context_.ws_mem }});
} else {
context_.net_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_MEAN, *context_.mean_mem},
{MKLDNN_ARG_VARIANCE, *context_.variance_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
}
}
context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
#else
@ -291,19 +339,40 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
} else {
#ifdef ENABLE_MKLDNN_V1
if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
context_.net_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
{MKLDNN_ARG_DST, *context_.dst_mem},
{MKLDNN_ARG_MEAN, *context_.mean_mem},
{ MKLDNN_ARG_VARIANCE,
*context_.variance_mem }});
if (IS_SET(fuse_norm_relu)) {
context_.net_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
{MKLDNN_ARG_DST, *context_.dst_mem},
{MKLDNN_ARG_MEAN, *context_.mean_mem},
{MKLDNN_ARG_VARIANCE, *context_.variance_mem},
{ MKLDNN_ARG_WORKSPACE,
*context_.ws_mem }});
} else {
context_.net_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
{MKLDNN_ARG_DST, *context_.dst_mem},
{MKLDNN_ARG_MEAN, *context_.mean_mem},
{ MKLDNN_ARG_VARIANCE,
*context_.variance_mem }});
}
} else {
context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_DST, *context_.dst_mem},
{MKLDNN_ARG_MEAN, *context_.mean_mem},
{ MKLDNN_ARG_VARIANCE,
*context_.variance_mem }});
if (IS_SET(fuse_norm_relu)) {
context_.net_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_DST, *context_.dst_mem},
{MKLDNN_ARG_MEAN, *context_.mean_mem},
{MKLDNN_ARG_VARIANCE, *context_.variance_mem},
{ MKLDNN_ARG_WORKSPACE,
*context_.ws_mem }});
} else {
context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_DST, *context_.dst_mem},
{MKLDNN_ARG_MEAN, *context_.mean_mem},
{ MKLDNN_ARG_VARIANCE,
*context_.variance_mem }});
}
}
context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
#else
@ -360,6 +429,7 @@ class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
key_creator.AddAsKey<int>(fwdParams.depth);
key_creator.AddAsKey<float>(fwdParams.eps);
key_creator.AddAsKey<bool>(fwdParams.training);
key_creator.AddAsKey<FusedBNActivationMode>(fwdParams.activation_mode);
key_creator.AddAsKey(typeid(T).name());
key_creator.AddAsKey(typeid(U).name());
return key_creator.GetKey();
@ -676,7 +746,8 @@ class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
// Adding a third parameter to the template to support FusedBatchNormV3
// with MKL. This is different from default where the classes are
// derived. Moves enabling to compile-time rather than runtime.
template <typename Device, typename T, typename U, bool reserved_space>
template <typename Device, typename T, typename U, bool reserved_space,
bool is_batch_norm_ex = false>
class MklFusedBatchNormOp : public OpKernel {
public:
explicit MklFusedBatchNormOp(OpKernelConstruction* context)
@ -696,6 +767,28 @@ class MklFusedBatchNormOp : public OpKernel {
depth_ = 0;
mean_values_ = nullptr;
variance_values_ = nullptr;
#ifndef ENABLE_MKLDNN_V1
OP_REQUIRES(context, !is_batch_norm_ex,
errors::InvalidArgument(
"_MklFusedBatchNormEx is not supported in DNNL 0.x ."));
#endif
if (!is_batch_norm_ex) {
activation_mode_ = FusedBNActivationMode::kIdentity;
} else {
int num_side_inputs;
OP_REQUIRES_OK(context,
context->GetAttr("num_side_inputs", &num_side_inputs));
// Currently _MklFusedBatchNormEx does not support side input.
OP_REQUIRES(context, num_side_inputs == 0,
errors::InvalidArgument(
"_MKLFusedBatchNorm do not support side input now."));
OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_));
OP_REQUIRES(context, activation_mode_ == FusedBNActivationMode::kRelu,
errors::InvalidArgument(
"_MKLFusedBatchNorm only support Relu activation"));
}
}
void Compute(OpKernelContext* context) override {
@ -744,9 +837,12 @@ class MklFusedBatchNormOp : public OpKernel {
// Handle the special case: input with 0 element and 0 batch size.
Tensor* dst_tensor = nullptr;
TensorShape workspace_tf_shape;
if (tf_shape_src.num_elements() == 0) {
HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
&dst_tensor);
size_t workspace_bytes = 0;
workspace_tf_shape.AddDim(workspace_bytes);
HandleEmptyInput(context, tf_shape_src, workspace_tf_shape,
scale_tensor.shape(), &dst_tensor);
return;
}
@ -758,23 +854,16 @@ class MklFusedBatchNormOp : public OpKernel {
// Index of output tensor(diff_src).
const size_t kDstIndex = 0;
// Allocate 4 output TF tensors.
// Allocate 5 output TF tensors.
Tensor* batch_mean_tensor = nullptr;
Tensor* batch_variance_tensor = nullptr;
Tensor* saved_mean_tensor = nullptr;
Tensor* saved_variance_tensor = nullptr;
Tensor* reserved_space_tensor = nullptr;
AllocateTFOutputs(context, scale_tensor.shape(), &batch_mean_tensor,
&batch_variance_tensor, &saved_mean_tensor,
&saved_variance_tensor, &reserved_space_tensor);
if (is_training_)
SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
else
SetMeanVariance(est_mean_tensor, est_variance_tensor);
MklDnnData<T> src(&cpu_engine_);
MklDnnData<U> weights(&cpu_engine_);
MklDnnData<U> wksp(&cpu_engine_);
MEMORY_FORMAT dnn_fmt;
MKL_TENSOR_FORMAT mkl_tensor_fmt;
@ -801,6 +890,51 @@ class MklFusedBatchNormOp : public OpKernel {
? dnn_shape_src.GetMklLayout()
: memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
#ifdef ENABLE_MKLDNN_V1
MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_,
src_md, activation_mode_);
#else
MklBatchNormFwdParams fwdParams(
src_dims, depth_, epsilon_, is_training_,
static_cast<MEMORY_FORMAT>(src_md.data.format), activation_mode_);
#endif // ENABLE_MKLDNN_V1
// Get forward batch-normalization op from the primitive caching pool.
MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd =
MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams);
// Allocate workspace tensor
U* ws_data = nullptr;
if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
#ifdef ENABLE_MKLDNN_V1
MEMORY_PRIMITIVE_DESC workspace_pd =
bn_fwd->GetBatchNormFwdPd()->workspace_desc();
size_t workspace_bytes = workspace_pd.get_size();
workspace_tf_shape.AddDim(workspace_bytes);
AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
&batch_mean_tensor, &batch_variance_tensor,
&saved_mean_tensor, &saved_variance_tensor,
&reserved_space_tensor);
if (reserved_space) {
wksp.SetUsrMem(workspace_pd, reserved_space_tensor);
ws_data = static_cast<U*>(wksp.GetOpMem().get_data_handle());
}
#endif // ENABLE_MKLDNN_V1
} else {
// There is no real workspace output in this case, so allocate a dummy one.
size_t workspace_bytes = 0;
workspace_tf_shape.AddDim(workspace_bytes);
AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
&batch_mean_tensor, &batch_variance_tensor,
&saved_mean_tensor, &saved_variance_tensor,
&reserved_space_tensor);
}
if (is_training_)
SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
else
SetMeanVariance(est_mean_tensor, est_variance_tensor);
// MKL-DNN packs scale & shift as "weights":
// <scale>...<scale><shift>...<shift>
weights.AllocateBuffer(2 * depth_ * sizeof(U));
@ -821,18 +955,6 @@ class MklFusedBatchNormOp : public OpKernel {
reinterpret_cast<char*>(variance_values_),
depth_ * sizeof(U));
#ifdef ENABLE_MKLDNN_V1
MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_,
src_md);
#else
MklBatchNormFwdParams fwdParams(
src_dims, depth_, epsilon_, is_training_,
static_cast<MEMORY_FORMAT>(src_md.data.format));
#endif // ENABLE_MKLDNN_V1
// Get forward batch-normalization op from the primitive caching pool.
MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd =
MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams);
// Check if reorder is needed for src.
const T* src_data = nullptr;
std::shared_ptr<BatchNormFwdPd> bn_fwd_pd = bn_fwd->GetBatchNormFwdPd();
@ -866,7 +988,7 @@ class MklFusedBatchNormOp : public OpKernel {
// Execute
bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data,
variance_op_data);
variance_op_data, ws_data);
float adjust_factor = 1.0;
if (is_training_) {
@ -924,6 +1046,7 @@ class MklFusedBatchNormOp : public OpKernel {
U* mean_values_;
U* variance_values_;
size_t depth_; // Batch normalization is performed per channel.
FusedBNActivationMode activation_mode_;
engine cpu_engine_ = engine(ENGINE_CPU, 0);
void ExtractParams(OpKernelContext* context) {
@ -938,6 +1061,7 @@ class MklFusedBatchNormOp : public OpKernel {
}
void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
TensorShape workspace_tf_shape,
TensorShape tf_shape_scale, Tensor** dst_tensor) {
DCHECK(dst_tensor);
@ -955,12 +1079,14 @@ class MklFusedBatchNormOp : public OpKernel {
Tensor* saved_mean_tensor = nullptr;
Tensor* saved_variance_tensor = nullptr;
Tensor* reserved_space_tensor = nullptr;
AllocateTFOutputs(context, tf_shape_scale, &batch_mean_tensor,
&batch_variance_tensor, &saved_mean_tensor,
&saved_variance_tensor, &reserved_space_tensor);
AllocateTFOutputs(context, tf_shape_scale, workspace_tf_shape,
&batch_mean_tensor, &batch_variance_tensor,
&saved_mean_tensor, &saved_variance_tensor,
&reserved_space_tensor);
}
void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale,
TensorShape workspace_tf_shape,
Tensor** batch_mean_tensor,
Tensor** batch_variance_tensor,
Tensor** saved_mean_tensor,
@ -1024,21 +1150,15 @@ class MklFusedBatchNormOp : public OpKernel {
std::fill_n(saved_variance_data, num_elements, static_cast<U>(0));
// Changes to support reserved_space_3 parameter in FusedBatchNormV3.
// TODO: This parameter functionality is not implemented on CPU.
// It is used to hold intermediate results. So the allocated
// memory is filled with 0s.
if (reserved_space) {
DCHECK(reserved_space_tensor != nullptr);
MklDnnShape mkl_shape_reserved_space;
mkl_shape_reserved_space.SetMklTensor(false);
AllocateOutputSetMklShape(context, kReservedSpaceIndex,
reserved_space_tensor, tf_shape_scale,
reserved_space_tensor, workspace_tf_shape,
mkl_shape_reserved_space);
DCHECK((*reserved_space_tensor) != nullptr);
auto saved_reserved_space_data =
(*reserved_space_tensor)->flat<U>().data();
std::fill_n(saved_reserved_space_data, num_elements, static_cast<U>(0));
}
}
};
@ -1367,7 +1487,7 @@ class MklFusedBatchNormGradOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklFusedBatchNormOp<CPUDevice, T, T, false>);
MklFusedBatchNormOp<CPUDevice, T, T, false, false>);
TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_CPU);
TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU);
@ -1380,7 +1500,7 @@ TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU);
.TypeConstraint<T>("T") \
.TypeConstraint<U>("U") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklFusedBatchNormOp<CPUDevice, T, U, false>);
MklFusedBatchNormOp<CPUDevice, T, U, false, false>);
REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float);
@ -1421,12 +1541,30 @@ REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(bfloat16, float);
.TypeConstraint<T>("T") \
.TypeConstraint<U>("U") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklFusedBatchNormOp<CPUDevice, T, U, true>);
MklFusedBatchNormOp<CPUDevice, T, U, true, false>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklFusedBatchNormEx") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<U>("U") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklFusedBatchNormOp<CPUDevice, T, U, true, true>);
REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_V3_CPU
REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T")
.TypeConstraint<float>("U"),
NoOp);
REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
.Device(DEVICE_CPU)
.TypeConstraint<bfloat16>("T")
.TypeConstraint<float>("U"),
NoOp);
#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(T, U) \
REGISTER_KERNEL_BUILDER( \
Name("_MklFusedBatchNormGradV3") \
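
For context on the kernel changes above: in the DNNL 1.x API, setting the fuse_norm_relu flag on a forward batch-normalization primitive makes it produce a workspace (the ReLU mask the backward pass will need), which must be allocated from the primitive descriptor and bound as MKLDNN_ARG_WORKSPACE at execution time. A minimal standalone sketch, assuming the mkldnn 1.x compat headers used in this file; the shapes and names outside the diff are illustrative:

#include "mkldnn.hpp"

// Minimal sketch: batch-norm forward with fused ReLU and its workspace.
void FusedBnReluSketch() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  memory::desc src_md({2, 24, 8, 8}, memory::data_type::f32,
                      memory::format_tag::nhwc);
  auto flags = normalization_flags::use_scale_shift |
               normalization_flags::fuse_norm_relu;
  batch_normalization_forward::desc desc(prop_kind::forward_training, src_md,
                                         /*epsilon=*/1e-4f, flags);
  batch_normalization_forward::primitive_desc pd(desc, eng);

  // The workspace holds the ReLU mask; MklFusedBatchNormOp maps it onto the
  // reserve_space_3 output so a backward op can later consume it.
  memory workspace(pd.workspace_desc(), eng);
  memory src(pd.src_desc(), eng), dst(pd.dst_desc(), eng);
  memory weights(pd.weights_desc(), eng);  // packed scale + shift
  memory mean(pd.mean_desc(), eng), variance(pd.variance_desc(), eng);

  batch_normalization_forward(pd).execute(
      strm, {{MKLDNN_ARG_SRC, src},
             {MKLDNN_ARG_WEIGHTS, weights},
             {MKLDNN_ARG_MEAN, mean},
             {MKLDNN_ARG_VARIANCE, variance},
             {MKLDNN_ARG_DST, dst},
             {MKLDNN_ARG_WORKSPACE, workspace}});
  strm.wait();
}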

View File

@ -1342,6 +1342,48 @@ REGISTER_OP("_MklFusedBatchNormGradV3")
R"doc(MKL-DNN implementation of FusedBatchNormGradV3: Do not invoke this operator directly in Python.
Graph rewrite pass is expected to invoke this operator.)doc");
REGISTER_OP("_MklFusedBatchNormEx")
.Input("x: T")
.Input("scale: U")
.Input("offset: U")
.Input("mean: U")
.Input("variance: U")
.Input("side_input: num_side_inputs * T")
.Input("mkl_x: uint8")
.Input("mkl_scale: uint8")
.Input("mkl_offset: uint8")
.Input("mkl_mean: uint8")
.Input("mkl_variance: uint8")
.Input("mkl_side_input: num_side_inputs * uint8")
.Output("y: T")
.Output("batch_mean: U")
.Output("batch_variance: U")
.Output("reserve_space_1: U")
.Output("reserve_space_2: U")
.Output("reserve_space_3: U")
.Output("mkl_y: uint8")
.Output("mkl_batch_mean: uint8")
.Output("mkl_batch_variance: uint8")
.Output("mkl_reserve_space_1: uint8")
.Output("mkl_reserve_space_2: uint8")
.Output("mkl_reserve_space_3: uint8")
.Attr("T: {bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
.Attr("exponential_avg_factor: float = 1.0")
.Attr(GetConvnetDataFormatAttrString())
.Attr("num_side_inputs: int >= 0 = 0")
.Attr("activation_mode: string = \"Identity\"")
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormShape)
.Doc(R"doc(
MKL version of FusedBatchNormEx operator. Uses MKL-DNN APIs to perform fused
batch normalization and ReLU.
NOTE: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke this operator.
)doc");
} // namespace tensorflow
#endif // INTEL_MKL
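
To make the input ordering implied by this registration concrete: the five TF tensors come first, then the five uint8 Mkl metadata tensors, with both side-input lists empty when num_side_inputs is 0 (this is the contiguous ordering SetUpContiguousInputs produces and the layout-pass test above checks). Below is a hedged sketch of assembling such a node with NodeBuilder in an INTEL_MKL build; the helper and its names are illustrative, not code from the pass:

#include <array>
#include <vector>

#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"

// Illustrative only: assemble a _MklFusedBatchNormEx node with contiguous
// input ordering -- the five TF inputs first, then the five uint8 Mkl
// metadata inputs, with both side-input lists empty (num_side_inputs = 0).
// The real layout pass builds this via SetUpContiguousInputs and also tags
// the node with the MKL layout-dependent kernel label.
tensorflow::Status BuildMklFusedBatchNormEx(
    tensorflow::Graph* graph, std::array<tensorflow::Node*, 5> tf_inputs,
    std::array<tensorflow::Node*, 5> mkl_metadata, tensorflow::Node** out) {
  tensorflow::NodeBuilder nb("fused_bn_relu", "_MklFusedBatchNormEx");
  for (tensorflow::Node* n : tf_inputs) nb.Input(n);
  nb.Input(std::vector<tensorflow::NodeBuilder::NodeOut>{});  // side_input
  for (tensorflow::Node* n : mkl_metadata) nb.Input(n);
  nb.Input(std::vector<tensorflow::NodeBuilder::NodeOut>{});  // mkl_side_input
  nb.Attr("T", tensorflow::DT_FLOAT)
      .Attr("U", tensorflow::DT_FLOAT)
      .Attr("num_side_inputs", 0)
      .Attr("activation_mode", "Relu")
      .Attr("is_training", true);
  return nb.Finalize(graph, out);
}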

View File

@ -238,7 +238,11 @@ REGISTER_OP("_FusedBatchNormEx")
.Output("reserve_space_1: U")
.Output("reserve_space_2: U")
.Output("reserve_space_3: U")
#ifdef ENABLE_MKLDNN_V1
.Attr("T: {half, float, bfloat16}")
#else
.Attr("T: {half, float}")
#endif
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
.Attr("exponential_avg_factor: float = 1.0")