Fix a batch task creation bug in TFRT batch fallback kernel. Without the fix,

TFRT batch fallback kernel to crash in high QPS load. The bug makes it only
create an object of base class BatchTask even when splitting a task of the
derived class FallbackBatchTask (used only in TFRT), and put it into a batch
with other FallbackBatchTask objects. When this batch of mixed types of tasks is
processed, it crashes.

PiperOrigin-RevId: 339927837
Change-Id: Ie52bd11c61c9ddbe6ab803cd90208419d4b2dba6
This commit is contained in:
Hongmin Fan 2020-10-30 13:17:43 -07:00 committed by TensorFlower Gardener
parent a877252856
commit 60ac36f504
2 changed files with 32 additions and 15 deletions

View File

@ -83,6 +83,26 @@ const string& GetModelName(OpKernelContext* ctx) {
} // namespace
std::unique_ptr<BatchResourceBase::BatchTask>
BatchResourceBase::BatchTask::CreateSplitTask(
int split_index, AsyncOpKernel::DoneCallback done_callback) {
std::unique_ptr<BatchTask> task = CreateDerivedTask();
task->guid = this->guid;
task->propagated_context = Context(ContextKind::kThread);
task->inputs.reserve(this->inputs.size());
task->captured_inputs = this->captured_inputs;
task->context = this->context;
task->done_callback = done_callback;
task->split_index = split_index;
task->output = this->output;
task->status = this->status;
task->is_partial = true;
task->start_time = this->start_time;
return task;
}
using ::tensorflow::concat_split_util::Concat;
using ::tensorflow::concat_split_util::Split;
using TensorMatrix = std::vector<std::vector<Tensor>>;
@ -317,20 +337,7 @@ Status BatchResourceBase::ConcatInputTensors(
output_tasks->reserve(output_task_num);
for (int i = 0; i < output_task_num; i++) {
auto task = absl::make_unique<BatchTask>();
task->guid = input_task.guid;
task->propagated_context = Context(ContextKind::kThread);
task->captured_inputs = input_task.captured_inputs;
task->context = input_task.context;
task->done_callback = barrier.Inc();
task->start_time = input_task.start_time;
task->split_index = i;
task->inputs.reserve(input_task.inputs.size());
task->is_partial = true;
task->status = input_task.status;
task->output = input_task.output;
output_tasks->push_back(std::move(task));
output_tasks->push_back(input_task.CreateSplitTask(i, barrier.Inc()));
}
const int num_input_tensors = input_task.inputs.size();

View File

@ -87,9 +87,19 @@ class BatchResourceBase : public ResourceBase {
bool is_partial = false;
uint64 start_time;
size_t size() const override { return inputs[0].shape().dim_size(0); }
uint64 start_time;
// Create a split task from this one. The caller needs to setup the inputs
// of the new task
std::unique_ptr<BatchTask> CreateSplitTask(
int split_index, AsyncOpKernel::DoneCallback done_callback);
protected:
virtual std::unique_ptr<BatchTask> CreateDerivedTask() {
return std::make_unique<BatchTask>();
}
};
// Appending a T suffix to make the type alias different to those in