LIBXSMM in TensorFlow broken due to TensorFlow API changes (#24706)

* Original issue about TF/API change was resolved in LIBXSMM 1.11.
* Register sparse matmul-op to use LIBXSMM's Bfloat16 functionality;
  - SpMDM/Bf16 was present since TF v1.1 but TF had no/public Bf16
* Updated (outdated) conv-integration with LIBXSMM; fixes compilation
  - https://github.com/hfp/libxsmm/issues/281 should be considered
  - Use desc. datatype_in/out; desc. previously had only "datatype"
* Generalized libxsmm.BUILD using glob/patterns.
* Updated LIBXSMM to v1.14.
This commit is contained in:
Hans Pabst 2019-11-18 19:16:17 +01:00
parent 88a3f5c153
commit dcbfb65b94
8 changed files with 293 additions and 276 deletions

View File

@ -241,8 +241,8 @@ struct LaunchXsmmBackwardFilter<CPUDevice, float> {
desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_RSCK; desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_NONE; desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
desc.datatype = LIBXSMM_DNN_DATATYPE_F32; desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
if (!CanUseXsmmConv2D(desc, data_format)) { if (!CanUseXsmmConv2D(desc, data_format)) {
return false; return false;
} }

View File

@ -290,9 +290,9 @@ struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
desc.filter_format = desc.filter_format =
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM; // LIBXSMM_DNN_TENSOR_FORMAT_RSCK; LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM; // LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE; desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
desc.datatype = LIBXSMM_DNN_DATATYPE_F32; desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
auto input_ptr = input_backward.data(); auto input_ptr = input_backward.data();
auto filter_ptr = kernel.data(); auto filter_ptr = kernel.data();
auto output_ptr = output_backward.data(); auto output_ptr = output_backward.data();

View File

@ -320,9 +320,9 @@ class LaunchXsmmConvOp<CPUDevice, float> {
desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC; desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM; desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE; desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
desc.datatype = LIBXSMM_DNN_DATATYPE_F32; desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
if (dilation_rows != 1 || dilation_cols != 1 || if (dilation_rows != 1 || dilation_cols != 1 ||
!CanUseXsmmConv2D(desc, data_format)) { !CanUseXsmmConv2D(desc, data_format)) {
return false; return false;

View File

@ -1396,7 +1396,7 @@ void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid, libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
int nthreads) { int nthreads) {
return libxsmm_spmdm_createSparseSlice_bfloat16_thread( return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
handle, transA, reinterpret_cast<const uint16*>(A), libxsmm_output_csr_a, handle, transA, reinterpret_cast<const libxsmm_bfloat16*>(A), libxsmm_output_csr_a,
block_id, tid, nthreads); block_id, tid, nthreads);
} }
@ -1406,9 +1406,9 @@ void wrapper_libxsmm_spmdm_compute_generic_thread(
libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC, libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
const bfloat16* beta, float* C, int block_id, int tid, int nthreads) { const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
return libxsmm_spmdm_compute_bfloat16_thread( return libxsmm_spmdm_compute_bfloat16_thread(
handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse, handle, transA, transB, reinterpret_cast<const libxsmm_bfloat16*>(alpha), A_sparse,
reinterpret_cast<const uint16*>(B), transC, reinterpret_cast<const libxsmm_bfloat16*>(B), transC,
reinterpret_cast<const uint16*>(beta), C, block_id, tid, nthreads); reinterpret_cast<const libxsmm_bfloat16*>(beta), C, block_id, tid, nthreads);
} }
void wrapper_libxsmm_spmdm_compute_generic_thread( void wrapper_libxsmm_spmdm_compute_generic_thread(
empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA, empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
@ -1427,13 +1427,6 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right, const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
bool transpose_output, MatrixMap* output) { bool transpose_output, MatrixMap* output) {
if (false) {
// Not handled by libxsmm currently
SparseMatMul<TL, TR>::Compute(
nullptr /* Assumes no cached data for fallback */, left, right,
transpose_left, thread_pool, transpose_output, output);
return;
}
const int num_threads = thread_pool->num_threads; const int num_threads = thread_pool->num_threads;
const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0); const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1); const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
@ -1444,6 +1437,7 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
(transpose_output ? output->dimension(1) : output->dimension(0))); (transpose_output ? output->dimension(1) : output->dimension(0)));
CHECK_EQ(right_dim1, CHECK_EQ(right_dim1,
(transpose_output ? output->dimension(0) : output->dimension(1))); (transpose_output ? output->dimension(0) : output->dimension(1)));
#if 0 // this issue seems to be resolved
if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) { if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
// Causes problems in libxsmm // Causes problems in libxsmm
SparseMatMul<TL, TR>::Compute( SparseMatMul<TL, TR>::Compute(
@ -1451,6 +1445,7 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
transpose_left, thread_pool, transpose_output, output); transpose_left, thread_pool, transpose_output, output);
return; return;
} }
#endif
auto left_data = left.data(); auto left_data = left.data();
auto right_data = right.data(); auto right_data = right.data();
auto output_data = output->data(); auto output_data = output->data();
@ -1640,15 +1635,14 @@ inline void SparseMatMul<TL, TR>::Compute(
SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>); SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
#endif #endif
REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
REGISTER_SPARSE_MATMUL(float, bfloat16); REGISTER_SPARSE_MATMUL(float, bfloat16);
REGISTER_SPARSE_MATMUL(bfloat16, float); REGISTER_SPARSE_MATMUL(bfloat16, float);
#ifdef TENSORFLOW_USE_LIBXSMM #ifdef TENSORFLOW_USE_LIBXSMM
REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16);
REGISTER_SPARSE_MATMUL_LIBXSMM(float, float); REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
#else #else
REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
REGISTER_SPARSE_MATMUL(float, float); REGISTER_SPARSE_MATMUL(float, float);
#endif #endif

View File

@ -27,6 +27,9 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty();
#include <stdlib.h> #include <stdlib.h>
#include <cstring> #include <cstring>
#if defined(_OPENMP) && defined(LIBXSMM_USE_OPENMP)
# include <omp.h>
#endif
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/blocking_counter.h"
@ -36,6 +39,11 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty();
#include "include/libxsmm_malloc.h" #include "include/libxsmm_malloc.h"
#include "src/libxsmm_main.h" // TODO(bsteiner): API to avoid incl. header from src/ #include "src/libxsmm_main.h" // TODO(bsteiner): API to avoid incl. header from src/
#define CHECK_LIBXSMM(CONDITION_OK, MESSAGE) if (!(CONDITION_OK)) VLOG(0) << (MESSAGE)
#define CHECK_LIBXSMM_DNN(STATUS, MESSAGE) CHECK_LIBXSMM(LIBXSMM_DNN_SUCCESS == (STATUS), MESSAGE) \
<< " failed: " << libxsmm_dnn_get_error(STATUS);
namespace tensorflow { namespace tensorflow {
// Xsmm*Conv2D are wrappers for libxsmm direct convolutions. // Xsmm*Conv2D are wrappers for libxsmm direct convolutions.
@ -73,12 +81,6 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
namespace functor { namespace functor {
static void chk_libxsmm_err(libxsmm_dnn_err_t status, string msg) {
if (status != LIBXSMM_DNN_SUCCESS) {
VLOG(0) << msg << " failed: " << libxsmm_dnn_get_error(status);
}
}
LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R, LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
int S, int C, int K, int blocksifm, int S, int C, int K, int blocksifm,
int blocksofm, int ifmblock, int blocksofm, int ifmblock,
@ -114,55 +116,115 @@ LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
} }
} }
class libxsmm_dnn_conv_desc_wrap { struct libxsmm_dnn_registry_key {
public: const libxsmm_dnn_conv_desc descriptor;
const libxsmm_dnn_conv_desc d; libxsmm_dnn_registry_key(const libxsmm_dnn_conv_desc& desc_): descriptor(desc_) {}
bool operator==(const libxsmm_dnn_registry_key& regkey) const {
libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc& d_) : d(d_) {} return 0 == memcmp(&descriptor, &regkey.descriptor, sizeof(descriptor));
bool operator==(const libxsmm_dnn_conv_desc_wrap& w) const {
return (d.N == w.d.N && d.C == w.d.C && d.H == w.d.H && d.W == w.d.W &&
d.K == w.d.K && d.R == w.d.R && d.S == w.d.S && d.u == w.d.u &&
d.v == w.d.v && d.pad_h == w.d.pad_h && d.pad_w == w.d.pad_w);
} }
}; };
struct HashFunction { struct HashFunction {
std::size_t operator()(const libxsmm_dnn_conv_desc_wrap& w) const { std::size_t operator()(const libxsmm_dnn_registry_key& regkey) const {
return libxsmm_hash(&w.d, sizeof(w.d), 25071975); return libxsmm_hash(&regkey.descriptor, sizeof(regkey.descriptor), 25071975);
} }
}; };
class handles { struct libxsmm_dnn_registry_value {
libxsmm_dnn_tensor_datalayout* layout_input;
libxsmm_dnn_tensor_datalayout* layout_filter;
libxsmm_dnn_tensor_datalayout* layout_output;
libxsmm_dnn_layer* handle;
};
typedef libxsmm_tf_allocator<libxsmm_scratch_allocator> libxsmm_tf_scratch_allocator;
static class libxsmm_dnn_registry_type {
private:
typedef std::unordered_map<
libxsmm_dnn_registry_key,
libxsmm_dnn_registry_value,
HashFunction>
container_type;
public: public:
libxsmm_dnn_layer* find(const libxsmm_dnn_conv_desc_wrap& w) { libxsmm_dnn_registry_type() {
std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*, libxsmm_init(); /* must be first */
HashFunction>::iterator i = libxsmm_handles.find(w); #if !defined(LIBXSMM_LOCAL_ALLOC)
if (i == libxsmm_handles.end()) { { libxsmm_malloc_function malloc_fn; libxsmm_free_function free_fn;
malloc_fn.function = libxsmm_tf_scratch_allocator::malloc;
free_fn.function = libxsmm_tf_scratch_allocator::free;
libxsmm_set_scratch_allocator(0/*context*/, malloc_fn, free_fn);
}
#endif
LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_LOCK_RWLOCK, &attr);
LIBXSMM_LOCK_INIT(LIBXSMM_LOCK_RWLOCK, &lock, &attr);
}
~libxsmm_dnn_registry_type() {
LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
const container_type::const_iterator end = container.end();
for (container_type::const_iterator i = container.begin(); i != end; ++i) {
CHECK_LIBXSMM_DNN(
libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_input),
"destroy input layout");
CHECK_LIBXSMM_DNN(
libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_output),
"destroy output layout");
CHECK_LIBXSMM_DNN(
libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_filter),
"destroy filter layout");
CHECK_LIBXSMM_DNN(
libxsmm_dnn_destroy_conv_layer(i->second.handle),
"destroy handle");
}
LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
LIBXSMM_LOCK_DESTROY(LIBXSMM_LOCK_RWLOCK, &lock);
LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_LOCK_RWLOCK, &attr);
libxsmm_finalize();
}
libxsmm_dnn_registry_value find(const libxsmm_dnn_registry_key& regkey) {
container_type::iterator i;
LIBXSMM_LOCK_ACQREAD(LIBXSMM_LOCK_RWLOCK, &lock);
i = container.find(regkey);
LIBXSMM_LOCK_RELREAD(LIBXSMM_LOCK_RWLOCK, &lock);
if (i == container.end()) {
libxsmm_dnn_err_t status; libxsmm_dnn_err_t status;
libxsmm_dnn_layer* libxsmm_handle = libxsmm_dnn_registry_value regentry;
libxsmm_dnn_create_conv_layer(w.d, &status);
chk_libxsmm_err(status, "Create handle"); LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
libxsmm_handles.insert(std::make_pair(w, libxsmm_handle)); i = container.find(regkey);
return libxsmm_handle; if (i == container.end()) { // re-check after lock acquisition
} else { regentry.handle = libxsmm_dnn_create_conv_layer(regkey.descriptor, &status);
if (LIBXSMM_DNN_WARN_FALLBACK != status) {
CHECK_LIBXSMM_DNN(status, "create handle");
}
else { // warning
VLOG(1) << libxsmm_dnn_get_error(status);
}
regentry.layout_input = libxsmm_dnn_create_tensor_datalayout(
regentry.handle, LIBXSMM_DNN_INPUT, &status);
CHECK_LIBXSMM_DNN(status, "create input layout");
regentry.layout_output = libxsmm_dnn_create_tensor_datalayout(
regentry.handle, LIBXSMM_DNN_OUTPUT, &status);
CHECK_LIBXSMM_DNN(status, "create output layout");
regentry.layout_filter = libxsmm_dnn_create_tensor_datalayout(
regentry.handle, LIBXSMM_DNN_FILTER, &status);
CHECK_LIBXSMM_DNN(status, "create filter layout");
i = container.insert(std::make_pair(regkey, regentry)).first;
}
LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
}
return i->second; return i->second;
} }
}
~handles() {
std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
HashFunction>::iterator i;
for (i = libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second),
"Destroy handle");
}
private: private:
std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*, container_type container;
HashFunction> LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_LOCK_RWLOCK) attr;
libxsmm_handles; LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK_RWLOCK) lock;
}; } libxsmm_dnn_registry;
static handles libxsmm_handles;
// #define LIBXSMM_DETAILED_TIMING // #define LIBXSMM_DETAILED_TIMING
@ -173,73 +235,61 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
InputPtr input, FilterPtr filter, InputPtr input, FilterPtr filter,
OutputPtr output) { OutputPtr output) {
#if defined(LIBXSMM_DETAILED_TIMING) #if defined(LIBXSMM_DETAILED_TIMING)
uint64 l_tick1; libxsmm_timer_tickint l_tick1;
uint64 l_tick2; libxsmm_timer_tickint l_tick2;
uint64 l_tick3; libxsmm_timer_tickint l_tick3;
uint64 l_tick4; libxsmm_timer_tickint l_tick4;
uint64 l_tick5; libxsmm_timer_tickint l_tick5;
uint64 l_tick6; libxsmm_timer_tickint l_tick6;
uint64 l_tick7; libxsmm_timer_tickint l_tick7;
uint64 l_tick8; libxsmm_timer_tickint l_tick8;
uint64 l_tick9; libxsmm_timer_tickint l_tick9;
uint64 l_tick10; libxsmm_timer_tickint l_tick10;
l_tick1 = libxsmm_timer_tick(); l_tick1 = libxsmm_timer_tick();
#endif #endif
// setup scoped allocator, which adopts the allocator from the context #if defined(LIBXSMM_LOCAL_ALLOC)
const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator(*ctx); // setup scoped allocator, which adopts the allocator of the current context
const libxsmm_tf_scratch_allocator tf_allocator(*ctx);
#endif
const libxsmm_dnn_registry_key regkey(desc);
const libxsmm_dnn_registry_value regentry = libxsmm_dnn_registry.find(regkey);
libxsmm_dnn_tensor *libxsmm_input, *libxsmm_output, *libxsmm_filter;
libxsmm_dnn_err_t status; libxsmm_dnn_err_t status;
libxsmm_dnn_layer* libxsmm_handle;
libxsmm_dnn_conv_desc_wrap w(desc);
void* scratch;
// if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) status = libxsmm_dnn_get_codegen_success(regentry.handle, kind);
libxsmm_handle = libxsmm_handles.find(w);
// else{
// libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
// chk_libxsmm_err(status, "Create handle");
//}
status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
if (status == LIBXSMM_DNN_WARN_FALLBACK) { if (status == LIBXSMM_DNN_WARN_FALLBACK) {
return false; // Use non-libxsmm code return false; // Use non-libxsmm code
} }
chk_libxsmm_err(status, "Check codegen status"); CHECK_LIBXSMM_DNN(status, "code generation");
libxsmm_dnn_buffer* libxsmm_input;
libxsmm_dnn_buffer* libxsmm_output;
libxsmm_dnn_filter* libxsmm_filter;
#if defined(LIBXSMM_DETAILED_TIMING) #if defined(LIBXSMM_DETAILED_TIMING)
l_tick2 = libxsmm_timer_tick(); l_tick2 = libxsmm_timer_tick();
#endif #endif
int ifmblock = (libxsmm_handle->ifmblock); const int ifmblock = regentry.handle->ifmblock;
int ofmblock = (libxsmm_handle->ofmblock); const int ofmblock = regentry.handle->ofmblock;
int blocksifm = const int blocksifm = (desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1);
desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1; const int blocksofm = (desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1);
int blocksofm =
desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1;
float* native_filter =
(float*)libxsmm_aligned_scratch(blocksofm * blocksifm * desc.R * desc.S *
ifmblock * ofmblock * sizeof(float),
2097152);
const DeviceBase::CpuWorkerThreads* worker_threads = const size_t filter_size = blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock;
float *const native_filter = (float*)libxsmm_aligned_scratch(
filter_size * sizeof(float), 2097152/*alignment*/);
const DeviceBase::CpuWorkerThreads *const worker_threads =
ctx->device()->tensorflow_cpu_worker_threads(); ctx->device()->tensorflow_cpu_worker_threads();
const int num_threads = worker_threads->num_threads;
int num_threads = worker_threads->num_threads;
#if 1 #if 1
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD || if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
if (blocksofm > num_threads) { if (blocksofm > num_threads) {
int work = blocksofm; const int work = blocksofm;
BlockingCounter count(num_threads); BlockingCounter count(num_threads);
for (int i = 0; i < num_threads; ++i) { for (int i = 0; i < num_threads; ++i) {
worker_threads->workers->Schedule([=, &count]() { worker_threads->workers->Schedule([=, &count]() {
int start = work / num_threads * i; const int start = work / num_threads * i;
int end = (start + work / num_threads) > work const int end = (start + work / num_threads) > work
? work ? work
: start + work / num_threads; : start + work / num_threads;
copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C, copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
@ -250,14 +300,14 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
} }
count.Wait(); count.Wait();
} else { } else {
int work = blocksofm; const int work = blocksofm;
int num_threads = work; const int num_tasks = work;
BlockingCounter count(num_threads); BlockingCounter count(num_tasks);
for (int i = 0; i < num_threads; ++i) { for (int i = 0; i < num_tasks; ++i) {
worker_threads->workers->Schedule([=, &count]() { worker_threads->workers->Schedule([=, &count]() {
int start = i; const int start = i;
int end = i + 1; const int end = i + 1;
copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C, copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
desc.K, blocksifm, blocksofm, ifmblock, ofmblock, desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
start, end); start, end);
@ -267,90 +317,81 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
count.Wait(); count.Wait();
} }
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
// Added: for weight update // weight update buffer must be in the right format (LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR)
libxsmm_filter = libxsmm_filter = libxsmm_dnn_link_tensor(regentry.layout_filter, filter, &status);
libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter, CHECK_LIBXSMM_DNN(status, "link filter with layout");
LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status);
chk_libxsmm_err(status,
"Link filter"); // weight update is in RSCK as
// filter should be returned in RSCK
// format
} }
#else #else
memset(native_filter, 0, memset(native_filter, 0, filter_size * sizeof(float));
blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock *
sizeof(float));
#endif #endif
#if defined(LIBXSMM_DETAILED_TIMING) #if defined(LIBXSMM_DETAILED_TIMING)
l_tick3 = libxsmm_timer_tick(); l_tick3 = libxsmm_timer_tick();
#endif #endif
libxsmm_input = // LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_INPUT, input, libxsmm_input = libxsmm_dnn_link_tensor(regentry.layout_input, input, &status);
LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status); CHECK_LIBXSMM_DNN(status, "link input buffer with layout");
chk_libxsmm_err(status, "Link input buffer");
libxsmm_output = // LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_OUTPUT, output, libxsmm_output = libxsmm_dnn_link_tensor(regentry.layout_output, output, &status);
LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status); CHECK_LIBXSMM_DNN(status, "link output buffer with layout");
chk_libxsmm_err(status, "Link output buffer");
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD || if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
libxsmm_filter = libxsmm_dnn_link_filter( // LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR
libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter, libxsmm_filter = libxsmm_dnn_link_tensor(regentry.layout_filter, native_filter, &status);
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status); CHECK_LIBXSMM_DNN(status, "link filter with layout");
chk_libxsmm_err(status, "Link filter"); CHECK_LIBXSMM_DNN(
libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
"bind filter to handle");
} }
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) { if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, CHECK_LIBXSMM_DNN(
LIBXSMM_DNN_REGULAR_INPUT), libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT),
"Bind input forward"); "bind input forward");
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, CHECK_LIBXSMM_DNN(
LIBXSMM_DNN_REGULAR_OUTPUT), libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
"Bind output forward"); "bind filter forward");
chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, CHECK_LIBXSMM_DNN(
LIBXSMM_DNN_REGULAR_FILTER), libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output, LIBXSMM_DNN_REGULAR_OUTPUT),
"Bind filter forward"); "bind output forward");
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_input), "Zero input"); CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_input), "zeroing input");
CHECK_LIBXSMM_DNN(
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input, LIBXSMM_DNN_GRADIENT_INPUT),
LIBXSMM_DNN_GRADIENT_INPUT), "bind input backward");
"Bind input backward"); CHECK_LIBXSMM_DNN(
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
LIBXSMM_DNN_GRADIENT_OUTPUT), "bind filter backward");
"Bind output backward"); CHECK_LIBXSMM_DNN(
chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output, LIBXSMM_DNN_GRADIENT_OUTPUT),
LIBXSMM_DNN_REGULAR_FILTER), "bind output backward");
"Bind filter backward");
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
chk_libxsmm_err(libxsmm_dnn_zero_filter(libxsmm_filter), "Zero filter"); CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_filter), "zeroing filter");
CHECK_LIBXSMM_DNN(
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT),
LIBXSMM_DNN_REGULAR_INPUT), "bind input weight update");
"Bind input weight update"); CHECK_LIBXSMM_DNN(
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter, LIBXSMM_DNN_GRADIENT_FILTER),
LIBXSMM_DNN_GRADIENT_OUTPUT), "bind filter weight update");
"Bind output weight update"); CHECK_LIBXSMM_DNN(
chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output, LIBXSMM_DNN_GRADIENT_OUTPUT),
LIBXSMM_DNN_GRADIENT_FILTER), "bind output weight update");
"Bind filter weight update");
} else { } else {
/* shouldn't happen */ assert(0/*should not happen*/);
} }
#if defined(LIBXSMM_DETAILED_TIMING) #if defined(LIBXSMM_DETAILED_TIMING)
l_tick4 = libxsmm_timer_tick(); l_tick4 = libxsmm_timer_tick();
#endif #endif
/* bind scratch */ const size_t scratch_size = libxsmm_dnn_get_scratch_size(regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status);
scratch = (void*)libxsmm_aligned_scratch( CHECK_LIBXSMM_DNN(status, "get scratch size");
libxsmm_dnn_get_scratch_size(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, void *const scratch = libxsmm_aligned_scratch(scratch_size, 2097152/*alignment*/);
&status), CHECK_LIBXSMM(0 != scratch, "scratch memory allocation");
2097152); CHECK_LIBXSMM_DNN(
chk_libxsmm_err(status, "scratch allocation"); libxsmm_dnn_bind_scratch(regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
chk_libxsmm_err(libxsmm_dnn_bind_scratch(
libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
"binding scratch"); "binding scratch");
#if defined(LIBXSMM_DETAILED_TIMING) #if defined(LIBXSMM_DETAILED_TIMING)
@ -358,30 +399,40 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
#endif #endif
if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER); libxsmm_dnn_transpose_filter(regentry.handle, LIBXSMM_DNN_FILTER);
} }
#if defined(LIBXSMM_DETAILED_TIMING) #if defined(LIBXSMM_DETAILED_TIMING)
l_tick6 = libxsmm_timer_tick(); l_tick6 = libxsmm_timer_tick();
#endif #endif
#if !defined(_OPENMP) || !defined(LIBXSMM_USE_OPENMP)
BlockingCounter counter(num_threads); BlockingCounter counter(num_threads);
for (int i = 0; i < num_threads; ++i) { for (int i = 0; i < num_threads; ++i) {
worker_threads->workers->Schedule([=, &counter]() { worker_threads->workers->Schedule([=, &counter]() {
chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i), CHECK_LIBXSMM_DNN(
"Worker"); libxsmm_dnn_execute_st(regentry.handle, kind, 0, i),
"worker");
counter.DecrementCount(); counter.DecrementCount();
}); });
} }
counter.Wait(); counter.Wait();
#else
# pragma omp parallel
{
CHECK_LIBXSMM_DNN(
libxsmm_dnn_execute_st(regentry.handle, kind, 0, omp_get_thread_num()),
"worker");
}
#endif
#if defined(LIBXSMM_DETAILED_TIMING) #if defined(LIBXSMM_DETAILED_TIMING)
l_tick7 = libxsmm_timer_tick(); l_tick7 = libxsmm_timer_tick();
#endif #endif
if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
libxsmm_dnn_reduce_wu_filters(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER); libxsmm_dnn_reduce_wu_filters(regentry.handle, LIBXSMM_DNN_GRADIENT_FILTER);
} }
#if defined(LIBXSMM_DETAILED_TIMING) #if defined(LIBXSMM_DETAILED_TIMING)
@ -389,54 +440,50 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
#endif #endif
/* clean up */ /* clean up */
chk_libxsmm_err( CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_scratch(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL), libxsmm_dnn_release_scratch(regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL),
"release scratch"); "release scratch");
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) { if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
chk_libxsmm_err( CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT), libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
"release input"); "release input");
chk_libxsmm_err( CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT), libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_OUTPUT),
"release output"); "release output");
chk_libxsmm_err( CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER), libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
"release filter"); "release filter");
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
chk_libxsmm_err( CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT), libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_INPUT),
"release input"); "release input");
chk_libxsmm_err( CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT), libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
"release output"); "release output");
chk_libxsmm_err( CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER), libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
"release filter"); "release filter");
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
chk_libxsmm_err( CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT), libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
"release input"); "release input");
chk_libxsmm_err( CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT), libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
"release output"); "release output");
chk_libxsmm_err( CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER), libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_FILTER),
"release filter"); "release filter");
} else { } else {
/* shouldn't happen */ /* shouldn't happen */
} }
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input"); CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_input), "destroy input");
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output"); CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_output), "destroy output");
chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter"); CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_filter), "destroy filter");
#if defined(LIBXSMM_DETAILED_TIMING) #if defined(LIBXSMM_DETAILED_TIMING)
l_tick9 = libxsmm_timer_tick(); l_tick9 = libxsmm_timer_tick();
#endif #endif
// if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
// chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
// "Destroy handle");
libxsmm_free(native_filter); libxsmm_free(native_filter);
libxsmm_free(scratch); libxsmm_free(scratch);

View File

@ -308,12 +308,11 @@ TEST(XsmmConv2DTest, Basic) {
desc.threads = num_threads; desc.threads = num_threads;
desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT; desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC; desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
desc.filter_format = desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_NONE; desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
desc.datatype = LIBXSMM_DNN_DATATYPE_F32; desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
if (!CanUseXsmmConv2D(desc, data_format)) { if (!CanUseXsmmConv2D(desc, data_format)) {
return false; return false;
} }

View File

@ -193,11 +193,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
tf_http_archive( tf_http_archive(
name = "libxsmm_archive", name = "libxsmm_archive",
build_file = clean_dep("//third_party:libxsmm.BUILD"), build_file = clean_dep("//third_party:libxsmm.BUILD"),
sha256 = "5fc1972471cd8e2b8b64ea017590193739fc88d9818e3d086621e5c08e86ea35", sha256 = "9c0af4509ea341d1ee2c6c19fc6f19289318c3bd4b17844efeb9e7f9691abf76",
strip_prefix = "libxsmm-1.11", strip_prefix = "libxsmm-1.14",
urls = [ urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/hfp/libxsmm/archive/1.11.tar.gz", "https://storage.googleapis.com/mirror.tensorflow.org/github.com/hfp/libxsmm/archive/1.14.tar.gz",
"https://github.com/hfp/libxsmm/archive/1.11.tar.gz", "https://github.com/hfp/libxsmm/archive/1.14.tar.gz",
], ],
) )

View File

@ -45,62 +45,39 @@ genrule(
cc_library( cc_library(
name = "xsmm_avx", name = "xsmm_avx",
srcs = [ srcs = glob([
"src/libxsmm_cpuid_x86.c", # general source files (translation units)
"src/libxsmm_dnn.c",
"src/libxsmm_dnn_convolution_backward.c",
"src/libxsmm_dnn_convolution_forward.c",
"src/libxsmm_dnn_convolution_weight_update.c",
"src/libxsmm_dnn_convolution_winograd_backward.c",
"src/libxsmm_dnn_convolution_winograd_forward.c",
"src/libxsmm_dnn_convolution_winograd_weight_update.c",
"src/libxsmm_dnn_handle.c",
"src/libxsmm_dump.c",
"src/libxsmm_ext_gemm.c",
"src/libxsmm_ext_trans.c",
"src/libxsmm_fsspmdm.c",
"src/libxsmm_gemm.c",
"src/libxsmm_main.c",
"src/libxsmm_malloc.c",
"src/libxsmm_perf.c",
"src/libxsmm_spmdm.c",
"src/libxsmm_sync.c",
"src/libxsmm_timer.c",
"src/libxsmm_trace.c",
"src/libxsmm_trans.c",
] + glob([
"src/generator_*.c", "src/generator_*.c",
"src/libxsmm_*.c",
], exclude=[
# exclude generators (with main functions)
"src/libxsmm_generator_*.c",
]), ]),
hdrs = [ hdrs = glob([
"include/libxsmm_cpuid.h", # general header files
"include/libxsmm_dnn.h", "include/libxsmm_*.h",
"include/libxsmm_frontend.h", # trigger rebuild if template changed
"include/libxsmm_fsspmdm.h", "src/template/*.c",
"include/libxsmm_generator.h", ], exclude=[
"include/libxsmm_intrinsics_x86.h", # exclude existing/generated headers
"include/libxsmm_macros.h", "include/libxsmm.h",
"include/libxsmm_malloc.h", "include/libxsmm_config.h",
"include/libxsmm_spmdm.h", "include/libxsmm_dispatch.h",
"include/libxsmm_sync.h", ]) + [
"include/libxsmm_timer.h", # source files included internally
"include/libxsmm_typedefs.h", "src/libxsmm_hash.c",
# Source files #included internally: # generated header files
"src/libxsmm_gemm_diff.c",
"src/libxsmm_hash.c",
# Generated:
"include/libxsmm.h", "include/libxsmm.h",
"include/libxsmm_config.h", "include/libxsmm_config.h",
"include/libxsmm_dispatch.h", "include/libxsmm_dispatch.h",
] + glob([
# trigger rebuild if template changed
"src/template/*.c",
]),
copts = [
"-mavx", # JIT does not work without avx anyway, and this silences some CRC32 warnings.
"-Wno-vla", # Libxsmm convolutions heavily use VLA.
], ],
#copts = [
# "-mavx", # JIT does not work without avx anyway, and this silences some CRC32 warnings.
# "-Wno-vla", # Libxsmm convolutions heavily use VLA.
#],
defines = [ defines = [
"LIBXSMM_BUILD", "LIBXSMM_BUILD",
"LIBXSMM_CTOR",
"__BLAS=0", "__BLAS=0",
], ],
includes = [ includes = [