Go: Update generated wrapper functions for TensorFlow ops.
PiperOrigin-RevId: 248763861
This commit is contained in:
parent
335915673e
commit
20b5b02371
@ -2240,6 +2240,17 @@ func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) {
|
|||||||
return op.Output(0)
|
return op.Output(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GatherV2Attr is an optional argument to GatherV2.
|
||||||
|
type GatherV2Attr func(optionalAttr)
|
||||||
|
|
||||||
|
// GatherV2BatchDims sets the optional batch_dims attribute to value.
|
||||||
|
// If not specified, defaults to 0
|
||||||
|
func GatherV2BatchDims(value int64) GatherV2Attr {
|
||||||
|
return func(m optionalAttr) {
|
||||||
|
m["batch_dims"] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Gather slices from `params` axis `axis` according to `indices`.
|
// Gather slices from `params` axis `axis` according to `indices`.
|
||||||
//
|
//
|
||||||
// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D).
|
// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D).
|
||||||
@ -2279,15 +2290,20 @@ func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) {
|
|||||||
//
|
//
|
||||||
// Returns Values from `params` gathered from indices given by `indices`, with
|
// Returns Values from `params` gathered from indices given by `indices`, with
|
||||||
// shape `params.shape[:axis] + indices.shape + params.shape[axis + 1:]`.
|
// shape `params.shape[:axis] + indices.shape + params.shape[axis + 1:]`.
|
||||||
func GatherV2(scope *Scope, params tf.Output, indices tf.Output, axis tf.Output) (output tf.Output) {
|
func GatherV2(scope *Scope, params tf.Output, indices tf.Output, axis tf.Output, optional ...GatherV2Attr) (output tf.Output) {
|
||||||
if scope.Err() != nil {
|
if scope.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
attrs := map[string]interface{}{}
|
||||||
|
for _, a := range optional {
|
||||||
|
a(attrs)
|
||||||
|
}
|
||||||
opspec := tf.OpSpec{
|
opspec := tf.OpSpec{
|
||||||
Type: "GatherV2",
|
Type: "GatherV2",
|
||||||
Input: []tf.Input{
|
Input: []tf.Input{
|
||||||
params, indices, axis,
|
params, indices, axis,
|
||||||
},
|
},
|
||||||
|
Attrs: attrs,
|
||||||
}
|
}
|
||||||
op := scope.AddOperation(opspec)
|
op := scope.AddOperation(opspec)
|
||||||
return op.Output(0)
|
return op.Output(0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user