From 3f95825a35788b7f75ccfe1aa0c4b2bf8e3c7d12 Mon Sep 17 00:00:00 2001 From: Deven Desai Date: Wed, 9 Dec 2020 19:50:02 +0000 Subject: [PATCH] Minor tweak to allow the use of ROCM platform as a gpu target when registering a custom_call in xla client --- tensorflow/compiler/xla/python/xla_client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 4a78e5cb909..d1d3de9041d 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -431,7 +431,10 @@ def register_custom_call_target(name, fn, platform='cpu'): fn: a PyCapsule object containing the function pointer. platform: the target platform. """ - _xla.register_custom_call_target(name, fn, xla_platform_names[platform]) + # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM" + # Since that is hardcoded to CUDA, we are using the following as a workaround. + _xla.register_custom_call_target(name, fn, + xla_platform_names.get(platform, platform)) # Deprecated. Use register_custom_call_target instead.