[ROCm] Fix to enable XLA_GPU device registration for ROCm platform

This commit is contained in:
Deven Desai 2020-01-03 18:11:27 +00:00
parent c14b6951de
commit 88a1e3b399

View File

@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend. // operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
#include <set> #include <set>
@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
namespace tensorflow { namespace tensorflow {
@ -69,7 +70,8 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
return Status::OK(); return Status::OK();
} }
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA"); auto platform =
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
if (!platform.ok()) { if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine. // Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status(); VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
@ -117,7 +119,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT); RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
(void)registrations; (void)registrations;
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA"); auto platform =
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
if (!platform.ok()) { if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine. // Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status(); VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();