Updated util/ to use tstring.

This is a part of a larger migration effort for tensorflow::tstring.
See: https://github.com/tensorflow/community/pull/91
PiperOrigin-RevId: 263574807
This commit is contained in:
Dero Gharibian 2019-08-15 09:04:45 -07:00 committed by TensorFlower Gardener
parent 0e271c3e39
commit 7adc342449
16 changed files with 159 additions and 112 deletions

View File

@ -50,8 +50,8 @@ Status HandleElementToSlice(T* src, T* dest, int64 num_values,
}
template <>
Status HandleElementToSlice<string>(string* src, string* dest, int64 num_values,
bool can_move) {
Status HandleElementToSlice<tstring>(tstring* src, tstring* dest,
int64 num_values, bool can_move) {
if (can_move) {
for (int64 i = 0; i < num_values; ++i) {
*dest++ = std::move(*src++);

View File

@ -465,7 +465,7 @@ enum class Type { Sparse, Dense };
struct SparseBuffer {
// Features are in one of the 3 vectors below depending on config's dtype.
// Other 2 vectors remain empty.
SmallVector<string> bytes_list;
SmallVector<tstring> bytes_list;
SmallVector<float> float_list;
SmallVector<int64> int64_list;
@ -666,8 +666,8 @@ Status FastParseSerializedExample(
break;
}
case DT_STRING: {
auto out_p = out.flat<string>().data() + offset;
LimitedArraySlice<string> slice(out_p, num_elements);
auto out_p = out.flat<tstring>().data() + offset;
LimitedArraySlice<tstring> slice(out_p, num_elements);
if (!feature.ParseBytesList(&slice)) return parse_error();
if (slice.EndDistance() != 0) {
return shape_error(num_elements - slice.EndDistance(), "bytes");
@ -907,7 +907,7 @@ const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
return buffer.float_list;
}
template <>
const SmallVector<string>& GetListFromBuffer<string>(
const SmallVector<tstring>& GetListFromBuffer<tstring>(
const SparseBuffer& buffer) {
return buffer.bytes_list;
}
@ -917,7 +917,7 @@ void CopyOrMoveBlock(const T* b, const T* e, T* t) {
std::copy(b, e, t);
}
template <>
void CopyOrMoveBlock(const string* b, const string* e, string* t) {
void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t) {
std::move(b, e, t);
}
@ -1002,8 +1002,8 @@ class TensorVector {
} // namespace
Status FastParseExample(const Config& config,
gtl::ArraySlice<string> serialized,
gtl::ArraySlice<string> example_names,
gtl::ArraySlice<tstring> serialized,
gtl::ArraySlice<tstring> example_names,
thread::ThreadPool* thread_pool, Result* result) {
DCHECK(result != nullptr);
// Check config so we can safely CHECK(false) in switches on config.*.dtype
@ -1253,8 +1253,8 @@ Status FastParseExample(const Config& config,
break;
}
case DT_STRING: {
FillAndCopyVarLen<string>(d, num_elements, num_elements_per_minibatch,
config, varlen_dense_buffers, &values);
FillAndCopyVarLen<tstring>(d, num_elements, num_elements_per_minibatch,
config, varlen_dense_buffers, &values);
break;
}
default:
@ -1440,8 +1440,8 @@ Status FastParseSingleExample(const Config& config,
break;
}
case DT_STRING: {
auto out_p = out->flat<string>().data();
LimitedArraySlice<string> slice(out_p, num_elements);
auto out_p = out->flat<tstring>().data();
LimitedArraySlice<tstring> slice(out_p, num_elements);
if (!feature.ParseBytesList(&slice)) return parse_error();
if (slice.EndDistance() != 0) {
return parse_error();
@ -1453,7 +1453,7 @@ Status FastParseSingleExample(const Config& config,
}
} else { // if variable length
SmallVector<string> bytes_list;
SmallVector<tstring> bytes_list;
TensorVector<float> float_list;
SmallVector<int64> int64_list;
@ -1627,7 +1627,7 @@ Status FastParseSingleExample(const Config& config,
// Return the number of bytes elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
string* out) {
tstring* out) {
int num_elements = 0;
uint32 length;
if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
@ -1638,12 +1638,23 @@ inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
while (!stream->ExpectAtEnd()) {
uint32 bytes_length;
if (!stream->ExpectTag(kDelimitedTag(1)) ||
!stream->ReadVarint32(&bytes_length) ||
(out != nullptr && !stream->ReadString(out++, bytes_length))) {
!stream->ReadVarint32(&bytes_length)) {
return -1;
}
if (out == nullptr) {
stream->Skip(bytes_length);
} else {
#ifdef USE_TSTRING
out->resize_uninitialized(bytes_length);
if (!stream->ReadRaw(out->data(), bytes_length)) {
return -1;
}
#else // USE_TSTRING
if (!stream->ReadString(out, bytes_length)) {
return -1;
}
#endif // USE_TSTRING
out++;
}
num_elements++;
}
@ -1809,7 +1820,7 @@ inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
Status FastParseSequenceExample(
const FastParseExampleConfig& context_config,
const FastParseExampleConfig& feature_list_config,
gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
gtl::ArraySlice<tstring> serialized, gtl::ArraySlice<tstring> example_names,
thread::ThreadPool* thread_pool, Result* context_result,
Result* feature_list_result, std::vector<Tensor>* dense_feature_lengths) {
int num_examples = serialized.size();
@ -1878,10 +1889,10 @@ Status FastParseSequenceExample(
all_context_features(num_examples);
std::vector<absl::flat_hash_map<StringPiece, StringPiece>>
all_sequence_features(num_examples);
const string kUnknown = "<unknown>";
const tstring kUnknown = "<unknown>";
for (int d = 0; d < num_examples; d++) {
const string& example = serialized[d];
const string& example_name =
const tstring& example = serialized[d];
const tstring& example_name =
example_names.empty() ? kUnknown : example_names[d];
auto* context_features = &all_context_features[d];
auto* sequence_features = &all_sequence_features[d];
@ -2074,7 +2085,7 @@ Status FastParseSequenceExample(
// TODO(sundberg): Refactor to reduce code duplication, and add bounds
// checking for the outputs.
string* out_bytes = nullptr;
tstring* out_bytes = nullptr;
float* out_float = nullptr;
int64* out_int64 = nullptr;
switch (dtype) {
@ -2097,7 +2108,7 @@ Status FastParseSequenceExample(
for (int e = 0; e < num_examples; e++) {
size_t num_elements = 0;
const auto feature_iter = all_context_features[e].find(c.feature_name);
const string& example_name =
const tstring& example_name =
example_names.empty() ? kUnknown : example_names[e];
if (feature_iter == all_context_features[e].end()) {
// Copy the default value, if present. If not, return an error.
@ -2107,7 +2118,7 @@ Status FastParseSequenceExample(
" (data type: ", DataTypeString(c.dtype), ")",
" is required but could not be found.");
}
const string* in_bytes = nullptr;
const tstring* in_bytes = nullptr;
const float* in_float = nullptr;
const int64* in_int64 = nullptr;
size_t num = 0;
@ -2185,7 +2196,7 @@ Status FastParseSequenceExample(
Tensor(allocator, DT_INT64, TensorShape({2}));
// TODO(sundberg): Refactor to reduce code duplication, and add bounds
// checking for the outputs.
string* out_bytes = nullptr;
tstring* out_bytes = nullptr;
float* out_float = nullptr;
int64* out_int64 = nullptr;
switch (dtype) {
@ -2211,7 +2222,7 @@ Status FastParseSequenceExample(
size_t max_num_cols = 0;
for (int e = 0; e < num_examples; e++) {
const auto& feature = all_context_features[e][c.feature_name];
const string& example_name =
const tstring& example_name =
example_names.empty() ? kUnknown : example_names[e];
if (!feature.empty()) {
protobuf::io::CodedInputStream stream(
@ -2276,7 +2287,7 @@ Status FastParseSequenceExample(
Tensor(allocator, DT_INT64, dense_length_shape);
int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data();
string* out_bytes = nullptr;
tstring* out_bytes = nullptr;
float* out_float = nullptr;
int64* out_int64 = nullptr;
switch (dtype) {
@ -2299,7 +2310,7 @@ Status FastParseSequenceExample(
for (int e = 0; e < num_examples; e++) {
size_t num_elements = 0, num_rows = 0;
const auto feature_iter = all_sequence_features[e].find(c.feature_name);
const string& example_name =
const tstring& example_name =
example_names.empty() ? kUnknown : example_names[e];
if (feature_iter == all_sequence_features[e].end()) {
// Return an error if this feature was not allowed to be missing.
@ -2387,7 +2398,7 @@ Status FastParseSequenceExample(
feature_list_result->sparse_shapes[t] =
Tensor(allocator, DT_INT64, TensorShape({3}));
string* out_bytes = nullptr;
tstring* out_bytes = nullptr;
float* out_float = nullptr;
int64* out_int64 = nullptr;
switch (dtype) {
@ -2416,7 +2427,7 @@ Status FastParseSequenceExample(
size_t max_num_cols = 0;
for (int e = 0; e < num_examples; e++) {
const auto& feature = all_sequence_features[e][c.feature_name];
const string& example_name =
const tstring& example_name =
example_names.empty() ? kUnknown : example_names[e];
if (!feature.empty()) {
protobuf::io::CodedInputStream stream(

View File

@ -99,8 +99,8 @@ struct Result {
// Given example names have to either be empty or the same size as serialized.
// example_names are used only for error messages.
Status FastParseExample(const FastParseExampleConfig& config,
gtl::ArraySlice<string> serialized,
gtl::ArraySlice<string> example_names,
gtl::ArraySlice<tstring> serialized,
gtl::ArraySlice<tstring> example_names,
thread::ThreadPool* thread_pool, Result* result);
// TODO(mrry): Move the hash table construction into the config object.
@ -116,7 +116,7 @@ Status FastParseSingleExample(const FastParseSingleExampleConfig& config,
Status FastParseSequenceExample(
const example::FastParseExampleConfig& context_config,
const example::FastParseExampleConfig& feature_list_config,
gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
gtl::ArraySlice<tstring> serialized, gtl::ArraySlice<tstring> example_names,
thread::ThreadPool* thread_pool, example::Result* context_result,
example::Result* feature_list_result,
std::vector<Tensor>* dense_feature_lengths);

View File

@ -273,7 +273,7 @@ static void AddSparseFeature(const char* feature_name, DataType dtype,
TEST(FastParse, StatsCollection) {
const size_t kNumExamples = 13;
std::vector<string> serialized(kNumExamples, ExampleWithSomeFeatures());
std::vector<tstring> serialized(kNumExamples, ExampleWithSomeFeatures());
FastParseExampleConfig config_dense;
AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &config_dense);
@ -417,8 +417,9 @@ TEST(TestFastParseExample, Empty) {
Result result;
FastParseExampleConfig config;
config.sparse.push_back({"test", DT_STRING});
Status status = FastParseExample(config, gtl::ArraySlice<string>(),
gtl::ArraySlice<string>(), nullptr, &result);
Status status =
FastParseExample(config, gtl::ArraySlice<tstring>(),
gtl::ArraySlice<tstring>(), nullptr, &result);
EXPECT_TRUE(status.ok()) << status;
}

View File

@ -101,7 +101,7 @@ Status FeatureDenseCopy(const std::size_t out_index, const string& name,
"Values size: ",
values.value_size(), " but output shape: ", shape.DebugString());
}
auto out_p = out->flat<string>().data() + offset;
auto out_p = out->flat<tstring>().data() + offset;
std::transform(values.value().data(),
values.value().data() + num_elements, out_p,
[](const string* s) { return *s; });
@ -136,7 +136,7 @@ Tensor FeatureSparseCopy(const std::size_t batch, const string& key,
const BytesList& values = feature.bytes_list();
const int64 num_elements = values.value_size();
Tensor out(dtype, TensorShape({num_elements}));
auto out_p = out.flat<string>().data();
auto out_p = out.flat<tstring>().data();
std::transform(values.value().data(),
values.value().data() + num_elements, out_p,
[](const string* s) { return *s; });
@ -175,8 +175,8 @@ int64 CopyIntoSparseTensor(const Tensor& in, const int batch,
break;
}
case DT_STRING: {
std::copy_n(in.flat<string>().data(), num_elements,
values->flat<string>().data() + offset);
std::copy_n(in.flat<tstring>().data(), num_elements,
values->flat<tstring>().data() + offset);
break;
}
default:
@ -203,8 +203,9 @@ void RowDenseCopy(const std::size_t& out_index, const DataType& dtype,
break;
}
case DT_STRING: {
std::copy_n(in.flat<string>().data(), num_elements,
out->flat<string>().data() + offset);
// TODO(dero): verify.
std::copy_n(in.flat<tstring>().data(), num_elements,
out->flat<tstring>().data() + offset);
break;
}
default:

View File

@ -337,10 +337,24 @@ inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) {
// serialized proto.
// May read all or part of a repeated field.
inline Status ReadBytes(CodedInputStream* input, int index, void* datap) {
string* data = reinterpret_cast<string*>(datap) + index;
tstring* data = reinterpret_cast<tstring*>(datap) + index;
#ifdef USE_TSTRING
uint32 length;
if (!input->ReadVarint32(&length)) {
return errors::DataLoss("Failed reading bytes");
}
data->resize_uninitialized(length);
if (!input->ReadRaw(data->data(), length)) {
return errors::DataLoss("Failed reading bytes");
}
#else // USE_TSTRING
if (!WireFormatLite::ReadBytes(input, data)) {
return errors::DataLoss("Failed reading bytes");
}
#endif // USE_TSTRING
return Status::OK();
}
@ -354,8 +368,19 @@ inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
// TODO(nix): there is a faster way to grab TYPE_GROUP bytes by relying
// on input->IsFlat() == true and using input->GetDirectBufferPointer()
// with input->CurrentPosition().
string* data = reinterpret_cast<string*>(datap) + index;
tstring* data = reinterpret_cast<tstring*>(datap) + index;
#ifdef USE_TSTRING
// TODO(dero): To mitigate the string to tstring copy, we can implement our
// own scanner as described above. We would first need to obtain the length
// in an initial pass and resize/reserve the tstring. But, given that
// TYPE_GROUP is deprecated and currently no tests in
// tensorflow/python/kernel_tests/proto:decode_proto_op_test target a
// TYPE_GROUP tag, we use std::string as a read buffer.
string buf;
StringOutputStream string_stream(&buf);
#else // USE_TSTRING
StringOutputStream string_stream(data);
#endif // USE_TSTRING
CodedOutputStream out(&string_stream);
if (!WireFormatLite::SkipField(
input,
@ -364,6 +389,9 @@ inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
&out)) {
return errors::DataLoss("Failed reading group");
}
#ifdef USE_TSTRING
*data = buf;
#endif // USE_TSTRING
return Status::OK();
}

View File

@ -179,29 +179,29 @@ inline void Fill(const Eigen::half* data, size_t n, TensorProto* t) {
// Custom implementation for string.
template <>
struct SaveTypeTraits<string> {
struct SaveTypeTraits<tstring> {
static constexpr bool supported = true;
typedef const string* SavedType;
typedef protobuf::RepeatedPtrField<string> RepeatedField;
};
template <>
inline const string* const* TensorProtoData<string>(const TensorProto& t) {
static_assert(SaveTypeTraits<string>::supported,
"Specified type string not supported for Restore");
inline const string* const* TensorProtoData<tstring>(const TensorProto& t) {
static_assert(SaveTypeTraits<tstring>::supported,
"Specified type tstring not supported for Restore");
return t.string_val().data();
}
template <>
inline protobuf::RepeatedPtrField<string>* MutableTensorProtoData<string>(
inline protobuf::RepeatedPtrField<string>* MutableTensorProtoData<tstring>(
TensorProto* t) {
static_assert(SaveTypeTraits<string>::supported,
"Specified type string not supported for Save");
static_assert(SaveTypeTraits<tstring>::supported,
"Specified type tstring not supported for Save");
return t->mutable_string_val();
}
template <>
inline void Fill(const string* data, size_t n, TensorProto* t) {
inline void Fill(const tstring* data, size_t n, TensorProto* t) {
typename protobuf::RepeatedPtrField<string> copy(data, data + n);
t->mutable_string_val()->Swap(&copy);
}

View File

@ -102,7 +102,7 @@ Example of grouping:
Tensor values(DT_STRING, TensorShape({N});
TensorShape shape({dim0,...});
SparseTensor sp(indices, vals, shape);
sp.Reorder<string>({1, 2, 0, 3, ...}); // Must provide NDIMS dims.
sp.Reorder<tstring>({1, 2, 0, 3, ...}); // Must provide NDIMS dims.
// group according to dims 1 and 2
for (const auto& g : sp.group({1, 2})) {
cout << "vals of ix[:, 1,2] for this group: "
@ -111,7 +111,7 @@ Example of grouping:
cout << "values of group:\n" << g.values();
TTypes<int64>::UnalignedMatrix g_ix = g.indices();
TTypes<string>::UnalignedVec g_v = g.values();
TTypes<tstring>::UnalignedVec g_v = g.values();
ASSERT(g_ix.dimension(0) == g_v.size()); // number of elements match.
}
@ -133,7 +133,7 @@ Shape checking is performed, as is boundary checking.
Tensor dense(DT_STRING, shape);
// initialize other indices to zero. copy.
ASSERT(sp.ToDense<string>(&dense, true));
ASSERT(sp.ToDense<tstring>(&dense, true));
Concat
@ -215,7 +215,7 @@ Coding Example:
EXPECT_EQ(conc.Order(), {-1, -1, -1});
// Reorder st3 so all input tensors have the exact same orders.
st3.Reorder<string>({1, 0, 2});
st3.Reorder<tstring>({1, 0, 2});
SparseTensor conc2 = SparseTensor::Concat<string>({st1, st2, st3});
EXPECT_EQ(conc2.Order(), {1, 0, 2});
// All indices' orders matched, so output is in order.

View File

@ -170,7 +170,7 @@ TEST(SparseTensorTest, SparseTensorConstruction) {
int N = 5;
const int NDIM = 3;
auto ix_c = GetSimpleIndexTensor(N, NDIM);
Eigen::Tensor<string, 1, Eigen::RowMajor> vals_c(N);
Eigen::Tensor<tstring, 1, Eigen::RowMajor> vals_c(N);
vals_c(0) = "hi0";
vals_c(1) = "hi1";
vals_c(2) = "hi2";
@ -200,7 +200,7 @@ TEST(SparseTensorTest, SparseTensorConstruction) {
// Regardless of how order is updated; so long as there are no
// duplicates, the resulting indices are valid.
st.Reorder<string>({2, 0, 1});
st.Reorder<tstring>({2, 0, 1});
TF_EXPECT_OK(st.IndicesValid());
EXPECT_EQ(vals_t(0), "hi0");
EXPECT_EQ(vals_t(1), "hi3");
@ -210,7 +210,7 @@ TEST(SparseTensorTest, SparseTensorConstruction) {
ix_t = ix_c;
vals_t = vals_c;
st.Reorder<string>({0, 1, 2});
st.Reorder<tstring>({0, 1, 2});
TF_EXPECT_OK(st.IndicesValid());
EXPECT_EQ(vals_t(0), "hi0");
EXPECT_EQ(vals_t(1), "hi4");
@ -220,7 +220,7 @@ TEST(SparseTensorTest, SparseTensorConstruction) {
ix_t = ix_c;
vals_t = vals_c;
st.Reorder<string>({2, 1, 0});
st.Reorder<tstring>({2, 1, 0});
TF_EXPECT_OK(st.IndicesValid());
}
@ -239,7 +239,7 @@ TEST(SparseTensorTest, EmptySparseTensorAllowed) {
EXPECT_EQ(st.order(), order);
std::vector<int64> new_order{1, 0, 2};
st.Reorder<string>(new_order);
st.Reorder<tstring>(new_order);
TF_EXPECT_OK(st.IndicesValid());
EXPECT_EQ(st.order(), new_order);
}
@ -259,13 +259,13 @@ TEST(SparseTensorTest, SortingWorksCorrectly) {
for (int n = 0; n < 100; ++n) {
ix_t = ix_t.random(Eigen::internal::UniformRandomGenerator<int64>(n + 1));
ix_t = ix_t.abs() % 1000;
st.Reorder<string>({0, 1, 2, 3});
st.Reorder<tstring>({0, 1, 2, 3});
TF_EXPECT_OK(st.IndicesValid());
st.Reorder<string>({3, 2, 1, 0});
st.Reorder<tstring>({3, 2, 1, 0});
TF_EXPECT_OK(st.IndicesValid());
st.Reorder<string>({1, 0, 2, 3});
st.Reorder<tstring>({1, 0, 2, 3});
TF_EXPECT_OK(st.IndicesValid());
st.Reorder<string>({3, 0, 2, 1});
st.Reorder<tstring>({3, 0, 2, 1});
TF_EXPECT_OK(st.IndicesValid());
}
}
@ -294,7 +294,7 @@ TEST(SparseTensorTest, ValidateIndicesFindsInvalid) {
SparseTensor st;
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
st.Reorder<string>(order);
st.Reorder<tstring>(order);
Status st_indices_valid = st.IndicesValid();
EXPECT_FALSE(st_indices_valid.ok());
EXPECT_EQ("indices[1] = [0,0,0] is repeated",
@ -302,12 +302,12 @@ TEST(SparseTensorTest, ValidateIndicesFindsInvalid) {
ix_orig(1, 2) = 1;
ix_t = ix_orig;
st.Reorder<string>(order);
st.Reorder<tstring>(order);
TF_EXPECT_OK(st.IndicesValid()); // second index now (0, 0, 1)
ix_orig(0, 2) = 1;
ix_t = ix_orig;
st.Reorder<string>(order);
st.Reorder<tstring>(order);
st_indices_valid = st.IndicesValid();
EXPECT_FALSE(st_indices_valid.ok()); // first index now (0, 0, 1)
EXPECT_EQ("indices[1] = [0,0,1] is repeated",
@ -332,12 +332,12 @@ TEST(SparseTensorTest, SparseTensorCheckBoundaries) {
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
EXPECT_FALSE(st.IndicesValid().ok());
st.Reorder<string>(order);
st.Reorder<tstring>(order);
TF_EXPECT_OK(st.IndicesValid());
ix_t(0, 0) = 11;
ix.matrix<int64>() = ix_t;
st.Reorder<string>(order);
st.Reorder<tstring>(order);
Status st_indices_valid = st.IndicesValid();
EXPECT_FALSE(st_indices_valid.ok());
// Error message references index 4 because of the call to Reorder.
@ -346,7 +346,7 @@ TEST(SparseTensorTest, SparseTensorCheckBoundaries) {
ix_t(0, 0) = -1;
ix.matrix<int64>() = ix_t;
st.Reorder<string>(order);
st.Reorder<tstring>(order);
st_indices_valid = st.IndicesValid();
EXPECT_FALSE(st_indices_valid.ok());
EXPECT_EQ("[-1,0,0] is out of bounds: need 0 <= index < [10,10,10]",
@ -354,7 +354,7 @@ TEST(SparseTensorTest, SparseTensorCheckBoundaries) {
ix_t(0, 0) = 0;
ix.matrix<int64>() = ix_t;
st.Reorder<string>(order);
st.Reorder<tstring>(order);
TF_EXPECT_OK(st.IndicesValid());
}
@ -382,9 +382,9 @@ TEST(SparseTensorTest, SparseTensorToDenseTensor) {
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
Tensor dense(DT_STRING, TensorShape({4, 4, 5}));
st.ToDense<string>(&dense);
st.ToDense<tstring>(&dense);
auto dense_t = dense.tensor<string, 3>();
auto dense_t = dense.tensor<tstring, 3>();
Eigen::array<Eigen::DenseIndex, NDIM> ix_n;
for (int n = 0; n < N; ++n) {
for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d);
@ -422,9 +422,9 @@ TEST(SparseTensorTest, SparseTensorToLargerDenseTensor) {
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
Tensor dense(DT_STRING, TensorShape({10, 10, 10}));
st.ToDense<string>(&dense);
st.ToDense<tstring>(&dense);
auto dense_t = dense.tensor<string, 3>();
auto dense_t = dense.tensor<tstring, 3>();
Eigen::array<Eigen::DenseIndex, NDIM> ix_n;
for (int n = 0; n < N; ++n) {
for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d);
@ -554,10 +554,10 @@ TEST(SparseTensorTest, Concat) {
SparseTensor st;
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
EXPECT_FALSE(st.IndicesValid().ok());
st.Reorder<string>(order);
st.Reorder<tstring>(order);
TF_EXPECT_OK(st.IndicesValid());
SparseTensor concatted = SparseTensor::Concat<string>({st, st, st, st});
SparseTensor concatted = SparseTensor::Concat<tstring>({st, st, st, st});
EXPECT_EQ(concatted.order(), st.order());
gtl::InlinedVector<int64, 8> expected_shape{40, 10, 10};
EXPECT_EQ(concatted.shape(), expected_shape);
@ -585,7 +585,7 @@ TEST(SparseTensorTest, Concat) {
SparseTensor st_ooo;
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, {0, 2, 1},
&st_ooo)); // non-primary ix OOO
SparseTensor conc_ooo = SparseTensor::Concat<string>({st, st, st, st_ooo});
SparseTensor conc_ooo = SparseTensor::Concat<tstring>({st, st, st, st_ooo});
std::vector<int64> expected_ooo{-1, -1, -1};
EXPECT_EQ(conc_ooo.order(), expected_ooo);
EXPECT_EQ(conc_ooo.shape(), expected_shape);
@ -782,7 +782,7 @@ static void BM_SparseReorderString(int iters, int N32, int NDIM32) {
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
testing::StartTiming();
st.Reorder<string>(reorder);
st.Reorder<tstring>(reorder);
}
}

View File

@ -69,7 +69,7 @@ namespace {
// Checksums the string lengths (as restored uint32 or uint64, not varint64
// bytes) and string bytes, and stores it into "actual_crc32c".
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
size_t offset, size_t size, string* destination,
size_t offset, size_t size, tstring* destination,
uint32* actual_crc32c, bool need_to_swap_bytes) {
if (size == 0) return Status::OK();
CHECK_GT(size, 0);
@ -127,7 +127,7 @@ Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
// Reads the actual string bytes.
for (size_t i = 0; i < num_elements; ++i) {
const uint64 string_length = string_lengths[i];
string* buffer = &destination[i];
tstring* buffer = &destination[i];
buffer->resize(string_length);
size_t bytes_read = 0;
@ -205,9 +205,9 @@ char* GetBackingBuffer(const Tensor& val) {
return const_cast<char*>(val.tensor_data().data());
}
string* GetStringBackingBuffer(const Tensor& val) {
tstring* GetStringBackingBuffer(const Tensor& val) {
CHECK_EQ(DT_STRING, val.dtype());
return const_cast<string*>(val.flat<string>().data());
return const_cast<tstring*>(val.flat<tstring>().data());
}
Status ParseEntryProto(StringPiece key, StringPiece value,
@ -244,14 +244,14 @@ Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
// Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
// the length-checksum, and all the string bytes.
DCHECK_EQ(val.dtype(), DT_STRING);
const string* strings = GetStringBackingBuffer(val);
const tstring* strings = GetStringBackingBuffer(val);
// Writes the varint lengths.
string lengths;
lengths.reserve(val.NumElements()); // At least 1 byte per element.
*crc32c = 0;
for (int64 i = 0; i < val.NumElements(); ++i) {
const string* elem = &strings[i];
const tstring* elem = &strings[i];
DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
const uint64 elem_size = static_cast<uint64>(elem->size());
@ -281,7 +281,7 @@ Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
// Writes all the string bytes out.
for (int64 i = 0; i < val.NumElements(); ++i) {
const string* string = &strings[i];
const tstring* string = &strings[i];
TF_RETURN_IF_ERROR(out->Append(*string));
*bytes_written += string->size();
*crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
@ -675,7 +675,7 @@ static Status MergeOneBundle(Env* env, StringPiece prefix,
return Status::OK();
}
Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
StringPiece merged_prefix) {
// Merges all metadata tables.
// TODO(zhifengc): KeyValue sorter if it becomes too big.
@ -823,10 +823,10 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
// Relaxes the check for string tensors as follows:
// entry.size() == bytes(varint lengths) + bytes(data)
// >= NumElems + bytes(data), since size bytes(varint) >= 1.
// TotalBytes() == sizeof(string) * NumElems + bytes(data)
// TotalBytes() == sizeof(tstring) * NumElems + bytes(data)
// Since we don't know bytes(varint lengths), we just check an inequality.
const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
sizeof(string) * ret->NumElements();
sizeof(tstring) * ret->NumElements();
if (entry.size() < lower_bound) {
return errors::DataLoss("Invalid size in bundle entry: key ", key(),
"; stored size ", entry.size(),

View File

@ -172,7 +172,7 @@ class BundleWriter {
//
// Once merged, makes a best effort to delete the old metadata files.
// Returns OK iff all bundles are successfully merged.
Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
StringPiece merged_prefix);
// On construction, silently attempts to read the metadata associated with

View File

@ -710,11 +710,12 @@ TEST(TensorBundleTest, StringTensorsOldFormat) {
EXPECT_EQ(AllTensorKeys(&reader),
std::vector<string>({"floats", "scalar", "string_tensor", "strs"}));
Expect<string>(&reader, "string_tensor", Tensor(DT_STRING, TensorShape({1})));
Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"}));
Expect<string>(
Expect<tstring>(&reader, "string_tensor",
Tensor(DT_STRING, TensorShape({1})));
Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));
Expect<tstring>(
&reader, "strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 10, 'c')}));
test::AsTensor<tstring>({"hello", "", "x01", string(1 << 10, 'c')}));
Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
}
@ -726,14 +727,19 @@ TEST(TensorBundleTest, StringTensors) {
BundleWriter writer(Env::Default(), Prefix("foo"));
TF_EXPECT_OK(writer.Add("string_tensor",
Tensor(DT_STRING, TensorShape({1})))); // Empty.
TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<string>({"hello"})));
TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<tstring>({"hello"})));
TF_EXPECT_OK(writer.Add(
"strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')})));
test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')})));
// Requires a 64-bit length.
string* backing_string = long_string_tensor.flat<string>().data();
tstring* backing_string = long_string_tensor.flat<tstring>().data();
#ifdef USE_TSTRING
backing_string->resize_uninitialized(kLongLength);
std::char_traits<char>::assign(backing_string->data(), kLongLength, 'd');
#else // USE_TSTRING
backing_string->assign(kLongLength, 'd');
#endif // USE_TSTRING
TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));
// Mixes in some floats.
@ -747,12 +753,12 @@ TEST(TensorBundleTest, StringTensors) {
std::vector<string>({"floats", "long_scalar", "scalar",
"string_tensor", "strs"}));
Expect<string>(&reader, "string_tensor",
Tensor(DT_STRING, TensorShape({1})));
Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"}));
Expect<string>(
Expect<tstring>(&reader, "string_tensor",
Tensor(DT_STRING, TensorShape({1})));
Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));
Expect<tstring>(
&reader, "strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')}));
test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')}));
Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
@ -767,17 +773,17 @@ TEST(TensorBundleTest, StringTensors) {
EXPECT_EQ(TensorShape({1}), shape);
// Zero-out the string so that we can be sure the new one is read in.
string* backing_string = long_string_tensor.flat<string>().data();
tstring* backing_string = long_string_tensor.flat<tstring>().data();
backing_string->assign("");
// Read long_scalar and check it contains kLongLength 'd's.
TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
ASSERT_EQ(backing_string, long_string_tensor.flat<string>().data());
ASSERT_EQ(backing_string, long_string_tensor.flat<tstring>().data());
EXPECT_EQ(kLongLength, backing_string->length());
for (char c : *backing_string) {
for (size_t i = 0; i < kLongLength; i++) {
// Not using ASSERT_EQ('d', c) because this way is twice as fast due to
// compiler optimizations.
if (c != 'd') {
if ((*backing_string)[i] != 'd') {
FAIL() << "long_scalar is not full of 'd's as expected.";
break;
}
@ -945,7 +951,7 @@ TEST(TensorBundleTest, Checksum) {
auto WriteStrings = []() {
BundleWriter writer(Env::Default(), Prefix("strings"));
TF_EXPECT_OK(
writer.Add("foo", test::AsTensor<string>({"hello", "world"})));
writer.Add("foo", test::AsTensor<tstring>({"hello", "world"})));
TF_ASSERT_OK(writer.Finish());
};
// Corrupts the first two bytes, which are the varint32-encoded lengths

View File

@ -55,7 +55,7 @@ struct CopyThatWorksWithStringPointer {
// Eigen makes it extremely difficult to dereference a tensor of string* into
// string, so we roll our own loop instead.
template <>
struct CopyThatWorksWithStringPointer<string> {
struct CopyThatWorksWithStringPointer<tstring> {
template <typename SrcTensor, typename DstTensor, typename Shape>
static void Copy(const SrcTensor& s, Shape s_start, Shape len, DstTensor& d,
Shape d_start) {

View File

@ -176,7 +176,7 @@ size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) {
}
template <>
Status TensorSliceWriter::SaveData(const string* data, int64 num_elements,
Status TensorSliceWriter::SaveData(const tstring* data, int64 num_elements,
SavedSlice* ss) {
size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes +
(num_elements * MaxBytesPerElement(DT_INT32));

View File

@ -178,7 +178,7 @@ Status TensorSliceWriter::SaveData(const T* data, int64 num_elements,
}
template <>
Status TensorSliceWriter::SaveData(const string* data, int64 num_elements,
Status TensorSliceWriter::SaveData(const tstring* data, int64 num_elements,
SavedSlice* ss);
// Create a table builder that will write to "filename" in

View File

@ -342,7 +342,7 @@ TEST(TensorSliceWriteTest, SizeErrors) {
{
TensorShape shape({256, 1024});
TensorSlice slice = TensorSlice::ParseOrDie("-:-");
const std::vector<string> data(256 * 1024, std::string(8192, 'f'));
const std::vector<tstring> data(256 * 1024, std::string(8192, 'f'));
Status s = writer.Add("test2", shape, slice, data.data());
EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
EXPECT_TRUE(absl::StrContains(s.error_message(),