Added ROCm support for the fake_quant ops
commit c9f8da889b
parent 5ae4902198
tensorflow/core/kernels/fake_quant_ops.cc

@@ -15,9 +15,9 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #include "tensorflow/core/kernels/fake_quant_ops_functor.h"
 
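Note why this hunk also changes `#ifdef` to `#if`: `#ifdef` tests only whether a single macro is defined and cannot express an OR of two build flags, while `#if` evaluates an undefined identifier as 0, so one expression covers CPU-only, CUDA (GOOGLE_CUDA=1), and ROCm (TENSORFLOW_USE_ROCM=1) builds. A minimal standalone sketch of that behavior (hypothetical file, not part of the commit):

// guard_sketch.cc -- hypothetical demo of the guard this commit adopts.
// Build three ways to see each branch:
//   c++ guard_sketch.cc                          -> CPU only
//   c++ -DGOOGLE_CUDA=1 guard_sketch.cc          -> GPU path
//   c++ -DTENSORFLOW_USE_ROCM=1 guard_sketch.cc  -> GPU path
#include <cstdio>

// Undefined identifiers evaluate to 0 inside #if, so this expression is
// false unless the build defines one of the two flags to a nonzero value.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU  // same side effect the real preamble wants
static const char* kFlavor = "GPU (CUDA or ROCm)";
#else
static const char* kFlavor = "CPU only";
#endif

int main() {
  std::printf("compiled for: %s\n", kFlavor);
  return 0;
}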
@@ -28,9 +28,9 @@ limitations under the License.
 
 using tensorflow::BinaryElementWiseOp;
 using tensorflow::DEVICE_CPU;
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 using tensorflow::DEVICE_GPU;
-#endif
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 using tensorflow::OpKernel;
 using tensorflow::OpKernelConstruction;
 using tensorflow::OpKernelContext;
@@ -143,7 +143,7 @@ REGISTER_KERNEL_BUILDER(
     Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_CPU),
     FakeQuantWithMinMaxArgsGradientOp<CPUDevice>);
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 typedef Eigen::GpuDevice GPUDevice;
 
 // Forward declarations for functor specializations for GPU.
@@ -165,7 +165,7 @@ void FakeQuantWithMinMaxArgsGradientFunctor<GPUDevice>::operator()(
 REGISTER_KERNEL_BUILDER(
     Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_GPU),
     FakeQuantWithMinMaxArgsGradientOp<GPUDevice>);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 // -----------------------------------------------------------------------------
 // Implementation of FakeQuantWithMinMaxVarsOp, see its documentation in
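These two hunks bracket the pattern that repeats for each op in this file: under the guard, the GPU specialization of the functor is only forward-declared (its definition is explicitly instantiated in the separately compiled _gpu.cu.cc unit, which goes through the device compiler: nvcc for CUDA or, after this commit, hipcc for ROCm), and the op is then registered for DEVICE_GPU. A self-contained mimic of that declaration/instantiation split, with hypothetical names:

// split_sketch.cc -- hypothetical single-file mimic; in TensorFlow the two
// marked halves live in different translation units.
#include <cstdio>

struct GPUDevice {};  // stand-in for Eigen::GpuDevice

template <typename Device>
struct ClampFunctor {  // stand-in for a FakeQuant*Functor
  void operator()(float* x, int n, float lo, float hi) {
    for (int i = 0; i < n; ++i) x[i] = x[i] < lo ? lo : (x[i] > hi ? hi : x[i]);
  }
};

// "fake_quant_ops.cc" side: a declaration only, so no GPU code has to be
// generated in the host-compiled unit.
extern template struct ClampFunctor<GPUDevice>;

// "fake_quant_ops_gpu.cu.cc" side: the explicit instantiation, emitted in
// the unit the device compiler actually builds.
template struct ClampFunctor<GPUDevice>;

int main() {
  float v[4] = {-2.f, 0.25f, 0.5f, 9.f};
  ClampFunctor<GPUDevice>()(v, 4, 0.f, 1.f);
  std::printf("%g %g %g %g\n", v[0], v[1], v[2], v[3]);
  return 0;
}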
@@ -265,7 +265,7 @@ REGISTER_KERNEL_BUILDER(
     Name("FakeQuantWithMinMaxVarsGradient").Device(DEVICE_CPU),
     FakeQuantWithMinMaxVarsGradientOp<CPUDevice>);
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 template <>
 void FakeQuantWithMinMaxVarsFunctor<GPUDevice>::operator()(
     const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
@@ -294,7 +294,7 @@ REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsGradient")
                             .HostMemory("min")
                             .HostMemory("max"),
                         FakeQuantWithMinMaxVarsGradientOp<GPUDevice>);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 // -----------------------------------------------------------------------------
 // Implementation of FakeQuantWithMinMaxVarsPerChannelOp, see its documentation
@@ -411,7 +411,7 @@ REGISTER_KERNEL_BUILDER(
     Name("FakeQuantWithMinMaxVarsPerChannelGradient").Device(DEVICE_CPU),
     FakeQuantWithMinMaxVarsPerChannelGradientOp<CPUDevice>);
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 template <>
 void FakeQuantWithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
     const GPUDevice& d, typename TTypes<float>::ConstMatrix inputs,
@@ -443,6 +443,6 @@ REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannelGradient")
                             .HostMemory("min")
                             .HostMemory("max"),
                         FakeQuantWithMinMaxVarsPerChannelGradientOp<GPUDevice>);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
tensorflow/core/kernels/fake_quant_ops_gpu.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define FAKE_QUANT_NO_DEBUG
 
@@ -34,4 +34,4 @@ template struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor<GPUDevice>;
 
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
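The _gpu.cu.cc unit wraps its entire contents in the same guard, so a CPU-only build compiles it to an empty translation unit, and it defines FAKE_QUANT_NO_DEBUG (assumed here, from the name, to compile out debug-only checks in the functors) before pulling in the functor header, then explicitly instantiates each GPU functor. A hypothetical, self-contained mirror of that layout; the DEBUG_CHECK macro and functor names are illustrative, not TensorFlow's:

// gpu_unit_sketch.cc -- hypothetical mirror of how the _gpu.cu.cc unit is
// laid out; builds and runs with any host C++ compiler.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define FAKE_QUANT_NO_DEBUG  // assumed: strips debug-only checks on GPU
#endif

#include <cstdio>

#ifdef FAKE_QUANT_NO_DEBUG
#define DEBUG_CHECK(cond) ((void)0)  // compiled out, as in the GPU build
#else
#define DEBUG_CHECK(cond) \
  do { if (!(cond)) std::printf("debug check failed: %s\n", #cond); } while (0)
#endif

struct GPUDevice {};

template <typename Device>
struct SquareFunctor {
  void operator()(float* x, int n) {
    DEBUG_CHECK(n >= 0);
    for (int i = 0; i < n; ++i) x[i] *= x[i];
  }
};

// Explicit instantiation, one line per functor in the real file.
template struct SquareFunctor<GPUDevice>;

int main() {
  float v[3] = {1.f, 2.f, 3.f};
  SquareFunctor<GPUDevice>()(v, 3);
  std::printf("%g %g %g\n", v[0], v[1], v[2]);
  return 0;
}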