Use tensorflow::StringPiece in literal_util.

Use template for RepeatedField assignment.

PiperOrigin-RevId: 157765477
This commit is contained in:
A. Unique TensorFlower 2017-06-01 14:44:53 -07:00 committed by TensorFlower Gardener
parent 7866fa01b7
commit d6fe47af57
2 changed files with 27 additions and 22 deletions
tensorflow/compiler/xla

View File

@ -308,7 +308,7 @@ Status Literal::Copy(const Literal& src_literal,
auto literal = MakeUnique<Literal>();
*literal->mutable_shape() =
ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())});
literal->set_u8s(value.ToString());
literal->set_u8s(tensorflow::StringPiece(value.ToString()));
return literal;
}
@ -1130,10 +1130,16 @@ void Literal::Resize<half>(int64 num_elements, half value) {
}
}
template <typename NativeT>
static void CopyToRepeatedField(proto2::RepeatedField<NativeT>* dest,
template <typename RepeatedFieldT, typename NativeT>
static void CopyToRepeatedField(RepeatedFieldT* dest,
const std::vector<NativeT>& src) {
*dest = proto2::RepeatedField<NativeT>(src.begin(), src.end());
*dest = RepeatedFieldT(src.begin(), src.end());
}
template <typename RepeatedFieldT>
static void CopyToRepeatedBoolField(RepeatedFieldT* dest,
const BoolVector& src) {
*dest = RepeatedFieldT(src.begin(), src.end());
}
LiteralProto Literal::ToProto() const {
@ -1143,24 +1149,23 @@ LiteralProto Literal::ToProto() const {
switch (shape().element_type()) {
case PRED:
if (preds().begin()) {
*proto.mutable_preds() =
proto2::RepeatedField<bool>(preds().begin(), preds().end());
CopyToRepeatedBoolField(proto.mutable_preds(), preds());
}
break;
case U8:
*proto.mutable_u8s() = u8s_string();
break;
case S32:
CopyToRepeatedField<int32>(proto.mutable_s32s(), s32s());
CopyToRepeatedField(proto.mutable_s32s(), s32s());
break;
case S64:
CopyToRepeatedField<int64>(proto.mutable_s64s(), s64s());
CopyToRepeatedField(proto.mutable_s64s(), s64s());
break;
case U32:
CopyToRepeatedField<uint32>(proto.mutable_u32s(), u32s());
CopyToRepeatedField(proto.mutable_u32s(), u32s());
break;
case U64:
CopyToRepeatedField<uint64>(proto.mutable_u64s(), u64s());
CopyToRepeatedField(proto.mutable_u64s(), u64s());
break;
case F16:
*proto.mutable_f16s() =
@ -1168,10 +1173,10 @@ LiteralProto Literal::ToProto() const {
f16s_.size() / sizeof(half));
break;
case F32:
CopyToRepeatedField<float>(proto.mutable_f32s(), f32s());
CopyToRepeatedField(proto.mutable_f32s(), f32s());
break;
case F64:
CopyToRepeatedField<double>(proto.mutable_f64s(), f64s());
CopyToRepeatedField(proto.mutable_f64s(), f64s());
break;
case TUPLE:
for (const auto& tuple : tuple_literals()) {
@ -1185,9 +1190,9 @@ LiteralProto Literal::ToProto() const {
return proto;
}
template <typename NativeT>
template <typename RepeatedFieldT, typename NativeT>
static void CopyFromRepeatedField(std::vector<NativeT>* dest,
const proto2::RepeatedField<NativeT>& src) {
const RepeatedFieldT& src) {
*dest = std::vector<NativeT>(src.begin(), src.end());
}
@ -1206,16 +1211,16 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
set_u8s(literal_proto.u8s());
break;
case S32:
CopyFromRepeatedField<int32>(mutable_s32s(), literal_proto.s32s());
CopyFromRepeatedField(mutable_s32s(), literal_proto.s32s());
break;
case S64:
CopyFromRepeatedField<int64>(mutable_s64s(), literal_proto.s64s());
CopyFromRepeatedField(mutable_s64s(), literal_proto.s64s());
break;
case U32:
CopyFromRepeatedField<uint32>(mutable_u32s(), literal_proto.u32s());
CopyFromRepeatedField(mutable_u32s(), literal_proto.u32s());
break;
case U64:
CopyFromRepeatedField<uint64>(mutable_u64s(), literal_proto.u64s());
CopyFromRepeatedField(mutable_u64s(), literal_proto.u64s());
break;
case F16: {
const string& s(literal_proto.f16s());
@ -1225,10 +1230,10 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
break;
}
case F32:
CopyFromRepeatedField<float>(mutable_f32s(), literal_proto.f32s());
CopyFromRepeatedField(mutable_f32s(), literal_proto.f32s());
break;
case F64:
CopyFromRepeatedField<double>(mutable_f64s(), literal_proto.f64s());
CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s());
break;
case TUPLE:
for (const auto& proto : literal_proto.tuple_literals()) {

View File

@ -228,13 +228,13 @@ class Literal {
int u8s_size() const { return u8s().size(); }
const std::vector<uint8>& u8s() const { return u8s_; }
void set_u8s(const std::vector<uint8>& value) { u8s_ = value; }
void set_u8s(absl::string_view value) {
void set_u8s(tensorflow::StringPiece value) {
u8s_ = std::vector<uint8>(value.size());
u8s_.clear();
append_u8s(value);
}
void append_u8s(absl::string_view value) {
void append_u8s(tensorflow::StringPiece value) {
u8s_.insert(u8s_.end(), value.begin(), value.end());
}