From 60ac36f504ccef5aa50261ee2d496a8b3590e78f Mon Sep 17 00:00:00 2001
From: Hongmin Fan
Date: Fri, 30 Oct 2020 13:17:43 -0700
Subject: [PATCH] Fix a batch task creation bug in TFRT batch fallback kernel.

Without the fix, the TFRT batch fallback kernel crashes under high QPS load.
The bug makes it create only an object of the base class BatchTask, even when
splitting a task of the derived class FallbackBatchTask (used only in TFRT),
and put that object into a batch with other FallbackBatchTask objects. When
this batch of mixed task types is processed, it crashes.

PiperOrigin-RevId: 339927837
Change-Id: Ie52bd11c61c9ddbe6ab803cd90208419d4b2dba6
---
 .../batching_util/batch_resource_base.cc      | 35 +++++++++++--------
 .../batching_util/batch_resource_base.h       | 12 ++++++-
 2 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc
index d638760b833..81a16522c55 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc
@@ -83,6 +83,26 @@ const string& GetModelName(OpKernelContext* ctx) {
 
 }  // namespace
 
+std::unique_ptr<BatchResourceBase::BatchTask>
+BatchResourceBase::BatchTask::CreateSplitTask(
+    int split_index, AsyncOpKernel::DoneCallback done_callback) {
+  std::unique_ptr<BatchTask> task = CreateDerivedTask();
+
+  task->guid = this->guid;
+  task->propagated_context = Context(ContextKind::kThread);
+  task->inputs.reserve(this->inputs.size());
+  task->captured_inputs = this->captured_inputs;
+  task->context = this->context;
+  task->done_callback = done_callback;
+  task->split_index = split_index;
+  task->output = this->output;
+  task->status = this->status;
+  task->is_partial = true;
+  task->start_time = this->start_time;
+
+  return task;
+}
+
 using ::tensorflow::concat_split_util::Concat;
 using ::tensorflow::concat_split_util::Split;
 using TensorMatrix = std::vector<std::vector<Tensor>>;
@@ -317,20 +337,7 @@ Status BatchResourceBase::ConcatInputTensors(
 
   output_tasks->reserve(output_task_num);
   for (int i = 0; i < output_task_num; i++) {
-    auto task = absl::make_unique<BatchTask>();
-    task->guid = input_task.guid;
-    task->propagated_context = Context(ContextKind::kThread);
-    task->captured_inputs = input_task.captured_inputs;
-    task->context = input_task.context;
-    task->done_callback = barrier.Inc();
-    task->start_time = input_task.start_time;
-    task->split_index = i;
-    task->inputs.reserve(input_task.inputs.size());
-    task->is_partial = true;
-    task->status = input_task.status;
-
-    task->output = input_task.output;
-    output_tasks->push_back(std::move(task));
+    output_tasks->push_back(input_task.CreateSplitTask(i, barrier.Inc()));
   }
 
   const int num_input_tensors = input_task.inputs.size();
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h
index 39d6e3dd951..89391f2defe 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.h
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h
@@ -87,9 +87,19 @@ class BatchResourceBase : public ResourceBase {
 
     bool is_partial = false;
 
+    uint64 start_time;
+
     size_t size() const override { return inputs[0].shape().dim_size(0); }
 
-    uint64 start_time;
+    // Create a split task from this one. The caller needs to set up the
+    // inputs of the new task.
+    std::unique_ptr<BatchTask> CreateSplitTask(
+        int split_index, AsyncOpKernel::DoneCallback done_callback);
+
+   protected:
+    virtual std::unique_ptr<BatchTask> CreateDerivedTask() {
+      return std::make_unique<BatchTask>();
+    }
   };
 
   // Appending a T suffix to make the type alias different to those in
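Note (not part of the patch): a minimal sketch of how a derived task type is
expected to hook into the new CreateDerivedTask() factory so that
CreateSplitTask() preserves the derived type. ExampleFallbackTask and its
example_state field are hypothetical stand-ins for TFRT's FallbackBatchTask,
which is defined outside these files.

#include <memory>

#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"

// Hypothetical derived task (a stand-in for TFRT's FallbackBatchTask).
// Overriding CreateDerivedTask() makes BatchTask::CreateSplitTask() produce
// objects of this type, so a batch never mixes base and derived task objects.
struct ExampleFallbackTask
    : public tensorflow::serving::BatchResourceBase::BatchTask {
  int example_state = 0;  // hypothetical derived-class field

 protected:
  std::unique_ptr<BatchTask> CreateDerivedTask() override {
    auto task = std::make_unique<ExampleFallbackTask>();
    // Copy any derived-class state that split tasks must carry; the base
    // fields are filled in by CreateSplitTask() itself.
    task->example_state = this->example_state;
    return task;
  }
};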