From 97afdaef078de6831ac221a2bda2ecae4a92fa44 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 22 Jan 2021 18:46:11 -0800 Subject: [PATCH] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 353358654 Change-Id: I3318addb3547781d5434763705db0043e8d1ee20 --- tensorflow/go/op/wrappers.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index b53e75f3af0..64c61c1559d 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -33195,16 +33195,32 @@ func BoostedTreesQuantileStreamResourceHandleOp(scope *Scope, optional ...Booste return op.Output(0) } +// XlaShardingAttr is an optional argument to XlaSharding. +type XlaShardingAttr func(optionalAttr) + +// XlaShardingSharding sets the optional sharding attribute to value. +// If not specified, defaults to "" +func XlaShardingSharding(value string) XlaShardingAttr { + return func(m optionalAttr) { + m["sharding"] = value + } +} + // An op which shards the input based on the given sharding attribute. -func XlaSharding(scope *Scope, input tf.Output) (output tf.Output) { +func XlaSharding(scope *Scope, input tf.Output, optional ...XlaShardingAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ Type: "XlaSharding", Input: []tf.Input{ input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0)