diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index b52b4ab563d..5c0e6cd6524 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -160,9 +160,12 @@ class BatchResource : public serving::BatchResourceBase { opts.cancellation_manager = last_task_context->cancellation_manager(); opts.collective_executor = last_task_context->collective_executor(); opts.stats_collector = last_task_context->stats_collector(); - opts.rendezvous = last_task_context->rendezvous(); opts.runner = last_task_context->runner(); opts.run_all_kernels_inline = last_task_context->run_all_kernels_inline(); + // We do not set 'opts.rendezvous', since if the function is run multiple + // times in parallel with the same rendezvous, a _Send node from one run + // might be matched with a _Recv node of a different run. Not setting the + // rendezvous causes a new rendezvous to be used for each run. Notification done_notif; flib_->Run(opts, fhandle_, inputs, combined_outputs, diff --git a/tensorflow/python/ops/batch_ops_test.py b/tensorflow/python/ops/batch_ops_test.py index e54e69a2366..b63c4c8b5f1 100644 --- a/tensorflow/python/ops/batch_ops_test.py +++ b/tensorflow/python/ops/batch_ops_test.py @@ -292,6 +292,31 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [10 + test_util.is_gpu_available()]) self.assertEqual(main_results[0], [20 + test_util.is_gpu_available()]) + def testParallelRunsWithCpuAndGpu(self): + # Run multiple instances of a batch function in parallel. This is a + # regression test: this used to fail because _Send nodes for one call would + # send the tensor to the _Recv node for a different call. + if context.executing_eagerly(): + return + @batch_ops.batch_function(1, 2, 1) + def f(x): + with ops.device("/GPU:0"): + x = x + 1. + with ops.device("/CPU:0"): + return x + 1 + num_calls = 10 + placeholders = [array_ops.placeholder(dtypes.float32, shape=(1,)) + for _ in range(num_calls)] + results = [] + for p in placeholders: + (result,) = f(p) + results.append(result) + inputs = [[float(i)] for i in range(num_calls)] + expected = [[float(i + 2)] for i in range(num_calls)] + with self.session() as sess: + outputs = sess.run(results, feed_dict=dict(zip(placeholders, inputs))) + self.assertAllEqual(outputs, expected) + def testSoftPlacement(self): if context.executing_eagerly(): return @@ -415,9 +440,6 @@ class BatchOpsTest(test.TestCase): def testBatchFunctionOpWithLargeBatchSplitted(self): """Tests that the batch_function op works with large batch splitted.""" - if test_util.is_xla_enabled(): - self.skipTest("b/178649404") - if context.executing_eagerly(): return