Introduce op attribute 'enable_large_batch_splitting' and keep default behavior consistent.

Plug 'enable_large_batch_splitting' into queue options, so the Queue can split input within its 'Schedule' method.

PiperOrigin-RevId: 313472810
Change-Id: I72753d3fc31d887d77d2015280b7b1628ba2c0aa
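The mechanism, in brief: with splitting enabled, 'Schedule' no longer rejects a task merely because its size exceeds the largest allowed batch size; it can carve the task into subtasks that each fit one. A minimal, self-contained sketch of that carving step (SplitTaskSizes is a hypothetical helper, not the TensorFlow implementation):

    #include <algorithm>
    #include <vector>

    // Greedily split task_size into chunks no larger than the largest
    // allowed batch size. The real queue would wrap each chunk in a
    // scheduler subtask; this sketch only computes the chunk sizes.
    std::vector<int> SplitTaskSizes(int task_size,
                                    const std::vector<int>& allowed_batch_sizes) {
      const int max_allowed = *std::max_element(allowed_batch_sizes.begin(),
                                                allowed_batch_sizes.end());
      std::vector<int> chunks;
      while (task_size > 0) {
        const int chunk = std::min(task_size, max_allowed);
        chunks.push_back(chunk);
        task_size -= chunk;
      }
      return chunks;
    }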
@@ -84,6 +84,13 @@ END
     name: "Tout"
     description: <<END
 the types of the output tensors.
 END
   }
+  attr {
+    name: "enable_large_batch_splitting"
+    description: <<END
+input with a large size (i.e., larger than the largest value of
+`allowed_batch_sizes`) will be split into multiple batches with batch sizes in `allowed_batch_sizes`.
+END
+  }
   summary: "Batches all the inputs tensors to the computation done by the function."
@@ -272,6 +272,7 @@ class BatchResource : public ResourceBase {
                        int32 batch_timeout_micros, int32 max_enqueued_batches,
                        const std::vector<int32>& allowed_batch_sizes,
                        FunctionLibraryRuntime::Handle fhandle,
+                       bool enable_large_batch_splitting,
                        std::unique_ptr<BatchResource>* resource) {
     std::unique_ptr<BatchResource> new_resource(new BatchResource);

@@ -286,6 +287,10 @@ class BatchResource : public ResourceBase {
     new_resource->batcher_queue_options_.batch_timeout_micros =
         batch_timeout_micros;

+    // Support for splitting large batches is still in progress.
+    new_resource->batcher_queue_options_.enable_large_batch_splitting =
+        enable_large_batch_splitting;
+
     new_resource->allowed_batch_sizes_ = allowed_batch_sizes;

     new_resource->fhandle_ = fhandle;
@@ -786,6 +791,13 @@ class BatchFunctionKernel : public AsyncOpKernel {
     OP_REQUIRES_OK(c, c->GetAttr("f", &func));
     OP_REQUIRES_OK(
         c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
+
+    if (c->HasAttr("enable_large_batch_splitting")) {
+      OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
+                                   &enable_large_batch_splitting_));
+    } else {
+      enable_large_batch_splitting_ = false;
+    }
   }

   bool IsExpensive() override { return false; }
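The HasAttr guard above is what keeps previously serialized graphs loading: a GraphDef produced before this change carries no 'enable_large_batch_splitting' attr, so construction falls back to false instead of failing. The same pattern as a standalone sketch (GetAttrOrDefault is a hypothetical helper, not TensorFlow API):

    // Read an op attr that may be absent from older graphs, falling back
    // to a default. Uses only OpKernelConstruction::HasAttr/GetAttr.
    template <typename T>
    T GetAttrOrDefault(OpKernelConstruction* c, const std::string& name,
                       T default_value) {
      T value = default_value;
      if (c->HasAttr(name)) {
        // Presence was just checked; if the read fails anyway, keep the
        // default rather than erroring out of a non-void function.
        c->GetAttr(name, &value).IgnoreError();
      }
      return value;
    }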
@@ -794,10 +806,10 @@ class BatchFunctionKernel : public AsyncOpKernel {
     BatchResource* br;
     std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
       std::unique_ptr<BatchResource> new_resource;
-      TF_RETURN_IF_ERROR(
-          BatchResource::Create(num_batch_threads_, max_batch_size_,
-                                batch_timeout_micros_, max_enqueued_batches_,
-                                allowed_batch_sizes_, fhandle_, &new_resource));
+      TF_RETURN_IF_ERROR(BatchResource::Create(
+          num_batch_threads_, max_batch_size_, batch_timeout_micros_,
+          max_enqueued_batches_, allowed_batch_sizes_, fhandle_,
+          enable_large_batch_splitting_, &new_resource));
       *r = new_resource.release();
       return Status::OK();
     };
@@ -844,6 +856,7 @@ class BatchFunctionKernel : public AsyncOpKernel {
   int32 max_enqueued_batches_;
   std::vector<int32> allowed_batch_sizes_;
   FunctionLibraryRuntime::Handle fhandle_;
+  bool enable_large_batch_splitting_;
 };

 REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
@@ -876,7 +889,7 @@ class BatchKernel : public AsyncOpKernel {
       std::unique_ptr<BatchResource> new_resource;
       TF_RETURN_IF_ERROR(BatchResource::Create(
           num_batch_threads_, max_batch_size_, batch_timeout_micros_,
-          max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
+          max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle, false,
           &new_resource));
       *r = new_resource.release();
       return Status::OK();
@@ -160,6 +160,10 @@ class SharedBatchScheduler
     // See the class documentation above for guidelines on how to tune this
     // parameter.
     size_t max_enqueued_batches = 10;
+
+    // If true, the queue implementation will split one input batch task into
+    // subtasks and fit them into different batches.
+    bool enable_large_batch_splitting = false;
   };
   Status AddQueue(const QueueOptions& options,
                   std::function<void(std::unique_ptr<Batch<TaskType>>)>
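For scheduler clients the new field is purely additive; existing queues keep the default (false). A hedged configuration sketch (a fragment from inside a Status-returning function; MyTask, the option values, and the callback body are illustrative, while the option fields and AddQueue come from the header above):

    SharedBatchScheduler<MyTask>::QueueOptions queue_options;
    queue_options.batch_timeout_micros = 1000;          // illustrative value
    queue_options.max_enqueued_batches = 10;
    queue_options.enable_large_batch_splitting = true;  // the new option

    std::unique_ptr<BatchScheduler<MyTask>> queue;
    TF_RETURN_IF_ERROR(scheduler->AddQueue(
        queue_options,
        [](std::unique_ptr<Batch<MyTask>> batch) {
          // Process one batch; with splitting on, an oversized input may
          // arrive here as several smaller batches.
        },
        &queue));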
@@ -25,6 +25,19 @@ REGISTER_OP("BatchFunction")
     .Output("out_tensors: Tout")
     .Attr("f: func")
     .Attr("num_batch_threads: int")
+    // 'max_batch_size' denotes the maximum batch size acceptable, i.e., inputs
+    // with a larger batch size are simply invalidated.
+    // By default, 'max_batch_size' must be equal to the max value of
+    // 'allowed_batch_sizes'.
+    // By setting 'enable_large_batch_splitting' (attribute below) to true,
+    // 'max_batch_size' can be greater than or equal to the max value of
+    // 'allowed_batch_sizes'; in other words,
+    // 1) an input with size > 'max_batch_size' is still invalidated.
+    // 2) an input with
+    //    a) size <= 'max_batch_size', and
+    //    b) size > the max value of 'allowed_batch_sizes'
+    //    will automatically be split into multiple batches (with batch size in
+    //    'allowed_batch_sizes'), executed, and re-composed (as the final output).
     .Attr("max_batch_size: int")
     .Attr("batch_timeout_micros: int")
     .Attr("max_enqueued_batches: int = 10")
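A concrete reading of those rules, with illustrative numbers (not from the commit) and the SplitTaskSizes sketch from the top of this change description:

    // Suppose max_batch_size = 64, allowed_batch_sizes = {8, 16, 32},
    // and enable_large_batch_splitting = true. Then:
    //   size 100: > max_batch_size    -> still invalidated.
    //   size  48: <= 64 but > 32      -> split, executed, re-composed.
    //   size  20: <= 32               -> handled as before, no split.
    std::vector<int> chunks = SplitTaskSizes(48, {8, 16, 32});
    // chunks == {32, 16}: each subtask now fits an allowed batch size.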
@@ -35,6 +48,12 @@ REGISTER_OP("BatchFunction")
     .Attr("Tin: list(type)")
     .Attr("Tcaptured: list(type) >= 0")
     .Attr("Tout: list(type)")
+    // If 'enable_large_batch_splitting' is true, for input batches exceeding
+    // the largest value in "allowed_batch_sizes", allow the batch to be split
+    // into multiple batches with batch size within "allowed_batch_sizes".
+    // NOTE: Support for `enable_large_batch_splitting == true` is still
+    // in progress.
+    .Attr("enable_large_batch_splitting: bool = false")
     // TODO(apassos): Fix this shape inference function. It requires shape
     // inference of function calls.
     .SetShapeFn(shape_inference::UnknownShape);
@@ -342,7 +342,7 @@ tf_module {
   }
   member_method {
     name: "BatchFunction"
-    argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
+    argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'enable_large_batch_splitting\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'False\', \'None\'], "
   }
   member_method {
     name: "BatchIFFT"
The same argspec update is recorded a second time, once for each of the two API-compatibility golden snapshots:

@@ -342,7 +342,7 @@ tf_module {
   }
   member_method {
     name: "BatchFunction"
-    argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
+    argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'enable_large_batch_splitting\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'False\', \'None\'], "
   }
   member_method {
     name: "BatchIFFT"