[ROCm] Fix to enable XLA_GPU device registration for ROCm platform
This commit is contained in:
parent
c14b6951de
commit
88a1e3b399
@ -14,7 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
|
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
|
||||||
// operators using XLA via the XLA "CUDA" (GPU) backend.
|
// operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/jit/xla_device_ops.h"
|
#include "tensorflow/compiler/jit/xla_device_ops.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
|
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -69,7 +70,8 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
|
auto platform =
|
||||||
|
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
|
||||||
if (!platform.ok()) {
|
if (!platform.ok()) {
|
||||||
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
||||||
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
|
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
|
||||||
@ -117,7 +119,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
|
|||||||
RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
|
RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
|
||||||
(void)registrations;
|
(void)registrations;
|
||||||
|
|
||||||
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
|
auto platform =
|
||||||
|
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
|
||||||
if (!platform.ok()) {
|
if (!platform.ok()) {
|
||||||
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
||||||
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
|
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
|
||||||
|
Loading…
Reference in New Issue
Block a user