Introduce op attribute 'enable_large_batch_splitting' and keep default behavior consistent.
Plug 'enable_large_batch_splitting' into queue options, so Queue can split input within its 'Schedule' method.

PiperOrigin-RevId: 313472810
Change-Id: I72753d3fc31d887d77d2015280b7b1628ba2c0aa
parent ad8a4e1bda
commit 75fc53e381
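For context on what "split input within its 'Schedule' method" means: when an input's batch size exceeds the largest entry of `allowed_batch_sizes`, the queue breaks it into smaller tasks whose sizes come from that list. The sketch below only illustrates that idea; the function name and the greedy strategy are assumptions, not the TensorFlow implementation.

    #include <algorithm>
    #include <vector>

    // Illustrative sketch only (not TensorFlow code): one plausible way to pick
    // chunk sizes from 'allowed_batch_sizes' (assumed sorted ascending) when an
    // input is larger than the largest allowed size.
    std::vector<int> SplitIntoAllowedSizes(int input_size,
                                           const std::vector<int>& allowed_batch_sizes) {
      std::vector<int> chunk_sizes;
      int remaining = input_size;
      const int largest = allowed_batch_sizes.back();
      while (remaining > largest) {
        chunk_sizes.push_back(largest);  // emit full-sized chunks first
        remaining -= largest;
      }
      // The remainder goes into the smallest allowed size that can hold it
      // (the batcher would pad it up to that size).
      chunk_sizes.push_back(*std::lower_bound(allowed_batch_sizes.begin(),
                                              allowed_batch_sizes.end(), remaining));
      return chunk_sizes;
    }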
@@ -84,6 +84,13 @@ END
     name: "Tout"
     description: <<END
 the types of the output tensors.
 END
   }
+  attr {
+    name: "enable_large_batch_splitting"
+    description: <<END
+input with a large size (i.e., larger than the largest value of
+`allowed_batch_sizes`) will be split into multiple batches with batch sizes drawn from `allowed_batch_sizes`.
+END
+  }
   summary: "Batches all the inputs tensors to the computation done by the function."
@@ -272,6 +272,7 @@ class BatchResource : public ResourceBase {
                        int32 batch_timeout_micros, int32 max_enqueued_batches,
                        const std::vector<int32>& allowed_batch_sizes,
                        FunctionLibraryRuntime::Handle fhandle,
+                       bool enable_large_batch_splitting,
                        std::unique_ptr<BatchResource>* resource) {
     std::unique_ptr<BatchResource> new_resource(new BatchResource);
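A hedged sketch of a call site for the extended factory above; it assumes the surrounding TensorFlow headers and an already-instantiated function handle, and every numeric value is a placeholder. Passing false for the new parameter keeps the pre-existing behavior.

    // Sketch of a caller of the extended Create() signature shown above.
    // All values are placeholders; 'false' preserves the old behavior.
    std::unique_ptr<BatchResource> resource;
    TF_RETURN_IF_ERROR(BatchResource::Create(
        /*num_batch_threads=*/2, /*max_batch_size=*/64,
        /*batch_timeout_micros=*/1000, /*max_enqueued_batches=*/10,
        /*allowed_batch_sizes=*/{16, 32, 64}, fhandle,
        /*enable_large_batch_splitting=*/false, &resource));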
@@ -286,6 +287,10 @@ class BatchResource : public ResourceBase {
     new_resource->batcher_queue_options_.batch_timeout_micros =
         batch_timeout_micros;
 
+    // Support for splitting large batches is still in progress.
+    new_resource->batcher_queue_options_.enable_large_batch_splitting =
+        enable_large_batch_splitting;
+
     new_resource->allowed_batch_sizes_ = allowed_batch_sizes;
 
     new_resource->fhandle_ = fhandle;
@@ -786,6 +791,13 @@ class BatchFunctionKernel : public AsyncOpKernel {
     OP_REQUIRES_OK(c, c->GetAttr("f", &func));
     OP_REQUIRES_OK(
         c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
+
+    if (c->HasAttr("enable_large_batch_splitting")) {
+      OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
+                                   &enable_large_batch_splitting_));
+    } else {
+      enable_large_batch_splitting_ = false;
+    }
   }
 
   bool IsExpensive() override { return false; }
@@ -794,10 +806,10 @@ class BatchFunctionKernel : public AsyncOpKernel {
     BatchResource* br;
     std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
       std::unique_ptr<BatchResource> new_resource;
-      TF_RETURN_IF_ERROR(
-          BatchResource::Create(num_batch_threads_, max_batch_size_,
-                                batch_timeout_micros_, max_enqueued_batches_,
-                                allowed_batch_sizes_, fhandle_, &new_resource));
+      TF_RETURN_IF_ERROR(BatchResource::Create(
+          num_batch_threads_, max_batch_size_, batch_timeout_micros_,
+          max_enqueued_batches_, allowed_batch_sizes_, fhandle_,
+          enable_large_batch_splitting_, &new_resource));
       *r = new_resource.release();
       return Status::OK();
     };
@@ -844,6 +856,7 @@ class BatchFunctionKernel : public AsyncOpKernel {
   int32 max_enqueued_batches_;
   std::vector<int32> allowed_batch_sizes_;
   FunctionLibraryRuntime::Handle fhandle_;
+  bool enable_large_batch_splitting_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
@@ -876,7 +889,7 @@ class BatchKernel : public AsyncOpKernel {
       std::unique_ptr<BatchResource> new_resource;
       TF_RETURN_IF_ERROR(BatchResource::Create(
           num_batch_threads_, max_batch_size_, batch_timeout_micros_,
-          max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
+          max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle, false,
           &new_resource));
       *r = new_resource.release();
       return Status::OK();
@@ -160,6 +160,10 @@ class SharedBatchScheduler
     // See the class documentation above for guidelines on how to tune this
     // parameter.
     size_t max_enqueued_batches = 10;
+
+    // If true, the queue implementation will split an input batch task into
+    // subtasks and fit them into different batches.
+    bool enable_large_batch_splitting = false;
   };
 
   Status AddQueue(const QueueOptions& options,
                   std::function<void(std::unique_ptr<Batch<TaskType>>)>
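A minimal sketch of configuring a queue with the new option. The task type, the batch-processing callback, the scheduler instance, and the third argument of AddQueue (whose signature is truncated in this diff) are assumptions made for illustration.

    // Sketch: enable large-batch splitting on a queue of the scheduler above.
    SharedBatchScheduler<MyTask>::QueueOptions options;
    options.max_enqueued_batches = 10;
    options.enable_large_batch_splitting = true;  // let Schedule() split oversized inputs
    std::unique_ptr<BatchScheduler<MyTask>> queue;  // assumed output parameter type
    TF_RETURN_IF_ERROR(
        scheduler->AddQueue(options, process_batch_callback, &queue));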
@@ -25,6 +25,19 @@ REGISTER_OP("BatchFunction")
     .Output("out_tensors: Tout")
     .Attr("f: func")
     .Attr("num_batch_threads: int")
+    // 'max_batch_size' denotes the maximum batch size acceptable, i.e., inputs
+    // with larger batch size are simply invalidated.
+    // By default, 'max_batch_size' must be equal to max value of
+    // 'allowed_batch_sizes'.
+    // By setting 'enable_large_batch_splitting' (attribute below) to true,
+    // 'max_batch_size' can be greater than or equal to max value of
+    // 'allowed_batch_sizes', in other words,
+    // 1) input with size > 'max_batch_size' is still invalidated.
+    // 2) input with
+    //    a) size <= 'max_batch_size'
+    //    b) size > max value of 'allowed_batch_sizes'
+    //    will automatically be split into multiple batches (with batch size in
+    //    'allowed_batch_sizes'), executed, and re-composed (as final output).
     .Attr("max_batch_size: int")
     .Attr("batch_timeout_micros: int")
     .Attr("max_enqueued_batches: int = 10")
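A concrete reading of rules (1) and (2) above, with assumed values: if max_batch_size is 128 and allowed_batch_sizes is {16, 32, 64}, an input of size 200 is rejected by rule (1), while an input of size 100 is accepted and, with splitting enabled, handled by rule (2). The check below is only an illustration of those rules, not the kernel's actual code.

    // Illustrative acceptance check for the rules documented above (not TF code).
    bool AcceptsInput(int input_size, int max_batch_size, int largest_allowed_size,
                      bool enable_large_batch_splitting) {
      if (input_size > max_batch_size) return false;        // rule (1): always invalidated
      if (input_size <= largest_allowed_size) return true;  // fits an allowed batch size
      return enable_large_batch_splitting;                  // rule (2): requires splitting
    }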
@@ -35,6 +48,12 @@ REGISTER_OP("BatchFunction")
     .Attr("Tin: list(type)")
     .Attr("Tcaptured: list(type) >= 0")
     .Attr("Tout: list(type)")
+    // If 'enable_large_batch_splitting' is true, for input batches exceeding
+    // the largest value in "allowed_batch_sizes", allow the batch to be split
+    // into multiple batches with batch size within "allowed_batch_sizes".
+    // NOTE: Support for `enable_large_batch_splitting == true` is still
+    // in development.
+    .Attr("enable_large_batch_splitting: bool = false")
     // TODO(apassos): Fix this shape inference function. It requires shape
     // inference of function calls.
     .SetShapeFn(shape_inference::UnknownShape);
@@ -342,7 +342,7 @@ tf_module {
   }
   member_method {
     name: "BatchFunction"
-    argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
+    argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'enable_large_batch_splitting\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'False\', \'None\'], "
   }
   member_method {
     name: "BatchIFFT"
@@ -342,7 +342,7 @@ tf_module {
   }
   member_method {
     name: "BatchFunction"
-    argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
+    argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'enable_large_batch_splitting\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'False\', \'None\'], "
   }
   member_method {
     name: "BatchIFFT"