diff --git a/tensorflow/core/kernels/ops_testutil.cc b/tensorflow/core/kernels/ops_testutil.cc index 87f70d3a3b3..c6f751d196c 100644 --- a/tensorflow/core/kernels/ops_testutil.cc +++ b/tensorflow/core/kernels/ops_testutil.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/node_properties.h" -#ifdef GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h" #endif @@ -112,7 +112,7 @@ void OpsTestBase::SetDevice(const DeviceType& device_type, thread_pool_.get()); device_type_ = device_type; -#ifdef GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (device_type == DEVICE_GPU) { managed_allocator_.reset(new GpuManagedAllocator()); allocator_ = managed_allocator_.get(); @@ -122,7 +122,7 @@ void OpsTestBase::SetDevice(const DeviceType& device_type, } #else CHECK_NE(device_type, DEVICE_GPU) - << "Requesting GPU on binary compiled without GOOGLE_CUDA."; + << "Requesting GPU on binary compiled without GOOGLE_CUDA or TENSORFLOW_USE_ROCM."; allocator_ = device_->GetAllocator(AllocatorAttributes()); #endif } @@ -195,7 +195,7 @@ TensorValue OpsTestBase::mutable_input(int input_index) { Tensor* OpsTestBase::GetOutput(int output_index) { CHECK_LT(output_index, context_->num_outputs()); Tensor* output = context_->mutable_output(output_index); -#ifdef GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (device_type_ == DEVICE_GPU) { managed_outputs_.resize(context_->num_outputs()); // Copy the output tensor to managed memory if we haven't done so.