diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc
index 06734a6a2a3..2ed0522ce4a 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op.cc
@@ -1522,21 +1522,24 @@ void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
 
 void wrapper_libxsmm_spmdm_compute_generic_thread(
     empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
-    char transA, char transB, const bfloat16* alpha,
-    libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
-    const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
+    char transA, char transB, libxsmm_CSR_sparseslice* A_sparse,
+    const bfloat16* B, char transC, float* C, int block_id, int tid,
+    int nthreads) {
+  const uint16 alpha = 1;
+  const uint16 beta = 0;
   return libxsmm_spmdm_compute_bfloat16_thread(
-      handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse,
-      reinterpret_cast<const uint16*>(B), transC,
-      reinterpret_cast<const uint16*>(beta), C, block_id, tid, nthreads);
+      handle, transA, transB, &alpha, A_sparse,
+      reinterpret_cast<const uint16*>(B), transC, &beta, C, block_id, tid,
+      nthreads);
 }
 void wrapper_libxsmm_spmdm_compute_generic_thread(
     empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
-    char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
-    const float* B, char transC, const float* beta, float* C, int block_id,
-    int tid, int nthreads) {
-  return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
-                                           A_sparse, B, transC, beta, C,
+    char transB, libxsmm_CSR_sparseslice* A_sparse, const float* B, char transC,
+    float* C, int block_id, int tid, int nthreads) {
+  const float alpha = 1.f;
+  const float beta = 0.f;
+  return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, &alpha,
+                                           A_sparse, B, transC, &beta, C,
                                            block_id, tid, nthreads);
 }
 
@@ -1648,13 +1651,11 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
     while (true) {
       int work_item = cur_mult_block_number.fetch_add(1);
       if (work_item >= total_num_mult_blocks) break;
-      const TL alpha(1.0);  // Stored in a variable so we can get a pointer
-      const TL beta(0.0);   // Stored in a variable so we can get a pointer
       wrapper_libxsmm_spmdm_compute_generic_thread(
           empty_type_wrapper<TL>{}, &entry->handle,
-          (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
-          right_data, (transpose_output ? 'T' : 'N'), &beta, output_data,
-          work_item, i, actual_num_threads);
+          (transpose_left ? 'T' : 'N'), 'N', entry->output_csr, right_data,
+          (transpose_output ? 'T' : 'N'), output_data, work_item, i,
+          actual_num_threads);
     }
   });
   // Put handle + CSR storage back into cache
@@ -1802,15 +1803,17 @@ inline void SparseMatMul<TL, TR>::Compute(
                           SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
 #endif
 
-REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
-
 REGISTER_SPARSE_MATMUL(float, bfloat16);
 
 REGISTER_SPARSE_MATMUL(bfloat16, float);
 
 #ifdef TENSORFLOW_USE_LIBXSMM
+REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16);
+
 REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
 #else
+REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
+
 REGISTER_SPARSE_MATMUL(float, float);
 #endif
 
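
The wrappers above dispatch on an empty tag type, empty_type_wrapper<T>, so the templated caller (LibxsmmSparseMatMul<TL, TR>::Compute) selects the fp32 or bfloat16 libxsmm kernel by plain overload resolution. Since the caller always passed alpha = 1 and beta = 0, the patch hard-codes those constants inside each overload, dropping the "stored in a variable so we can get a pointer" locals and the reinterpret_casts on alpha/beta. Below is a minimal, self-contained sketch of that tag-dispatch shape; demo_compute_fp32 / demo_compute_bf16 are hypothetical stand-ins for libxsmm_spmdm_compute_fp32_thread / libxsmm_spmdm_compute_bfloat16_thread, not real libxsmm entry points.

#include <cstdint>
#include <iostream>

// Empty tag type: carries only an element type, no data (mirrors
// empty_type_wrapper<T> in the patch).
template <typename T>
struct empty_type_wrapper {};

// Hypothetical stand-ins for the type-specific backend kernels.
void demo_compute_fp32(const float* alpha, const float* beta) {
  std::cout << "fp32 kernel, alpha=" << *alpha << " beta=" << *beta << "\n";
}
void demo_compute_bf16(const uint16_t* alpha, const uint16_t* beta) {
  std::cout << "bf16 kernel, alpha=" << *alpha << " beta=" << *beta << "\n";
}

// One overload per tag; as in the patch, alpha/beta are fixed inside the
// wrapper, so callers never name the type-specific constants.
void demo_compute_generic(empty_type_wrapper<float>) {
  const float alpha = 1.f;
  const float beta = 0.f;
  demo_compute_fp32(&alpha, &beta);
}
void demo_compute_generic(empty_type_wrapper<uint16_t>) {
  const uint16_t alpha = 1;  // same uint16 constants as the patch
  const uint16_t beta = 0;
  demo_compute_bf16(&alpha, &beta);
}

int main() {
  // A templated caller just writes empty_type_wrapper<TL>{}; overload
  // resolution picks the matching kernel at compile time.
  demo_compute_generic(empty_type_wrapper<float>{});
  demo_compute_generic(empty_type_wrapper<uint16_t>{});
}

The registration half of the patch follows the same spirit: bfloat16 x bfloat16 now routes through REGISTER_SPARSE_MATMUL_LIBXSMM when TENSORFLOW_USE_LIBXSMM is defined, and falls back to the generic REGISTER_SPARSE_MATMUL kernel otherwise.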