Adds a "currentThreadIndex" method to Eigen's ThreadPoolDevice. Uses it to handle per-thread buffer allocation for the tileable executor without resorting to thread_local, which is not fully supported on Android.
Change: 126009029
This commit is contained in:
parent
370a6d4e91
commit
7a62f1e0be
@ -1,7 +1,6 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
archive_dir = "eigen-eigen-802d984ade26"
|
||||
|
||||
archive_dir = "eigen-eigen-334b1d428283"
|
||||
cc_library(
|
||||
name = "eigen",
|
||||
hdrs = glob([archive_dir+"/**/*.h", archive_dir+"/unsupported/Eigen/CXX11/*", archive_dir+"/Eigen/*"]),
|
||||
|
@ -27,6 +27,8 @@ class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
|
||||
~EigenThreadPoolWrapper() override {}
|
||||
|
||||
void Schedule(std::function<void()> fn) override { pool_->Schedule(fn); }
|
||||
int NumThreads() const override { return pool_->NumThreads(); }
|
||||
int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
|
||||
|
||||
private:
|
||||
thread::ThreadPool* pool_ = nullptr;
|
||||
|
@ -35,7 +35,6 @@ struct scalar_fmod2_op {
|
||||
return std::fmod(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct functor_traits<scalar_fmod2_op<T>> {
|
||||
enum {
|
||||
@ -44,6 +43,24 @@ struct functor_traits<scalar_fmod2_op<T>> {
|
||||
};
|
||||
};
|
||||
|
||||
// TODO(rmlarsen): This is a workaround for upstream change
|
||||
// https://bitbucket.org/eigen/eigen/commits/f339468d04d0f87caeb6cab9aef568627e9f6ea9
|
||||
// that renamed scalar_binary_pow_op to scalar_pow_op and deleted the unary
|
||||
// version of the latter. Remove once we upgrade to Eigen 3.3.
|
||||
template <typename Scalar, typename Exponent>
|
||||
struct scalar_binary_pow_op_google {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_binary_pow_op_google)
|
||||
EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
|
||||
const Exponent& b) const {
|
||||
return numext::pow(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar, typename Exponent>
|
||||
struct functor_traits<scalar_binary_pow_op_google<Scalar, Exponent>> {
|
||||
enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
|
||||
};
|
||||
|
||||
template <typename T, typename DivOrMod>
|
||||
struct safe_div_or_mod_op {
|
||||
static_assert(std::is_integral<T>::value, "Integer type expected");
|
||||
@ -477,7 +494,7 @@ struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op<
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct pow : base<T, Eigen::internal::scalar_binary_pow_op<T, T> > {};
|
||||
struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {};
|
||||
|
||||
template <typename T>
|
||||
struct maximum : base<T, Eigen::internal::scalar_max_op<T> > {};
|
||||
|
@ -74,15 +74,14 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
|
||||
Impl(Env* env, const ThreadOptions& thread_options, const string& name,
|
||||
int num_threads)
|
||||
: Eigen::ThreadPoolTempl<EigenEnvironment>(
|
||||
num_threads, EigenEnvironment(env, thread_options, name)),
|
||||
num_threads_(num_threads) {}
|
||||
num_threads, EigenEnvironment(env, thread_options, name)) {}
|
||||
|
||||
void ParallelFor(int64 total, int64 cost_per_unit,
|
||||
std::function<void(int64, int64)> fn) {
|
||||
#ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
|
||||
CHECK_GE(total, 0);
|
||||
CHECK_EQ(total, (int64)(Eigen::Index)total);
|
||||
Eigen::ThreadPoolDevice device(this, num_threads_);
|
||||
Eigen::ThreadPoolDevice device(this, this->NumThreads());
|
||||
device.parallelFor(
|
||||
total, Eigen::TensorOpCost(0, 0, cost_per_unit),
|
||||
[&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
|
||||
@ -90,10 +89,6 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
|
||||
CHECK(0); // should not be used with the old thread pool
|
||||
#endif
|
||||
}
|
||||
|
||||
int NumThreads() const { return num_threads_; };
|
||||
|
||||
const int num_threads_;
|
||||
};
|
||||
|
||||
ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
|
||||
@ -120,5 +115,7 @@ void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
|
||||
|
||||
int ThreadPool::NumThreads() const { return impl_->NumThreads(); }
|
||||
|
||||
int ThreadPool::CurrentThreadId() const { return impl_->CurrentThreadId(); }
|
||||
|
||||
} // namespace thread
|
||||
} // namespace tensorflow
|
||||
|
@ -57,6 +57,10 @@ class ThreadPool {
|
||||
// Returns the number of threads in the pool.
|
||||
int NumThreads() const;
|
||||
|
||||
// Returns current thread id between 0 and NumThreads() - 1, if called from a
|
||||
// thread in the pool. Returns -1 otherwise.
|
||||
int CurrentThreadId() const;
|
||||
|
||||
struct Impl;
|
||||
|
||||
private:
|
||||
|
@ -6,8 +6,8 @@
|
||||
def tf_workspace(path_prefix = "", tf_repo_name = ""):
|
||||
native.new_http_archive(
|
||||
name = "eigen_archive",
|
||||
url = "https://bitbucket.org/eigen/eigen/get/802d984ade26.tar.gz",
|
||||
sha256 = "1499997676bd9006082950a761b88d5c48554fd550747763b2b34951da29a2e8",
|
||||
url = "https://bitbucket.org/eigen/eigen/get/334b1d428283.tar.gz",
|
||||
sha256 = "6d5efd02c7c11fbb9d02df4f0b64f22ecbd348e7549f8a83c13fb4d8d9e19d4b",
|
||||
build_file = path_prefix + "eigen.BUILD",
|
||||
)
|
||||
|
||||
@ -146,7 +146,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
|
||||
remote = "https://boringssl.googlesource.com/boringssl",
|
||||
build_file = path_prefix + "boringssl.BUILD",
|
||||
)
|
||||
|
||||
|
||||
native.bind(
|
||||
name = "boringssl_err_data_c",
|
||||
actual = "@//" + path_prefix + "third_party/boringssl:err_data_c",
|
||||
|
2
third_party/eigen3/Eigen/Cholesky
vendored
2
third_party/eigen3/Eigen/Cholesky
vendored
@ -1 +1 @@
|
||||
#include "eigen-eigen-802d984ade26/Eigen/Cholesky"
|
||||
#include "eigen-eigen-334b1d428283/Eigen/Cholesky"
|
||||
|
2
third_party/eigen3/Eigen/Core
vendored
2
third_party/eigen3/Eigen/Core
vendored
@ -1 +1 @@
|
||||
#include "eigen-eigen-802d984ade26/Eigen/Core"
|
||||
#include "eigen-eigen-334b1d428283/Eigen/Core"
|
||||
|
2
third_party/eigen3/Eigen/Eigenvalues
vendored
2
third_party/eigen3/Eigen/Eigenvalues
vendored
@ -1 +1 @@
|
||||
#include "eigen-eigen-802d984ade26/Eigen/Eigenvalues"
|
||||
#include "eigen-eigen-334b1d428283/Eigen/Eigenvalues"
|
||||
|
2
third_party/eigen3/Eigen/LU
vendored
2
third_party/eigen3/Eigen/LU
vendored
@ -1 +1 @@
|
||||
#include "eigen-eigen-802d984ade26/Eigen/LU"
|
||||
#include "eigen-eigen-334b1d428283/Eigen/LU"
|
||||
|
2
third_party/eigen3/Eigen/QR
vendored
2
third_party/eigen3/Eigen/QR
vendored
@ -1 +1 @@
|
||||
#include "eigen-eigen-802d984ade26/Eigen/QR"
|
||||
#include "eigen-eigen-334b1d428283/Eigen/QR"
|
||||
|
@ -1 +1 @@
|
||||
#include "eigen-eigen-802d984ade26/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "eigen-eigen-334b1d428283/unsupported/Eigen/CXX11/Tensor"
|
||||
|
Loading…
Reference in New Issue
Block a user