Use tensorflow::StringPiece in literal_util.
Use template for RepeatedField assignment. PiperOrigin-RevId: 157765477
This commit is contained in:
parent
7866fa01b7
commit
d6fe47af57
@ -308,7 +308,7 @@ Status Literal::Copy(const Literal& src_literal,
|
|||||||
auto literal = MakeUnique<Literal>();
|
auto literal = MakeUnique<Literal>();
|
||||||
*literal->mutable_shape() =
|
*literal->mutable_shape() =
|
||||||
ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())});
|
ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())});
|
||||||
literal->set_u8s(value.ToString());
|
literal->set_u8s(tensorflow::StringPiece(value.ToString()));
|
||||||
return literal;
|
return literal;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1130,10 +1130,16 @@ void Literal::Resize<half>(int64 num_elements, half value) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename RepeatedFieldT, typename NativeT>
|
||||||
static void CopyToRepeatedField(proto2::RepeatedField<NativeT>* dest,
|
static void CopyToRepeatedField(RepeatedFieldT* dest,
|
||||||
const std::vector<NativeT>& src) {
|
const std::vector<NativeT>& src) {
|
||||||
*dest = proto2::RepeatedField<NativeT>(src.begin(), src.end());
|
*dest = RepeatedFieldT(src.begin(), src.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename RepeatedFieldT>
|
||||||
|
static void CopyToRepeatedBoolField(RepeatedFieldT* dest,
|
||||||
|
const BoolVector& src) {
|
||||||
|
*dest = RepeatedFieldT(src.begin(), src.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
LiteralProto Literal::ToProto() const {
|
LiteralProto Literal::ToProto() const {
|
||||||
@ -1143,24 +1149,23 @@ LiteralProto Literal::ToProto() const {
|
|||||||
switch (shape().element_type()) {
|
switch (shape().element_type()) {
|
||||||
case PRED:
|
case PRED:
|
||||||
if (preds().begin()) {
|
if (preds().begin()) {
|
||||||
*proto.mutable_preds() =
|
CopyToRepeatedBoolField(proto.mutable_preds(), preds());
|
||||||
proto2::RepeatedField<bool>(preds().begin(), preds().end());
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case U8:
|
case U8:
|
||||||
*proto.mutable_u8s() = u8s_string();
|
*proto.mutable_u8s() = u8s_string();
|
||||||
break;
|
break;
|
||||||
case S32:
|
case S32:
|
||||||
CopyToRepeatedField<int32>(proto.mutable_s32s(), s32s());
|
CopyToRepeatedField(proto.mutable_s32s(), s32s());
|
||||||
break;
|
break;
|
||||||
case S64:
|
case S64:
|
||||||
CopyToRepeatedField<int64>(proto.mutable_s64s(), s64s());
|
CopyToRepeatedField(proto.mutable_s64s(), s64s());
|
||||||
break;
|
break;
|
||||||
case U32:
|
case U32:
|
||||||
CopyToRepeatedField<uint32>(proto.mutable_u32s(), u32s());
|
CopyToRepeatedField(proto.mutable_u32s(), u32s());
|
||||||
break;
|
break;
|
||||||
case U64:
|
case U64:
|
||||||
CopyToRepeatedField<uint64>(proto.mutable_u64s(), u64s());
|
CopyToRepeatedField(proto.mutable_u64s(), u64s());
|
||||||
break;
|
break;
|
||||||
case F16:
|
case F16:
|
||||||
*proto.mutable_f16s() =
|
*proto.mutable_f16s() =
|
||||||
@ -1168,10 +1173,10 @@ LiteralProto Literal::ToProto() const {
|
|||||||
f16s_.size() / sizeof(half));
|
f16s_.size() / sizeof(half));
|
||||||
break;
|
break;
|
||||||
case F32:
|
case F32:
|
||||||
CopyToRepeatedField<float>(proto.mutable_f32s(), f32s());
|
CopyToRepeatedField(proto.mutable_f32s(), f32s());
|
||||||
break;
|
break;
|
||||||
case F64:
|
case F64:
|
||||||
CopyToRepeatedField<double>(proto.mutable_f64s(), f64s());
|
CopyToRepeatedField(proto.mutable_f64s(), f64s());
|
||||||
break;
|
break;
|
||||||
case TUPLE:
|
case TUPLE:
|
||||||
for (const auto& tuple : tuple_literals()) {
|
for (const auto& tuple : tuple_literals()) {
|
||||||
@ -1185,9 +1190,9 @@ LiteralProto Literal::ToProto() const {
|
|||||||
return proto;
|
return proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename RepeatedFieldT, typename NativeT>
|
||||||
static void CopyFromRepeatedField(std::vector<NativeT>* dest,
|
static void CopyFromRepeatedField(std::vector<NativeT>* dest,
|
||||||
const proto2::RepeatedField<NativeT>& src) {
|
const RepeatedFieldT& src) {
|
||||||
*dest = std::vector<NativeT>(src.begin(), src.end());
|
*dest = std::vector<NativeT>(src.begin(), src.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1206,16 +1211,16 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
|
|||||||
set_u8s(literal_proto.u8s());
|
set_u8s(literal_proto.u8s());
|
||||||
break;
|
break;
|
||||||
case S32:
|
case S32:
|
||||||
CopyFromRepeatedField<int32>(mutable_s32s(), literal_proto.s32s());
|
CopyFromRepeatedField(mutable_s32s(), literal_proto.s32s());
|
||||||
break;
|
break;
|
||||||
case S64:
|
case S64:
|
||||||
CopyFromRepeatedField<int64>(mutable_s64s(), literal_proto.s64s());
|
CopyFromRepeatedField(mutable_s64s(), literal_proto.s64s());
|
||||||
break;
|
break;
|
||||||
case U32:
|
case U32:
|
||||||
CopyFromRepeatedField<uint32>(mutable_u32s(), literal_proto.u32s());
|
CopyFromRepeatedField(mutable_u32s(), literal_proto.u32s());
|
||||||
break;
|
break;
|
||||||
case U64:
|
case U64:
|
||||||
CopyFromRepeatedField<uint64>(mutable_u64s(), literal_proto.u64s());
|
CopyFromRepeatedField(mutable_u64s(), literal_proto.u64s());
|
||||||
break;
|
break;
|
||||||
case F16: {
|
case F16: {
|
||||||
const string& s(literal_proto.f16s());
|
const string& s(literal_proto.f16s());
|
||||||
@ -1225,10 +1230,10 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case F32:
|
case F32:
|
||||||
CopyFromRepeatedField<float>(mutable_f32s(), literal_proto.f32s());
|
CopyFromRepeatedField(mutable_f32s(), literal_proto.f32s());
|
||||||
break;
|
break;
|
||||||
case F64:
|
case F64:
|
||||||
CopyFromRepeatedField<double>(mutable_f64s(), literal_proto.f64s());
|
CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s());
|
||||||
break;
|
break;
|
||||||
case TUPLE:
|
case TUPLE:
|
||||||
for (const auto& proto : literal_proto.tuple_literals()) {
|
for (const auto& proto : literal_proto.tuple_literals()) {
|
||||||
|
@ -228,13 +228,13 @@ class Literal {
|
|||||||
int u8s_size() const { return u8s().size(); }
|
int u8s_size() const { return u8s().size(); }
|
||||||
const std::vector<uint8>& u8s() const { return u8s_; }
|
const std::vector<uint8>& u8s() const { return u8s_; }
|
||||||
void set_u8s(const std::vector<uint8>& value) { u8s_ = value; }
|
void set_u8s(const std::vector<uint8>& value) { u8s_ = value; }
|
||||||
void set_u8s(absl::string_view value) {
|
void set_u8s(tensorflow::StringPiece value) {
|
||||||
u8s_ = std::vector<uint8>(value.size());
|
u8s_ = std::vector<uint8>(value.size());
|
||||||
u8s_.clear();
|
u8s_.clear();
|
||||||
append_u8s(value);
|
append_u8s(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void append_u8s(absl::string_view value) {
|
void append_u8s(tensorflow::StringPiece value) {
|
||||||
u8s_.insert(u8s_.end(), value.begin(), value.end());
|
u8s_.insert(u8s_.end(), value.begin(), value.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user