[ROCm] Fix to enable XLA_GPU device registration for ROCm platform

This commit is contained in:
Deven Desai 2020-01-03 18:11:27 +00:00
parent c14b6951de
commit 88a1e3b399
1 changed file with 6 additions and 3 deletions

View File

@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend.
// operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
#include <set>
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
@@ -69,7 +70,8 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
return Status::OK();
}
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
auto platform =
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
@@ -117,7 +119,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
(void)registrations;
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
auto platform =
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();