MKL DNN 0.x code cleanup - MKLSlice op

This commit is contained in:
xiaohong1031 2020-09-11 09:28:38 -07:00
parent 80bdf7c72f
commit 5481229f1b
2 changed files with 28 additions and 76 deletions

View File

@ -18,7 +18,6 @@ limitations under the License.
#ifdef INTEL_MKL #ifdef INTEL_MKL
#include "mkldnn.hpp" #include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
@ -26,13 +25,10 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/prefetch.h" #include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/mkl_util.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
using mkldnn::stream; using mkldnn::stream;
#ifndef ENABLE_MKLDNN_V1
using mkldnn::view;
#endif
namespace tensorflow { namespace tensorflow {
@ -89,11 +85,10 @@ static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
const int input_dims = input_tf_shape.dims(); const int input_dims = input_tf_shape.dims();
OP_REQUIRES( OP_REQUIRES(
context, context, TensorShapeUtils::IsVector(begin_tensor.shape()) &&
TensorShapeUtils::IsVector(begin_tensor.shape()) && TensorShapeUtils::IsVector(size_tensor.shape()) &&
TensorShapeUtils::IsVector(size_tensor.shape()) && begin_tensor.NumElements() == input_dims &&
begin_tensor.NumElements() == input_dims && size_tensor.NumElements() == input_dims,
size_tensor.NumElements() == input_dims,
errors::InvalidArgument( errors::InvalidArgument(
"Expected begin and size arguments to be 1-D tensors of size ", "Expected begin and size arguments to be 1-D tensors of size ",
input_dims, ", but got shapes ", begin_tensor.shape().DebugString(), input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
@ -181,7 +176,7 @@ template <typename T>
class MklSlicePrimitive : public MklPrimitive { class MklSlicePrimitive : public MklPrimitive {
public: public:
explicit MklSlicePrimitive(const MklSliceParams& sliceParams) explicit MklSlicePrimitive(const MklSliceParams& sliceParams)
: MklPrimitive(engine(ENGINE_CPU, 0)) { : MklPrimitive(engine(engine::kind::cpu, 0)) {
Setup(sliceParams); Setup(sliceParams);
} }
@ -198,12 +193,9 @@ class MklSlicePrimitive : public MklPrimitive {
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle()); context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle()); context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
#endif // ENABLE_MKLDNN_THREADPOOL #endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.slice_primitives, slice_stream, execute_primitives(context_.slice_primitives, slice_stream,
context_.slice_primitives_args); context_.slice_primitives_args);
#else
slice_stream->submit(context_.slice_primitives);
#endif
// We should set it back to DummyData so as to make the primitive // We should set it back to DummyData so as to make the primitive
// in cache pool stateless. Otherwise, if the result for previous // in cache pool stateless. Otherwise, if the result for previous
@ -224,12 +216,8 @@ class MklSlicePrimitive : public MklPrimitive {
std::shared_ptr<reorder::primitive_desc> reorder_pd; std::shared_ptr<reorder::primitive_desc> reorder_pd;
std::shared_ptr<mkldnn::stream> slice_stream; std::shared_ptr<mkldnn::stream> slice_stream;
std::vector<mkldnn::primitive> slice_primitives; std::vector<mkldnn::primitive> slice_primitives;
#ifdef ENABLE_MKLDNN_V1
std::shared_ptr<mkldnn::memory> src_sub_mem; std::shared_ptr<mkldnn::memory> src_sub_mem;
std::vector<std::unordered_map<int, memory>> slice_primitives_args; std::vector<std::unordered_map<int, memory>> slice_primitives_args;
#else
std::shared_ptr<view::primitive_desc> view_pd;
#endif // ENABLE_MKLDNN_V1
SliceContext() SliceContext()
: src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {} : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
} context_; } context_;
@ -237,15 +225,13 @@ class MklSlicePrimitive : public MklPrimitive {
void Setup(const MklSliceParams& sliceParams) { void Setup(const MklSliceParams& sliceParams) {
// Actually, DummyData will not be used in computation, // Actually, DummyData will not be used in computation,
// because the real data will be filled before execution. // because the real data will be filled before execution.
context_.src_mem.reset(new MEMORY_CONSTRUCTOR_WITH_MEM_PD( context_.src_mem.reset(
sliceParams.from, cpu_engine_, DummyData)); new memory(sliceParams.from->get_desc(), cpu_engine_, DummyData));
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR_WITH_MEM_PD( context_.dst_mem.reset(
sliceParams.to, cpu_engine_, DummyData)); new memory(sliceParams.to->get_desc(), cpu_engine_, DummyData));
auto src_pd = context_.src_mem->GET_DESC; auto src_pd = context_.src_mem->get_desc();
auto dst_pd = context_.dst_mem->GET_DESC; auto dst_pd = context_.dst_mem->get_desc();
#ifdef ENABLE_MKLDNN_V1
// MKL-DNN 1.x removes struct view, alias of memory in 0.x version.
// So the implementation is based on submemory.
auto src_sub_desc = context_.src_mem->get_desc().submemory_desc( auto src_sub_desc = context_.src_mem->get_desc().submemory_desc(
sliceParams.size_dims, sliceParams.begin_dims); sliceParams.size_dims, sliceParams.begin_dims);
context_.src_sub_mem.reset(new memory(src_sub_desc, cpu_engine_, nullptr)); context_.src_sub_mem.reset(new memory(src_sub_desc, cpu_engine_, nullptr));
@ -256,18 +242,7 @@ class MklSlicePrimitive : public MklPrimitive {
context_.slice_primitives_args.push_back( context_.slice_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem}, {{MKLDNN_ARG_SRC, *context_.src_mem},
{ MKLDNN_ARG_DST, {MKLDNN_ARG_DST, *context_.dst_mem}});
*context_.dst_mem }});
#else
context_.view_pd =
std::make_shared<view::primitive_desc>(view::primitive_desc(
src_pd, sliceParams.size_dims, sliceParams.begin_dims));
context_.reorder_pd =
std::make_shared<reorder::primitive_desc>(reorder::primitive_desc(
context_.view_pd->dst_primitive_desc(), dst_pd));
context_.reorder_prim = std::make_shared<mkldnn::reorder>(
reorder(*context_.reorder_pd, *context_.src_mem, *context_.dst_mem));
#endif
context_.slice_primitives.push_back(*context_.reorder_prim); context_.slice_primitives.push_back(*context_.reorder_prim);
} }
}; };
@ -298,32 +273,24 @@ class MklSlicePrimitiveFactory : public MklPrimitiveFactory<T> {
static string CreateKey(const MklSliceParams& sliceParams) { static string CreateKey(const MklSliceParams& sliceParams) {
string prefix = "reorder"; string prefix = "reorder";
FactoryKeyCreator key_creator; FactoryKeyCreator key_creator;
auto const& from_desc = GET_MEMORY_DESC_FROM_MEM_PTR(sliceParams.from).data; auto const& from_desc = sliceParams.from->get_desc().data;
auto const& to_desc = GET_MEMORY_DESC_FROM_MEM_PTR(sliceParams.to).data; auto const& to_desc = sliceParams.to->get_desc().data;
const int kIdxFirstStride = 0; const int kIdxFirstStride = 0;
memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]); memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]); memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
// MKL-DNN removes "struct view". Submemory has similar capability. // MKL-DNN removes "struct view". Submemory has similar capability.
auto from_strides = from_desc.MEMORY_FORMAT_DESC.blocking.strides; auto from_strides = from_desc.format_desc.blocking.strides;
auto to_strides = to_desc.MEMORY_FORMAT_DESC.blocking.strides; auto to_strides = to_desc.format_desc.blocking.strides;
memory::dims from_strides_outer_blocks( memory::dims from_strides_outer_blocks(from_strides,
GET_BLOCK_STRIDES(from_strides, kIdxFirstStride), &from_strides[from_desc.ndims]);
&GET_BLOCK_STRIDES(from_strides, kIdxFirstStride)[from_desc.ndims]); memory::dims to_strides_outer_blocks(to_strides,
memory::dims to_strides_outer_blocks( &to_strides[to_desc.ndims]);
GET_BLOCK_STRIDES(to_strides, kIdxFirstStride),
&GET_BLOCK_STRIDES(to_strides, kIdxFirstStride)[to_desc.ndims]);
key_creator.AddAsKey(prefix); key_creator.AddAsKey(prefix);
#ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey(static_cast<int>(from_desc.format));
#endif
key_creator.AddAsKey(static_cast<int>(from_desc.data_type)); key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
key_creator.AddAsKey(from_dims); key_creator.AddAsKey(from_dims);
key_creator.AddAsKey(from_strides_outer_blocks); key_creator.AddAsKey(from_strides_outer_blocks);
#ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey(static_cast<int>(to_desc.format));
#endif
key_creator.AddAsKey(static_cast<int>(to_desc.data_type)); key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
key_creator.AddAsKey(to_dims); key_creator.AddAsKey(to_dims);
key_creator.AddAsKey(to_strides_outer_blocks); key_creator.AddAsKey(to_strides_outer_blocks);
@ -401,7 +368,7 @@ class MklSliceOp : public OpKernel {
// primitive descriptor. And the reorder uses source memory as input but // primitive descriptor. And the reorder uses source memory as input but
// traverses it according to a view in_submem_pd. // traverses it according to a view in_submem_pd.
auto cpu_engine = engine(ENGINE_CPU, 0); auto cpu_engine = engine(engine::kind::cpu, 0);
MklDnnData<T> src(&cpu_engine); MklDnnData<T> src(&cpu_engine);
MklDnnData<T> output(&cpu_engine); MklDnnData<T> output(&cpu_engine);
@ -468,22 +435,13 @@ class MklSliceOp : public OpKernel {
// Or else do nothing for it. // Or else do nothing for it.
auto op_md = auto op_md =
MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides); MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
#ifdef ENABLE_MKLDNN_V1
src.CheckReorderToOpMem(op_md, cpu_engine, context); src.CheckReorderToOpMem(op_md, cpu_engine, context);
#else
auto op_pd = memory::primitive_desc(op_md, cpu_engine);
src.CheckReorderToOpMem(op_pd);
#endif
// Step 2 - Create memory for output. // Step 2 - Create memory for output.
auto output_strides = CalculateTFStrides(size_dims); auto output_strides = CalculateTFStrides(size_dims);
auto output_md = auto output_md =
MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides); MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
#ifdef ENABLE_MKLDNN_V1
auto output_pd = output_md; auto output_pd = output_md;
#else
auto output_pd = memory::primitive_desc(output_md, cpu_engine);
#endif
AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims, AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
&output_tensor, &output_mkl_shape); &output_tensor, &output_mkl_shape);
DCHECK(output_tensor); DCHECK(output_tensor);
@ -500,9 +458,9 @@ class MklSliceOp : public OpKernel {
slice_stream.reset(CreateStream(context, reorder_prim->GetEngine())); slice_stream.reset(CreateStream(context, reorder_prim->GetEngine()));
reorder_prim->Execute(sliceParams, slice_stream); reorder_prim->Execute(sliceParams, slice_stream);
} catch (mkldnn::error& e) { } catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) + string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
", message: " + string(e.message) + ", in file " + string(e.message) + ", in file " + string(__FILE__) +
string(__FILE__) + ":" + std::to_string(__LINE__); ":" + std::to_string(__LINE__);
OP_REQUIRES_OK( OP_REQUIRES_OK(
context, context,
errors::Aborted("Operation received an exception:", error_msg)); errors::Aborted("Operation received an exception:", error_msg));
@ -512,7 +470,7 @@ class MklSliceOp : public OpKernel {
private: private:
void AllocateOutputTensor(OpKernelContext* context, void AllocateOutputTensor(OpKernelContext* context,
const MklDnnShape& input_mkl_shape, const MklDnnShape& input_mkl_shape,
MEMORY_PRIMITIVE_DESC* output_pd, memory::desc* output_pd,
const memory::dims& output_dims, const memory::dims& output_dims,
Tensor** output_tensor, Tensor** output_tensor,
MklDnnShape* output_mkl_shape) { MklDnnShape* output_mkl_shape) {

View File

@ -68,9 +68,6 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) {
EXPECT_EQ(a_md1.data.ndims, 2); EXPECT_EQ(a_md1.data.ndims, 2);
EXPECT_EQ(a_md1.data.dims[0], 3); EXPECT_EQ(a_md1.data.dims[0], 3);
EXPECT_EQ(a_md1.data.dims[1], 4); EXPECT_EQ(a_md1.data.dims[1], 4);
#ifndef ENABLE_MKLDNN_V1
EXPECT_EQ(a_md1.data.format, mkldnn_blocked);
#endif // !ENABLE_MKLDNN_V1
// Setting for case 2 // Setting for case 2
MklDnnData<float> b(&cpu_engine); MklDnnData<float> b(&cpu_engine);
@ -82,9 +79,6 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) {
EXPECT_EQ(b_md2.data.ndims, 2); EXPECT_EQ(b_md2.data.ndims, 2);
EXPECT_EQ(b_md2.data.dims[0], 3); EXPECT_EQ(b_md2.data.dims[0], 3);
EXPECT_EQ(b_md2.data.dims[1], 4); EXPECT_EQ(b_md2.data.dims[1], 4);
#ifndef ENABLE_MKLDNN_V1
EXPECT_EQ(b_md2.data.format, mkldnn_blocked);
#endif // !ENABLE_MKLDNN_V1
} }
TEST(MklUtilTest, LRUCacheTest) { TEST(MklUtilTest, LRUCacheTest) {