MKL DNN 0.x code cleanup - MKLSlice op
This commit is contained in:
parent
80bdf7c72f
commit
5481229f1b
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#include "mkldnn.hpp"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -26,13 +25,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/prefetch.h"
|
||||
#include "tensorflow/core/util/mkl_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
using mkldnn::stream;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
using mkldnn::view;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -89,11 +85,10 @@ static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
|
||||
const int input_dims = input_tf_shape.dims();
|
||||
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
TensorShapeUtils::IsVector(begin_tensor.shape()) &&
|
||||
TensorShapeUtils::IsVector(size_tensor.shape()) &&
|
||||
begin_tensor.NumElements() == input_dims &&
|
||||
size_tensor.NumElements() == input_dims,
|
||||
context, TensorShapeUtils::IsVector(begin_tensor.shape()) &&
|
||||
TensorShapeUtils::IsVector(size_tensor.shape()) &&
|
||||
begin_tensor.NumElements() == input_dims &&
|
||||
size_tensor.NumElements() == input_dims,
|
||||
errors::InvalidArgument(
|
||||
"Expected begin and size arguments to be 1-D tensors of size ",
|
||||
input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
|
||||
@ -181,7 +176,7 @@ template <typename T>
|
||||
class MklSlicePrimitive : public MklPrimitive {
|
||||
public:
|
||||
explicit MklSlicePrimitive(const MklSliceParams& sliceParams)
|
||||
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||
: MklPrimitive(engine(engine::kind::cpu, 0)) {
|
||||
Setup(sliceParams);
|
||||
}
|
||||
|
||||
@ -198,12 +193,9 @@ class MklSlicePrimitive : public MklPrimitive {
|
||||
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
|
||||
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
|
||||
execute_primitives(context_.slice_primitives, slice_stream,
|
||||
context_.slice_primitives_args);
|
||||
#else
|
||||
slice_stream->submit(context_.slice_primitives);
|
||||
#endif
|
||||
|
||||
// We should set it back to DummyData so as to make the primitive
|
||||
// in cache pool stateless. Otherwise, if the result for previous
|
||||
@ -224,12 +216,8 @@ class MklSlicePrimitive : public MklPrimitive {
|
||||
std::shared_ptr<reorder::primitive_desc> reorder_pd;
|
||||
std::shared_ptr<mkldnn::stream> slice_stream;
|
||||
std::vector<mkldnn::primitive> slice_primitives;
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
std::shared_ptr<mkldnn::memory> src_sub_mem;
|
||||
std::vector<std::unordered_map<int, memory>> slice_primitives_args;
|
||||
#else
|
||||
std::shared_ptr<view::primitive_desc> view_pd;
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
SliceContext()
|
||||
: src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
|
||||
} context_;
|
||||
@ -237,15 +225,13 @@ class MklSlicePrimitive : public MklPrimitive {
|
||||
void Setup(const MklSliceParams& sliceParams) {
|
||||
// Actually, DummyData will not be used in computation,
|
||||
// because the real data will be filled before execution.
|
||||
context_.src_mem.reset(new MEMORY_CONSTRUCTOR_WITH_MEM_PD(
|
||||
sliceParams.from, cpu_engine_, DummyData));
|
||||
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR_WITH_MEM_PD(
|
||||
sliceParams.to, cpu_engine_, DummyData));
|
||||
auto src_pd = context_.src_mem->GET_DESC;
|
||||
auto dst_pd = context_.dst_mem->GET_DESC;
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
// MKL-DNN 1.x removes struct view, alias of memory in 0.x version.
|
||||
// So the implementation is based on submemory.
|
||||
context_.src_mem.reset(
|
||||
new memory(sliceParams.from->get_desc(), cpu_engine_, DummyData));
|
||||
context_.dst_mem.reset(
|
||||
new memory(sliceParams.to->get_desc(), cpu_engine_, DummyData));
|
||||
auto src_pd = context_.src_mem->get_desc();
|
||||
auto dst_pd = context_.dst_mem->get_desc();
|
||||
|
||||
auto src_sub_desc = context_.src_mem->get_desc().submemory_desc(
|
||||
sliceParams.size_dims, sliceParams.begin_dims);
|
||||
context_.src_sub_mem.reset(new memory(src_sub_desc, cpu_engine_, nullptr));
|
||||
@ -256,18 +242,7 @@ class MklSlicePrimitive : public MklPrimitive {
|
||||
|
||||
context_.slice_primitives_args.push_back(
|
||||
{{MKLDNN_ARG_SRC, *context_.src_mem},
|
||||
{ MKLDNN_ARG_DST,
|
||||
*context_.dst_mem }});
|
||||
#else
|
||||
context_.view_pd =
|
||||
std::make_shared<view::primitive_desc>(view::primitive_desc(
|
||||
src_pd, sliceParams.size_dims, sliceParams.begin_dims));
|
||||
context_.reorder_pd =
|
||||
std::make_shared<reorder::primitive_desc>(reorder::primitive_desc(
|
||||
context_.view_pd->dst_primitive_desc(), dst_pd));
|
||||
context_.reorder_prim = std::make_shared<mkldnn::reorder>(
|
||||
reorder(*context_.reorder_pd, *context_.src_mem, *context_.dst_mem));
|
||||
#endif
|
||||
{MKLDNN_ARG_DST, *context_.dst_mem}});
|
||||
context_.slice_primitives.push_back(*context_.reorder_prim);
|
||||
}
|
||||
};
|
||||
@ -298,32 +273,24 @@ class MklSlicePrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
static string CreateKey(const MklSliceParams& sliceParams) {
|
||||
string prefix = "reorder";
|
||||
FactoryKeyCreator key_creator;
|
||||
auto const& from_desc = GET_MEMORY_DESC_FROM_MEM_PTR(sliceParams.from).data;
|
||||
auto const& to_desc = GET_MEMORY_DESC_FROM_MEM_PTR(sliceParams.to).data;
|
||||
auto const& from_desc = sliceParams.from->get_desc().data;
|
||||
auto const& to_desc = sliceParams.to->get_desc().data;
|
||||
const int kIdxFirstStride = 0;
|
||||
memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
|
||||
memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
|
||||
|
||||
// MKL-DNN removes "struct view". Submemory has similar capability.
|
||||
auto from_strides = from_desc.MEMORY_FORMAT_DESC.blocking.strides;
|
||||
auto to_strides = to_desc.MEMORY_FORMAT_DESC.blocking.strides;
|
||||
memory::dims from_strides_outer_blocks(
|
||||
GET_BLOCK_STRIDES(from_strides, kIdxFirstStride),
|
||||
&GET_BLOCK_STRIDES(from_strides, kIdxFirstStride)[from_desc.ndims]);
|
||||
memory::dims to_strides_outer_blocks(
|
||||
GET_BLOCK_STRIDES(to_strides, kIdxFirstStride),
|
||||
&GET_BLOCK_STRIDES(to_strides, kIdxFirstStride)[to_desc.ndims]);
|
||||
auto from_strides = from_desc.format_desc.blocking.strides;
|
||||
auto to_strides = to_desc.format_desc.blocking.strides;
|
||||
memory::dims from_strides_outer_blocks(from_strides,
|
||||
&from_strides[from_desc.ndims]);
|
||||
memory::dims to_strides_outer_blocks(to_strides,
|
||||
&to_strides[to_desc.ndims]);
|
||||
|
||||
key_creator.AddAsKey(prefix);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
key_creator.AddAsKey(static_cast<int>(from_desc.format));
|
||||
#endif
|
||||
key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
|
||||
key_creator.AddAsKey(from_dims);
|
||||
key_creator.AddAsKey(from_strides_outer_blocks);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
key_creator.AddAsKey(static_cast<int>(to_desc.format));
|
||||
#endif
|
||||
key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
|
||||
key_creator.AddAsKey(to_dims);
|
||||
key_creator.AddAsKey(to_strides_outer_blocks);
|
||||
@ -401,7 +368,7 @@ class MklSliceOp : public OpKernel {
|
||||
// primitive descriptor. And the reorder uses source memory as input but
|
||||
// traverses it according to a view in_submem_pd.
|
||||
|
||||
auto cpu_engine = engine(ENGINE_CPU, 0);
|
||||
auto cpu_engine = engine(engine::kind::cpu, 0);
|
||||
MklDnnData<T> src(&cpu_engine);
|
||||
MklDnnData<T> output(&cpu_engine);
|
||||
|
||||
@ -468,22 +435,13 @@ class MklSliceOp : public OpKernel {
|
||||
// Or else do nothing for it.
|
||||
auto op_md =
|
||||
MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
src.CheckReorderToOpMem(op_md, cpu_engine, context);
|
||||
#else
|
||||
auto op_pd = memory::primitive_desc(op_md, cpu_engine);
|
||||
src.CheckReorderToOpMem(op_pd);
|
||||
#endif
|
||||
|
||||
// Step 2 - Create memory for output.
|
||||
auto output_strides = CalculateTFStrides(size_dims);
|
||||
auto output_md =
|
||||
MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
auto output_pd = output_md;
|
||||
#else
|
||||
auto output_pd = memory::primitive_desc(output_md, cpu_engine);
|
||||
#endif
|
||||
AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
|
||||
&output_tensor, &output_mkl_shape);
|
||||
DCHECK(output_tensor);
|
||||
@ -500,9 +458,9 @@ class MklSliceOp : public OpKernel {
|
||||
slice_stream.reset(CreateStream(context, reorder_prim->GetEngine()));
|
||||
reorder_prim->Execute(sliceParams, slice_stream);
|
||||
} catch (mkldnn::error& e) {
|
||||
string error_msg = "Status: " + std::to_string(e.status) +
|
||||
", message: " + string(e.message) + ", in file " +
|
||||
string(__FILE__) + ":" + std::to_string(__LINE__);
|
||||
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
|
||||
string(e.message) + ", in file " + string(__FILE__) +
|
||||
":" + std::to_string(__LINE__);
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
errors::Aborted("Operation received an exception:", error_msg));
|
||||
@ -512,7 +470,7 @@ class MklSliceOp : public OpKernel {
|
||||
private:
|
||||
void AllocateOutputTensor(OpKernelContext* context,
|
||||
const MklDnnShape& input_mkl_shape,
|
||||
MEMORY_PRIMITIVE_DESC* output_pd,
|
||||
memory::desc* output_pd,
|
||||
const memory::dims& output_dims,
|
||||
Tensor** output_tensor,
|
||||
MklDnnShape* output_mkl_shape) {
|
||||
|
@ -68,9 +68,6 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) {
|
||||
EXPECT_EQ(a_md1.data.ndims, 2);
|
||||
EXPECT_EQ(a_md1.data.dims[0], 3);
|
||||
EXPECT_EQ(a_md1.data.dims[1], 4);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
EXPECT_EQ(a_md1.data.format, mkldnn_blocked);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// Setting for case 2
|
||||
MklDnnData<float> b(&cpu_engine);
|
||||
@ -82,9 +79,6 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) {
|
||||
EXPECT_EQ(b_md2.data.ndims, 2);
|
||||
EXPECT_EQ(b_md2.data.dims[0], 3);
|
||||
EXPECT_EQ(b_md2.data.dims[1], 4);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
EXPECT_EQ(b_md2.data.format, mkldnn_blocked);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
TEST(MklUtilTest, LRUCacheTest) {
|
||||
|
Loading…
Reference in New Issue
Block a user