diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc index 746fe63e2a0..1c7afcf8663 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.cc +++ b/tensorflow/core/kernels/crop_and_resize_op.cc @@ -19,6 +19,9 @@ limitations under the License. #include "tensorflow/core/kernels/crop_and_resize_op.h" +#include <functional> +#include <string> + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -26,10 +29,13 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA @@ -37,41 +43,67 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +using Callback = std::function<void()>; -static inline void ParseAndCheckBoxSizes(OpKernelContext* context, - const Tensor& boxes, - const Tensor& box_ind, - int* num_boxes) { - if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) { +namespace { + +static inline Status ParseAndCheckBoxSizes(const Tensor& boxes, + const Tensor& box_index, + int* num_boxes) { + if (boxes.NumElements() == 0 && box_index.NumElements() == 0) { *num_boxes = 0; - return; + return Status::OK(); } // The shape of 'boxes' is [num_boxes, 4]. - OP_REQUIRES(context, boxes.dims() == 2, - errors::InvalidArgument("boxes must be 2-D", - boxes.shape().DebugString())); + if (boxes.dims() != 2) { + return errors::InvalidArgument("boxes must be 2-D", + boxes.shape().DebugString()); + } *num_boxes = boxes.dim_size(0); - OP_REQUIRES(context, boxes.dim_size(1) == 4, - errors::InvalidArgument("boxes must have 4 columns")); - - // The shape of 'box_ind' is [num_boxes]. - OP_REQUIRES(context, box_ind.dims() == 1, - errors::InvalidArgument("box_ind must be 1-D", - box_ind.shape().DebugString())); - OP_REQUIRES(context, box_ind.dim_size(0) == *num_boxes, - errors::InvalidArgument("box_ind has incompatible shape")); + if (boxes.dim_size(1) != 4) { + return errors::InvalidArgument("boxes must have 4 columns"); + } + // The shape of 'box_index' is [num_boxes]. + if (box_index.dims() != 1) { + return errors::InvalidArgument("box_index must be 1-D", + box_index.shape().DebugString()); + } + if (box_index.dim_size(0) != *num_boxes) { + return errors::InvalidArgument("box_index has incompatible shape"); + } + return Status::OK(); } -// Verifies that all values in box_ind are in [0, batch). +// Conditionally calls the compute callback if all values in box_index are in +// [0, batch_size) then calls done. template <typename Device> -inline void CheckValidBoxInd( - OpKernelContext* context, - typename TTypes<int32, 1>::ConstTensor box_ind_data, int batch); +inline void RunIfBoxIndexIsValid( + OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index, + int batch_size, Callback compute, Callback done); + +// Specialization of CheckValidBoxIndex for a CPUDevice. +template <> +inline void RunIfBoxIndexIsValid<CPUDevice>( + OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index, + int batch_size, Callback compute, Callback done) { + const int num_boxes = box_index.dimension(0); + for (int b = 0; b < num_boxes; ++b) { + OP_REQUIRES_ASYNC( + context, FastBoundsCheck(box_index(b), batch_size), + errors::OutOfRange("box_index has values outside [0, batch_size)"), + done); + } + compute(); + done(); +} + +} // namespace template <typename Device, typename T> -class CropAndResizeOp : public OpKernel { +class CropAndResizeOp : public AsyncOpKernel { public: - explicit CropAndResizeOp(OpKernelConstruction* context) : OpKernel(context) { + explicit CropAndResizeOp(OpKernelConstruction* context) + : AsyncOpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", @@ -80,69 +112,77 @@ class CropAndResizeOp : public OpKernel { &extrapolation_value_)); } - void Compute(OpKernelContext* context) override { - // The shape of 'image' is [batch, image_height, image_width, channels]. + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + // The shape of 'image' is [batch_size, image_height, image_width, + // channels]. const Tensor& image = context->input(0); - OP_REQUIRES(context, image.dims() == 4, - errors::InvalidArgument("input image must be 4-D", - image.shape().DebugString())); - - const int batch = image.dim_size(0); - const int image_height = image.dim_size(1); - const int image_width = image.dim_size(2); - const int depth = image.dim_size(3); - OP_REQUIRES(context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive")); - // The shape of 'boxes' is [num_boxes, 4]. const Tensor& boxes = context->input(1); - - // The shape of 'box_ind' is [num_boxes]. - const Tensor& box_ind = context->input(2); - - int num_boxes = 0; - ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); - + // The shape of 'box_index' is [num_boxes]. + const Tensor& box_index = context->input(2); // The shape of 'crop_size' is [2]. const Tensor& crop_size = context->input(3); - OP_REQUIRES(context, crop_size.dims() == 1, - errors::InvalidArgument("crop_size must be 1-D", - crop_size.shape().DebugString())); - OP_REQUIRES(context, crop_size.dim_size(0) == 2, - errors::InvalidArgument("crop_size must have two elements", - crop_size.shape().DebugString())); + // Validate inputs dimensions. + OP_REQUIRES_ASYNC(context, image.dims() == 4, + errors::InvalidArgument("input image must be 4-D", + image.shape().DebugString()), + done); + const int batch_size = image.dim_size(0); + const int image_height = image.dim_size(1); + const int image_width = image.dim_size(2); + const int depth = image.dim_size(3); + OP_REQUIRES_ASYNC( + context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive"), done); + int num_boxes = 0; + OP_REQUIRES_OK_ASYNC( + context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); + OP_REQUIRES_ASYNC(context, crop_size.dims() == 1, + errors::InvalidArgument("crop_size must be 1-D", + crop_size.shape().DebugString()), + done); + OP_REQUIRES_ASYNC( + context, crop_size.dim_size(0) == 2, + errors::InvalidArgument("crop_size must have two elements", + crop_size.shape().DebugString()), + done); + + // Copy and validate crop sizes. auto crop_size_vec = crop_size.vec<int32>(); const int crop_height = internal::SubtleMustCopy(crop_size_vec(0)); const int crop_width = internal::SubtleMustCopy(crop_size_vec(1)); - OP_REQUIRES(context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("crop dimensions must be positive")); + OP_REQUIRES_ASYNC( + context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("crop dimensions must be positive"), done); // Allocate output tensor. Tensor* output = nullptr; - OP_REQUIRES_OK( + OP_REQUIRES_OK_ASYNC( context, context->allocate_output( 0, TensorShape({num_boxes, crop_height, crop_width, depth}), - &output)); + &output), + done); - typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>(); - typename TTypes<float, 2>::ConstTensor boxes_data = - boxes.tensor<float, 2>(); - typename TTypes<int32, 1>::ConstTensor box_ind_data = - box_ind.tensor<int32, 1>(); - typename TTypes<float, 4>::Tensor crops_data = output->tensor<float, 4>(); + auto compute_callback = [this, context, output]() { + const Tensor& image = context->input(0); + const Tensor& boxes = context->input(1); + const Tensor& box_index = context->input(2); + const bool status = functor::CropAndResize<Device, T>()( + context->eigen_device<Device>(), image.tensor<T, 4>(), + boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(), + extrapolation_value_, output->tensor<float, 4>()); + if (!status) { + context->SetStatus( + errors::Internal("Failed launch CropAndResizeKernel.")); + } + }; - CheckValidBoxInd<Device>(context, box_ind_data, batch); - - bool status = functor::CropAndResize<Device, T>()( - context->eigen_device<Device>(), image_data, boxes_data, box_ind_data, - extrapolation_value_, crops_data); - if (!status) { - context->SetStatus( - errors::Internal("Failed launch CropAndResizeKernel.")); - } + RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(), + batch_size, std::move(compute_callback), + std::move(done)); } private: @@ -155,10 +195,10 @@ template <typename T> struct CropAndResize<CPUDevice, T> { bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image, typename TTypes<float, 2>::ConstTensor boxes, - typename TTypes<int32, 1>::ConstTensor box_ind, + typename TTypes<int32, 1>::ConstTensor box_index, float extrapolation_value, typename TTypes<float, 4>::Tensor crops) { - const int batch = image.dimension(0); + const int batch_size = image.dimension(0); const int image_height = image.dimension(1); const int image_width = image.dimension(2); @@ -173,8 +213,8 @@ struct CropAndResize<CPUDevice, T> { const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_ind(b); - if (b_in < 0 || b_in >= batch) { + const int32 b_in = box_index(b); + if (!FastBoundsCheck(b_in, batch_size)) { continue; } @@ -235,89 +275,94 @@ struct CropAndResize<CPUDevice, T> { return true; } }; + } // namespace functor template <typename Device, typename T> -class CropAndResizeGradImageOp : public OpKernel { +class CropAndResizeGradImageOp : public AsyncOpKernel { public: explicit CropAndResizeGradImageOp(OpKernelConstruction* context) - : OpKernel(context) { + : AsyncOpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", errors::InvalidArgument("method must be 'bilinear'", method)); } - void Compute(OpKernelContext* context) override { + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth]. const Tensor& grads = context->input(0); - - OP_REQUIRES(context, grads.dims() == 4, - errors::InvalidArgument("grads image must be 4-D", - grads.shape().DebugString())); - const int crop_height = grads.dim_size(1); - const int crop_width = grads.dim_size(2); - OP_REQUIRES(context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("grads dimensions must be positive")); - // The shape of 'boxes' is [num_boxes, 4]. const Tensor& boxes = context->input(1); - - // The shape of 'box_ind' is [num_boxes]. - const Tensor& box_ind = context->input(2); - - int num_boxes = 0; - ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); - - OP_REQUIRES( - context, grads.dim_size(0) == num_boxes, - errors::InvalidArgument("boxes and grads have incompatible shape")); - + // The shape of 'box_index' is [num_boxes]. + const Tensor& box_index = context->input(2); // The shape of 'image_size' is [4]. const Tensor& image_size = context->input(3); - OP_REQUIRES(context, image_size.dims() == 1, - errors::InvalidArgument("image_size must be 1-D", - image_size.shape().DebugString())); - OP_REQUIRES(context, image_size.dim_size(0) == 4, - errors::InvalidArgument("image_size must have 4 elements", - image_size.shape().DebugString())); + // Validate input shapes. + OP_REQUIRES_ASYNC(context, grads.dims() == 4, + errors::InvalidArgument("grads image must be 4-D", + grads.shape().DebugString()), + done); + const int crop_height = grads.dim_size(1); + const int crop_width = grads.dim_size(2); + OP_REQUIRES_ASYNC( + context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("grads dimensions must be positive"), done); + int num_boxes = 0; + OP_REQUIRES_OK_ASYNC( + context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); + OP_REQUIRES_ASYNC( + context, grads.dim_size(0) == num_boxes, + errors::InvalidArgument("boxes and grads have incompatible shape"), + done); + + OP_REQUIRES_ASYNC(context, image_size.dims() == 1, + errors::InvalidArgument("image_size must be 1-D", + image_size.shape().DebugString()), + done); + OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4, + errors::InvalidArgument("image_size must have 4 elements", + image_size.shape().DebugString()), + done); auto image_size_vec = image_size.vec<int32>(); - const int batch = internal::SubtleMustCopy(image_size_vec(0)); + const int batch_size = internal::SubtleMustCopy(image_size_vec(0)); const int image_height = internal::SubtleMustCopy(image_size_vec(1)); const int image_width = internal::SubtleMustCopy(image_size_vec(2)); const int depth = internal::SubtleMustCopy(image_size_vec(3)); - - OP_REQUIRES(context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive")); - OP_REQUIRES( + OP_REQUIRES_ASYNC( + context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive"), done); + OP_REQUIRES_ASYNC( context, grads.dim_size(3) == depth, - errors::InvalidArgument("image_size and grads are incompatible")); + errors::InvalidArgument("image_size and grads are incompatible"), done); // Allocate output tensor. Tensor* output = nullptr; - OP_REQUIRES_OK( - context, context->allocate_output( - 0, TensorShape({batch, image_height, image_width, depth}), - &output)); + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_output( + 0, TensorShape({batch_size, image_height, image_width, depth}), + &output), + done); - typename TTypes<float, 4>::ConstTensor grads_data = - grads.tensor<float, 4>(); - typename TTypes<float, 2>::ConstTensor boxes_data = - boxes.tensor<float, 2>(); - typename TTypes<int32, 1>::ConstTensor box_ind_data = - box_ind.tensor<int32, 1>(); - typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>(); + auto compute_callback = [context, output]() { + const Tensor& grads = context->input(0); + const Tensor& boxes = context->input(1); + const Tensor& box_index = context->input(2); + const bool status = functor::CropAndResizeBackpropImage<Device, T>()( + context->eigen_device<Device>(), grads.tensor<float, 4>(), + boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(), + output->tensor<T, 4>()); + if (!status) { + context->SetStatus(errors::Internal( + "Failed launch CropAndResizeBackpropImage kernel.")); + } + }; - CheckValidBoxInd<Device>(context, box_ind_data, batch); - - bool status = functor::CropAndResizeBackpropImage<Device, T>()( - context->eigen_device<Device>(), grads_data, boxes_data, box_ind_data, - output_data); - if (!status) { - context->SetStatus( - errors::Internal("Failed launch CropAndResizeBackpropImageKernel.")); - } + RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(), + batch_size, std::move(compute_callback), + std::move(done)); } }; @@ -328,9 +373,9 @@ struct CropAndResizeBackpropImage<CPUDevice, T> { bool operator()(const CPUDevice& d, typename TTypes<float, 4>::ConstTensor grads, typename TTypes<float, 2>::ConstTensor boxes, - typename TTypes<int32, 1>::ConstTensor box_ind, + typename TTypes<int32, 1>::ConstTensor box_index, typename TTypes<T, 4>::Tensor grads_image) { - const int batch = grads_image.dimension(0); + const int batch_size = grads_image.dimension(0); const int image_height = grads_image.dimension(1); const int image_width = grads_image.dimension(2); @@ -347,8 +392,8 @@ struct CropAndResizeBackpropImage<CPUDevice, T> { const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_ind(b); - if (b_in < 0 || b_in >= batch) { + const int32 b_in = box_index(b); + if (!FastBoundsCheck(b_in, batch_size)) { continue; } @@ -399,83 +444,90 @@ struct CropAndResizeBackpropImage<CPUDevice, T> { return true; } }; + } // namespace functor template <typename Device, typename T> -class CropAndResizeGradBoxesOp : public OpKernel { +class CropAndResizeGradBoxesOp : public AsyncOpKernel { public: explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context) - : OpKernel(context) { + : AsyncOpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", errors::InvalidArgument("method must be 'bilinear'", method)); } - void Compute(OpKernelContext* context) override { + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth]. const Tensor& grads = context->input(0); + // The shape of 'boxes' is [num_boxes, 4]. + const Tensor& boxes = context->input(2); + // The shape of 'box_index' is [num_boxes]. + const Tensor& box_index = context->input(3); + // The shape of 'image' is [batch_size, image_height, image_width, depth]. + const Tensor& image = context->input(1); - OP_REQUIRES(context, grads.dims() == 4, - errors::InvalidArgument("grads image must be 4-D", - grads.shape().DebugString())); - + // Validate input shapes. + OP_REQUIRES_ASYNC(context, grads.dims() == 4, + errors::InvalidArgument("grads image must be 4-D", + grads.shape().DebugString()), + done); const int crop_height = grads.dim_size(1); const int crop_width = grads.dim_size(2); const int depth = grads.dim_size(3); - OP_REQUIRES(context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("grads dimensions must be positive")); + OP_REQUIRES_ASYNC( + context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("grads dimensions must be positive"), done); - // The shape of 'image' is [batch, image_height, image_width, depth]. - const Tensor& image = context->input(1); - OP_REQUIRES(context, image.dims() == 4, - errors::InvalidArgument("input image must be 4-D", - image.shape().DebugString())); - - const int batch = image.dim_size(0); + OP_REQUIRES_ASYNC(context, image.dims() == 4, + errors::InvalidArgument("input image must be 4-D", + image.shape().DebugString()), + done); + const int batch_size = image.dim_size(0); const int image_height = image.dim_size(1); const int image_width = image.dim_size(2); - OP_REQUIRES(context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive")); - OP_REQUIRES(context, image.dim_size(3) == depth, - errors::InvalidArgument("image, grads depth differ")); - - // The shape of 'boxes' is [num_boxes, 4]. - const Tensor& boxes = context->input(2); - - // The shape of 'box_ind' is [num_boxes]. - const Tensor& box_ind = context->input(3); + OP_REQUIRES_ASYNC( + context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive"), done); + OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth, + errors::InvalidArgument("image, grads depth differ"), + done); int num_boxes = 0; - ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); + OP_REQUIRES_OK_ASYNC( + context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); - OP_REQUIRES( + OP_REQUIRES_ASYNC( context, grads.dim_size(0) == num_boxes, - errors::InvalidArgument("boxes and grads have incompatible shape")); + errors::InvalidArgument("boxes and grads have incompatible shape"), + done); // Allocate output tensor. Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output( - 0, TensorShape({num_boxes, 4}), &output)); + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_output(0, TensorShape({num_boxes, 4}), &output), + done); - typename TTypes<float, 4>::ConstTensor grads_data = - grads.tensor<float, 4>(); - typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>(); - typename TTypes<float, 2>::ConstTensor boxes_data = - boxes.tensor<float, 2>(); - typename TTypes<int32, 1>::ConstTensor box_ind_data = - box_ind.tensor<int32, 1>(); - typename TTypes<float, 2>::Tensor output_data = output->tensor<float, 2>(); + auto compute_callback = [context, output]() { + const Tensor& grads = context->input(0); + const Tensor& image = context->input(1); + const Tensor& boxes = context->input(2); + const Tensor& box_index = context->input(3); + const bool status = functor::CropAndResizeBackpropBoxes<Device, T>()( + context->eigen_device<Device>(), grads.tensor<float, 4>(), + image.tensor<T, 4>(), boxes.tensor<float, 2>(), + box_index.tensor<int32, 1>(), output->tensor<float, 2>()); + if (!status) { + context->SetStatus(errors::Internal( + "Failed launch CropAndResizeBackpropBoxes kernel.")); + } + }; - CheckValidBoxInd<Device>(context, box_ind_data, batch); - - bool status = functor::CropAndResizeBackpropBoxes<Device, T>()( - context->eigen_device<Device>(), grads_data, image_data, boxes_data, - box_ind_data, output_data); - if (!status) { - context->SetStatus( - errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel.")); - } + RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(), + batch_size, std::move(compute_callback), + std::move(done)); } }; @@ -487,9 +539,9 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> { typename TTypes<float, 4>::ConstTensor grads, typename TTypes<T, 4>::ConstTensor image, typename TTypes<float, 2>::ConstTensor boxes, - typename TTypes<int32, 1>::ConstTensor box_ind, + typename TTypes<int32, 1>::ConstTensor box_index, typename TTypes<float, 2>::Tensor grads_boxes) { - const int batch = image.dimension(0); + const int batch_size = image.dimension(0); const int image_height = image.dimension(1); const int image_width = image.dimension(2); @@ -506,8 +558,8 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> { const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_ind(b); - if (b_in < 0 || b_in >= batch) { + const int32 b_in = box_index(b); + if (!FastBoundsCheck(b_in, batch_size)) { continue; } @@ -589,30 +641,19 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> { return true; } }; + } // namespace functor -// Specialization of CheckValidBoxInd for a CPUDevice. -template <> -inline void CheckValidBoxInd<CPUDevice>( - OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind, - int batch) { - const int num_boxes = box_ind.dimension(0); - for (int b = 0; b < num_boxes; ++b) { - OP_REQUIRES(context, box_ind(b) >= 0 && box_ind(b) < batch, - errors::OutOfRange("box_ind has values outside [0, batch)")); - } -} - -#define REGISTER_KERNEL(T) \ - REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<T>("T") \ - .HostMemory("crop_size"), \ - CropAndResizeOp<CPUDevice, T>); \ - \ - REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<T>("T"), \ +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .HostMemory("crop_size"), \ + CropAndResizeOp<CPUDevice, T>); \ + \ + REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T"), \ CropAndResizeGradBoxesOp<CPUDevice, T>); TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); @@ -634,50 +675,86 @@ TF_CALL_double(REGISTER_KERNEL); #if GOOGLE_CUDA -// Forward declaration of the CheckValidBoxIndHelper specialization for GPU. +// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU. namespace functor { template <> -void CheckValidBoxIndHelper<GPUDevice>::operator()( - const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_ind, - int batch, typename TTypes<bool, 0>::Tensor isvalid); -extern template struct CheckValidBoxIndHelper<GPUDevice>; +void CheckValidBoxIndexHelper<GPUDevice>::operator()( + const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_index, + int batch_size, typename TTypes<bool, 0>::Tensor isvalid); +extern template struct CheckValidBoxIndexHelper<GPUDevice>; } // namespace functor -// Specialization of CheckValidBoxInd for a GPUDevice. +namespace { + +// Specialization of CheckValidBoxIndex for a GPUDevice. template <> -inline void CheckValidBoxInd<GPUDevice>( - OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind, - int batch) { - const int num_boxes = box_ind.dimension(0); +inline void RunIfBoxIndexIsValid<GPUDevice>( + OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index, + int batch_size, Callback compute, Callback done) { + const int num_boxes = box_index.dimension(0); if (num_boxes == 0) { + compute(); + done(); return; } - Tensor isvalid_tensor; - OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<bool>::value, - TensorShape({}), &isvalid_tensor)); - typename TTypes<bool, 0>::Tensor isvalid = isvalid_tensor.tensor<bool, 0>(); + Tensor isvalid_dev_tensor; + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}), + &isvalid_dev_tensor), + done); + typename TTypes<bool, 0>::Tensor isvalid_dev = + isvalid_dev_tensor.tensor<bool, 0>(); - functor::CheckValidBoxIndHelper<GPUDevice>()( - context->eigen_device<GPUDevice>(), box_ind, batch, isvalid); + // Run the actual box check on the device. + functor::CheckValidBoxIndexHelper<GPUDevice>()( + context->eigen_device<GPUDevice>(), box_index, batch_size, isvalid_dev); + // Copy the result back to the host. auto* stream = context->op_device_context()->stream(); - OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + OP_REQUIRES_ASYNC(context, stream, + errors::Internal("No GPU stream available."), done); + Tensor isvalid_host_tensor; + // Use pinned host memory on the host to avoid unnecessary + // synchronization. + AllocatorAttributes alloc_attr; + alloc_attr.set_on_host(true); + alloc_attr.set_gpu_compatible(true); + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}), + &isvalid_host_tensor, alloc_attr), + done); + typename TTypes<bool, 0>::Tensor isvalid_host = + isvalid_host_tensor.tensor<bool, 0>(); - bool isvalid_host = false; - perftools::gputools::DeviceMemoryBase isvalid_gpu(isvalid.data(), - sizeof(bool)); - stream->ThenMemcpy(&isvalid_host, isvalid_gpu, sizeof(bool)); - stream->BlockHostUntilDone(); + perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(), + sizeof(bool)); + const bool status = stream + ->ThenMemcpy(isvalid_host.data() /* destination */, + wrapped /* source */, sizeof(bool)) + .ok(); + OP_REQUIRES_ASYNC( + context, status, + errors::Internal("Failed to launch copy of isvalid from device to host."), + done); - OP_REQUIRES(context, stream->ok(), - errors::Internal("cudaMemcpy from device to host failed")); + auto wrapped_callback = [context, isvalid_host, compute, done]() { + OP_REQUIRES_ASYNC( + context, isvalid_host(), + errors::OutOfRange("box_index has values outside [0, batch_size)"), + done); + compute(); + done(); + }; - OP_REQUIRES(context, isvalid_host, - errors::OutOfRange("box_ind has values outside [0, batch)")); + context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute( + stream, wrapped_callback); } +} // namespace + #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ .Device(DEVICE_GPU) \ diff --git a/tensorflow/core/kernels/crop_and_resize_op.h b/tensorflow/core/kernels/crop_and_resize_op.h index 22df1bdd56b..460dbad22b4 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.h +++ b/tensorflow/core/kernels/crop_and_resize_op.h @@ -53,12 +53,12 @@ struct CropAndResizeBackpropBoxes { }; template <typename Device> -struct CheckValidBoxIndHelper { - // Checks if all values in box_ind are in [0, batch). +struct CheckValidBoxIndexHelper { + // Checks if all values in box_index are in [0, batch). void operator()(const Device& d, - typename TTypes<int32, 1>::ConstTensor box_ind, int batch, + typename TTypes<int32, 1>::ConstTensor box_index, int batch, typename TTypes<bool, 0>::Tensor isvalid) { - isvalid.device(d) = ((box_ind >= 0) && (box_ind < batch)).all(); + isvalid.device(d) = ((box_index >= 0) && (box_index < batch)).all(); } }; diff --git a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc index 254475db465..c1235fda892 100644 --- a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc +++ b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc @@ -440,7 +440,7 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); #undef DEFINE_GPU_SPECS -template struct CheckValidBoxIndHelper<GPUDevice>; +template struct CheckValidBoxIndexHelper<GPUDevice>; } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/crop_and_resize_op_test.cc b/tensorflow/core/kernels/crop_and_resize_op_test.cc index 3a7f180598e..d6139dae966 100644 --- a/tensorflow/core/kernels/crop_and_resize_op_test.cc +++ b/tensorflow/core/kernels/crop_and_resize_op_test.cc @@ -251,7 +251,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) { Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE( - StringPiece(s.ToString()).contains("box_ind has incompatible shape")) + StringPiece(s.ToString()).contains("box_index has incompatible shape")) << s; } @@ -264,8 +264,10 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) { Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE(StringPiece(s.ToString()) - .contains("box_ind has values outside [0, batch)")) + .contains("box_index has values outside [0, batch_size)")) << s; } +// TODO(zhengxq, rmlarsen): Add a benchmark. + } // namespace tensorflow