fix MKL_Complex cast problem

error : argument of type "" is incompatible with parameter of type ""
This commit is contained in:
fo40225 2018-02-16 00:19:29 +08:00
parent 023d47d0f1
commit 24e343b18c
3 changed files with 47 additions and 33 deletions

View File

@ -29,7 +29,6 @@ limitations under the License.
#include <vector> #include <vector>
#include "mkl_cblas.h" #include "mkl_cblas.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
@ -41,9 +40,6 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#define MKL_Complex8 tensorflow::complex64
#define MKL_Complex16 tensorflow::complex128
namespace tensorflow { namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::ThreadPoolDevice CPUDevice;
@ -180,16 +176,16 @@ class BatchMatMulMkl : public OpKernel {
void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA, void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
const bool TransB, const MKL_INT *M_Array, const bool TransB, const MKL_INT *M_Array,
const MKL_INT *N_Array, const MKL_INT *K_Array, const MKL_INT *N_Array, const MKL_INT *K_Array,
const MKL_Complex8 **A_Array, const MKL_INT *lda_Array, const complex64 **A_Array, const MKL_INT *lda_Array,
const MKL_Complex8 **B_Array, const MKL_INT *ldb_Array, const complex64 **B_Array, const MKL_INT *ldb_Array,
MKL_Complex8 **C_Array, const MKL_INT *ldc_Array, complex64 **C_Array, const MKL_INT *ldc_Array,
const MKL_INT group_count, const MKL_INT *group_size) { const MKL_INT group_count, const MKL_INT *group_size) {
std::vector<CBLAS_TRANSPOSE> TransA_array( std::vector<CBLAS_TRANSPOSE> TransA_array(
group_size[0], TransA ? CblasConjTrans : CblasNoTrans); group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
std::vector<CBLAS_TRANSPOSE> TransB_array( std::vector<CBLAS_TRANSPOSE> TransB_array(
group_size[0], TransB ? CblasConjTrans : CblasNoTrans); group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
std::vector<MKL_Complex8> alpha_Array(group_size[0], {1.0f, 0.0f}); std::vector<complex64> alpha_Array(group_size[0], {1.0f, 0.0f});
std::vector<MKL_Complex8> beta_Array(group_size[0], {0.0f, 0.0f}); std::vector<complex64> beta_Array(group_size[0], {0.0f, 0.0f});
cblas_cgemm_batch( cblas_cgemm_batch(
Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array, Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array,
static_cast<const void *>(&alpha_Array[0]), static_cast<const void *>(&alpha_Array[0]),
@ -202,18 +198,18 @@ class BatchMatMulMkl : public OpKernel {
void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA, void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
const bool TransB, const MKL_INT *M_Array, const bool TransB, const MKL_INT *M_Array,
const MKL_INT *N_Array, const MKL_INT *K_Array, const MKL_INT *N_Array, const MKL_INT *K_Array,
const MKL_Complex16 **A_Array, const complex128 **A_Array,
const MKL_INT *lda_Array, const MKL_INT *lda_Array,
const MKL_Complex16 **B_Array, const complex128 **B_Array,
const MKL_INT *ldb_Array, MKL_Complex16 **C_Array, const MKL_INT *ldb_Array, complex128 **C_Array,
const MKL_INT *ldc_Array, const MKL_INT group_count, const MKL_INT *ldc_Array, const MKL_INT group_count,
const MKL_INT *group_size) { const MKL_INT *group_size) {
std::vector<CBLAS_TRANSPOSE> TransA_array( std::vector<CBLAS_TRANSPOSE> TransA_array(
group_size[0], TransA ? CblasConjTrans : CblasNoTrans); group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
std::vector<CBLAS_TRANSPOSE> TransB_array( std::vector<CBLAS_TRANSPOSE> TransB_array(
group_size[0], TransB ? CblasConjTrans : CblasNoTrans); group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
std::vector<MKL_Complex16> alpha_Array(group_size[0], {1.0f, 0.0f}); std::vector<complex128> alpha_Array(group_size[0], {1.0f, 0.0f});
std::vector<MKL_Complex16> beta_Array(group_size[0], {0.0f, 0.0f}); std::vector<complex128> beta_Array(group_size[0], {0.0f, 0.0f});
cblas_zgemm_batch( cblas_zgemm_batch(
Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array, Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array,
static_cast<const void *>(&alpha_Array[0]), static_cast<const void *>(&alpha_Array[0]),

View File

@ -170,32 +170,32 @@ class MklMatMulOp : public OpKernel {
// Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors. // Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors.
// For detailed info about parameters, look at FP32 function description. // For detailed info about parameters, look at FP32 function description.
void MklBlasGemm(bool transa, bool transb, const int m, const int n, void MklBlasGemm(bool transa, bool transb, const int m, const int n,
const int k, const std::complex<float>* a, const int lda, const int k, const complex64* a, const int lda,
const std::complex<float>* b, const int ldb, const complex64* b, const int ldb,
std::complex<float>* c, int const ldc) { complex64* c, int const ldc) {
const MKL_Complex8 alpha = {1.0f, 0.0f}; const MKL_Complex8 alpha = {1.0f, 0.0f};
const MKL_Complex8 beta = {0.0f, 0.0f}; const MKL_Complex8 beta = {0.0f, 0.0f};
cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
transb ? CblasTrans : CblasNoTrans, m, n, k, transb ? CblasTrans : CblasNoTrans,
static_cast<const void*>(&alpha), static_cast<const void*>(a), m, n, k, &alpha, reinterpret_cast<const MKL_Complex8*>(a), lda,
lda, static_cast<const void*>(b), ldb, reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta,
static_cast<const void*>(&beta), static_cast<void*>(c), ldc); reinterpret_cast<MKL_Complex8*>(c), ldc);
} }
// Matrix-Matrix Multiplication with Complex128 (std::complex<double>) // Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
// tensors. For detailed info about parameters, look at FP32 function // tensors. For detailed info about parameters, look at FP32 function
// description. // description.
void MklBlasGemm(bool transa, bool transb, const int m, const int n, void MklBlasGemm(bool transa, bool transb, const int m, const int n,
const int k, const std::complex<double>* a, const int lda, const int k, const complex128* a, const int lda,
const std::complex<double>* b, const int ldb, const complex128* b, const int ldb,
std::complex<double>* c, const int ldc) { complex128* c, const int ldc) {
const MKL_Complex16 alpha = {1.0, 0.0}; const MKL_Complex16 alpha = {1.0, 0.0};
const MKL_Complex16 beta = {0.0, 0.0}; const MKL_Complex16 beta = {0.0, 0.0};
cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
transb ? CblasTrans : CblasNoTrans, m, n, k, transb ? CblasTrans : CblasNoTrans,
static_cast<const void*>(&alpha), static_cast<const void*>(a), m, n, k, &alpha, reinterpret_cast<const MKL_Complex16*>(a), lda,
lda, static_cast<const void*>(b), ldb, reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
static_cast<const void*>(&beta), static_cast<void*>(c), ldc); reinterpret_cast<MKL_Complex16*>(c), ldc);
} }
}; };

View File

@ -18,9 +18,6 @@ limitations under the License.
#ifdef INTEL_MKL #ifdef INTEL_MKL
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#include "tensorflow/core/framework/numeric_types.h"
#define MKL_Complex8 tensorflow::complex64
#define MKL_Complex16 tensorflow::complex128
#include "mkl_trans.h" #include "mkl_trans.h"
#include "tensorflow/core/kernels/transpose_functor.h" #include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/kernels/transpose_op.h" #include "tensorflow/core/kernels/transpose_op.h"
@ -62,10 +59,31 @@ Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out);
INSTANTIATE(float, s) INSTANTIATE(float, s)
INSTANTIATE(double, d) INSTANTIATE(double, d)
INSTANTIATE(complex64, c)
INSTANTIATE(complex128, z)
#undef INSTANTIATE #undef INSTANTIATE
// Specialization of MKLTranspose2D for complex64 tensors.
// Forwards the flattened buffers of `in` and `out` to mkl_comatcopy, passing
// `trans` through unchanged, and always returns Status::OK().
template <>
Status MKLTranspose2D<complex64>(const char trans, const Tensor& in,
                                 Tensor* out) {
  const auto rows = in.dim_size(0);
  const auto cols = in.dim_size(1);
  // Scale factor 1 + 0i handed to MKL together with the source matrix.
  const MKL_Complex8 alpha = {1.0f, 0.0f};
  // MKL's interface is declared in terms of its own complex struct, so the
  // complex64 buffers are reinterpreted in place rather than converted.
  // NOTE(review): assumes complex64 and MKL_Complex8 share a layout -- confirm.
  const MKL_Complex8* src =
      reinterpret_cast<const MKL_Complex8*>(in.flat<complex64>().data());
  MKL_Complex8* dst = reinterpret_cast<MKL_Complex8*>(
      const_cast<complex64*>(out->flat<complex64>().data()));
  // `cols` is passed as the source leading dimension and `rows` as the
  // destination leading dimension (lda / ldb in MKL's signature).
  mkl_comatcopy('R', trans, rows, cols, alpha, src, cols, dst, rows);
  return Status::OK();
}
// Specialization of MKLTranspose2D for complex128 tensors.
// Forwards the flattened buffers of `in` and `out` to mkl_zomatcopy, passing
// `trans` through unchanged, and always returns Status::OK().
template <>
Status MKLTranspose2D<complex128>(const char trans, const Tensor& in,
                                  Tensor* out) {
  const auto rows = in.dim_size(0);
  const auto cols = in.dim_size(1);
  // Scale factor 1 + 0i handed to MKL together with the source matrix.
  const MKL_Complex16 alpha = {1.0, 0.0};
  // MKL's interface is declared in terms of its own complex struct, so the
  // complex128 buffers are reinterpreted in place rather than converted.
  // NOTE(review): assumes complex128 and MKL_Complex16 share a layout --
  // confirm.
  const MKL_Complex16* src =
      reinterpret_cast<const MKL_Complex16*>(in.flat<complex128>().data());
  MKL_Complex16* dst = reinterpret_cast<MKL_Complex16*>(
      const_cast<complex128*>(out->flat<complex128>().data()));
  // `cols` is passed as the source leading dimension and `rows` as the
  // destination leading dimension (lda / ldb in MKL's signature).
  mkl_zomatcopy('R', trans, rows, cols, alpha, src, cols, dst, rows);
  return Status::OK();
}
static const char kMKLTranspose = 'T'; static const char kMKLTranspose = 'T';
static const char kMKLConjugateTranspose = 'C'; static const char kMKLConjugateTranspose = 'C';