Adds a "currentThreadIndex" method to Eigen's ThreadPoolDevice. Use it to handle per-thread buffer allocation for the tileable executor without resorting to thread_local that is not fully supported on Android.

Change: 126009029
This commit is contained in:
A. Unique TensorFlower 2016-06-27 14:53:07 -08:00 committed by TensorFlower Gardener
parent 370a6d4e91
commit 7a62f1e0be
12 changed files with 39 additions and 20 deletions
eigen.BUILD
tensorflow
third_party/eigen3
Eigen
unsupported/Eigen/CXX11

View File

@ -1,7 +1,6 @@
package(default_visibility = ["//visibility:public"])
archive_dir = "eigen-eigen-802d984ade26"
archive_dir = "eigen-eigen-334b1d428283"
cc_library(
name = "eigen",
hdrs = glob([archive_dir+"/**/*.h", archive_dir+"/unsupported/Eigen/CXX11/*", archive_dir+"/Eigen/*"]),

View File

@ -27,6 +27,8 @@ class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
~EigenThreadPoolWrapper() override {}
void Schedule(std::function<void()> fn) override { pool_->Schedule(fn); }
int NumThreads() const override { return pool_->NumThreads(); }
int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
private:
thread::ThreadPool* pool_ = nullptr;

View File

@ -35,7 +35,6 @@ struct scalar_fmod2_op {
return std::fmod(a, b);
}
};
template <typename T>
struct functor_traits<scalar_fmod2_op<T>> {
enum {
@ -44,6 +43,24 @@ struct functor_traits<scalar_fmod2_op<T>> {
};
};
// TODO(rmlarsen): This is a workaround for upstream change
// https://bitbucket.org/eigen/eigen/commits/f339468d04d0f87caeb6cab9aef568627e9f6ea9
// that renamed scalar_binary_pow_op to scalar_pow_op and deleted the unary
// version of the latter. Remove once we upgrade to Eigen 3.3.
template <typename Scalar, typename Exponent>
struct scalar_binary_pow_op_google {
EIGEN_EMPTY_STRUCT_CTOR(scalar_binary_pow_op_google)
EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
const Exponent& b) const {
return numext::pow(a, b);
}
};
template <typename Scalar, typename Exponent>
struct functor_traits<scalar_binary_pow_op_google<Scalar, Exponent>> {
enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
};
template <typename T, typename DivOrMod>
struct safe_div_or_mod_op {
static_assert(std::is_integral<T>::value, "Integer type expected");
@ -477,7 +494,7 @@ struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op<
};
template <typename T>
struct pow : base<T, Eigen::internal::scalar_binary_pow_op<T, T> > {};
struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {};
template <typename T>
struct maximum : base<T, Eigen::internal::scalar_max_op<T> > {};

View File

@ -74,15 +74,14 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
Impl(Env* env, const ThreadOptions& thread_options, const string& name,
int num_threads)
: Eigen::ThreadPoolTempl<EigenEnvironment>(
num_threads, EigenEnvironment(env, thread_options, name)),
num_threads_(num_threads) {}
num_threads, EigenEnvironment(env, thread_options, name)) {}
void ParallelFor(int64 total, int64 cost_per_unit,
std::function<void(int64, int64)> fn) {
#ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
CHECK_GE(total, 0);
CHECK_EQ(total, (int64)(Eigen::Index)total);
Eigen::ThreadPoolDevice device(this, num_threads_);
Eigen::ThreadPoolDevice device(this, this->NumThreads());
device.parallelFor(
total, Eigen::TensorOpCost(0, 0, cost_per_unit),
[&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
@ -90,10 +89,6 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
CHECK(0); // should not be used with the old thread pool
#endif
}
int NumThreads() const { return num_threads_; };
const int num_threads_;
};
ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
@ -120,5 +115,7 @@ void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
int ThreadPool::NumThreads() const { return impl_->NumThreads(); }
int ThreadPool::CurrentThreadId() const { return impl_->CurrentThreadId(); }
} // namespace thread
} // namespace tensorflow

View File

@ -57,6 +57,10 @@ class ThreadPool {
// Returns the number of threads in the pool.
int NumThreads() const;
// Returns current thread id between 0 and NumThreads() - 1, if called from a
// thread in the pool. Returns -1 otherwise.
int CurrentThreadId() const;
struct Impl;
private:

View File

@ -6,8 +6,8 @@
def tf_workspace(path_prefix = "", tf_repo_name = ""):
native.new_http_archive(
name = "eigen_archive",
url = "https://bitbucket.org/eigen/eigen/get/802d984ade26.tar.gz",
sha256 = "1499997676bd9006082950a761b88d5c48554fd550747763b2b34951da29a2e8",
url = "https://bitbucket.org/eigen/eigen/get/334b1d428283.tar.gz",
sha256 = "6d5efd02c7c11fbb9d02df4f0b64f22ecbd348e7549f8a83c13fb4d8d9e19d4b",
build_file = path_prefix + "eigen.BUILD",
)
@ -146,7 +146,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
remote = "https://boringssl.googlesource.com/boringssl",
build_file = path_prefix + "boringssl.BUILD",
)
native.bind(
name = "boringssl_err_data_c",
actual = "@//" + path_prefix + "third_party/boringssl:err_data_c",

View File

@ -1 +1 @@
#include "eigen-eigen-802d984ade26/Eigen/Cholesky"
#include "eigen-eigen-334b1d428283/Eigen/Cholesky"

View File

@ -1 +1 @@
#include "eigen-eigen-802d984ade26/Eigen/Core"
#include "eigen-eigen-334b1d428283/Eigen/Core"

View File

@ -1 +1 @@
#include "eigen-eigen-802d984ade26/Eigen/Eigenvalues"
#include "eigen-eigen-334b1d428283/Eigen/Eigenvalues"

View File

@ -1 +1 @@
#include "eigen-eigen-802d984ade26/Eigen/LU"
#include "eigen-eigen-334b1d428283/Eigen/LU"

View File

@ -1 +1 @@
#include "eigen-eigen-802d984ade26/Eigen/QR"
#include "eigen-eigen-334b1d428283/Eigen/QR"

View File

@ -1 +1 @@
#include "eigen-eigen-802d984ade26/unsupported/Eigen/CXX11/Tensor"
#include "eigen-eigen-334b1d428283/unsupported/Eigen/CXX11/Tensor"