Add experimental custom parse_example op

PiperOrigin-RevId: 347640612
Change-Id: If4c83fc598391e8972f6b71015caf9def4d4bb7a
Mihai Maruseac 2020-12-15 10:19:58 -08:00 committed by TensorFlower Gardener
parent 67e713e624
commit af93fc7294
8 changed files with 0 additions and 2300 deletions
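For context, the new kernel parses serialized tensorflow.Example protos inside a TFLite graph. The sketch below only shows how such an input can be produced; it mirrors the feature_util helpers used by the test in this diff, and the feature names are illustrative:

#include <string>

#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature_util.h"

// Builds the kind of serialized tensorflow.Example the custom op consumes.
// Feature names ("time", "num") are only examples.
std::string MakeSerializedExample() {
  tensorflow::Example example;
  tensorflow::AppendFeatureValues<float>({1.5f, 1.5f}, "time", &example);
  tensorflow::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
  return example.SerializeAsString();
}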


@@ -1,75 +0,0 @@
# Kernel for custom parse_example
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "parse_example",
srcs = [
"example_proto_fast_parsing.cc",
"parse_example.cc",
],
hdrs = [
"example_proto_fast_parsing.h",
"parse_example.h",
],
deps = [
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@flatbuffers",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite:string_util",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//tensorflow:ios": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:core_cpu",
"//tensorflow/core:feature_util",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
}),
)
cc_test(
name = "parse_example_test",
srcs = ["parse_example_test.cc"],
deps = [
":parse_example",
"@flatbuffers",
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/api:op_resolver",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:test_main",
"//tensorflow/lite/kernels:test_util",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//tensorflow:ios": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/example:feature_util",
"//tensorflow/core/platform:protobuf",
"//tensorflow/core/platform:tstring",
],
}),
)


@@ -1,170 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h"
namespace tensorflow {
namespace example {
string ExampleName(const gtl::ArraySlice<tstring> example_names, int n) {
return example_names.empty() ? "<unknown>" : example_names[n];
}
void CountSparseFeatures(
const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
size_t* total_num_features, size_t* max_num_features) {
for (auto& sparse_values_tmp : sparse_buffers) {
const std::vector<size_t>& end_indices =
sparse_values_tmp[d].example_end_indices;
*total_num_features += end_indices.back();
*max_num_features = std::max(*max_num_features, end_indices[0]);
for (size_t i = 1; i < end_indices.size(); ++i) {
size_t example_size = end_indices[i] - end_indices[i - 1];
*max_num_features = std::max(*max_num_features, example_size);
}
}
}
void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
Tensor* dst) {
switch (dtype) {
case DT_INT64: {
std::copy(src->int64_list.begin(), src->int64_list.end(),
dst->flat<int64>().data() + offset);
break;
}
case DT_FLOAT: {
std::copy(src->float_list.begin(), src->float_list.end(),
dst->flat<float>().data() + offset);
break;
}
case DT_STRING: {
std::move(src->bytes_list.begin(), src->bytes_list.end(),
dst->flat<tstring>().data() + offset);
break;
}
default:
ReportUnexpectedDataType(dtype);
}
}
uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
DCHECK(stream != nullptr);
const void* ptr;
int size;
if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
return *static_cast<const uint8*>(ptr);
}
bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) {
DCHECK(stream != nullptr);
DCHECK(result != nullptr);
uint32 length;
if (!stream->ReadVarint32(&length)) return false;
if (length == 0) {
*result = StringPiece(nullptr, 0);
return true;
}
const void* stream_alias;
int stream_size;
if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
return false;
}
if (static_cast<uint32>(stream_size) < length) return false;
*result = StringPiece(static_cast<const char*>(stream_alias), length);
stream->Skip(length);
return true;
}
bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
parsed::FeatureMapEntry* feature_map_entry) {
DCHECK(stream != nullptr);
DCHECK(feature_map_entry != nullptr);
uint32 length;
if (!stream->ReadVarint32(&length)) return false;
auto limit = stream->PushLimit(length);
if (!stream->ExpectTag(kDelimitedTag(1))) return false;
if (!ParseString(stream, &feature_map_entry->first)) return false;
if (!stream->ExpectTag(kDelimitedTag(2))) return false;
StringPiece feature_string_piece;
if (!ParseString(stream, &feature_string_piece)) return false;
feature_map_entry->second = parsed::Feature(feature_string_piece);
if (!stream->ExpectAtEnd()) return false;
stream->PopLimit(limit);
return true;
}
bool ParseFeatures(protobuf::io::CodedInputStream* stream,
parsed::Example* example) {
DCHECK(stream != nullptr);
DCHECK(example != nullptr);
uint32 length;
if (!stream->ReadVarint32(&length)) return false;
auto limit = stream->PushLimit(length);
while (!stream->ExpectAtEnd()) {
parsed::FeatureMapEntry feature_map_entry;
if (!stream->ExpectTag(kDelimitedTag(1))) return false;
if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false;
example->push_back(std::move(feature_map_entry));
}
stream->PopLimit(limit);
return true;
}
bool ParseExample(protobuf::io::CodedInputStream* stream,
parsed::Example* example) {
DCHECK(stream != nullptr);
DCHECK(example != nullptr);
// Loop over the input stream which may contain multiple serialized Example
// protos merged together as strings. This behavior is consistent with Proto's
// ParseFromString when string representations are concatenated.
while (!stream->ExpectAtEnd()) {
if (!stream->ExpectTag(kDelimitedTag(1))) {
if (!SkipExtraneousTag(stream)) return false;
} else {
if (!ParseFeatures(stream, example)) return false;
}
}
return true;
}
bool ParseExample(StringPiece serialized, parsed::Example* example) {
DCHECK(example != nullptr);
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
EnableAliasing(&stream);
return ParseExample(&stream, example);
}
template <>
void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t) {
std::move(b, e, t);
}
template <>
const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer) {
return buffer.int64_list;
}
template <>
const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
return buffer.float_list;
}
template <>
const SmallVector<tstring>& GetListFromBuffer<tstring>(
const SparseBuffer& buffer) {
return buffer.bytes_list;
}
} // namespace example
} // namespace tensorflow


@@ -1,688 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
#define TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
#include "tensorflow/core/util/example_proto_fast_parsing.h"
#include <vector>
#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/presized_cuckoo_map.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
namespace tensorflow {
namespace example {
template <typename T>
using SmallVector = gtl::InlinedVector<T, 4>;
template <typename T>
class LimitedArraySlice {
public:
using value_type = T;
LimitedArraySlice(T* begin, size_t num_elements)
: current_(begin), begin_(begin), end_(begin + num_elements) {}
// May return a negative value if there were push_back calls after the slice was filled.
int64 EndDistance() const { return end_ - current_; }
// Attempts to push value to the back of this slice. If the slice has
// already been filled, this method has no effect on the underlying data, but
// it still advances the internal pointer, so EndDistance() becomes negative.
void push_back(T&& value) {
if (EndDistance() > 0) *current_ = std::move(value);
++current_;
}
// "Constructs" an element at the back of this by resizing the slice, and
// returns a mutable reference to the new last element.
// REQUIRES: EndDistance() > 0.
T& construct_at_end() {
DCHECK_GT(EndDistance(), 0);
return *(current_++);
}
// Returns a mutable reference to the last element in the slice.
// REQUIRES: size() > 0.
T& back() { return *(current_ - 1); }
// Returns the number of elements in the slice.
size_t size() const { return std::min(current_ - begin_, end_ - begin_); }
// Attempts to resize the slice to the given size. It does so by advancing
// the pointer to the current element, possibly beyond the end of the slice.
// As a consequence, calling `size()` after `resize(x)` was called might
// return a value less than `x`.
void resize(size_t size) { current_ = begin_ + size; }
// Returns the pointer to the underlying data buffer.
T* data() { return begin_; }
private:
T* current_;
T* begin_;
T* end_;
};
template <typename A>
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
a->EnableAliasing(true);
}
template <typename A>
void EnableAliasing(A&& a) {}
uint8 PeekTag(protobuf::io::CodedInputStream* stream);
constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
namespace parsed {
// ParseDataType has to be called first, then the appropriate ParseZzzzList method.
class Feature {
public:
Feature() {}
explicit Feature(StringPiece serialized) : serialized_(serialized) {}
Status ParseDataType(DataType* dtype) {
DCHECK(dtype != nullptr);
if (serialized_.empty()) {
*dtype = DT_INVALID;
return Status::OK();
}
uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
serialized_.remove_prefix(1);
switch (oneof_tag) {
case kDelimitedTag(1):
*dtype = DT_STRING;
break;
case kDelimitedTag(2):
*dtype = DT_FLOAT;
break;
case kDelimitedTag(3):
*dtype = DT_INT64;
break;
default:
// Initialize variable to avoid compiler warning
*dtype = DT_INVALID;
return errors::InvalidArgument("Unsupported datatype.");
}
return Status::OK();
}
bool GetNumElementsInBytesList(int* num_elements) {
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
EnableAliasing(&stream);
uint32 length = 0;
if (!stream.ReadVarint32(&length)) return false;
auto limit = stream.PushLimit(length);
*num_elements = 0;
while (!stream.ExpectAtEnd()) {
if (!stream.ExpectTag(kDelimitedTag(1))) return false;
uint32 bytes_length = 0;
if (!stream.ReadVarint32(&bytes_length)) return false;
if (!stream.Skip(bytes_length)) return false;
++*num_elements;
}
stream.PopLimit(limit);
return true;
}
// Helper methods
tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
if (bytes_list->EndDistance() <= 0) {
return nullptr;
}
return &bytes_list->construct_at_end();
}
tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
return &bytes_list->emplace_back();
}
template <typename Result>
bool ParseBytesList(Result* bytes_list) {
DCHECK(bytes_list != nullptr);
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
EnableAliasing(&stream);
uint32 length;
if (!stream.ReadVarint32(&length)) return false;
auto limit = stream.PushLimit(length);
while (!stream.ExpectAtEnd()) {
if (!stream.ExpectTag(kDelimitedTag(1))) return false;
// parse string
uint32 bytes_length;
if (!stream.ReadVarint32(&bytes_length)) return false;
tstring* bytes = construct_at_end(bytes_list);
if (bytes == nullptr) return false;
bytes->resize_uninitialized(bytes_length);
if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
}
stream.PopLimit(limit);
return true;
}
template <typename Result>
bool ParseFloatList(Result* float_list) {
DCHECK(float_list != nullptr);
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
EnableAliasing(&stream);
uint32 length;
if (!stream.ReadVarint32(&length)) return false;
auto limit = stream.PushLimit(length);
if (!stream.ExpectAtEnd()) {
uint8 peek_tag = PeekTag(&stream);
if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
return false;
}
constexpr int32 kNumFloatBytes = 4;
if (peek_tag == kDelimitedTag(1)) { // packed
if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
uint32 packed_length;
if (!stream.ReadVarint32(&packed_length)) return false;
auto packed_limit = stream.PushLimit(packed_length);
// Store the initial size to know the offset we have to start writing
// data from before resizing the output "vector".
const size_t initial_size = float_list->size();
float_list->resize(initial_size + packed_length / kNumFloatBytes);
// If the result data type is float and we are on a little endian
// machine then we can simply memcpy the data from the proto into the
// result vector.
if (port::kLittleEndian &&
sizeof(typename Result::value_type) == kNumFloatBytes) {
// Calculate the length of the available buffer, which can be less than
// what we requested in resize() in the case of a LimitedArraySlice.
const uint32 bytes_to_copy =
std::min(static_cast<uint32>((float_list->size() - initial_size) *
kNumFloatBytes),
packed_length);
if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
return false;
} else {
int64 index = initial_size;
while (!stream.ExpectAtEnd()) {
uint32 buffer32;
if (!stream.ReadLittleEndian32(&buffer32)) return false;
if (index < float_list->size()) {
float_list->data()[index] = absl::bit_cast<float>(buffer32);
++index;
}
}
}
stream.PopLimit(packed_limit);
} else { // non-packed
const size_t initial_size = float_list->size();
// 1 byte for the tag (`1` encoded as Varint32) and kNumFloatBytes for
// the value.
const int64 num_elements =
stream.BytesUntilLimit() / (1 + kNumFloatBytes);
float_list->resize(initial_size + num_elements);
int64 index = initial_size;
while (!stream.ExpectAtEnd()) {
if (!stream.ExpectTag(kFixed32Tag(1))) return false;
uint32 buffer32;
if (!stream.ReadLittleEndian32(&buffer32)) return false;
float_list->data()[index] = absl::bit_cast<float>(buffer32);
++index;
}
}
}
stream.PopLimit(limit);
return true;
}
template <typename Result>
bool ParseInt64List(Result* int64_list) {
DCHECK(int64_list != nullptr);
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
EnableAliasing(&stream);
uint32 length;
if (!stream.ReadVarint32(&length)) return false;
auto limit = stream.PushLimit(length);
if (!stream.ExpectAtEnd()) {
uint8 peek_tag = PeekTag(&stream);
if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
return false;
}
if (peek_tag == kDelimitedTag(1)) { // packed
if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
uint32 packed_length;
if (!stream.ReadVarint32(&packed_length)) return false;
auto packed_limit = stream.PushLimit(packed_length);
while (!stream.ExpectAtEnd()) {
protobuf_uint64 n; // There is no API for int64
if (!stream.ReadVarint64(&n)) return false;
int64_list->push_back(static_cast<int64>(n));
}
stream.PopLimit(packed_limit);
} else { // non-packed
while (!stream.ExpectAtEnd()) {
if (!stream.ExpectTag(kVarintTag(1))) return false;
protobuf_uint64 n; // There is no API for int64
if (!stream.ReadVarint64(&n)) return false;
int64_list->push_back(static_cast<int64>(n));
}
}
}
stream.PopLimit(limit);
return true;
}
StringPiece GetSerialized() const { return serialized_; }
private:
StringPiece serialized_;
};
using FeatureMapEntry = std::pair<StringPiece, Feature>;
using Example = std::vector<FeatureMapEntry>;
} // namespace parsed
inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
uint32 data;
protobuf_uint64 dummy;
switch (stream->ReadTag() & 0x7) {
case 0: // varint
if (!stream->ReadVarint32(&data)) return false;
return true;
case 1: // fixed64
if (!stream->ReadLittleEndian64(&dummy)) return false;
return true;
case 2: // length delimited
if (!stream->ReadVarint32(&data)) return false;
stream->Skip(data);
return true;
case 3: // group begin
return false; // groups not supported.
case 4: // group end
return false; // groups not supported.
case 5: // fixed32
if (!stream->ReadLittleEndian32(&data)) return false;
return true;
}
return false; // unrecognized tag type
}
bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result);
bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
parsed::FeatureMapEntry* feature_map_entry);
bool ParseFeatures(protobuf::io::CodedInputStream* stream,
parsed::Example* example);
bool ParseExample(protobuf::io::CodedInputStream* stream,
parsed::Example* example);
bool ParseExample(StringPiece serialized, parsed::Example* example);
using Config = FastParseExampleConfig;
// Enumeration for distinguishing feature types.
// Note: FastParseSequenceExample constructs a map that includes Type values,
// and relies on the fact that they are default-initialized to Dense.
enum class Type { Dense, Sparse, Ragged };
// Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
struct SparseBuffer {
// Features are in one of the 3 vectors below depending on config's dtype.
// Other 2 vectors remain empty.
SmallVector<tstring> bytes_list;
SmallVector<float> float_list;
SmallVector<int64> int64_list;
// Features of example i are elements with indices
// from example_end_indices[i-1] to example_end_indices[i]-1 on the
// appropriate xxxxx_list
std::vector<size_t> example_end_indices;
};
struct SeededHasher {
uint64 operator()(StringPiece s) const {
return Hash64(s.data(), s.size(), seed);
}
uint64 seed{0xDECAFCAFFE};
};
// Use this in the "default" clause of switch statements when dispatching
// on a dtype variable that was checked by CheckConfigDataType():
inline void ReportUnexpectedDataType(DataType dtype) {
DCHECK(false)
<< "Encountered unexpected DataType " << DataTypeString(dtype)
<< "in variable that should have been checked by CheckConfigDataType().";
}
template <typename T>
const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);
template <>
const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer);
template <>
const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer);
template <>
const SmallVector<tstring>& GetListFromBuffer<tstring>(
const SparseBuffer& buffer);
template <typename T>
void CopyOrMoveBlock(const T* b, const T* e, T* t) {
std::copy(b, e, t);
}
template <>
void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t);
void CountSparseFeatures(
const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
size_t* total_num_features, size_t* max_num_features);
void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
Tensor* dst);
// A struct used by FastParseSequenceExample to hold the serialized proto
// substrings for a single feature, plus some auxiliary information derived
// from those protos (such as the total value length).
struct FeatureProtos {
// Proto substrings from each serialized SequenceExample that correspond
// with this feature. `protos_present` records whether the proto had a
// value defined (even if that value is empty).
std::vector<StringPiece> protos;
std::vector<bool> protos_present;
// Information derived from protos:
size_t length; // total length for ragged/sparse, max row length for dense.
size_t num_rows; // only populated for ragged sequence features.
// Information from the config:
Type type; // Whether this feature is sparse, ragged, or dense.
DataType dtype;
};
// Map from feature name to FeatureProtos for that feature.
using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>;
string ExampleName(const gtl::ArraySlice<tstring> example_names, int n);
// Return the number of bytes elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
tstring* out) {
int num_elements = 0;
uint32 length;
if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
return -1;
}
if (length > 0) {
auto limit = stream->PushLimit(length);
while (!stream->ExpectAtEnd()) {
uint32 bytes_length;
if (!stream->ExpectTag(kDelimitedTag(1)) ||
!stream->ReadVarint32(&bytes_length)) {
return -1;
}
if (out == nullptr) {
stream->Skip(bytes_length);
} else {
out->resize_uninitialized(bytes_length);
if (!stream->ReadRaw(out->data(), bytes_length)) {
return -1;
}
out++;
}
num_elements++;
}
stream->PopLimit(limit);
}
return num_elements;
}
inline void PadFloatFeature(int num_to_pad, float* out) {
for (int i = 0; i < num_to_pad; i++) {
*out++ = 0.0;
}
}
inline void PadInt64Feature(int num_to_pad, int64* out) {
for (int i = 0; i < num_to_pad; i++) {
*out++ = 0;
}
}
// Return the number of float elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
float* out) {
int num_elements = 0;
uint32 length;
if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
return -1;
}
if (length > 0) {
auto limit = stream->PushLimit(length);
uint8 peek_tag = PeekTag(stream);
if (peek_tag == kDelimitedTag(1)) { // packed
uint32 packed_length;
if (!stream->ExpectTag(kDelimitedTag(1)) ||
!stream->ReadVarint32(&packed_length)) {
return -1;
}
auto packed_limit = stream->PushLimit(packed_length);
while (!stream->ExpectAtEnd()) {
uint32 buffer32;
if (!stream->ReadLittleEndian32(&buffer32)) {
return -1;
}
if (out != nullptr) {
*out++ = absl::bit_cast<float>(buffer32);
}
num_elements++;
}
stream->PopLimit(packed_limit);
} else if (peek_tag == kFixed32Tag(1)) {
while (!stream->ExpectAtEnd()) {
uint32 buffer32;
if (!stream->ExpectTag(kFixed32Tag(1)) ||
!stream->ReadLittleEndian32(&buffer32)) {
return -1;
}
if (out != nullptr) {
*out++ = absl::bit_cast<float>(buffer32);
}
num_elements++;
}
} else {
// Unknown tag.
return -1;
}
stream->PopLimit(limit);
}
return num_elements;
}
// Return the number of int64 elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
int64* out) {
int num_elements = 0;
uint32 length;
if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
return -1;
}
if (length > 0) {
auto limit = stream->PushLimit(length);
uint8 peek_tag = PeekTag(stream);
if (peek_tag == kDelimitedTag(1)) { // packed
uint32 packed_length;
if (!stream->ExpectTag(kDelimitedTag(1)) ||
!stream->ReadVarint32(&packed_length)) {
return -1;
}
auto packed_limit = stream->PushLimit(packed_length);
while (!stream->ExpectAtEnd()) {
protobuf_uint64 n; // There is no API for int64
if (!stream->ReadVarint64(&n)) {
return -1;
}
if (out != nullptr) {
*out++ = n;
}
num_elements++;
}
stream->PopLimit(packed_limit);
} else if (peek_tag == kVarintTag(1)) {
while (!stream->ExpectAtEnd()) {
protobuf_uint64 n; // There is no API for int64
if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
return -1;
}
if (out != nullptr) {
*out++ = n;
}
num_elements++;
}
} else {
// Unknown tag.
return -1;
}
stream->PopLimit(limit);
}
return num_elements;
}
// Parses the next feature on `stream` into `out` starting at `out_offset`.
// Updates `out_offset`, and returns the number of values added.
// Returns -1 if the next feature on `stream` doesn't match `dtype`.
inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream,
Tensor* out, size_t* out_offset) {
int delta;
switch (dtype) {
case DT_STRING:
delta =
ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset);
break;
case DT_FLOAT:
delta =
ParseFloatFeature(stream, out->flat<float>().data() + *out_offset);
break;
case DT_INT64:
delta =
ParseInt64Feature(stream, out->flat<int64>().data() + *out_offset);
break;
default:
ReportUnexpectedDataType(dtype);
delta = 0;
}
if (delta > 0) {
*out_offset += delta;
}
return delta;
}
// Returns the length of the next feature on `stream`.
// Returns -1 if the next feature on `stream` doesn't match `dtype`.
inline int GetFeatureLength(DataType dtype,
protobuf::io::CodedInputStream* stream) {
switch (dtype) {
case DT_STRING:
return ParseBytesFeature(stream, nullptr);
case DT_FLOAT:
return ParseFloatFeature(stream, nullptr);
case DT_INT64:
return ParseInt64Feature(stream, nullptr);
default:
ReportUnexpectedDataType(dtype);
return -1;
}
}
inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
uint8 peek_tag = PeekTag(stream);
switch (peek_tag) {
case kDelimitedTag(1):
return DT_STRING;
case kDelimitedTag(2):
return DT_FLOAT;
case kDelimitedTag(3):
return DT_INT64;
default:
return DT_INVALID;
}
}
inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
DataType dtype) {
switch (dtype) {
case DT_STRING:
if (!stream->ExpectTag(kDelimitedTag(1))) {
return false;
}
break;
case DT_FLOAT:
if (!stream->ExpectTag(kDelimitedTag(2))) {
return false;
}
break;
case DT_INT64:
if (!stream->ExpectTag(kDelimitedTag(3))) {
return false;
}
break;
default:
return false;
}
uint32 length;
return stream->ReadVarint32(&length) && length == 0;
}
} // namespace example
} // namespace tensorflow
#endif // TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
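As a sanity check on the tag helpers in the header above: protobuf packs the field number and wire type into a single tag byte, tag = (field_number << 3) | wire_type, where wire type 0 is varint, 2 is length-delimited, and 5 is fixed32. The standalone sketch below spells out the constants the parser peeks at (it duplicates the helpers so it compiles on its own):

#include <cstdint>

constexpr uint8_t kVarintTag(uint32_t tag) { return (tag << 3) | 0; }     // wire type 0
constexpr uint8_t kDelimitedTag(uint32_t tag) { return (tag << 3) | 2; }  // wire type 2
constexpr uint8_t kFixed32Tag(uint32_t tag) { return (tag << 3) | 5; }    // wire type 5

// In tensorflow.Feature, bytes_list is field 1, float_list field 2, and
// int64_list field 3, all length-delimited, so ParseDataType peeks at:
static_assert(kDelimitedTag(1) == 0x0a, "bytes_list tag");
static_assert(kDelimitedTag(2) == 0x12, "float_list tag");
static_assert(kDelimitedTag(3) == 0x1a, "int64_list tag");
// Non-packed elements inside a list use field 1 of the inner message:
static_assert(kVarintTag(1) == 0x08, "non-packed int64 element tag");
static_assert(kFixed32Tag(1) == 0x0d, "non-packed float element tag");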

File diff suppressed because it is too large


@@ -1,33 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_PARSE_EXAMPLE_H_
#define TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_PARSE_EXAMPLE_H_
#include "tensorflow/lite/mutable_op_resolver.h"
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_PARSE_EXAMPLE();
TfLiteRegistration* Register_PARSE_EXAMPLE_V2();
extern "C" void AddParseExampleOp(::tflite::MutableOpResolver* resolver);
} // namespace custom
} // namespace ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_PARSE_EXAMPLE_H_
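A minimal sketch of wiring these registrations into an interpreter, using only public TFLite APIs; the helper name and model path are hypothetical, and error handling is reduced to a null check:

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/kernels/parse_example/parse_example.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model_builder.h"

// Builds an interpreter for a model that uses the custom "ParseExample" op.
std::unique_ptr<tflite::Interpreter> BuildWithParseExample(const char* model_path) {
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  if (!model) return nullptr;
  tflite::ops::builtin::BuiltinOpResolver resolver;
  // Registers the custom kernels declared above, the same way the
  // TfLiteDriver change at the bottom of this diff calls AddParseExampleOp.
  tflite::ops::custom::AddParseExampleOp(&resolver);
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  return interpreter;
}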


@@ -1,330 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/parse_example/parse_example.h"
#include <initializer_list>
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "tensorflow/core/example/feature_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
namespace ops {
namespace custom {
namespace tf = ::tensorflow;
const char* kNodeDefTxt = R"pb(
name: "ParseExample/ParseExample"
op: "ParseExample"
input: "serialized"
input: "ParseExample/ParseExample/names"
input: "ParseExample/ParseExample/dense_keys_0"
input: "ParseExample/Const"
attr {
key: "Ndense"
value { i: 1 }
}
attr {
key: "Nsparse"
value { i: 0 }
}
attr {
key: "Tdense"
value { list { type: DT_FLOAT } }
}
attr {
key: "dense_shapes"
value { list { shape { dim { size: 2 } } } }
}
attr {
key: "sparse_types"
value { list { type: DT_FLOAT } }
}
)pb";
const char* kNodeDefTxt2 = R"pb(
name: "ParseExample/ParseExample"
op: "ParseExample"
input: "serialized"
input: "ParseExample/ParseExample/names"
input: "ParseExample/ParseExample/sparse_keys_0"
attr {
key: "Ndense"
value { i: 0 }
}
attr {
key: "Nsparse"
value { i: 1 }
}
attr {
key: "Tdense"
value {}
}
attr {
key: "dense_shapes"
value {}
}
attr {
key: "sparse_types"
value { list { type: DT_FLOAT } }
}
)pb";
const char* kNodeDefTxt3 = R"pb(
name: "ParseExample/ParseExample"
op: "ParseExample"
input: "serialized"
input: "ParseExample/ParseExample/names"
input: "ParseExample/ParseExample/sparse_keys_0"
attr {
key: "Ndense"
value { i: 1 }
}
attr {
key: "Nsparse"
value { i: 0 }
}
attr {
key: "Tdense"
value { list { type: DT_STRING } }
}
attr {
key: "dense_shapes"
value { list { shape { dim { size: 1 } } } }
}
attr {
key: "sparse_types"
value { list { type: DT_FLOAT } }
}
)pb";
const char* kNodeDefTxt4 = R"pb(
name: "ParseExample/ParseExample"
op: "ParseExample"
input: "serialized"
input: "ParseExample/ParseExample/names"
input: "ParseExample/ParseExample/sparse_keys_0"
attr {
key: "Ndense"
value { i: 0 }
}
attr {
key: "Nsparse"
value { i: 1 }
}
attr {
key: "Tdense"
value {}
}
attr {
key: "dense_shapes"
value {}
}
attr {
key: "sparse_types"
value { list { type: DT_STRING } }
}
)pb";
template <typename DefaultType>
class ParseExampleOpModel : public SingleOpModel {
public:
ParseExampleOpModel(std::string serialized_example,
std::vector<std::string> sparse_keys,
std::vector<std::string> dense_keys,
std::initializer_list<DefaultType> dense_defaults,
std::vector<TensorType> dense_types,
std::vector<TensorType> sparse_types,
const char* text_def, int dense_size = 2) {
// Example
string_indices_.push_back(AddInput(TensorData(TensorType_STRING, {1})));
// Names
string_indices_.push_back(
AddConstInput<std::string>(TensorData(TensorType_STRING, {0}), {""}));
std::for_each(sparse_keys.begin(), sparse_keys.end(), [&](auto&&) {
string_indices_.push_back(AddInput(TensorData(TensorType_STRING, {1})));
});
std::for_each(dense_keys.begin(), dense_keys.end(), [&](auto&&) {
string_indices_.push_back(AddInput(TensorData(TensorType_STRING, {1})));
});
if (dense_size > 0) {
dense_defaults_ = AddConstInput<DefaultType>(
TensorData(dense_types[0], {dense_size}), dense_defaults);
}
if (!sparse_keys.empty()) {
for (int i = 0; i < sparse_keys.size(); i++) {
sparse_indices_outputs_.push_back(AddOutput(TensorType_INT64));
}
for (int i = 0; i < sparse_keys.size(); i++) {
sparse_values_outputs_.push_back(AddOutput(sparse_types[i]));
}
for (int i = 0; i < sparse_keys.size(); i++) {
sparse_shapes_outputs_.push_back(AddOutput({TensorType_INT64, {2}}));
}
}
for (int i = 0; i < dense_keys.size(); i++) {
dense_outputs_.push_back(AddOutput({dense_types[i], {dense_size}}));
}
tf::NodeDef nodedef;
tf::protobuf::TextFormat::Parser parser;
tf::protobuf::io::ArrayInputStream input_stream(text_def, strlen(text_def));
if (!parser.Parse(&input_stream, &nodedef)) {
abort();
}
std::string serialized_nodedef;
nodedef.SerializeToString(&serialized_nodedef);
flexbuffers::Builder fbb;
fbb.Vector([&]() {
fbb.String(nodedef.op());
fbb.String(serialized_nodedef);
});
fbb.Finish();
const auto buffer = fbb.GetBuffer();
SetCustomOp("ParseExample", buffer, Register_PARSE_EXAMPLE);
BuildInterpreter({});
int idx = 0;
PopulateStringTensor(string_indices_[idx++], {serialized_example});
PopulateStringTensor(string_indices_[idx++], {""});
for (const auto& key : sparse_keys) {
PopulateStringTensor(string_indices_[idx++], {key});
}
for (const auto& key : dense_keys) {
PopulateStringTensor(string_indices_[idx++], {key});
}
}
template <typename T>
std::vector<T> GetSparseIndicesOutput(int i) {
return ExtractVector<T>(sparse_indices_outputs_[i]);
}
template <typename T>
std::vector<T> GetSparseValuesOutput(int i) {
return ExtractVector<T>(sparse_values_outputs_[i]);
}
template <typename T>
std::vector<T> GetSparseShapesOutput(int i) {
return ExtractVector<T>(sparse_shapes_outputs_[i]);
}
template <typename T>
std::vector<T> GetDenseOutput(int i) {
return ExtractVector<T>(dense_outputs_[i]);
}
std::vector<std::string> GetStringOutput(int i) {
auto* t = interpreter_->tensor(i);
int count = GetStringCount(t);
std::vector<std::string> v;
for (int i = 0; i < count; ++i) {
auto ref = GetString(t, i);
v.emplace_back(ref.str, ref.len);
}
return v;
}
int DenseDefaults() { return dense_defaults_; }
int SparseValuesOutputs(int i) { return sparse_values_outputs_[i]; }
int DenseOutputs(int i) { return dense_outputs_[i]; }
std::vector<int> dense_outputs_;
std::vector<int> sparse_indices_outputs_;
std::vector<int> sparse_shapes_outputs_;
std::vector<int> sparse_values_outputs_;
std::vector<int> string_indices_;
int dense_defaults_ = -1;
};
TEST(ParseExampleOpsTest, SimpleTest) {
tf::Example example;
tf::AppendFeatureValues<float>({1.5f, 1.5f}, "time", &example);
tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
ParseExampleOpModel<float> m(example.SerializeAsString(), {}, {"time"},
{0.f, 0.f}, {TensorType_FLOAT32}, {},
kNodeDefTxt);
m.Invoke();
EXPECT_THAT(m.GetDenseOutput<float>(0),
ElementsAreArray(ArrayFloatNear({1.5f, 1.5f})));
}
TEST(ParseExampleOpsTest, SparseTest) {
tf::Example example;
tf::AppendFeatureValues<float>({1.5f}, "time", &example);
ParseExampleOpModel<float> m(example.SerializeAsString(), {"time"}, {}, {},
{}, {TensorType_FLOAT32}, kNodeDefTxt2, 0);
m.Invoke();
EXPECT_THAT(m.GetSparseIndicesOutput<int64_t>(0),
ElementsAreArray(ArrayFloatNear({0, 0})));
EXPECT_THAT(m.GetSparseValuesOutput<float>(0),
ElementsAreArray(ArrayFloatNear({1.5f})));
EXPECT_THAT(m.GetSparseShapesOutput<int64_t>(0),
ElementsAreArray(ArrayFloatNear({1, 1})));
}
TEST(ParseExampleOpsTest, SimpleBytesTest) {
tf::Example example;
const std::string test_data = "simpletest";
tf::AppendFeatureValues<tensorflow::tstring>({test_data}, "time", &example);
tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
std::string default_value = "missing";
ParseExampleOpModel<std::string> m(example.SerializeAsString(), {}, {"time"},
{default_value}, {TensorType_STRING}, {},
kNodeDefTxt3, 1);
m.PopulateStringTensor(m.DenseDefaults(), {default_value});
m.Invoke();
std::vector<string> c = m.GetStringOutput(m.DenseOutputs(0));
EXPECT_EQ(1, c.size());
EXPECT_EQ(test_data, c[0]);
}
TEST(ParseExampleOpsTest, SparseBytesTest) {
tf::Example example;
const std::string test_data = "simpletest";
tf::AppendFeatureValues<tensorflow::tstring>({test_data, test_data}, "time",
&example);
tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
ParseExampleOpModel<std::string> m(example.SerializeAsString(), {"time"}, {},
{}, {}, {TensorType_STRING}, kNodeDefTxt4,
0);
m.Invoke();
EXPECT_THAT(m.GetSparseIndicesOutput<int64_t>(0),
testing::ElementsAreArray({0, 0, 0, 1}));
auto values = m.GetStringOutput(m.SparseValuesOutputs(0));
EXPECT_EQ(2, values.size());
EXPECT_EQ(test_data, values[0]);
EXPECT_EQ(test_data, values[1]);
EXPECT_THAT(m.GetSparseShapesOutput<int64_t>(0),
testing::ElementsAreArray({1, 2}));
}
} // namespace custom
} // namespace ops
} // namespace tflite
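The test above passes the TensorFlow NodeDef to the custom kernel through the op's custom options, encoded as a flexbuffer vector holding the op name followed by the serialized NodeDef. A standalone sketch of just that encoding step, using the same APIs as the test (the helper name is illustrative):

#include <cstdint>
#include <string>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/framework/node_def.pb.h"

// Packs a NodeDef the way the test does: a flexbuffer vector containing the
// op name and the serialized NodeDef bytes.
std::vector<uint8_t> PackParseExampleOptions(const tensorflow::NodeDef& nodedef) {
  std::string serialized_nodedef;
  nodedef.SerializeToString(&serialized_nodedef);
  flexbuffers::Builder fbb;
  fbb.Vector([&]() {
    fbb.String(nodedef.op());
    fbb.String(serialized_nodedef);
  });
  fbb.Finish();
  return fbb.GetBuffer();
}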


@@ -239,7 +239,6 @@ cc_library(
"//tensorflow/lite/kernels:reference_ops",
"//tensorflow/lite/kernels:test_delegate_providers_lib",
"//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
"//tensorflow/lite/kernels/parse_example:parse_example",
"//tensorflow/lite/tools/evaluation:utils",
] + select({
"//tensorflow:ios": [],


@@ -26,7 +26,6 @@ limitations under the License.
#endif
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/hashtable/hashtable_ops.h"
#include "tensorflow/lite/kernels/parse_example/parse_example.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register_ref.h"
#include "tensorflow/lite/kernels/test_delegate_providers.h"
@@ -371,7 +370,6 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
ops::builtin::BuiltinOpResolver* buildinop_resolver_ =
reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
tflite::ops::custom::AddHashtableOps(buildinop_resolver_);
tflite::ops::custom::AddParseExampleOp(buildinop_resolver_);
}
switch (delegate_type) {