Remove usages of tensorflow::EigenThreadPoolWrapper class and its clones.

tensorflow::thread::ThreadPool now has an AsEigenThreadPool() function that
provides an implementation of the Eigen thread pool interface. There is no
longer a need to wrap the thread pool in order to use it inside an Eigen device.
PiperOrigin-RevId: 243145027
This commit is contained in:
Sung Jin Hwang 2019-04-11 14:40:18 -07:00 committed by TensorFlower Gardener
parent 4e28bb323d
commit 2c017701cc
13 changed files with 40 additions and 124 deletions

View File

@ -66,38 +66,16 @@ const absl::optional<std::set<int>>& BackendOptions::allowed_devices() const {
return allowed_devices_;
}
namespace {
// Adapter that exposes a tensorflow::thread::ThreadPool through the
// Eigen::ThreadPoolInterface API. Every interface method forwards directly to
// the wrapped pool.
class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
public:
// Stores `pool` as a raw pointer; the destructor below is empty, so this
// class never deletes the pool.
explicit EigenThreadPoolWrapper(tensorflow::thread::ThreadPool* pool)
: pool_(pool) {}
~EigenThreadPoolWrapper() override {}
// Hands `fn` to the wrapped pool unchanged.
void Schedule(std::function<void()> fn) override {
pool_->Schedule(std::move(fn));
}
// Both queries are answered by the wrapped pool.
int NumThreads() const override { return pool_->NumThreads(); }
int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
private:
// Wrapped pool; not deleted by this class.
tensorflow::thread::ThreadPool* pool_ = nullptr;
};
} // namespace
// Define this in .cc file to avoid having to include eigen or forward declare
// these types in the header.
struct Backend::IntraOpThreadPool {
explicit IntraOpThreadPool(const int num_threads)
: pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(),
"XLAEigen", num_threads)),
wrapper(new EigenThreadPoolWrapper(pool.get())),
device(new Eigen::ThreadPoolDevice(wrapper.get(),
wrapper->NumThreads())) {}
device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(),
pool->NumThreads())) {}
std::unique_ptr<tensorflow::thread::ThreadPool> pool;
std::unique_ptr<EigenThreadPoolWrapper> wrapper;
std::unique_ptr<Eigen::ThreadPoolDevice> device;
};

View File

@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@ -101,8 +100,7 @@ std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a,
} else {
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
2);
tensorflow::EigenThreadPoolWrapper tp(&pool);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());
ExecutableRunOptions run_options;
run_options.set_intra_op_thread_pool(&device);

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <math.h>
#include <algorithm>
#include <memory>
#include <new>
@ -42,7 +43,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test_benchmark.h"
@ -895,8 +895,7 @@ void BM_ParallelFusion(int num_iters) {
// Initialize thread pool.
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
intra_op_parallelism_threads);
tensorflow::EigenThreadPoolWrapper tp(&pool);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());
// Initialize ExecutableRunOptions.
ExecutableRunOptions options;

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
@ -108,12 +107,10 @@ struct LocalClientTestBase::EigenThreadPoolWrapper {
explicit EigenThreadPoolWrapper()
: pool(new tensorflow::thread::ThreadPool(
tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)),
wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())),
device(new Eigen::ThreadPoolDevice(wrapper.get(),
wrapper->NumThreads())) {}
device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(),
pool->NumThreads())) {}
std::unique_ptr<tensorflow::thread::ThreadPool> pool;
std::unique_ptr<tensorflow::EigenThreadPoolWrapper> wrapper;
std::unique_ptr<Eigen::ThreadPoolDevice> device;
};

View File

@ -47,18 +47,14 @@ class SingleThreadedCpuDevice : public Device {
DeviceLocality())) {
eigen_worker_threads_.num_threads = kNumThreads;
eigen_worker_threads_.workers = GraphRunnerThreadPool();
eigen_threadpool_wrapper_.reset(
new EigenThreadPoolWrapper(eigen_worker_threads_.workers));
eigen_device_.reset(new Eigen::ThreadPoolDevice(
eigen_threadpool_wrapper_.get(), eigen_worker_threads_.num_threads));
eigen_worker_threads_.workers->AsEigenThreadPool(),
eigen_worker_threads_.num_threads));
set_tensorflow_cpu_worker_threads(&eigen_worker_threads_);
set_eigen_cpu_device(eigen_device_.get());
}
~SingleThreadedCpuDevice() override {
eigen_threadpool_wrapper_.reset();
eigen_device_.reset();
}
~SingleThreadedCpuDevice() override { eigen_device_.reset(); }
Status Sync() override { return Status::OK(); }
@ -79,7 +75,6 @@ class SingleThreadedCpuDevice : public Device {
private:
DeviceBase::CpuWorkerThreads eigen_worker_threads_;
std::unique_ptr<Eigen::ThreadPoolInterface> eigen_threadpool_wrapper_;
std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
};

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/core/framework/device_base.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
@ -29,8 +28,8 @@ namespace tensorflow {
TEST(DeviceBaseTest, CpuDevice) {
DeviceBase dbase(Env::Default());
thread::ThreadPool pool(Env::Default(), "test", 16);
EigenThreadPoolWrapper wrapper(&pool);
Eigen::ThreadPoolDevice eigen_device(&wrapper, pool.NumThreads());
Eigen::ThreadPoolDevice eigen_device(pool.AsEigenThreadPool(),
pool.NumThreads());
ASSERT_FALSE(dbase.has_eigen_cpu_device());
dbase.set_eigen_cpu_device(&eigen_device);
ASSERT_TRUE(dbase.has_eigen_cpu_device());

View File

@ -59,27 +59,6 @@ namespace grappler {
using TensorVector = gtl::InlinedVector<TensorValue, 4>;
namespace {
// Adapter that exposes a thread::ThreadPool through the
// Eigen::ThreadPoolInterface API. Unlike a plain forwarding wrapper, Schedule()
// runs each task with denormals flushed and the rounding mode set to
// FE_TONEAREST.
class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
public:
// Stores `pool` as a raw pointer; the destructor below is empty, so this
// class never deletes the pool.
explicit EigenThreadPoolWrapper(thread::ThreadPool* pool) : pool_(pool) {}
~EigenThreadPoolWrapper() override {}
// Wraps `fn` in a closure that sets up the floating-point environment before
// calling it, then schedules the closure on the wrapped pool.
void Schedule(std::function<void()> fn) override {
// `fn` is captured by copy ([=]), so the closure holds its own copy of it.
auto wrapped = [=]() {
// TensorFlow flushes denormals to zero and rounds to nearest, so we do
// the same here.
port::ScopedFlushDenormal flush;
port::ScopedSetRound round(FE_TONEAREST);
fn();
};
pool_->Schedule(std::move(wrapped));
}
// Both queries are answered by the wrapped pool.
int NumThreads() const override { return pool_->NumThreads(); }
int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
private:
// Wrapped pool; not deleted by this class.
thread::ThreadPool* pool_ = nullptr;
};
template <typename T>
bool AllValuesAre(const TensorProto& proto, const T& value) {
Tensor tensor;

View File

@ -26,44 +26,18 @@ namespace tensorflow {
namespace grappler {
using TensorVector = gtl::InlinedVector<TensorValue, 4>;
namespace {
// Adapter that exposes a thread::ThreadPool through the
// Eigen::ThreadPoolInterface API. Unlike a plain forwarding wrapper, Schedule()
// runs each task with denormals flushed and the rounding mode set to
// FE_TONEAREST.
class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
public:
// Stores `pool` as a raw pointer; the destructor below is empty, so this
// class never deletes the pool.
explicit EigenThreadPoolWrapper(thread::ThreadPool* pool) : pool_(pool) {}
~EigenThreadPoolWrapper() override {}
// Wraps `fn` in a closure that sets up the floating-point environment before
// calling it, then schedules the closure on the wrapped pool.
void Schedule(std::function<void()> fn) override {
// `fn` is captured by copy ([=]), so the closure holds its own copy of it.
auto wrapped = [=]() {
// TensorFlow flushes denormals to zero and rounds to nearest, so we do
// the same here.
port::ScopedFlushDenormal flush;
port::ScopedSetRound round(FE_TONEAREST);
fn();
};
pool_->Schedule(std::move(wrapped));
}
// Both queries are answered by the wrapped pool.
int NumThreads() const override { return pool_->NumThreads(); }
int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
private:
// Wrapped pool; not deleted by this class.
thread::ThreadPool* pool_ = nullptr;
};
} // namespace
DeviceSimple::DeviceSimple() : DeviceBase(Env::Default()) {
eigen_worker_threads_.num_threads = port::NumSchedulableCPUs();
eigen_worker_threads_.workers = new thread::ThreadPool(
Env::Default(), "evaluation_utils", eigen_worker_threads_.num_threads);
eigen_threadpool_wrapper_.reset(
new EigenThreadPoolWrapper(eigen_worker_threads_.workers));
eigen_device_.reset(new Eigen::ThreadPoolDevice(
eigen_threadpool_wrapper_.get(), eigen_worker_threads_.num_threads));
eigen_worker_threads_.workers->AsEigenThreadPool(),
eigen_worker_threads_.num_threads));
set_tensorflow_cpu_worker_threads(&eigen_worker_threads_);
set_eigen_cpu_device(eigen_device_.get());
}
DeviceSimple::~DeviceSimple() {
eigen_threadpool_wrapper_.reset();
eigen_device_.reset();
delete eigen_worker_threads_.workers;
}

View File

@ -47,7 +47,6 @@ class DeviceSimple : public DeviceBase {
private:
DeviceBase::CpuWorkerThreads eigen_worker_threads_;
std::unique_ptr<Eigen::ThreadPoolInterface> eigen_threadpool_wrapper_;
std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
};

View File

@ -19,6 +19,8 @@ limitations under the License.
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/cc/ops/nn_ops.h"
#include <functional>
#include <memory>
#include <unordered_map>
@ -26,10 +28,8 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
@ -735,8 +735,8 @@ static void BM_LRNFloat(int iters, int depth, int cols, int rows,
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(),
num_threads);
device->set_eigen_cpu_device(&eigen_cpu_device);
gtl::InlinedVector<TensorValue, 4> inputs;
@ -817,8 +817,8 @@ static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth,
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(),
num_threads);
device->set_eigen_cpu_device(&eigen_cpu_device);
gtl::InlinedVector<TensorValue, 4> inputs;
@ -909,8 +909,8 @@ static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols,
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(),
num_threads);
device->set_eigen_cpu_device(&eigen_cpu_device);
gtl::InlinedVector<TensorValue, 4> inputs;
@ -1013,8 +1013,8 @@ static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth,
DeviceFactory::NewDevice("CPU", options, "/job:a/replica:0/task:0"));
thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(),
num_threads);
device->set_eigen_cpu_device(&eigen_cpu_device);
gtl::InlinedVector<TensorValue, 4> inputs;
@ -1193,8 +1193,8 @@ static void BM_ReluFloat(int iters, int batch_size, int rows, int cols,
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(),
num_threads);
device->set_eigen_cpu_device(&eigen_cpu_device);
gtl::InlinedVector<TensorValue, 4> inputs;

View File

@ -19,7 +19,6 @@ limitations under the License.
#include <limits>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
@ -213,8 +212,8 @@ void TestRequantizeManyInNewRange8To32Bit() {
template <typename InputType, typename OutputType>
void TestRequantizeManyInNewRangeEigenVsNonEigen() {
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(),
2 /* num_threads */);
const size_t ranges_count = 6;
const float ranges[ranges_count][4] = {
@ -295,8 +294,8 @@ void TimeRequantizeManyInNewRange(int64 num_elements, int64 iterations,
}
thread::ThreadPool threadpool(Env::Default(), "test", 4 /* num_threads */);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_device(&wrapper, 4 /* num_threads */);
Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(),
4 /* num_threads */);
Tensor i_tensor =
tensorflow::test::AsTensor(gtl::ArraySlice<InputType>(values_quantized));
@ -607,8 +606,8 @@ void TestRequantizeManyInNewRange32To8Bit() {
void TestRequantizeManyInNewRange32To8BitUsingEigen() {
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(),
2 /* num_threads */);
TestRequantizeManyInNewRange32To8Bit(&eigen_device);
}
@ -638,8 +637,8 @@ void TestFloatTensorToQuantized() {
// FloatToQuantized.
void TestFloatToQuantizedInPlaceUsingEigen() {
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(),
2 /* num_threads */);
TestFloatToQuantizedInPlaceUsingEigen<quint8>(&eigen_device);
TestFloatToQuantizedInPlaceUsingEigen<qint8>(&eigen_device);
@ -649,8 +648,8 @@ void TestFloatToQuantizedInPlaceUsingEigen() {
void TestOverflowWithEigen() {
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(),
2 /* num_threads */);
const int num_vals = 4;
const float input_min = 0.0f;
@ -717,8 +716,8 @@ void TestQuantizedTensorToFloat() {
// QuantizedToFloat.
void TestQuantizedToFloatInPlaceUsingEigen() {
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(),
2 /* num_threads */);
TestQuantizedToFloatInPlaceUsingEigen<quint8>(&eigen_device);
TestQuantizedToFloatInPlaceUsingEigen<qint8>(&eigen_device);

View File

@ -16,7 +16,6 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
@ -218,8 +217,7 @@ TEST_F(QuantizedBatchNormOpTest, SameAsFloat) {
allocator(), DT_FLOAT,
TensorShape({input_batch, input_height, input_width, input_depth}));
thread::ThreadPool threadpool(Env::Default(), "test", 1);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, 1);
Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(), 1);
const Tensor& const_input_float = input_float;
const Tensor& const_mean_float = mean_float;
const Tensor& const_variance_float = variance_float;

View File

@ -210,6 +210,7 @@ void ThreadPool::SetStealPartitions(
}
// Returns this pool's implementation object as an Eigen::ThreadPoolInterface*,
// so it can be passed straight to an Eigen::ThreadPoolDevice without a separate
// wrapper. The pointer is obtained via impl_.get(); NOTE(review): presumably
// impl_ retains ownership and the result must not outlive this ThreadPool --
// confirm against the declaration of impl_.
Eigen::ThreadPoolInterface* ThreadPool::AsEigenThreadPool() {
// Debug-time check that the implementation object exists before exposing it.
DCHECK(impl_ != nullptr);
return impl_.get();
}
} // namespace thread