Further improve performance of ParseExample fixing regression on specific benchmarks.

- Ensure that each thread receives an equal piece of the work.
- Use at least 8 minibatches (given 8 or more examples) and at most 64.
- Use InlinedVector for short feature lists to avoid heap allocations.
Change: 132488652
This commit is contained in:
A. Unique TensorFlower 2016-09-07 14:37:07 -08:00 committed by TensorFlower Gardener
parent d3b34b7e56
commit 7705791619
2 changed files with 65 additions and 35 deletions

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/presized_cuckoo_map.h"
@ -34,6 +35,10 @@ namespace tensorflow {
namespace example {
namespace {
// Vector with inline storage for up to 4 elements: short feature lists
// avoid a heap allocation entirely (falls back to the heap beyond 4).
template <typename T>
using SmallVector = gtl::InlinedVector<T, 4>;
template <typename A>
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
a->EnableAliasing(true);
@ -86,7 +91,7 @@ class Feature {
return Status::OK();
}
bool ParseBytesList(std::vector<string>* bytes_list) {
bool ParseBytesList(SmallVector<string>* bytes_list) {
DCHECK(bytes_list != nullptr);
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
@ -110,7 +115,7 @@ class Feature {
return true;
}
bool ParseFloatList(std::vector<float>* float_list) {
bool ParseFloatList(SmallVector<float>* float_list) {
DCHECK(float_list != nullptr);
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
@ -152,7 +157,7 @@ class Feature {
return true;
}
bool ParseInt64List(std::vector<int64>* int64_list) {
bool ParseInt64List(SmallVector<int64>* int64_list) {
DCHECK(int64_list != nullptr);
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
@ -293,7 +298,7 @@ bool TestFastParse(const string& serialized, Example* example) {
case DT_INVALID:
break;
case DT_STRING: {
std::vector<string> list;
SmallVector<string> list;
if (!entry.second.ParseBytesList(&list)) return false;
auto* result_list = value.mutable_bytes_list();
for (auto& bytes : list) {
@ -302,7 +307,7 @@ bool TestFastParse(const string& serialized, Example* example) {
break;
}
case DT_FLOAT: {
std::vector<float> list;
SmallVector<float> list;
if (!entry.second.ParseFloatList(&list)) return false;
auto* result_list = value.mutable_float_list();
for (float f : list) {
@ -311,7 +316,7 @@ bool TestFastParse(const string& serialized, Example* example) {
break;
}
case DT_INT64: {
std::vector<int64> list;
SmallVector<int64> list;
if (!entry.second.ParseInt64List(&list)) return false;
auto* result_list = value.mutable_int64_list();
for (int64 i : list) {
@ -334,28 +339,32 @@ using Config = FastParseExampleConfig;
// Invokes f(0), f(1), ..., f(n-1), returning only after every call has
// completed.
//
// If thread_pool is null, all calls run serially on the calling thread;
// otherwise f(1)..f(n-1) are scheduled on the pool and f(0) runs on the
// calling thread, so the caller contributes work instead of only blocking.
//
// NOTE(review): the diff rendering had interleaved the removed
// unconditional-pool implementation with the new branching one; this is
// the reconstructed post-change version.
void ParallelFor(const std::function<void(size_t)>& f, size_t n,
                 thread::ThreadPool* thread_pool) {
  if (n == 0) return;
  if (thread_pool == nullptr) {
    // Serial fallback: callers may legitimately pass a null pool
    // (e.g. the empty-input unit test below).
    for (size_t i = 0; i < n; ++i) {
      f(i);
    }
  } else {
    // Only n - 1 tasks go to the pool; the counter tracks exactly those.
    BlockingCounter counter(n - 1);
    for (size_t i = 1; i < n; ++i) {
      thread_pool->Schedule([i, &f, &counter] {
        f(i);
        counter.DecrementCount();
      });
    }
    f(0);
    counter.Wait();
  }
}
enum class Type { Sparse, Dense };
struct SparseBuffer {
// TODO(lew): Use InlinedVector.
// Features are in one of the 3 vectors below depending on config's dtype.
// Other 2 vectors remain empty.
std::vector<string> bytes_list;
std::vector<float> float_list;
std::vector<int64> int64_list;
SmallVector<string> bytes_list;
SmallVector<float> float_list;
SmallVector<int64> int64_list;
// Features of example i are elements with indices
// from example_end_indices[i-1] to example_end_indices[i]-1 on the
@ -432,7 +441,7 @@ Status FastParseSerializedExample(
switch (config.dense[d].dtype) {
case DT_INT64: {
std::vector<int64> list;
SmallVector<int64> list;
if (!feature.ParseInt64List(&list)) return parse_error(feature_name);
if (list.size() != num_elements) {
return shape_error(list.size(), "int64");
@ -442,7 +451,7 @@ Status FastParseSerializedExample(
break;
}
case DT_FLOAT: {
std::vector<float> list;
SmallVector<float> list;
if (!feature.ParseFloatList(&list)) return parse_error(feature_name);
if (list.size() != num_elements) {
return shape_error(list.size(), "float");
@ -452,7 +461,7 @@ Status FastParseSerializedExample(
break;
}
case DT_STRING: {
std::vector<string> list;
SmallVector<string> list;
if (!feature.ParseBytesList(&list)) return parse_error(feature_name);
if (list.size() != num_elements) {
return shape_error(list.size(), "bytes");
@ -580,7 +589,6 @@ Status FastParseExample(const Config& config,
gtl::ArraySlice<string> serialized,
gtl::ArraySlice<string> example_names,
thread::ThreadPool* thread_pool, Result* result) {
DCHECK(thread_pool != nullptr);
DCHECK(result != nullptr);
// Check config so we can safely CHECK(false) in switches on config.*.dtype
for (auto& c : config.sparse) {
@ -626,36 +634,49 @@ Status FastParseExample(const Config& config,
}
// This parameter affects performance in a big and data-dependent way.
const size_t kMiniBatchSizeBytes = 100000;
const size_t kMiniBatchSizeBytes = 50000;
// Split examples into mini-batches for parallel processing.
auto first_example_of_minibatch = [&] {
std::vector<size_t> result;
// Calculate number of minibatches.
// In main regime make each minibatch around kMiniBatchSizeBytes bytes.
// Apply 'special logic' below for small and big regimes.
const size_t num_minibatches = [&] {
size_t result = 0;
size_t minibatch_bytes = 0;
for (size_t i = 0; i < serialized.size(); i++) {
if (minibatch_bytes == 0) { // start minibatch
result.push_back(i);
result++;
}
minibatch_bytes += serialized[i].size() + 1;
if (minibatch_bytes > kMiniBatchSizeBytes) {
minibatch_bytes = 0;
}
}
return result;
// 'special logic'
const size_t min_minibatches = std::min<size_t>(8, serialized.size());
const size_t max_minibatches = 64;
return std::max<size_t>(min_minibatches,
std::min<size_t>(max_minibatches, result));
}();
size_t num_minibatches = first_example_of_minibatch.size();
auto first_example_of_minibatch = [&](size_t minibatch) -> size_t {
return (serialized.size() * minibatch) / num_minibatches;
};
// TODO(lew): A big performance low-hanging fruit here is to improve
// num_minibatches calculation to take into account actual amount of work
// needed, as the size in bytes is not perfect. Linear combination of
// size in bytes and average number of features per example is promising.
// Even better: measure time instead of estimating, but this is too costly
// in small batches.
// Maybe accept outside parameter #num_minibatches?
// Do minibatches in parallel.
std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
std::vector<Status> status_of_minibatch(num_minibatches);
auto ProcessMiniBatch = [&](size_t minibatch) {
sparse_buffers[minibatch].resize(config.sparse.size());
size_t start = first_example_of_minibatch[minibatch];
size_t end = minibatch + 1 < num_minibatches
? first_example_of_minibatch[minibatch + 1]
: serialized.size();
size_t start = first_example_of_minibatch(minibatch);
size_t end = first_example_of_minibatch(minibatch + 1);
for (size_t e = start; e < end; ++e) {
status_of_minibatch[minibatch] = FastParseSerializedExample(
serialized[e],
@ -711,7 +732,7 @@ Status FastParseExample(const Config& config,
// Update indices.
int64* ix_p = &indices->matrix<int64>()(offset, 0);
size_t delta = 0;
size_t example_index = first_example_of_minibatch[i];
size_t example_index = first_example_of_minibatch(i);
for (size_t example_end_index : buffer.example_end_indices) {
size_t feature_index = 0;
for (; delta < example_end_index; ++delta) {

View File

@ -178,6 +178,15 @@ string MakeSerializedExample() {
return serialized;
}
// An empty batch of serialized Examples (with a null thread pool) must
// parse successfully and leave the result untouched.
TEST(TestFastParseExample, Empty) {
  FastParseExampleConfig config;
  config.sparse.push_back({"test", DT_STRING});

  Result result;
  const Status status =
      FastParseExample(config, gtl::ArraySlice<string>(),
                       gtl::ArraySlice<string>(), /*thread_pool=*/nullptr,
                       &result);
  EXPECT_TRUE(status.ok()) << status;
}
} // namespace
} // namespace example