LIBXSMM in TensorFlow broken due to TensorFlow API changes (#24706)
* Original issue about TF/API change was resolved in LIBXSMM 1.11. * Register sparse matmul-op to use LIBXSMM's Bfloat16 functionality; - SpMDM/Bf16 was present since TF v1.1 but TF had no/public Bf16 * Updated (outdated) conv-integration with LIBXSMM; fixes compilation - https://github.com/hfp/libxsmm/issues/281 should be considered - Use desc. datatype_in/out; desc. previously had only "datatype" * Generalized libxsmm.BUILD using glob/patterns. * Updated LIBXSMM v1.14.
This commit is contained in:
parent
88a3f5c153
commit
dcbfb65b94
|
@ -241,8 +241,8 @@ struct LaunchXsmmBackwardFilter<CPUDevice, float> {
|
|||
desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
|
||||
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
|
||||
desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
|
||||
desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
|
||||
|
||||
desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
|
||||
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
|
||||
if (!CanUseXsmmConv2D(desc, data_format)) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -290,9 +290,9 @@ struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
|
|||
desc.filter_format =
|
||||
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM; // LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
|
||||
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
|
||||
desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
|
||||
desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
|
||||
|
||||
desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
|
||||
desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
|
||||
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
|
||||
auto input_ptr = input_backward.data();
|
||||
auto filter_ptr = kernel.data();
|
||||
auto output_ptr = output_backward.data();
|
||||
|
|
|
@ -320,9 +320,9 @@ class LaunchXsmmConvOp<CPUDevice, float> {
|
|||
desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
|
||||
desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
|
||||
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
|
||||
desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
|
||||
desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
|
||||
|
||||
desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
|
||||
desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
|
||||
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
|
||||
if (dilation_rows != 1 || dilation_cols != 1 ||
|
||||
!CanUseXsmmConv2D(desc, data_format)) {
|
||||
return false;
|
||||
|
|
|
@ -1396,7 +1396,7 @@ void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
|
|||
libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
|
||||
int nthreads) {
|
||||
return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
|
||||
handle, transA, reinterpret_cast<const uint16*>(A), libxsmm_output_csr_a,
|
||||
handle, transA, reinterpret_cast<const libxsmm_bfloat16*>(A), libxsmm_output_csr_a,
|
||||
block_id, tid, nthreads);
|
||||
}
|
||||
|
||||
|
@ -1406,9 +1406,9 @@ void wrapper_libxsmm_spmdm_compute_generic_thread(
|
|||
libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
|
||||
const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
|
||||
return libxsmm_spmdm_compute_bfloat16_thread(
|
||||
handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse,
|
||||
reinterpret_cast<const uint16*>(B), transC,
|
||||
reinterpret_cast<const uint16*>(beta), C, block_id, tid, nthreads);
|
||||
handle, transA, transB, reinterpret_cast<const libxsmm_bfloat16*>(alpha), A_sparse,
|
||||
reinterpret_cast<const libxsmm_bfloat16*>(B), transC,
|
||||
reinterpret_cast<const libxsmm_bfloat16*>(beta), C, block_id, tid, nthreads);
|
||||
}
|
||||
void wrapper_libxsmm_spmdm_compute_generic_thread(
|
||||
empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
|
||||
|
@ -1427,13 +1427,6 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
|
|||
const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
|
||||
bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
|
||||
bool transpose_output, MatrixMap* output) {
|
||||
if (false) {
|
||||
// Not handled by libxsmm currently
|
||||
SparseMatMul<TL, TR>::Compute(
|
||||
nullptr /* Assumes no cached data for fallback */, left, right,
|
||||
transpose_left, thread_pool, transpose_output, output);
|
||||
return;
|
||||
}
|
||||
const int num_threads = thread_pool->num_threads;
|
||||
const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
|
||||
const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
|
||||
|
@ -1444,6 +1437,7 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
|
|||
(transpose_output ? output->dimension(1) : output->dimension(0)));
|
||||
CHECK_EQ(right_dim1,
|
||||
(transpose_output ? output->dimension(0) : output->dimension(1)));
|
||||
#if 0 // this issue seems to be resolved
|
||||
if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
|
||||
// Causes problems in libxsmm
|
||||
SparseMatMul<TL, TR>::Compute(
|
||||
|
@ -1451,6 +1445,7 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
|
|||
transpose_left, thread_pool, transpose_output, output);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
auto left_data = left.data();
|
||||
auto right_data = right.data();
|
||||
auto output_data = output->data();
|
||||
|
@ -1640,15 +1635,14 @@ inline void SparseMatMul<TL, TR>::Compute(
|
|||
SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
|
||||
#endif
|
||||
|
||||
REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
|
||||
|
||||
REGISTER_SPARSE_MATMUL(float, bfloat16);
|
||||
|
||||
REGISTER_SPARSE_MATMUL(bfloat16, float);
|
||||
|
||||
#ifdef TENSORFLOW_USE_LIBXSMM
|
||||
REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16);
|
||||
REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
|
||||
#else
|
||||
REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
|
||||
REGISTER_SPARSE_MATMUL(float, float);
|
||||
#endif
|
||||
|
||||
|
|
|
@ -27,6 +27,9 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty();
|
|||
|
||||
#include <stdlib.h>
|
||||
#include <cstring>
|
||||
#if defined(_OPENMP) && defined(LIBXSMM_USE_OPENMP)
|
||||
# include <omp.h>
|
||||
#endif
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
|
@ -36,6 +39,11 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty();
|
|||
#include "include/libxsmm_malloc.h"
|
||||
#include "src/libxsmm_main.h" // TODO(bsteiner): API to avoid incl. header from src/
|
||||
|
||||
#define CHECK_LIBXSMM(CONDITION_OK, MESSAGE) if (!(CONDITION_OK)) VLOG(0) << (MESSAGE)
|
||||
#define CHECK_LIBXSMM_DNN(STATUS, MESSAGE) CHECK_LIBXSMM(LIBXSMM_DNN_SUCCESS == (STATUS), MESSAGE) \
|
||||
<< " failed: " << libxsmm_dnn_get_error(STATUS);
|
||||
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Xsmm*Conv2D are wrappers for libxsmm direct convolutions.
|
||||
|
@ -73,12 +81,6 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
|
|||
|
||||
namespace functor {
|
||||
|
||||
static void chk_libxsmm_err(libxsmm_dnn_err_t status, string msg) {
|
||||
if (status != LIBXSMM_DNN_SUCCESS) {
|
||||
VLOG(0) << msg << " failed: " << libxsmm_dnn_get_error(status);
|
||||
}
|
||||
}
|
||||
|
||||
LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
|
||||
int S, int C, int K, int blocksifm,
|
||||
int blocksofm, int ifmblock,
|
||||
|
@ -114,55 +116,115 @@ LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
|
|||
}
|
||||
}
|
||||
|
||||
class libxsmm_dnn_conv_desc_wrap {
|
||||
public:
|
||||
const libxsmm_dnn_conv_desc d;
|
||||
|
||||
libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc& d_) : d(d_) {}
|
||||
bool operator==(const libxsmm_dnn_conv_desc_wrap& w) const {
|
||||
return (d.N == w.d.N && d.C == w.d.C && d.H == w.d.H && d.W == w.d.W &&
|
||||
d.K == w.d.K && d.R == w.d.R && d.S == w.d.S && d.u == w.d.u &&
|
||||
d.v == w.d.v && d.pad_h == w.d.pad_h && d.pad_w == w.d.pad_w);
|
||||
struct libxsmm_dnn_registry_key {
|
||||
const libxsmm_dnn_conv_desc descriptor;
|
||||
libxsmm_dnn_registry_key(const libxsmm_dnn_conv_desc& desc_): descriptor(desc_) {}
|
||||
bool operator==(const libxsmm_dnn_registry_key& regkey) const {
|
||||
return 0 == memcmp(&descriptor, ®key.descriptor, sizeof(descriptor));
|
||||
}
|
||||
};
|
||||
|
||||
struct HashFunction {
|
||||
std::size_t operator()(const libxsmm_dnn_conv_desc_wrap& w) const {
|
||||
return libxsmm_hash(&w.d, sizeof(w.d), 25071975);
|
||||
std::size_t operator()(const libxsmm_dnn_registry_key& regkey) const {
|
||||
return libxsmm_hash(®key.descriptor, sizeof(regkey.descriptor), 25071975);
|
||||
}
|
||||
};
|
||||
|
||||
class handles {
|
||||
public:
|
||||
libxsmm_dnn_layer* find(const libxsmm_dnn_conv_desc_wrap& w) {
|
||||
std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
|
||||
HashFunction>::iterator i = libxsmm_handles.find(w);
|
||||
if (i == libxsmm_handles.end()) {
|
||||
libxsmm_dnn_err_t status;
|
||||
libxsmm_dnn_layer* libxsmm_handle =
|
||||
libxsmm_dnn_create_conv_layer(w.d, &status);
|
||||
chk_libxsmm_err(status, "Create handle");
|
||||
libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
|
||||
return libxsmm_handle;
|
||||
} else {
|
||||
return i->second;
|
||||
struct libxsmm_dnn_registry_value {
|
||||
libxsmm_dnn_tensor_datalayout* layout_input;
|
||||
libxsmm_dnn_tensor_datalayout* layout_filter;
|
||||
libxsmm_dnn_tensor_datalayout* layout_output;
|
||||
libxsmm_dnn_layer* handle;
|
||||
};
|
||||
|
||||
typedef libxsmm_tf_allocator<libxsmm_scratch_allocator> libxsmm_tf_scratch_allocator;
|
||||
|
||||
static class libxsmm_dnn_registry_type {
|
||||
private:
|
||||
typedef std::unordered_map<
|
||||
libxsmm_dnn_registry_key,
|
||||
libxsmm_dnn_registry_value,
|
||||
HashFunction>
|
||||
container_type;
|
||||
|
||||
public:
|
||||
libxsmm_dnn_registry_type() {
|
||||
libxsmm_init(); /* must be first */
|
||||
#if !defined(LIBXSMM_LOCAL_ALLOC)
|
||||
{ libxsmm_malloc_function malloc_fn; libxsmm_free_function free_fn;
|
||||
malloc_fn.function = libxsmm_tf_scratch_allocator::malloc;
|
||||
free_fn.function = libxsmm_tf_scratch_allocator::free;
|
||||
libxsmm_set_scratch_allocator(0/*context*/, malloc_fn, free_fn);
|
||||
}
|
||||
#endif
|
||||
LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_LOCK_RWLOCK, &attr);
|
||||
LIBXSMM_LOCK_INIT(LIBXSMM_LOCK_RWLOCK, &lock, &attr);
|
||||
}
|
||||
~handles() {
|
||||
std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
|
||||
HashFunction>::iterator i;
|
||||
for (i = libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
|
||||
chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second),
|
||||
"Destroy handle");
|
||||
~libxsmm_dnn_registry_type() {
|
||||
LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
|
||||
const container_type::const_iterator end = container.end();
|
||||
for (container_type::const_iterator i = container.begin(); i != end; ++i) {
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_input),
|
||||
"destroy input layout");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_output),
|
||||
"destroy output layout");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_filter),
|
||||
"destroy filter layout");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_destroy_conv_layer(i->second.handle),
|
||||
"destroy handle");
|
||||
}
|
||||
LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
|
||||
LIBXSMM_LOCK_DESTROY(LIBXSMM_LOCK_RWLOCK, &lock);
|
||||
LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_LOCK_RWLOCK, &attr);
|
||||
libxsmm_finalize();
|
||||
}
|
||||
libxsmm_dnn_registry_value find(const libxsmm_dnn_registry_key& regkey) {
|
||||
container_type::iterator i;
|
||||
LIBXSMM_LOCK_ACQREAD(LIBXSMM_LOCK_RWLOCK, &lock);
|
||||
i = container.find(regkey);
|
||||
LIBXSMM_LOCK_RELREAD(LIBXSMM_LOCK_RWLOCK, &lock);
|
||||
if (i == container.end()) {
|
||||
libxsmm_dnn_err_t status;
|
||||
libxsmm_dnn_registry_value regentry;
|
||||
|
||||
LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
|
||||
i = container.find(regkey);
|
||||
if (i == container.end()) { // re-check after lock acquisition
|
||||
regentry.handle = libxsmm_dnn_create_conv_layer(regkey.descriptor, &status);
|
||||
if (LIBXSMM_DNN_WARN_FALLBACK != status) {
|
||||
CHECK_LIBXSMM_DNN(status, "create handle");
|
||||
}
|
||||
else { // warning
|
||||
VLOG(1) << libxsmm_dnn_get_error(status);
|
||||
}
|
||||
regentry.layout_input = libxsmm_dnn_create_tensor_datalayout(
|
||||
regentry.handle, LIBXSMM_DNN_INPUT, &status);
|
||||
CHECK_LIBXSMM_DNN(status, "create input layout");
|
||||
|
||||
regentry.layout_output = libxsmm_dnn_create_tensor_datalayout(
|
||||
regentry.handle, LIBXSMM_DNN_OUTPUT, &status);
|
||||
CHECK_LIBXSMM_DNN(status, "create output layout");
|
||||
|
||||
regentry.layout_filter = libxsmm_dnn_create_tensor_datalayout(
|
||||
regentry.handle, LIBXSMM_DNN_FILTER, &status);
|
||||
CHECK_LIBXSMM_DNN(status, "create filter layout");
|
||||
|
||||
i = container.insert(std::make_pair(regkey, regentry)).first;
|
||||
}
|
||||
LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
|
||||
}
|
||||
return i->second;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
|
||||
HashFunction>
|
||||
libxsmm_handles;
|
||||
};
|
||||
|
||||
static handles libxsmm_handles;
|
||||
private:
|
||||
container_type container;
|
||||
LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_LOCK_RWLOCK) attr;
|
||||
LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK_RWLOCK) lock;
|
||||
} libxsmm_dnn_registry;
|
||||
|
||||
// #define LIBXSMM_DETAILED_TIMING
|
||||
|
||||
|
@ -173,73 +235,61 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
|
|||
InputPtr input, FilterPtr filter,
|
||||
OutputPtr output) {
|
||||
#if defined(LIBXSMM_DETAILED_TIMING)
|
||||
uint64 l_tick1;
|
||||
uint64 l_tick2;
|
||||
uint64 l_tick3;
|
||||
uint64 l_tick4;
|
||||
uint64 l_tick5;
|
||||
uint64 l_tick6;
|
||||
uint64 l_tick7;
|
||||
uint64 l_tick8;
|
||||
uint64 l_tick9;
|
||||
uint64 l_tick10;
|
||||
libxsmm_timer_tickint l_tick1;
|
||||
libxsmm_timer_tickint l_tick2;
|
||||
libxsmm_timer_tickint l_tick3;
|
||||
libxsmm_timer_tickint l_tick4;
|
||||
libxsmm_timer_tickint l_tick5;
|
||||
libxsmm_timer_tickint l_tick6;
|
||||
libxsmm_timer_tickint l_tick7;
|
||||
libxsmm_timer_tickint l_tick8;
|
||||
libxsmm_timer_tickint l_tick9;
|
||||
libxsmm_timer_tickint l_tick10;
|
||||
l_tick1 = libxsmm_timer_tick();
|
||||
#endif
|
||||
// setup scoped allocator, which adopts the allocator from the context
|
||||
const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator(*ctx);
|
||||
#if defined(LIBXSMM_LOCAL_ALLOC)
|
||||
// setup scoped allocator, which adopts the allocator of the current context
|
||||
const libxsmm_tf_scratch_allocator tf_allocator(*ctx);
|
||||
#endif
|
||||
const libxsmm_dnn_registry_key regkey(desc);
|
||||
const libxsmm_dnn_registry_value regentry = libxsmm_dnn_registry.find(regkey);
|
||||
libxsmm_dnn_tensor *libxsmm_input, *libxsmm_output, *libxsmm_filter;
|
||||
libxsmm_dnn_err_t status;
|
||||
libxsmm_dnn_layer* libxsmm_handle;
|
||||
libxsmm_dnn_conv_desc_wrap w(desc);
|
||||
void* scratch;
|
||||
|
||||
// if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)
|
||||
libxsmm_handle = libxsmm_handles.find(w);
|
||||
// else{
|
||||
// libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
|
||||
// chk_libxsmm_err(status, "Create handle");
|
||||
//}
|
||||
|
||||
status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
|
||||
status = libxsmm_dnn_get_codegen_success(regentry.handle, kind);
|
||||
if (status == LIBXSMM_DNN_WARN_FALLBACK) {
|
||||
return false; // Use non-libxsmm code
|
||||
}
|
||||
chk_libxsmm_err(status, "Check codegen status");
|
||||
|
||||
libxsmm_dnn_buffer* libxsmm_input;
|
||||
libxsmm_dnn_buffer* libxsmm_output;
|
||||
libxsmm_dnn_filter* libxsmm_filter;
|
||||
CHECK_LIBXSMM_DNN(status, "code generation");
|
||||
|
||||
#if defined(LIBXSMM_DETAILED_TIMING)
|
||||
l_tick2 = libxsmm_timer_tick();
|
||||
#endif
|
||||
|
||||
int ifmblock = (libxsmm_handle->ifmblock);
|
||||
int ofmblock = (libxsmm_handle->ofmblock);
|
||||
const int ifmblock = regentry.handle->ifmblock;
|
||||
const int ofmblock = regentry.handle->ofmblock;
|
||||
|
||||
int blocksifm =
|
||||
desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1;
|
||||
int blocksofm =
|
||||
desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1;
|
||||
float* native_filter =
|
||||
(float*)libxsmm_aligned_scratch(blocksofm * blocksifm * desc.R * desc.S *
|
||||
ifmblock * ofmblock * sizeof(float),
|
||||
2097152);
|
||||
const int blocksifm = (desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1);
|
||||
const int blocksofm = (desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1);
|
||||
|
||||
const DeviceBase::CpuWorkerThreads* worker_threads =
|
||||
ctx->device()->tensorflow_cpu_worker_threads();
|
||||
const size_t filter_size = blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock;
|
||||
float *const native_filter = (float*)libxsmm_aligned_scratch(
|
||||
filter_size * sizeof(float), 2097152/*alignment*/);
|
||||
|
||||
int num_threads = worker_threads->num_threads;
|
||||
const DeviceBase::CpuWorkerThreads *const worker_threads =
|
||||
ctx->device()->tensorflow_cpu_worker_threads();
|
||||
const int num_threads = worker_threads->num_threads;
|
||||
|
||||
#if 1
|
||||
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
|
||||
kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
|
||||
if (blocksofm > num_threads) {
|
||||
int work = blocksofm;
|
||||
const int work = blocksofm;
|
||||
BlockingCounter count(num_threads);
|
||||
for (int i = 0; i < num_threads; ++i) {
|
||||
worker_threads->workers->Schedule([=, &count]() {
|
||||
int start = work / num_threads * i;
|
||||
int end = (start + work / num_threads) > work
|
||||
const int start = work / num_threads * i;
|
||||
const int end = (start + work / num_threads) > work
|
||||
? work
|
||||
: start + work / num_threads;
|
||||
copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
|
||||
|
@ -250,14 +300,14 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
|
|||
}
|
||||
count.Wait();
|
||||
} else {
|
||||
int work = blocksofm;
|
||||
int num_threads = work;
|
||||
const int work = blocksofm;
|
||||
const int num_tasks = work;
|
||||
|
||||
BlockingCounter count(num_threads);
|
||||
for (int i = 0; i < num_threads; ++i) {
|
||||
BlockingCounter count(num_tasks);
|
||||
for (int i = 0; i < num_tasks; ++i) {
|
||||
worker_threads->workers->Schedule([=, &count]() {
|
||||
int start = i;
|
||||
int end = i + 1;
|
||||
const int start = i;
|
||||
const int end = i + 1;
|
||||
copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
|
||||
desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
|
||||
start, end);
|
||||
|
@ -267,121 +317,122 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
|
|||
count.Wait();
|
||||
}
|
||||
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
|
||||
// Added: for weight update
|
||||
libxsmm_filter =
|
||||
libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter,
|
||||
LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status);
|
||||
chk_libxsmm_err(status,
|
||||
"Link filter"); // weight update is in RSCK as
|
||||
// filter should be returned in RSCK
|
||||
// format
|
||||
// weight update buffer must be in the right format (LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR)
|
||||
libxsmm_filter = libxsmm_dnn_link_tensor(regentry.layout_filter, filter, &status);
|
||||
CHECK_LIBXSMM_DNN(status, "link filter with layout");
|
||||
}
|
||||
#else
|
||||
memset(native_filter, 0,
|
||||
blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock *
|
||||
sizeof(float));
|
||||
memset(native_filter, 0, filter_size * sizeof(float));
|
||||
#endif
|
||||
|
||||
#if defined(LIBXSMM_DETAILED_TIMING)
|
||||
l_tick3 = libxsmm_timer_tick();
|
||||
#endif
|
||||
|
||||
libxsmm_input =
|
||||
libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_INPUT, input,
|
||||
LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
|
||||
chk_libxsmm_err(status, "Link input buffer");
|
||||
libxsmm_output =
|
||||
libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_OUTPUT, output,
|
||||
LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
|
||||
chk_libxsmm_err(status, "Link output buffer");
|
||||
// LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
|
||||
libxsmm_input = libxsmm_dnn_link_tensor(regentry.layout_input, input, &status);
|
||||
CHECK_LIBXSMM_DNN(status, "link input buffer with layout");
|
||||
|
||||
// LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
|
||||
libxsmm_output = libxsmm_dnn_link_tensor(regentry.layout_output, output, &status);
|
||||
CHECK_LIBXSMM_DNN(status, "link output buffer with layout");
|
||||
|
||||
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
|
||||
kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
|
||||
libxsmm_filter = libxsmm_dnn_link_filter(
|
||||
libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter,
|
||||
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status);
|
||||
chk_libxsmm_err(status, "Link filter");
|
||||
// LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR
|
||||
libxsmm_filter = libxsmm_dnn_link_tensor(regentry.layout_filter, native_filter, &status);
|
||||
CHECK_LIBXSMM_DNN(status, "link filter with layout");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
|
||||
"bind filter to handle");
|
||||
}
|
||||
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
|
||||
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
|
||||
LIBXSMM_DNN_REGULAR_INPUT),
|
||||
"Bind input forward");
|
||||
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
|
||||
LIBXSMM_DNN_REGULAR_OUTPUT),
|
||||
"Bind output forward");
|
||||
chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
|
||||
LIBXSMM_DNN_REGULAR_FILTER),
|
||||
"Bind filter forward");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT),
|
||||
"bind input forward");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
|
||||
"bind filter forward");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output, LIBXSMM_DNN_REGULAR_OUTPUT),
|
||||
"bind output forward");
|
||||
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
|
||||
chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_input), "Zero input");
|
||||
|
||||
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
|
||||
LIBXSMM_DNN_GRADIENT_INPUT),
|
||||
"Bind input backward");
|
||||
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
|
||||
LIBXSMM_DNN_GRADIENT_OUTPUT),
|
||||
"Bind output backward");
|
||||
chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
|
||||
LIBXSMM_DNN_REGULAR_FILTER),
|
||||
"Bind filter backward");
|
||||
CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_input), "zeroing input");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input, LIBXSMM_DNN_GRADIENT_INPUT),
|
||||
"bind input backward");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
|
||||
"bind filter backward");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output, LIBXSMM_DNN_GRADIENT_OUTPUT),
|
||||
"bind output backward");
|
||||
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
|
||||
chk_libxsmm_err(libxsmm_dnn_zero_filter(libxsmm_filter), "Zero filter");
|
||||
|
||||
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
|
||||
LIBXSMM_DNN_REGULAR_INPUT),
|
||||
"Bind input weight update");
|
||||
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
|
||||
LIBXSMM_DNN_GRADIENT_OUTPUT),
|
||||
"Bind output weight update");
|
||||
chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
|
||||
LIBXSMM_DNN_GRADIENT_FILTER),
|
||||
"Bind filter weight update");
|
||||
CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_filter), "zeroing filter");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT),
|
||||
"bind input weight update");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter, LIBXSMM_DNN_GRADIENT_FILTER),
|
||||
"bind filter weight update");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output, LIBXSMM_DNN_GRADIENT_OUTPUT),
|
||||
"bind output weight update");
|
||||
} else {
|
||||
/* shouldn't happen */
|
||||
assert(0/*should not happen*/);
|
||||
}
|
||||
|
||||
#if defined(LIBXSMM_DETAILED_TIMING)
|
||||
l_tick4 = libxsmm_timer_tick();
|
||||
#endif
|
||||
|
||||
/* bind scratch */
|
||||
scratch = (void*)libxsmm_aligned_scratch(
|
||||
libxsmm_dnn_get_scratch_size(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL,
|
||||
&status),
|
||||
2097152);
|
||||
chk_libxsmm_err(status, "scratch allocation");
|
||||
chk_libxsmm_err(libxsmm_dnn_bind_scratch(
|
||||
libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
|
||||
"binding scratch");
|
||||
const size_t scratch_size = libxsmm_dnn_get_scratch_size(regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status);
|
||||
CHECK_LIBXSMM_DNN(status, "get scratch size");
|
||||
void *const scratch = libxsmm_aligned_scratch(scratch_size, 2097152/*alignment*/);
|
||||
CHECK_LIBXSMM(0 != scratch, "scratch memory allocation");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_bind_scratch(regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
|
||||
"binding scratch");
|
||||
|
||||
#if defined(LIBXSMM_DETAILED_TIMING)
|
||||
l_tick5 = libxsmm_timer_tick();
|
||||
#endif
|
||||
|
||||
if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
|
||||
libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER);
|
||||
libxsmm_dnn_transpose_filter(regentry.handle, LIBXSMM_DNN_FILTER);
|
||||
}
|
||||
|
||||
#if defined(LIBXSMM_DETAILED_TIMING)
|
||||
l_tick6 = libxsmm_timer_tick();
|
||||
#endif
|
||||
|
||||
#if !defined(_OPENMP) || !defined(LIBXSMM_USE_OPENMP)
|
||||
BlockingCounter counter(num_threads);
|
||||
|
||||
for (int i = 0; i < num_threads; ++i) {
|
||||
worker_threads->workers->Schedule([=, &counter]() {
|
||||
chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i),
|
||||
"Worker");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_execute_st(regentry.handle, kind, 0, i),
|
||||
"worker");
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
#else
|
||||
# pragma omp parallel
|
||||
{
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_execute_st(regentry.handle, kind, 0, omp_get_thread_num()),
|
||||
"worker");
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(LIBXSMM_DETAILED_TIMING)
|
||||
l_tick7 = libxsmm_timer_tick();
|
||||
#endif
|
||||
|
||||
if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
|
||||
libxsmm_dnn_reduce_wu_filters(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER);
|
||||
libxsmm_dnn_reduce_wu_filters(regentry.handle, LIBXSMM_DNN_GRADIENT_FILTER);
|
||||
}
|
||||
|
||||
#if defined(LIBXSMM_DETAILED_TIMING)
|
||||
|
@ -389,54 +440,50 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
|
|||
#endif
|
||||
|
||||
/* clean up */
|
||||
chk_libxsmm_err(
|
||||
libxsmm_dnn_release_scratch(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL),
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_release_scratch(regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL),
|
||||
"release scratch");
|
||||
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
|
||||
chk_libxsmm_err(
|
||||
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
|
||||
"release input");
|
||||
chk_libxsmm_err(
|
||||
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT),
|
||||
"release output");
|
||||
chk_libxsmm_err(
|
||||
libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
|
||||
"release filter");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
|
||||
"release input");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_OUTPUT),
|
||||
"release output");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
|
||||
"release filter");
|
||||
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
|
||||
chk_libxsmm_err(
|
||||
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT),
|
||||
"release input");
|
||||
chk_libxsmm_err(
|
||||
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
|
||||
"release output");
|
||||
chk_libxsmm_err(
|
||||
libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
|
||||
"release filter");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_INPUT),
|
||||
"release input");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
|
||||
"release output");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
|
||||
"release filter");
|
||||
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
|
||||
chk_libxsmm_err(
|
||||
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
|
||||
"release input");
|
||||
chk_libxsmm_err(
|
||||
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
|
||||
"release output");
|
||||
chk_libxsmm_err(
|
||||
libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER),
|
||||
"release filter");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
|
||||
"release input");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
|
||||
"release output");
|
||||
CHECK_LIBXSMM_DNN(
|
||||
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_FILTER),
|
||||
"release filter");
|
||||
} else {
|
||||
/* shouldn't happen */
|
||||
}
|
||||
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input");
|
||||
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output");
|
||||
chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter");
|
||||
CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_input), "destroy input");
|
||||
CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_output), "destroy output");
|
||||
CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_filter), "destroy filter");
|
||||
|
||||
#if defined(LIBXSMM_DETAILED_TIMING)
|
||||
l_tick9 = libxsmm_timer_tick();
|
||||
#endif
|
||||
|
||||
// if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
|
||||
// chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
|
||||
// "Destroy handle");
|
||||
|
||||
libxsmm_free(native_filter);
|
||||
libxsmm_free(scratch);
|
||||
|
||||
|
|
|
@ -308,12 +308,11 @@ TEST(XsmmConv2DTest, Basic) {
|
|||
desc.threads = num_threads;
|
||||
desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
|
||||
desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
|
||||
desc.filter_format =
|
||||
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
|
||||
desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
|
||||
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
|
||||
desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
|
||||
desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
|
||||
|
||||
desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
|
||||
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
|
||||
if (!CanUseXsmmConv2D(desc, data_format)) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -193,11 +193,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
|||
tf_http_archive(
|
||||
name = "libxsmm_archive",
|
||||
build_file = clean_dep("//third_party:libxsmm.BUILD"),
|
||||
sha256 = "5fc1972471cd8e2b8b64ea017590193739fc88d9818e3d086621e5c08e86ea35",
|
||||
strip_prefix = "libxsmm-1.11",
|
||||
sha256 = "9c0af4509ea341d1ee2c6c19fc6f19289318c3bd4b17844efeb9e7f9691abf76",
|
||||
strip_prefix = "libxsmm-1.14",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/hfp/libxsmm/archive/1.11.tar.gz",
|
||||
"https://github.com/hfp/libxsmm/archive/1.11.tar.gz",
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/hfp/libxsmm/archive/1.14.tar.gz",
|
||||
"https://github.com/hfp/libxsmm/archive/1.14.tar.gz",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -45,62 +45,39 @@ genrule(
|
|||
|
||||
cc_library(
|
||||
name = "xsmm_avx",
|
||||
srcs = [
|
||||
"src/libxsmm_cpuid_x86.c",
|
||||
"src/libxsmm_dnn.c",
|
||||
"src/libxsmm_dnn_convolution_backward.c",
|
||||
"src/libxsmm_dnn_convolution_forward.c",
|
||||
"src/libxsmm_dnn_convolution_weight_update.c",
|
||||
"src/libxsmm_dnn_convolution_winograd_backward.c",
|
||||
"src/libxsmm_dnn_convolution_winograd_forward.c",
|
||||
"src/libxsmm_dnn_convolution_winograd_weight_update.c",
|
||||
"src/libxsmm_dnn_handle.c",
|
||||
"src/libxsmm_dump.c",
|
||||
"src/libxsmm_ext_gemm.c",
|
||||
"src/libxsmm_ext_trans.c",
|
||||
"src/libxsmm_fsspmdm.c",
|
||||
"src/libxsmm_gemm.c",
|
||||
"src/libxsmm_main.c",
|
||||
"src/libxsmm_malloc.c",
|
||||
"src/libxsmm_perf.c",
|
||||
"src/libxsmm_spmdm.c",
|
||||
"src/libxsmm_sync.c",
|
||||
"src/libxsmm_timer.c",
|
||||
"src/libxsmm_trace.c",
|
||||
"src/libxsmm_trans.c",
|
||||
] + glob([
|
||||
srcs = glob([
|
||||
# general source files (translation units)
|
||||
"src/generator_*.c",
|
||||
"src/libxsmm_*.c",
|
||||
], exclude=[
|
||||
# exclude generators (with main functions)
|
||||
"src/libxsmm_generator_*.c",
|
||||
]),
|
||||
hdrs = [
|
||||
"include/libxsmm_cpuid.h",
|
||||
"include/libxsmm_dnn.h",
|
||||
"include/libxsmm_frontend.h",
|
||||
"include/libxsmm_fsspmdm.h",
|
||||
"include/libxsmm_generator.h",
|
||||
"include/libxsmm_intrinsics_x86.h",
|
||||
"include/libxsmm_macros.h",
|
||||
"include/libxsmm_malloc.h",
|
||||
"include/libxsmm_spmdm.h",
|
||||
"include/libxsmm_sync.h",
|
||||
"include/libxsmm_timer.h",
|
||||
"include/libxsmm_typedefs.h",
|
||||
# Source files #included internally:
|
||||
"src/libxsmm_gemm_diff.c",
|
||||
"src/libxsmm_hash.c",
|
||||
# Generated:
|
||||
hdrs = glob([
|
||||
# general header files
|
||||
"include/libxsmm_*.h",
|
||||
# trigger rebuild if template changed
|
||||
"src/template/*.c",
|
||||
], exclude=[
|
||||
# exclude existing/generated headers
|
||||
"include/libxsmm.h",
|
||||
"include/libxsmm_config.h",
|
||||
"include/libxsmm_dispatch.h",
|
||||
]) + [
|
||||
# source files included internally
|
||||
"src/libxsmm_hash.c",
|
||||
# generated header files
|
||||
"include/libxsmm.h",
|
||||
"include/libxsmm_config.h",
|
||||
"include/libxsmm_dispatch.h",
|
||||
] + glob([
|
||||
# trigger rebuild if template changed
|
||||
"src/template/*.c",
|
||||
]),
|
||||
copts = [
|
||||
"-mavx", # JIT does not work without avx anyway, and this silences some CRC32 warnings.
|
||||
"-Wno-vla", # Libxsmm convolutions heavily use VLA.
|
||||
],
|
||||
#copts = [
|
||||
# "-mavx", # JIT does not work without avx anyway, and this silences some CRC32 warnings.
|
||||
# "-Wno-vla", # Libxsmm convolutions heavily use VLA.
|
||||
#],
|
||||
defines = [
|
||||
"LIBXSMM_BUILD",
|
||||
"LIBXSMM_CTOR",
|
||||
"__BLAS=0",
|
||||
],
|
||||
includes = [
|
||||
|
|
Loading…
Reference in New Issue