Remove usages of tensorflow::EigenThreadPoolWrapper class and its clones.
tensorflow::threads::ThreadPool now has AsEigenThreadPool() function that provides an implementation of Eigen thread pool interface. There is no longer need to wrap the thread pool to use inside an Eigen device. PiperOrigin-RevId: 243145027
This commit is contained in:
parent
4e28bb323d
commit
2c017701cc
@ -66,38 +66,16 @@ const absl::optional<std::set<int>>& BackendOptions::allowed_devices() const {
|
||||
return allowed_devices_;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
|
||||
public:
|
||||
explicit EigenThreadPoolWrapper(tensorflow::thread::ThreadPool* pool)
|
||||
: pool_(pool) {}
|
||||
~EigenThreadPoolWrapper() override {}
|
||||
|
||||
void Schedule(std::function<void()> fn) override {
|
||||
pool_->Schedule(std::move(fn));
|
||||
}
|
||||
int NumThreads() const override { return pool_->NumThreads(); }
|
||||
int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
|
||||
|
||||
private:
|
||||
tensorflow::thread::ThreadPool* pool_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Define this in .cc file to avoid having to include eigen or forward declare
|
||||
// these types in the header.
|
||||
struct Backend::IntraOpThreadPool {
|
||||
explicit IntraOpThreadPool(const int num_threads)
|
||||
: pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(),
|
||||
"XLAEigen", num_threads)),
|
||||
wrapper(new EigenThreadPoolWrapper(pool.get())),
|
||||
device(new Eigen::ThreadPoolDevice(wrapper.get(),
|
||||
wrapper->NumThreads())) {}
|
||||
device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(),
|
||||
pool->NumThreads())) {}
|
||||
|
||||
std::unique_ptr<tensorflow::thread::ThreadPool> pool;
|
||||
std::unique_ptr<EigenThreadPoolWrapper> wrapper;
|
||||
std::unique_ptr<Eigen::ThreadPoolDevice> device;
|
||||
};
|
||||
|
||||
|
@ -28,7 +28,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@ -101,8 +100,7 @@ std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a,
|
||||
} else {
|
||||
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
|
||||
2);
|
||||
tensorflow::EigenThreadPoolWrapper tp(&pool);
|
||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||
Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());
|
||||
ExecutableRunOptions run_options;
|
||||
run_options.set_intra_op_thread_pool(&device);
|
||||
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <new>
|
||||
@ -42,7 +43,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
@ -895,8 +895,7 @@ void BM_ParallelFusion(int num_iters) {
|
||||
// Initialize thread pool.
|
||||
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
|
||||
intra_op_parallelism_threads);
|
||||
tensorflow::EigenThreadPoolWrapper tp(&pool);
|
||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||
Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());
|
||||
|
||||
// Initialize ExecutableRunOptions.
|
||||
ExecutableRunOptions options;
|
||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/byte_order.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
@ -108,12 +107,10 @@ struct LocalClientTestBase::EigenThreadPoolWrapper {
|
||||
explicit EigenThreadPoolWrapper()
|
||||
: pool(new tensorflow::thread::ThreadPool(
|
||||
tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)),
|
||||
wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())),
|
||||
device(new Eigen::ThreadPoolDevice(wrapper.get(),
|
||||
wrapper->NumThreads())) {}
|
||||
device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(),
|
||||
pool->NumThreads())) {}
|
||||
|
||||
std::unique_ptr<tensorflow::thread::ThreadPool> pool;
|
||||
std::unique_ptr<tensorflow::EigenThreadPoolWrapper> wrapper;
|
||||
std::unique_ptr<Eigen::ThreadPoolDevice> device;
|
||||
};
|
||||
|
||||
|
@ -47,18 +47,14 @@ class SingleThreadedCpuDevice : public Device {
|
||||
DeviceLocality())) {
|
||||
eigen_worker_threads_.num_threads = kNumThreads;
|
||||
eigen_worker_threads_.workers = GraphRunnerThreadPool();
|
||||
eigen_threadpool_wrapper_.reset(
|
||||
new EigenThreadPoolWrapper(eigen_worker_threads_.workers));
|
||||
eigen_device_.reset(new Eigen::ThreadPoolDevice(
|
||||
eigen_threadpool_wrapper_.get(), eigen_worker_threads_.num_threads));
|
||||
eigen_worker_threads_.workers->AsEigenThreadPool(),
|
||||
eigen_worker_threads_.num_threads));
|
||||
set_tensorflow_cpu_worker_threads(&eigen_worker_threads_);
|
||||
set_eigen_cpu_device(eigen_device_.get());
|
||||
}
|
||||
|
||||
~SingleThreadedCpuDevice() override {
|
||||
eigen_threadpool_wrapper_.reset();
|
||||
eigen_device_.reset();
|
||||
}
|
||||
~SingleThreadedCpuDevice() override { eigen_device_.reset(); }
|
||||
|
||||
Status Sync() override { return Status::OK(); }
|
||||
|
||||
@ -79,7 +75,6 @@ class SingleThreadedCpuDevice : public Device {
|
||||
|
||||
private:
|
||||
DeviceBase::CpuWorkerThreads eigen_worker_threads_;
|
||||
std::unique_ptr<Eigen::ThreadPoolInterface> eigen_threadpool_wrapper_;
|
||||
std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
|
||||
};
|
||||
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@ -29,8 +28,8 @@ namespace tensorflow {
|
||||
TEST(DeviceBaseTest, CpuDevice) {
|
||||
DeviceBase dbase(Env::Default());
|
||||
thread::ThreadPool pool(Env::Default(), "test", 16);
|
||||
EigenThreadPoolWrapper wrapper(&pool);
|
||||
Eigen::ThreadPoolDevice eigen_device(&wrapper, pool.NumThreads());
|
||||
Eigen::ThreadPoolDevice eigen_device(pool.AsEigenThreadPool(),
|
||||
pool.NumThreads());
|
||||
ASSERT_FALSE(dbase.has_eigen_cpu_device());
|
||||
dbase.set_eigen_cpu_device(&eigen_device);
|
||||
ASSERT_TRUE(dbase.has_eigen_cpu_device());
|
||||
|
@ -59,27 +59,6 @@ namespace grappler {
|
||||
using TensorVector = gtl::InlinedVector<TensorValue, 4>;
|
||||
|
||||
namespace {
|
||||
class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
|
||||
public:
|
||||
explicit EigenThreadPoolWrapper(thread::ThreadPool* pool) : pool_(pool) {}
|
||||
~EigenThreadPoolWrapper() override {}
|
||||
void Schedule(std::function<void()> fn) override {
|
||||
auto wrapped = [=]() {
|
||||
// TensorFlow flushes denormals to zero and rounds to nearest, so we do
|
||||
// the same here.
|
||||
port::ScopedFlushDenormal flush;
|
||||
port::ScopedSetRound round(FE_TONEAREST);
|
||||
fn();
|
||||
};
|
||||
pool_->Schedule(std::move(wrapped));
|
||||
}
|
||||
int NumThreads() const override { return pool_->NumThreads(); }
|
||||
int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
|
||||
|
||||
private:
|
||||
thread::ThreadPool* pool_ = nullptr;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
bool AllValuesAre(const TensorProto& proto, const T& value) {
|
||||
Tensor tensor;
|
||||
|
@ -26,44 +26,18 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
using TensorVector = gtl::InlinedVector<TensorValue, 4>;
|
||||
|
||||
namespace {
|
||||
class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
|
||||
public:
|
||||
explicit EigenThreadPoolWrapper(thread::ThreadPool* pool) : pool_(pool) {}
|
||||
~EigenThreadPoolWrapper() override {}
|
||||
void Schedule(std::function<void()> fn) override {
|
||||
auto wrapped = [=]() {
|
||||
// TensorFlow flushes denormals to zero and rounds to nearest, so we do
|
||||
// the same here.
|
||||
port::ScopedFlushDenormal flush;
|
||||
port::ScopedSetRound round(FE_TONEAREST);
|
||||
fn();
|
||||
};
|
||||
pool_->Schedule(std::move(wrapped));
|
||||
}
|
||||
int NumThreads() const override { return pool_->NumThreads(); }
|
||||
int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
|
||||
|
||||
private:
|
||||
thread::ThreadPool* pool_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
DeviceSimple::DeviceSimple() : DeviceBase(Env::Default()) {
|
||||
eigen_worker_threads_.num_threads = port::NumSchedulableCPUs();
|
||||
eigen_worker_threads_.workers = new thread::ThreadPool(
|
||||
Env::Default(), "evaluation_utils", eigen_worker_threads_.num_threads);
|
||||
eigen_threadpool_wrapper_.reset(
|
||||
new EigenThreadPoolWrapper(eigen_worker_threads_.workers));
|
||||
eigen_device_.reset(new Eigen::ThreadPoolDevice(
|
||||
eigen_threadpool_wrapper_.get(), eigen_worker_threads_.num_threads));
|
||||
eigen_worker_threads_.workers->AsEigenThreadPool(),
|
||||
eigen_worker_threads_.num_threads));
|
||||
set_tensorflow_cpu_worker_threads(&eigen_worker_threads_);
|
||||
set_eigen_cpu_device(eigen_device_.get());
|
||||
}
|
||||
|
||||
DeviceSimple::~DeviceSimple() {
|
||||
eigen_threadpool_wrapper_.reset();
|
||||
eigen_device_.reset();
|
||||
delete eigen_worker_threads_.workers;
|
||||
}
|
||||
|
@ -47,7 +47,6 @@ class DeviceSimple : public DeviceBase {
|
||||
|
||||
private:
|
||||
DeviceBase::CpuWorkerThreads eigen_worker_threads_;
|
||||
std::unique_ptr<Eigen::ThreadPoolInterface> eigen_threadpool_wrapper_;
|
||||
std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
|
||||
};
|
||||
|
||||
|
@ -19,6 +19,8 @@ limitations under the License.
|
||||
#define EIGEN_USE_GPU
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
@ -26,10 +28,8 @@ limitations under the License.
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/nn_ops_internal.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
@ -735,8 +735,8 @@ static void BM_LRNFloat(int iters, int depth, int cols, int rows,
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
|
||||
|
||||
thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
|
||||
EigenThreadPoolWrapper wrapper(&threadpool);
|
||||
Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
|
||||
Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(),
|
||||
num_threads);
|
||||
device->set_eigen_cpu_device(&eigen_cpu_device);
|
||||
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
@ -817,8 +817,8 @@ static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth,
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
|
||||
|
||||
thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
|
||||
EigenThreadPoolWrapper wrapper(&threadpool);
|
||||
Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
|
||||
Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(),
|
||||
num_threads);
|
||||
device->set_eigen_cpu_device(&eigen_cpu_device);
|
||||
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
@ -909,8 +909,8 @@ static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols,
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
|
||||
|
||||
thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
|
||||
EigenThreadPoolWrapper wrapper(&threadpool);
|
||||
Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
|
||||
Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(),
|
||||
num_threads);
|
||||
device->set_eigen_cpu_device(&eigen_cpu_device);
|
||||
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
@ -1013,8 +1013,8 @@ static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth,
|
||||
DeviceFactory::NewDevice("CPU", options, "/job:a/replica:0/task:0"));
|
||||
|
||||
thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
|
||||
EigenThreadPoolWrapper wrapper(&threadpool);
|
||||
Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
|
||||
Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(),
|
||||
num_threads);
|
||||
device->set_eigen_cpu_device(&eigen_cpu_device);
|
||||
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
@ -1193,8 +1193,8 @@ static void BM_ReluFloat(int iters, int batch_size, int rows, int cols,
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
|
||||
|
||||
thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
|
||||
EigenThreadPoolWrapper wrapper(&threadpool);
|
||||
Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
|
||||
Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(),
|
||||
num_threads);
|
||||
device->set_eigen_cpu_device(&eigen_cpu_device);
|
||||
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
||||
#include <limits>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
@ -213,8 +212,8 @@ void TestRequantizeManyInNewRange8To32Bit() {
|
||||
template <typename InputType, typename OutputType>
|
||||
void TestRequantizeManyInNewRangeEigenVsNonEigen() {
|
||||
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
|
||||
EigenThreadPoolWrapper wrapper(&threadpool);
|
||||
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
|
||||
Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(),
|
||||
2 /* num_threads */);
|
||||
|
||||
const size_t ranges_count = 6;
|
||||
const float ranges[ranges_count][4] = {
|
||||
@ -295,8 +294,8 @@ void TimeRequantizeManyInNewRange(int64 num_elements, int64 iterations,
|
||||
}
|
||||
|
||||
thread::ThreadPool threadpool(Env::Default(), "test", 4 /* num_threads */);
|
||||
EigenThreadPoolWrapper wrapper(&threadpool);
|
||||
Eigen::ThreadPoolDevice eigen_device(&wrapper, 4 /* num_threads */);
|
||||
Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(),
|
||||
4 /* num_threads */);
|
||||
|
||||
Tensor i_tensor =
|
||||
tensorflow::test::AsTensor(gtl::ArraySlice<InputType>(values_quantized));
|
||||
@ -607,8 +606,8 @@ void TestRequantizeManyInNewRange32To8Bit() {
|
||||
|
||||
void TestRequantizeManyInNewRange32To8BitUsingEigen() {
|
||||
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
|
||||
EigenThreadPoolWrapper wrapper(&threadpool);
|
||||
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
|
||||
Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(),
|
||||
2 /* num_threads */);
|
||||
TestRequantizeManyInNewRange32To8Bit(&eigen_device);
|
||||
}
|
||||
|
||||
@ -638,8 +637,8 @@ void TestFloatTensorToQuantized() {
|
||||
// FloatToQuantized.
|
||||
void TestFloatToQuantizedInPlaceUsingEigen() {
|
||||
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
|
||||
EigenThreadPoolWrapper wrapper(&threadpool);
|
||||
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
|
||||
Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(),
|
||||
2 /* num_threads */);
|
||||
|
||||
TestFloatToQuantizedInPlaceUsingEigen<quint8>(&eigen_device);
|
||||
TestFloatToQuantizedInPlaceUsingEigen<qint8>(&eigen_device);
|
||||
@ -649,8 +648,8 @@ void TestFloatToQuantizedInPlaceUsingEigen() {
|
||||
|
||||
void TestOverflowWithEigen() {
|
||||
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
|
||||
EigenThreadPoolWrapper wrapper(&threadpool);
|
||||
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
|
||||
Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(),
|
||||
2 /* num_threads */);
|
||||
|
||||
const int num_vals = 4;
|
||||
const float input_min = 0.0f;
|
||||
@ -717,8 +716,8 @@ void TestQuantizedTensorToFloat() {
|
||||
// QuantizedToFloat.
|
||||
void TestQuantizedToFloatInPlaceUsingEigen() {
|
||||
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
|
||||
EigenThreadPoolWrapper wrapper(&threadpool);
|
||||
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
|
||||
Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(),
|
||||
2 /* num_threads */);
|
||||
|
||||
TestQuantizedToFloatInPlaceUsingEigen<quint8>(&eigen_device);
|
||||
TestQuantizedToFloatInPlaceUsingEigen<qint8>(&eigen_device);
|
||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -218,8 +217,7 @@ TEST_F(QuantizedBatchNormOpTest, SameAsFloat) {
|
||||
allocator(), DT_FLOAT,
|
||||
TensorShape({input_batch, input_height, input_width, input_depth}));
|
||||
thread::ThreadPool threadpool(Env::Default(), "test", 1);
|
||||
EigenThreadPoolWrapper wrapper(&threadpool);
|
||||
Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, 1);
|
||||
Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(), 1);
|
||||
const Tensor& const_input_float = input_float;
|
||||
const Tensor& const_mean_float = mean_float;
|
||||
const Tensor& const_variance_float = variance_float;
|
||||
|
@ -210,6 +210,7 @@ void ThreadPool::SetStealPartitions(
|
||||
}
|
||||
|
||||
Eigen::ThreadPoolInterface* ThreadPool::AsEigenThreadPool() {
|
||||
DCHECK(impl_ != nullptr);
|
||||
return impl_.get();
|
||||
}
|
||||
} // namespace thread
|
||||
|
Loading…
Reference in New Issue
Block a user