Add ROCm support for adjust_hue and adjust_saturation op
This commit is contained in:
parent
ef0b1eff8d
commit
9f46967616
@ -14,7 +14,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
@ -141,5 +141,5 @@ __global__ void adjust_hsv_nhwc(const int64 number_elements,
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#endif // TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
|
||||
|
@ -13,7 +13,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -249,7 +249,7 @@ REGISTER_KERNEL_BUILDER(
|
||||
Name("AdjustHue").Device(DEVICE_CPU).TypeConstraint<float>("T"),
|
||||
AdjustHueOp<CPUDevice, float>);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
template <typename T>
|
||||
class AdjustHueOp<GPUDevice, T> : public AdjustHueOpBase {
|
||||
public:
|
||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
@ -37,5 +37,5 @@ struct AdjustHueGPU {
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#endif // TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
|
||||
|
@ -12,7 +12,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
@ -35,10 +35,10 @@ void AdjustHueGPU<T>::operator()(GPUDevice* device,
|
||||
const int threads_per_block = config.thread_per_block;
|
||||
const int block_count =
|
||||
(number_of_elements + threads_per_block - 1) / threads_per_block;
|
||||
TF_CHECK_OK(CudaLaunchKernel(internal::adjust_hsv_nhwc<true, false, false, T>,
|
||||
block_count, threads_per_block, 0, stream,
|
||||
number_of_elements, input, output, delta,
|
||||
nullptr, nullptr));
|
||||
TF_CHECK_OK(GpuLaunchKernel(internal::adjust_hsv_nhwc<true, false, false, T>,
|
||||
block_count, threads_per_block, 0, stream,
|
||||
number_of_elements, input, output, delta,
|
||||
nullptr, nullptr));
|
||||
}
|
||||
|
||||
template struct AdjustHueGPU<float>;
|
||||
@ -46,4 +46,4 @@ template struct AdjustHueGPU<Eigen::half>;
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -215,7 +215,7 @@ REGISTER_KERNEL_BUILDER(
|
||||
Name("AdjustSaturation").Device(DEVICE_CPU).TypeConstraint<float>("T"),
|
||||
AdjustSaturationOp<CPUDevice, float>);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
template <typename T>
|
||||
class AdjustSaturationOp<GPUDevice, T> : public AdjustSaturationOpBase {
|
||||
public:
|
||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
@ -37,5 +37,5 @@ struct AdjustSaturationGPU {
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#endif // TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
|
||||
|
@ -12,7 +12,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
@ -36,10 +36,10 @@ void AdjustSaturationGPU<T>::operator()(GPUDevice* device,
|
||||
const int threads_per_block = config.thread_per_block;
|
||||
const int block_count =
|
||||
(number_of_elements + threads_per_block - 1) / threads_per_block;
|
||||
TF_CHECK_OK(CudaLaunchKernel(internal::adjust_hsv_nhwc<false, true, false, T>,
|
||||
block_count, threads_per_block, 0, stream,
|
||||
number_of_elements, input, output, nullptr,
|
||||
scale, nullptr));
|
||||
TF_CHECK_OK(GpuLaunchKernel(internal::adjust_hsv_nhwc<false, true, false, T>,
|
||||
block_count, threads_per_block, 0, stream,
|
||||
number_of_elements, input, output, nullptr,
|
||||
scale, nullptr));
|
||||
}
|
||||
|
||||
template struct AdjustSaturationGPU<float>;
|
||||
@ -47,4 +47,4 @@ template struct AdjustSaturationGPU<Eigen::half>;
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
Loading…
Reference in New Issue
Block a user