Remove unused RPC op kernels.

PiperOrigin-RevId: 354639187
Change-Id: I302f2cfc9b386f53cf0641461b699f3730a63150
This commit is contained in:
Jiri Simsa 2021-01-29 17:04:11 -08:00 committed by TensorFlower Gardener
parent 6574fc4e08
commit 222861851e
25 changed files with 0 additions and 1516 deletions

View File

@ -641,7 +641,6 @@ cc_library(
"//tensorflow/core/kernels:required",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:rnn_ops",
"//tensorflow/core/kernels:rpc_op",
"//tensorflow/core/kernels:scoped_allocator_ops",
"//tensorflow/core/kernels:sdca_ops",
"//tensorflow/core/kernels:searchsorted_op",
@ -978,7 +977,6 @@ filegroup(
"resource_variable_ops_op_lib",
"risc_ops_op_lib",
"rnn_ops_op_lib",
"rpc_ops_op_lib",
"scoped_allocator_ops_op_lib",
"script_ops_op_lib",
"sdca_ops_op_lib",

View File

@ -1,108 +0,0 @@
op {
graph_op_name: "Rpc"
in_arg {
name: "address"
description: <<END
`0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `method` and `request`.
END
}
in_arg {
name: "method"
description: <<END
`0-D` or `1-D`. The method address on the RPC server.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `address` and `request`.
END
}
in_arg {
name: "request"
description: <<END
`0-D` or `1-D`. Serialized proto strings: the rpc request argument.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `address` and `method`.
END
}
out_arg {
name: "response"
description: <<END
Same shape as `request`. Serialized proto strings: the rpc responses.
END
}
attr {
name: "protocol"
description: <<END
RPC protocol to use. Empty string means use the default protocol.
Options include 'grpc'.
END
}
attr {
name: "fail_fast"
description: <<END
`boolean`. If `true` (default), then failures to connect
(i.e., the server does not immediately respond) cause an RPC failure.
END
}
attr {
name: "timeout_in_ms"
description: <<END
`int`. If `0` (default), then the kernel will run the RPC
request and only time out if the RPC deadline passes or the session times out.
If this value is greater than `0`, then the op will raise an exception if
the RPC takes longer than `timeout_in_ms`.
END
}
summary: <<END
Perform batches of RPC requests.
END
description: <<END
This op asynchronously performs either a single RPC request, or a batch
of requests. RPC requests are defined by three main parameters:
- `address` (the host+port or BNS address of the request)
- `method` (the RPC method name for the request)
- `request` (the serialized proto string, or vector of strings,
of the RPC request argument).
For example, if you have an RPC service running on port localhost:2345,
and its interface is configured with the following proto declaration:
```
service MyService {
rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
}
};
```
then call this op with arguments:
```
address = "localhost:2345"
method = "MyService/MyMethod"
```
The `request` tensor is a string tensor representing serialized `MyRequestProto`
strings; and the output string tensor `response` will have the same shape
and contain (upon successful completion) corresponding serialized
`MyResponseProto` strings.
For example, to send a single, empty, `MyRequestProto`, call
this op with `request = ""`. To send 5 **parallel** empty requests,
call this op with `request = ["", "", "", "", ""]`.
More generally, one can create a batch of `MyRequestProto` serialized protos
from regular batched tensors using the `encode_proto` op, and convert
the response `MyResponseProto` serialized protos to batched tensors
using the `decode_proto` op.
**NOTE** Working with serialized proto strings is faster than instantiating
actual proto objects in memory, so no performance degradation is expected
compared to writing custom kernels for this workflow.
If the connection fails or the remote worker returns an error
status, the op reraises this exception locally.
See the `TryRpc` op if you prefer to handle RPC failures manually in the graph.
END
}

View File

@ -1,123 +0,0 @@
op {
graph_op_name: "TryRpc"
in_arg {
name: "address"
description: <<END
`0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `method` and `request`.
END
}
in_arg {
name: "method"
description: <<END
`0-D` or `1-D`. The method address on the RPC server.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `address` and `request`.
END
}
in_arg {
name: "request"
description: <<END
`0-D` or `1-D`. Serialized proto strings: the rpc request argument.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `address` and `method`.
END
}
out_arg {
name: "response"
description: <<END
Same shape as `request`. Serialized proto strings: the rpc responses.
END
}
out_arg {
name: "status_code"
description: <<END
Same shape as `request`. Values correspond to tensorflow Status enum codes.
END
}
out_arg {
name: "status_message"
description: <<END
Same shape as `request`. Values correspond to Status messages
returned from the RPC calls.
END
}
attr {
name: "protocol"
description: <<END
RPC protocol to use. Empty string means use the default protocol.
Options include 'grpc'.
END
}
attr {
name: "fail_fast"
description: <<END
`boolean`. If `true` (default), then failures to connect
(i.e., the server does not immediately respond) cause an RPC failure.
END
}
attr {
name: "timeout_in_ms"
description: <<END
`int`. If `0` (default), then the kernel will run the RPC
request and only time out if the RPC deadline passes or the session times out.
If this value is greater than `0`, then the op will raise an exception if
the RPC takes longer than `timeout_in_ms`.
END
}
summary: <<END
Perform batches of RPC requests.
END
description: <<END
This op asynchronously performs either a single RPC request, or a batch
of requests. RPC requests are defined by three main parameters:
- `address` (the host+port or BNS address of the request)
- `method` (the method name for the request)
- `request` (the serialized proto string, or vector of strings,
of the RPC request argument).
For example, if you have an RPC service running on port localhost:2345,
and its interface is configured with the following proto declaration:
```
service MyService {
rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
}
};
```
then call this op with arguments:
```
address = "localhost:2345"
method = "MyService/MyMethod"
```
The `request` tensor is a string tensor representing serialized `MyRequestProto`
strings; and the output string tensor `response` will have the same shape
and contain (upon successful completion) corresponding serialized
`MyResponseProto` strings.
For example, to send a single, empty, `MyRequestProto`, call
this op with `request = ""`. To send 5 **parallel** empty requests,
call this op with `request = ["", "", "", "", ""]`.
More generally, one can create a batch of `MyRequestProto` serialized protos
from regular batched tensors using the `encode_proto` op, and convert
the response `MyResponseProto` serialized protos to batched tensors
using the `decode_proto` op.
**NOTE** Working with serialized proto strings is faster than instantiating
actual proto objects in memory, so no performance degradation is expected
compared to writing custom kernels for this workflow.
Unlike the standard `Rpc` op, if the connection fails or the remote worker
returns an error status, this op does **not** reraise the exception.
Instead, the `status_code` and `status_message` entry for the corresponding RPC
call is set with the error returned from the RPC call. The `response` tensor
will contain valid response values for those minibatch entries whose RPCs did
not fail; the rest of the entries will have empty strings.
END
}

View File

@ -1,3 +0,0 @@
op {
graph_op_name: "Rpc"
}

View File

@ -1,3 +0,0 @@
op {
graph_op_name: "TryRpc"
}

View File

@ -563,33 +563,3 @@ tf_cuda_cc_test(
"//tensorflow/core/protobuf:master_proto_cc",
],
)
cc_library(
name = "grpc_rpc_factory",
srcs = [
"grpc_rpc_factory.cc",
],
hdrs = ["grpc_rpc_factory.h"],
deps = [
":grpc_state",
":grpc_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/util/rpc:call_container",
"//tensorflow/core/util/rpc:rpc_factory",
],
)
cc_library(
name = "grpc_rpc_factory_registration",
srcs = [
"grpc_rpc_factory_registration.cc",
],
deps = [
":grpc_rpc_factory",
"//tensorflow/core/util/rpc:rpc_factory",
"//tensorflow/core/util/rpc:rpc_factory_registry",
],
alwayslink = 1,
)

View File

@ -1,217 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/util/rpc/call_container.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"
namespace tensorflow {
namespace internal {
class GrpcCall {
public:
explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
const tstring* request_msg, tstring* response_msg,
int32* status_code, tstring* status_message)
: container_(container),
index_(index),
try_rpc_(try_rpc),
request_msg_(request_msg),
response_msg_(response_msg),
status_code_(status_code),
status_message_(status_message) {}
void StartCancel() { call_opts_.StartCancel(); }
void Done(const Status& s) {
DCHECK(container_ != nullptr);
if (!s.ok() && try_rpc_) {
DCHECK(status_code_ != nullptr);
DCHECK(status_message_ != nullptr);
*status_code_ = s.code();
*status_message_ = s.error_message();
}
container_->Done(s, index_);
}
CallOptions* call_opts() { return &call_opts_; }
int index() { return index_; }
const tstring& request() const { return *request_msg_; }
tstring* response() const { return response_msg_; }
private:
CallContainer<GrpcCall>* const container_;
const int index_;
bool try_rpc_;
CallOptions call_opts_;
const tstring* request_msg_;
tstring* response_msg_;
int* status_code_;
tstring* status_message_;
};
} // namespace internal
using internal::GrpcCall;
GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
int64 timeout_in_ms)
: RPCFactory(), fail_fast_(fail_fast), timeout_in_ms_(timeout_in_ms) {
// TODO(ebrevdo): Investigate possible performance improvements by
// replacing this thread with a threadpool.
polling_thread_ =
ctx->env()->StartThread(ThreadOptions(), "rpc_op_grpc_factory", [this]() {
void* tag;
bool ok;
while (completion_queue_.Next(&tag, &ok)) {
GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
callback_tag->OnCompleted(ok);
}
});
}
GrpcRPCFactory::~GrpcRPCFactory() {
// The amount of time we wait depends on several parameters, including:
// - the value of the fail_fast attribute.
// - the timeout option of the rpc call in the proto declaration.
// - the network roundtrip time and service's execution time.
//
// If a connection is made but the service doesn't ever respond, and
// there is no timeout option set for this rpc call, then it is
// possible the RPC request will wait forever.
//
completion_queue_.Shutdown();
delete polling_thread_;
}
void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
const Tensor& address_t, const Tensor& method_t,
const Tensor& request_t, const bool try_rpc,
Tensor* response_t, Tensor* status_code_t,
Tensor* status_message_t,
AsyncOpKernel::DoneCallback done) {
if (try_rpc) {
// In this case status_code will never be set in the response,
// so we just set it to OK.
DCHECK(status_code_t != nullptr);
status_code_t->flat<int32>().setConstant(
static_cast<int>(errors::Code::OK));
}
CallContainer<GrpcCall>::CreateCallFn create_call_fn =
[this, &request_t, &try_rpc, response_t, status_code_t, status_message_t](
CallContainer<GrpcCall>* container, int index) {
CreateCall(request_t, try_rpc, index, container, response_t,
status_code_t, status_message_t);
};
CallContainer<GrpcCall>::StartCallFn start_call_fn =
[this, &address_t, &method_t](GrpcCall* call) {
StartCall(address_t, method_t, call);
};
// This object will delete itself when done.
new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
std::move(done), std::move(create_call_fn),
std::move(start_call_fn));
}
::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress(
const string& address) {
mutex_lock lock(mu_);
auto stub = stubs_.find(address);
if (stub != stubs_.end()) return stub->second.get();
ChannelPtr channel = CreateChannelForAddress(address);
auto* created = new ::grpc::GenericStub(channel);
stubs_[address].reset(created);
return created;
}
GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress(
const string& address) {
::grpc::ChannelArguments args;
args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
// Set a standard backoff timeout of 1s instead of the
// (sometimes default) 20s.
args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
return ::grpc::CreateCustomChannel(
/*target=*/address, ::grpc::InsecureChannelCredentials(), args);
}
void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc,
int index, CallContainer<GrpcCall>* container,
Tensor* response_t, Tensor* status_code_t,
Tensor* status_message_t) {
auto request = request_t.flat<tstring>();
auto get_request_ptr = [&request](int64 ix) -> const tstring* {
return (request.size() > 1) ? &(request(ix)) : &(request(0));
};
auto response = response_t->flat<tstring>();
int32* status_code_ptr = nullptr;
tstring* status_message_ptr = nullptr;
if (try_rpc) {
status_code_ptr = status_code_t->flat<int32>().data();
status_message_ptr = status_message_t->flat<tstring>().data();
}
container->RegisterCall(container, index, try_rpc, get_request_ptr(index),
&response(index),
(try_rpc) ? &status_code_ptr[index] : nullptr,
(try_rpc) ? &status_message_ptr[index] : nullptr);
}
void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
GrpcCall* call) {
auto address = address_t.flat<tstring>();
auto method = method_t.flat<tstring>();
// Stubs are maintained by the GrpcRPCFactory class and will be
// deleted when the class is destroyed.
::grpc::GenericStub* singleton_stub = nullptr;
if (address.size() == 1) {
singleton_stub = GetOrCreateStubForAddress(address(0));
}
auto get_stub = [&address, this,
singleton_stub](int64 ix) -> ::grpc::GenericStub* {
return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
: singleton_stub;
};
auto get_method_ptr = [&method](int64 ix) -> const tstring* {
return (method.size() > 1) ? &(method(ix)) : &(method(0));
};
int index = call->index();
// This object will delete itself when done.
new RPCState<tstring>(
get_stub(index), &completion_queue_, *get_method_ptr(index),
call->request(), call->response(),
/*done=*/[call](const Status& s) { call->Done(s); }, call->call_opts(),
/*threadpool=*/nullptr, fail_fast_, timeout_in_ms_, /*max_retries=*/0,
/*target=*/nullptr);
}
} // namespace tensorflow

View File

@ -1,77 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/util/rpc/call_container.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
namespace tensorflow {
// Forward declaration of GrpcCall.
namespace internal {
class GrpcCall;
} // namespace internal
class GrpcRPCFactory : public RPCFactory {
public:
explicit GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
int64 timeout_in_ms);
// Explicit destructor to control destruction order.
~GrpcRPCFactory() override;
void Call(OpKernelContext* ctx, int64 num_elements, const Tensor& address_t,
const Tensor& method_t, const Tensor& request_t, const bool try_rpc,
Tensor* response_t, Tensor* status_code_t, Tensor* status_message_t,
AsyncOpKernel::DoneCallback done) override;
protected:
typedef std::shared_ptr<::grpc::Channel> ChannelPtr;
virtual ChannelPtr CreateChannelForAddress(const string& address);
private:
// Creates a call and registers it with given `container`. The `index` is used
// to index into the tensor arguments.
void CreateCall(const Tensor& request_t, const bool try_rpc, int index,
CallContainer<internal::GrpcCall>* container,
Tensor* response_t, Tensor* status_code_t,
Tensor* status_message_t);
// Asynchronously invokes the given `call`. The call completion is handled
// by the call container the call was previously registered with.
void StartCall(const Tensor& address_t, const Tensor& method_t,
internal::GrpcCall* call);
::grpc::GenericStub* GetOrCreateStubForAddress(const string& address);
bool fail_fast_;
int64 timeout_in_ms_;
::grpc::CompletionQueue completion_queue_;
Thread* polling_thread_; // Owned.
mutex mu_;
typedef std::unique_ptr<::grpc::GenericStub> StubPtr;
std::unordered_map<string, StubPtr> stubs_ TF_GUARDED_BY(mu_);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_

View File

@ -1,34 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
namespace tensorflow {
namespace {
// Used for adding the grpc factory to the RPC factory registry.
struct Value {
static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast,
int64 timeout_in_ms) {
return new GrpcRPCFactory(ctx, fail_fast, timeout_in_ms);
}
};
REGISTER_RPC_FACTORY("grpc", Value::Function);
} // namespace
} // namespace tensorflow

View File

@ -6387,7 +6387,6 @@ filegroup(
"spectrogram_convert_test_data.cc",
"decode_proto_op.cc",
"encode_proto_op.cc",
"rpc_op.cc",
"sobol_op.cc",
# Excluded due to experimental status:
"debug_ops.*",
@ -7304,22 +7303,6 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "rpc_op",
srcs = [
"rpc_op.cc",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/util/rpc:call_container",
"//tensorflow/core/util/rpc:rpc_factory",
"//tensorflow/core/util/rpc:rpc_factory_registry",
"//third_party/eigen3",
],
)
tf_kernel_library(
name = "unicode_script_op",
srcs = ["unicode_script_op.cc"],

View File

@ -1,128 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// RpcOp is a TensorFlow op that sends and receives arbitrary messages.
//
// See docs in ../ops/rpc_op.cc.
#include <memory>
#include <string>
#include <vector>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/rpc/call_container.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
namespace tensorflow {
class RpcOp : public AsyncOpKernel {
public:
explicit RpcOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("protocol", &protocol_));
OP_REQUIRES(context, !protocol_.empty(),
errors::InvalidArgument("protocol must be non-empty."));
bool fail_fast;
OP_REQUIRES_OK(context, context->GetAttr("fail_fast", &fail_fast));
int64 timeout_in_ms;
OP_REQUIRES_OK(context, context->GetAttr("timeout_in_ms", &timeout_in_ms));
RPCFactoryRegistry::RPCFactoryFn* rpc_factory_fn =
RPCFactoryRegistry::Global()->Get(protocol_);
OP_REQUIRES(context, rpc_factory_fn != nullptr,
errors::InvalidArgument("The protocol ", protocol_,
" was not recognized."));
rpc_factory_.reset((*rpc_factory_fn)(context, fail_fast, timeout_in_ms));
}
~RpcOp() override {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
const Tensor& address_t = ctx->input(0);
const Tensor& method_t = ctx->input(1);
const Tensor& request_t = ctx->input(2);
OP_REQUIRES_ASYNC(
ctx, address_t.dims() == 0 || address_t.dims() == 1,
errors::InvalidArgument("address must be a scalar or vector."), done);
OP_REQUIRES_ASYNC(
ctx, method_t.dims() == 0 || method_t.dims() == 1,
errors::InvalidArgument("method must be a scalar or vector."), done);
OP_REQUIRES_ASYNC(
ctx, request_t.dims() == 0 || request_t.dims() == 1,
errors::InvalidArgument("request must be a scalar or vector."), done);
TensorShape output_shape({});
for (const Tensor& t : {address_t, method_t, request_t}) {
if (t.dims() == 1) {
OP_REQUIRES_ASYNC(
ctx,
output_shape.dims() == 0 ||
output_shape.dim_size(0) == t.dim_size(0),
errors::InvalidArgument(
"Input vector shapes don't match: ", output_shape.DebugString(),
" vs. ", t.shape().DebugString()),
done);
output_shape = t.shape();
}
}
Tensor* response_t;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->allocate_output(0, output_shape, &response_t), done);
const bool try_rpc = (ctx->num_outputs() > 1);
Tensor* status_code_t = nullptr;
Tensor* status_message_t = nullptr;
if (try_rpc) {
OP_REQUIRES_OK_ASYNC(
ctx, ctx->allocate_output(1, output_shape, &status_code_t), done);
OP_REQUIRES_OK_ASYNC(
ctx, ctx->allocate_output(2, output_shape, &status_message_t), done);
}
if (request_t.NumElements() == 0) {
// Special case, we finished early!
done();
return;
}
int64 num_elements = output_shape.num_elements();
rpc_factory_->Call(ctx, num_elements, address_t, method_t, request_t,
try_rpc, response_t, status_code_t, status_message_t,
std::move(done));
}
private:
string protocol_;
std::unique_ptr<RPCFactory> rpc_factory_;
TF_DISALLOW_COPY_AND_ASSIGN(RpcOp);
};
REGISTER_KERNEL_BUILDER(Name("Rpc").Device(DEVICE_CPU), RpcOp);
REGISTER_KERNEL_BUILDER(Name("TryRpc").Device(DEVICE_CPU), RpcOp);
} // namespace tensorflow

View File

@ -85,7 +85,6 @@ tf_gen_op_libs(
"remote_fused_graph_ops",
"risc_ops",
"rnn_ops",
"rpc_ops",
"scoped_allocator_ops",
"sdca_ops",
"set_ops",
@ -286,7 +285,6 @@ cc_library(
":stateful_random_ops_op_lib",
":remote_fused_graph_ops_op_lib",
":resource_variable_ops_op_lib",
":rpc_ops_op_lib",
":scoped_allocator_ops_op_lib",
":script_ops_op_lib",
":sdca_ops_op_lib",

View File

@ -1,41 +0,0 @@
op {
name: "Rpc"
input_arg {
name: "address"
type: DT_STRING
}
input_arg {
name: "method"
type: DT_STRING
}
input_arg {
name: "request"
type: DT_STRING
}
output_arg {
name: "response"
type: DT_STRING
}
attr {
name: "protocol"
type: "string"
default_value {
s: ""
}
}
attr {
name: "fail_fast"
type: "bool"
default_value {
b: true
}
}
attr {
name: "timeout_in_ms"
type: "int"
default_value {
i: 0
}
}
is_stateful: true
}

View File

@ -1,49 +0,0 @@
op {
name: "TryRpc"
input_arg {
name: "address"
type: DT_STRING
}
input_arg {
name: "method"
type: DT_STRING
}
input_arg {
name: "request"
type: DT_STRING
}
output_arg {
name: "response"
type: DT_STRING
}
output_arg {
name: "status_code"
type: DT_INT32
}
output_arg {
name: "status_message"
type: DT_STRING
}
attr {
name: "protocol"
type: "string"
default_value {
s: ""
}
}
attr {
name: "fail_fast"
type: "bool"
default_value {
b: true
}
}
attr {
name: "timeout_in_ms"
type: "int"
default_value {
i: 0
}
}
is_stateful: true
}

View File

@ -1,41 +0,0 @@
op {
name: "Rpc"
input_arg {
name: "address"
type: DT_STRING
}
input_arg {
name: "method"
type: DT_STRING
}
input_arg {
name: "request"
type: DT_STRING
}
output_arg {
name: "response"
type: DT_STRING
}
attr {
name: "protocol"
type: "string"
default_value {
s: ""
}
}
attr {
name: "fail_fast"
type: "bool"
default_value {
b: true
}
}
attr {
name: "timeout_in_ms"
type: "int"
default_value {
i: 0
}
}
is_stateful: true
}

View File

@ -1,49 +0,0 @@
op {
name: "TryRpc"
input_arg {
name: "address"
type: DT_STRING
}
input_arg {
name: "method"
type: DT_STRING
}
input_arg {
name: "request"
type: DT_STRING
}
output_arg {
name: "response"
type: DT_STRING
}
output_arg {
name: "status_code"
type: DT_INT32
}
output_arg {
name: "status_message"
type: DT_STRING
}
attr {
name: "protocol"
type: "string"
default_value {
s: ""
}
}
attr {
name: "fail_fast"
type: "bool"
default_value {
b: true
}
}
attr {
name: "timeout_in_ms"
type: "int"
default_value {
i: 0
}
}
is_stateful: true
}

View File

@ -1,80 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;
Status RpcShapeOp(InferenceContext* c, bool try_rpc) {
ShapeHandle address;
ShapeHandle method;
ShapeHandle request;
ShapeHandle output;
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &address));
if (c->Rank(address) == 1) {
TF_RETURN_IF_ERROR(c->Merge(output, address, &output));
}
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &method));
if (c->Rank(method) == 1) {
TF_RETURN_IF_ERROR(c->Merge(output, method, &output));
}
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &request));
if (c->Rank(request) == 1) {
TF_RETURN_IF_ERROR(c->Merge(output, request, &output));
}
if (!c->RankKnown(output)) {
output = request;
}
c->set_output(0, output); // response
if (try_rpc) {
c->set_output(1, output); // status_code
c->set_output(2, output); // status_message
}
return Status::OK();
}
REGISTER_OP("Rpc")
.Input("address: string")
.Input("method: string")
.Input("request: string")
.Attr("protocol: string = ''")
.Attr("fail_fast: bool = true")
.Attr("timeout_in_ms: int = 0")
.Output("response: string")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
return RpcShapeOp(c, /*try_rpc=*/false);
});
REGISTER_OP("TryRpc")
.Input("address: string")
.Input("method: string")
.Input("request: string")
.Attr("protocol: string = ''")
.Attr("fail_fast: bool = true")
.Attr("timeout_in_ms: int = 0")
.Output("response: string")
.Output("status_code: int32")
.Output("status_message: string")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
return RpcShapeOp(c, /*try_rpc=*/true);
});
} // namespace tensorflow

View File

@ -1,48 +0,0 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "call_container",
hdrs = ["call_container.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
cc_library(
name = "rpc_factory",
srcs = ["rpc_factory.cc"],
hdrs = ["rpc_factory.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
cc_library(
name = "rpc_factory_registry",
srcs = ["rpc_factory_registry.cc"],
hdrs = ["rpc_factory_registry.h"],
deps = [
":rpc_factory",
"//tensorflow/core:framework",
],
)
tf_cc_test(
name = "rpc_factory_registry_test",
srcs = ["rpc_factory_registry_test.cc"],
deps = [
":rpc_factory_registry",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

View File

@ -1,182 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
#define TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
#include <list>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/util/reffed_status_callback.h"
namespace tensorflow {
namespace internal {
// The following class is used for coordination between a `CallContainer`
// instance and a cancellation callback to make sure that the `CallContainer`
// instance waits for the cancellation callback to be destroyed (either because
// a cancellation occurred or because the callback was deregistered) before
// deleting itself. Without this coordination the cancellation callback could
// attempt to access a `CallContainer` instance that is no longer valid.
class NotifyWhenDestroyed {
public:
explicit NotifyWhenDestroyed(std::shared_ptr<Notification> notification)
: notification_(std::move(notification)) {}
~NotifyWhenDestroyed() { notification_->Notify(); }
private:
std::shared_ptr<Notification> notification_;
};
} // namespace internal
// The following class is responsible for the life cycle management of a set of
// RPC calls. The calls are started when an instance of the class is created and
// the class contract guarantees to invoke a "done" callback provided by the
// caller when all RPC calls have either completed or been cancelled.
//
// The caller should not make any assumptions about the validity of an instance
// of this class after the provided callback has been invoked, which may be
// immediately after the instance was created.
template <class Call>
class CallContainer {
public:
typedef std::function<void(CallContainer<Call>*, int)> CreateCallFn;
typedef std::function<void(Call*)> StartCallFn;
// Uses the provided `create_call_fn` and `start_call_fn` functions to create
// and start a set of RPC calls. When all RPC calls have either completed or
// been cancelled, the `done` callback is invoked. The caller should not make
// any assumptions about the validity of the created instance as the instance
// will delete itself after invoking the `done` callback.
explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast,
bool try_rpc, AsyncOpKernel::DoneCallback done,
CreateCallFn create_call_fn,
StartCallFn start_call_fn);
// Registers a call with this container. This method expects its arguments to
// match those of a `Call` constructor as it forwards them to an underlying
// collection, which creates a `Call` instance in place.
template <class... Args>
void RegisterCall(Args&&... args);
// Starts the cancellation of all RPC calls managed by this container.
void StartCancel();
// Indicates that the `index`-th RPC call has finished.
void Done(const Status& s, int index);
private:
OpKernelContext* ctx_;
std::list<Call> calls_;
const AsyncOpKernel::DoneCallback done_;
const CancellationToken token_;
const bool fail_fast_;
const bool try_rpc_;
std::shared_ptr<Notification> callback_destroyed_;
// Performs its own reference counting.
ReffedStatusCallback* reffed_status_callback_;
};
template <class Call>
CallContainer<Call>::CallContainer(
OpKernelContext* ctx, int num_calls, bool fail_fast, bool try_rpc,
AsyncOpKernel::DoneCallback done,
typename CallContainer<Call>::CreateCallFn create_call_fn,
typename CallContainer<Call>::StartCallFn start_call_fn)
: ctx_(ctx),
done_(std::move(done)),
token_(ctx->cancellation_manager() != nullptr
? ctx->cancellation_manager()->get_cancellation_token()
: CancellationManager::kInvalidToken),
fail_fast_(fail_fast),
try_rpc_(try_rpc),
callback_destroyed_(new Notification) {
CHECK_GT(num_calls, 0);
// This will run when all RPCs are finished.
reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
if (token_ != CancellationManager::kInvalidToken) {
ctx_->cancellation_manager()->DeregisterCallback(token_);
}
ctx_->SetStatus(s);
done_();
callback_destroyed_->WaitForNotification();
delete this;
});
// The cancellation callback needs to be registered before the RPC calls are
// started to make sure that the callback is properly cleaned up by the
// `reffed_status_callback` when all calls complete. At the same time, the
// cancellation callback should wait for the RPC calls to be started for the
// cancellation to take effect.
std::shared_ptr<internal::NotifyWhenDestroyed> notify_when_destroyed(
new internal::NotifyWhenDestroyed(callback_destroyed_));
std::shared_ptr<Notification> calls_started(new Notification);
bool is_cancelled = false;
if (token_ != CancellationManager::kInvalidToken) {
is_cancelled = !ctx_->cancellation_manager()->RegisterCallback(
token_, [this, calls_started, notify_when_destroyed]() {
calls_started->WaitForNotification();
StartCancel();
});
}
for (int i = 0; i < num_calls; ++i) {
create_call_fn(this, i);
// Increase the reference on the callback for each new RPC.
reffed_status_callback_->Ref();
}
for (Call& call : calls_) {
start_call_fn(&call);
}
calls_started->Notify();
if (is_cancelled) {
ctx_->SetStatus(errors::Cancelled("Operation has been cancelled."));
StartCancel();
}
// Subtract reference count from the initial creation.
reffed_status_callback_->Unref();
}
template <class Call>
template <class... Args>
void CallContainer<Call>::RegisterCall(Args&&... args) {
calls_.emplace_back(std::forward<Args>(args)...);
}
template <class Call>
void CallContainer<Call>::StartCancel() {
for (auto& call : calls_) {
call.StartCancel();
}
}
template <class Call>
void CallContainer<Call>::Done(const Status& s, int index) {
if (!try_rpc_) {
reffed_status_callback_->UpdateStatus(s);
}
reffed_status_callback_->Unref();
}
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_

View File

@ -1,53 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
namespace tensorflow {
template <>
bool GetEnvVar(const char* key, const string& default_value, string* value) {
const char* env_value = std::getenv(key);
if (!env_value || env_value[0] == '\0') {
*value = default_value;
} else {
*value = env_value;
}
return true;
}
template <>
bool GetEnvVar(const char* key, const int64& default_value, int64* value) {
const char* env_value = std::getenv(key);
if (!env_value || env_value[0] == '\0') {
*value = default_value;
return true;
}
return strings::safe_strto64(env_value, value);
}
template <>
bool GetEnvVar(const char* key, const uint64& default_value, uint64* value) {
const char* env_value = std::getenv(key);
if (!env_value || env_value[0] == '\0') {
*value = default_value;
return true;
}
return strings::safe_strtou64(env_value, value);
}
} // namespace tensorflow

View File

@ -1,71 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
// Return the environment variable `key`. If the variable is not set,
// use the default value. If it is set but could not be parsed,
// return `false`. Otherwise set `value` and return `true`.
template <typename T>
bool GetEnvVar(const char* key, const T& default_value, T* value);
class RPCFactory {
public:
RPCFactory() {}
virtual ~RPCFactory() {}
// Asynchronously invokes methods `method_t` at addresses `address_t` with
// request strings from `request_t`. Any of these may be scalar
// Tensors, in which case the operands are broadcasted.
// Upon completion of all requests, `response_t` will be populated and the
// `done` callback will be invoked.
//
// If `try_rpc` is `true`, then `status_message_t` and
// `status_code_t` will be populated as well.
//
// If `try_rpc` is `false`, then `status_message_t` and
// `status_code_t` are ignored (and may be nullptr). Instead, the
// status of any failed call will be propagated to the op.
//
// REQUIRES:
// - `response_t` is not null, and is a string Tensor with the same shape as
// `request_t`.
//
// If `try_rpc` is `true`:
// - `status_code_t` and `status_message_t` are not null.
// - `status_code_t` is an int32 Tensor with the same shape as
// `request_t`.
// - `status_message_t` is a string Tensor with the same shape as
// `request_t`.
virtual void Call(OpKernelContext* ctx, int64 num_elements,
const Tensor& address_t, const Tensor& method_t,
const Tensor& request_t, const bool try_rpc,
Tensor* response_t, Tensor* status_code_t,
Tensor* status_message_t,
AsyncOpKernel::DoneCallback done) = 0;
private:
TF_DISALLOW_COPY_AND_ASSIGN(RPCFactory);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_

View File

@ -1,44 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "tensorflow/core/util/rpc/rpc_factory.h"
#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
namespace tensorflow {
RPCFactoryRegistry* RPCFactoryRegistry::Global() {
static RPCFactoryRegistry* registry = new RPCFactoryRegistry;
return registry;
}
RPCFactoryRegistry::RPCFactoryFn* RPCFactoryRegistry::Get(
const string& protocol) {
auto found = fns_.find(protocol);
if (found == fns_.end()) return nullptr;
return &found->second;
}
void RPCFactoryRegistry::Register(const string& protocol,
const RPCFactoryFn& factory_fn) {
auto existing = Get(protocol);
CHECK_EQ(existing, nullptr)
<< "RPC factory for protocol: " << protocol << " already registered";
fns_.insert(std::pair<const string&, RPCFactoryFn>(protocol, factory_fn));
}
} // namespace tensorflow

View File

@ -1,72 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
#include <map>
#include <string>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
namespace tensorflow {
class RPCFactoryRegistry {
public:
typedef std::function<RPCFactory*(OpKernelConstruction* ctx, bool fail_fast,
int64 timeout_in_ms)>
RPCFactoryFn;
// Returns a pointer to a global RPCFactoryRegistry object.
static RPCFactoryRegistry* Global();
// Returns a pointer to an function that creates an RPC factory for the given
// protocol.
RPCFactoryFn* Get(const string& protocol);
// Registers a function that creates and RPC factory for the given protocol.
// The function should transfer the ownership of the factory to its caller.
void Register(const string& protocol, const RPCFactoryFn& factory_fn);
private:
std::map<string, RPCFactoryFn> fns_;
};
namespace rpc_factory_registration {
class RPCFactoryRegistration {
public:
RPCFactoryRegistration(const string& protocol,
const RPCFactoryRegistry::RPCFactoryFn& factory_fn) {
RPCFactoryRegistry::Global()->Register(protocol, factory_fn);
}
};
} // namespace rpc_factory_registration
#define REGISTER_RPC_FACTORY(protocol, factory_fn) \
REGISTER_RPC_FACTORY_UNIQ_HELPER(__COUNTER__, protocol, factory_fn)
#define REGISTER_RPC_FACTORY_UNIQ_HELPER(ctr, protocol, factory_fn) \
REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn)
#define REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn) \
static rpc_factory_registration::RPCFactoryRegistration \
rpc_factory_registration_fn_##ctr(protocol, factory_fn)
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_

View File

@ -1,41 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
struct Value {
static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast,
int64 timeout_in_ms) {
return nullptr;
}
};
REGISTER_RPC_FACTORY("TEST FACTORY 1", Value::Function);
REGISTER_RPC_FACTORY("TEST FACTORY 2", Value::Function);
} // namespace
TEST(RPCFactoryRegistryTest, TestBasic) {
EXPECT_EQ(RPCFactoryRegistry::Global()->Get("NON-EXISTENT"), nullptr);
auto factory1 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 1");
EXPECT_NE(factory1, nullptr);
auto factory2 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 2");
EXPECT_NE(factory2, nullptr);
}
} // namespace tensorflow

View File

@ -5206,7 +5206,6 @@ pywrap_tensorflow_macro(
":bfloat16_lib",
":cost_analyzer_lib",
":model_analyzer_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
"//tensorflow/python/util:cpp_python_util",