[Intel MKL] Compilation fixes to integrate DNNL 1.0
This commit is contained in:
parent
be6324180a
commit
4f61e4cab9
tensorflow
core
graph
kernels
mkl_avgpooling_op.cc mkl_concat_op.cc mkl_conv_grad_filter_ops.cc mkl_conv_grad_input_ops.cc mkl_conv_ops.cc mkl_dequantize_op.cc mkl_fused_batch_norm_op.cc mkl_matmul_op_fused.cc mkl_matmul_ops_common.h mkl_maxpooling_op.cc mkl_pooling_ops_common.cc mkl_pooling_ops_common.h mkl_qmatmul_op.cc mkl_relu_op.cc mkl_softmax_op.cc
ops
util
third_party/mkl_dnn
@ -365,7 +365,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
|
||||
CopyAttrsAll, AlwaysRewrite,
|
||||
kRewriteForLayoutPropagation});
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
|
||||
CopyAttrsAll, RewriteIfAtleastOneMklInput,
|
||||
kRewriteForLayoutPropagation});
|
||||
@ -403,12 +402,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
{csinfo_.conjugate_transpose,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
|
||||
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
rinfo_.push_back({csinfo_.conv2d,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.conv2d),
|
||||
CopyAttrsConvCheckConstFilter, AlwaysRewrite,
|
||||
kRewriteForLayoutPropagation});
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
|
||||
CopyAttrsConvCheckConstFilter, AlwaysRewrite,
|
||||
kRewriteForLayoutPropagation});
|
||||
@ -477,20 +474,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
{csinfo_.fused_batch_norm_grad_v3,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
|
||||
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
|
||||
CopyAttrsFusedConv2D, FusedConv2DRewrite,
|
||||
kRewriteForLayoutPropagation});
|
||||
rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul,
|
||||
CopyAttrsAllCheckConstFilter, FusedMatMulRewrite});
|
||||
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
rinfo_.push_back({csinfo_.identity,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.identity),
|
||||
CopyAttrsAll, RewriteIfAtleastOneMklInput,
|
||||
kRewriteForLayoutPropagation});
|
||||
|
||||
rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
|
||||
CopyAttrsAll, LrnRewrite, kRewriteForLayoutPropagation});
|
||||
rinfo_.push_back(
|
||||
@ -654,7 +647,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
|
||||
CopyAttrsAll, QuantizeOpRewrite,
|
||||
kRewriteForLayoutPropagation});
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
|
||||
CopyAttrsAll, AlwaysRewrite,
|
||||
kRewriteForLayoutPropagation});
|
||||
@ -667,11 +659,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
rinfo_.push_back(
|
||||
{csinfo_.relu6_grad, mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
|
||||
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
rinfo_.push_back(
|
||||
{csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
|
||||
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
// Disable these two MKL operators for now due to some test failures caused
|
||||
// by these two ops
|
||||
/*
|
||||
@ -691,7 +681,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
mkl_op_registry::GetMklOpName(csinfo_.slice),
|
||||
CopyAttrsAll, RewriteIfAtleastOneMklInput,
|
||||
kRewriteForLayoutPropagation});
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
rinfo_.push_back(
|
||||
{csinfo_.softmax, mkl_op_registry::GetMklOpName(csinfo_.softmax),
|
||||
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
|
||||
@ -798,7 +787,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
// CheckForMklOp
|
||||
FuseMaxPool3D,
|
||||
CopyAttrsPooling});
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
// Standard interface to run pass
|
||||
|
@ -29,7 +29,9 @@ using mkldnn::algorithm;
|
||||
using mkldnn::engine;
|
||||
using mkldnn::error;
|
||||
using mkldnn::memory;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
using mkldnn::padding_kind;
|
||||
#endif
|
||||
using mkldnn::pooling_backward;
|
||||
using mkldnn::pooling_forward;
|
||||
using mkldnn::prop_kind;
|
||||
@ -108,11 +110,18 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
|
||||
pooling_prop_kind = prop_kind::forward_inference;
|
||||
else
|
||||
pooling_prop_kind = prop_kind::forward_training;
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
MklPoolingParams fwdParams(
|
||||
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
|
||||
padding_right, ALGORITHM::pooling_avg_exclude_padding,
|
||||
pooling_prop_kind, static_cast<memory::format>(input_md.data.format));
|
||||
|
||||
pooling_prop_kind,
|
||||
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
|
||||
#else
|
||||
MklPoolingParams fwdParams(
|
||||
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
|
||||
padding_right, ALGORITHM::pooling_avg_exclude_padding,
|
||||
pooling_prop_kind, static_cast<MEMORY_FORMAT>(input_md.data.format));
|
||||
#endif
|
||||
pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);
|
||||
|
||||
// Allocate output tensor.
|
||||
@ -224,11 +233,19 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
||||
|
||||
// Pass prop_kind::forward_training to create a forward primitive
|
||||
// that is used in the backward pass.
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
MklPoolingParams bwdParams(
|
||||
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
|
||||
strides, padding_left, padding_right,
|
||||
ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training,
|
||||
static_cast<memory::format>(src_md.data.format));
|
||||
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
|
||||
#else
|
||||
MklPoolingParams bwdParams(
|
||||
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
|
||||
strides, padding_left, padding_right,
|
||||
ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training,
|
||||
static_cast<MEMORY_FORMAT>(src_md.data.format));
|
||||
#endif
|
||||
MklPoolingBwdPrimitive<T>* pooling_bwd =
|
||||
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
|
||||
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "mkldnn.hpp"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
@ -32,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/mkl_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
using mkldnn::concat;
|
||||
using mkldnn::stream;
|
||||
@ -266,7 +266,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
explicit MklConcatFwdPrimitive(const MklConcatFwdParams& concat_fwd_dims,
|
||||
const std::vector<memory::desc>& srcs_md)
|
||||
: cpu_engine_(ENGINE_CPU, 0) {
|
||||
context_.fwd_stream.reset(new CPU_STREAM(stream::kind::eager));
|
||||
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
|
||||
// Create concat primitive
|
||||
Setup(concat_fwd_dims, srcs_md);
|
||||
}
|
||||
@ -292,8 +292,8 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
}
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.fwd_primitives, *context_.fwd_stream,
|
||||
context_.fwd_primitives_args.at(i));
|
||||
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
|
||||
context_.fwd_primitives_args);
|
||||
#else
|
||||
context_.fwd_stream->submit(context_.fwd_primitives);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
@ -328,7 +328,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
std::shared_ptr<mkldnn::memory> dst_mem;
|
||||
|
||||
// Memory descriptor
|
||||
std::vector<std::shared_ptr<mkldnn::memory::desc>> src_md;
|
||||
std::vector<mkldnn::memory::desc> src_md;
|
||||
std::shared_ptr<mkldnn::memory::desc> dst_md;
|
||||
|
||||
// Concat primitive descriptor
|
||||
@ -339,7 +339,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
std::vector<mkldnn::primitive> fwd_primitives;
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
std::vector<std::unordered_map<int, memory>> fwd_primitive_args;
|
||||
std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
ConcatFwdContext()
|
||||
@ -355,15 +355,14 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
const std::vector<memory::desc>& srcs_md) {
|
||||
// Create memory descriptors for concat with specified srcs format
|
||||
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
||||
std::shared_ptr<mkldnn::memory::desc> source_md(
|
||||
new memory::desc(srcs_md[i].data));
|
||||
mkldnn::memory::desc source_md(memory::desc(srcs_md[i].data));
|
||||
context_.src_md.push_back(source_md);
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
std::shared_ptr<mkldnn::memory> src_mem(
|
||||
new mkldnn::memory(*source_md, cpu_engine_, DummyData));
|
||||
new mkldnn::memory(source_md, cpu_engine_, DummyData));
|
||||
#else
|
||||
std::shared_ptr<mkldnn::memory::primitive_desc> src_mpd(
|
||||
new memory::primitive_desc(*source_md, cpu_engine_));
|
||||
new memory::primitive_desc(source_md, cpu_engine_));
|
||||
context_.src_pd_shdptr.push_back(src_mpd);
|
||||
|
||||
std::shared_ptr<mkldnn::memory> src_mem(
|
||||
@ -665,8 +664,9 @@ class MklConcatOp : public OpKernel {
|
||||
if (input_tensors[k].NumElements() == 0) continue;
|
||||
auto src_md = mkl_input_shapes[k].GetMklLayout();
|
||||
srcs[k].SetUsrMem(src_md, &input_tensors[k]);
|
||||
|
||||
if (src_md.data.format != mkl_common_format) {
|
||||
auto src_tf_fmt = MklTensorFormatToMklDnnDataFormat(
|
||||
mkl_input_shapes[k].GetTfDataFormat());
|
||||
if (src_tf_fmt != mkl_common_format) {
|
||||
memory::dims src_dims(src_md.data.dims,
|
||||
&src_md.data.dims[src_md.data.ndims]);
|
||||
src_md =
|
||||
@ -935,7 +935,8 @@ class MklConcatOp : public OpKernel {
|
||||
for (int k = 0; k < input_shapes.size(); k++) {
|
||||
auto src_dims = TFShapeToMklDnnDims(input_shapes[k].GetTfShape());
|
||||
*concat_dim_size += src_dims[concat_dim];
|
||||
int fmt = static_cast<int>(input_shapes[k].GetMklLayout().data.format);
|
||||
int fmt = static_cast<int>(
|
||||
MklTensorFormatToMklDnnDataFormat(input_shapes[k].GetTfDataFormat()));
|
||||
occurrence_map[fmt] += 1;
|
||||
}
|
||||
|
||||
@ -943,7 +944,7 @@ class MklConcatOp : public OpKernel {
|
||||
// this means that all inputs have a same format
|
||||
// return it with is_reorder_needed set false.
|
||||
return static_cast<MEMORY_FORMAT>(
|
||||
input_shapes[0].GetMklLayout().data.format);
|
||||
MklTensorFormatToMklDnnDataFormat(input_shapes[0].GetTfDataFormat()));
|
||||
}
|
||||
|
||||
// Input tensors have different formats. Thus, reorder is needed.
|
||||
@ -970,7 +971,7 @@ class MklConcatOp : public OpKernel {
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("concat_dim") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>) \
|
||||
MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_MklConcatV2") \
|
||||
.Device(DEVICE_CPU) \
|
||||
@ -978,7 +979,7 @@ class MklConcatOp : public OpKernel {
|
||||
.TypeConstraint<int32>("Tidx") \
|
||||
.HostMemory("axis") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConcatOp<CPUDevice, type, NAME_IS_AXIS>)
|
||||
MklConcatOp<CPUDevice, type, NAME_IS_AXIS>);
|
||||
|
||||
TF_CALL_float(REGISTER_MKL_CPU);
|
||||
TF_CALL_bfloat16(REGISTER_MKL_CPU);
|
||||
@ -988,14 +989,14 @@ REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
|
||||
.TypeConstraint<quint8>("T")
|
||||
.HostMemory("axis")
|
||||
.Label(mkl_op_registry::kMklQuantizedOpLabel),
|
||||
MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS>)
|
||||
MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<qint8>("T")
|
||||
.HostMemory("axis")
|
||||
.Label(mkl_op_registry::kMklQuantizedOpLabel),
|
||||
MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS>)
|
||||
MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS>);
|
||||
|
||||
#undef REGISTER_CONCAT_MKL
|
||||
} // namespace tensorflow
|
||||
|
@ -62,13 +62,19 @@ struct MklConvBwdFilterParams {
|
||||
memory::dims dilations;
|
||||
memory::dims padding_left;
|
||||
memory::dims padding_right;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
padding_kind padding;
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
MklConvBwdFilterParams(memory::dims src_dims, memory::dims diff_filter_dims,
|
||||
memory::dims diff_bias_dims,
|
||||
memory::dims diff_dst_dims, memory::dims strides,
|
||||
memory::dims dilations, memory::dims padding_left,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
memory::dims padding_right, padding_kind padding)
|
||||
#else
|
||||
memory::dims padding_right)
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
: src_dims(src_dims),
|
||||
diff_filter_dims(diff_filter_dims),
|
||||
diff_bias_dims(diff_bias_dims),
|
||||
@ -76,8 +82,14 @@ struct MklConvBwdFilterParams {
|
||||
strides(strides),
|
||||
dilations(dilations),
|
||||
padding_left(padding_left),
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
padding_right(padding_right),
|
||||
padding(padding) {}
|
||||
padding(padding) {
|
||||
}
|
||||
#else
|
||||
padding_right(padding_right) {
|
||||
}
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -241,8 +253,12 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
|
||||
*context_.diff_filter_md, *context_.diff_dst_md,
|
||||
convBwdFilterDims.strides, convBwdFilterDims.dilations,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
|
||||
convBwdFilterDims.padding));
|
||||
#else
|
||||
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
|
||||
|
||||
// Create descriptor and primitive descriptor for convolution bwd filter.
|
||||
@ -252,14 +268,22 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
*context_.diff_filter_md, *context_.diff_bias_md,
|
||||
*context_.diff_dst_md, convBwdFilterDims.strides,
|
||||
convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convBwdFilterDims.padding_right, convBwdFilterDims.padding));
|
||||
#else
|
||||
convBwdFilterDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
} else {
|
||||
context_.bwd_filter_desc.reset(new ConvBwdFilterDesc(
|
||||
ALGORITHM::convolution_direct, *context_.src_md,
|
||||
*context_.diff_filter_md, *context_.diff_dst_md,
|
||||
convBwdFilterDims.strides, convBwdFilterDims.dilations,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
|
||||
convBwdFilterDims.padding));
|
||||
#else
|
||||
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
context_.bwd_filter_pd.reset(new ConvBwdFilterPd(
|
||||
*context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd));
|
||||
@ -495,11 +519,14 @@ class MklConvCustomBackpropFilterOp
|
||||
// The default dilation factor for each dimension is 1 in TF and
|
||||
// 0 in MKL-DNN.
|
||||
for (int i = 0; i < dilations.size(); ++i) --dilations[i];
|
||||
|
||||
MklConvBwdFilterParams convBwdFilterDims(
|
||||
fwd_src_dims, fwd_filter_dims, diff_bias_dims, diff_dst_dims, strides,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
dilations, padding_left, padding_right,
|
||||
TFPaddingToMklDnnPadding(this->padding_));
|
||||
#else
|
||||
dilations, padding_left, padding_right);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// MKL-DNN allocates large buffers when a conv gradient filter primitive is
|
||||
// created. So we don't cache conv backward primitives when the env
|
||||
|
@ -65,20 +65,32 @@ struct MklConvBwdInputParams {
|
||||
memory::dims dilations;
|
||||
memory::dims padding_left;
|
||||
memory::dims padding_right;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
padding_kind padding;
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
MklConvBwdInputParams(memory::dims diff_src_dims, memory::dims filter_dims,
|
||||
memory::dims diff_dst_dims, memory::dims strides,
|
||||
memory::dims dilations, memory::dims padding_left,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
memory::dims padding_right, padding_kind padding)
|
||||
#else
|
||||
memory::dims padding_right)
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
: diff_src_dims(diff_src_dims),
|
||||
filter_dims(filter_dims),
|
||||
diff_dst_dims(diff_dst_dims),
|
||||
strides(strides),
|
||||
dilations(dilations),
|
||||
padding_left(padding_left),
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
padding_right(padding_right),
|
||||
padding(padding) {}
|
||||
padding(padding) {
|
||||
}
|
||||
#else
|
||||
padding_right(padding_right) {
|
||||
}
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -211,14 +223,22 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
|
||||
ALGORITHM::convolution_direct, *context_.diff_src_md,
|
||||
*context_.filter_md, *context_.diff_dst_md, convBwdInputDims.strides,
|
||||
convBwdInputDims.dilations, convBwdInputDims.padding_left,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convBwdInputDims.padding_right, convBwdInputDims.padding));
|
||||
#else
|
||||
convBwdInputDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
context_.fwd_desc.reset(new ConvFwdDesc(
|
||||
prop_kind::forward, ALGORITHM::convolution_direct,
|
||||
*context_.diff_src_md, *context_.filter_md, *context_.diff_dst_md,
|
||||
convBwdInputDims.strides, convBwdInputDims.dilations,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convBwdInputDims.padding_left, convBwdInputDims.padding_right,
|
||||
convBwdInputDims.padding));
|
||||
#else
|
||||
convBwdInputDims.padding_left, convBwdInputDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// Create primitive descriptors for conv fwd and conv bwd input.
|
||||
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
|
||||
@ -440,11 +460,14 @@ class MklConvCustomBackpropInputOp
|
||||
// The default dilation factor for each dimension is 1 in TF and
|
||||
// 0 in MKL-DNN.
|
||||
for (int i = 0; i < dilations.size(); ++i) --dilations[i];
|
||||
|
||||
MklConvBwdInputParams convBwdInputDims(
|
||||
fwd_src_dims, fwd_filter_dims, diff_dst_dims, strides, dilations,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
padding_left, padding_right,
|
||||
TFPaddingToMklDnnPadding(this->padding_));
|
||||
#else
|
||||
padding_left, padding_right);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// We don't cache those primitives if the environment variable
|
||||
// TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and if primitive descriptor
|
||||
|
@ -24,8 +24,8 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "mkldnn.hpp"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "mkldnn.hpp"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -239,13 +239,21 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
|
||||
*context_.filter_md, *context_.bias_md, *context_.dst_md,
|
||||
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convFwdDims.padding_right, padding_kind::zero));
|
||||
#else
|
||||
convFwdDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
} else {
|
||||
context_.fwd_desc.reset(new convolution_forward::desc(
|
||||
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
|
||||
*context_.filter_md, *context_.dst_md, convFwdDims.strides,
|
||||
convFwdDims.dilations, convFwdDims.padding_left,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convFwdDims.padding_right, padding_kind::zero));
|
||||
#else
|
||||
convFwdDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
|
||||
@ -261,12 +269,7 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
float op_scale = post_op_param.param[0];
|
||||
float op_alpha = post_op_param.param[1];
|
||||
float op_beta = post_op_param.param[2];
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_relu,
|
||||
op_alpha,
|
||||
#else
|
||||
post_ops.append_eltwise(op_scale, post_op_param.alg, op_alpha,
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
op_beta);
|
||||
} else if (post_op_param.name == "sum") {
|
||||
DCHECK_EQ(post_op_param.param.size(), 1);
|
||||
@ -1033,6 +1036,7 @@ class MklConvOp : public OpKernel {
|
||||
&cached_filter_data_ptensor_, filter_tensor));
|
||||
|
||||
Tensor* second_tensor = nullptr;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
TensorShape filter_mkl_format;
|
||||
filter_mkl_format.AddDim(
|
||||
sizeof(GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)) /
|
||||
@ -1042,6 +1046,21 @@ class MklConvOp : public OpKernel {
|
||||
&cached_filter_md_ptensor_, &second_tensor));
|
||||
second_tensor->scalar<int32>()() = static_cast<int32>(
|
||||
GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc));
|
||||
#else
|
||||
// There is no tensor format in DNNL 1.x. So we cache the complete filter
|
||||
// descriptor as flat byte array.
|
||||
TensorShape cached_filter_md_shape;
|
||||
memory::desc weights_desc = conv_prim_desc.weights_desc();
|
||||
// We don't use .get_size() method of memory::desc since it returns size
|
||||
// required to store primitive's input memory. It is much more than size of
|
||||
// memory::desc itself.
|
||||
cached_filter_md_shape.AddDim(sizeof(weights_desc) / sizeof(uint8));
|
||||
OP_REQUIRES_OK(context, context->allocate_persistent(
|
||||
DT_UINT8, cached_filter_md_shape,
|
||||
&cached_filter_md_ptensor_, &second_tensor));
|
||||
*reinterpret_cast<memory::desc*>(second_tensor->flat<uint8>().data()) =
|
||||
weights_desc;
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
void AllocatePersistentTensor(OpKernelContext* context,
|
||||
@ -1230,12 +1249,11 @@ class MklConvOp : public OpKernel {
|
||||
const Tensor& cached_filter_md =
|
||||
*cached_filter_md_ptensor_.AccessTensor(context);
|
||||
|
||||
// Check if the memory descriptor of the cached weights is same as
|
||||
// filter_md. If so, we can use the cached weights; otherwise
|
||||
// return nullptr.
|
||||
// Check if the memory descriptor of the cached weights is same as
|
||||
// filter_md. If so, we can use the cached weights; otherwise
|
||||
// return nullptr.
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
if (cached_filter_md.scalar<int64>().size() &&
|
||||
AreMemoryDescriptorsEqual(filter_md, cached_filter_md)) {
|
||||
if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
|
||||
#else
|
||||
if (cached_filter_md.scalar<int32>().size() &&
|
||||
cached_filter_md.scalar<int32>()() == filter_md) {
|
||||
@ -1568,7 +1586,7 @@ class MklQuantizedConv2DOp
|
||||
|
||||
if (!scaled_bias_buf_)
|
||||
AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_,
|
||||
conv_fwd_pd->bias_primitive_desc(),
|
||||
GET_BIAS_DESC_FROM_OP_PD(conv_fwd_pd),
|
||||
&scaled_bias_buf_);
|
||||
if (!scaled_bias_) {
|
||||
scaled_bias_ = new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_,
|
||||
|
@ -89,12 +89,26 @@ class MklDequantizeOp : public OpKernel {
|
||||
Tensor* output_tensor = nullptr;
|
||||
MklDnnShape output_mkl_shape;
|
||||
TensorShape output_tf_shape;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
memory::desc dst_md =
|
||||
src_mkl_shape.IsMklTensor()
|
||||
? memory::desc(src_dims, MklDnnType<float>(),
|
||||
static_cast<memory::format>(src_md.data.format))
|
||||
static_cast<MEMORY_FORMAT>(src_md.data.format))
|
||||
: memory::desc(src_dims, MklDnnType<float>(),
|
||||
MEMORY_FORMAT::nhwc);
|
||||
#else
|
||||
memory::desc dst_md = memory::desc();
|
||||
if (src_mkl_shape.IsMklTensor()) {
|
||||
dst_md = memory::desc(src_mkl_shape.GetMklLayout().data);
|
||||
// There is no API in MKL-DNN v1.x to construct memory descriptor with
|
||||
// same .data field but different type.
|
||||
dst_md.data.data_type = memory::convert_to_c(MklDnnType<float>());
|
||||
} else {
|
||||
dst_md =
|
||||
memory::desc(src_dims, MklDnnType<float>(), MEMORY_FORMAT::nhwc);
|
||||
}
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// If input is MKL shape, output is also MKL shape.
|
||||
// If input is TF shape, output is also TF shape.
|
||||
if (src_mkl_shape.IsMklTensor()) {
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -105,6 +105,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
||||
memory::dims weight_dims = memory::dims({channel, k});
|
||||
memory::dims bias_dims = memory::dims({channel});
|
||||
memory::dims dst_dims = memory::dims({batch, channel});
|
||||
MEMORY_FORMAT src_format = MEMORY_FORMAT::nc;
|
||||
MEMORY_FORMAT weight_format =
|
||||
transpose_b_ ? MEMORY_FORMAT::oi : MEMORY_FORMAT::io;
|
||||
|
||||
@ -112,7 +113,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
||||
// 1. const, let MKL-DNN determine format because it will be cached;
|
||||
// 2. var, keep the original format to avoid reordering.
|
||||
MklDnnMatMulFwdParams matmul_params(
|
||||
src_dims, weight_dims, bias_dims, dst_dims,
|
||||
src_dims, weight_dims, bias_dims, dst_dims, src_format,
|
||||
(this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format);
|
||||
|
||||
// Extend the basic parameters for data types and fusions.
|
||||
@ -152,44 +153,44 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
||||
MklDnnData<T> src_mkl(&(this->cpu_engine_));
|
||||
MklDnnData<T> weight_mkl(&(this->cpu_engine_));
|
||||
|
||||
if (src_mkl_shape.IsMklTensor()) {
|
||||
memory::desc input_md = src_mkl_shape.GetMklLayout();
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
if (input_md != matmul_pd->src_desc()) {
|
||||
#else
|
||||
if (input_md.data.format != MKL_TENSOR_FORMAT_NC) {
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
src_mkl.SetUsrMem(input_md, src_data);
|
||||
src_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
matmul_pd.get()->PRIMITIVE_DESC_SRC, this->cpu_engine_));
|
||||
src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
|
||||
}
|
||||
auto src_md = src_mkl_shape.IsMklTensor()
|
||||
? src_mkl_shape.GetMklLayout()
|
||||
: memory::desc(src_dims, MklDnnType<T>(), src_format);
|
||||
|
||||
if (IS_SRC_REORDER_NEEDED(src_md, matmul_pd, matmul_prim)) {
|
||||
src_mkl.SetUsrMem(src_md, src_data);
|
||||
src_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
matmul_pd.get()->PRIMITIVE_DESC_SRC, this->cpu_engine_));
|
||||
src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
|
||||
}
|
||||
|
||||
// Get cached data when weight is const.
|
||||
memory::format expected_format = matmul_prim->GetWeightMemoryFormat();
|
||||
DCHECK(expected_format != weight_format && this->is_weight_const_);
|
||||
if (this->is_weight_const_) {
|
||||
const memory::desc weight_md =
|
||||
memory::desc(weight_dims, MklDnnType<T>(), weight_format);
|
||||
if (IS_WEIGHTS_REORDER_NEEDED(weight_md, matmul_pd, matmul_prim)) {
|
||||
T* cached_weight_data = nullptr;
|
||||
if (this->IsWeightCacheEmpty(ctx)) {
|
||||
auto weight_md =
|
||||
memory::desc(weight_dims, MklDnnType<T>(), weight_format);
|
||||
this->CacheWeight(ctx, matmul_pd, cached_weight_data, weight_tensor,
|
||||
weight_mkl, weight_md);
|
||||
|
||||
if (this->is_weight_const_) {
|
||||
if (this->IsWeightCacheEmpty(ctx)) {
|
||||
this->CacheWeight(ctx, matmul_pd, cached_weight_data, weight_tensor,
|
||||
weight_mkl, weight_md);
|
||||
}
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
cached_weight_data = this->GetCachedWeight(
|
||||
ctx, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_pd));
|
||||
#else
|
||||
cached_weight_data = this->GetCachedWeight(
|
||||
ctx, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_pd).desc());
|
||||
#endif
|
||||
}
|
||||
cached_weight_data = this->GetCachedWeight(ctx, expected_format);
|
||||
|
||||
// Cache weight may fail when it gets different format in different
|
||||
// iteration. Fall back to reorder if it happens.
|
||||
// TODO: Fix this slow path.
|
||||
// Also do general reorder if weight isn't const.
|
||||
if (cached_weight_data != nullptr) {
|
||||
weight_data = cached_weight_data;
|
||||
} else {
|
||||
memory::desc input_md =
|
||||
memory::desc(weight_dims, MklDnnType<T>(), weight_format);
|
||||
|
||||
//>>>>>>> master
|
||||
weight_mkl.SetUsrMem(input_md, weight_data);
|
||||
weight_mkl.SetUsrMem(weight_md, weight_data);
|
||||
weight_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
matmul_pd.get()->PRIMITIVE_DESC_WEIGHTS, this->cpu_engine_));
|
||||
weight_data =
|
||||
@ -210,23 +211,21 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
||||
|
||||
void ExtendMklDnnMatMulFwdParams(OpKernelContext* ctx,
|
||||
MklDnnMatMulFwdParams& params) {
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
if (fused_ops_.size() == 2) {
|
||||
string post_op = fused_ops_[1];
|
||||
|
||||
if (post_op == "Relu") {
|
||||
params.post_op_params.push_back({"relu", { 1.0, 0.0, 0.0 }});
|
||||
params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}});
|
||||
} else if (post_op == "Relu6") {
|
||||
params.post_op_params.push_back({"relu6", { 1.0, 6.0, 0.0 }});
|
||||
params.post_op_params.push_back({"relu6", {1.0, 6.0, 0.0}});
|
||||
} else if (post_op == "Elu") {
|
||||
params.post_op_params.push_back({"elu", { 1.0, 1.0, 0.0 }});
|
||||
params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}});
|
||||
} else {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, errors::InvalidArgument(
|
||||
"Unsupported post-argument in MklFusedMatMul: ", post_op));
|
||||
}
|
||||
}
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -41,7 +41,8 @@ struct MklDnnMatMulFwdParams {
|
||||
memory::dims weight_dims;
|
||||
memory::dims bias_dims;
|
||||
memory::dims dst_dims;
|
||||
MEMORY_FORMAT weight_fmt;
|
||||
MEMORY_FORMAT src_format;
|
||||
MEMORY_FORMAT weight_format;
|
||||
string dtypes = string("");
|
||||
struct PostOpParam {
|
||||
string name;
|
||||
@ -51,12 +52,14 @@ struct MklDnnMatMulFwdParams {
|
||||
|
||||
MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
|
||||
memory::dims bias_dims, memory::dims dst_dims,
|
||||
MEMORY_FORMAT weight_fmt = MEMORY_FORMAT::any)
|
||||
MEMORY_FORMAT src_format = MEMORY_FORMAT::any,
|
||||
MEMORY_FORMAT weight_format = MEMORY_FORMAT::any)
|
||||
: src_dims(src_dims),
|
||||
weight_dims(weight_dims),
|
||||
bias_dims(bias_dims),
|
||||
dst_dims(dst_dims),
|
||||
weight_fmt(weight_fmt) {}
|
||||
src_format(src_format),
|
||||
weight_format(weight_format) {}
|
||||
};
|
||||
|
||||
// With quantization, input, weight, bias, and output can have different types.
|
||||
@ -182,15 +185,11 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
|
||||
// format.
|
||||
context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims},
|
||||
MklDnnType<Tinput>(),
|
||||
MEMORY_FORMAT::any));
|
||||
matmul_fwd_params.src_format));
|
||||
|
||||
context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims},
|
||||
MklDnnType<Tweight>(),
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
MEMORY_FORMAT::any));
|
||||
#else
|
||||
matmul_fwd_params.weight_fmt));
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
matmul_fwd_params.weight_format));
|
||||
|
||||
context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
|
||||
MklDnnType<Toutput>(),
|
||||
@ -438,49 +437,61 @@ class MklDnnMatMulOpBase : public OpKernel {
|
||||
|
||||
// reorder and cache the weight
|
||||
weight.SetUsrMem(weight_md, &weight_tensor);
|
||||
weight.CheckReorderToOpMem(matmul_fwd_pd.get()->weights_primitive_desc());
|
||||
weight.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_));
|
||||
weight_data = static_cast<Tweight*>(weight.GetOpMem().get_data_handle());
|
||||
|
||||
Tensor* weight_tensor_ptr = nullptr;
|
||||
|
||||
size_t size = matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS.get_size();
|
||||
TensorShape weight_tf_shape;
|
||||
weight_tf_shape.AddDim(
|
||||
(matmul_fwd_pd.get()->weights_primitive_desc().get_size() /
|
||||
sizeof(Tweight)));
|
||||
weight_tf_shape.AddDim(size / sizeof(Tweight));
|
||||
|
||||
OP_REQUIRES_OK(context, context->allocate_persistent(
|
||||
DataTypeToEnum<Tweight>::value, weight_tf_shape,
|
||||
&weight_oi_, &weight_tensor_ptr));
|
||||
|
||||
void* weight_oi_t_data = weight.GetTensorBuffer(weight_tensor_ptr);
|
||||
size_t weight_size = weight.GetOpMem().get_primitive_desc().get_size();
|
||||
memcpy(weight_oi_t_data, weight_data, weight_size);
|
||||
memcpy(weight_oi_t_data, weight_data, size);
|
||||
|
||||
// cache the memory descriptor
|
||||
// cache the memory descriptor
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
auto expected_md = GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd);
|
||||
#else
|
||||
auto expected_md = GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd).desc();
|
||||
#endif
|
||||
Tensor* weight_md_tensor_ptr = nullptr;
|
||||
TensorShape weight_mkl_format;
|
||||
weight_mkl_format.AddDim(1);
|
||||
weight_mkl_format.AddDim(sizeof(expected_md) / sizeof(Tweight));
|
||||
|
||||
OP_REQUIRES_OK(context, context->allocate_persistent(
|
||||
DT_INT32, weight_mkl_format, &weight_oi_md_,
|
||||
&weight_md_tensor_ptr));
|
||||
weight_md_tensor_ptr->scalar<int32>()() =
|
||||
matmul_fwd_pd.get()->weights_primitive_desc().desc().data.format;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_persistent(DataTypeToEnum<Tweight>::value,
|
||||
weight_mkl_format, &weight_oi_md_,
|
||||
&weight_md_tensor_ptr));
|
||||
*reinterpret_cast<memory::desc*>(
|
||||
weight_md_tensor_ptr->flat<Tweight>().data()) = expected_md;
|
||||
}
|
||||
|
||||
Tweight* GetCachedWeight(OpKernelContext* context,
|
||||
const memory::format& weight_mf)
|
||||
const memory::desc& expected_md)
|
||||
LOCKS_EXCLUDED(mu_) {
|
||||
tf_shared_lock lock(mu_);
|
||||
const Tensor& weight_t = *weight_oi_.AccessTensor(context);
|
||||
const Tensor& weight_md_t = *weight_oi_md_.AccessTensor(context);
|
||||
|
||||
// Check if the memory descriptor of the cached weight is same as
|
||||
// weight_mf. if so use the cached memory, else return NULL
|
||||
if (weight_md_t.scalar<int32>().size() &&
|
||||
weight_md_t.scalar<int32>()() == weight_mf) {
|
||||
return static_cast<Tweight*>(
|
||||
const_cast<Tweight*>(weight_t.flat<Tweight>().data()));
|
||||
// Check if the memory descriptor of the cached weight is same as
|
||||
// expected_md. if so use the cached memory, else return NULL
|
||||
if (weight_md_t.flat<Tweight>().size()) {
|
||||
const memory::desc& stored_md =
|
||||
*(static_cast<memory::desc*>(weight_md_t.data()));
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
if (stored_md == expected_md) {
|
||||
#else
|
||||
if (stored_md.data.format == expected_md.data.format) {
|
||||
#endif
|
||||
return static_cast<Tweight*>(
|
||||
const_cast<Tweight*>(weight_t.flat<Tweight>().data()));
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@ -527,7 +538,8 @@ void dnnl_gemm_exec(const dnnl::desc& a_md, const dnnl::desc& b_md,
|
||||
dnnl::stream s(cpu_engine);
|
||||
matmul_prim.execute(s, {{DNNL_ARG_SRC, a_memory},
|
||||
{DNNL_ARG_WEIGHTS, b_memory},
|
||||
{DNNL_ARG_DST, c_memory}});
|
||||
{ DNNL_ARG_DST,
|
||||
c_memory }});
|
||||
s.wait();
|
||||
}
|
||||
|
||||
|
@ -32,7 +32,9 @@ using mkldnn::algorithm;
|
||||
using mkldnn::engine;
|
||||
using mkldnn::error;
|
||||
using mkldnn::memory;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
using mkldnn::padding_kind;
|
||||
#endif
|
||||
using mkldnn::pooling_backward;
|
||||
using mkldnn::pooling_forward;
|
||||
using mkldnn::prop_kind;
|
||||
@ -136,21 +138,31 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
|
||||
pooling_prop_kind = prop_kind::forward_inference;
|
||||
else
|
||||
pooling_prop_kind = prop_kind::forward_training;
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
MklPoolingParams fwdParams(
|
||||
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
|
||||
padding_right, ALGORITHM::pooling_max, pooling_prop_kind,
|
||||
static_cast<memory::format>(input_md.data.format));
|
||||
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
|
||||
#else
|
||||
MklPoolingParams fwdParams(
|
||||
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
|
||||
padding_right, ALGORITHM::pooling_max, pooling_prop_kind,
|
||||
static_cast<MEMORY_FORMAT>(input_md.data.format));
|
||||
#endif
|
||||
pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);
|
||||
|
||||
// Allocate output tensor.
|
||||
this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()),
|
||||
output_dims_mkl_order,
|
||||
this->tensor_format_mkldnn_, &output_tensor);
|
||||
OP_REQUIRES_OK(context, context->status());
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
dnn_data_output.SetUsrMem(output_dims_mkl_order,
|
||||
pooling_fwd->GetDstMemoryFormat(),
|
||||
output_tensor);
|
||||
|
||||
this->data_format_mkldnn_, output_tensor);
|
||||
#else
|
||||
dnn_data_output.SetUsrMem(
|
||||
GET_DST_DESC_FROM_OP_PD(pooling_fwd->GetPoolingFwdPd()),
|
||||
output_tensor);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
const T* src_data = input_tensor.flat<T>().data();
|
||||
|
||||
T* dst_data = output_tensor->flat<T>().data();
|
||||
@ -183,7 +195,6 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
|
||||
OP_REQUIRES_OK(context, context->status());
|
||||
T* ws_data =
|
||||
static_cast<T*>(dnn_data_wksp.GetOpMem().get_data_handle());
|
||||
|
||||
// Execute pooling op.
|
||||
pooling_fwd->Execute(src_data, dst_data, ws_data);
|
||||
}
|
||||
@ -285,11 +296,19 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
||||
: memory::desc(orig_input_dims_mkl_order, MklDnnType<T>(),
|
||||
this->data_format_mkldnn_);
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
MklPoolingParams bwdParams(
|
||||
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
|
||||
strides, padding_left, padding_right, ALGORITHM::pooling_max,
|
||||
prop_kind::forward_training,
|
||||
static_cast<memory::format>(src_md.data.format));
|
||||
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
|
||||
#else
|
||||
MklPoolingParams bwdParams(
|
||||
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
|
||||
strides, padding_left, padding_right, ALGORITHM::pooling_max,
|
||||
prop_kind::forward_training,
|
||||
static_cast<MEMORY_FORMAT>(src_md.data.format));
|
||||
#endif
|
||||
MklPoolingBwdPrimitive<T>* pooling_bwd =
|
||||
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
|
||||
|
||||
|
@ -54,17 +54,25 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
|
||||
context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType<T>(),
|
||||
MEMORY_FORMAT::any));
|
||||
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
// Create a pooling descriptor.
|
||||
context_.fwd_desc.reset(new pooling_forward::desc(
|
||||
fwdParams.prop_kind, fwdParams.alg_kind, *context_.src_md,
|
||||
*context_.dst_md, fwdParams.strides, fwdParams.filter_dims,
|
||||
fwdParams.padding_left, fwdParams.padding_right, padding_kind::zero));
|
||||
#else
|
||||
context_.fwd_desc.reset(new pooling_forward::desc(
|
||||
fwdParams.prop_kind, fwdParams.alg_kind, *context_.src_md,
|
||||
*context_.dst_md, fwdParams.strides, fwdParams.filter_dims,
|
||||
fwdParams.padding_left, fwdParams.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
context_.fwd_pd.reset(
|
||||
new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine_));
|
||||
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
context_.dst_fmt = static_cast<MEMORY_FORMAT>(
|
||||
context_.fwd_pd.get()->PRIMITIVE_DESC_DST.desc().data.format);
|
||||
#else
|
||||
context_.dst_fmt = static_cast<MEMORY_FORMAT>(MEMORY_FORMAT::any);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
// Create MKL-DNN internal memory object with dummy data.
|
||||
@ -72,7 +80,6 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
|
||||
context_.fwd_pd.get()->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData));
|
||||
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR(
|
||||
context_.fwd_pd.get()->PRIMITIVE_DESC_DST, cpu_engine_, DummyData));
|
||||
|
||||
// For max pooling, need to return workspace (ws) for backward computing.
|
||||
if (fwdParams.alg_kind == ALGORITHM::pooling_max &&
|
||||
fwdParams.prop_kind == prop_kind::forward_training) {
|
||||
@ -160,10 +167,11 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
|
||||
|
||||
// Create memory descriptor.
|
||||
context_.diff_src_md.reset(new memory::desc(
|
||||
{bwdParams.src_dims}, MklDnnType<T>(), memory::format::any));
|
||||
{bwdParams.src_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));
|
||||
context_.diff_dst_md.reset(new memory::desc(
|
||||
{bwdParams.dst_dims}, MklDnnType<T>(), bwdParams.src_format));
|
||||
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
context_.bwd_desc.reset(new pooling_backward::desc(
|
||||
bwdParams.alg_kind, *context_.diff_src_md, *context_.diff_dst_md,
|
||||
bwdParams.strides, bwdParams.filter_dims, bwdParams.padding_left,
|
||||
@ -175,6 +183,18 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
|
||||
bwdParams.prop_kind, bwdParams.alg_kind, *context_.diff_src_md,
|
||||
*context_.diff_dst_md, bwdParams.strides, bwdParams.filter_dims,
|
||||
bwdParams.padding_left, bwdParams.padding_right, padding_kind::zero));
|
||||
#else
|
||||
context_.bwd_desc.reset(new pooling_backward::desc(
|
||||
bwdParams.alg_kind, *context_.diff_src_md, *context_.diff_dst_md,
|
||||
bwdParams.strides, bwdParams.filter_dims, bwdParams.padding_left,
|
||||
bwdParams.padding_right));
|
||||
// Create a forward primitive,
|
||||
// which will be used as a hint for creating backward primitive.
|
||||
context_.fwd_desc.reset(new pooling_forward::desc(
|
||||
bwdParams.prop_kind, bwdParams.alg_kind, *context_.diff_src_md,
|
||||
*context_.diff_dst_md, bwdParams.strides, bwdParams.filter_dims,
|
||||
bwdParams.padding_left, bwdParams.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
context_.fwd_pd.reset(
|
||||
new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine_));
|
||||
context_.bwd_pd.reset(new pooling_backward::primitive_desc(
|
||||
|
@ -47,13 +47,13 @@ struct MklPoolingParams {
|
||||
memory::dims padding_right;
|
||||
mkldnn::algorithm alg_kind;
|
||||
mkldnn::prop_kind prop_kind;
|
||||
memory::format src_format;
|
||||
MEMORY_FORMAT src_format;
|
||||
|
||||
MklPoolingParams(memory::dims src_dims, memory::dims dst_dims,
|
||||
memory::dims filter_dims, memory::dims strides,
|
||||
memory::dims padding_left, memory::dims padding_right,
|
||||
mkldnn::algorithm alg_kind, mkldnn::prop_kind prop_kind,
|
||||
memory::format src_format)
|
||||
MEMORY_FORMAT src_format)
|
||||
: src_dims(src_dims),
|
||||
dst_dims(dst_dims),
|
||||
filter_dims(filter_dims),
|
||||
|
@ -269,11 +269,11 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
|
||||
}
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
weight_data = this->GetCachedWeight(
|
||||
context, static_cast<int32>(weight_mkl_shape.GetTfDataFormat()));
|
||||
context, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd));
|
||||
#else
|
||||
weight_data = this->GetCachedWeight(
|
||||
context, matmul_fwd->GetWeightMemoryFormat());
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
context, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd).desc());
|
||||
#endif
|
||||
is_weight_cached = (weight_data != nullptr);
|
||||
}
|
||||
|
||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mkldnn.hpp"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
@ -27,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/util/mkl_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
using mkldnn::algorithm;
|
||||
using mkldnn::eltwise_forward;
|
||||
@ -164,7 +164,11 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
|
||||
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
|
||||
|
||||
context_.src_mpd.reset(
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
new MEMORY_PRIMITIVE_DESC(*context_.src_md));
|
||||
#else
|
||||
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
|
||||
#endif
|
||||
|
||||
// Create an eltwise forward descriptor and primitive descriptor
|
||||
context_.fwd_desc.reset(new eltwise_forward::desc(
|
||||
@ -397,7 +401,6 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
|
||||
// Create memory descriptors for eltwise data w/ no specified format
|
||||
context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
|
||||
context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
|
||||
|
||||
context_.src_mpd.reset(
|
||||
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
|
||||
context_.diff_dst_mpd.reset(
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#include "mkldnn.hpp"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
@ -27,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/mkl_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
using mkldnn::prop_kind;
|
||||
using mkldnn::softmax_forward;
|
||||
@ -65,7 +65,7 @@ class MklSoftmaxPrimitive : public MklPrimitive {
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
|
||||
context_.net_args);
|
||||
context_.fwd_net_args);
|
||||
#else
|
||||
context_.fwd_stream->submit(context_.fwd_primitives);
|
||||
#endif
|
||||
|
@ -1259,6 +1259,7 @@ REGISTER_OP("_MklFusedBatchNormV3")
|
||||
.Attr("U: {float}")
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
.Attr(GetConvnetDataFormatAttrString())
|
||||
.Attr("exponential_avg_factor: float = 1.0")
|
||||
.Attr("is_training: bool = true")
|
||||
.SetShapeFn(shape_inference::FusedBatchNormShape)
|
||||
.Doc(
|
||||
|
@ -2524,6 +2524,7 @@ REGISTER_OP("_MklFusedBatchNorm")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
.Attr("data_format: string = 'NHWC'")
|
||||
.Attr("exponential_avg_factor: float = 1.0")
|
||||
.Attr("is_training: bool = true")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle x;
|
||||
@ -2673,6 +2674,7 @@ REGISTER_OP("_MklFusedBatchNormV2")
|
||||
.Attr("U: {float}")
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
.Attr(GetConvnetDataFormatAttrString())
|
||||
.Attr("exponential_avg_factor: float = 1.0")
|
||||
.Attr("is_training: bool = true")
|
||||
.SetShapeFn(shape_inference::FusedBatchNormShape);
|
||||
|
||||
|
@ -23,6 +23,7 @@ namespace tensorflow {
|
||||
#define ADD_MD add_md
|
||||
#define ALGORITHM mkldnn::algorithm
|
||||
#define ALGORITHM_UNDEF ALGORITHM::undef
|
||||
#define BN_FLAGS mkldnn::normalization_flags
|
||||
#define CPU_STREAM(engine) stream(engine)
|
||||
#define DATA_WITH_ENGINE(data, engine) data, engine
|
||||
#define DST_MD dst_md
|
||||
@ -41,6 +42,8 @@ namespace tensorflow {
|
||||
GET_MEMORY_DESC_FROM_MEM_PTR(mem_ptr)
|
||||
#define GET_MEMORY_SIZE_FROM_MD(md, engine) md.get_size()
|
||||
#define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd->src_desc()
|
||||
#define GET_DST_DESC_FROM_OP_PD(op_pd) op_pd->dst_desc()
|
||||
#define GET_BIAS_DESC_FROM_OP_PD(op_pd) op_pd->bias_desc()
|
||||
#define GET_DIFF_DST_DESC_FROM_OP_PD(op_pd) op_pd->diff_dst_desc()
|
||||
#define GET_WORKSPACE_DESC_FROM_OP_PD(op_pd) op_pd->workspace_desc()
|
||||
#define GET_TENSOR_FORMAT(fmt) MklTensorFormatToMklDnnDataFormat(fmt)
|
||||
@ -112,6 +115,7 @@ namespace tensorflow {
|
||||
#define TENSOR_FORMAT_NHWC MKL_TENSOR_FORMAT_NHWC
|
||||
#define TENSOR_MAX_DIMS MKLDNN_MAX_NDIMS
|
||||
#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemDesc()
|
||||
#define BN_FLAGS mkldnn::normalization_flags
|
||||
|
||||
#else
|
||||
|
||||
@ -136,6 +140,8 @@ namespace tensorflow {
|
||||
#define GET_MEMORY_SIZE_FROM_MD(md, engine) \
|
||||
memory::primitive_desc(md, engine).get_size()
|
||||
#define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd.get()->src_primitive_desc()
|
||||
#define GET_DST_DESC_FROM_OP_PD(op_pd) op_pd.get()->dst_primitive_desc()
|
||||
#define GET_BIAS_DESC_FROM_OP_PD(op_pd) op_pd.get()->bias_primitive_desc()
|
||||
#define GET_DIFF_DST_DESC_FROM_OP_PD(op_pd) \
|
||||
op_pd.get()->diff_dst_primitive_desc()
|
||||
#define GET_WORKSPACE_DESC_FROM_OP_PD(op_pd) \
|
||||
@ -210,6 +216,7 @@ namespace tensorflow {
|
||||
#define TENSOR_FORMAT TensorFormat
|
||||
#define TENSOR_FORMAT_NHWC FORMAT_NHWC
|
||||
#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemPrimDesc()
|
||||
#define BN_FLAGS mkldnn
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -42,7 +42,9 @@ limitations under the License.
|
||||
|
||||
using mkldnn::engine;
|
||||
using mkldnn::memory;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
using mkldnn::padding_kind;
|
||||
#endif
|
||||
using mkldnn::primitive;
|
||||
using mkldnn::reorder;
|
||||
using mkldnn::stream;
|
||||
@ -1224,10 +1226,12 @@ inline memory::dims CalculateTFStrides(const memory::dims& dims_tf_order) {
|
||||
return strides;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
inline padding_kind TFPaddingToMklDnnPadding(Padding pad) {
|
||||
// MKL-DNN only supports zero padding.
|
||||
return padding_kind::zero;
|
||||
}
|
||||
#endif
|
||||
|
||||
/// Helper function to create memory descriptor in Blocked format
|
||||
///
|
||||
@ -2157,7 +2161,7 @@ inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
|
||||
}
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
void execute_primitives(
|
||||
inline void execute_primitives(
|
||||
std::vector<mkldnn::primitive>& primitives, std::shared_ptr<stream> stream,
|
||||
std::vector<std::unordered_map<int, memory>>& net_args) {
|
||||
DCHECK_EQ(primitives.size(), net_args.size());
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
#include "tensorflow/core/util/mkl_types.h"
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -23,7 +24,7 @@ namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(MklUtilTest, MklDnnTfShape) {
|
||||
auto cpu_engine = engine(engine::cpu, 0);
|
||||
auto cpu_engine = engine(ENGINE_CPU, 0);
|
||||
MklDnnData<float> a(&cpu_engine);
|
||||
|
||||
const int N = 1, C = 2, H = 3, W = 4;
|
||||
@ -31,7 +32,7 @@ TEST(MklUtilTest, MklDnnTfShape) {
|
||||
MklDnnShape a_mkldnn_shape;
|
||||
a_mkldnn_shape.SetMklTensor(true);
|
||||
// Create TF layout in NCHW.
|
||||
a_mkldnn_shape.SetTfLayout(a_dims.size(), a_dims, memory::format::nchw);
|
||||
a_mkldnn_shape.SetTfLayout(a_dims.size(), a_dims, MKL_TENSOR_FORMAT_NCHW);
|
||||
TensorShape a_tf_shape_nchw({N, C, H, W});
|
||||
TensorShape a_tf_shape_nhwc({N, H, W, C});
|
||||
TensorShape a_mkldnn_tf_shape = a_mkldnn_shape.GetTfShape();
|
||||
@ -43,7 +44,7 @@ TEST(MklUtilTest, MklDnnTfShape) {
|
||||
MklDnnShape b_mkldnn_shape;
|
||||
b_mkldnn_shape.SetMklTensor(true);
|
||||
// Create TF layout in NHWC.
|
||||
b_mkldnn_shape.SetTfLayout(b_dims.size(), b_dims, memory::format::nhwc);
|
||||
b_mkldnn_shape.SetTfLayout(b_dims.size(), b_dims, MKL_TENSOR_FORMAT_NHWC);
|
||||
TensorShape b_tf_shape_nhwc({N, H, W, C});
|
||||
TensorShape b_tf_shape_nchw({N, C, H, W});
|
||||
TensorShape b_mkldnn_tf_shape = b_mkldnn_shape.GetTfShape();
|
||||
@ -55,7 +56,7 @@ TEST(MklUtilTest, MklDnnTfShape) {
|
||||
TEST(MklUtilTest, MklDnnBlockedFormatTest) {
|
||||
// Let's create 2D tensor of shape {3, 4} with 3 being innermost dimension
|
||||
// first (case 1) and then it being outermost dimension (case 2).
|
||||
auto cpu_engine = engine(engine::cpu, 0);
|
||||
auto cpu_engine = engine(ENGINE_CPU, 0);
|
||||
|
||||
// Setting for case 1
|
||||
MklDnnData<float> a(&cpu_engine);
|
||||
@ -67,7 +68,9 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) {
|
||||
EXPECT_EQ(a_md1.data.ndims, 2);
|
||||
EXPECT_EQ(a_md1.data.dims[0], 3);
|
||||
EXPECT_EQ(a_md1.data.dims[1], 4);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
EXPECT_EQ(a_md1.data.format, mkldnn_blocked);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// Setting for case 2
|
||||
MklDnnData<float> b(&cpu_engine);
|
||||
@ -79,7 +82,9 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) {
|
||||
EXPECT_EQ(b_md2.data.ndims, 2);
|
||||
EXPECT_EQ(b_md2.data.dims[0], 3);
|
||||
EXPECT_EQ(b_md2.data.dims[1], 4);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
EXPECT_EQ(b_md2.data.format, mkldnn_blocked);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
TEST(MklUtilTest, LRUCacheTest) {
|
||||
|
@ -174,12 +174,12 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
||||
|
||||
tf_http_archive(
|
||||
name = "mkl_dnn_v1",
|
||||
build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
|
||||
sha256 = "fcc2d951f7170eade0cfdd0d8d1d58e3e7785bd326bca6555f3722f8cba71811",
|
||||
strip_prefix = "mkl-dnn-1.0-pc2",
|
||||
build_file = clean_dep("//third_party/mkl_dnn:mkldnn_v1.BUILD"),
|
||||
sha256 = "27fd9da9720c452852f1226581e7914efcf74e1ff898468fdcbe1813528831ba",
|
||||
strip_prefix = "mkl-dnn-1.0",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/archive/v1.0-pc2.tar.gz",
|
||||
"https://github.com/intel/mkl-dnn/archive/v1.0-pc2.tar.gz",
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/archive/v1.0.tar.gz",
|
||||
"https://github.com/intel/mkl-dnn/archive/v1.0.tar.gz",
|
||||
],
|
||||
)
|
||||
|
||||
|
135
third_party/mkl_dnn/mkldnn_v1.BUILD
vendored
Normal file
135
third_party/mkl_dnn/mkldnn_v1.BUILD
vendored
Normal file
@ -0,0 +1,135 @@
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load(
|
||||
"@org_tensorflow//third_party/mkl_dnn:build_defs.bzl",
|
||||
"if_mkl_open_source_only",
|
||||
"if_mkl_v1_open_source_only",
|
||||
)
|
||||
load(
|
||||
"@org_tensorflow//third_party:common.bzl",
|
||||
"template_rule",
|
||||
)
|
||||
|
||||
# Build condition that holds when targeting the "k8" (x86_64) CPU with
# --define=using_clang=true. It is used as a select() key further down in this
# file to choose compiler options.
config_setting(
    name = "clang_linux_x86_64",
    values = {
        "cpu": "k8",
        "define": "using_clang=true",
    },
)
|
||||
|
||||
# Produces include/mkldnn_config.h from the upstream CMake template. The two
# "#cmakedefine" lines are replaced with fixed definitions: the CPU runtime is
# set to MKLDNN_RUNTIME_OMP and the GPU runtime to MKLDNN_RUNTIME_NONE.
template_rule(
    name = "mkldnn_config_h",
    src = "include/mkldnn_config.h.in",
    out = "include/mkldnn_config.h",
    substitutions = {
        "#cmakedefine MKLDNN_CPU_RUNTIME MKLDNN_RUNTIME_${MKLDNN_CPU_RUNTIME_CURRENT}": "#define MKLDNN_CPU_RUNTIME MKLDNN_RUNTIME_OMP",
        "#cmakedefine MKLDNN_GPU_RUNTIME MKLDNN_RUNTIME_${MKLDNN_GPU_RUNTIME}": "#define MKLDNN_GPU_RUNTIME MKLDNN_RUNTIME_NONE",
    },
)
|
||||
|
||||
# Generates include/mkldnn_version.h, which carries the MKL-DNN version
# numbers.
#
# The numbers below are hard-coded and must be updated by hand whenever
# MKL-DNN is upgraded. Take them from the PROJECT_VERSION setting in the
# upstream CMakeLists.txt, which has the form
# "version_major.version_minor.version_patch". The git hash may be left as NA.
# TODO(agramesh1) Automatically get the version numbers from CMakeLists.txt.
template_rule(
    name = "mkldnn_version_h",
    src = "include/mkldnn_version.h.in",
    out = "include/mkldnn_version.h",
    substitutions = {
        "@MKLDNN_VERSION_MAJOR@": "1",
        "@MKLDNN_VERSION_MINOR@": "0",
        "@MKLDNN_VERSION_PATCH@": "0",
        "@MKLDNN_VERSION_HASH@": "N/A",
    },
)
|
||||
|
||||
# Main MKL-DNN v1 library target.
#
# Sources are the common and CPU implementation files plus the generated
# version header; the generated config header is added only through
# if_mkl_v1_open_source_only().
cc_library(
    name = "mkl_dnn",
    srcs = glob([
        "src/common/*.cpp",
        "src/common/*.hpp",
        "src/cpu/*.cpp",
        "src/cpu/*.hpp",
        "src/cpu/**/*.cpp",
        "src/cpu/**/*.hpp",
        "src/cpu/xbyak/*.h",
    ]) + if_mkl_v1_open_source_only([
        ":mkldnn_config_h",
    ]) + [":mkldnn_version_h"],
    hdrs = glob(["include/*"]),
    # USE_MKL / USE_CBLAS are defined by default and then undefined again for
    # either of the open-source-only configurations.
    copts = [
        "-fexceptions",
        "-DUSE_MKL",
        "-DUSE_CBLAS",
    ] + if_mkl_open_source_only([
        "-UUSE_MKL",
        "-UUSE_CBLAS",
    ]) + if_mkl_v1_open_source_only([
        "-UUSE_MKL",
        "-UUSE_CBLAS",
    ]) + select({
        "@org_tensorflow//tensorflow:linux_x86_64": [
            "-fopenmp",  # only works with gcc
        ],
        # TODO(ibiryukov): enable openmp with clang by including libomp as a
        # dependency.
        ":clang_linux_x86_64": [],
        "//conditions:default": [],
    }),
    includes = [
        "include",
        "src",
        "src/common",
        "src/cpu",
        "src/cpu/gemm",
        "src/cpu/xbyak",
    ],
    visibility = ["//visibility:public"],
    # Per-platform MKL headers and libraries; no MKL dependency elsewhere.
    deps = select({
        "@org_tensorflow//tensorflow:linux_x86_64": [
            "@mkl_linux//:mkl_headers",
            "@mkl_linux//:mkl_libs_linux",
        ],
        "@org_tensorflow//tensorflow:macos": [
            "@mkl_darwin//:mkl_headers",
            "@mkl_darwin//:mkl_libs_darwin",
        ],
        "@org_tensorflow//tensorflow:windows": [
            "@mkl_windows//:mkl_headers",
            "@mkl_windows//:mkl_libs_windows",
        ],
        "//conditions:default": [],
    }),
)
|
||||
|
||||
# MKL-DNN v1 built without threading and without any MKL dependency.
#
# Both generated headers are listed in srcs, matching the mkl_dnn target in
# this file: mkldnn_config.h and mkldnn_version.h exist only as outputs of the
# template rules above (the glob over include/* picks up just the ".in"
# templates), so omitting either one leaves the library without that header.
cc_library(
    name = "mkldnn_single_threaded",
    srcs = glob([
        "src/common/*.cpp",
        "src/common/*.hpp",
        "src/cpu/*.cpp",
        "src/cpu/*.hpp",
        "src/cpu/**/*.cpp",
        "src/cpu/**/*.hpp",
        "src/cpu/xbyak/*.h",
    ]) + [
        ":mkldnn_config_h",
        ":mkldnn_version_h",
    ],
    hdrs = glob(["include/*"]),
    copts = [
        "-fexceptions",
        # NOTE(review): MKLDNN_THR is the threading switch used by the 0.x
        # sources, while the generated mkldnn_config.h above selects
        # MKLDNN_RUNTIME_OMP. Confirm this flag still disables threading
        # with the 1.0 sources.
        "-DMKLDNN_THR=MKLDNN_THR_SEQ",  # Disables threading.
    ],
    includes = [
        "include",
        "src",
        "src/common",
        "src/cpu",
        "src/cpu/gemm",
        "src/cpu/xbyak",
    ],
    visibility = ["//visibility:public"],
)
|
Loading…
Reference in New Issue
Block a user