Don't needlessly synchronize the CUDA stream in CropAndResize.

Make the op Async so we don't block an executor thread while waiting for the result of the box bounds check to be copied back to the host.
Change: 154868460
A. Unique TensorFlower authored 2017-05-02 12:03:34 -08:00, committed by TensorFlower Gardener
parent fc407cbcc0
commit 867d407d98
4 changed files with 319 additions and 240 deletions
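
For orientation before the diffs: a minimal sketch of the pattern this change introduces, assuming a hypothetical op (MyCheckedOp) with stand-in helpers LaunchDeviceCheck and DoRealWork; the TensorFlow calls themselves (AsyncOpKernel::ComputeAsync, pinned allocate_temp, Stream::ThenMemcpy, EventMgr::ThenExecute) are the ones the diff below uses. The kernel enqueues the device-side check and the device-to-host copy, then hands the remaining work to the GPU EventMgr instead of calling stream->BlockHostUntilDone().

// A minimal sketch (not the committed code). MyCheckedOp, LaunchDeviceCheck,
// and DoRealWork are hypothetical stand-ins for illustration only.
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

// Hypothetical helpers: a device-side check that writes a bool result, and
// the op's main computation.
void LaunchDeviceCheck(OpKernelContext* context, Tensor* isvalid_dev);
void DoRealWork(OpKernelContext* context);

class MyCheckedOp : public AsyncOpKernel {
 public:
  explicit MyCheckedOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // Launch the device-side check; its bool result stays on the GPU.
    Tensor isvalid_dev;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_temp(DT_BOOL, TensorShape({}), &isvalid_dev), done);
    LaunchDeviceCheck(context, &isvalid_dev);

    // Stage the result in pinned, GPU-compatible host memory so the
    // device->host copy can be enqueued without an implicit sync.
    AllocatorAttributes alloc_attr;
    alloc_attr.set_on_host(true);
    alloc_attr.set_gpu_compatible(true);
    Tensor isvalid_host;
    OP_REQUIRES_OK_ASYNC(context,
                         context->allocate_temp(DT_BOOL, TensorShape({}),
                                                &isvalid_host, alloc_attr),
                         done);

    // Enqueue the async copy on the op's stream.
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(context, stream,
                      errors::Internal("No GPU stream available."), done);
    perftools::gputools::DeviceMemoryBase wrapped(
        isvalid_dev.flat<bool>().data(), sizeof(bool));
    stream->ThenMemcpy(isvalid_host.flat<bool>().data(), wrapped,
                       sizeof(bool));

    // Key point: no stream->BlockHostUntilDone(). The EventMgr runs the
    // continuation once everything enqueued so far has completed, and the
    // executor thread returns immediately.
    auto continuation = [context, isvalid_host, done]() {
      OP_REQUIRES_ASYNC(context, isvalid_host.scalar<bool>()(),
                        errors::OutOfRange("device-side check failed"), done);
      DoRealWork(context);  // the deferred main computation
      done();
    };
    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, std::move(continuation));
  }
};

}  // namespace tensorflow

The three kernels changed below (forward op, image gradient, box gradient) each follow this shape, with a compute_callback lambda playing the role of DoRealWork.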

tensorflow/core/kernels/crop_and_resize_op.cc

@@ -19,6 +19,9 @@ limitations under the License.
#include "tensorflow/core/kernels/crop_and_resize_op.h"
#include <functional>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -26,10 +29,13 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@@ -37,41 +43,67 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
using Callback = std::function<void()>;
static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
const Tensor& boxes,
const Tensor& box_ind,
int* num_boxes) {
if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) {
namespace {
static inline Status ParseAndCheckBoxSizes(const Tensor& boxes,
const Tensor& box_index,
int* num_boxes) {
if (boxes.NumElements() == 0 && box_index.NumElements() == 0) {
*num_boxes = 0;
return;
return Status::OK();
}
// The shape of 'boxes' is [num_boxes, 4].
OP_REQUIRES(context, boxes.dims() == 2,
errors::InvalidArgument("boxes must be 2-D",
boxes.shape().DebugString()));
if (boxes.dims() != 2) {
return errors::InvalidArgument("boxes must be 2-D",
boxes.shape().DebugString());
}
*num_boxes = boxes.dim_size(0);
OP_REQUIRES(context, boxes.dim_size(1) == 4,
errors::InvalidArgument("boxes must have 4 columns"));
// The shape of 'box_ind' is [num_boxes].
OP_REQUIRES(context, box_ind.dims() == 1,
errors::InvalidArgument("box_ind must be 1-D",
box_ind.shape().DebugString()));
OP_REQUIRES(context, box_ind.dim_size(0) == *num_boxes,
errors::InvalidArgument("box_ind has incompatible shape"));
if (boxes.dim_size(1) != 4) {
return errors::InvalidArgument("boxes must have 4 columns");
}
// The shape of 'box_index' is [num_boxes].
if (box_index.dims() != 1) {
return errors::InvalidArgument("box_index must be 1-D",
box_index.shape().DebugString());
}
if (box_index.dim_size(0) != *num_boxes) {
return errors::InvalidArgument("box_index has incompatible shape");
}
return Status::OK();
}
// Verifies that all values in box_ind are in [0, batch).
// Conditionally calls the compute callback if all values in box_index are in
// [0, batch_size) then calls done.
template <typename Device>
inline void CheckValidBoxInd(
OpKernelContext* context,
typename TTypes<int32, 1>::ConstTensor box_ind_data, int batch);
inline void RunIfBoxIndexIsValid(
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
int batch_size, Callback compute, Callback done);
// Specialization of CheckValidBoxIndex for a CPUDevice.
template <>
inline void RunIfBoxIndexIsValid<CPUDevice>(
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
int batch_size, Callback compute, Callback done) {
const int num_boxes = box_index.dimension(0);
for (int b = 0; b < num_boxes; ++b) {
OP_REQUIRES_ASYNC(
context, FastBoundsCheck(box_index(b), batch_size),
errors::OutOfRange("box_index has values outside [0, batch_size)"),
done);
}
compute();
done();
}
} // namespace
template <typename Device, typename T>
class CropAndResizeOp : public OpKernel {
class CropAndResizeOp : public AsyncOpKernel {
public:
explicit CropAndResizeOp(OpKernelConstruction* context) : OpKernel(context) {
explicit CropAndResizeOp(OpKernelConstruction* context)
: AsyncOpKernel(context) {
string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear",
@@ -80,69 +112,77 @@ class CropAndResizeOp : public OpKernel {
&extrapolation_value_));
}
void Compute(OpKernelContext* context) override {
// The shape of 'image' is [batch, image_height, image_width, channels].
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
// The shape of 'image' is [batch_size, image_height, image_width,
// channels].
const Tensor& image = context->input(0);
OP_REQUIRES(context, image.dims() == 4,
errors::InvalidArgument("input image must be 4-D",
image.shape().DebugString()));
const int batch = image.dim_size(0);
const int image_height = image.dim_size(1);
const int image_width = image.dim_size(2);
const int depth = image.dim_size(3);
OP_REQUIRES(context, image_height > 0 && image_width > 0,
errors::InvalidArgument("image dimensions must be positive"));
// The shape of 'boxes' is [num_boxes, 4].
const Tensor& boxes = context->input(1);
// The shape of 'box_ind' is [num_boxes].
const Tensor& box_ind = context->input(2);
int num_boxes = 0;
ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
// The shape of 'box_index' is [num_boxes].
const Tensor& box_index = context->input(2);
// The shape of 'crop_size' is [2].
const Tensor& crop_size = context->input(3);
OP_REQUIRES(context, crop_size.dims() == 1,
errors::InvalidArgument("crop_size must be 1-D",
crop_size.shape().DebugString()));
OP_REQUIRES(context, crop_size.dim_size(0) == 2,
errors::InvalidArgument("crop_size must have two elements",
crop_size.shape().DebugString()));
// Validate inputs dimensions.
OP_REQUIRES_ASYNC(context, image.dims() == 4,
errors::InvalidArgument("input image must be 4-D",
image.shape().DebugString()),
done);
const int batch_size = image.dim_size(0);
const int image_height = image.dim_size(1);
const int image_width = image.dim_size(2);
const int depth = image.dim_size(3);
OP_REQUIRES_ASYNC(
context, image_height > 0 && image_width > 0,
errors::InvalidArgument("image dimensions must be positive"), done);
int num_boxes = 0;
OP_REQUIRES_OK_ASYNC(
context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
OP_REQUIRES_ASYNC(context, crop_size.dims() == 1,
errors::InvalidArgument("crop_size must be 1-D",
crop_size.shape().DebugString()),
done);
OP_REQUIRES_ASYNC(
context, crop_size.dim_size(0) == 2,
errors::InvalidArgument("crop_size must have two elements",
crop_size.shape().DebugString()),
done);
// Copy and validate crop sizes.
auto crop_size_vec = crop_size.vec<int32>();
const int crop_height = internal::SubtleMustCopy(crop_size_vec(0));
const int crop_width = internal::SubtleMustCopy(crop_size_vec(1));
OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
errors::InvalidArgument("crop dimensions must be positive"));
OP_REQUIRES_ASYNC(
context, crop_height > 0 && crop_width > 0,
errors::InvalidArgument("crop dimensions must be positive"), done);
// Allocate output tensor.
Tensor* output = nullptr;
OP_REQUIRES_OK(
OP_REQUIRES_OK_ASYNC(
context,
context->allocate_output(
0, TensorShape({num_boxes, crop_height, crop_width, depth}),
&output));
&output),
done);
typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>();
typename TTypes<float, 2>::ConstTensor boxes_data =
boxes.tensor<float, 2>();
typename TTypes<int32, 1>::ConstTensor box_ind_data =
box_ind.tensor<int32, 1>();
typename TTypes<float, 4>::Tensor crops_data = output->tensor<float, 4>();
auto compute_callback = [this, context, output]() {
const Tensor& image = context->input(0);
const Tensor& boxes = context->input(1);
const Tensor& box_index = context->input(2);
const bool status = functor::CropAndResize<Device, T>()(
context->eigen_device<Device>(), image.tensor<T, 4>(),
boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
extrapolation_value_, output->tensor<float, 4>());
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeKernel."));
}
};
CheckValidBoxInd<Device>(context, box_ind_data, batch);
bool status = functor::CropAndResize<Device, T>()(
context->eigen_device<Device>(), image_data, boxes_data, box_ind_data,
extrapolation_value_, crops_data);
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeKernel."));
}
RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
batch_size, std::move(compute_callback),
std::move(done));
}
private:
@@ -155,10 +195,10 @@ template <typename T>
struct CropAndResize<CPUDevice, T> {
bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<int32, 1>::ConstTensor box_index,
float extrapolation_value,
typename TTypes<float, 4>::Tensor crops) {
const int batch = image.dimension(0);
const int batch_size = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@@ -173,8 +213,8 @@ struct CropAndResize<CPUDevice, T> {
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
const int32 b_in = box_ind(b);
if (b_in < 0 || b_in >= batch) {
const int32 b_in = box_index(b);
if (!FastBoundsCheck(b_in, batch_size)) {
continue;
}
@@ -235,89 +275,94 @@ struct CropAndResize<CPUDevice, T> {
return true;
}
};
} // namespace functor
template <typename Device, typename T>
class CropAndResizeGradImageOp : public OpKernel {
class CropAndResizeGradImageOp : public AsyncOpKernel {
public:
explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
: OpKernel(context) {
: AsyncOpKernel(context) {
string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear",
errors::InvalidArgument("method must be 'bilinear'", method));
}
void Compute(OpKernelContext* context) override {
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0);
OP_REQUIRES(context, grads.dims() == 4,
errors::InvalidArgument("grads image must be 4-D",
grads.shape().DebugString()));
const int crop_height = grads.dim_size(1);
const int crop_width = grads.dim_size(2);
OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
errors::InvalidArgument("grads dimensions must be positive"));
// The shape of 'boxes' is [num_boxes, 4].
const Tensor& boxes = context->input(1);
// The shape of 'box_ind' is [num_boxes].
const Tensor& box_ind = context->input(2);
int num_boxes = 0;
ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
OP_REQUIRES(
context, grads.dim_size(0) == num_boxes,
errors::InvalidArgument("boxes and grads have incompatible shape"));
// The shape of 'box_index' is [num_boxes].
const Tensor& box_index = context->input(2);
// The shape of 'image_size' is [4].
const Tensor& image_size = context->input(3);
OP_REQUIRES(context, image_size.dims() == 1,
errors::InvalidArgument("image_size must be 1-D",
image_size.shape().DebugString()));
OP_REQUIRES(context, image_size.dim_size(0) == 4,
errors::InvalidArgument("image_size must have 4 elements",
image_size.shape().DebugString()));
// Validate input shapes.
OP_REQUIRES_ASYNC(context, grads.dims() == 4,
errors::InvalidArgument("grads image must be 4-D",
grads.shape().DebugString()),
done);
const int crop_height = grads.dim_size(1);
const int crop_width = grads.dim_size(2);
OP_REQUIRES_ASYNC(
context, crop_height > 0 && crop_width > 0,
errors::InvalidArgument("grads dimensions must be positive"), done);
int num_boxes = 0;
OP_REQUIRES_OK_ASYNC(
context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
OP_REQUIRES_ASYNC(
context, grads.dim_size(0) == num_boxes,
errors::InvalidArgument("boxes and grads have incompatible shape"),
done);
OP_REQUIRES_ASYNC(context, image_size.dims() == 1,
errors::InvalidArgument("image_size must be 1-D",
image_size.shape().DebugString()),
done);
OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4,
errors::InvalidArgument("image_size must have 4 elements",
image_size.shape().DebugString()),
done);
auto image_size_vec = image_size.vec<int32>();
const int batch = internal::SubtleMustCopy(image_size_vec(0));
const int batch_size = internal::SubtleMustCopy(image_size_vec(0));
const int image_height = internal::SubtleMustCopy(image_size_vec(1));
const int image_width = internal::SubtleMustCopy(image_size_vec(2));
const int depth = internal::SubtleMustCopy(image_size_vec(3));
OP_REQUIRES(context, image_height > 0 && image_width > 0,
errors::InvalidArgument("image dimensions must be positive"));
OP_REQUIRES(
OP_REQUIRES_ASYNC(
context, image_height > 0 && image_width > 0,
errors::InvalidArgument("image dimensions must be positive"), done);
OP_REQUIRES_ASYNC(
context, grads.dim_size(3) == depth,
errors::InvalidArgument("image_size and grads are incompatible"));
errors::InvalidArgument("image_size and grads are incompatible"), done);
// Allocate output tensor.
Tensor* output = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(
0, TensorShape({batch, image_height, image_width, depth}),
&output));
OP_REQUIRES_OK_ASYNC(
context,
context->allocate_output(
0, TensorShape({batch_size, image_height, image_width, depth}),
&output),
done);
typename TTypes<float, 4>::ConstTensor grads_data =
grads.tensor<float, 4>();
typename TTypes<float, 2>::ConstTensor boxes_data =
boxes.tensor<float, 2>();
typename TTypes<int32, 1>::ConstTensor box_ind_data =
box_ind.tensor<int32, 1>();
typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>();
auto compute_callback = [context, output]() {
const Tensor& grads = context->input(0);
const Tensor& boxes = context->input(1);
const Tensor& box_index = context->input(2);
const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
context->eigen_device<Device>(), grads.tensor<float, 4>(),
boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
output->tensor<T, 4>());
if (!status) {
context->SetStatus(errors::Internal(
"Failed launch CropAndResizeBackpropImage kernel."));
}
};
CheckValidBoxInd<Device>(context, box_ind_data, batch);
bool status = functor::CropAndResizeBackpropImage<Device, T>()(
context->eigen_device<Device>(), grads_data, boxes_data, box_ind_data,
output_data);
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeBackpropImageKernel."));
}
RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
batch_size, std::move(compute_callback),
std::move(done));
}
};
@@ -328,9 +373,9 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
bool operator()(const CPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<int32, 1>::ConstTensor box_index,
typename TTypes<T, 4>::Tensor grads_image) {
const int batch = grads_image.dimension(0);
const int batch_size = grads_image.dimension(0);
const int image_height = grads_image.dimension(1);
const int image_width = grads_image.dimension(2);
@@ -347,8 +392,8 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
const int32 b_in = box_ind(b);
if (b_in < 0 || b_in >= batch) {
const int32 b_in = box_index(b);
if (!FastBoundsCheck(b_in, batch_size)) {
continue;
}
@@ -399,83 +444,90 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
return true;
}
};
} // namespace functor
template <typename Device, typename T>
class CropAndResizeGradBoxesOp : public OpKernel {
class CropAndResizeGradBoxesOp : public AsyncOpKernel {
public:
explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context)
: OpKernel(context) {
: AsyncOpKernel(context) {
string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear",
errors::InvalidArgument("method must be 'bilinear'", method));
}
void Compute(OpKernelContext* context) override {
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0);
// The shape of 'boxes' is [num_boxes, 4].
const Tensor& boxes = context->input(2);
// The shape of 'box_index' is [num_boxes].
const Tensor& box_index = context->input(3);
// The shape of 'image' is [batch_size, image_height, image_width, depth].
const Tensor& image = context->input(1);
OP_REQUIRES(context, grads.dims() == 4,
errors::InvalidArgument("grads image must be 4-D",
grads.shape().DebugString()));
// Validate input shapes.
OP_REQUIRES_ASYNC(context, grads.dims() == 4,
errors::InvalidArgument("grads image must be 4-D",
grads.shape().DebugString()),
done);
const int crop_height = grads.dim_size(1);
const int crop_width = grads.dim_size(2);
const int depth = grads.dim_size(3);
OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
errors::InvalidArgument("grads dimensions must be positive"));
OP_REQUIRES_ASYNC(
context, crop_height > 0 && crop_width > 0,
errors::InvalidArgument("grads dimensions must be positive"), done);
// The shape of 'image' is [batch, image_height, image_width, depth].
const Tensor& image = context->input(1);
OP_REQUIRES(context, image.dims() == 4,
errors::InvalidArgument("input image must be 4-D",
image.shape().DebugString()));
const int batch = image.dim_size(0);
OP_REQUIRES_ASYNC(context, image.dims() == 4,
errors::InvalidArgument("input image must be 4-D",
image.shape().DebugString()),
done);
const int batch_size = image.dim_size(0);
const int image_height = image.dim_size(1);
const int image_width = image.dim_size(2);
OP_REQUIRES(context, image_height > 0 && image_width > 0,
errors::InvalidArgument("image dimensions must be positive"));
OP_REQUIRES(context, image.dim_size(3) == depth,
errors::InvalidArgument("image, grads depth differ"));
// The shape of 'boxes' is [num_boxes, 4].
const Tensor& boxes = context->input(2);
// The shape of 'box_ind' is [num_boxes].
const Tensor& box_ind = context->input(3);
OP_REQUIRES_ASYNC(
context, image_height > 0 && image_width > 0,
errors::InvalidArgument("image dimensions must be positive"), done);
OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth,
errors::InvalidArgument("image, grads depth differ"),
done);
int num_boxes = 0;
ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
OP_REQUIRES_OK_ASYNC(
context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
OP_REQUIRES(
OP_REQUIRES_ASYNC(
context, grads.dim_size(0) == num_boxes,
errors::InvalidArgument("boxes and grads have incompatible shape"));
errors::InvalidArgument("boxes and grads have incompatible shape"),
done);
// Allocate output tensor.
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
0, TensorShape({num_boxes, 4}), &output));
OP_REQUIRES_OK_ASYNC(
context,
context->allocate_output(0, TensorShape({num_boxes, 4}), &output),
done);
typename TTypes<float, 4>::ConstTensor grads_data =
grads.tensor<float, 4>();
typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>();
typename TTypes<float, 2>::ConstTensor boxes_data =
boxes.tensor<float, 2>();
typename TTypes<int32, 1>::ConstTensor box_ind_data =
box_ind.tensor<int32, 1>();
typename TTypes<float, 2>::Tensor output_data = output->tensor<float, 2>();
auto compute_callback = [context, output]() {
const Tensor& grads = context->input(0);
const Tensor& image = context->input(1);
const Tensor& boxes = context->input(2);
const Tensor& box_index = context->input(3);
const bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
context->eigen_device<Device>(), grads.tensor<float, 4>(),
image.tensor<T, 4>(), boxes.tensor<float, 2>(),
box_index.tensor<int32, 1>(), output->tensor<float, 2>());
if (!status) {
context->SetStatus(errors::Internal(
"Failed launch CropAndResizeBackpropBoxes kernel."));
}
};
CheckValidBoxInd<Device>(context, box_ind_data, batch);
bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
context->eigen_device<Device>(), grads_data, image_data, boxes_data,
box_ind_data, output_data);
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel."));
}
RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
batch_size, std::move(compute_callback),
std::move(done));
}
};
@@ -487,9 +539,9 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<int32, 1>::ConstTensor box_index,
typename TTypes<float, 2>::Tensor grads_boxes) {
const int batch = image.dimension(0);
const int batch_size = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@@ -506,8 +558,8 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
const int32 b_in = box_ind(b);
if (b_in < 0 || b_in >= batch) {
const int32 b_in = box_index(b);
if (!FastBoundsCheck(b_in, batch_size)) {
continue;
}
@@ -589,30 +641,19 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
return true;
}
};
} // namespace functor
// Specialization of CheckValidBoxInd for a CPUDevice.
template <>
inline void CheckValidBoxInd<CPUDevice>(
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
int batch) {
const int num_boxes = box_ind.dimension(0);
for (int b = 0; b < num_boxes; ++b) {
OP_REQUIRES(context, box_ind(b) >= 0 && box_ind(b) < batch,
errors::OutOfRange("box_ind has values outside [0, batch)"));
}
}
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.HostMemory("crop_size"), \
CropAndResizeOp<CPUDevice, T>); \
\
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.HostMemory("crop_size"), \
CropAndResizeOp<CPUDevice, T>); \
\
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<CPUDevice, T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
@@ -634,50 +675,86 @@ TF_CALL_double(REGISTER_KERNEL);
#if GOOGLE_CUDA
// Forward declaration of the CheckValidBoxIndHelper specialization for GPU.
// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU.
namespace functor {
template <>
void CheckValidBoxIndHelper<GPUDevice>::operator()(
const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_ind,
int batch, typename TTypes<bool, 0>::Tensor isvalid);
extern template struct CheckValidBoxIndHelper<GPUDevice>;
void CheckValidBoxIndexHelper<GPUDevice>::operator()(
const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_index,
int batch_size, typename TTypes<bool, 0>::Tensor isvalid);
extern template struct CheckValidBoxIndexHelper<GPUDevice>;
} // namespace functor
// Specialization of CheckValidBoxInd for a GPUDevice.
namespace {
// Specialization of CheckValidBoxIndex for a GPUDevice.
template <>
inline void CheckValidBoxInd<GPUDevice>(
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
int batch) {
const int num_boxes = box_ind.dimension(0);
inline void RunIfBoxIndexIsValid<GPUDevice>(
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
int batch_size, Callback compute, Callback done) {
const int num_boxes = box_index.dimension(0);
if (num_boxes == 0) {
compute();
done();
return;
}
Tensor isvalid_tensor;
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<bool>::value,
TensorShape({}), &isvalid_tensor));
typename TTypes<bool, 0>::Tensor isvalid = isvalid_tensor.tensor<bool, 0>();
Tensor isvalid_dev_tensor;
OP_REQUIRES_OK_ASYNC(
context,
context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
&isvalid_dev_tensor),
done);
typename TTypes<bool, 0>::Tensor isvalid_dev =
isvalid_dev_tensor.tensor<bool, 0>();
functor::CheckValidBoxIndHelper<GPUDevice>()(
context->eigen_device<GPUDevice>(), box_ind, batch, isvalid);
// Run the actual box check on the device.
functor::CheckValidBoxIndexHelper<GPUDevice>()(
context->eigen_device<GPUDevice>(), box_index, batch_size, isvalid_dev);
// Copy the result back to the host.
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
OP_REQUIRES_ASYNC(context, stream,
errors::Internal("No GPU stream available."), done);
Tensor isvalid_host_tensor;
// Use pinned host memory on the host to avoid unnecessary
// synchronization.
AllocatorAttributes alloc_attr;
alloc_attr.set_on_host(true);
alloc_attr.set_gpu_compatible(true);
OP_REQUIRES_OK_ASYNC(
context,
context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
&isvalid_host_tensor, alloc_attr),
done);
typename TTypes<bool, 0>::Tensor isvalid_host =
isvalid_host_tensor.tensor<bool, 0>();
bool isvalid_host = false;
perftools::gputools::DeviceMemoryBase isvalid_gpu(isvalid.data(),
sizeof(bool));
stream->ThenMemcpy(&isvalid_host, isvalid_gpu, sizeof(bool));
stream->BlockHostUntilDone();
perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(),
sizeof(bool));
const bool status = stream
->ThenMemcpy(isvalid_host.data() /* destination */,
wrapped /* source */, sizeof(bool))
.ok();
OP_REQUIRES_ASYNC(
context, status,
errors::Internal("Failed to launch copy of isvalid from device to host."),
done);
OP_REQUIRES(context, stream->ok(),
errors::Internal("cudaMemcpy from device to host failed"));
auto wrapped_callback = [context, isvalid_host, compute, done]() {
OP_REQUIRES_ASYNC(
context, isvalid_host(),
errors::OutOfRange("box_index has values outside [0, batch_size)"),
done);
compute();
done();
};
OP_REQUIRES(context, isvalid_host,
errors::OutOfRange("box_ind has values outside [0, batch)"));
context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
stream, wrapped_callback);
}
} // namespace
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
.Device(DEVICE_GPU) \

tensorflow/core/kernels/crop_and_resize_op.h

@@ -53,12 +53,12 @@ struct CropAndResizeBackpropBoxes {
};
template <typename Device>
struct CheckValidBoxIndHelper {
// Checks if all values in box_ind are in [0, batch).
struct CheckValidBoxIndexHelper {
// Checks if all values in box_index are in [0, batch).
void operator()(const Device& d,
typename TTypes<int32, 1>::ConstTensor box_ind, int batch,
typename TTypes<int32, 1>::ConstTensor box_index, int batch,
typename TTypes<bool, 0>::Tensor isvalid) {
isvalid.device(d) = ((box_ind >= 0) && (box_ind < batch)).all();
isvalid.device(d) = ((box_index >= 0) && (box_index < batch)).all();
}
};
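
The validity check is a single Eigen expression, which is why one functor serves both CPU and GPU devices. As a freestanding illustration (plain Eigen, no TensorFlow), this is the reduction it performs; the values here are made up:

#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // box_index = [0, 2, 1, 5] against a batch of size 4: 5 is out of range.
  Eigen::Tensor<int, 1> box_index(4);
  box_index.setValues({0, 2, 1, 5});
  const int batch = 4;

  // Same expression as CheckValidBoxIndexHelper: a 0-D bool reduction that
  // is true iff every index lies in [0, batch).
  Eigen::Tensor<bool, 0> isvalid = (box_index >= 0 && box_index < batch).all();
  std::cout << (isvalid() ? "valid" : "invalid") << std::endl;  // "invalid"
  return 0;
}

On a GpuDevice the same expression compiles to a device kernel whose 0-D result lives in GPU memory, which is why the GPU path above needs the asynchronous device-to-host copy.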

tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc

@@ -440,7 +440,7 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
#undef DEFINE_GPU_SPECS
template struct CheckValidBoxIndHelper<GPUDevice>;
template struct CheckValidBoxIndexHelper<GPUDevice>;
} // namespace functor
} // namespace tensorflow

tensorflow/core/kernels/crop_and_resize_op_test.cc

@@ -251,7 +251,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
Status s = RunOpKernel();
ASSERT_FALSE(s.ok());
EXPECT_TRUE(
StringPiece(s.ToString()).contains("box_ind has incompatible shape"))
StringPiece(s.ToString()).contains("box_index has incompatible shape"))
<< s;
}
@@ -264,8 +264,10 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
Status s = RunOpKernel();
ASSERT_FALSE(s.ok());
EXPECT_TRUE(StringPiece(s.ToString())
.contains("box_ind has values outside [0, batch)"))
.contains("box_index has values outside [0, batch_size)"))
<< s;
}
// TODO(zhengxq, rmlarsen): Add a benchmark.
} // namespace tensorflow