[ROCm] Fix to enable XLA_GPU device registration for ROCm platform
This commit is contained in:
parent
c14b6951de
commit
88a1e3b399
|
@ -14,7 +14,7 @@ limitations under the License.
|
|||
==============================================================================*/
|
||||
|
||||
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
|
||||
// operators using XLA via the XLA "CUDA" (GPU) backend.
|
||||
// operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
|
||||
|
||||
#include <set>
|
||||
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/jit/xla_device_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -69,7 +70,8 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
|
||||
auto platform =
|
||||
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
|
||||
if (!platform.ok()) {
|
||||
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
||||
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
|
||||
|
@ -117,7 +119,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
|
|||
RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
|
||||
(void)registrations;
|
||||
|
||||
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
|
||||
auto platform =
|
||||
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
|
||||
if (!platform.ok()) {
|
||||
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
||||
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
|
||||
|
|
Loading…
Reference in New Issue