From 222861851e0135fce7556c6bbfe511c805667b1e Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Fri, 29 Jan 2021 17:04:11 -0800 Subject: [PATCH] Remove unused RPC op kernels. PiperOrigin-RevId: 354639187 Change-Id: I302f2cfc9b386f53cf0641461b699f3730a63150 --- tensorflow/core/BUILD | 2 - .../core/api_def/base_api/api_def_Rpc.pbtxt | 108 --------- .../api_def/base_api/api_def_TryRpc.pbtxt | 123 ---------- .../core/api_def/java_api/api_def_Rpc.pbtxt | 3 - .../api_def/java_api/api_def_TryRpc.pbtxt | 3 - tensorflow/core/distributed_runtime/rpc/BUILD | 30 --- .../rpc/grpc_rpc_factory.cc | 217 ------------------ .../rpc/grpc_rpc_factory.h | 77 ------- .../rpc/grpc_rpc_factory_registration.cc | 34 --- tensorflow/core/kernels/BUILD | 17 -- tensorflow/core/kernels/rpc_op.cc | 128 ----------- tensorflow/core/ops/BUILD | 2 - .../core/ops/compat/ops_history_v1/Rpc.pbtxt | 41 ---- .../ops/compat/ops_history_v1/TryRpc.pbtxt | 49 ---- .../core/ops/compat/ops_history_v2/Rpc.pbtxt | 41 ---- .../ops/compat/ops_history_v2/TryRpc.pbtxt | 49 ---- tensorflow/core/ops/rpc_ops.cc | 80 ------- tensorflow/core/util/rpc/BUILD | 48 ---- tensorflow/core/util/rpc/call_container.h | 182 --------------- tensorflow/core/util/rpc/rpc_factory.cc | 53 ----- tensorflow/core/util/rpc/rpc_factory.h | 71 ------ .../core/util/rpc/rpc_factory_registry.cc | 44 ---- .../core/util/rpc/rpc_factory_registry.h | 72 ------ .../util/rpc/rpc_factory_registry_test.cc | 41 ---- tensorflow/python/BUILD | 1 - 25 files changed, 1516 deletions(-) delete mode 100644 tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt delete mode 100644 tensorflow/core/api_def/base_api/api_def_TryRpc.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_Rpc.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_TryRpc.pbtxt delete mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc delete mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h delete mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc delete mode 100644 tensorflow/core/kernels/rpc_op.cc delete mode 100644 tensorflow/core/ops/compat/ops_history_v1/Rpc.pbtxt delete mode 100644 tensorflow/core/ops/compat/ops_history_v1/TryRpc.pbtxt delete mode 100644 tensorflow/core/ops/compat/ops_history_v2/Rpc.pbtxt delete mode 100644 tensorflow/core/ops/compat/ops_history_v2/TryRpc.pbtxt delete mode 100644 tensorflow/core/ops/rpc_ops.cc delete mode 100644 tensorflow/core/util/rpc/BUILD delete mode 100644 tensorflow/core/util/rpc/call_container.h delete mode 100644 tensorflow/core/util/rpc/rpc_factory.cc delete mode 100644 tensorflow/core/util/rpc/rpc_factory.h delete mode 100644 tensorflow/core/util/rpc/rpc_factory_registry.cc delete mode 100644 tensorflow/core/util/rpc/rpc_factory_registry.h delete mode 100644 tensorflow/core/util/rpc/rpc_factory_registry_test.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 7740bd930a3..2104df3e8bb 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -641,7 +641,6 @@ cc_library( "//tensorflow/core/kernels:required", "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:rnn_ops", - "//tensorflow/core/kernels:rpc_op", "//tensorflow/core/kernels:scoped_allocator_ops", "//tensorflow/core/kernels:sdca_ops", "//tensorflow/core/kernels:searchsorted_op", @@ -978,7 +977,6 @@ filegroup( "resource_variable_ops_op_lib", "risc_ops_op_lib", "rnn_ops_op_lib", - "rpc_ops_op_lib", "scoped_allocator_ops_op_lib", "script_ops_op_lib", "sdca_ops_op_lib", diff --git a/tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt b/tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt deleted file mode 100644 index 344ef191fd5..00000000000 --- a/tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt +++ /dev/null @@ -1,108 +0,0 @@ -op { - graph_op_name: "Rpc" - in_arg { - name: "address" - description: < -#include -#include - -#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/util/rpc/call_container.h" -#include "tensorflow/core/util/rpc/rpc_factory.h" - -#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h" - -namespace tensorflow { - -namespace internal { -class GrpcCall { - public: - explicit GrpcCall(CallContainer* container, int index, bool try_rpc, - const tstring* request_msg, tstring* response_msg, - int32* status_code, tstring* status_message) - : container_(container), - index_(index), - try_rpc_(try_rpc), - request_msg_(request_msg), - response_msg_(response_msg), - status_code_(status_code), - status_message_(status_message) {} - - void StartCancel() { call_opts_.StartCancel(); } - - void Done(const Status& s) { - DCHECK(container_ != nullptr); - if (!s.ok() && try_rpc_) { - DCHECK(status_code_ != nullptr); - DCHECK(status_message_ != nullptr); - *status_code_ = s.code(); - *status_message_ = s.error_message(); - } - container_->Done(s, index_); - } - - CallOptions* call_opts() { return &call_opts_; } - int index() { return index_; } - const tstring& request() const { return *request_msg_; } - tstring* response() const { return response_msg_; } - - private: - CallContainer* const container_; - const int index_; - bool try_rpc_; - CallOptions call_opts_; - const tstring* request_msg_; - tstring* response_msg_; - int* status_code_; - tstring* status_message_; -}; - -} // namespace internal - -using internal::GrpcCall; - -GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast, - int64 timeout_in_ms) - : RPCFactory(), fail_fast_(fail_fast), timeout_in_ms_(timeout_in_ms) { - // TODO(ebrevdo): Investigate possible performance improvements by - // replacing this thread with a threadpool. - polling_thread_ = - ctx->env()->StartThread(ThreadOptions(), "rpc_op_grpc_factory", [this]() { - void* tag; - bool ok; - while (completion_queue_.Next(&tag, &ok)) { - GrpcClientCQTag* callback_tag = static_cast(tag); - callback_tag->OnCompleted(ok); - } - }); -} - -GrpcRPCFactory::~GrpcRPCFactory() { - // The amount of time we wait depends on several parameters, including: - // - the value of the fail_fast attribute. - // - the timeout option of the rpc call in the proto declaration. - // - the network roundtrip time and service's execution time. - // - // If a connection is made but the service doesn't ever respond, and - // there is no timeout option set for this rpc call, then it is - // possible the RPC request will wait forever. - // - completion_queue_.Shutdown(); - delete polling_thread_; -} - -void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements, - const Tensor& address_t, const Tensor& method_t, - const Tensor& request_t, const bool try_rpc, - Tensor* response_t, Tensor* status_code_t, - Tensor* status_message_t, - AsyncOpKernel::DoneCallback done) { - if (try_rpc) { - // In this case status_code will never be set in the response, - // so we just set it to OK. - DCHECK(status_code_t != nullptr); - status_code_t->flat().setConstant( - static_cast(errors::Code::OK)); - } - - CallContainer::CreateCallFn create_call_fn = - [this, &request_t, &try_rpc, response_t, status_code_t, status_message_t]( - CallContainer* container, int index) { - CreateCall(request_t, try_rpc, index, container, response_t, - status_code_t, status_message_t); - }; - - CallContainer::StartCallFn start_call_fn = - [this, &address_t, &method_t](GrpcCall* call) { - StartCall(address_t, method_t, call); - }; - - // This object will delete itself when done. - new CallContainer(ctx, num_elements, fail_fast_, try_rpc, - std::move(done), std::move(create_call_fn), - std::move(start_call_fn)); -} - -::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress( - const string& address) { - mutex_lock lock(mu_); - - auto stub = stubs_.find(address); - if (stub != stubs_.end()) return stub->second.get(); - - ChannelPtr channel = CreateChannelForAddress(address); - auto* created = new ::grpc::GenericStub(channel); - stubs_[address].reset(created); - return created; -} - -GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress( - const string& address) { - ::grpc::ChannelArguments args; - args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits::max()); - - // Set a standard backoff timeout of 1s instead of the - // (sometimes default) 20s. - args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000); - return ::grpc::CreateCustomChannel( - /*target=*/address, ::grpc::InsecureChannelCredentials(), args); -} - -void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc, - int index, CallContainer* container, - Tensor* response_t, Tensor* status_code_t, - Tensor* status_message_t) { - auto request = request_t.flat(); - auto get_request_ptr = [&request](int64 ix) -> const tstring* { - return (request.size() > 1) ? &(request(ix)) : &(request(0)); - }; - auto response = response_t->flat(); - int32* status_code_ptr = nullptr; - tstring* status_message_ptr = nullptr; - if (try_rpc) { - status_code_ptr = status_code_t->flat().data(); - status_message_ptr = status_message_t->flat().data(); - } - container->RegisterCall(container, index, try_rpc, get_request_ptr(index), - &response(index), - (try_rpc) ? &status_code_ptr[index] : nullptr, - (try_rpc) ? &status_message_ptr[index] : nullptr); -} - -void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t, - GrpcCall* call) { - auto address = address_t.flat(); - auto method = method_t.flat(); - // Stubs are maintained by the GrpcRPCFactory class and will be - // deleted when the class is destroyed. - ::grpc::GenericStub* singleton_stub = nullptr; - if (address.size() == 1) { - singleton_stub = GetOrCreateStubForAddress(address(0)); - } - auto get_stub = [&address, this, - singleton_stub](int64 ix) -> ::grpc::GenericStub* { - return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix)) - : singleton_stub; - }; - auto get_method_ptr = [&method](int64 ix) -> const tstring* { - return (method.size() > 1) ? &(method(ix)) : &(method(0)); - }; - - int index = call->index(); - // This object will delete itself when done. - new RPCState( - get_stub(index), &completion_queue_, *get_method_ptr(index), - call->request(), call->response(), - /*done=*/[call](const Status& s) { call->Done(s); }, call->call_opts(), - /*threadpool=*/nullptr, fail_fast_, timeout_in_ms_, /*max_retries=*/0, - /*target=*/nullptr); -} - -} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h deleted file mode 100644 index ae9abf765df..00000000000 --- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_ -#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_ - -#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/util/rpc/call_container.h" -#include "tensorflow/core/util/rpc/rpc_factory.h" - -namespace tensorflow { - -// Forward declaration of GrpcCall. -namespace internal { -class GrpcCall; -} // namespace internal - -class GrpcRPCFactory : public RPCFactory { - public: - explicit GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast, - int64 timeout_in_ms); - - // Explicit destructor to control destruction order. - ~GrpcRPCFactory() override; - - void Call(OpKernelContext* ctx, int64 num_elements, const Tensor& address_t, - const Tensor& method_t, const Tensor& request_t, const bool try_rpc, - Tensor* response_t, Tensor* status_code_t, Tensor* status_message_t, - AsyncOpKernel::DoneCallback done) override; - - protected: - typedef std::shared_ptr<::grpc::Channel> ChannelPtr; - virtual ChannelPtr CreateChannelForAddress(const string& address); - - private: - // Creates a call and registers it with given `container`. The `index` is used - // to index into the tensor arguments. - void CreateCall(const Tensor& request_t, const bool try_rpc, int index, - CallContainer* container, - Tensor* response_t, Tensor* status_code_t, - Tensor* status_message_t); - - // Asynchronously invokes the given `call`. The call completion is handled - // by the call container the call was previously registered with. - void StartCall(const Tensor& address_t, const Tensor& method_t, - internal::GrpcCall* call); - - ::grpc::GenericStub* GetOrCreateStubForAddress(const string& address); - - bool fail_fast_; - int64 timeout_in_ms_; - ::grpc::CompletionQueue completion_queue_; - Thread* polling_thread_; // Owned. - - mutex mu_; - typedef std::unique_ptr<::grpc::GenericStub> StubPtr; - std::unordered_map stubs_ TF_GUARDED_BY(mu_); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc deleted file mode 100644 index b8844893784..00000000000 --- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h" -#include "tensorflow/core/util/rpc/rpc_factory.h" -#include "tensorflow/core/util/rpc/rpc_factory_registry.h" - -namespace tensorflow { -namespace { - -// Used for adding the grpc factory to the RPC factory registry. -struct Value { - static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast, - int64 timeout_in_ms) { - return new GrpcRPCFactory(ctx, fail_fast, timeout_in_ms); - } -}; - -REGISTER_RPC_FACTORY("grpc", Value::Function); - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 3a9e6f8dfa4..68909457180 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -6387,7 +6387,6 @@ filegroup( "spectrogram_convert_test_data.cc", "decode_proto_op.cc", "encode_proto_op.cc", - "rpc_op.cc", "sobol_op.cc", # Excluded due to experimental status: "debug_ops.*", @@ -7304,22 +7303,6 @@ tf_kernel_library( ], ) -tf_kernel_library( - name = "rpc_op", - srcs = [ - "rpc_op.cc", - ], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/util/rpc:call_container", - "//tensorflow/core/util/rpc:rpc_factory", - "//tensorflow/core/util/rpc:rpc_factory_registry", - "//third_party/eigen3", - ], -) - tf_kernel_library( name = "unicode_script_op", srcs = ["unicode_script_op.cc"], diff --git a/tensorflow/core/kernels/rpc_op.cc b/tensorflow/core/kernels/rpc_op.cc deleted file mode 100644 index 3c606e4ec67..00000000000 --- a/tensorflow/core/kernels/rpc_op.cc +++ /dev/null @@ -1,128 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// RpcOp is a TensorFlow op that sends and receives arbitrary messages. -// -// See docs in ../ops/rpc_op.cc. - -#include -#include -#include - -#include "third_party/eigen3/Eigen/Core" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/util/rpc/call_container.h" -#include "tensorflow/core/util/rpc/rpc_factory.h" -#include "tensorflow/core/util/rpc/rpc_factory_registry.h" - -namespace tensorflow { - -class RpcOp : public AsyncOpKernel { - public: - explicit RpcOp(OpKernelConstruction* context) : AsyncOpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("protocol", &protocol_)); - OP_REQUIRES(context, !protocol_.empty(), - errors::InvalidArgument("protocol must be non-empty.")); - bool fail_fast; - OP_REQUIRES_OK(context, context->GetAttr("fail_fast", &fail_fast)); - int64 timeout_in_ms; - OP_REQUIRES_OK(context, context->GetAttr("timeout_in_ms", &timeout_in_ms)); - - RPCFactoryRegistry::RPCFactoryFn* rpc_factory_fn = - RPCFactoryRegistry::Global()->Get(protocol_); - OP_REQUIRES(context, rpc_factory_fn != nullptr, - errors::InvalidArgument("The protocol ", protocol_, - " was not recognized.")); - - rpc_factory_.reset((*rpc_factory_fn)(context, fail_fast, timeout_in_ms)); - } - - ~RpcOp() override {} - - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - const Tensor& address_t = ctx->input(0); - const Tensor& method_t = ctx->input(1); - const Tensor& request_t = ctx->input(2); - - OP_REQUIRES_ASYNC( - ctx, address_t.dims() == 0 || address_t.dims() == 1, - errors::InvalidArgument("address must be a scalar or vector."), done); - OP_REQUIRES_ASYNC( - ctx, method_t.dims() == 0 || method_t.dims() == 1, - errors::InvalidArgument("method must be a scalar or vector."), done); - OP_REQUIRES_ASYNC( - ctx, request_t.dims() == 0 || request_t.dims() == 1, - errors::InvalidArgument("request must be a scalar or vector."), done); - - TensorShape output_shape({}); - for (const Tensor& t : {address_t, method_t, request_t}) { - if (t.dims() == 1) { - OP_REQUIRES_ASYNC( - ctx, - output_shape.dims() == 0 || - output_shape.dim_size(0) == t.dim_size(0), - errors::InvalidArgument( - "Input vector shapes don't match: ", output_shape.DebugString(), - " vs. ", t.shape().DebugString()), - done); - output_shape = t.shape(); - } - } - - Tensor* response_t; - OP_REQUIRES_OK_ASYNC( - ctx, ctx->allocate_output(0, output_shape, &response_t), done); - - const bool try_rpc = (ctx->num_outputs() > 1); - - Tensor* status_code_t = nullptr; - Tensor* status_message_t = nullptr; - if (try_rpc) { - OP_REQUIRES_OK_ASYNC( - ctx, ctx->allocate_output(1, output_shape, &status_code_t), done); - OP_REQUIRES_OK_ASYNC( - ctx, ctx->allocate_output(2, output_shape, &status_message_t), done); - } - - if (request_t.NumElements() == 0) { - // Special case, we finished early! - done(); - return; - } - - int64 num_elements = output_shape.num_elements(); - - rpc_factory_->Call(ctx, num_elements, address_t, method_t, request_t, - try_rpc, response_t, status_code_t, status_message_t, - std::move(done)); - } - - private: - string protocol_; - std::unique_ptr rpc_factory_; - - TF_DISALLOW_COPY_AND_ASSIGN(RpcOp); -}; - -REGISTER_KERNEL_BUILDER(Name("Rpc").Device(DEVICE_CPU), RpcOp); -REGISTER_KERNEL_BUILDER(Name("TryRpc").Device(DEVICE_CPU), RpcOp); - -} // namespace tensorflow diff --git a/tensorflow/core/ops/BUILD b/tensorflow/core/ops/BUILD index 40aaae8f148..16f12bc9f39 100644 --- a/tensorflow/core/ops/BUILD +++ b/tensorflow/core/ops/BUILD @@ -85,7 +85,6 @@ tf_gen_op_libs( "remote_fused_graph_ops", "risc_ops", "rnn_ops", - "rpc_ops", "scoped_allocator_ops", "sdca_ops", "set_ops", @@ -286,7 +285,6 @@ cc_library( ":stateful_random_ops_op_lib", ":remote_fused_graph_ops_op_lib", ":resource_variable_ops_op_lib", - ":rpc_ops_op_lib", ":scoped_allocator_ops_op_lib", ":script_ops_op_lib", ":sdca_ops_op_lib", diff --git a/tensorflow/core/ops/compat/ops_history_v1/Rpc.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Rpc.pbtxt deleted file mode 100644 index 224e52ea574..00000000000 --- a/tensorflow/core/ops/compat/ops_history_v1/Rpc.pbtxt +++ /dev/null @@ -1,41 +0,0 @@ -op { - name: "Rpc" - input_arg { - name: "address" - type: DT_STRING - } - input_arg { - name: "method" - type: DT_STRING - } - input_arg { - name: "request" - type: DT_STRING - } - output_arg { - name: "response" - type: DT_STRING - } - attr { - name: "protocol" - type: "string" - default_value { - s: "" - } - } - attr { - name: "fail_fast" - type: "bool" - default_value { - b: true - } - } - attr { - name: "timeout_in_ms" - type: "int" - default_value { - i: 0 - } - } - is_stateful: true -} diff --git a/tensorflow/core/ops/compat/ops_history_v1/TryRpc.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/TryRpc.pbtxt deleted file mode 100644 index e585195fb9b..00000000000 --- a/tensorflow/core/ops/compat/ops_history_v1/TryRpc.pbtxt +++ /dev/null @@ -1,49 +0,0 @@ -op { - name: "TryRpc" - input_arg { - name: "address" - type: DT_STRING - } - input_arg { - name: "method" - type: DT_STRING - } - input_arg { - name: "request" - type: DT_STRING - } - output_arg { - name: "response" - type: DT_STRING - } - output_arg { - name: "status_code" - type: DT_INT32 - } - output_arg { - name: "status_message" - type: DT_STRING - } - attr { - name: "protocol" - type: "string" - default_value { - s: "" - } - } - attr { - name: "fail_fast" - type: "bool" - default_value { - b: true - } - } - attr { - name: "timeout_in_ms" - type: "int" - default_value { - i: 0 - } - } - is_stateful: true -} diff --git a/tensorflow/core/ops/compat/ops_history_v2/Rpc.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/Rpc.pbtxt deleted file mode 100644 index 224e52ea574..00000000000 --- a/tensorflow/core/ops/compat/ops_history_v2/Rpc.pbtxt +++ /dev/null @@ -1,41 +0,0 @@ -op { - name: "Rpc" - input_arg { - name: "address" - type: DT_STRING - } - input_arg { - name: "method" - type: DT_STRING - } - input_arg { - name: "request" - type: DT_STRING - } - output_arg { - name: "response" - type: DT_STRING - } - attr { - name: "protocol" - type: "string" - default_value { - s: "" - } - } - attr { - name: "fail_fast" - type: "bool" - default_value { - b: true - } - } - attr { - name: "timeout_in_ms" - type: "int" - default_value { - i: 0 - } - } - is_stateful: true -} diff --git a/tensorflow/core/ops/compat/ops_history_v2/TryRpc.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/TryRpc.pbtxt deleted file mode 100644 index e585195fb9b..00000000000 --- a/tensorflow/core/ops/compat/ops_history_v2/TryRpc.pbtxt +++ /dev/null @@ -1,49 +0,0 @@ -op { - name: "TryRpc" - input_arg { - name: "address" - type: DT_STRING - } - input_arg { - name: "method" - type: DT_STRING - } - input_arg { - name: "request" - type: DT_STRING - } - output_arg { - name: "response" - type: DT_STRING - } - output_arg { - name: "status_code" - type: DT_INT32 - } - output_arg { - name: "status_message" - type: DT_STRING - } - attr { - name: "protocol" - type: "string" - default_value { - s: "" - } - } - attr { - name: "fail_fast" - type: "bool" - default_value { - b: true - } - } - attr { - name: "timeout_in_ms" - type: "int" - default_value { - i: 0 - } - } - is_stateful: true -} diff --git a/tensorflow/core/ops/rpc_ops.cc b/tensorflow/core/ops/rpc_ops.cc deleted file mode 100644 index 136f96d9ea7..00000000000 --- a/tensorflow/core/ops/rpc_ops.cc +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -using tensorflow::shape_inference::InferenceContext; -using tensorflow::shape_inference::ShapeHandle; - -Status RpcShapeOp(InferenceContext* c, bool try_rpc) { - ShapeHandle address; - ShapeHandle method; - ShapeHandle request; - ShapeHandle output; - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &address)); - if (c->Rank(address) == 1) { - TF_RETURN_IF_ERROR(c->Merge(output, address, &output)); - } - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &method)); - if (c->Rank(method) == 1) { - TF_RETURN_IF_ERROR(c->Merge(output, method, &output)); - } - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &request)); - if (c->Rank(request) == 1) { - TF_RETURN_IF_ERROR(c->Merge(output, request, &output)); - } - if (!c->RankKnown(output)) { - output = request; - } - c->set_output(0, output); // response - if (try_rpc) { - c->set_output(1, output); // status_code - c->set_output(2, output); // status_message - } - return Status::OK(); -} - -REGISTER_OP("Rpc") - .Input("address: string") - .Input("method: string") - .Input("request: string") - .Attr("protocol: string = ''") - .Attr("fail_fast: bool = true") - .Attr("timeout_in_ms: int = 0") - .Output("response: string") - .SetIsStateful() - .SetShapeFn([](InferenceContext* c) { - return RpcShapeOp(c, /*try_rpc=*/false); - }); - -REGISTER_OP("TryRpc") - .Input("address: string") - .Input("method: string") - .Input("request: string") - .Attr("protocol: string = ''") - .Attr("fail_fast: bool = true") - .Attr("timeout_in_ms: int = 0") - .Output("response: string") - .Output("status_code: int32") - .Output("status_message: string") - .SetIsStateful() - .SetShapeFn([](InferenceContext* c) { - return RpcShapeOp(c, /*try_rpc=*/true); - }); - -} // namespace tensorflow diff --git a/tensorflow/core/util/rpc/BUILD b/tensorflow/core/util/rpc/BUILD deleted file mode 100644 index c1b8869b8d2..00000000000 --- a/tensorflow/core/util/rpc/BUILD +++ /dev/null @@ -1,48 +0,0 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "call_container", - hdrs = ["call_container.h"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - -cc_library( - name = "rpc_factory", - srcs = ["rpc_factory.cc"], - hdrs = ["rpc_factory.h"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "rpc_factory_registry", - srcs = ["rpc_factory_registry.cc"], - hdrs = ["rpc_factory_registry.h"], - deps = [ - ":rpc_factory", - "//tensorflow/core:framework", - ], -) - -tf_cc_test( - name = "rpc_factory_registry_test", - srcs = ["rpc_factory_registry_test.cc"], - deps = [ - ":rpc_factory_registry", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) diff --git a/tensorflow/core/util/rpc/call_container.h b/tensorflow/core/util/rpc/call_container.h deleted file mode 100644 index 39ead10815a..00000000000 --- a/tensorflow/core/util/rpc/call_container.h +++ /dev/null @@ -1,182 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_ -#define TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_ - -#include - -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/util/reffed_status_callback.h" - -namespace tensorflow { - -namespace internal { -// The following class is used for coordination between a `CallContainer` -// instance and a cancellation callback to make sure that the `CallContainer` -// instance waits for the cancellation callback to be destroyed (either because -// a cancellation occurred or because the callback was deregistered) before -// deleting itself. Without this coordination the cancellation callback could -// attempt to access a `CallContainer` instance that is no longer valid. -class NotifyWhenDestroyed { - public: - explicit NotifyWhenDestroyed(std::shared_ptr notification) - : notification_(std::move(notification)) {} - - ~NotifyWhenDestroyed() { notification_->Notify(); } - - private: - std::shared_ptr notification_; -}; -} // namespace internal - -// The following class is responsible for the life cycle management of a set of -// RPC calls. The calls are started when an instance of the class is created and -// the class contract guarantees to invoke a "done" callback provided by the -// caller when all RPC calls have either completed or been cancelled. -// -// The caller should not make any assumptions about the validity of an instance -// of this class after the provided callback has been invoked, which may be -// immediately after the instance was created. -template -class CallContainer { - public: - typedef std::function*, int)> CreateCallFn; - typedef std::function StartCallFn; - - // Uses the provided `create_call_fn` and `start_call_fn` functions to create - // and start a set of RPC calls. When all RPC calls have either completed or - // been cancelled, the `done` callback is invoked. The caller should not make - // any assumptions about the validity of the created instance as the instance - // will delete itself after invoking the `done` callback. - explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast, - bool try_rpc, AsyncOpKernel::DoneCallback done, - CreateCallFn create_call_fn, - StartCallFn start_call_fn); - - // Registers a call with this container. This method expects its arguments to - // match those of a `Call` constructor as it forwards them to an underlying - // collection, which creates a `Call` instance in place. - template - void RegisterCall(Args&&... args); - - // Starts the cancellation of all RPC calls managed by this container. - void StartCancel(); - - // Indicates that the `index`-th RPC call has finished. - void Done(const Status& s, int index); - - private: - OpKernelContext* ctx_; - std::list calls_; - const AsyncOpKernel::DoneCallback done_; - const CancellationToken token_; - const bool fail_fast_; - const bool try_rpc_; - std::shared_ptr callback_destroyed_; - - // Performs its own reference counting. - ReffedStatusCallback* reffed_status_callback_; -}; - -template -CallContainer::CallContainer( - OpKernelContext* ctx, int num_calls, bool fail_fast, bool try_rpc, - AsyncOpKernel::DoneCallback done, - typename CallContainer::CreateCallFn create_call_fn, - typename CallContainer::StartCallFn start_call_fn) - : ctx_(ctx), - done_(std::move(done)), - token_(ctx->cancellation_manager() != nullptr - ? ctx->cancellation_manager()->get_cancellation_token() - : CancellationManager::kInvalidToken), - fail_fast_(fail_fast), - try_rpc_(try_rpc), - callback_destroyed_(new Notification) { - CHECK_GT(num_calls, 0); - - // This will run when all RPCs are finished. - reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) { - if (token_ != CancellationManager::kInvalidToken) { - ctx_->cancellation_manager()->DeregisterCallback(token_); - } - ctx_->SetStatus(s); - done_(); - callback_destroyed_->WaitForNotification(); - delete this; - }); - - // The cancellation callback needs to be registered before the RPC calls are - // started to make sure that the callback is properly cleaned up by the - // `reffed_status_callback` when all calls complete. At the same time, the - // cancellation callback should wait for the RPC calls to be started for the - // cancellation to take effect. - std::shared_ptr notify_when_destroyed( - new internal::NotifyWhenDestroyed(callback_destroyed_)); - std::shared_ptr calls_started(new Notification); - bool is_cancelled = false; - if (token_ != CancellationManager::kInvalidToken) { - is_cancelled = !ctx_->cancellation_manager()->RegisterCallback( - token_, [this, calls_started, notify_when_destroyed]() { - calls_started->WaitForNotification(); - StartCancel(); - }); - } - - for (int i = 0; i < num_calls; ++i) { - create_call_fn(this, i); - // Increase the reference on the callback for each new RPC. - reffed_status_callback_->Ref(); - } - for (Call& call : calls_) { - start_call_fn(&call); - } - calls_started->Notify(); - - if (is_cancelled) { - ctx_->SetStatus(errors::Cancelled("Operation has been cancelled.")); - StartCancel(); - } - - // Subtract reference count from the initial creation. - reffed_status_callback_->Unref(); -} - -template -template -void CallContainer::RegisterCall(Args&&... args) { - calls_.emplace_back(std::forward(args)...); -} - -template -void CallContainer::StartCancel() { - for (auto& call : calls_) { - call.StartCancel(); - } -} - -template -void CallContainer::Done(const Status& s, int index) { - if (!try_rpc_) { - reffed_status_callback_->UpdateStatus(s); - } - reffed_status_callback_->Unref(); -} - -} // namespace tensorflow -#endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_ diff --git a/tensorflow/core/util/rpc/rpc_factory.cc b/tensorflow/core/util/rpc/rpc_factory.cc deleted file mode 100644 index 8530f02b6e2..00000000000 --- a/tensorflow/core/util/rpc/rpc_factory.cc +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/lib/strings/numbers.h" - -#include "tensorflow/core/util/rpc/rpc_factory.h" - -namespace tensorflow { - -template <> -bool GetEnvVar(const char* key, const string& default_value, string* value) { - const char* env_value = std::getenv(key); - if (!env_value || env_value[0] == '\0') { - *value = default_value; - } else { - *value = env_value; - } - return true; -} - -template <> -bool GetEnvVar(const char* key, const int64& default_value, int64* value) { - const char* env_value = std::getenv(key); - if (!env_value || env_value[0] == '\0') { - *value = default_value; - return true; - } - return strings::safe_strto64(env_value, value); -} - -template <> -bool GetEnvVar(const char* key, const uint64& default_value, uint64* value) { - const char* env_value = std::getenv(key); - if (!env_value || env_value[0] == '\0') { - *value = default_value; - return true; - } - return strings::safe_strtou64(env_value, value); -} - -} // namespace tensorflow diff --git a/tensorflow/core/util/rpc/rpc_factory.h b/tensorflow/core/util/rpc/rpc_factory.h deleted file mode 100644 index c4eaaf44570..00000000000 --- a/tensorflow/core/util/rpc/rpc_factory.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_ -#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_ - -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_types.h" - -namespace tensorflow { - -// Return the environment variable `key`. If the variable is not set, -// use the default value. If it is set but could not be parsed, -// return `false`. Otherwise set `value` and return `true`. -template -bool GetEnvVar(const char* key, const T& default_value, T* value); - -class RPCFactory { - public: - RPCFactory() {} - virtual ~RPCFactory() {} - - // Asynchronously invokes methods `method_t` at addresses `address_t` with - // request strings from `request_t`. Any of these may be scalar - // Tensors, in which case the operands are broadcasted. - // Upon completion of all requests, `response_t` will be populated and the - // `done` callback will be invoked. - // - // If `try_rpc` is `true`, then `status_message_t` and - // `status_code_t` will be populated as well. - // - // If `try_rpc` is `false`, then `status_message_t` and - // `status_code_t` are ignored (and may be nullptr). Instead, the - // status of any failed call will be propagated to the op. - // - // REQUIRES: - // - `response_t` is not null, and is a string Tensor with the same shape as - // `request_t`. - // - // If `try_rpc` is `true`: - // - `status_code_t` and `status_message_t` are not null. - // - `status_code_t` is an int32 Tensor with the same shape as - // `request_t`. - // - `status_message_t` is a string Tensor with the same shape as - // `request_t`. - virtual void Call(OpKernelContext* ctx, int64 num_elements, - const Tensor& address_t, const Tensor& method_t, - const Tensor& request_t, const bool try_rpc, - Tensor* response_t, Tensor* status_code_t, - Tensor* status_message_t, - AsyncOpKernel::DoneCallback done) = 0; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(RPCFactory); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_ diff --git a/tensorflow/core/util/rpc/rpc_factory_registry.cc b/tensorflow/core/util/rpc/rpc_factory_registry.cc deleted file mode 100644 index a148b5c04d0..00000000000 --- a/tensorflow/core/util/rpc/rpc_factory_registry.cc +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "tensorflow/core/util/rpc/rpc_factory.h" - -#include "tensorflow/core/util/rpc/rpc_factory_registry.h" - -namespace tensorflow { - -RPCFactoryRegistry* RPCFactoryRegistry::Global() { - static RPCFactoryRegistry* registry = new RPCFactoryRegistry; - return registry; -} - -RPCFactoryRegistry::RPCFactoryFn* RPCFactoryRegistry::Get( - const string& protocol) { - auto found = fns_.find(protocol); - if (found == fns_.end()) return nullptr; - return &found->second; -} - -void RPCFactoryRegistry::Register(const string& protocol, - const RPCFactoryFn& factory_fn) { - auto existing = Get(protocol); - CHECK_EQ(existing, nullptr) - << "RPC factory for protocol: " << protocol << " already registered"; - fns_.insert(std::pair(protocol, factory_fn)); -} - -} // namespace tensorflow diff --git a/tensorflow/core/util/rpc/rpc_factory_registry.h b/tensorflow/core/util/rpc/rpc_factory_registry.h deleted file mode 100644 index 2635a4012e8..00000000000 --- a/tensorflow/core/util/rpc/rpc_factory_registry.h +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_ -#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_ - -#include -#include - -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/util/rpc/rpc_factory.h" - -namespace tensorflow { - -class RPCFactoryRegistry { - public: - typedef std::function - RPCFactoryFn; - - // Returns a pointer to a global RPCFactoryRegistry object. - static RPCFactoryRegistry* Global(); - - // Returns a pointer to an function that creates an RPC factory for the given - // protocol. - RPCFactoryFn* Get(const string& protocol); - - // Registers a function that creates and RPC factory for the given protocol. - // The function should transfer the ownership of the factory to its caller. - void Register(const string& protocol, const RPCFactoryFn& factory_fn); - - private: - std::map fns_; -}; - -namespace rpc_factory_registration { - -class RPCFactoryRegistration { - public: - RPCFactoryRegistration(const string& protocol, - const RPCFactoryRegistry::RPCFactoryFn& factory_fn) { - RPCFactoryRegistry::Global()->Register(protocol, factory_fn); - } -}; - -} // namespace rpc_factory_registration - -#define REGISTER_RPC_FACTORY(protocol, factory_fn) \ - REGISTER_RPC_FACTORY_UNIQ_HELPER(__COUNTER__, protocol, factory_fn) - -#define REGISTER_RPC_FACTORY_UNIQ_HELPER(ctr, protocol, factory_fn) \ - REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn) - -#define REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn) \ - static rpc_factory_registration::RPCFactoryRegistration \ - rpc_factory_registration_fn_##ctr(protocol, factory_fn) - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_ diff --git a/tensorflow/core/util/rpc/rpc_factory_registry_test.cc b/tensorflow/core/util/rpc/rpc_factory_registry_test.cc deleted file mode 100644 index cfd0f95016e..00000000000 --- a/tensorflow/core/util/rpc/rpc_factory_registry_test.cc +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/util/rpc/rpc_factory_registry.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { - -struct Value { - static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast, - int64 timeout_in_ms) { - return nullptr; - } -}; - -REGISTER_RPC_FACTORY("TEST FACTORY 1", Value::Function); -REGISTER_RPC_FACTORY("TEST FACTORY 2", Value::Function); -} // namespace - -TEST(RPCFactoryRegistryTest, TestBasic) { - EXPECT_EQ(RPCFactoryRegistry::Global()->Get("NON-EXISTENT"), nullptr); - auto factory1 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 1"); - EXPECT_NE(factory1, nullptr); - auto factory2 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 2"); - EXPECT_NE(factory2, nullptr); -} - -} // namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 7b1ccdae12e..dc8cd56f91a 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5206,7 +5206,6 @@ pywrap_tensorflow_macro( ":bfloat16_lib", ":cost_analyzer_lib", ":model_analyzer_lib", - "//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_session", "//tensorflow/python/util:cpp_python_util",