Merge pull request #24706 from hfp:upstream

PiperOrigin-RevId: 281336500
Change-Id: I65db6e77184d62717133ee8e61cbb7bf4b42bb56

commit e776cbc7ca
@@ -241,8 +241,8 @@ struct LaunchXsmmBackwardFilter<CPUDevice, float> {
     desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
     desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
-    desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
+    desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
+    desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
     if (!CanUseXsmmConv2D(desc, data_format)) {
       return false;
     }
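Note: the recurring descriptor change in this commit is that newer libxsmm releases split the single datatype field of libxsmm_dnn_conv_desc into separate datatype_in and datatype_out fields. A minimal sketch of how the updated code fills such a descriptor (the shape values below are placeholders, not taken from this patch):

// Sketch only: descriptor setup in the style of the updated kernels.
// N/C/H/W/K/R/S values are hypothetical examples.
libxsmm_dnn_conv_desc desc;
desc.N = 32;  desc.C = 64;  desc.H = 56;  desc.W = 56;
desc.K = 128; desc.R = 3;   desc.S = 3;
desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;   // was: desc.datatype
desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;  // was: desc.datatype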
@@ -290,9 +290,9 @@ struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
     desc.filter_format =
         LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;  // LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
-    desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
-    desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
+    desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
+    desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
+    desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
     auto input_ptr = input_backward.data();
     auto filter_ptr = kernel.data();
     auto output_ptr = output_backward.data();
@@ -320,9 +320,9 @@ class LaunchXsmmConvOp<CPUDevice, float> {
     desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
     desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
-    desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
-    desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
+    desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
+    desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
+    desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
     if (dilation_rows != 1 || dilation_cols != 1 ||
         !CanUseXsmmConv2D(desc, data_format)) {
       return false;
@@ -1396,8 +1396,8 @@ void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
     libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
     int nthreads) {
   return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
-      handle, transA, reinterpret_cast<const uint16*>(A), libxsmm_output_csr_a,
-      block_id, tid, nthreads);
+      handle, transA, reinterpret_cast<const libxsmm_bfloat16*>(A),
+      libxsmm_output_csr_a, block_id, tid, nthreads);
 }

 void wrapper_libxsmm_spmdm_compute_generic_thread(
@@ -1406,9 +1406,10 @@ void wrapper_libxsmm_spmdm_compute_generic_thread(
     libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
     const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
   return libxsmm_spmdm_compute_bfloat16_thread(
-      handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse,
-      reinterpret_cast<const uint16*>(B), transC,
-      reinterpret_cast<const uint16*>(beta), C, block_id, tid, nthreads);
+      handle, transA, transB, reinterpret_cast<const libxsmm_bfloat16*>(alpha),
+      A_sparse, reinterpret_cast<const libxsmm_bfloat16*>(B), transC,
+      reinterpret_cast<const libxsmm_bfloat16*>(beta), C, block_id, tid,
+      nthreads);
 }

 void wrapper_libxsmm_spmdm_compute_generic_thread(
     empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
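Note: both wrapper changes above only retarget the pointer casts. The old code reinterpreted TensorFlow bfloat16 pointers as uint16, while libxsmm 1.14 takes libxsmm_bfloat16, which is likewise a 16-bit storage type, so the reinterpret_cast relabels the element type rather than converting values. A small sketch of the assumption being relied on:

// Sketch: the cast is a relabeling of 16-bit storage, not a conversion.
static_assert(sizeof(tensorflow::bfloat16) == 2, "bfloat16 is 16 bits");
// libxsmm_bfloat16 is also a 16-bit type in libxsmm's headers.
const tensorflow::bfloat16* tf_ptr = nullptr;  // stands in for A, B, alpha, beta
const libxsmm_bfloat16* xsmm_ptr =
    reinterpret_cast<const libxsmm_bfloat16*>(tf_ptr);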
@@ -1427,13 +1428,6 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
     const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
     bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
     bool transpose_output, MatrixMap* output) {
-  if (false) {
-    // Not handled by libxsmm currently
-    SparseMatMul<TL, TR>::Compute(
-        nullptr /* Assumes no cached data for fallback */, left, right,
-        transpose_left, thread_pool, transpose_output, output);
-    return;
-  }
   const int num_threads = thread_pool->num_threads;
   const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
   const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
@@ -1444,6 +1438,7 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
            (transpose_output ? output->dimension(1) : output->dimension(0)));
   CHECK_EQ(right_dim1,
            (transpose_output ? output->dimension(0) : output->dimension(1)));
+#if 0  // this issue seems to be resolved
   if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
     // Causes problems in libxsmm
     SparseMatMul<TL, TR>::Compute(
@@ -1451,6 +1446,7 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
         transpose_left, thread_pool, transpose_output, output);
     return;
   }
+#endif
   auto left_data = left.data();
   auto right_data = right.data();
   auto output_data = output->data();
@@ -1640,15 +1636,14 @@ inline void SparseMatMul<TL, TR>::Compute(
                           SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
 #endif

-REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
-
 REGISTER_SPARSE_MATMUL(float, bfloat16);

 REGISTER_SPARSE_MATMUL(bfloat16, float);

 #ifdef TENSORFLOW_USE_LIBXSMM
+REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16);
 REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
 #else
+REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
 REGISTER_SPARSE_MATMUL(float, float);
 #endif
@@ -27,6 +27,9 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty();

 #include <stdlib.h>
 #include <cstring>
+#if defined(_OPENMP) && defined(LIBXSMM_USE_OPENMP)
+#include <omp.h>
+#endif

 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
@@ -36,6 +39,12 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty();
 #include "include/libxsmm_malloc.h"
 #include "src/libxsmm_main.h"  // TODO(bsteiner): API to avoid incl. header from src/

+#define CHECK_LIBXSMM(CONDITION_OK, MESSAGE) \
+  if (!(CONDITION_OK)) VLOG(0) << (MESSAGE)
+#define CHECK_LIBXSMM_DNN(STATUS, MESSAGE) \
+  CHECK_LIBXSMM(LIBXSMM_DNN_SUCCESS == (STATUS), MESSAGE) \
+      << " failed: " << libxsmm_dnn_get_error(STATUS);
+
 namespace tensorflow {

 // Xsmm*Conv2D are wrappers for libxsmm direct convolutions.
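Note: these two macros take over from the chk_libxsmm_err() helper that the next hunk deletes; they log the failure via VLOG(0) but do not abort. A worked expansion, assuming status holds the result of a libxsmm_dnn_* call:

// CHECK_LIBXSMM_DNN(status, "create input layout");
// expands (roughly) to:
if (!(LIBXSMM_DNN_SUCCESS == (status)))
  VLOG(0) << ("create input layout")
          << " failed: " << libxsmm_dnn_get_error(status);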
@@ -73,12 +82,6 @@ typedef Eigen::ThreadPoolDevice CPUDevice;

 namespace functor {

-static void chk_libxsmm_err(libxsmm_dnn_err_t status, string msg) {
-  if (status != LIBXSMM_DNN_SUCCESS) {
-    VLOG(0) << msg << " failed: " << libxsmm_dnn_get_error(status);
-  }
-}
-
 LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
                                         int S, int C, int K, int blocksifm,
                                         int blocksofm, int ifmblock,
@@ -114,55 +117,117 @@ LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
   }
 }

-class libxsmm_dnn_conv_desc_wrap {
- public:
-  const libxsmm_dnn_conv_desc d;
-
-  libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc& d_) : d(d_) {}
-  bool operator==(const libxsmm_dnn_conv_desc_wrap& w) const {
-    return (d.N == w.d.N && d.C == w.d.C && d.H == w.d.H && d.W == w.d.W &&
-            d.K == w.d.K && d.R == w.d.R && d.S == w.d.S && d.u == w.d.u &&
-            d.v == w.d.v && d.pad_h == w.d.pad_h && d.pad_w == w.d.pad_w);
+struct libxsmm_dnn_registry_key {
+  const libxsmm_dnn_conv_desc descriptor;
+  libxsmm_dnn_registry_key(const libxsmm_dnn_conv_desc& desc_)
+      : descriptor(desc_) {}
+  bool operator==(const libxsmm_dnn_registry_key& regkey) const {
+    return 0 == memcmp(&descriptor, &regkey.descriptor, sizeof(descriptor));
   }
 };

 struct HashFunction {
-  std::size_t operator()(const libxsmm_dnn_conv_desc_wrap& w) const {
-    return libxsmm_hash(&w.d, sizeof(w.d), 25071975);
+  std::size_t operator()(const libxsmm_dnn_registry_key& regkey) const {
+    return libxsmm_hash(&regkey.descriptor, sizeof(regkey.descriptor),
+                        25071975);
   }
 };

-class handles {
+struct libxsmm_dnn_registry_value {
+  libxsmm_dnn_tensor_datalayout* layout_input;
+  libxsmm_dnn_tensor_datalayout* layout_filter;
+  libxsmm_dnn_tensor_datalayout* layout_output;
+  libxsmm_dnn_layer* handle;
+};
+
+typedef libxsmm_tf_allocator<libxsmm_scratch_allocator>
+    libxsmm_tf_scratch_allocator;
+
+static class libxsmm_dnn_registry_type {
+ private:
+  typedef std::unordered_map<libxsmm_dnn_registry_key,
+                             libxsmm_dnn_registry_value, HashFunction>
+      container_type;
+
  public:
-  libxsmm_dnn_layer* find(const libxsmm_dnn_conv_desc_wrap& w) {
-    std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
-                       HashFunction>::iterator i = libxsmm_handles.find(w);
-    if (i == libxsmm_handles.end()) {
+  libxsmm_dnn_registry_type() {
+    libxsmm_init(); /* must be first */
+#if !defined(LIBXSMM_LOCAL_ALLOC)
+    {
+      libxsmm_malloc_function malloc_fn;
+      libxsmm_free_function free_fn;
+      malloc_fn.function = libxsmm_tf_scratch_allocator::malloc;
+      free_fn.function = libxsmm_tf_scratch_allocator::free;
+      libxsmm_set_scratch_allocator(0 /*context*/, malloc_fn, free_fn);
+    }
+#endif
+    LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_LOCK_RWLOCK, &attr);
+    LIBXSMM_LOCK_INIT(LIBXSMM_LOCK_RWLOCK, &lock, &attr);
+  }
+  ~libxsmm_dnn_registry_type() {
+    LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
+    const container_type::const_iterator end = container.end();
+    for (container_type::const_iterator i = container.begin(); i != end; ++i) {
+      CHECK_LIBXSMM_DNN(
+          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_input),
+          "destroy input layout");
+      CHECK_LIBXSMM_DNN(
+          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_output),
+          "destroy output layout");
+      CHECK_LIBXSMM_DNN(
+          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_filter),
+          "destroy filter layout");
+      CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_conv_layer(i->second.handle),
+                        "destroy handle");
+    }
+    LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
+    LIBXSMM_LOCK_DESTROY(LIBXSMM_LOCK_RWLOCK, &lock);
+    LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_LOCK_RWLOCK, &attr);
+    libxsmm_finalize();
+  }
+  libxsmm_dnn_registry_value find(const libxsmm_dnn_registry_key& regkey) {
+    container_type::iterator i;
+    LIBXSMM_LOCK_ACQREAD(LIBXSMM_LOCK_RWLOCK, &lock);
+    i = container.find(regkey);
+    LIBXSMM_LOCK_RELREAD(LIBXSMM_LOCK_RWLOCK, &lock);
+    if (i == container.end()) {
       libxsmm_dnn_err_t status;
-      libxsmm_dnn_layer* libxsmm_handle =
-          libxsmm_dnn_create_conv_layer(w.d, &status);
-      chk_libxsmm_err(status, "Create handle");
-      libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
-      return libxsmm_handle;
-    } else {
-      return i->second;
+      libxsmm_dnn_registry_value regentry;
+      LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
+      i = container.find(regkey);
+      if (i == container.end()) {  // re-check after lock acquisition
+        regentry.handle =
+            libxsmm_dnn_create_conv_layer(regkey.descriptor, &status);
+        if (LIBXSMM_DNN_WARN_FALLBACK != status) {
+          CHECK_LIBXSMM_DNN(status, "create handle");
+        } else {  // warning
+          VLOG(1) << libxsmm_dnn_get_error(status);
+        }
+        regentry.layout_input = libxsmm_dnn_create_tensor_datalayout(
+            regentry.handle, LIBXSMM_DNN_INPUT, &status);
+        CHECK_LIBXSMM_DNN(status, "create input layout");
+
+        regentry.layout_output = libxsmm_dnn_create_tensor_datalayout(
+            regentry.handle, LIBXSMM_DNN_OUTPUT, &status);
+        CHECK_LIBXSMM_DNN(status, "create output layout");
+
+        regentry.layout_filter = libxsmm_dnn_create_tensor_datalayout(
+            regentry.handle, LIBXSMM_DNN_FILTER, &status);
+        CHECK_LIBXSMM_DNN(status, "create filter layout");
+
+        i = container.insert(std::make_pair(regkey, regentry)).first;
+      }
+      LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
     }
-  }
-  ~handles() {
-    std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
-                       HashFunction>::iterator i;
-    for (i = libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
-      chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second),
-                      "Destroy handle");
+    return i->second;
   }

  private:
-  std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
-                     HashFunction>
-      libxsmm_handles;
-};
-
-static handles libxsmm_handles;
+  container_type container;
+  LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_LOCK_RWLOCK) attr;
+  LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK_RWLOCK) lock;
+} libxsmm_dnn_registry;

 // #define LIBXSMM_DETAILED_TIMING
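Note: the new registry does the common lookup under a read lock and, only on a miss, takes the write lock and re-checks the key before creating the layer (double-checked locking), so concurrent kernels with the same descriptor build the libxsmm handle and layouts exactly once. A minimal sketch of the same pattern using standard C++ primitives instead of the LIBXSMM_LOCK_* macros (names here are illustrative, not from the patch):

#include <mutex>
#include <shared_mutex>
#include <unordered_map>

template <typename K, typename V, typename Hash, typename MakeFn>
V find_or_create(std::unordered_map<K, V, Hash>& cache, std::shared_mutex& mu,
                 const K& key, MakeFn make) {
  {
    std::shared_lock<std::shared_mutex> read_lock(mu);  // fast path: read lock
    auto it = cache.find(key);
    if (it != cache.end()) return it->second;
  }
  std::unique_lock<std::shared_mutex> write_lock(mu);  // slow path: write lock
  auto it = cache.find(key);  // re-check after lock acquisition
  if (it == cache.end()) it = cache.emplace(key, make(key)).first;
  return it->second;
}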
@@ -173,73 +238,64 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
                                    InputPtr input, FilterPtr filter,
                                    OutputPtr output) {
 #if defined(LIBXSMM_DETAILED_TIMING)
-  uint64 l_tick1;
-  uint64 l_tick2;
-  uint64 l_tick3;
-  uint64 l_tick4;
-  uint64 l_tick5;
-  uint64 l_tick6;
-  uint64 l_tick7;
-  uint64 l_tick8;
-  uint64 l_tick9;
-  uint64 l_tick10;
+  libxsmm_timer_tickint l_tick1;
+  libxsmm_timer_tickint l_tick2;
+  libxsmm_timer_tickint l_tick3;
+  libxsmm_timer_tickint l_tick4;
+  libxsmm_timer_tickint l_tick5;
+  libxsmm_timer_tickint l_tick6;
+  libxsmm_timer_tickint l_tick7;
+  libxsmm_timer_tickint l_tick8;
+  libxsmm_timer_tickint l_tick9;
+  libxsmm_timer_tickint l_tick10;
   l_tick1 = libxsmm_timer_tick();
 #endif
-  // setup scoped allocator, which adopts the allocator from the context
-  const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator(*ctx);
+#if defined(LIBXSMM_LOCAL_ALLOC)
+  // setup scoped allocator, which adopts the allocator of the current context
+  const libxsmm_tf_scratch_allocator tf_allocator(*ctx);
+#endif
+  const libxsmm_dnn_registry_key regkey(desc);
+  const libxsmm_dnn_registry_value regentry = libxsmm_dnn_registry.find(regkey);
+  libxsmm_dnn_tensor *libxsmm_input, *libxsmm_output, *libxsmm_filter;
   libxsmm_dnn_err_t status;
-  libxsmm_dnn_layer* libxsmm_handle;
-  libxsmm_dnn_conv_desc_wrap w(desc);
-  void* scratch;
-
-  // if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)
-  libxsmm_handle = libxsmm_handles.find(w);
-  // else{
-  //   libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
-  //   chk_libxsmm_err(status, "Create handle");
-  //}
-
-  status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
+  status = libxsmm_dnn_get_codegen_success(regentry.handle, kind);
   if (status == LIBXSMM_DNN_WARN_FALLBACK) {
     return false;  // Use non-libxsmm code
   }
-  chk_libxsmm_err(status, "Check codegen status");
-
-  libxsmm_dnn_buffer* libxsmm_input;
-  libxsmm_dnn_buffer* libxsmm_output;
-  libxsmm_dnn_filter* libxsmm_filter;
+  CHECK_LIBXSMM_DNN(status, "code generation");

 #if defined(LIBXSMM_DETAILED_TIMING)
   l_tick2 = libxsmm_timer_tick();
 #endif

-  int ifmblock = (libxsmm_handle->ifmblock);
-  int ofmblock = (libxsmm_handle->ofmblock);
+  const int ifmblock = regentry.handle->ifmblock;
+  const int ofmblock = regentry.handle->ofmblock;

-  int blocksifm =
-      desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1;
-  int blocksofm =
-      desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1;
-  float* native_filter =
-      (float*)libxsmm_aligned_scratch(blocksofm * blocksifm * desc.R * desc.S *
-                                          ifmblock * ofmblock * sizeof(float),
-                                      2097152);
-
-  const DeviceBase::CpuWorkerThreads* worker_threads =
+  const int blocksifm =
+      (desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1);
+  const int blocksofm =
+      (desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1);
+  const size_t filter_size =
+      blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock;
+  float* const native_filter = (float*)libxsmm_aligned_scratch(
+      filter_size * sizeof(float), 2097152 /*alignment*/);
+
+  const DeviceBase::CpuWorkerThreads* const worker_threads =
       ctx->device()->tensorflow_cpu_worker_threads();
-  int num_threads = worker_threads->num_threads;
+  const int num_threads = worker_threads->num_threads;

 #if 1
   if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
       kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
     if (blocksofm > num_threads) {
-      int work = blocksofm;
+      const int work = blocksofm;
       BlockingCounter count(num_threads);
       for (int i = 0; i < num_threads; ++i) {
         worker_threads->workers->Schedule([=, &count]() {
-          int start = work / num_threads * i;
-          int end = (start + work / num_threads) > work
-                        ? work
-                        : start + work / num_threads;
+          const int start = work / num_threads * i;
+          const int end = (start + work / num_threads) > work
+                              ? work
+                              : start + work / num_threads;
           copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
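Note: blocksifm and blocksofm above round the channel counts up to whole ifmblock/ofmblock blocks; the conditional expression is just a ceiling division. A quick worked check with hypothetical numbers:

// Both forms compute ceil(C / ifmblock); values are made up for illustration.
int C = 130, ifmblock = 64;
int blocks_cond = (C % ifmblock == 0) ? C / ifmblock : C / ifmblock + 1;  // 3
int blocks_ceil = (C + ifmblock - 1) / ifmblock;                          // 3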
@@ -250,14 +306,14 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
       }
       count.Wait();
     } else {
-      int work = blocksofm;
-      int num_threads = work;
+      const int work = blocksofm;
+      const int num_tasks = work;

-      BlockingCounter count(num_threads);
-      for (int i = 0; i < num_threads; ++i) {
+      BlockingCounter count(num_tasks);
+      for (int i = 0; i < num_tasks; ++i) {
         worker_threads->workers->Schedule([=, &count]() {
-          int start = i;
-          int end = i + 1;
+          const int start = i;
+          const int end = i + 1;
           copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                               desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                               start, end);
@@ -267,90 +323,89 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
       count.Wait();
     }
   } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
-    // Added: for weight update
+    // weight update buffer must be in the right format
+    // (LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR)
     libxsmm_filter =
-        libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter,
-                                LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status);
-    chk_libxsmm_err(status,
-                    "Link filter");  // weight update is in RSCK as
-                                     // filter should be returned in RSCK
-                                     // format
+        libxsmm_dnn_link_tensor(regentry.layout_filter, filter, &status);
+    CHECK_LIBXSMM_DNN(status, "link filter with layout");
   }
 #else
-  memset(native_filter, 0,
-         blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock *
-             sizeof(float));
+  memset(native_filter, 0, filter_size * sizeof(float));
 #endif

 #if defined(LIBXSMM_DETAILED_TIMING)
   l_tick3 = libxsmm_timer_tick();
 #endif

+  // LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
   libxsmm_input =
-      libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_INPUT, input,
-                              LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
-  chk_libxsmm_err(status, "Link input buffer");
+      libxsmm_dnn_link_tensor(regentry.layout_input, input, &status);
+  CHECK_LIBXSMM_DNN(status, "link input buffer with layout");
+  // LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
   libxsmm_output =
-      libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_OUTPUT, output,
-                              LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
-  chk_libxsmm_err(status, "Link output buffer");
+      libxsmm_dnn_link_tensor(regentry.layout_output, output, &status);
+  CHECK_LIBXSMM_DNN(status, "link output buffer with layout");
   if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
       kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
-    libxsmm_filter = libxsmm_dnn_link_filter(
-        libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter,
-        LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status);
-    chk_libxsmm_err(status, "Link filter");
+    // LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR
+    libxsmm_filter =
+        libxsmm_dnn_link_tensor(regentry.layout_filter, native_filter, &status);
+    CHECK_LIBXSMM_DNN(status, "link filter with layout");
+    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
+                                              LIBXSMM_DNN_REGULAR_FILTER),
+                      "bind filter to handle");
   }
   if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
-    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
-                                            LIBXSMM_DNN_REGULAR_INPUT),
-                    "Bind input forward");
-    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
-                                            LIBXSMM_DNN_REGULAR_OUTPUT),
-                    "Bind output forward");
-    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
-                                            LIBXSMM_DNN_REGULAR_FILTER),
-                    "Bind filter forward");
+    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
+                                              LIBXSMM_DNN_REGULAR_INPUT),
+                      "bind input forward");
+    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
+                                              LIBXSMM_DNN_REGULAR_FILTER),
+                      "bind filter forward");
+    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
+                                              LIBXSMM_DNN_REGULAR_OUTPUT),
+                      "bind output forward");
   } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
-    chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_input), "Zero input");
-
-    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
-                                            LIBXSMM_DNN_GRADIENT_INPUT),
-                    "Bind input backward");
-    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
-                                            LIBXSMM_DNN_GRADIENT_OUTPUT),
-                    "Bind output backward");
-    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
-                                            LIBXSMM_DNN_REGULAR_FILTER),
-                    "Bind filter backward");
-  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
-    chk_libxsmm_err(libxsmm_dnn_zero_filter(libxsmm_filter), "Zero filter");
-
-    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
-                                            LIBXSMM_DNN_REGULAR_INPUT),
-                    "Bind input weight update");
-    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
-                                            LIBXSMM_DNN_GRADIENT_OUTPUT),
-                    "Bind output weight update");
-    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
-                                            LIBXSMM_DNN_GRADIENT_FILTER),
-                    "Bind filter weight update");
+    CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_input), "zeroing input");
+    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
+                                              LIBXSMM_DNN_GRADIENT_INPUT),
+                      "bind input backward");
+    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
+                                              LIBXSMM_DNN_REGULAR_FILTER),
+                      "bind filter backward");
+    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
+                                              LIBXSMM_DNN_GRADIENT_OUTPUT),
+                      "bind output backward");
+  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
+    CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_filter),
+                      "zeroing filter");
+    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
+                                              LIBXSMM_DNN_REGULAR_INPUT),
+                      "bind input weight update");
+    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
+                                              LIBXSMM_DNN_GRADIENT_FILTER),
+                      "bind filter weight update");
+    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
+                                              LIBXSMM_DNN_GRADIENT_OUTPUT),
+                      "bind output weight update");
   } else {
-    /* shouldn't happen */
+    assert(0 /*should not happen*/);
   }

 #if defined(LIBXSMM_DETAILED_TIMING)
   l_tick4 = libxsmm_timer_tick();
 #endif

-  /* bind scratch */
-  scratch = (void*)libxsmm_aligned_scratch(
-      libxsmm_dnn_get_scratch_size(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL,
-                                   &status),
-      2097152);
-  chk_libxsmm_err(status, "scratch allocation");
-  chk_libxsmm_err(libxsmm_dnn_bind_scratch(
-                      libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
-                  "binding scratch");
+  const size_t scratch_size = libxsmm_dnn_get_scratch_size(
+      regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status);
+  CHECK_LIBXSMM_DNN(status, "get scratch size");
+  void* const scratch =
+      libxsmm_aligned_scratch(scratch_size, 2097152 /*alignment*/);
+  CHECK_LIBXSMM(0 != scratch, "scratch memory allocation");
+  CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_scratch(
+                        regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
+                    "binding scratch");

 #if defined(LIBXSMM_DETAILED_TIMING)
@@ -358,30 +413,39 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
 #endif

   if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
-    libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER);
+    libxsmm_dnn_transpose_filter(regentry.handle, LIBXSMM_DNN_FILTER);
   }

 #if defined(LIBXSMM_DETAILED_TIMING)
   l_tick6 = libxsmm_timer_tick();
 #endif

+#if !defined(_OPENMP) || !defined(LIBXSMM_USE_OPENMP)
   BlockingCounter counter(num_threads);

   for (int i = 0; i < num_threads; ++i) {
     worker_threads->workers->Schedule([=, &counter]() {
-      chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i),
-                      "Worker");
+      CHECK_LIBXSMM_DNN(libxsmm_dnn_execute_st(regentry.handle, kind, 0, i),
+                        "worker");
       counter.DecrementCount();
     });
   }
   counter.Wait();
+#else
+#pragma omp parallel
+  {
+    CHECK_LIBXSMM_DNN(
+        libxsmm_dnn_execute_st(regentry.handle, kind, 0, omp_get_thread_num()),
+        "worker");
+  }
+#endif

 #if defined(LIBXSMM_DETAILED_TIMING)
   l_tick7 = libxsmm_timer_tick();
 #endif

   if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
-    libxsmm_dnn_reduce_wu_filters(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER);
+    libxsmm_dnn_reduce_wu_filters(regentry.handle, LIBXSMM_DNN_GRADIENT_FILTER);
   }

 #if defined(LIBXSMM_DETAILED_TIMING)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/* clean up */
|
/* clean up */
|
||||||
chk_libxsmm_err(
|
CHECK_LIBXSMM_DNN(libxsmm_dnn_release_scratch(regentry.handle,
|
||||||
libxsmm_dnn_release_scratch(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL),
|
LIBXSMM_DNN_COMPUTE_KIND_ALL),
|
||||||
"release scratch");
|
"release scratch");
|
||||||
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
|
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
|
||||||
chk_libxsmm_err(
|
CHECK_LIBXSMM_DNN(
|
||||||
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
|
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
|
||||||
"release input");
|
"release input");
|
||||||
chk_libxsmm_err(
|
CHECK_LIBXSMM_DNN(
|
||||||
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT),
|
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_OUTPUT),
|
||||||
"release output");
|
"release output");
|
||||||
chk_libxsmm_err(
|
CHECK_LIBXSMM_DNN(
|
||||||
libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
|
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
|
||||||
"release filter");
|
"release filter");
|
||||||
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
|
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
|
||||||
chk_libxsmm_err(
|
CHECK_LIBXSMM_DNN(
|
||||||
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT),
|
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_INPUT),
|
||||||
"release input");
|
"release input");
|
||||||
chk_libxsmm_err(
|
CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
|
||||||
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
|
LIBXSMM_DNN_GRADIENT_OUTPUT),
|
||||||
"release output");
|
"release output");
|
||||||
chk_libxsmm_err(
|
CHECK_LIBXSMM_DNN(
|
||||||
libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
|
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
|
||||||
"release filter");
|
"release filter");
|
||||||
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
|
} else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
|
||||||
chk_libxsmm_err(
|
CHECK_LIBXSMM_DNN(
|
||||||
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
|
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
|
||||||
"release input");
|
"release input");
|
||||||
chk_libxsmm_err(
|
CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
|
||||||
libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
|
LIBXSMM_DNN_GRADIENT_OUTPUT),
|
||||||
"release output");
|
"release output");
|
||||||
chk_libxsmm_err(
|
CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
|
||||||
libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER),
|
LIBXSMM_DNN_GRADIENT_FILTER),
|
||||||
"release filter");
|
"release filter");
|
||||||
} else {
|
} else {
|
||||||
/* shouldn't happen */
|
/* shouldn't happen */
|
||||||
}
|
}
|
||||||
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input");
|
CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_input), "destroy input");
|
||||||
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output");
|
CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_output),
|
||||||
chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter");
|
"destroy output");
|
||||||
|
CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_filter),
|
||||||
|
"destroy filter");
|
||||||
|
|
||||||
#if defined(LIBXSMM_DETAILED_TIMING)
|
#if defined(LIBXSMM_DETAILED_TIMING)
|
||||||
l_tick9 = libxsmm_timer_tick();
|
l_tick9 = libxsmm_timer_tick();
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
|
|
||||||
// chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
|
|
||||||
// "Destroy handle");
|
|
||||||
|
|
||||||
libxsmm_free(native_filter);
|
libxsmm_free(native_filter);
|
||||||
libxsmm_free(scratch);
|
libxsmm_free(scratch);
|
||||||
|
|
||||||
|
|
|
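Note: after this refactor, CallLibxsmmConvGeneric runs the same tensor lifecycle for every compute kind: look up the cached handle and layouts, link the user buffers as tensors, bind them plus scratch to the handle, execute on every worker thread, then release, destroy, and free. A condensed sketch of that flow (error checks, the per-kind bind variations, and the thread fan-out are omitted; regentry, desc, and kind are the names used in the patch):

// Condensed sketch of the post-refactor flow; not the literal function body.
const libxsmm_dnn_registry_key regkey(desc);
const libxsmm_dnn_registry_value regentry = libxsmm_dnn_registry.find(regkey);
libxsmm_dnn_err_t status;

libxsmm_dnn_tensor* t_input =
    libxsmm_dnn_link_tensor(regentry.layout_input, input, &status);
libxsmm_dnn_bind_tensor(regentry.handle, t_input, LIBXSMM_DNN_REGULAR_INPUT);
// ... link and bind the output and filter tensors the same way ...

void* scratch = libxsmm_aligned_scratch(
    libxsmm_dnn_get_scratch_size(regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL,
                                 &status),
    2097152 /*alignment*/);
libxsmm_dnn_bind_scratch(regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch);

libxsmm_dnn_execute_st(regentry.handle, kind, 0, /*tid=*/0);  // one call per thread

libxsmm_dnn_release_scratch(regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL);
libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT);
libxsmm_dnn_destroy_tensor(t_input);
libxsmm_free(scratch);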
@@ -312,8 +312,8 @@ TEST(XsmmConv2DTest, Basic) {
       LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
   desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
   desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
-  desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
+  desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
+  desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
   if (!CanUseXsmmConv2D(desc, data_format)) {
     return false;
   }
@@ -193,11 +193,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
     tf_http_archive(
         name = "libxsmm_archive",
         build_file = clean_dep("//third_party:libxsmm.BUILD"),
-        sha256 = "5fc1972471cd8e2b8b64ea017590193739fc88d9818e3d086621e5c08e86ea35",
-        strip_prefix = "libxsmm-1.11",
+        sha256 = "9c0af4509ea341d1ee2c6c19fc6f19289318c3bd4b17844efeb9e7f9691abf76",
+        strip_prefix = "libxsmm-1.14",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/hfp/libxsmm/archive/1.11.tar.gz",
-            "https://github.com/hfp/libxsmm/archive/1.11.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/hfp/libxsmm/archive/1.14.tar.gz",
+            "https://github.com/hfp/libxsmm/archive/1.14.tar.gz",
         ],
     )
@@ -45,62 +45,45 @@ genrule(

 cc_library(
     name = "xsmm_avx",
-    srcs = [
-        "src/libxsmm_cpuid_x86.c",
-        "src/libxsmm_dnn.c",
-        "src/libxsmm_dnn_convolution_backward.c",
-        "src/libxsmm_dnn_convolution_forward.c",
-        "src/libxsmm_dnn_convolution_weight_update.c",
-        "src/libxsmm_dnn_convolution_winograd_backward.c",
-        "src/libxsmm_dnn_convolution_winograd_forward.c",
-        "src/libxsmm_dnn_convolution_winograd_weight_update.c",
-        "src/libxsmm_dnn_handle.c",
-        "src/libxsmm_dump.c",
-        "src/libxsmm_ext_gemm.c",
-        "src/libxsmm_ext_trans.c",
-        "src/libxsmm_fsspmdm.c",
-        "src/libxsmm_gemm.c",
-        "src/libxsmm_main.c",
-        "src/libxsmm_malloc.c",
-        "src/libxsmm_perf.c",
-        "src/libxsmm_spmdm.c",
-        "src/libxsmm_sync.c",
-        "src/libxsmm_timer.c",
-        "src/libxsmm_trace.c",
-        "src/libxsmm_trans.c",
-    ] + glob([
-        "src/generator_*.c",
-    ]),
-    hdrs = [
-        "include/libxsmm_cpuid.h",
-        "include/libxsmm_dnn.h",
-        "include/libxsmm_frontend.h",
-        "include/libxsmm_fsspmdm.h",
-        "include/libxsmm_generator.h",
-        "include/libxsmm_intrinsics_x86.h",
-        "include/libxsmm_macros.h",
-        "include/libxsmm_malloc.h",
-        "include/libxsmm_spmdm.h",
-        "include/libxsmm_sync.h",
-        "include/libxsmm_timer.h",
-        "include/libxsmm_typedefs.h",
-        # Source files #included internally:
-        "src/libxsmm_gemm_diff.c",
-        "src/libxsmm_hash.c",
-        # Generated:
-        "include/libxsmm.h",
-        "include/libxsmm_config.h",
-        "include/libxsmm_dispatch.h",
-    ] + glob([
-        # trigger rebuild if template changed
-        "src/template/*.c",
-    ]),
-    copts = [
-        "-mavx",  # JIT does not work without avx anyway, and this silences some CRC32 warnings.
-        "-Wno-vla",  # Libxsmm convolutions heavily use VLA.
-    ],
+    srcs = glob(
+        [
+            # general source files (translation units)
+            "src/generator_*.c",
+            "src/libxsmm_*.c",
+        ],
+        exclude = [
+            # exclude generators (with main functions)
+            "src/libxsmm_generator_*.c",
+        ],
+    ),
+    hdrs = glob(
+        [
+            # general header files
+            "include/libxsmm_*.h",
+            # trigger rebuild if template changed
+            "src/template/*.c",
+        ],
+        exclude = [
+            # exclude existing/generated headers
+            "include/libxsmm.h",
+            "include/libxsmm_config.h",
+            "include/libxsmm_dispatch.h",
+        ],
+    ) + [
+        # source files included internally
+        "src/libxsmm_hash.c",
+        # generated header files
+        "include/libxsmm.h",
+        "include/libxsmm_config.h",
+        "include/libxsmm_dispatch.h",
+    ],
+    #copts = [
+    #    "-mavx",  # JIT does not work without avx anyway, and this silences some CRC32 warnings.
+    #    "-Wno-vla",  # Libxsmm convolutions heavily use VLA.
+    #],
     defines = [
         "LIBXSMM_BUILD",
+        "LIBXSMM_CTOR",
         "__BLAS=0",
     ],
     includes = [