Add annotations for producing and consuming batch in adaptive shared batch scheduler.

PiperOrigin-RevId: 341962301
Change-Id: I14f6536c8ca399cd8ff6bafa42d8b8d535a8bd7c
This commit is contained in:
Mingming Liu 2020-11-11 19:44:32 -08:00 committed by TensorFlower Gardener
parent b6a59d53d3
commit 41485207f3
3 changed files with 47 additions and 5 deletions
tensorflow/core

View File

@ -123,6 +123,7 @@ cc_library(
":batch_scheduler",
":periodic_function_dynamic",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:connected_traceme",
],
)

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#include <algorithm>
#include <atomic>
#include <functional>
#include <memory>
#include <random>
@ -34,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
namespace tensorflow {
namespace serving {
@ -277,6 +279,10 @@ class ASBSQueue : public BatchScheduler<TaskType> {
// Number of size 1 tasks which could currently be scheduled without failing.
size_t SchedulingCapacityLocked() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Returns uint64 one greater than was returned by the previous call.
// Context id is reused after std::numeric_limits<uint64>::max is exhausted.
static uint64 NewTraceMeContextIdForBatch();
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
const QueueOptions options_;
// Owned by scheduler_.
@ -292,10 +298,11 @@ template <typename TaskType>
class ASBSBatch : public Batch<TaskType> {
public:
ASBSBatch(ASBSQueue<TaskType>* queue, int64 creation_time_micros,
int64 batch_timeout_micros)
int64 batch_timeout_micros, uint64 traceme_context_id)
: queue_(queue),
creation_time_micros_(creation_time_micros),
schedulable_time_micros_(creation_time_micros + batch_timeout_micros) {}
schedulable_time_micros_(creation_time_micros + batch_timeout_micros),
traceme_context_id_(traceme_context_id) {}
~ASBSBatch() override {}
@ -305,10 +312,13 @@ class ASBSBatch : public Batch<TaskType> {
int64 schedulable_time_micros() const { return schedulable_time_micros_; }
uint64 traceme_context_id() const { return traceme_context_id_; }
private:
ASBSQueue<TaskType>* queue_;
const int64 creation_time_micros_;
const int64 schedulable_time_micros_;
const uint64 traceme_context_id_;
TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
};
} // namespace internal
@ -505,6 +515,13 @@ void AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper(
const internal::ASBSBatch<TaskType>* batch,
AdaptiveSharedBatchScheduler<TaskType>::BatchProcessor callback,
bool is_express) {
profiler::TraceMeConsumer trace_me(
[&] {
return profiler::TraceMeEncode(
"ProcessBatch", {{"batch_size_before_padding", batch->size()}});
},
profiler::ContextType::kAdaptiveSharedBatchScheduler,
batch->traceme_context_id());
int64 start_time = batch->creation_time_micros();
callback(std::unique_ptr<Batch<TaskType>>(
const_cast<internal::ASBSBatch<TaskType>*>(batch)));
@ -599,6 +616,7 @@ Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
if (size > SchedulingCapacityLocked()) {
return errors::Unavailable("The batch scheduling queue is full");
}
int remaining_batch_size =
current_batch_ == nullptr
? options_.max_batch_size
@ -626,11 +644,26 @@ Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
}
if (!current_batch_) {
num_enqueued_batches_++;
current_batch_ =
new ASBSBatch<TaskType>(this, scheduler_->GetEnv()->NowMicros(),
options_.batch_timeout_micros);
// batch.traceme_context_id connects TraceMeProducer and
// TraceMeConsumer.
// When multiple calls to "ASBS::Schedule" accumulate to one batch, they
// are processed in the same batch and should share traceme_context_id.
current_batch_ = new ASBSBatch<TaskType>(
this, scheduler_->GetEnv()->NowMicros(),
options_.batch_timeout_micros, NewTraceMeContextIdForBatch());
new_batches.push_back(current_batch_);
}
// Annotate each task (corresponds to one call of schedule) with a
// TraceMeProducer.
profiler::TraceMeProducer trace_me(
[task_size = task->size()] {
return profiler::TraceMeEncode(
"ASBSQueue::Schedule",
{{"batching_input_task_size", task_size}});
},
profiler::ContextType::kAdaptiveSharedBatchScheduler,
this->current_batch_->traceme_context_id());
current_batch_->AddTask(std::move(task));
num_enqueued_tasks_++;
// If current_batch_ is now full, allow it to be processed immediately.
@ -683,6 +716,13 @@ size_t ASBSQueue<TaskType>::SchedulingCapacityLocked() const {
options_.max_enqueued_batches - num_enqueued_batches_;
return spare_batches * options_.max_batch_size + current_batch_capacity;
}
template <typename TaskType>
// static
uint64 ASBSQueue<TaskType>::NewTraceMeContextIdForBatch() {
static std::atomic<uint64> traceme_context_id(0);
return traceme_context_id.fetch_add(1, std::memory_order_relaxed);
}
} // namespace internal
} // namespace serving
} // namespace tensorflow

View File

@ -30,6 +30,7 @@ enum class ContextType : int {
kTfExecutor,
kSharedBatchScheduler,
kPjRt,
kAdaptiveSharedBatchScheduler,
};
/*