Use tensorflow::StringPiece in literal_util.

Use template for RepeatedField assignment.

PiperOrigin-RevId: 157765477
This commit is contained in:
A. Unique TensorFlower 2017-06-01 14:44:53 -07:00 committed by TensorFlower Gardener
parent 7866fa01b7
commit d6fe47af57
2 changed files with 27 additions and 22 deletions

View File

@ -308,7 +308,7 @@ Status Literal::Copy(const Literal& src_literal,
auto literal = MakeUnique<Literal>(); auto literal = MakeUnique<Literal>();
*literal->mutable_shape() = *literal->mutable_shape() =
ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}); ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())});
literal->set_u8s(value.ToString()); literal->set_u8s(tensorflow::StringPiece(value.ToString()));
return literal; return literal;
} }
@ -1130,10 +1130,16 @@ void Literal::Resize<half>(int64 num_elements, half value) {
} }
} }
template <typename NativeT> template <typename RepeatedFieldT, typename NativeT>
static void CopyToRepeatedField(proto2::RepeatedField<NativeT>* dest, static void CopyToRepeatedField(RepeatedFieldT* dest,
const std::vector<NativeT>& src) { const std::vector<NativeT>& src) {
*dest = proto2::RepeatedField<NativeT>(src.begin(), src.end()); *dest = RepeatedFieldT(src.begin(), src.end());
}
template <typename RepeatedFieldT>
static void CopyToRepeatedBoolField(RepeatedFieldT* dest,
const BoolVector& src) {
*dest = RepeatedFieldT(src.begin(), src.end());
} }
LiteralProto Literal::ToProto() const { LiteralProto Literal::ToProto() const {
@ -1143,24 +1149,23 @@ LiteralProto Literal::ToProto() const {
switch (shape().element_type()) { switch (shape().element_type()) {
case PRED: case PRED:
if (preds().begin()) { if (preds().begin()) {
*proto.mutable_preds() = CopyToRepeatedBoolField(proto.mutable_preds(), preds());
proto2::RepeatedField<bool>(preds().begin(), preds().end());
} }
break; break;
case U8: case U8:
*proto.mutable_u8s() = u8s_string(); *proto.mutable_u8s() = u8s_string();
break; break;
case S32: case S32:
CopyToRepeatedField<int32>(proto.mutable_s32s(), s32s()); CopyToRepeatedField(proto.mutable_s32s(), s32s());
break; break;
case S64: case S64:
CopyToRepeatedField<int64>(proto.mutable_s64s(), s64s()); CopyToRepeatedField(proto.mutable_s64s(), s64s());
break; break;
case U32: case U32:
CopyToRepeatedField<uint32>(proto.mutable_u32s(), u32s()); CopyToRepeatedField(proto.mutable_u32s(), u32s());
break; break;
case U64: case U64:
CopyToRepeatedField<uint64>(proto.mutable_u64s(), u64s()); CopyToRepeatedField(proto.mutable_u64s(), u64s());
break; break;
case F16: case F16:
*proto.mutable_f16s() = *proto.mutable_f16s() =
@ -1168,10 +1173,10 @@ LiteralProto Literal::ToProto() const {
f16s_.size() / sizeof(half)); f16s_.size() / sizeof(half));
break; break;
case F32: case F32:
CopyToRepeatedField<float>(proto.mutable_f32s(), f32s()); CopyToRepeatedField(proto.mutable_f32s(), f32s());
break; break;
case F64: case F64:
CopyToRepeatedField<double>(proto.mutable_f64s(), f64s()); CopyToRepeatedField(proto.mutable_f64s(), f64s());
break; break;
case TUPLE: case TUPLE:
for (const auto& tuple : tuple_literals()) { for (const auto& tuple : tuple_literals()) {
@ -1185,9 +1190,9 @@ LiteralProto Literal::ToProto() const {
return proto; return proto;
} }
template <typename NativeT> template <typename RepeatedFieldT, typename NativeT>
static void CopyFromRepeatedField(std::vector<NativeT>* dest, static void CopyFromRepeatedField(std::vector<NativeT>* dest,
const proto2::RepeatedField<NativeT>& src) { const RepeatedFieldT& src) {
*dest = std::vector<NativeT>(src.begin(), src.end()); *dest = std::vector<NativeT>(src.begin(), src.end());
} }
@ -1206,16 +1211,16 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
set_u8s(literal_proto.u8s()); set_u8s(literal_proto.u8s());
break; break;
case S32: case S32:
CopyFromRepeatedField<int32>(mutable_s32s(), literal_proto.s32s()); CopyFromRepeatedField(mutable_s32s(), literal_proto.s32s());
break; break;
case S64: case S64:
CopyFromRepeatedField<int64>(mutable_s64s(), literal_proto.s64s()); CopyFromRepeatedField(mutable_s64s(), literal_proto.s64s());
break; break;
case U32: case U32:
CopyFromRepeatedField<uint32>(mutable_u32s(), literal_proto.u32s()); CopyFromRepeatedField(mutable_u32s(), literal_proto.u32s());
break; break;
case U64: case U64:
CopyFromRepeatedField<uint64>(mutable_u64s(), literal_proto.u64s()); CopyFromRepeatedField(mutable_u64s(), literal_proto.u64s());
break; break;
case F16: { case F16: {
const string& s(literal_proto.f16s()); const string& s(literal_proto.f16s());
@ -1225,10 +1230,10 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
break; break;
} }
case F32: case F32:
CopyFromRepeatedField<float>(mutable_f32s(), literal_proto.f32s()); CopyFromRepeatedField(mutable_f32s(), literal_proto.f32s());
break; break;
case F64: case F64:
CopyFromRepeatedField<double>(mutable_f64s(), literal_proto.f64s()); CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s());
break; break;
case TUPLE: case TUPLE:
for (const auto& proto : literal_proto.tuple_literals()) { for (const auto& proto : literal_proto.tuple_literals()) {

View File

@ -228,13 +228,13 @@ class Literal {
int u8s_size() const { return u8s().size(); } int u8s_size() const { return u8s().size(); }
const std::vector<uint8>& u8s() const { return u8s_; } const std::vector<uint8>& u8s() const { return u8s_; }
void set_u8s(const std::vector<uint8>& value) { u8s_ = value; } void set_u8s(const std::vector<uint8>& value) { u8s_ = value; }
void set_u8s(absl::string_view value) { void set_u8s(tensorflow::StringPiece value) {
u8s_ = std::vector<uint8>(value.size()); u8s_ = std::vector<uint8>(value.size());
u8s_.clear(); u8s_.clear();
append_u8s(value); append_u8s(value);
} }
void append_u8s(absl::string_view value) { void append_u8s(tensorflow::StringPiece value) {
u8s_.insert(u8s_.end(), value.begin(), value.end()); u8s_.insert(u8s_.end(), value.begin(), value.end());
} }