From 4e5bee560fcc13e867bc0d80c31e66fa6a8e4548 Mon Sep 17 00:00:00 2001 From: Dero Gharibian <dero@google.com> Date: Thu, 22 Aug 2019 21:25:26 -0700 Subject: [PATCH] Update core/distributed_runtime to use tstring. This is a part of a larger migration effort for tensorflow::tstring. See: https://github.com/tensorflow/community/pull/91 PiperOrigin-RevId: 264982634 --- .../rpc/grpc_rpc_factory.cc | 28 +++++++++---------- .../rpc/grpc_tensor_coding_test.cc | 4 +-- .../core/distributed_runtime/rpc/grpc_util.cc | 18 +++++++++++- .../core/distributed_runtime/rpc/grpc_util.h | 3 ++ .../distributed_runtime/tensor_coding_test.cc | 4 +-- tensorflow/core/platform/tstring.h | 22 +++++++++++++++ 6 files changed, 60 insertions(+), 19 deletions(-) diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc index 8be6f1d6994..272d6bb1b20 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc @@ -34,8 +34,8 @@ namespace internal { class GrpcCall { public: explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc, - const string* request_msg, string* response_msg, - int32* status_code, string* status_message) + const tstring* request_msg, tstring* response_msg, + int32* status_code, tstring* status_message) : container_(container), index_(index), try_rpc_(try_rpc), @@ -59,18 +59,18 @@ class GrpcCall { CallOptions* call_opts() { return &call_opts_; } int index() { return index_; } - const string& request() const { return *request_msg_; } - string* response() const { return response_msg_; } + const tstring& request() const { return *request_msg_; } + tstring* response() const { return response_msg_; } private: CallContainer<GrpcCall>* const container_; const int index_; bool try_rpc_; CallOptions call_opts_; - const string* request_msg_; - string* response_msg_; + const tstring* request_msg_; + tstring* response_msg_; int* status_code_; - string* status_message_; + tstring* status_message_; }; } // namespace internal @@ -168,16 +168,16 @@ void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc, int index, CallContainer<GrpcCall>* container, Tensor* response_t, Tensor* status_code_t, Tensor* status_message_t) { - auto request = request_t.flat<string>(); - auto get_request_ptr = [&request](int64 ix) -> const string* { + auto request = request_t.flat<tstring>(); + auto get_request_ptr = [&request](int64 ix) -> const tstring* { return (request.size() > 1) ? &(request(ix)) : &(request(0)); }; - auto response = response_t->flat<string>(); + auto response = response_t->flat<tstring>(); int32* status_code_ptr = nullptr; - string* status_message_ptr = nullptr; + tstring* status_message_ptr = nullptr; if (try_rpc) { status_code_ptr = status_code_t->flat<int32>().data(); - status_message_ptr = status_message_t->flat<string>().data(); + status_message_ptr = status_message_t->flat<tstring>().data(); } container->RegisterCall(container, index, try_rpc, get_request_ptr(index), &response(index), @@ -200,13 +200,13 @@ void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t, return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix)) : singleton_stub; }; - auto get_method_ptr = [&method](int64 ix) -> const string* { + auto get_method_ptr = [&method](int64 ix) -> const tstring* { return (method.size() > 1) ? &(method(ix)) : &(method(0)); }; int index = call->index(); // This object will delete itself when done. - new RPCState<string>( + new RPCState<tstring>( get_stub(index), &completion_queue_, *get_method_ptr(index), call->request(), call->response(), /*done=*/[call](const Status& s) { call->Done(s); }, call->call_opts(), diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc index d07bac5631c..29ee480e39b 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc @@ -65,11 +65,11 @@ class GrpcTensorCodingTest : public ::testing::Test { } } void DoTestForStrings(DataType dt) { - gtl::InlinedVector<string, 4> v; + gtl::InlinedVector<tstring, 4> v; for (int elems = 0; elems <= 10000; elems++) { if (elems < 100 || (elems % 1000 == 0)) { Tensor a(dt, TensorShape({1, static_cast<int64>(v.size())})); - test::FillValues<string>(&a, v); + test::FillValues<tstring>(&a, v); Validate(a, (elems == 0)); } v.push_back(strings::StrCat("This is string ", elems)); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_util.cc b/tensorflow/core/distributed_runtime/rpc/grpc_util.cc index 471e2c16b34..5dda1459167 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_util.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_util.cc @@ -100,7 +100,7 @@ bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, TensorResponse* dst) { return s.ok(); } -// GrpcMaybeParseProto into a string simply copies bytes into the string. +// GrpcMaybeParseProto simply copies bytes into the string. bool GrpcMaybeParseProto(grpc::ByteBuffer* src, string* dst) { dst->clear(); dst->reserve(src->Length()); @@ -114,4 +114,20 @@ bool GrpcMaybeParseProto(grpc::ByteBuffer* src, string* dst) { return true; } +#ifdef USE_TSTRING +// GrpcMaybeParseProto simply copies bytes into the tstring. +bool GrpcMaybeParseProto(grpc::ByteBuffer* src, tstring* dst) { + dst->clear(); + dst->reserve(src->Length()); + std::vector<::grpc::Slice> slices; + if (!src->Dump(&slices).ok()) { + return false; + } + for (const ::grpc::Slice& s : slices) { + dst->append(reinterpret_cast<const char*>(s.begin()), s.size()); + } + return true; +} +#endif // USE_TSTRING + } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_util.h b/tensorflow/core/distributed_runtime/rpc/grpc_util.h index 976f3e6452a..aed798217cb 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_util.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_util.h @@ -131,6 +131,9 @@ bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, TensorResponse* dst); // Copy grpc buffer src to string *dst. bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, string* dst); +// Copy grpc buffer src to tstring *dst. +bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, tstring* dst); + } // namespace tensorflow #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ diff --git a/tensorflow/core/distributed_runtime/tensor_coding_test.cc b/tensorflow/core/distributed_runtime/tensor_coding_test.cc index 52a057bdb2f..02e137a46c6 100644 --- a/tensorflow/core/distributed_runtime/tensor_coding_test.cc +++ b/tensorflow/core/distributed_runtime/tensor_coding_test.cc @@ -120,12 +120,12 @@ class TensorResponseTest : public ::testing::Test { } } void DoTestForStrings(DataType dt) { - gtl::InlinedVector<string, 4> v; + gtl::InlinedVector<tstring, 4> v; LOG(ERROR) << "DT: string"; for (int elems = 0; elems <= 10000; elems++) { if (elems < 100 || (elems % 1000 == 0)) { Tensor a(dt, TensorShape({1, static_cast<int64>(v.size())})); - test::FillValues<string>(&a, v); + test::FillValues<tstring>(&a, v); Validate(a, (elems == 0), true); } v.push_back(strings::StrCat("This is string ", elems)); diff --git a/tensorflow/core/platform/tstring.h b/tensorflow/core/platform/tstring.h index ea145525fcf..d7c82755e48 100644 --- a/tensorflow/core/platform/tstring.h +++ b/tensorflow/core/platform/tstring.h @@ -143,12 +143,16 @@ class tstring { char& operator[](size_t i) { return str_[i]; } + void clear() noexcept { str_.clear(); } + void resize(size_t new_size) { str_.resize(new_size); } void resize_uninitialized(size_t new_size) { ResizeUninitialized<decltype(str_)>::Resize(str_, new_size); } + void reserve(size_t n) { str_.reserve(n); } + tstring& assign(const char* str, size_t len) { str_.assign(str, len); @@ -161,6 +165,24 @@ class tstring { return *this; } + tstring& append(const tstring& str) { + str_.append(str); + + return *this; + } + + tstring& append(const char* str, size_t len) { + str_.append(str, len); + + return *this; + } + + tstring& append(const char* str) { + str_.append(str); + + return *this; + } + friend const tstring operator+(const tstring& a, const tstring& b); friend bool operator==(const char* a, const tstring& b); friend bool operator==(const std::string& a, const tstring& b);