Further improve performance of ParseExample, fixing a regression on specific benchmarks.
- Ensure that each thread receives an equal piece.
- At least 8 threads (given 8 or more examples) and at most 64.
- Use InlinedVector.
Change: 132488652
This commit is contained in:
parent
d3b34b7e56
commit
7705791619
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/casts.h"
|
#include "tensorflow/core/lib/core/casts.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
#include "tensorflow/core/util/presized_cuckoo_map.h"
|
#include "tensorflow/core/util/presized_cuckoo_map.h"
|
||||||
@ -34,6 +35,10 @@ namespace tensorflow {
|
|||||||
namespace example {
|
namespace example {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using SmallVector = gtl::InlinedVector<T, 4>;
|
||||||
|
|
||||||
template <typename A>
|
template <typename A>
|
||||||
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
|
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
|
||||||
a->EnableAliasing(true);
|
a->EnableAliasing(true);
|
||||||
@ -86,7 +91,7 @@ class Feature {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ParseBytesList(std::vector<string>* bytes_list) {
|
bool ParseBytesList(SmallVector<string>* bytes_list) {
|
||||||
DCHECK(bytes_list != nullptr);
|
DCHECK(bytes_list != nullptr);
|
||||||
protobuf::io::CodedInputStream stream(
|
protobuf::io::CodedInputStream stream(
|
||||||
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
|
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
|
||||||
@ -110,7 +115,7 @@ class Feature {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ParseFloatList(std::vector<float>* float_list) {
|
bool ParseFloatList(SmallVector<float>* float_list) {
|
||||||
DCHECK(float_list != nullptr);
|
DCHECK(float_list != nullptr);
|
||||||
protobuf::io::CodedInputStream stream(
|
protobuf::io::CodedInputStream stream(
|
||||||
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
|
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
|
||||||
@ -152,7 +157,7 @@ class Feature {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ParseInt64List(std::vector<int64>* int64_list) {
|
bool ParseInt64List(SmallVector<int64>* int64_list) {
|
||||||
DCHECK(int64_list != nullptr);
|
DCHECK(int64_list != nullptr);
|
||||||
protobuf::io::CodedInputStream stream(
|
protobuf::io::CodedInputStream stream(
|
||||||
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
|
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
|
||||||
@ -293,7 +298,7 @@ bool TestFastParse(const string& serialized, Example* example) {
|
|||||||
case DT_INVALID:
|
case DT_INVALID:
|
||||||
break;
|
break;
|
||||||
case DT_STRING: {
|
case DT_STRING: {
|
||||||
std::vector<string> list;
|
SmallVector<string> list;
|
||||||
if (!entry.second.ParseBytesList(&list)) return false;
|
if (!entry.second.ParseBytesList(&list)) return false;
|
||||||
auto* result_list = value.mutable_bytes_list();
|
auto* result_list = value.mutable_bytes_list();
|
||||||
for (auto& bytes : list) {
|
for (auto& bytes : list) {
|
||||||
@ -302,7 +307,7 @@ bool TestFastParse(const string& serialized, Example* example) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case DT_FLOAT: {
|
case DT_FLOAT: {
|
||||||
std::vector<float> list;
|
SmallVector<float> list;
|
||||||
if (!entry.second.ParseFloatList(&list)) return false;
|
if (!entry.second.ParseFloatList(&list)) return false;
|
||||||
auto* result_list = value.mutable_float_list();
|
auto* result_list = value.mutable_float_list();
|
||||||
for (float f : list) {
|
for (float f : list) {
|
||||||
@ -311,7 +316,7 @@ bool TestFastParse(const string& serialized, Example* example) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case DT_INT64: {
|
case DT_INT64: {
|
||||||
std::vector<int64> list;
|
SmallVector<int64> list;
|
||||||
if (!entry.second.ParseInt64List(&list)) return false;
|
if (!entry.second.ParseInt64List(&list)) return false;
|
||||||
auto* result_list = value.mutable_int64_list();
|
auto* result_list = value.mutable_int64_list();
|
||||||
for (int64 i : list) {
|
for (int64 i : list) {
|
||||||
@ -334,28 +339,32 @@ using Config = FastParseExampleConfig;
|
|||||||
|
|
||||||
void ParallelFor(const std::function<void(size_t)>& f, size_t n,
|
void ParallelFor(const std::function<void(size_t)>& f, size_t n,
|
||||||
thread::ThreadPool* thread_pool) {
|
thread::ThreadPool* thread_pool) {
|
||||||
DCHECK(thread_pool != nullptr);
|
|
||||||
if (n == 0) return;
|
if (n == 0) return;
|
||||||
BlockingCounter counter(n - 1);
|
if (thread_pool == nullptr) {
|
||||||
for (size_t i = 1; i < n; ++i) {
|
for (size_t i = 0; i < n; ++i) {
|
||||||
thread_pool->Schedule([i, &f, &counter] {
|
|
||||||
f(i);
|
f(i);
|
||||||
counter.DecrementCount();
|
}
|
||||||
});
|
} else {
|
||||||
|
BlockingCounter counter(n - 1);
|
||||||
|
for (size_t i = 1; i < n; ++i) {
|
||||||
|
thread_pool->Schedule([i, &f, &counter] {
|
||||||
|
f(i);
|
||||||
|
counter.DecrementCount();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
f(0);
|
||||||
|
counter.Wait();
|
||||||
}
|
}
|
||||||
f(0);
|
|
||||||
counter.Wait();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
enum class Type { Sparse, Dense };
|
enum class Type { Sparse, Dense };
|
||||||
|
|
||||||
struct SparseBuffer {
|
struct SparseBuffer {
|
||||||
// TODO(lew): Use InlinedVector.
|
|
||||||
// Features are in one of the 3 vectors below depending on config's dtype.
|
// Features are in one of the 3 vectors below depending on config's dtype.
|
||||||
// Other 2 vectors remain empty.
|
// Other 2 vectors remain empty.
|
||||||
std::vector<string> bytes_list;
|
SmallVector<string> bytes_list;
|
||||||
std::vector<float> float_list;
|
SmallVector<float> float_list;
|
||||||
std::vector<int64> int64_list;
|
SmallVector<int64> int64_list;
|
||||||
|
|
||||||
// Features of example i are elements with indices
|
// Features of example i are elements with indices
|
||||||
// from example_end_indices[i-1] to example_end_indices[i]-1 on the
|
// from example_end_indices[i-1] to example_end_indices[i]-1 on the
|
||||||
@ -432,7 +441,7 @@ Status FastParseSerializedExample(
|
|||||||
|
|
||||||
switch (config.dense[d].dtype) {
|
switch (config.dense[d].dtype) {
|
||||||
case DT_INT64: {
|
case DT_INT64: {
|
||||||
std::vector<int64> list;
|
SmallVector<int64> list;
|
||||||
if (!feature.ParseInt64List(&list)) return parse_error(feature_name);
|
if (!feature.ParseInt64List(&list)) return parse_error(feature_name);
|
||||||
if (list.size() != num_elements) {
|
if (list.size() != num_elements) {
|
||||||
return shape_error(list.size(), "int64");
|
return shape_error(list.size(), "int64");
|
||||||
@ -442,7 +451,7 @@ Status FastParseSerializedExample(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case DT_FLOAT: {
|
case DT_FLOAT: {
|
||||||
std::vector<float> list;
|
SmallVector<float> list;
|
||||||
if (!feature.ParseFloatList(&list)) return parse_error(feature_name);
|
if (!feature.ParseFloatList(&list)) return parse_error(feature_name);
|
||||||
if (list.size() != num_elements) {
|
if (list.size() != num_elements) {
|
||||||
return shape_error(list.size(), "float");
|
return shape_error(list.size(), "float");
|
||||||
@ -452,7 +461,7 @@ Status FastParseSerializedExample(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case DT_STRING: {
|
case DT_STRING: {
|
||||||
std::vector<string> list;
|
SmallVector<string> list;
|
||||||
if (!feature.ParseBytesList(&list)) return parse_error(feature_name);
|
if (!feature.ParseBytesList(&list)) return parse_error(feature_name);
|
||||||
if (list.size() != num_elements) {
|
if (list.size() != num_elements) {
|
||||||
return shape_error(list.size(), "bytes");
|
return shape_error(list.size(), "bytes");
|
||||||
@ -580,7 +589,6 @@ Status FastParseExample(const Config& config,
|
|||||||
gtl::ArraySlice<string> serialized,
|
gtl::ArraySlice<string> serialized,
|
||||||
gtl::ArraySlice<string> example_names,
|
gtl::ArraySlice<string> example_names,
|
||||||
thread::ThreadPool* thread_pool, Result* result) {
|
thread::ThreadPool* thread_pool, Result* result) {
|
||||||
DCHECK(thread_pool != nullptr);
|
|
||||||
DCHECK(result != nullptr);
|
DCHECK(result != nullptr);
|
||||||
// Check config so we can safely CHECK(false) in switches on config.*.dtype
|
// Check config so we can safely CHECK(false) in switches on config.*.dtype
|
||||||
for (auto& c : config.sparse) {
|
for (auto& c : config.sparse) {
|
||||||
@ -626,36 +634,49 @@ Status FastParseExample(const Config& config,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// This parameter affects performance in a big and data-dependent way.
|
// This parameter affects performance in a big and data-dependent way.
|
||||||
const size_t kMiniBatchSizeBytes = 100000;
|
const size_t kMiniBatchSizeBytes = 50000;
|
||||||
|
|
||||||
// Split examples into mini-batches for parallel processing.
|
// Calculate number of minibatches.
|
||||||
auto first_example_of_minibatch = [&] {
|
// In main regime make each minibatch around kMiniBatchSizeBytes bytes.
|
||||||
std::vector<size_t> result;
|
// Apply 'special logic' below for small and big regimes.
|
||||||
|
const size_t num_minibatches = [&] {
|
||||||
|
size_t result = 0;
|
||||||
size_t minibatch_bytes = 0;
|
size_t minibatch_bytes = 0;
|
||||||
for (size_t i = 0; i < serialized.size(); i++) {
|
for (size_t i = 0; i < serialized.size(); i++) {
|
||||||
if (minibatch_bytes == 0) { // start minibatch
|
if (minibatch_bytes == 0) { // start minibatch
|
||||||
result.push_back(i);
|
result++;
|
||||||
}
|
}
|
||||||
minibatch_bytes += serialized[i].size() + 1;
|
minibatch_bytes += serialized[i].size() + 1;
|
||||||
if (minibatch_bytes > kMiniBatchSizeBytes) {
|
if (minibatch_bytes > kMiniBatchSizeBytes) {
|
||||||
minibatch_bytes = 0;
|
minibatch_bytes = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result;
|
// 'special logic'
|
||||||
|
const size_t min_minibatches = std::min<size_t>(8, serialized.size());
|
||||||
|
const size_t max_minibatches = 64;
|
||||||
|
return std::max<size_t>(min_minibatches,
|
||||||
|
std::min<size_t>(max_minibatches, result));
|
||||||
}();
|
}();
|
||||||
|
|
||||||
size_t num_minibatches = first_example_of_minibatch.size();
|
auto first_example_of_minibatch = [&](size_t minibatch) -> size_t {
|
||||||
|
return (serialized.size() * minibatch) / num_minibatches;
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO(lew): A big performance low-hanging fruit here is to improve
|
||||||
|
// num_minibatches calculation to take into account actual amount of work
|
||||||
|
// needed, as the size in bytes is not perfect. Linear combination of
|
||||||
|
// size in bytes and average number of features per example is promising.
|
||||||
|
// Even better: measure time instead of estimating, but this is too costly
|
||||||
|
// in small batches.
|
||||||
|
// Maybe accept outside parameter #num_minibatches?
|
||||||
|
|
||||||
// Do minibatches in parallel.
|
// Do minibatches in parallel.
|
||||||
std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
|
std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
|
||||||
std::vector<Status> status_of_minibatch(num_minibatches);
|
std::vector<Status> status_of_minibatch(num_minibatches);
|
||||||
|
|
||||||
auto ProcessMiniBatch = [&](size_t minibatch) {
|
auto ProcessMiniBatch = [&](size_t minibatch) {
|
||||||
sparse_buffers[minibatch].resize(config.sparse.size());
|
sparse_buffers[minibatch].resize(config.sparse.size());
|
||||||
size_t start = first_example_of_minibatch[minibatch];
|
size_t start = first_example_of_minibatch(minibatch);
|
||||||
size_t end = minibatch + 1 < num_minibatches
|
size_t end = first_example_of_minibatch(minibatch + 1);
|
||||||
? first_example_of_minibatch[minibatch + 1]
|
|
||||||
: serialized.size();
|
|
||||||
for (size_t e = start; e < end; ++e) {
|
for (size_t e = start; e < end; ++e) {
|
||||||
status_of_minibatch[minibatch] = FastParseSerializedExample(
|
status_of_minibatch[minibatch] = FastParseSerializedExample(
|
||||||
serialized[e],
|
serialized[e],
|
||||||
@ -711,7 +732,7 @@ Status FastParseExample(const Config& config,
|
|||||||
// Update indices.
|
// Update indices.
|
||||||
int64* ix_p = &indices->matrix<int64>()(offset, 0);
|
int64* ix_p = &indices->matrix<int64>()(offset, 0);
|
||||||
size_t delta = 0;
|
size_t delta = 0;
|
||||||
size_t example_index = first_example_of_minibatch[i];
|
size_t example_index = first_example_of_minibatch(i);
|
||||||
for (size_t example_end_index : buffer.example_end_indices) {
|
for (size_t example_end_index : buffer.example_end_indices) {
|
||||||
size_t feature_index = 0;
|
size_t feature_index = 0;
|
||||||
for (; delta < example_end_index; ++delta) {
|
for (; delta < example_end_index; ++delta) {
|
||||||
|
@ -178,6 +178,15 @@ string MakeSerializedExample() {
|
|||||||
return serialized;
|
return serialized;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TestFastParseExample, Empty) {
|
||||||
|
Result result;
|
||||||
|
FastParseExampleConfig config;
|
||||||
|
config.sparse.push_back({"test", DT_STRING});
|
||||||
|
Status status = FastParseExample(config, gtl::ArraySlice<string>(),
|
||||||
|
gtl::ArraySlice<string>(), nullptr, &result);
|
||||||
|
EXPECT_TRUE(status.ok()) << status;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
} // namespace example
|
} // namespace example
|
||||||
|
Loading…
Reference in New Issue
Block a user