Add TraceMeProducer/Consumer for SharedBatchScheduler.

PiperOrigin-RevId: 316976310
Change-Id: I1cd2a03390aedd7e8e85b2826ab3aadd096bafdd
Author: Li Lao (committed by TensorFlower Gardener)
Date: 2020-06-17 15:21:23 -07:00
Commit: b186ba0334 (parent: 7a92859246)
4 changed files with 36 additions and 6 deletions
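
For orientation, here is a minimal, self-contained sketch (not part of this commit) of the pattern the change applies: the enqueue path records a TraceMeProducer event, the batch-processing path records a TraceMeConsumer event, and because both carry the same context id under the new kSharedBatchScheduler context type, the profiler can connect each per-task "Schedule" event to the "ProcessBatch" event that later runs it, even across threads. FakeBatch below is a hypothetical stand-in for Batch<TaskType>, which (per the diff below) now stores the id it was constructed with.

#include <cstdint>
#include <string>

#include "tensorflow/core/profiler/lib/connected_traceme.h"

// Hypothetical stand-in for Batch<TaskType>; it only carries the context id
// that the real batch now stores.
struct FakeBatch {
  uint64_t traceme_context_id;  // assigned once, when the batch is opened
  int size = 0;
};

// Producer side: one event per enqueued task, tagged with the open batch's id.
void Enqueue(FakeBatch* batch, int task_size) {
  tensorflow::profiler::TraceMeProducer trace_me(
      [&] { return "Schedule:" + std::to_string(task_size); },
      tensorflow::profiler::ContextType::kSharedBatchScheduler,
      batch->traceme_context_id);
  batch->size += task_size;  // stands in for Batch::AddTask()
}

// Consumer side: one event for the whole batch, tagged with the same id, so
// the profiler can connect it to the Schedule events above across threads.
void Process(FakeBatch* batch) {
  tensorflow::profiler::TraceMeConsumer trace_me(
      [&] { return "ProcessBatch:" + std::to_string(batch->size); },
      tensorflow::profiler::ContextType::kSharedBatchScheduler,
      batch->traceme_context_id);
  // ... run the batched computation here ...
}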
tensorflow/core

@@ -70,6 +70,7 @@ cc_library(
":batch_scheduler_hdrs",
":periodic_function_dynamic",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core/profiler/lib:connected_traceme",
"//tensorflow/core/profiler/lib:traceme",
],
)
@@ -81,6 +82,7 @@ cc_library(
":batch_scheduler",
":periodic_function_dynamic",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:connected_traceme",
"//tensorflow/core/profiler/lib:traceme",
],
alwayslink = 1,

@@ -77,7 +77,8 @@ class BatchTask {
template <typename TaskType>
class Batch {
public:
Batch() = default;
Batch();
explicit Batch(uint64 traceme_context_id);
virtual ~Batch(); // Blocks until the batch is closed.
// Appends 'task' to the batch. After calling AddTask(), the newly-added task
@@ -113,6 +114,9 @@ class Batch {
// Marks the batch as closed. Dies if called more than once.
void Close();
// Returns the TraceMe context id of this batch.
uint64 traceme_context_id() const;
private:
mutable mutex mu_;
@@ -125,6 +129,9 @@ class Batch {
// Whether the batch has been closed.
Notification closed_;
// The TraceMe context id.
const uint64 traceme_context_id_;
TF_DISALLOW_COPY_AND_ASSIGN(Batch);
};
@@ -187,6 +194,13 @@ class BatchScheduler {
//////////
// Implementation details follow. API users need not read.
template <typename TaskType>
Batch<TaskType>::Batch() : Batch(0) {}
template <typename TaskType>
Batch<TaskType>::Batch(uint64 traceme_context_id)
: traceme_context_id_(traceme_context_id) {}
template <typename TaskType>
Batch<TaskType>::~Batch() {
WaitUntilClosed();
@@ -275,6 +289,11 @@ void Batch<TaskType>::Close() {
closed_.Notify();
}
template <typename TaskType>
uint64 Batch<TaskType>::traceme_context_id() const {
return traceme_context_id_;
}
} // namespace serving
} // namespace tensorflow

@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace tensorflow {
@@ -311,6 +312,9 @@ class Queue {
// The enqueued batches. See the invariants in the class comments above.
std::deque<std::unique_ptr<Batch<TaskType>>> batches_ TF_GUARDED_BY(mu_);
// The counter of the TraceMe context ids.
uint64 traceme_context_id_counter_ TF_GUARDED_BY(mu_) = 0;
// The time at which the first task was added to the open (back-most) batch
// in 'batches_'. Valid iff that batch contains at least one task.
uint64 open_batch_start_time_micros_ TF_GUARDED_BY(mu_);
@@ -529,8 +533,6 @@ Queue<TaskType>::~Queue() {
template <typename TaskType>
Status Queue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
profiler::TraceMe trace_me(
[task] { return strings::StrCat("Schedule:", (*task)->size()); });
if ((*task)->size() > options_.max_batch_size) {
return errors::InvalidArgument("Task size ", (*task)->size(),
" is larger than maximum batch size ",
@@ -554,6 +556,10 @@ Status Queue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
if (batches_.back()->empty()) {
open_batch_start_time_micros_ = env_->NowMicros();
}
profiler::TraceMeProducer trace_me(
[&] { return strings::StrCat("Schedule:", (*task)->size()); },
profiler::ContextType::kSharedBatchScheduler,
batches_.back()->traceme_context_id());
batches_.back()->AddTask(std::move(*task));
if (!schedulable_batch_) {
@@ -621,8 +627,10 @@ std::unique_ptr<Batch<TaskType>> Queue<TaskType>::ScheduleBatch() {
template <typename TaskType>
void Queue<TaskType>::ProcessBatch(std::unique_ptr<Batch<TaskType>> batch) {
profiler::TraceMe trace_me(
[&batch] { return strings::StrCat("ProcessBatch:", batch->size()); });
profiler::TraceMeConsumer trace_me(
[&batch] { return strings::StrCat("ProcessBatch:", batch->size()); },
profiler::ContextType::kSharedBatchScheduler,
batch->traceme_context_id());
process_batch_callback_(std::move(batch));
{
@@ -665,7 +673,7 @@ bool Queue<TaskType>::IsEmptyInternal() const {
template <typename TaskType>
void Queue<TaskType>::StartNewBatch() {
batches_.back()->Close();
batches_.emplace_back(new Batch<TaskType>);
batches_.emplace_back(new Batch<TaskType>(++traceme_context_id_counter_));
}
template <typename TaskType>

@@ -28,6 +28,7 @@ namespace profiler {
enum class ContextType : int {
kGeneric,
kTfExecutor,
kSharedBatchScheduler,
};
/*
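
As an aside on the last hunk: the ContextType enumerator is what both the producer and the consumer pass, so context ids from different subsystems are matched only within their own domain. A hypothetical sketch of this extension point (kMyScheduler is made up for illustration):

// Each subsystem that wants its producer/consumer events connected adds its
// own enumerator; ids are only matched within the same context type.
enum class ContextType : int {
  kGeneric,
  kTfExecutor,
  kSharedBatchScheduler,  // added by this change
  // kMyScheduler,        // a future subsystem would append its value here
};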