Merge pull request #43225 from Intel-tensorflow:dnn0x_cleanup_transpose
PiperOrigin-RevId: 331886427 Change-Id: Ia883fcf1555b2cf312304c73e82f41c03ad9bfd5
This commit is contained in:
commit
b4f7c37ccb
@ -27,7 +27,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||
#include "tensorflow/core/kernels/transpose_op.h"
|
||||
#include "tensorflow/core/util/mkl_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
using mkldnn::stream;
|
||||
@ -126,7 +125,7 @@ template <typename T>
|
||||
Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
|
||||
Tensor* out_tensor, const gtl::ArraySlice<int32>& perm) {
|
||||
try {
|
||||
engine cpu_engine = engine(ENGINE_CPU, 0);
|
||||
engine cpu_engine = engine(engine::kind::cpu, 0);
|
||||
MklDnnData<T> in(&cpu_engine);
|
||||
MklDnnData<T> out(&cpu_engine);
|
||||
|
||||
@ -144,7 +143,6 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
|
||||
out.SetUsrMem(in_dims, out_strides, out_tensor);
|
||||
|
||||
std::vector<primitive> net;
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
auto* prim = FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem());
|
||||
transpose_stream.reset(CreateStream(context, prim->GetEngine()));
|
||||
in.SetUsrMemDataHandle(&in_tensor, transpose_stream);
|
||||
@ -154,11 +152,6 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
|
||||
net_args.push_back({{MKLDNN_ARG_FROM, *in.GetUsrMem()},
|
||||
{MKLDNN_ARG_TO, *out.GetUsrMem()}});
|
||||
execute_primitives(net, transpose_stream, net_args);
|
||||
#else
|
||||
transpose_stream.reset(new CPU_STREAM(cpu_engine));
|
||||
net.push_back(FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem()));
|
||||
transpose_stream->submit(net).wait();
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
return Status::OK();
|
||||
} catch (mkldnn::error& e) {
|
||||
@ -196,7 +189,7 @@ Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
|
||||
|
||||
// MKL-DNN has limit on the maximum number of dimensions in a tensor.
|
||||
// Fallback to Eigen for not supported cases.
|
||||
if (in.dims() <= TENSOR_MAX_DIMS) {
|
||||
if (in.dims() <= MKLDNN_MAX_NDIMS) {
|
||||
switch (in.dtype()) {
|
||||
case DT_FLOAT:
|
||||
return MKLTransposeND<float>(ctx, in, out, perm);
|
||||
@ -243,7 +236,7 @@ Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
|
||||
|
||||
// MKL-DNN has limit on the maximum number of dimensions in a tensor.
|
||||
// Fallback to Eigen for not supported cases.
|
||||
if (in.dims() <= TENSOR_MAX_DIMS) {
|
||||
if (in.dims() <= MKLDNN_MAX_NDIMS) {
|
||||
switch (in.dtype()) {
|
||||
case DT_FLOAT:
|
||||
return MKLTransposeND<float>(ctx, in, out, perm);
|
||||
|
Loading…
x
Reference in New Issue
Block a user