Fix a batch task creation bug in TFRT batch fallback kernel. Without the fix,
TFRT batch fallback kernel to crash in high QPS load. The bug makes it only create an object of base class BatchTask even when splitting a task of the derived class FallbackBatchTask (used only in TFRT), and put it into a batch with other FallbackBatchTask objects. When this batch of mixed types of tasks is processed, it crashes. PiperOrigin-RevId: 339927837 Change-Id: Ie52bd11c61c9ddbe6ab803cd90208419d4b2dba6
This commit is contained in:
parent
a877252856
commit
60ac36f504
@ -83,6 +83,26 @@ const string& GetModelName(OpKernelContext* ctx) {
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<BatchResourceBase::BatchTask>
|
||||
BatchResourceBase::BatchTask::CreateSplitTask(
|
||||
int split_index, AsyncOpKernel::DoneCallback done_callback) {
|
||||
std::unique_ptr<BatchTask> task = CreateDerivedTask();
|
||||
|
||||
task->guid = this->guid;
|
||||
task->propagated_context = Context(ContextKind::kThread);
|
||||
task->inputs.reserve(this->inputs.size());
|
||||
task->captured_inputs = this->captured_inputs;
|
||||
task->context = this->context;
|
||||
task->done_callback = done_callback;
|
||||
task->split_index = split_index;
|
||||
task->output = this->output;
|
||||
task->status = this->status;
|
||||
task->is_partial = true;
|
||||
task->start_time = this->start_time;
|
||||
|
||||
return task;
|
||||
}
|
||||
|
||||
using ::tensorflow::concat_split_util::Concat;
|
||||
using ::tensorflow::concat_split_util::Split;
|
||||
using TensorMatrix = std::vector<std::vector<Tensor>>;
|
||||
@ -317,20 +337,7 @@ Status BatchResourceBase::ConcatInputTensors(
|
||||
|
||||
output_tasks->reserve(output_task_num);
|
||||
for (int i = 0; i < output_task_num; i++) {
|
||||
auto task = absl::make_unique<BatchTask>();
|
||||
task->guid = input_task.guid;
|
||||
task->propagated_context = Context(ContextKind::kThread);
|
||||
task->captured_inputs = input_task.captured_inputs;
|
||||
task->context = input_task.context;
|
||||
task->done_callback = barrier.Inc();
|
||||
task->start_time = input_task.start_time;
|
||||
task->split_index = i;
|
||||
task->inputs.reserve(input_task.inputs.size());
|
||||
task->is_partial = true;
|
||||
task->status = input_task.status;
|
||||
|
||||
task->output = input_task.output;
|
||||
output_tasks->push_back(std::move(task));
|
||||
output_tasks->push_back(input_task.CreateSplitTask(i, barrier.Inc()));
|
||||
}
|
||||
|
||||
const int num_input_tensors = input_task.inputs.size();
|
||||
|
||||
@ -87,9 +87,19 @@ class BatchResourceBase : public ResourceBase {
|
||||
|
||||
bool is_partial = false;
|
||||
|
||||
uint64 start_time;
|
||||
|
||||
size_t size() const override { return inputs[0].shape().dim_size(0); }
|
||||
|
||||
uint64 start_time;
|
||||
// Create a split task from this one. The caller needs to setup the inputs
|
||||
// of the new task
|
||||
std::unique_ptr<BatchTask> CreateSplitTask(
|
||||
int split_index, AsyncOpKernel::DoneCallback done_callback);
|
||||
|
||||
protected:
|
||||
virtual std::unique_ptr<BatchTask> CreateDerivedTask() {
|
||||
return std::make_unique<BatchTask>();
|
||||
}
|
||||
};
|
||||
|
||||
// Appending a T suffix to make the type alias different to those in
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user