Fix bug in BatchFunction when run on GPU.
Sometimes the outputs would be wrong when the function was run multiple times in parallel, which caused the testBatchFunctionOpWithLargeBatchSplitted test to occasionally fail. The test was previously disabled when XLA was used, but the test would sometimes fail even without XLA. The issue was a rendezvous was specified in FunctionLibraryRuntime::Options when running the function. The rendezvous used was the Session's rendezvous, so it was shared among all calls to the function. If the function was run multiple times in parallel, tensors from one run's Send may be incorrectly sent to another run's Recv since they share the same rendezvous, edge name, and incarnation. Now the rendezvous is not specified, causing each run to create a new rendezvous. PiperOrigin-RevId: 357831120 Change-Id: I80092fcc5d511e374aebf8ebcba9bce690d1e0fa
This commit is contained in:
parent
52ecc01520
commit
a00d75168b
@ -160,9 +160,12 @@ class BatchResource : public serving::BatchResourceBase {
|
||||
opts.cancellation_manager = last_task_context->cancellation_manager();
|
||||
opts.collective_executor = last_task_context->collective_executor();
|
||||
opts.stats_collector = last_task_context->stats_collector();
|
||||
opts.rendezvous = last_task_context->rendezvous();
|
||||
opts.runner = last_task_context->runner();
|
||||
opts.run_all_kernels_inline = last_task_context->run_all_kernels_inline();
|
||||
// We do not set 'opts.rendezvous', since if the function is run multiple
|
||||
// times in parallel with the same rendezvous, a _Send node from one run
|
||||
// might be matched with a _Recv node of a different run. Not setting the
|
||||
// rendezvous causes a new rendezvous to be used for each run.
|
||||
Notification done_notif;
|
||||
|
||||
flib_->Run(opts, fhandle_, inputs, combined_outputs,
|
||||
|
@ -292,6 +292,31 @@ class BatchOpsTest(test.TestCase):
|
||||
self.assertEqual(thread_results[0], [10 + test_util.is_gpu_available()])
|
||||
self.assertEqual(main_results[0], [20 + test_util.is_gpu_available()])
|
||||
|
||||
def testParallelRunsWithCpuAndGpu(self):
|
||||
# Run multiple instances of a batch function in parallel. This is a
|
||||
# regression test: this used to fail because _Send nodes for one call would
|
||||
# send the tensor to the _Recv node for a different call.
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
@batch_ops.batch_function(1, 2, 1)
|
||||
def f(x):
|
||||
with ops.device("/GPU:0"):
|
||||
x = x + 1.
|
||||
with ops.device("/CPU:0"):
|
||||
return x + 1
|
||||
num_calls = 10
|
||||
placeholders = [array_ops.placeholder(dtypes.float32, shape=(1,))
|
||||
for _ in range(num_calls)]
|
||||
results = []
|
||||
for p in placeholders:
|
||||
(result,) = f(p)
|
||||
results.append(result)
|
||||
inputs = [[float(i)] for i in range(num_calls)]
|
||||
expected = [[float(i + 2)] for i in range(num_calls)]
|
||||
with self.session() as sess:
|
||||
outputs = sess.run(results, feed_dict=dict(zip(placeholders, inputs)))
|
||||
self.assertAllEqual(outputs, expected)
|
||||
|
||||
def testSoftPlacement(self):
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
@ -415,9 +440,6 @@ class BatchOpsTest(test.TestCase):
|
||||
|
||||
def testBatchFunctionOpWithLargeBatchSplitted(self):
|
||||
"""Tests that the batch_function op works with large batch splitted."""
|
||||
if test_util.is_xla_enabled():
|
||||
self.skipTest("b/178649404")
|
||||
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user