Merge pull request from Intel-tensorflow:nhasabni/fixes_for_dnnl1.0

PiperOrigin-RevId: 297740489
Change-Id: I6201722923452c038bf0fd44065776dda12f87e5
TensorFlower Gardener 2020-02-27 18:38:50 -08:00
commit 7dbff51a35
21 changed files with 422 additions and 147 deletions


@ -116,12 +116,13 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
padding_right, ALGORITHM::pooling_avg_exclude_padding,
pooling_prop_kind,
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), input_md);
#else
MklPoolingParams fwdParams(
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
padding_right, ALGORITHM::pooling_avg_exclude_padding,
pooling_prop_kind, static_cast<MEMORY_FORMAT>(input_md.data.format));
pooling_prop_kind, static_cast<MEMORY_FORMAT>(input_md.data.format),
input_md);
#endif
pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);
@ -240,13 +241,13 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
strides, padding_left, padding_right,
ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training,
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md);
#else
MklPoolingParams bwdParams(
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
strides, padding_left, padding_right,
ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training,
static_cast<MEMORY_FORMAT>(src_md.data.format));
static_cast<MEMORY_FORMAT>(src_md.data.format), src_md);
#endif
MklPoolingBwdPrimitive<T>* pooling_bwd =
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);


@ -266,7 +266,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
explicit MklConcatFwdPrimitive(const MklConcatFwdParams& concat_fwd_dims,
const std::vector<memory::desc>& srcs_md)
: cpu_engine_(ENGINE_CPU, 0) {
context_.fwd_stream.reset(new CPU_STREAM(stream::kind::eager));
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
// Create concat primitive
Setup(concat_fwd_dims, srcs_md);
}
@ -292,8 +292,8 @@ class MklConcatFwdPrimitive : public MklPrimitive {
}
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.fwd_primitives, *context_.fwd_stream,
context_.fwd_primitives_args.at(i));
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
context_.fwd_primitives_args);
#else
context_.fwd_stream->submit(context_.fwd_primitives);
#endif // ENABLE_MKLDNN_V1
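For context, the stream changes in this file follow the general MKL-DNN v1.x execution model: a stream is bound to an engine, and each primitive is executed with an explicit argument map instead of being submitted in bulk. A minimal sketch of the two models (illustrative only; `primitives` and `args` stand in for context_.fwd_primitives and context_.fwd_primitives_args):

#include <unordered_map>
#include <vector>
#include "mkldnn.hpp"

void RunAll(std::vector<mkldnn::primitive>& primitives,
            std::vector<std::unordered_map<int, mkldnn::memory>>& args) {
#ifdef ENABLE_MKLDNN_V1
  // v1.x: the stream is constructed from an engine; every primitive takes
  // an explicit {argument-id -> memory} map when executed.
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream s(eng);
  for (size_t i = 0; i < primitives.size(); ++i) primitives[i].execute(s, args[i]);
  s.wait();
#else
  // v0.x: an eager stream collects the primitives and submit() runs them.
  mkldnn::stream s(mkldnn::stream::kind::eager);
  s.submit(primitives).wait();
#endif  // ENABLE_MKLDNN_V1
}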
@ -328,7 +328,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
std::shared_ptr<mkldnn::memory> dst_mem;
// Memory descriptor
std::vector<std::shared_ptr<mkldnn::memory::desc>> src_md;
std::vector<mkldnn::memory::desc> src_md;
std::shared_ptr<mkldnn::memory::desc> dst_md;
// Concat primitive descriptor
@ -339,7 +339,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
std::vector<mkldnn::primitive> fwd_primitives;
#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> fwd_primitive_args;
std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
#endif // ENABLE_MKLDNN_V1
ConcatFwdContext()
@ -355,15 +355,14 @@ class MklConcatFwdPrimitive : public MklPrimitive {
const std::vector<memory::desc>& srcs_md) {
// Create memory descriptors for concat with specified srcs format
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
std::shared_ptr<mkldnn::memory::desc> source_md(
new memory::desc(srcs_md[i].data));
mkldnn::memory::desc source_md(memory::desc(srcs_md[i].data));
context_.src_md.push_back(source_md);
#ifdef ENABLE_MKLDNN_V1
std::shared_ptr<mkldnn::memory> src_mem(
new mkldnn::memory(*source_md, cpu_engine_, DummyData));
new mkldnn::memory(source_md, cpu_engine_, DummyData));
#else
std::shared_ptr<mkldnn::memory::primitive_desc> src_mpd(
new memory::primitive_desc(*source_md, cpu_engine_));
new memory::primitive_desc(source_md, cpu_engine_));
context_.src_pd_shdptr.push_back(src_mpd);
std::shared_ptr<mkldnn::memory> src_mem(
@ -665,8 +664,9 @@ class MklConcatOp : public OpKernel {
if (input_tensors[k].NumElements() == 0) continue;
auto src_md = mkl_input_shapes[k].GetMklLayout();
srcs[k].SetUsrMem(src_md, &input_tensors[k]);
if (src_md.data.format != mkl_common_format) {
auto src_tf_fmt = MklTensorFormatToMklDnnDataFormat(
mkl_input_shapes[k].GetTfDataFormat());
if (src_tf_fmt != mkl_common_format) {
memory::dims src_dims(src_md.data.dims,
&src_md.data.dims[src_md.data.ndims]);
src_md =
@ -935,7 +935,8 @@ class MklConcatOp : public OpKernel {
for (int k = 0; k < input_shapes.size(); k++) {
auto src_dims = TFShapeToMklDnnDims(input_shapes[k].GetTfShape());
*concat_dim_size += src_dims[concat_dim];
int fmt = static_cast<int>(input_shapes[k].GetMklLayout().data.format);
int fmt = static_cast<int>(
MklTensorFormatToMklDnnDataFormat(input_shapes[k].GetTfDataFormat()));
occurrence_map[fmt] += 1;
}
@ -943,7 +944,7 @@ class MklConcatOp : public OpKernel {
// this means that all inputs have the same format;
// return it with is_reorder_needed set to false.
return static_cast<MEMORY_FORMAT>(
input_shapes[0].GetMklLayout().data.format);
MklTensorFormatToMklDnnDataFormat(input_shapes[0].GetTfDataFormat()));
}
// Input tensors have different formats. Thus, reorder is needed.
@ -970,7 +971,7 @@ class MklConcatOp : public OpKernel {
.TypeConstraint<type>("T") \
.HostMemory("concat_dim") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>) \
MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklConcatV2") \
.Device(DEVICE_CPU) \
@ -978,7 +979,7 @@ class MklConcatOp : public OpKernel {
.TypeConstraint<int32>("Tidx") \
.HostMemory("axis") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConcatOp<CPUDevice, type, NAME_IS_AXIS>)
MklConcatOp<CPUDevice, type, NAME_IS_AXIS>);
TF_CALL_float(REGISTER_MKL_CPU);
TF_CALL_bfloat16(REGISTER_MKL_CPU);
@ -988,14 +989,14 @@ REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
.TypeConstraint<quint8>("T")
.HostMemory("axis")
.Label(mkl_op_registry::kMklQuantizedOpLabel),
MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS>)
MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS>);
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
.Device(DEVICE_CPU)
.TypeConstraint<qint8>("T")
.HostMemory("axis")
.Label(mkl_op_registry::kMklQuantizedOpLabel),
MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS>)
MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS>);
#undef REGISTER_CONCAT_MKL
} // namespace tensorflow


@ -62,13 +62,19 @@ struct MklConvBwdFilterParams {
memory::dims dilations;
memory::dims padding_left;
memory::dims padding_right;
#ifndef ENABLE_MKLDNN_V1
padding_kind padding;
#endif // !ENABLE_MKLDNN_V1
MklConvBwdFilterParams(memory::dims src_dims, memory::dims diff_filter_dims,
memory::dims diff_bias_dims,
memory::dims diff_dst_dims, memory::dims strides,
memory::dims dilations, memory::dims padding_left,
#ifndef ENABLE_MKLDNN_V1
memory::dims padding_right, padding_kind padding)
#else
memory::dims padding_right)
#endif // !ENABLE_MKLDNN_V1
: src_dims(src_dims),
diff_filter_dims(diff_filter_dims),
diff_bias_dims(diff_bias_dims),
@ -76,8 +82,14 @@ struct MklConvBwdFilterParams {
strides(strides),
dilations(dilations),
padding_left(padding_left),
#ifndef ENABLE_MKLDNN_V1
padding_right(padding_right),
padding(padding) {}
padding(padding) {
}
#else
padding_right(padding_right) {
}
#endif // !ENABLE_MKLDNN_V1
};
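The #ifdef blocks above (and the matching ones in the descriptor-creation code below) exist because DNNL 1.x removed padding_kind from the convolution descriptors. A hedged sketch of the signature difference, assuming src_md, filter_md, dst_md and the dims/padding vectors are already defined:

#ifndef ENABLE_MKLDNN_V1
  // v0.x: descriptors end with an explicit padding_kind argument.
  auto fwd_desc = mkldnn::convolution_forward::desc(
      mkldnn::prop_kind::forward, mkldnn::algorithm::convolution_direct,
      src_md, filter_md, dst_md, strides, dilations, padding_left,
      padding_right, mkldnn::padding_kind::zero);
#else
  // v1.x: padding_kind is gone from the API; zero padding is implied.
  auto fwd_desc = mkldnn::convolution_forward::desc(
      mkldnn::prop_kind::forward, mkldnn::algorithm::convolution_direct,
      src_md, filter_md, dst_md, strides, dilations, padding_left,
      padding_right);
#endif  // !ENABLE_MKLDNN_V1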
template <typename T>
@ -241,8 +253,12 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
*context_.diff_filter_md, *context_.diff_dst_md,
convBwdFilterDims.strides, convBwdFilterDims.dilations,
#ifndef ENABLE_MKLDNN_V1
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
convBwdFilterDims.padding));
#else
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right));
#endif // !ENABLE_MKLDNN_V1
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
// Create descriptor and primitive descriptor for convolution bwd filter.
@ -252,14 +268,22 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
*context_.diff_filter_md, *context_.diff_bias_md,
*context_.diff_dst_md, convBwdFilterDims.strides,
convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
#ifndef ENABLE_MKLDNN_V1
convBwdFilterDims.padding_right, convBwdFilterDims.padding));
#else
convBwdFilterDims.padding_right));
#endif // !ENABLE_MKLDNN_V1
} else {
context_.bwd_filter_desc.reset(new ConvBwdFilterDesc(
ALGORITHM::convolution_direct, *context_.src_md,
*context_.diff_filter_md, *context_.diff_dst_md,
convBwdFilterDims.strides, convBwdFilterDims.dilations,
#ifndef ENABLE_MKLDNN_V1
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
convBwdFilterDims.padding));
#else
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right));
#endif // !ENABLE_MKLDNN_V1
}
context_.bwd_filter_pd.reset(new ConvBwdFilterPd(
*context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd));
@ -495,11 +519,14 @@ class MklConvCustomBackpropFilterOp
// The default dilation factor for each dimension is 1 in TF and
// 0 in MKL-DNN.
for (int i = 0; i < dilations.size(); ++i) --dilations[i];
MklConvBwdFilterParams convBwdFilterDims(
fwd_src_dims, fwd_filter_dims, diff_bias_dims, diff_dst_dims, strides,
#ifndef ENABLE_MKLDNN_V1
dilations, padding_left, padding_right,
TFPaddingToMklDnnPadding(this->padding_));
#else
dilations, padding_left, padding_right);
#endif // !ENABLE_MKLDNN_V1
// MKL-DNN allocates large buffers when a conv gradient filter primitive is
// created. So we don't cache conv backward primitives when the env


@ -65,20 +65,32 @@ struct MklConvBwdInputParams {
memory::dims dilations;
memory::dims padding_left;
memory::dims padding_right;
#ifndef ENABLE_MKLDNN_V1
padding_kind padding;
#endif // !ENABLE_MKLDNN_V1
MklConvBwdInputParams(memory::dims diff_src_dims, memory::dims filter_dims,
memory::dims diff_dst_dims, memory::dims strides,
memory::dims dilations, memory::dims padding_left,
#ifndef ENABLE_MKLDNN_V1
memory::dims padding_right, padding_kind padding)
#else
memory::dims padding_right)
#endif // !ENABLE_MKLDNN_V1
: diff_src_dims(diff_src_dims),
filter_dims(filter_dims),
diff_dst_dims(diff_dst_dims),
strides(strides),
dilations(dilations),
padding_left(padding_left),
#ifndef ENABLE_MKLDNN_V1
padding_right(padding_right),
padding(padding) {}
padding(padding) {
}
#else
padding_right(padding_right) {
}
#endif // !ENABLE_MKLDNN_V1
};
template <typename T>
@ -211,14 +223,22 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
ALGORITHM::convolution_direct, *context_.diff_src_md,
*context_.filter_md, *context_.diff_dst_md, convBwdInputDims.strides,
convBwdInputDims.dilations, convBwdInputDims.padding_left,
#ifndef ENABLE_MKLDNN_V1
convBwdInputDims.padding_right, convBwdInputDims.padding));
#else
convBwdInputDims.padding_right));
#endif // !ENABLE_MKLDNN_V1
context_.fwd_desc.reset(new ConvFwdDesc(
prop_kind::forward, ALGORITHM::convolution_direct,
*context_.diff_src_md, *context_.filter_md, *context_.diff_dst_md,
convBwdInputDims.strides, convBwdInputDims.dilations,
#ifndef ENABLE_MKLDNN_V1
convBwdInputDims.padding_left, convBwdInputDims.padding_right,
convBwdInputDims.padding));
#else
convBwdInputDims.padding_left, convBwdInputDims.padding_right));
#endif // !ENABLE_MKLDNN_V1
// Create primitive descriptors for conv fwd and conv bwd input.
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
@ -440,11 +460,14 @@ class MklConvCustomBackpropInputOp
// The default dilation factor for each dimension is 1 in TF and
// 0 in MKL-DNN.
for (int i = 0; i < dilations.size(); ++i) --dilations[i];
MklConvBwdInputParams convBwdInputDims(
fwd_src_dims, fwd_filter_dims, diff_dst_dims, strides, dilations,
#ifndef ENABLE_MKLDNN_V1
padding_left, padding_right,
TFPaddingToMklDnnPadding(this->padding_));
#else
padding_left, padding_right);
#endif // !ENABLE_MKLDNN_V1
// We don't cache those primitives if the environment variable
// TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and if primitive descriptor


@ -239,13 +239,21 @@ class MklConvFwdPrimitive : public MklPrimitive {
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
*context_.filter_md, *context_.bias_md, *context_.dst_md,
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
#ifndef ENABLE_MKLDNN_V1
convFwdDims.padding_right, padding_kind::zero));
#else
convFwdDims.padding_right));
#endif // !ENABLE_MKLDNN_V1
} else {
context_.fwd_desc.reset(new convolution_forward::desc(
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
*context_.filter_md, *context_.dst_md, convFwdDims.strides,
convFwdDims.dilations, convFwdDims.padding_left,
#ifndef ENABLE_MKLDNN_V1
convFwdDims.padding_right, padding_kind::zero));
#else
convFwdDims.padding_right));
#endif // !ENABLE_MKLDNN_V1
}
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
@ -261,12 +269,7 @@ class MklConvFwdPrimitive : public MklPrimitive {
float op_scale = post_op_param.param[0];
float op_alpha = post_op_param.param[1];
float op_beta = post_op_param.param[2];
#ifdef ENABLE_MKLDNN_V1
post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_relu,
op_alpha,
#else
post_ops.append_eltwise(op_scale, post_op_param.alg, op_alpha,
#endif // ENABLE_MKLDNN_V1
op_beta);
} else if (post_op_param.name == "sum") {
DCHECK_EQ(post_op_param.param.size(), 1);
@ -1033,6 +1036,7 @@ class MklConvOp : public OpKernel {
&cached_filter_data_ptensor_, filter_tensor));
Tensor* second_tensor = nullptr;
#ifndef ENABLE_MKLDNN_V1
TensorShape filter_mkl_format;
filter_mkl_format.AddDim(
sizeof(GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)) /
@ -1042,6 +1046,21 @@ class MklConvOp : public OpKernel {
&cached_filter_md_ptensor_, &second_tensor));
second_tensor->scalar<int32>()() = static_cast<int32>(
GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc));
#else
// There is no tensor format in DNNL 1.x. So we cache the complete filter
// descriptor as a flat byte array.
TensorShape cached_filter_md_shape;
memory::desc weights_desc = conv_prim_desc.weights_desc();
// We don't use the .get_size() method of memory::desc since it returns the
// size required to store the primitive's input memory, which is much larger
// than the size of memory::desc itself.
cached_filter_md_shape.AddDim(sizeof(weights_desc) / sizeof(uint8));
OP_REQUIRES_OK(context, context->allocate_persistent(
DT_UINT8, cached_filter_md_shape,
&cached_filter_md_ptensor_, &second_tensor));
*reinterpret_cast<memory::desc*>(second_tensor->flat<uint8>().data()) =
weights_desc;
#endif // !ENABLE_MKLDNN_V1
}
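A minimal sketch of the caching idea in the v1.x branch above, using a plain byte buffer instead of a persistent tensor (illustrative only; weights_desc and filter_md are assumed to be existing mkldnn::memory::desc objects):

#include <cstdint>
#include <vector>
#include "mkldnn.hpp"

bool CacheAndCompare(const mkldnn::memory::desc& weights_desc,
                     const mkldnn::memory::desc& filter_md) {
  // Store: memory::desc is a plain wrapper around mkldnn_memory_desc_t, so
  // its bytes can be copied into a flat buffer, as the persistent tensor
  // above does.
  std::vector<uint8_t> cache(sizeof(mkldnn::memory::desc));
  *reinterpret_cast<mkldnn::memory::desc*>(cache.data()) = weights_desc;
  // Look up: reinterpret the bytes and compare; v1.x memory::desc provides
  // operator==, used the same way in the weight-cache check further down.
  const auto& cached =
      *reinterpret_cast<const mkldnn::memory::desc*>(cache.data());
  return cached == filter_md;
}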
void AllocatePersistentTensor(OpKernelContext* context,
@ -1230,12 +1249,11 @@ class MklConvOp : public OpKernel {
const Tensor& cached_filter_md =
*cached_filter_md_ptensor_.AccessTensor(context);
// Check if the memory descriptor of the cached weights is the same as
// filter_md. If so, we can use the cached weights; otherwise
// return nullptr.
#ifdef ENABLE_MKLDNN_V1
if (cached_filter_md.scalar<int64>().size() &&
AreMemoryDescriptorsEqual(filter_md, cached_filter_md)) {
if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
#else
if (cached_filter_md.scalar<int32>().size() &&
cached_filter_md.scalar<int32>()() == filter_md) {
@ -1568,7 +1586,7 @@ class MklQuantizedConv2DOp
if (!scaled_bias_buf_)
AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_,
conv_fwd_pd->bias_primitive_desc(),
GET_BIAS_DESC_FROM_OP_PD(conv_fwd_pd),
&scaled_bias_buf_);
if (!scaled_bias_) {
scaled_bias_ = new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_,


@ -89,12 +89,26 @@ class MklDequantizeOp : public OpKernel {
Tensor* output_tensor = nullptr;
MklDnnShape output_mkl_shape;
TensorShape output_tf_shape;
#ifndef ENABLE_MKLDNN_V1
memory::desc dst_md =
src_mkl_shape.IsMklTensor()
? memory::desc(src_dims, MklDnnType<float>(),
static_cast<memory::format>(src_md.data.format))
static_cast<MEMORY_FORMAT>(src_md.data.format))
: memory::desc(src_dims, MklDnnType<float>(),
MEMORY_FORMAT::nhwc);
#else
memory::desc dst_md = memory::desc();
if (src_mkl_shape.IsMklTensor()) {
dst_md = memory::desc(src_mkl_shape.GetMklLayout().data);
// There is no API in MKL-DNN v1.x to construct a memory descriptor with the
// same .data field but a different data type.
dst_md.data.data_type = memory::convert_to_c(MklDnnType<float>());
} else {
dst_md =
memory::desc(src_dims, MklDnnType<float>(), MEMORY_FORMAT::nhwc);
}
#endif // !ENABLE_MKLDNN_V1
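A standalone sketch of the v1.x workaround described in the comment above: copy the source descriptor and patch its C-level data_type field, since no constructor takes an existing .data with a different type (src_md is assumed; the f32 output type matches MklDequantizeOp):

mkldnn::memory::desc dst_md(src_md.data);  // same dims and layout as input
dst_md.data.data_type =
    mkldnn::memory::convert_to_c(mkldnn::memory::data_type::f32);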
// If input is MKL shape, output is also MKL shape.
// If input is TF shape, output is also TF shape.
if (src_mkl_shape.IsMklTensor()) {


@ -42,15 +42,29 @@ struct MklBatchNormFwdParams {
int depth;
float eps;
bool training;
#ifndef ENABLE_MKLDNN_V1
MEMORY_FORMAT src_format;
#else
memory::desc src_md;
#endif // !ENABLE_MKLDNN_V1
MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
#ifndef ENABLE_MKLDNN_V1
bool training, MEMORY_FORMAT src_format)
#else
bool training, memory::desc src_md)
#endif // !ENABLE_MKLDNN_V1
: src_dims(src_dims),
depth(depth),
eps(eps),
training(training),
src_format(src_format) {}
#ifndef ENABLE_MKLDNN_V1
src_format(src_format) {
}
#else
src_md(src_md) {
}
#endif // !ENABLE_MKLDNN_V1
};
template <typename T, typename U>
@ -177,16 +191,17 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
context_.pkind = fwdParams.training ? prop_kind::forward_training
: prop_kind::forward_scoring;
// Memory descriptor
auto src_md = memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
fwdParams.src_format);
#ifdef ENABLE_MKLDNN_V1
// Memory descriptor
auto src_md = fwdParams.src_md;
// Create forward BatchNorm descriptor and primitive descriptor.
auto fwd_desc = batch_normalization_forward::desc(
context_.pkind, src_md, fwdParams.eps,
static_cast<mkldnn::normalization_flags>(context_.flags));
#else
// Memory descriptor
auto src_md = memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
fwdParams.src_format);
auto fwd_desc = batch_normalization_forward::desc(
context_.pkind, src_md, fwdParams.eps, context_.flags);
#endif // ENABLE_MKLDNN_V1
@ -368,17 +383,36 @@ struct MklBatchNormBwdParams {
int depth;
float eps;
bool training;
#ifndef ENABLE_MKLDNN_V1
MEMORY_FORMAT src_format;
#else
memory::desc src_md;
memory::desc diff_dst_md;
#endif // !ENABLE_MKLDNN_V1
MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims,
int depth, float eps, bool training,
#ifndef ENABLE_MKLDNN_V1
MEMORY_FORMAT src_format)
: src_dims(src_dims),
diff_dst_dims(diff_dst_dims),
depth(depth),
eps(eps),
training(training),
src_format(src_format) {}
src_format(src_format) {
}
#else
memory::desc src_md, memory::desc diff_dst_md)
: src_dims(src_dims),
diff_dst_dims(diff_dst_dims),
depth(depth),
eps(eps),
training(training),
src_md(src_md),
diff_dst_md(diff_dst_md) {
}
#endif // !ENABLE_MKLDNN_V1
};
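Under v1.x the backward params now carry the two memory descriptors directly. A hedged sketch of how they would feed the backward batch-norm descriptor (the flags value is illustrative, not taken from the original code):

auto bn_bwd_desc = mkldnn::batch_normalization_backward::desc(
    mkldnn::prop_kind::backward, bwdParams.diff_dst_md, bwdParams.src_md,
    bwdParams.eps, mkldnn::normalization_flags::use_scale_shift);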
template <typename T, typename U>
@ -432,12 +466,9 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
#ifdef ENABLE_MKLDNN_V1
// Execute backward batch-normalization primitives.
// TODO(intel-tf): Use execute_primitive instead of the inlined code.
DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
for (size_t i = 0; i < context_.bwd_primitives.size(); ++i) {
context_.bwd_primitives.at(i).execute(*context_.bwd_stream,
context_.net_args.at(i));
}
execute_primitives(context_.bwd_primitives, context_.bwd_stream,
context_.net_args);
#else
context_.bwd_stream->submit(context_.bwd_primitives);
#endif // ENABLE_MKLDNN_V1
@ -516,10 +547,15 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
: (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));
// Memory descriptors.
#ifndef ENABLE_MKLDNN_V1
auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType<T>(),
bwdParams.src_format);
auto diff_dst_md = memory::desc({bwdParams.diff_dst_dims}, MklDnnType<T>(),
bwdParams.src_format);
#else
auto src_md = bwdParams.src_md;
auto diff_dst_md = bwdParams.diff_dst_md;
#endif // !ENABLE_MKLDNN_V1
auto variance_desc =
memory::desc({1, bwdParams.depth}, MklDnnType<U>(), MEMORY_FORMAT::nc);
auto mean_desc =
@ -794,7 +830,7 @@ class MklFusedBatchNormOp : public OpKernel {
#ifdef ENABLE_MKLDNN_V1
MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_,
dnn_fmt);
src_md);
#else
MklBatchNormFwdParams fwdParams(
src_dims, depth_, epsilon_, is_training_,
@ -1149,7 +1185,7 @@ class MklFusedBatchNormGradOp : public OpKernel {
#ifdef ENABLE_MKLDNN_V1
MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_,
is_training_, dnn_fmt);
is_training_, src_md, diff_dst_md);
#else
MklBatchNormBwdParams bwdParams(
src_dims, diff_dst_dims, depth_, epsilon_, is_training_,
@ -1158,19 +1194,10 @@ class MklFusedBatchNormGradOp : public OpKernel {
MklFusedBatchNormBwdPrimitive<T, U>* bn_bwd =
MklFusedBatchNormBwdPrimitiveFactory<T, U>::Get(bwdParams);
// Check if src/diff_dst needs to be reordered.
const T* src_data = nullptr;
const T* src_data = src_tensor.flat<T>().data();
const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
// Check if diff_dst input needs to be reordered
std::shared_ptr<BatchNormBwdPd> bn_bwd_pd = bn_bwd->GetBatchNormBwdPd();
if (IS_SRC_REORDER_NEEDED(src_md, bn_bwd_pd, bn_bwd)) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
GET_SRC_DESC_FROM_OP_PD(bn_bwd_pd), cpu_engine_));
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
} else {
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
}
const T* diff_dst_data = nullptr;
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bn_bwd_pd, bn_bwd)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(


@ -105,6 +105,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
memory::dims weight_dims = memory::dims({channel, k});
memory::dims bias_dims = memory::dims({channel});
memory::dims dst_dims = memory::dims({batch, channel});
MEMORY_FORMAT src_format = MEMORY_FORMAT::nc;
MEMORY_FORMAT weight_format =
transpose_b_ ? MEMORY_FORMAT::oi : MEMORY_FORMAT::io;
@ -112,7 +113,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
// 1. const, let MKL-DNN determine format because it will be cached;
// 2. var, keep the original format to avoid reordering.
MklDnnMatMulFwdParams matmul_params(
src_dims, weight_dims, bias_dims, dst_dims,
src_dims, weight_dims, bias_dims, dst_dims, src_format,
(this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format);
// Extend the basic parameters for data types and fusions.
@ -152,44 +153,44 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
MklDnnData<T> src_mkl(&(this->cpu_engine_));
MklDnnData<T> weight_mkl(&(this->cpu_engine_));
if (src_mkl_shape.IsMklTensor()) {
memory::desc input_md = src_mkl_shape.GetMklLayout();
#ifdef ENABLE_MKLDNN_V1
if (input_md != matmul_pd->src_desc()) {
#else
if (input_md.data.format != MKL_TENSOR_FORMAT_NC) {
#endif // ENABLE_MKLDNN_V1
src_mkl.SetUsrMem(input_md, src_data);
src_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
matmul_pd.get()->PRIMITIVE_DESC_SRC, this->cpu_engine_));
src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
}
auto src_md = src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
: memory::desc(src_dims, MklDnnType<T>(), src_format);
if (IS_SRC_REORDER_NEEDED(src_md, matmul_pd, matmul_prim)) {
src_mkl.SetUsrMem(src_md, src_data);
src_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
matmul_pd.get()->PRIMITIVE_DESC_SRC, this->cpu_engine_));
src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
}
// Get cached data when weight is const.
memory::format expected_format = matmul_prim->GetWeightMemoryFormat();
DCHECK(expected_format != weight_format && this->is_weight_const_);
if (this->is_weight_const_) {
const memory::desc weight_md =
memory::desc(weight_dims, MklDnnType<T>(), weight_format);
if (IS_WEIGHTS_REORDER_NEEDED(weight_md, matmul_pd, matmul_prim)) {
T* cached_weight_data = nullptr;
if (this->IsWeightCacheEmpty(ctx)) {
auto weight_md =
memory::desc(weight_dims, MklDnnType<T>(), weight_format);
this->CacheWeight(ctx, matmul_pd, cached_weight_data, weight_tensor,
weight_mkl, weight_md);
if (this->is_weight_const_) {
if (this->IsWeightCacheEmpty(ctx)) {
this->CacheWeight(ctx, matmul_pd, cached_weight_data, weight_tensor,
weight_mkl, weight_md);
}
#ifdef ENABLE_MKLDNN_V1
cached_weight_data = this->GetCachedWeight(
ctx, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_pd));
#else
cached_weight_data = this->GetCachedWeight(
ctx, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_pd).desc());
#endif
}
cached_weight_data = this->GetCachedWeight(ctx, expected_format);
// Caching the weight may fail when it gets a different format in a
// different iteration. Fall back to reorder if that happens.
// TODO: Fix this slow path.
// Also do a general reorder if the weight isn't const.
if (cached_weight_data != nullptr) {
weight_data = cached_weight_data;
} else {
memory::desc input_md =
memory::desc(weight_dims, MklDnnType<T>(), weight_format);
//>>>>>>> master
weight_mkl.SetUsrMem(input_md, weight_data);
weight_mkl.SetUsrMem(weight_md, weight_data);
weight_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
matmul_pd.get()->PRIMITIVE_DESC_WEIGHTS, this->cpu_engine_));
weight_data =
@ -210,23 +211,21 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
void ExtendMklDnnMatMulFwdParams(OpKernelContext* ctx,
MklDnnMatMulFwdParams& params) {
#ifndef ENABLE_MKLDNN_V1
if (fused_ops_.size() == 2) {
string post_op = fused_ops_[1];
if (post_op == "Relu") {
params.post_op_params.push_back({"relu", { 1.0, 0.0, 0.0 }});
params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}});
} else if (post_op == "Relu6") {
params.post_op_params.push_back({"relu6", { 1.0, 6.0, 0.0 }});
params.post_op_params.push_back({"relu6", {1.0, 6.0, 0.0}});
} else if (post_op == "Elu") {
params.post_op_params.push_back({"elu", { 1.0, 1.0, 0.0 }});
params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}});
} else {
OP_REQUIRES_OK(
ctx, errors::InvalidArgument(
"Unsupported post-argument in MklFusedMatMul: ", post_op));
}
}
#endif // !ENABLE_MKLDNN_V1
}
private:


@ -41,7 +41,8 @@ struct MklDnnMatMulFwdParams {
memory::dims weight_dims;
memory::dims bias_dims;
memory::dims dst_dims;
MEMORY_FORMAT weight_fmt;
MEMORY_FORMAT src_format;
MEMORY_FORMAT weight_format;
string dtypes = string("");
struct PostOpParam {
string name;
@ -51,12 +52,14 @@ struct MklDnnMatMulFwdParams {
MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
memory::dims bias_dims, memory::dims dst_dims,
MEMORY_FORMAT weight_fmt = MEMORY_FORMAT::any)
MEMORY_FORMAT src_format = MEMORY_FORMAT::any,
MEMORY_FORMAT weight_format = MEMORY_FORMAT::any)
: src_dims(src_dims),
weight_dims(weight_dims),
bias_dims(bias_dims),
dst_dims(dst_dims),
weight_fmt(weight_fmt) {}
src_format(src_format),
weight_format(weight_format) {}
};
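A hypothetical construction with the two new defaulted format arguments; MEMORY_FORMAT is TF's macro that maps to memory::format_tag under v1.x and memory::format under v0.x:

MklDnnMatMulFwdParams matmul_params(src_dims, weight_dims, bias_dims,
                                    dst_dims,
                                    MEMORY_FORMAT::nc,    // src_format
                                    MEMORY_FORMAT::any);  // weight_format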
// With quantization, input, weight, bias, and output can have different types.
@ -182,15 +185,11 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
// format.
context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims},
MklDnnType<Tinput>(),
MEMORY_FORMAT::any));
matmul_fwd_params.src_format));
context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims},
MklDnnType<Tweight>(),
#ifdef ENABLE_MKLDNN_V1
MEMORY_FORMAT::any));
#else
matmul_fwd_params.weight_fmt));
#endif // ENABLE_MKLDNN_V1
matmul_fwd_params.weight_format));
context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
MklDnnType<Toutput>(),
@ -438,49 +437,61 @@ class MklDnnMatMulOpBase : public OpKernel {
// reorder and cache the weight
weight.SetUsrMem(weight_md, &weight_tensor);
weight.CheckReorderToOpMem(matmul_fwd_pd.get()->weights_primitive_desc());
weight.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_));
weight_data = static_cast<Tweight*>(weight.GetOpMem().get_data_handle());
Tensor* weight_tensor_ptr = nullptr;
size_t weight_size = matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS.get_size();
TensorShape weight_tf_shape;
weight_tf_shape.AddDim(
(matmul_fwd_pd.get()->weights_primitive_desc().get_size() /
sizeof(Tweight)));
weight_tf_shape.AddDim(weight_size / sizeof(Tweight));
OP_REQUIRES_OK(context, context->allocate_persistent(
DataTypeToEnum<Tweight>::value, weight_tf_shape,
&weight_oi_, &weight_tensor_ptr));
void* weight_oi_t_data = weight.GetTensorBuffer(weight_tensor_ptr);
size_t weight_size = weight.GetOpMem().get_primitive_desc().get_size();
memcpy(weight_oi_t_data, weight_data, weight_size);
// cache the memory descriptor
#ifdef ENABLE_MKLDNN_V1
auto expected_md = GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd);
#else
auto expected_md = GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd).desc();
#endif
Tensor* weight_md_tensor_ptr = nullptr;
TensorShape weight_mkl_format;
weight_mkl_format.AddDim(1);
weight_mkl_format.AddDim(sizeof(expected_md) / sizeof(Tweight));
OP_REQUIRES_OK(context, context->allocate_persistent(
DT_INT32, weight_mkl_format, &weight_oi_md_,
&weight_md_tensor_ptr));
weight_md_tensor_ptr->scalar<int32>()() =
matmul_fwd_pd.get()->weights_primitive_desc().desc().data.format;
OP_REQUIRES_OK(
context, context->allocate_persistent(DataTypeToEnum<Tweight>::value,
weight_mkl_format, &weight_oi_md_,
&weight_md_tensor_ptr));
*reinterpret_cast<memory::desc*>(
weight_md_tensor_ptr->flat<Tweight>().data()) = expected_md;
}
Tweight* GetCachedWeight(OpKernelContext* context,
const memory::format& weight_mf)
const memory::desc& expected_md)
LOCKS_EXCLUDED(mu_) {
tf_shared_lock lock(mu_);
const Tensor& weight_t = *weight_oi_.AccessTensor(context);
const Tensor& weight_md_t = *weight_oi_md_.AccessTensor(context);
// Check if the memory descriptor of the cached weight is same as
// weight_mf. if so use the cached memory, else return NULL
if (weight_md_t.scalar<int32>().size() &&
weight_md_t.scalar<int32>()() == weight_mf) {
return static_cast<Tweight*>(
const_cast<Tweight*>(weight_t.flat<Tweight>().data()));
// Check if the memory descriptor of the cached weight is the same as
// expected_md. If so, use the cached memory; else return NULL.
if (weight_md_t.flat<Tweight>().size()) {
const memory::desc& stored_md =
*(static_cast<memory::desc*>(weight_md_t.data()));
#ifdef ENABLE_MKLDNN_V1
if (stored_md == expected_md) {
#else
if (stored_md.data.format == expected_md.data.format) {
#endif
return static_cast<Tweight*>(
const_cast<Tweight*>(weight_t.flat<Tweight>().data()));
}
}
return nullptr;
}
@ -527,7 +538,8 @@ void dnnl_gemm_exec(const dnnl::desc& a_md, const dnnl::desc& b_md,
dnnl::stream s(cpu_engine);
matmul_prim.execute(s, {{DNNL_ARG_SRC, a_memory},
{DNNL_ARG_WEIGHTS, b_memory},
{DNNL_ARG_DST, c_memory}});
{ DNNL_ARG_DST,
c_memory }});
s.wait();
}


@ -143,12 +143,12 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
MklPoolingParams fwdParams(
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
padding_right, ALGORITHM::pooling_max, pooling_prop_kind,
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), input_md);
#else
MklPoolingParams fwdParams(
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
padding_right, ALGORITHM::pooling_max, pooling_prop_kind,
static_cast<MEMORY_FORMAT>(input_md.data.format));
static_cast<MEMORY_FORMAT>(input_md.data.format), input_md);
#endif
pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);
// Allocate output tensor.
@ -303,13 +303,13 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
strides, padding_left, padding_right, ALGORITHM::pooling_max,
prop_kind::forward_training,
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md);
#else
MklPoolingParams bwdParams(
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
strides, padding_left, padding_right, ALGORITHM::pooling_max,
prop_kind::forward_training,
static_cast<MEMORY_FORMAT>(src_md.data.format));
static_cast<MEMORY_FORMAT>(src_md.data.format), src_md);
#endif
MklPoolingBwdPrimitive<T>* pooling_bwd =
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);


@ -43,6 +43,7 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
// so src format is currently hard-coded.
// A utility function is used to do this,
// which may be broken with future CPU architectures
#ifndef ENABLE_MKLDNN_V1
bool is_2d = (fwdParams.src_dims.size() == 4);
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value)
context_.src_fmt = is_2d ? MEMORY_FORMAT::nhwc : MEMORY_FORMAT::ndhwc;
@ -51,6 +52,9 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
context_.src_md.reset(new memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
context_.src_fmt));
#else
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
#endif // !ENABLE_MKLDNN_V1
context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType<T>(),
MEMORY_FORMAT::any));


@ -47,13 +47,14 @@ struct MklPoolingParams {
memory::dims padding_right;
mkldnn::algorithm alg_kind;
mkldnn::prop_kind prop_kind;
memory::format src_format;
MEMORY_FORMAT src_format;
memory::desc src_md;
MklPoolingParams(memory::dims src_dims, memory::dims dst_dims,
memory::dims filter_dims, memory::dims strides,
memory::dims padding_left, memory::dims padding_right,
mkldnn::algorithm alg_kind, mkldnn::prop_kind prop_kind,
memory::format src_format)
MEMORY_FORMAT src_format, memory::desc src_md)
: src_dims(src_dims),
dst_dims(dst_dims),
filter_dims(filter_dims),
@ -62,7 +63,8 @@ struct MklPoolingParams {
padding_right(padding_right),
alg_kind(alg_kind),
prop_kind(prop_kind),
src_format(src_format) {}
src_format(src_format),
src_md(src_md) {}
};
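A hedged sketch of how the carried src_md is consumed on the v1.x path when the forward descriptor is built (dst_md is assumed to be defined; the other fields are as in MklPoolingParams):

auto fwd_desc = mkldnn::pooling_forward::desc(
    fwdParams.prop_kind, fwdParams.alg_kind,
    mkldnn::memory::desc(fwdParams.src_md.data),  // exact input layout
    dst_md, fwdParams.strides, fwdParams.filter_dims,
    fwdParams.padding_left, fwdParams.padding_right);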
template <typename T>


@ -269,11 +269,11 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
}
#ifdef ENABLE_MKLDNN_V1
weight_data = this->GetCachedWeight(
context, static_cast<int32>(weight_mkl_shape.GetTfDataFormat()));
context, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd));
#else
weight_data = this->GetCachedWeight(
context, matmul_fwd->GetWeightMemoryFormat());
#endif // ENABLE_MKLDNN_V1
context, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd).desc());
#endif
is_weight_cached = (weight_data != nullptr);
}


@ -164,7 +164,11 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
context_.src_mpd.reset(
#ifdef ENABLE_MKLDNN_V1
new MEMORY_PRIMITIVE_DESC(*context_.src_md));
#else
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
#endif
// Create an eltwise forward descriptor and primitive descriptor
context_.fwd_desc.reset(new eltwise_forward::desc(
@ -397,7 +401,6 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
// Create memory descriptors for eltwise data w/ no specified format
context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
context_.src_mpd.reset(
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
context_.diff_dst_mpd.reset(


@ -65,7 +65,7 @@ class MklSoftmaxPrimitive : public MklPrimitive {
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
context_.net_args);
context_.fwd_net_args);
#else
context_.fwd_stream->submit(context_.fwd_primitives);
#endif


@ -1302,6 +1302,7 @@ REGISTER_OP("_MklFusedBatchNormV3")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
.Attr(GetConvnetDataFormatAttrString())
.Attr("exponential_avg_factor: float = 1.0")
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormShape)
.Doc(


@ -2524,6 +2524,7 @@ REGISTER_OP("_MklFusedBatchNorm")
.Attr("T: numbertype")
.Attr("epsilon: float = 0.0001")
.Attr("data_format: string = 'NHWC'")
.Attr("exponential_avg_factor: float = 1.0")
.Attr("is_training: bool = true")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle x;
@ -2673,6 +2674,7 @@ REGISTER_OP("_MklFusedBatchNormV2")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
.Attr(GetConvnetDataFormatAttrString())
.Attr("exponential_avg_factor: float = 1.0")
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormShape);


@ -48,6 +48,7 @@ namespace tensorflow {
#define GET_WORKSPACE_DESC_FROM_OP_PD(op_pd) op_pd->workspace_desc()
#define GET_TENSOR_FORMAT(fmt) MklTensorFormatToMklDnnDataFormat(fmt)
#define GET_TF_DATA_FORMAT(shape, mem_desc) shape.GetTfDataFormat()
#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemDesc()
#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd->weights_desc()
#define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op) \
GET_WEIGHTS_DESC_FROM_OP_PD(op_pd)
@ -114,7 +115,6 @@ namespace tensorflow {
#define TENSOR_FORMAT MKL_TENSOR_FORMAT
#define TENSOR_FORMAT_NHWC MKL_TENSOR_FORMAT_NHWC
#define TENSOR_MAX_DIMS MKLDNN_MAX_NDIMS
#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemDesc()
#else
@ -148,6 +148,7 @@ namespace tensorflow {
op_pd.get()->workspace_primitive_desc()
#define GET_TENSOR_FORMAT(fmt) fmt
#define GET_TF_DATA_FORMAT(shape, mem_desc) mem_desc.data.format
#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemPrimDesc()
#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd.get()->weights_primitive_desc()
#define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op) op->GetFilterMemoryFormat()
#define IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, op_pd, op) \
@ -215,7 +216,6 @@ namespace tensorflow {
#define SUMMAND_MD summand_pd
#define TENSOR_FORMAT TensorFormat
#define TENSOR_FORMAT_NHWC FORMAT_NHWC
#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemPrimDesc()
#endif // ENABLE_MKLDNN_V1
} // namespace tensorflow


@ -18,12 +18,13 @@ limitations under the License.
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/mkl_types.h"
namespace tensorflow {
namespace {
TEST(MklUtilTest, MklDnnTfShape) {
auto cpu_engine = engine(engine::cpu, 0);
auto cpu_engine = engine(ENGINE_CPU, 0);
MklDnnData<float> a(&cpu_engine);
const int N = 1, C = 2, H = 3, W = 4;
@ -31,7 +32,7 @@ TEST(MklUtilTest, MklDnnTfShape) {
MklDnnShape a_mkldnn_shape;
a_mkldnn_shape.SetMklTensor(true);
// Create TF layout in NCHW.
a_mkldnn_shape.SetTfLayout(a_dims.size(), a_dims, memory::format::nchw);
a_mkldnn_shape.SetTfLayout(a_dims.size(), a_dims, MKL_TENSOR_FORMAT_NCHW);
TensorShape a_tf_shape_nchw({N, C, H, W});
TensorShape a_tf_shape_nhwc({N, H, W, C});
TensorShape a_mkldnn_tf_shape = a_mkldnn_shape.GetTfShape();
@ -43,7 +44,7 @@ TEST(MklUtilTest, MklDnnTfShape) {
MklDnnShape b_mkldnn_shape;
b_mkldnn_shape.SetMklTensor(true);
// Create TF layout in NHWC.
b_mkldnn_shape.SetTfLayout(b_dims.size(), b_dims, memory::format::nhwc);
b_mkldnn_shape.SetTfLayout(b_dims.size(), b_dims, MKL_TENSOR_FORMAT_NHWC);
TensorShape b_tf_shape_nhwc({N, H, W, C});
TensorShape b_tf_shape_nchw({N, C, H, W});
TensorShape b_mkldnn_tf_shape = b_mkldnn_shape.GetTfShape();
@ -55,7 +56,7 @@ TEST(MklUtilTest, MklDnnTfShape) {
TEST(MklUtilTest, MklDnnBlockedFormatTest) {
// Let's create a 2D tensor of shape {3, 4} with 3 being the innermost
// dimension first (case 1) and then the outermost dimension (case 2).
auto cpu_engine = engine(engine::cpu, 0);
auto cpu_engine = engine(ENGINE_CPU, 0);
// Setting for case 1
MklDnnData<float> a(&cpu_engine);
@ -67,7 +68,9 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) {
EXPECT_EQ(a_md1.data.ndims, 2);
EXPECT_EQ(a_md1.data.dims[0], 3);
EXPECT_EQ(a_md1.data.dims[1], 4);
#ifndef ENABLE_MKLDNN_V1
EXPECT_EQ(a_md1.data.format, mkldnn_blocked);
#endif // !ENABLE_MKLDNN_V1
// Setting for case 2
MklDnnData<float> b(&cpu_engine);
@ -79,7 +82,9 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) {
EXPECT_EQ(b_md2.data.ndims, 2);
EXPECT_EQ(b_md2.data.dims[0], 3);
EXPECT_EQ(b_md2.data.dims[1], 4);
#ifndef ENABLE_MKLDNN_V1
EXPECT_EQ(b_md2.data.format, mkldnn_blocked);
#endif // !ENABLE_MKLDNN_V1
}
TEST(MklUtilTest, LRUCacheTest) {


@ -174,12 +174,12 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "mkl_dnn_v1",
build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
sha256 = "fcc2d951f7170eade0cfdd0d8d1d58e3e7785bd326bca6555f3722f8cba71811",
strip_prefix = "mkl-dnn-1.0-pc2",
build_file = clean_dep("//third_party/mkl_dnn:mkldnn_v1.BUILD"),
sha256 = "30979a09753e8e35d942446c3778c9f0eba543acf2fb0282af8b9c89355d0ddf",
strip_prefix = "mkl-dnn-1.2",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/archive/v1.0-pc2.tar.gz",
"https://github.com/intel/mkl-dnn/archive/v1.0-pc2.tar.gz",
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/archive/v1.2.tar.gz",
"https://github.com/intel/mkl-dnn/archive/v1.2.tar.gz",
],
)

third_party/mkl_dnn/mkldnn_v1.BUILD (new vendored file, 136 lines)

@ -0,0 +1,136 @@
exports_files(["LICENSE"])
load(
"@org_tensorflow//third_party/mkl_dnn:build_defs.bzl",
"if_mkl_open_source_only",
"if_mkl_v1_open_source_only",
)
load(
"@org_tensorflow//third_party:common.bzl",
"template_rule",
)
config_setting(
name = "clang_linux_x86_64",
values = {
"cpu": "k8",
"define": "using_clang=true",
},
)
template_rule(
name = "dnnl_config_h",
src = "include/dnnl_config.h.in",
out = "include/dnnl_config.h",
substitutions = {
"#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP",
"#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_OMP",
"#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE",
},
)
# Create the file mkldnn_version.h with MKL-DNN version numbers.
# Currently, the version numbers are hard-coded here. If MKL-DNN is upgraded, then
# the version numbers have to be updated manually. The version numbers can be
# obtained from the PROJECT_VERSION settings in CMakeLists.txt. The variable is
# set to "version_major.version_minor.version_patch". The git hash version can
# be set to NA.
# TODO(agramesh1) Automatically get the version numbers from CMakeLists.txt.
template_rule(
name = "dnnl_version_h",
src = "include/dnnl_version.h.in",
out = "include/dnnl_version.h",
substitutions = {
"@DNNL_VERSION_MAJOR@": "1",
"@DNNL_VERSION_MINOR@": "2",
"@DNNL_VERSION_PATCH@": "0",
"@DNNL_VERSION_HASH@": "N/A",
},
)
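For reference, the substitutions above would produce a generated include/dnnl_version.h along these lines (assuming the upstream template defines exactly these macros; shown for illustration only):

#define DNNL_VERSION_MAJOR 1
#define DNNL_VERSION_MINOR 2
#define DNNL_VERSION_PATCH 0
#define DNNL_VERSION_HASH  "N/A"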
cc_library(
name = "mkl_dnn",
srcs = glob([
"src/common/*.cpp",
"src/common/*.hpp",
"src/cpu/*.cpp",
"src/cpu/*.hpp",
"src/cpu/**/*.cpp",
"src/cpu/**/*.hpp",
"src/cpu/xbyak/*.h",
]) + if_mkl_v1_open_source_only([
":dnnl_config_h",
]) + [":dnnl_version_h"],
hdrs = glob(["include/*"]),
copts = [
"-fexceptions",
"-DUSE_MKL",
"-DUSE_CBLAS",
] + if_mkl_open_source_only([
"-UUSE_MKL",
"-UUSE_CBLAS",
]) + if_mkl_v1_open_source_only([
"-UUSE_MKL",
"-UUSE_CBLAS",
]) + select({
"@org_tensorflow//tensorflow:linux_x86_64": [
"-fopenmp", # only works with gcc
],
# TODO(ibiryukov): enable openmp with clang by including libomp as a
# dependency.
":clang_linux_x86_64": [],
"//conditions:default": [],
}),
includes = [
"include",
"src",
"src/common",
"src/cpu",
"src/cpu/gemm",
"src/cpu/xbyak",
],
visibility = ["//visibility:public"],
deps = select({
"@org_tensorflow//tensorflow:linux_x86_64": [
"@mkl_linux//:mkl_headers",
"@mkl_linux//:mkl_libs_linux",
],
"@org_tensorflow//tensorflow:macos": [
"@mkl_darwin//:mkl_headers",
"@mkl_darwin//:mkl_libs_darwin",
],
"@org_tensorflow//tensorflow:windows": [
"@mkl_windows//:mkl_headers",
"@mkl_windows//:mkl_libs_windows",
],
"//conditions:default": [],
}),
)
cc_library(
name = "mkldnn_single_threaded",
srcs = glob([
"src/common/*.cpp",
"src/common/*.hpp",
"src/cpu/*.cpp",
"src/cpu/*.hpp",
"src/cpu/**/*.cpp",
"src/cpu/**/*.hpp",
"src/cpu/xbyak/*.h",
]) + [":dnnl_config_h"],
hdrs = glob(["include/*"]),
copts = [
"-fexceptions",
"-DMKLDNN_THR=MKLDNN_THR_SEQ", # Disables threading.
],
includes = [
"include",
"src",
"src/common",
"src/cpu",
"src/cpu/gemm",
"src/cpu/xbyak",
],
visibility = ["//visibility:public"],
)