Add experimental custom parse_example op
PiperOrigin-RevId: 347640612
Change-Id: If4c83fc598391e8972f6b71015caf9def4d4bb7a

Parent: 67e713e624
Commit: af93fc7294
tensorflow/lite/kernels/parse_example/BUILD
@ -1,75 +0,0 @@
# Kernel for custom parse_example
package(
    default_visibility = [
        "//visibility:public",
    ],
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "parse_example",
    srcs = [
        "example_proto_fast_parsing.cc",
        "parse_example.cc",
    ],
    hdrs = [
        "example_proto_fast_parsing.h",
        "parse_example.h",
    ],
    deps = [
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/container:flat_hash_map",
        "@flatbuffers",
        "//tensorflow/lite:framework",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/kernels:kernel_util",
        "//tensorflow/lite/kernels/internal:tensor",
        "//tensorflow/lite:string_util",
    ] + select({
        "//tensorflow:android": [
            "//tensorflow/core:portable_tensorflow_lib_lite",
        ],
        "//tensorflow:ios": [
            "//tensorflow/core:portable_tensorflow_lib_lite",
        ],
        "//conditions:default": [
            "//tensorflow/core:core_cpu",
            "//tensorflow/core:feature_util",
            "//tensorflow/core:framework",
            "//tensorflow/core:framework_internal",
            "//tensorflow/core:lib",
            "//tensorflow/core:lib_internal",
            "//tensorflow/core:protos_all_cc",
        ],
    }),
)

cc_test(
    name = "parse_example_test",
    srcs = ["parse_example_test.cc"],
    deps = [
        ":parse_example",
        "@flatbuffers",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/core/api:op_resolver",
        "//tensorflow/lite/kernels:builtin_ops",
        "//tensorflow/lite/kernels:test_main",
        "//tensorflow/lite/kernels:test_util",
        "//tensorflow/lite/schema:schema_fbs",
        "//tensorflow/lite:framework",
        "//tensorflow/lite:string_util",
    ] + select({
        "//tensorflow:android": [
            "//tensorflow/core:portable_tensorflow_lib_lite",
        ],
        "//tensorflow:ios": [
            "//tensorflow/core:portable_tensorflow_lib_lite",
        ],
        "//conditions:default": [
            "//tensorflow/core:protos_all_cc",
            "//tensorflow/core/example:feature_util",
            "//tensorflow/core/platform:protobuf",
            "//tensorflow/core/platform:tstring",
        ],
    }),
)
tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.cc
@ -1,170 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h"

namespace tensorflow {
namespace example {

string ExampleName(const gtl::ArraySlice<tstring> example_names, int n) {
  return example_names.empty() ? "<unknown>" : example_names[n];
}

void CountSparseFeatures(
    const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
    size_t* total_num_features, size_t* max_num_features) {
  for (auto& sparse_values_tmp : sparse_buffers) {
    const std::vector<size_t>& end_indices =
        sparse_values_tmp[d].example_end_indices;
    *total_num_features += end_indices.back();
    *max_num_features = std::max(*max_num_features, end_indices[0]);
    for (size_t i = 1; i < end_indices.size(); ++i) {
      size_t example_size = end_indices[i] - end_indices[i - 1];
      *max_num_features = std::max(*max_num_features, example_size);
    }
  }
}

void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
                              Tensor* dst) {
  switch (dtype) {
    case DT_INT64: {
      std::copy(src->int64_list.begin(), src->int64_list.end(),
                dst->flat<int64>().data() + offset);
      break;
    }
    case DT_FLOAT: {
      std::copy(src->float_list.begin(), src->float_list.end(),
                dst->flat<float>().data() + offset);
      break;
    }
    case DT_STRING: {
      std::move(src->bytes_list.begin(), src->bytes_list.end(),
                dst->flat<tstring>().data() + offset);
      break;
    }
    default:
      ReportUnexpectedDataType(dtype);
  }
}

uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
  DCHECK(stream != nullptr);
  const void* ptr;
  int size;
  if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
  return *static_cast<const uint8*>(ptr);
}

bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) {
  DCHECK(stream != nullptr);
  DCHECK(result != nullptr);
  uint32 length;
  if (!stream->ReadVarint32(&length)) return false;
  if (length == 0) {
    *result = StringPiece(nullptr, 0);
    return true;
  }
  const void* stream_alias;
  int stream_size;
  if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
    return false;
  }
  if (static_cast<uint32>(stream_size) < length) return false;
  *result = StringPiece(static_cast<const char*>(stream_alias), length);
  stream->Skip(length);
  return true;
}

bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
                          parsed::FeatureMapEntry* feature_map_entry) {
  DCHECK(stream != nullptr);
  DCHECK(feature_map_entry != nullptr);
  uint32 length;
  if (!stream->ReadVarint32(&length)) return false;
  auto limit = stream->PushLimit(length);
  if (!stream->ExpectTag(kDelimitedTag(1))) return false;
  if (!ParseString(stream, &feature_map_entry->first)) return false;
  if (!stream->ExpectTag(kDelimitedTag(2))) return false;
  StringPiece feature_string_piece;
  if (!ParseString(stream, &feature_string_piece)) return false;
  feature_map_entry->second = parsed::Feature(feature_string_piece);
  if (!stream->ExpectAtEnd()) return false;
  stream->PopLimit(limit);
  return true;
}

bool ParseFeatures(protobuf::io::CodedInputStream* stream,
                   parsed::Example* example) {
  DCHECK(stream != nullptr);
  DCHECK(example != nullptr);
  uint32 length;
  if (!stream->ReadVarint32(&length)) return false;
  auto limit = stream->PushLimit(length);
  while (!stream->ExpectAtEnd()) {
    parsed::FeatureMapEntry feature_map_entry;
    if (!stream->ExpectTag(kDelimitedTag(1))) return false;
    if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false;
    example->push_back(std::move(feature_map_entry));
  }
  stream->PopLimit(limit);
  return true;
}

bool ParseExample(protobuf::io::CodedInputStream* stream,
                  parsed::Example* example) {
  DCHECK(stream != nullptr);
  DCHECK(example != nullptr);
  // Loop over the input stream which may contain multiple serialized Example
  // protos merged together as strings. This behavior is consistent with
  // Proto's ParseFromString when string representations are concatenated.
  while (!stream->ExpectAtEnd()) {
    if (!stream->ExpectTag(kDelimitedTag(1))) {
      if (!SkipExtraneousTag(stream)) return false;
    } else {
      if (!ParseFeatures(stream, example)) return false;
    }
  }
  return true;
}

bool ParseExample(StringPiece serialized, parsed::Example* example) {
  DCHECK(example != nullptr);
  protobuf::io::CodedInputStream stream(
      reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
  EnableAliasing(&stream);
  return ParseExample(&stream, example);
}

template <>
void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t) {
  std::move(b, e, t);
}

template <>
const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer) {
  return buffer.int64_list;
}
template <>
const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
  return buffer.float_list;
}
template <>
const SmallVector<tstring>& GetListFromBuffer<tstring>(
    const SparseBuffer& buffer) {
  return buffer.bytes_list;
}

}  // namespace example
}  // namespace tensorflow
tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h
@ -1,688 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
#define TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
#include "tensorflow/core/util/example_proto_fast_parsing.h"

#include <vector>

#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/presized_cuckoo_map.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {
namespace example {

template <typename T>
using SmallVector = gtl::InlinedVector<T, 4>;

template <typename T>
class LimitedArraySlice {
 public:
  using value_type = T;

  LimitedArraySlice(T* begin, size_t num_elements)
      : current_(begin), begin_(begin), end_(begin + num_elements) {}

  // May return negative if there were push_back calls after slice was filled.
  int64 EndDistance() const { return end_ - current_; }

  // Attempts to push value to the back of this. If the slice has
  // already been filled, this method has no effect on the underlying data, but
  // it changes the number returned by EndDistance into negative values.
  void push_back(T&& value) {
    if (EndDistance() > 0) *current_ = std::move(value);
    ++current_;
  }

  // "Constructs" an element at the back of this by resizing the slice, and
  // returns a mutable reference to the new last element.
  // REQUIRES: EndDistance() > 0.
  T& construct_at_end() {
    DCHECK_GT(EndDistance(), 0);
    return *(current_++);
  }

  // Returns a mutable reference to the last element in the slice.
  // REQUIRES: size() > 0.
  T& back() { return *(current_ - 1); }

  // Returns the number of elements in the slice.
  size_t size() const { return std::min(current_ - begin_, end_ - begin_); }

  // Attempts to resize the vector to the given size. It does so by advancing
  // the pointer to the current element, possibly beyond the end of the slice.
  // As a consequence, calling `size()` after `resize(x)` was called might
  // return a value less than `x`.
  void resize(size_t size) { current_ = begin_ + size; }

  // Returns the pointer to the underlying data buffer.
  T* data() { return begin_; }

 private:
  T* current_;
  T* begin_;
  T* end_;
};

template <typename A>
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
  a->EnableAliasing(true);
}

template <typename A>
void EnableAliasing(A&& a) {}

uint8 PeekTag(protobuf::io::CodedInputStream* stream);

constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }

namespace parsed {

// ParseDataType has to be called first, then appropriate ParseZzzzList.
class Feature {
 public:
  Feature() {}
  explicit Feature(StringPiece serialized) : serialized_(serialized) {}

  Status ParseDataType(DataType* dtype) {
    DCHECK(dtype != nullptr);
    if (serialized_.empty()) {
      *dtype = DT_INVALID;
      return Status::OK();
    }
    uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
    serialized_.remove_prefix(1);
    switch (oneof_tag) {
      case kDelimitedTag(1):
        *dtype = DT_STRING;
        break;
      case kDelimitedTag(2):
        *dtype = DT_FLOAT;
        break;
      case kDelimitedTag(3):
        *dtype = DT_INT64;
        break;
      default:
        // Initialize variable to avoid compiler warning
        *dtype = DT_INVALID;
        return errors::InvalidArgument("Unsupported datatype.");
    }
    return Status::OK();
  }

  bool GetNumElementsInBytesList(int* num_elements) {
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length = 0;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);
    *num_elements = 0;
    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
      uint32 bytes_length = 0;
      if (!stream.ReadVarint32(&bytes_length)) return false;
      if (!stream.Skip(bytes_length)) return false;
      ++*num_elements;
    }
    stream.PopLimit(limit);
    return true;
  }

  // Helper methods
  tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
    if (bytes_list->EndDistance() <= 0) {
      return nullptr;
    }
    return &bytes_list->construct_at_end();
  }
  tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
    return &bytes_list->emplace_back();
  }

  template <typename Result>
  bool ParseBytesList(Result* bytes_list) {
    DCHECK(bytes_list != nullptr);

    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());

    EnableAliasing(&stream);

    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
      // parse string
      uint32 bytes_length;
      if (!stream.ReadVarint32(&bytes_length)) return false;
      tstring* bytes = construct_at_end(bytes_list);
      if (bytes == nullptr) return false;
      bytes->resize_uninitialized(bytes_length);
      if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
    }
    stream.PopLimit(limit);
    return true;
  }

  template <typename Result>
  bool ParseFloatList(Result* float_list) {
    DCHECK(float_list != nullptr);
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    if (!stream.ExpectAtEnd()) {
      uint8 peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
        return false;
      }

      constexpr int32 kNumFloatBytes = 4;
      if (peek_tag == kDelimitedTag(1)) {                       // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
        uint32 packed_length;
        if (!stream.ReadVarint32(&packed_length)) return false;
        auto packed_limit = stream.PushLimit(packed_length);

        // Store the initial size to know the offset we have to start writing
        // data from before resizing the output "vector".
        const size_t initial_size = float_list->size();
        float_list->resize(initial_size + packed_length / kNumFloatBytes);

        // If the result data type is float and we are on a little endian
        // machine then we can simply memcpy the data from the proto into the
        // result vector.
        if (port::kLittleEndian &&
            sizeof(typename Result::value_type) == kNumFloatBytes) {
          // Calculate the length of the buffer available what can be less than
          // what we requested in resize in case of a LimitedArraySlice.
          const uint32 bytes_to_copy =
              std::min(static_cast<uint32>((float_list->size() - initial_size) *
                                           kNumFloatBytes),
                       packed_length);
          if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
            return false;
        } else {
          int64 index = initial_size;
          while (!stream.ExpectAtEnd()) {
            uint32 buffer32;
            if (!stream.ReadLittleEndian32(&buffer32)) return false;
            if (index < float_list->size()) {
              float_list->data()[index] = absl::bit_cast<float>(buffer32);
              ++index;
            }
          }
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        const size_t initial_size = float_list->size();
        // 1 byte for the tag (`1` encoded as Variant32) and kNumFloatBytes for
        // the value.
        const int64 num_elements =
            stream.BytesUntilLimit() / (1 + kNumFloatBytes);
        float_list->resize(initial_size + num_elements);
        int64 index = initial_size;
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kFixed32Tag(1))) return false;
          uint32 buffer32;
          if (!stream.ReadLittleEndian32(&buffer32)) return false;
          float_list->data()[index] = absl::bit_cast<float>(buffer32);
          ++index;
        }
      }
    }

    stream.PopLimit(limit);
    return true;
  }

  template <typename Result>
  bool ParseInt64List(Result* int64_list) {
    DCHECK(int64_list != nullptr);
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    if (!stream.ExpectAtEnd()) {
      uint8 peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
        return false;
      }
      if (peek_tag == kDelimitedTag(1)) {                       // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
        uint32 packed_length;
        if (!stream.ReadVarint32(&packed_length)) return false;
        auto packed_limit = stream.PushLimit(packed_length);

        while (!stream.ExpectAtEnd()) {
          protobuf_uint64 n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) return false;
          int64_list->push_back(static_cast<int64>(n));
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kVarintTag(1))) return false;
          protobuf_uint64 n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) return false;
          int64_list->push_back(static_cast<int64>(n));
        }
      }
    }
    stream.PopLimit(limit);
    return true;
  }

  StringPiece GetSerialized() const { return serialized_; }

 private:
  StringPiece serialized_;
};

using FeatureMapEntry = std::pair<StringPiece, Feature>;
using Example = std::vector<FeatureMapEntry>;

}  // namespace parsed

inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
  uint32 data;
  protobuf_uint64 dummy;
  switch (stream->ReadTag() & 0x7) {
    case 0:  // varint
      if (!stream->ReadVarint32(&data)) return false;
      return true;
    case 1:  // fixed64
      if (!stream->ReadLittleEndian64(&dummy)) return false;
      return true;
    case 2:  // length delimited
      if (!stream->ReadVarint32(&data)) return false;
      stream->Skip(data);
      return true;
    case 3:          // group begin
      return false;  // groups not supported.
    case 4:          // group end
      return false;  // groups not supported.
    case 5:  // fixed32
      if (!stream->ReadLittleEndian32(&data)) return false;
      return true;
  }
  return false;  // unrecognized tag type
}

bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result);

bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
                          parsed::FeatureMapEntry* feature_map_entry);

bool ParseFeatures(protobuf::io::CodedInputStream* stream,
                   parsed::Example* example);

bool ParseExample(protobuf::io::CodedInputStream* stream,
                  parsed::Example* example);

bool ParseExample(StringPiece serialized, parsed::Example* example);

using Config = FastParseExampleConfig;

// Enumeration for distinguishing feature types.
// Note: FastParseSequenceExample constructs a map that includes Type values,
// and relies on the fact that they are default-initialized to Dense.
enum class Type { Dense, Sparse, Ragged };

// Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
struct SparseBuffer {
  // Features are in one of the 3 vectors below depending on config's dtype.
  // Other 2 vectors remain empty.
  SmallVector<tstring> bytes_list;
  SmallVector<float> float_list;
  SmallVector<int64> int64_list;

  // Features of example i are elements with indices
  // from example_end_indices[i-1] to example_end_indices[i]-1 on the
  // appropriate xxxxx_list
  std::vector<size_t> example_end_indices;
};

struct SeededHasher {
  uint64 operator()(StringPiece s) const {
    return Hash64(s.data(), s.size(), seed);
  }
  uint64 seed{0xDECAFCAFFE};
};

// Use this in the "default" clause of switch statements when dispatching
// on a dtype variable that was checked by CheckConfigDataType():
inline void ReportUnexpectedDataType(DataType dtype) {
  DCHECK(false)
      << "Encountered unexpected DataType " << DataTypeString(dtype)
      << "in variable that should have been checked by CheckConfigDataType().";
}

template <typename T>
const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);

template <>
const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer);

template <>
const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer);

template <>
const SmallVector<tstring>& GetListFromBuffer<tstring>(
    const SparseBuffer& buffer);

template <typename T>
void CopyOrMoveBlock(const T* b, const T* e, T* t) {
  std::copy(b, e, t);
}
template <>
void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t);

void CountSparseFeatures(
    const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
    size_t* total_num_features, size_t* max_num_features);

void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
                              Tensor* dst);

// A struct used by FastParseSequenceExample to hold the serialized proto
// substrings for a single feature, plus some auxiliary information derived
// from those protos (such as the total value length).
struct FeatureProtos {
  // Proto substrings from each serialized SequenceExample that correspond
  // with this feature. `protos_present` records whether the proto had a
  // value defined (even if that value is empty).
  std::vector<StringPiece> protos;
  std::vector<bool> protos_present;

  // Information derived from protos:
  size_t length;    // total length for ragged/sparse, max row length for dense.
  size_t num_rows;  // only populated for ragged sequence features.

  // Information from the config:
  Type type;  // Whether this feature is sparse, ragged, or dense.
  DataType dtype;
};

// Map from feature name to FeatureProtos for that feature.
using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>;

string ExampleName(const gtl::ArraySlice<tstring> example_names, int n);

// Return the number of bytes elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
                             tstring* out) {
  int num_elements = 0;
  uint32 length;
  if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    auto limit = stream->PushLimit(length);
    while (!stream->ExpectAtEnd()) {
      uint32 bytes_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&bytes_length)) {
        return -1;
      }
      if (out == nullptr) {
        stream->Skip(bytes_length);
      } else {
        out->resize_uninitialized(bytes_length);
        if (!stream->ReadRaw(out->data(), bytes_length)) {
          return -1;
        }
        out++;
      }
      num_elements++;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}

inline void PadFloatFeature(int num_to_pad, float* out) {
  for (int i = 0; i < num_to_pad; i++) {
    *out++ = 0.0;
  }
}

inline void PadInt64Feature(int num_to_pad, int64* out) {
  for (int i = 0; i < num_to_pad; i++) {
    *out++ = 0;
  }
}

// Return the number of float elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
                             float* out) {
  int num_elements = 0;
  uint32 length;
  if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    auto limit = stream->PushLimit(length);
    uint8 peek_tag = PeekTag(stream);
    if (peek_tag == kDelimitedTag(1)) {  // packed
      uint32 packed_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&packed_length)) {
        return -1;
      }
      auto packed_limit = stream->PushLimit(packed_length);
      while (!stream->ExpectAtEnd()) {
        uint32 buffer32;
        if (!stream->ReadLittleEndian32(&buffer32)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = absl::bit_cast<float>(buffer32);
        }
        num_elements++;
      }
      stream->PopLimit(packed_limit);
    } else if (peek_tag == kFixed32Tag(1)) {
      while (!stream->ExpectAtEnd()) {
        uint32 buffer32;
        if (!stream->ExpectTag(kFixed32Tag(1)) ||
            !stream->ReadLittleEndian32(&buffer32)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = absl::bit_cast<float>(buffer32);
        }
        num_elements++;
      }
    } else {
      // Unknown tag.
      return -1;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}

// Return the number of int64 elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
                             int64* out) {
  int num_elements = 0;
  uint32 length;
  if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    auto limit = stream->PushLimit(length);
    uint8 peek_tag = PeekTag(stream);
    if (peek_tag == kDelimitedTag(1)) {  // packed
      uint32 packed_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&packed_length)) {
        return -1;
      }
      auto packed_limit = stream->PushLimit(packed_length);
      while (!stream->ExpectAtEnd()) {
        protobuf_uint64 n;  // There is no API for int64
        if (!stream->ReadVarint64(&n)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = n;
        }
        num_elements++;
      }
      stream->PopLimit(packed_limit);
    } else if (peek_tag == kVarintTag(1)) {
      while (!stream->ExpectAtEnd()) {
        protobuf_uint64 n;  // There is no API for int64
        if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = n;
        }
        num_elements++;
      }
    } else {
      // Unknown tag.
      return -1;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}

// Parses the next feature on `stream` into `out` starting at `out_offset`.
// Updates `out_offset`, and returns the number of values added.
// Returns -1 if the next feature on `stream` doesn't match `dtype`.
inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream,
                        Tensor* out, size_t* out_offset) {
  int delta;
  switch (dtype) {
    case DT_STRING:
      delta =
          ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset);
      break;
    case DT_FLOAT:
      delta =
          ParseFloatFeature(stream, out->flat<float>().data() + *out_offset);
      break;
    case DT_INT64:
      delta =
          ParseInt64Feature(stream, out->flat<int64>().data() + *out_offset);
      break;
    default:
      ReportUnexpectedDataType(dtype);
      delta = 0;
  }
  if (delta > 0) {
    *out_offset += delta;
  }
  return delta;
}

// Returns the length of the next feature on `stream`.
// Returns -1 if the next feature on `stream` doesn't match `dtype`.
inline int GetFeatureLength(DataType dtype,
                            protobuf::io::CodedInputStream* stream) {
  switch (dtype) {
    case DT_STRING:
      return ParseBytesFeature(stream, nullptr);
    case DT_FLOAT:
      return ParseFloatFeature(stream, nullptr);
    case DT_INT64:
      return ParseInt64Feature(stream, nullptr);
    default:
      ReportUnexpectedDataType(dtype);
      return -1;
  }
}

inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
  uint8 peek_tag = PeekTag(stream);
  switch (peek_tag) {
    case kDelimitedTag(1):
      return DT_STRING;
    case kDelimitedTag(2):
      return DT_FLOAT;
    case kDelimitedTag(3):
      return DT_INT64;
    default:
      return DT_INVALID;
  }
}

inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
                             DataType dtype) {
  switch (dtype) {
    case DT_STRING:
      if (!stream->ExpectTag(kDelimitedTag(1))) {
        return false;
      }
      break;
    case DT_FLOAT:
      if (!stream->ExpectTag(kDelimitedTag(2))) {
        return false;
      }
      break;
    case DT_INT64:
      if (!stream->ExpectTag(kDelimitedTag(3))) {
        return false;
      }
      break;
    default:
      return false;
  }
  uint32 length;
  return stream->ReadVarint32(&length) && length == 0;
}

}  // namespace example
}  // namespace tensorflow

#endif  // TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
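Editorial aside (not part of the commit): the kVarintTag/kDelimitedTag/kFixed32Tag helpers in the header above build protobuf wire-format tag bytes as (field_number << 3) | wire_type, and the Feature oneof fields bytes_list, float_list, and int64_list are field numbers 1, 2, and 3, all length-delimited (wire type 2). That is why Feature::ParseDataType() maps kDelimitedTag(1..3) to DT_STRING / DT_FLOAT / DT_INT64. A minimal compile-time sanity check, written as a standalone sketch:

// Sketch: the tag byte values ParseDataType() switches on.
// (field_number << 3) | wire_type, with wire type 2 = length-delimited.
static_assert(((1 << 3) | 2) == 0x0A, "bytes_list -> kDelimitedTag(1)");
static_assert(((2 << 3) | 2) == 0x12, "float_list -> kDelimitedTag(2)");
static_assert(((3 << 3) | 2) == 0x1A, "int64_list -> kDelimitedTag(3)");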
tensorflow/lite/kernels/parse_example/parse_example.cc
File diff suppressed because it is too large.
tensorflow/lite/kernels/parse_example/parse_example.h
@ -1,33 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_PARSE_EXAMPLE_H_
#define TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_PARSE_EXAMPLE_H_

#include "tensorflow/lite/mutable_op_resolver.h"

namespace tflite {
namespace ops {
namespace custom {

TfLiteRegistration* Register_PARSE_EXAMPLE();
TfLiteRegistration* Register_PARSE_EXAMPLE_V2();

extern "C" void AddParseExampleOp(::tflite::MutableOpResolver* resolver);

}  // namespace custom
}  // namespace ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_PARSE_EXAMPLE_H_
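Editorial aside (not part of the commit): AddParseExampleOp above is the entry point this change wires into the TFLite test driver (see the tflite_driver.cc hunk near the end of this diff). A minimal sketch of how a client could register the custom op before building an interpreter follows; the function name and the model path are illustrative assumptions, not part of this commit:

// Sketch only: add the experimental ParseExample custom op on top of the
// builtin resolver, then build an interpreter for a model that uses it.
#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/kernels/parse_example/parse_example.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model_builder.h"

std::unique_ptr<tflite::Interpreter> BuildWithParseExample(
    const char* model_path) {  // hypothetical helper name and path argument
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  if (!model) return nullptr;
  // BuiltinOpResolver derives from MutableOpResolver, so the custom op can be
  // registered directly on top of the builtin registrations.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  tflite::ops::custom::AddParseExampleOp(&resolver);
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  return interpreter;
}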
tensorflow/lite/kernels/parse_example/parse_example_test.cc
@ -1,330 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/parse_example/parse_example.h"

#include <initializer_list>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/example/feature_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_util.h"

namespace tflite {
namespace ops {
namespace custom {

namespace tf = ::tensorflow;

const char* kNodeDefTxt = R"pb(
  name: "ParseExample/ParseExample"
  op: "ParseExample"
  input: "serialized"
  input: "ParseExample/ParseExample/names"
  input: "ParseExample/ParseExample/dense_keys_0"
  input: "ParseExample/Const"
  attr {
    key: "Ndense"
    value { i: 1 }
  }
  attr {
    key: "Nsparse"
    value { i: 0 }
  }
  attr {
    key: "Tdense"
    value { list { type: DT_FLOAT } }
  }
  attr {
    key: "dense_shapes"
    value { list { shape { dim { size: 2 } } } }
  }
  attr {
    key: "sparse_types"
    value { list { type: DT_FLOAT } }
  }
)pb";

const char* kNodeDefTxt2 = R"pb(
  name: "ParseExample/ParseExample"
  op: "ParseExample"
  input: "serialized"
  input: "ParseExample/ParseExample/names"
  input: "ParseExample/ParseExample/sparse_keys_0"
  attr {
    key: "Ndense"
    value { i: 0 }
  }
  attr {
    key: "Nsparse"
    value { i: 1 }
  }
  attr {
    key: "Tdense"
    value {}
  }
  attr {
    key: "dense_shapes"
    value {}
  }
  attr {
    key: "sparse_types"
    value { list { type: DT_FLOAT } }
  }
)pb";

const char* kNodeDefTxt3 = R"pb(
  name: "ParseExample/ParseExample"
  op: "ParseExample"
  input: "serialized"
  input: "ParseExample/ParseExample/names"
  input: "ParseExample/ParseExample/sparse_keys_0"
  attr {
    key: "Ndense"
    value { i: 1 }
  }
  attr {
    key: "Nsparse"
    value { i: 0 }
  }
  attr {
    key: "Tdense"
    value { list { type: DT_STRING } }
  }
  attr {
    key: "dense_shapes"
    value { list { shape { dim { size: 1 } } } }
  }
  attr {
    key: "sparse_types"
    value { list { type: DT_FLOAT } }
  }
)pb";

const char* kNodeDefTxt4 = R"pb(
  name: "ParseExample/ParseExample"
  op: "ParseExample"
  input: "serialized"
  input: "ParseExample/ParseExample/names"
  input: "ParseExample/ParseExample/sparse_keys_0"
  attr {
    key: "Ndense"
    value { i: 0 }
  }
  attr {
    key: "Nsparse"
    value { i: 1 }
  }
  attr {
    key: "Tdense"
    value {}
  }
  attr {
    key: "dense_shapes"
    value {}
  }
  attr {
    key: "sparse_types"
    value { list { type: DT_STRING } }
  }
)pb";

template <typename DefaultType>
class ParseExampleOpModel : public SingleOpModel {
 public:
  ParseExampleOpModel(std::string serialized_example,
                      std::vector<std::string> sparse_keys,
                      std::vector<std::string> dense_keys,
                      std::initializer_list<DefaultType> dense_defaults,
                      std::vector<TensorType> dense_types,
                      std::vector<TensorType> sparse_types,
                      const char* text_def, int dense_size = 2) {
    // Example
    string_indices_.push_back(AddInput(TensorData(TensorType_STRING, {1})));
    // Names
    string_indices_.push_back(
        AddConstInput<std::string>(TensorData(TensorType_STRING, {0}), {""}));
    std::for_each(sparse_keys.begin(), sparse_keys.end(), [&](auto&&) {
      string_indices_.push_back(AddInput(TensorData(TensorType_STRING, {1})));
    });
    std::for_each(dense_keys.begin(), dense_keys.end(), [&](auto&&) {
      string_indices_.push_back(AddInput(TensorData(TensorType_STRING, {1})));
    });
    if (dense_size > 0) {
      dense_defaults_ = AddConstInput<DefaultType>(
          TensorData(dense_types[0], {dense_size}), dense_defaults);
    }
    if (!sparse_keys.empty()) {
      for (int i = 0; i < sparse_keys.size(); i++) {
        sparse_indices_outputs_.push_back(AddOutput(TensorType_INT64));
      }
      for (int i = 0; i < sparse_keys.size(); i++) {
        sparse_values_outputs_.push_back(AddOutput(sparse_types[i]));
      }
      for (int i = 0; i < sparse_keys.size(); i++) {
        sparse_shapes_outputs_.push_back(AddOutput({TensorType_INT64, {2}}));
      }
    }
    for (int i = 0; i < dense_keys.size(); i++) {
      dense_outputs_.push_back(AddOutput({dense_types[i], {dense_size}}));
    }

    tf::NodeDef nodedef;
    tf::protobuf::TextFormat::Parser parser;
    tf::protobuf::io::ArrayInputStream input_stream(text_def, strlen(text_def));
    if (!parser.Parse(&input_stream, &nodedef)) {
      abort();
    }
    std::string serialized_nodedef;
    nodedef.SerializeToString(&serialized_nodedef);
    flexbuffers::Builder fbb;
    fbb.Vector([&]() {
      fbb.String(nodedef.op());
      fbb.String(serialized_nodedef);
    });
    fbb.Finish();
    const auto buffer = fbb.GetBuffer();
    SetCustomOp("ParseExample", buffer, Register_PARSE_EXAMPLE);
    BuildInterpreter({});
    int idx = 0;
    PopulateStringTensor(string_indices_[idx++], {serialized_example});
    PopulateStringTensor(string_indices_[idx++], {""});
    for (const auto& key : sparse_keys) {
      PopulateStringTensor(string_indices_[idx++], {key});
    }
    for (const auto& key : dense_keys) {
      PopulateStringTensor(string_indices_[idx++], {key});
    }
  }

  template <typename T>
  std::vector<T> GetSparseIndicesOutput(int i) {
    return ExtractVector<T>(sparse_indices_outputs_[i]);
  }

  template <typename T>
  std::vector<T> GetSparseValuesOutput(int i) {
    return ExtractVector<T>(sparse_values_outputs_[i]);
  }

  template <typename T>
  std::vector<T> GetSparseShapesOutput(int i) {
    return ExtractVector<T>(sparse_shapes_outputs_[i]);
  }

  template <typename T>
  std::vector<T> GetDenseOutput(int i) {
    return ExtractVector<T>(dense_outputs_[i]);
  }

  std::vector<std::string> GetStringOutput(int i) {
    auto* t = interpreter_->tensor(i);
    int count = GetStringCount(t);
    std::vector<std::string> v;
    for (int i = 0; i < count; ++i) {
      auto ref = GetString(t, i);
      v.emplace_back(ref.str, ref.len);
    }
    return v;
  }

  int DenseDefaults() { return dense_defaults_; }

  int SparseValuesOutputs(int i) { return sparse_values_outputs_[i]; }

  int DenseOutputs(int i) { return dense_outputs_[i]; }

  std::vector<int> dense_outputs_;
  std::vector<int> sparse_indices_outputs_;
  std::vector<int> sparse_shapes_outputs_;
  std::vector<int> sparse_values_outputs_;
  std::vector<int> string_indices_;
  int dense_defaults_ = -1;
};

TEST(ParseExampleOpsTest, SimpleTest) {
  tf::Example example;
  tf::AppendFeatureValues<float>({1.5f, 1.5f}, "time", &example);
  tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
  ParseExampleOpModel<float> m(example.SerializeAsString(), {}, {"time"},
                               {0.f, 0.f}, {TensorType_FLOAT32}, {},
                               kNodeDefTxt);
  m.Invoke();
  EXPECT_THAT(m.GetDenseOutput<float>(0),
              ElementsAreArray(ArrayFloatNear({1.5f, 1.5f})));
}

TEST(ParseExampleOpsTest, SparseTest) {
  tf::Example example;
  tf::AppendFeatureValues<float>({1.5f}, "time", &example);
  ParseExampleOpModel<float> m(example.SerializeAsString(), {"time"}, {}, {},
                               {}, {TensorType_FLOAT32}, kNodeDefTxt2, 0);
  m.Invoke();
  EXPECT_THAT(m.GetSparseIndicesOutput<int64_t>(0),
              ElementsAreArray(ArrayFloatNear({0, 0})));
  EXPECT_THAT(m.GetSparseValuesOutput<float>(0),
              ElementsAreArray(ArrayFloatNear({1.5f})));
  EXPECT_THAT(m.GetSparseShapesOutput<int64_t>(0),
              ElementsAreArray(ArrayFloatNear({1, 1})));
}

TEST(ParseExampleOpsTest, SimpleBytesTest) {
  tf::Example example;
  const std::string test_data = "simpletest";
  tf::AppendFeatureValues<tensorflow::tstring>({test_data}, "time", &example);
  tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
  std::string default_value = "missing";
  ParseExampleOpModel<std::string> m(example.SerializeAsString(), {}, {"time"},
                                     {default_value}, {TensorType_STRING}, {},
                                     kNodeDefTxt3, 1);
  m.PopulateStringTensor(m.DenseDefaults(), {default_value});
  m.Invoke();
  std::vector<string> c = m.GetStringOutput(m.DenseOutputs(0));
  EXPECT_EQ(1, c.size());
  EXPECT_EQ(test_data, c[0]);
}

TEST(ParseExampleOpsTest, SparseBytesTest) {
  tf::Example example;
  const std::string test_data = "simpletest";
  tf::AppendFeatureValues<tensorflow::tstring>({test_data, test_data}, "time",
                                               &example);
  tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
  ParseExampleOpModel<std::string> m(example.SerializeAsString(), {"time"}, {},
                                     {}, {}, {TensorType_STRING}, kNodeDefTxt4,
                                     0);
  m.Invoke();
  EXPECT_THAT(m.GetSparseIndicesOutput<int64_t>(0),
              testing::ElementsAreArray({0, 0, 0, 1}));
  auto values = m.GetStringOutput(m.SparseValuesOutputs(0));
  EXPECT_EQ(2, values.size());
  EXPECT_EQ(test_data, values[0]);
  EXPECT_EQ(test_data, values[1]);
  EXPECT_THAT(m.GetSparseShapesOutput<int64_t>(0),
              testing::ElementsAreArray({1, 2}));
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite
@ -239,7 +239,6 @@ cc_library(
         "//tensorflow/lite/kernels:reference_ops",
         "//tensorflow/lite/kernels:test_delegate_providers_lib",
         "//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
-        "//tensorflow/lite/kernels/parse_example:parse_example",
         "//tensorflow/lite/tools/evaluation:utils",
     ] + select({
         "//tensorflow:ios": [],
@ -26,7 +26,6 @@ limitations under the License.
 #endif
 #include "tensorflow/lite/kernels/custom_ops_register.h"
 #include "tensorflow/lite/kernels/hashtable/hashtable_ops.h"
-#include "tensorflow/lite/kernels/parse_example/parse_example.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/kernels/register_ref.h"
 #include "tensorflow/lite/kernels/test_delegate_providers.h"
@ -371,7 +370,6 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
     ops::builtin::BuiltinOpResolver* buildinop_resolver_ =
         reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
     tflite::ops::custom::AddHashtableOps(buildinop_resolver_);
-    tflite::ops::custom::AddParseExampleOp(buildinop_resolver_);
   }

   switch (delegate_type) {