Merge pull request #24706 from hfp:upstream

PiperOrigin-RevId: 281336500
Change-Id: I65db6e77184d62717133ee8e61cbb7bf4b42bb56
This commit is contained in:
TensorFlower Gardener 2019-11-19 12:31:39 -08:00
commit e776cbc7ca
8 changed files with 308 additions and 268 deletions

View File

@ -241,8 +241,8 @@ struct LaunchXsmmBackwardFilter<CPUDevice, float> {
desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
if (!CanUseXsmmConv2D(desc, data_format)) {
return false;
}

View File

@ -290,9 +290,9 @@ struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
desc.filter_format =
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM; // LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
auto input_ptr = input_backward.data();
auto filter_ptr = kernel.data();
auto output_ptr = output_backward.data();

View File

@ -320,9 +320,9 @@ class LaunchXsmmConvOp<CPUDevice, float> {
desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
if (dilation_rows != 1 || dilation_cols != 1 ||
!CanUseXsmmConv2D(desc, data_format)) {
return false;

View File

@ -1396,8 +1396,8 @@ void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
int nthreads) {
return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
handle, transA, reinterpret_cast<const uint16*>(A), libxsmm_output_csr_a,
block_id, tid, nthreads);
handle, transA, reinterpret_cast<const libxsmm_bfloat16*>(A),
libxsmm_output_csr_a, block_id, tid, nthreads);
}
void wrapper_libxsmm_spmdm_compute_generic_thread(
@ -1406,9 +1406,10 @@ void wrapper_libxsmm_spmdm_compute_generic_thread(
libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
return libxsmm_spmdm_compute_bfloat16_thread(
handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse,
reinterpret_cast<const uint16*>(B), transC,
reinterpret_cast<const uint16*>(beta), C, block_id, tid, nthreads);
handle, transA, transB, reinterpret_cast<const libxsmm_bfloat16*>(alpha),
A_sparse, reinterpret_cast<const libxsmm_bfloat16*>(B), transC,
reinterpret_cast<const libxsmm_bfloat16*>(beta), C, block_id, tid,
nthreads);
}
void wrapper_libxsmm_spmdm_compute_generic_thread(
empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
@ -1427,13 +1428,6 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
bool transpose_output, MatrixMap* output) {
if (false) {
// Not handled by libxsmm currently
SparseMatMul<TL, TR>::Compute(
nullptr /* Assumes no cached data for fallback */, left, right,
transpose_left, thread_pool, transpose_output, output);
return;
}
const int num_threads = thread_pool->num_threads;
const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
@ -1444,6 +1438,7 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
(transpose_output ? output->dimension(1) : output->dimension(0)));
CHECK_EQ(right_dim1,
(transpose_output ? output->dimension(0) : output->dimension(1)));
#if 0 // this issue seems to be resolved
if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
// Causes problems in libxsmm
SparseMatMul<TL, TR>::Compute(
@ -1451,6 +1446,7 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
transpose_left, thread_pool, transpose_output, output);
return;
}
#endif
auto left_data = left.data();
auto right_data = right.data();
auto output_data = output->data();
@ -1640,15 +1636,14 @@ inline void SparseMatMul<TL, TR>::Compute(
SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
#endif
REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
REGISTER_SPARSE_MATMUL(float, bfloat16);
REGISTER_SPARSE_MATMUL(bfloat16, float);
#ifdef TENSORFLOW_USE_LIBXSMM
REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16);
REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
#else
REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
REGISTER_SPARSE_MATMUL(float, float);
#endif

View File

@ -27,6 +27,9 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty();
#include <stdlib.h>
#include <cstring>
#if defined(_OPENMP) && defined(LIBXSMM_USE_OPENMP)
#include <omp.h>
#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
@ -36,6 +39,12 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty();
#include "include/libxsmm_malloc.h"
#include "src/libxsmm_main.h" // TODO(bsteiner): API to avoid incl. header from src/
#define CHECK_LIBXSMM(CONDITION_OK, MESSAGE) \
if (!(CONDITION_OK)) VLOG(0) << (MESSAGE)
#define CHECK_LIBXSMM_DNN(STATUS, MESSAGE) \
CHECK_LIBXSMM(LIBXSMM_DNN_SUCCESS == (STATUS), MESSAGE) \
<< " failed: " << libxsmm_dnn_get_error(STATUS);
namespace tensorflow {
// Xsmm*Conv2D are wrappers for libxsmm direct convolutions.
@ -73,12 +82,6 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
namespace functor {
static void chk_libxsmm_err(libxsmm_dnn_err_t status, string msg) {
if (status != LIBXSMM_DNN_SUCCESS) {
VLOG(0) << msg << " failed: " << libxsmm_dnn_get_error(status);
}
}
LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
int S, int C, int K, int blocksifm,
int blocksofm, int ifmblock,
@ -114,55 +117,117 @@ LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
}
}
class libxsmm_dnn_conv_desc_wrap {
public:
const libxsmm_dnn_conv_desc d;
libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc& d_) : d(d_) {}
bool operator==(const libxsmm_dnn_conv_desc_wrap& w) const {
return (d.N == w.d.N && d.C == w.d.C && d.H == w.d.H && d.W == w.d.W &&
d.K == w.d.K && d.R == w.d.R && d.S == w.d.S && d.u == w.d.u &&
d.v == w.d.v && d.pad_h == w.d.pad_h && d.pad_w == w.d.pad_w);
struct libxsmm_dnn_registry_key {
const libxsmm_dnn_conv_desc descriptor;
libxsmm_dnn_registry_key(const libxsmm_dnn_conv_desc& desc_)
: descriptor(desc_) {}
bool operator==(const libxsmm_dnn_registry_key& regkey) const {
return 0 == memcmp(&descriptor, &regkey.descriptor, sizeof(descriptor));
}
};
struct HashFunction {
std::size_t operator()(const libxsmm_dnn_conv_desc_wrap& w) const {
return libxsmm_hash(&w.d, sizeof(w.d), 25071975);
std::size_t operator()(const libxsmm_dnn_registry_key& regkey) const {
return libxsmm_hash(&regkey.descriptor, sizeof(regkey.descriptor),
25071975);
}
};
class handles {
struct libxsmm_dnn_registry_value {
libxsmm_dnn_tensor_datalayout* layout_input;
libxsmm_dnn_tensor_datalayout* layout_filter;
libxsmm_dnn_tensor_datalayout* layout_output;
libxsmm_dnn_layer* handle;
};
typedef libxsmm_tf_allocator<libxsmm_scratch_allocator>
libxsmm_tf_scratch_allocator;
static class libxsmm_dnn_registry_type {
private:
typedef std::unordered_map<libxsmm_dnn_registry_key,
libxsmm_dnn_registry_value, HashFunction>
container_type;
public:
libxsmm_dnn_layer* find(const libxsmm_dnn_conv_desc_wrap& w) {
std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
HashFunction>::iterator i = libxsmm_handles.find(w);
if (i == libxsmm_handles.end()) {
libxsmm_dnn_registry_type() {
libxsmm_init(); /* must be first */
#if !defined(LIBXSMM_LOCAL_ALLOC)
{
libxsmm_malloc_function malloc_fn;
libxsmm_free_function free_fn;
malloc_fn.function = libxsmm_tf_scratch_allocator::malloc;
free_fn.function = libxsmm_tf_scratch_allocator::free;
libxsmm_set_scratch_allocator(0 /*context*/, malloc_fn, free_fn);
}
#endif
LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_LOCK_RWLOCK, &attr);
LIBXSMM_LOCK_INIT(LIBXSMM_LOCK_RWLOCK, &lock, &attr);
}
~libxsmm_dnn_registry_type() {
LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
const container_type::const_iterator end = container.end();
for (container_type::const_iterator i = container.begin(); i != end; ++i) {
CHECK_LIBXSMM_DNN(
libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_input),
"destroy input layout");
CHECK_LIBXSMM_DNN(
libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_output),
"destroy output layout");
CHECK_LIBXSMM_DNN(
libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_filter),
"destroy filter layout");
CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_conv_layer(i->second.handle),
"destroy handle");
}
LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
LIBXSMM_LOCK_DESTROY(LIBXSMM_LOCK_RWLOCK, &lock);
LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_LOCK_RWLOCK, &attr);
libxsmm_finalize();
}
libxsmm_dnn_registry_value find(const libxsmm_dnn_registry_key& regkey) {
container_type::iterator i;
LIBXSMM_LOCK_ACQREAD(LIBXSMM_LOCK_RWLOCK, &lock);
i = container.find(regkey);
LIBXSMM_LOCK_RELREAD(LIBXSMM_LOCK_RWLOCK, &lock);
if (i == container.end()) {
libxsmm_dnn_err_t status;
libxsmm_dnn_layer* libxsmm_handle =
libxsmm_dnn_create_conv_layer(w.d, &status);
chk_libxsmm_err(status, "Create handle");
libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
return libxsmm_handle;
} else {
libxsmm_dnn_registry_value regentry;
LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
i = container.find(regkey);
if (i == container.end()) { // re-check after lock acquisition
regentry.handle =
libxsmm_dnn_create_conv_layer(regkey.descriptor, &status);
if (LIBXSMM_DNN_WARN_FALLBACK != status) {
CHECK_LIBXSMM_DNN(status, "create handle");
} else { // warning
VLOG(1) << libxsmm_dnn_get_error(status);
}
regentry.layout_input = libxsmm_dnn_create_tensor_datalayout(
regentry.handle, LIBXSMM_DNN_INPUT, &status);
CHECK_LIBXSMM_DNN(status, "create input layout");
regentry.layout_output = libxsmm_dnn_create_tensor_datalayout(
regentry.handle, LIBXSMM_DNN_OUTPUT, &status);
CHECK_LIBXSMM_DNN(status, "create output layout");
regentry.layout_filter = libxsmm_dnn_create_tensor_datalayout(
regentry.handle, LIBXSMM_DNN_FILTER, &status);
CHECK_LIBXSMM_DNN(status, "create filter layout");
i = container.insert(std::make_pair(regkey, regentry)).first;
}
LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
}
return i->second;
}
}
~handles() {
std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
HashFunction>::iterator i;
for (i = libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second),
"Destroy handle");
}
private:
std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
HashFunction>
libxsmm_handles;
};
static handles libxsmm_handles;
container_type container;
LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_LOCK_RWLOCK) attr;
LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK_RWLOCK) lock;
} libxsmm_dnn_registry;
// #define LIBXSMM_DETAILED_TIMING
@ -173,73 +238,64 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
InputPtr input, FilterPtr filter,
OutputPtr output) {
#if defined(LIBXSMM_DETAILED_TIMING)
uint64 l_tick1;
uint64 l_tick2;
uint64 l_tick3;
uint64 l_tick4;
uint64 l_tick5;
uint64 l_tick6;
uint64 l_tick7;
uint64 l_tick8;
uint64 l_tick9;
uint64 l_tick10;
libxsmm_timer_tickint l_tick1;
libxsmm_timer_tickint l_tick2;
libxsmm_timer_tickint l_tick3;
libxsmm_timer_tickint l_tick4;
libxsmm_timer_tickint l_tick5;
libxsmm_timer_tickint l_tick6;
libxsmm_timer_tickint l_tick7;
libxsmm_timer_tickint l_tick8;
libxsmm_timer_tickint l_tick9;
libxsmm_timer_tickint l_tick10;
l_tick1 = libxsmm_timer_tick();
#endif
// setup scoped allocator, which adopts the allocator from the context
const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator(*ctx);
#if defined(LIBXSMM_LOCAL_ALLOC)
// setup scoped allocator, which adopts the allocator of the current context
const libxsmm_tf_scratch_allocator tf_allocator(*ctx);
#endif
const libxsmm_dnn_registry_key regkey(desc);
const libxsmm_dnn_registry_value regentry = libxsmm_dnn_registry.find(regkey);
libxsmm_dnn_tensor *libxsmm_input, *libxsmm_output, *libxsmm_filter;
libxsmm_dnn_err_t status;
libxsmm_dnn_layer* libxsmm_handle;
libxsmm_dnn_conv_desc_wrap w(desc);
void* scratch;
// if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)
libxsmm_handle = libxsmm_handles.find(w);
// else{
// libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
// chk_libxsmm_err(status, "Create handle");
//}
status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
status = libxsmm_dnn_get_codegen_success(regentry.handle, kind);
if (status == LIBXSMM_DNN_WARN_FALLBACK) {
return false; // Use non-libxsmm code
}
chk_libxsmm_err(status, "Check codegen status");
libxsmm_dnn_buffer* libxsmm_input;
libxsmm_dnn_buffer* libxsmm_output;
libxsmm_dnn_filter* libxsmm_filter;
CHECK_LIBXSMM_DNN(status, "code generation");
#if defined(LIBXSMM_DETAILED_TIMING)
l_tick2 = libxsmm_timer_tick();
#endif
int ifmblock = (libxsmm_handle->ifmblock);
int ofmblock = (libxsmm_handle->ofmblock);
const int ifmblock = regentry.handle->ifmblock;
const int ofmblock = regentry.handle->ofmblock;
int blocksifm =
desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1;
int blocksofm =
desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1;
float* native_filter =
(float*)libxsmm_aligned_scratch(blocksofm * blocksifm * desc.R * desc.S *
ifmblock * ofmblock * sizeof(float),
2097152);
const int blocksifm =
(desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1);
const int blocksofm =
(desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1);
const DeviceBase::CpuWorkerThreads* worker_threads =
const size_t filter_size =
blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock;
float* const native_filter = (float*)libxsmm_aligned_scratch(
filter_size * sizeof(float), 2097152 /*alignment*/);
const DeviceBase::CpuWorkerThreads* const worker_threads =
ctx->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
const int num_threads = worker_threads->num_threads;
#if 1
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
if (blocksofm > num_threads) {
int work = blocksofm;
const int work = blocksofm;
BlockingCounter count(num_threads);
for (int i = 0; i < num_threads; ++i) {
worker_threads->workers->Schedule([=, &count]() {
int start = work / num_threads * i;
int end = (start + work / num_threads) > work
const int start = work / num_threads * i;
const int end = (start + work / num_threads) > work
? work
: start + work / num_threads;
copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
@ -250,14 +306,14 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
}
count.Wait();
} else {
int work = blocksofm;
int num_threads = work;
const int work = blocksofm;
const int num_tasks = work;
BlockingCounter count(num_threads);
for (int i = 0; i < num_threads; ++i) {
BlockingCounter count(num_tasks);
for (int i = 0; i < num_tasks; ++i) {
worker_threads->workers->Schedule([=, &count]() {
int start = i;
int end = i + 1;
const int start = i;
const int end = i + 1;
copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
start, end);
@ -267,90 +323,89 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
count.Wait();
}
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
// Added: for weight update
// weight update buffer must be in the right format
// (LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR)
libxsmm_filter =
libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter,
LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status);
chk_libxsmm_err(status,
"Link filter"); // weight update is in RSCK as
// filter should be returned in RSCK
// format
libxsmm_dnn_link_tensor(regentry.layout_filter, filter, &status);
CHECK_LIBXSMM_DNN(status, "link filter with layout");
}
#else
memset(native_filter, 0,
blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock *
sizeof(float));
memset(native_filter, 0, filter_size * sizeof(float));
#endif
#if defined(LIBXSMM_DETAILED_TIMING)
l_tick3 = libxsmm_timer_tick();
#endif
// LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
libxsmm_input =
libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_INPUT, input,
LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
chk_libxsmm_err(status, "Link input buffer");
libxsmm_dnn_link_tensor(regentry.layout_input, input, &status);
CHECK_LIBXSMM_DNN(status, "link input buffer with layout");
// LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
libxsmm_output =
libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_OUTPUT, output,
LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
chk_libxsmm_err(status, "Link output buffer");
libxsmm_dnn_link_tensor(regentry.layout_output, output, &status);
CHECK_LIBXSMM_DNN(status, "link output buffer with layout");
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
libxsmm_filter = libxsmm_dnn_link_filter(
libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter,
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status);
chk_libxsmm_err(status, "Link filter");
// LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR
libxsmm_filter =
libxsmm_dnn_link_tensor(regentry.layout_filter, native_filter, &status);
CHECK_LIBXSMM_DNN(status, "link filter with layout");
CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
LIBXSMM_DNN_REGULAR_FILTER),
"bind filter to handle");
}
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
LIBXSMM_DNN_REGULAR_INPUT),
"Bind input forward");
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
"bind input forward");
CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
LIBXSMM_DNN_REGULAR_FILTER),
"bind filter forward");
CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
LIBXSMM_DNN_REGULAR_OUTPUT),
"Bind output forward");
chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
LIBXSMM_DNN_REGULAR_FILTER),
"Bind filter forward");
"bind output forward");
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_input), "Zero input");
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_input), "zeroing input");
CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
LIBXSMM_DNN_GRADIENT_INPUT),
"Bind input backward");
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
LIBXSMM_DNN_GRADIENT_OUTPUT),
"Bind output backward");
chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
"bind input backward");
CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
LIBXSMM_DNN_REGULAR_FILTER),
"Bind filter backward");
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
chk_libxsmm_err(libxsmm_dnn_zero_filter(libxsmm_filter), "Zero filter");
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
LIBXSMM_DNN_REGULAR_INPUT),
"Bind input weight update");
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
"bind filter backward");
CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
LIBXSMM_DNN_GRADIENT_OUTPUT),
"Bind output weight update");
chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
"bind output backward");
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_filter),
"zeroing filter");
CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
LIBXSMM_DNN_REGULAR_INPUT),
"bind input weight update");
CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
LIBXSMM_DNN_GRADIENT_FILTER),
"Bind filter weight update");
"bind filter weight update");
CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
LIBXSMM_DNN_GRADIENT_OUTPUT),
"bind output weight update");
} else {
/* shouldn't happen */
assert(0 /*should not happen*/);
}
#if defined(LIBXSMM_DETAILED_TIMING)
l_tick4 = libxsmm_timer_tick();
#endif
/* bind scratch */
scratch = (void*)libxsmm_aligned_scratch(
libxsmm_dnn_get_scratch_size(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL,
&status),
2097152);
chk_libxsmm_err(status, "scratch allocation");
chk_libxsmm_err(libxsmm_dnn_bind_scratch(
libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
const size_t scratch_size = libxsmm_dnn_get_scratch_size(
regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status);
CHECK_LIBXSMM_DNN(status, "get scratch size");
void* const scratch =
libxsmm_aligned_scratch(scratch_size, 2097152 /*alignment*/);
CHECK_LIBXSMM(0 != scratch, "scratch memory allocation");
CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_scratch(
regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
"binding scratch");
#if defined(LIBXSMM_DETAILED_TIMING)
@ -358,30 +413,39 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
#endif
if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER);
libxsmm_dnn_transpose_filter(regentry.handle, LIBXSMM_DNN_FILTER);
}
#if defined(LIBXSMM_DETAILED_TIMING)
l_tick6 = libxsmm_timer_tick();
#endif
#if !defined(_OPENMP) || !defined(LIBXSMM_USE_OPENMP)
BlockingCounter counter(num_threads);
for (int i = 0; i < num_threads; ++i) {
worker_threads->workers->Schedule([=, &counter]() {
chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i),
"Worker");
CHECK_LIBXSMM_DNN(libxsmm_dnn_execute_st(regentry.handle, kind, 0, i),
"worker");
counter.DecrementCount();
});
}
counter.Wait();
#else
#pragma omp parallel
{
CHECK_LIBXSMM_DNN(
libxsmm_dnn_execute_st(regentry.handle, kind, 0, omp_get_thread_num()),
"worker");
}
#endif
#if defined(LIBXSMM_DETAILED_TIMING)
l_tick7 = libxsmm_timer_tick();
#endif
if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
libxsmm_dnn_reduce_wu_filters(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER);
libxsmm_dnn_reduce_wu_filters(regentry.handle, LIBXSMM_DNN_GRADIENT_FILTER);
}
#if defined(LIBXSMM_DETAILED_TIMING)
@ -389,54 +453,52 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
#endif
/* clean up */
chk_libxsmm_err(
libxsmm_dnn_release_scratch(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL),
CHECK_LIBXSMM_DNN(libxsmm_dnn_release_scratch(regentry.handle,
LIBXSMM_DNN_COMPUTE_KIND_ALL),
"release scratch");
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
chk_libxsmm_err(
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
"release input");
chk_libxsmm_err(
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT),
CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_OUTPUT),
"release output");
chk_libxsmm_err(
libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
"release filter");
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
chk_libxsmm_err(
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT),
CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_INPUT),
"release input");
chk_libxsmm_err(
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
LIBXSMM_DNN_GRADIENT_OUTPUT),
"release output");
chk_libxsmm_err(
libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
"release filter");
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
chk_libxsmm_err(
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
CHECK_LIBXSMM_DNN(
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
"release input");
chk_libxsmm_err(
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
LIBXSMM_DNN_GRADIENT_OUTPUT),
"release output");
chk_libxsmm_err(
libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER),
CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
LIBXSMM_DNN_GRADIENT_FILTER),
"release filter");
} else {
/* shouldn't happen */
}
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input");
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output");
chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter");
CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_input), "destroy input");
CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_output),
"destroy output");
CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_filter),
"destroy filter");
#if defined(LIBXSMM_DETAILED_TIMING)
l_tick9 = libxsmm_timer_tick();
#endif
// if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
// chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
// "Destroy handle");
libxsmm_free(native_filter);
libxsmm_free(scratch);

View File

@ -312,8 +312,8 @@ TEST(XsmmConv2DTest, Basic) {
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
if (!CanUseXsmmConv2D(desc, data_format)) {
return false;
}

View File

@ -193,11 +193,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "libxsmm_archive",
build_file = clean_dep("//third_party:libxsmm.BUILD"),
sha256 = "5fc1972471cd8e2b8b64ea017590193739fc88d9818e3d086621e5c08e86ea35",
strip_prefix = "libxsmm-1.11",
sha256 = "9c0af4509ea341d1ee2c6c19fc6f19289318c3bd4b17844efeb9e7f9691abf76",
strip_prefix = "libxsmm-1.14",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/hfp/libxsmm/archive/1.11.tar.gz",
"https://github.com/hfp/libxsmm/archive/1.11.tar.gz",
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/hfp/libxsmm/archive/1.14.tar.gz",
"https://github.com/hfp/libxsmm/archive/1.14.tar.gz",
],
)

View File

@ -45,62 +45,45 @@ genrule(
cc_library(
name = "xsmm_avx",
srcs = [
"src/libxsmm_cpuid_x86.c",
"src/libxsmm_dnn.c",
"src/libxsmm_dnn_convolution_backward.c",
"src/libxsmm_dnn_convolution_forward.c",
"src/libxsmm_dnn_convolution_weight_update.c",
"src/libxsmm_dnn_convolution_winograd_backward.c",
"src/libxsmm_dnn_convolution_winograd_forward.c",
"src/libxsmm_dnn_convolution_winograd_weight_update.c",
"src/libxsmm_dnn_handle.c",
"src/libxsmm_dump.c",
"src/libxsmm_ext_gemm.c",
"src/libxsmm_ext_trans.c",
"src/libxsmm_fsspmdm.c",
"src/libxsmm_gemm.c",
"src/libxsmm_main.c",
"src/libxsmm_malloc.c",
"src/libxsmm_perf.c",
"src/libxsmm_spmdm.c",
"src/libxsmm_sync.c",
"src/libxsmm_timer.c",
"src/libxsmm_trace.c",
"src/libxsmm_trans.c",
] + glob([
srcs = glob(
[
# general source files (translation units)
"src/generator_*.c",
]),
hdrs = [
"include/libxsmm_cpuid.h",
"include/libxsmm_dnn.h",
"include/libxsmm_frontend.h",
"include/libxsmm_fsspmdm.h",
"include/libxsmm_generator.h",
"include/libxsmm_intrinsics_x86.h",
"include/libxsmm_macros.h",
"include/libxsmm_malloc.h",
"include/libxsmm_spmdm.h",
"include/libxsmm_sync.h",
"include/libxsmm_timer.h",
"include/libxsmm_typedefs.h",
# Source files #included internally:
"src/libxsmm_gemm_diff.c",
"src/libxsmm_hash.c",
# Generated:
"src/libxsmm_*.c",
],
exclude = [
# exclude generators (with main functions)
"src/libxsmm_generator_*.c",
],
),
hdrs = glob(
[
# general header files
"include/libxsmm_*.h",
# trigger rebuild if template changed
"src/template/*.c",
],
exclude = [
# exclude existing/generated headers
"include/libxsmm.h",
"include/libxsmm_config.h",
"include/libxsmm_dispatch.h",
] + glob([
# trigger rebuild if template changed
"src/template/*.c",
]),
copts = [
"-mavx", # JIT does not work without avx anyway, and this silences some CRC32 warnings.
"-Wno-vla", # Libxsmm convolutions heavily use VLA.
],
) + [
# source files included internally
"src/libxsmm_hash.c",
# generated header files
"include/libxsmm.h",
"include/libxsmm_config.h",
"include/libxsmm_dispatch.h",
],
#copts = [
# "-mavx", # JIT does not work without avx anyway, and this silences some CRC32 warnings.
# "-Wno-vla", # Libxsmm convolutions heavily use VLA.
#],
defines = [
"LIBXSMM_BUILD",
"LIBXSMM_CTOR",
"__BLAS=0",
],
includes = [