Add TraceMeProducer/Consumer for SharedBatchScheduler.
PiperOrigin-RevId: 316976310 Change-Id: I1cd2a03390aedd7e8e85b2826ab3aadd096bafdd
This commit is contained in:
parent
7a92859246
commit
b186ba0334
@ -70,6 +70,7 @@ cc_library(
|
|||||||
":batch_scheduler_hdrs",
|
":batch_scheduler_hdrs",
|
||||||
":periodic_function_dynamic",
|
":periodic_function_dynamic",
|
||||||
"//tensorflow/core:framework_headers_lib",
|
"//tensorflow/core:framework_headers_lib",
|
||||||
|
"//tensorflow/core/profiler/lib:connected_traceme",
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -81,6 +82,7 @@ cc_library(
|
|||||||
":batch_scheduler",
|
":batch_scheduler",
|
||||||
":periodic_function_dynamic",
|
":periodic_function_dynamic",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/profiler/lib:connected_traceme",
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
|
@ -77,7 +77,8 @@ class BatchTask {
|
|||||||
template <typename TaskType>
|
template <typename TaskType>
|
||||||
class Batch {
|
class Batch {
|
||||||
public:
|
public:
|
||||||
Batch() = default;
|
Batch();
|
||||||
|
explicit Batch(uint64 traceme_context_id);
|
||||||
virtual ~Batch(); // Blocks until the batch is closed.
|
virtual ~Batch(); // Blocks until the batch is closed.
|
||||||
|
|
||||||
// Appends 'task' to the batch. After calling AddTask(), the newly-added task
|
// Appends 'task' to the batch. After calling AddTask(), the newly-added task
|
||||||
@ -113,6 +114,9 @@ class Batch {
|
|||||||
// Marks the batch as closed. Dies if called more than once.
|
// Marks the batch as closed. Dies if called more than once.
|
||||||
void Close();
|
void Close();
|
||||||
|
|
||||||
|
// Returns the TraceMe context id of this batch.
|
||||||
|
uint64 traceme_context_id() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
mutable mutex mu_;
|
mutable mutex mu_;
|
||||||
|
|
||||||
@ -125,6 +129,9 @@ class Batch {
|
|||||||
// Whether the batch has been closed.
|
// Whether the batch has been closed.
|
||||||
Notification closed_;
|
Notification closed_;
|
||||||
|
|
||||||
|
// The TracMe context id.
|
||||||
|
const uint64 traceme_context_id_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(Batch);
|
TF_DISALLOW_COPY_AND_ASSIGN(Batch);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -187,6 +194,13 @@ class BatchScheduler {
|
|||||||
//////////
|
//////////
|
||||||
// Implementation details follow. API users need not read.
|
// Implementation details follow. API users need not read.
|
||||||
|
|
||||||
|
template <typename TaskType>
|
||||||
|
Batch<TaskType>::Batch() : Batch(0) {}
|
||||||
|
|
||||||
|
template <typename TaskType>
|
||||||
|
Batch<TaskType>::Batch(uint64 traceme_context_id)
|
||||||
|
: traceme_context_id_(traceme_context_id) {}
|
||||||
|
|
||||||
template <typename TaskType>
|
template <typename TaskType>
|
||||||
Batch<TaskType>::~Batch() {
|
Batch<TaskType>::~Batch() {
|
||||||
WaitUntilClosed();
|
WaitUntilClosed();
|
||||||
@ -275,6 +289,11 @@ void Batch<TaskType>::Close() {
|
|||||||
closed_.Notify();
|
closed_.Notify();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename TaskType>
|
||||||
|
uint64 Batch<TaskType>::traceme_context_id() const {
|
||||||
|
return traceme_context_id_;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace serving
|
} // namespace serving
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/profiler/lib/connected_traceme.h"
|
||||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -311,6 +312,9 @@ class Queue {
|
|||||||
// The enqueued batches. See the invariants in the class comments above.
|
// The enqueued batches. See the invariants in the class comments above.
|
||||||
std::deque<std::unique_ptr<Batch<TaskType>>> batches_ TF_GUARDED_BY(mu_);
|
std::deque<std::unique_ptr<Batch<TaskType>>> batches_ TF_GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
// The counter of the TraceMe context ids.
|
||||||
|
uint64 traceme_context_id_counter_ TF_GUARDED_BY(mu_) = 0;
|
||||||
|
|
||||||
// The time at which the first task was added to the open (back-most) batch
|
// The time at which the first task was added to the open (back-most) batch
|
||||||
// in 'batches_'. Valid iff that batch contains at least one task.
|
// in 'batches_'. Valid iff that batch contains at least one task.
|
||||||
uint64 open_batch_start_time_micros_ TF_GUARDED_BY(mu_);
|
uint64 open_batch_start_time_micros_ TF_GUARDED_BY(mu_);
|
||||||
@ -529,8 +533,6 @@ Queue<TaskType>::~Queue() {
|
|||||||
|
|
||||||
template <typename TaskType>
|
template <typename TaskType>
|
||||||
Status Queue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
|
Status Queue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
|
||||||
profiler::TraceMe trace_me(
|
|
||||||
[task] { return strings::StrCat("Schedule:", (*task)->size()); });
|
|
||||||
if ((*task)->size() > options_.max_batch_size) {
|
if ((*task)->size() > options_.max_batch_size) {
|
||||||
return errors::InvalidArgument("Task size ", (*task)->size(),
|
return errors::InvalidArgument("Task size ", (*task)->size(),
|
||||||
" is larger than maximum batch size ",
|
" is larger than maximum batch size ",
|
||||||
@ -554,6 +556,10 @@ Status Queue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
|
|||||||
if (batches_.back()->empty()) {
|
if (batches_.back()->empty()) {
|
||||||
open_batch_start_time_micros_ = env_->NowMicros();
|
open_batch_start_time_micros_ = env_->NowMicros();
|
||||||
}
|
}
|
||||||
|
profiler::TraceMeProducer trace_me(
|
||||||
|
[&] { return strings::StrCat("Schedule:", (*task)->size()); },
|
||||||
|
profiler::ContextType::kSharedBatchScheduler,
|
||||||
|
batches_.back()->traceme_context_id());
|
||||||
batches_.back()->AddTask(std::move(*task));
|
batches_.back()->AddTask(std::move(*task));
|
||||||
|
|
||||||
if (!schedulable_batch_) {
|
if (!schedulable_batch_) {
|
||||||
@ -621,8 +627,10 @@ std::unique_ptr<Batch<TaskType>> Queue<TaskType>::ScheduleBatch() {
|
|||||||
|
|
||||||
template <typename TaskType>
|
template <typename TaskType>
|
||||||
void Queue<TaskType>::ProcessBatch(std::unique_ptr<Batch<TaskType>> batch) {
|
void Queue<TaskType>::ProcessBatch(std::unique_ptr<Batch<TaskType>> batch) {
|
||||||
profiler::TraceMe trace_me(
|
profiler::TraceMeConsumer trace_me(
|
||||||
[&batch] { return strings::StrCat("ProcessBatch:", batch->size()); });
|
[&batch] { return strings::StrCat("ProcessBatch:", batch->size()); },
|
||||||
|
profiler::ContextType::kSharedBatchScheduler,
|
||||||
|
batch->traceme_context_id());
|
||||||
process_batch_callback_(std::move(batch));
|
process_batch_callback_(std::move(batch));
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -665,7 +673,7 @@ bool Queue<TaskType>::IsEmptyInternal() const {
|
|||||||
template <typename TaskType>
|
template <typename TaskType>
|
||||||
void Queue<TaskType>::StartNewBatch() {
|
void Queue<TaskType>::StartNewBatch() {
|
||||||
batches_.back()->Close();
|
batches_.back()->Close();
|
||||||
batches_.emplace_back(new Batch<TaskType>);
|
batches_.emplace_back(new Batch<TaskType>(++traceme_context_id_counter_));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TaskType>
|
template <typename TaskType>
|
||||||
|
@ -28,6 +28,7 @@ namespace profiler {
|
|||||||
enum class ContextType : int {
|
enum class ContextType : int {
|
||||||
kGeneric,
|
kGeneric,
|
||||||
kTfExecutor,
|
kTfExecutor,
|
||||||
|
kSharedBatchScheduler,
|
||||||
};
|
};
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
Loading…
x
Reference in New Issue
Block a user