diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 594dbd0d0df..9fd9fe6d73d 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -241,8 +241,8 @@ struct LaunchXsmmBackwardFilter { desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_RSCK; desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; desc.options = LIBXSMM_DNN_CONV_OPTION_NONE; - desc.datatype = LIBXSMM_DNN_DATATYPE_F32; - + desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32; + desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32; if (!CanUseXsmmConv2D(desc, data_format)) { return false; } diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index 2f6200e5045..1b004a7f683 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -290,9 +290,9 @@ struct LaunchXsmmBackwardInputConvolution { desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM; // LIBXSMM_DNN_TENSOR_FORMAT_RSCK; desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; - desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE; - desc.datatype = LIBXSMM_DNN_DATATYPE_F32; - + desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE; + desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32; + desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32; auto input_ptr = input_backward.data(); auto filter_ptr = kernel.data(); auto output_ptr = output_backward.data(); diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index d5ce7de1d25..7322b4ecb38 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -320,9 +320,9 @@ class LaunchXsmmConvOp { desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC; desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM; desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; - desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE; - desc.datatype = LIBXSMM_DNN_DATATYPE_F32; - + desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE; + desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32; + desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32; if (dilation_rows != 1 || dilation_cols != 1 || !CanUseXsmmConv2D(desc, data_format)) { return false; diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc index a2ee69cecd7..4cdbb762679 100644 --- a/tensorflow/core/kernels/sparse_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_matmul_op.cc @@ -1396,8 +1396,8 @@ void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread( libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid, int nthreads) { return libxsmm_spmdm_createSparseSlice_bfloat16_thread( - handle, transA, reinterpret_cast(A), libxsmm_output_csr_a, - block_id, tid, nthreads); + handle, transA, reinterpret_cast(A), + libxsmm_output_csr_a, block_id, tid, nthreads); } void wrapper_libxsmm_spmdm_compute_generic_thread( @@ -1406,9 +1406,10 @@ void wrapper_libxsmm_spmdm_compute_generic_thread( libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC, const bfloat16* beta, float* C, int block_id, int tid, int nthreads) { return libxsmm_spmdm_compute_bfloat16_thread( - handle, transA, transB, reinterpret_cast(alpha), A_sparse, - reinterpret_cast(B), transC, - reinterpret_cast(beta), C, block_id, tid, nthreads); + handle, transA, transB, reinterpret_cast(alpha), + A_sparse, reinterpret_cast(B), transC, + reinterpret_cast(beta), C, block_id, tid, + nthreads); } void wrapper_libxsmm_spmdm_compute_generic_thread( empty_type_wrapper, const libxsmm_spmdm_handle* handle, char transA, @@ -1427,13 +1428,6 @@ inline void LibxsmmSparseMatMul::Compute( const typename LibxsmmSparseMatMul::ConstMatrixMapR& right, bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, bool transpose_output, MatrixMap* output) { - if (false) { - // Not handled by libxsmm currently - SparseMatMul::Compute( - nullptr /* Assumes no cached data for fallback */, left, right, - transpose_left, thread_pool, transpose_output, output); - return; - } const int num_threads = thread_pool->num_threads; const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0); const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1); @@ -1444,6 +1438,7 @@ inline void LibxsmmSparseMatMul::Compute( (transpose_output ? output->dimension(1) : output->dimension(0))); CHECK_EQ(right_dim1, (transpose_output ? output->dimension(0) : output->dimension(1))); +#if 0 // this issue seems to be resolved if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) { // Causes problems in libxsmm SparseMatMul::Compute( @@ -1451,6 +1446,7 @@ inline void LibxsmmSparseMatMul::Compute( transpose_left, thread_pool, transpose_output, output); return; } +#endif auto left_data = left.data(); auto right_data = right.data(); auto output_data = output->data(); @@ -1640,15 +1636,14 @@ inline void SparseMatMul::Compute( SparseMatMulOp); #endif -REGISTER_SPARSE_MATMUL(bfloat16, bfloat16); - REGISTER_SPARSE_MATMUL(float, bfloat16); - REGISTER_SPARSE_MATMUL(bfloat16, float); #ifdef TENSORFLOW_USE_LIBXSMM +REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16); REGISTER_SPARSE_MATMUL_LIBXSMM(float, float); #else +REGISTER_SPARSE_MATMUL(bfloat16, bfloat16); REGISTER_SPARSE_MATMUL(float, float); #endif diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc index 941e2bdf545..3da48898d05 100644 --- a/tensorflow/core/kernels/xsmm_conv2d.cc +++ b/tensorflow/core/kernels/xsmm_conv2d.cc @@ -27,6 +27,9 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty(); #include #include +#if defined(_OPENMP) && defined(LIBXSMM_USE_OPENMP) +#include +#endif #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/blocking_counter.h" @@ -36,6 +39,12 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty(); #include "include/libxsmm_malloc.h" #include "src/libxsmm_main.h" // TODO(bsteiner): API to avoid incl. header from src/ +#define CHECK_LIBXSMM(CONDITION_OK, MESSAGE) \ + if (!(CONDITION_OK)) VLOG(0) << (MESSAGE) +#define CHECK_LIBXSMM_DNN(STATUS, MESSAGE) \ + CHECK_LIBXSMM(LIBXSMM_DNN_SUCCESS == (STATUS), MESSAGE) \ + << " failed: " << libxsmm_dnn_get_error(STATUS); + namespace tensorflow { // Xsmm*Conv2D are wrappers for libxsmm direct convolutions. @@ -73,12 +82,6 @@ typedef Eigen::ThreadPoolDevice CPUDevice; namespace functor { -static void chk_libxsmm_err(libxsmm_dnn_err_t status, string msg) { - if (status != LIBXSMM_DNN_SUCCESS) { - VLOG(0) << msg << " failed: " << libxsmm_dnn_get_error(status); - } -} - LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R, int S, int C, int K, int blocksifm, int blocksofm, int ifmblock, @@ -114,55 +117,117 @@ LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R, } } -class libxsmm_dnn_conv_desc_wrap { - public: - const libxsmm_dnn_conv_desc d; - - libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc& d_) : d(d_) {} - bool operator==(const libxsmm_dnn_conv_desc_wrap& w) const { - return (d.N == w.d.N && d.C == w.d.C && d.H == w.d.H && d.W == w.d.W && - d.K == w.d.K && d.R == w.d.R && d.S == w.d.S && d.u == w.d.u && - d.v == w.d.v && d.pad_h == w.d.pad_h && d.pad_w == w.d.pad_w); +struct libxsmm_dnn_registry_key { + const libxsmm_dnn_conv_desc descriptor; + libxsmm_dnn_registry_key(const libxsmm_dnn_conv_desc& desc_) + : descriptor(desc_) {} + bool operator==(const libxsmm_dnn_registry_key& regkey) const { + return 0 == memcmp(&descriptor, ®key.descriptor, sizeof(descriptor)); } }; struct HashFunction { - std::size_t operator()(const libxsmm_dnn_conv_desc_wrap& w) const { - return libxsmm_hash(&w.d, sizeof(w.d), 25071975); + std::size_t operator()(const libxsmm_dnn_registry_key& regkey) const { + return libxsmm_hash(®key.descriptor, sizeof(regkey.descriptor), + 25071975); } }; -class handles { +struct libxsmm_dnn_registry_value { + libxsmm_dnn_tensor_datalayout* layout_input; + libxsmm_dnn_tensor_datalayout* layout_filter; + libxsmm_dnn_tensor_datalayout* layout_output; + libxsmm_dnn_layer* handle; +}; + +typedef libxsmm_tf_allocator + libxsmm_tf_scratch_allocator; + +static class libxsmm_dnn_registry_type { + private: + typedef std::unordered_map + container_type; + public: - libxsmm_dnn_layer* find(const libxsmm_dnn_conv_desc_wrap& w) { - std::unordered_map::iterator i = libxsmm_handles.find(w); - if (i == libxsmm_handles.end()) { - libxsmm_dnn_err_t status; - libxsmm_dnn_layer* libxsmm_handle = - libxsmm_dnn_create_conv_layer(w.d, &status); - chk_libxsmm_err(status, "Create handle"); - libxsmm_handles.insert(std::make_pair(w, libxsmm_handle)); - return libxsmm_handle; - } else { - return i->second; + libxsmm_dnn_registry_type() { + libxsmm_init(); /* must be first */ +#if !defined(LIBXSMM_LOCAL_ALLOC) + { + libxsmm_malloc_function malloc_fn; + libxsmm_free_function free_fn; + malloc_fn.function = libxsmm_tf_scratch_allocator::malloc; + free_fn.function = libxsmm_tf_scratch_allocator::free; + libxsmm_set_scratch_allocator(0 /*context*/, malloc_fn, free_fn); } +#endif + LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_LOCK_RWLOCK, &attr); + LIBXSMM_LOCK_INIT(LIBXSMM_LOCK_RWLOCK, &lock, &attr); } - ~handles() { - std::unordered_map::iterator i; - for (i = libxsmm_handles.begin(); i != libxsmm_handles.end(); i++) - chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second), - "Destroy handle"); + ~libxsmm_dnn_registry_type() { + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock); + const container_type::const_iterator end = container.end(); + for (container_type::const_iterator i = container.begin(); i != end; ++i) { + CHECK_LIBXSMM_DNN( + libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_input), + "destroy input layout"); + CHECK_LIBXSMM_DNN( + libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_output), + "destroy output layout"); + CHECK_LIBXSMM_DNN( + libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_filter), + "destroy filter layout"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_conv_layer(i->second.handle), + "destroy handle"); + } + LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock); + LIBXSMM_LOCK_DESTROY(LIBXSMM_LOCK_RWLOCK, &lock); + LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_LOCK_RWLOCK, &attr); + libxsmm_finalize(); + } + libxsmm_dnn_registry_value find(const libxsmm_dnn_registry_key& regkey) { + container_type::iterator i; + LIBXSMM_LOCK_ACQREAD(LIBXSMM_LOCK_RWLOCK, &lock); + i = container.find(regkey); + LIBXSMM_LOCK_RELREAD(LIBXSMM_LOCK_RWLOCK, &lock); + if (i == container.end()) { + libxsmm_dnn_err_t status; + libxsmm_dnn_registry_value regentry; + + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock); + i = container.find(regkey); + if (i == container.end()) { // re-check after lock acquisition + regentry.handle = + libxsmm_dnn_create_conv_layer(regkey.descriptor, &status); + if (LIBXSMM_DNN_WARN_FALLBACK != status) { + CHECK_LIBXSMM_DNN(status, "create handle"); + } else { // warning + VLOG(1) << libxsmm_dnn_get_error(status); + } + regentry.layout_input = libxsmm_dnn_create_tensor_datalayout( + regentry.handle, LIBXSMM_DNN_INPUT, &status); + CHECK_LIBXSMM_DNN(status, "create input layout"); + + regentry.layout_output = libxsmm_dnn_create_tensor_datalayout( + regentry.handle, LIBXSMM_DNN_OUTPUT, &status); + CHECK_LIBXSMM_DNN(status, "create output layout"); + + regentry.layout_filter = libxsmm_dnn_create_tensor_datalayout( + regentry.handle, LIBXSMM_DNN_FILTER, &status); + CHECK_LIBXSMM_DNN(status, "create filter layout"); + + i = container.insert(std::make_pair(regkey, regentry)).first; + } + LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock); + } + return i->second; } private: - std::unordered_map - libxsmm_handles; -}; - -static handles libxsmm_handles; + container_type container; + LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_LOCK_RWLOCK) attr; + LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK_RWLOCK) lock; +} libxsmm_dnn_registry; // #define LIBXSMM_DETAILED_TIMING @@ -173,75 +238,66 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, InputPtr input, FilterPtr filter, OutputPtr output) { #if defined(LIBXSMM_DETAILED_TIMING) - uint64 l_tick1; - uint64 l_tick2; - uint64 l_tick3; - uint64 l_tick4; - uint64 l_tick5; - uint64 l_tick6; - uint64 l_tick7; - uint64 l_tick8; - uint64 l_tick9; - uint64 l_tick10; + libxsmm_timer_tickint l_tick1; + libxsmm_timer_tickint l_tick2; + libxsmm_timer_tickint l_tick3; + libxsmm_timer_tickint l_tick4; + libxsmm_timer_tickint l_tick5; + libxsmm_timer_tickint l_tick6; + libxsmm_timer_tickint l_tick7; + libxsmm_timer_tickint l_tick8; + libxsmm_timer_tickint l_tick9; + libxsmm_timer_tickint l_tick10; l_tick1 = libxsmm_timer_tick(); #endif - // setup scoped allocator, which adopts the allocator from the context - const libxsmm_tf_allocator tf_allocator(*ctx); +#if defined(LIBXSMM_LOCAL_ALLOC) + // setup scoped allocator, which adopts the allocator of the current context + const libxsmm_tf_scratch_allocator tf_allocator(*ctx); +#endif + const libxsmm_dnn_registry_key regkey(desc); + const libxsmm_dnn_registry_value regentry = libxsmm_dnn_registry.find(regkey); + libxsmm_dnn_tensor *libxsmm_input, *libxsmm_output, *libxsmm_filter; libxsmm_dnn_err_t status; - libxsmm_dnn_layer* libxsmm_handle; - libxsmm_dnn_conv_desc_wrap w(desc); - void* scratch; - // if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) - libxsmm_handle = libxsmm_handles.find(w); - // else{ - // libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status); - // chk_libxsmm_err(status, "Create handle"); - //} - - status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind); + status = libxsmm_dnn_get_codegen_success(regentry.handle, kind); if (status == LIBXSMM_DNN_WARN_FALLBACK) { return false; // Use non-libxsmm code } - chk_libxsmm_err(status, "Check codegen status"); - - libxsmm_dnn_buffer* libxsmm_input; - libxsmm_dnn_buffer* libxsmm_output; - libxsmm_dnn_filter* libxsmm_filter; + CHECK_LIBXSMM_DNN(status, "code generation"); #if defined(LIBXSMM_DETAILED_TIMING) l_tick2 = libxsmm_timer_tick(); #endif - int ifmblock = (libxsmm_handle->ifmblock); - int ofmblock = (libxsmm_handle->ofmblock); + const int ifmblock = regentry.handle->ifmblock; + const int ofmblock = regentry.handle->ofmblock; - int blocksifm = - desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1; - int blocksofm = - desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1; - float* native_filter = - (float*)libxsmm_aligned_scratch(blocksofm * blocksifm * desc.R * desc.S * - ifmblock * ofmblock * sizeof(float), - 2097152); + const int blocksifm = + (desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1); + const int blocksofm = + (desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1); - const DeviceBase::CpuWorkerThreads* worker_threads = + const size_t filter_size = + blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock; + float* const native_filter = (float*)libxsmm_aligned_scratch( + filter_size * sizeof(float), 2097152 /*alignment*/); + + const DeviceBase::CpuWorkerThreads* const worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); - - int num_threads = worker_threads->num_threads; + const int num_threads = worker_threads->num_threads; #if 1 if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD || kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { if (blocksofm > num_threads) { - int work = blocksofm; + const int work = blocksofm; BlockingCounter count(num_threads); for (int i = 0; i < num_threads; ++i) { worker_threads->workers->Schedule([=, &count]() { - int start = work / num_threads * i; - int end = (start + work / num_threads) > work - ? work - : start + work / num_threads; + const int start = work / num_threads * i; + const int end = (start + work / num_threads) > work + ? work + : start + work / num_threads; copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C, desc.K, blocksifm, blocksofm, ifmblock, ofmblock, start, end); @@ -250,14 +306,14 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, } count.Wait(); } else { - int work = blocksofm; - int num_threads = work; + const int work = blocksofm; + const int num_tasks = work; - BlockingCounter count(num_threads); - for (int i = 0; i < num_threads; ++i) { + BlockingCounter count(num_tasks); + for (int i = 0; i < num_tasks; ++i) { worker_threads->workers->Schedule([=, &count]() { - int start = i; - int end = i + 1; + const int start = i; + const int end = i + 1; copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C, desc.K, blocksifm, blocksofm, ifmblock, ofmblock, start, end); @@ -267,121 +323,129 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, count.Wait(); } } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { - // Added: for weight update + // weight update buffer must be in the right format + // (LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR) libxsmm_filter = - libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter, - LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status); - chk_libxsmm_err(status, - "Link filter"); // weight update is in RSCK as - // filter should be returned in RSCK - // format + libxsmm_dnn_link_tensor(regentry.layout_filter, filter, &status); + CHECK_LIBXSMM_DNN(status, "link filter with layout"); } #else - memset(native_filter, 0, - blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock * - sizeof(float)); + memset(native_filter, 0, filter_size * sizeof(float)); #endif #if defined(LIBXSMM_DETAILED_TIMING) l_tick3 = libxsmm_timer_tick(); #endif + // LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR libxsmm_input = - libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_INPUT, input, - LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status); - chk_libxsmm_err(status, "Link input buffer"); + libxsmm_dnn_link_tensor(regentry.layout_input, input, &status); + CHECK_LIBXSMM_DNN(status, "link input buffer with layout"); + + // LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR libxsmm_output = - libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_OUTPUT, output, - LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status); - chk_libxsmm_err(status, "Link output buffer"); + libxsmm_dnn_link_tensor(regentry.layout_output, output, &status); + CHECK_LIBXSMM_DNN(status, "link output buffer with layout"); + if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD || kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { - libxsmm_filter = libxsmm_dnn_link_filter( - libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter, - LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status); - chk_libxsmm_err(status, "Link filter"); + // LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR + libxsmm_filter = + libxsmm_dnn_link_tensor(regentry.layout_filter, native_filter, &status); + CHECK_LIBXSMM_DNN(status, "link filter with layout"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter, + LIBXSMM_DNN_REGULAR_FILTER), + "bind filter to handle"); } if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) { - chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, - LIBXSMM_DNN_REGULAR_INPUT), - "Bind input forward"); - chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, - LIBXSMM_DNN_REGULAR_OUTPUT), - "Bind output forward"); - chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, - LIBXSMM_DNN_REGULAR_FILTER), - "Bind filter forward"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input, + LIBXSMM_DNN_REGULAR_INPUT), + "bind input forward"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter, + LIBXSMM_DNN_REGULAR_FILTER), + "bind filter forward"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output, + LIBXSMM_DNN_REGULAR_OUTPUT), + "bind output forward"); } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { - chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_input), "Zero input"); - - chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, - LIBXSMM_DNN_GRADIENT_INPUT), - "Bind input backward"); - chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, - LIBXSMM_DNN_GRADIENT_OUTPUT), - "Bind output backward"); - chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, - LIBXSMM_DNN_REGULAR_FILTER), - "Bind filter backward"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_input), "zeroing input"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input, + LIBXSMM_DNN_GRADIENT_INPUT), + "bind input backward"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter, + LIBXSMM_DNN_REGULAR_FILTER), + "bind filter backward"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output, + LIBXSMM_DNN_GRADIENT_OUTPUT), + "bind output backward"); } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { - chk_libxsmm_err(libxsmm_dnn_zero_filter(libxsmm_filter), "Zero filter"); - - chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, - LIBXSMM_DNN_REGULAR_INPUT), - "Bind input weight update"); - chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, - LIBXSMM_DNN_GRADIENT_OUTPUT), - "Bind output weight update"); - chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, - LIBXSMM_DNN_GRADIENT_FILTER), - "Bind filter weight update"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_filter), + "zeroing filter"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input, + LIBXSMM_DNN_REGULAR_INPUT), + "bind input weight update"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter, + LIBXSMM_DNN_GRADIENT_FILTER), + "bind filter weight update"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output, + LIBXSMM_DNN_GRADIENT_OUTPUT), + "bind output weight update"); } else { - /* shouldn't happen */ + assert(0 /*should not happen*/); } #if defined(LIBXSMM_DETAILED_TIMING) l_tick4 = libxsmm_timer_tick(); #endif - /* bind scratch */ - scratch = (void*)libxsmm_aligned_scratch( - libxsmm_dnn_get_scratch_size(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, - &status), - 2097152); - chk_libxsmm_err(status, "scratch allocation"); - chk_libxsmm_err(libxsmm_dnn_bind_scratch( - libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch), - "binding scratch"); + const size_t scratch_size = libxsmm_dnn_get_scratch_size( + regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status); + CHECK_LIBXSMM_DNN(status, "get scratch size"); + void* const scratch = + libxsmm_aligned_scratch(scratch_size, 2097152 /*alignment*/); + CHECK_LIBXSMM(0 != scratch, "scratch memory allocation"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_scratch( + regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch), + "binding scratch"); #if defined(LIBXSMM_DETAILED_TIMING) l_tick5 = libxsmm_timer_tick(); #endif if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { - libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER); + libxsmm_dnn_transpose_filter(regentry.handle, LIBXSMM_DNN_FILTER); } #if defined(LIBXSMM_DETAILED_TIMING) l_tick6 = libxsmm_timer_tick(); #endif +#if !defined(_OPENMP) || !defined(LIBXSMM_USE_OPENMP) BlockingCounter counter(num_threads); for (int i = 0; i < num_threads; ++i) { worker_threads->workers->Schedule([=, &counter]() { - chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i), - "Worker"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_execute_st(regentry.handle, kind, 0, i), + "worker"); counter.DecrementCount(); }); } counter.Wait(); +#else +#pragma omp parallel + { + CHECK_LIBXSMM_DNN( + libxsmm_dnn_execute_st(regentry.handle, kind, 0, omp_get_thread_num()), + "worker"); + } +#endif #if defined(LIBXSMM_DETAILED_TIMING) l_tick7 = libxsmm_timer_tick(); #endif if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { - libxsmm_dnn_reduce_wu_filters(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER); + libxsmm_dnn_reduce_wu_filters(regentry.handle, LIBXSMM_DNN_GRADIENT_FILTER); } #if defined(LIBXSMM_DETAILED_TIMING) @@ -389,54 +453,52 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, #endif /* clean up */ - chk_libxsmm_err( - libxsmm_dnn_release_scratch(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL), - "release scratch"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_release_scratch(regentry.handle, + LIBXSMM_DNN_COMPUTE_KIND_ALL), + "release scratch"); if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) { - chk_libxsmm_err( - libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT), + CHECK_LIBXSMM_DNN( + libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT), "release input"); - chk_libxsmm_err( - libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT), + CHECK_LIBXSMM_DNN( + libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_OUTPUT), "release output"); - chk_libxsmm_err( - libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER), + CHECK_LIBXSMM_DNN( + libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER), "release filter"); } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { - chk_libxsmm_err( - libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT), + CHECK_LIBXSMM_DNN( + libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_INPUT), "release input"); - chk_libxsmm_err( - libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT), - "release output"); - chk_libxsmm_err( - libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER), + CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle, + LIBXSMM_DNN_GRADIENT_OUTPUT), + "release output"); + CHECK_LIBXSMM_DNN( + libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER), "release filter"); } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { - chk_libxsmm_err( - libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT), + CHECK_LIBXSMM_DNN( + libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT), "release input"); - chk_libxsmm_err( - libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT), - "release output"); - chk_libxsmm_err( - libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER), - "release filter"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle, + LIBXSMM_DNN_GRADIENT_OUTPUT), + "release output"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle, + LIBXSMM_DNN_GRADIENT_FILTER), + "release filter"); } else { /* shouldn't happen */ } - chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input"); - chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output"); - chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_input), "destroy input"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_output), + "destroy output"); + CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_filter), + "destroy filter"); #if defined(LIBXSMM_DETAILED_TIMING) l_tick9 = libxsmm_timer_tick(); #endif - // if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD) - // chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle), - // "Destroy handle"); - libxsmm_free(native_filter); libxsmm_free(scratch); diff --git a/tensorflow/core/kernels/xsmm_conv2d_test.cc b/tensorflow/core/kernels/xsmm_conv2d_test.cc index 481f3b7ba46..8e6aedf2506 100644 --- a/tensorflow/core/kernels/xsmm_conv2d_test.cc +++ b/tensorflow/core/kernels/xsmm_conv2d_test.cc @@ -312,8 +312,8 @@ TEST(XsmmConv2DTest, Basic) { LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK; desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; desc.options = LIBXSMM_DNN_CONV_OPTION_NONE; - desc.datatype = LIBXSMM_DNN_DATATYPE_F32; - + desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32; + desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32; if (!CanUseXsmmConv2D(desc, data_format)) { return false; } diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 1494075ee21..5dc60cad1f9 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -193,11 +193,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "libxsmm_archive", build_file = clean_dep("//third_party:libxsmm.BUILD"), - sha256 = "5fc1972471cd8e2b8b64ea017590193739fc88d9818e3d086621e5c08e86ea35", - strip_prefix = "libxsmm-1.11", + sha256 = "9c0af4509ea341d1ee2c6c19fc6f19289318c3bd4b17844efeb9e7f9691abf76", + strip_prefix = "libxsmm-1.14", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/hfp/libxsmm/archive/1.11.tar.gz", - "https://github.com/hfp/libxsmm/archive/1.11.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/hfp/libxsmm/archive/1.14.tar.gz", + "https://github.com/hfp/libxsmm/archive/1.14.tar.gz", ], ) diff --git a/third_party/libxsmm.BUILD b/third_party/libxsmm.BUILD index dc7dcc95170..0e59250fef3 100644 --- a/third_party/libxsmm.BUILD +++ b/third_party/libxsmm.BUILD @@ -45,62 +45,45 @@ genrule( cc_library( name = "xsmm_avx", - srcs = [ - "src/libxsmm_cpuid_x86.c", - "src/libxsmm_dnn.c", - "src/libxsmm_dnn_convolution_backward.c", - "src/libxsmm_dnn_convolution_forward.c", - "src/libxsmm_dnn_convolution_weight_update.c", - "src/libxsmm_dnn_convolution_winograd_backward.c", - "src/libxsmm_dnn_convolution_winograd_forward.c", - "src/libxsmm_dnn_convolution_winograd_weight_update.c", - "src/libxsmm_dnn_handle.c", - "src/libxsmm_dump.c", - "src/libxsmm_ext_gemm.c", - "src/libxsmm_ext_trans.c", - "src/libxsmm_fsspmdm.c", - "src/libxsmm_gemm.c", - "src/libxsmm_main.c", - "src/libxsmm_malloc.c", - "src/libxsmm_perf.c", - "src/libxsmm_spmdm.c", - "src/libxsmm_sync.c", - "src/libxsmm_timer.c", - "src/libxsmm_trace.c", - "src/libxsmm_trans.c", - ] + glob([ - "src/generator_*.c", - ]), - hdrs = [ - "include/libxsmm_cpuid.h", - "include/libxsmm_dnn.h", - "include/libxsmm_frontend.h", - "include/libxsmm_fsspmdm.h", - "include/libxsmm_generator.h", - "include/libxsmm_intrinsics_x86.h", - "include/libxsmm_macros.h", - "include/libxsmm_malloc.h", - "include/libxsmm_spmdm.h", - "include/libxsmm_sync.h", - "include/libxsmm_timer.h", - "include/libxsmm_typedefs.h", - # Source files #included internally: - "src/libxsmm_gemm_diff.c", + srcs = glob( + [ + # general source files (translation units) + "src/generator_*.c", + "src/libxsmm_*.c", + ], + exclude = [ + # exclude generators (with main functions) + "src/libxsmm_generator_*.c", + ], + ), + hdrs = glob( + [ + # general header files + "include/libxsmm_*.h", + # trigger rebuild if template changed + "src/template/*.c", + ], + exclude = [ + # exclude existing/generated headers + "include/libxsmm.h", + "include/libxsmm_config.h", + "include/libxsmm_dispatch.h", + ], + ) + [ + # source files included internally "src/libxsmm_hash.c", - # Generated: + # generated header files "include/libxsmm.h", "include/libxsmm_config.h", "include/libxsmm_dispatch.h", - ] + glob([ - # trigger rebuild if template changed - "src/template/*.c", - ]), - copts = [ - "-mavx", # JIT does not work without avx anyway, and this silences some CRC32 warnings. - "-Wno-vla", # Libxsmm convolutions heavily use VLA. ], + #copts = [ + # "-mavx", # JIT does not work without avx anyway, and this silences some CRC32 warnings. + # "-Wno-vla", # Libxsmm convolutions heavily use VLA. + #], defines = [ "LIBXSMM_BUILD", + "LIBXSMM_CTOR", "__BLAS=0", ], includes = [