remove ParallelExecute and use d.parallelFor instead of Shard

leslie-fang-intel 2019-11-20 20:30:38 +08:00
parent 126a6f7879
commit 2002d5e283
3 changed files with 10 additions and 24 deletions
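
The change in a nutshell: instead of sharding the output index range through TensorFlow's Shard() helper (which needs the OpKernelContext to reach the CPU worker threads), the kernel now hands the range straight to the Eigen ThreadPoolDevice via d.parallelFor(), with an Eigen::TensorOpCost describing the per-element work. A minimal standalone sketch of that call shape (plain Eigen, made-up sizes and work, not the TensorFlow kernel itself):

// Standalone sketch only: plain Eigen, not the TensorFlow resize kernel.
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>

#include <cstdio>
#include <vector>

int main() {
  Eigen::ThreadPool pool(4);            // backing thread pool (assumed 4 threads)
  Eigen::ThreadPoolDevice d(&pool, 4);  // the kind of device a CPU kernel receives

  const Eigen::Index N = 1 << 20;
  std::vector<float> out(N, 0.0f);

  // Same callable shape as the kernel's ParallelResize lambda:
  // process the half-open index range [start, end).
  auto work = [&out](Eigen::Index start, Eigen::Index end) {
    for (Eigen::Index i = start; i < end; ++i) {
      out[i] = 0.5f * static_cast<float>(i);
    }
  };

  // Eigen picks block sizes and thread fan-out from the per-element cost
  // estimate (bytes loaded, bytes stored, compute cycles) and N.
  d.parallelFor(N, Eigen::TensorOpCost(0, 0, 1000.0), work);

  std::printf("out[123] = %f\n", out[123]);
  return 0;
}

Because the Eigen device already carries the thread pool, the CPU functor no longer needs the OpKernelContext parameter, which is why the signatures in the diffs below drop it.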


@@ -69,13 +69,13 @@ class ResizeNearestNeighborOp : public OpKernel {
                                                 /*half_pixe_centers=*/true,
                                                 /*align_corners=*/true>()(
             context->eigen_device<Device>(), input_data, st.height_scale,
-            st.width_scale, output_data, context);
+            st.width_scale, output_data);
       } else {
         status = functor::ResizeNearestNeighbor<Device, T,
                                                 /*half_pixe_centers=*/true,
                                                 /*align_corners=*/false>()(
             context->eigen_device<Device>(), input_data, st.height_scale,
-            st.width_scale, output_data, context);
+            st.width_scale, output_data);
       }
     } else {
       if (align_corners_) {
@@ -83,13 +83,13 @@ class ResizeNearestNeighborOp : public OpKernel {
                                                 /*half_pixe_centers=*/false,
                                                 /*align_corners=*/true>()(
             context->eigen_device<Device>(), input_data, st.height_scale,
-            st.width_scale, output_data, context);
+            st.width_scale, output_data);
       } else {
         status = functor::ResizeNearestNeighbor<Device, T,
                                                 /*half_pixe_centers=*/false,
                                                 /*align_corners=*/false>()(
             context->eigen_device<Device>(), input_data, st.height_scale,
-            st.width_scale, output_data, context);
+            st.width_scale, output_data);
       }
     }
     if (!status) {
@@ -131,13 +131,9 @@ struct BoolToScaler<false> {
 namespace functor {
 template <typename T, bool half_pixel_centers, bool align_corners>
 struct ResizeNearestNeighbor<CPUDevice, T, half_pixel_centers, align_corners> {
-  bool ParallelExecute(const CPUDevice& d,
-                       typename TTypes<T, 4>::ConstTensor input,
-                       const float height_scale, const float width_scale,
-                       typename TTypes<T, 4>::Tensor output,
-                       OpKernelContext* c) {
-    const DeviceBase::CpuWorkerThreads& worker_threads =
-        *(c->device()->tensorflow_cpu_worker_threads());
+  bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
+                  const float height_scale, const float width_scale,
+                  typename TTypes<T, 4>::Tensor output) {
     const Eigen::Index batch_size = input.dimension(0);
     const Eigen::Index in_height = input.dimension(1);
     const Eigen::Index in_width = input.dimension(2);
@@ -170,18 +166,9 @@ struct ResizeNearestNeighbor<CPUDevice, T, half_pixel_centers, align_corners> {
       }
     };
     Eigen::Index N = batch_size * out_height * out_width;
-    Shard(worker_threads.num_threads, worker_threads.workers, N, 1000.0,
-          ParallelResize);
+    d.parallelFor(N, Eigen::TensorOpCost(0, 0, 1000.0), ParallelResize);
+    // TODO: Come up with a good cost estimate:
+    // 3500:26~27fps, 1000:27~28fps.
     return true;
   }
-  bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
-                  const float height_scale, const float width_scale,
-                  typename TTypes<T, 4>::Tensor output,
-                  OpKernelContext* context) {
-    return ParallelExecute(d, input, height_scale, width_scale, output,
-                           context);
-  }
 };
 }  // namespace functor
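
On the cost estimate in the TODO above: Eigen::TensorOpCost(bytes_loaded, bytes_stored, compute_cycles) is what parallelFor uses to choose block sizes and how many threads to fan out to, so the third argument trades scheduling overhead against load balance. A rough, self-contained way to probe different estimates (arbitrary workload and sizes, not the resize kernel) is to time the same dispatch under a few values:

// Hedged micro-benchmark sketch, not TensorFlow code; the workload is a stand-in.
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>

#include <chrono>
#include <cmath>
#include <cstdio>
#include <vector>

// Times one parallelFor dispatch of the same work under a given
// compute-cycles-per-element estimate (bytes loaded/stored left at 0,
// mirroring the TensorOpCost(0, 0, 1000.0) call in the diff above).
static double TimeParallelFor(const Eigen::ThreadPoolDevice& d, Eigen::Index n,
                              double cycles_per_element, std::vector<float>& out) {
  auto work = [&out](Eigen::Index start, Eigen::Index end) {
    for (Eigen::Index i = start; i < end; ++i) {
      out[i] = std::sqrt(static_cast<float>(i));
    }
  };
  const auto t0 = std::chrono::steady_clock::now();
  d.parallelFor(n, Eigen::TensorOpCost(0, 0, cycles_per_element), work);
  const auto t1 = std::chrono::steady_clock::now();
  return std::chrono::duration<double, std::milli>(t1 - t0).count();
}

int main() {
  Eigen::ThreadPool pool(4);
  Eigen::ThreadPoolDevice d(&pool, 4);
  const Eigen::Index n = 1 << 22;
  std::vector<float> out(n);
  // 1000 and 3500 are the two estimates mentioned in the TODO; 100 is an extra point.
  for (double cost : {100.0, 1000.0, 3500.0}) {
    std::printf("cost=%.0f -> %.2f ms\n", cost, TimeParallelFor(d, n, cost, out));
  }
  return 0;
}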


@@ -28,8 +28,7 @@ template <typename Device, typename T, bool half_pixel_centers,
 struct ResizeNearestNeighbor {
   bool operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
                   const float height_scale, const float width_scale,
-                  typename TTypes<T, 4>::Tensor output,
-                  OpKernelContext* context = NULL);
+                  typename TTypes<T, 4>::Tensor output);
 };
 template <typename Device, typename T, bool half_pixel_centers,


@@ -51,7 +51,7 @@ static Graph* BM_Resize(const char* algorithm, int batches, int width,
 BM_ResizeDev(cpu, ResizeNearestNeighbor, 10, 499, 499);
 BM_ResizeDev(cpu, ResizeBilinear, 10, 499, 499);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_ResizeDev(gpu, ResizeNearestNeighbor, 10, 499, 499);
 BM_ResizeDev(gpu, ResizeBilinear, 10, 499, 499);
 #endif