From 75fc53e381ab874eed2eafb13bfdd42dcfdcd440 Mon Sep 17 00:00:00 2001
From: Mingming Liu
Date: Wed, 27 May 2020 15:11:06 -0700
Subject: [PATCH] Introduce op attribute 'enable_large_batch_splitting' and
 keep default behavior consistent.

Plug 'enable_large_batch_splitting' into queue options, so the queue can
split input within its 'Schedule' method.

PiperOrigin-RevId: 313472810
Change-Id: I72753d3fc31d887d77d2015280b7b1628ba2c0aa
---
 .../base_api/api_def_BatchFunction.pbtxt   |  7 ++++++
 tensorflow/core/kernels/batch_kernels.cc   | 23 +++++++++++++++----
 .../batching_util/shared_batch_scheduler.h |  4 ++++
 tensorflow/core/ops/batch_ops.cc           | 19 +++++++++++++++
 .../api/golden/v1/tensorflow.raw_ops.pbtxt |  2 +-
 .../api/golden/v2/tensorflow.raw_ops.pbtxt |  2 +-
 6 files changed, 50 insertions(+), 7 deletions(-)

diff --git a/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt b/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt
index 09eff6177b1..ae5942b3617 100644
--- a/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt
@@ -84,6 +84,13 @@ END
     name: "Tout"
     description: <<END
 the types of the output tensors.
 END
   }
+  attr {
+    name: "enable_large_batch_splitting"
+    description: <<END
+input with a size larger than the largest value of `allowed_batch_sizes`
+will be split into multiple batches of sizes within `allowed_batch_sizes`.
+END
+  }
 }
diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc
--- a/tensorflow/core/kernels/batch_kernels.cc
+++ b/tensorflow/core/kernels/batch_kernels.cc
@@ ... @@ class BatchResource : public ResourceBase {
                        const std::vector<int32>& allowed_batch_sizes,
                        FunctionLibraryRuntime::Handle fhandle,
+                       bool enable_large_batch_splitting,
                        std::unique_ptr<BatchResource>* resource) {
     std::unique_ptr<BatchResource> new_resource(new BatchResource);
@@ -286,6 +287,10 @@ class BatchResource : public ResourceBase {
     new_resource->batcher_queue_options_.batch_timeout_micros =
         batch_timeout_micros;
 
+    // Support for splitting large batches is still in progress.
+    new_resource->batcher_queue_options_.enable_large_batch_splitting =
+        enable_large_batch_splitting;
+
     new_resource->allowed_batch_sizes_ = allowed_batch_sizes;
 
     new_resource->fhandle_ = fhandle;
@@ -786,6 +791,13 @@ class BatchFunctionKernel : public AsyncOpKernel {
     OP_REQUIRES_OK(c, c->GetAttr("f", &func));
     OP_REQUIRES_OK(
         c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
+
+    if (c->HasAttr("enable_large_batch_splitting")) {
+      OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
+                                   &enable_large_batch_splitting_));
+    } else {
+      enable_large_batch_splitting_ = false;
+    }
   }
 
   bool IsExpensive() override { return false; }
@@ -794,10 +806,10 @@ class BatchFunctionKernel : public AsyncOpKernel {
     BatchResource* br;
     std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
       std::unique_ptr<BatchResource> new_resource;
-      TF_RETURN_IF_ERROR(
-          BatchResource::Create(num_batch_threads_, max_batch_size_,
-                                batch_timeout_micros_, max_enqueued_batches_,
-                                allowed_batch_sizes_, fhandle_, &new_resource));
+      TF_RETURN_IF_ERROR(BatchResource::Create(
+          num_batch_threads_, max_batch_size_, batch_timeout_micros_,
+          max_enqueued_batches_, allowed_batch_sizes_, fhandle_,
+          enable_large_batch_splitting_, &new_resource));
       *r = new_resource.release();
       return Status::OK();
     };
@@ -844,6 +856,7 @@ class BatchFunctionKernel : public AsyncOpKernel {
   int32 max_enqueued_batches_;
   std::vector<int32> allowed_batch_sizes_;
   FunctionLibraryRuntime::Handle fhandle_;
+  bool enable_large_batch_splitting_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
@@ -876,7 +889,7 @@ class BatchKernel : public AsyncOpKernel {
       std::unique_ptr<BatchResource> new_resource;
       TF_RETURN_IF_ERROR(BatchResource::Create(
           num_batch_threads_, max_batch_size_, batch_timeout_micros_,
-          max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
+          max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle, false,
           &new_resource));
       *r = new_resource.release();
       return Status::OK();
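The interesting detail in the batch_kernels.cc change is the guarded attribute read: graphs serialized before this change carry no 'enable_large_batch_splitting' attr, so the kernel probes with HasAttr and falls back to the legacy behavior instead of failing. A minimal sketch of that pattern follows; ExampleKernel is a hypothetical stand-in, not the real BatchFunctionKernel.

#include "tensorflow/core/framework/op_kernel.h"

// Hypothetical stand-in kernel showing the optional-attr pattern from the
// patch: probe with HasAttr(), then either read the attr or keep the
// pre-patch default so old GraphDefs still load.
class ExampleKernel : public tensorflow::AsyncOpKernel {
 public:
  explicit ExampleKernel(tensorflow::OpKernelConstruction* c)
      : AsyncOpKernel(c) {
    if (c->HasAttr("enable_large_batch_splitting")) {
      // NodeDef was produced after the attr was introduced.
      OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
                                   &enable_large_batch_splitting_));
    } else {
      // Legacy NodeDef: keep the old no-splitting behavior.
      enable_large_batch_splitting_ = false;
    }
  }

  void ComputeAsync(tensorflow::OpKernelContext* c,
                    DoneCallback done) override {
    done();  // No-op body; only the constructor pattern matters here.
  }

 private:
  bool enable_large_batch_splitting_ = false;
};

Without the HasAttr guard, GetAttr would error out on pre-existing graphs even though the op registration (below) supplies a default, presumably because some load paths construct kernels from NodeDefs without re-applying the newer op definition's attr defaults.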
diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
index c44de023ced..66bdff933d8 100644
--- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
+++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
@@ -160,6 +160,10 @@ class SharedBatchScheduler
     // See the class documentation above for guidelines on how to tune this
     // parameter.
     size_t max_enqueued_batches = 10;
+
+    // If true, the queue implementation will split one input batch task into
+    // subtasks and fit them into different batches.
+    bool enable_large_batch_splitting = false;
   };
   Status AddQueue(const QueueOptions& options,
                   std::function<void(std::unique_ptr<Batch<TaskType>>)>
diff --git a/tensorflow/core/ops/batch_ops.cc b/tensorflow/core/ops/batch_ops.cc
index ba7faeb5e8a..cfa4049938d 100644
--- a/tensorflow/core/ops/batch_ops.cc
+++ b/tensorflow/core/ops/batch_ops.cc
@@ -25,6 +25,19 @@ REGISTER_OP("BatchFunction")
     .Output("out_tensors: Tout")
     .Attr("f: func")
     .Attr("num_batch_threads: int")
+    // 'max_batch_size' denotes the maximum batch size acceptable, i.e.,
+    // inputs with a larger batch size are simply invalidated.
+    // By default, 'max_batch_size' must be equal to the max value of
+    // 'allowed_batch_sizes'.
+    // By setting 'enable_large_batch_splitting' (attribute below) to true,
+    // 'max_batch_size' can be greater than or equal to the max value of
+    // 'allowed_batch_sizes'; in other words,
+    // 1) input with size > 'max_batch_size' is still invalidated.
+    // 2) input with
+    //    a) size <= 'max_batch_size', and
+    //    b) size > the max value of 'allowed_batch_sizes'
+    //    will automatically be split into multiple batches (with batch sizes
+    //    in 'allowed_batch_sizes'), executed, and re-composed (as final output).
     .Attr("max_batch_size: int")
     .Attr("batch_timeout_micros: int")
     .Attr("max_enqueued_batches: int = 10")
@@ -35,6 +48,12 @@ REGISTER_OP("BatchFunction")
     .Attr("Tin: list(type)")
     .Attr("Tcaptured: list(type) >= 0")
     .Attr("Tout: list(type)")
+    // If 'enable_large_batch_splitting' is true, for input batches exceeding
+    // the largest value in "allowed_batch_sizes", allow the batch to be split
+    // into multiple batches with batch size within "allowed_batch_sizes".
+    // NOTE: Support for `enable_large_batch_splitting == true` is still in
+    // development.
+    .Attr("enable_large_batch_splitting: bool = false")
     // TODO(apassos): Fix this shape inference function. It requires shape
     // inference of function calls.
     .SetShapeFn(shape_inference::UnknownShape);
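The comment block added to batch_ops.cc fully specifies the admission and splitting contract. As a worked illustration, here is a standalone C++ sketch of that contract (plain STL, no TensorFlow types); the function name and the greedy largest-chunk-first policy are assumptions made for illustration, since the actual splitting lives in the queue's Schedule method and is still in development per the patch.

#include <algorithm>
#include <cstdio>
#include <optional>
#include <vector>

// Models the contract documented in the op registration above. Returns the
// sub-batch sizes an input of `size` would be handled as, or std::nullopt if
// the input is invalidated.
std::optional<std::vector<int>> SplitSizes(
    int size, int max_batch_size, std::vector<int> allowed_batch_sizes,
    bool enable_large_batch_splitting) {
  if (size > max_batch_size) return std::nullopt;  // Rule 1: always invalid.
  if (!enable_large_batch_splitting) {
    // Default behavior: max_batch_size equals max(allowed_batch_sizes), so an
    // admitted input fits in one batch (padded up to an allowed size later).
    return std::vector<int>{size};
  }
  std::sort(allowed_batch_sizes.begin(), allowed_batch_sizes.end());
  const int max_allowed =
      allowed_batch_sizes.empty() ? max_batch_size : allowed_batch_sizes.back();
  // Rule 2: size <= max_batch_size but > max(allowed_batch_sizes) is split
  // into chunks of at most the largest allowed size; the executed results are
  // then re-composed into the final output.
  std::vector<int> chunks;
  while (size > 0) {
    const int chunk = std::min(size, max_allowed);
    chunks.push_back(chunk);
    size -= chunk;
  }
  return chunks;
}

int main() {
  // With splitting enabled, a size-250 input under max_batch_size=300 and
  // allowed_batch_sizes={16, 64, 128} is admitted and split as {128, 122}.
  const auto chunks = SplitSizes(250, 300, {16, 64, 128}, true);
  for (const int c : *chunks) std::printf("%d ", c);
  std::printf("\n");
}

In that example the trailing 122-element chunk would presumably be padded up to an allowed size (128 here), as with any partial batch when 'allowed_batch_sizes' is set.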
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 37a95cc88d1..a8efb9e59b5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -342,7 +342,7 @@ tf_module {
   }
   member_method {
     name: "BatchFunction"
-    argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
+    argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'enable_large_batch_splitting\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'False\', \'None\'], "
   }
   member_method {
     name: "BatchIFFT"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 37a95cc88d1..a8efb9e59b5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -342,7 +342,7 @@ tf_module {
   }
   member_method {
     name: "BatchFunction"
-    argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
+    argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'enable_large_batch_splitting\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'False\', \'None\'], "
   }
   member_method {
     name: "BatchIFFT"
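Finally, for direct users of the batching library (rather than the BatchFunction op), the new QueueOptions field is set at queue-creation time. A sketch, assuming the SharedBatchScheduler API from the header change above; MyTask and the no-op callback are illustrative.

#include <functional>
#include <memory>

#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"

using tensorflow::serving::Batch;
using tensorflow::serving::BatchScheduler;
using tensorflow::serving::BatchTask;
using tensorflow::serving::SharedBatchScheduler;

// Hypothetical task type; any BatchTask subclass reporting its size() works.
struct MyTask : public BatchTask {
  size_t size() const override { return num_examples; }
  size_t num_examples = 0;
};

tensorflow::Status SetUpQueue(
    SharedBatchScheduler<MyTask>* scheduler,
    std::unique_ptr<BatchScheduler<MyTask>>* queue) {
  SharedBatchScheduler<MyTask>::QueueOptions options;
  options.max_batch_size = 256;         // Admit inputs up to this size.
  options.batch_timeout_micros = 1000;  // Flush partial batches after 1 ms.
  options.max_enqueued_batches = 10;    // Header default, kept explicit.
  // New in this patch: let Schedule() split an oversized task into subtasks
  // that are placed into different batches.
  options.enable_large_batch_splitting = true;
  return scheduler->AddQueue(
      options,
      [](std::unique_ptr<Batch<MyTask>> batch) {
        // Process the batch here (illustrative no-op).
      },
      queue);
}

The BatchFunction kernel path shown earlier ends up setting the same field through BatchResource::Create, so the op attribute and this library-level option are the same knob exposed at two layers.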