Use tensorflow::StringPiece in literal_util.
Use template for RepeatedField assignment. PiperOrigin-RevId: 157765477
This commit is contained in:
parent
7866fa01b7
commit
d6fe47af57
tensorflow/compiler/xla
@ -308,7 +308,7 @@ Status Literal::Copy(const Literal& src_literal,
|
||||
auto literal = MakeUnique<Literal>();
|
||||
*literal->mutable_shape() =
|
||||
ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())});
|
||||
literal->set_u8s(value.ToString());
|
||||
literal->set_u8s(tensorflow::StringPiece(value.ToString()));
|
||||
return literal;
|
||||
}
|
||||
|
||||
@ -1130,10 +1130,16 @@ void Literal::Resize<half>(int64 num_elements, half value) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
static void CopyToRepeatedField(proto2::RepeatedField<NativeT>* dest,
|
||||
template <typename RepeatedFieldT, typename NativeT>
|
||||
static void CopyToRepeatedField(RepeatedFieldT* dest,
|
||||
const std::vector<NativeT>& src) {
|
||||
*dest = proto2::RepeatedField<NativeT>(src.begin(), src.end());
|
||||
*dest = RepeatedFieldT(src.begin(), src.end());
|
||||
}
|
||||
|
||||
template <typename RepeatedFieldT>
|
||||
static void CopyToRepeatedBoolField(RepeatedFieldT* dest,
|
||||
const BoolVector& src) {
|
||||
*dest = RepeatedFieldT(src.begin(), src.end());
|
||||
}
|
||||
|
||||
LiteralProto Literal::ToProto() const {
|
||||
@ -1143,24 +1149,23 @@ LiteralProto Literal::ToProto() const {
|
||||
switch (shape().element_type()) {
|
||||
case PRED:
|
||||
if (preds().begin()) {
|
||||
*proto.mutable_preds() =
|
||||
proto2::RepeatedField<bool>(preds().begin(), preds().end());
|
||||
CopyToRepeatedBoolField(proto.mutable_preds(), preds());
|
||||
}
|
||||
break;
|
||||
case U8:
|
||||
*proto.mutable_u8s() = u8s_string();
|
||||
break;
|
||||
case S32:
|
||||
CopyToRepeatedField<int32>(proto.mutable_s32s(), s32s());
|
||||
CopyToRepeatedField(proto.mutable_s32s(), s32s());
|
||||
break;
|
||||
case S64:
|
||||
CopyToRepeatedField<int64>(proto.mutable_s64s(), s64s());
|
||||
CopyToRepeatedField(proto.mutable_s64s(), s64s());
|
||||
break;
|
||||
case U32:
|
||||
CopyToRepeatedField<uint32>(proto.mutable_u32s(), u32s());
|
||||
CopyToRepeatedField(proto.mutable_u32s(), u32s());
|
||||
break;
|
||||
case U64:
|
||||
CopyToRepeatedField<uint64>(proto.mutable_u64s(), u64s());
|
||||
CopyToRepeatedField(proto.mutable_u64s(), u64s());
|
||||
break;
|
||||
case F16:
|
||||
*proto.mutable_f16s() =
|
||||
@ -1168,10 +1173,10 @@ LiteralProto Literal::ToProto() const {
|
||||
f16s_.size() / sizeof(half));
|
||||
break;
|
||||
case F32:
|
||||
CopyToRepeatedField<float>(proto.mutable_f32s(), f32s());
|
||||
CopyToRepeatedField(proto.mutable_f32s(), f32s());
|
||||
break;
|
||||
case F64:
|
||||
CopyToRepeatedField<double>(proto.mutable_f64s(), f64s());
|
||||
CopyToRepeatedField(proto.mutable_f64s(), f64s());
|
||||
break;
|
||||
case TUPLE:
|
||||
for (const auto& tuple : tuple_literals()) {
|
||||
@ -1185,9 +1190,9 @@ LiteralProto Literal::ToProto() const {
|
||||
return proto;
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
template <typename RepeatedFieldT, typename NativeT>
|
||||
static void CopyFromRepeatedField(std::vector<NativeT>* dest,
|
||||
const proto2::RepeatedField<NativeT>& src) {
|
||||
const RepeatedFieldT& src) {
|
||||
*dest = std::vector<NativeT>(src.begin(), src.end());
|
||||
}
|
||||
|
||||
@ -1206,16 +1211,16 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
|
||||
set_u8s(literal_proto.u8s());
|
||||
break;
|
||||
case S32:
|
||||
CopyFromRepeatedField<int32>(mutable_s32s(), literal_proto.s32s());
|
||||
CopyFromRepeatedField(mutable_s32s(), literal_proto.s32s());
|
||||
break;
|
||||
case S64:
|
||||
CopyFromRepeatedField<int64>(mutable_s64s(), literal_proto.s64s());
|
||||
CopyFromRepeatedField(mutable_s64s(), literal_proto.s64s());
|
||||
break;
|
||||
case U32:
|
||||
CopyFromRepeatedField<uint32>(mutable_u32s(), literal_proto.u32s());
|
||||
CopyFromRepeatedField(mutable_u32s(), literal_proto.u32s());
|
||||
break;
|
||||
case U64:
|
||||
CopyFromRepeatedField<uint64>(mutable_u64s(), literal_proto.u64s());
|
||||
CopyFromRepeatedField(mutable_u64s(), literal_proto.u64s());
|
||||
break;
|
||||
case F16: {
|
||||
const string& s(literal_proto.f16s());
|
||||
@ -1225,10 +1230,10 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
|
||||
break;
|
||||
}
|
||||
case F32:
|
||||
CopyFromRepeatedField<float>(mutable_f32s(), literal_proto.f32s());
|
||||
CopyFromRepeatedField(mutable_f32s(), literal_proto.f32s());
|
||||
break;
|
||||
case F64:
|
||||
CopyFromRepeatedField<double>(mutable_f64s(), literal_proto.f64s());
|
||||
CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s());
|
||||
break;
|
||||
case TUPLE:
|
||||
for (const auto& proto : literal_proto.tuple_literals()) {
|
||||
|
@ -228,13 +228,13 @@ class Literal {
|
||||
int u8s_size() const { return u8s().size(); }
|
||||
const std::vector<uint8>& u8s() const { return u8s_; }
|
||||
void set_u8s(const std::vector<uint8>& value) { u8s_ = value; }
|
||||
void set_u8s(absl::string_view value) {
|
||||
void set_u8s(tensorflow::StringPiece value) {
|
||||
u8s_ = std::vector<uint8>(value.size());
|
||||
u8s_.clear();
|
||||
append_u8s(value);
|
||||
}
|
||||
|
||||
void append_u8s(absl::string_view value) {
|
||||
void append_u8s(tensorflow::StringPiece value) {
|
||||
u8s_.insert(u8s_.end(), value.begin(), value.end());
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user