Go: Update generated wrapper functions for TensorFlow ops.
PiperOrigin-RevId: 353358654 Change-Id: I3318addb3547781d5434763705db0043e8d1ee20
This commit is contained in:
parent
62a1fad4fa
commit
97afdaef07
@ -33195,16 +33195,32 @@ func BoostedTreesQuantileStreamResourceHandleOp(scope *Scope, optional ...Booste
|
|||||||
return op.Output(0)
|
return op.Output(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// XlaShardingAttr is an optional argument to XlaSharding.
|
||||||
|
type XlaShardingAttr func(optionalAttr)
|
||||||
|
|
||||||
|
// XlaShardingSharding sets the optional sharding attribute to value.
|
||||||
|
// If not specified, defaults to ""
|
||||||
|
func XlaShardingSharding(value string) XlaShardingAttr {
|
||||||
|
return func(m optionalAttr) {
|
||||||
|
m["sharding"] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// An op which shards the input based on the given sharding attribute.
|
// An op which shards the input based on the given sharding attribute.
|
||||||
func XlaSharding(scope *Scope, input tf.Output) (output tf.Output) {
|
func XlaSharding(scope *Scope, input tf.Output, optional ...XlaShardingAttr) (output tf.Output) {
|
||||||
if scope.Err() != nil {
|
if scope.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
attrs := map[string]interface{}{}
|
||||||
|
for _, a := range optional {
|
||||||
|
a(attrs)
|
||||||
|
}
|
||||||
opspec := tf.OpSpec{
|
opspec := tf.OpSpec{
|
||||||
Type: "XlaSharding",
|
Type: "XlaSharding",
|
||||||
Input: []tf.Input{
|
Input: []tf.Input{
|
||||||
input,
|
input,
|
||||||
},
|
},
|
||||||
|
Attrs: attrs,
|
||||||
}
|
}
|
||||||
op := scope.AddOperation(opspec)
|
op := scope.AddOperation(opspec)
|
||||||
return op.Output(0)
|
return op.Output(0)
|
||||||
|
Loading…
Reference in New Issue
Block a user