Minor tweak to allow the use of ROCM platform as a gpu target when registering a custom_call in xla client

This commit is contained in:
Deven Desai 2020-12-09 19:50:02 +00:00
parent da8326db56
commit 3f95825a35

View File

@ -431,7 +431,10 @@ def register_custom_call_target(name, fn, platform='cpu'):
fn: a PyCapsule object containing the function pointer.
platform: the target platform.
"""
_xla.register_custom_call_target(name, fn, xla_platform_names[platform])
# To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM"
# Since that is hardcoded to CUDA, we are using the following as workaround.
_xla.register_custom_call_target(name, fn,
xla_platform_names.get(platform, platform))
# Deprecated. Use register_custom_call_target instead.