Merge pull request #43155 from Intel-tensorflow:dnn0x_cleanup_slice
PiperOrigin-RevId: 331279868 Change-Id: I1b6abbff7d270e9229b97fb8ec5d3e908de210a2
This commit is contained in:
commit
c3137a8294
@ -26,13 +26,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/prefetch.h"
|
||||
#include "tensorflow/core/util/mkl_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
using mkldnn::stream;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
using mkldnn::view;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -181,7 +177,7 @@ template <typename T>
|
||||
class MklSlicePrimitive : public MklPrimitive {
|
||||
public:
|
||||
explicit MklSlicePrimitive(const MklSliceParams& sliceParams)
|
||||
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||
: MklPrimitive(engine(engine::kind::cpu, 0)) {
|
||||
Setup(sliceParams);
|
||||
}
|
||||
|
||||
@ -198,12 +194,9 @@ class MklSlicePrimitive : public MklPrimitive {
|
||||
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
|
||||
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
|
||||
execute_primitives(context_.slice_primitives, slice_stream,
|
||||
context_.slice_primitives_args);
|
||||
#else
|
||||
slice_stream->submit(context_.slice_primitives);
|
||||
#endif
|
||||
|
||||
// We should set it back to DummyData so as to make the primitive
|
||||
// in cache pool stateless. Otherwise, if the result for previous
|
||||
@ -224,12 +217,8 @@ class MklSlicePrimitive : public MklPrimitive {
|
||||
std::shared_ptr<reorder::primitive_desc> reorder_pd;
|
||||
std::shared_ptr<mkldnn::stream> slice_stream;
|
||||
std::vector<mkldnn::primitive> slice_primitives;
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
std::shared_ptr<mkldnn::memory> src_sub_mem;
|
||||
std::vector<std::unordered_map<int, memory>> slice_primitives_args;
|
||||
#else
|
||||
std::shared_ptr<view::primitive_desc> view_pd;
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
SliceContext()
|
||||
: src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
|
||||
} context_;
|
||||
@ -237,15 +226,13 @@ class MklSlicePrimitive : public MklPrimitive {
|
||||
void Setup(const MklSliceParams& sliceParams) {
|
||||
// Actually, DummyData will not be used in computation,
|
||||
// because the real data will be filled before execution.
|
||||
context_.src_mem.reset(new MEMORY_CONSTRUCTOR_WITH_MEM_PD(
|
||||
sliceParams.from, cpu_engine_, DummyData));
|
||||
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR_WITH_MEM_PD(
|
||||
sliceParams.to, cpu_engine_, DummyData));
|
||||
auto src_pd = context_.src_mem->GET_DESC;
|
||||
auto dst_pd = context_.dst_mem->GET_DESC;
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
// MKL-DNN 1.x removes struct view, which was an alias of memory in the 0.x version.
|
||||
// So the implementation is based on submemory.
|
||||
context_.src_mem.reset(
|
||||
new memory(sliceParams.from->get_desc(), cpu_engine_, DummyData));
|
||||
context_.dst_mem.reset(
|
||||
new memory(sliceParams.to->get_desc(), cpu_engine_, DummyData));
|
||||
auto src_pd = context_.src_mem->get_desc();
|
||||
auto dst_pd = context_.dst_mem->get_desc();
|
||||
|
||||
auto src_sub_desc = context_.src_mem->get_desc().submemory_desc(
|
||||
sliceParams.size_dims, sliceParams.begin_dims);
|
||||
context_.src_sub_mem.reset(new memory(src_sub_desc, cpu_engine_, nullptr));
|
||||
@ -256,18 +243,7 @@ class MklSlicePrimitive : public MklPrimitive {
|
||||
|
||||
context_.slice_primitives_args.push_back(
|
||||
{{MKLDNN_ARG_SRC, *context_.src_mem},
|
||||
{ MKLDNN_ARG_DST,
|
||||
*context_.dst_mem }});
|
||||
#else
|
||||
context_.view_pd =
|
||||
std::make_shared<view::primitive_desc>(view::primitive_desc(
|
||||
src_pd, sliceParams.size_dims, sliceParams.begin_dims));
|
||||
context_.reorder_pd =
|
||||
std::make_shared<reorder::primitive_desc>(reorder::primitive_desc(
|
||||
context_.view_pd->dst_primitive_desc(), dst_pd));
|
||||
context_.reorder_prim = std::make_shared<mkldnn::reorder>(
|
||||
reorder(*context_.reorder_pd, *context_.src_mem, *context_.dst_mem));
|
||||
#endif
|
||||
{MKLDNN_ARG_DST, *context_.dst_mem}});
|
||||
context_.slice_primitives.push_back(*context_.reorder_prim);
|
||||
}
|
||||
};
|
||||
@ -298,32 +274,24 @@ class MklSlicePrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
static string CreateKey(const MklSliceParams& sliceParams) {
|
||||
string prefix = "reorder";
|
||||
FactoryKeyCreator key_creator;
|
||||
auto const& from_desc = GET_MEMORY_DESC_FROM_MEM_PTR(sliceParams.from).data;
|
||||
auto const& to_desc = GET_MEMORY_DESC_FROM_MEM_PTR(sliceParams.to).data;
|
||||
auto const& from_desc = sliceParams.from->get_desc().data;
|
||||
auto const& to_desc = sliceParams.to->get_desc().data;
|
||||
const int kIdxFirstStride = 0;
|
||||
memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
|
||||
memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
|
||||
|
||||
// MKL-DNN 1.x removed "struct view"; a submemory descriptor has similar capability.
|
||||
auto from_strides = from_desc.MEMORY_FORMAT_DESC.blocking.strides;
|
||||
auto to_strides = to_desc.MEMORY_FORMAT_DESC.blocking.strides;
|
||||
memory::dims from_strides_outer_blocks(
|
||||
GET_BLOCK_STRIDES(from_strides, kIdxFirstStride),
|
||||
&GET_BLOCK_STRIDES(from_strides, kIdxFirstStride)[from_desc.ndims]);
|
||||
memory::dims to_strides_outer_blocks(
|
||||
GET_BLOCK_STRIDES(to_strides, kIdxFirstStride),
|
||||
&GET_BLOCK_STRIDES(to_strides, kIdxFirstStride)[to_desc.ndims]);
|
||||
auto from_strides = from_desc.format_desc.blocking.strides;
|
||||
auto to_strides = to_desc.format_desc.blocking.strides;
|
||||
memory::dims from_strides_outer_blocks(from_strides,
|
||||
&from_strides[from_desc.ndims]);
|
||||
memory::dims to_strides_outer_blocks(to_strides,
|
||||
&to_strides[to_desc.ndims]);
|
||||
|
||||
key_creator.AddAsKey(prefix);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
key_creator.AddAsKey(static_cast<int>(from_desc.format));
|
||||
#endif
|
||||
key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
|
||||
key_creator.AddAsKey(from_dims);
|
||||
key_creator.AddAsKey(from_strides_outer_blocks);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
key_creator.AddAsKey(static_cast<int>(to_desc.format));
|
||||
#endif
|
||||
key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
|
||||
key_creator.AddAsKey(to_dims);
|
||||
key_creator.AddAsKey(to_strides_outer_blocks);
|
||||
@ -401,7 +369,7 @@ class MklSliceOp : public OpKernel {
|
||||
// primitive descriptor. And the reorder uses source memory as input but
|
||||
// traverses it according to a view in_submem_pd.
|
||||
|
||||
auto cpu_engine = engine(ENGINE_CPU, 0);
|
||||
auto cpu_engine = engine(engine::kind::cpu, 0);
|
||||
MklDnnData<T> src(&cpu_engine);
|
||||
MklDnnData<T> output(&cpu_engine);
|
||||
|
||||
@ -468,22 +436,13 @@ class MklSliceOp : public OpKernel {
|
||||
// Or else do nothing for it.
|
||||
auto op_md =
|
||||
MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
src.CheckReorderToOpMem(op_md, cpu_engine, context);
|
||||
#else
|
||||
auto op_pd = memory::primitive_desc(op_md, cpu_engine);
|
||||
src.CheckReorderToOpMem(op_pd);
|
||||
#endif
|
||||
|
||||
// Step 2 - Create memory for output.
|
||||
auto output_strides = CalculateTFStrides(size_dims);
|
||||
auto output_md =
|
||||
MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
auto output_pd = output_md;
|
||||
#else
|
||||
auto output_pd = memory::primitive_desc(output_md, cpu_engine);
|
||||
#endif
|
||||
AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
|
||||
&output_tensor, &output_mkl_shape);
|
||||
DCHECK(output_tensor);
|
||||
@ -512,7 +471,7 @@ class MklSliceOp : public OpKernel {
|
||||
private:
|
||||
void AllocateOutputTensor(OpKernelContext* context,
|
||||
const MklDnnShape& input_mkl_shape,
|
||||
MEMORY_PRIMITIVE_DESC* output_pd,
|
||||
memory::desc* output_pd,
|
||||
const memory::dims& output_dims,
|
||||
Tensor** output_tensor,
|
||||
MklDnnShape* output_mkl_shape) {
|
||||
|
||||
@ -68,9 +68,6 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) {
|
||||
EXPECT_EQ(a_md1.data.ndims, 2);
|
||||
EXPECT_EQ(a_md1.data.dims[0], 3);
|
||||
EXPECT_EQ(a_md1.data.dims[1], 4);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
EXPECT_EQ(a_md1.data.format, mkldnn_blocked);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// Setting for case 2
|
||||
MklDnnData<float> b(&cpu_engine);
|
||||
@ -82,9 +79,6 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) {
|
||||
EXPECT_EQ(b_md2.data.ndims, 2);
|
||||
EXPECT_EQ(b_md2.data.dims[0], 3);
|
||||
EXPECT_EQ(b_md2.data.dims[1], 4);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
EXPECT_EQ(b_md2.data.format, mkldnn_blocked);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
TEST(MklUtilTest, LRUCacheTest) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user