diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index b53e75f3af0..64c61c1559d 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -33195,16 +33195,32 @@ func BoostedTreesQuantileStreamResourceHandleOp(scope *Scope, optional ...Booste return op.Output(0) } +// XlaShardingAttr is an optional argument to XlaSharding. +type XlaShardingAttr func(optionalAttr) + +// XlaShardingSharding sets the optional sharding attribute to value. +// If not specified, defaults to "" +func XlaShardingSharding(value string) XlaShardingAttr { + return func(m optionalAttr) { + m["sharding"] = value + } +} + // An op which shards the input based on the given sharding attribute. -func XlaSharding(scope *Scope, input tf.Output) (output tf.Output) { +func XlaSharding(scope *Scope, input tf.Output, optional ...XlaShardingAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ Type: "XlaSharding", Input: []tf.Input{ input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0)