Merge pull request #43225 from Intel-tensorflow:dnn0x_cleanup_transpose
PiperOrigin-RevId: 331886427 Change-Id: Ia883fcf1555b2cf312304c73e82f41c03ad9bfd5
This commit is contained in:
commit
b4f7c37ccb
@ -27,7 +27,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||||
#include "tensorflow/core/kernels/transpose_op.h"
|
#include "tensorflow/core/kernels/transpose_op.h"
|
||||||
#include "tensorflow/core/util/mkl_types.h"
|
|
||||||
#include "tensorflow/core/util/mkl_util.h"
|
#include "tensorflow/core/util/mkl_util.h"
|
||||||
|
|
||||||
using mkldnn::stream;
|
using mkldnn::stream;
|
||||||
@ -126,7 +125,7 @@ template <typename T>
|
|||||||
Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
|
Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
|
||||||
Tensor* out_tensor, const gtl::ArraySlice<int32>& perm) {
|
Tensor* out_tensor, const gtl::ArraySlice<int32>& perm) {
|
||||||
try {
|
try {
|
||||||
engine cpu_engine = engine(ENGINE_CPU, 0);
|
engine cpu_engine = engine(engine::kind::cpu, 0);
|
||||||
MklDnnData<T> in(&cpu_engine);
|
MklDnnData<T> in(&cpu_engine);
|
||||||
MklDnnData<T> out(&cpu_engine);
|
MklDnnData<T> out(&cpu_engine);
|
||||||
|
|
||||||
@ -144,7 +143,6 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
|
|||||||
out.SetUsrMem(in_dims, out_strides, out_tensor);
|
out.SetUsrMem(in_dims, out_strides, out_tensor);
|
||||||
|
|
||||||
std::vector<primitive> net;
|
std::vector<primitive> net;
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
|
||||||
auto* prim = FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem());
|
auto* prim = FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem());
|
||||||
transpose_stream.reset(CreateStream(context, prim->GetEngine()));
|
transpose_stream.reset(CreateStream(context, prim->GetEngine()));
|
||||||
in.SetUsrMemDataHandle(&in_tensor, transpose_stream);
|
in.SetUsrMemDataHandle(&in_tensor, transpose_stream);
|
||||||
@ -154,11 +152,6 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
|
|||||||
net_args.push_back({{MKLDNN_ARG_FROM, *in.GetUsrMem()},
|
net_args.push_back({{MKLDNN_ARG_FROM, *in.GetUsrMem()},
|
||||||
{MKLDNN_ARG_TO, *out.GetUsrMem()}});
|
{MKLDNN_ARG_TO, *out.GetUsrMem()}});
|
||||||
execute_primitives(net, transpose_stream, net_args);
|
execute_primitives(net, transpose_stream, net_args);
|
||||||
#else
|
|
||||||
transpose_stream.reset(new CPU_STREAM(cpu_engine));
|
|
||||||
net.push_back(FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem()));
|
|
||||||
transpose_stream->submit(net).wait();
|
|
||||||
#endif // ENABLE_MKLDNN_V1
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} catch (mkldnn::error& e) {
|
} catch (mkldnn::error& e) {
|
||||||
@ -196,7 +189,7 @@ Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
|
|||||||
|
|
||||||
// MKL-DNN has a limit on the maximum number of dimensions in a tensor.
|
// MKL-DNN has a limit on the maximum number of dimensions in a tensor.
|
||||||
// Fall back to Eigen for unsupported cases.
|
// Fall back to Eigen for unsupported cases.
|
||||||
if (in.dims() <= TENSOR_MAX_DIMS) {
|
if (in.dims() <= MKLDNN_MAX_NDIMS) {
|
||||||
switch (in.dtype()) {
|
switch (in.dtype()) {
|
||||||
case DT_FLOAT:
|
case DT_FLOAT:
|
||||||
return MKLTransposeND<float>(ctx, in, out, perm);
|
return MKLTransposeND<float>(ctx, in, out, perm);
|
||||||
@ -243,7 +236,7 @@ Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
|
|||||||
|
|
||||||
// MKL-DNN has a limit on the maximum number of dimensions in a tensor.
|
// MKL-DNN has a limit on the maximum number of dimensions in a tensor.
|
||||||
// Fall back to Eigen for unsupported cases.
|
// Fall back to Eigen for unsupported cases.
|
||||||
if (in.dims() <= TENSOR_MAX_DIMS) {
|
if (in.dims() <= MKLDNN_MAX_NDIMS) {
|
||||||
switch (in.dtype()) {
|
switch (in.dtype()) {
|
||||||
case DT_FLOAT:
|
case DT_FLOAT:
|
||||||
return MKLTransposeND<float>(ctx, in, out, perm);
|
return MKLTransposeND<float>(ctx, in, out, perm);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user