From 65fe9180219728e064655537a0dcbdf30ac1d96d Mon Sep 17 00:00:00 2001 From: Eugene Kuznetsov Date: Fri, 13 Mar 2020 00:03:48 -0700 Subject: [PATCH] Implementation of a missing hook for XlaDeviceContext --- tensorflow/compiler/jit/xla_device_context.cc | 9 +++++++++ tensorflow/compiler/jit/xla_device_context.h | 3 +++ 2 files changed, 12 insertions(+) diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index e1cef25e33e..8ab62112719 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -294,4 +294,13 @@ se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() { return device_to_device_stream(stream); } +Status XlaDeviceContext::ThenExecute(Device* device, stream_executor::Stream* stream, + std::function func) +{ + VLOG(2) << "XlaDeviceContext::ThenExecute"; + stream->ThenDoHostCallback(std::move(func)); + return Status::OK(); +} + + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 05d8dfa7556..84e5badcd88 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -86,6 +86,9 @@ class XlaDeviceContext : public DeviceContext { // Returns a device-to-device stream, in round-robin fashion. se::Stream* GetDeviceToDeviceStream(); + Status ThenExecute(Device* device, stream_executor::Stream* stream, + std::function func) override; + private: bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; }