Update kernels and related libs to use tstring.

This is a part of a larger migration effort for tensorflow::tstring.
See: https://github.com/tensorflow/community/pull/91
PiperOrigin-RevId: 265811129
Author: Dero Gharibian, 2019-08-27 18:15:54 -07:00; committed by TensorFlower Gardener
parent 63ba081d07
commit f742e74da3
36 changed files with 303 additions and 154 deletions

View File

@ -124,7 +124,7 @@ class EncodeJpegOp : public OpKernel {
context->allocate_output(0, TensorShape({}), &output));
OP_REQUIRES(context,
jpeg::Compress(image.flat<uint8>().data(), dim_size1, dim_size0,
adjusted_flags, &output->scalar<string>()()),
adjusted_flags, &output->scalar<tstring>()()),
errors::Internal("JPEG encoding failed"));
}
@ -190,7 +190,7 @@ class EncodeJpegVariableQualityOp : public OpKernel {
context->allocate_output(0, TensorShape({}), &output));
OP_REQUIRES(context,
jpeg::Compress(image.flat<uint8>().data(), dim_size1, dim_size0,
adjusted_flags, &output->scalar<string>()()),
adjusted_flags, &output->scalar<tstring>()()),
errors::Internal("JPEG encoding failed"));
}
};

View File

@ -78,17 +78,17 @@ class EncodePngOp : public OpKernel {
context->allocate_output(0, TensorShape({}), &output));
if (desired_channel_bits_ == 8) {
OP_REQUIRES(context,
png::WriteImageToBuffer(image.flat<uint8>().data(), width,
height, width * channels, channels,
desired_channel_bits_, compression_,
&output->scalar<string>()(), nullptr),
png::WriteImageToBuffer(
image.flat<uint8>().data(), width, height,
width * channels, channels, desired_channel_bits_,
compression_, &output->scalar<tstring>()(), nullptr),
errors::Internal("PNG encoding failed"));
} else {
OP_REQUIRES(context,
png::WriteImageToBuffer(
image.flat<uint16>().data(), width, height,
width * channels * 2, channels, desired_channel_bits_,
compression_, &output->scalar<string>()(), nullptr),
compression_, &output->scalar<tstring>()(), nullptr),
errors::Internal("PNG encoding failed"));
}
}

View File

@ -298,6 +298,26 @@ Status WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input,
return Status::OK();
}
// Adapter giving WireFormatLite::WriteString a signature that accepts a
// tstring value, so it can be used as the serializer callback for
// TYPE_STRING fields.
static void WriteStringAdapter(int field_number, const tstring& value,
CodedOutputStream* output) {
// Unfortunately, external proto does not accept string_view.
#if defined(PLATFORM_GOOGLE)
WireFormatLite::WriteString(field_number, StringPiece(value), output);
#else
// NOTE(review): this branch copies `value` into a temporary std::string
// because the external protobuf API has no string_view overload.
WireFormatLite::WriteString(field_number, string(value), output);
#endif
}
// Adapter giving WireFormatLite::WriteBytes a signature that accepts a
// tstring value, so it can be used as the serializer callback for
// TYPE_BYTES (and length-delimited TYPE_MESSAGE) fields.
static void WriteBytesAdapter(int field_number, const tstring& value,
CodedOutputStream* output) {
// Unfortunately, external proto does not accept string_view.
#if defined(PLATFORM_GOOGLE)
WireFormatLite::WriteBytes(field_number, StringPiece(value), output);
#else
// NOTE(review): this branch copies `value` into a temporary std::string
// because the external protobuf API has no string_view overload.
WireFormatLite::WriteBytes(field_number, string(value), output);
#endif
}
// Writes a group field. Groups are treated like submessages, but tag-delimited
// instead of length-delimited. WireFormatLite handles this differently so we
// code it ourselves.
@ -388,15 +408,15 @@ Status WriteField(const FieldDescriptor& field_desc, const Tensor& input,
WireFormatLite::WriteBoolNoTag>(
field_desc, input, message_index, size, output);
case WireFormatLite::TYPE_STRING:
return WriteVarLenField<string, WireFormatLite::WriteString>(
return WriteVarLenField<tstring, WriteStringAdapter>(
field_desc, input, message_index, size, output);
case WireFormatLite::TYPE_GROUP:
return WriteGroup(field_desc, input, message_index, size, output);
case WireFormatLite::TYPE_MESSAGE:
return WriteVarLenField<string, WireFormatLite::WriteBytes>(
return WriteVarLenField<tstring, WriteBytesAdapter>(
field_desc, input, message_index, size, output);
case WireFormatLite::TYPE_BYTES:
return WriteVarLenField<string, WireFormatLite::WriteBytes>(
return WriteVarLenField<tstring, WriteBytesAdapter>(
field_desc, input, message_index, size, output);
case WireFormatLite::TYPE_UINT32:
switch (dtype) {
@ -592,7 +612,7 @@ class EncodeProtoOp : public OpKernel {
message_index++) {
// TODO(nix): possibly optimize allocation here by calling
// `bufs(message_index).reserve(DEFAULT_BUF_SIZE)`.
StringOutputStream output_string(&bufs(message_index));
TStringOutputStream output_string(&bufs(message_index));
CodedOutputStream out(&output_string);
// Write fields in ascending field_number order.
for (int i : sorted_field_index_) {

View File

@ -58,7 +58,7 @@ class EncodeWavOp : public OpKernel {
OP_REQUIRES_OK(context,
wav::EncodeAudioAsS16LEWav(
audio.flat<float>().data(), sample_rate, channel_count,
sample_count, &output->scalar<string>()()));
sample_count, &output->scalar<tstring>()()));
}
};
REGISTER_KERNEL_BUILDER(Name("EncodeWav").Device(DEVICE_CPU), EncodeWavOp);

View File

@ -132,10 +132,10 @@ class ParseExampleOp : public OpKernel {
config.sparse.push_back({sparse_keys_t[d], attrs_.sparse_types[d]});
}
auto serialized_t = serialized->flat<string>();
auto names_t = names->flat<string>();
gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size());
gtl::ArraySlice<string> names_slice(names_t.data(), names_t.size());
auto serialized_t = serialized->flat<tstring>();
auto names_t = names->flat<tstring>();
gtl::ArraySlice<tstring> slice(serialized_t.data(), serialized_t.size());
gtl::ArraySlice<tstring> names_slice(names_t.data(), names_t.size());
OP_REQUIRES_OK(
ctx,
@ -352,11 +352,11 @@ class ParseSequenceExampleOp : public OpKernel {
attrs_.feature_list_sparse_types[d]});
}
auto serialized_t = serialized->flat<string>();
auto debug_name_t = debug_name->flat<string>();
gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size());
gtl::ArraySlice<string> names_slice(debug_name_t.data(),
debug_name_t.size());
auto serialized_t = serialized->flat<tstring>();
auto debug_name_t = debug_name->flat<tstring>();
gtl::ArraySlice<tstring> slice(serialized_t.data(), serialized_t.size());
gtl::ArraySlice<tstring> names_slice(debug_name_t.data(),
debug_name_t.size());
OP_REQUIRES_OK(
ctx,
@ -853,10 +853,12 @@ class DecodeJSONExampleOp : public OpKernel {
&binary_examples));
for (int i = 0; i < json_examples->NumElements(); ++i) {
const string& json_example = json_examples->flat<string>()(i);
auto status = protobuf::util::JsonToBinaryString(
resolver_.get(), "type.googleapis.com/tensorflow.Example",
json_example, &binary_examples->flat<string>()(i));
const tstring& json_example = json_examples->flat<tstring>()(i);
protobuf::io::ArrayInputStream in(json_example.data(),
json_example.size());
TStringOutputStream out(&binary_examples->flat<tstring>()(i));
auto status = protobuf::util::JsonToBinaryStream(
resolver_.get(), "type.googleapis.com/tensorflow.Example", &in, &out);
OP_REQUIRES(ctx, status.ok(),
errors::InvalidArgument("Error while parsing JSON: ",
string(status.error_message())));

View File

@ -124,7 +124,7 @@ struct ExampleStore {
Features* features = example.mutable_features();
(*features->mutable_feature())[k_str] = f;
}
CHECK(example.SerializeToString(&string_t(b)));
CHECK(SerializeToTString(example, &string_t(b)));
}
(*examples)[std::make_tuple(batch_size, num_keys, feature_size)] =
record_string;

View File

@ -52,7 +52,7 @@ void FarmhashFingerprint64(TTypes<uint8, 2>::ConstTensor input,
}
}
void FarmhashFingerprint64(TTypes<string>::ConstFlat input,
void FarmhashFingerprint64(TTypes<tstring>::ConstFlat input,
TTypes<uint8, 2>::Matrix output) {
DCHECK_EQ(output.dimension(0), input.dimension(0));
DCHECK_EQ(output.dimension(1), sizeof(uint64));
@ -79,7 +79,7 @@ class FingerprintOp : public OpKernel {
errors::InvalidArgument("`method` should be a scalar string: ",
method_tensor.shape()));
// For now, farmhash64 is the only function supported.
const string& method = method_tensor.scalar<string>()();
const tstring& method = method_tensor.scalar<tstring>()();
OP_REQUIRES(
context, method == "farmhash64",
errors::InvalidArgument("Unsupported fingerprint method: ", method));

View File

@ -82,10 +82,10 @@ TEST_F(FingerprintOpTest, StringGoldenValue) {
buffer(1).resize(7);
buffer(2).resize(0);
buffer(3).resize(19);
std::iota(buffer(0).begin(), buffer(0).end(), 0);
std::iota(buffer(1).begin(), buffer(1).end(), 7);
std::iota(buffer(2).begin(), buffer(2).end(), 71);
std::iota(buffer(3).begin(), buffer(3).end(), 41);
std::iota(&buffer(0)[0], &buffer(0)[0] + buffer(0).size(), 0);
std::iota(&buffer(1)[0], &buffer(1)[0] + buffer(1).size(), 7);
std::iota(&buffer(2)[0], &buffer(2)[0] + buffer(2).size(), 71);
std::iota(&buffer(3)[0], &buffer(3)[0] + buffer(3).size(), 41);
TF_ASSERT_OK(MakeFingerprintOp(&data));
TF_ASSERT_OK(RunOpKernel());
@ -137,7 +137,7 @@ TEST_F(FingerprintOpTest, CollisionString) {
auto& input = tensor.vec<tstring>()(0);
input.resize(size);
TTypes<uint8>::UnalignedFlat buffer(reinterpret_cast<uint8*>(&*input.begin()),
TTypes<uint8>::UnalignedFlat buffer(reinterpret_cast<uint8*>(&input[0]),
input.size());
buffer.setRandom();

View File

@ -134,7 +134,7 @@ T SubtleMustCopyIfIntegral(const T& value) {
return internal::SubtleMustCopy(value);
}
inline const string& SubtleMustCopyIfIntegral(const string& value) {
// Overload for tstring: strings are not integral types, so no defensive
// copy is required — the reference is passed through unchanged.
inline const tstring& SubtleMustCopyIfIntegral(const tstring& value) {
return value;
}

View File

@ -55,7 +55,7 @@ class RecordInputOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
Tensor out(DT_STRING, {batch_size_});
auto t_out = out.flat<string>();
auto t_out = out.flat<tstring>();
for (int i = 0; i < batch_size_; ++i) {
OP_REQUIRES_OK(ctx, yielder_->YieldOne(&t_out(i)));
}

View File

@ -44,7 +44,7 @@ RecordYielder::~RecordYielder() {
delete thread_;
}
Status RecordYielder::YieldOne(string* value) {
Status RecordYielder::YieldOne(tstring* value) {
mutex_lock l(mu_);
while (!BufEnough() && status_.ok()) {
buf_enough_.wait(l);

View File

@ -90,7 +90,7 @@ class RecordYielder {
RecordYielder& operator=(const RecordYielder&) = delete;
// Yields one 'value'.
Status YieldOne(string* value);
Status YieldOne(tstring* value);
// Returns the current epoch number.
int64 current_epoch() const { return epoch_; }

View File

@ -48,11 +48,15 @@ Status InternalCompute(const RE2& match, const string& rewrite,
}
auto output_flat = output_tensor->flat<tstring>();
for (size_t i = 0; i < output_flat.size(); ++i) {
// TODO(dero): Mitigate copy; Global and GlobalReplace below currently only
// accept std::string.
string buf = output_flat(i);
if (replace_global) {
RE2::GlobalReplace(&output_flat(i), match, rewrite);
RE2::GlobalReplace(&buf, match, rewrite);
} else {
RE2::Replace(&output_flat(i), match, rewrite);
RE2::Replace(&buf, match, rewrite);
}
output_flat(i) = std::move(buf);
}
return Status::OK();
}

View File

@ -71,9 +71,9 @@ Graph* SetupRegexReplaceGraph(const Tensor& input, const string& input_pattern,
const string& input_rewrite) {
Graph* g = new Graph(OpRegistry::Global());
Tensor pattern(DT_STRING, TensorShape({}));
pattern.flat<string>().setConstant(input_pattern);
pattern.flat<tstring>().setConstant(input_pattern);
Tensor rewrite(DT_STRING, TensorShape({}));
rewrite.flat<string>().setConstant(input_rewrite);
rewrite.flat<tstring>().setConstant(input_rewrite);
TF_CHECK_OK(NodeBuilder("regex_replace_op", "RegexReplace")
.Input(test::graph::Constant(g, input))

View File

@ -128,7 +128,7 @@ class StringNGramsOp : public tensorflow::OpKernel {
}
}
void CreateNgrams(const string* data, string* output, int num_ngrams,
void CreateNgrams(const tstring* data, tstring* output, int num_ngrams,
int ngram_width) const {
for (int ngram_index = 0; ngram_index < num_ngrams; ++ngram_index) {
int pad_width = get_pad_width(ngram_width);
@ -154,20 +154,20 @@ class StringNGramsOp : public tensorflow::OpKernel {
ngram_size += num_separators * separator_.length();
// Build the ngram.
string* ngram = &output[ngram_index];
tstring* ngram = &output[ngram_index];
ngram->reserve(ngram_size);
for (int n = 0; n < left_padding; ++n) {
*ngram += left_pad_;
*ngram += separator_;
ngram->append(left_pad_);
ngram->append(separator_);
}
for (int n = 0; n < num_tokens - 1; ++n) {
*ngram += data[data_start_index + n];
*ngram += separator_;
ngram->append(data[data_start_index + n]);
ngram->append(separator_);
}
*ngram += data[data_start_index + num_tokens - 1];
ngram->append(data[data_start_index + num_tokens - 1]);
for (int n = 0; n < right_padding; ++n) {
*ngram += separator_;
*ngram += right_pad_;
ngram->append(separator_);
ngram->append(right_pad_);
}
// In debug mode only: validate that we've reserved enough space for the

View File

@ -51,12 +51,12 @@ class NgramKernelTest : public tensorflow::OpsTestBase {
TF_ASSERT_OK(InitOp());
}
void assert_string_equal(const std::vector<string> &expected,
void assert_string_equal(const std::vector<tstring> &expected,
const Tensor &value) {
Tensor expected_tensor(allocator(), DT_STRING,
TensorShape({static_cast<int64>(expected.size())}));
test::FillValues<string>(&expected_tensor, expected);
test::ExpectTensorEqual<string>(expected_tensor, value);
test::FillValues<tstring>(&expected_tensor, expected);
test::ExpectTensorEqual<tstring>(expected_tensor, value);
}
void assert_int64_equal(const std::vector<int64> &expected,
const Tensor &value) {
@ -72,11 +72,11 @@ TEST_F(NgramKernelTest, TestPaddedTrigrams) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values( //
std::vector<tstring> expected_values( //
{"LP|LP|a", "LP|a|b", "a|b|c", "b|c|d", "c|d|RP", "d|RP|RP", // 0
"LP|LP|e", "LP|e|f", "e|f|RP", "f|RP|RP"}); // 1
std::vector<int64> expected_splits({0, 6, 10});
@ -90,11 +90,11 @@ TEST_F(NgramKernelTest, TestPaddedBigramsAndTrigrams) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values(
std::vector<tstring> expected_values(
{"LP|a", "a|b", "b|c", "c|d", "d|RP", "LP|LP|a", "LP|a|b", "a|b|c",
"b|c|d", "c|d|RP", "d|RP|RP", // 0
"LP|e", "e|f", "f|RP", "LP|LP|e", "LP|e|f", "e|f|RP", "f|RP|RP"}); // 1
@ -109,11 +109,11 @@ TEST_F(NgramKernelTest, TestPaddedBigrams) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values( //
std::vector<tstring> expected_values( //
{"LP|a", "a|b", "b|c", "c|d", "d|RP", // 0
"LP|e", "e|f", "f|RP"}); // 1
std::vector<int64> expected_splits({0, 5, 8});
@ -127,11 +127,11 @@ TEST_F(NgramKernelTest, TestPaddingIsAtMostNGramSizeMinus1) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values( //
std::vector<tstring> expected_values( //
{"LP|a", "a|b", "b|c", "c|d", "d|RP", // 0
"LP|e", "e|f", "f|RP"}); // 1
std::vector<int64> expected_splits({0, 5, 8});
@ -145,11 +145,11 @@ TEST_F(NgramKernelTest, TestPaddedUnigramAndBigrams) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values( //
std::vector<tstring> expected_values( //
{"a", "b", "c", "d", "LP|a", "a|b", "b|c", "c|d", "d|RP", // 0
"e", "f", "LP|e", "e|f", "f|RP"}); // 1
std::vector<int64> expected_splits({0, 9, 14});
@ -166,11 +166,11 @@ TEST_F(NgramKernelTest, TestOverlappingPaddedNGrams) {
// 0: "a"
// 1: "b", "c", "d"
// 2: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values( //
std::vector<tstring> expected_values( //
{"LP|LP|a", "LP|a|RP", "a|RP|RP", // ngrams for elem. 0
"LP|LP|b", "LP|b|c", "b|c|d", "c|d|RP", "d|RP|RP", // ngrams for elem. 1
"LP|LP|e", "LP|e|f", "e|f|RP", "f|RP|RP"}); // ngrams for elem. 2
@ -186,12 +186,12 @@ TEST_F(NgramKernelTest, TestOverlappingPaddedMultiCharNGrams) {
// 0: "a"
// 1: "b", "c", "d"
// 2: "e", "f"
AddInputFromArray<string>(TensorShape({6}),
{"aa", "bb", "cc", "dd", "ee", "ff"});
AddInputFromArray<tstring>(TensorShape({6}),
{"aa", "bb", "cc", "dd", "ee", "ff"});
AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values( //
std::vector<tstring> expected_values( //
{"LP|LP|aa", "LP|aa|RP", "aa|RP|RP", //
"LP|LP|bb", "LP|bb|cc", "bb|cc|dd", "cc|dd|RP", "dd|RP|RP", //
"LP|LP|ee", "LP|ee|ff", "ee|ff|RP", "ff|RP|RP"}); //
@ -207,13 +207,13 @@ TEST_F(NgramKernelTest, TestMultiOverlappingPaddedNGrams) {
MakeOp("|", {5}, "LP", "RP", -1, false);
// Batch items are:
// 0: "a"
AddInputFromArray<string>(TensorShape({1}), {"a"});
AddInputFromArray<tstring>(TensorShape({1}), {"a"});
AddInputFromArray<int64>(TensorShape({2}), {0, 1});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values({"LP|LP|LP|LP|a", "LP|LP|LP|a|RP",
"LP|LP|a|RP|RP", "LP|a|RP|RP|RP",
"a|RP|RP|RP|RP"});
std::vector<tstring> expected_values({"LP|LP|LP|LP|a", "LP|LP|LP|a|RP",
"LP|LP|a|RP|RP", "LP|a|RP|RP|RP",
"a|RP|RP|RP|RP"});
std::vector<int64> expected_splits({0, 5});
assert_string_equal(expected_values, *GetOutput(0));
@ -225,11 +225,11 @@ TEST_F(NgramKernelTest, TestUnpaddedTrigrams) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values({"a|b|c", "b|c|d"});
std::vector<tstring> expected_values({"a|b|c", "b|c|d"});
std::vector<int64> expected_splits({0, 2, 2});
assert_string_equal(expected_values, *GetOutput(0));
@ -241,11 +241,11 @@ TEST_F(NgramKernelTest, TestUnpaddedTrigramsWithEmptySequence) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({4}), {0, 4, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values({"a|b|c", "b|c|d"});
std::vector<tstring> expected_values({"a|b|c", "b|c|d"});
std::vector<int64> expected_splits({0, 2, 2, 2});
assert_string_equal(expected_values, *GetOutput(0));
@ -257,11 +257,11 @@ TEST_F(NgramKernelTest, TestUnpaddedTrigramsWithPreserveShort) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values({"a|b|c", "b|c|d", "e|f"});
std::vector<tstring> expected_values({"a|b|c", "b|c|d", "e|f"});
std::vector<int64> expected_splits({0, 2, 3});
assert_string_equal(expected_values, *GetOutput(0));
@ -273,11 +273,11 @@ TEST_F(NgramKernelTest, TestUnpaddedTrigramsWithPreserveShortAndEmptySequence) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({4}), {0, 4, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values({"a|b|c", "b|c|d", "e|f"});
std::vector<tstring> expected_values({"a|b|c", "b|c|d", "e|f"});
std::vector<int64> expected_splits({0, 2, 2, 3});
assert_string_equal(expected_values, *GetOutput(0));
@ -289,11 +289,11 @@ TEST_F(NgramKernelTest, TestUnpaddedTrigramsAndQuadgramsWithPreserveShort) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values({"a|b|c|d", "a|b|c", "b|c|d", "e|f"});
std::vector<tstring> expected_values({"a|b|c|d", "a|b|c", "b|c|d", "e|f"});
std::vector<int64> expected_splits({0, 3, 4});
assert_string_equal(expected_values, *GetOutput(0));
@ -305,11 +305,11 @@ TEST_F(NgramKernelTest, TestUnpaddedBigramsAndTrigrams) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values(
std::vector<tstring> expected_values(
{"a|b", "b|c", "c|d", "a|b|c", "b|c|d", "e|f"});
std::vector<int64> expected_splits({0, 5, 6});
@ -322,13 +322,13 @@ TEST_F(NgramKernelTest, TestUnpaddedBigramsAndTrigramsWithPreserveShort) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
// Note that in this case, because the bigram 'e|f' was already generated,
// the op will not generate a special preserve_short bigram.
std::vector<string> expected_values(
std::vector<tstring> expected_values(
{"a|b", "b|c", "c|d", "a|b|c", "b|c|d", "e|f"});
std::vector<int64> expected_splits({0, 5, 6});
@ -341,13 +341,13 @@ TEST_F(NgramKernelTest, TestUnpaddedTrigramsAndBigramsWithPreserveShort) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
// Note that in this case, because the bigram 'e|f' was already generated,
// the op will not generate a special preserve_short bigram.
std::vector<string> expected_values(
std::vector<tstring> expected_values(
{"a|b|c", "b|c|d", "a|b", "b|c", "c|d", "e|f"});
std::vector<int64> expected_splits({0, 5, 6});
@ -360,11 +360,11 @@ TEST_F(NgramKernelTest, TestUnpaddedBigrams) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values({"a|b", "b|c", "c|d", "e|f"});
std::vector<tstring> expected_values({"a|b", "b|c", "c|d", "e|f"});
std::vector<int64> expected_splits({0, 3, 4});
assert_string_equal(expected_values, *GetOutput(0));
@ -377,11 +377,11 @@ TEST_F(NgramKernelTest, TestOverlappingUnpaddedNGrams) {
// 0: "a"
// 1: "b", "c", "d"
// 2: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values({"b|c|d"});
std::vector<tstring> expected_values({"b|c|d"});
std::vector<int64> expected_splits({0, 0, 1, 1});
assert_string_equal(expected_values, *GetOutput(0));
@ -394,11 +394,11 @@ TEST_F(NgramKernelTest, TestOverlappingUnpaddedNGramsNoOutput) {
// 0: "a"
// 1: "b", "c", "d"
// 2: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values({});
std::vector<tstring> expected_values({});
std::vector<int64> expected_splits({0, 0, 0, 0});
assert_string_equal(expected_values, *GetOutput(0));
@ -410,12 +410,13 @@ TEST_F(NgramKernelTest, TestSinglyPaddedTrigrams) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values({"LP|a|b", "a|b|c", "b|c|d", "c|d|RP", //
"LP|e|f", "e|f|RP"});
std::vector<tstring> expected_values({"LP|a|b", "a|b|c", "b|c|d",
"c|d|RP", //
"LP|e|f", "e|f|RP"});
std::vector<int64> expected_splits({0, 4, 6});
assert_string_equal(expected_values, *GetOutput(0));
@ -427,12 +428,12 @@ TEST_F(NgramKernelTest, TestSinglyPaddedBigrams) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values({"LP|a", "a|b", "b|c", "c|d", "d|RP", //
"LP|e", "e|f", "f|RP"});
std::vector<tstring> expected_values({"LP|a", "a|b", "b|c", "c|d", "d|RP", //
"LP|e", "e|f", "f|RP"});
std::vector<int64> expected_splits({0, 5, 8});
assert_string_equal(expected_values, *GetOutput(0));
@ -444,11 +445,11 @@ TEST_F(NgramKernelTest, TestSinglyPaddedBigramsAnd5grams) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values( //
std::vector<tstring> expected_values( //
{"LP|a", "a|b", "b|c", "c|d", "d|RP", "LP|a|b|c|d", "a|b|c|d|RP", //
"LP|e", "e|f", "f|RP"});
std::vector<int64> expected_splits({0, 7, 10});
@ -462,12 +463,12 @@ TEST_F(NgramKernelTest, TestSinglyPadded5gramsWithPreserveShort) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values( //
{"LP|a|b|c|d", "a|b|c|d|RP", //
std::vector<tstring> expected_values( //
{"LP|a|b|c|d", "a|b|c|d|RP", //
"LP|e|f|RP"});
std::vector<int64> expected_splits({0, 2, 3});
@ -481,11 +482,11 @@ TEST_F(NgramKernelTest, TestOverlappingSinglyPaddedNGrams) {
// 0: "a"
// 1: "b", "c", "d"
// 2: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values(
std::vector<tstring> expected_values(
{"LP|a|RP", // ngrams for elem. 0
"LP|b|c", "b|c|d", "c|d|RP", // ngrams for elem. 1
"LP|e|f", "e|f|RP"}); // ngrams for elem. 2
@ -501,11 +502,11 @@ TEST_F(NgramKernelTest, TestOverlappingSinglyPaddedNGramsNoOutput) {
// 0: "a"
// 1: "b", "c", "d"
// 2: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values({"LP|b|c|d|RP"});
std::vector<tstring> expected_values({"LP|b|c|d|RP"});
std::vector<int64> expected_splits({0, 0, 1, 1});
assert_string_equal(expected_values, *GetOutput(0));
@ -517,11 +518,11 @@ TEST_F(NgramKernelTest, TestSinglyPaddedUnigrams) {
// Batch items are:
// 0: "a", "b", "c", "d"
// 1: "e", "f"
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values({"a", "b", "c", "d", "e", "f"});
std::vector<tstring> expected_values({"a", "b", "c", "d", "e", "f"});
std::vector<int64> expected_splits({0, 4, 6});
assert_string_equal(expected_values, *GetOutput(0));
@ -530,11 +531,11 @@ TEST_F(NgramKernelTest, TestSinglyPaddedUnigrams) {
TEST_F(NgramKernelTest, TestEmptyInput) {
MakeOp("|", {1}, "LP", "RP", 3, false);
AddInputFromArray<string>(TensorShape({0}), {});
AddInputFromArray<tstring>(TensorShape({0}), {});
AddInputFromArray<int64>(TensorShape({0}), {});
TF_ASSERT_OK(RunOpKernel());
std::vector<string> expected_values({});
std::vector<tstring> expected_values({});
std::vector<int64> expected_splits({});
assert_string_equal(expected_values, *GetOutput(0));

View File

@ -52,7 +52,7 @@ namespace tensorflow {
namespace {
void Encode(const UnicodeEncoding encoding, const icu::UnicodeString& in,
string* out) {
tstring* out) {
if (encoding == UnicodeEncoding::UTF8) {
out->clear();
in.toUTF8String(*out);
@ -330,7 +330,7 @@ class UnicodeTranscodeOp : public OpKernel {
// Transcode the string from input encoding to the output_encoding_. If
// non-valid characters are encountered, use the subst_/elide_replacement_
// config to handle them.
void Transcode(string* s, UConverter* input_encoder,
void Transcode(tstring* s, UConverter* input_encoder,
bool* found_any_format_error) {
icu::UnicodeString source;
IterateUnicodeString(
@ -561,9 +561,9 @@ class UnicodeEncodeOp : public OpKernel {
appendable_unicode_string.appendCodePoint(code_point);
}
// Encode our string and save in the output.
string result;
tstring result;
Encode(encoding_, unicode_string, &result);
output_tensor_flat(i - 1) = result;
output_tensor_flat(i - 1) = std::move(result);
}
}

View File

@ -34,8 +34,8 @@ limitations under the License.
namespace tensorflow {
static Status ReadEntireFile(Env* env, const string& filename,
string* contents) {
template <typename T>
static Status ReadEntireFile(Env* env, const string& filename, T* contents) {
std::unique_ptr<RandomAccessFile> file;
TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
io::RandomAccessInputStream input_stream(file.get());
@ -112,8 +112,8 @@ class ReadFileOp : public OpKernel {
OP_REQUIRES_OK(context, context->allocate_output("contents",
TensorShape({}), &output));
OP_REQUIRES_OK(context,
ReadEntireFile(context->env(), input->scalar<string>()(),
&output->scalar<string>()()));
ReadEntireFile(context->env(), input->scalar<tstring>()(),
&output->scalar<tstring>()()));
}
};

View File

@ -167,7 +167,8 @@ Status BufferedInputStream::Seek(int64 position) {
return SkipNBytes(position - bufpos);
}
Status BufferedInputStream::ReadAll(string* result) {
template <typename T>
Status BufferedInputStream::ReadAll(T* result) {
result->clear();
Status status;
while (status.ok()) {
@ -186,6 +187,11 @@ Status BufferedInputStream::ReadAll(string* result) {
return status;
}
// Explicit instantiations: the ReadAll<T> template is defined in this .cc
// file, so instantiate it here for every supported string type.
template Status BufferedInputStream::ReadAll<string>(string* result);
#ifdef USE_TSTRING
template Status BufferedInputStream::ReadAll<tstring>(tstring* result);
#endif  // USE_TSTRING
Status BufferedInputStream::Reset() {
TF_RETURN_IF_ERROR(input_stream_->Reset());
pos_ = 0;

View File

@ -79,7 +79,8 @@ class BufferedInputStream : public InputStreamInterface {
//
// Note: the amount of memory used by this function call is unbounded, so only
// use in ops that expect that behavior.
Status ReadAll(string* result);
template <typename T>
Status ReadAll(T* result);
Status Reset() override;

View File

@ -84,7 +84,7 @@ void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize) {
// -----------------------------------------------------------------------------
void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize,
string *destination) {
tstring *destination) {
MemDestMgr *dest;
if (cinfo->dest == nullptr) {
cinfo->dest = reinterpret_cast<struct jpeg_destination_mgr *>(

View File

@ -33,7 +33,7 @@ typedef struct {
JOCTET *buffer;
int bufsize;
int datacount;
string *dest;
tstring *dest;
} MemDestMgr;
typedef struct {
@ -52,7 +52,7 @@ void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize);
// Same as above, except that buffer is only used as a temporary structure and
// is emptied into "destination" as soon as it fills up.
void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize,
string *destination);
tstring *destination);
} // namespace jpeg
} // namespace tensorflow

View File

@ -592,7 +592,7 @@ bool GetImageInfo(const void* srcdata, int datasize, int* width, int* height,
namespace {
bool CompressInternal(const uint8* srcdata, int width, int height,
const CompressFlags& flags, string* output) {
const CompressFlags& flags, tstring* output) {
output->clear();
const int components = (static_cast<int>(flags.format) & 0xff);
@ -762,14 +762,14 @@ bool CompressInternal(const uint8* srcdata, int width, int height,
// -----------------------------------------------------------------------------
bool Compress(const void* srcdata, int width, int height,
const CompressFlags& flags, string* output) {
const CompressFlags& flags, tstring* output) {
return CompressInternal(static_cast<const uint8*>(srcdata), width, height,
flags, output);
}
string Compress(const void* srcdata, int width, int height,
const CompressFlags& flags) {
string temp;
tstring Compress(const void* srcdata, int width, int height,
const CompressFlags& flags) {
tstring temp;
CompressInternal(static_cast<const uint8*>(srcdata), width, height, flags,
&temp);
// If CompressInternal fails, temp will be empty.

View File

@ -149,12 +149,12 @@ struct CompressFlags {
// The encoded data is returned as a string.
// If not empty, XMP metadata can be embedded in the image header
// On error, returns the empty string (which is never a valid jpeg).
string Compress(const void* srcdata, int width, int height,
const CompressFlags& flags);
tstring Compress(const void* srcdata, int width, int height,
const CompressFlags& flags);
// On error, returns false and sets output to empty.
bool Compress(const void* srcdata, int width, int height,
const CompressFlags& flags, string* output);
const CompressFlags& flags, tstring* output);
} // namespace jpeg
} // namespace tensorflow

View File

@ -326,7 +326,7 @@ TEST(JpegMemTest, Jpeg2) {
CHECK_NE(string::npos, cpdata1.find(kXMP));
// Test the other API, where a storage string is supplied
string cptest;
tstring cptest;
flags.stride = 0;
Compress(refdata1.get(), in_w, in_h, flags, &cptest);
CHECK_EQ(cptest, cpdata1);
@ -465,7 +465,7 @@ TEST(JpegMemTest, ChromaDownsampling) {
flags.format = FORMAT_RGB;
flags.quality = 85;
flags.chroma_downsampling = downsample;
string recompressed;
tstring recompressed;
Compress(uncompressed.get(), w, h, flags, &recompressed);
CHECK(!recompressed.empty());
CHECK_EQ(IsChromaDownsampled(recompressed), downsample);

View File

@ -105,8 +105,9 @@ void StringReader(png_structp png_ptr, png_bytep data, png_size_t length) {
}
}
template <typename T>
void StringWriter(png_structp png_ptr, png_bytep data, png_size_t length) {
string* const s = absl::bit_cast<string*>(png_get_io_ptr(png_ptr));
T* const s = absl::bit_cast<T*>(png_get_io_ptr(png_ptr));
s->append(absl::bit_cast<const char*>(data), length);
}
@ -340,9 +341,10 @@ bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) {
return ok;
}
template <typename T>
bool WriteImageToBuffer(
const void* image, int width, int height, int row_bytes, int num_channels,
int channel_bits, int compression, string* png_string,
int channel_bits, int compression, T* png_string,
const std::vector<std::pair<string, string> >* metadata) {
CHECK_NOTNULL(image);
CHECK_NOTNULL(png_string);
@ -384,7 +386,7 @@ bool WriteImageToBuffer(
return false;
}
png_set_write_fn(png_ptr, png_string, StringWriter, StringWriterFlush);
png_set_write_fn(png_ptr, png_string, StringWriter<T>, StringWriterFlush);
if (compression < 0) compression = Z_DEFAULT_COMPRESSION;
png_set_compression_level(png_ptr, compression);
png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL);
@ -418,5 +420,16 @@ bool WriteImageToBuffer(
return true;
}
// Explicit instantiations: WriteImageToBuffer<T> is defined in this .cc
// file, so instantiate it for each output-string type callers may use.
template bool WriteImageToBuffer<string>(
    const void* image, int width, int height, int row_bytes, int num_channels,
    int channel_bits, int compression, string* png_string,
    const std::vector<std::pair<string, string> >* metadata);
#ifdef USE_TSTRING
template bool WriteImageToBuffer<tstring>(
    const void* image, int width, int height, int row_bytes, int num_channels,
    int channel_bits, int compression, tstring* png_string,
    const std::vector<std::pair<string, string> >* metadata);
#endif  // USE_TSTRING
} // namespace png
} // namespace tensorflow

View File

@ -94,9 +94,10 @@ void CommonFreeDecode(DecodeContext* context);
// compression is in [-1,9], where 0 is fast and weak compression, 9 is slow
// and strong, and -1 is the zlib default.
template <typename T>
bool WriteImageToBuffer(
const void* image, int width, int height, int row_bytes, int num_channels,
int channel_bits, int compression, string* png_string,
int channel_bits, int compression, T* png_string,
const std::vector<std::pair<string, string> >* metadata);
} // namespace png

View File

@ -73,7 +73,8 @@ Status DecodeThreeChars(const char* codes, char* result) {
}
} // namespace
Status Base64Decode(StringPiece data, string* decoded) {
template <typename T>
Status Base64Decode(StringPiece data, T* decoded) {
if (decoded == nullptr) {
return errors::Internal("'decoded' cannot be nullptr.");
}
@ -135,11 +136,13 @@ Status Base64Decode(StringPiece data, string* decoded) {
return Status::OK();
}
Status Base64Encode(StringPiece source, string* encoded) {
template <typename T>
Status Base64Encode(StringPiece source, T* encoded) {
return Base64Encode(source, false, encoded);
}
Status Base64Encode(StringPiece source, bool with_padding, string* encoded) {
template <typename T>
Status Base64Encode(StringPiece source, bool with_padding, T* encoded) {
const char* const base64_chars = kBase64UrlSafeChars;
if (encoded == nullptr) {
return errors::Internal("'encoded' cannot be nullptr.");
@ -191,4 +194,16 @@ Status Base64Encode(StringPiece source, bool with_padding, string* encoded) {
return Status::OK();
}
// Explicit instantiations: the Base64Encode/Decode templates are defined in
// this .cc file, so instantiate them for each supported string type.
template Status Base64Decode<string>(StringPiece data, string* decoded);
template Status Base64Encode<string>(StringPiece source, string* encoded);
template Status Base64Encode<string>(StringPiece source, bool with_padding,
                                     string* encoded);
#ifdef USE_TSTRING
template Status Base64Decode<tstring>(StringPiece data, tstring* decoded);
template Status Base64Encode<tstring>(StringPiece source, tstring* encoded);
template Status Base64Encode<tstring>(StringPiece source, bool with_padding,
                                      tstring* encoded);
#endif  // USE_TSTRING
} // namespace tensorflow

View File

@ -24,13 +24,17 @@ namespace tensorflow {
/// \brief Converts data into web-safe base64 encoding.
///
/// See https://en.wikipedia.org/wiki/Base64
Status Base64Encode(StringPiece data, bool with_padding, string* encoded);
Status Base64Encode(StringPiece data, string* encoded); // with_padding=false.
template <typename T>
Status Base64Encode(StringPiece source, bool with_padding, T* encoded);
template <typename T>
Status Base64Encode(StringPiece source,
T* encoded); // with_padding=false.
/// \brief Converts data from web-safe base64 encoding.
///
/// See https://en.wikipedia.org/wiki/Base64
Status Base64Decode(StringPiece data, string* decoded);
template <typename T>
Status Base64Decode(StringPiece data, T* decoded);
} // namespace tensorflow

View File

@ -21,11 +21,11 @@ namespace tensorflow {
TEST(Base64, EncodeDecode) {
const string original = "a simple test message!";
string encoded;
tstring encoded;
TF_EXPECT_OK(Base64Encode(original, &encoded));
EXPECT_EQ("YSBzaW1wbGUgdGVzdCBtZXNzYWdlIQ", encoded);
string decoded;
tstring decoded;
TF_EXPECT_OK(Base64Decode(encoded, &decoded));
EXPECT_EQ(original, decoded);
}

View File

@ -132,9 +132,10 @@ Status ReadString(const string& data, int expected_length, string* value,
return Status::OK();
}
template <typename T>
Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
size_t num_channels, size_t num_frames,
string* wav_string) {
T* wav_string) {
constexpr size_t kFormatChunkSize = 16;
constexpr size_t kCompressionCodePcm = 1;
constexpr size_t kBitsPerSample = 16;
@ -173,7 +174,7 @@ Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
}
wav_string->resize(file_size);
char* data = &wav_string->at(0);
char* data = &(*wav_string)[0];
WavHeader* header = absl::bit_cast<WavHeader*>(data);
// Fill RIFF chunk.
@ -208,6 +209,19 @@ Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
return Status::OK();
}
// Explicit instantiations: EncodeAudioAsS16LEWav<T> is defined in this .cc
// file, so instantiate it for each supported output-string type.
template Status EncodeAudioAsS16LEWav<string>(const float* audio,
                                              size_t sample_rate,
                                              size_t num_channels,
                                              size_t num_frames,
                                              string* wav_string);
#ifdef USE_TSTRING
template Status EncodeAudioAsS16LEWav<tstring>(const float* audio,
                                               size_t sample_rate,
                                               size_t num_channels,
                                               size_t num_frames,
                                               tstring* wav_string);
#endif  // USE_TSTRING
Status DecodeLin16WaveAsFloatVector(const string& wav_string,
std::vector<float>* float_values,
uint32* sample_count, uint16* channel_count,

View File

@ -41,9 +41,10 @@ namespace wav {
// if (EncodeAudioAsS16LEWav(audio_buffer, 8000, 2, 4, &wav_string).ok()) {
// // Use wav_string.
// }
template <typename T>
Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
size_t num_channels, size_t num_frames,
string* wav_string);
T* wav_string);
// Decodes the little-endian signed 16-bit PCM WAV file data (aka LIN16
// encoding) into a float Tensor. The channels are encoded as the lowest

View File

@ -34,12 +34,13 @@ Status ReadString(const string& data, int expected_length, string* value,
TEST(WavIO, BadArguments) {
float audio[] = {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
string result;
tstring result;
EXPECT_EQ(error::INVALID_ARGUMENT,
EncodeAudioAsS16LEWav(nullptr, 44100, 2, 3, &result).code());
EXPECT_EQ(error::INVALID_ARGUMENT,
EncodeAudioAsS16LEWav(audio, 44100, 2, 3, nullptr).code());
EXPECT_EQ(
error::INVALID_ARGUMENT,
EncodeAudioAsS16LEWav(audio, 44100, 2, 3, (tstring*)nullptr).code());
const size_t kuint32max_plus_one = static_cast<size_t>(kuint32max) + 1;
const size_t kuint16max_plus_one = static_cast<size_t>(kuint16max) + 1;

View File

@ -20,4 +20,43 @@ namespace tensorflow {
const char* kProtobufInt64Typename = "::tensorflow::protobuf_int64";
const char* kProtobufUint64Typename = "::tensorflow::protobuf_uint64";
#ifdef USE_TSTRING
TStringOutputStream::TStringOutputStream(tstring* target) : target_(target) {}
// Grows *target_ and hands the caller the newly exposed tail of the buffer
// via *data/*size, following the ZeroCopyOutputStream protocol (any unused
// suffix must be returned with BackUp()). Returns false only when the target
// cannot grow without overflowing the int size arithmetic below.
bool TStringOutputStream::Next(void** data, int* size) {
  int old_size = target_->size();
  // Grow the string.
  if (old_size < target_->capacity()) {
    // Resize the string to match its capacity, since we can get away
    // without a memory allocation this way.
    target_->resize_uninitialized(target_->capacity());
  } else {
    // Size has reached capacity, try to double the size.
    if (old_size > std::numeric_limits<int>::max() / 2) {
      // Can not double the size, otherwise it is going to cause integer
      // overflow in the expression below: old_size * 2.
      return false;
    }
    // Double the size, also make sure that the new size is at least
    // kMinimumSize.
    target_->resize_uninitialized(
        std::max(old_size * 2,
                 kMinimumSize + 0));  // "+ 0" works around GCC4 weirdness.
  }
  // Expose the bytes between the old and new size as writable scratch.
  // NOTE(review): this assigns target_->data() + old_size to a void** out
  // parameter, which requires a non-const data() overload on tstring -- the
  // visible tstring fragment only shows `const char* data() const`; confirm.
  *data = target_->data() + old_size;
  *size = target_->size() - old_size;
  return true;
}
// Returns the last `count` bytes obtained from Next() to the stream by
// trimming *target_ down to the bytes that were actually written.
void TStringOutputStream::BackUp(int count) {
  const size_t bytes_written = target_->size() - static_cast<size_t>(count);
  target_->resize(bytes_written);
}
// Total number of bytes written through this stream so far; for a
// string-backed stream that is exactly the target's current length.
protobuf::io::ByteCountInt64 TStringOutputStream::ByteCount() const {
  return static_cast<protobuf::io::ByteCountInt64>(target_->size());
}
#endif // USE_TSTRING
} // namespace tensorflow

View File

@ -90,6 +90,29 @@ inline bool SerializeToTString(const protobuf::MessageLite& proto,
#endif // USE_TSTRING
}
#ifdef USE_TSTRING
// Analogue to protobuf::io::StringOutputStream, but writing into a
// tensorflow::tstring, allowing protos to be serialized directly into a
// tstring without an intermediate std::string copy.
class TStringOutputStream : public protobuf::io::ZeroCopyOutputStream {
 public:
  // Does not take ownership of *target; it must outlive the stream.
  explicit TStringOutputStream(tstring* target);
  ~TStringOutputStream() override = default;

  // Not copyable: the stream holds a raw pointer to external state.
  TStringOutputStream(const TStringOutputStream&) = delete;
  void operator=(const TStringOutputStream&) = delete;

  // ZeroCopyOutputStream interface.
  bool Next(void** data, int* size) override;
  void BackUp(int count) override;
  protobuf::io::ByteCountInt64 ByteCount() const override;

 private:
  // Smallest size Next() will grow the target to on first expansion.
  static const int kMinimumSize = 16;

  tstring* target_;  // Not owned.
};
#else // USE_TSTRING
typedef protobuf::io::StringOutputStream TStringOutputStream;
#endif // USE_TSTRING
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_

View File

@ -157,6 +157,8 @@ class tstring {
size_t size() const { return str_.size(); }
size_t capacity() const { return str_.capacity(); }
const char* c_str() const { return str_.c_str(); }
const char* data() const { return str_.data(); }
@ -207,6 +209,8 @@ class tstring {
return *this;
}
void push_back(char ch) { str_.push_back(ch); }
friend const tstring operator+(const tstring& a, const tstring& b);
friend bool operator==(const char* a, const tstring& b);
friend bool operator==(const std::string& a, const tstring& b);