Add experimental custom parse_example op
PiperOrigin-RevId: 347640612 Change-Id: If4c83fc598391e8972f6b71015caf9def4d4bb7a
This commit is contained in:
parent
67e713e624
commit
af93fc7294
@ -1,75 +0,0 @@
|
||||
# Kernel for custom parse_example
|
||||
package(
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parse_example",
|
||||
srcs = [
|
||||
"example_proto_fast_parsing.cc",
|
||||
"parse_example.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"example_proto_fast_parsing.h",
|
||||
"parse_example.h",
|
||||
],
|
||||
deps = [
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@flatbuffers",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
"//tensorflow/lite/kernels/internal:tensor",
|
||||
"//tensorflow/lite:string_util",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//tensorflow:ios": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:feature_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "parse_example_test",
|
||||
srcs = ["parse_example_test.cc"],
|
||||
deps = [
|
||||
":parse_example",
|
||||
"@flatbuffers",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/core/api:op_resolver",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/kernels:test_main",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:string_util",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//tensorflow:ios": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/example:feature_util",
|
||||
"//tensorflow/core/platform:protobuf",
|
||||
"//tensorflow/core/platform:tstring",
|
||||
],
|
||||
}),
|
||||
)
|
@ -1,170 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace example {
|
||||
|
||||
string ExampleName(const gtl::ArraySlice<tstring> example_names, int n) {
|
||||
return example_names.empty() ? "<unknown>" : example_names[n];
|
||||
}
|
||||
|
||||
void CountSparseFeatures(
|
||||
const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
|
||||
size_t* total_num_features, size_t* max_num_features) {
|
||||
for (auto& sparse_values_tmp : sparse_buffers) {
|
||||
const std::vector<size_t>& end_indices =
|
||||
sparse_values_tmp[d].example_end_indices;
|
||||
*total_num_features += end_indices.back();
|
||||
*max_num_features = std::max(*max_num_features, end_indices[0]);
|
||||
for (size_t i = 1; i < end_indices.size(); ++i) {
|
||||
size_t example_size = end_indices[i] - end_indices[i - 1];
|
||||
*max_num_features = std::max(*max_num_features, example_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
|
||||
Tensor* dst) {
|
||||
switch (dtype) {
|
||||
case DT_INT64: {
|
||||
std::copy(src->int64_list.begin(), src->int64_list.end(),
|
||||
dst->flat<int64>().data() + offset);
|
||||
break;
|
||||
}
|
||||
case DT_FLOAT: {
|
||||
std::copy(src->float_list.begin(), src->float_list.end(),
|
||||
dst->flat<float>().data() + offset);
|
||||
break;
|
||||
}
|
||||
case DT_STRING: {
|
||||
std::move(src->bytes_list.begin(), src->bytes_list.end(),
|
||||
dst->flat<tstring>().data() + offset);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
ReportUnexpectedDataType(dtype);
|
||||
}
|
||||
}
|
||||
|
||||
uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
|
||||
DCHECK(stream != nullptr);
|
||||
const void* ptr;
|
||||
int size;
|
||||
if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
|
||||
return *static_cast<const uint8*>(ptr);
|
||||
}
|
||||
|
||||
bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) {
|
||||
DCHECK(stream != nullptr);
|
||||
DCHECK(result != nullptr);
|
||||
uint32 length;
|
||||
if (!stream->ReadVarint32(&length)) return false;
|
||||
if (length == 0) {
|
||||
*result = StringPiece(nullptr, 0);
|
||||
return true;
|
||||
}
|
||||
const void* stream_alias;
|
||||
int stream_size;
|
||||
if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
|
||||
return false;
|
||||
}
|
||||
if (static_cast<uint32>(stream_size) < length) return false;
|
||||
*result = StringPiece(static_cast<const char*>(stream_alias), length);
|
||||
stream->Skip(length);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
|
||||
parsed::FeatureMapEntry* feature_map_entry) {
|
||||
DCHECK(stream != nullptr);
|
||||
DCHECK(feature_map_entry != nullptr);
|
||||
uint32 length;
|
||||
if (!stream->ReadVarint32(&length)) return false;
|
||||
auto limit = stream->PushLimit(length);
|
||||
if (!stream->ExpectTag(kDelimitedTag(1))) return false;
|
||||
if (!ParseString(stream, &feature_map_entry->first)) return false;
|
||||
if (!stream->ExpectTag(kDelimitedTag(2))) return false;
|
||||
StringPiece feature_string_piece;
|
||||
if (!ParseString(stream, &feature_string_piece)) return false;
|
||||
feature_map_entry->second = parsed::Feature(feature_string_piece);
|
||||
if (!stream->ExpectAtEnd()) return false;
|
||||
stream->PopLimit(limit);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParseFeatures(protobuf::io::CodedInputStream* stream,
|
||||
parsed::Example* example) {
|
||||
DCHECK(stream != nullptr);
|
||||
DCHECK(example != nullptr);
|
||||
uint32 length;
|
||||
if (!stream->ReadVarint32(&length)) return false;
|
||||
auto limit = stream->PushLimit(length);
|
||||
while (!stream->ExpectAtEnd()) {
|
||||
parsed::FeatureMapEntry feature_map_entry;
|
||||
if (!stream->ExpectTag(kDelimitedTag(1))) return false;
|
||||
if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false;
|
||||
example->push_back(std::move(feature_map_entry));
|
||||
}
|
||||
stream->PopLimit(limit);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParseExample(protobuf::io::CodedInputStream* stream,
|
||||
parsed::Example* example) {
|
||||
DCHECK(stream != nullptr);
|
||||
DCHECK(example != nullptr);
|
||||
// Loop over the input stream which may contain multiple serialized Example
|
||||
// protos merged together as strings. This behavior is consistent with Proto's
|
||||
// ParseFromString when string representations are concatenated.
|
||||
while (!stream->ExpectAtEnd()) {
|
||||
if (!stream->ExpectTag(kDelimitedTag(1))) {
|
||||
if (!SkipExtraneousTag(stream)) return false;
|
||||
} else {
|
||||
if (!ParseFeatures(stream, example)) return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParseExample(StringPiece serialized, parsed::Example* example) {
|
||||
DCHECK(example != nullptr);
|
||||
protobuf::io::CodedInputStream stream(
|
||||
reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
|
||||
EnableAliasing(&stream);
|
||||
return ParseExample(&stream, example);
|
||||
}
|
||||
|
||||
template <>
|
||||
void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t) {
|
||||
std::move(b, e, t);
|
||||
}
|
||||
|
||||
template <>
|
||||
const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer) {
|
||||
return buffer.int64_list;
|
||||
}
|
||||
template <>
|
||||
const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
|
||||
return buffer.float_list;
|
||||
}
|
||||
template <>
|
||||
const SmallVector<tstring>& GetListFromBuffer<tstring>(
|
||||
const SparseBuffer& buffer) {
|
||||
return buffer.bytes_list;
|
||||
}
|
||||
|
||||
} // namespace example
|
||||
} // namespace tensorflow
|
@ -1,688 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
|
||||
#include "tensorflow/core/util/example_proto_fast_parsing.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/casts.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/core/example/example.pb.h"
|
||||
#include "tensorflow/core/example/feature.pb.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/platform/byte_order.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/util/presized_cuckoo_map.h"
|
||||
#include "tensorflow/core/util/sparse/sparse_tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace example {
|
||||
|
||||
template <typename T>
|
||||
using SmallVector = gtl::InlinedVector<T, 4>;
|
||||
|
||||
template <typename T>
|
||||
class LimitedArraySlice {
|
||||
public:
|
||||
using value_type = T;
|
||||
|
||||
LimitedArraySlice(T* begin, size_t num_elements)
|
||||
: current_(begin), begin_(begin), end_(begin + num_elements) {}
|
||||
|
||||
// May return negative if there were push_back calls after slice was filled.
|
||||
int64 EndDistance() const { return end_ - current_; }
|
||||
|
||||
// Attempts to push value to the back of this. If the slice has
|
||||
// already been filled, this method has no effect on the underlying data, but
|
||||
// it changes the number returned by EndDistance into negative values.
|
||||
void push_back(T&& value) {
|
||||
if (EndDistance() > 0) *current_ = std::move(value);
|
||||
++current_;
|
||||
}
|
||||
|
||||
// "Constructs" an element at the back of this by resizing the slice, and
|
||||
// returns a mutable reference to the new last element.
|
||||
// REQUIRES: EndDistance() > 0.
|
||||
T& construct_at_end() {
|
||||
DCHECK_GT(EndDistance(), 0);
|
||||
return *(current_++);
|
||||
}
|
||||
|
||||
// Returns a mutable reference to the last element in the slice.
|
||||
// REQUIRES: size() > 0.
|
||||
T& back() { return *(current_ - 1); }
|
||||
|
||||
// Returns the number of elements in the slice.
|
||||
size_t size() const { return std::min(current_ - begin_, end_ - begin_); }
|
||||
|
||||
// Attempts to resize the vector to the given size. It does so by advancing
|
||||
// the pointer to the current element, possibly beyond the end of the slice.
|
||||
// As a consequence, calling `size()` after `resize(x)` was called might
|
||||
// return a value less than `x`.
|
||||
void resize(size_t size) { current_ = begin_ + size; }
|
||||
|
||||
// Returns the pointer to the underlying data buffer.
|
||||
T* data() { return begin_; }
|
||||
|
||||
private:
|
||||
T* current_;
|
||||
T* begin_;
|
||||
T* end_;
|
||||
};
|
||||
|
||||
template <typename A>
|
||||
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
|
||||
a->EnableAliasing(true);
|
||||
}
|
||||
|
||||
template <typename A>
|
||||
void EnableAliasing(A&& a) {}
|
||||
|
||||
uint8 PeekTag(protobuf::io::CodedInputStream* stream);
|
||||
|
||||
constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
|
||||
constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
|
||||
constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
|
||||
|
||||
namespace parsed {
|
||||
|
||||
// ParseDataType has to be called first, then appropriate ParseZzzzList.
|
||||
class Feature {
|
||||
public:
|
||||
Feature() {}
|
||||
explicit Feature(StringPiece serialized) : serialized_(serialized) {}
|
||||
|
||||
Status ParseDataType(DataType* dtype) {
|
||||
DCHECK(dtype != nullptr);
|
||||
if (serialized_.empty()) {
|
||||
*dtype = DT_INVALID;
|
||||
return Status::OK();
|
||||
}
|
||||
uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
|
||||
serialized_.remove_prefix(1);
|
||||
switch (oneof_tag) {
|
||||
case kDelimitedTag(1):
|
||||
*dtype = DT_STRING;
|
||||
break;
|
||||
case kDelimitedTag(2):
|
||||
*dtype = DT_FLOAT;
|
||||
break;
|
||||
case kDelimitedTag(3):
|
||||
*dtype = DT_INT64;
|
||||
break;
|
||||
default:
|
||||
// Initialize variable to avoid compiler warning
|
||||
*dtype = DT_INVALID;
|
||||
return errors::InvalidArgument("Unsupported datatype.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool GetNumElementsInBytesList(int* num_elements) {
|
||||
protobuf::io::CodedInputStream stream(
|
||||
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
|
||||
EnableAliasing(&stream);
|
||||
uint32 length = 0;
|
||||
if (!stream.ReadVarint32(&length)) return false;
|
||||
auto limit = stream.PushLimit(length);
|
||||
*num_elements = 0;
|
||||
while (!stream.ExpectAtEnd()) {
|
||||
if (!stream.ExpectTag(kDelimitedTag(1))) return false;
|
||||
uint32 bytes_length = 0;
|
||||
if (!stream.ReadVarint32(&bytes_length)) return false;
|
||||
if (!stream.Skip(bytes_length)) return false;
|
||||
++*num_elements;
|
||||
}
|
||||
stream.PopLimit(limit);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
|
||||
if (bytes_list->EndDistance() <= 0) {
|
||||
return nullptr;
|
||||
}
|
||||
return &bytes_list->construct_at_end();
|
||||
}
|
||||
tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
|
||||
return &bytes_list->emplace_back();
|
||||
}
|
||||
|
||||
template <typename Result>
|
||||
bool ParseBytesList(Result* bytes_list) {
|
||||
DCHECK(bytes_list != nullptr);
|
||||
|
||||
protobuf::io::CodedInputStream stream(
|
||||
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
|
||||
|
||||
EnableAliasing(&stream);
|
||||
|
||||
uint32 length;
|
||||
if (!stream.ReadVarint32(&length)) return false;
|
||||
auto limit = stream.PushLimit(length);
|
||||
|
||||
while (!stream.ExpectAtEnd()) {
|
||||
if (!stream.ExpectTag(kDelimitedTag(1))) return false;
|
||||
// parse string
|
||||
uint32 bytes_length;
|
||||
if (!stream.ReadVarint32(&bytes_length)) return false;
|
||||
tstring* bytes = construct_at_end(bytes_list);
|
||||
if (bytes == nullptr) return false;
|
||||
bytes->resize_uninitialized(bytes_length);
|
||||
if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
|
||||
}
|
||||
stream.PopLimit(limit);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Result>
|
||||
bool ParseFloatList(Result* float_list) {
|
||||
DCHECK(float_list != nullptr);
|
||||
protobuf::io::CodedInputStream stream(
|
||||
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
|
||||
EnableAliasing(&stream);
|
||||
uint32 length;
|
||||
if (!stream.ReadVarint32(&length)) return false;
|
||||
auto limit = stream.PushLimit(length);
|
||||
|
||||
if (!stream.ExpectAtEnd()) {
|
||||
uint8 peek_tag = PeekTag(&stream);
|
||||
if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr int32 kNumFloatBytes = 4;
|
||||
if (peek_tag == kDelimitedTag(1)) { // packed
|
||||
if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
|
||||
uint32 packed_length;
|
||||
if (!stream.ReadVarint32(&packed_length)) return false;
|
||||
auto packed_limit = stream.PushLimit(packed_length);
|
||||
|
||||
// Store the initial size to know the offset we have to start writing
|
||||
// data from before resizing the output "vector".
|
||||
const size_t initial_size = float_list->size();
|
||||
float_list->resize(initial_size + packed_length / kNumFloatBytes);
|
||||
|
||||
// If the result data type is float and we are on a little endian
|
||||
// machine then we can simply memcpy the data from the proto into the
|
||||
// result vector.
|
||||
if (port::kLittleEndian &&
|
||||
sizeof(typename Result::value_type) == kNumFloatBytes) {
|
||||
// Calculate the length of the buffer available what can be less than
|
||||
// what we requested in resize in case of a LimitedArraySlice.
|
||||
const uint32 bytes_to_copy =
|
||||
std::min(static_cast<uint32>((float_list->size() - initial_size) *
|
||||
kNumFloatBytes),
|
||||
packed_length);
|
||||
if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
|
||||
return false;
|
||||
} else {
|
||||
int64 index = initial_size;
|
||||
while (!stream.ExpectAtEnd()) {
|
||||
uint32 buffer32;
|
||||
if (!stream.ReadLittleEndian32(&buffer32)) return false;
|
||||
if (index < float_list->size()) {
|
||||
float_list->data()[index] = absl::bit_cast<float>(buffer32);
|
||||
++index;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stream.PopLimit(packed_limit);
|
||||
} else { // non-packed
|
||||
const size_t initial_size = float_list->size();
|
||||
// 1 byte for the tag (`1` encoded as Variant32) and kNumFloatBytes for
|
||||
// the value.
|
||||
const int64 num_elements =
|
||||
stream.BytesUntilLimit() / (1 + kNumFloatBytes);
|
||||
float_list->resize(initial_size + num_elements);
|
||||
int64 index = initial_size;
|
||||
while (!stream.ExpectAtEnd()) {
|
||||
if (!stream.ExpectTag(kFixed32Tag(1))) return false;
|
||||
uint32 buffer32;
|
||||
if (!stream.ReadLittleEndian32(&buffer32)) return false;
|
||||
float_list->data()[index] = absl::bit_cast<float>(buffer32);
|
||||
++index;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stream.PopLimit(limit);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Result>
|
||||
bool ParseInt64List(Result* int64_list) {
|
||||
DCHECK(int64_list != nullptr);
|
||||
protobuf::io::CodedInputStream stream(
|
||||
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
|
||||
EnableAliasing(&stream);
|
||||
uint32 length;
|
||||
if (!stream.ReadVarint32(&length)) return false;
|
||||
auto limit = stream.PushLimit(length);
|
||||
|
||||
if (!stream.ExpectAtEnd()) {
|
||||
uint8 peek_tag = PeekTag(&stream);
|
||||
if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
|
||||
return false;
|
||||
}
|
||||
if (peek_tag == kDelimitedTag(1)) { // packed
|
||||
if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
|
||||
uint32 packed_length;
|
||||
if (!stream.ReadVarint32(&packed_length)) return false;
|
||||
auto packed_limit = stream.PushLimit(packed_length);
|
||||
|
||||
while (!stream.ExpectAtEnd()) {
|
||||
protobuf_uint64 n; // There is no API for int64
|
||||
if (!stream.ReadVarint64(&n)) return false;
|
||||
int64_list->push_back(static_cast<int64>(n));
|
||||
}
|
||||
|
||||
stream.PopLimit(packed_limit);
|
||||
} else { // non-packed
|
||||
while (!stream.ExpectAtEnd()) {
|
||||
if (!stream.ExpectTag(kVarintTag(1))) return false;
|
||||
protobuf_uint64 n; // There is no API for int64
|
||||
if (!stream.ReadVarint64(&n)) return false;
|
||||
int64_list->push_back(static_cast<int64>(n));
|
||||
}
|
||||
}
|
||||
}
|
||||
stream.PopLimit(limit);
|
||||
return true;
|
||||
}
|
||||
|
||||
StringPiece GetSerialized() const { return serialized_; }
|
||||
|
||||
private:
|
||||
StringPiece serialized_;
|
||||
};
|
||||
|
||||
using FeatureMapEntry = std::pair<StringPiece, Feature>;
|
||||
using Example = std::vector<FeatureMapEntry>;
|
||||
|
||||
} // namespace parsed
|
||||
|
||||
inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
|
||||
uint32 data;
|
||||
protobuf_uint64 dummy;
|
||||
switch (stream->ReadTag() & 0x7) {
|
||||
case 0: // varint
|
||||
if (!stream->ReadVarint32(&data)) return false;
|
||||
return true;
|
||||
case 1: // fixed64
|
||||
if (!stream->ReadLittleEndian64(&dummy)) return false;
|
||||
return true;
|
||||
case 2: // length delimited
|
||||
if (!stream->ReadVarint32(&data)) return false;
|
||||
stream->Skip(data);
|
||||
return true;
|
||||
case 3: // group begin
|
||||
return false; // groups not supported.
|
||||
case 4: // group end
|
||||
return false; // groups not supported.
|
||||
case 5: // fixed32
|
||||
if (!stream->ReadLittleEndian32(&data)) return false;
|
||||
return true;
|
||||
}
|
||||
return false; // unrecognized tag type
|
||||
}
|
||||
|
||||
bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result);
|
||||
|
||||
bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
|
||||
parsed::FeatureMapEntry* feature_map_entry);
|
||||
|
||||
bool ParseFeatures(protobuf::io::CodedInputStream* stream,
|
||||
parsed::Example* example);
|
||||
|
||||
bool ParseExample(protobuf::io::CodedInputStream* stream,
|
||||
parsed::Example* example);
|
||||
|
||||
bool ParseExample(StringPiece serialized, parsed::Example* example);
|
||||
|
||||
using Config = FastParseExampleConfig;
|
||||
|
||||
// Enumeration for distinguishing feature types.
|
||||
// Note: FastParseSequenceExample constructs a map that includes Type values,
|
||||
// and relies on the fact that they are default-initialized to Dense.
|
||||
enum class Type { Dense, Sparse, Ragged };
|
||||
|
||||
// Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
|
||||
struct SparseBuffer {
|
||||
// Features are in one of the 3 vectors below depending on config's dtype.
|
||||
// Other 2 vectors remain empty.
|
||||
SmallVector<tstring> bytes_list;
|
||||
SmallVector<float> float_list;
|
||||
SmallVector<int64> int64_list;
|
||||
|
||||
// Features of example i are elements with indices
|
||||
// from example_end_indices[i-1] to example_end_indices[i]-1 on the
|
||||
// appropriate xxxxx_list
|
||||
std::vector<size_t> example_end_indices;
|
||||
};
|
||||
|
||||
struct SeededHasher {
|
||||
uint64 operator()(StringPiece s) const {
|
||||
return Hash64(s.data(), s.size(), seed);
|
||||
}
|
||||
uint64 seed{0xDECAFCAFFE};
|
||||
};
|
||||
|
||||
// Use this in the "default" clause of switch statements when dispatching
|
||||
// on a dtype variable that was checked by CheckConfigDataType():
|
||||
inline void ReportUnexpectedDataType(DataType dtype) {
|
||||
DCHECK(false)
|
||||
<< "Encountered unexpected DataType " << DataTypeString(dtype)
|
||||
<< "in variable that should have been checked by CheckConfigDataType().";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);
|
||||
|
||||
template <>
|
||||
const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer);
|
||||
|
||||
template <>
|
||||
const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer);
|
||||
|
||||
template <>
|
||||
const SmallVector<tstring>& GetListFromBuffer<tstring>(
|
||||
const SparseBuffer& buffer);
|
||||
|
||||
template <typename T>
|
||||
void CopyOrMoveBlock(const T* b, const T* e, T* t) {
|
||||
std::copy(b, e, t);
|
||||
}
|
||||
template <>
|
||||
void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t);
|
||||
|
||||
void CountSparseFeatures(
|
||||
const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
|
||||
size_t* total_num_features, size_t* max_num_features);
|
||||
|
||||
void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
|
||||
Tensor* dst);
|
||||
|
||||
// A struct used by FastParseSequenceExample to hold the serialized proto
|
||||
// substrings for a single feature, plus some auxiliary information derived
|
||||
// from those protos (such as the total value length).
|
||||
struct FeatureProtos {
|
||||
// Proto substrings from each serialized SequenceExample that correspond
|
||||
// with this feature. `protos_present` records whether the proto had a
|
||||
// value defined (even if that value is empty).
|
||||
std::vector<StringPiece> protos;
|
||||
std::vector<bool> protos_present;
|
||||
|
||||
// Information derived from protos:
|
||||
size_t length; // total length for ragged/sparse, max row length for dense.
|
||||
size_t num_rows; // only populated for ragged sequence features.
|
||||
|
||||
// Information from the config:
|
||||
Type type; // Whether this feature is sparse, ragged, or dense.
|
||||
DataType dtype;
|
||||
};
|
||||
|
||||
// Map from feature name to FeatureProtos for that feature.
|
||||
using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>;
|
||||
|
||||
string ExampleName(const gtl::ArraySlice<tstring> example_names, int n);
|
||||
|
||||
// Return the number of bytes elements parsed, or -1 on error. If out is null,
|
||||
// this method simply counts the number of elements without any copying.
|
||||
inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
|
||||
tstring* out) {
|
||||
int num_elements = 0;
|
||||
uint32 length;
|
||||
if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
|
||||
return -1;
|
||||
}
|
||||
if (length > 0) {
|
||||
auto limit = stream->PushLimit(length);
|
||||
while (!stream->ExpectAtEnd()) {
|
||||
uint32 bytes_length;
|
||||
if (!stream->ExpectTag(kDelimitedTag(1)) ||
|
||||
!stream->ReadVarint32(&bytes_length)) {
|
||||
return -1;
|
||||
}
|
||||
if (out == nullptr) {
|
||||
stream->Skip(bytes_length);
|
||||
} else {
|
||||
out->resize_uninitialized(bytes_length);
|
||||
if (!stream->ReadRaw(out->data(), bytes_length)) {
|
||||
return -1;
|
||||
}
|
||||
out++;
|
||||
}
|
||||
num_elements++;
|
||||
}
|
||||
stream->PopLimit(limit);
|
||||
}
|
||||
return num_elements;
|
||||
}
|
||||
|
||||
inline void PadFloatFeature(int num_to_pad, float* out) {
|
||||
for (int i = 0; i < num_to_pad; i++) {
|
||||
*out++ = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
inline void PadInt64Feature(int num_to_pad, int64* out) {
|
||||
for (int i = 0; i < num_to_pad; i++) {
|
||||
*out++ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Return the number of float elements parsed, or -1 on error. If out is null,
|
||||
// this method simply counts the number of elements without any copying.
|
||||
inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
|
||||
float* out) {
|
||||
int num_elements = 0;
|
||||
uint32 length;
|
||||
if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
|
||||
return -1;
|
||||
}
|
||||
if (length > 0) {
|
||||
auto limit = stream->PushLimit(length);
|
||||
uint8 peek_tag = PeekTag(stream);
|
||||
if (peek_tag == kDelimitedTag(1)) { // packed
|
||||
uint32 packed_length;
|
||||
if (!stream->ExpectTag(kDelimitedTag(1)) ||
|
||||
!stream->ReadVarint32(&packed_length)) {
|
||||
return -1;
|
||||
}
|
||||
auto packed_limit = stream->PushLimit(packed_length);
|
||||
while (!stream->ExpectAtEnd()) {
|
||||
uint32 buffer32;
|
||||
if (!stream->ReadLittleEndian32(&buffer32)) {
|
||||
return -1;
|
||||
}
|
||||
if (out != nullptr) {
|
||||
*out++ = absl::bit_cast<float>(buffer32);
|
||||
}
|
||||
num_elements++;
|
||||
}
|
||||
stream->PopLimit(packed_limit);
|
||||
} else if (peek_tag == kFixed32Tag(1)) {
|
||||
while (!stream->ExpectAtEnd()) {
|
||||
uint32 buffer32;
|
||||
if (!stream->ExpectTag(kFixed32Tag(1)) ||
|
||||
!stream->ReadLittleEndian32(&buffer32)) {
|
||||
return -1;
|
||||
}
|
||||
if (out != nullptr) {
|
||||
*out++ = absl::bit_cast<float>(buffer32);
|
||||
}
|
||||
num_elements++;
|
||||
}
|
||||
} else {
|
||||
// Unknown tag.
|
||||
return -1;
|
||||
}
|
||||
stream->PopLimit(limit);
|
||||
}
|
||||
return num_elements;
|
||||
}
|
||||
|
||||
// Return the number of int64 elements parsed, or -1 on error. If out is null,
|
||||
// this method simply counts the number of elements without any copying.
|
||||
inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
|
||||
int64* out) {
|
||||
int num_elements = 0;
|
||||
uint32 length;
|
||||
if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
|
||||
return -1;
|
||||
}
|
||||
if (length > 0) {
|
||||
auto limit = stream->PushLimit(length);
|
||||
uint8 peek_tag = PeekTag(stream);
|
||||
if (peek_tag == kDelimitedTag(1)) { // packed
|
||||
uint32 packed_length;
|
||||
if (!stream->ExpectTag(kDelimitedTag(1)) ||
|
||||
!stream->ReadVarint32(&packed_length)) {
|
||||
return -1;
|
||||
}
|
||||
auto packed_limit = stream->PushLimit(packed_length);
|
||||
while (!stream->ExpectAtEnd()) {
|
||||
protobuf_uint64 n; // There is no API for int64
|
||||
if (!stream->ReadVarint64(&n)) {
|
||||
return -1;
|
||||
}
|
||||
if (out != nullptr) {
|
||||
*out++ = n;
|
||||
}
|
||||
num_elements++;
|
||||
}
|
||||
stream->PopLimit(packed_limit);
|
||||
} else if (peek_tag == kVarintTag(1)) {
|
||||
while (!stream->ExpectAtEnd()) {
|
||||
protobuf_uint64 n; // There is no API for int64
|
||||
if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
|
||||
return -1;
|
||||
}
|
||||
if (out != nullptr) {
|
||||
*out++ = n;
|
||||
}
|
||||
num_elements++;
|
||||
}
|
||||
} else {
|
||||
// Unknown tag.
|
||||
return -1;
|
||||
}
|
||||
stream->PopLimit(limit);
|
||||
}
|
||||
return num_elements;
|
||||
}
|
||||
|
||||
// Parses the next feature on `stream` into `out` starting at `out_offset`.
|
||||
// Updates `out_offset`, and returns the number of values added.
|
||||
// Returns -1 if the next feature on `stream` doesn't match `dtype`.
|
||||
inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream,
|
||||
Tensor* out, size_t* out_offset) {
|
||||
int delta;
|
||||
switch (dtype) {
|
||||
case DT_STRING:
|
||||
delta =
|
||||
ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset);
|
||||
break;
|
||||
case DT_FLOAT:
|
||||
delta =
|
||||
ParseFloatFeature(stream, out->flat<float>().data() + *out_offset);
|
||||
break;
|
||||
case DT_INT64:
|
||||
delta =
|
||||
ParseInt64Feature(stream, out->flat<int64>().data() + *out_offset);
|
||||
break;
|
||||
default:
|
||||
ReportUnexpectedDataType(dtype);
|
||||
delta = 0;
|
||||
}
|
||||
if (delta > 0) {
|
||||
*out_offset += delta;
|
||||
}
|
||||
return delta;
|
||||
}
|
||||
|
||||
// Returns the length of the next feature on `stream`.
|
||||
// Returns -1 if the next feature on `stream` doesn't match `dtype`.
|
||||
inline int GetFeatureLength(DataType dtype,
|
||||
protobuf::io::CodedInputStream* stream) {
|
||||
switch (dtype) {
|
||||
case DT_STRING:
|
||||
return ParseBytesFeature(stream, nullptr);
|
||||
case DT_FLOAT:
|
||||
return ParseFloatFeature(stream, nullptr);
|
||||
case DT_INT64:
|
||||
return ParseInt64Feature(stream, nullptr);
|
||||
default:
|
||||
ReportUnexpectedDataType(dtype);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// Infers the dtype of the next Feature on `stream` from its field tag:
// bytes_list (field 1), float_list (field 2), or int64_list (field 3).
// Returns DT_INVALID for anything else. Does not consume the tag.
inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
  const uint8 tag = PeekTag(stream);
  if (tag == kDelimitedTag(1)) {
    return DT_STRING;
  }
  if (tag == kDelimitedTag(2)) {
    return DT_FLOAT;
  }
  if (tag == kDelimitedTag(3)) {
    return DT_INT64;
  }
  return DT_INVALID;
}
|
||||
|
||||
// Consumes an empty Feature of type `dtype` from `stream`.
// Returns true only if the next message is the expected list field for
// `dtype` and its payload length is exactly zero.
inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
                             DataType dtype) {
  // Map the dtype to the Feature field tag it must appear under.
  uint32 expected_tag;
  switch (dtype) {
    case DT_STRING:
      expected_tag = kDelimitedTag(1);
      break;
    case DT_FLOAT:
      expected_tag = kDelimitedTag(2);
      break;
    case DT_INT64:
      expected_tag = kDelimitedTag(3);
      break;
    default:
      return false;
  }
  if (!stream->ExpectTag(expected_tag)) {
    return false;
  }
  uint32 length;
  return stream->ReadVarint32(&length) && length == 0;
}
|
||||
|
||||
} // namespace example
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
|
File diff suppressed because it is too large
Load Diff
@ -1,33 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_PARSE_EXAMPLE_H_
#define TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_PARSE_EXAMPLE_H_

#include "tensorflow/lite/mutable_op_resolver.h"

namespace tflite {
namespace ops {
namespace custom {

// Registration for the custom "ParseExample" op.
TfLiteRegistration* Register_PARSE_EXAMPLE();
// Registration for the custom "ParseExampleV2" op.
TfLiteRegistration* Register_PARSE_EXAMPLE_V2();

// Registers both ParseExample ops on `resolver`. Declared extern "C" so the
// symbol is un-mangled and can be looked up by name (e.g. via dlsym).
extern "C" void AddParseExampleOp(::tflite::MutableOpResolver* resolver);

}  // namespace custom
}  // namespace ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_PARSE_EXAMPLE_H_
|
@ -1,330 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/parse_example/parse_example.h"
|
||||
|
||||
#include <initializer_list>
|
||||
|
||||
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/core/example/feature_util.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model_builder.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
|
||||
namespace tf = ::tensorflow;

// ParseExample NodeDef: one dense float feature of shape [2], no sparse
// features. Used by SimpleTest.
const char* kNodeDefTxt = R"pb(
  name: "ParseExample/ParseExample"
  op: "ParseExample"
  input: "serialized"
  input: "ParseExample/ParseExample/names"
  input: "ParseExample/ParseExample/dense_keys_0"
  input: "ParseExample/Const"
  attr {
    key: "Ndense"
    value { i: 1 }
  }
  attr {
    key: "Nsparse"
    value { i: 0 }
  }
  attr {
    key: "Tdense"
    value { list { type: DT_FLOAT } }
  }
  attr {
    key: "dense_shapes"
    value { list { shape { dim { size: 2 } } } }
  }
  attr {
    key: "sparse_types"
    value { list { type: DT_FLOAT } }
  }
)pb";

// ParseExample NodeDef: one sparse float feature, no dense features.
// Used by SparseTest.
const char* kNodeDefTxt2 = R"pb(
  name: "ParseExample/ParseExample"
  op: "ParseExample"
  input: "serialized"
  input: "ParseExample/ParseExample/names"
  input: "ParseExample/ParseExample/sparse_keys_0"
  attr {
    key: "Ndense"
    value { i: 0 }
  }
  attr {
    key: "Nsparse"
    value { i: 1 }
  }
  attr {
    key: "Tdense"
    value {}
  }
  attr {
    key: "dense_shapes"
    value {}
  }
  attr {
    key: "sparse_types"
    value { list { type: DT_FLOAT } }
  }
)pb";

// ParseExample NodeDef: one dense string feature of shape [1], no sparse
// features. Used by SimpleBytesTest.
const char* kNodeDefTxt3 = R"pb(
  name: "ParseExample/ParseExample"
  op: "ParseExample"
  input: "serialized"
  input: "ParseExample/ParseExample/names"
  input: "ParseExample/ParseExample/sparse_keys_0"
  attr {
    key: "Ndense"
    value { i: 1 }
  }
  attr {
    key: "Nsparse"
    value { i: 0 }
  }
  attr {
    key: "Tdense"
    value { list { type: DT_STRING } }
  }
  attr {
    key: "dense_shapes"
    value { list { shape { dim { size: 1 } } } }
  }
  attr {
    key: "sparse_types"
    value { list { type: DT_FLOAT } }
  }
)pb";

// ParseExample NodeDef: one sparse string feature, no dense features.
// Used by SparseBytesTest.
const char* kNodeDefTxt4 = R"pb(
  name: "ParseExample/ParseExample"
  op: "ParseExample"
  input: "serialized"
  input: "ParseExample/ParseExample/names"
  input: "ParseExample/ParseExample/sparse_keys_0"
  attr {
    key: "Ndense"
    value { i: 0 }
  }
  attr {
    key: "Nsparse"
    value { i: 1 }
  }
  attr {
    key: "Tdense"
    value {}
  }
  attr {
    key: "dense_shapes"
    value {}
  }
  attr {
    key: "sparse_types"
    value { list { type: DT_STRING } }
  }
)pb";
|
||||
|
||||
// Test harness that builds a single-op interpreter around the custom
// ParseExample kernel. `DefaultType` is the C++ element type of the dense
// default values (e.g. float or std::string).
template <typename DefaultType>
class ParseExampleOpModel : public SingleOpModel {
 public:
  // Wires up inputs/outputs matching `text_def` (a NodeDef in text proto
  // form), builds the interpreter, and populates the string inputs.
  // NOTE(review): the registration order below (example, names, sparse keys,
  // dense keys, defaults; then sparse indices/values/shapes, dense outputs)
  // appears to mirror the op's expected tensor layout — confirm against the
  // kernel before reordering.
  ParseExampleOpModel(std::string serialized_example,
                      std::vector<std::string> sparse_keys,
                      std::vector<std::string> dense_keys,
                      std::initializer_list<DefaultType> dense_defaults,
                      std::vector<TensorType> dense_types,
                      std::vector<TensorType> sparse_types,
                      const char* text_def, int dense_size = 2) {
    // Example
    string_indices_.push_back(AddInput(TensorData(TensorType_STRING, {1})));
    // Names
    string_indices_.push_back(
        AddConstInput<std::string>(TensorData(TensorType_STRING, {0}), {""}));
    // One string input per sparse key, then one per dense key.
    std::for_each(sparse_keys.begin(), sparse_keys.end(), [&](auto&&) {
      string_indices_.push_back(AddInput(TensorData(TensorType_STRING, {1})));
    });
    std::for_each(dense_keys.begin(), dense_keys.end(), [&](auto&&) {
      string_indices_.push_back(AddInput(TensorData(TensorType_STRING, {1})));
    });
    if (dense_size > 0) {
      // Dense defaults use the first dense type; tests only exercise one
      // dense feature.
      dense_defaults_ = AddConstInput<DefaultType>(
          TensorData(dense_types[0], {dense_size}), dense_defaults);
    }
    if (!sparse_keys.empty()) {
      // Sparse outputs come in three groups: all indices, all values, then
      // all shapes.
      for (int i = 0; i < sparse_keys.size(); i++) {
        sparse_indices_outputs_.push_back(AddOutput(TensorType_INT64));
      }
      for (int i = 0; i < sparse_keys.size(); i++) {
        sparse_values_outputs_.push_back(AddOutput(sparse_types[i]));
      }
      for (int i = 0; i < sparse_keys.size(); i++) {
        sparse_shapes_outputs_.push_back(AddOutput({TensorType_INT64, {2}}));
      }
    }
    for (int i = 0; i < dense_keys.size(); i++) {
      dense_outputs_.push_back(AddOutput({dense_types[i], {dense_size}}));
    }

    // Parse the text NodeDef and pack {op name, serialized NodeDef} into a
    // flexbuffer — the format the custom op reads its attributes from.
    tf::NodeDef nodedef;
    tf::protobuf::TextFormat::Parser parser;
    tf::protobuf::io::ArrayInputStream input_stream(text_def, strlen(text_def));
    if (!parser.Parse(&input_stream, &nodedef)) {
      abort();  // Malformed test fixture; fail hard.
    }
    std::string serialized_nodedef;
    nodedef.SerializeToString(&serialized_nodedef);
    flexbuffers::Builder fbb;
    fbb.Vector([&]() {
      fbb.String(nodedef.op());
      fbb.String(serialized_nodedef);
    });
    fbb.Finish();
    const auto buffer = fbb.GetBuffer();
    SetCustomOp("ParseExample", buffer, Register_PARSE_EXAMPLE);
    BuildInterpreter({});
    // Fill string inputs in the same order they were registered above.
    int idx = 0;
    PopulateStringTensor(string_indices_[idx++], {serialized_example});
    PopulateStringTensor(string_indices_[idx++], {""});
    for (const auto& key : sparse_keys) {
      PopulateStringTensor(string_indices_[idx++], {key});
    }
    for (const auto& key : dense_keys) {
      PopulateStringTensor(string_indices_[idx++], {key});
    }
  }

  // Returns the i-th sparse-indices output as a flat vector.
  template <typename T>
  std::vector<T> GetSparseIndicesOutput(int i) {
    return ExtractVector<T>(sparse_indices_outputs_[i]);
  }

  // Returns the i-th sparse-values output as a flat vector.
  template <typename T>
  std::vector<T> GetSparseValuesOutput(int i) {
    return ExtractVector<T>(sparse_values_outputs_[i]);
  }

  // Returns the i-th sparse-shape output as a flat vector.
  template <typename T>
  std::vector<T> GetSparseShapesOutput(int i) {
    return ExtractVector<T>(sparse_shapes_outputs_[i]);
  }

  // Returns the i-th dense output as a flat vector.
  template <typename T>
  std::vector<T> GetDenseOutput(int i) {
    return ExtractVector<T>(dense_outputs_[i]);
  }

  // Decodes all strings stored in tensor index `i`.
  std::vector<std::string> GetStringOutput(int i) {
    auto* t = interpreter_->tensor(i);
    int count = GetStringCount(t);
    std::vector<std::string> v;
    for (int i = 0; i < count; ++i) {
      auto ref = GetString(t, i);
      v.emplace_back(ref.str, ref.len);
    }
    return v;
  }

  // Tensor index of the dense-defaults input (-1 if none was added).
  int DenseDefaults() { return dense_defaults_; }

  // Tensor index of the i-th sparse-values output.
  int SparseValuesOutputs(int i) { return sparse_values_outputs_[i]; }

  // Tensor index of the i-th dense output.
  int DenseOutputs(int i) { return dense_outputs_[i]; }

  std::vector<int> dense_outputs_;
  std::vector<int> sparse_indices_outputs_;
  std::vector<int> sparse_shapes_outputs_;
  std::vector<int> sparse_values_outputs_;
  std::vector<int> string_indices_;
  int dense_defaults_ = -1;
};
|
||||
|
||||
// Dense parse: the requested "time" feature ([2] floats) is returned as-is;
// the extra "num" feature in the example is ignored.
TEST(ParseExampleOpsTest, SimpleTest) {
  tf::Example example;
  tf::AppendFeatureValues<float>({1.5f, 1.5f}, "time", &example);
  tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
  ParseExampleOpModel<float> m(example.SerializeAsString(), {}, {"time"},
                               {0.f, 0.f}, {TensorType_FLOAT32}, {},
                               kNodeDefTxt);
  m.Invoke();
  EXPECT_THAT(m.GetDenseOutput<float>(0),
              ElementsAreArray(ArrayFloatNear({1.5f, 1.5f})));
}
|
||||
|
||||
// Sparse parse: a single float value yields indices {0, 0} (example 0,
// position 0) and dense shape {1, 1} (one example, max one value).
TEST(ParseExampleOpsTest, SparseTest) {
  tf::Example example;
  tf::AppendFeatureValues<float>({1.5f}, "time", &example);
  ParseExampleOpModel<float> m(example.SerializeAsString(), {"time"}, {}, {},
                               {}, {TensorType_FLOAT32}, kNodeDefTxt2, 0);
  m.Invoke();
  // Indices and shapes are int64_t; compare them exactly instead of via the
  // float-tolerance ArrayFloatNear matcher (consistent with SparseBytesTest).
  EXPECT_THAT(m.GetSparseIndicesOutput<int64_t>(0),
              testing::ElementsAreArray({0, 0}));
  EXPECT_THAT(m.GetSparseValuesOutput<float>(0),
              ElementsAreArray(ArrayFloatNear({1.5f})));
  EXPECT_THAT(m.GetSparseShapesOutput<int64_t>(0),
              testing::ElementsAreArray({1, 1}));
}
|
||||
|
||||
// Dense bytes parse: a single string feature is returned verbatim; the
// dense default ("missing") is not used because the feature is present.
TEST(ParseExampleOpsTest, SimpleBytesTest) {
  tf::Example example;
  const std::string test_data = "simpletest";
  tf::AppendFeatureValues<tensorflow::tstring>({test_data}, "time", &example);
  tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
  std::string default_value = "missing";
  ParseExampleOpModel<std::string> m(example.SerializeAsString(), {}, {"time"},
                                     {default_value}, {TensorType_STRING}, {},
                                     kNodeDefTxt3, 1);
  m.PopulateStringTensor(m.DenseDefaults(), {default_value});
  m.Invoke();
  std::vector<string> c = m.GetStringOutput(m.DenseOutputs(0));
  EXPECT_EQ(1, c.size());
  EXPECT_EQ(test_data, c[0]);
}
|
||||
|
||||
// Sparse bytes parse: two string values in one example produce indices
// {0,0, 0,1} (row-major (example, position) pairs) and dense shape {1, 2}.
TEST(ParseExampleOpsTest, SparseBytesTest) {
  tf::Example example;
  const std::string test_data = "simpletest";
  tf::AppendFeatureValues<tensorflow::tstring>({test_data, test_data}, "time",
                                               &example);
  tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
  ParseExampleOpModel<std::string> m(example.SerializeAsString(), {"time"}, {},
                                     {}, {}, {TensorType_STRING}, kNodeDefTxt4,
                                     0);
  m.Invoke();
  EXPECT_THAT(m.GetSparseIndicesOutput<int64_t>(0),
              testing::ElementsAreArray({0, 0, 0, 1}));
  auto values = m.GetStringOutput(m.SparseValuesOutputs(0));
  EXPECT_EQ(2, values.size());
  EXPECT_EQ(test_data, values[0]);
  EXPECT_EQ(test_data, values[1]);
  EXPECT_THAT(m.GetSparseShapesOutput<int64_t>(0),
              testing::ElementsAreArray({1, 2}));
}
|
||||
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
@ -239,7 +239,6 @@ cc_library(
|
||||
"//tensorflow/lite/kernels:reference_ops",
|
||||
"//tensorflow/lite/kernels:test_delegate_providers_lib",
|
||||
"//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
|
||||
"//tensorflow/lite/kernels/parse_example:parse_example",
|
||||
"//tensorflow/lite/tools/evaluation:utils",
|
||||
] + select({
|
||||
"//tensorflow:ios": [],
|
||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
||||
#endif
|
||||
#include "tensorflow/lite/kernels/custom_ops_register.h"
|
||||
#include "tensorflow/lite/kernels/hashtable/hashtable_ops.h"
|
||||
#include "tensorflow/lite/kernels/parse_example/parse_example.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/register_ref.h"
|
||||
#include "tensorflow/lite/kernels/test_delegate_providers.h"
|
||||
@ -371,7 +370,6 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
|
||||
ops::builtin::BuiltinOpResolver* buildinop_resolver_ =
|
||||
reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
|
||||
tflite::ops::custom::AddHashtableOps(buildinop_resolver_);
|
||||
tflite::ops::custom::AddParseExampleOp(buildinop_resolver_);
|
||||
}
|
||||
|
||||
switch (delegate_type) {
|
||||
|
Loading…
Reference in New Issue
Block a user