Add TraceMeProducer/Consumer for SharedBatchScheduler.
PiperOrigin-RevId: 316976310 Change-Id: I1cd2a03390aedd7e8e85b2826ab3aadd096bafdd
This commit is contained in:
parent
7a92859246
commit
b186ba0334
tensorflow/core
@ -70,6 +70,7 @@ cc_library(
|
||||
":batch_scheduler_hdrs",
|
||||
":periodic_function_dynamic",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core/profiler/lib:connected_traceme",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
],
|
||||
)
|
||||
@ -81,6 +82,7 @@ cc_library(
|
||||
":batch_scheduler",
|
||||
":periodic_function_dynamic",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/profiler/lib:connected_traceme",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -77,7 +77,8 @@ class BatchTask {
|
||||
template <typename TaskType>
|
||||
class Batch {
|
||||
public:
|
||||
Batch() = default;
|
||||
Batch();
|
||||
explicit Batch(uint64 traceme_context_id);
|
||||
virtual ~Batch(); // Blocks until the batch is closed.
|
||||
|
||||
// Appends 'task' to the batch. After calling AddTask(), the newly-added task
|
||||
@ -113,6 +114,9 @@ class Batch {
|
||||
// Marks the batch as closed. Dies if called more than once.
|
||||
void Close();
|
||||
|
||||
// Returns the TraceMe context id of this batch.
|
||||
uint64 traceme_context_id() const;
|
||||
|
||||
private:
|
||||
mutable mutex mu_;
|
||||
|
||||
@ -125,6 +129,9 @@ class Batch {
|
||||
// Whether the batch has been closed.
|
||||
Notification closed_;
|
||||
|
||||
// The TracMe context id.
|
||||
const uint64 traceme_context_id_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Batch);
|
||||
};
|
||||
|
||||
@ -187,6 +194,13 @@ class BatchScheduler {
|
||||
//////////
|
||||
// Implementation details follow. API users need not read.
|
||||
|
||||
template <typename TaskType>
|
||||
Batch<TaskType>::Batch() : Batch(0) {}
|
||||
|
||||
template <typename TaskType>
|
||||
Batch<TaskType>::Batch(uint64 traceme_context_id)
|
||||
: traceme_context_id_(traceme_context_id) {}
|
||||
|
||||
template <typename TaskType>
|
||||
Batch<TaskType>::~Batch() {
|
||||
WaitUntilClosed();
|
||||
@ -275,6 +289,11 @@ void Batch<TaskType>::Close() {
|
||||
closed_.Notify();
|
||||
}
|
||||
|
||||
template <typename TaskType>
|
||||
uint64 Batch<TaskType>::traceme_context_id() const {
|
||||
return traceme_context_id_;
|
||||
}
|
||||
|
||||
} // namespace serving
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/lib/connected_traceme.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -311,6 +312,9 @@ class Queue {
|
||||
// The enqueued batches. See the invariants in the class comments above.
|
||||
std::deque<std::unique_ptr<Batch<TaskType>>> batches_ TF_GUARDED_BY(mu_);
|
||||
|
||||
// The counter of the TraceMe context ids.
|
||||
uint64 traceme_context_id_counter_ TF_GUARDED_BY(mu_) = 0;
|
||||
|
||||
// The time at which the first task was added to the open (back-most) batch
|
||||
// in 'batches_'. Valid iff that batch contains at least one task.
|
||||
uint64 open_batch_start_time_micros_ TF_GUARDED_BY(mu_);
|
||||
@ -529,8 +533,6 @@ Queue<TaskType>::~Queue() {
|
||||
|
||||
template <typename TaskType>
|
||||
Status Queue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
|
||||
profiler::TraceMe trace_me(
|
||||
[task] { return strings::StrCat("Schedule:", (*task)->size()); });
|
||||
if ((*task)->size() > options_.max_batch_size) {
|
||||
return errors::InvalidArgument("Task size ", (*task)->size(),
|
||||
" is larger than maximum batch size ",
|
||||
@ -554,6 +556,10 @@ Status Queue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
|
||||
if (batches_.back()->empty()) {
|
||||
open_batch_start_time_micros_ = env_->NowMicros();
|
||||
}
|
||||
profiler::TraceMeProducer trace_me(
|
||||
[&] { return strings::StrCat("Schedule:", (*task)->size()); },
|
||||
profiler::ContextType::kSharedBatchScheduler,
|
||||
batches_.back()->traceme_context_id());
|
||||
batches_.back()->AddTask(std::move(*task));
|
||||
|
||||
if (!schedulable_batch_) {
|
||||
@ -621,8 +627,10 @@ std::unique_ptr<Batch<TaskType>> Queue<TaskType>::ScheduleBatch() {
|
||||
|
||||
template <typename TaskType>
|
||||
void Queue<TaskType>::ProcessBatch(std::unique_ptr<Batch<TaskType>> batch) {
|
||||
profiler::TraceMe trace_me(
|
||||
[&batch] { return strings::StrCat("ProcessBatch:", batch->size()); });
|
||||
profiler::TraceMeConsumer trace_me(
|
||||
[&batch] { return strings::StrCat("ProcessBatch:", batch->size()); },
|
||||
profiler::ContextType::kSharedBatchScheduler,
|
||||
batch->traceme_context_id());
|
||||
process_batch_callback_(std::move(batch));
|
||||
|
||||
{
|
||||
@ -665,7 +673,7 @@ bool Queue<TaskType>::IsEmptyInternal() const {
|
||||
template <typename TaskType>
|
||||
void Queue<TaskType>::StartNewBatch() {
|
||||
batches_.back()->Close();
|
||||
batches_.emplace_back(new Batch<TaskType>);
|
||||
batches_.emplace_back(new Batch<TaskType>(++traceme_context_id_counter_));
|
||||
}
|
||||
|
||||
template <typename TaskType>
|
||||
|
@ -28,6 +28,7 @@ namespace profiler {
|
||||
enum class ContextType : int {
|
||||
kGeneric,
|
||||
kTfExecutor,
|
||||
kSharedBatchScheduler,
|
||||
};
|
||||
|
||||
/*
|
||||
|
Loading…
Reference in New Issue
Block a user