diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index d49af0007b7..484f8112e9e 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -4620,6 +4620,45 @@ func KmeansPlusPlusInitialization(scope *Scope, points tf.Output, num_to_sample return op.Output(0) } +// CollectiveBcastRecvV2Attr is an optional argument to CollectiveBcastRecvV2. +type CollectiveBcastRecvV2Attr func(optionalAttr) + +// CollectiveBcastRecvV2CommunicationHint sets the optional communication_hint attribute to value. +// If not specified, defaults to "auto" +func CollectiveBcastRecvV2CommunicationHint(value string) CollectiveBcastRecvV2Attr { + return func(m optionalAttr) { + m["communication_hint"] = value + } +} + +// CollectiveBcastRecvV2TimeoutSeconds sets the optional timeout_seconds attribute to value. +// If not specified, defaults to 0 +func CollectiveBcastRecvV2TimeoutSeconds(value float32) CollectiveBcastRecvV2Attr { + return func(m optionalAttr) { + m["timeout_seconds"] = value + } +} + +// Receives a tensor value broadcast from another device. +func CollectiveBcastRecvV2(scope *Scope, group_size tf.Output, group_key tf.Output, instance_key tf.Output, shape tf.Output, T tf.DataType, optional ...CollectiveBcastRecvV2Attr) (data tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"T": T} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CollectiveBcastRecvV2", + Input: []tf.Input{ + group_size, group_key, instance_key, shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // AbortAttr is an optional argument to Abort. type AbortAttr func(optionalAttr) @@ -49360,6 +49399,45 @@ func LoadTPUEmbeddingRMSPropParameters(scope *Scope, parameters tf.Output, ms tf return scope.AddOperation(opspec) } +// CollectiveBcastSendV2Attr is an optional argument to CollectiveBcastSendV2. +type CollectiveBcastSendV2Attr func(optionalAttr) + +// CollectiveBcastSendV2CommunicationHint sets the optional communication_hint attribute to value. +// If not specified, defaults to "auto" +func CollectiveBcastSendV2CommunicationHint(value string) CollectiveBcastSendV2Attr { + return func(m optionalAttr) { + m["communication_hint"] = value + } +} + +// CollectiveBcastSendV2TimeoutSeconds sets the optional timeout_seconds attribute to value. +// If not specified, defaults to 0 +func CollectiveBcastSendV2TimeoutSeconds(value float32) CollectiveBcastSendV2Attr { + return func(m optionalAttr) { + m["timeout_seconds"] = value + } +} + +// Broadcasts a tensor value to one or more other devices. +func CollectiveBcastSendV2(scope *Scope, input tf.Output, group_size tf.Output, group_key tf.Output, instance_key tf.Output, optional ...CollectiveBcastSendV2Attr) (data tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CollectiveBcastSendV2", + Input: []tf.Input{ + input, group_size, group_key, instance_key, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // InfeedEnqueueTupleAttr is an optional argument to InfeedEnqueueTuple. type InfeedEnqueueTupleAttr func(optionalAttr)