From e38610399a1293e5f8f1b1c290c355149bed1418 Mon Sep 17 00:00:00 2001
From: Ayush Dubey <ayushd@google.com>
Date: Thu, 17 Dec 2020 10:40:16 -0800
Subject: [PATCH] Add CollectiveBcastSend/RecvV2, which take input tensors
 rather than attributes.

CollectiveBcastSend/Recv accept the following inputs as attributes on the op:
group_size, group_key, and instance_key.  Because they are attributes, these
values are embedded in the NodeDef at graph construction time.

The use case motivating this change is a compact representation for SPMD
computation.  The goal is to change those inputs to the collective op, which
can be provided at runtime, from attributes to tensors.  This enables the
graph builder to avoid early explosion of the SPMD program.

This op is not exposed in the `tf.` namespace for now, and should be considered
experimental.
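
For illustration, a minimal sketch of how the new Python wrappers
`broadcast_send_v2` and `broadcast_recv_v2` (added in
tensorflow/python/ops/collective_ops.py by this change) can be fed the group
and instance keys as runtime tensors, e.g. via `tf.identity`.  The device
names and key values below are arbitrary placeholders, and two local devices
are assumed to be configured.

    import tensorflow as tf
    from tensorflow.python.ops import collective_ops

    @tf.function
    def broadcast_pair():
      # Passing the keys through tf.identity yields runtime tensors rather
      # than NodeDef attributes, which is what the V2 ops accept.
      group_size = tf.identity(2)
      group_key = tf.identity(100)
      instance_key = tf.identity(100)
      value = tf.constant([1., 2., 3.])
      results = []
      with tf.device('/device:CPU:0'):
        results.append(
            collective_ops.broadcast_send_v2(
                value, group_size, group_key, instance_key))
      with tf.device('/device:CPU:1'):
        results.append(
            collective_ops.broadcast_recv_v2(
                tf.shape(value), value.dtype, group_size, group_key,
                instance_key))
      return results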

PiperOrigin-RevId: 348049802
Change-Id: I15ffcb562fb75cc64e5970ac7795cb81ff187fd8
---
 .../api_def_CollectiveBcastRecvV2.pbtxt       |   5 +
 .../api_def_CollectiveBcastSendV2.pbtxt       |   5 +
 .../grappler/optimizers/function_optimizer.cc |   3 +-
 tensorflow/core/kernels/collective_ops.cc     | 257 ++++++++++++++++++
 tensorflow/core/ops/collective_ops.cc         |  31 +++
 .../python/framework/auto_control_deps.py     |   2 +
 .../kernel_tests/collective_ops_test.py       |  57 ++++
 tensorflow/python/ops/collective_ops.py       |  72 +++++
 .../api/golden/v1/tensorflow.raw_ops.pbtxt    |   8 +
 .../api/golden/v2/tensorflow.raw_ops.pbtxt    |   8 +
 10 files changed, 447 insertions(+), 1 deletion(-)
 create mode 100644 tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecvV2.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_CollectiveBcastSendV2.pbtxt

diff --git a/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecvV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecvV2.pbtxt
new file mode 100644
index 00000000000..71148889b2b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecvV2.pbtxt
@@ -0,0 +1,5 @@
+op {
+  graph_op_name: "CollectiveBcastRecvV2"
+  summary: "Receives a tensor value broadcast from another device."
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSendV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSendV2.pbtxt
new file mode 100644
index 00000000000..8d0ccedcedd
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSendV2.pbtxt
@@ -0,0 +1,5 @@
+op {
+  graph_op_name: "CollectiveBcastSendV2"
+  summary: "Broadcasts a tensor value to one or more other devices."
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index ad9f3d5ed8f..96c533994a8 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -830,7 +830,8 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
        // to run asynchronously to avoid deadlock.
        "CollectiveGather", "CollectiveGatherV2", "CollectiveReduce",
        "CollectiveReduceV2", "CollectiveBcastSend", "CollectiveBcastRecv",
-       "NcclAllReduce", "Send", "Recv",
+       "CollectiveBcastSendV2", "CollectiveBcastRecvV2", "NcclAllReduce",
+       "Send", "Recv",
 
        // Legacy random ops.
        // See details in tensorflow/python/framework/auto_control_deps.py.
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 357ae158ea1..f8f18751d86 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -742,5 +743,261 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2")
                             .HostMemory("instance_key"),
                         CollectiveGatherV2OpKernel);
 
+class CollectiveBcastSendV2OpKernel : public AsyncOpKernel {
+ public:
+  explicit CollectiveBcastSendV2OpKernel(OpKernelConstruction* c)
+      : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
+    OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
+    OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
+    OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
+    const bool is_source = true;
+    name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
+  }
+
+ protected:
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    CollectiveExecutor* col_exec = c->collective_executor();
+    OP_REQUIRES_ASYNC(
+        c, col_exec,
+        errors::Internal(
+            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+            name_),
+        done);
+    const Tensor& input = c->input(0);
+    const Tensor& group_size = c->input(1);
+    const Tensor& group_key = c->input(2);
+    const Tensor& instance_key = c->input(3);
+    OP_REQUIRES_ASYNC(
+        c, group_size.dims() == 0,
+        errors::Internal("Unexpected dimensions on input group_size"), done);
+    OP_REQUIRES_ASYNC(
+        c, group_key.dims() == 0,
+        errors::Internal("Unexpected dimensions on input group_key"), done);
+    OP_REQUIRES_ASYNC(
+        c, instance_key.dims() == 0,
+        errors::Internal("Unexpected dimensions on input instance_key"), done);
+
+    auto col_params = new CollectiveParams();
+    col_params->name = name_;
+    col_params->group.device_type = device_type_;
+    col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
+    col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
+    col_params->instance.type = BROADCAST_COLLECTIVE;
+    col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
+    col_params->instance.data_type = data_type_;
+    col_params->instance.impl_details.communication_hint = communication_hint_;
+    col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
+    col_params->is_source = true;
+    // Add a default value for subdiv offsets, which is the same as the default
+    // value in the V1 op's attribute.
+    col_params->instance.impl_details.subdiv_offsets.push_back(0);
+    VLOG(1) << "CollectiveBcastSendV2 group_size "
+            << col_params->group.group_size << " group_key "
+            << col_params->group.group_key << " instance_key "
+            << col_params->instance.instance_key;
+
+    auto done_with_cleanup = [col_params, done = std::move(done)]() {
+      delete col_params;
+      done();
+    };
+
+    // Allocate the output tensor, trying to reuse the input.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK_ASYNC(
+        c, c->forward_input_or_allocate_output({0}, 0, input.shape(), &output),
+        done_with_cleanup);
+    col_params->instance.shape = input.shape();
+
+    // Resolve the collective params.
+    // Schedule the `CompleteParamsAsync` call on a work queue that can handle
+    // blocking work, because `CompleteParamsAsync` may block.
+    c->collective_executor()->RunClosure([c,
+                                          done = std::move(done_with_cleanup),
+                                          col_params, col_exec]() {
+      VLOG(1) << "CollectiveBcastSendV2 CompleteParams for collective "
+              << col_params->name << " device " << c->device()->name()
+              << " group " << col_params->group.group_key << " instance "
+              << col_params->instance.instance_key;
+      col_exec->CompleteParamsAsync(
+          c->device()->attributes(), col_params, c->cancellation_manager(),
+          [c, done = std::move(done), col_params, col_exec](const Status& s) {
+            if (s.ok()) {
+              auto actual_done = [c, group_key = col_params->group.group_key,
+                                  instance_key =
+                                      col_params->instance.instance_key,
+                                  done = std::move(done)](const Status& s) {
+                VLOG(1) << "CollectiveBcastSendV2 ExecuteAsync done for "
+                           "collective "
+                        << c->op_kernel().name() << " device "
+                        << c->device()->name() << " group " << group_key
+                        << " instance " << instance_key << " status " << s;
+                OP_REQUIRES_OK_ASYNC(c, s, done);
+                done();
+              };
+              VLOG(1) << "CollectiveBcastSendV2 ExecuteAsync start for "
+                         "collective "
+                      << col_params->name << " device " << c->device()->name()
+                      << " group " << col_params->group.group_key
+                      << " instance " << col_params->instance.instance_key;
+              col_exec->ExecuteAsync(
+                  c, *col_params,
+                  CollectiveKey(c, col_params->group.group_key,
+                                col_params->instance.instance_key),
+                  actual_done);
+            } else {
+              c->SetStatus(s);
+              done();
+            }
+          });
+    });
+  }
+
+ private:
+  DeviceType device_type_;
+  DataType data_type_ = DT_INVALID;
+  string communication_hint_;
+  float timeout_seconds_ = 0;
+  string name_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2").Device(DEVICE_CPU),
+                        CollectiveBcastSendV2OpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("group_size")
+                            .HostMemory("group_key")
+                            .HostMemory("instance_key"),
+                        CollectiveBcastSendV2OpKernel);
+
+class CollectiveBcastRecvV2OpKernel : public AsyncOpKernel {
+ public:
+  explicit CollectiveBcastRecvV2OpKernel(OpKernelConstruction* c)
+      : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
+    OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
+    OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
+    OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
+    const bool is_source = false;
+    name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
+  }
+
+ protected:
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    CollectiveExecutor* col_exec = c->collective_executor();
+    OP_REQUIRES_ASYNC(
+        c, col_exec,
+        errors::Internal(
+            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+            name_),
+        done);
+    const Tensor& group_size = c->input(0);
+    const Tensor& group_key = c->input(1);
+    const Tensor& instance_key = c->input(2);
+    const Tensor& shape_tensor = c->input(3);
+    OP_REQUIRES_ASYNC(
+        c, group_size.dims() == 0,
+        errors::Internal("Unexpected dimensions on input group_size"), done);
+    OP_REQUIRES_ASYNC(
+        c, group_key.dims() == 0,
+        errors::Internal("Unexpected dimensions on input group_key"), done);
+    OP_REQUIRES_ASYNC(
+        c, instance_key.dims() == 0,
+        errors::Internal("Unexpected dimensions on input instance_key"), done);
+
+    auto col_params = new CollectiveParams();
+    auto done_with_cleanup = [col_params, done = std::move(done)]() {
+      delete col_params;
+      done();
+    };
+
+    OP_REQUIRES_OK_ASYNC(
+        c, tensor::MakeShape(shape_tensor, &col_params->instance.shape),
+        done_with_cleanup);
+    col_params->name = name_;
+    col_params->group.device_type = device_type_;
+    col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
+    col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
+    col_params->instance.type = BROADCAST_COLLECTIVE;
+    col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
+    col_params->instance.data_type = data_type_;
+    col_params->instance.impl_details.communication_hint = communication_hint_;
+    col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
+    col_params->is_source = false;
+    // Add a default value for subdiv offsets, which is the same as the default
+    // value in the V1 op's attribute.
+    col_params->instance.impl_details.subdiv_offsets.push_back(0);
+    VLOG(1) << "CollectiveBcastRecvV2 group_size "
+            << col_params->group.group_size << " group_key "
+            << col_params->group.group_key << " instance_key "
+            << col_params->instance.instance_key;
+
+    // Allocate the output tensor.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK_ASYNC(c,
+                         c->forward_input_or_allocate_output(
+                             {0}, 0, col_params->instance.shape, &output),
+                         done_with_cleanup);
+
+    // Resolve the collective params.
+    // Schedule the `CompleteParamsAsync` call on a work queue that can handle
+    // blocking work, because `CompleteParamsAsync` may block.
+    c->collective_executor()->RunClosure([c,
+                                          done = std::move(done_with_cleanup),
+                                          col_params, col_exec]() {
+      VLOG(1) << "CollectiveBcastRecvV2 CompleteParams for collective "
+              << col_params->name << " device " << c->device()->name()
+              << " group " << col_params->group.group_key << " instance "
+              << col_params->instance.instance_key;
+      col_exec->CompleteParamsAsync(
+          c->device()->attributes(), col_params, c->cancellation_manager(),
+          [c, done = std::move(done), col_params, col_exec](const Status& s) {
+            if (s.ok()) {
+              auto actual_done = [c, group_key = col_params->group.group_key,
+                                  instance_key =
+                                      col_params->instance.instance_key,
+                                  done = std::move(done)](const Status& s) {
+                VLOG(1) << "CollectiveBcastRecvV2 ExecuteAsync done for "
+                           "collective "
+                        << c->op_kernel().name() << " device "
+                        << c->device()->name() << " group " << group_key
+                        << " instance " << instance_key << " status " << s;
+                OP_REQUIRES_OK_ASYNC(c, s, done);
+                done();
+              };
+              VLOG(1) << "CollectiveBcastRecvV2 ExecuteAsync start for "
+                         "collective "
+                      << col_params->name << " device " << c->device()->name()
+                      << " group " << col_params->group.group_key
+                      << " instance " << col_params->instance.instance_key;
+              col_exec->ExecuteAsync(
+                  c, *col_params,
+                  CollectiveKey(c, col_params->group.group_key,
+                                col_params->instance.instance_key),
+                  actual_done);
+            } else {
+              c->SetStatus(s);
+              done();
+            }
+          });
+    });
+  }
+
+ private:
+  DeviceType device_type_;
+  DataType data_type_ = DT_INVALID;
+  string communication_hint_;
+  float timeout_seconds_ = 0;
+  string name_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2").Device(DEVICE_CPU),
+                        CollectiveBcastRecvV2OpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("group_size")
+                            .HostMemory("group_key")
+                            .HostMemory("instance_key")
+                            .HostMemory("shape"),
+                        CollectiveBcastRecvV2OpKernel);
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/collective_ops.cc b/tensorflow/core/ops/collective_ops.cc
index 9033c7154ff..e58db89f6b5 100644
--- a/tensorflow/core/ops/collective_ops.cc
+++ b/tensorflow/core/ops/collective_ops.cc
@@ -145,4 +145,35 @@ REGISTER_OP("CollectiveGatherV2")
       return Status::OK();
     });
 
+REGISTER_OP("CollectiveBcastSendV2")
+    .Input("input: T")
+    .Output("data: T")
+    .Attr("T: {bool, float, float16, float64, int32, int64}")
+    .Input("group_size: int32")
+    .Input("group_key: int32")
+    .Input("instance_key: int32")
+    .Attr("communication_hint: string = 'auto'")
+    .Attr("timeout_seconds: float = 0")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("CollectiveBcastRecvV2")
+    .Output("data: T")
+    .Attr("T: {bool, float, float16, float64, int32, int64}")
+    .Input("group_size: int32")
+    .Input("group_key: int32")
+    .Input("instance_key: int32")
+    .Input("shape: Tshape")
+    .Attr("Tshape: {int32, int64} = DT_INT32")
+    .Attr("communication_hint: string = 'auto'")
+    .Attr("timeout_seconds: float = 0")
+    .SetIsStateful()
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      // The output shape is given by the `shape` input at index 3.
+      shape_inference::ShapeHandle out;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(/*input_idx=*/3, &out));
+      c->set_output(/*idx=*/0, out);
+      return Status::OK();
+    });
+
 }  // namespace tensorflow
diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py
index 3f6ab98605b..fe78637a93c 100644
--- a/tensorflow/python/framework/auto_control_deps.py
+++ b/tensorflow/python/framework/auto_control_deps.py
@@ -45,7 +45,9 @@ ASYNC_STATEFUL_OPS = [
     "CollectiveReduce",
     "CollectiveReduceV2",
     "CollectiveBcastSend",
+    "CollectiveBcastSendV2",
     "CollectiveBcastRecv",
+    "CollectiveBcastRecvV2",
     "NcclAllReduce",
     # We do not add "Send" here since we want it to be added as a control output
     # in order to avoid being pruned.
diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py
index fdaf3213759..916b23489e7 100644
--- a/tensorflow/python/kernel_tests/collective_ops_test.py
+++ b/tensorflow/python/kernel_tests/collective_ops_test.py
@@ -43,6 +43,8 @@ from tensorflow.python.platform import test
 class CollectiveOpsV1(object):
   all_reduce = _collective_ops.all_reduce
   all_gather = _collective_ops.all_gather
+  broadcast_send = _collective_ops.broadcast_send
+  broadcast_recv = _collective_ops.broadcast_recv
 
 
 class CollectiveOpsV2(object):
@@ -63,6 +65,25 @@ class CollectiveOpsV2(object):
     return _collective_ops.all_gather_v2(t, group_size, group_key, instance_key,
                                          *args, **kwargs)
 
+  @staticmethod
+  def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
+                     *args, **kwargs):
+    group_size = array_ops.identity(group_size)
+    group_key = array_ops.identity(group_key)
+    instance_key = array_ops.identity(instance_key)
+    return _collective_ops.broadcast_send_v2(t, group_size, group_key,
+                                             instance_key, *args, **kwargs)
+
+  @staticmethod
+  def broadcast_recv(shape, dtype, group_size, group_key, instance_key, *args,
+                     **kwargs):
+    group_size = array_ops.identity(group_size)
+    group_key = array_ops.identity(group_key)
+    instance_key = array_ops.identity(instance_key)
+    shape = array_ops.identity(shape)
+    return _collective_ops.broadcast_recv_v2(
+        shape, dtype, group_size, group_key, instance_key, *args, **kwargs)
+
 
 device_combination = (
     combinations.combine(device='CPU', communication='RING', required_gpus=0) +
@@ -191,6 +212,42 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
     for result in run_all_gather_2devices():
       self.assertAllClose(result, [1., 1.], rtol=1e-5, atol=1e-5)
 
+  def testBroadcast(self, collective_ops, device, communication):
+    dev0 = '/device:%s:0' % device
+    dev1 = '/device:%s:1' % device
+
+    @def_function.function
+    def run_broadcast_2devices():
+      shape = [3]
+      in_value = constant_op.constant([1., 2., 3.], shape=shape)
+      group_size = 2
+      group_key = 2
+      instance_key = 2
+      collectives = []
+      with ops.device(dev0):
+        collectives.append(
+            collective_ops.broadcast_send(
+                in_value,
+                shape,
+                in_value.dtype,
+                group_size,
+                group_key,
+                instance_key,
+                communication_hint=communication))
+      with ops.device(dev1):
+        collectives.append(
+            collective_ops.broadcast_recv(
+                shape,
+                in_value.dtype,
+                group_size,
+                group_key,
+                instance_key,
+                communication_hint=communication))
+      return collectives
+
+    for result in run_broadcast_2devices():
+      self.assertAllClose(result, [1., 2., 3.], rtol=1e-5, atol=1e-5)
+
   def testInstanceKeyScopedUnderGroupKey(self, collective_ops, device,
                                          communication):
     if device == 'GPU' and context.num_gpus() < 4:
diff --git a/tensorflow/python/ops/collective_ops.py b/tensorflow/python/ops/collective_ops.py
index 4f33c3aeecc..fcbd0ca826d 100644
--- a/tensorflow/python/ops/collective_ops.py
+++ b/tensorflow/python/ops/collective_ops.py
@@ -261,6 +261,40 @@ def broadcast_send(t,
       timeout_seconds=timeout)
 
 
+def broadcast_send_v2(t,
+                      group_size,
+                      group_key,
+                      instance_key,
+                      communication_hint='auto',
+                      timeout=0):
+  """Broadcasts one tensor to a group of others, across devices.
+
+  Args:
+    t: the tensor to be sent.
+    group_size: an int32 tensor.  One plus the number of receiving tensors, i.e.
+        the total number of devices participating.  Each tensor must reside on a
+        different device.
+    group_key: an int32 tensor identifying the group of devices.
+    instance_key: an int32 tensor identifying the participating group of Ops.
+    communication_hint: preferred collective communication.  The implementation
+      may fall back to another mechanism.  Options include `auto`, `ring`, and
+      `nccl`.
+    timeout: a float. If set to a nonzero value, it sets a completion timeout,
+      in seconds, used to detect staleness. If the timer goes off, a
+      DeadlineExceededError is raised. This feature is experimental.
+
+  Returns:
+    An Op implementing the distributed broadcast send.
+  """
+  return gen_collective_ops.collective_bcast_send_v2(
+      t,
+      group_size=group_size,
+      group_key=group_key,
+      instance_key=instance_key,
+      communication_hint=communication_hint.lower(),
+      timeout_seconds=timeout)
+
+
 def broadcast_recv(shape,
                    dtype,
                    group_size,
@@ -302,3 +336,41 @@ def broadcast_recv(shape,
       instance_key=instance_key,
       communication_hint=communication_hint.lower(),
       timeout_seconds=timeout)
+
+
+def broadcast_recv_v2(shape,
+                      dtype,
+                      group_size,
+                      group_key,
+                      instance_key,
+                      communication_hint='auto',
+                      timeout=0):
+  """Receives a broadcast tensor, across devices.
+
+  Args:
+    shape: an int tensor.  Shape of the tensor to be received.
+    dtype: Type of the tensor to be received.
+    group_size: an int32 tensor.  One plus the number of receiving tensors, i.e.
+        the total number of devices participating.  Each tensor must reside on a
+        different device.
+    group_key: an int32 tensor identifying the group of devices.
+    instance_key: an int32 tensor identifying the participating group of Ops.
+    communication_hint: preferred collective communication.  The implementation
+      may fall back to another mechanism.  Options include `auto`, `ring`, and
+      `nccl`.
+    timeout: a float. If set to a nonzero value, it sets a completion timeout,
+      in seconds, used to detect staleness. If the timer goes off, a
+      DeadlineExceededError is raised. This feature is experimental.
+
+  Returns:
+    An Op implementing the broadcast receive.
+  """
+  return gen_collective_ops.collective_bcast_recv_v2(
+      T=dtype,
+      group_size=group_size,
+      group_key=group_key,
+      instance_key=instance_key,
+      shape=shape,
+      communication_hint=communication_hint.lower(),
+      timeout_seconds=timeout)
+
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 21dd896e68d..7f6943d9e67 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -752,10 +752,18 @@ tf_module {
     name: "CollectiveBcastRecv"
     argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
+  member_method {
+    name: "CollectiveBcastRecvV2"
+    argspec: "args=[\'group_size\', \'group_key\', \'instance_key\', \'shape\', \'T\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
+  }
   member_method {
     name: "CollectiveBcastSend"
     argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
+  member_method {
+    name: "CollectiveBcastSendV2"
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
+  }
   member_method {
     name: "CollectiveGather"
     argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 21dd896e68d..7f6943d9e67 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -752,10 +752,18 @@ tf_module {
     name: "CollectiveBcastRecv"
     argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
+  member_method {
+    name: "CollectiveBcastRecvV2"
+    argspec: "args=[\'group_size\', \'group_key\', \'instance_key\', \'shape\', \'T\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
+  }
   member_method {
     name: "CollectiveBcastSend"
     argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
   }
+  member_method {
+    name: "CollectiveBcastSendV2"
+    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
+  }
   member_method {
     name: "CollectiveGather"
     argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "