Further improve the performance of ParseExample, fixing a regression on specific benchmarks.
- Ensure that each thread receives an equal piece. - Use at least 8 threads (given 8 or more examples) and at most 64. - Use InlinedVector. Change: 132488652
This commit is contained in:
parent
d3b34b7e56
commit
7705791619
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/casts.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/util/presized_cuckoo_map.h"
|
||||
@ -34,6 +35,10 @@ namespace tensorflow {
|
||||
namespace example {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
using SmallVector = gtl::InlinedVector<T, 4>;
|
||||
|
||||
template <typename A>
|
||||
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
|
||||
a->EnableAliasing(true);
|
||||
@ -86,7 +91,7 @@ class Feature {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ParseBytesList(std::vector<string>* bytes_list) {
|
||||
bool ParseBytesList(SmallVector<string>* bytes_list) {
|
||||
DCHECK(bytes_list != nullptr);
|
||||
protobuf::io::CodedInputStream stream(
|
||||
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
|
||||
@ -110,7 +115,7 @@ class Feature {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParseFloatList(std::vector<float>* float_list) {
|
||||
bool ParseFloatList(SmallVector<float>* float_list) {
|
||||
DCHECK(float_list != nullptr);
|
||||
protobuf::io::CodedInputStream stream(
|
||||
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
|
||||
@ -152,7 +157,7 @@ class Feature {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParseInt64List(std::vector<int64>* int64_list) {
|
||||
bool ParseInt64List(SmallVector<int64>* int64_list) {
|
||||
DCHECK(int64_list != nullptr);
|
||||
protobuf::io::CodedInputStream stream(
|
||||
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
|
||||
@ -293,7 +298,7 @@ bool TestFastParse(const string& serialized, Example* example) {
|
||||
case DT_INVALID:
|
||||
break;
|
||||
case DT_STRING: {
|
||||
std::vector<string> list;
|
||||
SmallVector<string> list;
|
||||
if (!entry.second.ParseBytesList(&list)) return false;
|
||||
auto* result_list = value.mutable_bytes_list();
|
||||
for (auto& bytes : list) {
|
||||
@ -302,7 +307,7 @@ bool TestFastParse(const string& serialized, Example* example) {
|
||||
break;
|
||||
}
|
||||
case DT_FLOAT: {
|
||||
std::vector<float> list;
|
||||
SmallVector<float> list;
|
||||
if (!entry.second.ParseFloatList(&list)) return false;
|
||||
auto* result_list = value.mutable_float_list();
|
||||
for (float f : list) {
|
||||
@ -311,7 +316,7 @@ bool TestFastParse(const string& serialized, Example* example) {
|
||||
break;
|
||||
}
|
||||
case DT_INT64: {
|
||||
std::vector<int64> list;
|
||||
SmallVector<int64> list;
|
||||
if (!entry.second.ParseInt64List(&list)) return false;
|
||||
auto* result_list = value.mutable_int64_list();
|
||||
for (int64 i : list) {
|
||||
@ -334,28 +339,32 @@ using Config = FastParseExampleConfig;
|
||||
|
||||
void ParallelFor(const std::function<void(size_t)>& f, size_t n,
|
||||
thread::ThreadPool* thread_pool) {
|
||||
DCHECK(thread_pool != nullptr);
|
||||
if (n == 0) return;
|
||||
BlockingCounter counter(n - 1);
|
||||
for (size_t i = 1; i < n; ++i) {
|
||||
thread_pool->Schedule([i, &f, &counter] {
|
||||
if (thread_pool == nullptr) {
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
f(i);
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
} else {
|
||||
BlockingCounter counter(n - 1);
|
||||
for (size_t i = 1; i < n; ++i) {
|
||||
thread_pool->Schedule([i, &f, &counter] {
|
||||
f(i);
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
f(0);
|
||||
counter.Wait();
|
||||
}
|
||||
f(0);
|
||||
counter.Wait();
|
||||
}
|
||||
|
||||
enum class Type { Sparse, Dense };
|
||||
|
||||
struct SparseBuffer {
|
||||
// TODO(lew): Use InlinedVector.
|
||||
// Features are in one of the 3 vectors below depending on config's dtype.
|
||||
// Other 2 vectors remain empty.
|
||||
std::vector<string> bytes_list;
|
||||
std::vector<float> float_list;
|
||||
std::vector<int64> int64_list;
|
||||
SmallVector<string> bytes_list;
|
||||
SmallVector<float> float_list;
|
||||
SmallVector<int64> int64_list;
|
||||
|
||||
// Features of example i are elements with indices
|
||||
// from example_end_indices[i-1] to example_end_indices[i]-1 on the
|
||||
@ -432,7 +441,7 @@ Status FastParseSerializedExample(
|
||||
|
||||
switch (config.dense[d].dtype) {
|
||||
case DT_INT64: {
|
||||
std::vector<int64> list;
|
||||
SmallVector<int64> list;
|
||||
if (!feature.ParseInt64List(&list)) return parse_error(feature_name);
|
||||
if (list.size() != num_elements) {
|
||||
return shape_error(list.size(), "int64");
|
||||
@ -442,7 +451,7 @@ Status FastParseSerializedExample(
|
||||
break;
|
||||
}
|
||||
case DT_FLOAT: {
|
||||
std::vector<float> list;
|
||||
SmallVector<float> list;
|
||||
if (!feature.ParseFloatList(&list)) return parse_error(feature_name);
|
||||
if (list.size() != num_elements) {
|
||||
return shape_error(list.size(), "float");
|
||||
@ -452,7 +461,7 @@ Status FastParseSerializedExample(
|
||||
break;
|
||||
}
|
||||
case DT_STRING: {
|
||||
std::vector<string> list;
|
||||
SmallVector<string> list;
|
||||
if (!feature.ParseBytesList(&list)) return parse_error(feature_name);
|
||||
if (list.size() != num_elements) {
|
||||
return shape_error(list.size(), "bytes");
|
||||
@ -580,7 +589,6 @@ Status FastParseExample(const Config& config,
|
||||
gtl::ArraySlice<string> serialized,
|
||||
gtl::ArraySlice<string> example_names,
|
||||
thread::ThreadPool* thread_pool, Result* result) {
|
||||
DCHECK(thread_pool != nullptr);
|
||||
DCHECK(result != nullptr);
|
||||
// Check config so we can safely CHECK(false) in switches on config.*.dtype
|
||||
for (auto& c : config.sparse) {
|
||||
@ -626,36 +634,49 @@ Status FastParseExample(const Config& config,
|
||||
}
|
||||
|
||||
// This parameter affects performance in a big and data-dependent way.
|
||||
const size_t kMiniBatchSizeBytes = 100000;
|
||||
const size_t kMiniBatchSizeBytes = 50000;
|
||||
|
||||
// Split examples into mini-batches for parallel processing.
|
||||
auto first_example_of_minibatch = [&] {
|
||||
std::vector<size_t> result;
|
||||
// Calculate number of minibatches.
|
||||
// In main regime make each minibatch around kMiniBatchSizeBytes bytes.
|
||||
// Apply 'special logic' below for small and big regimes.
|
||||
const size_t num_minibatches = [&] {
|
||||
size_t result = 0;
|
||||
size_t minibatch_bytes = 0;
|
||||
for (size_t i = 0; i < serialized.size(); i++) {
|
||||
if (minibatch_bytes == 0) { // start minibatch
|
||||
result.push_back(i);
|
||||
result++;
|
||||
}
|
||||
minibatch_bytes += serialized[i].size() + 1;
|
||||
if (minibatch_bytes > kMiniBatchSizeBytes) {
|
||||
minibatch_bytes = 0;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
// 'special logic'
|
||||
const size_t min_minibatches = std::min<size_t>(8, serialized.size());
|
||||
const size_t max_minibatches = 64;
|
||||
return std::max<size_t>(min_minibatches,
|
||||
std::min<size_t>(max_minibatches, result));
|
||||
}();
|
||||
|
||||
size_t num_minibatches = first_example_of_minibatch.size();
|
||||
auto first_example_of_minibatch = [&](size_t minibatch) -> size_t {
|
||||
return (serialized.size() * minibatch) / num_minibatches;
|
||||
};
|
||||
|
||||
// TODO(lew): A big performance low-hanging fruit here is to improve
|
||||
// num_minibatches calculation to take into account actual amount of work
|
||||
// needed, as the size in bytes is not perfect. Linear combination of
|
||||
// size in bytes and average number of features per example is promising.
|
||||
// Even better: measure time instead of estimating, but this is too costly
|
||||
// in small batches.
|
||||
// Maybe accept outside parameter #num_minibatches?
|
||||
|
||||
// Do minibatches in parallel.
|
||||
std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
|
||||
std::vector<Status> status_of_minibatch(num_minibatches);
|
||||
|
||||
auto ProcessMiniBatch = [&](size_t minibatch) {
|
||||
sparse_buffers[minibatch].resize(config.sparse.size());
|
||||
size_t start = first_example_of_minibatch[minibatch];
|
||||
size_t end = minibatch + 1 < num_minibatches
|
||||
? first_example_of_minibatch[minibatch + 1]
|
||||
: serialized.size();
|
||||
size_t start = first_example_of_minibatch(minibatch);
|
||||
size_t end = first_example_of_minibatch(minibatch + 1);
|
||||
for (size_t e = start; e < end; ++e) {
|
||||
status_of_minibatch[minibatch] = FastParseSerializedExample(
|
||||
serialized[e],
|
||||
@ -711,7 +732,7 @@ Status FastParseExample(const Config& config,
|
||||
// Update indices.
|
||||
int64* ix_p = &indices->matrix<int64>()(offset, 0);
|
||||
size_t delta = 0;
|
||||
size_t example_index = first_example_of_minibatch[i];
|
||||
size_t example_index = first_example_of_minibatch(i);
|
||||
for (size_t example_end_index : buffer.example_end_indices) {
|
||||
size_t feature_index = 0;
|
||||
for (; delta < example_end_index; ++delta) {
|
||||
|
@ -178,6 +178,15 @@ string MakeSerializedExample() {
|
||||
return serialized;
|
||||
}
|
||||
|
||||
// Checks that FastParseExample returns an OK status when it is called with
// no serialized examples, no example names, and a null thread pool.
TEST(TestFastParseExample, Empty) {
|
||||
Result result;
|
||||
FastParseExampleConfig config;
|
||||
// Configure a single sparse feature named "test" with dtype DT_STRING.
config.sparse.push_back({"test", DT_STRING});
|
||||
// Both slices are default-constructed (empty); nullptr is passed in the
// thread_pool position.
Status status = FastParseExample(config, gtl::ArraySlice<string>(),
|
||||
gtl::ArraySlice<string>(), nullptr, &result);
|
||||
// On failure, the status is streamed into the test output.
EXPECT_TRUE(status.ok()) << status;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace example
|
||||
|
Loading…
Reference in New Issue
Block a user