Introduce op attribute 'enable_large_batch_splitting' and keep default behavior consistent.

Plug 'enable_large_batch_splitting' into queue options so that the Queue can split inputs within its 'Schedule' method.

PiperOrigin-RevId: 313472810
Change-Id: I72753d3fc31d887d77d2015280b7b1628ba2c0aa
Mingming Liu 2020-05-27 15:11:06 -07:00 committed by TensorFlower Gardener
parent ad8a4e1bda
commit 75fc53e381
6 changed files with 50 additions and 7 deletions


@@ -84,6 +84,13 @@ END
name: "Tout"
description: <<END
the types of the output tensors.
END
}
attr {
name: "enable_large_batch_splitting"
description: <<END
input with a large size (i.e., larger than the largest value of
`allowed_batch_sizes`) will be split into multiple batches with batch sizes
within `allowed_batch_sizes`.
END
}
summary: "Batches all the inputs tensors to the computation done by the function."


@@ -272,6 +272,7 @@ class BatchResource : public ResourceBase {
int32 batch_timeout_micros, int32 max_enqueued_batches,
const std::vector<int32>& allowed_batch_sizes,
FunctionLibraryRuntime::Handle fhandle,
bool enable_large_batch_splitting,
std::unique_ptr<BatchResource>* resource) {
std::unique_ptr<BatchResource> new_resource(new BatchResource);
@@ -286,6 +287,10 @@ class BatchResource : public ResourceBase {
new_resource->batcher_queue_options_.batch_timeout_micros =
batch_timeout_micros;
// Support for splitting large batches is still in progress.
new_resource->batcher_queue_options_.enable_large_batch_splitting =
enable_large_batch_splitting;
new_resource->allowed_batch_sizes_ = allowed_batch_sizes;
new_resource->fhandle_ = fhandle;
@@ -786,6 +791,13 @@ class BatchFunctionKernel : public AsyncOpKernel {
OP_REQUIRES_OK(c, c->GetAttr("f", &func));
OP_REQUIRES_OK(
c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
if (c->HasAttr("enable_large_batch_splitting")) {
OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
&enable_large_batch_splitting_));
} else {
enable_large_batch_splitting_ = false;
}
}
bool IsExpensive() override { return false; }
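
For context, the constructor above reads the new attribute only when the NodeDef actually carries it, so GraphDefs serialized before this change keep their old behavior. A hedged sketch of that pattern follows; `ReadBoolAttrOrDefault` is a hypothetical helper, not part of this commit or of the TensorFlow API.

// Hypothetical helper capturing the backward-compatible attribute read above.
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

bool ReadBoolAttrOrDefault(OpKernelConstruction* c, const char* attr_name,
                           bool default_value) {
  bool value = default_value;
  if (c->HasAttr(attr_name)) {
    // GetAttr can only fail here on a type mismatch; keep the default then.
    if (!c->GetAttr(attr_name, &value).ok()) value = default_value;
  }
  return value;
}

}  // namespace tensorflow

// The constructor could then read, for example:
//   enable_large_batch_splitting_ =
//       ReadBoolAttrOrDefault(c, "enable_large_batch_splitting", false);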
@@ -794,10 +806,10 @@ class BatchFunctionKernel : public AsyncOpKernel {
BatchResource* br;
std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
std::unique_ptr<BatchResource> new_resource;
TF_RETURN_IF_ERROR(
BatchResource::Create(num_batch_threads_, max_batch_size_,
batch_timeout_micros_, max_enqueued_batches_,
allowed_batch_sizes_, fhandle_, &new_resource));
TF_RETURN_IF_ERROR(BatchResource::Create(
num_batch_threads_, max_batch_size_, batch_timeout_micros_,
max_enqueued_batches_, allowed_batch_sizes_, fhandle_,
enable_large_batch_splitting_, &new_resource));
*r = new_resource.release();
return Status::OK();
};
@@ -844,6 +856,7 @@ class BatchFunctionKernel : public AsyncOpKernel {
int32 max_enqueued_batches_;
std::vector<int32> allowed_batch_sizes_;
FunctionLibraryRuntime::Handle fhandle_;
bool enable_large_batch_splitting_;
};
REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
@@ -876,7 +889,7 @@ class BatchKernel : public AsyncOpKernel {
std::unique_ptr<BatchResource> new_resource;
TF_RETURN_IF_ERROR(BatchResource::Create(
num_batch_threads_, max_batch_size_, batch_timeout_micros_,
max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle, false,
&new_resource));
*r = new_resource.release();
return Status::OK();


@@ -160,6 +160,10 @@ class SharedBatchScheduler
// See the class documentation above for guidelines on how to tune this
// parameter.
size_t max_enqueued_batches = 10;
// If true, the queue implementation will split one input batch task into
// subtasks and fit them into different batches.
bool enable_large_batch_splitting = false;
};
Status AddQueue(const QueueOptions& options,
std::function<void(std::unique_ptr<Batch<TaskType>>)>
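
A standalone sketch, under simplifying assumptions, of what "split one input batch task into subtasks and fit them into different batches" could look like; the real SharedBatchScheduler queue logic is considerably more involved and is not part of this hunk.

// Standalone illustration (not SharedBatchScheduler code): split one large
// task across the currently open batch plus as many new batches as needed.
#include <algorithm>
#include <iostream>
#include <vector>

struct Batch { int size = 0; };

void ScheduleWithSplit(int task_size, int max_execution_batch_size,
                       std::vector<Batch>& batches) {
  while (task_size > 0) {
    if (batches.empty() || batches.back().size == max_execution_batch_size) {
      batches.push_back(Batch{});  // Open a fresh batch once the last one is full.
    }
    const int room = max_execution_batch_size - batches.back().size;
    const int take = std::min(room, task_size);
    batches.back().size += take;  // A real queue would enqueue a subtask here.
    task_size -= take;
  }
}

int main() {
  std::vector<Batch> batches;
  ScheduleWithSplit(10, 8, batches);  // -> batches of sizes 8 and 2.
  ScheduleWithSplit(5, 8, batches);   // Tops the open batch of 2 up to 7.
  for (const Batch& b : batches) std::cout << b.size << " ";  // prints "8 7 "
  std::cout << "\n";
}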


@@ -25,6 +25,19 @@ REGISTER_OP("BatchFunction")
.Output("out_tensors: Tout")
.Attr("f: func")
.Attr("num_batch_threads: int")
// 'max_batch_size' denotes the maximum acceptable batch size, i.e., inputs
// with a larger batch size are simply invalidated.
// By default, 'max_batch_size' must be equal to the max value of
// 'allowed_batch_sizes'.
// By setting 'enable_large_batch_splitting' (attribute below) to true,
// 'max_batch_size' can be greater than or equal to the max value of
// 'allowed_batch_sizes'; in other words,
// 1) input with size > 'max_batch_size' is still invalidated.
// 2) input with
//    a) size <= 'max_batch_size', and
//    b) size > the max value of 'allowed_batch_sizes'
//    will automatically be split into multiple batches (with batch sizes in
//    'allowed_batch_sizes'), executed, and re-composed (as the final output).
.Attr("max_batch_size: int")
.Attr("batch_timeout_micros: int")
.Attr("max_enqueued_batches: int = 10")
@@ -35,6 +48,12 @@ REGISTER_OP("BatchFunction")
.Attr("Tin: list(type)")
.Attr("Tcaptured: list(type) >= 0")
.Attr("Tout: list(type)")
// If 'enable_large_batch_splitting' is true, an input batch whose size exceeds
// the largest value in "allowed_batch_sizes" may be split into multiple
// batches with batch sizes within "allowed_batch_sizes".
// NOTE: Support for `enable_large_batch_splitting == true` is still under
// development.
.Attr("enable_large_batch_splitting: bool = false")
// TODO(apassos): Fix this shape inference function. It requires shape
// inference of function calls.
.SetShapeFn(shape_inference::UnknownShape);
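
To make the admission rule described in the comments above concrete, here is a hedged standalone sketch (illustrative names only, not TensorFlow code).

// Restates the rule: reject anything above 'max_batch_size'; below that,
// split only when splitting is enabled and the size exceeds the largest
// allowed batch size. Assumes 'allowed_batch_sizes' is sorted ascending and,
// when splitting is disabled, that 'max_batch_size' equals its largest value.
#include <iostream>
#include <vector>

enum class Admission { kReject, kPadToAllowedSize, kSplit };

Admission Admit(int input_size, int max_batch_size,
                const std::vector<int>& allowed_batch_sizes,
                bool enable_large_batch_splitting) {
  if (input_size > max_batch_size) return Admission::kReject;  // 1) still invalidated.
  if (enable_large_batch_splitting && input_size > allowed_batch_sizes.back()) {
    return Admission::kSplit;  // 2) split, execute, and re-compose the outputs.
  }
  return Admission::kPadToAllowedSize;  // Fits an allowed size, possibly after padding.
}

int main() {
  const std::vector<int> allowed = {16, 32, 64};
  std::cout << (Admit(200, 128, allowed, true) == Admission::kReject) << "\n";           // 1
  std::cout << (Admit(100, 128, allowed, true) == Admission::kSplit) << "\n";            // 1
  std::cout << (Admit(20, 128, allowed, true) == Admission::kPadToAllowedSize) << "\n";  // 1
}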


@@ -342,7 +342,7 @@ tf_module {
}
member_method {
name: "BatchFunction"
argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'enable_large_batch_splitting\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'False\', \'None\'], "
}
member_method {
name: "BatchIFFT"


@@ -342,7 +342,7 @@ tf_module {
}
member_method {
name: "BatchFunction"
argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'enable_large_batch_splitting\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'False\', \'None\'], "
}
member_method {
name: "BatchIFFT"