Go: Update generated wrapper functions for TensorFlow ops.
PiperOrigin-RevId: 348050910 Change-Id: Ib86ac36d5e76d498b7d32e5db1b2b867e224209b
This commit is contained in:
parent
911a08098a
commit
869022b4d6
@ -4620,6 +4620,45 @@ func KmeansPlusPlusInitialization(scope *Scope, points tf.Output, num_to_sample
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// CollectiveBcastRecvV2Attr is an optional argument to CollectiveBcastRecvV2.
|
||||
type CollectiveBcastRecvV2Attr func(optionalAttr)
|
||||
|
||||
// CollectiveBcastRecvV2CommunicationHint sets the optional communication_hint attribute to value.
|
||||
// If not specified, defaults to "auto"
|
||||
func CollectiveBcastRecvV2CommunicationHint(value string) CollectiveBcastRecvV2Attr {
|
||||
return func(m optionalAttr) {
|
||||
m["communication_hint"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// CollectiveBcastRecvV2TimeoutSeconds sets the optional timeout_seconds attribute to value.
|
||||
// If not specified, defaults to 0
|
||||
func CollectiveBcastRecvV2TimeoutSeconds(value float32) CollectiveBcastRecvV2Attr {
|
||||
return func(m optionalAttr) {
|
||||
m["timeout_seconds"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Receives a tensor value broadcast from another device.
|
||||
func CollectiveBcastRecvV2(scope *Scope, group_size tf.Output, group_key tf.Output, instance_key tf.Output, shape tf.Output, T tf.DataType, optional ...CollectiveBcastRecvV2Attr) (data tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
attrs := map[string]interface{}{"T": T}
|
||||
for _, a := range optional {
|
||||
a(attrs)
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "CollectiveBcastRecvV2",
|
||||
Input: []tf.Input{
|
||||
group_size, group_key, instance_key, shape,
|
||||
},
|
||||
Attrs: attrs,
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// AbortAttr is an optional argument to Abort.
|
||||
type AbortAttr func(optionalAttr)
|
||||
|
||||
@ -49360,6 +49399,45 @@ func LoadTPUEmbeddingRMSPropParameters(scope *Scope, parameters tf.Output, ms tf
|
||||
return scope.AddOperation(opspec)
|
||||
}
|
||||
|
||||
// CollectiveBcastSendV2Attr is an optional argument to CollectiveBcastSendV2.
|
||||
type CollectiveBcastSendV2Attr func(optionalAttr)
|
||||
|
||||
// CollectiveBcastSendV2CommunicationHint sets the optional communication_hint attribute to value.
|
||||
// If not specified, defaults to "auto"
|
||||
func CollectiveBcastSendV2CommunicationHint(value string) CollectiveBcastSendV2Attr {
|
||||
return func(m optionalAttr) {
|
||||
m["communication_hint"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// CollectiveBcastSendV2TimeoutSeconds sets the optional timeout_seconds attribute to value.
|
||||
// If not specified, defaults to 0
|
||||
func CollectiveBcastSendV2TimeoutSeconds(value float32) CollectiveBcastSendV2Attr {
|
||||
return func(m optionalAttr) {
|
||||
m["timeout_seconds"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcasts a tensor value to one or more other devices.
|
||||
func CollectiveBcastSendV2(scope *Scope, input tf.Output, group_size tf.Output, group_key tf.Output, instance_key tf.Output, optional ...CollectiveBcastSendV2Attr) (data tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
attrs := map[string]interface{}{}
|
||||
for _, a := range optional {
|
||||
a(attrs)
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "CollectiveBcastSendV2",
|
||||
Input: []tf.Input{
|
||||
input, group_size, group_key, instance_key,
|
||||
},
|
||||
Attrs: attrs,
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// InfeedEnqueueTupleAttr is an optional argument to InfeedEnqueueTuple.
|
||||
type InfeedEnqueueTupleAttr func(optionalAttr)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user