From a00d75168ba8ce52d2e03833871f7d5a2a88ccd3 Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne
Date: Tue, 16 Feb 2021 16:34:03 -0800
Subject: [PATCH] Fix bug in BatchFunction when run on GPU.

Sometimes the outputs would be wrong when the function was run multiple
times in parallel, which caused the testBatchFunctionOpWithLargeBatchSplitted
test to occasionally fail. The test was previously disabled when XLA was
used, but the test would sometimes fail even without XLA.

The issue was that a rendezvous was specified in
FunctionLibraryRuntime::Options when running the function. The rendezvous
used was the Session's rendezvous, so it was shared among all calls to the
function. If the function was run multiple times in parallel, tensors from
one run's Send could be incorrectly delivered to another run's Recv, since
the runs shared the same rendezvous, edge name, and incarnation. Now the
rendezvous is not specified, so each run creates a new rendezvous.

PiperOrigin-RevId: 357831120
Change-Id: I80092fcc5d511e374aebf8ebcba9bce690d1e0fa
---
 tensorflow/core/kernels/batch_kernels.cc |  5 ++++-
 tensorflow/python/ops/batch_ops_test.py  | 28 +++++++++++++++++++++---
 2 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc
index b52b4ab563d..5c0e6cd6524 100644
--- a/tensorflow/core/kernels/batch_kernels.cc
+++ b/tensorflow/core/kernels/batch_kernels.cc
@@ -160,9 +160,12 @@ class BatchResource : public serving::BatchResourceBase {
     opts.cancellation_manager = last_task_context->cancellation_manager();
     opts.collective_executor = last_task_context->collective_executor();
     opts.stats_collector = last_task_context->stats_collector();
-    opts.rendezvous = last_task_context->rendezvous();
     opts.runner = last_task_context->runner();
     opts.run_all_kernels_inline = last_task_context->run_all_kernels_inline();
+    // We do not set 'opts.rendezvous', since if the function is run multiple
+    // times in parallel with the same rendezvous, a _Send node from one run
+    // might be matched with a _Recv node of a different run. Not setting the
+    // rendezvous causes a new rendezvous to be used for each run.
 
     Notification done_notif;
     flib_->Run(opts, fhandle_, inputs, combined_outputs,
diff --git a/tensorflow/python/ops/batch_ops_test.py b/tensorflow/python/ops/batch_ops_test.py
index e54e69a2366..b63c4c8b5f1 100644
--- a/tensorflow/python/ops/batch_ops_test.py
+++ b/tensorflow/python/ops/batch_ops_test.py
@@ -292,6 +292,31 @@ class BatchOpsTest(test.TestCase):
       self.assertEqual(thread_results[0], [10 + test_util.is_gpu_available()])
       self.assertEqual(main_results[0], [20 + test_util.is_gpu_available()])
 
+  def testParallelRunsWithCpuAndGpu(self):
+    # Run multiple instances of a batch function in parallel. This is a
+    # regression test: this used to fail because _Send nodes for one call would
+    # send the tensor to the _Recv node for a different call.
+    if context.executing_eagerly():
+      return
+    @batch_ops.batch_function(1, 2, 1)
+    def f(x):
+      with ops.device("/GPU:0"):
+        x = x + 1.
+      with ops.device("/CPU:0"):
+        return x + 1
+    num_calls = 10
+    placeholders = [array_ops.placeholder(dtypes.float32, shape=(1,))
+                    for _ in range(num_calls)]
+    results = []
+    for p in placeholders:
+      (result,) = f(p)
+      results.append(result)
+    inputs = [[float(i)] for i in range(num_calls)]
+    expected = [[float(i + 2)] for i in range(num_calls)]
+    with self.session() as sess:
+      outputs = sess.run(results, feed_dict=dict(zip(placeholders, inputs)))
+    self.assertAllEqual(outputs, expected)
+
   def testSoftPlacement(self):
     if context.executing_eagerly():
       return
@@ -415,9 +440,6 @@ class BatchOpsTest(test.TestCase):
 
   def testBatchFunctionOpWithLargeBatchSplitted(self):
     """Tests that the batch_function op works with large batch splitted."""
-    if test_util.is_xla_enabled():
-      self.skipTest("b/178649404")
-
     if context.executing_eagerly():
       return
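
The failure mode described in the commit message can be illustrated without TensorFlow. The sketch below is a plain-Python analogy, not code from this patch: a "rendezvous" shared by two concurrent runs and keyed only by an edge name lets a value sent by one run be received by the other, which is the cross-talk avoided by letting each function run create its own rendezvous. The SharedRendezvous class, the run helper, and the edge-name string are all invented for illustration.

# Plain-Python analogy (not TensorFlow code): a toy rendezvous shared by two
# concurrent runs, keyed only by edge name. Because both runs use the same
# key, a recv in one run can consume the value sent by the other run.
import threading

class SharedRendezvous:
  """Toy stand-in for a rendezvous: a map of edge name -> queued values."""

  def __init__(self):
    self._slots = {}
    self._cv = threading.Condition()

  def send(self, edge_name, value):
    with self._cv:
      self._slots.setdefault(edge_name, []).append(value)
      self._cv.notify_all()

  def recv(self, edge_name):
    with self._cv:
      while not self._slots.get(edge_name):
        self._cv.wait()
      return self._slots[edge_name].pop(0)

shared = SharedRendezvous()
results = {}

def run(run_id):
  # Both runs use the same edge name, just as two parallel executions of the
  # same function graph would share edge name and incarnation.
  shared.send("edge_0;gpu_to_cpu", "tensor from run %d" % run_id)
  results[run_id] = shared.recv("edge_0;gpu_to_cpu")

threads = [threading.Thread(target=run, args=(i,)) for i in range(2)]
for t in threads:
  t.start()
for t in threads:
  t.join()

# With a shared rendezvous, results[0] may hold run 1's value (and vice
# versa). Giving each run its own rendezvous removes that ambiguity.
print(results)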