Update core/distributed_runtime to use tstring.
This is a part of a larger migration effort for tensorflow::tstring. See: https://github.com/tensorflow/community/pull/91 PiperOrigin-RevId: 264982634
This commit is contained in:
parent
f09d3e2997
commit
4e5bee560f
@ -34,8 +34,8 @@ namespace internal {
|
||||
class GrpcCall {
|
||||
public:
|
||||
explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
|
||||
const string* request_msg, string* response_msg,
|
||||
int32* status_code, string* status_message)
|
||||
const tstring* request_msg, tstring* response_msg,
|
||||
int32* status_code, tstring* status_message)
|
||||
: container_(container),
|
||||
index_(index),
|
||||
try_rpc_(try_rpc),
|
||||
@ -59,18 +59,18 @@ class GrpcCall {
|
||||
|
||||
CallOptions* call_opts() { return &call_opts_; }
|
||||
int index() { return index_; }
|
||||
const string& request() const { return *request_msg_; }
|
||||
string* response() const { return response_msg_; }
|
||||
const tstring& request() const { return *request_msg_; }
|
||||
tstring* response() const { return response_msg_; }
|
||||
|
||||
private:
|
||||
CallContainer<GrpcCall>* const container_;
|
||||
const int index_;
|
||||
bool try_rpc_;
|
||||
CallOptions call_opts_;
|
||||
const string* request_msg_;
|
||||
string* response_msg_;
|
||||
const tstring* request_msg_;
|
||||
tstring* response_msg_;
|
||||
int* status_code_;
|
||||
string* status_message_;
|
||||
tstring* status_message_;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
@ -168,16 +168,16 @@ void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc,
|
||||
int index, CallContainer<GrpcCall>* container,
|
||||
Tensor* response_t, Tensor* status_code_t,
|
||||
Tensor* status_message_t) {
|
||||
auto request = request_t.flat<string>();
|
||||
auto get_request_ptr = [&request](int64 ix) -> const string* {
|
||||
auto request = request_t.flat<tstring>();
|
||||
auto get_request_ptr = [&request](int64 ix) -> const tstring* {
|
||||
return (request.size() > 1) ? &(request(ix)) : &(request(0));
|
||||
};
|
||||
auto response = response_t->flat<string>();
|
||||
auto response = response_t->flat<tstring>();
|
||||
int32* status_code_ptr = nullptr;
|
||||
string* status_message_ptr = nullptr;
|
||||
tstring* status_message_ptr = nullptr;
|
||||
if (try_rpc) {
|
||||
status_code_ptr = status_code_t->flat<int32>().data();
|
||||
status_message_ptr = status_message_t->flat<string>().data();
|
||||
status_message_ptr = status_message_t->flat<tstring>().data();
|
||||
}
|
||||
container->RegisterCall(container, index, try_rpc, get_request_ptr(index),
|
||||
&response(index),
|
||||
@ -200,13 +200,13 @@ void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
|
||||
return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
|
||||
: singleton_stub;
|
||||
};
|
||||
auto get_method_ptr = [&method](int64 ix) -> const string* {
|
||||
auto get_method_ptr = [&method](int64 ix) -> const tstring* {
|
||||
return (method.size() > 1) ? &(method(ix)) : &(method(0));
|
||||
};
|
||||
|
||||
int index = call->index();
|
||||
// This object will delete itself when done.
|
||||
new RPCState<string>(
|
||||
new RPCState<tstring>(
|
||||
get_stub(index), &completion_queue_, *get_method_ptr(index),
|
||||
call->request(), call->response(),
|
||||
/*done=*/[call](const Status& s) { call->Done(s); }, call->call_opts(),
|
||||
|
@ -65,11 +65,11 @@ class GrpcTensorCodingTest : public ::testing::Test {
|
||||
}
|
||||
}
|
||||
void DoTestForStrings(DataType dt) {
|
||||
gtl::InlinedVector<string, 4> v;
|
||||
gtl::InlinedVector<tstring, 4> v;
|
||||
for (int elems = 0; elems <= 10000; elems++) {
|
||||
if (elems < 100 || (elems % 1000 == 0)) {
|
||||
Tensor a(dt, TensorShape({1, static_cast<int64>(v.size())}));
|
||||
test::FillValues<string>(&a, v);
|
||||
test::FillValues<tstring>(&a, v);
|
||||
Validate(a, (elems == 0));
|
||||
}
|
||||
v.push_back(strings::StrCat("This is string ", elems));
|
||||
|
@ -100,7 +100,7 @@ bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, TensorResponse* dst) {
|
||||
return s.ok();
|
||||
}
|
||||
|
||||
// GrpcMaybeParseProto into a string simply copies bytes into the string.
|
||||
// GrpcMaybeParseProto simply copies bytes into the string.
|
||||
bool GrpcMaybeParseProto(grpc::ByteBuffer* src, string* dst) {
|
||||
dst->clear();
|
||||
dst->reserve(src->Length());
|
||||
@ -114,4 +114,20 @@ bool GrpcMaybeParseProto(grpc::ByteBuffer* src, string* dst) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
// GrpcMaybeParseProto simply copies bytes into the tstring.
|
||||
bool GrpcMaybeParseProto(grpc::ByteBuffer* src, tstring* dst) {
|
||||
dst->clear();
|
||||
dst->reserve(src->Length());
|
||||
std::vector<::grpc::Slice> slices;
|
||||
if (!src->Dump(&slices).ok()) {
|
||||
return false;
|
||||
}
|
||||
for (const ::grpc::Slice& s : slices) {
|
||||
dst->append(reinterpret_cast<const char*>(s.begin()), s.size());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
#endif // USE_TSTRING
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -131,6 +131,9 @@ bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, TensorResponse* dst);
|
||||
// Copy grpc buffer src to string *dst.
|
||||
bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, string* dst);
|
||||
|
||||
// Copy grpc buffer src to tstring *dst.
|
||||
bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, tstring* dst);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
|
||||
|
@ -120,12 +120,12 @@ class TensorResponseTest : public ::testing::Test {
|
||||
}
|
||||
}
|
||||
void DoTestForStrings(DataType dt) {
|
||||
gtl::InlinedVector<string, 4> v;
|
||||
gtl::InlinedVector<tstring, 4> v;
|
||||
LOG(ERROR) << "DT: string";
|
||||
for (int elems = 0; elems <= 10000; elems++) {
|
||||
if (elems < 100 || (elems % 1000 == 0)) {
|
||||
Tensor a(dt, TensorShape({1, static_cast<int64>(v.size())}));
|
||||
test::FillValues<string>(&a, v);
|
||||
test::FillValues<tstring>(&a, v);
|
||||
Validate(a, (elems == 0), true);
|
||||
}
|
||||
v.push_back(strings::StrCat("This is string ", elems));
|
||||
|
@ -143,12 +143,16 @@ class tstring {
|
||||
|
||||
char& operator[](size_t i) { return str_[i]; }
|
||||
|
||||
void clear() noexcept { str_.clear(); }
|
||||
|
||||
void resize(size_t new_size) { str_.resize(new_size); }
|
||||
|
||||
void resize_uninitialized(size_t new_size) {
|
||||
ResizeUninitialized<decltype(str_)>::Resize(str_, new_size);
|
||||
}
|
||||
|
||||
void reserve(size_t n) { str_.reserve(n); }
|
||||
|
||||
tstring& assign(const char* str, size_t len) {
|
||||
str_.assign(str, len);
|
||||
|
||||
@ -161,6 +165,24 @@ class tstring {
|
||||
return *this;
|
||||
}
|
||||
|
||||
tstring& append(const tstring& str) {
|
||||
str_.append(str);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
tstring& append(const char* str, size_t len) {
|
||||
str_.append(str, len);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
tstring& append(const char* str) {
|
||||
str_.append(str);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
friend const tstring operator+(const tstring& a, const tstring& b);
|
||||
friend bool operator==(const char* a, const tstring& b);
|
||||
friend bool operator==(const std::string& a, const tstring& b);
|
||||
|
Loading…
Reference in New Issue
Block a user