Update core/distributed_runtime to use tstring.

This is a part of a larger migration effort for tensorflow::tstring.
See: https://github.com/tensorflow/community/pull/91
PiperOrigin-RevId: 264982634
This commit is contained in:
Dero Gharibian 2019-08-22 21:25:26 -07:00 committed by TensorFlower Gardener
parent f09d3e2997
commit 4e5bee560f
6 changed files with 60 additions and 19 deletions

View File

@ -34,8 +34,8 @@ namespace internal {
class GrpcCall {
public:
explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
const string* request_msg, string* response_msg,
int32* status_code, string* status_message)
const tstring* request_msg, tstring* response_msg,
int32* status_code, tstring* status_message)
: container_(container),
index_(index),
try_rpc_(try_rpc),
@ -59,18 +59,18 @@ class GrpcCall {
CallOptions* call_opts() { return &call_opts_; }
int index() { return index_; }
const string& request() const { return *request_msg_; }
string* response() const { return response_msg_; }
const tstring& request() const { return *request_msg_; }
tstring* response() const { return response_msg_; }
private:
CallContainer<GrpcCall>* const container_;
const int index_;
bool try_rpc_;
CallOptions call_opts_;
const string* request_msg_;
string* response_msg_;
const tstring* request_msg_;
tstring* response_msg_;
int* status_code_;
string* status_message_;
tstring* status_message_;
};
} // namespace internal
@ -168,16 +168,16 @@ void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc,
int index, CallContainer<GrpcCall>* container,
Tensor* response_t, Tensor* status_code_t,
Tensor* status_message_t) {
auto request = request_t.flat<string>();
auto get_request_ptr = [&request](int64 ix) -> const string* {
auto request = request_t.flat<tstring>();
auto get_request_ptr = [&request](int64 ix) -> const tstring* {
return (request.size() > 1) ? &(request(ix)) : &(request(0));
};
auto response = response_t->flat<string>();
auto response = response_t->flat<tstring>();
int32* status_code_ptr = nullptr;
string* status_message_ptr = nullptr;
tstring* status_message_ptr = nullptr;
if (try_rpc) {
status_code_ptr = status_code_t->flat<int32>().data();
status_message_ptr = status_message_t->flat<string>().data();
status_message_ptr = status_message_t->flat<tstring>().data();
}
container->RegisterCall(container, index, try_rpc, get_request_ptr(index),
&response(index),
@ -200,13 +200,13 @@ void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
: singleton_stub;
};
auto get_method_ptr = [&method](int64 ix) -> const string* {
auto get_method_ptr = [&method](int64 ix) -> const tstring* {
return (method.size() > 1) ? &(method(ix)) : &(method(0));
};
int index = call->index();
// This object will delete itself when done.
new RPCState<string>(
new RPCState<tstring>(
get_stub(index), &completion_queue_, *get_method_ptr(index),
call->request(), call->response(),
/*done=*/[call](const Status& s) { call->Done(s); }, call->call_opts(),

View File

@ -65,11 +65,11 @@ class GrpcTensorCodingTest : public ::testing::Test {
}
}
void DoTestForStrings(DataType dt) {
gtl::InlinedVector<string, 4> v;
gtl::InlinedVector<tstring, 4> v;
for (int elems = 0; elems <= 10000; elems++) {
if (elems < 100 || (elems % 1000 == 0)) {
Tensor a(dt, TensorShape({1, static_cast<int64>(v.size())}));
test::FillValues<string>(&a, v);
test::FillValues<tstring>(&a, v);
Validate(a, (elems == 0));
}
v.push_back(strings::StrCat("This is string ", elems));

View File

@ -100,7 +100,7 @@ bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, TensorResponse* dst) {
return s.ok();
}
// GrpcMaybeParseProto into a string simply copies bytes into the string.
// GrpcMaybeParseProto simply copies bytes into the string.
bool GrpcMaybeParseProto(grpc::ByteBuffer* src, string* dst) {
dst->clear();
dst->reserve(src->Length());
@ -114,4 +114,20 @@ bool GrpcMaybeParseProto(grpc::ByteBuffer* src, string* dst) {
return true;
}
#ifdef USE_TSTRING
// GrpcMaybeParseProto simply copies bytes into the tstring.
bool GrpcMaybeParseProto(grpc::ByteBuffer* src, tstring* dst) {
dst->clear();
dst->reserve(src->Length());
std::vector<::grpc::Slice> slices;
if (!src->Dump(&slices).ok()) {
return false;
}
for (const ::grpc::Slice& s : slices) {
dst->append(reinterpret_cast<const char*>(s.begin()), s.size());
}
return true;
}
#endif // USE_TSTRING
} // namespace tensorflow

View File

@ -131,6 +131,9 @@ bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, TensorResponse* dst);
// Copy grpc buffer src to string *dst.
bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, string* dst);
// Copy grpc buffer src to tstring *dst.
bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, tstring* dst);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_

View File

@ -120,12 +120,12 @@ class TensorResponseTest : public ::testing::Test {
}
}
void DoTestForStrings(DataType dt) {
gtl::InlinedVector<string, 4> v;
gtl::InlinedVector<tstring, 4> v;
LOG(ERROR) << "DT: string";
for (int elems = 0; elems <= 10000; elems++) {
if (elems < 100 || (elems % 1000 == 0)) {
Tensor a(dt, TensorShape({1, static_cast<int64>(v.size())}));
test::FillValues<string>(&a, v);
test::FillValues<tstring>(&a, v);
Validate(a, (elems == 0), true);
}
v.push_back(strings::StrCat("This is string ", elems));

View File

@ -143,12 +143,16 @@ class tstring {
char& operator[](size_t i) { return str_[i]; }
void clear() noexcept { str_.clear(); }
void resize(size_t new_size) { str_.resize(new_size); }
void resize_uninitialized(size_t new_size) {
ResizeUninitialized<decltype(str_)>::Resize(str_, new_size);
}
void reserve(size_t n) { str_.reserve(n); }
tstring& assign(const char* str, size_t len) {
str_.assign(str, len);
@ -161,6 +165,24 @@ class tstring {
return *this;
}
tstring& append(const tstring& str) {
str_.append(str);
return *this;
}
tstring& append(const char* str, size_t len) {
str_.append(str, len);
return *this;
}
tstring& append(const char* str) {
str_.append(str);
return *this;
}
friend const tstring operator+(const tstring& a, const tstring& b);
friend bool operator==(const char* a, const tstring& b);
friend bool operator==(const std::string& a, const tstring& b);