TensorFlow: make split_op not use internal header library for callback,
since this breaks the build on GPU. Change: 115582331
parent 86e93febaa
commit 13d7f52034
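Context (inferred from the diff below): the affected code is the GPU path of SplitOp. It stages an array of per-output pointers in a host tensor (output_ptrs_on_host), copies that array to device memory (output_ptrs_base), and must keep the host tensor alive until the copy has completed. Both the old and the new code do this by holding a TensorReference and releasing it from a callback; what changes is how that callback is scheduled, as shown in the second hunk.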
@@ -26,7 +26,6 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #if GOOGLE_CUDA
-#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/common_runtime/gpu_device_context.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #endif  // GOOGLE_CUDA
@@ -230,11 +229,10 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
     perftools::gputools::DeviceMemoryBase output_ptrs_base{
         output_ptrs_on_gpu.flat<int8>().data(), static_cast<uint64>(num_split)};
     TensorReference tensor_ref(output_ptrs_on_host);
-    stream->ThenMemcpy(&output_ptrs_base,
-                       output_ptrs_on_host.flat<int8>().data(),
-                       output_ptrs_total_bytes);
-    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
-        stream, [tensor_ref]() { tensor_ref.Unref(); });
+    stream
+        ->ThenMemcpy(&output_ptrs_base, output_ptrs_on_host.flat<int8>().data(),
+                     output_ptrs_total_bytes)
+        .ThenDoHostCallback([tensor_ref]() { tensor_ref.Unref(); });
     SplitOpGPULaunch<T>().Run(
         context->eigen_device<GPUDevice>(), input.flat<T>().data(), num_split,
         prefix_dim_size, split_dim_size, suffix_dim_size,
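The pattern on both sides of the hunk is the same: take a reference on the host staging tensor, enqueue an asynchronous copy, and drop the reference only after the copy has finished. Below is a minimal, self-contained C++ sketch of that lifetime-extension idiom. It uses a hypothetical RefCountedBuffer/FakeStream pair, not TensorFlow's TensorReference or StreamExecutor APIs, and only models the ordering guarantee (the Unref runs strictly after the async work on the same stream), not real CUDA stream semantics.

#include <algorithm>
#include <atomic>
#include <cstdio>
#include <functional>
#include <thread>
#include <vector>

// Hypothetical stand-in for a ref-counted tensor buffer: it stays alive
// as long as at least one reference is held.
class RefCountedBuffer {
 public:
  explicit RefCountedBuffer(std::vector<char> data)
      : data_(std::move(data)), refs_(1) {}
  void Ref() { refs_.fetch_add(1); }
  void Unref() {
    if (refs_.fetch_sub(1) == 1) delete this;  // last reference released
  }
  const char* data() const { return data_.data(); }
  size_t size() const { return data_.size(); }

 private:
  ~RefCountedBuffer() { std::printf("staging buffer freed\n"); }
  std::vector<char> data_;
  std::atomic<int> refs_;
};

// Hypothetical stand-in for a stream: enqueued items run in order on one
// worker thread, so a callback enqueued after a copy runs after the copy.
class FakeStream {
 public:
  void Enqueue(std::function<void()> fn) { work_.push_back(std::move(fn)); }
  void BlockHostUntilDone() {
    std::thread worker([this] {
      for (auto& fn : work_) fn();
    });
    worker.join();
  }

 private:
  std::vector<std::function<void()>> work_;
};

int main() {
  auto* staging = new RefCountedBuffer(std::vector<char>(64, 1));
  std::vector<char> device(64);

  FakeStream stream;
  staging->Ref();  // keep the staging buffer alive across the async copy
  stream.Enqueue([staging, &device] {
    // The "memcpy" that still needs the staging buffer when it runs.
    std::copy(staging->data(), staging->data() + staging->size(),
              device.begin());
  });
  // Host callback enqueued on the same stream: runs only after the copy,
  // mirroring ThenMemcpy(...).ThenDoHostCallback([ref] { ref.Unref(); }).
  stream.Enqueue([staging] { staging->Unref(); });

  staging->Unref();  // caller's own reference can be dropped immediately
  stream.BlockHostUntilDone();
  return 0;
}

In the real kernel the point of the change is presumably that ThenDoHostCallback is part of the StreamExecutor interface already included via tensorflow/core/platform/stream_executor.h, whereas EventMgr::ThenExecute comes from the internal common_runtime/gpu/gpu_event_mgr.h header that broke the GPU build, which is why that include is dropped in the first hunk.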