Adding ROCm support to OpsTestBase (necessary for some unit tests)

Eugene Kuznetsov 2020-01-22 12:24:24 -08:00
parent b6be4f36eb
commit 580dc945a6


@@ -14,7 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/framework/node_properties.h"
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h"
 #endif
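
The guard rewrite relies on how #if differs from #ifdef: in an #if expression the preprocessor treats any undefined identifier as 0, while a defined macro contributes its value. TensorFlow's GPU build configurations define GOOGLE_CUDA=1 for CUDA and TENSORFLOW_USE_ROCM=1 for ROCm, so the single expression enables the block for either backend and still compiles it out of CPU-only builds. A minimal standalone sketch of the pattern (the file name and compile lines are illustrative, not TensorFlow's actual build setup):

// guard_demo.cc -- standalone illustration of the guard used above.
// Hypothetical build lines mirroring the -D flags the Bazel configs pass:
//   g++ -DGOOGLE_CUDA=1 guard_demo.cc           # CUDA-style build
//   g++ -DTENSORFLOW_USE_ROCM=1 guard_demo.cc   # ROCm-style build
//   g++ guard_demo.cc                           # CPU-only build
#include <cstdio>

int main() {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  // Reached when either macro is defined to a nonzero value; undefined
  // identifiers evaluate to 0 inside #if, so CPU-only builds skip this.
  std::printf("GPU code path compiled in\n");
#else
  std::printf("CPU-only code path\n");
#endif
  return 0;
}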
@@ -112,7 +112,7 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
       thread_pool_.get());
   device_type_ = device_type;
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   if (device_type == DEVICE_GPU) {
     managed_allocator_.reset(new GpuManagedAllocator());
     allocator_ = managed_allocator_.get();
@@ -122,7 +122,7 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
   }
 #else
   CHECK_NE(device_type, DEVICE_GPU)
-      << "Requesting GPU on binary compiled without GOOGLE_CUDA.";
+      << "Requesting GPU on binary compiled without GOOGLE_CUDA or TENSORFLOW_USE_ROCM.";
   allocator_ = device_->GetAllocator(AllocatorAttributes());
 #endif
 }
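
With this change a kernel test can target a ROCm GPU exactly as it would a CUDA GPU. A hedged sketch of how a test built on OpsTestBase typically reaches the GPU branch of SetDevice (the test name is hypothetical; the SetDevice signature is the one shown in the hunk header above):

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TEST_F(OpsTestBase, RunsOnGpu) {
  // DeviceFactory::NewDevice("GPU", ...) yields the CUDA device on CUDA
  // builds and the ROCm device on ROCm builds; both register as "GPU".
  SetDevice(DEVICE_GPU,
            std::unique_ptr<tensorflow::Device>(
                tensorflow::DeviceFactory::NewDevice(
                    "GPU", {}, "/job:a/replica:0/task:0")));
  // ... construct and run the op under test as usual ...
}
#endif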
@@ -195,7 +195,7 @@ TensorValue OpsTestBase::mutable_input(int input_index) {
 Tensor* OpsTestBase::GetOutput(int output_index) {
   CHECK_LT(output_index, context_->num_outputs());
   Tensor* output = context_->mutable_output(output_index);
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   if (device_type_ == DEVICE_GPU) {
     managed_outputs_.resize(context_->num_outputs());
     // Copy the output tensor to managed memory if we haven't done so.
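
The GetOutput change is what lets host-side assertions read GPU results: on CUDA or ROCm builds it copies each device-resident output into GpuManagedAllocator-backed memory before returning it. A hedged sketch of a test body that exercises this path after selecting the GPU device as above ("ZerosLike" and the concrete values are illustrative only):

  TF_ASSERT_OK(NodeDefBuilder("zeros_like_op", "ZerosLike")
                   .Input(FakeInput(DT_FLOAT))
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  AddInputFromArray<float>(TensorShape({2}), {1.0f, 2.0f});
  TF_ASSERT_OK(RunOpKernel());
  // GetOutput(0) hands back a managed-memory copy of the device tensor,
  // so the comparison below runs safely on the host.
  Tensor expected(allocator(), DT_FLOAT, TensorShape({2}));
  test::FillValues<float>(&expected, {0.0f, 0.0f});
  test::ExpectTensorEqual<float>(expected, *GetOutput(0));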