review changes.
This commit is contained in:
parent
dbd14b1587
commit
aae350dbb5
@ -1098,7 +1098,6 @@ class MklConvOp : public OpKernel {
|
||||
}
|
||||
|
||||
// TODO(intel-mkl): This function does not seem to be called. Remove it.
|
||||
// LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
|
||||
// Prepare and execute net - checks for input and output reorders.
|
||||
void PrepareAndExecuteNet(const ConvFwdPd& conv_prim_desc,
|
||||
MklDnnData<Tinput>* src,
|
||||
|
@ -26,32 +26,36 @@ limitations under the License.
|
||||
#include <vector>
|
||||
#include "mkldnn.hpp"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#define EIGEN_USE_THREADS
|
||||
#include "tensorflow/core/platform/threadpool.h"
|
||||
#define EIGEN_USE_THREADS
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
using dnnl::threadpool_iface;
|
||||
using dnnl::stream_attr;
|
||||
|
||||
namespace tensorflow {
|
||||
// balance211 function tries to divide n jobs equally among 'team' threads.
|
||||
// This is the same as DNNL load balancer.
|
||||
template <typename T, typename U>
|
||||
inline void balance211(T n, U team, U tid, T& n_start, T& n_end) {
|
||||
T& n_my = n_end;
|
||||
if (team <= 1 || n == 0) {
|
||||
n_start = 0;
|
||||
n_my = n;
|
||||
} else {
|
||||
// team = T1 + T2
|
||||
// n = T1*n1 + T2*n2 (n1 - n2 = 1)
|
||||
T n1 = (n + (T)team - 1) / team;
|
||||
T n2 = n1 - 1;
|
||||
T T1 = n - n2 * (T)team;
|
||||
n_my = (T)tid < T1 ? n1 : n2;
|
||||
n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2;
|
||||
}
|
||||
|
||||
n_end += n_start;
|
||||
// Divide 'n' units of work equally among 'teams' threads. If 'n' is not
|
||||
// divisible by 'teams' and has a remainder 'r', the first 'r' teams have one
|
||||
// unit of work more than the rest. Returns the range of work that belongs to
|
||||
// the team 'tid'.
|
||||
// Parameters
|
||||
// n Total number of jobs.
|
||||
// team Number of workers.
|
||||
// tid Current thread_id.
|
||||
// n_start start of range operated by the thread.
|
||||
// n_end end of the range operated by the thread.
|
||||
|
||||
template <typename T, typename U>
|
||||
inline void balance211(T n, U team, U tid, T* n_start, T* n_end) {
|
||||
if (team <= 1 || n == 0) {
|
||||
*n_start = 0;
|
||||
*n_end = n;
|
||||
return;
|
||||
}
|
||||
T min_per_team = n / team;
|
||||
T remainder = n - min_per_team * team; // i.e., n % teams.
|
||||
*n_start = tid * min_per_team + std::min(tid, remainder);
|
||||
*n_end = *n_start + min_per_team + (tid < remainder);
|
||||
}
|
||||
|
||||
struct MklDnnThreadPool : public dnnl::threadpool_iface {
|
||||
@ -60,13 +64,7 @@ struct MklDnnThreadPool : public dnnl::threadpool_iface {
|
||||
MklDnnThreadPool(OpKernelContext* ctx)
|
||||
: eigen_interface_(ctx->device()
|
||||
->tensorflow_cpu_worker_threads()
|
||||
->workers->AsEigenThreadPool())
|
||||
#if DNNL_PRINT_STATS
|
||||
,
|
||||
jobs_per_thread(eigen_interface_->NumThreads())
|
||||
#endif
|
||||
{
|
||||
}
|
||||
->workers->AsEigenThreadPool()) {}
|
||||
virtual int get_num_threads() override {
|
||||
return eigen_interface_->NumThreads();
|
||||
}
|
||||
@ -91,26 +89,16 @@ struct MklDnnThreadPool : public dnnl::threadpool_iface {
|
||||
eigen_interface_->ScheduleWithHint(
|
||||
[i, n, njobs, fn]() {
|
||||
int start, end;
|
||||
balance211(n, njobs, i, start, end);
|
||||
balance211(n, njobs, i, &start, &end);
|
||||
for (int j = start; j < end; j++) fn(j, n);
|
||||
#if DNNL_PRINT_STATS
|
||||
jobs_per_thread[eigen_interface_->CurrentThreadId()]++;
|
||||
#endif
|
||||
},
|
||||
i, i + 1);
|
||||
}
|
||||
}
|
||||
#if DNNL_PRINT_STATS
|
||||
void print_thread_usage_stats() {
|
||||
for (int i = 0; i < jobs_per_thread_.size(); i++)
|
||||
std::cout << " Thread" << i << "," << jobs_per_thread[i] << std::endl;
|
||||
}
|
||||
#endif
|
||||
~MklDnnThreadPool() {}
|
||||
|
||||
private:
|
||||
Eigen::ThreadPoolInterface* eigen_interface_ = nullptr;
|
||||
std::vector<int> jobs_per_thread_;
|
||||
};
|
||||
|
||||
class MklDnnThreadPoolWrapper {
|
||||
|
Loading…
Reference in New Issue
Block a user