Merge pull request #21895 from Cibifang:golang-add-gradients-with-name
PiperOrigin-RevId: 222855902
commit 082a7b9293
@@ -174,6 +174,68 @@ func (g *Graph) Operations() []Operation {
	return ops
}

// AddGradients adds operations to compute the partial derivatives of the sum of tensors in y
// with respect to tensors in x, i.e., d(y[0] + y[1] + ...) / d x[0], d(y[0] + y[1] + ...) / d x[1], etc.
//
// prefix, if non-empty, is the name prefix used for all operations added to the graph to compute
// these gradients.
func (g *Graph) AddGradients(prefix string, y []Output, x []Output, dx []Output) ([]Output, error) {
	var (
		cprefix *C.char

		cy  = make([]C.TF_Output, len(y))
		cx  = make([]C.TF_Output, len(x))
		cdx = make([]C.TF_Output, len(dx))
		cdy = make([]C.TF_Output, len(x))

		pcy  *C.TF_Output
		pcx  *C.TF_Output
		pcdx *C.TF_Output
		pcdy *C.TF_Output

		status = newStatus()
	)

	if len(y) > 0 {
		pcy = &cy[0]
		for i, o := range y {
			cy[i] = o.c()
		}
	}
	if len(x) > 0 {
		pcx = &cx[0]
		for i, o := range x {
			cx[i] = o.c()
		}
		pcdy = &cdy[0]
	}
	if len(dx) > 0 {
		pcdx = &cdx[0]
		for i, o := range dx {
			cdx[i] = o.c()
		}
	}

	// If prefix is empty, C.TF_AddGradientsWithPrefix requires cprefix to be nil, not "".
	if len(prefix) != 0 {
		cprefix = C.CString(prefix)
		defer C.free(unsafe.Pointer(cprefix))
	}

	C.TF_AddGradientsWithPrefix(g.c, cprefix, pcy, C.int(len(y)), pcx, C.int(len(x)), pcdx, status.c, pcdy)

	if err := status.Err(); err != nil {
		return nil, err
	}
	dy := make([]Output, len(x))
	for i, co := range cdy {
		op := &Operation{co.oper, g}
		dy[i] = Output{op, int(co.index)}
	}

	return dy, nil
}

// OpSpec is the specification of an Operation to be added to a Graph
// (using Graph.AddOperation).
type OpSpec struct {
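For orientation, here is a minimal sketch of calling AddGradients directly from client code (assuming the usual import tf "github.com/tensorflow/tensorflow/tensorflow/go"; "Placeholder" and "Square" are standard TensorFlow op types, and error checks are collapsed for brevity):

	g := tf.NewGraph()
	// A float placeholder; "dtype" is Placeholder's required attribute.
	xOp, _ := g.AddOperation(tf.OpSpec{
		Type:  "Placeholder",
		Name:  "x",
		Attrs: map[string]interface{}{"dtype": tf.Float},
	})
	x := xOp.Output(0)
	yOp, _ := g.AddOperation(tf.OpSpec{
		Type:  "Square",
		Name:  "y",
		Input: []tf.Input{x},
	})
	y := yOp.Output(0)
	// dy[0] is d(y)/d(x); with an empty prefix the new ops are named under "gradients/".
	dy, err := g.AddGradients("", []tf.Output{y}, []tf.Output{x}, nil)
	if err != nil {
		// handle the error
	}
	_ = dy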
@@ -19,6 +19,7 @@ package tensorflow

import (
	"bytes"
	"fmt"
	"strings"
	"testing"
)

@@ -80,3 +81,260 @@ func TestGraphWriteToAndImport(t *testing.T) {
		t.Error(err)
	}
}
func TestGraphAddGradients(t *testing.T) {
	g := NewGraph()
	x1, err := Placeholder(g, "x1", Float)
	if err != nil {
		t.Fatal(err)
	}
	x2, err := Placeholder(g, "x2", Float)
	if err != nil {
		t.Fatal(err)
	}
	op0, err := g.AddOperation(OpSpec{
		Type:  "Square",
		Name:  "y0",
		Input: []Input{x1},
	})
	if err != nil {
		t.Fatal(err)
	}
	y0 := op0.Output(0)
	op1, err := g.AddOperation(OpSpec{
		Type:  "Square",
		Name:  "y1",
		Input: []Input{y0},
	})
	if err != nil {
		t.Fatal(err)
	}
	y1 := op1.Output(0)
	op2, err := g.AddOperation(OpSpec{
		Type:  "AddN",
		Input: []Input{OutputList([]Output{y0, x2})},
	})
	if err != nil {
		t.Fatal(err)
	}
	y2 := op2.Output(0)

	grads0, err := g.AddGradients("", []Output{y1}, []Output{x1}, nil)
	if err != nil {
		t.Fatal(err)
	}
	if len(grads0) != 1 {
		t.Fatal(len(grads0))
	}
	if grads0[0].DataType() != Float {
		t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), Float)
	}

	grads1, err := g.AddGradients("", []Output{y2}, []Output{x1, x2}, nil)
	if err != nil {
		t.Fatal(err)
	}
	if len(grads1) != 2 {
		t.Fatal(len(grads1))
	}
	if grads1[0].DataType() != Float {
		t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), Float)
	}
	if grads1[1].DataType() != Float {
		t.Fatalf("Got DataType %v, wanted %v", grads1[1].DataType(), Float)
	}

	sess, err := NewSession(g, nil)
	if err != nil {
		t.Fatal(err)
	}

	c1, _ := NewTensor(float32(3.0))
	c2, _ := NewTensor(float32(2.0))
	outputs, err := sess.Run(
		map[Output]*Tensor{x1: c1, x2: c2},
		[]Output{grads0[0], grads1[0], grads1[1]},
		nil)
	if err != nil {
		t.Fatal(err)
	}
	if len(outputs) != 3 {
		t.Fatal(len(outputs))
	}
	if outputs[0].Value().(float32) != 108.0 {
		t.Fatalf("Got %v, wanted float 108.0", outputs[0].Value())
	}
	if outputs[1].Value().(float32) != 6.0 {
		t.Fatalf("Got %v, wanted float 6.0", outputs[1].Value())
	}
	if outputs[2].Value().(float32) != 1.0 {
		t.Fatalf("Got %v, wanted float 1.0", outputs[2].Value())
	}
}
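For reference, the expected values follow from the chain rule at x1 = 3: y1 = (x1^2)^2 = x1^4, so d(y1)/d(x1) = 4*x1^3 = 108; y2 = x1^2 + x2, so d(y2)/d(x1) = 2*x1 = 6 and d(y2)/d(x2) = 1.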
func TestGraphAddGradientsSums(t *testing.T) {
	g := NewGraph()
	x, err := Placeholder(g, "x", Float)
	if err != nil {
		t.Fatal(err)
	}
	op0, err := g.AddOperation(OpSpec{
		Type:  "Square",
		Name:  "y0",
		Input: []Input{x},
	})
	if err != nil {
		t.Fatal(err)
	}
	y0 := op0.Output(0)
	op1, err := g.AddOperation(OpSpec{
		Type:  "Square",
		Name:  "y1",
		Input: []Input{y0},
	})
	if err != nil {
		t.Fatal(err)
	}
	y1 := op1.Output(0)

	grad, err := g.AddGradients("", []Output{y0, y1}, []Output{x}, nil)
	if err != nil {
		t.Fatal(err)
	}
	if len(grad) != 1 {
		t.Fatal(len(grad))
	}
	if grad[0].DataType() != Float {
		t.Fatalf("Got DataType %v, wanted %v", grad[0].DataType(), Float)
	}

	sess, err := NewSession(g, nil)
	if err != nil {
		t.Fatal(err)
	}

	c, _ := NewTensor(float32(3.0))
	outputs, err := sess.Run(
		map[Output]*Tensor{x: c},
		[]Output{grad[0]},
		nil)
	if err != nil {
		t.Fatal(err)
	}
	if outputs[0].Value().(float32) != 114.0 {
		t.Fatalf("Got %v, wanted float 114.0", outputs[0].Value())
	}
}
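Here AddGradients differentiates the sum of the ys: y0 + y1 = x^2 + x^4, so the single returned gradient is 2*x + 4*x^3 = 6 + 108 = 114 at x = 3.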
func TestGraphAddGradientsWithInitialValues(t *testing.T) {
	g := NewGraph()
	x, err := Placeholder(g, "x", Float)
	if err != nil {
		t.Fatal(err)
	}
	op0, err := g.AddOperation(OpSpec{
		Type:  "Square",
		Name:  "y0",
		Input: []Input{x},
	})
	if err != nil {
		t.Fatal(err)
	}
	y0 := op0.Output(0)
	op1, err := g.AddOperation(OpSpec{
		Type:  "Square",
		Name:  "y1",
		Input: []Input{y0},
	})
	if err != nil {
		t.Fatal(err)
	}
	y1 := op1.Output(0)

	grads0, err := g.AddGradients("", []Output{y1}, []Output{y0}, nil)
	if err != nil {
		t.Fatal(err)
	}
	if len(grads0) != 1 {
		t.Fatal(len(grads0))
	}
	if grads0[0].DataType() != Float {
		t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), Float)
	}

	grads1, err := g.AddGradients("", []Output{y0}, []Output{x}, []Output{grads0[0]})
	if err != nil {
		t.Fatal(err)
	}
	if len(grads1) != 1 {
		t.Fatal(len(grads1))
	}
	if grads1[0].DataType() != Float {
		t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), Float)
	}

	sess, err := NewSession(g, nil)
	if err != nil {
		t.Fatal(err)
	}

	c, _ := NewTensor(float32(3.0))
	outputs, err := sess.Run(
		map[Output]*Tensor{x: c},
		[]Output{grads1[0]},
		nil)
	if err != nil {
		t.Fatal(err)
	}
	if outputs[0].Value().(float32) != 108.0 {
		t.Fatalf("Got %v, wanted float 108.0", outputs[0].Value())
	}
}
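This test exercises the dx parameter: grads0 = d(y1)/d(y0) = 2*y0 = 18 at x = 3, and feeding it back as the initial gradient seeds the second call, so grads1 = d(y0)/d(x) * grads0 = 2*x * 18 = 108, matching d(y1)/d(x) computed in one step.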
func TestGraphValidateGradientsNames(t *testing.T) {
	g := NewGraph()
	x, err := Placeholder(g, "x", Float)
	if err != nil {
		t.Fatal(err)
	}
	op0, err := g.AddOperation(OpSpec{
		Type:  "Square",
		Name:  "y0",
		Input: []Input{x},
	})
	if err != nil {
		t.Fatal(err)
	}
	y0 := op0.Output(0)

	grads0, err := g.AddGradients("", []Output{y0}, []Output{x}, nil)
	if err != nil {
		t.Fatal(err)
	}
	if !strings.HasPrefix(grads0[0].Op.Name(), "gradients/") {
		t.Fatalf("Got name %v, wanted a name starting with gradients/", grads0[0].Op.Name())
	}

	grads1, err := g.AddGradients("", []Output{y0}, []Output{x}, nil)
	if err != nil {
		t.Fatal(err)
	}
	if !strings.HasPrefix(grads1[0].Op.Name(), "gradients_1/") {
		t.Fatalf("Got name %v, wanted a name starting with gradients_1/", grads1[0].Op.Name())
	}

	grads2, err := g.AddGradients("more_gradients", []Output{y0}, []Output{x}, nil)
	if err != nil {
		t.Fatal(err)
	}
	if !strings.HasPrefix(grads2[0].Op.Name(), "more_gradients/") {
		t.Fatalf("Got name %v, wanted a name starting with more_gradients/", grads2[0].Op.Name())
	}

	grads3, err := g.AddGradients("even_more_gradients", []Output{y0}, []Output{x}, nil)
	if err != nil {
		t.Fatal(err)
	}
	if !strings.HasPrefix(grads3[0].Op.Name(), "even_more_gradients/") {
		t.Fatalf("Got name %v, wanted a name starting with even_more_gradients/", grads3[0].Op.Name())
	}

	_, err = g.AddGradients("even_more_gradients", []Output{y0}, []Output{x}, nil)
	if err == nil {
		t.Error("AddGradients should have failed because the gradients name already exists")
	}
}
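As the test demonstrates, an empty prefix maps to the default gradients/ name scope and is uniquified (gradients_1/, ...) on repeated calls, whereas an explicit prefix must not collide with an existing name scope or AddGradients returns an error.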
tensorflow/go/op/gradients.go (new file, 49 lines)
@@ -0,0 +1,49 @@
/*
Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package op

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

// Gradients adds gradient-computation ops to the graph according to scope.
//
// Arguments:
//  y: outputs of the function to differentiate
//  x: inputs of the function for which partial derivatives are computed
//  dx: if non-nil, the partial derivatives of some loss function L w.r.t. y
//
// Returns the partial derivatives of the sum of y w.r.t. each x.
func Gradients(scope *Scope, y []tf.Output, x []tf.Output, dx ...tf.Output) (output []tf.Output) {
	if len(scope.controlDependencies) > 0 {
		scope.UpdateErr("Gradients", fmt.Errorf("Gradients does not currently support control dependencies (via Scope.WithControlDependencies)"))
		return
	}
	if scope.device != "" {
		scope.UpdateErr("Gradients", fmt.Errorf("Gradients does not currently support device annotations (via Scope.WithDevice)"))
		return
	}

	var err error
	if output, err = scope.graph.AddGradients(scope.opName("Gradients"), y, x, dx); err != nil {
		scope.UpdateErr("Gradients", err)
		return
	}
	return output
}
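A quick sketch of typical usage of this wrapper (mirroring the tests below; Placeholder and Square are the generated op wrappers in this package, and the session/run steps are omitted):

	s := op.NewScope()
	x := op.Placeholder(s.SubScope("x"), tf.Float)
	y := op.Square(s.SubScope("y"), x)
	// One gradient per entry in x, added under the scope's "Gradients/" name scope.
	grads := op.Gradients(s, []tf.Output{y}, []tf.Output{x})
	if err := s.Err(); err != nil {
		// the scope accumulates errors instead of returning them per call
	}
	_ = grads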
tensorflow/go/op/gradients_test.go (new file, 246 lines)
@@ -0,0 +1,246 @@
/*
Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package op

import (
	"strings"
	"testing"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
func TestAddGradients(t *testing.T) {
	var (
		s  = NewScope()
		x1 = Placeholder(s.SubScope("x1"), tf.Float)
		x2 = Placeholder(s.SubScope("x2"), tf.Float)
		y0 = Square(s.SubScope("y0"), x1)
		y1 = Square(s.SubScope("y1"), y0)
		y2 = AddN(s.SubScope("y2"), []tf.Output{y0, x2})
	)

	grads0 := Gradients(s, []tf.Output{y1}, []tf.Output{x1})
	if err := s.Err(); err != nil {
		t.Fatal(err)
	}
	if len(grads0) != 1 {
		t.Fatal(len(grads0))
	}
	if grads0[0].DataType() != tf.Float {
		t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), tf.Float)
	}

	sub := s.SubScope("sub")
	grads1 := Gradients(sub, []tf.Output{y2}, []tf.Output{x1, x2})
	if err := sub.Err(); err != nil {
		t.Fatal(err)
	}
	if len(grads1) != 2 {
		t.Fatal(len(grads1))
	}
	if grads1[0].DataType() != tf.Float {
		t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), tf.Float)
	}
	if grads1[1].DataType() != tf.Float {
		t.Fatalf("Got DataType %v, wanted %v", grads1[1].DataType(), tf.Float)
	}

	graph, err := sub.Finalize()
	if err != nil {
		t.Fatal(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		t.Fatal(err)
	}

	c1, _ := tf.NewTensor(float32(3.0))
	c2, _ := tf.NewTensor(float32(3.0))
	outputs, err := sess.Run(
		map[tf.Output]*tf.Tensor{x1: c1, x2: c2},
		[]tf.Output{grads0[0], grads1[0], grads1[1]},
		nil)
	if err != nil {
		t.Fatal(err)
	}
	if len(outputs) != 3 {
		t.Fatal(len(outputs))
	}
	if outputs[0].Value().(float32) != 108.0 {
		t.Fatalf("Got %v, wanted float 108.0", outputs[0].Value())
	}
	if outputs[1].Value().(float32) != 6.0 {
		t.Fatalf("Got %v, wanted float 6.0", outputs[1].Value())
	}
	if outputs[2].Value().(float32) != 1.0 {
		t.Fatalf("Got %v, wanted float 1.0", outputs[2].Value())
	}
}
func TestAddGradientsSums(t *testing.T) {
	var (
		s  = NewScope()
		x  = Placeholder(s.SubScope("x"), tf.Float)
		y0 = Square(s.SubScope("y0"), x)
		y1 = Square(s.SubScope("y1"), y0)
	)

	grad := Gradients(s, []tf.Output{y0, y1}, []tf.Output{x})
	if err := s.Err(); err != nil {
		t.Fatal(err)
	}
	if len(grad) != 1 {
		t.Fatal(len(grad))
	}
	if grad[0].DataType() != tf.Float {
		t.Fatalf("Got DataType %v, wanted %v", grad[0].DataType(), tf.Float)
	}

	graph, err := s.Finalize()
	if err != nil {
		t.Fatal(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		t.Fatal(err)
	}

	c, _ := tf.NewTensor(float32(3.0))
	outputs, err := sess.Run(
		map[tf.Output]*tf.Tensor{x: c},
		[]tf.Output{grad[0]},
		nil)
	if err != nil {
		t.Fatal(err)
	}
	if outputs[0].Value().(float32) != 114.0 {
		t.Fatalf("Got %v, wanted float 114.0", outputs[0].Value())
	}
}
func TestAddGradientsWithInitialValues(t *testing.T) {
	var (
		s  = NewScope()
		x  = Placeholder(s.SubScope("x1"), tf.Float)
		y0 = Square(s.SubScope("y0"), x)
		y1 = Square(s.SubScope("y1"), y0)
	)

	grads0 := Gradients(s, []tf.Output{y1}, []tf.Output{y0})
	if err := s.Err(); err != nil {
		t.Fatal(err)
	}
	if len(grads0) != 1 {
		t.Fatal(len(grads0))
	}
	if grads0[0].DataType() != tf.Float {
		t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), tf.Float)
	}

	sub := s.SubScope("sub")
	grads1 := Gradients(sub, []tf.Output{y0}, []tf.Output{x}, grads0[0])
	if err := sub.Err(); err != nil {
		t.Fatal(err)
	}
	if len(grads1) != 1 {
		t.Fatal(len(grads1))
	}
	if grads1[0].DataType() != tf.Float {
		t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), tf.Float)
	}

	graph, err := sub.Finalize()
	if err != nil {
		t.Fatal(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		t.Fatal(err)
	}

	c, _ := tf.NewTensor(float32(3.0))
	outputs, err := sess.Run(
		map[tf.Output]*tf.Tensor{x: c},
		[]tf.Output{grads1[0]},
		nil)
	if err != nil {
		t.Fatal(err)
	}
	if outputs[0].Value().(float32) != 108.0 {
		t.Fatalf("Got %v, wanted float 108.0", outputs[0].Value())
	}
}
func TestValidateGradientsNames(t *testing.T) {
	var (
		s  = NewScope()
		x  = Placeholder(s.SubScope("x"), tf.Float)
		y0 = Square(s.SubScope("y0"), x)
	)

	grads0 := Gradients(s, []tf.Output{y0}, []tf.Output{x})
	if err := s.Err(); err != nil {
		t.Fatal(err)
	}
	if !strings.HasPrefix(grads0[0].Op.Name(), "Gradients/") {
		t.Fatalf("Got name %v, wanted a name starting with Gradients/", grads0[0].Op.Name())
	}

	sub := s.SubScope("sub")
	grads1 := Gradients(sub, []tf.Output{y0}, []tf.Output{x})
	if err := s.Err(); err != nil {
		t.Fatal(err)
	}
	if !strings.HasPrefix(grads1[0].Op.Name(), "sub/Gradients/") {
		t.Fatalf("Got name %v, wanted a name starting with sub/Gradients/", grads1[0].Op.Name())
	}

	Gradients(sub, []tf.Output{y0}, []tf.Output{x})
	if err := s.Err(); err == nil {
		t.Error("Gradients should have failed when executed more than once in the same scope namespace")
	}
}
func TestAddGradientsWithControlDependencies(t *testing.T) {
	var (
		s        = NewScope()
		zero     = Const(s.SubScope("zero"), int32(0))
		x        = Placeholder(s.SubScope("x"), tf.Float)
		y0       = Square(s.SubScope("y0"), x)
		variable = VarHandleOp(s, tf.Int32, tf.ScalarShape())
		init     = AssignVariableOp(s, variable, zero)
		readDeps = []*tf.Operation{init}
	)
	s = s.WithControlDependencies(readDeps...)
	Gradients(s, []tf.Output{y0}, []tf.Output{x})
	if err := s.Err(); err == nil {
		t.Error("Gradients should have failed when control dependencies are set")
	}
}
func TestAddGradientsWithDevice(t *testing.T) {
	var (
		s  = NewScope()
		x  = Placeholder(s.SubScope("x"), tf.Float)
		y0 = Square(s.SubScope("y0"), x)
	)
	s = s.WithDevice("/device:GPU:0")
	Gradients(s, []tf.Output{y0}, []tf.Output{x})
	if err := s.Err(); err == nil {
		t.Error("Gradients should have failed when a device is set")
	}
}