Minor tweak to allow the use of ROCM platform as a gpu target when registering a custom_call in xla client
This commit is contained in:
parent
da8326db56
commit
3f95825a35
@ -431,7 +431,10 @@ def register_custom_call_target(name, fn, platform='cpu'):
|
||||
fn: a PyCapsule object containing the function pointer.
|
||||
platform: the target platform.
|
||||
"""
|
||||
_xla.register_custom_call_target(name, fn, xla_platform_names[platform])
|
||||
# To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM"
|
||||
# Since that is hardcoded to CUDA, we are using the following as workaround.
|
||||
_xla.register_custom_call_target(name, fn,
|
||||
xla_platform_names.get(platform, platform))
|
||||
|
||||
|
||||
# Deprecated. Use register_custom_call_target instead.
|
||||
|
Loading…
Reference in New Issue
Block a user