Added ROCm support for the fake_quant ops

This commit is contained in:
Jeffrey Poznanovic 2019-03-19 17:57:19 +00:00
parent 5ae4902198
commit c9f8da889b
2 changed files with 12 additions and 12 deletions

View File

@@ -15,9 +15,9 @@ limitations under the License.
 #define EIGEN_USE_THREADS
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/fake_quant_ops_functor.h"
@@ -28,9 +28,9 @@ limitations under the License.
 using tensorflow::BinaryElementWiseOp;
 using tensorflow::DEVICE_CPU;
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 using tensorflow::DEVICE_GPU;
-#endif
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 using tensorflow::OpKernel;
 using tensorflow::OpKernelConstruction;
 using tensorflow::OpKernelContext;
@@ -143,7 +143,7 @@ REGISTER_KERNEL_BUILDER(
     Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_CPU),
     FakeQuantWithMinMaxArgsGradientOp<CPUDevice>);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 typedef Eigen::GpuDevice GPUDevice;
 // Forward declarations for functor specializations for GPU.
@@ -165,7 +165,7 @@ void FakeQuantWithMinMaxArgsGradientFunctor<GPUDevice>::operator()(
 REGISTER_KERNEL_BUILDER(
     Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_GPU),
     FakeQuantWithMinMaxArgsGradientOp<GPUDevice>);
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // -----------------------------------------------------------------------------
 // Implementation of FakeQuantWithMinMaxVarsOp, see its documentation in
@@ -265,7 +265,7 @@ REGISTER_KERNEL_BUILDER(
     Name("FakeQuantWithMinMaxVarsGradient").Device(DEVICE_CPU),
     FakeQuantWithMinMaxVarsGradientOp<CPUDevice>);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 template <>
 void FakeQuantWithMinMaxVarsFunctor<GPUDevice>::operator()(
     const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
@@ -294,7 +294,7 @@ REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsGradient")
                             .HostMemory("min")
                             .HostMemory("max"),
                         FakeQuantWithMinMaxVarsGradientOp<GPUDevice>);
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // -----------------------------------------------------------------------------
 // Implementation of FakeQuantWithMinMaxVarsPerChannelOp, see its documentation
@@ -411,7 +411,7 @@ REGISTER_KERNEL_BUILDER(
     Name("FakeQuantWithMinMaxVarsPerChannelGradient").Device(DEVICE_CPU),
     FakeQuantWithMinMaxVarsPerChannelGradientOp<CPUDevice>);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 template <>
 void FakeQuantWithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
     const GPUDevice& d, typename TTypes<float>::ConstMatrix inputs,
@@ -443,6 +443,6 @@ REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannelGradient")
                             .HostMemory("min")
                             .HostMemory("max"),
                         FakeQuantWithMinMaxVarsPerChannelGradientOp<GPUDevice>);
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 } // namespace tensorflow

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define FAKE_QUANT_NO_DEBUG
@@ -34,4 +34,4 @@ template struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor<GPUDevice>;
 } // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM