Support GPUs in BatchFunction.
Now multiple devices are supported within BatchFunction. Currently, inputs and outputs must still be on the CPU, as the concatenation/splitting is done on the CPU.

PiperOrigin-RevId: 347524478
Change-Id: Ib329987bf09513570c3c260e4c0834d6102a4364

parent 85f554c486
commit ba28d6de31
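The host-only input/output contract described above is enforced at kernel-registration time: the GPU kernel added in this change pins every tensor list to host memory via .HostMemory(...) (see the registration hunk further down). As a self-contained illustration of that mechanism, here is a hedged sketch of a hypothetical op registered the same way; the op name HostEchoOp is invented for illustration and is not part of this change.

// Sketch only: a hypothetical op whose GPU kernel keeps its data in host
// memory, mirroring how BatchFunction is registered in this change.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical op: copies its (host-resident) input to its output.
REGISTER_OP("HostEchoOp")
    .Input("in: float")
    .Output("out: float");

class HostEchoOp : public OpKernel {
 public:
  explicit HostEchoOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    // Because of the HostMemory constraints below, "in" and "out" live in
    // host memory even when this kernel is placed on a GPU device.
    c->set_output(0, c->input(0));
  }
};

// The CPU registration is unconstrained; the GPU registration pins both the
// input and the output to host memory, like the BatchFunction registration
// in this commit.
REGISTER_KERNEL_BUILDER(Name("HostEchoOp").Device(DEVICE_CPU), HostEchoOp);
REGISTER_KERNEL_BUILDER(Name("HostEchoOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("in")
                            .HostMemory("out"),
                        HostEchoOp);

}  // namespace tensorflow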
@@ -1426,6 +1426,10 @@ void ProcessFunctionLibraryRuntime::Run(
                                     InternalArgs* comp_args) -> Status {
     // "Index"s of _Arg nodes are unique when all arguments are local Tensors.
     for (const auto& it : comp_data.arg_indices) {
+      if (it.index >= args.size()) {
+        return errors::InvalidArgument(
+            "index ", it.index, " is out of range [0, ", args.size(), ")");
+      }
       if (it.sub_index >= 0) {
         const Tensor& t = args[it.index];
         if (t.dtype() != DT_RESOURCE) {
@@ -632,6 +632,7 @@ cc_library(
     srcs = ["batch_kernels.cc"],
     deps = [
         ":ops_util_hdrs",
+        "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/

 #include "absl/strings/str_cat.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/device.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
@@ -187,12 +189,7 @@ class BatchFunctionKernel : public AsyncOpKernel {
                       c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
     OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));

-    auto lib = c->function_library();
-    OP_REQUIRES(c, lib != nullptr, errors::Internal("No function library"));
-    NameAttrList func;
-    OP_REQUIRES_OK(c, c->GetAttr("f", &func));
-    OP_REQUIRES_OK(
-        c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
+    OP_REQUIRES_OK(c, c->GetAttr("f", &func_));
     if (num_batch_threads_ <= 0) {
       adaptive_batch_scheduler_options_ =
           absl::make_optional(AdaptiveBatchSchedulerOptions{
@@ -242,8 +239,11 @@ class BatchFunctionKernel : public AsyncOpKernel {

     std::function<Status(BatchResource**)> creator;

+    FunctionLibraryRuntime::Handle handle;
+    OP_REQUIRES_OK_ASYNC(c, GetOrCreateFunctionHandle(c, &handle), done);
+
     if (adaptive_batch_scheduler_options_ != absl::nullopt) {
-      creator = [this](BatchResource** r) {
+      creator = [this, handle](BatchResource** r) {
         serving::AdaptiveSharedBatchScheduler<
             serving::BatchResourceBase::BatchTask>::Options
             adaptive_shared_batch_scheduler_options;
@@ -274,16 +274,16 @@ class BatchFunctionKernel : public AsyncOpKernel {
         TF_RETURN_IF_ERROR(BatchResource::Create(
             adaptive_shared_batch_scheduler_options, max_batch_size_,
             batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_,
-            fhandle_, &new_resource));
+            handle, &new_resource));
         *r = new_resource.release();
         return Status::OK();
       };
     } else {
-      creator = [this](BatchResource** r) {
+      creator = [this, handle](BatchResource** r) {
         std::unique_ptr<BatchResource> new_resource;
         TF_RETURN_IF_ERROR(BatchResource::Create(
             num_batch_threads_, max_batch_size_, batch_timeout_micros_,
-            max_enqueued_batches_, allowed_batch_sizes_, fhandle_,
+            max_enqueued_batches_, allowed_batch_sizes_, handle,
             enable_large_batch_splitting_, &new_resource));
         *r = new_resource.release();
         return Status::OK();
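Note that the creator callbacks above now capture the freshly resolved handle by value ([this, handle]) instead of reading a member set at construction time. In the kernel this creator is handed to the resource manager so the BatchResource is built at most once per shared name (the lookup itself is outside this hunk). Below is a minimal, hedged sketch of that capture-by-value creator pattern with ResourceMgr::LookupOrCreate; DemoResource and the container/name strings are invented for illustration.

// Sketch only: a creator callback that captures state by value; the resource
// manager runs it at most once per (container, name) pair.
#include <functional>

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Hypothetical resource wrapping a cached handle (illustrative only).
class DemoResource : public ResourceBase {
 public:
  explicit DemoResource(int handle) : handle_(handle) {}
  std::string DebugString() const override { return "DemoResource"; }
  int handle() const { return handle_; }

 private:
  const int handle_;
};

Status GetOrCreateDemoResource(ResourceMgr* rm, int handle,
                               DemoResource** out) {
  // Capture `handle` by value, mirroring the [this, handle] captures above:
  // the resource remembers the handle that was current when the creator was
  // built, not whatever a member variable holds later.
  std::function<Status(DemoResource**)> creator = [handle](DemoResource** r) {
    *r = new DemoResource(handle);
    return Status::OK();
  };
  return rm->LookupOrCreate<DemoResource>("demo_container", "demo_resource",
                                          out, creator);
}

}  // namespace tensorflow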
@@ -302,6 +302,75 @@ class BatchFunctionKernel : public AsyncOpKernel {
     // Assume br calls done, so nothing to do here.
   }

+  Status InstantiateFunction(OpKernelContext* c,
+                             FunctionLibraryRuntime::Handle* handle) const {
+    // TODO(b/173748062): Merge this instantiation logic with PartitionedCall.
+    FunctionLibraryRuntime* lib = c->function_library();
+    if (!lib) {
+      return errors::Internal("No function library");
+    }
+
+    FunctionLibraryRuntime::InstantiateOptions opts;
+    opts.target = lib->device() == nullptr ? "" : lib->device()->name();
+    opts.is_multi_device_function = true;
+
+    Device* cpu_device;
+    TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
+
+    const FunctionDef* fdef =
+        lib->GetFunctionLibraryDefinition()->Find(func_.name());
+    if (!fdef) {
+      return errors::NotFound("Failed to find definition for function \"",
+                              func_.name(), "\"");
+    }
+    OpInputList in_tensors;
+    TF_RETURN_IF_ERROR(c->input_list("in_tensors", &in_tensors));
+    for (int i = 0; i < in_tensors.size(); i++) {
+      if (in_tensors[i].dtype() == DT_RESOURCE) {
+        return errors::InvalidArgument(
+            "BatchFunction cannot take resource inputs but input ", i,
+            " is a resource.");
+      } else {
+        // Currently, inputs are on CPU since they are concatenated on CPU
+        opts.input_devices.push_back(cpu_device->name());
+      }
+    }
+    OpInputList captured_tensors;
+    TF_RETURN_IF_ERROR(c->input_list("captured_tensors", &captured_tensors));
+    for (const Tensor& t : captured_tensors) {
+      if (t.dtype() == DT_RESOURCE) {
+        const ResourceHandle& rhandle = t.flat<ResourceHandle>()(0);
+        opts.input_devices.push_back(rhandle.device());
+      } else {
+        opts.input_devices.push_back(cpu_device->name());
+      }
+    }
+    const OpDef& signature = fdef->signature();
+    for (int i = 0; i < signature.output_arg_size(); i++) {
+      // Currently, outputs must be on CPU since they are split on CPU.
+      opts.output_devices.push_back(cpu_device->name());
+    }
+    if (opts.input_devices.size() != signature.input_arg_size()) {
+      return errors::InvalidArgument(
+          "Function takes ", signature.input_arg_size(), " argument(s) but ",
+          opts.input_devices.size(), " argument(s) were passed");
+    }
+    return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts,
+                            handle);
+  }
+
+  Status GetOrCreateFunctionHandle(OpKernelContext* c,
+                                   FunctionLibraryRuntime::Handle* handle) {
+    mutex_lock ml(mu_);
+    if (!fhandle_) {
+      TF_RETURN_IF_ERROR(InstantiateFunction(c, handle));
+      fhandle_ = *handle;
+    } else {
+      *handle = fhandle_.value();
+    }
+    return Status::OK();
+  }
+
   // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
   // and the last one must equal 'max_batch_size_'.
   Status ValidateAllowedBatchSizes() const {
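GetOrCreateFunctionHandle above memoizes the instantiated handle: the first ComputeAsync call pays for instantiation (with data inputs and outputs pinned to the CPU device and captured resources pinned to their own devices), and later calls reuse the cached handle under the mutex. A framework-free sketch of the same memoization pattern follows; the names are illustrative, not TensorFlow APIs.

// Sketch only: lazily compute a value once and cache it under a mutex,
// mirroring the absl::optional<Handle> + mutex pattern used by
// GetOrCreateFunctionHandle above.
#include <functional>
#include <mutex>
#include <optional>
#include <utility>

class LazyHandle {
 public:
  explicit LazyHandle(std::function<int()> instantiate)
      : instantiate_(std::move(instantiate)) {}

  // Returns the cached handle, instantiating it on the first call only.
  int GetOrCreate() {
    std::lock_guard<std::mutex> lock(mu_);
    if (!handle_.has_value()) {
      handle_ = instantiate_();  // Expensive: runs at most once.
    }
    return handle_.value();
  }

 private:
  std::function<int()> instantiate_;
  std::mutex mu_;
  std::optional<int> handle_;  // Unset until the first GetOrCreate().
};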
@@ -337,9 +406,11 @@ class BatchFunctionKernel : public AsyncOpKernel {
   int32 batch_timeout_micros_;
   int32 max_enqueued_batches_;
   std::vector<int32> allowed_batch_sizes_;
-  FunctionLibraryRuntime::Handle fhandle_;
+  NameAttrList func_;
+  absl::optional<FunctionLibraryRuntime::Handle> fhandle_ TF_GUARDED_BY(mu_);
   bool enable_large_batch_splitting_;
   bool has_attribute_enable_large_batch_splitting_;
+  mutex mu_;

   // Parameters for adaptive batch scheduler only.
   // Note 'num_batch_threads_' above is shared by two implementations of batch
@@ -355,6 +426,14 @@ class BatchFunctionKernel : public AsyncOpKernel {

 REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
                         BatchFunctionKernel);
+// Currently all inputs and outputs are on the host.
+// TODO(b/173748277): Accept inputs/outputs on the device.
+REGISTER_KERNEL_BUILDER(Name("BatchFunction")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("in_tensors")
+                            .HostMemory("captured_tensors")
+                            .HostMemory("out_tensors"),
+                        BatchFunctionKernel);

 class BatchKernel : public AsyncOpKernel {
  public:
@@ -71,8 +71,10 @@ Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,

   TensorShape output_shape(input_shape);
   output_shape.set_dim(0, output_dim0);
-  TF_RETURN_IF_ERROR(
-      context->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
+  AllocatorAttributes attr;
+  attr.set_on_host(true);
+  TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<T>::value,
+                                            output_shape, output, attr));
   if (output->NumElements() > 0) {
     auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
@@ -167,8 +169,10 @@ Status SplitCPU(OpKernelContext* context, const Tensor& input,
   TensorShape output_shape = input.shape();
   output_shape.set_dim(0, size);
   Tensor output;
+  AllocatorAttributes attr;
+  attr.set_on_host(true);
   TF_RETURN_IF_ERROR(
-      context->allocate_temp(input.dtype(), output_shape, &output));
+      context->allocate_temp(input.dtype(), output_shape, &output, attr));
   auto output_shaped = output.shaped<T, 2>({size, suffix_dim_size});

   Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{
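Both allocation sites above switch allocate_temp to an AllocatorAttributes with set_on_host(true), so the concat/split scratch tensors stay in host memory even when the kernel is placed on a GPU. A hedged sketch of that allocation pattern in isolation; the helper name AllocateHostTemp is invented for illustration.

// Sketch only: allocate a temporary tensor that is guaranteed to live in
// host memory, regardless of the device the kernel is placed on.
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

// Hypothetical helper mirroring the Concat/SplitCPU changes above.
Status AllocateHostTemp(OpKernelContext* context, DataType dtype,
                        const TensorShape& shape, Tensor* out) {
  AllocatorAttributes attr;
  attr.set_on_host(true);  // Request host (CPU-accessible) memory.
  // Without `attr`, allocate_temp would use the device allocator, which on a
  // GPU kernel would hand back device memory that host-side
  // concatenation/splitting code cannot touch directly.
  return context->allocate_temp(dtype, shape, out, attr);
}

}  // namespace tensorflow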
@@ -2829,7 +2829,7 @@ py_library(
     ],
 )

-py_test(
+cuda_py_test(
     name = "batch_ops_test",
     size = "small",
     srcs = ["ops/batch_ops_test.py"],
@@ -25,12 +25,17 @@ import numpy as np
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.framework.errors import InvalidArgumentError
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import batch_ops
 from tensorflow.python.ops import gen_batch_ops
+from tensorflow.python.ops import gen_functional_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import script_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test


@@ -50,7 +55,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that a single batched tensor executes together and only once."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, index, _ = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=2,
@@ -92,7 +97,7 @@ class BatchOpsTest(test.TestCase):
     """Test that batching with padding up to an allowed batch size works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
       batched, index, _ = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=10,
@@ -124,7 +129,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that multiple batched tensors execute together."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, _, _ = batch_ops.batch(
@@ -165,7 +170,7 @@ class BatchOpsTest(test.TestCase):
     """Tests illegally feeding tensors with different dim0 sizes."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
       batched, index, _ = batch_ops.batch(
@@ -181,7 +186,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that batch and unbatch work together."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, index, id_t = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=10,
@@ -207,7 +212,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that the batch_function decorator works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       # TODO(apassos): Removing this line causes test flakiness! Ideally should
       # be investigated.
       default_inp = array_ops.placeholder_with_default(2, shape=[])  # pylint: disable=unused-variable
@@ -235,33 +240,62 @@ class BatchOpsTest(test.TestCase):
     """Tests that the batch_function decorator works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
-      captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
-      captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
+    with self.cached_session(use_gpu=True) as sess:
+      captured_inp0 = array_ops.placeholder_with_default(2., shape=[])
+      captured_inp1 = resource_variable_ops.ResourceVariable(3.)
+      with ops.device("/cpu:0"):
+        captured_inp2 = resource_variable_ops.ResourceVariable(4.)

       @batch_ops.batch_function(1, 10, 100000)
       def computation(in_t):
-        return in_t + captured_inp0 - captured_inp1
+        return in_t + captured_inp0 + captured_inp1 + captured_inp2

-      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+      inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
       result = computation(inp)
       thread_results = []

       def worker():
         thread_results.extend(sess.run([result], feed_dict={inp: [1]}))

+      sess.run(variables.global_variables_initializer())
       worker_thread = threading.Thread(target=worker)
       worker_thread.start()
       main_results = sess.run([result], feed_dict={inp: [2]})
       worker_thread.join()
-      self.assertEqual(thread_results[0], [2])
-      self.assertEqual(main_results[0], [3])
+      self.assertEqual(thread_results[0], [10])
+      self.assertEqual(main_results[0], [11])
+
+  @test_util.disable_xla("DeviceIndex returns sentinel value with XLA")
+  def testBatchDecoratedGpu(self):
+    if context.executing_eagerly():
+      return
+    with self.cached_session(use_gpu=True) as sess:
+
+      @batch_ops.batch_function(1, 10, 100000)
+      def computation(in_t):
+        # index is 0 on CPU and 1 on GPU
+        index = gen_functional_ops.DeviceIndex(device_names=["CPU", "GPU"])
+        return in_t + math_ops.cast(index, dtypes.float32)
+
+      inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
+      result = computation(inp)
+      thread_results = []
+
+      def worker():
+        thread_results.extend(sess.run([result], feed_dict={inp: [10.]}))
+
+      worker_thread = threading.Thread(target=worker)
+      worker_thread.start()
+      main_results = sess.run([result], feed_dict={inp: [20.]})
+      worker_thread.join()
+      self.assertEqual(thread_results[0], [10 + test_util.is_gpu_available()])
+      self.assertEqual(main_results[0], [20 + test_util.is_gpu_available()])

   def testBatchFunctionOp(self):
     """Tests that the batch_function op works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:

       @function.Defun(dtypes.int32)
       def computation(in_t):
@@ -292,7 +326,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that batch_function op works with captured input."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
       captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@@ -328,7 +362,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that batch_function op works with error in the inputs."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])

       @function.Defun(dtypes.int32, dtypes.int32)
@@ -345,8 +379,9 @@ class BatchOpsTest(test.TestCase):
           captured_tensors=computation.captured_inputs,
           Tout=[o.type for o in computation.definition.signature.output_arg])

-      with self.assertRaisesRegex(InvalidArgumentError,
-                                  ".*2 arguments.*but 1.*"):
+      with self.assertRaisesRegex(
+          InvalidArgumentError,
+          r"Function takes 2 argument\(s\) but 1 argument\(s\) were passed"):
         sess.run([result], feed_dict={inp: [2]})

   def testBatchFunctionOpWithLargeBatchSplitted(self):
@@ -354,7 +389,7 @@ class BatchOpsTest(test.TestCase):
     if context.executing_eagerly():
       return

-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:

       @function.Defun(dtypes.int32)
       def computation(in_t):
@@ -408,7 +443,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that the batch_function decorator works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:

       @batch_ops.batch_function(1, 10, 100000)
       def computation(in_t):
@@ -432,7 +467,7 @@ class BatchOpsTest(test.TestCase):
     """Tests that the unbatch timeout works."""
     if context.executing_eagerly():
       return
-    with self.cached_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, index, id_t = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=2,