Update kernels and related libs to use tstring.

This is a part of a larger migration effort for tensorflow::tstring.
See: https://github.com/tensorflow/community/pull/91

PiperOrigin-RevId: 265811129

This commit is contained in:
parent 63ba081d07
commit f742e74da3
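For orientation, the sketch below is illustrative only and not part of the commit. It shows the access pattern this diff applies throughout the kernels: elements of DT_STRING tensors are read and written as tensorflow::tstring rather than std::string. The kernel fragment and its name are hypothetical; only the scalar<tstring>() / flat<tstring>() accessors and tstring::append() mirror the real changes below.

    // Hypothetical kernel fragment -- a minimal sketch of the tstring access pattern.
    #include "tensorflow/core/framework/op_kernel.h"
    #include "tensorflow/core/framework/tensor.h"

    namespace tensorflow {

    void FillStringOutput(OpKernelContext* context, Tensor* output) {
      // Scalar DT_STRING output: the element type is now tensorflow::tstring.
      tstring& encoded = output->scalar<tstring>()();
      encoded.clear();

      // Vector DT_STRING input: flat<tstring>() replaces flat<string>().
      const Tensor& input = context->input(0);
      auto values = input.flat<tstring>();
      for (int i = 0; i < values.size(); ++i) {
        encoded.append(values(i));  // tstring::append(), as in the n-gram kernel change below.
      }
    }

    }  // namespace tensorflow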
@@ -124,7 +124,7 @@ class EncodeJpegOp : public OpKernel {
  context->allocate_output(0, TensorShape({}), &output));
  OP_REQUIRES(context,
  jpeg::Compress(image.flat<uint8>().data(), dim_size1, dim_size0,
- adjusted_flags, &output->scalar<string>()()),
+ adjusted_flags, &output->scalar<tstring>()()),
  errors::Internal("JPEG encoding failed"));
  }

@@ -190,7 +190,7 @@ class EncodeJpegVariableQualityOp : public OpKernel {
  context->allocate_output(0, TensorShape({}), &output));
  OP_REQUIRES(context,
  jpeg::Compress(image.flat<uint8>().data(), dim_size1, dim_size0,
- adjusted_flags, &output->scalar<string>()()),
+ adjusted_flags, &output->scalar<tstring>()()),
  errors::Internal("JPEG encoding failed"));
  }
  };
@@ -78,17 +78,17 @@ class EncodePngOp : public OpKernel {
  context->allocate_output(0, TensorShape({}), &output));
  if (desired_channel_bits_ == 8) {
  OP_REQUIRES(context,
- png::WriteImageToBuffer(image.flat<uint8>().data(), width,
- height, width * channels, channels,
- desired_channel_bits_, compression_,
- &output->scalar<string>()(), nullptr),
+ png::WriteImageToBuffer(
+ image.flat<uint8>().data(), width, height,
+ width * channels, channels, desired_channel_bits_,
+ compression_, &output->scalar<tstring>()(), nullptr),
  errors::Internal("PNG encoding failed"));
  } else {
  OP_REQUIRES(context,
  png::WriteImageToBuffer(
  image.flat<uint16>().data(), width, height,
  width * channels * 2, channels, desired_channel_bits_,
- compression_, &output->scalar<string>()(), nullptr),
+ compression_, &output->scalar<tstring>()(), nullptr),
  errors::Internal("PNG encoding failed"));
  }
  }
@@ -298,6 +298,26 @@ Status WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input,
  return Status::OK();
  }

+ static void WriteStringAdapter(int field_number, const tstring& value,
+ CodedOutputStream* output) {
+ // Unfortunately, external proto does not accept string_view.
+ #if defined(PLATFORM_GOOGLE)
+ WireFormatLite::WriteString(field_number, StringPiece(value), output);
+ #else
+ WireFormatLite::WriteString(field_number, string(value), output);
+ #endif
+ }
+
+ static void WriteBytesAdapter(int field_number, const tstring& value,
+ CodedOutputStream* output) {
+ // Unfortunately, external proto does not accept string_view.
+ #if defined(PLATFORM_GOOGLE)
+ WireFormatLite::WriteBytes(field_number, StringPiece(value), output);
+ #else
+ WireFormatLite::WriteBytes(field_number, string(value), output);
+ #endif
+ }
+
  // Writes a group field. Groups are treated like submessages, but tag-delimited
  // instead of length-delimited. WireFormatLite handles this differently so we
  // code it ourselves.

@@ -388,15 +408,15 @@ Status WriteField(const FieldDescriptor& field_desc, const Tensor& input,
  WireFormatLite::WriteBoolNoTag>(
  field_desc, input, message_index, size, output);
  case WireFormatLite::TYPE_STRING:
- return WriteVarLenField<string, WireFormatLite::WriteString>(
+ return WriteVarLenField<tstring, WriteStringAdapter>(
  field_desc, input, message_index, size, output);
  case WireFormatLite::TYPE_GROUP:
  return WriteGroup(field_desc, input, message_index, size, output);
  case WireFormatLite::TYPE_MESSAGE:
- return WriteVarLenField<string, WireFormatLite::WriteBytes>(
+ return WriteVarLenField<tstring, WriteBytesAdapter>(
  field_desc, input, message_index, size, output);
  case WireFormatLite::TYPE_BYTES:
- return WriteVarLenField<string, WireFormatLite::WriteBytes>(
+ return WriteVarLenField<tstring, WriteBytesAdapter>(
  field_desc, input, message_index, size, output);
  case WireFormatLite::TYPE_UINT32:
  switch (dtype) {

@@ -592,7 +612,7 @@ class EncodeProtoOp : public OpKernel {
  message_index++) {
  // TODO(nix): possibly optimize allocation here by calling
  // `bufs(message_index).reserve(DEFAULT_BUF_SIZE)`.
- StringOutputStream output_string(&bufs(message_index));
+ TStringOutputStream output_string(&bufs(message_index));
  CodedOutputStream out(&output_string);
  // Write fields in ascending field_number order.
  for (int i : sorted_field_index_) {
@@ -58,7 +58,7 @@ class EncodeWavOp : public OpKernel {
  OP_REQUIRES_OK(context,
  wav::EncodeAudioAsS16LEWav(
  audio.flat<float>().data(), sample_rate, channel_count,
- sample_count, &output->scalar<string>()()));
+ sample_count, &output->scalar<tstring>()()));
  }
  };
  REGISTER_KERNEL_BUILDER(Name("EncodeWav").Device(DEVICE_CPU), EncodeWavOp);
@@ -132,10 +132,10 @@ class ParseExampleOp : public OpKernel {
  config.sparse.push_back({sparse_keys_t[d], attrs_.sparse_types[d]});
  }

- auto serialized_t = serialized->flat<string>();
- auto names_t = names->flat<string>();
- gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size());
- gtl::ArraySlice<string> names_slice(names_t.data(), names_t.size());
+ auto serialized_t = serialized->flat<tstring>();
+ auto names_t = names->flat<tstring>();
+ gtl::ArraySlice<tstring> slice(serialized_t.data(), serialized_t.size());
+ gtl::ArraySlice<tstring> names_slice(names_t.data(), names_t.size());

  OP_REQUIRES_OK(
  ctx,

@@ -352,11 +352,11 @@ class ParseSequenceExampleOp : public OpKernel {
  attrs_.feature_list_sparse_types[d]});
  }

- auto serialized_t = serialized->flat<string>();
- auto debug_name_t = debug_name->flat<string>();
- gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size());
- gtl::ArraySlice<string> names_slice(debug_name_t.data(),
- debug_name_t.size());
+ auto serialized_t = serialized->flat<tstring>();
+ auto debug_name_t = debug_name->flat<tstring>();
+ gtl::ArraySlice<tstring> slice(serialized_t.data(), serialized_t.size());
+ gtl::ArraySlice<tstring> names_slice(debug_name_t.data(),
+ debug_name_t.size());

  OP_REQUIRES_OK(
  ctx,

@@ -853,10 +853,12 @@ class DecodeJSONExampleOp : public OpKernel {
  &binary_examples));

  for (int i = 0; i < json_examples->NumElements(); ++i) {
- const string& json_example = json_examples->flat<string>()(i);
- auto status = protobuf::util::JsonToBinaryString(
- resolver_.get(), "type.googleapis.com/tensorflow.Example",
- json_example, &binary_examples->flat<string>()(i));
+ const tstring& json_example = json_examples->flat<tstring>()(i);
+ protobuf::io::ArrayInputStream in(json_example.data(),
+ json_example.size());
+ TStringOutputStream out(&binary_examples->flat<tstring>()(i));
+ auto status = protobuf::util::JsonToBinaryStream(
+ resolver_.get(), "type.googleapis.com/tensorflow.Example", &in, &out);
  OP_REQUIRES(ctx, status.ok(),
  errors::InvalidArgument("Error while parsing JSON: ",
  string(status.error_message())));
@@ -124,7 +124,7 @@ struct ExampleStore {
  Features* features = example.mutable_features();
  (*features->mutable_feature())[k_str] = f;
  }
- CHECK(example.SerializeToString(&string_t(b)));
+ CHECK(SerializeToTString(example, &string_t(b)));
  }
  (*examples)[std::make_tuple(batch_size, num_keys, feature_size)] =
  record_string;
@@ -52,7 +52,7 @@ void FarmhashFingerprint64(TTypes<uint8, 2>::ConstTensor input,
  }
  }

- void FarmhashFingerprint64(TTypes<string>::ConstFlat input,
+ void FarmhashFingerprint64(TTypes<tstring>::ConstFlat input,
  TTypes<uint8, 2>::Matrix output) {
  DCHECK_EQ(output.dimension(0), input.dimension(0));
  DCHECK_EQ(output.dimension(1), sizeof(uint64));

@@ -79,7 +79,7 @@ class FingerprintOp : public OpKernel {
  errors::InvalidArgument("`method` should be a scalar string: ",
  method_tensor.shape()));
  // For now, farmhash64 is the only function supported.
- const string& method = method_tensor.scalar<string>()();
+ const tstring& method = method_tensor.scalar<tstring>()();
  OP_REQUIRES(
  context, method == "farmhash64",
  errors::InvalidArgument("Unsupported fingerprint method: ", method));

@@ -82,10 +82,10 @@ TEST_F(FingerprintOpTest, StringGoldenValue) {
  buffer(1).resize(7);
  buffer(2).resize(0);
  buffer(3).resize(19);
- std::iota(buffer(0).begin(), buffer(0).end(), 0);
- std::iota(buffer(1).begin(), buffer(1).end(), 7);
- std::iota(buffer(2).begin(), buffer(2).end(), 71);
- std::iota(buffer(3).begin(), buffer(3).end(), 41);
+ std::iota(&buffer(0)[0], &buffer(0)[0] + buffer(0).size(), 0);
+ std::iota(&buffer(1)[0], &buffer(1)[0] + buffer(1).size(), 7);
+ std::iota(&buffer(2)[0], &buffer(2)[0] + buffer(2).size(), 71);
+ std::iota(&buffer(3)[0], &buffer(3)[0] + buffer(3).size(), 41);

  TF_ASSERT_OK(MakeFingerprintOp(&data));
  TF_ASSERT_OK(RunOpKernel());

@@ -137,7 +137,7 @@ TEST_F(FingerprintOpTest, CollisionString) {
  auto& input = tensor.vec<tstring>()(0);
  input.resize(size);

- TTypes<uint8>::UnalignedFlat buffer(reinterpret_cast<uint8*>(&*input.begin()),
+ TTypes<uint8>::UnalignedFlat buffer(reinterpret_cast<uint8*>(&input[0]),
  input.size());
  buffer.setRandom();
@@ -134,7 +134,7 @@ T SubtleMustCopyIfIntegral(const T& value) {
  return internal::SubtleMustCopy(value);
  }

- inline const string& SubtleMustCopyIfIntegral(const string& value) {
+ inline const tstring& SubtleMustCopyIfIntegral(const tstring& value) {
  return value;
  }
@@ -55,7 +55,7 @@ class RecordInputOp : public OpKernel {

  void Compute(OpKernelContext* ctx) override {
  Tensor out(DT_STRING, {batch_size_});
- auto t_out = out.flat<string>();
+ auto t_out = out.flat<tstring>();
  for (int i = 0; i < batch_size_; ++i) {
  OP_REQUIRES_OK(ctx, yielder_->YieldOne(&t_out(i)));
  }

@@ -44,7 +44,7 @@ RecordYielder::~RecordYielder() {
  delete thread_;
  }

- Status RecordYielder::YieldOne(string* value) {
+ Status RecordYielder::YieldOne(tstring* value) {
  mutex_lock l(mu_);
  while (!BufEnough() && status_.ok()) {
  buf_enough_.wait(l);

@@ -90,7 +90,7 @@ class RecordYielder {
  RecordYielder& operator=(const RecordYielder&) = delete;

  // Yields one 'value'.
- Status YieldOne(string* value);
+ Status YieldOne(tstring* value);

  // Returns the current epoch number.
  int64 current_epoch() const { return epoch_; }
@@ -48,11 +48,15 @@ Status InternalCompute(const RE2& match, const string& rewrite,
  }
  auto output_flat = output_tensor->flat<tstring>();
  for (size_t i = 0; i < output_flat.size(); ++i) {
+ // TODO(dero): Mitigate copy; Global and GlobalReplace below currently only
+ // accept std::string.
+ string buf = output_flat(i);
  if (replace_global) {
- RE2::GlobalReplace(&output_flat(i), match, rewrite);
+ RE2::GlobalReplace(&buf, match, rewrite);
  } else {
- RE2::Replace(&output_flat(i), match, rewrite);
+ RE2::Replace(&buf, match, rewrite);
  }
+ output_flat(i) = std::move(buf);
  }
  return Status::OK();
  }

@@ -71,9 +71,9 @@ Graph* SetupRegexReplaceGraph(const Tensor& input, const string& input_pattern,
  const string& input_rewrite) {
  Graph* g = new Graph(OpRegistry::Global());
  Tensor pattern(DT_STRING, TensorShape({}));
- pattern.flat<string>().setConstant(input_pattern);
+ pattern.flat<tstring>().setConstant(input_pattern);
  Tensor rewrite(DT_STRING, TensorShape({}));
- rewrite.flat<string>().setConstant(input_rewrite);
+ rewrite.flat<tstring>().setConstant(input_rewrite);

  TF_CHECK_OK(NodeBuilder("regex_replace_op", "RegexReplace")
  .Input(test::graph::Constant(g, input))
@@ -128,7 +128,7 @@ class StringNGramsOp : public tensorflow::OpKernel {
  }
  }

- void CreateNgrams(const string* data, string* output, int num_ngrams,
+ void CreateNgrams(const tstring* data, tstring* output, int num_ngrams,
  int ngram_width) const {
  for (int ngram_index = 0; ngram_index < num_ngrams; ++ngram_index) {
  int pad_width = get_pad_width(ngram_width);

@@ -154,20 +154,20 @@ class StringNGramsOp : public tensorflow::OpKernel {
  ngram_size += num_separators * separator_.length();

  // Build the ngram.
- string* ngram = &output[ngram_index];
+ tstring* ngram = &output[ngram_index];
  ngram->reserve(ngram_size);
  for (int n = 0; n < left_padding; ++n) {
- *ngram += left_pad_;
- *ngram += separator_;
+ ngram->append(left_pad_);
+ ngram->append(separator_);
  }
  for (int n = 0; n < num_tokens - 1; ++n) {
- *ngram += data[data_start_index + n];
- *ngram += separator_;
+ ngram->append(data[data_start_index + n]);
+ ngram->append(separator_);
  }
- *ngram += data[data_start_index + num_tokens - 1];
+ ngram->append(data[data_start_index + num_tokens - 1]);
  for (int n = 0; n < right_padding; ++n) {
- *ngram += separator_;
- *ngram += right_pad_;
+ ngram->append(separator_);
+ ngram->append(right_pad_);
  }

  // In debug mode only: validate that we've reserved enough space for the
@ -51,12 +51,12 @@ class NgramKernelTest : public tensorflow::OpsTestBase {
|
||||
TF_ASSERT_OK(InitOp());
|
||||
}
|
||||
|
||||
void assert_string_equal(const std::vector<string> &expected,
|
||||
void assert_string_equal(const std::vector<tstring> &expected,
|
||||
const Tensor &value) {
|
||||
Tensor expected_tensor(allocator(), DT_STRING,
|
||||
TensorShape({static_cast<int64>(expected.size())}));
|
||||
test::FillValues<string>(&expected_tensor, expected);
|
||||
test::ExpectTensorEqual<string>(expected_tensor, value);
|
||||
test::FillValues<tstring>(&expected_tensor, expected);
|
||||
test::ExpectTensorEqual<tstring>(expected_tensor, value);
|
||||
}
|
||||
void assert_int64_equal(const std::vector<int64> &expected,
|
||||
const Tensor &value) {
|
||||
@ -72,11 +72,11 @@ TEST_F(NgramKernelTest, TestPaddedTrigrams) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values( //
|
||||
std::vector<tstring> expected_values( //
|
||||
{"LP|LP|a", "LP|a|b", "a|b|c", "b|c|d", "c|d|RP", "d|RP|RP", // 0
|
||||
"LP|LP|e", "LP|e|f", "e|f|RP", "f|RP|RP"}); // 1
|
||||
std::vector<int64> expected_splits({0, 6, 10});
|
||||
@ -90,11 +90,11 @@ TEST_F(NgramKernelTest, TestPaddedBigramsAndTrigrams) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values(
|
||||
std::vector<tstring> expected_values(
|
||||
{"LP|a", "a|b", "b|c", "c|d", "d|RP", "LP|LP|a", "LP|a|b", "a|b|c",
|
||||
"b|c|d", "c|d|RP", "d|RP|RP", // 0
|
||||
"LP|e", "e|f", "f|RP", "LP|LP|e", "LP|e|f", "e|f|RP", "f|RP|RP"}); // 1
|
||||
@ -109,11 +109,11 @@ TEST_F(NgramKernelTest, TestPaddedBigrams) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values( //
|
||||
std::vector<tstring> expected_values( //
|
||||
{"LP|a", "a|b", "b|c", "c|d", "d|RP", // 0
|
||||
"LP|e", "e|f", "f|RP"}); // 1
|
||||
std::vector<int64> expected_splits({0, 5, 8});
|
||||
@ -127,11 +127,11 @@ TEST_F(NgramKernelTest, TestPaddingIsAtMostNGramSizeMinus1) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values( //
|
||||
std::vector<tstring> expected_values( //
|
||||
{"LP|a", "a|b", "b|c", "c|d", "d|RP", // 0
|
||||
"LP|e", "e|f", "f|RP"}); // 1
|
||||
std::vector<int64> expected_splits({0, 5, 8});
|
||||
@ -145,11 +145,11 @@ TEST_F(NgramKernelTest, TestPaddedUnigramAndBigrams) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values( //
|
||||
std::vector<tstring> expected_values( //
|
||||
{"a", "b", "c", "d", "LP|a", "a|b", "b|c", "c|d", "d|RP", // 0
|
||||
"e", "f", "LP|e", "e|f", "f|RP"}); // 1
|
||||
std::vector<int64> expected_splits({0, 9, 14});
|
||||
@ -166,11 +166,11 @@ TEST_F(NgramKernelTest, TestOverlappingPaddedNGrams) {
|
||||
// 0: "a"
|
||||
// 1: "b", "c", "d"
|
||||
// 2: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values( //
|
||||
std::vector<tstring> expected_values( //
|
||||
{"LP|LP|a", "LP|a|RP", "a|RP|RP", // ngrams for elem. 0
|
||||
"LP|LP|b", "LP|b|c", "b|c|d", "c|d|RP", "d|RP|RP", // ngrams for elem. 1
|
||||
"LP|LP|e", "LP|e|f", "e|f|RP", "f|RP|RP"}); // ngrams for elem. 2
|
||||
@ -186,12 +186,12 @@ TEST_F(NgramKernelTest, TestOverlappingPaddedMultiCharNGrams) {
|
||||
// 0: "a"
|
||||
// 1: "b", "c", "d"
|
||||
// 2: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}),
|
||||
{"aa", "bb", "cc", "dd", "ee", "ff"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}),
|
||||
{"aa", "bb", "cc", "dd", "ee", "ff"});
|
||||
AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values( //
|
||||
std::vector<tstring> expected_values( //
|
||||
{"LP|LP|aa", "LP|aa|RP", "aa|RP|RP", //
|
||||
"LP|LP|bb", "LP|bb|cc", "bb|cc|dd", "cc|dd|RP", "dd|RP|RP", //
|
||||
"LP|LP|ee", "LP|ee|ff", "ee|ff|RP", "ff|RP|RP"}); //
|
||||
@ -207,13 +207,13 @@ TEST_F(NgramKernelTest, TestMultiOverlappingPaddedNGrams) {
|
||||
MakeOp("|", {5}, "LP", "RP", -1, false);
|
||||
// Batch items are:
|
||||
// 0: "a"
|
||||
AddInputFromArray<string>(TensorShape({1}), {"a"});
|
||||
AddInputFromArray<tstring>(TensorShape({1}), {"a"});
|
||||
AddInputFromArray<int64>(TensorShape({2}), {0, 1});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values({"LP|LP|LP|LP|a", "LP|LP|LP|a|RP",
|
||||
"LP|LP|a|RP|RP", "LP|a|RP|RP|RP",
|
||||
"a|RP|RP|RP|RP"});
|
||||
std::vector<tstring> expected_values({"LP|LP|LP|LP|a", "LP|LP|LP|a|RP",
|
||||
"LP|LP|a|RP|RP", "LP|a|RP|RP|RP",
|
||||
"a|RP|RP|RP|RP"});
|
||||
std::vector<int64> expected_splits({0, 5});
|
||||
|
||||
assert_string_equal(expected_values, *GetOutput(0));
|
||||
@ -225,11 +225,11 @@ TEST_F(NgramKernelTest, TestUnpaddedTrigrams) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values({"a|b|c", "b|c|d"});
|
||||
std::vector<tstring> expected_values({"a|b|c", "b|c|d"});
|
||||
std::vector<int64> expected_splits({0, 2, 2});
|
||||
|
||||
assert_string_equal(expected_values, *GetOutput(0));
|
||||
@ -241,11 +241,11 @@ TEST_F(NgramKernelTest, TestUnpaddedTrigramsWithEmptySequence) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({4}), {0, 4, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values({"a|b|c", "b|c|d"});
|
||||
std::vector<tstring> expected_values({"a|b|c", "b|c|d"});
|
||||
std::vector<int64> expected_splits({0, 2, 2, 2});
|
||||
|
||||
assert_string_equal(expected_values, *GetOutput(0));
|
||||
@ -257,11 +257,11 @@ TEST_F(NgramKernelTest, TestUnpaddedTrigramsWithPreserveShort) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values({"a|b|c", "b|c|d", "e|f"});
|
||||
std::vector<tstring> expected_values({"a|b|c", "b|c|d", "e|f"});
|
||||
std::vector<int64> expected_splits({0, 2, 3});
|
||||
|
||||
assert_string_equal(expected_values, *GetOutput(0));
|
||||
@ -273,11 +273,11 @@ TEST_F(NgramKernelTest, TestUnpaddedTrigramsWithPreserveShortAndEmptySequence) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({4}), {0, 4, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values({"a|b|c", "b|c|d", "e|f"});
|
||||
std::vector<tstring> expected_values({"a|b|c", "b|c|d", "e|f"});
|
||||
std::vector<int64> expected_splits({0, 2, 2, 3});
|
||||
|
||||
assert_string_equal(expected_values, *GetOutput(0));
|
||||
@ -289,11 +289,11 @@ TEST_F(NgramKernelTest, TestUnpaddedTrigramsAndQuadgramsWithPreserveShort) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values({"a|b|c|d", "a|b|c", "b|c|d", "e|f"});
|
||||
std::vector<tstring> expected_values({"a|b|c|d", "a|b|c", "b|c|d", "e|f"});
|
||||
std::vector<int64> expected_splits({0, 3, 4});
|
||||
|
||||
assert_string_equal(expected_values, *GetOutput(0));
|
||||
@ -305,11 +305,11 @@ TEST_F(NgramKernelTest, TestUnpaddedBigramsAndTrigrams) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values(
|
||||
std::vector<tstring> expected_values(
|
||||
{"a|b", "b|c", "c|d", "a|b|c", "b|c|d", "e|f"});
|
||||
std::vector<int64> expected_splits({0, 5, 6});
|
||||
|
||||
@ -322,13 +322,13 @@ TEST_F(NgramKernelTest, TestUnpaddedBigramsAndTrigramsWithPreserveShort) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Note that in this case, because the bigram 'e|f' was already generated,
|
||||
// the op will not generate a special preserve_short bigram.
|
||||
std::vector<string> expected_values(
|
||||
std::vector<tstring> expected_values(
|
||||
{"a|b", "b|c", "c|d", "a|b|c", "b|c|d", "e|f"});
|
||||
std::vector<int64> expected_splits({0, 5, 6});
|
||||
|
||||
@ -341,13 +341,13 @@ TEST_F(NgramKernelTest, TestUnpaddedTrigramsAndBigramsWithPreserveShort) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Note that in this case, because the bigram 'e|f' was already generated,
|
||||
// the op will not generate a special preserve_short bigram.
|
||||
std::vector<string> expected_values(
|
||||
std::vector<tstring> expected_values(
|
||||
{"a|b|c", "b|c|d", "a|b", "b|c", "c|d", "e|f"});
|
||||
std::vector<int64> expected_splits({0, 5, 6});
|
||||
|
||||
@ -360,11 +360,11 @@ TEST_F(NgramKernelTest, TestUnpaddedBigrams) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values({"a|b", "b|c", "c|d", "e|f"});
|
||||
std::vector<tstring> expected_values({"a|b", "b|c", "c|d", "e|f"});
|
||||
std::vector<int64> expected_splits({0, 3, 4});
|
||||
|
||||
assert_string_equal(expected_values, *GetOutput(0));
|
||||
@ -377,11 +377,11 @@ TEST_F(NgramKernelTest, TestOverlappingUnpaddedNGrams) {
|
||||
// 0: "a"
|
||||
// 1: "b", "c", "d"
|
||||
// 2: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values({"b|c|d"});
|
||||
std::vector<tstring> expected_values({"b|c|d"});
|
||||
std::vector<int64> expected_splits({0, 0, 1, 1});
|
||||
|
||||
assert_string_equal(expected_values, *GetOutput(0));
|
||||
@ -394,11 +394,11 @@ TEST_F(NgramKernelTest, TestOverlappingUnpaddedNGramsNoOutput) {
|
||||
// 0: "a"
|
||||
// 1: "b", "c", "d"
|
||||
// 2: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values({});
|
||||
std::vector<tstring> expected_values({});
|
||||
std::vector<int64> expected_splits({0, 0, 0, 0});
|
||||
|
||||
assert_string_equal(expected_values, *GetOutput(0));
|
||||
@ -410,12 +410,13 @@ TEST_F(NgramKernelTest, TestSinglyPaddedTrigrams) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values({"LP|a|b", "a|b|c", "b|c|d", "c|d|RP", //
|
||||
"LP|e|f", "e|f|RP"});
|
||||
std::vector<tstring> expected_values({"LP|a|b", "a|b|c", "b|c|d",
|
||||
"c|d|RP", //
|
||||
"LP|e|f", "e|f|RP"});
|
||||
std::vector<int64> expected_splits({0, 4, 6});
|
||||
|
||||
assert_string_equal(expected_values, *GetOutput(0));
|
||||
@ -427,12 +428,12 @@ TEST_F(NgramKernelTest, TestSinglyPaddedBigrams) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values({"LP|a", "a|b", "b|c", "c|d", "d|RP", //
|
||||
"LP|e", "e|f", "f|RP"});
|
||||
std::vector<tstring> expected_values({"LP|a", "a|b", "b|c", "c|d", "d|RP", //
|
||||
"LP|e", "e|f", "f|RP"});
|
||||
std::vector<int64> expected_splits({0, 5, 8});
|
||||
|
||||
assert_string_equal(expected_values, *GetOutput(0));
|
||||
@ -444,11 +445,11 @@ TEST_F(NgramKernelTest, TestSinglyPaddedBigramsAnd5grams) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values( //
|
||||
std::vector<tstring> expected_values( //
|
||||
{"LP|a", "a|b", "b|c", "c|d", "d|RP", "LP|a|b|c|d", "a|b|c|d|RP", //
|
||||
"LP|e", "e|f", "f|RP"});
|
||||
std::vector<int64> expected_splits({0, 7, 10});
|
||||
@ -462,12 +463,12 @@ TEST_F(NgramKernelTest, TestSinglyPadded5gramsWithPreserveShort) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values( //
|
||||
{"LP|a|b|c|d", "a|b|c|d|RP", //
|
||||
std::vector<tstring> expected_values( //
|
||||
{"LP|a|b|c|d", "a|b|c|d|RP", //
|
||||
"LP|e|f|RP"});
|
||||
std::vector<int64> expected_splits({0, 2, 3});
|
||||
|
||||
@ -481,11 +482,11 @@ TEST_F(NgramKernelTest, TestOverlappingSinglyPaddedNGrams) {
|
||||
// 0: "a"
|
||||
// 1: "b", "c", "d"
|
||||
// 2: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values(
|
||||
std::vector<tstring> expected_values(
|
||||
{"LP|a|RP", // ngrams for elem. 0
|
||||
"LP|b|c", "b|c|d", "c|d|RP", // ngrams for elem. 1
|
||||
"LP|e|f", "e|f|RP"}); // ngrams for elem. 2
|
||||
@ -501,11 +502,11 @@ TEST_F(NgramKernelTest, TestOverlappingSinglyPaddedNGramsNoOutput) {
|
||||
// 0: "a"
|
||||
// 1: "b", "c", "d"
|
||||
// 2: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values({"LP|b|c|d|RP"});
|
||||
std::vector<tstring> expected_values({"LP|b|c|d|RP"});
|
||||
std::vector<int64> expected_splits({0, 0, 1, 1});
|
||||
|
||||
assert_string_equal(expected_values, *GetOutput(0));
|
||||
@ -517,11 +518,11 @@ TEST_F(NgramKernelTest, TestSinglyPaddedUnigrams) {
|
||||
// Batch items are:
|
||||
// 0: "a", "b", "c", "d"
|
||||
// 1: "e", "f"
|
||||
AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<tstring>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
|
||||
AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values({"a", "b", "c", "d", "e", "f"});
|
||||
std::vector<tstring> expected_values({"a", "b", "c", "d", "e", "f"});
|
||||
std::vector<int64> expected_splits({0, 4, 6});
|
||||
|
||||
assert_string_equal(expected_values, *GetOutput(0));
|
||||
@ -530,11 +531,11 @@ TEST_F(NgramKernelTest, TestSinglyPaddedUnigrams) {
|
||||
|
||||
TEST_F(NgramKernelTest, TestEmptyInput) {
|
||||
MakeOp("|", {1}, "LP", "RP", 3, false);
|
||||
AddInputFromArray<string>(TensorShape({0}), {});
|
||||
AddInputFromArray<tstring>(TensorShape({0}), {});
|
||||
AddInputFromArray<int64>(TensorShape({0}), {});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::vector<string> expected_values({});
|
||||
std::vector<tstring> expected_values({});
|
||||
std::vector<int64> expected_splits({});
|
||||
|
||||
assert_string_equal(expected_values, *GetOutput(0));
|
||||
|
@ -52,7 +52,7 @@ namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
void Encode(const UnicodeEncoding encoding, const icu::UnicodeString& in,
|
||||
string* out) {
|
||||
tstring* out) {
|
||||
if (encoding == UnicodeEncoding::UTF8) {
|
||||
out->clear();
|
||||
in.toUTF8String(*out);
|
||||
@ -330,7 +330,7 @@ class UnicodeTranscodeOp : public OpKernel {
|
||||
// Transcode the string from input encoding to the output_encoding_. If
|
||||
// non-valid characters are encountered, use the subst_/elide_replacement_
|
||||
// config to handle them.
|
||||
void Transcode(string* s, UConverter* input_encoder,
|
||||
void Transcode(tstring* s, UConverter* input_encoder,
|
||||
bool* found_any_format_error) {
|
||||
icu::UnicodeString source;
|
||||
IterateUnicodeString(
|
||||
@ -561,9 +561,9 @@ class UnicodeEncodeOp : public OpKernel {
|
||||
appendable_unicode_string.appendCodePoint(code_point);
|
||||
}
|
||||
// Encode our string and save in the output.
|
||||
string result;
|
||||
tstring result;
|
||||
Encode(encoding_, unicode_string, &result);
|
||||
output_tensor_flat(i - 1) = result;
|
||||
output_tensor_flat(i - 1) = std::move(result);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -34,8 +34,8 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
static Status ReadEntireFile(Env* env, const string& filename,
|
||||
string* contents) {
|
||||
template <typename T>
|
||||
static Status ReadEntireFile(Env* env, const string& filename, T* contents) {
|
||||
std::unique_ptr<RandomAccessFile> file;
|
||||
TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
|
||||
io::RandomAccessInputStream input_stream(file.get());
|
||||
@ -112,8 +112,8 @@ class ReadFileOp : public OpKernel {
|
||||
OP_REQUIRES_OK(context, context->allocate_output("contents",
|
||||
TensorShape({}), &output));
|
||||
OP_REQUIRES_OK(context,
|
||||
ReadEntireFile(context->env(), input->scalar<string>()(),
|
||||
&output->scalar<string>()()));
|
||||
ReadEntireFile(context->env(), input->scalar<tstring>()(),
|
||||
&output->scalar<tstring>()()));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -167,7 +167,8 @@ Status BufferedInputStream::Seek(int64 position) {
|
||||
return SkipNBytes(position - bufpos);
|
||||
}
|
||||
|
||||
Status BufferedInputStream::ReadAll(string* result) {
|
||||
template <typename T>
|
||||
Status BufferedInputStream::ReadAll(T* result) {
|
||||
result->clear();
|
||||
Status status;
|
||||
while (status.ok()) {
|
||||
@ -186,6 +187,11 @@ Status BufferedInputStream::ReadAll(string* result) {
|
||||
return status;
|
||||
}
|
||||
|
||||
template Status BufferedInputStream::ReadAll<string>(string* result);
|
||||
#ifdef USE_TSTRING
|
||||
template Status BufferedInputStream::ReadAll<tstring>(tstring* result);
|
||||
#endif // USE_TSTRING
|
||||
|
||||
Status BufferedInputStream::Reset() {
|
||||
TF_RETURN_IF_ERROR(input_stream_->Reset());
|
||||
pos_ = 0;
|
||||
|
@ -79,7 +79,8 @@ class BufferedInputStream : public InputStreamInterface {
|
||||
//
|
||||
// Note: the amount of memory used by this function call is unbounded, so only
|
||||
// use in ops that expect that behavior.
|
||||
Status ReadAll(string* result);
|
||||
template <typename T>
|
||||
Status ReadAll(T* result);
|
||||
|
||||
Status Reset() override;
|
||||
|
||||
|
@ -84,7 +84,7 @@ void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize) {
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize,
|
||||
string *destination) {
|
||||
tstring *destination) {
|
||||
MemDestMgr *dest;
|
||||
if (cinfo->dest == nullptr) {
|
||||
cinfo->dest = reinterpret_cast<struct jpeg_destination_mgr *>(
|
||||
|
@ -33,7 +33,7 @@ typedef struct {
|
||||
JOCTET *buffer;
|
||||
int bufsize;
|
||||
int datacount;
|
||||
string *dest;
|
||||
tstring *dest;
|
||||
} MemDestMgr;
|
||||
|
||||
typedef struct {
|
||||
@ -52,7 +52,7 @@ void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize);
|
||||
// Same as above, except that buffer is only used as a temporary structure and
|
||||
// is emptied into "destination" as soon as it fills up.
|
||||
void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize,
|
||||
string *destination);
|
||||
tstring *destination);
|
||||
|
||||
} // namespace jpeg
|
||||
} // namespace tensorflow
|
||||
|
@ -592,7 +592,7 @@ bool GetImageInfo(const void* srcdata, int datasize, int* width, int* height,
|
||||
|
||||
namespace {
|
||||
bool CompressInternal(const uint8* srcdata, int width, int height,
|
||||
const CompressFlags& flags, string* output) {
|
||||
const CompressFlags& flags, tstring* output) {
|
||||
output->clear();
|
||||
const int components = (static_cast<int>(flags.format) & 0xff);
|
||||
|
||||
@ -762,14 +762,14 @@ bool CompressInternal(const uint8* srcdata, int width, int height,
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
bool Compress(const void* srcdata, int width, int height,
|
||||
const CompressFlags& flags, string* output) {
|
||||
const CompressFlags& flags, tstring* output) {
|
||||
return CompressInternal(static_cast<const uint8*>(srcdata), width, height,
|
||||
flags, output);
|
||||
}
|
||||
|
||||
string Compress(const void* srcdata, int width, int height,
|
||||
const CompressFlags& flags) {
|
||||
string temp;
|
||||
tstring Compress(const void* srcdata, int width, int height,
|
||||
const CompressFlags& flags) {
|
||||
tstring temp;
|
||||
CompressInternal(static_cast<const uint8*>(srcdata), width, height, flags,
|
||||
&temp);
|
||||
// If CompressInternal fails, temp will be empty.
|
||||
|
@ -149,12 +149,12 @@ struct CompressFlags {
|
||||
// The encoded data is returned as a string.
|
||||
// If not empty, XMP metadata can be embedded in the image header
|
||||
// On error, returns the empty string (which is never a valid jpeg).
|
||||
string Compress(const void* srcdata, int width, int height,
|
||||
const CompressFlags& flags);
|
||||
tstring Compress(const void* srcdata, int width, int height,
|
||||
const CompressFlags& flags);
|
||||
|
||||
// On error, returns false and sets output to empty.
|
||||
bool Compress(const void* srcdata, int width, int height,
|
||||
const CompressFlags& flags, string* output);
|
||||
const CompressFlags& flags, tstring* output);
|
||||
|
||||
} // namespace jpeg
|
||||
} // namespace tensorflow
|
||||
|
@ -326,7 +326,7 @@ TEST(JpegMemTest, Jpeg2) {
|
||||
CHECK_NE(string::npos, cpdata1.find(kXMP));
|
||||
|
||||
// Test the other API, where a storage string is supplied
|
||||
string cptest;
|
||||
tstring cptest;
|
||||
flags.stride = 0;
|
||||
Compress(refdata1.get(), in_w, in_h, flags, &cptest);
|
||||
CHECK_EQ(cptest, cpdata1);
|
||||
@ -465,7 +465,7 @@ TEST(JpegMemTest, ChromaDownsampling) {
|
||||
flags.format = FORMAT_RGB;
|
||||
flags.quality = 85;
|
||||
flags.chroma_downsampling = downsample;
|
||||
string recompressed;
|
||||
tstring recompressed;
|
||||
Compress(uncompressed.get(), w, h, flags, &recompressed);
|
||||
CHECK(!recompressed.empty());
|
||||
CHECK_EQ(IsChromaDownsampled(recompressed), downsample);
|
||||
|
@ -105,8 +105,9 @@ void StringReader(png_structp png_ptr, png_bytep data, png_size_t length) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void StringWriter(png_structp png_ptr, png_bytep data, png_size_t length) {
|
||||
string* const s = absl::bit_cast<string*>(png_get_io_ptr(png_ptr));
|
||||
T* const s = absl::bit_cast<T*>(png_get_io_ptr(png_ptr));
|
||||
s->append(absl::bit_cast<const char*>(data), length);
|
||||
}
|
||||
|
||||
@ -340,9 +341,10 @@ bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) {
|
||||
return ok;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool WriteImageToBuffer(
|
||||
const void* image, int width, int height, int row_bytes, int num_channels,
|
||||
int channel_bits, int compression, string* png_string,
|
||||
int channel_bits, int compression, T* png_string,
|
||||
const std::vector<std::pair<string, string> >* metadata) {
|
||||
CHECK_NOTNULL(image);
|
||||
CHECK_NOTNULL(png_string);
|
||||
@ -384,7 +386,7 @@ bool WriteImageToBuffer(
|
||||
return false;
|
||||
}
|
||||
|
||||
png_set_write_fn(png_ptr, png_string, StringWriter, StringWriterFlush);
|
||||
png_set_write_fn(png_ptr, png_string, StringWriter<T>, StringWriterFlush);
|
||||
if (compression < 0) compression = Z_DEFAULT_COMPRESSION;
|
||||
png_set_compression_level(png_ptr, compression);
|
||||
png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL);
|
||||
@ -418,5 +420,16 @@ bool WriteImageToBuffer(
|
||||
return true;
|
||||
}
|
||||
|
||||
template bool WriteImageToBuffer<string>(
|
||||
const void* image, int width, int height, int row_bytes, int num_channels,
|
||||
int channel_bits, int compression, string* png_string,
|
||||
const std::vector<std::pair<string, string> >* metadata);
|
||||
#ifdef USE_TSTRING
|
||||
template bool WriteImageToBuffer<tstring>(
|
||||
const void* image, int width, int height, int row_bytes, int num_channels,
|
||||
int channel_bits, int compression, tstring* png_string,
|
||||
const std::vector<std::pair<string, string> >* metadata);
|
||||
#endif // USE_TSTRING
|
||||
|
||||
} // namespace png
|
||||
} // namespace tensorflow
|
||||
|
@ -94,9 +94,10 @@ void CommonFreeDecode(DecodeContext* context);
|
||||
// compression is in [-1,9], where 0 is fast and weak compression, 9 is slow
|
||||
// and strong, and -1 is the zlib default.
|
||||
|
||||
template <typename T>
|
||||
bool WriteImageToBuffer(
|
||||
const void* image, int width, int height, int row_bytes, int num_channels,
|
||||
int channel_bits, int compression, string* png_string,
|
||||
int channel_bits, int compression, T* png_string,
|
||||
const std::vector<std::pair<string, string> >* metadata);
|
||||
|
||||
} // namespace png
|
||||
|
@ -73,7 +73,8 @@ Status DecodeThreeChars(const char* codes, char* result) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status Base64Decode(StringPiece data, string* decoded) {
|
||||
template <typename T>
|
||||
Status Base64Decode(StringPiece data, T* decoded) {
|
||||
if (decoded == nullptr) {
|
||||
return errors::Internal("'decoded' cannot be nullptr.");
|
||||
}
|
||||
@ -135,11 +136,13 @@ Status Base64Decode(StringPiece data, string* decoded) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Base64Encode(StringPiece source, string* encoded) {
|
||||
template <typename T>
|
||||
Status Base64Encode(StringPiece source, T* encoded) {
|
||||
return Base64Encode(source, false, encoded);
|
||||
}
|
||||
|
||||
Status Base64Encode(StringPiece source, bool with_padding, string* encoded) {
|
||||
template <typename T>
|
||||
Status Base64Encode(StringPiece source, bool with_padding, T* encoded) {
|
||||
const char* const base64_chars = kBase64UrlSafeChars;
|
||||
if (encoded == nullptr) {
|
||||
return errors::Internal("'encoded' cannot be nullptr.");
|
||||
@ -191,4 +194,16 @@ Status Base64Encode(StringPiece source, bool with_padding, string* encoded) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template Status Base64Decode<string>(StringPiece data, string* decoded);
|
||||
template Status Base64Encode<string>(StringPiece source, string* encoded);
|
||||
template Status Base64Encode<string>(StringPiece source, bool with_padding,
|
||||
string* encoded);
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
template Status Base64Decode<tstring>(StringPiece data, tstring* decoded);
|
||||
template Status Base64Encode<tstring>(StringPiece source, tstring* encoded);
|
||||
template Status Base64Encode<tstring>(StringPiece source, bool with_padding,
|
||||
tstring* encoded);
|
||||
#endif // USE_TSTRING
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -24,13 +24,17 @@ namespace tensorflow {
|
||||
/// \brief Converts data into web-safe base64 encoding.
|
||||
///
|
||||
/// See https://en.wikipedia.org/wiki/Base64
|
||||
Status Base64Encode(StringPiece data, bool with_padding, string* encoded);
|
||||
Status Base64Encode(StringPiece data, string* encoded); // with_padding=false.
|
||||
template <typename T>
|
||||
Status Base64Encode(StringPiece source, bool with_padding, T* encoded);
|
||||
template <typename T>
|
||||
Status Base64Encode(StringPiece source,
|
||||
T* encoded); // with_padding=false.
|
||||
|
||||
/// \brief Converts data from web-safe base64 encoding.
|
||||
///
|
||||
/// See https://en.wikipedia.org/wiki/Base64
|
||||
Status Base64Decode(StringPiece data, string* decoded);
|
||||
template <typename T>
|
||||
Status Base64Decode(StringPiece data, T* decoded);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -21,11 +21,11 @@ namespace tensorflow {
|
||||
|
||||
TEST(Base64, EncodeDecode) {
|
||||
const string original = "a simple test message!";
|
||||
string encoded;
|
||||
tstring encoded;
|
||||
TF_EXPECT_OK(Base64Encode(original, &encoded));
|
||||
EXPECT_EQ("YSBzaW1wbGUgdGVzdCBtZXNzYWdlIQ", encoded);
|
||||
|
||||
string decoded;
|
||||
tstring decoded;
|
||||
TF_EXPECT_OK(Base64Decode(encoded, &decoded));
|
||||
EXPECT_EQ(original, decoded);
|
||||
}
|
||||
|
@ -132,9 +132,10 @@ Status ReadString(const string& data, int expected_length, string* value,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
|
||||
size_t num_channels, size_t num_frames,
|
||||
string* wav_string) {
|
||||
T* wav_string) {
|
||||
constexpr size_t kFormatChunkSize = 16;
|
||||
constexpr size_t kCompressionCodePcm = 1;
|
||||
constexpr size_t kBitsPerSample = 16;
|
||||
@ -173,7 +174,7 @@ Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
|
||||
}
|
||||
|
||||
wav_string->resize(file_size);
|
||||
char* data = &wav_string->at(0);
|
||||
char* data = &(*wav_string)[0];
|
||||
WavHeader* header = absl::bit_cast<WavHeader*>(data);
|
||||
|
||||
// Fill RIFF chunk.
|
||||
@ -208,6 +209,19 @@ Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template Status EncodeAudioAsS16LEWav<string>(const float* audio,
|
||||
size_t sample_rate,
|
||||
size_t num_channels,
|
||||
size_t num_frames,
|
||||
string* wav_string);
|
||||
#ifdef USE_TSTRING
|
||||
template Status EncodeAudioAsS16LEWav<tstring>(const float* audio,
|
||||
size_t sample_rate,
|
||||
size_t num_channels,
|
||||
size_t num_frames,
|
||||
tstring* wav_string);
|
||||
#endif // USE_TSTRING
|
||||
|
||||
Status DecodeLin16WaveAsFloatVector(const string& wav_string,
|
||||
std::vector<float>* float_values,
|
||||
uint32* sample_count, uint16* channel_count,
|
||||
|
@ -41,9 +41,10 @@ namespace wav {
|
||||
// if (EncodeAudioAsS16LEWav(audio_buffer, 8000, 2, 4, &wav_string).ok()) {
|
||||
// // Use wav_string.
|
||||
// }
|
||||
template <typename T>
|
||||
Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
|
||||
size_t num_channels, size_t num_frames,
|
||||
string* wav_string);
|
||||
T* wav_string);
|
||||
|
||||
// Decodes the little-endian signed 16-bit PCM WAV file data (aka LIN16
|
||||
// encoding) into a float Tensor. The channels are encoded as the lowest
|
||||
|
@ -34,12 +34,13 @@ Status ReadString(const string& data, int expected_length, string* value,
|
||||
|
||||
TEST(WavIO, BadArguments) {
|
||||
float audio[] = {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
|
||||
string result;
|
||||
tstring result;
|
||||
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT,
|
||||
EncodeAudioAsS16LEWav(nullptr, 44100, 2, 3, &result).code());
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT,
|
||||
EncodeAudioAsS16LEWav(audio, 44100, 2, 3, nullptr).code());
|
||||
EXPECT_EQ(
|
||||
error::INVALID_ARGUMENT,
|
||||
EncodeAudioAsS16LEWav(audio, 44100, 2, 3, (tstring*)nullptr).code());
|
||||
|
||||
const size_t kuint32max_plus_one = static_cast<size_t>(kuint32max) + 1;
|
||||
const size_t kuint16max_plus_one = static_cast<size_t>(kuint16max) + 1;
|
||||
|
@@ -20,4 +20,43 @@ namespace tensorflow {
  const char* kProtobufInt64Typename = "::tensorflow::protobuf_int64";
  const char* kProtobufUint64Typename = "::tensorflow::protobuf_uint64";

+ #ifdef USE_TSTRING
+ TStringOutputStream::TStringOutputStream(tstring* target) : target_(target) {}
+
+ bool TStringOutputStream::Next(void** data, int* size) {
+ int old_size = target_->size();
+
+ // Grow the string.
+ if (old_size < target_->capacity()) {
+ // Resize the string to match its capacity, since we can get away
+ // without a memory allocation this way.
+ target_->resize_uninitialized(target_->capacity());
+ } else {
+ // Size has reached capacity, try to double the size.
+ if (old_size > std::numeric_limits<int>::max() / 2) {
+ // Can not double the size otherwise it is going to cause integer
+ // overflow in the expression below: old_size * 2 ";
+ return false;
+ }
+ // Double the size, also make sure that the new size is at least
+ // kMinimumSize.
+ target_->resize_uninitialized(
+ std::max(old_size * 2,
+ kMinimumSize + 0));  // "+ 0" works around GCC4 weirdness.
+ }
+
+ *data = target_->data() + old_size;
+ *size = target_->size() - old_size;
+ return true;
+ }
+
+ void TStringOutputStream::BackUp(int count) {
+ target_->resize(target_->size() - count);
+ }
+
+ protobuf::io::ByteCountInt64 TStringOutputStream::ByteCount() const {
+ return target_->size();
+ }
+ #endif  // USE_TSTRING

  }  // namespace tensorflow
@@ -90,6 +90,29 @@ inline bool SerializeToTString(const protobuf::MessageLite& proto,
  #endif  // USE_TSTRING
  }

+ #ifdef USE_TSTRING
+ // Analogue to StringOutputStream for tstring.
+ class TStringOutputStream : public protobuf::io::ZeroCopyOutputStream {
+ public:
+ explicit TStringOutputStream(tstring* target);
+ ~TStringOutputStream() override = default;
+
+ TStringOutputStream(const TStringOutputStream&) = delete;
+ void operator=(const TStringOutputStream&) = delete;
+
+ bool Next(void** data, int* size) override;
+ void BackUp(int count) override;
+ protobuf::io::ByteCountInt64 ByteCount() const override;
+
+ private:
+ static const int kMinimumSize = 16;
+
+ tstring* target_;
+ };
+ #else  // USE_TSTRING
+ typedef protobuf::io::StringOutputStream TStringOutputStream;
+ #endif  // USE_TSTRING

  }  // namespace tensorflow

  #endif  // TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_
@@ -157,6 +157,8 @@ class tstring {
  size_t size() const { return str_.size(); }

+ size_t capacity() const { return str_.capacity(); }
+
  const char* c_str() const { return str_.c_str(); }

  const char* data() const { return str_.data(); }

@@ -207,6 +209,8 @@ class tstring {
  return *this;
  }

+ void push_back(char ch) { str_.push_back(ch); }
+
  friend const tstring operator+(const tstring& a, const tstring& b);
  friend bool operator==(const char* a, const tstring& b);
  friend bool operator==(const std::string& a, const tstring& b);
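To close, a minimal usage sketch (illustrative only, not part of the commit) of the TStringOutputStream introduced above: it adapts a tstring to protobuf's ZeroCopyOutputStream interface, so serialization can go through CodedOutputStream exactly as EncodeProtoOp does earlier in this diff. When USE_TSTRING is not defined, the typedef falls back to StringOutputStream, so the same code compiles either way. The helper name here is hypothetical; the existing SerializeToTString helper visible in the protobuf.h hunk context serves a similar purpose via its own implementation.

    // Hypothetical helper -- shows the intended wiring of TStringOutputStream.
    #include "tensorflow/core/platform/protobuf.h"
    #include "tensorflow/core/platform/types.h"

    namespace tensorflow {

    bool SerializeMessageToTString(const protobuf::MessageLite& message, tstring* out) {
      out->clear();
      TStringOutputStream zero_copy_stream(out);           // tstring-backed ZeroCopyOutputStream
      protobuf::io::CodedOutputStream coded_stream(&zero_copy_stream);
      // Writes the wire format into *out; the coded stream's destructor trims
      // the buffer back (via BackUp) to the bytes actually written.
      return message.SerializeToCodedStream(&coded_stream);
    }

    }  // namespace tensorflow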