Updated util/ to use tstring.

This is a part of a larger migration effort for tensorflow::tstring.
See: https://github.com/tensorflow/community/pull/91
PiperOrigin-RevId: 263574807
This commit is contained in:
Dero Gharibian 2019-08-15 09:04:45 -07:00 committed by TensorFlower Gardener
parent 0e271c3e39
commit 7adc342449
16 changed files with 159 additions and 112 deletions

View File

@ -50,8 +50,8 @@ Status HandleElementToSlice(T* src, T* dest, int64 num_values,
} }
template <> template <>
Status HandleElementToSlice<string>(string* src, string* dest, int64 num_values, Status HandleElementToSlice<tstring>(tstring* src, tstring* dest,
bool can_move) { int64 num_values, bool can_move) {
if (can_move) { if (can_move) {
for (int64 i = 0; i < num_values; ++i) { for (int64 i = 0; i < num_values; ++i) {
*dest++ = std::move(*src++); *dest++ = std::move(*src++);

View File

@ -465,7 +465,7 @@ enum class Type { Sparse, Dense };
struct SparseBuffer { struct SparseBuffer {
// Features are in one of the 3 vectors below depending on config's dtype. // Features are in one of the 3 vectors below depending on config's dtype.
// Other 2 vectors remain empty. // Other 2 vectors remain empty.
SmallVector<string> bytes_list; SmallVector<tstring> bytes_list;
SmallVector<float> float_list; SmallVector<float> float_list;
SmallVector<int64> int64_list; SmallVector<int64> int64_list;
@ -666,8 +666,8 @@ Status FastParseSerializedExample(
break; break;
} }
case DT_STRING: { case DT_STRING: {
auto out_p = out.flat<string>().data() + offset; auto out_p = out.flat<tstring>().data() + offset;
LimitedArraySlice<string> slice(out_p, num_elements); LimitedArraySlice<tstring> slice(out_p, num_elements);
if (!feature.ParseBytesList(&slice)) return parse_error(); if (!feature.ParseBytesList(&slice)) return parse_error();
if (slice.EndDistance() != 0) { if (slice.EndDistance() != 0) {
return shape_error(num_elements - slice.EndDistance(), "bytes"); return shape_error(num_elements - slice.EndDistance(), "bytes");
@ -907,7 +907,7 @@ const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
return buffer.float_list; return buffer.float_list;
} }
template <> template <>
const SmallVector<string>& GetListFromBuffer<string>( const SmallVector<tstring>& GetListFromBuffer<tstring>(
const SparseBuffer& buffer) { const SparseBuffer& buffer) {
return buffer.bytes_list; return buffer.bytes_list;
} }
@ -917,7 +917,7 @@ void CopyOrMoveBlock(const T* b, const T* e, T* t) {
std::copy(b, e, t); std::copy(b, e, t);
} }
template <> template <>
void CopyOrMoveBlock(const string* b, const string* e, string* t) { void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t) {
std::move(b, e, t); std::move(b, e, t);
} }
@ -1002,8 +1002,8 @@ class TensorVector {
} // namespace } // namespace
Status FastParseExample(const Config& config, Status FastParseExample(const Config& config,
gtl::ArraySlice<string> serialized, gtl::ArraySlice<tstring> serialized,
gtl::ArraySlice<string> example_names, gtl::ArraySlice<tstring> example_names,
thread::ThreadPool* thread_pool, Result* result) { thread::ThreadPool* thread_pool, Result* result) {
DCHECK(result != nullptr); DCHECK(result != nullptr);
// Check config so we can safely CHECK(false) in switches on config.*.dtype // Check config so we can safely CHECK(false) in switches on config.*.dtype
@ -1253,8 +1253,8 @@ Status FastParseExample(const Config& config,
break; break;
} }
case DT_STRING: { case DT_STRING: {
FillAndCopyVarLen<string>(d, num_elements, num_elements_per_minibatch, FillAndCopyVarLen<tstring>(d, num_elements, num_elements_per_minibatch,
config, varlen_dense_buffers, &values); config, varlen_dense_buffers, &values);
break; break;
} }
default: default:
@ -1440,8 +1440,8 @@ Status FastParseSingleExample(const Config& config,
break; break;
} }
case DT_STRING: { case DT_STRING: {
auto out_p = out->flat<string>().data(); auto out_p = out->flat<tstring>().data();
LimitedArraySlice<string> slice(out_p, num_elements); LimitedArraySlice<tstring> slice(out_p, num_elements);
if (!feature.ParseBytesList(&slice)) return parse_error(); if (!feature.ParseBytesList(&slice)) return parse_error();
if (slice.EndDistance() != 0) { if (slice.EndDistance() != 0) {
return parse_error(); return parse_error();
@ -1453,7 +1453,7 @@ Status FastParseSingleExample(const Config& config,
} }
} else { // if variable length } else { // if variable length
SmallVector<string> bytes_list; SmallVector<tstring> bytes_list;
TensorVector<float> float_list; TensorVector<float> float_list;
SmallVector<int64> int64_list; SmallVector<int64> int64_list;
@ -1627,7 +1627,7 @@ Status FastParseSingleExample(const Config& config,
// Return the number of bytes elements parsed, or -1 on error. If out is null, // Return the number of bytes elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying. // this method simply counts the number of elements without any copying.
inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream, inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
string* out) { tstring* out) {
int num_elements = 0; int num_elements = 0;
uint32 length; uint32 length;
if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) { if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
@ -1638,12 +1638,23 @@ inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
while (!stream->ExpectAtEnd()) { while (!stream->ExpectAtEnd()) {
uint32 bytes_length; uint32 bytes_length;
if (!stream->ExpectTag(kDelimitedTag(1)) || if (!stream->ExpectTag(kDelimitedTag(1)) ||
!stream->ReadVarint32(&bytes_length) || !stream->ReadVarint32(&bytes_length)) {
(out != nullptr && !stream->ReadString(out++, bytes_length))) {
return -1; return -1;
} }
if (out == nullptr) { if (out == nullptr) {
stream->Skip(bytes_length); stream->Skip(bytes_length);
} else {
#ifdef USE_TSTRING
out->resize_uninitialized(bytes_length);
if (!stream->ReadRaw(out->data(), bytes_length)) {
return -1;
}
#else // USE_TSTRING
if (!stream->ReadString(out, bytes_length)) {
return -1;
}
#endif // USE_TSTRING
out++;
} }
num_elements++; num_elements++;
} }
@ -1809,7 +1820,7 @@ inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
Status FastParseSequenceExample( Status FastParseSequenceExample(
const FastParseExampleConfig& context_config, const FastParseExampleConfig& context_config,
const FastParseExampleConfig& feature_list_config, const FastParseExampleConfig& feature_list_config,
gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names, gtl::ArraySlice<tstring> serialized, gtl::ArraySlice<tstring> example_names,
thread::ThreadPool* thread_pool, Result* context_result, thread::ThreadPool* thread_pool, Result* context_result,
Result* feature_list_result, std::vector<Tensor>* dense_feature_lengths) { Result* feature_list_result, std::vector<Tensor>* dense_feature_lengths) {
int num_examples = serialized.size(); int num_examples = serialized.size();
@ -1878,10 +1889,10 @@ Status FastParseSequenceExample(
all_context_features(num_examples); all_context_features(num_examples);
std::vector<absl::flat_hash_map<StringPiece, StringPiece>> std::vector<absl::flat_hash_map<StringPiece, StringPiece>>
all_sequence_features(num_examples); all_sequence_features(num_examples);
const string kUnknown = "<unknown>"; const tstring kUnknown = "<unknown>";
for (int d = 0; d < num_examples; d++) { for (int d = 0; d < num_examples; d++) {
const string& example = serialized[d]; const tstring& example = serialized[d];
const string& example_name = const tstring& example_name =
example_names.empty() ? kUnknown : example_names[d]; example_names.empty() ? kUnknown : example_names[d];
auto* context_features = &all_context_features[d]; auto* context_features = &all_context_features[d];
auto* sequence_features = &all_sequence_features[d]; auto* sequence_features = &all_sequence_features[d];
@ -2074,7 +2085,7 @@ Status FastParseSequenceExample(
// TODO(sundberg): Refactor to reduce code duplication, and add bounds // TODO(sundberg): Refactor to reduce code duplication, and add bounds
// checking for the outputs. // checking for the outputs.
string* out_bytes = nullptr; tstring* out_bytes = nullptr;
float* out_float = nullptr; float* out_float = nullptr;
int64* out_int64 = nullptr; int64* out_int64 = nullptr;
switch (dtype) { switch (dtype) {
@ -2097,7 +2108,7 @@ Status FastParseSequenceExample(
for (int e = 0; e < num_examples; e++) { for (int e = 0; e < num_examples; e++) {
size_t num_elements = 0; size_t num_elements = 0;
const auto feature_iter = all_context_features[e].find(c.feature_name); const auto feature_iter = all_context_features[e].find(c.feature_name);
const string& example_name = const tstring& example_name =
example_names.empty() ? kUnknown : example_names[e]; example_names.empty() ? kUnknown : example_names[e];
if (feature_iter == all_context_features[e].end()) { if (feature_iter == all_context_features[e].end()) {
// Copy the default value, if present. If not, return an error. // Copy the default value, if present. If not, return an error.
@ -2107,7 +2118,7 @@ Status FastParseSequenceExample(
" (data type: ", DataTypeString(c.dtype), ")", " (data type: ", DataTypeString(c.dtype), ")",
" is required but could not be found."); " is required but could not be found.");
} }
const string* in_bytes = nullptr; const tstring* in_bytes = nullptr;
const float* in_float = nullptr; const float* in_float = nullptr;
const int64* in_int64 = nullptr; const int64* in_int64 = nullptr;
size_t num = 0; size_t num = 0;
@ -2185,7 +2196,7 @@ Status FastParseSequenceExample(
Tensor(allocator, DT_INT64, TensorShape({2})); Tensor(allocator, DT_INT64, TensorShape({2}));
// TODO(sundberg): Refactor to reduce code duplication, and add bounds // TODO(sundberg): Refactor to reduce code duplication, and add bounds
// checking for the outputs. // checking for the outputs.
string* out_bytes = nullptr; tstring* out_bytes = nullptr;
float* out_float = nullptr; float* out_float = nullptr;
int64* out_int64 = nullptr; int64* out_int64 = nullptr;
switch (dtype) { switch (dtype) {
@ -2211,7 +2222,7 @@ Status FastParseSequenceExample(
size_t max_num_cols = 0; size_t max_num_cols = 0;
for (int e = 0; e < num_examples; e++) { for (int e = 0; e < num_examples; e++) {
const auto& feature = all_context_features[e][c.feature_name]; const auto& feature = all_context_features[e][c.feature_name];
const string& example_name = const tstring& example_name =
example_names.empty() ? kUnknown : example_names[e]; example_names.empty() ? kUnknown : example_names[e];
if (!feature.empty()) { if (!feature.empty()) {
protobuf::io::CodedInputStream stream( protobuf::io::CodedInputStream stream(
@ -2276,7 +2287,7 @@ Status FastParseSequenceExample(
Tensor(allocator, DT_INT64, dense_length_shape); Tensor(allocator, DT_INT64, dense_length_shape);
int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data(); int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data();
string* out_bytes = nullptr; tstring* out_bytes = nullptr;
float* out_float = nullptr; float* out_float = nullptr;
int64* out_int64 = nullptr; int64* out_int64 = nullptr;
switch (dtype) { switch (dtype) {
@ -2299,7 +2310,7 @@ Status FastParseSequenceExample(
for (int e = 0; e < num_examples; e++) { for (int e = 0; e < num_examples; e++) {
size_t num_elements = 0, num_rows = 0; size_t num_elements = 0, num_rows = 0;
const auto feature_iter = all_sequence_features[e].find(c.feature_name); const auto feature_iter = all_sequence_features[e].find(c.feature_name);
const string& example_name = const tstring& example_name =
example_names.empty() ? kUnknown : example_names[e]; example_names.empty() ? kUnknown : example_names[e];
if (feature_iter == all_sequence_features[e].end()) { if (feature_iter == all_sequence_features[e].end()) {
// Return an error if this feature was not allowed to be missing. // Return an error if this feature was not allowed to be missing.
@ -2387,7 +2398,7 @@ Status FastParseSequenceExample(
feature_list_result->sparse_shapes[t] = feature_list_result->sparse_shapes[t] =
Tensor(allocator, DT_INT64, TensorShape({3})); Tensor(allocator, DT_INT64, TensorShape({3}));
string* out_bytes = nullptr; tstring* out_bytes = nullptr;
float* out_float = nullptr; float* out_float = nullptr;
int64* out_int64 = nullptr; int64* out_int64 = nullptr;
switch (dtype) { switch (dtype) {
@ -2416,7 +2427,7 @@ Status FastParseSequenceExample(
size_t max_num_cols = 0; size_t max_num_cols = 0;
for (int e = 0; e < num_examples; e++) { for (int e = 0; e < num_examples; e++) {
const auto& feature = all_sequence_features[e][c.feature_name]; const auto& feature = all_sequence_features[e][c.feature_name];
const string& example_name = const tstring& example_name =
example_names.empty() ? kUnknown : example_names[e]; example_names.empty() ? kUnknown : example_names[e];
if (!feature.empty()) { if (!feature.empty()) {
protobuf::io::CodedInputStream stream( protobuf::io::CodedInputStream stream(

View File

@ -99,8 +99,8 @@ struct Result {
// Given example names have to either be empty or the same size as serialized. // Given example names have to either be empty or the same size as serialized.
// example_names are used only for error messages. // example_names are used only for error messages.
Status FastParseExample(const FastParseExampleConfig& config, Status FastParseExample(const FastParseExampleConfig& config,
gtl::ArraySlice<string> serialized, gtl::ArraySlice<tstring> serialized,
gtl::ArraySlice<string> example_names, gtl::ArraySlice<tstring> example_names,
thread::ThreadPool* thread_pool, Result* result); thread::ThreadPool* thread_pool, Result* result);
// TODO(mrry): Move the hash table construction into the config object. // TODO(mrry): Move the hash table construction into the config object.
@ -116,7 +116,7 @@ Status FastParseSingleExample(const FastParseSingleExampleConfig& config,
Status FastParseSequenceExample( Status FastParseSequenceExample(
const example::FastParseExampleConfig& context_config, const example::FastParseExampleConfig& context_config,
const example::FastParseExampleConfig& feature_list_config, const example::FastParseExampleConfig& feature_list_config,
gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names, gtl::ArraySlice<tstring> serialized, gtl::ArraySlice<tstring> example_names,
thread::ThreadPool* thread_pool, example::Result* context_result, thread::ThreadPool* thread_pool, example::Result* context_result,
example::Result* feature_list_result, example::Result* feature_list_result,
std::vector<Tensor>* dense_feature_lengths); std::vector<Tensor>* dense_feature_lengths);

View File

@ -273,7 +273,7 @@ static void AddSparseFeature(const char* feature_name, DataType dtype,
TEST(FastParse, StatsCollection) { TEST(FastParse, StatsCollection) {
const size_t kNumExamples = 13; const size_t kNumExamples = 13;
std::vector<string> serialized(kNumExamples, ExampleWithSomeFeatures()); std::vector<tstring> serialized(kNumExamples, ExampleWithSomeFeatures());
FastParseExampleConfig config_dense; FastParseExampleConfig config_dense;
AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &config_dense); AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &config_dense);
@ -417,8 +417,9 @@ TEST(TestFastParseExample, Empty) {
Result result; Result result;
FastParseExampleConfig config; FastParseExampleConfig config;
config.sparse.push_back({"test", DT_STRING}); config.sparse.push_back({"test", DT_STRING});
Status status = FastParseExample(config, gtl::ArraySlice<string>(), Status status =
gtl::ArraySlice<string>(), nullptr, &result); FastParseExample(config, gtl::ArraySlice<tstring>(),
gtl::ArraySlice<tstring>(), nullptr, &result);
EXPECT_TRUE(status.ok()) << status; EXPECT_TRUE(status.ok()) << status;
} }

View File

@ -101,7 +101,7 @@ Status FeatureDenseCopy(const std::size_t out_index, const string& name,
"Values size: ", "Values size: ",
values.value_size(), " but output shape: ", shape.DebugString()); values.value_size(), " but output shape: ", shape.DebugString());
} }
auto out_p = out->flat<string>().data() + offset; auto out_p = out->flat<tstring>().data() + offset;
std::transform(values.value().data(), std::transform(values.value().data(),
values.value().data() + num_elements, out_p, values.value().data() + num_elements, out_p,
[](const string* s) { return *s; }); [](const string* s) { return *s; });
@ -136,7 +136,7 @@ Tensor FeatureSparseCopy(const std::size_t batch, const string& key,
const BytesList& values = feature.bytes_list(); const BytesList& values = feature.bytes_list();
const int64 num_elements = values.value_size(); const int64 num_elements = values.value_size();
Tensor out(dtype, TensorShape({num_elements})); Tensor out(dtype, TensorShape({num_elements}));
auto out_p = out.flat<string>().data(); auto out_p = out.flat<tstring>().data();
std::transform(values.value().data(), std::transform(values.value().data(),
values.value().data() + num_elements, out_p, values.value().data() + num_elements, out_p,
[](const string* s) { return *s; }); [](const string* s) { return *s; });
@ -175,8 +175,8 @@ int64 CopyIntoSparseTensor(const Tensor& in, const int batch,
break; break;
} }
case DT_STRING: { case DT_STRING: {
std::copy_n(in.flat<string>().data(), num_elements, std::copy_n(in.flat<tstring>().data(), num_elements,
values->flat<string>().data() + offset); values->flat<tstring>().data() + offset);
break; break;
} }
default: default:
@ -203,8 +203,9 @@ void RowDenseCopy(const std::size_t& out_index, const DataType& dtype,
break; break;
} }
case DT_STRING: { case DT_STRING: {
std::copy_n(in.flat<string>().data(), num_elements, // TODO(dero): verify.
out->flat<string>().data() + offset); std::copy_n(in.flat<tstring>().data(), num_elements,
out->flat<tstring>().data() + offset);
break; break;
} }
default: default:

View File

@ -337,10 +337,24 @@ inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) {
// serialized proto. // serialized proto.
// May read all or part of a repeated field. // May read all or part of a repeated field.
inline Status ReadBytes(CodedInputStream* input, int index, void* datap) { inline Status ReadBytes(CodedInputStream* input, int index, void* datap) {
string* data = reinterpret_cast<string*>(datap) + index; tstring* data = reinterpret_cast<tstring*>(datap) + index;
#ifdef USE_TSTRING
uint32 length;
if (!input->ReadVarint32(&length)) {
return errors::DataLoss("Failed reading bytes");
}
data->resize_uninitialized(length);
if (!input->ReadRaw(data->data(), length)) {
return errors::DataLoss("Failed reading bytes");
}
#else // USE_TSTRING
if (!WireFormatLite::ReadBytes(input, data)) { if (!WireFormatLite::ReadBytes(input, data)) {
return errors::DataLoss("Failed reading bytes"); return errors::DataLoss("Failed reading bytes");
} }
#endif // USE_TSTRING
return Status::OK(); return Status::OK();
} }
@ -354,8 +368,19 @@ inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
// TODO(nix): there is a faster way to grab TYPE_GROUP bytes by relying // TODO(nix): there is a faster way to grab TYPE_GROUP bytes by relying
// on input->IsFlat() == true and using input->GetDirectBufferPointer() // on input->IsFlat() == true and using input->GetDirectBufferPointer()
// with input->CurrentPosition(). // with input->CurrentPosition().
string* data = reinterpret_cast<string*>(datap) + index; tstring* data = reinterpret_cast<tstring*>(datap) + index;
#ifdef USE_TSTRING
// TODO(dero): To mitigate the string to tstring copy, we can implement our
// own scanner as described above. We would first need to obtain the length
// in an initial pass and resize/reserve the tstring. But, given that
// TYPE_GROUP is deprecated and currently no tests in
// tensorflow/python/kernel_tests/proto:decode_proto_op_test target a
// TYPE_GROUP tag, we use std::string as a read buffer.
string buf;
StringOutputStream string_stream(&buf);
#else // USE_TSTRING
StringOutputStream string_stream(data); StringOutputStream string_stream(data);
#endif // USE_TSTRING
CodedOutputStream out(&string_stream); CodedOutputStream out(&string_stream);
if (!WireFormatLite::SkipField( if (!WireFormatLite::SkipField(
input, input,
@ -364,6 +389,9 @@ inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
&out)) { &out)) {
return errors::DataLoss("Failed reading group"); return errors::DataLoss("Failed reading group");
} }
#ifdef USE_TSTRING
*data = buf;
#endif // USE_TSTRING
return Status::OK(); return Status::OK();
} }

View File

@ -179,29 +179,29 @@ inline void Fill(const Eigen::half* data, size_t n, TensorProto* t) {
// Custom implementation for string. // Custom implementation for string.
template <> template <>
struct SaveTypeTraits<string> { struct SaveTypeTraits<tstring> {
static constexpr bool supported = true; static constexpr bool supported = true;
typedef const string* SavedType; typedef const string* SavedType;
typedef protobuf::RepeatedPtrField<string> RepeatedField; typedef protobuf::RepeatedPtrField<string> RepeatedField;
}; };
template <> template <>
inline const string* const* TensorProtoData<string>(const TensorProto& t) { inline const string* const* TensorProtoData<tstring>(const TensorProto& t) {
static_assert(SaveTypeTraits<string>::supported, static_assert(SaveTypeTraits<tstring>::supported,
"Specified type string not supported for Restore"); "Specified type tstring not supported for Restore");
return t.string_val().data(); return t.string_val().data();
} }
template <> template <>
inline protobuf::RepeatedPtrField<string>* MutableTensorProtoData<string>( inline protobuf::RepeatedPtrField<string>* MutableTensorProtoData<tstring>(
TensorProto* t) { TensorProto* t) {
static_assert(SaveTypeTraits<string>::supported, static_assert(SaveTypeTraits<tstring>::supported,
"Specified type string not supported for Save"); "Specified type tstring not supported for Save");
return t->mutable_string_val(); return t->mutable_string_val();
} }
template <> template <>
inline void Fill(const string* data, size_t n, TensorProto* t) { inline void Fill(const tstring* data, size_t n, TensorProto* t) {
typename protobuf::RepeatedPtrField<string> copy(data, data + n); typename protobuf::RepeatedPtrField<string> copy(data, data + n);
t->mutable_string_val()->Swap(&copy); t->mutable_string_val()->Swap(&copy);
} }

View File

@ -102,7 +102,7 @@ Example of grouping:
Tensor values(DT_STRING, TensorShape({N})); Tensor values(DT_STRING, TensorShape({N}));
TensorShape shape({dim0,...}); TensorShape shape({dim0,...});
SparseTensor sp(indices, values, shape); SparseTensor sp(indices, values, shape);
sp.Reorder<string>({1, 2, 0, 3, ...}); // Must provide NDIMS dims. sp.Reorder<tstring>({1, 2, 0, 3, ...}); // Must provide NDIMS dims.
// group according to dims 1 and 2 // group according to dims 1 and 2
for (const auto& g : sp.group({1, 2})) { for (const auto& g : sp.group({1, 2})) {
cout << "vals of ix[:, 1,2] for this group: " cout << "vals of ix[:, 1,2] for this group: "
@ -111,7 +111,7 @@ Example of grouping:
cout << "values of group:\n" << g.values(); cout << "values of group:\n" << g.values();
TTypes<int64>::UnalignedMatrix g_ix = g.indices(); TTypes<int64>::UnalignedMatrix g_ix = g.indices();
TTypes<string>::UnalignedVec g_v = g.values(); TTypes<tstring>::UnalignedVec g_v = g.values();
ASSERT(g_ix.dimension(0) == g_v.size()); // number of elements match. ASSERT(g_ix.dimension(0) == g_v.size()); // number of elements match.
} }
@ -133,7 +133,7 @@ Shape checking is performed, as is boundary checking.
Tensor dense(DT_STRING, shape); Tensor dense(DT_STRING, shape);
// initialize other indices to zero. copy. // initialize other indices to zero. copy.
ASSERT(sp.ToDense<string>(&dense, true)); ASSERT(sp.ToDense<tstring>(&dense, true));
Concat Concat
@ -215,7 +215,7 @@ Coding Example:
EXPECT_EQ(conc.Order(), {-1, -1, -1}); EXPECT_EQ(conc.Order(), {-1, -1, -1});
// Reorder st3 so all input tensors have the exact same orders. // Reorder st3 so all input tensors have the exact same orders.
st3.Reorder<string>({1, 0, 2}); st3.Reorder<tstring>({1, 0, 2});
SparseTensor conc2 = SparseTensor::Concat<string>({st1, st2, st3}); SparseTensor conc2 = SparseTensor::Concat<string>({st1, st2, st3});
EXPECT_EQ(conc2.Order(), {1, 0, 2}); EXPECT_EQ(conc2.Order(), {1, 0, 2});
// All indices' orders matched, so output is in order. // All indices' orders matched, so output is in order.

View File

@ -170,7 +170,7 @@ TEST(SparseTensorTest, SparseTensorConstruction) {
int N = 5; int N = 5;
const int NDIM = 3; const int NDIM = 3;
auto ix_c = GetSimpleIndexTensor(N, NDIM); auto ix_c = GetSimpleIndexTensor(N, NDIM);
Eigen::Tensor<string, 1, Eigen::RowMajor> vals_c(N); Eigen::Tensor<tstring, 1, Eigen::RowMajor> vals_c(N);
vals_c(0) = "hi0"; vals_c(0) = "hi0";
vals_c(1) = "hi1"; vals_c(1) = "hi1";
vals_c(2) = "hi2"; vals_c(2) = "hi2";
@ -200,7 +200,7 @@ TEST(SparseTensorTest, SparseTensorConstruction) {
// Regardless of how order is updated; so long as there are no // Regardless of how order is updated; so long as there are no
// duplicates, the resulting indices are valid. // duplicates, the resulting indices are valid.
st.Reorder<string>({2, 0, 1}); st.Reorder<tstring>({2, 0, 1});
TF_EXPECT_OK(st.IndicesValid()); TF_EXPECT_OK(st.IndicesValid());
EXPECT_EQ(vals_t(0), "hi0"); EXPECT_EQ(vals_t(0), "hi0");
EXPECT_EQ(vals_t(1), "hi3"); EXPECT_EQ(vals_t(1), "hi3");
@ -210,7 +210,7 @@ TEST(SparseTensorTest, SparseTensorConstruction) {
ix_t = ix_c; ix_t = ix_c;
vals_t = vals_c; vals_t = vals_c;
st.Reorder<string>({0, 1, 2}); st.Reorder<tstring>({0, 1, 2});
TF_EXPECT_OK(st.IndicesValid()); TF_EXPECT_OK(st.IndicesValid());
EXPECT_EQ(vals_t(0), "hi0"); EXPECT_EQ(vals_t(0), "hi0");
EXPECT_EQ(vals_t(1), "hi4"); EXPECT_EQ(vals_t(1), "hi4");
@ -220,7 +220,7 @@ TEST(SparseTensorTest, SparseTensorConstruction) {
ix_t = ix_c; ix_t = ix_c;
vals_t = vals_c; vals_t = vals_c;
st.Reorder<string>({2, 1, 0}); st.Reorder<tstring>({2, 1, 0});
TF_EXPECT_OK(st.IndicesValid()); TF_EXPECT_OK(st.IndicesValid());
} }
@ -239,7 +239,7 @@ TEST(SparseTensorTest, EmptySparseTensorAllowed) {
EXPECT_EQ(st.order(), order); EXPECT_EQ(st.order(), order);
std::vector<int64> new_order{1, 0, 2}; std::vector<int64> new_order{1, 0, 2};
st.Reorder<string>(new_order); st.Reorder<tstring>(new_order);
TF_EXPECT_OK(st.IndicesValid()); TF_EXPECT_OK(st.IndicesValid());
EXPECT_EQ(st.order(), new_order); EXPECT_EQ(st.order(), new_order);
} }
@ -259,13 +259,13 @@ TEST(SparseTensorTest, SortingWorksCorrectly) {
for (int n = 0; n < 100; ++n) { for (int n = 0; n < 100; ++n) {
ix_t = ix_t.random(Eigen::internal::UniformRandomGenerator<int64>(n + 1)); ix_t = ix_t.random(Eigen::internal::UniformRandomGenerator<int64>(n + 1));
ix_t = ix_t.abs() % 1000; ix_t = ix_t.abs() % 1000;
st.Reorder<string>({0, 1, 2, 3}); st.Reorder<tstring>({0, 1, 2, 3});
TF_EXPECT_OK(st.IndicesValid()); TF_EXPECT_OK(st.IndicesValid());
st.Reorder<string>({3, 2, 1, 0}); st.Reorder<tstring>({3, 2, 1, 0});
TF_EXPECT_OK(st.IndicesValid()); TF_EXPECT_OK(st.IndicesValid());
st.Reorder<string>({1, 0, 2, 3}); st.Reorder<tstring>({1, 0, 2, 3});
TF_EXPECT_OK(st.IndicesValid()); TF_EXPECT_OK(st.IndicesValid());
st.Reorder<string>({3, 0, 2, 1}); st.Reorder<tstring>({3, 0, 2, 1});
TF_EXPECT_OK(st.IndicesValid()); TF_EXPECT_OK(st.IndicesValid());
} }
} }
@ -294,7 +294,7 @@ TEST(SparseTensorTest, ValidateIndicesFindsInvalid) {
SparseTensor st; SparseTensor st;
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
st.Reorder<string>(order); st.Reorder<tstring>(order);
Status st_indices_valid = st.IndicesValid(); Status st_indices_valid = st.IndicesValid();
EXPECT_FALSE(st_indices_valid.ok()); EXPECT_FALSE(st_indices_valid.ok());
EXPECT_EQ("indices[1] = [0,0,0] is repeated", EXPECT_EQ("indices[1] = [0,0,0] is repeated",
@ -302,12 +302,12 @@ TEST(SparseTensorTest, ValidateIndicesFindsInvalid) {
ix_orig(1, 2) = 1; ix_orig(1, 2) = 1;
ix_t = ix_orig; ix_t = ix_orig;
st.Reorder<string>(order); st.Reorder<tstring>(order);
TF_EXPECT_OK(st.IndicesValid()); // second index now (0, 0, 1) TF_EXPECT_OK(st.IndicesValid()); // second index now (0, 0, 1)
ix_orig(0, 2) = 1; ix_orig(0, 2) = 1;
ix_t = ix_orig; ix_t = ix_orig;
st.Reorder<string>(order); st.Reorder<tstring>(order);
st_indices_valid = st.IndicesValid(); st_indices_valid = st.IndicesValid();
EXPECT_FALSE(st_indices_valid.ok()); // first index now (0, 0, 1) EXPECT_FALSE(st_indices_valid.ok()); // first index now (0, 0, 1)
EXPECT_EQ("indices[1] = [0,0,1] is repeated", EXPECT_EQ("indices[1] = [0,0,1] is repeated",
@ -332,12 +332,12 @@ TEST(SparseTensorTest, SparseTensorCheckBoundaries) {
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
EXPECT_FALSE(st.IndicesValid().ok()); EXPECT_FALSE(st.IndicesValid().ok());
st.Reorder<string>(order); st.Reorder<tstring>(order);
TF_EXPECT_OK(st.IndicesValid()); TF_EXPECT_OK(st.IndicesValid());
ix_t(0, 0) = 11; ix_t(0, 0) = 11;
ix.matrix<int64>() = ix_t; ix.matrix<int64>() = ix_t;
st.Reorder<string>(order); st.Reorder<tstring>(order);
Status st_indices_valid = st.IndicesValid(); Status st_indices_valid = st.IndicesValid();
EXPECT_FALSE(st_indices_valid.ok()); EXPECT_FALSE(st_indices_valid.ok());
// Error message references index 4 because of the call to Reorder. // Error message references index 4 because of the call to Reorder.
@ -346,7 +346,7 @@ TEST(SparseTensorTest, SparseTensorCheckBoundaries) {
ix_t(0, 0) = -1; ix_t(0, 0) = -1;
ix.matrix<int64>() = ix_t; ix.matrix<int64>() = ix_t;
st.Reorder<string>(order); st.Reorder<tstring>(order);
st_indices_valid = st.IndicesValid(); st_indices_valid = st.IndicesValid();
EXPECT_FALSE(st_indices_valid.ok()); EXPECT_FALSE(st_indices_valid.ok());
EXPECT_EQ("[-1,0,0] is out of bounds: need 0 <= index < [10,10,10]", EXPECT_EQ("[-1,0,0] is out of bounds: need 0 <= index < [10,10,10]",
@ -354,7 +354,7 @@ TEST(SparseTensorTest, SparseTensorCheckBoundaries) {
ix_t(0, 0) = 0; ix_t(0, 0) = 0;
ix.matrix<int64>() = ix_t; ix.matrix<int64>() = ix_t;
st.Reorder<string>(order); st.Reorder<tstring>(order);
TF_EXPECT_OK(st.IndicesValid()); TF_EXPECT_OK(st.IndicesValid());
} }
@ -382,9 +382,9 @@ TEST(SparseTensorTest, SparseTensorToDenseTensor) {
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
Tensor dense(DT_STRING, TensorShape({4, 4, 5})); Tensor dense(DT_STRING, TensorShape({4, 4, 5}));
st.ToDense<string>(&dense); st.ToDense<tstring>(&dense);
auto dense_t = dense.tensor<string, 3>(); auto dense_t = dense.tensor<tstring, 3>();
Eigen::array<Eigen::DenseIndex, NDIM> ix_n; Eigen::array<Eigen::DenseIndex, NDIM> ix_n;
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d); for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d);
@ -422,9 +422,9 @@ TEST(SparseTensorTest, SparseTensorToLargerDenseTensor) {
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
Tensor dense(DT_STRING, TensorShape({10, 10, 10})); Tensor dense(DT_STRING, TensorShape({10, 10, 10}));
st.ToDense<string>(&dense); st.ToDense<tstring>(&dense);
auto dense_t = dense.tensor<string, 3>(); auto dense_t = dense.tensor<tstring, 3>();
Eigen::array<Eigen::DenseIndex, NDIM> ix_n; Eigen::array<Eigen::DenseIndex, NDIM> ix_n;
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d); for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d);
@ -554,10 +554,10 @@ TEST(SparseTensorTest, Concat) {
SparseTensor st; SparseTensor st;
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
EXPECT_FALSE(st.IndicesValid().ok()); EXPECT_FALSE(st.IndicesValid().ok());
st.Reorder<string>(order); st.Reorder<tstring>(order);
TF_EXPECT_OK(st.IndicesValid()); TF_EXPECT_OK(st.IndicesValid());
SparseTensor concatted = SparseTensor::Concat<string>({st, st, st, st}); SparseTensor concatted = SparseTensor::Concat<tstring>({st, st, st, st});
EXPECT_EQ(concatted.order(), st.order()); EXPECT_EQ(concatted.order(), st.order());
gtl::InlinedVector<int64, 8> expected_shape{40, 10, 10}; gtl::InlinedVector<int64, 8> expected_shape{40, 10, 10};
EXPECT_EQ(concatted.shape(), expected_shape); EXPECT_EQ(concatted.shape(), expected_shape);
@ -585,7 +585,7 @@ TEST(SparseTensorTest, Concat) {
SparseTensor st_ooo; SparseTensor st_ooo;
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, {0, 2, 1}, TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, {0, 2, 1},
&st_ooo)); // non-primary ix OOO &st_ooo)); // non-primary ix OOO
SparseTensor conc_ooo = SparseTensor::Concat<string>({st, st, st, st_ooo}); SparseTensor conc_ooo = SparseTensor::Concat<tstring>({st, st, st, st_ooo});
std::vector<int64> expected_ooo{-1, -1, -1}; std::vector<int64> expected_ooo{-1, -1, -1};
EXPECT_EQ(conc_ooo.order(), expected_ooo); EXPECT_EQ(conc_ooo.order(), expected_ooo);
EXPECT_EQ(conc_ooo.shape(), expected_shape); EXPECT_EQ(conc_ooo.shape(), expected_shape);
@ -782,7 +782,7 @@ static void BM_SparseReorderString(int iters, int N32, int NDIM32) {
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
testing::StartTiming(); testing::StartTiming();
st.Reorder<string>(reorder); st.Reorder<tstring>(reorder);
} }
} }

View File

@ -69,7 +69,7 @@ namespace {
// Checksums the string lengths (as restored uint32 or uint64, not varint64 // Checksums the string lengths (as restored uint32 or uint64, not varint64
// bytes) and string bytes, and stores it into "actual_crc32c". // bytes) and string bytes, and stores it into "actual_crc32c".
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements, Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
size_t offset, size_t size, string* destination, size_t offset, size_t size, tstring* destination,
uint32* actual_crc32c, bool need_to_swap_bytes) { uint32* actual_crc32c, bool need_to_swap_bytes) {
if (size == 0) return Status::OK(); if (size == 0) return Status::OK();
CHECK_GT(size, 0); CHECK_GT(size, 0);
@ -127,7 +127,7 @@ Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
// Reads the actual string bytes. // Reads the actual string bytes.
for (size_t i = 0; i < num_elements; ++i) { for (size_t i = 0; i < num_elements; ++i) {
const uint64 string_length = string_lengths[i]; const uint64 string_length = string_lengths[i];
string* buffer = &destination[i]; tstring* buffer = &destination[i];
buffer->resize(string_length); buffer->resize(string_length);
size_t bytes_read = 0; size_t bytes_read = 0;
@ -205,9 +205,9 @@ char* GetBackingBuffer(const Tensor& val) {
return const_cast<char*>(val.tensor_data().data()); return const_cast<char*>(val.tensor_data().data());
} }
string* GetStringBackingBuffer(const Tensor& val) { tstring* GetStringBackingBuffer(const Tensor& val) {
CHECK_EQ(DT_STRING, val.dtype()); CHECK_EQ(DT_STRING, val.dtype());
return const_cast<string*>(val.flat<string>().data()); return const_cast<tstring*>(val.flat<tstring>().data());
} }
Status ParseEntryProto(StringPiece key, StringPiece value, Status ParseEntryProto(StringPiece key, StringPiece value,
@ -244,14 +244,14 @@ Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
// Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes), // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
// the length-checksum, and all the string bytes. // the length-checksum, and all the string bytes.
DCHECK_EQ(val.dtype(), DT_STRING); DCHECK_EQ(val.dtype(), DT_STRING);
const string* strings = GetStringBackingBuffer(val); const tstring* strings = GetStringBackingBuffer(val);
// Writes the varint lengths. // Writes the varint lengths.
string lengths; string lengths;
lengths.reserve(val.NumElements()); // At least 1 byte per element. lengths.reserve(val.NumElements()); // At least 1 byte per element.
*crc32c = 0; *crc32c = 0;
for (int64 i = 0; i < val.NumElements(); ++i) { for (int64 i = 0; i < val.NumElements(); ++i) {
const string* elem = &strings[i]; const tstring* elem = &strings[i];
DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size())); DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
const uint64 elem_size = static_cast<uint64>(elem->size()); const uint64 elem_size = static_cast<uint64>(elem->size());
@ -281,7 +281,7 @@ Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
// Writes all the string bytes out. // Writes all the string bytes out.
for (int64 i = 0; i < val.NumElements(); ++i) { for (int64 i = 0; i < val.NumElements(); ++i) {
const string* string = &strings[i]; const tstring* string = &strings[i];
TF_RETURN_IF_ERROR(out->Append(*string)); TF_RETURN_IF_ERROR(out->Append(*string));
*bytes_written += string->size(); *bytes_written += string->size();
*crc32c = crc32c::Extend(*crc32c, string->data(), string->size()); *crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
@ -675,7 +675,7 @@ static Status MergeOneBundle(Env* env, StringPiece prefix,
return Status::OK(); return Status::OK();
} }
Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes, Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
StringPiece merged_prefix) { StringPiece merged_prefix) {
// Merges all metadata tables. // Merges all metadata tables.
// TODO(zhifengc): KeyValue sorter if it becomes too big. // TODO(zhifengc): KeyValue sorter if it becomes too big.
@ -823,10 +823,10 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
// Relaxes the check for string tensors as follows: // Relaxes the check for string tensors as follows:
// entry.size() == bytes(varint lengths) + bytes(data) // entry.size() == bytes(varint lengths) + bytes(data)
// >= NumElems + bytes(data), since size bytes(varint) >= 1. // >= NumElems + bytes(data), since size bytes(varint) >= 1.
// TotalBytes() == sizeof(string) * NumElems + bytes(data) // TotalBytes() == sizeof(tstring) * NumElems + bytes(data)
// Since we don't know bytes(varint lengths), we just check an inequality. // Since we don't know bytes(varint lengths), we just check an inequality.
const size_t lower_bound = ret->NumElements() + ret->TotalBytes() - const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
sizeof(string) * ret->NumElements(); sizeof(tstring) * ret->NumElements();
if (entry.size() < lower_bound) { if (entry.size() < lower_bound) {
return errors::DataLoss("Invalid size in bundle entry: key ", key(), return errors::DataLoss("Invalid size in bundle entry: key ", key(),
"; stored size ", entry.size(), "; stored size ", entry.size(),

View File

@ -172,7 +172,7 @@ class BundleWriter {
// //
// Once merged, makes a best effort to delete the old metadata files. // Once merged, makes a best effort to delete the old metadata files.
// Returns OK iff all bundles are successfully merged. // Returns OK iff all bundles are successfully merged.
Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes, Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
StringPiece merged_prefix); StringPiece merged_prefix);
// On construction, silently attempts to read the metadata associated with // On construction, silently attempts to read the metadata associated with

View File

@ -710,11 +710,12 @@ TEST(TensorBundleTest, StringTensorsOldFormat) {
EXPECT_EQ(AllTensorKeys(&reader), EXPECT_EQ(AllTensorKeys(&reader),
std::vector<string>({"floats", "scalar", "string_tensor", "strs"})); std::vector<string>({"floats", "scalar", "string_tensor", "strs"}));
Expect<string>(&reader, "string_tensor", Tensor(DT_STRING, TensorShape({1}))); Expect<tstring>(&reader, "string_tensor",
Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"})); Tensor(DT_STRING, TensorShape({1})));
Expect<string>( Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));
Expect<tstring>(
&reader, "strs", &reader, "strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 10, 'c')})); test::AsTensor<tstring>({"hello", "", "x01", string(1 << 10, 'c')}));
Expect<float>(&reader, "floats", Constant_2x3<float>(16.18)); Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
} }
@ -726,14 +727,19 @@ TEST(TensorBundleTest, StringTensors) {
BundleWriter writer(Env::Default(), Prefix("foo")); BundleWriter writer(Env::Default(), Prefix("foo"));
TF_EXPECT_OK(writer.Add("string_tensor", TF_EXPECT_OK(writer.Add("string_tensor",
Tensor(DT_STRING, TensorShape({1})))); // Empty. Tensor(DT_STRING, TensorShape({1})))); // Empty.
TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<string>({"hello"}))); TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<tstring>({"hello"})));
TF_EXPECT_OK(writer.Add( TF_EXPECT_OK(writer.Add(
"strs", "strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')}))); test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')})));
// Requires a 64-bit length. // Requires a 64-bit length.
string* backing_string = long_string_tensor.flat<string>().data(); tstring* backing_string = long_string_tensor.flat<tstring>().data();
#ifdef USE_TSTRING
backing_string->resize_uninitialized(kLongLength);
std::char_traits<char>::assign(backing_string->data(), kLongLength, 'd');
#else // USE_TSTRING
backing_string->assign(kLongLength, 'd'); backing_string->assign(kLongLength, 'd');
#endif // USE_TSTRING
TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor)); TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));
// Mixes in some floats. // Mixes in some floats.
@ -747,12 +753,12 @@ TEST(TensorBundleTest, StringTensors) {
std::vector<string>({"floats", "long_scalar", "scalar", std::vector<string>({"floats", "long_scalar", "scalar",
"string_tensor", "strs"})); "string_tensor", "strs"}));
Expect<string>(&reader, "string_tensor", Expect<tstring>(&reader, "string_tensor",
Tensor(DT_STRING, TensorShape({1}))); Tensor(DT_STRING, TensorShape({1})));
Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"})); Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));
Expect<string>( Expect<tstring>(
&reader, "strs", &reader, "strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')})); test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')}));
Expect<float>(&reader, "floats", Constant_2x3<float>(16.18)); Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
@ -767,17 +773,17 @@ TEST(TensorBundleTest, StringTensors) {
EXPECT_EQ(TensorShape({1}), shape); EXPECT_EQ(TensorShape({1}), shape);
// Zero-out the string so that we can be sure the new one is read in. // Zero-out the string so that we can be sure the new one is read in.
string* backing_string = long_string_tensor.flat<string>().data(); tstring* backing_string = long_string_tensor.flat<tstring>().data();
backing_string->assign(""); backing_string->assign("");
// Read long_scalar and check it contains kLongLength 'd's. // Read long_scalar and check it contains kLongLength 'd's.
TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor)); TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
ASSERT_EQ(backing_string, long_string_tensor.flat<string>().data()); ASSERT_EQ(backing_string, long_string_tensor.flat<tstring>().data());
EXPECT_EQ(kLongLength, backing_string->length()); EXPECT_EQ(kLongLength, backing_string->length());
for (char c : *backing_string) { for (size_t i = 0; i < kLongLength; i++) {
// Not using ASSERT_EQ('d', c) because this way is twice as fast due to // Not using ASSERT_EQ('d', c) because this way is twice as fast due to
// compiler optimizations. // compiler optimizations.
if (c != 'd') { if ((*backing_string)[i] != 'd') {
FAIL() << "long_scalar is not full of 'd's as expected."; FAIL() << "long_scalar is not full of 'd's as expected.";
break; break;
} }
@ -945,7 +951,7 @@ TEST(TensorBundleTest, Checksum) {
auto WriteStrings = []() { auto WriteStrings = []() {
BundleWriter writer(Env::Default(), Prefix("strings")); BundleWriter writer(Env::Default(), Prefix("strings"));
TF_EXPECT_OK( TF_EXPECT_OK(
writer.Add("foo", test::AsTensor<string>({"hello", "world"}))); writer.Add("foo", test::AsTensor<tstring>({"hello", "world"})));
TF_ASSERT_OK(writer.Finish()); TF_ASSERT_OK(writer.Finish());
}; };
// Corrupts the first two bytes, which are the varint32-encoded lengths // Corrupts the first two bytes, which are the varint32-encoded lengths

View File

@ -55,7 +55,7 @@ struct CopyThatWorksWithStringPointer {
// Eigen makes it extremely difficult to dereference a tensor of string* into // Eigen makes it extremely difficult to dereference a tensor of string* into
// string, so we roll our own loop instead. // string, so we roll our own loop instead.
template <> template <>
struct CopyThatWorksWithStringPointer<string> { struct CopyThatWorksWithStringPointer<tstring> {
template <typename SrcTensor, typename DstTensor, typename Shape> template <typename SrcTensor, typename DstTensor, typename Shape>
static void Copy(const SrcTensor& s, Shape s_start, Shape len, DstTensor& d, static void Copy(const SrcTensor& s, Shape s_start, Shape len, DstTensor& d,
Shape d_start) { Shape d_start) {

View File

@ -176,7 +176,7 @@ size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) {
} }
template <> template <>
Status TensorSliceWriter::SaveData(const string* data, int64 num_elements, Status TensorSliceWriter::SaveData(const tstring* data, int64 num_elements,
SavedSlice* ss) { SavedSlice* ss) {
size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes + size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes +
(num_elements * MaxBytesPerElement(DT_INT32)); (num_elements * MaxBytesPerElement(DT_INT32));

View File

@ -178,7 +178,7 @@ Status TensorSliceWriter::SaveData(const T* data, int64 num_elements,
} }
template <> template <>
Status TensorSliceWriter::SaveData(const string* data, int64 num_elements, Status TensorSliceWriter::SaveData(const tstring* data, int64 num_elements,
SavedSlice* ss); SavedSlice* ss);
// Create a table builder that will write to "filename" in // Create a table builder that will write to "filename" in

View File

@ -342,7 +342,7 @@ TEST(TensorSliceWriteTest, SizeErrors) {
{ {
TensorShape shape({256, 1024}); TensorShape shape({256, 1024});
TensorSlice slice = TensorSlice::ParseOrDie("-:-"); TensorSlice slice = TensorSlice::ParseOrDie("-:-");
const std::vector<string> data(256 * 1024, std::string(8192, 'f')); const std::vector<tstring> data(256 * 1024, std::string(8192, 'f'));
Status s = writer.Add("test2", shape, slice, data.data()); Status s = writer.Add("test2", shape, slice, data.data());
EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
EXPECT_TRUE(absl::StrContains(s.error_message(), EXPECT_TRUE(absl::StrContains(s.error_message(),