Adding ROCm support to OpsTestBase (necessary for some unit tests)
This commit is contained in:
parent
b6be4f36eb
commit
580dc945a6
@@ -14,7 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/node_properties.h"
|
||||
#ifdef GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h"
|
||||
#endif
|
||||
@@ -112,7 +112,7 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
|
||||
thread_pool_.get());
|
||||
|
||||
device_type_ = device_type;
|
||||
#ifdef GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
if (device_type == DEVICE_GPU) {
|
||||
managed_allocator_.reset(new GpuManagedAllocator());
|
||||
allocator_ = managed_allocator_.get();
|
||||
@@ -122,7 +122,7 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
|
||||
}
|
||||
#else
|
||||
CHECK_NE(device_type, DEVICE_GPU)
|
||||
<< "Requesting GPU on binary compiled without GOOGLE_CUDA.";
|
||||
<< "Requesting GPU on binary compiled without GOOGLE_CUDA or TENSORFLOW_USE_ROCM.";
|
||||
allocator_ = device_->GetAllocator(AllocatorAttributes());
|
||||
#endif
|
||||
}
|
||||
@@ -195,7 +195,7 @@ TensorValue OpsTestBase::mutable_input(int input_index) {
|
||||
Tensor* OpsTestBase::GetOutput(int output_index) {
|
||||
CHECK_LT(output_index, context_->num_outputs());
|
||||
Tensor* output = context_->mutable_output(output_index);
|
||||
#ifdef GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
if (device_type_ == DEVICE_GPU) {
|
||||
managed_outputs_.resize(context_->num_outputs());
|
||||
// Copy the output tensor to managed memory if we haven't done so.
|
||||
|
Loading…
Reference in New Issue
Block a user