Split Rendezvous class into pure-virtual RendezvousInterface and refcounted Rendezvous.

This change lays the groundwork for creating non-refcounted RendezvousInterface implementations, which would allow us to avoid dynamic allocation and atomic refcount
operations in some cases. It modifies internal classes that use Rendezvous* to use RendezvousInterface* instead: the change is safe because none of these rely on the ability to modify the rendezvous' refcount (and it is unlikely that it would be safe for them to do so).

PiperOrigin-RevId: 282764107
Change-Id: I8ef6fe995962dfa6556ae066f990c6445462a13e
This commit is contained in:
Derek Murray 2019-11-27 07:20:21 -08:00 committed by TensorFlower Gardener
parent 957b33238c
commit 81f844c1ff
13 changed files with 55 additions and 44 deletions

View File

@ -1287,7 +1287,7 @@ class ExecutorState {
int64 step_id_;
// Not owned.
Rendezvous* rendezvous_;
RendezvousInterface* rendezvous_;
Executor::RendezvousFactory* create_rendezvous_ = nullptr;
CollectiveExecutor* collective_executor_ = nullptr;
SessionState* session_state_;

View File

@ -88,7 +88,7 @@ class Executor {
struct Args {
int64 step_id = 0;
Rendezvous* rendezvous = nullptr;
RendezvousInterface* rendezvous = nullptr;
StepStatsCollectorInterface* stats_collector = nullptr;
CallFrameInterface* call_frame = nullptr;
CancellationManager* cancellation_manager = nullptr;

View File

@ -1017,7 +1017,7 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
Item* item, DoneCallback done) {
string target_device = parent_->GetDeviceName(handle);
string source_device = opts.source_device;
Rendezvous* rendezvous = opts.rendezvous;
RendezvousInterface* rendezvous = opts.rendezvous;
DeviceContext* device_context;
Status s = parent_->GetDeviceContext(target_device, &device_context);
if (!s.ok()) {

View File

@ -1854,7 +1854,8 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
Tensor y;
FunctionLibraryRuntime::Options opts;
opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get());
Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_.get());
opts.rendezvous = rendezvous;
opts.source_device = "/device:CPU:1";
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true));
@ -1869,7 +1870,7 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
y,
test::AsTensor<tstring>({"/job:localhost/replica:0/task:0/device:CPU:1"},
TensorShape({})));
opts.rendezvous->Unref();
rendezvous->Unref();
}
namespace {

View File

@ -45,7 +45,7 @@ namespace {
// A simple rendezvous class.
// Assumes a single sender and a single receiver, no duplicate sends, and no
// sends of dead tensors.
class SimpleRendezvous : public Rendezvous {
class SimpleRendezvous : public RendezvousInterface {
public:
explicit SimpleRendezvous() {}
@ -124,8 +124,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
std::unique_ptr<Graph> graph_to_run(new Graph(graph->op_registry()));
CopyGraph(*graph, graph_to_run.get());
SimpleRendezvous* rendez = new SimpleRendezvous;
core::ScopedUnref rendez_unref(rendez);
SimpleRendezvous rendez;
// Extract the input names and keys, and feed in the inputs.
std::vector<string> input_names;
@ -136,8 +135,8 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
tensor_name, FrameAndIter(0, 0));
Rendezvous::ParsedKey parsed;
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(full_key, &parsed));
TF_RETURN_IF_ERROR(rendez->Send(parsed, Rendezvous::Args(), in.second,
false /* is_dead */));
TF_RETURN_IF_ERROR(rendez.Send(parsed, Rendezvous::Args(), in.second,
false /* is_dead */));
}
// Call RewriteGraphForExecution
@ -180,7 +179,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
// called via this method.
args.step_id = LogMemory::CONSTANT_FOLDING_STEP_ID;
args.runner = runner;
args.rendezvous = rendez;
args.rendezvous = &rendez;
// NOTE: Use of graph runner is limited to single-device executions
// so a CollectiveExecutor should never be required.
args.collective_executor = nullptr;
@ -201,7 +200,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
bool is_dead;
Tensor output_tensor;
TF_RETURN_IF_ERROR(
rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead));
rendez.Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead));
// Does a deep copy so that ownership of the tensor isn't tied to the
// allocator of the cpu device we created above. The allocator could be
// deleted along with the device.

View File

@ -122,7 +122,7 @@ Status ProcessFunctionLibraryRuntime::SendTensors(
const string& key_prefix, int64 src_incarnation,
gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
Rendezvous* rendezvous) {
RendezvousInterface* rendezvous) {
std::vector<string> keys;
for (int i = 0; i < tensors_to_send.size(); ++i) {
string name = strings::StrCat(key_prefix, i);
@ -140,8 +140,9 @@ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
const string& source_device, const string& target_device,
const string& key_prefix, int64 src_incarnation, int64 num_tensors,
DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs, Rendezvous* rendezvous,
std::vector<Tensor>* received_tensors, StatusCallback done) {
const std::vector<AllocatorAttributes>& alloc_attrs,
RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
StatusCallback done) {
std::vector<string> keys;
for (int64 i = 0; i < num_tensors; ++i) {
string name = strings::StrCat(key_prefix, i);

View File

@ -92,7 +92,7 @@ class ProcessFunctionLibraryRuntime {
gtl::ArraySlice<Tensor> tensors_to_send,
DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
Rendezvous* rendezvous);
RendezvousInterface* rendezvous);
// Receives `received_tensors` from `target_device` (originally sent from
// `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
@ -105,7 +105,7 @@ class ProcessFunctionLibraryRuntime {
const string& key_prefix, int64 src_incarnation, int64 num_tensors,
DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
StatusCallback done);
static const char kDefaultFLRDevice[];

View File

@ -20,7 +20,7 @@ limitations under the License.
namespace tensorflow {
Status SendTensorsToRendezvous(
Rendezvous* rendezvous, DeviceContext* device_context,
RendezvousInterface* rendezvous, DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
const std::vector<string>& keys, gtl::ArraySlice<Tensor> tensors_to_send) {
if (keys.size() != tensors_to_send.size()) {
@ -54,7 +54,7 @@ Status SendTensorsToRendezvous(
}
void RecvOutputsFromRendezvousAsync(
Rendezvous* rendezvous, DeviceContext* device_context,
RendezvousInterface* rendezvous, DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
StatusCallback done) {
@ -118,7 +118,8 @@ void RecvOutputsFromRendezvousAsync(
status_cb->Unref();
}
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous,
NamedTensors* out,
const Rendezvous::Args& args) {
// Receives values requested by the caller.
Rendezvous::ParsedKey parsed;

View File

@ -31,7 +31,7 @@ typedef std::function<void(const Status&)> StatusCallback;
// allocated. `alloc_attrs` should either be {} or should match the length of
// `keys`.
Status SendTensorsToRendezvous(
Rendezvous* rendezvous, DeviceContext* device_context,
RendezvousInterface* rendezvous, DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
const std::vector<string>& keys, gtl::ArraySlice<Tensor> tensors_to_send);
@ -40,12 +40,13 @@ Status SendTensorsToRendezvous(
// information as how to store the received tensors. Should be {} or match the
// length of `keys`.
void RecvOutputsFromRendezvousAsync(
Rendezvous* rendezvous, DeviceContext* device_context,
RendezvousInterface* rendezvous, DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
StatusCallback done);
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous,
NamedTensors* out,
const Rendezvous::Args& args);
} // namespace tensorflow

View File

@ -687,7 +687,7 @@ class FunctionLibraryRuntime {
// tensors to the remote TensorHandles in the default device.
absl::optional<int64> op_id = absl::nullopt;
Rendezvous* rendezvous = nullptr;
RendezvousInterface* rendezvous = nullptr;
CancellationManager* cancellation_manager = nullptr;
CollectiveExecutor* collective_executor = nullptr;
ScopedStepContainer* step_container = nullptr;

View File

@ -672,7 +672,7 @@ class OpKernelContext {
// Mechanism used by this op kernel invocation to communicate with
// computations running on other devices.
Rendezvous* rendezvous = nullptr;
RendezvousInterface* rendezvous = nullptr;
const std::function<Status(const int64, const DeviceMgr*, Rendezvous** r)>*
create_rendezvous;
@ -1100,7 +1100,7 @@ class OpKernelContext {
//
// An op kernel communicates with outside environment through
// Rendezvous Send() and Recv().
Rendezvous* rendezvous() const { return params_->rendezvous; }
RendezvousInterface* rendezvous() const { return params_->rendezvous; }
Status create_rendezvous(const int64 step_id, const DeviceMgr* device_mgr,
Rendezvous** r) const {
return (*params_->create_rendezvous)(step_id, device_mgr, r);

View File

@ -113,10 +113,10 @@ Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) {
return errors::InvalidArgument("Invalid rendezvous key: ", key);
}
Rendezvous::~Rendezvous() {}
RendezvousInterface::~RendezvousInterface() {}
Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args,
Tensor* val, bool* is_dead, int64 timeout_ms) {
Status RendezvousInterface::Recv(const ParsedKey& key, const Args& recv_args,
Tensor* val, bool* is_dead, int64 timeout_ms) {
Status ret;
Notification n;
RecvAsync(key, recv_args,
@ -141,8 +141,8 @@ Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args,
return ret;
}
Status Rendezvous::Recv(const ParsedKey& key, const Args& args, Tensor* val,
bool* is_dead) {
Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args,
Tensor* val, bool* is_dead) {
const int64 no_timeout = 0;
return Recv(key, args, val, is_dead, no_timeout);
}

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
#define TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
#ifndef TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_
#define TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_
#include <string>
@ -44,7 +44,7 @@ namespace tensorflow {
// been produced. A consumer has the choice of making a blocking call
// or providing a callback: in either case, the consumer receives the
// Tensor as soon as it is available. A producer never blocks.
class Rendezvous : public core::RefCounted {
class RendezvousInterface {
public:
struct Args {
DeviceContext* device_context = nullptr;
@ -52,13 +52,6 @@ class Rendezvous : public core::RefCounted {
CancellationManager* cancellation_manager = nullptr; // not owned.
};
// Constructs a rendezvous key for the tensor of "name" sent from
// "src_device" to "dst_device". The tensor is generated in the frame
// and iteration specified by "frame_iter".
static string CreateKey(const string& src_device, uint64 src_incarnation,
const string& dst_device, const string& name,
const FrameAndIter& frame_iter);
// Parses the key constructed by CreateKey and parse src/dst device
// names into structures respectively.
struct ParsedKey {
@ -81,7 +74,6 @@ class Rendezvous : public core::RefCounted {
friend class RecvOp;
string buf_;
};
static Status ParseKey(StringPiece key, ParsedKey* out);
// The caller is a tensor producer and it sends a message (a tensor
// "val" and a bool "is_dead") under the given "key".
@ -123,12 +115,28 @@ class Rendezvous : public core::RefCounted {
virtual void StartAbort(const Status& status) = 0;
protected:
~Rendezvous() override;
virtual ~RendezvousInterface();
virtual bool is_cross_process() { return false; }
friend class ProcessFunctionLibraryRuntime;
};
// A reference-counted implementation of RendezvousInterface.
//
// This class is used in cases where a rendezvous may be shared between multiple
// threads with no clear owner.
class Rendezvous : public RendezvousInterface, public core::RefCounted {
public:
// Constructs a rendezvous key for the tensor of "name" sent from
// "src_device" to "dst_device". The tensor is generated in the frame
// and iteration specified by "frame_iter".
static string CreateKey(const string& src_device, uint64 src_incarnation,
const string& dst_device, const string& name,
const FrameAndIter& frame_iter);
static Status ParseKey(StringPiece key, ParsedKey* out);
};
// Returns a Rendezvous instance that is limited to use only by
// producers and consumers in the local process. The caller assumes
// ownership of one Ref() on the returned object.
@ -136,4 +144,4 @@ Rendezvous* NewLocalRendezvous();
} // end namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
#endif // TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_