Merge pull request #37081 from Intel-tensorflow:nhasabni/fixes_for_dnnl1.0
PiperOrigin-RevId: 297740489 Change-Id: I6201722923452c038bf0fd44065776dda12f87e5
This commit is contained in:
commit
7dbff51a35
tensorflow
core
kernels
mkl_avgpooling_op.cc mkl_concat_op.cc mkl_conv_grad_filter_ops.cc mkl_conv_grad_input_ops.cc mkl_conv_ops.cc mkl_dequantize_op.cc mkl_fused_batch_norm_op.cc mkl_matmul_op_fused.cc mkl_matmul_ops_common.h mkl_maxpooling_op.cc mkl_pooling_ops_common.cc mkl_pooling_ops_common.h mkl_qmatmul_op.cc mkl_relu_op.cc mkl_softmax_op.cc
ops
util
third_party/mkl_dnn
@ -116,12 +116,13 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
|
||||
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
|
||||
padding_right, ALGORITHM::pooling_avg_exclude_padding,
|
||||
pooling_prop_kind,
|
||||
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
|
||||
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), input_md);
|
||||
#else
|
||||
MklPoolingParams fwdParams(
|
||||
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
|
||||
padding_right, ALGORITHM::pooling_avg_exclude_padding,
|
||||
pooling_prop_kind, static_cast<MEMORY_FORMAT>(input_md.data.format));
|
||||
pooling_prop_kind, static_cast<MEMORY_FORMAT>(input_md.data.format),
|
||||
input_md);
|
||||
#endif
|
||||
pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);
|
||||
|
||||
@ -240,13 +241,13 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
||||
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
|
||||
strides, padding_left, padding_right,
|
||||
ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training,
|
||||
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
|
||||
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md);
|
||||
#else
|
||||
MklPoolingParams bwdParams(
|
||||
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
|
||||
strides, padding_left, padding_right,
|
||||
ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training,
|
||||
static_cast<MEMORY_FORMAT>(src_md.data.format));
|
||||
static_cast<MEMORY_FORMAT>(src_md.data.format), src_md);
|
||||
#endif
|
||||
MklPoolingBwdPrimitive<T>* pooling_bwd =
|
||||
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
|
||||
|
@ -266,7 +266,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
explicit MklConcatFwdPrimitive(const MklConcatFwdParams& concat_fwd_dims,
|
||||
const std::vector<memory::desc>& srcs_md)
|
||||
: cpu_engine_(ENGINE_CPU, 0) {
|
||||
context_.fwd_stream.reset(new CPU_STREAM(stream::kind::eager));
|
||||
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
|
||||
// Create concat primitive
|
||||
Setup(concat_fwd_dims, srcs_md);
|
||||
}
|
||||
@ -292,8 +292,8 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
}
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.fwd_primitives, *context_.fwd_stream,
|
||||
context_.fwd_primitives_args.at(i));
|
||||
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
|
||||
context_.fwd_primitives_args);
|
||||
#else
|
||||
context_.fwd_stream->submit(context_.fwd_primitives);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
@ -328,7 +328,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
std::shared_ptr<mkldnn::memory> dst_mem;
|
||||
|
||||
// Memory descriptor
|
||||
std::vector<std::shared_ptr<mkldnn::memory::desc>> src_md;
|
||||
std::vector<mkldnn::memory::desc> src_md;
|
||||
std::shared_ptr<mkldnn::memory::desc> dst_md;
|
||||
|
||||
// Concat primitive descriptor
|
||||
@ -339,7 +339,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
std::vector<mkldnn::primitive> fwd_primitives;
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
std::vector<std::unordered_map<int, memory>> fwd_primitive_args;
|
||||
std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
ConcatFwdContext()
|
||||
@ -355,15 +355,14 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
const std::vector<memory::desc>& srcs_md) {
|
||||
// Create memory descriptors for concat with specified srcs format
|
||||
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
||||
std::shared_ptr<mkldnn::memory::desc> source_md(
|
||||
new memory::desc(srcs_md[i].data));
|
||||
mkldnn::memory::desc source_md(memory::desc(srcs_md[i].data));
|
||||
context_.src_md.push_back(source_md);
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
std::shared_ptr<mkldnn::memory> src_mem(
|
||||
new mkldnn::memory(*source_md, cpu_engine_, DummyData));
|
||||
new mkldnn::memory(source_md, cpu_engine_, DummyData));
|
||||
#else
|
||||
std::shared_ptr<mkldnn::memory::primitive_desc> src_mpd(
|
||||
new memory::primitive_desc(*source_md, cpu_engine_));
|
||||
new memory::primitive_desc(source_md, cpu_engine_));
|
||||
context_.src_pd_shdptr.push_back(src_mpd);
|
||||
|
||||
std::shared_ptr<mkldnn::memory> src_mem(
|
||||
@ -665,8 +664,9 @@ class MklConcatOp : public OpKernel {
|
||||
if (input_tensors[k].NumElements() == 0) continue;
|
||||
auto src_md = mkl_input_shapes[k].GetMklLayout();
|
||||
srcs[k].SetUsrMem(src_md, &input_tensors[k]);
|
||||
|
||||
if (src_md.data.format != mkl_common_format) {
|
||||
auto src_tf_fmt = MklTensorFormatToMklDnnDataFormat(
|
||||
mkl_input_shapes[k].GetTfDataFormat());
|
||||
if (src_tf_fmt != mkl_common_format) {
|
||||
memory::dims src_dims(src_md.data.dims,
|
||||
&src_md.data.dims[src_md.data.ndims]);
|
||||
src_md =
|
||||
@ -935,7 +935,8 @@ class MklConcatOp : public OpKernel {
|
||||
for (int k = 0; k < input_shapes.size(); k++) {
|
||||
auto src_dims = TFShapeToMklDnnDims(input_shapes[k].GetTfShape());
|
||||
*concat_dim_size += src_dims[concat_dim];
|
||||
int fmt = static_cast<int>(input_shapes[k].GetMklLayout().data.format);
|
||||
int fmt = static_cast<int>(
|
||||
MklTensorFormatToMklDnnDataFormat(input_shapes[k].GetTfDataFormat()));
|
||||
occurrence_map[fmt] += 1;
|
||||
}
|
||||
|
||||
@ -943,7 +944,7 @@ class MklConcatOp : public OpKernel {
|
||||
// this means that all inputs have a same format
|
||||
// return it with is_reorder_needed set false.
|
||||
return static_cast<MEMORY_FORMAT>(
|
||||
input_shapes[0].GetMklLayout().data.format);
|
||||
MklTensorFormatToMklDnnDataFormat(input_shapes[0].GetTfDataFormat()));
|
||||
}
|
||||
|
||||
// Input tensors have different formats. Thus, reorder is needed.
|
||||
@ -970,7 +971,7 @@ class MklConcatOp : public OpKernel {
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("concat_dim") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>) \
|
||||
MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_MklConcatV2") \
|
||||
.Device(DEVICE_CPU) \
|
||||
@ -978,7 +979,7 @@ class MklConcatOp : public OpKernel {
|
||||
.TypeConstraint<int32>("Tidx") \
|
||||
.HostMemory("axis") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConcatOp<CPUDevice, type, NAME_IS_AXIS>)
|
||||
MklConcatOp<CPUDevice, type, NAME_IS_AXIS>);
|
||||
|
||||
TF_CALL_float(REGISTER_MKL_CPU);
|
||||
TF_CALL_bfloat16(REGISTER_MKL_CPU);
|
||||
@ -988,14 +989,14 @@ REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
|
||||
.TypeConstraint<quint8>("T")
|
||||
.HostMemory("axis")
|
||||
.Label(mkl_op_registry::kMklQuantizedOpLabel),
|
||||
MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS>)
|
||||
MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<qint8>("T")
|
||||
.HostMemory("axis")
|
||||
.Label(mkl_op_registry::kMklQuantizedOpLabel),
|
||||
MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS>)
|
||||
MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS>);
|
||||
|
||||
#undef REGISTER_CONCAT_MKL
|
||||
} // namespace tensorflow
|
||||
|
@ -62,13 +62,19 @@ struct MklConvBwdFilterParams {
|
||||
memory::dims dilations;
|
||||
memory::dims padding_left;
|
||||
memory::dims padding_right;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
padding_kind padding;
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
MklConvBwdFilterParams(memory::dims src_dims, memory::dims diff_filter_dims,
|
||||
memory::dims diff_bias_dims,
|
||||
memory::dims diff_dst_dims, memory::dims strides,
|
||||
memory::dims dilations, memory::dims padding_left,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
memory::dims padding_right, padding_kind padding)
|
||||
#else
|
||||
memory::dims padding_right)
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
: src_dims(src_dims),
|
||||
diff_filter_dims(diff_filter_dims),
|
||||
diff_bias_dims(diff_bias_dims),
|
||||
@ -76,8 +82,14 @@ struct MklConvBwdFilterParams {
|
||||
strides(strides),
|
||||
dilations(dilations),
|
||||
padding_left(padding_left),
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
padding_right(padding_right),
|
||||
padding(padding) {}
|
||||
padding(padding) {
|
||||
}
|
||||
#else
|
||||
padding_right(padding_right) {
|
||||
}
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -241,8 +253,12 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
|
||||
*context_.diff_filter_md, *context_.diff_dst_md,
|
||||
convBwdFilterDims.strides, convBwdFilterDims.dilations,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
|
||||
convBwdFilterDims.padding));
|
||||
#else
|
||||
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
|
||||
|
||||
// Create descriptor and primitive descriptor for convolution bwd filter.
|
||||
@ -252,14 +268,22 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
*context_.diff_filter_md, *context_.diff_bias_md,
|
||||
*context_.diff_dst_md, convBwdFilterDims.strides,
|
||||
convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convBwdFilterDims.padding_right, convBwdFilterDims.padding));
|
||||
#else
|
||||
convBwdFilterDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
} else {
|
||||
context_.bwd_filter_desc.reset(new ConvBwdFilterDesc(
|
||||
ALGORITHM::convolution_direct, *context_.src_md,
|
||||
*context_.diff_filter_md, *context_.diff_dst_md,
|
||||
convBwdFilterDims.strides, convBwdFilterDims.dilations,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
|
||||
convBwdFilterDims.padding));
|
||||
#else
|
||||
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
context_.bwd_filter_pd.reset(new ConvBwdFilterPd(
|
||||
*context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd));
|
||||
@ -495,11 +519,14 @@ class MklConvCustomBackpropFilterOp
|
||||
// The default dilation factor for each dimension is 1 in TF and
|
||||
// 0 in MKL-DNN.
|
||||
for (int i = 0; i < dilations.size(); ++i) --dilations[i];
|
||||
|
||||
MklConvBwdFilterParams convBwdFilterDims(
|
||||
fwd_src_dims, fwd_filter_dims, diff_bias_dims, diff_dst_dims, strides,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
dilations, padding_left, padding_right,
|
||||
TFPaddingToMklDnnPadding(this->padding_));
|
||||
#else
|
||||
dilations, padding_left, padding_right);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// MKL-DNN allocates large buffers when a conv gradient filter primitive is
|
||||
// created. So we don't cache conv backward primitives when the env
|
||||
|
@ -65,20 +65,32 @@ struct MklConvBwdInputParams {
|
||||
memory::dims dilations;
|
||||
memory::dims padding_left;
|
||||
memory::dims padding_right;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
padding_kind padding;
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
MklConvBwdInputParams(memory::dims diff_src_dims, memory::dims filter_dims,
|
||||
memory::dims diff_dst_dims, memory::dims strides,
|
||||
memory::dims dilations, memory::dims padding_left,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
memory::dims padding_right, padding_kind padding)
|
||||
#else
|
||||
memory::dims padding_right)
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
: diff_src_dims(diff_src_dims),
|
||||
filter_dims(filter_dims),
|
||||
diff_dst_dims(diff_dst_dims),
|
||||
strides(strides),
|
||||
dilations(dilations),
|
||||
padding_left(padding_left),
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
padding_right(padding_right),
|
||||
padding(padding) {}
|
||||
padding(padding) {
|
||||
}
|
||||
#else
|
||||
padding_right(padding_right) {
|
||||
}
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -211,14 +223,22 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
|
||||
ALGORITHM::convolution_direct, *context_.diff_src_md,
|
||||
*context_.filter_md, *context_.diff_dst_md, convBwdInputDims.strides,
|
||||
convBwdInputDims.dilations, convBwdInputDims.padding_left,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convBwdInputDims.padding_right, convBwdInputDims.padding));
|
||||
#else
|
||||
convBwdInputDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
context_.fwd_desc.reset(new ConvFwdDesc(
|
||||
prop_kind::forward, ALGORITHM::convolution_direct,
|
||||
*context_.diff_src_md, *context_.filter_md, *context_.diff_dst_md,
|
||||
convBwdInputDims.strides, convBwdInputDims.dilations,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convBwdInputDims.padding_left, convBwdInputDims.padding_right,
|
||||
convBwdInputDims.padding));
|
||||
#else
|
||||
convBwdInputDims.padding_left, convBwdInputDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// Create primitive descriptors for conv fwd and conv bwd input.
|
||||
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
|
||||
@ -440,11 +460,14 @@ class MklConvCustomBackpropInputOp
|
||||
// The default dilation factor for each dimension is 1 in TF and
|
||||
// 0 in MKL-DNN.
|
||||
for (int i = 0; i < dilations.size(); ++i) --dilations[i];
|
||||
|
||||
MklConvBwdInputParams convBwdInputDims(
|
||||
fwd_src_dims, fwd_filter_dims, diff_dst_dims, strides, dilations,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
padding_left, padding_right,
|
||||
TFPaddingToMklDnnPadding(this->padding_));
|
||||
#else
|
||||
padding_left, padding_right);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// We don't cache those primitives if the environment variable
|
||||
// TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and if primitive descriptor
|
||||
|
@ -239,13 +239,21 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
|
||||
*context_.filter_md, *context_.bias_md, *context_.dst_md,
|
||||
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convFwdDims.padding_right, padding_kind::zero));
|
||||
#else
|
||||
convFwdDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
} else {
|
||||
context_.fwd_desc.reset(new convolution_forward::desc(
|
||||
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
|
||||
*context_.filter_md, *context_.dst_md, convFwdDims.strides,
|
||||
convFwdDims.dilations, convFwdDims.padding_left,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convFwdDims.padding_right, padding_kind::zero));
|
||||
#else
|
||||
convFwdDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
|
||||
@ -261,12 +269,7 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
float op_scale = post_op_param.param[0];
|
||||
float op_alpha = post_op_param.param[1];
|
||||
float op_beta = post_op_param.param[2];
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_relu,
|
||||
op_alpha,
|
||||
#else
|
||||
post_ops.append_eltwise(op_scale, post_op_param.alg, op_alpha,
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
op_beta);
|
||||
} else if (post_op_param.name == "sum") {
|
||||
DCHECK_EQ(post_op_param.param.size(), 1);
|
||||
@ -1033,6 +1036,7 @@ class MklConvOp : public OpKernel {
|
||||
&cached_filter_data_ptensor_, filter_tensor));
|
||||
|
||||
Tensor* second_tensor = nullptr;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
TensorShape filter_mkl_format;
|
||||
filter_mkl_format.AddDim(
|
||||
sizeof(GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)) /
|
||||
@ -1042,6 +1046,21 @@ class MklConvOp : public OpKernel {
|
||||
&cached_filter_md_ptensor_, &second_tensor));
|
||||
second_tensor->scalar<int32>()() = static_cast<int32>(
|
||||
GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc));
|
||||
#else
|
||||
// There is no tensor format in DNNL 1.x. So we cache the complete filter
|
||||
// descriptor as flat byte array.
|
||||
TensorShape cached_filter_md_shape;
|
||||
memory::desc weights_desc = conv_prim_desc.weights_desc();
|
||||
// We don't use .get_size() method of memory::desc since it returns size
|
||||
// required to store primitive's input memory. It is much more than size of
|
||||
// memory::desc itself.
|
||||
cached_filter_md_shape.AddDim(sizeof(weights_desc) / sizeof(uint8));
|
||||
OP_REQUIRES_OK(context, context->allocate_persistent(
|
||||
DT_UINT8, cached_filter_md_shape,
|
||||
&cached_filter_md_ptensor_, &second_tensor));
|
||||
*reinterpret_cast<memory::desc*>(second_tensor->flat<uint8>().data()) =
|
||||
weights_desc;
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
void AllocatePersistentTensor(OpKernelContext* context,
|
||||
@ -1230,12 +1249,11 @@ class MklConvOp : public OpKernel {
|
||||
const Tensor& cached_filter_md =
|
||||
*cached_filter_md_ptensor_.AccessTensor(context);
|
||||
|
||||
// Check if the memory descriptor of the cached weights is same as
|
||||
// filter_md. If so, we can use the cached weights; otherwise
|
||||
// return nullptr.
|
||||
// Check if the memory descriptor of the cached weights is same as
|
||||
// filter_md. If so, we can use the cached weights; otherwise
|
||||
// return nullptr.
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
if (cached_filter_md.scalar<int64>().size() &&
|
||||
AreMemoryDescriptorsEqual(filter_md, cached_filter_md)) {
|
||||
if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
|
||||
#else
|
||||
if (cached_filter_md.scalar<int32>().size() &&
|
||||
cached_filter_md.scalar<int32>()() == filter_md) {
|
||||
@ -1568,7 +1586,7 @@ class MklQuantizedConv2DOp
|
||||
|
||||
if (!scaled_bias_buf_)
|
||||
AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_,
|
||||
conv_fwd_pd->bias_primitive_desc(),
|
||||
GET_BIAS_DESC_FROM_OP_PD(conv_fwd_pd),
|
||||
&scaled_bias_buf_);
|
||||
if (!scaled_bias_) {
|
||||
scaled_bias_ = new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_,
|
||||
|
@ -89,12 +89,26 @@ class MklDequantizeOp : public OpKernel {
|
||||
Tensor* output_tensor = nullptr;
|
||||
MklDnnShape output_mkl_shape;
|
||||
TensorShape output_tf_shape;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
memory::desc dst_md =
|
||||
src_mkl_shape.IsMklTensor()
|
||||
? memory::desc(src_dims, MklDnnType<float>(),
|
||||
static_cast<memory::format>(src_md.data.format))
|
||||
static_cast<MEMORY_FORMAT>(src_md.data.format))
|
||||
: memory::desc(src_dims, MklDnnType<float>(),
|
||||
MEMORY_FORMAT::nhwc);
|
||||
#else
|
||||
memory::desc dst_md = memory::desc();
|
||||
if (src_mkl_shape.IsMklTensor()) {
|
||||
dst_md = memory::desc(src_mkl_shape.GetMklLayout().data);
|
||||
// There is no API in MKL-DNN v1.x to construct memory descriptor with
|
||||
// same .data field but different type.
|
||||
dst_md.data.data_type = memory::convert_to_c(MklDnnType<float>());
|
||||
} else {
|
||||
dst_md =
|
||||
memory::desc(src_dims, MklDnnType<float>(), MEMORY_FORMAT::nhwc);
|
||||
}
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// If input is MKL shape, output is also MKL shape.
|
||||
// If input is TF shape, output is also TF shape.
|
||||
if (src_mkl_shape.IsMklTensor()) {
|
||||
|
@ -42,15 +42,29 @@ struct MklBatchNormFwdParams {
|
||||
int depth;
|
||||
float eps;
|
||||
bool training;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
MEMORY_FORMAT src_format;
|
||||
#else
|
||||
memory::desc src_md;
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
bool training, MEMORY_FORMAT src_format)
|
||||
#else
|
||||
bool training, memory::desc src_md)
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
: src_dims(src_dims),
|
||||
depth(depth),
|
||||
eps(eps),
|
||||
training(training),
|
||||
src_format(src_format) {}
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
src_format(src_format) {
|
||||
}
|
||||
#else
|
||||
src_md(src_md) {
|
||||
}
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
};
|
||||
|
||||
template <typename T, typename U>
|
||||
@ -177,16 +191,17 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
|
||||
context_.pkind = fwdParams.training ? prop_kind::forward_training
|
||||
: prop_kind::forward_scoring;
|
||||
|
||||
// Memory descriptor
|
||||
auto src_md = memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
|
||||
fwdParams.src_format);
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
// Memory descriptor
|
||||
auto src_md = fwdParams.src_md;
|
||||
// Create forward BatchNorm descriptor and primitive descriptor.
|
||||
auto fwd_desc = batch_normalization_forward::desc(
|
||||
context_.pkind, src_md, fwdParams.eps,
|
||||
static_cast<mkldnn::normalization_flags>(context_.flags));
|
||||
#else
|
||||
// Memory descriptor
|
||||
auto src_md = memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
|
||||
fwdParams.src_format);
|
||||
auto fwd_desc = batch_normalization_forward::desc(
|
||||
context_.pkind, src_md, fwdParams.eps, context_.flags);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
@ -368,17 +383,36 @@ struct MklBatchNormBwdParams {
|
||||
int depth;
|
||||
float eps;
|
||||
bool training;
|
||||
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
MEMORY_FORMAT src_format;
|
||||
#else
|
||||
memory::desc src_md;
|
||||
memory::desc diff_dst_md;
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims,
|
||||
int depth, float eps, bool training,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
MEMORY_FORMAT src_format)
|
||||
: src_dims(src_dims),
|
||||
diff_dst_dims(diff_dst_dims),
|
||||
depth(depth),
|
||||
eps(eps),
|
||||
training(training),
|
||||
src_format(src_format) {}
|
||||
src_format(src_format) {
|
||||
}
|
||||
#else
|
||||
memory::desc src_md, memory::desc diff_dst_md)
|
||||
: src_dims(src_dims),
|
||||
diff_dst_dims(diff_dst_dims),
|
||||
depth(depth),
|
||||
eps(eps),
|
||||
training(training),
|
||||
src_md(src_md),
|
||||
diff_dst_md(diff_dst_md) {
|
||||
}
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
};
|
||||
|
||||
template <typename T, typename U>
|
||||
@ -432,12 +466,9 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
// Execute backward batch-normalization primitives.
|
||||
// TODO(intel-tf): Use execute_primitive instead of the inlined code.
|
||||
DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
|
||||
for (size_t i = 0; i < context_.bwd_primitives.size(); ++i) {
|
||||
context_.bwd_primitives.at(i).execute(*context_.bwd_stream,
|
||||
context_.net_args.at(i));
|
||||
}
|
||||
execute_primitives(context_.bwd_primitives, context_.bwd_stream,
|
||||
context_.net_args);
|
||||
#else
|
||||
context_.bwd_stream->submit(context_.bwd_primitives);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
@ -516,10 +547,15 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
|
||||
: (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));
|
||||
|
||||
// Memory descriptors.
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType<T>(),
|
||||
bwdParams.src_format);
|
||||
auto diff_dst_md = memory::desc({bwdParams.diff_dst_dims}, MklDnnType<T>(),
|
||||
bwdParams.src_format);
|
||||
#else
|
||||
auto src_md = bwdParams.src_md;
|
||||
auto diff_dst_md = bwdParams.diff_dst_md;
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
auto variance_desc =
|
||||
memory::desc({1, bwdParams.depth}, MklDnnType<U>(), MEMORY_FORMAT::nc);
|
||||
auto mean_desc =
|
||||
@ -794,7 +830,7 @@ class MklFusedBatchNormOp : public OpKernel {
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_,
|
||||
dnn_fmt);
|
||||
src_md);
|
||||
#else
|
||||
MklBatchNormFwdParams fwdParams(
|
||||
src_dims, depth_, epsilon_, is_training_,
|
||||
@ -1149,7 +1185,7 @@ class MklFusedBatchNormGradOp : public OpKernel {
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_,
|
||||
is_training_, dnn_fmt);
|
||||
is_training_, src_md, diff_dst_md);
|
||||
#else
|
||||
MklBatchNormBwdParams bwdParams(
|
||||
src_dims, diff_dst_dims, depth_, epsilon_, is_training_,
|
||||
@ -1158,19 +1194,10 @@ class MklFusedBatchNormGradOp : public OpKernel {
|
||||
MklFusedBatchNormBwdPrimitive<T, U>* bn_bwd =
|
||||
MklFusedBatchNormBwdPrimitiveFactory<T, U>::Get(bwdParams);
|
||||
|
||||
// Check if src/diff_dst needs to be reordered.
|
||||
const T* src_data = nullptr;
|
||||
const T* src_data = src_tensor.flat<T>().data();
|
||||
const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
|
||||
// Check if diff_dst input needs to be reordered
|
||||
std::shared_ptr<BatchNormBwdPd> bn_bwd_pd = bn_bwd->GetBatchNormBwdPd();
|
||||
if (IS_SRC_REORDER_NEEDED(src_md, bn_bwd_pd, bn_bwd)) {
|
||||
src.SetUsrMem(src_md, &src_tensor);
|
||||
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
GET_SRC_DESC_FROM_OP_PD(bn_bwd_pd), cpu_engine_));
|
||||
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
|
||||
} else {
|
||||
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
|
||||
}
|
||||
|
||||
const T* diff_dst_data = nullptr;
|
||||
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bn_bwd_pd, bn_bwd)) {
|
||||
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
|
||||
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
|
@ -105,6 +105,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
||||
memory::dims weight_dims = memory::dims({channel, k});
|
||||
memory::dims bias_dims = memory::dims({channel});
|
||||
memory::dims dst_dims = memory::dims({batch, channel});
|
||||
MEMORY_FORMAT src_format = MEMORY_FORMAT::nc;
|
||||
MEMORY_FORMAT weight_format =
|
||||
transpose_b_ ? MEMORY_FORMAT::oi : MEMORY_FORMAT::io;
|
||||
|
||||
@ -112,7 +113,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
||||
// 1. const, let MKL-DNN determine format because it will be cached;
|
||||
// 2. var, keep the original format to avoid reordering.
|
||||
MklDnnMatMulFwdParams matmul_params(
|
||||
src_dims, weight_dims, bias_dims, dst_dims,
|
||||
src_dims, weight_dims, bias_dims, dst_dims, src_format,
|
||||
(this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format);
|
||||
|
||||
// Extend the basic parameters for data types and fusions.
|
||||
@ -152,44 +153,44 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
||||
MklDnnData<T> src_mkl(&(this->cpu_engine_));
|
||||
MklDnnData<T> weight_mkl(&(this->cpu_engine_));
|
||||
|
||||
if (src_mkl_shape.IsMklTensor()) {
|
||||
memory::desc input_md = src_mkl_shape.GetMklLayout();
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
if (input_md != matmul_pd->src_desc()) {
|
||||
#else
|
||||
if (input_md.data.format != MKL_TENSOR_FORMAT_NC) {
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
src_mkl.SetUsrMem(input_md, src_data);
|
||||
src_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
matmul_pd.get()->PRIMITIVE_DESC_SRC, this->cpu_engine_));
|
||||
src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
|
||||
}
|
||||
auto src_md = src_mkl_shape.IsMklTensor()
|
||||
? src_mkl_shape.GetMklLayout()
|
||||
: memory::desc(src_dims, MklDnnType<T>(), src_format);
|
||||
|
||||
if (IS_SRC_REORDER_NEEDED(src_md, matmul_pd, matmul_prim)) {
|
||||
src_mkl.SetUsrMem(src_md, src_data);
|
||||
src_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
matmul_pd.get()->PRIMITIVE_DESC_SRC, this->cpu_engine_));
|
||||
src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
|
||||
}
|
||||
|
||||
// Get cached data when weight is const.
|
||||
memory::format expected_format = matmul_prim->GetWeightMemoryFormat();
|
||||
DCHECK(expected_format != weight_format && this->is_weight_const_);
|
||||
if (this->is_weight_const_) {
|
||||
const memory::desc weight_md =
|
||||
memory::desc(weight_dims, MklDnnType<T>(), weight_format);
|
||||
if (IS_WEIGHTS_REORDER_NEEDED(weight_md, matmul_pd, matmul_prim)) {
|
||||
T* cached_weight_data = nullptr;
|
||||
if (this->IsWeightCacheEmpty(ctx)) {
|
||||
auto weight_md =
|
||||
memory::desc(weight_dims, MklDnnType<T>(), weight_format);
|
||||
this->CacheWeight(ctx, matmul_pd, cached_weight_data, weight_tensor,
|
||||
weight_mkl, weight_md);
|
||||
|
||||
if (this->is_weight_const_) {
|
||||
if (this->IsWeightCacheEmpty(ctx)) {
|
||||
this->CacheWeight(ctx, matmul_pd, cached_weight_data, weight_tensor,
|
||||
weight_mkl, weight_md);
|
||||
}
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
cached_weight_data = this->GetCachedWeight(
|
||||
ctx, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_pd));
|
||||
#else
|
||||
cached_weight_data = this->GetCachedWeight(
|
||||
ctx, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_pd).desc());
|
||||
#endif
|
||||
}
|
||||
cached_weight_data = this->GetCachedWeight(ctx, expected_format);
|
||||
|
||||
// Cache weight may fail when it gets different format in different
|
||||
// iteration. Fall back to reorder if it happens.
|
||||
// TODO: Fix this slow path.
|
||||
// Also do general reorder if weight isn't const.
|
||||
if (cached_weight_data != nullptr) {
|
||||
weight_data = cached_weight_data;
|
||||
} else {
|
||||
memory::desc input_md =
|
||||
memory::desc(weight_dims, MklDnnType<T>(), weight_format);
|
||||
|
||||
//>>>>>>> master
|
||||
weight_mkl.SetUsrMem(input_md, weight_data);
|
||||
weight_mkl.SetUsrMem(weight_md, weight_data);
|
||||
weight_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
matmul_pd.get()->PRIMITIVE_DESC_WEIGHTS, this->cpu_engine_));
|
||||
weight_data =
|
||||
@ -210,23 +211,21 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
||||
|
||||
void ExtendMklDnnMatMulFwdParams(OpKernelContext* ctx,
|
||||
MklDnnMatMulFwdParams& params) {
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
if (fused_ops_.size() == 2) {
|
||||
string post_op = fused_ops_[1];
|
||||
|
||||
if (post_op == "Relu") {
|
||||
params.post_op_params.push_back({"relu", { 1.0, 0.0, 0.0 }});
|
||||
params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}});
|
||||
} else if (post_op == "Relu6") {
|
||||
params.post_op_params.push_back({"relu6", { 1.0, 6.0, 0.0 }});
|
||||
params.post_op_params.push_back({"relu6", {1.0, 6.0, 0.0}});
|
||||
} else if (post_op == "Elu") {
|
||||
params.post_op_params.push_back({"elu", { 1.0, 1.0, 0.0 }});
|
||||
params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}});
|
||||
} else {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, errors::InvalidArgument(
|
||||
"Unsupported post-argument in MklFusedMatMul: ", post_op));
|
||||
}
|
||||
}
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -41,7 +41,8 @@ struct MklDnnMatMulFwdParams {
|
||||
memory::dims weight_dims;
|
||||
memory::dims bias_dims;
|
||||
memory::dims dst_dims;
|
||||
MEMORY_FORMAT weight_fmt;
|
||||
MEMORY_FORMAT src_format;
|
||||
MEMORY_FORMAT weight_format;
|
||||
string dtypes = string("");
|
||||
struct PostOpParam {
|
||||
string name;
|
||||
@ -51,12 +52,14 @@ struct MklDnnMatMulFwdParams {
|
||||
|
||||
MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
|
||||
memory::dims bias_dims, memory::dims dst_dims,
|
||||
MEMORY_FORMAT weight_fmt = MEMORY_FORMAT::any)
|
||||
MEMORY_FORMAT src_format = MEMORY_FORMAT::any,
|
||||
MEMORY_FORMAT weight_format = MEMORY_FORMAT::any)
|
||||
: src_dims(src_dims),
|
||||
weight_dims(weight_dims),
|
||||
bias_dims(bias_dims),
|
||||
dst_dims(dst_dims),
|
||||
weight_fmt(weight_fmt) {}
|
||||
src_format(src_format),
|
||||
weight_format(weight_format) {}
|
||||
};
|
||||
|
||||
// With quantization, input, weight, bias, and output can have different types.
|
||||
@ -182,15 +185,11 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
|
||||
// format.
|
||||
context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims},
|
||||
MklDnnType<Tinput>(),
|
||||
MEMORY_FORMAT::any));
|
||||
matmul_fwd_params.src_format));
|
||||
|
||||
context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims},
|
||||
MklDnnType<Tweight>(),
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
MEMORY_FORMAT::any));
|
||||
#else
|
||||
matmul_fwd_params.weight_fmt));
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
matmul_fwd_params.weight_format));
|
||||
|
||||
context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
|
||||
MklDnnType<Toutput>(),
|
||||
@ -438,49 +437,61 @@ class MklDnnMatMulOpBase : public OpKernel {
|
||||
|
||||
// reorder and cache the weight
|
||||
weight.SetUsrMem(weight_md, &weight_tensor);
|
||||
weight.CheckReorderToOpMem(matmul_fwd_pd.get()->weights_primitive_desc());
|
||||
weight.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_));
|
||||
weight_data = static_cast<Tweight*>(weight.GetOpMem().get_data_handle());
|
||||
|
||||
Tensor* weight_tensor_ptr = nullptr;
|
||||
|
||||
size_t weight_size = matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS.get_size();
|
||||
TensorShape weight_tf_shape;
|
||||
weight_tf_shape.AddDim(
|
||||
(matmul_fwd_pd.get()->weights_primitive_desc().get_size() /
|
||||
sizeof(Tweight)));
|
||||
weight_tf_shape.AddDim(weight_size / sizeof(Tweight));
|
||||
|
||||
OP_REQUIRES_OK(context, context->allocate_persistent(
|
||||
DataTypeToEnum<Tweight>::value, weight_tf_shape,
|
||||
&weight_oi_, &weight_tensor_ptr));
|
||||
|
||||
void* weight_oi_t_data = weight.GetTensorBuffer(weight_tensor_ptr);
|
||||
size_t weight_size = weight.GetOpMem().get_primitive_desc().get_size();
|
||||
memcpy(weight_oi_t_data, weight_data, weight_size);
|
||||
|
||||
// cache the memory descriptor
|
||||
// cache the memory descriptor
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
auto expected_md = GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd);
|
||||
#else
|
||||
auto expected_md = GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd).desc();
|
||||
#endif
|
||||
Tensor* weight_md_tensor_ptr = nullptr;
|
||||
TensorShape weight_mkl_format;
|
||||
weight_mkl_format.AddDim(1);
|
||||
weight_mkl_format.AddDim(sizeof(expected_md) / sizeof(Tweight));
|
||||
|
||||
OP_REQUIRES_OK(context, context->allocate_persistent(
|
||||
DT_INT32, weight_mkl_format, &weight_oi_md_,
|
||||
&weight_md_tensor_ptr));
|
||||
weight_md_tensor_ptr->scalar<int32>()() =
|
||||
matmul_fwd_pd.get()->weights_primitive_desc().desc().data.format;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_persistent(DataTypeToEnum<Tweight>::value,
|
||||
weight_mkl_format, &weight_oi_md_,
|
||||
&weight_md_tensor_ptr));
|
||||
*reinterpret_cast<memory::desc*>(
|
||||
weight_md_tensor_ptr->flat<Tweight>().data()) = expected_md;
|
||||
}
|
||||
|
||||
Tweight* GetCachedWeight(OpKernelContext* context,
|
||||
const memory::format& weight_mf)
|
||||
const memory::desc& expected_md)
|
||||
LOCKS_EXCLUDED(mu_) {
|
||||
tf_shared_lock lock(mu_);
|
||||
const Tensor& weight_t = *weight_oi_.AccessTensor(context);
|
||||
const Tensor& weight_md_t = *weight_oi_md_.AccessTensor(context);
|
||||
|
||||
// Check if the memory descriptor of the cached weight is same as
|
||||
// weight_mf. If so, use the cached memory; otherwise return nullptr.
|
||||
if (weight_md_t.scalar<int32>().size() &&
|
||||
weight_md_t.scalar<int32>()() == weight_mf) {
|
||||
return static_cast<Tweight*>(
|
||||
const_cast<Tweight*>(weight_t.flat<Tweight>().data()));
|
||||
// Check if the memory descriptor of the cached weight is same as
|
||||
// expected_md. If so, use the cached memory; otherwise return nullptr.
|
||||
if (weight_md_t.flat<Tweight>().size()) {
|
||||
const memory::desc& stored_md =
|
||||
*(static_cast<memory::desc*>(weight_md_t.data()));
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
if (stored_md == expected_md) {
|
||||
#else
|
||||
if (stored_md.data.format == expected_md.data.format) {
|
||||
#endif
|
||||
return static_cast<Tweight*>(
|
||||
const_cast<Tweight*>(weight_t.flat<Tweight>().data()));
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@ -527,7 +538,8 @@ void dnnl_gemm_exec(const dnnl::desc& a_md, const dnnl::desc& b_md,
|
||||
dnnl::stream s(cpu_engine);
|
||||
matmul_prim.execute(s, {{DNNL_ARG_SRC, a_memory},
|
||||
{DNNL_ARG_WEIGHTS, b_memory},
|
||||
{DNNL_ARG_DST, c_memory}});
|
||||
{ DNNL_ARG_DST,
|
||||
c_memory }});
|
||||
s.wait();
|
||||
}
|
||||
|
||||
|
@ -143,12 +143,12 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
|
||||
MklPoolingParams fwdParams(
|
||||
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
|
||||
padding_right, ALGORITHM::pooling_max, pooling_prop_kind,
|
||||
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
|
||||
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), input_md);
|
||||
#else
|
||||
MklPoolingParams fwdParams(
|
||||
src_dims, output_dims_mkl_order, filter_dims, strides, padding_left,
|
||||
padding_right, ALGORITHM::pooling_max, pooling_prop_kind,
|
||||
static_cast<MEMORY_FORMAT>(input_md.data.format));
|
||||
static_cast<MEMORY_FORMAT>(input_md.data.format), input_md);
|
||||
#endif
|
||||
pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);
|
||||
// Allocate output tensor.
|
||||
@ -303,13 +303,13 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
||||
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
|
||||
strides, padding_left, padding_right, ALGORITHM::pooling_max,
|
||||
prop_kind::forward_training,
|
||||
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_));
|
||||
static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md);
|
||||
#else
|
||||
MklPoolingParams bwdParams(
|
||||
orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
|
||||
strides, padding_left, padding_right, ALGORITHM::pooling_max,
|
||||
prop_kind::forward_training,
|
||||
static_cast<MEMORY_FORMAT>(src_md.data.format));
|
||||
static_cast<MEMORY_FORMAT>(src_md.data.format), src_md);
|
||||
#endif
|
||||
MklPoolingBwdPrimitive<T>* pooling_bwd =
|
||||
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
|
||||
|
@ -43,6 +43,7 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
|
||||
// so src format is currently hard-coded.
|
||||
// A utility function is used to do this,
|
||||
// which may be broken with future CPU architectures
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
bool is_2d = (fwdParams.src_dims.size() == 4);
|
||||
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value)
|
||||
context_.src_fmt = is_2d ? MEMORY_FORMAT::nhwc : MEMORY_FORMAT::ndhwc;
|
||||
@ -51,6 +52,9 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
|
||||
|
||||
context_.src_md.reset(new memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
|
||||
context_.src_fmt));
|
||||
#else
|
||||
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType<T>(),
|
||||
MEMORY_FORMAT::any));
|
||||
|
||||
|
@ -47,13 +47,14 @@ struct MklPoolingParams {
|
||||
memory::dims padding_right;
|
||||
mkldnn::algorithm alg_kind;
|
||||
mkldnn::prop_kind prop_kind;
|
||||
memory::format src_format;
|
||||
MEMORY_FORMAT src_format;
|
||||
memory::desc src_md;
|
||||
|
||||
MklPoolingParams(memory::dims src_dims, memory::dims dst_dims,
|
||||
memory::dims filter_dims, memory::dims strides,
|
||||
memory::dims padding_left, memory::dims padding_right,
|
||||
mkldnn::algorithm alg_kind, mkldnn::prop_kind prop_kind,
|
||||
memory::format src_format)
|
||||
MEMORY_FORMAT src_format, memory::desc src_md)
|
||||
: src_dims(src_dims),
|
||||
dst_dims(dst_dims),
|
||||
filter_dims(filter_dims),
|
||||
@ -62,7 +63,8 @@ struct MklPoolingParams {
|
||||
padding_right(padding_right),
|
||||
alg_kind(alg_kind),
|
||||
prop_kind(prop_kind),
|
||||
src_format(src_format) {}
|
||||
src_format(src_format),
|
||||
src_md(src_md) {}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
@ -269,11 +269,11 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
|
||||
}
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
weight_data = this->GetCachedWeight(
|
||||
context, static_cast<int32>(weight_mkl_shape.GetTfDataFormat()));
|
||||
context, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd));
|
||||
#else
|
||||
weight_data = this->GetCachedWeight(
|
||||
context, matmul_fwd->GetWeightMemoryFormat());
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
context, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd).desc());
|
||||
#endif
|
||||
is_weight_cached = (weight_data != nullptr);
|
||||
}
|
||||
|
||||
|
@ -164,7 +164,11 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
|
||||
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
|
||||
|
||||
context_.src_mpd.reset(
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
new MEMORY_PRIMITIVE_DESC(*context_.src_md));
|
||||
#else
|
||||
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
|
||||
#endif
|
||||
|
||||
// Create an eltwise forward descriptor and primitive descriptor
|
||||
context_.fwd_desc.reset(new eltwise_forward::desc(
|
||||
@ -397,7 +401,6 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
|
||||
// Create memory descriptors for eltwise data w/ no specified format
|
||||
context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
|
||||
context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
|
||||
|
||||
context_.src_mpd.reset(
|
||||
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
|
||||
context_.diff_dst_mpd.reset(
|
||||
|
@ -65,7 +65,7 @@ class MklSoftmaxPrimitive : public MklPrimitive {
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
|
||||
context_.net_args);
|
||||
context_.fwd_net_args);
|
||||
#else
|
||||
context_.fwd_stream->submit(context_.fwd_primitives);
|
||||
#endif
|
||||
|
@ -1302,6 +1302,7 @@ REGISTER_OP("_MklFusedBatchNormV3")
|
||||
.Attr("U: {float}")
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
.Attr(GetConvnetDataFormatAttrString())
|
||||
.Attr("exponential_avg_factor: float = 1.0")
|
||||
.Attr("is_training: bool = true")
|
||||
.SetShapeFn(shape_inference::FusedBatchNormShape)
|
||||
.Doc(
|
||||
|
@ -2524,6 +2524,7 @@ REGISTER_OP("_MklFusedBatchNorm")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
.Attr("data_format: string = 'NHWC'")
|
||||
.Attr("exponential_avg_factor: float = 1.0")
|
||||
.Attr("is_training: bool = true")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle x;
|
||||
@ -2673,6 +2674,7 @@ REGISTER_OP("_MklFusedBatchNormV2")
|
||||
.Attr("U: {float}")
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
.Attr(GetConvnetDataFormatAttrString())
|
||||
.Attr("exponential_avg_factor: float = 1.0")
|
||||
.Attr("is_training: bool = true")
|
||||
.SetShapeFn(shape_inference::FusedBatchNormShape);
|
||||
|
||||
|
@ -48,6 +48,7 @@ namespace tensorflow {
|
||||
#define GET_WORKSPACE_DESC_FROM_OP_PD(op_pd) op_pd->workspace_desc()
|
||||
#define GET_TENSOR_FORMAT(fmt) MklTensorFormatToMklDnnDataFormat(fmt)
|
||||
#define GET_TF_DATA_FORMAT(shape, mem_desc) shape.GetTfDataFormat()
|
||||
#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemDesc()
|
||||
#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd->weights_desc()
|
||||
#define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op) \
|
||||
GET_WEIGHTS_DESC_FROM_OP_PD(op_pd)
|
||||
@ -114,7 +115,6 @@ namespace tensorflow {
|
||||
#define TENSOR_FORMAT MKL_TENSOR_FORMAT
|
||||
#define TENSOR_FORMAT_NHWC MKL_TENSOR_FORMAT_NHWC
|
||||
#define TENSOR_MAX_DIMS MKLDNN_MAX_NDIMS
|
||||
#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemDesc()
|
||||
|
||||
#else
|
||||
|
||||
@ -148,6 +148,7 @@ namespace tensorflow {
|
||||
op_pd.get()->workspace_primitive_desc()
|
||||
#define GET_TENSOR_FORMAT(fmt) fmt
|
||||
#define GET_TF_DATA_FORMAT(shape, mem_desc) mem_desc.data.format
|
||||
#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemPrimDesc()
|
||||
#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd.get()->weights_primitive_desc()
|
||||
#define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op) op->GetFilterMemoryFormat()
|
||||
#define IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, op_pd, op) \
|
||||
@ -215,7 +216,6 @@ namespace tensorflow {
|
||||
#define SUMMAND_MD summand_pd
|
||||
#define TENSOR_FORMAT TensorFormat
|
||||
#define TENSOR_FORMAT_NHWC FORMAT_NHWC
|
||||
#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemPrimDesc()
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -18,12 +18,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/util/mkl_types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(MklUtilTest, MklDnnTfShape) {
|
||||
auto cpu_engine = engine(engine::cpu, 0);
|
||||
auto cpu_engine = engine(ENGINE_CPU, 0);
|
||||
MklDnnData<float> a(&cpu_engine);
|
||||
|
||||
const int N = 1, C = 2, H = 3, W = 4;
|
||||
@ -31,7 +32,7 @@ TEST(MklUtilTest, MklDnnTfShape) {
|
||||
MklDnnShape a_mkldnn_shape;
|
||||
a_mkldnn_shape.SetMklTensor(true);
|
||||
// Create TF layout in NCHW.
|
||||
a_mkldnn_shape.SetTfLayout(a_dims.size(), a_dims, memory::format::nchw);
|
||||
a_mkldnn_shape.SetTfLayout(a_dims.size(), a_dims, MKL_TENSOR_FORMAT_NCHW);
|
||||
TensorShape a_tf_shape_nchw({N, C, H, W});
|
||||
TensorShape a_tf_shape_nhwc({N, H, W, C});
|
||||
TensorShape a_mkldnn_tf_shape = a_mkldnn_shape.GetTfShape();
|
||||
@ -43,7 +44,7 @@ TEST(MklUtilTest, MklDnnTfShape) {
|
||||
MklDnnShape b_mkldnn_shape;
|
||||
b_mkldnn_shape.SetMklTensor(true);
|
||||
// Create TF layout in NHWC.
|
||||
b_mkldnn_shape.SetTfLayout(b_dims.size(), b_dims, memory::format::nhwc);
|
||||
b_mkldnn_shape.SetTfLayout(b_dims.size(), b_dims, MKL_TENSOR_FORMAT_NHWC);
|
||||
TensorShape b_tf_shape_nhwc({N, H, W, C});
|
||||
TensorShape b_tf_shape_nchw({N, C, H, W});
|
||||
TensorShape b_mkldnn_tf_shape = b_mkldnn_shape.GetTfShape();
|
||||
@ -55,7 +56,7 @@ TEST(MklUtilTest, MklDnnTfShape) {
|
||||
TEST(MklUtilTest, MklDnnBlockedFormatTest) {
|
||||
// Let's create 2D tensor of shape {3, 4} with 3 being innermost dimension
|
||||
// first (case 1) and then it being outermost dimension (case 2).
|
||||
auto cpu_engine = engine(engine::cpu, 0);
|
||||
auto cpu_engine = engine(ENGINE_CPU, 0);
|
||||
|
||||
// Setting for case 1
|
||||
MklDnnData<float> a(&cpu_engine);
|
||||
@ -67,7 +68,9 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) {
|
||||
EXPECT_EQ(a_md1.data.ndims, 2);
|
||||
EXPECT_EQ(a_md1.data.dims[0], 3);
|
||||
EXPECT_EQ(a_md1.data.dims[1], 4);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
EXPECT_EQ(a_md1.data.format, mkldnn_blocked);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// Setting for case 2
|
||||
MklDnnData<float> b(&cpu_engine);
|
||||
@ -79,7 +82,9 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) {
|
||||
EXPECT_EQ(b_md2.data.ndims, 2);
|
||||
EXPECT_EQ(b_md2.data.dims[0], 3);
|
||||
EXPECT_EQ(b_md2.data.dims[1], 4);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
EXPECT_EQ(b_md2.data.format, mkldnn_blocked);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
TEST(MklUtilTest, LRUCacheTest) {
|
||||
|
@ -174,12 +174,12 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
||||
|
||||
tf_http_archive(
|
||||
name = "mkl_dnn_v1",
|
||||
build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
|
||||
sha256 = "fcc2d951f7170eade0cfdd0d8d1d58e3e7785bd326bca6555f3722f8cba71811",
|
||||
strip_prefix = "mkl-dnn-1.0-pc2",
|
||||
build_file = clean_dep("//third_party/mkl_dnn:mkldnn_v1.BUILD"),
|
||||
sha256 = "30979a09753e8e35d942446c3778c9f0eba543acf2fb0282af8b9c89355d0ddf",
|
||||
strip_prefix = "mkl-dnn-1.2",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/archive/v1.0-pc2.tar.gz",
|
||||
"https://github.com/intel/mkl-dnn/archive/v1.0-pc2.tar.gz",
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/archive/v1.2.tar.gz",
|
||||
"https://github.com/intel/mkl-dnn/archive/v1.2.tar.gz",
|
||||
],
|
||||
)
|
||||
|
||||
|
136
third_party/mkl_dnn/mkldnn_v1.BUILD
vendored
Normal file
136
third_party/mkl_dnn/mkldnn_v1.BUILD
vendored
Normal file
@ -0,0 +1,136 @@
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load(
|
||||
"@org_tensorflow//third_party/mkl_dnn:build_defs.bzl",
|
||||
"if_mkl_open_source_only",
|
||||
"if_mkl_v1_open_source_only",
|
||||
)
|
||||
load(
|
||||
"@org_tensorflow//third_party:common.bzl",
|
||||
"template_rule",
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "clang_linux_x86_64",
|
||||
values = {
|
||||
"cpu": "k8",
|
||||
"define": "using_clang=true",
|
||||
},
|
||||
)
|
||||
|
||||
template_rule(
|
||||
name = "dnnl_config_h",
|
||||
src = "include/dnnl_config.h.in",
|
||||
out = "include/dnnl_config.h",
|
||||
substitutions = {
|
||||
"#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP",
|
||||
"#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_OMP",
|
||||
"#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE",
|
||||
},
|
||||
)
|
||||
|
||||
# Create the file dnnl_version.h with MKL-DNN (DNNL) version numbers.
|
||||
# Currently, the version numbers are hard coded here. If MKL-DNN is upgraded then
|
||||
# the version numbers have to be updated manually. The version numbers can be
|
||||
# obtained from the PROJECT_VERSION settings in CMakeLists.txt. The variable is
|
||||
# set to "version_major.version_minor.version_patch". The git hash version can
|
||||
# be set to NA.
|
||||
# TODO(agramesh1) Automatically get the version numbers from CMakeLists.txt.
|
||||
|
||||
template_rule(
|
||||
name = "dnnl_version_h",
|
||||
src = "include/dnnl_version.h.in",
|
||||
out = "include/dnnl_version.h",
|
||||
substitutions = {
|
||||
"@DNNL_VERSION_MAJOR@": "1",
|
||||
"@DNNL_VERSION_MINOR@": "2",
|
||||
"@DNNL_VERSION_PATCH@": "0",
|
||||
"@DNNL_VERSION_HASH@": "N/A",
|
||||
},
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mkl_dnn",
|
||||
srcs = glob([
|
||||
"src/common/*.cpp",
|
||||
"src/common/*.hpp",
|
||||
"src/cpu/*.cpp",
|
||||
"src/cpu/*.hpp",
|
||||
"src/cpu/**/*.cpp",
|
||||
"src/cpu/**/*.hpp",
|
||||
"src/cpu/xbyak/*.h",
|
||||
]) + if_mkl_v1_open_source_only([
|
||||
":dnnl_config_h",
|
||||
]) + [":dnnl_version_h"],
|
||||
hdrs = glob(["include/*"]),
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-DUSE_MKL",
|
||||
"-DUSE_CBLAS",
|
||||
] + if_mkl_open_source_only([
|
||||
"-UUSE_MKL",
|
||||
"-UUSE_CBLAS",
|
||||
]) + if_mkl_v1_open_source_only([
|
||||
"-UUSE_MKL",
|
||||
"-UUSE_CBLAS",
|
||||
]) + select({
|
||||
"@org_tensorflow//tensorflow:linux_x86_64": [
|
||||
"-fopenmp", # only works with gcc
|
||||
],
|
||||
# TODO(ibiryukov): enable openmp with clang by including libomp as a
|
||||
# dependency.
|
||||
":clang_linux_x86_64": [],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
includes = [
|
||||
"include",
|
||||
"src",
|
||||
"src/common",
|
||||
"src/cpu",
|
||||
"src/cpu/gemm",
|
||||
"src/cpu/xbyak",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"@org_tensorflow//tensorflow:linux_x86_64": [
|
||||
"@mkl_linux//:mkl_headers",
|
||||
"@mkl_linux//:mkl_libs_linux",
|
||||
],
|
||||
"@org_tensorflow//tensorflow:macos": [
|
||||
"@mkl_darwin//:mkl_headers",
|
||||
"@mkl_darwin//:mkl_libs_darwin",
|
||||
],
|
||||
"@org_tensorflow//tensorflow:windows": [
|
||||
"@mkl_windows//:mkl_headers",
|
||||
"@mkl_windows//:mkl_libs_windows",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mkldnn_single_threaded",
|
||||
srcs = glob([
|
||||
"src/common/*.cpp",
|
||||
"src/common/*.hpp",
|
||||
"src/cpu/*.cpp",
|
||||
"src/cpu/*.hpp",
|
||||
"src/cpu/**/*.cpp",
|
||||
"src/cpu/**/*.hpp",
|
||||
"src/cpu/xbyak/*.h",
|
||||
]) + [":dnnl_config_h"],
|
||||
hdrs = glob(["include/*"]),
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-DMKLDNN_THR=MKLDNN_THR_SEQ", # Disables threading.
|
||||
],
|
||||
includes = [
|
||||
"include",
|
||||
"src",
|
||||
"src/common",
|
||||
"src/cpu",
|
||||
"src/cpu/gemm",
|
||||
"src/cpu/xbyak",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
Loading…
Reference in New Issue
Block a user