diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index d638760b833..81a16522c55 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -83,6 +83,26 @@ const string& GetModelName(OpKernelContext* ctx) { } // namespace +std::unique_ptr +BatchResourceBase::BatchTask::CreateSplitTask( + int split_index, AsyncOpKernel::DoneCallback done_callback) { + std::unique_ptr task = CreateDerivedTask(); + + task->guid = this->guid; + task->propagated_context = Context(ContextKind::kThread); + task->inputs.reserve(this->inputs.size()); + task->captured_inputs = this->captured_inputs; + task->context = this->context; + task->done_callback = done_callback; + task->split_index = split_index; + task->output = this->output; + task->status = this->status; + task->is_partial = true; + task->start_time = this->start_time; + + return task; +} + using ::tensorflow::concat_split_util::Concat; using ::tensorflow::concat_split_util::Split; using TensorMatrix = std::vector>; @@ -317,20 +337,7 @@ Status BatchResourceBase::ConcatInputTensors( output_tasks->reserve(output_task_num); for (int i = 0; i < output_task_num; i++) { - auto task = absl::make_unique(); - task->guid = input_task.guid; - task->propagated_context = Context(ContextKind::kThread); - task->captured_inputs = input_task.captured_inputs; - task->context = input_task.context; - task->done_callback = barrier.Inc(); - task->start_time = input_task.start_time; - task->split_index = i; - task->inputs.reserve(input_task.inputs.size()); - task->is_partial = true; - task->status = input_task.status; - - task->output = input_task.output; - output_tasks->push_back(std::move(task)); + output_tasks->push_back(input_task.CreateSplitTask(i, barrier.Inc())); } const int num_input_tensors = input_task.inputs.size(); diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h index 39d6e3dd951..89391f2defe 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.h +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h @@ -87,9 +87,19 @@ class BatchResourceBase : public ResourceBase { bool is_partial = false; + uint64 start_time; + size_t size() const override { return inputs[0].shape().dim_size(0); } - uint64 start_time; + // Create a split task from this one. The caller needs to setup the inputs + // of the new task + std::unique_ptr CreateSplitTask( + int split_index, AsyncOpKernel::DoneCallback done_callback); + + protected: + virtual std::unique_ptr CreateDerivedTask() { + return std::make_unique(); + } }; // Appending a T suffix to make the type alias different to those in