Merge pull request #17049 from fo40225/fix_mkl_win
Fix MKL build break on Windows
commit d66c9726dd
@@ -21,7 +21,6 @@ limitations under the License.
 
 #ifdef INTEL_MKL
 
-#include <unistd.h>
 #include <cstdlib>
 #include <string>
 #include "tensorflow/core/common_runtime/bfc_allocator.h"
@@ -222,7 +222,7 @@ Status MklToTfConversionPass::InsertInputConversionNode(
       BaseType(n->input_type(0)));
 
   // Check ordering of edges
-  for (uint i = 0; i < 4; i++) {
+  for (uint32 i = 0; i < 4; i++) {
     CHECK_EQ((edges[i]->dst_input() == i), true);
   }
 
@@ -29,7 +29,6 @@ limitations under the License.
 #include <vector>
 #include "mkl_cblas.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -41,9 +40,6 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 
-#define MKL_Complex8 tensorflow::complex64
-#define MKL_Complex16 tensorflow::complex128
-
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -180,16 +176,16 @@ class BatchMatMulMkl : public OpKernel {
   void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                          const bool TransB, const MKL_INT *M_Array,
                          const MKL_INT *N_Array, const MKL_INT *K_Array,
-                         const MKL_Complex8 **A_Array, const MKL_INT *lda_Array,
-                         const MKL_Complex8 **B_Array, const MKL_INT *ldb_Array,
-                         MKL_Complex8 **C_Array, const MKL_INT *ldc_Array,
+                         const complex64 **A_Array, const MKL_INT *lda_Array,
+                         const complex64 **B_Array, const MKL_INT *ldb_Array,
+                         complex64 **C_Array, const MKL_INT *ldc_Array,
                          const MKL_INT group_count, const MKL_INT *group_size) {
     std::vector<CBLAS_TRANSPOSE> TransA_array(
         group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
     std::vector<CBLAS_TRANSPOSE> TransB_array(
         group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
-    std::vector<MKL_Complex8> alpha_Array(group_size[0], {1.0f, 0.0f});
-    std::vector<MKL_Complex8> beta_Array(group_size[0], {0.0f, 0.0f});
+    std::vector<complex64> alpha_Array(group_size[0], {1.0f, 0.0f});
+    std::vector<complex64> beta_Array(group_size[0], {0.0f, 0.0f});
     cblas_cgemm_batch(
         Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array,
         static_cast<const void *>(&alpha_Array[0]),
@@ -202,18 +198,18 @@ class BatchMatMulMkl : public OpKernel {
   void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                          const bool TransB, const MKL_INT *M_Array,
                          const MKL_INT *N_Array, const MKL_INT *K_Array,
-                         const MKL_Complex16 **A_Array,
+                         const complex128 **A_Array,
                          const MKL_INT *lda_Array,
-                         const MKL_Complex16 **B_Array,
-                         const MKL_INT *ldb_Array, MKL_Complex16 **C_Array,
+                         const complex128 **B_Array,
+                         const MKL_INT *ldb_Array, complex128 **C_Array,
                          const MKL_INT *ldc_Array, const MKL_INT group_count,
                          const MKL_INT *group_size) {
     std::vector<CBLAS_TRANSPOSE> TransA_array(
         group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
     std::vector<CBLAS_TRANSPOSE> TransB_array(
         group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
-    std::vector<MKL_Complex16> alpha_Array(group_size[0], {1.0f, 0.0f});
-    std::vector<MKL_Complex16> beta_Array(group_size[0], {0.0f, 0.0f});
+    std::vector<complex128> alpha_Array(group_size[0], {1.0f, 0.0f});
+    std::vector<complex128> beta_Array(group_size[0], {0.0f, 0.0f});
     cblas_zgemm_batch(
         Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array,
         static_cast<const void *>(&alpha_Array[0]),
@@ -145,8 +145,8 @@ class MklInputConversionOp : public OpKernel {
     const MklShape* mkl_shape;
     const Tensor* tf_tensor;
     MklShape* tf_mkl_shape;
-    uint mkl_tensor_index;
-    uint tf_tensor_index;
+    uint32 mkl_tensor_index;
+    uint32 tf_tensor_index;
     if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
       mkl_tensor = &input_tensor_0;
       mkl_shape = &input_shape_0;
@@ -170,32 +170,32 @@ class MklMatMulOp : public OpKernel {
   // Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors.
   // For detailed info about parameters, look at FP32 function description.
   void MklBlasGemm(bool transa, bool transb, const int m, const int n,
-                   const int k, const std::complex<float>* a, const int lda,
-                   const std::complex<float>* b, const int ldb,
-                   std::complex<float>* c, int const ldc) {
+                   const int k, const complex64* a, const int lda,
+                   const complex64* b, const int ldb,
+                   complex64* c, int const ldc) {
     const MKL_Complex8 alpha = {1.0f, 0.0f};
     const MKL_Complex8 beta = {0.0f, 0.0f};
     cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
-                transb ? CblasTrans : CblasNoTrans, m, n, k,
-                static_cast<const void*>(&alpha), static_cast<const void*>(a),
-                lda, static_cast<const void*>(b), ldb,
-                static_cast<const void*>(&beta), static_cast<void*>(c), ldc);
+                transb ? CblasTrans : CblasNoTrans,
+                m, n, k, &alpha, reinterpret_cast<const MKL_Complex8*>(a), lda,
+                reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta,
+                reinterpret_cast<MKL_Complex8*>(c), ldc);
   }
 
   // Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
   // tensors. For detailed info about parameters, look at FP32 function
   // description.
   void MklBlasGemm(bool transa, bool transb, const int m, const int n,
-                   const int k, const std::complex<double>* a, const int lda,
-                   const std::complex<double>* b, const int ldb,
-                   std::complex<double>* c, const int ldc) {
+                   const int k, const complex128* a, const int lda,
+                   const complex128* b, const int ldb,
+                   complex128* c, const int ldc) {
     const MKL_Complex16 alpha = {1.0, 0.0};
     const MKL_Complex16 beta = {0.0, 0.0};
     cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
-                transb ? CblasTrans : CblasNoTrans, m, n, k,
-                static_cast<const void*>(&alpha), static_cast<const void*>(a),
-                lda, static_cast<const void*>(b), ldb,
-                static_cast<const void*>(&beta), static_cast<void*>(c), ldc);
+                transb ? CblasTrans : CblasNoTrans,
+                m, n, k, &alpha, reinterpret_cast<const MKL_Complex16*>(a), lda,
+                reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
+                reinterpret_cast<MKL_Complex16*>(c), ldc);
   }
 };
 
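Note: the cblas_cgemm/cblas_zgemm calls above hand tensorflow::complex64/complex128 (std::complex<float>/<double>) buffers to MKL through reinterpret_cast, which relies on std::complex sharing its in-memory layout (two packed floats or doubles) with MKL's complex structs. Below is a minimal standalone sketch of that assumption, not part of this PR; MklComplex8 is a hypothetical stand-in for the real MKL_Complex8 from mkl_types.h so the snippet compiles without MKL.

// Sketch only: illustrates the layout assumption behind the casts above.
#include <complex>
#include <iostream>

// Hypothetical stand-in for MKL's default MKL_Complex8 (two packed floats).
struct MklComplex8 { float real; float imag; };

int main() {
  // std::complex<float> is specified to be laid out like float[2],
  // so its size must match the two-float MKL-style struct.
  static_assert(sizeof(std::complex<float>) == sizeof(MklComplex8),
                "layout assumption behind the reinterpret_cast");
  std::complex<float> value(1.5f, -2.0f);
  // Same kind of pointer reinterpretation the kernels perform before
  // calling into cblas_cgemm.
  const MklComplex8* view = reinterpret_cast<const MklComplex8*>(&value);
  std::cout << view->real << " " << view->imag << "\n";  // prints: 1.5 -2
  return 0;
}

If MKL were configured with a different complex representation, this cast pattern would not hold; the static_assert makes the assumption explicit.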
@@ -128,7 +128,7 @@ class MklToTfOp : public OpKernel {
 #else
   static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context,
                              string data_format_str, DataType op_data_type,
-                             bool has_avx512f, uint input_number) {
+                             bool has_avx512f, uint32 input_number) {
     // Check that input tensor is in MKL format.
     const Tensor& input_tensor = MklGetInput(context, input_number);
     MklShape input_shape;
@@ -18,9 +18,6 @@ limitations under the License.
 #ifdef INTEL_MKL
 #define EIGEN_USE_THREADS
 
-#include "tensorflow/core/framework/numeric_types.h"
-#define MKL_Complex8 tensorflow::complex64
-#define MKL_Complex16 tensorflow::complex128
 #include "mkl_trans.h"
 #include "tensorflow/core/kernels/transpose_functor.h"
 #include "tensorflow/core/kernels/transpose_op.h"
@@ -62,10 +59,31 @@ Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out);
 
 INSTANTIATE(float, s)
 INSTANTIATE(double, d)
-INSTANTIATE(complex64, c)
-INSTANTIATE(complex128, z)
 #undef INSTANTIATE
 
+template <>
+Status MKLTranspose2D<complex64>(const char trans, const Tensor& in, Tensor* out) {
+  const MKL_Complex8 alpha = { 1.0f, 0.0f };
+  mkl_comatcopy('R', trans, in.dim_size(0), in.dim_size(1), alpha,
+                reinterpret_cast<const MKL_Complex8*>(in.flat<complex64>().data()),
+                in.dim_size(1),
+                reinterpret_cast<MKL_Complex8*>(const_cast<complex64*>(out->flat<complex64>().data())),
+                in.dim_size(0));
+  return Status::OK();
+}
+
+template <>
+Status MKLTranspose2D<complex128>(const char trans, const Tensor& in, Tensor* out) {
+  const MKL_Complex16 alpha = { 1.0, 0.0 };
+  mkl_zomatcopy('R', trans, in.dim_size(0), in.dim_size(1), alpha,
+                reinterpret_cast<const MKL_Complex16*>(in.flat<complex128>().data()),
+                in.dim_size(1),
+                reinterpret_cast<MKL_Complex16*>(const_cast<complex128*>(out->flat<complex128>().data())),
+                in.dim_size(0));
+  return Status::OK();
+}
+
 static const char kMKLTranspose = 'T';
 static const char kMKLConjugateTranspose = 'C';
 
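Note: the new MKLTranspose2D specializations above delegate to mkl_comatcopy/mkl_zomatcopy. As a rough illustration of what that call computes for trans == 'T' (a scaled, row-major, out-of-place transpose), the plain-loop sketch below compiles and runs without MKL; it is not part of this PR, and the 'C' case would additionally conjugate each element.

// Sketch only: the effect of a row-major out-of-place transpose,
// out[j][i] = alpha * in[i][j], written with ordinary loops.
#include <complex>
#include <cstddef>
#include <iostream>
#include <vector>

void RowMajorTranspose(const std::complex<float>* in, std::size_t rows,
                       std::size_t cols, std::complex<float>* out,
                       std::complex<float> alpha = {1.0f, 0.0f}) {
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t j = 0; j < cols; ++j)
      out[j * rows + i] = alpha * in[i * cols + j];
}

int main() {
  // 2x3 input: [1 2 3; 4 5 6]
  std::vector<std::complex<float>> in = {{1, 0}, {2, 0}, {3, 0},
                                         {4, 0}, {5, 0}, {6, 0}};
  std::vector<std::complex<float>> out(6);  // 3x2 output
  RowMajorTranspose(in.data(), 2, 3, out.data());
  for (const auto& v : out) std::cout << v.real() << " ";  // 1 4 2 5 3 6
  std::cout << "\n";
  return 0;
}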
@@ -358,11 +358,11 @@ class MklSliceOp : public OpKernel {
       /* data format = NCHW */
 
 #pragma omp parallel for
-      for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
+      for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
         T* ip = in_buf + (d0 * in_strides[0]);
         T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
 #pragma omp parallel for
-        for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
+        for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
           T* ip1 = ip + (d1 * in_strides[1]);
           T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
           // For NCHW, H and W will be contiguous. So we can copy
@@ -376,15 +376,15 @@ class MklSliceOp : public OpKernel {
       /* data_format = NHWC */
 
 #pragma omp parallel for
-      for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
+      for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
         T* ip = in_buf + (d0 * in_strides[0]);
         T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
 #pragma omp parallel for
-        for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
+        for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
           T* ip1 = ip + (d1 * in_strides[1]);
           T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
 #pragma omp parallel for
-          for (size_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) {
+          for (ssize_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) {
             T* ip2 = ip1 + (d2 * in_strides[2]);
             T* ip3 = ip2 + begin[3];
             T* op2 = op1 + ((d2 - begin[2]) * out_strides[2]);
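Note: the size_t to ssize_t change in the slice loops above is presumably motivated by MSVC, whose OpenMP 2.0 implementation requires a signed integral loop variable in an "omp parallel for"; an unsigned size_t index fails to compile there. A small self-contained sketch of that constraint, not part of this PR, using std::ptrdiff_t as a portable signed index:

// Sketch only: a parallel-for loop with a signed index, as OpenMP 2.0
// (the level MSVC implements) requires. Compiles with or without /openmp.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> data(16, 1);
  long long sum = 0;
#pragma omp parallel for reduction(+ : sum)
  for (std::ptrdiff_t i = 0; i < static_cast<std::ptrdiff_t>(data.size()); ++i) {
    sum += data[i];  // signed index keeps the loop valid under OpenMP 2.0
  }
  std::cout << sum << "\n";  // prints: 16
  return 0;
}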
@@ -27,9 +27,6 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty();
 
 #include <stdlib.h>
 #include <cstring>
-#if 0
-#include <omp.h>
-#endif
 
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
@@ -360,7 +357,6 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
   l_tick6 = libxsmm_timer_tick();
 #endif
 
-#if 1
   BlockingCounter counter(num_threads);
 
   for (int i = 0; i < num_threads; ++i) {
@@ -371,14 +367,6 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
     });
   }
   counter.Wait();
-#else
-#pragma omp parallel
-  {
-    chk_libxsmm_err(
-        libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, omp_get_thread_num()),
-        "Worker");
-  }
-#endif
 
 #if defined(LIBXSMM_DETAILED_TIMING)
   l_tick7 = libxsmm_timer_tick();
@@ -1112,9 +1112,9 @@ inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
 // Forward the MKL shape ONLY (used in elementwise and other ops where
 // we call the eigen implementation and MKL shape is not used)
 inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
-                                      uint idx_data_in, uint idx_data_out) {
-  uint idx_meta_in = GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
-  uint idx_meta_out =
+                                      uint32 idx_data_in, uint32_t idx_data_out) {
+  uint32 idx_meta_in = GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
+  uint32 idx_meta_out =
       GetTensorMetaDataIndex(idx_data_out, context->num_outputs());
 
   if (IsRefType(context->input_dtype(idx_data_in))) {
@@ -1126,7 +1126,7 @@ inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
 
 // Set a dummy MKL shape (called when the output is in TF format)
 inline void SetDummyMklShapeOutput(OpKernelContext* context,
-                                   uint idx_data_out) {
+                                   uint32 idx_data_out) {
   MklShape mkl_shape_output;
   mkl_shape_output.SetMklTensor(false);
   AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);