Fixed bfloat16 integration of LIBXSMM sparse mat-mul.
Change: 149617825
parent 89df2a1b41
commit 62be492ef4
@@ -1522,21 +1522,24 @@ void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
 void wrapper_libxsmm_spmdm_compute_generic_thread(
     empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
-    char transA, char transB, const bfloat16* alpha,
-    libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
-    const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
+    char transA, char transB, libxsmm_CSR_sparseslice* A_sparse,
+    const bfloat16* B, char transC, float* C, int block_id, int tid,
+    int nthreads) {
+  const uint16 alpha = 1;
+  const uint16 beta = 0;
   return libxsmm_spmdm_compute_bfloat16_thread(
-      handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse,
-      reinterpret_cast<const uint16*>(B), transC,
-      reinterpret_cast<const uint16*>(beta), C, block_id, tid, nthreads);
+      handle, transA, transB, &alpha, A_sparse,
+      reinterpret_cast<const uint16*>(B), transC, &beta, C, block_id, tid,
+      nthreads);
 }
 void wrapper_libxsmm_spmdm_compute_generic_thread(
     empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
-    char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
-    const float* B, char transC, const float* beta, float* C, int block_id,
-    int tid, int nthreads) {
-  return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
-                                           A_sparse, B, transC, beta, C,
+    char transB, libxsmm_CSR_sparseslice* A_sparse, const float* B, char transC,
+    float* C, int block_id, int tid, int nthreads) {
+  const float alpha = 1.f;
+  const float beta = 0.f;
+  return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, &alpha,
+                                           A_sparse, B, transC, &beta, C,
                                            block_id, tid, nthreads);
 }
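Note: the fix relies on bfloat16 being bit-compatible with uint16. A bfloat16 is exactly the upper 16 bits of an IEEE-754 float, so the B buffer can be handed to LIBXSMM through reinterpret_cast<const uint16*>, while alpha and beta become plain uint16 constants owned by the wrapper instead of bfloat16 pointers reinterpreted at the call site. A minimal, self-contained sketch of that bit-compatibility; the bf16 struct and float_to_bf16 helper below are hypothetical stand-ins, not TensorFlow's bfloat16 class:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical stand-in for TensorFlow's bfloat16: a single 16-bit payload.
struct bf16 {
  std::uint16_t value;
};

// Truncating float -> bfloat16 conversion: keep the high 16 bits.
bf16 float_to_bf16(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return bf16{static_cast<std::uint16_t>(bits >> 16)};
}

int main() {
  const bf16 one = float_to_bf16(1.0f);
  // 1.0f is 0x3F800000, so its bfloat16 encoding is 0x3F80. Because bf16
  // holds exactly one uint16, a buffer of bf16 can be viewed as a buffer of
  // uint16 -- the same reinterpretation the wrapper applies to B.
  std::printf("bfloat16(1.0) bits: 0x%04X\n", static_cast<unsigned>(one.value));
  const std::uint16_t* raw = reinterpret_cast<const std::uint16_t*>(&one);
  std::printf("same bits via reinterpret_cast: 0x%04X\n", static_cast<unsigned>(*raw));
  return 0;
}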
@@ -1648,13 +1651,11 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
     while (true) {
       int work_item = cur_mult_block_number.fetch_add(1);
       if (work_item >= total_num_mult_blocks) break;
-      const TL alpha(1.0);  // Stored in a variable so we can get a pointer
-      const TL beta(0.0);   // Stored in a variable so we can get a pointer
       wrapper_libxsmm_spmdm_compute_generic_thread(
           empty_type_wrapper<TL>{}, &entry->handle,
-          (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
-          right_data, (transpose_output ? 'T' : 'N'), &beta, output_data,
-          work_item, i, actual_num_threads);
+          (transpose_left ? 'T' : 'N'), 'N', entry->output_csr, right_data,
+          (transpose_output ? 'T' : 'N'), output_data, work_item, i,
+          actual_num_threads);
     }
   });
   // Put handle + CSR storage back into cache
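Note: empty_type_wrapper<TL>{} is a zero-size tag whose only job is to steer overload resolution to the per-type wrapper at compile time, which is how this generic call site reaches the fp32 or bfloat16 LIBXSMM entry point. A minimal sketch of that tag-dispatch pattern, assuming nothing from TensorFlow beyond the names it mirrors (the print bodies are placeholders):

#include <cstdio>

// The tag: an empty struct template whose only role is to carry a type.
template <typename T>
struct empty_type_wrapper {};

struct bfloat16 {
  unsigned short value;
};

// One overload per element type, standing in for the wrappers in the diff.
void compute(empty_type_wrapper<float>, const float*) {
  std::printf("dispatched to the fp32 kernel\n");
}
void compute(empty_type_wrapper<bfloat16>, const bfloat16*) {
  std::printf("dispatched to the bfloat16 kernel\n");
}

// Generic code selects the overload by constructing the zero-size tag.
template <typename TL>
void generic_compute(const TL* B) {
  compute(empty_type_wrapper<TL>{}, B);
}

int main() {
  float f = 1.0f;
  bfloat16 b{0x3F80};
  generic_compute(&f);  // resolves to the fp32 overload at compile time
  generic_compute(&b);  // resolves to the bfloat16 overload at compile time
  return 0;
}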
@@ -1802,15 +1803,17 @@ inline void SparseMatMul<TL, TR>::Compute(
                           SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
 #endif
 
-REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
-
 REGISTER_SPARSE_MATMUL(float, bfloat16);
 
 REGISTER_SPARSE_MATMUL(bfloat16, float);
 
 #ifdef TENSORFLOW_USE_LIBXSMM
+REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16);
+
 REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
 #else
+REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
+
 REGISTER_SPARSE_MATMUL(float, float);
 #endif
 
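Note: the REGISTER_SPARSE_MATMUL* macros, defined earlier in sparse_matmul_op.cc, evidently stamp out one kernel registration per (left, right) element-type pair, instantiating SparseMatMulOp with the chosen implementation. The net effect of this hunk is that the (bfloat16, bfloat16) kernel is now routed through LibxsmmSparseMatMul when TENSORFLOW_USE_LIBXSMM is defined, and falls back to the default implementation otherwise. A self-contained toy sketch of the macro-registration idea; the _DEMO macro and the Kernel/Registrar/registry types are hypothetical, not TensorFlow's API:

#include <cstdio>
#include <map>
#include <string>

// Toy kernel registry; the real one is TensorFlow's global kernel registry.
struct Kernel {
  void (*run)();
};

std::map<std::string, Kernel>& registry() {
  static std::map<std::string, Kernel> r;
  return r;
}

struct bfloat16 {
  unsigned short value;
};

// Templated op body, standing in for SparseMatMulOp<TA, TB, ...>.
template <typename TA, typename TB>
void sparse_matmul_kernel() {
  std::printf("SparseMatMul kernel invoked\n");
}

// Records one kernel at static-initialization time.
struct Registrar {
  Registrar(const std::string& key, Kernel k) { registry()[key] = k; }
};

// One registration per (left, right) element-type pair -- the role played by
// REGISTER_SPARSE_MATMUL / REGISTER_SPARSE_MATMUL_LIBXSMM in the diff.
#define REGISTER_SPARSE_MATMUL_DEMO(TA, TB) \
  static Registrar registrar_##TA##_##TB(   \
      #TA "_" #TB, Kernel{&sparse_matmul_kernel<TA, TB>});

REGISTER_SPARSE_MATMUL_DEMO(float, float)
REGISTER_SPARSE_MATMUL_DEMO(bfloat16, bfloat16)

int main() {
  // Both kernels were registered before main() ran.
  std::printf("registered kernels: %zu\n", registry().size());
  registry()["float_float"].run();
  return 0;
}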