MKL DNN 0.x code cleanup - MKLSlice op

This commit is contained in:
xiaohong1031 2020-09-11 09:28:38 -07:00
parent 80bdf7c72f
commit 5481229f1b
2 changed files with 28 additions and 76 deletions

View File

@ -18,7 +18,6 @@ limitations under the License.
#ifdef INTEL_MKL
#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@ -26,13 +25,10 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
using mkldnn::stream;
#ifndef ENABLE_MKLDNN_V1
using mkldnn::view;
#endif
namespace tensorflow {
@ -89,11 +85,10 @@ static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
const int input_dims = input_tf_shape.dims();
OP_REQUIRES(
context,
TensorShapeUtils::IsVector(begin_tensor.shape()) &&
TensorShapeUtils::IsVector(size_tensor.shape()) &&
begin_tensor.NumElements() == input_dims &&
size_tensor.NumElements() == input_dims,
context, TensorShapeUtils::IsVector(begin_tensor.shape()) &&
TensorShapeUtils::IsVector(size_tensor.shape()) &&
begin_tensor.NumElements() == input_dims &&
size_tensor.NumElements() == input_dims,
errors::InvalidArgument(
"Expected begin and size arguments to be 1-D tensors of size ",
input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
@ -181,7 +176,7 @@ template <typename T>
class MklSlicePrimitive : public MklPrimitive {
public:
explicit MklSlicePrimitive(const MklSliceParams& sliceParams)
: MklPrimitive(engine(ENGINE_CPU, 0)) {
: MklPrimitive(engine(engine::kind::cpu, 0)) {
Setup(sliceParams);
}
@ -198,12 +193,9 @@ class MklSlicePrimitive : public MklPrimitive {
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
#endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.slice_primitives, slice_stream,
context_.slice_primitives_args);
#else
slice_stream->submit(context_.slice_primitives);
#endif
// We should set it back to DummyData so as to make the primitive
// in cache pool stateless. Otherwise, if the result for previous
@ -224,12 +216,8 @@ class MklSlicePrimitive : public MklPrimitive {
std::shared_ptr<reorder::primitive_desc> reorder_pd;
std::shared_ptr<mkldnn::stream> slice_stream;
std::vector<mkldnn::primitive> slice_primitives;
#ifdef ENABLE_MKLDNN_V1
std::shared_ptr<mkldnn::memory> src_sub_mem;
std::vector<std::unordered_map<int, memory>> slice_primitives_args;
#else
std::shared_ptr<view::primitive_desc> view_pd;
#endif // ENABLE_MKLDNN_V1
SliceContext()
: src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
} context_;
@ -237,15 +225,13 @@ class MklSlicePrimitive : public MklPrimitive {
void Setup(const MklSliceParams& sliceParams) {
// Actually, DummyData will not be used in computation,
// because the real data will be filled before execution.
context_.src_mem.reset(new MEMORY_CONSTRUCTOR_WITH_MEM_PD(
sliceParams.from, cpu_engine_, DummyData));
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR_WITH_MEM_PD(
sliceParams.to, cpu_engine_, DummyData));
auto src_pd = context_.src_mem->GET_DESC;
auto dst_pd = context_.dst_mem->GET_DESC;
#ifdef ENABLE_MKLDNN_V1
// MKL-DNN 1.x removes struct view, alias of memory in 0.x version.
// So the implementation is based on submemory.
context_.src_mem.reset(
new memory(sliceParams.from->get_desc(), cpu_engine_, DummyData));
context_.dst_mem.reset(
new memory(sliceParams.to->get_desc(), cpu_engine_, DummyData));
auto src_pd = context_.src_mem->get_desc();
auto dst_pd = context_.dst_mem->get_desc();
auto src_sub_desc = context_.src_mem->get_desc().submemory_desc(
sliceParams.size_dims, sliceParams.begin_dims);
context_.src_sub_mem.reset(new memory(src_sub_desc, cpu_engine_, nullptr));
@ -256,18 +242,7 @@ class MklSlicePrimitive : public MklPrimitive {
context_.slice_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
#else
context_.view_pd =
std::make_shared<view::primitive_desc>(view::primitive_desc(
src_pd, sliceParams.size_dims, sliceParams.begin_dims));
context_.reorder_pd =
std::make_shared<reorder::primitive_desc>(reorder::primitive_desc(
context_.view_pd->dst_primitive_desc(), dst_pd));
context_.reorder_prim = std::make_shared<mkldnn::reorder>(
reorder(*context_.reorder_pd, *context_.src_mem, *context_.dst_mem));
#endif
{MKLDNN_ARG_DST, *context_.dst_mem}});
context_.slice_primitives.push_back(*context_.reorder_prim);
}
};
@ -298,32 +273,24 @@ class MklSlicePrimitiveFactory : public MklPrimitiveFactory<T> {
static string CreateKey(const MklSliceParams& sliceParams) {
string prefix = "reorder";
FactoryKeyCreator key_creator;
auto const& from_desc = GET_MEMORY_DESC_FROM_MEM_PTR(sliceParams.from).data;
auto const& to_desc = GET_MEMORY_DESC_FROM_MEM_PTR(sliceParams.to).data;
auto const& from_desc = sliceParams.from->get_desc().data;
auto const& to_desc = sliceParams.to->get_desc().data;
const int kIdxFirstStride = 0;
memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
// MKL-DNN removes "struct view". Submemory has similar capability.
auto from_strides = from_desc.MEMORY_FORMAT_DESC.blocking.strides;
auto to_strides = to_desc.MEMORY_FORMAT_DESC.blocking.strides;
memory::dims from_strides_outer_blocks(
GET_BLOCK_STRIDES(from_strides, kIdxFirstStride),
&GET_BLOCK_STRIDES(from_strides, kIdxFirstStride)[from_desc.ndims]);
memory::dims to_strides_outer_blocks(
GET_BLOCK_STRIDES(to_strides, kIdxFirstStride),
&GET_BLOCK_STRIDES(to_strides, kIdxFirstStride)[to_desc.ndims]);
auto from_strides = from_desc.format_desc.blocking.strides;
auto to_strides = to_desc.format_desc.blocking.strides;
memory::dims from_strides_outer_blocks(from_strides,
&from_strides[from_desc.ndims]);
memory::dims to_strides_outer_blocks(to_strides,
&to_strides[to_desc.ndims]);
key_creator.AddAsKey(prefix);
#ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey(static_cast<int>(from_desc.format));
#endif
key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
key_creator.AddAsKey(from_dims);
key_creator.AddAsKey(from_strides_outer_blocks);
#ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey(static_cast<int>(to_desc.format));
#endif
key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
key_creator.AddAsKey(to_dims);
key_creator.AddAsKey(to_strides_outer_blocks);
@ -401,7 +368,7 @@ class MklSliceOp : public OpKernel {
// primitive descriptor. And the reorder uses source memory as input but
// traverses it according to a view in_submem_pd.
auto cpu_engine = engine(ENGINE_CPU, 0);
auto cpu_engine = engine(engine::kind::cpu, 0);
MklDnnData<T> src(&cpu_engine);
MklDnnData<T> output(&cpu_engine);
@ -468,22 +435,13 @@ class MklSliceOp : public OpKernel {
// Or else do nothing for it.
auto op_md =
MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
#ifdef ENABLE_MKLDNN_V1
src.CheckReorderToOpMem(op_md, cpu_engine, context);
#else
auto op_pd = memory::primitive_desc(op_md, cpu_engine);
src.CheckReorderToOpMem(op_pd);
#endif
// Step 2 - Create memory for output.
auto output_strides = CalculateTFStrides(size_dims);
auto output_md =
MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
#ifdef ENABLE_MKLDNN_V1
auto output_pd = output_md;
#else
auto output_pd = memory::primitive_desc(output_md, cpu_engine);
#endif
AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
&output_tensor, &output_mkl_shape);
DCHECK(output_tensor);
@ -500,9 +458,9 @@ class MklSliceOp : public OpKernel {
slice_stream.reset(CreateStream(context, reorder_prim->GetEngine()));
reorder_prim->Execute(sliceParams, slice_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
string(e.message) + ", in file " + string(__FILE__) +
":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
@ -512,7 +470,7 @@ class MklSliceOp : public OpKernel {
private:
void AllocateOutputTensor(OpKernelContext* context,
const MklDnnShape& input_mkl_shape,
MEMORY_PRIMITIVE_DESC* output_pd,
memory::desc* output_pd,
const memory::dims& output_dims,
Tensor** output_tensor,
MklDnnShape* output_mkl_shape) {

View File

@ -68,9 +68,6 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) {
EXPECT_EQ(a_md1.data.ndims, 2);
EXPECT_EQ(a_md1.data.dims[0], 3);
EXPECT_EQ(a_md1.data.dims[1], 4);
#ifndef ENABLE_MKLDNN_V1
EXPECT_EQ(a_md1.data.format, mkldnn_blocked);
#endif // !ENABLE_MKLDNN_V1
// Setting for case 2
MklDnnData<float> b(&cpu_engine);
@ -82,9 +79,6 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) {
EXPECT_EQ(b_md2.data.ndims, 2);
EXPECT_EQ(b_md2.data.dims[0], 3);
EXPECT_EQ(b_md2.data.dims[1], 4);
#ifndef ENABLE_MKLDNN_V1
EXPECT_EQ(b_md2.data.format, mkldnn_blocked);
#endif // !ENABLE_MKLDNN_V1
}
TEST(MklUtilTest, LRUCacheTest) {