Split Rendezvous class into pure-virtual RendezvousInterface and refcounted Rendezvous.
This change lays the groundwork for creating non-refcounted RendezvousInterface implementations, which would allow us to avoid dynamic allocation and atomic refcount operations in some cases. It modifies internal classes that use Rendezvous* to use RendezvousInterface* instead: the change is safe because none of these rely on the ability to modify the rendezvous' refcount (and it is unlikely that it would be safe for them to do so). PiperOrigin-RevId: 282764107 Change-Id: I8ef6fe995962dfa6556ae066f990c6445462a13e
This commit is contained in:
parent
957b33238c
commit
81f844c1ff
@ -1287,7 +1287,7 @@ class ExecutorState {
|
||||
|
||||
int64 step_id_;
|
||||
// Not owned.
|
||||
Rendezvous* rendezvous_;
|
||||
RendezvousInterface* rendezvous_;
|
||||
Executor::RendezvousFactory* create_rendezvous_ = nullptr;
|
||||
CollectiveExecutor* collective_executor_ = nullptr;
|
||||
SessionState* session_state_;
|
||||
|
@ -88,7 +88,7 @@ class Executor {
|
||||
|
||||
struct Args {
|
||||
int64 step_id = 0;
|
||||
Rendezvous* rendezvous = nullptr;
|
||||
RendezvousInterface* rendezvous = nullptr;
|
||||
StepStatsCollectorInterface* stats_collector = nullptr;
|
||||
CallFrameInterface* call_frame = nullptr;
|
||||
CancellationManager* cancellation_manager = nullptr;
|
||||
|
@ -1017,7 +1017,7 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
|
||||
Item* item, DoneCallback done) {
|
||||
string target_device = parent_->GetDeviceName(handle);
|
||||
string source_device = opts.source_device;
|
||||
Rendezvous* rendezvous = opts.rendezvous;
|
||||
RendezvousInterface* rendezvous = opts.rendezvous;
|
||||
DeviceContext* device_context;
|
||||
Status s = parent_->GetDeviceContext(target_device, &device_context);
|
||||
if (!s.ok()) {
|
||||
|
@ -1854,7 +1854,8 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
|
||||
|
||||
Tensor y;
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get());
|
||||
Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_.get());
|
||||
opts.rendezvous = rendezvous;
|
||||
opts.source_device = "/device:CPU:1";
|
||||
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
|
||||
TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true));
|
||||
@ -1869,7 +1870,7 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
|
||||
y,
|
||||
test::AsTensor<tstring>({"/job:localhost/replica:0/task:0/device:CPU:1"},
|
||||
TensorShape({})));
|
||||
opts.rendezvous->Unref();
|
||||
rendezvous->Unref();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -45,7 +45,7 @@ namespace {
|
||||
// A simple rendezvous class.
|
||||
// Assumes a single sender and a single receiver, no duplicate sends, and no
|
||||
// sends of dead tensors.
|
||||
class SimpleRendezvous : public Rendezvous {
|
||||
class SimpleRendezvous : public RendezvousInterface {
|
||||
public:
|
||||
explicit SimpleRendezvous() {}
|
||||
|
||||
@ -124,8 +124,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
||||
std::unique_ptr<Graph> graph_to_run(new Graph(graph->op_registry()));
|
||||
CopyGraph(*graph, graph_to_run.get());
|
||||
|
||||
SimpleRendezvous* rendez = new SimpleRendezvous;
|
||||
core::ScopedUnref rendez_unref(rendez);
|
||||
SimpleRendezvous rendez;
|
||||
|
||||
// Extract the input names and keys, and feed in the inputs.
|
||||
std::vector<string> input_names;
|
||||
@ -136,8 +135,8 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
||||
tensor_name, FrameAndIter(0, 0));
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(full_key, &parsed));
|
||||
TF_RETURN_IF_ERROR(rendez->Send(parsed, Rendezvous::Args(), in.second,
|
||||
false /* is_dead */));
|
||||
TF_RETURN_IF_ERROR(rendez.Send(parsed, Rendezvous::Args(), in.second,
|
||||
false /* is_dead */));
|
||||
}
|
||||
|
||||
// Call RewriteGraphForExecution
|
||||
@ -180,7 +179,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
||||
// called via this method.
|
||||
args.step_id = LogMemory::CONSTANT_FOLDING_STEP_ID;
|
||||
args.runner = runner;
|
||||
args.rendezvous = rendez;
|
||||
args.rendezvous = &rendez;
|
||||
// NOTE: Use of graph runner is limited to single-device executions
|
||||
// so a CollectiveExecutor should never be required.
|
||||
args.collective_executor = nullptr;
|
||||
@ -201,7 +200,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
||||
bool is_dead;
|
||||
Tensor output_tensor;
|
||||
TF_RETURN_IF_ERROR(
|
||||
rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead));
|
||||
rendez.Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead));
|
||||
// Does a deep copy so that ownership of the tensor isn't tied to the
|
||||
// allocator of the cpu device we created above. The allocator could be
|
||||
// deleted along with the device.
|
||||
|
@ -122,7 +122,7 @@ Status ProcessFunctionLibraryRuntime::SendTensors(
|
||||
const string& key_prefix, int64 src_incarnation,
|
||||
gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
|
||||
const std::vector<AllocatorAttributes>& alloc_attrs,
|
||||
Rendezvous* rendezvous) {
|
||||
RendezvousInterface* rendezvous) {
|
||||
std::vector<string> keys;
|
||||
for (int i = 0; i < tensors_to_send.size(); ++i) {
|
||||
string name = strings::StrCat(key_prefix, i);
|
||||
@ -140,8 +140,9 @@ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
|
||||
const string& source_device, const string& target_device,
|
||||
const string& key_prefix, int64 src_incarnation, int64 num_tensors,
|
||||
DeviceContext* device_context,
|
||||
const std::vector<AllocatorAttributes>& alloc_attrs, Rendezvous* rendezvous,
|
||||
std::vector<Tensor>* received_tensors, StatusCallback done) {
|
||||
const std::vector<AllocatorAttributes>& alloc_attrs,
|
||||
RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
|
||||
StatusCallback done) {
|
||||
std::vector<string> keys;
|
||||
for (int64 i = 0; i < num_tensors; ++i) {
|
||||
string name = strings::StrCat(key_prefix, i);
|
||||
|
@ -92,7 +92,7 @@ class ProcessFunctionLibraryRuntime {
|
||||
gtl::ArraySlice<Tensor> tensors_to_send,
|
||||
DeviceContext* device_context,
|
||||
const std::vector<AllocatorAttributes>& alloc_attrs,
|
||||
Rendezvous* rendezvous);
|
||||
RendezvousInterface* rendezvous);
|
||||
|
||||
// Receives `received_tensors` from `target_device` (originally sent from
|
||||
// `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
|
||||
@ -105,7 +105,7 @@ class ProcessFunctionLibraryRuntime {
|
||||
const string& key_prefix, int64 src_incarnation, int64 num_tensors,
|
||||
DeviceContext* device_context,
|
||||
const std::vector<AllocatorAttributes>& alloc_attrs,
|
||||
Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
|
||||
RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
|
||||
StatusCallback done);
|
||||
|
||||
static const char kDefaultFLRDevice[];
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
Status SendTensorsToRendezvous(
|
||||
Rendezvous* rendezvous, DeviceContext* device_context,
|
||||
RendezvousInterface* rendezvous, DeviceContext* device_context,
|
||||
const std::vector<AllocatorAttributes>& alloc_attrs,
|
||||
const std::vector<string>& keys, gtl::ArraySlice<Tensor> tensors_to_send) {
|
||||
if (keys.size() != tensors_to_send.size()) {
|
||||
@ -54,7 +54,7 @@ Status SendTensorsToRendezvous(
|
||||
}
|
||||
|
||||
void RecvOutputsFromRendezvousAsync(
|
||||
Rendezvous* rendezvous, DeviceContext* device_context,
|
||||
RendezvousInterface* rendezvous, DeviceContext* device_context,
|
||||
const std::vector<AllocatorAttributes>& alloc_attrs,
|
||||
const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
|
||||
StatusCallback done) {
|
||||
@ -118,7 +118,8 @@ void RecvOutputsFromRendezvousAsync(
|
||||
status_cb->Unref();
|
||||
}
|
||||
|
||||
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
|
||||
Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous,
|
||||
NamedTensors* out,
|
||||
const Rendezvous::Args& args) {
|
||||
// Receives values requested by the caller.
|
||||
Rendezvous::ParsedKey parsed;
|
||||
|
@ -31,7 +31,7 @@ typedef std::function<void(const Status&)> StatusCallback;
|
||||
// allocated. `alloc_attrs` should either be {} or should match the length of
|
||||
// `keys`.
|
||||
Status SendTensorsToRendezvous(
|
||||
Rendezvous* rendezvous, DeviceContext* device_context,
|
||||
RendezvousInterface* rendezvous, DeviceContext* device_context,
|
||||
const std::vector<AllocatorAttributes>& alloc_attrs,
|
||||
const std::vector<string>& keys, gtl::ArraySlice<Tensor> tensors_to_send);
|
||||
|
||||
@ -40,12 +40,13 @@ Status SendTensorsToRendezvous(
|
||||
// information as how to store the received tensors. Should be {} or match the
|
||||
// length of `keys`.
|
||||
void RecvOutputsFromRendezvousAsync(
|
||||
Rendezvous* rendezvous, DeviceContext* device_context,
|
||||
RendezvousInterface* rendezvous, DeviceContext* device_context,
|
||||
const std::vector<AllocatorAttributes>& alloc_attrs,
|
||||
const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
|
||||
StatusCallback done);
|
||||
|
||||
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
|
||||
Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous,
|
||||
NamedTensors* out,
|
||||
const Rendezvous::Args& args);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -687,7 +687,7 @@ class FunctionLibraryRuntime {
|
||||
// tensors to the remote TensorHandles in the default device.
|
||||
absl::optional<int64> op_id = absl::nullopt;
|
||||
|
||||
Rendezvous* rendezvous = nullptr;
|
||||
RendezvousInterface* rendezvous = nullptr;
|
||||
CancellationManager* cancellation_manager = nullptr;
|
||||
CollectiveExecutor* collective_executor = nullptr;
|
||||
ScopedStepContainer* step_container = nullptr;
|
||||
|
@ -672,7 +672,7 @@ class OpKernelContext {
|
||||
|
||||
// Mechanism used by this op kernel invocation to communicate with
|
||||
// computations running on other devices.
|
||||
Rendezvous* rendezvous = nullptr;
|
||||
RendezvousInterface* rendezvous = nullptr;
|
||||
const std::function<Status(const int64, const DeviceMgr*, Rendezvous** r)>*
|
||||
create_rendezvous;
|
||||
|
||||
@ -1100,7 +1100,7 @@ class OpKernelContext {
|
||||
//
|
||||
// An op kernel communicates with outside environment through
|
||||
// Rendezvous Send() and Recv().
|
||||
Rendezvous* rendezvous() const { return params_->rendezvous; }
|
||||
RendezvousInterface* rendezvous() const { return params_->rendezvous; }
|
||||
Status create_rendezvous(const int64 step_id, const DeviceMgr* device_mgr,
|
||||
Rendezvous** r) const {
|
||||
return (*params_->create_rendezvous)(step_id, device_mgr, r);
|
||||
|
@ -113,10 +113,10 @@ Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) {
|
||||
return errors::InvalidArgument("Invalid rendezvous key: ", key);
|
||||
}
|
||||
|
||||
Rendezvous::~Rendezvous() {}
|
||||
RendezvousInterface::~RendezvousInterface() {}
|
||||
|
||||
Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args,
|
||||
Tensor* val, bool* is_dead, int64 timeout_ms) {
|
||||
Status RendezvousInterface::Recv(const ParsedKey& key, const Args& recv_args,
|
||||
Tensor* val, bool* is_dead, int64 timeout_ms) {
|
||||
Status ret;
|
||||
Notification n;
|
||||
RecvAsync(key, recv_args,
|
||||
@ -141,8 +141,8 @@ Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args,
|
||||
return ret;
|
||||
}
|
||||
|
||||
Status Rendezvous::Recv(const ParsedKey& key, const Args& args, Tensor* val,
|
||||
bool* is_dead) {
|
||||
Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args,
|
||||
Tensor* val, bool* is_dead) {
|
||||
const int64 no_timeout = 0;
|
||||
return Recv(key, args, val, is_dead, no_timeout);
|
||||
}
|
||||
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
|
||||
#define TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
|
||||
#ifndef TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
@ -44,7 +44,7 @@ namespace tensorflow {
|
||||
// been produced. A consumer has the choice of making a blocking call
|
||||
// or providing a callback: in either case, the consumer receives the
|
||||
// Tensor as soon as it is available. A producer never blocks.
|
||||
class Rendezvous : public core::RefCounted {
|
||||
class RendezvousInterface {
|
||||
public:
|
||||
struct Args {
|
||||
DeviceContext* device_context = nullptr;
|
||||
@ -52,13 +52,6 @@ class Rendezvous : public core::RefCounted {
|
||||
CancellationManager* cancellation_manager = nullptr; // not owned.
|
||||
};
|
||||
|
||||
// Constructs a rendezvous key for the tensor of "name" sent from
|
||||
// "src_device" to "dst_device". The tensor is generated in the frame
|
||||
// and iteration specified by "frame_iter".
|
||||
static string CreateKey(const string& src_device, uint64 src_incarnation,
|
||||
const string& dst_device, const string& name,
|
||||
const FrameAndIter& frame_iter);
|
||||
|
||||
// Parses the key constructed by CreateKey and parse src/dst device
|
||||
// names into structures respectively.
|
||||
struct ParsedKey {
|
||||
@ -81,7 +74,6 @@ class Rendezvous : public core::RefCounted {
|
||||
friend class RecvOp;
|
||||
string buf_;
|
||||
};
|
||||
static Status ParseKey(StringPiece key, ParsedKey* out);
|
||||
|
||||
// The caller is a tensor producer and it sends a message (a tensor
|
||||
// "val" and a bool "is_dead") under the given "key".
|
||||
@ -123,12 +115,28 @@ class Rendezvous : public core::RefCounted {
|
||||
virtual void StartAbort(const Status& status) = 0;
|
||||
|
||||
protected:
|
||||
~Rendezvous() override;
|
||||
virtual ~RendezvousInterface();
|
||||
|
||||
virtual bool is_cross_process() { return false; }
|
||||
friend class ProcessFunctionLibraryRuntime;
|
||||
};
|
||||
|
||||
// A reference-counted implementation of RendezvousInterface.
|
||||
//
|
||||
// This class is used in cases where a rendezvous may be shared between multiple
|
||||
// threads with no clear owner.
|
||||
class Rendezvous : public RendezvousInterface, public core::RefCounted {
|
||||
public:
|
||||
// Constructs a rendezvous key for the tensor of "name" sent from
|
||||
// "src_device" to "dst_device". The tensor is generated in the frame
|
||||
// and iteration specified by "frame_iter".
|
||||
static string CreateKey(const string& src_device, uint64 src_incarnation,
|
||||
const string& dst_device, const string& name,
|
||||
const FrameAndIter& frame_iter);
|
||||
|
||||
static Status ParseKey(StringPiece key, ParsedKey* out);
|
||||
};
|
||||
|
||||
// Returns a Rendezvous instance that is limited to use only by
|
||||
// producers and consumers in the local process. The caller assumes
|
||||
// ownership of one Ref() on the returned object.
|
||||
@ -136,4 +144,4 @@ Rendezvous* NewLocalRendezvous();
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_
|
||||
|
Loading…
Reference in New Issue
Block a user