Add ROCm support for adjust_hue and adjust_saturation op
This commit is contained in:
parent
ef0b1eff8d
commit
9f46967616
@ -14,7 +14,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
|
#define TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
@ -141,5 +141,5 @@ __global__ void adjust_hsv_nhwc(const int64 number_elements,
|
|||||||
} // namespace internal
|
} // namespace internal
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
|
#endif // TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
|
||||||
|
@ -13,7 +13,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#define EIGEN_USE_THREADS
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -249,7 +249,7 @@ REGISTER_KERNEL_BUILDER(
|
|||||||
Name("AdjustHue").Device(DEVICE_CPU).TypeConstraint<float>("T"),
|
Name("AdjustHue").Device(DEVICE_CPU).TypeConstraint<float>("T"),
|
||||||
AdjustHueOp<CPUDevice, float>);
|
AdjustHueOp<CPUDevice, float>);
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class AdjustHueOp<GPUDevice, T> : public AdjustHueOpBase {
|
class AdjustHueOp<GPUDevice, T> : public AdjustHueOpBase {
|
||||||
public:
|
public:
|
||||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
|
#define TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
@ -37,5 +37,5 @@ struct AdjustHueGPU {
|
|||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
|
#endif // TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
|
||||||
|
@ -12,7 +12,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
@ -35,10 +35,10 @@ void AdjustHueGPU<T>::operator()(GPUDevice* device,
|
|||||||
const int threads_per_block = config.thread_per_block;
|
const int threads_per_block = config.thread_per_block;
|
||||||
const int block_count =
|
const int block_count =
|
||||||
(number_of_elements + threads_per_block - 1) / threads_per_block;
|
(number_of_elements + threads_per_block - 1) / threads_per_block;
|
||||||
TF_CHECK_OK(CudaLaunchKernel(internal::adjust_hsv_nhwc<true, false, false, T>,
|
TF_CHECK_OK(GpuLaunchKernel(internal::adjust_hsv_nhwc<true, false, false, T>,
|
||||||
block_count, threads_per_block, 0, stream,
|
block_count, threads_per_block, 0, stream,
|
||||||
number_of_elements, input, output, delta,
|
number_of_elements, input, output, delta,
|
||||||
nullptr, nullptr));
|
nullptr, nullptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
template struct AdjustHueGPU<float>;
|
template struct AdjustHueGPU<float>;
|
||||||
@ -46,4 +46,4 @@ template struct AdjustHueGPU<Eigen::half>;
|
|||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#define EIGEN_USE_THREADS
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -215,7 +215,7 @@ REGISTER_KERNEL_BUILDER(
|
|||||||
Name("AdjustSaturation").Device(DEVICE_CPU).TypeConstraint<float>("T"),
|
Name("AdjustSaturation").Device(DEVICE_CPU).TypeConstraint<float>("T"),
|
||||||
AdjustSaturationOp<CPUDevice, float>);
|
AdjustSaturationOp<CPUDevice, float>);
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class AdjustSaturationOp<GPUDevice, T> : public AdjustSaturationOpBase {
|
class AdjustSaturationOp<GPUDevice, T> : public AdjustSaturationOpBase {
|
||||||
public:
|
public:
|
||||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
|
#define TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
@ -37,5 +37,5 @@ struct AdjustSaturationGPU {
|
|||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
|
#endif // TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
|
||||||
|
@ -12,7 +12,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
@ -36,10 +36,10 @@ void AdjustSaturationGPU<T>::operator()(GPUDevice* device,
|
|||||||
const int threads_per_block = config.thread_per_block;
|
const int threads_per_block = config.thread_per_block;
|
||||||
const int block_count =
|
const int block_count =
|
||||||
(number_of_elements + threads_per_block - 1) / threads_per_block;
|
(number_of_elements + threads_per_block - 1) / threads_per_block;
|
||||||
TF_CHECK_OK(CudaLaunchKernel(internal::adjust_hsv_nhwc<false, true, false, T>,
|
TF_CHECK_OK(GpuLaunchKernel(internal::adjust_hsv_nhwc<false, true, false, T>,
|
||||||
block_count, threads_per_block, 0, stream,
|
block_count, threads_per_block, 0, stream,
|
||||||
number_of_elements, input, output, nullptr,
|
number_of_elements, input, output, nullptr,
|
||||||
scale, nullptr));
|
scale, nullptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
template struct AdjustSaturationGPU<float>;
|
template struct AdjustSaturationGPU<float>;
|
||||||
@ -47,4 +47,4 @@ template struct AdjustSaturationGPU<Eigen::half>;
|
|||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
Loading…
Reference in New Issue
Block a user