Support ROCm in dynamic_partition fix

This commit is contained in:
drebain 2020-08-24 00:26:22 -07:00
parent 113555fab7
commit bd92fe5f43

View File

@@ -36,7 +36,6 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -48,9 +47,14 @@ limitations under the License.
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/transform_output_iterator.h"
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
using stream_executor::cuda::ScopedActivateExecutorContext;
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm.h"
using stream_executor::rocm::ScopedActivateExecutorContext;
#endif // GOOGLE_CUDA
namespace tensorflow {
@@ -306,12 +310,8 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
TensorReference partition_ref(partition_count);
auto wrapped_callback = [this, c, &data, &partitions, indices_out,
partition_ref, cpu_tensor, done]() {
GPUDeviceContext* gpu_device_context = static_cast<GPUDeviceContext*>(
c->op_device_context());
auto* stream = gpu_device_context->stream();
stream_executor::gpu::ScopedActivateExecutorContext scoped_activation {
stream->parent()
};
auto stream = c->op_device_context()->stream();
ScopedActivateExecutorContext scoped_activation{stream->parent()};
OpOutputList outputs;
this->AllocateOutputs(c, &data, &partitions, &cpu_tensor, &outputs, done);