Minor tweak to allow the use of the ROCM platform as a gpu target when registering a custom_call in the XLA client
parent da8326db56
commit 3f95825a35
@@ -431,7 +431,10 @@ def register_custom_call_target(name, fn, platform='cpu'):
     fn: a PyCapsule object containing the function pointer.
     platform: the target platform.
   """
-  _xla.register_custom_call_target(name, fn, xla_platform_names[platform])
+  # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM"
+  # Since that is hardcoded to CUDA, we are using the following as workaround.
+  _xla.register_custom_call_target(name, fn,
+                                   xla_platform_names.get(platform, platform))
 
 
 # Deprecated. Use register_custom_call_target instead.
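For context, a minimal sketch of how a custom call could be registered for an AMD GPU once this change is in place. It assumes the patched module is importable as jaxlib.xla_client and that a pybind11 or C extension has already produced a PyCapsule wrapping the function pointer; the capsule variable and the target name below are hypothetical.

from jaxlib import xla_client  # assumed import path for the patched client

# _rocm_capsule is a hypothetical PyCapsule holding the device function
# pointer (e.g. produced by a pybind11 extension); it is not defined here.

# With xla_platform_names.get(platform, platform), a platform string that is
# not in the lookup table is forwarded to XLA unchanged, so "ROCM" is no
# longer rewritten to "CUDA" the way the hardcoded "gpu" entry is.
xla_client.register_custom_call_target(
    b"my_rocm_kernel",   # hypothetical custom-call name
    _rocm_capsule,       # PyCapsule containing the function pointer
    platform="ROCM")     # passes through the .get() fallback untouched

Callers that keep using platform="gpu" still resolve to CUDA via the lookup table, so existing registrations are unaffected by the fallback.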