fix MKL_Complex cast problem
error : argument of type "" is incompatible with parameter of type ""
This commit is contained in:
parent
023d47d0f1
commit
24e343b18c
@ -29,7 +29,6 @@ limitations under the License.
|
||||
#include <vector>
|
||||
#include "mkl_cblas.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
@ -41,9 +40,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
#define MKL_Complex8 tensorflow::complex64
|
||||
#define MKL_Complex16 tensorflow::complex128
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
@ -180,16 +176,16 @@ class BatchMatMulMkl : public OpKernel {
|
||||
void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
|
||||
const bool TransB, const MKL_INT *M_Array,
|
||||
const MKL_INT *N_Array, const MKL_INT *K_Array,
|
||||
const MKL_Complex8 **A_Array, const MKL_INT *lda_Array,
|
||||
const MKL_Complex8 **B_Array, const MKL_INT *ldb_Array,
|
||||
MKL_Complex8 **C_Array, const MKL_INT *ldc_Array,
|
||||
const complex64 **A_Array, const MKL_INT *lda_Array,
|
||||
const complex64 **B_Array, const MKL_INT *ldb_Array,
|
||||
complex64 **C_Array, const MKL_INT *ldc_Array,
|
||||
const MKL_INT group_count, const MKL_INT *group_size) {
|
||||
std::vector<CBLAS_TRANSPOSE> TransA_array(
|
||||
group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
|
||||
std::vector<CBLAS_TRANSPOSE> TransB_array(
|
||||
group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
|
||||
std::vector<MKL_Complex8> alpha_Array(group_size[0], {1.0f, 0.0f});
|
||||
std::vector<MKL_Complex8> beta_Array(group_size[0], {0.0f, 0.0f});
|
||||
std::vector<complex64> alpha_Array(group_size[0], {1.0f, 0.0f});
|
||||
std::vector<complex64> beta_Array(group_size[0], {0.0f, 0.0f});
|
||||
cblas_cgemm_batch(
|
||||
Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array,
|
||||
static_cast<const void *>(&alpha_Array[0]),
|
||||
@ -202,18 +198,18 @@ class BatchMatMulMkl : public OpKernel {
|
||||
void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
|
||||
const bool TransB, const MKL_INT *M_Array,
|
||||
const MKL_INT *N_Array, const MKL_INT *K_Array,
|
||||
const MKL_Complex16 **A_Array,
|
||||
const complex128 **A_Array,
|
||||
const MKL_INT *lda_Array,
|
||||
const MKL_Complex16 **B_Array,
|
||||
const MKL_INT *ldb_Array, MKL_Complex16 **C_Array,
|
||||
const complex128 **B_Array,
|
||||
const MKL_INT *ldb_Array, complex128 **C_Array,
|
||||
const MKL_INT *ldc_Array, const MKL_INT group_count,
|
||||
const MKL_INT *group_size) {
|
||||
std::vector<CBLAS_TRANSPOSE> TransA_array(
|
||||
group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
|
||||
std::vector<CBLAS_TRANSPOSE> TransB_array(
|
||||
group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
|
||||
std::vector<MKL_Complex16> alpha_Array(group_size[0], {1.0f, 0.0f});
|
||||
std::vector<MKL_Complex16> beta_Array(group_size[0], {0.0f, 0.0f});
|
||||
std::vector<complex128> alpha_Array(group_size[0], {1.0f, 0.0f});
|
||||
std::vector<complex128> beta_Array(group_size[0], {0.0f, 0.0f});
|
||||
cblas_zgemm_batch(
|
||||
Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array,
|
||||
static_cast<const void *>(&alpha_Array[0]),
|
||||
|
@ -170,32 +170,32 @@ class MklMatMulOp : public OpKernel {
|
||||
// Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors.
|
||||
// For detailed info about parameters, look at FP32 function description.
|
||||
void MklBlasGemm(bool transa, bool transb, const int m, const int n,
|
||||
const int k, const std::complex<float>* a, const int lda,
|
||||
const std::complex<float>* b, const int ldb,
|
||||
std::complex<float>* c, int const ldc) {
|
||||
const int k, const complex64* a, const int lda,
|
||||
const complex64* b, const int ldb,
|
||||
complex64* c, int const ldc) {
|
||||
const MKL_Complex8 alpha = {1.0f, 0.0f};
|
||||
const MKL_Complex8 beta = {0.0f, 0.0f};
|
||||
cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
|
||||
transb ? CblasTrans : CblasNoTrans, m, n, k,
|
||||
static_cast<const void*>(&alpha), static_cast<const void*>(a),
|
||||
lda, static_cast<const void*>(b), ldb,
|
||||
static_cast<const void*>(&beta), static_cast<void*>(c), ldc);
|
||||
transb ? CblasTrans : CblasNoTrans,
|
||||
m, n, k, &alpha, reinterpret_cast<const MKL_Complex8*>(a), lda,
|
||||
reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta,
|
||||
reinterpret_cast<MKL_Complex8*>(c), ldc);
|
||||
}
|
||||
|
||||
// Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
|
||||
// tensors. For detailed info about parameters, look at FP32 function
|
||||
// description.
|
||||
void MklBlasGemm(bool transa, bool transb, const int m, const int n,
|
||||
const int k, const std::complex<double>* a, const int lda,
|
||||
const std::complex<double>* b, const int ldb,
|
||||
std::complex<double>* c, const int ldc) {
|
||||
const int k, const complex128* a, const int lda,
|
||||
const complex128* b, const int ldb,
|
||||
complex128* c, const int ldc) {
|
||||
const MKL_Complex16 alpha = {1.0, 0.0};
|
||||
const MKL_Complex16 beta = {0.0, 0.0};
|
||||
cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
|
||||
transb ? CblasTrans : CblasNoTrans, m, n, k,
|
||||
static_cast<const void*>(&alpha), static_cast<const void*>(a),
|
||||
lda, static_cast<const void*>(b), ldb,
|
||||
static_cast<const void*>(&beta), static_cast<void*>(c), ldc);
|
||||
transb ? CblasTrans : CblasNoTrans,
|
||||
m, n, k, &alpha, reinterpret_cast<const MKL_Complex16*>(a), lda,
|
||||
reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
|
||||
reinterpret_cast<MKL_Complex16*>(c), ldc);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -18,9 +18,6 @@ limitations under the License.
|
||||
#ifdef INTEL_MKL
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#define MKL_Complex8 tensorflow::complex64
|
||||
#define MKL_Complex16 tensorflow::complex128
|
||||
#include "mkl_trans.h"
|
||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||
#include "tensorflow/core/kernels/transpose_op.h"
|
||||
@ -62,10 +59,31 @@ Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out);
|
||||
|
||||
INSTANTIATE(float, s)
|
||||
INSTANTIATE(double, d)
|
||||
INSTANTIATE(complex64, c)
|
||||
INSTANTIATE(complex128, z)
|
||||
|
||||
#undef INSTANTIATE
|
||||
|
||||
template <>
|
||||
Status MKLTranspose2D<complex64>(const char trans, const Tensor& in, Tensor* out) {
|
||||
const MKL_Complex8 alpha = { 1.0f, 0.0f };
|
||||
mkl_comatcopy('R', trans, in.dim_size(0), in.dim_size(1), alpha,
|
||||
reinterpret_cast<const MKL_Complex8*>(in.flat<complex64>().data()),
|
||||
in.dim_size(1),
|
||||
reinterpret_cast<MKL_Complex8*>(const_cast<complex64*>(out->flat<complex64>().data())),
|
||||
in.dim_size(0));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status MKLTranspose2D<complex128>(const char trans, const Tensor& in, Tensor* out) {
|
||||
const MKL_Complex16 alpha = { 1.0, 0.0 };
|
||||
mkl_zomatcopy('R', trans, in.dim_size(0), in.dim_size(1), alpha,
|
||||
reinterpret_cast<const MKL_Complex16*>(in.flat<complex128>().data()),
|
||||
in.dim_size(1),
|
||||
reinterpret_cast<MKL_Complex16*>(const_cast<complex128*>(out->flat<complex128>().data())),
|
||||
in.dim_size(0));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static const char kMKLTranspose = 'T';
|
||||
static const char kMKLConjugateTranspose = 'C';
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user