diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 4a78e5cb909..d1d3de9041d 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -431,7 +431,10 @@ def register_custom_call_target(name, fn, platform='cpu'): fn: a PyCapsule object containing the function pointer. platform: the target platform. """ - _xla.register_custom_call_target(name, fn, xla_platform_names[platform]) + # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM" + # Since that is hardcoded to CUDA, we are using the following as workaround. + _xla.register_custom_call_target(name, fn, + xla_platform_names.get(platform, platform)) # Deprecated. Use register_custom_call_target instead.