Support ROCm in dynamic_partition fix
This commit is contained in:
parent
113555fab7
commit
bd92fe5f43
@ -36,7 +36,6 @@ limitations under the License.
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
@ -48,9 +47,14 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
#include "tensorflow/core/util/transform_output_iterator.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
||||
#endif
|
||||
using stream_executor::cuda::ScopedActivateExecutorContext;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/platform/rocm.h"
|
||||
using stream_executor::rocm::ScopedActivateExecutorContext;
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -306,12 +310,8 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
|
||||
TensorReference partition_ref(partition_count);
|
||||
auto wrapped_callback = [this, c, &data, &partitions, indices_out,
|
||||
partition_ref, cpu_tensor, done]() {
|
||||
GPUDeviceContext* gpu_device_context = static_cast<GPUDeviceContext*>(
|
||||
c->op_device_context());
|
||||
auto* stream = gpu_device_context->stream();
|
||||
stream_executor::gpu::ScopedActivateExecutorContext scoped_activation {
|
||||
stream->parent()
|
||||
};
|
||||
auto stream = c->op_device_context()->stream();
|
||||
ScopedActivateExecutorContext scoped_activation{stream->parent()};
|
||||
|
||||
OpOutputList outputs;
|
||||
this->AllocateOutputs(c, &data, &partitions, &cpu_tensor, &outputs, done);
|
||||
|
Loading…
Reference in New Issue
Block a user