Support GPUs in BatchFunction.

BatchFunction can now run the batched function across multiple devices, including GPUs. Currently, inputs and outputs must still be on the CPU, as the concatenation and splitting of batches are done on the CPU.
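
A minimal usage sketch, assuming the public tf.nondifferentiable_batch_function wrapper, which lowers to this BatchFunction kernel; with a GPU present, the function body may be placed on it while the batched inputs and outputs remain host tensors:

import tensorflow as tf

@tf.nondifferentiable_batch_function(
    num_batch_threads=1, max_batch_size=16, batch_timeout_micros=100000)
def model(x):
  # Ops in the body can now run on a GPU when one is available; the batched
  # input x is concatenated on the host and the result is split on the host.
  return 2.0 * x + 1.0

# Concurrent calls are batched together; inputs and outputs are CPU tensors.
print(model(tf.constant([3.0])))  # expected: [7.0]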

PiperOrigin-RevId: 347524478
Change-Id: Ib329987bf09513570c3c260e4c0834d6102a4364
Reed Wanderman-Milne 2020-12-14 20:03:22 -08:00 committed by TensorFlower Gardener
parent 85f554c486
commit ba28d6de31
6 changed files with 159 additions and 36 deletions

View File

@ -1426,6 +1426,10 @@ void ProcessFunctionLibraryRuntime::Run(
InternalArgs* comp_args) -> Status {
// "Index"s of _Arg nodes are unique when all arguments are local Tensors.
for (const auto& it : comp_data.arg_indices) {
if (it.index >= args.size()) {
return errors::InvalidArgument(
"index ", it.index, " is out of range [0, ", args.size(), ")");
}
if (it.sub_index >= 0) {
const Tensor& t = args[it.index];
if (t.dtype() != DT_RESOURCE) {

View File

@ -632,6 +632,7 @@ cc_library(
srcs = ["batch_kernels.cc"],
deps = [
":ops_util_hdrs",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
@ -187,12 +189,7 @@ class BatchFunctionKernel : public AsyncOpKernel {
c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
auto lib = c->function_library();
OP_REQUIRES(c, lib != nullptr, errors::Internal("No function library"));
NameAttrList func;
OP_REQUIRES_OK(c, c->GetAttr("f", &func));
OP_REQUIRES_OK(
c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
OP_REQUIRES_OK(c, c->GetAttr("f", &func_));
if (num_batch_threads_ <= 0) {
adaptive_batch_scheduler_options_ =
absl::make_optional(AdaptiveBatchSchedulerOptions{
@ -242,8 +239,11 @@ class BatchFunctionKernel : public AsyncOpKernel {
std::function<Status(BatchResource**)> creator;
FunctionLibraryRuntime::Handle handle;
OP_REQUIRES_OK_ASYNC(c, GetOrCreateFunctionHandle(c, &handle), done);
if (adaptive_batch_scheduler_options_ != absl::nullopt) {
creator = [this](BatchResource** r) {
creator = [this, handle](BatchResource** r) {
serving::AdaptiveSharedBatchScheduler<
serving::BatchResourceBase::BatchTask>::Options
adaptive_shared_batch_scheduler_options;
@ -274,16 +274,16 @@ class BatchFunctionKernel : public AsyncOpKernel {
TF_RETURN_IF_ERROR(BatchResource::Create(
adaptive_shared_batch_scheduler_options, max_batch_size_,
batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_,
fhandle_, &new_resource));
handle, &new_resource));
*r = new_resource.release();
return Status::OK();
};
} else {
creator = [this](BatchResource** r) {
creator = [this, handle](BatchResource** r) {
std::unique_ptr<BatchResource> new_resource;
TF_RETURN_IF_ERROR(BatchResource::Create(
num_batch_threads_, max_batch_size_, batch_timeout_micros_,
max_enqueued_batches_, allowed_batch_sizes_, fhandle_,
max_enqueued_batches_, allowed_batch_sizes_, handle,
enable_large_batch_splitting_, &new_resource));
*r = new_resource.release();
return Status::OK();
@ -302,6 +302,75 @@ class BatchFunctionKernel : public AsyncOpKernel {
// Assume br calls done, so nothing to do here.
}
Status InstantiateFunction(OpKernelContext* c,
FunctionLibraryRuntime::Handle* handle) const {
// TODO(b/173748062): Merge this instantiation logic with PartitionedCall.
FunctionLibraryRuntime* lib = c->function_library();
if (!lib) {
return errors::Internal("No function library");
}
FunctionLibraryRuntime::InstantiateOptions opts;
opts.target = lib->device() == nullptr ? "" : lib->device()->name();
opts.is_multi_device_function = true;
Device* cpu_device;
TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
const FunctionDef* fdef =
lib->GetFunctionLibraryDefinition()->Find(func_.name());
if (!fdef) {
return errors::NotFound("Failed to find definition for function \"",
func_.name(), "\"");
}
OpInputList in_tensors;
TF_RETURN_IF_ERROR(c->input_list("in_tensors", &in_tensors));
for (int i = 0; i < in_tensors.size(); i++) {
if (in_tensors[i].dtype() == DT_RESOURCE) {
return errors::InvalidArgument(
"BatchFunction cannot take resource inputs but input ", i,
" is a resource.");
} else {
// Currently, inputs are on CPU since they are concatenated on CPU
opts.input_devices.push_back(cpu_device->name());
}
}
OpInputList captured_tensors;
TF_RETURN_IF_ERROR(c->input_list("captured_tensors", &captured_tensors));
for (const Tensor& t : captured_tensors) {
if (t.dtype() == DT_RESOURCE) {
const ResourceHandle& rhandle = t.flat<ResourceHandle>()(0);
opts.input_devices.push_back(rhandle.device());
} else {
opts.input_devices.push_back(cpu_device->name());
}
}
const OpDef& signature = fdef->signature();
for (int i = 0; i < signature.output_arg_size(); i++) {
// Currently, outputs must be on CPU since they are split on CPU.
opts.output_devices.push_back(cpu_device->name());
}
if (opts.input_devices.size() != signature.input_arg_size()) {
return errors::InvalidArgument(
"Function takes ", signature.input_arg_size(), " argument(s) but ",
opts.input_devices.size(), " argument(s) were passed");
}
return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts,
handle);
}
Status GetOrCreateFunctionHandle(OpKernelContext* c,
FunctionLibraryRuntime::Handle* handle) {
mutex_lock ml(mu_);
if (!fhandle_) {
TF_RETURN_IF_ERROR(InstantiateFunction(c, handle));
fhandle_ = *handle;
} else {
*handle = fhandle_.value();
}
return Status::OK();
}
// Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
// and the last one must equal 'max_batch_size_'.
Status ValidateAllowedBatchSizes() const {
@ -337,9 +406,11 @@ class BatchFunctionKernel : public AsyncOpKernel {
int32 batch_timeout_micros_;
int32 max_enqueued_batches_;
std::vector<int32> allowed_batch_sizes_;
FunctionLibraryRuntime::Handle fhandle_;
NameAttrList func_;
absl::optional<FunctionLibraryRuntime::Handle> fhandle_ TF_GUARDED_BY(mu_);
bool enable_large_batch_splitting_;
bool has_attribute_enable_large_batch_splitting_;
mutex mu_;
// Parameters for adaptive batch scheduler only.
// Note 'num_batch_threads_' above is shared by two implementations of batch
@ -355,6 +426,14 @@ class BatchFunctionKernel : public AsyncOpKernel {
REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
BatchFunctionKernel);
// Currently all inputs and outputs are on the host.
// TODO(b/173748277): Accept inputs/outputs on the device.
REGISTER_KERNEL_BUILDER(Name("BatchFunction")
.Device(DEVICE_GPU)
.HostMemory("in_tensors")
.HostMemory("captured_tensors")
.HostMemory("out_tensors"),
BatchFunctionKernel);
class BatchKernel : public AsyncOpKernel {
public:

View File

@ -71,8 +71,10 @@ Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
TensorShape output_shape(input_shape);
output_shape.set_dim(0, output_dim0);
TF_RETURN_IF_ERROR(
context->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
AllocatorAttributes attr;
attr.set_on_host(true);
TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<T>::value,
output_shape, output, attr));
if (output->NumElements() > 0) {
auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
@ -167,8 +169,10 @@ Status SplitCPU(OpKernelContext* context, const Tensor& input,
TensorShape output_shape = input.shape();
output_shape.set_dim(0, size);
Tensor output;
AllocatorAttributes attr;
attr.set_on_host(true);
TF_RETURN_IF_ERROR(
context->allocate_temp(input.dtype(), output_shape, &output));
context->allocate_temp(input.dtype(), output_shape, &output, attr));
auto output_shaped = output.shaped<T, 2>({size, suffix_dim_size});
Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{

View File

@ -2829,7 +2829,7 @@ py_library(
],
)
py_test(
cuda_py_test(
name = "batch_ops_test",
size = "small",
srcs = ["ops/batch_ops_test.py"],

View File

@ -25,12 +25,17 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.errors import InvalidArgumentError
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import batch_ops
from tensorflow.python.ops import gen_batch_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -50,7 +55,7 @@ class BatchOpsTest(test.TestCase):
"""Tests that a single batched tensor executes together and only once."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
with self.cached_session(use_gpu=True) as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
@ -92,7 +97,7 @@ class BatchOpsTest(test.TestCase):
"""Test that batching with padding up to an allowed batch size works."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
with self.cached_session(use_gpu=True) as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
@ -124,7 +129,7 @@ class BatchOpsTest(test.TestCase):
"""Tests that multiple batched tensors execute together."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
with self.cached_session(use_gpu=True) as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, _, _ = batch_ops.batch(
@ -165,7 +170,7 @@ class BatchOpsTest(test.TestCase):
"""Tests illegally feeding tensors with different dim0 sizes."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
with self.cached_session(use_gpu=True) as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
@ -181,7 +186,7 @@ class BatchOpsTest(test.TestCase):
"""Tests that batch and unbatch work together."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
with self.cached_session(use_gpu=True) as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
@ -207,7 +212,7 @@ class BatchOpsTest(test.TestCase):
"""Tests that the batch_function decorator works."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
with self.cached_session(use_gpu=True) as sess:
# TODO(apassos): Removing this line causes test flakiness! Ideally should
# be investigated.
default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
@ -235,33 +240,62 @@ class BatchOpsTest(test.TestCase):
"""Tests that the batch_function decorator works."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
with self.cached_session(use_gpu=True) as sess:
captured_inp0 = array_ops.placeholder_with_default(2., shape=[])
captured_inp1 = resource_variable_ops.ResourceVariable(3.)
with ops.device("/cpu:0"):
captured_inp2 = resource_variable_ops.ResourceVariable(4.)
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
return in_t + captured_inp0 - captured_inp1
return in_t + captured_inp0 + captured_inp1 + captured_inp2
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
result = computation(inp)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
sess.run(variables.global_variables_initializer())
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
self.assertEqual(thread_results[0], [10])
self.assertEqual(main_results[0], [11])
@test_util.disable_xla("DeviceIndex returns sentinel value with XLA")
def testBatchDecoratedGpu(self):
if context.executing_eagerly():
return
with self.cached_session(use_gpu=True) as sess:
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
# index is 0 on CPU and 1 on GPU
index = gen_functional_ops.DeviceIndex(device_names=["CPU", "GPU"])
return in_t + math_ops.cast(index, dtypes.float32)
inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
result = computation(inp)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [10.]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [20.]})
worker_thread.join()
self.assertEqual(thread_results[0], [10 + test_util.is_gpu_available()])
self.assertEqual(main_results[0], [20 + test_util.is_gpu_available()])
def testBatchFunctionOp(self):
"""Tests that the batch_function op works."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
with self.cached_session(use_gpu=True) as sess:
@function.Defun(dtypes.int32)
def computation(in_t):
@ -292,7 +326,7 @@ class BatchOpsTest(test.TestCase):
"""Tests that batch_function op works with captured input."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
with self.cached_session(use_gpu=True) as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@ -328,7 +362,7 @@ class BatchOpsTest(test.TestCase):
"""Tests that batch_function op works with error in the inputs."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
with self.cached_session(use_gpu=True) as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@function.Defun(dtypes.int32, dtypes.int32)
@ -345,8 +379,9 @@ class BatchOpsTest(test.TestCase):
captured_tensors=computation.captured_inputs,
Tout=[o.type for o in computation.definition.signature.output_arg])
with self.assertRaisesRegex(InvalidArgumentError,
".*2 arguments.*but 1.*"):
with self.assertRaisesRegex(
InvalidArgumentError,
r"Function takes 2 argument\(s\) but 1 argument\(s\) were passed"):
sess.run([result], feed_dict={inp: [2]})
def testBatchFunctionOpWithLargeBatchSplitted(self):
@ -354,7 +389,7 @@ class BatchOpsTest(test.TestCase):
if context.executing_eagerly():
return
with self.cached_session() as sess:
with self.cached_session(use_gpu=True) as sess:
@function.Defun(dtypes.int32)
def computation(in_t):
@ -408,7 +443,7 @@ class BatchOpsTest(test.TestCase):
"""Tests that the batch_function decorator works."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
with self.cached_session(use_gpu=True) as sess:
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
@ -432,7 +467,7 @@ class BatchOpsTest(test.TestCase):
"""Tests that the unbatch timeout works."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
with self.cached_session(use_gpu=True) as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,