Remove unused RPC op kernels.

PiperOrigin-RevId: 354639187
Change-Id: I302f2cfc9b386f53cf0641461b699f3730a63150

This commit is contained in:
parent 6574fc4e08
commit 222861851e
@@ -641,7 +641,6 @@ cc_library(
        "//tensorflow/core/kernels:required",
        "//tensorflow/core/kernels:resource_variable_ops",
        "//tensorflow/core/kernels:rnn_ops",
        "//tensorflow/core/kernels:rpc_op",
        "//tensorflow/core/kernels:scoped_allocator_ops",
        "//tensorflow/core/kernels:sdca_ops",
        "//tensorflow/core/kernels:searchsorted_op",
@@ -978,7 +977,6 @@ filegroup(
        "resource_variable_ops_op_lib",
        "risc_ops_op_lib",
        "rnn_ops_op_lib",
        "rpc_ops_op_lib",
        "scoped_allocator_ops_op_lib",
        "script_ops_op_lib",
        "sdca_ops_op_lib",
@@ -1,108 +0,0 @@
op {
  graph_op_name: "Rpc"
  in_arg {
    name: "address"
    description: <<END
`0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `method` and `request`.
END
  }
  in_arg {
    name: "method"
    description: <<END
`0-D` or `1-D`. The method address on the RPC server.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `address` and `request`.
END
  }
  in_arg {
    name: "request"
    description: <<END
`0-D` or `1-D`. Serialized proto strings: the rpc request argument.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `address` and `method`.
END
  }
  out_arg {
    name: "response"
    description: <<END
Same shape as `request`. Serialized proto strings: the rpc responses.
END
  }
  attr {
    name: "protocol"
    description: <<END
RPC protocol to use. Empty string means use the default protocol.
Options include 'grpc'.
END
  }
  attr {
    name: "fail_fast"
    description: <<END
`boolean`. If `true` (default), then failures to connect
(i.e., the server does not immediately respond) cause an RPC failure.
END
  }
  attr {
    name: "timeout_in_ms"
    description: <<END
`int`. If `0` (default), then the kernel will run the RPC
request and only time out if the RPC deadline passes or the session times out.
If this value is greater than `0`, then the op will raise an exception if
the RPC takes longer than `timeout_in_ms`.
END
  }
  summary: <<END
Perform batches of RPC requests.
END
  description: <<END
This op asynchronously performs either a single RPC request, or a batch
of requests. RPC requests are defined by three main parameters:

  - `address` (the host+port or BNS address of the request)
  - `method` (the RPC method name for the request)
  - `request` (the serialized proto string, or vector of strings,
    of the RPC request argument).

For example, if you have an RPC service running on port localhost:2345,
and its interface is configured with the following proto declaration:

```
service MyService {
  rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
  }
};
```

then call this op with arguments:

```
address = "localhost:2345"
method = "MyService/MyMethod"
```

The `request` tensor is a string tensor representing serialized `MyRequestProto`
strings; and the output string tensor `response` will have the same shape
and contain (upon successful completion) corresponding serialized
`MyResponseProto` strings.

For example, to send a single, empty, `MyRequestProto`, call
this op with `request = ""`. To send 5 **parallel** empty requests,
call this op with `request = ["", "", "", "", ""]`.

More generally, one can create a batch of `MyRequestProto` serialized protos
from regular batched tensors using the `encode_proto` op, and convert
the response `MyResponseProto` serialized protos to batched tensors
using the `decode_proto` op.

**NOTE** Working with serialized proto strings is faster than instantiating
actual proto objects in memory, so no performance degradation is expected
compared to writing custom kernels for this workflow.

If the connection fails or the remote worker returns an error
status, the op reraises this exception locally.

See the `TryRpc` op if you prefer to handle RPC failures manually in the graph.
END
}
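For context on what this change removes, here is a minimal, hypothetical sketch of how the `Rpc` op documented above was typically driven from Python before removal. The `tf.raw_ops.Rpc` endpoint, TF 1.x graph/session usage, and the server address and method name are assumptions for illustration, not part of this commit:

```python
# Hypothetical usage sketch of the removed Rpc op (illustrative only).
import tensorflow.compat.v1 as tf  # assumes a TF release that still shipped the op

tf.disable_eager_execution()

# Five parallel empty MyRequestProto requests, as in the documentation above.
request = tf.constant(["", "", "", "", ""])

response = tf.raw_ops.Rpc(
    address="localhost:2345",     # 0-D or 1-D; broadcasts with method/request
    method="MyService/MyMethod",
    request=request,
    protocol="grpc",              # kernel requires a non-empty protocol
    fail_fast=True,
    timeout_in_ms=0,              # 0: rely on RPC deadline / session timeout
)

with tf.Session() as sess:
    # Serialized MyResponseProto strings, same shape as `request`;
    # a failed call raises the RPC error in the session.
    print(sess.run(response))
```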
@@ -1,123 +0,0 @@
op {
  graph_op_name: "TryRpc"
  in_arg {
    name: "address"
    description: <<END
`0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `method` and `request`.
END
  }
  in_arg {
    name: "method"
    description: <<END
`0-D` or `1-D`. The method address on the RPC server.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `address` and `request`.
END
  }
  in_arg {
    name: "request"
    description: <<END
`0-D` or `1-D`. Serialized proto strings: the rpc request argument.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `address` and `method`.
END
  }
  out_arg {
    name: "response"
    description: <<END
Same shape as `request`. Serialized proto strings: the rpc responses.
END
  }
  out_arg {
    name: "status_code"
    description: <<END
Same shape as `request`. Values correspond to tensorflow Status enum codes.
END
  }
  out_arg {
    name: "status_message"
    description: <<END
Same shape as `request`. Values correspond to Status messages
returned from the RPC calls.
END
  }
  attr {
    name: "protocol"
    description: <<END
RPC protocol to use. Empty string means use the default protocol.
Options include 'grpc'.
END
  }
  attr {
    name: "fail_fast"
    description: <<END
`boolean`. If `true` (default), then failures to connect
(i.e., the server does not immediately respond) cause an RPC failure.
END
  }
  attr {
    name: "timeout_in_ms"
    description: <<END
`int`. If `0` (default), then the kernel will run the RPC
request and only time out if the RPC deadline passes or the session times out.
If this value is greater than `0`, then the op will raise an exception if
the RPC takes longer than `timeout_in_ms`.
END
  }
  summary: <<END
Perform batches of RPC requests.
END
  description: <<END
This op asynchronously performs either a single RPC request, or a batch
of requests. RPC requests are defined by three main parameters:

  - `address` (the host+port or BNS address of the request)
  - `method` (the method name for the request)
  - `request` (the serialized proto string, or vector of strings,
    of the RPC request argument).

For example, if you have an RPC service running on port localhost:2345,
and its interface is configured with the following proto declaration:

```
service MyService {
  rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
  }
};
```

then call this op with arguments:

```
address = "localhost:2345"
method = "MyService/MyMethod"
```

The `request` tensor is a string tensor representing serialized `MyRequestProto`
strings; and the output string tensor `response` will have the same shape
and contain (upon successful completion) corresponding serialized
`MyResponseProto` strings.

For example, to send a single, empty, `MyRequestProto`, call
this op with `request = ""`. To send 5 **parallel** empty requests,
call this op with `request = ["", "", "", "", ""]`.

More generally, one can create a batch of `MyRequestProto` serialized protos
from regular batched tensors using the `encode_proto` op, and convert
the response `MyResponseProto` serialized protos to batched tensors
using the `decode_proto` op.

**NOTE** Working with serialized proto strings is faster than instantiating
actual proto objects in memory, so no performance degradation is expected
compared to writing custom kernels for this workflow.

Unlike the standard `Rpc` op, if the connection fails or the remote worker
returns an error status, this op does **not** reraise the exception.
Instead, the `status_code` and `status_message` entry for the corresponding RPC
call is set with the error returned from the RPC call. The `response` tensor
will contain valid response values for those minibatch entries whose RPCs did
not fail; the rest of the entries will have empty strings.
END
}
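Similarly, a hedged sketch of the `TryRpc` variant described above, which surfaces per-element failures through `status_code`/`status_message` instead of raising. Same assumptions (and the `tf` import and `request` batch) as the previous sketch:

```python
# Hypothetical sketch: per-element error handling with the removed TryRpc op.
response, status_code, status_message = tf.raw_ops.TryRpc(
    address="localhost:2345",
    method="MyService/MyMethod",
    request=request,   # reuses the batch from the previous sketch
    protocol="grpc",
)
# status_code holds tensorflow Status enum codes (0 == OK); entries whose RPC
# failed carry an empty string in `response` and the error in status_message.
```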
@@ -1,3 +0,0 @@
op {
  graph_op_name: "Rpc"
}
@@ -1,3 +0,0 @@
op {
  graph_op_name: "TryRpc"
}
@@ -563,33 +563,3 @@ tf_cuda_cc_test(
        "//tensorflow/core/protobuf:master_proto_cc",
    ],
)

cc_library(
    name = "grpc_rpc_factory",
    srcs = [
        "grpc_rpc_factory.cc",
    ],
    hdrs = ["grpc_rpc_factory.h"],
    deps = [
        ":grpc_state",
        ":grpc_util",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core/util/rpc:call_container",
        "//tensorflow/core/util/rpc:rpc_factory",
    ],
)

cc_library(
    name = "grpc_rpc_factory_registration",
    srcs = [
        "grpc_rpc_factory_registration.cc",
    ],
    deps = [
        ":grpc_rpc_factory",
        "//tensorflow/core/util/rpc:rpc_factory",
        "//tensorflow/core/util/rpc:rpc_factory_registry",
    ],
    alwayslink = 1,
)
@@ -1,217 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/util/rpc/call_container.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"

#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"

namespace tensorflow {

namespace internal {
class GrpcCall {
 public:
  explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
                    const tstring* request_msg, tstring* response_msg,
                    int32* status_code, tstring* status_message)
      : container_(container),
        index_(index),
        try_rpc_(try_rpc),
        request_msg_(request_msg),
        response_msg_(response_msg),
        status_code_(status_code),
        status_message_(status_message) {}

  void StartCancel() { call_opts_.StartCancel(); }

  void Done(const Status& s) {
    DCHECK(container_ != nullptr);
    if (!s.ok() && try_rpc_) {
      DCHECK(status_code_ != nullptr);
      DCHECK(status_message_ != nullptr);
      *status_code_ = s.code();
      *status_message_ = s.error_message();
    }
    container_->Done(s, index_);
  }

  CallOptions* call_opts() { return &call_opts_; }
  int index() { return index_; }
  const tstring& request() const { return *request_msg_; }
  tstring* response() const { return response_msg_; }

 private:
  CallContainer<GrpcCall>* const container_;
  const int index_;
  bool try_rpc_;
  CallOptions call_opts_;
  const tstring* request_msg_;
  tstring* response_msg_;
  int* status_code_;
  tstring* status_message_;
};

}  // namespace internal

using internal::GrpcCall;

GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
                               int64 timeout_in_ms)
    : RPCFactory(), fail_fast_(fail_fast), timeout_in_ms_(timeout_in_ms) {
  // TODO(ebrevdo): Investigate possible performance improvements by
  // replacing this thread with a threadpool.
  polling_thread_ =
      ctx->env()->StartThread(ThreadOptions(), "rpc_op_grpc_factory", [this]() {
        void* tag;
        bool ok;
        while (completion_queue_.Next(&tag, &ok)) {
          GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
          callback_tag->OnCompleted(ok);
        }
      });
}

GrpcRPCFactory::~GrpcRPCFactory() {
  // The amount of time we wait depends on several parameters, including:
  //   - the value of the fail_fast attribute.
  //   - the timeout option of the rpc call in the proto declaration.
  //   - the network roundtrip time and service's execution time.
  //
  // If a connection is made but the service doesn't ever respond, and
  // there is no timeout option set for this rpc call, then it is
  // possible the RPC request will wait forever.
  //
  completion_queue_.Shutdown();
  delete polling_thread_;
}

void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
                          const Tensor& address_t, const Tensor& method_t,
                          const Tensor& request_t, const bool try_rpc,
                          Tensor* response_t, Tensor* status_code_t,
                          Tensor* status_message_t,
                          AsyncOpKernel::DoneCallback done) {
  if (try_rpc) {
    // In this case status_code will never be set in the response,
    // so we just set it to OK.
    DCHECK(status_code_t != nullptr);
    status_code_t->flat<int32>().setConstant(
        static_cast<int>(errors::Code::OK));
  }

  CallContainer<GrpcCall>::CreateCallFn create_call_fn =
      [this, &request_t, &try_rpc, response_t, status_code_t, status_message_t](
          CallContainer<GrpcCall>* container, int index) {
        CreateCall(request_t, try_rpc, index, container, response_t,
                   status_code_t, status_message_t);
      };

  CallContainer<GrpcCall>::StartCallFn start_call_fn =
      [this, &address_t, &method_t](GrpcCall* call) {
        StartCall(address_t, method_t, call);
      };

  // This object will delete itself when done.
  new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
                              std::move(done), std::move(create_call_fn),
                              std::move(start_call_fn));
}

::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress(
    const string& address) {
  mutex_lock lock(mu_);

  auto stub = stubs_.find(address);
  if (stub != stubs_.end()) return stub->second.get();

  ChannelPtr channel = CreateChannelForAddress(address);
  auto* created = new ::grpc::GenericStub(channel);
  stubs_[address].reset(created);
  return created;
}

GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress(
    const string& address) {
  ::grpc::ChannelArguments args;
  args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());

  // Set a standard backoff timeout of 1s instead of the
  // (sometimes default) 20s.
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
  return ::grpc::CreateCustomChannel(
      /*target=*/address, ::grpc::InsecureChannelCredentials(), args);
}

void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc,
                                int index, CallContainer<GrpcCall>* container,
                                Tensor* response_t, Tensor* status_code_t,
                                Tensor* status_message_t) {
  auto request = request_t.flat<tstring>();
  auto get_request_ptr = [&request](int64 ix) -> const tstring* {
    return (request.size() > 1) ? &(request(ix)) : &(request(0));
  };
  auto response = response_t->flat<tstring>();
  int32* status_code_ptr = nullptr;
  tstring* status_message_ptr = nullptr;
  if (try_rpc) {
    status_code_ptr = status_code_t->flat<int32>().data();
    status_message_ptr = status_message_t->flat<tstring>().data();
  }
  container->RegisterCall(container, index, try_rpc, get_request_ptr(index),
                          &response(index),
                          (try_rpc) ? &status_code_ptr[index] : nullptr,
                          (try_rpc) ? &status_message_ptr[index] : nullptr);
}

void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
                               GrpcCall* call) {
  auto address = address_t.flat<tstring>();
  auto method = method_t.flat<tstring>();
  // Stubs are maintained by the GrpcRPCFactory class and will be
  // deleted when the class is destroyed.
  ::grpc::GenericStub* singleton_stub = nullptr;
  if (address.size() == 1) {
    singleton_stub = GetOrCreateStubForAddress(address(0));
  }
  auto get_stub = [&address, this,
                   singleton_stub](int64 ix) -> ::grpc::GenericStub* {
    return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
                                : singleton_stub;
  };
  auto get_method_ptr = [&method](int64 ix) -> const tstring* {
    return (method.size() > 1) ? &(method(ix)) : &(method(0));
  };

  int index = call->index();
  // This object will delete itself when done.
  new RPCState<tstring>(
      get_stub(index), &completion_queue_, *get_method_ptr(index),
      call->request(), call->response(),
      /*done=*/[call](const Status& s) { call->Done(s); }, call->call_opts(),
      /*threadpool=*/nullptr, fail_fast_, timeout_in_ms_, /*max_retries=*/0,
      /*target=*/nullptr);
}

}  // namespace tensorflow
@@ -1,77 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_

#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/util/rpc/call_container.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"

namespace tensorflow {

// Forward declaration of GrpcCall.
namespace internal {
class GrpcCall;
}  // namespace internal

class GrpcRPCFactory : public RPCFactory {
 public:
  explicit GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
                          int64 timeout_in_ms);

  // Explicit destructor to control destruction order.
  ~GrpcRPCFactory() override;

  void Call(OpKernelContext* ctx, int64 num_elements, const Tensor& address_t,
            const Tensor& method_t, const Tensor& request_t, const bool try_rpc,
            Tensor* response_t, Tensor* status_code_t, Tensor* status_message_t,
            AsyncOpKernel::DoneCallback done) override;

 protected:
  typedef std::shared_ptr<::grpc::Channel> ChannelPtr;
  virtual ChannelPtr CreateChannelForAddress(const string& address);

 private:
  // Creates a call and registers it with given `container`. The `index` is used
  // to index into the tensor arguments.
  void CreateCall(const Tensor& request_t, const bool try_rpc, int index,
                  CallContainer<internal::GrpcCall>* container,
                  Tensor* response_t, Tensor* status_code_t,
                  Tensor* status_message_t);

  // Asynchronously invokes the given `call`. The call completion is handled
  // by the call container the call was previously registered with.
  void StartCall(const Tensor& address_t, const Tensor& method_t,
                 internal::GrpcCall* call);

  ::grpc::GenericStub* GetOrCreateStubForAddress(const string& address);

  bool fail_fast_;
  int64 timeout_in_ms_;
  ::grpc::CompletionQueue completion_queue_;
  Thread* polling_thread_;  // Owned.

  mutex mu_;
  typedef std::unique_ptr<::grpc::GenericStub> StubPtr;
  std::unordered_map<string, StubPtr> stubs_ TF_GUARDED_BY(mu_);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
@@ -1,34 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
#include "tensorflow/core/util/rpc/rpc_factory_registry.h"

namespace tensorflow {
namespace {

// Used for adding the grpc factory to the RPC factory registry.
struct Value {
  static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast,
                              int64 timeout_in_ms) {
    return new GrpcRPCFactory(ctx, fail_fast, timeout_in_ms);
  }
};

REGISTER_RPC_FACTORY("grpc", Value::Function);

}  // namespace
}  // namespace tensorflow
@@ -6387,7 +6387,6 @@ filegroup(
            "spectrogram_convert_test_data.cc",
            "decode_proto_op.cc",
            "encode_proto_op.cc",
            "rpc_op.cc",
            "sobol_op.cc",
            # Excluded due to experimental status:
            "debug_ops.*",
@@ -7304,22 +7303,6 @@ tf_kernel_library(
    ],
)

tf_kernel_library(
    name = "rpc_op",
    srcs = [
        "rpc_op.cc",
    ],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core/util/rpc:call_container",
        "//tensorflow/core/util/rpc:rpc_factory",
        "//tensorflow/core/util/rpc:rpc_factory_registry",
        "//third_party/eigen3",
    ],
)

tf_kernel_library(
    name = "unicode_script_op",
    srcs = ["unicode_script_op.cc"],
@@ -1,128 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// RpcOp is a TensorFlow op that sends and receives arbitrary messages.
//
// See docs in ../ops/rpc_op.cc.

#include <memory>
#include <string>
#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/rpc/call_container.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
#include "tensorflow/core/util/rpc/rpc_factory_registry.h"

namespace tensorflow {

class RpcOp : public AsyncOpKernel {
 public:
  explicit RpcOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("protocol", &protocol_));
    OP_REQUIRES(context, !protocol_.empty(),
                errors::InvalidArgument("protocol must be non-empty."));
    bool fail_fast;
    OP_REQUIRES_OK(context, context->GetAttr("fail_fast", &fail_fast));
    int64 timeout_in_ms;
    OP_REQUIRES_OK(context, context->GetAttr("timeout_in_ms", &timeout_in_ms));

    RPCFactoryRegistry::RPCFactoryFn* rpc_factory_fn =
        RPCFactoryRegistry::Global()->Get(protocol_);
    OP_REQUIRES(context, rpc_factory_fn != nullptr,
                errors::InvalidArgument("The protocol ", protocol_,
                                        " was not recognized."));

    rpc_factory_.reset((*rpc_factory_fn)(context, fail_fast, timeout_in_ms));
  }

  ~RpcOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    const Tensor& address_t = ctx->input(0);
    const Tensor& method_t = ctx->input(1);
    const Tensor& request_t = ctx->input(2);

    OP_REQUIRES_ASYNC(
        ctx, address_t.dims() == 0 || address_t.dims() == 1,
        errors::InvalidArgument("address must be a scalar or vector."), done);
    OP_REQUIRES_ASYNC(
        ctx, method_t.dims() == 0 || method_t.dims() == 1,
        errors::InvalidArgument("method must be a scalar or vector."), done);
    OP_REQUIRES_ASYNC(
        ctx, request_t.dims() == 0 || request_t.dims() == 1,
        errors::InvalidArgument("request must be a scalar or vector."), done);

    TensorShape output_shape({});
    for (const Tensor& t : {address_t, method_t, request_t}) {
      if (t.dims() == 1) {
        OP_REQUIRES_ASYNC(
            ctx,
            output_shape.dims() == 0 ||
                output_shape.dim_size(0) == t.dim_size(0),
            errors::InvalidArgument(
                "Input vector shapes don't match: ", output_shape.DebugString(),
                " vs. ", t.shape().DebugString()),
            done);
        output_shape = t.shape();
      }
    }

    Tensor* response_t;
    OP_REQUIRES_OK_ASYNC(
        ctx, ctx->allocate_output(0, output_shape, &response_t), done);

    const bool try_rpc = (ctx->num_outputs() > 1);

    Tensor* status_code_t = nullptr;
    Tensor* status_message_t = nullptr;
    if (try_rpc) {
      OP_REQUIRES_OK_ASYNC(
          ctx, ctx->allocate_output(1, output_shape, &status_code_t), done);
      OP_REQUIRES_OK_ASYNC(
          ctx, ctx->allocate_output(2, output_shape, &status_message_t), done);
    }

    if (request_t.NumElements() == 0) {
      // Special case, we finished early!
      done();
      return;
    }

    int64 num_elements = output_shape.num_elements();

    rpc_factory_->Call(ctx, num_elements, address_t, method_t, request_t,
                       try_rpc, response_t, status_code_t, status_message_t,
                       std::move(done));
  }

 private:
  string protocol_;
  std::unique_ptr<RPCFactory> rpc_factory_;

  TF_DISALLOW_COPY_AND_ASSIGN(RpcOp);
};

REGISTER_KERNEL_BUILDER(Name("Rpc").Device(DEVICE_CPU), RpcOp);
REGISTER_KERNEL_BUILDER(Name("TryRpc").Device(DEVICE_CPU), RpcOp);

}  // namespace tensorflow
@@ -85,7 +85,6 @@ tf_gen_op_libs(
        "remote_fused_graph_ops",
        "risc_ops",
        "rnn_ops",
        "rpc_ops",
        "scoped_allocator_ops",
        "sdca_ops",
        "set_ops",
@@ -286,7 +285,6 @@ cc_library(
        ":stateful_random_ops_op_lib",
        ":remote_fused_graph_ops_op_lib",
        ":resource_variable_ops_op_lib",
        ":rpc_ops_op_lib",
        ":scoped_allocator_ops_op_lib",
        ":script_ops_op_lib",
        ":sdca_ops_op_lib",
@@ -1,41 +0,0 @@
op {
  name: "Rpc"
  input_arg {
    name: "address"
    type: DT_STRING
  }
  input_arg {
    name: "method"
    type: DT_STRING
  }
  input_arg {
    name: "request"
    type: DT_STRING
  }
  output_arg {
    name: "response"
    type: DT_STRING
  }
  attr {
    name: "protocol"
    type: "string"
    default_value {
      s: ""
    }
  }
  attr {
    name: "fail_fast"
    type: "bool"
    default_value {
      b: true
    }
  }
  attr {
    name: "timeout_in_ms"
    type: "int"
    default_value {
      i: 0
    }
  }
  is_stateful: true
}
@@ -1,49 +0,0 @@
op {
  name: "TryRpc"
  input_arg {
    name: "address"
    type: DT_STRING
  }
  input_arg {
    name: "method"
    type: DT_STRING
  }
  input_arg {
    name: "request"
    type: DT_STRING
  }
  output_arg {
    name: "response"
    type: DT_STRING
  }
  output_arg {
    name: "status_code"
    type: DT_INT32
  }
  output_arg {
    name: "status_message"
    type: DT_STRING
  }
  attr {
    name: "protocol"
    type: "string"
    default_value {
      s: ""
    }
  }
  attr {
    name: "fail_fast"
    type: "bool"
    default_value {
      b: true
    }
  }
  attr {
    name: "timeout_in_ms"
    type: "int"
    default_value {
      i: 0
    }
  }
  is_stateful: true
}
@@ -1,41 +0,0 @@
op {
  name: "Rpc"
  input_arg {
    name: "address"
    type: DT_STRING
  }
  input_arg {
    name: "method"
    type: DT_STRING
  }
  input_arg {
    name: "request"
    type: DT_STRING
  }
  output_arg {
    name: "response"
    type: DT_STRING
  }
  attr {
    name: "protocol"
    type: "string"
    default_value {
      s: ""
    }
  }
  attr {
    name: "fail_fast"
    type: "bool"
    default_value {
      b: true
    }
  }
  attr {
    name: "timeout_in_ms"
    type: "int"
    default_value {
      i: 0
    }
  }
  is_stateful: true
}
@@ -1,49 +0,0 @@
op {
  name: "TryRpc"
  input_arg {
    name: "address"
    type: DT_STRING
  }
  input_arg {
    name: "method"
    type: DT_STRING
  }
  input_arg {
    name: "request"
    type: DT_STRING
  }
  output_arg {
    name: "response"
    type: DT_STRING
  }
  output_arg {
    name: "status_code"
    type: DT_INT32
  }
  output_arg {
    name: "status_message"
    type: DT_STRING
  }
  attr {
    name: "protocol"
    type: "string"
    default_value {
      s: ""
    }
  }
  attr {
    name: "fail_fast"
    type: "bool"
    default_value {
      b: true
    }
  }
  attr {
    name: "timeout_in_ms"
    type: "int"
    default_value {
      i: 0
    }
  }
  is_stateful: true
}
@@ -1,80 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;

Status RpcShapeOp(InferenceContext* c, bool try_rpc) {
  ShapeHandle address;
  ShapeHandle method;
  ShapeHandle request;
  ShapeHandle output;
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &address));
  if (c->Rank(address) == 1) {
    TF_RETURN_IF_ERROR(c->Merge(output, address, &output));
  }
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &method));
  if (c->Rank(method) == 1) {
    TF_RETURN_IF_ERROR(c->Merge(output, method, &output));
  }
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &request));
  if (c->Rank(request) == 1) {
    TF_RETURN_IF_ERROR(c->Merge(output, request, &output));
  }
  if (!c->RankKnown(output)) {
    output = request;
  }
  c->set_output(0, output);  // response
  if (try_rpc) {
    c->set_output(1, output);  // status_code
    c->set_output(2, output);  // status_message
  }
  return Status::OK();
}

REGISTER_OP("Rpc")
    .Input("address: string")
    .Input("method: string")
    .Input("request: string")
    .Attr("protocol: string = ''")
    .Attr("fail_fast: bool = true")
    .Attr("timeout_in_ms: int = 0")
    .Output("response: string")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      return RpcShapeOp(c, /*try_rpc=*/false);
    });

REGISTER_OP("TryRpc")
    .Input("address: string")
    .Input("method: string")
    .Input("request: string")
    .Attr("protocol: string = ''")
    .Attr("fail_fast: bool = true")
    .Attr("timeout_in_ms: int = 0")
    .Output("response: string")
    .Output("status_code: int32")
    .Output("status_message: string")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      return RpcShapeOp(c, /*try_rpc=*/true);
    });

}  // namespace tensorflow
@@ -1,48 +0,0 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")

package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "call_container",
    hdrs = ["call_container.h"],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
)

cc_library(
    name = "rpc_factory",
    srcs = ["rpc_factory.cc"],
    hdrs = ["rpc_factory.h"],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "rpc_factory_registry",
    srcs = ["rpc_factory_registry.cc"],
    hdrs = ["rpc_factory_registry.h"],
    deps = [
        ":rpc_factory",
        "//tensorflow/core:framework",
    ],
)

tf_cc_test(
    name = "rpc_factory_registry_test",
    srcs = ["rpc_factory_registry_test.cc"],
    deps = [
        ":rpc_factory_registry",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)
@@ -1,182 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
#define TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_

#include <list>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/util/reffed_status_callback.h"

namespace tensorflow {

namespace internal {
// The following class is used for coordination between a `CallContainer`
// instance and a cancellation callback to make sure that the `CallContainer`
// instance waits for the cancellation callback to be destroyed (either because
// a cancellation occurred or because the callback was deregistered) before
// deleting itself. Without this coordination the cancellation callback could
// attempt to access a `CallContainer` instance that is no longer valid.
class NotifyWhenDestroyed {
 public:
  explicit NotifyWhenDestroyed(std::shared_ptr<Notification> notification)
      : notification_(std::move(notification)) {}

  ~NotifyWhenDestroyed() { notification_->Notify(); }

 private:
  std::shared_ptr<Notification> notification_;
};
}  // namespace internal

// The following class is responsible for the life cycle management of a set of
// RPC calls. The calls are started when an instance of the class is created and
// the class contract guarantees to invoke a "done" callback provided by the
// caller when all RPC calls have either completed or been cancelled.
//
// The caller should not make any assumptions about the validity of an instance
// of this class after the provided callback has been invoked, which may be
// immediately after the instance was created.
template <class Call>
class CallContainer {
 public:
  typedef std::function<void(CallContainer<Call>*, int)> CreateCallFn;
  typedef std::function<void(Call*)> StartCallFn;

  // Uses the provided `create_call_fn` and `start_call_fn` functions to create
  // and start a set of RPC calls. When all RPC calls have either completed or
  // been cancelled, the `done` callback is invoked. The caller should not make
  // any assumptions about the validity of the created instance as the instance
  // will delete itself after invoking the `done` callback.
  explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast,
                         bool try_rpc, AsyncOpKernel::DoneCallback done,
                         CreateCallFn create_call_fn,
                         StartCallFn start_call_fn);

  // Registers a call with this container. This method expects its arguments to
  // match those of a `Call` constructor as it forwards them to an underlying
  // collection, which creates a `Call` instance in place.
  template <class... Args>
  void RegisterCall(Args&&... args);

  // Starts the cancellation of all RPC calls managed by this container.
  void StartCancel();

  // Indicates that the `index`-th RPC call has finished.
  void Done(const Status& s, int index);

 private:
  OpKernelContext* ctx_;
  std::list<Call> calls_;
  const AsyncOpKernel::DoneCallback done_;
  const CancellationToken token_;
  const bool fail_fast_;
  const bool try_rpc_;
  std::shared_ptr<Notification> callback_destroyed_;

  // Performs its own reference counting.
  ReffedStatusCallback* reffed_status_callback_;
};

template <class Call>
CallContainer<Call>::CallContainer(
    OpKernelContext* ctx, int num_calls, bool fail_fast, bool try_rpc,
    AsyncOpKernel::DoneCallback done,
    typename CallContainer<Call>::CreateCallFn create_call_fn,
    typename CallContainer<Call>::StartCallFn start_call_fn)
    : ctx_(ctx),
      done_(std::move(done)),
      token_(ctx->cancellation_manager() != nullptr
                 ? ctx->cancellation_manager()->get_cancellation_token()
                 : CancellationManager::kInvalidToken),
      fail_fast_(fail_fast),
      try_rpc_(try_rpc),
      callback_destroyed_(new Notification) {
  CHECK_GT(num_calls, 0);

  // This will run when all RPCs are finished.
  reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
    if (token_ != CancellationManager::kInvalidToken) {
      ctx_->cancellation_manager()->DeregisterCallback(token_);
    }
    ctx_->SetStatus(s);
    done_();
    callback_destroyed_->WaitForNotification();
    delete this;
  });

  // The cancellation callback needs to be registered before the RPC calls are
  // started to make sure that the callback is properly cleaned up by the
  // `reffed_status_callback` when all calls complete. At the same time, the
  // cancellation callback should wait for the RPC calls to be started for the
  // cancellation to take effect.
  std::shared_ptr<internal::NotifyWhenDestroyed> notify_when_destroyed(
      new internal::NotifyWhenDestroyed(callback_destroyed_));
  std::shared_ptr<Notification> calls_started(new Notification);
  bool is_cancelled = false;
  if (token_ != CancellationManager::kInvalidToken) {
    is_cancelled = !ctx_->cancellation_manager()->RegisterCallback(
        token_, [this, calls_started, notify_when_destroyed]() {
          calls_started->WaitForNotification();
          StartCancel();
        });
  }

  for (int i = 0; i < num_calls; ++i) {
    create_call_fn(this, i);
    // Increase the reference on the callback for each new RPC.
    reffed_status_callback_->Ref();
  }
  for (Call& call : calls_) {
    start_call_fn(&call);
  }
  calls_started->Notify();

  if (is_cancelled) {
    ctx_->SetStatus(errors::Cancelled("Operation has been cancelled."));
    StartCancel();
  }

  // Subtract reference count from the initial creation.
  reffed_status_callback_->Unref();
}

template <class Call>
template <class... Args>
void CallContainer<Call>::RegisterCall(Args&&... args) {
  calls_.emplace_back(std::forward<Args>(args)...);
}

template <class Call>
void CallContainer<Call>::StartCancel() {
  for (auto& call : calls_) {
    call.StartCancel();
  }
}

template <class Call>
void CallContainer<Call>::Done(const Status& s, int index) {
  if (!try_rpc_) {
    reffed_status_callback_->UpdateStatus(s);
  }
  reffed_status_callback_->Unref();
}

}  // namespace tensorflow
#endif  // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
@@ -1,53 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/lib/strings/numbers.h"

#include "tensorflow/core/util/rpc/rpc_factory.h"

namespace tensorflow {

template <>
bool GetEnvVar(const char* key, const string& default_value, string* value) {
  const char* env_value = std::getenv(key);
  if (!env_value || env_value[0] == '\0') {
    *value = default_value;
  } else {
    *value = env_value;
  }
  return true;
}

template <>
bool GetEnvVar(const char* key, const int64& default_value, int64* value) {
  const char* env_value = std::getenv(key);
  if (!env_value || env_value[0] == '\0') {
    *value = default_value;
    return true;
  }
  return strings::safe_strto64(env_value, value);
}

template <>
bool GetEnvVar(const char* key, const uint64& default_value, uint64* value) {
  const char* env_value = std::getenv(key);
  if (!env_value || env_value[0] == '\0') {
    *value = default_value;
    return true;
  }
  return strings::safe_strtou64(env_value, value);
}

}  // namespace tensorflow
@@ -1,71 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {

// Return the environment variable `key`. If the variable is not set,
// use the default value. If it is set but could not be parsed,
// return `false`. Otherwise set `value` and return `true`.
template <typename T>
bool GetEnvVar(const char* key, const T& default_value, T* value);

class RPCFactory {
 public:
  RPCFactory() {}
  virtual ~RPCFactory() {}

  // Asynchronously invokes methods `method_t` at addresses `address_t` with
  // request strings from `request_t`. Any of these may be scalar
  // Tensors, in which case the operands are broadcasted.
  // Upon completion of all requests, `response_t` will be populated and the
  // `done` callback will be invoked.
  //
  // If `try_rpc` is `true`, then `status_message_t` and
  // `status_code_t` will be populated as well.
  //
  // If `try_rpc` is `false`, then `status_message_t` and
  // `status_code_t` are ignored (and may be nullptr). Instead, the
  // status of any failed call will be propagated to the op.
  //
  // REQUIRES:
  //   - `response_t` is not null, and is a string Tensor with the same shape as
  //     `request_t`.
  //
  // If `try_rpc` is `true`:
  //   - `status_code_t` and `status_message_t` are not null.
  //   - `status_code_t` is an int32 Tensor with the same shape as
  //     `request_t`.
  //   - `status_message_t` is a string Tensor with the same shape as
  //     `request_t`.
  virtual void Call(OpKernelContext* ctx, int64 num_elements,
                    const Tensor& address_t, const Tensor& method_t,
                    const Tensor& request_t, const bool try_rpc,
                    Tensor* response_t, Tensor* status_code_t,
                    Tensor* status_message_t,
                    AsyncOpKernel::DoneCallback done) = 0;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RPCFactory);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
@@ -1,44 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "tensorflow/core/util/rpc/rpc_factory.h"

#include "tensorflow/core/util/rpc/rpc_factory_registry.h"

namespace tensorflow {

RPCFactoryRegistry* RPCFactoryRegistry::Global() {
  static RPCFactoryRegistry* registry = new RPCFactoryRegistry;
  return registry;
}

RPCFactoryRegistry::RPCFactoryFn* RPCFactoryRegistry::Get(
    const string& protocol) {
  auto found = fns_.find(protocol);
  if (found == fns_.end()) return nullptr;
  return &found->second;
}

void RPCFactoryRegistry::Register(const string& protocol,
                                  const RPCFactoryFn& factory_fn) {
  auto existing = Get(protocol);
  CHECK_EQ(existing, nullptr)
      << "RPC factory for protocol: " << protocol << " already registered";
  fns_.insert(std::pair<const string&, RPCFactoryFn>(protocol, factory_fn));
}

}  // namespace tensorflow
@@ -1,72 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_

#include <map>
#include <string>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"

namespace tensorflow {

class RPCFactoryRegistry {
 public:
  typedef std::function<RPCFactory*(OpKernelConstruction* ctx, bool fail_fast,
                                    int64 timeout_in_ms)>
      RPCFactoryFn;

  // Returns a pointer to a global RPCFactoryRegistry object.
  static RPCFactoryRegistry* Global();

  // Returns a pointer to a function that creates an RPC factory for the given
  // protocol.
  RPCFactoryFn* Get(const string& protocol);

  // Registers a function that creates an RPC factory for the given protocol.
  // The function should transfer the ownership of the factory to its caller.
  void Register(const string& protocol, const RPCFactoryFn& factory_fn);

 private:
  std::map<string, RPCFactoryFn> fns_;
};

namespace rpc_factory_registration {

class RPCFactoryRegistration {
 public:
  RPCFactoryRegistration(const string& protocol,
                         const RPCFactoryRegistry::RPCFactoryFn& factory_fn) {
    RPCFactoryRegistry::Global()->Register(protocol, factory_fn);
  }
};

}  // namespace rpc_factory_registration

#define REGISTER_RPC_FACTORY(protocol, factory_fn) \
  REGISTER_RPC_FACTORY_UNIQ_HELPER(__COUNTER__, protocol, factory_fn)

#define REGISTER_RPC_FACTORY_UNIQ_HELPER(ctr, protocol, factory_fn) \
  REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn)

#define REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn) \
  static rpc_factory_registration::RPCFactoryRegistration    \
      rpc_factory_registration_fn_##ctr(protocol, factory_fn)

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
@@ -1,41 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

struct Value {
  static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast,
                              int64 timeout_in_ms) {
    return nullptr;
  }
};

REGISTER_RPC_FACTORY("TEST FACTORY 1", Value::Function);
REGISTER_RPC_FACTORY("TEST FACTORY 2", Value::Function);
}  // namespace

TEST(RPCFactoryRegistryTest, TestBasic) {
  EXPECT_EQ(RPCFactoryRegistry::Global()->Get("NON-EXISTENT"), nullptr);
  auto factory1 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 1");
  EXPECT_NE(factory1, nullptr);
  auto factory2 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 2");
  EXPECT_NE(factory2, nullptr);
}

}  // namespace tensorflow
@@ -5206,7 +5206,6 @@ pywrap_tensorflow_macro(
        ":bfloat16_lib",
        ":cost_analyzer_lib",
        ":model_analyzer_lib",
        "//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
        "//tensorflow/core/distributed_runtime/rpc:grpc_session",
        "//tensorflow/python/util:cpp_python_util",