From 4e5bee560fcc13e867bc0d80c31e66fa6a8e4548 Mon Sep 17 00:00:00 2001
From: Dero Gharibian <dero@google.com>
Date: Thu, 22 Aug 2019 21:25:26 -0700
Subject: [PATCH] Update core/distributed_runtime to use tstring.

This is a part of a larger migration effort for tensorflow::tstring.
See: https://github.com/tensorflow/community/pull/91
PiperOrigin-RevId: 264982634
---
 .../rpc/grpc_rpc_factory.cc                   | 28 +++++++++----------
 .../rpc/grpc_tensor_coding_test.cc            |  4 +--
 .../core/distributed_runtime/rpc/grpc_util.cc | 18 +++++++++++-
 .../core/distributed_runtime/rpc/grpc_util.h  |  3 ++
 .../distributed_runtime/tensor_coding_test.cc |  4 +--
 tensorflow/core/platform/tstring.h            | 22 +++++++++++++++
 6 files changed, 60 insertions(+), 19 deletions(-)

diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
index 8be6f1d6994..272d6bb1b20 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
@@ -34,8 +34,8 @@ namespace internal {
 class GrpcCall {
  public:
   explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
-                    const string* request_msg, string* response_msg,
-                    int32* status_code, string* status_message)
+                    const tstring* request_msg, tstring* response_msg,
+                    int32* status_code, tstring* status_message)
       : container_(container),
         index_(index),
         try_rpc_(try_rpc),
@@ -59,18 +59,18 @@ class GrpcCall {
 
   CallOptions* call_opts() { return &call_opts_; }
   int index() { return index_; }
-  const string& request() const { return *request_msg_; }
-  string* response() const { return response_msg_; }
+  const tstring& request() const { return *request_msg_; }
+  tstring* response() const { return response_msg_; }
 
  private:
   CallContainer<GrpcCall>* const container_;
   const int index_;
   bool try_rpc_;
   CallOptions call_opts_;
-  const string* request_msg_;
-  string* response_msg_;
+  const tstring* request_msg_;
+  tstring* response_msg_;
   int* status_code_;
-  string* status_message_;
+  tstring* status_message_;
 };
 
 }  // namespace internal
@@ -168,16 +168,16 @@ void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc,
                                 int index, CallContainer<GrpcCall>* container,
                                 Tensor* response_t, Tensor* status_code_t,
                                 Tensor* status_message_t) {
-  auto request = request_t.flat<string>();
-  auto get_request_ptr = [&request](int64 ix) -> const string* {
+  auto request = request_t.flat<tstring>();
+  auto get_request_ptr = [&request](int64 ix) -> const tstring* {
     return (request.size() > 1) ? &(request(ix)) : &(request(0));
   };
-  auto response = response_t->flat<string>();
+  auto response = response_t->flat<tstring>();
   int32* status_code_ptr = nullptr;
-  string* status_message_ptr = nullptr;
+  tstring* status_message_ptr = nullptr;
   if (try_rpc) {
     status_code_ptr = status_code_t->flat<int32>().data();
-    status_message_ptr = status_message_t->flat<string>().data();
+    status_message_ptr = status_message_t->flat<tstring>().data();
   }
   container->RegisterCall(container, index, try_rpc, get_request_ptr(index),
                           &response(index),
@@ -200,13 +200,13 @@ void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
     return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
                                 : singleton_stub;
   };
-  auto get_method_ptr = [&method](int64 ix) -> const string* {
+  auto get_method_ptr = [&method](int64 ix) -> const tstring* {
     return (method.size() > 1) ? &(method(ix)) : &(method(0));
   };
 
   int index = call->index();
   // This object will delete itself when done.
-  new RPCState<string>(
+  new RPCState<tstring>(
       get_stub(index), &completion_queue_, *get_method_ptr(index),
       call->request(), call->response(),
       /*done=*/[call](const Status& s) { call->Done(s); }, call->call_opts(),
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc
index d07bac5631c..29ee480e39b 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc
@@ -65,11 +65,11 @@ class GrpcTensorCodingTest : public ::testing::Test {
     }
   }
   void DoTestForStrings(DataType dt) {
-    gtl::InlinedVector<string, 4> v;
+    gtl::InlinedVector<tstring, 4> v;
     for (int elems = 0; elems <= 10000; elems++) {
       if (elems < 100 || (elems % 1000 == 0)) {
         Tensor a(dt, TensorShape({1, static_cast<int64>(v.size())}));
-        test::FillValues<string>(&a, v);
+        test::FillValues<tstring>(&a, v);
         Validate(a, (elems == 0));
       }
       v.push_back(strings::StrCat("This is string ", elems));
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_util.cc b/tensorflow/core/distributed_runtime/rpc/grpc_util.cc
index 471e2c16b34..5dda1459167 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_util.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_util.cc
@@ -100,7 +100,7 @@ bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, TensorResponse* dst) {
   return s.ok();
 }
 
-// GrpcMaybeParseProto into a string simply copies bytes into the string.
+// GrpcMaybeParseProto simply copies bytes into the string.
 bool GrpcMaybeParseProto(grpc::ByteBuffer* src, string* dst) {
   dst->clear();
   dst->reserve(src->Length());
@@ -114,4 +114,20 @@ bool GrpcMaybeParseProto(grpc::ByteBuffer* src, string* dst) {
   return true;
 }
 
+#ifdef USE_TSTRING
+// GrpcMaybeParseProto simply copies bytes into the tstring.
+bool GrpcMaybeParseProto(grpc::ByteBuffer* src, tstring* dst) {
+  dst->clear();
+  dst->reserve(src->Length());
+  std::vector<::grpc::Slice> slices;
+  if (!src->Dump(&slices).ok()) {
+    return false;
+  }
+  for (const ::grpc::Slice& s : slices) {
+    dst->append(reinterpret_cast<const char*>(s.begin()), s.size());
+  }
+  return true;
+}
+#endif  // USE_TSTRING
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_util.h b/tensorflow/core/distributed_runtime/rpc/grpc_util.h
index 976f3e6452a..aed798217cb 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_util.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_util.h
@@ -131,6 +131,9 @@ bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, TensorResponse* dst);
 // Copy grpc buffer src to string *dst.
 bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, string* dst);
 
+// Copy grpc buffer src to tstring *dst.
+bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, tstring* dst);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
diff --git a/tensorflow/core/distributed_runtime/tensor_coding_test.cc b/tensorflow/core/distributed_runtime/tensor_coding_test.cc
index 52a057bdb2f..02e137a46c6 100644
--- a/tensorflow/core/distributed_runtime/tensor_coding_test.cc
+++ b/tensorflow/core/distributed_runtime/tensor_coding_test.cc
@@ -120,12 +120,12 @@ class TensorResponseTest : public ::testing::Test {
     }
   }
   void DoTestForStrings(DataType dt) {
-    gtl::InlinedVector<string, 4> v;
+    gtl::InlinedVector<tstring, 4> v;
     LOG(ERROR) << "DT: string";
     for (int elems = 0; elems <= 10000; elems++) {
       if (elems < 100 || (elems % 1000 == 0)) {
         Tensor a(dt, TensorShape({1, static_cast<int64>(v.size())}));
-        test::FillValues<string>(&a, v);
+        test::FillValues<tstring>(&a, v);
         Validate(a, (elems == 0), true);
       }
       v.push_back(strings::StrCat("This is string ", elems));
diff --git a/tensorflow/core/platform/tstring.h b/tensorflow/core/platform/tstring.h
index ea145525fcf..d7c82755e48 100644
--- a/tensorflow/core/platform/tstring.h
+++ b/tensorflow/core/platform/tstring.h
@@ -143,12 +143,16 @@ class tstring {
 
   char& operator[](size_t i) { return str_[i]; }
 
+  void clear() noexcept { str_.clear(); }
+
   void resize(size_t new_size) { str_.resize(new_size); }
 
   void resize_uninitialized(size_t new_size) {
     ResizeUninitialized<decltype(str_)>::Resize(str_, new_size);
   }
 
+  void reserve(size_t n) { str_.reserve(n); }
+
   tstring& assign(const char* str, size_t len) {
     str_.assign(str, len);
 
@@ -161,6 +165,24 @@ class tstring {
     return *this;
   }
 
+  tstring& append(const tstring& str) {
+    str_.append(str);
+
+    return *this;
+  }
+
+  tstring& append(const char* str, size_t len) {
+    str_.append(str, len);
+
+    return *this;
+  }
+
+  tstring& append(const char* str) {
+    str_.append(str);
+
+    return *this;
+  }
+
   friend const tstring operator+(const tstring& a, const tstring& b);
   friend bool operator==(const char* a, const tstring& b);
   friend bool operator==(const std::string& a, const tstring& b);