diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 2a67c039ac0..71e0de97246 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -21,7 +21,6 @@ limitations under the License.
 
 #ifdef INTEL_MKL
 
-#include
 #include
 #include
 #include "tensorflow/core/common_runtime/bfc_allocator.h"
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index 5343e6802d1..e9ced4d2b6b 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -222,7 +222,7 @@ Status MklToTfConversionPass::InsertInputConversionNode(
                      BaseType(n->input_type(0)));
 
   // Check ordering of edges
-  for (uint i = 0; i < 4; i++) {
+  for (uint32 i = 0; i < 4; i++) {
     CHECK_EQ((edges[i]->dst_input() == i), true);
   }
 
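The uint -> uint32 changes in this hunk (and the matching ones further down in mkl_input_conversion_op.cc, mkl_tfconv_op.h, and mkl_util.h) replace a non-standard typedef with a fixed-width alias. A minimal standalone sketch of the distinction, not part of the patch; the local `uint32` alias below only stands in for the one tensorflow/core/platform/types.h provides:

    // Sketch only: `uint` is a POSIX/glibc typedef that MSVC and other
    // non-POSIX toolchains do not define, while a fixed-width 32-bit alias
    // is portable everywhere.
    #include <cstdint>

    typedef std::uint32_t uint32;  // stand-in for tensorflow::uint32

    int main() {
      for (uint32 i = 0; i < 4; i++) {
        // same loop shape as the edge-ordering check in
        // MklToTfConversionPass::InsertInputConversionNode
      }
      return 0;
    }
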
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index d9713075be6..c48a2038f92 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -29,7 +29,6 @@ limitations under the License.
 #include
 #include "mkl_cblas.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -41,9 +40,6 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 
-#define MKL_Complex8 tensorflow::complex64
-#define MKL_Complex16 tensorflow::complex128
-
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -180,16 +176,16 @@ class BatchMatMulMkl : public OpKernel {
   void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                          const bool TransB, const MKL_INT *M_Array,
                          const MKL_INT *N_Array, const MKL_INT *K_Array,
-                         const MKL_Complex8 **A_Array, const MKL_INT *lda_Array,
-                         const MKL_Complex8 **B_Array, const MKL_INT *ldb_Array,
-                         MKL_Complex8 **C_Array, const MKL_INT *ldc_Array,
+                         const complex64 **A_Array, const MKL_INT *lda_Array,
+                         const complex64 **B_Array, const MKL_INT *ldb_Array,
+                         complex64 **C_Array, const MKL_INT *ldc_Array,
                          const MKL_INT group_count, const MKL_INT *group_size) {
     std::vector<CBLAS_TRANSPOSE> TransA_array(
         group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
     std::vector<CBLAS_TRANSPOSE> TransB_array(
         group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
-    std::vector<MKL_Complex8> alpha_Array(group_size[0], {1.0f, 0.0f});
-    std::vector<MKL_Complex8> beta_Array(group_size[0], {0.0f, 0.0f});
+    std::vector<complex64> alpha_Array(group_size[0], {1.0f, 0.0f});
+    std::vector<complex64> beta_Array(group_size[0], {0.0f, 0.0f});
     cblas_cgemm_batch(
         Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array,
         static_cast<const void *>(&alpha_Array[0]),
@@ -202,18 +198,18 @@ class BatchMatMulMkl : public OpKernel {
   void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                          const bool TransB, const MKL_INT *M_Array,
                          const MKL_INT *N_Array, const MKL_INT *K_Array,
-                         const MKL_Complex16 **A_Array,
+                         const complex128 **A_Array,
                          const MKL_INT *lda_Array,
-                         const MKL_Complex16 **B_Array,
-                         const MKL_INT *ldb_Array, MKL_Complex16 **C_Array,
+                         const complex128 **B_Array,
+                         const MKL_INT *ldb_Array, complex128 **C_Array,
                          const MKL_INT *ldc_Array, const MKL_INT group_count,
                          const MKL_INT *group_size) {
     std::vector<CBLAS_TRANSPOSE> TransA_array(
         group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
     std::vector<CBLAS_TRANSPOSE> TransB_array(
         group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
-    std::vector<MKL_Complex16> alpha_Array(group_size[0], {1.0f, 0.0f});
-    std::vector<MKL_Complex16> beta_Array(group_size[0], {0.0f, 0.0f});
+    std::vector<complex128> alpha_Array(group_size[0], {1.0f, 0.0f});
+    std::vector<complex128> beta_Array(group_size[0], {0.0f, 0.0f});
     cblas_zgemm_batch(
         Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array,
         static_cast<const void *>(&alpha_Array[0]),
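For context on the hunks above: cblas_?gemm_batch runs many independent GEMMs in one call, grouped by shared shape, transpose flags, and alpha/beta. A minimal standalone sketch of a single-group call on complex64-style data, not part of the patch; it assumes an MKL development environment (mkl_cblas.h and the usual MKL link line), and uses std::complex<float> as a stand-in for tensorflow::complex64:

    // Sketch only: two independent 2x2 complex products in one grouped call.
    #include <complex>
    #include <cstdio>
    #include <vector>
    #include "mkl_cblas.h"

    int main() {
      typedef std::complex<float> c64;
      const MKL_INT batch = 2, n = 2;
      std::vector<c64> a(batch * n * n, c64(1.0f, 0.0f));
      std::vector<c64> b(batch * n * n, c64(0.0f, 1.0f));
      std::vector<c64> c(batch * n * n, c64(0.0f, 0.0f));

      // One group: every matrix in the batch shares its parameters.
      std::vector<const c64*> a_ptrs, b_ptrs;
      std::vector<c64*> c_ptrs;
      for (MKL_INT i = 0; i < batch; ++i) {
        a_ptrs.push_back(&a[i * n * n]);
        b_ptrs.push_back(&b[i * n * n]);
        c_ptrs.push_back(&c[i * n * n]);
      }
      CBLAS_TRANSPOSE trans[1] = {CblasNoTrans};
      MKL_INT dims[1] = {n}, ld[1] = {n}, group_size[1] = {batch};
      c64 alpha(1.0f, 0.0f), beta(0.0f, 0.0f);

      cblas_cgemm_batch(CblasRowMajor, trans, trans, dims, dims, dims,
                        static_cast<const void*>(&alpha),
                        reinterpret_cast<const void**>(a_ptrs.data()), ld,
                        reinterpret_cast<const void**>(b_ptrs.data()), ld,
                        static_cast<const void*>(&beta),
                        reinterpret_cast<void**>(c_ptrs.data()), ld,
                        1 /* group_count */, group_size);

      std::printf("c[0] = (%f, %f)\n", c[0].real(), c[0].imag());
      return 0;
    }
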
diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc
index 5a8799ae93c..e9a2376b545 100644
--- a/tensorflow/core/kernels/mkl_input_conversion_op.cc
+++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc
@@ -145,8 +145,8 @@ class MklInputConversionOp : public OpKernel {
     const MklShape* mkl_shape;
     const Tensor* tf_tensor;
     MklShape* tf_mkl_shape;
-    uint mkl_tensor_index;
-    uint tf_tensor_index;
+    uint32 mkl_tensor_index;
+    uint32 tf_tensor_index;
     if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
       mkl_tensor = &input_tensor_0;
       mkl_shape = &input_shape_0;
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index 47598f443f7..25ad8c94a78 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -170,32 +170,32 @@ class MklMatMulOp : public OpKernel {
   // Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors.
   // For detailed info about parameters, look at FP32 function description.
   void MklBlasGemm(bool transa, bool transb, const int m, const int n,
-                   const int k, const std::complex<float>* a, const int lda,
-                   const std::complex<float>* b, const int ldb,
-                   std::complex<float>* c, int const ldc) {
+                   const int k, const complex64* a, const int lda,
+                   const complex64* b, const int ldb,
+                   complex64* c, int const ldc) {
    const MKL_Complex8 alpha = {1.0f, 0.0f};
    const MKL_Complex8 beta = {0.0f, 0.0f};
    cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
-                transb ? CblasTrans : CblasNoTrans, m, n, k,
-                static_cast<const void*>(&alpha), static_cast<const void*>(a),
-                lda, static_cast<const void*>(b), ldb,
-                static_cast<const void*>(&beta), static_cast<void*>(c), ldc);
+                transb ? CblasTrans : CblasNoTrans,
+                m, n, k, &alpha, reinterpret_cast<const MKL_Complex8*>(a), lda,
+                reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta,
+                reinterpret_cast<MKL_Complex8*>(c), ldc);
   }
 
   // Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
   // tensors. For detailed info about parameters, look at FP32 function
   // description.
   void MklBlasGemm(bool transa, bool transb, const int m, const int n,
-                   const int k, const std::complex<double>* a, const int lda,
-                   const std::complex<double>* b, const int ldb,
-                   std::complex<double>* c, const int ldc) {
+                   const int k, const complex128* a, const int lda,
+                   const complex128* b, const int ldb,
+                   complex128* c, const int ldc) {
    const MKL_Complex16 alpha = {1.0, 0.0};
    const MKL_Complex16 beta = {0.0, 0.0};
    cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
-                transb ? CblasTrans : CblasNoTrans, m, n, k,
-                static_cast<const void*>(&alpha), static_cast<const void*>(a),
-                lda, static_cast<const void*>(b), ldb,
-                static_cast<const void*>(&beta), static_cast<void*>(c), ldc);
+                transb ? CblasTrans : CblasNoTrans,
+                m, n, k, &alpha, reinterpret_cast<const MKL_Complex16*>(a), lda,
+                reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
+                reinterpret_cast<MKL_Complex16*>(c), ldc);
   }
 };
 
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.h b/tensorflow/core/kernels/mkl_tfconv_op.h
index 5fafa14b5db..ddea9e281b2 100644
--- a/tensorflow/core/kernels/mkl_tfconv_op.h
+++ b/tensorflow/core/kernels/mkl_tfconv_op.h
@@ -128,7 +128,7 @@ class MklToTfOp : public OpKernel {
 #else
   static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context,
                              string data_format_str, DataType op_data_type,
-                             bool has_avx512f, uint input_number) {
+                             bool has_avx512f, uint32 input_number) {
     // Check that input tensor is in MKL format.
     const Tensor& input_tensor = MklGetInput(context, input_number);
     MklShape input_shape;
diff --git a/tensorflow/core/kernels/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl_transpose_op.cc
index 764d4c9400e..b44b4d6f542 100644
--- a/tensorflow/core/kernels/mkl_transpose_op.cc
+++ b/tensorflow/core/kernels/mkl_transpose_op.cc
@@ -18,9 +18,6 @@ limitations under the License.
 #ifdef INTEL_MKL
 #define EIGEN_USE_THREADS
 
-#include "tensorflow/core/framework/numeric_types.h"
-#define MKL_Complex8 tensorflow::complex64
-#define MKL_Complex16 tensorflow::complex128
 #include "mkl_trans.h"
 #include "tensorflow/core/kernels/transpose_functor.h"
 #include "tensorflow/core/kernels/transpose_op.h"
@@ -62,10 +59,31 @@ Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out);
 
 INSTANTIATE(float, s)
 INSTANTIATE(double, d)
-INSTANTIATE(complex64, c)
-INSTANTIATE(complex128, z)
+
 #undef INSTANTIATE
 
+template <>
+Status MKLTranspose2D<complex64>(const char trans, const Tensor& in, Tensor* out) {
+  const MKL_Complex8 alpha = { 1.0f, 0.0f };
+  mkl_comatcopy('R', trans, in.dim_size(0), in.dim_size(1), alpha,
+                reinterpret_cast<const MKL_Complex8*>(in.flat<complex64>().data()),
+                in.dim_size(1),
+                reinterpret_cast<MKL_Complex8*>(const_cast<complex64*>(out->flat<complex64>().data())),
+                in.dim_size(0));
+  return Status::OK();
+}
+
+template <>
+Status MKLTranspose2D<complex128>(const char trans, const Tensor& in, Tensor* out) {
+  const MKL_Complex16 alpha = { 1.0, 0.0 };
+  mkl_zomatcopy('R', trans, in.dim_size(0), in.dim_size(1), alpha,
+                reinterpret_cast<const MKL_Complex16*>(in.flat<complex128>().data()),
+                in.dim_size(1),
+                reinterpret_cast<MKL_Complex16*>(const_cast<complex128*>(out->flat<complex128>().data())),
+                in.dim_size(0));
+  return Status::OK();
+}
+
 static const char kMKLTranspose = 'T';
 static const char kMKLConjugateTranspose = 'C';
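The new MKLTranspose2D specializations above delegate to MKL's out-of-place copy/transpose routines. A minimal standalone sketch of mkl_comatcopy on a small complex matrix, not part of the patch; it assumes an MKL development environment (mkl_trans.h) and again treats std::complex<float> as layout-compatible with MKL_Complex8, the same assumption behind the reinterpret_casts in the kernel:

    // Sketch only: transpose a row-major 2x3 complex matrix into a 3x2 one.
    #include <complex>
    #include <cstddef>
    #include <cstdio>
    #include "mkl_trans.h"

    int main() {
      const std::size_t rows = 2, cols = 3;
      std::complex<float> in[rows * cols];   // row-major 2x3 source
      std::complex<float> out[cols * rows];  // row-major 3x2 destination
      for (std::size_t i = 0; i < rows * cols; ++i)
        in[i] = std::complex<float>(static_cast<float>(i), 0.0f);

      const MKL_Complex8 alpha = {1.0f, 0.0f};  // plain transpose, no scaling
      // 'R' = row-major, 'T' = transpose ('C' would conjugate-transpose).
      mkl_comatcopy('R', 'T', rows, cols, alpha,
                    reinterpret_cast<const MKL_Complex8*>(in), cols,
                    reinterpret_cast<MKL_Complex8*>(out), rows);

      // out is now 3x2: out[j * rows + i] == in[i * cols + j].
      std::printf("out[1][0] = (%f, %f)\n", out[1 * rows + 0].real(),
                  out[1 * rows + 0].imag());
      return 0;
    }
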
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 79369fd4a9c..77594479cb1 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -358,11 +358,11 @@ class MklSliceOp : public OpKernel {
       /* data format = NCHW */
 
 #pragma omp parallel for
-      for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
+      for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
         T* ip = in_buf + (d0 * in_strides[0]);
         T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
 #pragma omp parallel for
-        for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
+        for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
           T* ip1 = ip + (d1 * in_strides[1]);
           T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
           // For NCHW, H and W will be contiguous. So we can copy
@@ -376,15 +376,15 @@ class MklSliceOp : public OpKernel {
       /* data_format = NHWC */
 
 #pragma omp parallel for
-      for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
+      for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
         T* ip = in_buf + (d0 * in_strides[0]);
         T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
 #pragma omp parallel for
-        for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
+        for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
           T* ip1 = ip + (d1 * in_strides[1]);
           T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
 #pragma omp parallel for
-          for (size_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) {
+          for (ssize_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) {
            T* ip2 = ip1 + (d2 * in_strides[2]);
            T* ip3 = ip2 + begin[3];
            T* op2 = op1 + ((d2 - begin[2]) * out_strides[2]);
diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc
index 601704c8a70..ba03357cc6a 100644
--- a/tensorflow/core/kernels/xsmm_conv2d.cc
+++ b/tensorflow/core/kernels/xsmm_conv2d.cc
@@ -27,9 +27,6 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty();
 #include
 #include
 
-#if 0
-#include
-#endif
 
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
@@ -360,7 +357,6 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
   l_tick6 = libxsmm_timer_tick();
 #endif
 
-#if 1
   BlockingCounter counter(num_threads);
 
   for (int i = 0; i < num_threads; ++i) {
@@ -371,14 +367,6 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
     });
   }
   counter.Wait();
-#else
-#pragma omp parallel
-  {
-    chk_libxsmm_err(
-        libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, omp_get_thread_num()),
-        "Worker");
-  }
-#endif
 
 #if defined(LIBXSMM_DETAILED_TIMING)
   l_tick7 = libxsmm_timer_tick();
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index db4c5c35e36..eda966bc334 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -1112,9 +1112,9 @@ inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
 // Forward the MKL shape ONLY (used in elementwise and other ops where
 // we call the eigen implementation and MKL shape is not used)
 inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
-                                      uint idx_data_in, uint idx_data_out) {
-  uint idx_meta_in = GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
-  uint idx_meta_out =
+                                      uint32 idx_data_in, uint32_t idx_data_out) {
+  uint32 idx_meta_in = GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
+  uint32 idx_meta_out =
       GetTensorMetaDataIndex(idx_data_out, context->num_outputs());
 
   if (IsRefType(context->input_dtype(idx_data_in))) {
@@ -1126,7 +1126,7 @@ inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
 
 // Set a dummy MKL shape (called when the output is in TF format)
 inline void SetDummyMklShapeOutput(OpKernelContext* context,
-                                   uint idx_data_out) {
+                                   uint32 idx_data_out) {
   MklShape mkl_shape_output;
   mkl_shape_output.SetMklTensor(false);
   AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
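The size_t -> ssize_t change in the MklSliceOp loops above switches the OpenMP induction variables to a signed type; OpenMP 2.x (the level MSVC implements) only accepts signed integer loop counters on a parallel for. A minimal standalone sketch of the same shifted-index pattern, not part of the patch; note that ssize_t itself is a POSIX type rather than standard C++, so the include below assumes a POSIX toolchain (compile with e.g. g++ -fopenmp):

    // Sketch only: a parallel copy of a [begin, begin + size) slice.
    #include <cstdio>
    #include <sys/types.h>  // ssize_t (POSIX)

    int main() {
      const ssize_t begin = 2, size = 4;
      float in[8] = {0, 1, 2, 3, 4, 5, 6, 7};
      float out[4] = {0, 0, 0, 0};

    #pragma omp parallel for
      for (ssize_t d = begin; d < begin + size; d++) {
        out[d - begin] = in[d];  // same shifted-index pattern as MklSliceOp
      }

      for (ssize_t d = 0; d < size; d++) std::printf("%g ", out[d]);
      std::printf("\n");
      return 0;
    }
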