Updated util/ to use tstring.

This is part of a larger migration effort for tensorflow::tstring.
See: https://github.com/tensorflow/community/pull/91

PiperOrigin-RevId: 263574807

parent 0e271c3e39
commit 7adc342449
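Every hunk below applies the same mechanical change: helpers that were specialized on the old string alias (std::string) are re-specialized on tensorflow::tstring, and string-typed buffers flowing through the example parsers, proto decoding, sparse-tensor utilities, and checkpoint writers change type with them. The following standalone sketch only illustrates that pattern; it is not code from this commit, it aliases tstring to std::string so it compiles outside TensorFlow, and it simplifies the helper to return void instead of Status.

    #include <algorithm>
    #include <cstdint>
    #include <string>
    #include <utility>

    // Stand-in for tensorflow::tstring; an assumption made for this sketch only.
    using tstring = std::string;

    // Generic element copier, shaped like HandleElementToSlice in the diff below.
    template <typename T>
    void HandleElementToSlice(T* src, T* dest, std::int64_t num_values,
                              bool /*can_move*/) {
      std::copy_n(src, num_values, dest);
    }

    // The migration re-keys the specialization from string to tstring; when the
    // source buffer can be consumed, each element is moved instead of copied.
    template <>
    void HandleElementToSlice<tstring>(tstring* src, tstring* dest,
                                       std::int64_t num_values, bool can_move) {
      if (can_move) {
        for (std::int64_t i = 0; i < num_values; ++i) *dest++ = std::move(*src++);
      } else {
        std::copy_n(src, num_values, dest);
      }
    }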
@@ -50,8 +50,8 @@ Status HandleElementToSlice(T* src, T* dest, int64 num_values,
 }

 template <>
-Status HandleElementToSlice<string>(string* src, string* dest, int64 num_values,
-                                    bool can_move) {
+Status HandleElementToSlice<tstring>(tstring* src, tstring* dest,
+                                     int64 num_values, bool can_move) {
   if (can_move) {
     for (int64 i = 0; i < num_values; ++i) {
       *dest++ = std::move(*src++);
@@ -465,7 +465,7 @@ enum class Type { Sparse, Dense };
 struct SparseBuffer {
   // Features are in one of the 3 vectors below depending on config's dtype.
   // Other 2 vectors remain empty.
-  SmallVector<string> bytes_list;
+  SmallVector<tstring> bytes_list;
   SmallVector<float> float_list;
   SmallVector<int64> int64_list;

@@ -666,8 +666,8 @@ Status FastParseSerializedExample(
         break;
       }
       case DT_STRING: {
-        auto out_p = out.flat<string>().data() + offset;
-        LimitedArraySlice<string> slice(out_p, num_elements);
+        auto out_p = out.flat<tstring>().data() + offset;
+        LimitedArraySlice<tstring> slice(out_p, num_elements);
         if (!feature.ParseBytesList(&slice)) return parse_error();
         if (slice.EndDistance() != 0) {
           return shape_error(num_elements - slice.EndDistance(), "bytes");
@@ -907,7 +907,7 @@ const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
   return buffer.float_list;
 }
 template <>
-const SmallVector<string>& GetListFromBuffer<string>(
+const SmallVector<tstring>& GetListFromBuffer<tstring>(
     const SparseBuffer& buffer) {
   return buffer.bytes_list;
 }
@@ -917,7 +917,7 @@ void CopyOrMoveBlock(const T* b, const T* e, T* t) {
   std::copy(b, e, t);
 }
 template <>
-void CopyOrMoveBlock(const string* b, const string* e, string* t) {
+void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t) {
   std::move(b, e, t);
 }

@@ -1002,8 +1002,8 @@ class TensorVector {
 }  // namespace

 Status FastParseExample(const Config& config,
-                        gtl::ArraySlice<string> serialized,
-                        gtl::ArraySlice<string> example_names,
+                        gtl::ArraySlice<tstring> serialized,
+                        gtl::ArraySlice<tstring> example_names,
                         thread::ThreadPool* thread_pool, Result* result) {
   DCHECK(result != nullptr);
   // Check config so we can safely CHECK(false) in switches on config.*.dtype
@@ -1253,8 +1253,8 @@ Status FastParseExample(const Config& config,
         break;
       }
       case DT_STRING: {
-        FillAndCopyVarLen<string>(d, num_elements, num_elements_per_minibatch,
-                                  config, varlen_dense_buffers, &values);
+        FillAndCopyVarLen<tstring>(d, num_elements, num_elements_per_minibatch,
+                                   config, varlen_dense_buffers, &values);
         break;
       }
       default:
@@ -1440,8 +1440,8 @@ Status FastParseSingleExample(const Config& config,
         break;
       }
       case DT_STRING: {
-        auto out_p = out->flat<string>().data();
-        LimitedArraySlice<string> slice(out_p, num_elements);
+        auto out_p = out->flat<tstring>().data();
+        LimitedArraySlice<tstring> slice(out_p, num_elements);
         if (!feature.ParseBytesList(&slice)) return parse_error();
         if (slice.EndDistance() != 0) {
           return parse_error();
@@ -1453,7 +1453,7 @@ Status FastParseSingleExample(const Config& config,
       }

     } else {  // if variable length
-      SmallVector<string> bytes_list;
+      SmallVector<tstring> bytes_list;
       TensorVector<float> float_list;
       SmallVector<int64> int64_list;

@@ -1627,7 +1627,7 @@ Status FastParseSingleExample(const Config& config,
 // Return the number of bytes elements parsed, or -1 on error. If out is null,
 // this method simply counts the number of elements without any copying.
 inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
-                             string* out) {
+                             tstring* out) {
   int num_elements = 0;
   uint32 length;
   if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
@@ -1638,12 +1638,23 @@ inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
   while (!stream->ExpectAtEnd()) {
     uint32 bytes_length;
     if (!stream->ExpectTag(kDelimitedTag(1)) ||
-        !stream->ReadVarint32(&bytes_length) ||
-        (out != nullptr && !stream->ReadString(out++, bytes_length))) {
+        !stream->ReadVarint32(&bytes_length)) {
       return -1;
     }
     if (out == nullptr) {
       stream->Skip(bytes_length);
+    } else {
+#ifdef USE_TSTRING
+      out->resize_uninitialized(bytes_length);
+      if (!stream->ReadRaw(out->data(), bytes_length)) {
+        return -1;
+      }
+#else  // USE_TSTRING
+      if (!stream->ReadString(out, bytes_length)) {
+        return -1;
+      }
+#endif  // USE_TSTRING
+      out++;
     }
     num_elements++;
   }
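Per the comment in the ParseBytesFeature hunk above, the same routine both counts bytes elements (when out == nullptr) and fills a destination buffer once the caller has sized it. The sketch below shows that count-then-fill calling pattern in isolation; it is a hedged stand-in (a toy one-byte length prefix, a tstring alias, and an invented CountOrParse helper), not the TensorFlow parser.

    #include <cstddef>
    #include <cstdint>
    #include <string>
    #include <vector>

    // Invented stand-ins for illustration; ParseBytesFeature in the diff plays
    // the role of CountOrParse here, and tstring is aliased to std::string.
    using tstring = std::string;

    // Returns the number of elements in `encoded`, or -1 on malformed input.
    // Fills `out` (which must already be sized) when it is non-null.
    // Toy wire format: each element is "<1-byte length><bytes>".
    int CountOrParse(const std::string& encoded, tstring* out) {
      int num_elements = 0;
      for (std::size_t i = 0; i < encoded.size();) {
        const std::uint8_t len = static_cast<std::uint8_t>(encoded[i++]);
        if (i + len > encoded.size()) return -1;  // truncated element
        if (out != nullptr) out[num_elements].assign(encoded, i, len);
        i += len;
        ++num_elements;
      }
      return num_elements;
    }

    // Two-pass use, mirroring how the parser sizes its output tensor first:
    // count with a null destination, then parse into the sized buffer.
    std::vector<tstring> ParseAll(const std::string& encoded) {
      const int n = CountOrParse(encoded, nullptr);
      std::vector<tstring> values(n > 0 ? n : 0);
      if (n > 0) CountOrParse(encoded, values.data());
      return values;
    }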
@@ -1809,7 +1820,7 @@ inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
 Status FastParseSequenceExample(
     const FastParseExampleConfig& context_config,
     const FastParseExampleConfig& feature_list_config,
-    gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
+    gtl::ArraySlice<tstring> serialized, gtl::ArraySlice<tstring> example_names,
     thread::ThreadPool* thread_pool, Result* context_result,
     Result* feature_list_result, std::vector<Tensor>* dense_feature_lengths) {
   int num_examples = serialized.size();
@@ -1878,10 +1889,10 @@ Status FastParseSequenceExample(
       all_context_features(num_examples);
   std::vector<absl::flat_hash_map<StringPiece, StringPiece>>
       all_sequence_features(num_examples);
-  const string kUnknown = "<unknown>";
+  const tstring kUnknown = "<unknown>";
   for (int d = 0; d < num_examples; d++) {
-    const string& example = serialized[d];
-    const string& example_name =
+    const tstring& example = serialized[d];
+    const tstring& example_name =
         example_names.empty() ? kUnknown : example_names[d];
     auto* context_features = &all_context_features[d];
     auto* sequence_features = &all_sequence_features[d];
@@ -2074,7 +2085,7 @@ Status FastParseSequenceExample(

       // TODO(sundberg): Refactor to reduce code duplication, and add bounds
       // checking for the outputs.
-      string* out_bytes = nullptr;
+      tstring* out_bytes = nullptr;
       float* out_float = nullptr;
       int64* out_int64 = nullptr;
       switch (dtype) {
@@ -2097,7 +2108,7 @@ Status FastParseSequenceExample(
       for (int e = 0; e < num_examples; e++) {
         size_t num_elements = 0;
         const auto feature_iter = all_context_features[e].find(c.feature_name);
-        const string& example_name =
+        const tstring& example_name =
             example_names.empty() ? kUnknown : example_names[e];
         if (feature_iter == all_context_features[e].end()) {
           // Copy the default value, if present. If not, return an error.
@@ -2107,7 +2118,7 @@ Status FastParseSequenceExample(
               " (data type: ", DataTypeString(c.dtype), ")",
               " is required but could not be found.");
         }
-        const string* in_bytes = nullptr;
+        const tstring* in_bytes = nullptr;
         const float* in_float = nullptr;
         const int64* in_int64 = nullptr;
         size_t num = 0;
@@ -2185,7 +2196,7 @@ Status FastParseSequenceExample(
           Tensor(allocator, DT_INT64, TensorShape({2}));
       // TODO(sundberg): Refactor to reduce code duplication, and add bounds
       // checking for the outputs.
-      string* out_bytes = nullptr;
+      tstring* out_bytes = nullptr;
       float* out_float = nullptr;
       int64* out_int64 = nullptr;
       switch (dtype) {
@@ -2211,7 +2222,7 @@ Status FastParseSequenceExample(
       size_t max_num_cols = 0;
       for (int e = 0; e < num_examples; e++) {
         const auto& feature = all_context_features[e][c.feature_name];
-        const string& example_name =
+        const tstring& example_name =
             example_names.empty() ? kUnknown : example_names[e];
         if (!feature.empty()) {
           protobuf::io::CodedInputStream stream(
@@ -2276,7 +2287,7 @@ Status FastParseSequenceExample(
           Tensor(allocator, DT_INT64, dense_length_shape);
      int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data();

-      string* out_bytes = nullptr;
+      tstring* out_bytes = nullptr;
       float* out_float = nullptr;
       int64* out_int64 = nullptr;
       switch (dtype) {
@@ -2299,7 +2310,7 @@ Status FastParseSequenceExample(
       for (int e = 0; e < num_examples; e++) {
         size_t num_elements = 0, num_rows = 0;
         const auto feature_iter = all_sequence_features[e].find(c.feature_name);
-        const string& example_name =
+        const tstring& example_name =
             example_names.empty() ? kUnknown : example_names[e];
         if (feature_iter == all_sequence_features[e].end()) {
           // Return an error if this feature was not allowed to be missing.
@@ -2387,7 +2398,7 @@ Status FastParseSequenceExample(
       feature_list_result->sparse_shapes[t] =
           Tensor(allocator, DT_INT64, TensorShape({3}));

-      string* out_bytes = nullptr;
+      tstring* out_bytes = nullptr;
       float* out_float = nullptr;
       int64* out_int64 = nullptr;
       switch (dtype) {
@@ -2416,7 +2427,7 @@ Status FastParseSequenceExample(
       size_t max_num_cols = 0;
       for (int e = 0; e < num_examples; e++) {
         const auto& feature = all_sequence_features[e][c.feature_name];
-        const string& example_name =
+        const tstring& example_name =
             example_names.empty() ? kUnknown : example_names[e];
         if (!feature.empty()) {
           protobuf::io::CodedInputStream stream(
@@ -99,8 +99,8 @@ struct Result {
 // Given example names have to either be empty or the same size as serialized.
 // example_names are used only for error messages.
 Status FastParseExample(const FastParseExampleConfig& config,
-                        gtl::ArraySlice<string> serialized,
-                        gtl::ArraySlice<string> example_names,
+                        gtl::ArraySlice<tstring> serialized,
+                        gtl::ArraySlice<tstring> example_names,
                         thread::ThreadPool* thread_pool, Result* result);

 // TODO(mrry): Move the hash table construction into the config object.
@@ -116,7 +116,7 @@ Status FastParseSingleExample(const FastParseSingleExampleConfig& config,
 Status FastParseSequenceExample(
     const example::FastParseExampleConfig& context_config,
     const example::FastParseExampleConfig& feature_list_config,
-    gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
+    gtl::ArraySlice<tstring> serialized, gtl::ArraySlice<tstring> example_names,
     thread::ThreadPool* thread_pool, example::Result* context_result,
     example::Result* feature_list_result,
     std::vector<Tensor>* dense_feature_lengths);
@@ -273,7 +273,7 @@ static void AddSparseFeature(const char* feature_name, DataType dtype,

 TEST(FastParse, StatsCollection) {
   const size_t kNumExamples = 13;
-  std::vector<string> serialized(kNumExamples, ExampleWithSomeFeatures());
+  std::vector<tstring> serialized(kNumExamples, ExampleWithSomeFeatures());

   FastParseExampleConfig config_dense;
   AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &config_dense);
@@ -417,8 +417,9 @@ TEST(TestFastParseExample, Empty) {
   Result result;
   FastParseExampleConfig config;
   config.sparse.push_back({"test", DT_STRING});
-  Status status = FastParseExample(config, gtl::ArraySlice<string>(),
-                                   gtl::ArraySlice<string>(), nullptr, &result);
+  Status status =
+      FastParseExample(config, gtl::ArraySlice<tstring>(),
+                       gtl::ArraySlice<tstring>(), nullptr, &result);
   EXPECT_TRUE(status.ok()) << status;
 }

@@ -101,7 +101,7 @@ Status FeatureDenseCopy(const std::size_t out_index, const string& name,
             "Values size: ",
             values.value_size(), " but output shape: ", shape.DebugString());
       }
-      auto out_p = out->flat<string>().data() + offset;
+      auto out_p = out->flat<tstring>().data() + offset;
       std::transform(values.value().data(),
                      values.value().data() + num_elements, out_p,
                      [](const string* s) { return *s; });
@@ -136,7 +136,7 @@ Tensor FeatureSparseCopy(const std::size_t batch, const string& key,
       const BytesList& values = feature.bytes_list();
       const int64 num_elements = values.value_size();
       Tensor out(dtype, TensorShape({num_elements}));
-      auto out_p = out.flat<string>().data();
+      auto out_p = out.flat<tstring>().data();
       std::transform(values.value().data(),
                      values.value().data() + num_elements, out_p,
                      [](const string* s) { return *s; });
@@ -175,8 +175,8 @@ int64 CopyIntoSparseTensor(const Tensor& in, const int batch,
       break;
     }
     case DT_STRING: {
-      std::copy_n(in.flat<string>().data(), num_elements,
-                  values->flat<string>().data() + offset);
+      std::copy_n(in.flat<tstring>().data(), num_elements,
+                  values->flat<tstring>().data() + offset);
       break;
     }
     default:
@@ -203,8 +203,9 @@ void RowDenseCopy(const std::size_t& out_index, const DataType& dtype,
       break;
     }
     case DT_STRING: {
-      std::copy_n(in.flat<string>().data(), num_elements,
-                  out->flat<string>().data() + offset);
+      // TODO(dero): verify.
+      std::copy_n(in.flat<tstring>().data(), num_elements,
+                  out->flat<tstring>().data() + offset);
       break;
     }
     default:
@@ -337,10 +337,24 @@ inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) {
 // serialized proto.
 // May read all or part of a repeated field.
 inline Status ReadBytes(CodedInputStream* input, int index, void* datap) {
-  string* data = reinterpret_cast<string*>(datap) + index;
+  tstring* data = reinterpret_cast<tstring*>(datap) + index;
+
+#ifdef USE_TSTRING
+  uint32 length;
+  if (!input->ReadVarint32(&length)) {
+    return errors::DataLoss("Failed reading bytes");
+  }
+
+  data->resize_uninitialized(length);
+
+  if (!input->ReadRaw(data->data(), length)) {
+    return errors::DataLoss("Failed reading bytes");
+  }
+#else  // USE_TSTRING
   if (!WireFormatLite::ReadBytes(input, data)) {
     return errors::DataLoss("Failed reading bytes");
   }
+#endif  // USE_TSTRING
   return Status::OK();
 }

@@ -354,8 +368,19 @@ inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
   // TODO(nix): there is a faster way to grab TYPE_GROUP bytes by relying
   // on input->IsFlat() == true and using input->GetDirectBufferPointer()
   // with input->CurrentPosition().
-  string* data = reinterpret_cast<string*>(datap) + index;
+  tstring* data = reinterpret_cast<tstring*>(datap) + index;
+#ifdef USE_TSTRING
+  // TODO(dero): To mitigate the string to tstring copy, we can implement our
+  // own scanner as described above. We would first need to obtain the length
+  // in an initial pass and resize/reserve the tstring. But, given that
+  // TYPE_GROUP is deprecated and currently no tests in
+  // tensorflow/python/kernel_tests/proto:decode_proto_op_test target a
+  // TYPE_GROUP tag, we use std::string as a read buffer.
+  string buf;
+  StringOutputStream string_stream(&buf);
+#else  // USE_TSTRING
   StringOutputStream string_stream(data);
+#endif  // USE_TSTRING
   CodedOutputStream out(&string_stream);
   if (!WireFormatLite::SkipField(
           input,
@@ -364,6 +389,9 @@ inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
           &out)) {
     return errors::DataLoss("Failed reading group");
   }
+#ifdef USE_TSTRING
+  *data = buf;
+#endif  // USE_TSTRING
   return Status::OK();
 }

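Both ReadBytes paths above switch to the same low-level read: pull the varint length, size the tstring once with resize_uninitialized(), then copy the raw payload, instead of going through WireFormatLite::ReadBytes / ReadString. A rough standalone sketch of that decode step against protobuf's CodedInputStream follows; it assumes protobuf is available, and approximates resize_uninitialized() with resize() on a std::string stand-in.

    #include <cstdint>
    #include <string>

    #include "google/protobuf/io/coded_stream.h"

    // Stand-in; the real tensorflow::tstring adds resize_uninitialized().
    using tstring = std::string;

    // Reads one length-delimited value: a varint32 length, then `length` raw bytes.
    bool ReadDelimited(google::protobuf::io::CodedInputStream* input, tstring* out) {
      std::uint32_t length = 0;
      if (!input->ReadVarint32(&length)) return false;
      out->resize(length);  // real code: out->resize_uninitialized(length)
      return length == 0 ||
             input->ReadRaw(&(*out)[0], static_cast<int>(length));
    }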
@@ -179,29 +179,29 @@ inline void Fill(const Eigen::half* data, size_t n, TensorProto* t) {
 // Custom implementation for string.

 template <>
-struct SaveTypeTraits<string> {
+struct SaveTypeTraits<tstring> {
   static constexpr bool supported = true;
   typedef const string* SavedType;
   typedef protobuf::RepeatedPtrField<string> RepeatedField;
 };

 template <>
-inline const string* const* TensorProtoData<string>(const TensorProto& t) {
-  static_assert(SaveTypeTraits<string>::supported,
-                "Specified type string not supported for Restore");
+inline const string* const* TensorProtoData<tstring>(const TensorProto& t) {
+  static_assert(SaveTypeTraits<tstring>::supported,
+                "Specified type tstring not supported for Restore");
   return t.string_val().data();
 }

 template <>
-inline protobuf::RepeatedPtrField<string>* MutableTensorProtoData<string>(
+inline protobuf::RepeatedPtrField<string>* MutableTensorProtoData<tstring>(
     TensorProto* t) {
-  static_assert(SaveTypeTraits<string>::supported,
-                "Specified type string not supported for Save");
+  static_assert(SaveTypeTraits<tstring>::supported,
+                "Specified type tstring not supported for Save");
   return t->mutable_string_val();
 }

 template <>
-inline void Fill(const string* data, size_t n, TensorProto* t) {
+inline void Fill(const tstring* data, size_t n, TensorProto* t) {
   typename protobuf::RepeatedPtrField<string> copy(data, data + n);
   t->mutable_string_val()->Swap(&copy);
 }
@@ -102,7 +102,7 @@ Example of grouping:
     Tensor values(DT_STRING, TensorShape({N});
     TensorShape shape({dim0,...});
     SparseTensor sp(indices, vals, shape);
-    sp.Reorder<string>({1, 2, 0, 3, ...}); // Must provide NDIMS dims.
+    sp.Reorder<tstring>({1, 2, 0, 3, ...}); // Must provide NDIMS dims.
     // group according to dims 1 and 2
     for (const auto& g : sp.group({1, 2})) {
       cout << "vals of ix[:, 1,2] for this group: "
@@ -111,7 +111,7 @@ Example of grouping:
       cout << "values of group:\n" << g.values();

       TTypes<int64>::UnalignedMatrix g_ix = g.indices();
-      TTypes<string>::UnalignedVec g_v = g.values();
+      TTypes<tstring>::UnalignedVec g_v = g.values();
       ASSERT(g_ix.dimension(0) == g_v.size());  // number of elements match.
     }

@@ -133,7 +133,7 @@ Shape checking is performed, as is boundary checking.

     Tensor dense(DT_STRING, shape);
     // initialize other indices to zero. copy.
-    ASSERT(sp.ToDense<string>(&dense, true));
+    ASSERT(sp.ToDense<tstring>(&dense, true));


 Concat
@@ -215,7 +215,7 @@ Coding Example:
     EXPECT_EQ(conc.Order(), {-1, -1, -1});

     // Reorder st3 so all input tensors have the exact same orders.
-    st3.Reorder<string>({1, 0, 2});
+    st3.Reorder<tstring>({1, 0, 2});
     SparseTensor conc2 = SparseTensor::Concat<string>({st1, st2, st3});
     EXPECT_EQ(conc2.Order(), {1, 0, 2});
     // All indices' orders matched, so output is in order.
@@ -170,7 +170,7 @@ TEST(SparseTensorTest, SparseTensorConstruction) {
   int N = 5;
   const int NDIM = 3;
   auto ix_c = GetSimpleIndexTensor(N, NDIM);
-  Eigen::Tensor<string, 1, Eigen::RowMajor> vals_c(N);
+  Eigen::Tensor<tstring, 1, Eigen::RowMajor> vals_c(N);
   vals_c(0) = "hi0";
   vals_c(1) = "hi1";
   vals_c(2) = "hi2";
@@ -200,7 +200,7 @@ TEST(SparseTensorTest, SparseTensorConstruction) {

   // Regardless of how order is updated; so long as there are no
   // duplicates, the resulting indices are valid.
-  st.Reorder<string>({2, 0, 1});
+  st.Reorder<tstring>({2, 0, 1});
   TF_EXPECT_OK(st.IndicesValid());
   EXPECT_EQ(vals_t(0), "hi0");
   EXPECT_EQ(vals_t(1), "hi3");
@@ -210,7 +210,7 @@ TEST(SparseTensorTest, SparseTensorConstruction) {

   ix_t = ix_c;
   vals_t = vals_c;
-  st.Reorder<string>({0, 1, 2});
+  st.Reorder<tstring>({0, 1, 2});
   TF_EXPECT_OK(st.IndicesValid());
   EXPECT_EQ(vals_t(0), "hi0");
   EXPECT_EQ(vals_t(1), "hi4");
@@ -220,7 +220,7 @@ TEST(SparseTensorTest, SparseTensorConstruction) {

   ix_t = ix_c;
   vals_t = vals_c;
-  st.Reorder<string>({2, 1, 0});
+  st.Reorder<tstring>({2, 1, 0});
   TF_EXPECT_OK(st.IndicesValid());
 }

@@ -239,7 +239,7 @@ TEST(SparseTensorTest, EmptySparseTensorAllowed) {
   EXPECT_EQ(st.order(), order);

   std::vector<int64> new_order{1, 0, 2};
-  st.Reorder<string>(new_order);
+  st.Reorder<tstring>(new_order);
   TF_EXPECT_OK(st.IndicesValid());
   EXPECT_EQ(st.order(), new_order);
 }
@@ -259,13 +259,13 @@ TEST(SparseTensorTest, SortingWorksCorrectly) {
   for (int n = 0; n < 100; ++n) {
     ix_t = ix_t.random(Eigen::internal::UniformRandomGenerator<int64>(n + 1));
     ix_t = ix_t.abs() % 1000;
-    st.Reorder<string>({0, 1, 2, 3});
+    st.Reorder<tstring>({0, 1, 2, 3});
     TF_EXPECT_OK(st.IndicesValid());
-    st.Reorder<string>({3, 2, 1, 0});
+    st.Reorder<tstring>({3, 2, 1, 0});
     TF_EXPECT_OK(st.IndicesValid());
-    st.Reorder<string>({1, 0, 2, 3});
+    st.Reorder<tstring>({1, 0, 2, 3});
     TF_EXPECT_OK(st.IndicesValid());
-    st.Reorder<string>({3, 0, 2, 1});
+    st.Reorder<tstring>({3, 0, 2, 1});
     TF_EXPECT_OK(st.IndicesValid());
   }
 }
@@ -294,7 +294,7 @@ TEST(SparseTensorTest, ValidateIndicesFindsInvalid) {
   SparseTensor st;
   TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));

-  st.Reorder<string>(order);
+  st.Reorder<tstring>(order);
   Status st_indices_valid = st.IndicesValid();
   EXPECT_FALSE(st_indices_valid.ok());
   EXPECT_EQ("indices[1] = [0,0,0] is repeated",
@@ -302,12 +302,12 @@ TEST(SparseTensorTest, ValidateIndicesFindsInvalid) {

   ix_orig(1, 2) = 1;
   ix_t = ix_orig;
-  st.Reorder<string>(order);
+  st.Reorder<tstring>(order);
   TF_EXPECT_OK(st.IndicesValid());  // second index now (0, 0, 1)

   ix_orig(0, 2) = 1;
   ix_t = ix_orig;
-  st.Reorder<string>(order);
+  st.Reorder<tstring>(order);
   st_indices_valid = st.IndicesValid();
   EXPECT_FALSE(st_indices_valid.ok());  // first index now (0, 0, 1)
   EXPECT_EQ("indices[1] = [0,0,1] is repeated",
@@ -332,12 +332,12 @@ TEST(SparseTensorTest, SparseTensorCheckBoundaries) {
   TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
   EXPECT_FALSE(st.IndicesValid().ok());

-  st.Reorder<string>(order);
+  st.Reorder<tstring>(order);
   TF_EXPECT_OK(st.IndicesValid());

   ix_t(0, 0) = 11;
   ix.matrix<int64>() = ix_t;
-  st.Reorder<string>(order);
+  st.Reorder<tstring>(order);
   Status st_indices_valid = st.IndicesValid();
   EXPECT_FALSE(st_indices_valid.ok());
   // Error message references index 4 because of the call to Reorder.
@@ -346,7 +346,7 @@ TEST(SparseTensorTest, SparseTensorCheckBoundaries) {

   ix_t(0, 0) = -1;
   ix.matrix<int64>() = ix_t;
-  st.Reorder<string>(order);
+  st.Reorder<tstring>(order);
   st_indices_valid = st.IndicesValid();
   EXPECT_FALSE(st_indices_valid.ok());
   EXPECT_EQ("[-1,0,0] is out of bounds: need 0 <= index < [10,10,10]",
@@ -354,7 +354,7 @@ TEST(SparseTensorTest, SparseTensorCheckBoundaries) {

   ix_t(0, 0) = 0;
   ix.matrix<int64>() = ix_t;
-  st.Reorder<string>(order);
+  st.Reorder<tstring>(order);
   TF_EXPECT_OK(st.IndicesValid());
 }

@@ -382,9 +382,9 @@ TEST(SparseTensorTest, SparseTensorToDenseTensor) {
   TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));

   Tensor dense(DT_STRING, TensorShape({4, 4, 5}));
-  st.ToDense<string>(&dense);
+  st.ToDense<tstring>(&dense);

-  auto dense_t = dense.tensor<string, 3>();
+  auto dense_t = dense.tensor<tstring, 3>();
   Eigen::array<Eigen::DenseIndex, NDIM> ix_n;
   for (int n = 0; n < N; ++n) {
     for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d);
@@ -422,9 +422,9 @@ TEST(SparseTensorTest, SparseTensorToLargerDenseTensor) {
   TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));

   Tensor dense(DT_STRING, TensorShape({10, 10, 10}));
-  st.ToDense<string>(&dense);
+  st.ToDense<tstring>(&dense);

-  auto dense_t = dense.tensor<string, 3>();
+  auto dense_t = dense.tensor<tstring, 3>();
   Eigen::array<Eigen::DenseIndex, NDIM> ix_n;
   for (int n = 0; n < N; ++n) {
     for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d);
@@ -554,10 +554,10 @@ TEST(SparseTensorTest, Concat) {
   SparseTensor st;
   TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
   EXPECT_FALSE(st.IndicesValid().ok());
-  st.Reorder<string>(order);
+  st.Reorder<tstring>(order);
   TF_EXPECT_OK(st.IndicesValid());

-  SparseTensor concatted = SparseTensor::Concat<string>({st, st, st, st});
+  SparseTensor concatted = SparseTensor::Concat<tstring>({st, st, st, st});
   EXPECT_EQ(concatted.order(), st.order());
   gtl::InlinedVector<int64, 8> expected_shape{40, 10, 10};
   EXPECT_EQ(concatted.shape(), expected_shape);
@@ -585,7 +585,7 @@ TEST(SparseTensorTest, Concat) {
   SparseTensor st_ooo;
   TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, {0, 2, 1},
                                     &st_ooo));  // non-primary ix OOO
-  SparseTensor conc_ooo = SparseTensor::Concat<string>({st, st, st, st_ooo});
+  SparseTensor conc_ooo = SparseTensor::Concat<tstring>({st, st, st, st_ooo});
   std::vector<int64> expected_ooo{-1, -1, -1};
   EXPECT_EQ(conc_ooo.order(), expected_ooo);
   EXPECT_EQ(conc_ooo.shape(), expected_shape);
@@ -782,7 +782,7 @@ static void BM_SparseReorderString(int iters, int N32, int NDIM32) {
     TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));

     testing::StartTiming();
-    st.Reorder<string>(reorder);
+    st.Reorder<tstring>(reorder);
   }
 }

@@ -69,7 +69,7 @@ namespace {
 // Checksums the string lengths (as restored uint32 or uint64, not varint64
 // bytes) and string bytes, and stores it into "actual_crc32c".
 Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
-                        size_t offset, size_t size, string* destination,
+                        size_t offset, size_t size, tstring* destination,
                         uint32* actual_crc32c, bool need_to_swap_bytes) {
   if (size == 0) return Status::OK();
   CHECK_GT(size, 0);
@@ -127,7 +127,7 @@ Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
   // Reads the actual string bytes.
   for (size_t i = 0; i < num_elements; ++i) {
     const uint64 string_length = string_lengths[i];
-    string* buffer = &destination[i];
+    tstring* buffer = &destination[i];

     buffer->resize(string_length);
     size_t bytes_read = 0;
@@ -205,9 +205,9 @@ char* GetBackingBuffer(const Tensor& val) {
   return const_cast<char*>(val.tensor_data().data());
 }

-string* GetStringBackingBuffer(const Tensor& val) {
+tstring* GetStringBackingBuffer(const Tensor& val) {
   CHECK_EQ(DT_STRING, val.dtype());
-  return const_cast<string*>(val.flat<string>().data());
+  return const_cast<tstring*>(val.flat<tstring>().data());
 }

 Status ParseEntryProto(StringPiece key, StringPiece value,
@@ -244,14 +244,14 @@ Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
   // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
   // the length-checksum, and all the string bytes.
   DCHECK_EQ(val.dtype(), DT_STRING);
-  const string* strings = GetStringBackingBuffer(val);
+  const tstring* strings = GetStringBackingBuffer(val);

   // Writes the varint lengths.
   string lengths;
   lengths.reserve(val.NumElements());  // At least 1 byte per element.
   *crc32c = 0;
   for (int64 i = 0; i < val.NumElements(); ++i) {
-    const string* elem = &strings[i];
+    const tstring* elem = &strings[i];
     DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
     const uint64 elem_size = static_cast<uint64>(elem->size());

@@ -281,7 +281,7 @@ Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,

   // Writes all the string bytes out.
   for (int64 i = 0; i < val.NumElements(); ++i) {
-    const string* string = &strings[i];
+    const tstring* string = &strings[i];
     TF_RETURN_IF_ERROR(out->Append(*string));
     *bytes_written += string->size();
     *crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
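For readers new to this writer: the comment in the hunk above describes the checksum layout, with the running crc32c covering every element's length first and then all of the raw string bytes. The sketch below shows that two-phase accumulation in a self-contained form; it is illustrative only, and it substitutes a plain bitwise CRC-32 for the CRC32C routine TensorFlow actually uses.

    #include <cstddef>
    #include <cstdint>
    #include <string>
    #include <vector>

    using tstring = std::string;  // stand-in for tensorflow::tstring

    // Plain CRC-32 (zlib polynomial); only a stand-in for crc32c::Extend here.
    std::uint32_t Extend(std::uint32_t crc, const char* data, std::size_t n) {
      crc = ~crc;
      for (std::size_t i = 0; i < n; ++i) {
        crc ^= static_cast<unsigned char>(data[i]);
        for (int k = 0; k < 8; ++k)
          crc = (crc >> 1) ^ (0xEDB88320u & (0u - (crc & 1u)));
      }
      return ~crc;
    }

    // Mirrors the write path's order: checksum all lengths (as fixed-width
    // integers) first, then all element bytes, into one running value.
    std::uint32_t ChecksumStrings(const std::vector<tstring>& elems) {
      std::uint32_t crc = 0;
      for (const tstring& elem : elems) {
        const std::uint64_t len = elem.size();
        crc = Extend(crc, reinterpret_cast<const char*>(&len), sizeof(len));
      }
      for (const tstring& elem : elems) {
        crc = Extend(crc, elem.data(), elem.size());
      }
      return crc;
    }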
@@ -675,7 +675,7 @@ static Status MergeOneBundle(Env* env, StringPiece prefix,
   return Status::OK();
 }

-Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
+Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
                     StringPiece merged_prefix) {
   // Merges all metadata tables.
   // TODO(zhifengc): KeyValue sorter if it becomes too big.
@@ -823,10 +823,10 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
     // Relaxes the check for string tensors as follows:
     // entry.size() == bytes(varint lengths) + bytes(data)
     // >= NumElems + bytes(data), since size bytes(varint) >= 1.
-    // TotalBytes() == sizeof(string) * NumElems + bytes(data)
+    // TotalBytes() == sizeof(tstring) * NumElems + bytes(data)
     // Since we don't know bytes(varint lengths), we just check an inequality.
     const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
-                               sizeof(string) * ret->NumElements();
+                               sizeof(tstring) * ret->NumElements();
     if (entry.size() < lower_bound) {
       return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                               "; stored size ", entry.size(),
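The comment in the GetValue hunk above leans on a small piece of arithmetic: each stored element contributes at least one varint length byte plus its data, while TotalBytes() for a string tensor counts sizeof(tstring) per element plus the data, so adding the element count and subtracting the per-element struct size gives a safe lower bound on the stored size. A worked numeric sketch (all values are made up for illustration):

    #include <cstddef>
    #include <iostream>

    int main() {
      // Hypothetical string tensor: 3 elements holding 5, 0, and 11 data bytes.
      const std::size_t num_elements = 3;
      const std::size_t data_bytes = 5 + 0 + 11;  // 16
      const std::size_t sizeof_tstring = 24;      // assumed per-element size

      // What TotalBytes() would report: per-element struct + payload bytes.
      const std::size_t total_bytes = sizeof_tstring * num_elements + data_bytes;

      // entry.size() >= num_elements + data_bytes, since every varint length
      // occupies at least one byte; rearranged exactly as in the hunk above.
      const std::size_t lower_bound =
          num_elements + total_bytes - sizeof_tstring * num_elements;

      std::cout << "lower_bound = " << lower_bound << "\n";  // prints 19
      return 0;
    }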
@@ -172,7 +172,7 @@ class BundleWriter {
 //
 // Once merged, makes a best effort to delete the old metadata files.
 // Returns OK iff all bundles are successfully merged.
-Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
+Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
                     StringPiece merged_prefix);

 // On construction, silently attempts to read the metadata associated with
@@ -710,11 +710,12 @@ TEST(TensorBundleTest, StringTensorsOldFormat) {
   EXPECT_EQ(AllTensorKeys(&reader),
             std::vector<string>({"floats", "scalar", "string_tensor", "strs"}));

-  Expect<string>(&reader, "string_tensor", Tensor(DT_STRING, TensorShape({1})));
-  Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"}));
-  Expect<string>(
+  Expect<tstring>(&reader, "string_tensor",
+                  Tensor(DT_STRING, TensorShape({1})));
+  Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));
+  Expect<tstring>(
       &reader, "strs",
-      test::AsTensor<string>({"hello", "", "x01", string(1 << 10, 'c')}));
+      test::AsTensor<tstring>({"hello", "", "x01", string(1 << 10, 'c')}));
   Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
 }

@@ -726,14 +727,19 @@ TEST(TensorBundleTest, StringTensors) {
   BundleWriter writer(Env::Default(), Prefix("foo"));
   TF_EXPECT_OK(writer.Add("string_tensor",
                           Tensor(DT_STRING, TensorShape({1}))));  // Empty.
-  TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<string>({"hello"})));
+  TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<tstring>({"hello"})));
   TF_EXPECT_OK(writer.Add(
       "strs",
-      test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')})));
+      test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')})));

   // Requires a 64-bit length.
-  string* backing_string = long_string_tensor.flat<string>().data();
+  tstring* backing_string = long_string_tensor.flat<tstring>().data();
+#ifdef USE_TSTRING
+  backing_string->resize_uninitialized(kLongLength);
+  std::char_traits<char>::assign(backing_string->data(), kLongLength, 'd');
+#else  // USE_TSTRING
   backing_string->assign(kLongLength, 'd');
+#endif  // USE_TSTRING
   TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));

   // Mixes in some floats.
@@ -747,12 +753,12 @@ TEST(TensorBundleTest, StringTensors) {
             std::vector<string>({"floats", "long_scalar", "scalar",
                                  "string_tensor", "strs"}));

-  Expect<string>(&reader, "string_tensor",
-                 Tensor(DT_STRING, TensorShape({1})));
-  Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"}));
-  Expect<string>(
+  Expect<tstring>(&reader, "string_tensor",
+                  Tensor(DT_STRING, TensorShape({1})));
+  Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));
+  Expect<tstring>(
       &reader, "strs",
-      test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')}));
+      test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')}));

   Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));

@@ -767,17 +773,17 @@ TEST(TensorBundleTest, StringTensors) {
   EXPECT_EQ(TensorShape({1}), shape);

   // Zero-out the string so that we can be sure the new one is read in.
-  string* backing_string = long_string_tensor.flat<string>().data();
+  tstring* backing_string = long_string_tensor.flat<tstring>().data();
   backing_string->assign("");

   // Read long_scalar and check it contains kLongLength 'd's.
   TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
-  ASSERT_EQ(backing_string, long_string_tensor.flat<string>().data());
+  ASSERT_EQ(backing_string, long_string_tensor.flat<tstring>().data());
   EXPECT_EQ(kLongLength, backing_string->length());
-  for (char c : *backing_string) {
+  for (size_t i = 0; i < kLongLength; i++) {
     // Not using ASSERT_EQ('d', c) because this way is twice as fast due to
     // compiler optimizations.
-    if (c != 'd') {
+    if ((*backing_string)[i] != 'd') {
       FAIL() << "long_scalar is not full of 'd's as expected.";
       break;
     }
@@ -945,7 +951,7 @@ TEST(TensorBundleTest, Checksum) {
   auto WriteStrings = []() {
     BundleWriter writer(Env::Default(), Prefix("strings"));
     TF_EXPECT_OK(
-        writer.Add("foo", test::AsTensor<string>({"hello", "world"})));
+        writer.Add("foo", test::AsTensor<tstring>({"hello", "world"})));
     TF_ASSERT_OK(writer.Finish());
   };
   // Corrupts the first two bytes, which are the varint32-encoded lengths
@@ -55,7 +55,7 @@ struct CopyThatWorksWithStringPointer {
 // Eigen makes it extremely difficult to dereference a tensor of string* into
 // string, so we roll our own loop instead.
 template <>
-struct CopyThatWorksWithStringPointer<string> {
+struct CopyThatWorksWithStringPointer<tstring> {
   template <typename SrcTensor, typename DstTensor, typename Shape>
   static void Copy(const SrcTensor& s, Shape s_start, Shape len, DstTensor& d,
                    Shape d_start) {
@@ -176,7 +176,7 @@ size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) {
 }

 template <>
-Status TensorSliceWriter::SaveData(const string* data, int64 num_elements,
+Status TensorSliceWriter::SaveData(const tstring* data, int64 num_elements,
                                    SavedSlice* ss) {
   size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes +
                       (num_elements * MaxBytesPerElement(DT_INT32));
@@ -178,7 +178,7 @@ Status TensorSliceWriter::SaveData(const T* data, int64 num_elements,
 }

 template <>
-Status TensorSliceWriter::SaveData(const string* data, int64 num_elements,
+Status TensorSliceWriter::SaveData(const tstring* data, int64 num_elements,
                                    SavedSlice* ss);

 // Create a table builder that will write to "filename" in
@@ -342,7 +342,7 @@ TEST(TensorSliceWriteTest, SizeErrors) {
   {
     TensorShape shape({256, 1024});
     TensorSlice slice = TensorSlice::ParseOrDie("-:-");
-    const std::vector<string> data(256 * 1024, std::string(8192, 'f'));
+    const std::vector<tstring> data(256 * 1024, std::string(8192, 'f'));
     Status s = writer.Add("test2", shape, slice, data.data());
     EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
     EXPECT_TRUE(absl::StrContains(s.error_message(),