Remove from func Gradients() in gradients.go

This commit is contained in:
Cibifang 2018-10-01 18:02:32 +08:00
parent 72ba6b2271
commit 8c9d5eb52b
2 changed files with 15 additions and 28 deletions

View File

@ -21,19 +21,14 @@ import tf "github.com/tensorflow/tensorflow/tensorflow/go"
// Gradients adds gradients computation ops to the graph according to scope.
//
// Arguments:
// prefix: unique string prefix applied before the names of nodes added to the graph to
// compute gradients. If null, will use "Gradients".
// y: output of the function to derive
// x: inputs of the function for which partial derivatives are computed
// dx: if not null, the partial derivatives of some loss function L w.r.t. y
//
// return the partial derivatives
func Gradients(scope *Scope, prefix string, y []tf.Output, x []tf.Output, dx ...tf.Output) (output []tf.Output) {
func Gradients(scope *Scope, y []tf.Output, x []tf.Output, dx ...tf.Output) (output []tf.Output) {
var err error
if prefix == "" {
prefix = "Gradients"
}
if output, err = scope.graph.AddGradients(scope.opName(scope.uniqueName(prefix)), y, x, dx); err != nil {
if output, err = scope.graph.AddGradients(scope.opName(scope.uniqueName("Gradients")), y, x, dx); err != nil {
scope.UpdateErr("Gradients", err)
return
}

View File

@ -33,7 +33,7 @@ func TestAddGradients(t *testing.T) {
y2 = AddN(s.SubScope("y2"), []tf.Output{y0, x2})
)
grads0 := Gradients(s, "", []tf.Output{y1}, []tf.Output{x1})
grads0 := Gradients(s, []tf.Output{y1}, []tf.Output{x1})
if err := s.Err(); err != nil {
t.Fatal(err)
}
@ -44,7 +44,7 @@ func TestAddGradients(t *testing.T) {
t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), tf.Float)
}
grads1 := Gradients(s, "", []tf.Output{y2}, []tf.Output{x1, x2})
grads1 := Gradients(s, []tf.Output{y2}, []tf.Output{x1, x2})
if err := s.Err(); err != nil {
t.Fatal(err)
}
@ -98,7 +98,7 @@ func TestAddGradientsSums(t *testing.T) {
y1 = Square(s.SubScope("y1"), y0)
)
grad := Gradients(s, "", []tf.Output{y0, y1}, []tf.Output{x})
grad := Gradients(s, []tf.Output{y0, y1}, []tf.Output{x})
if err := s.Err(); err != nil {
t.Fatal(err)
}
@ -139,7 +139,7 @@ func TestAddGradientsWithInitialValues(t *testing.T) {
y1 = Square(s.SubScope("y1"), y0)
)
grads0 := Gradients(s, "", []tf.Output{y1}, []tf.Output{y0})
grads0 := Gradients(s, []tf.Output{y1}, []tf.Output{y0})
if err := s.Err(); err != nil {
t.Fatal(err)
}
@ -150,7 +150,7 @@ func TestAddGradientsWithInitialValues(t *testing.T) {
t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), tf.Float)
}
grads1 := Gradients(s, "", []tf.Output{y0}, []tf.Output{x}, grads0[0])
grads1 := Gradients(s, []tf.Output{y0}, []tf.Output{x}, grads0[0])
if err := s.Err(); err != nil {
t.Fatal(err)
}
@ -190,7 +190,7 @@ func TestValidateGradientsNames(t *testing.T) {
y0 = Square(s.SubScope("y0"), x)
)
grads0 := Gradients(s, "", []tf.Output{y0}, []tf.Output{x})
grads0 := Gradients(s, []tf.Output{y0}, []tf.Output{x})
if err := s.Err(); err != nil {
t.Fatal(err)
}
@ -198,7 +198,7 @@ func TestValidateGradientsNames(t *testing.T) {
t.Fatalf("Got name %v, wanted started with Gradients/", grads0[0].Op.Name())
}
grads1 := Gradients(s, "", []tf.Output{y0}, []tf.Output{x})
grads1 := Gradients(s, []tf.Output{y0}, []tf.Output{x})
if err := s.Err(); err != nil {
t.Fatal(err)
}
@ -206,28 +206,20 @@ func TestValidateGradientsNames(t *testing.T) {
t.Fatalf("Got name %v, wanted started with Gradients_1/", grads1[0].Op.Name())
}
grads2 := Gradients(s, "more_gradients", []tf.Output{y0}, []tf.Output{x})
if err := s.Err(); err != nil {
t.Fatal(err)
}
if !strings.HasPrefix(grads2[0].Op.Name(), "more_gradients/") {
t.Fatalf("Got name %v, wanted started with more_gradients/", grads2[0].Op.Name())
}
sub := s.SubScope("sub")
grads3 := Gradients(sub, "even_more_gradients", []tf.Output{y0}, []tf.Output{x})
grads3 := Gradients(sub, []tf.Output{y0}, []tf.Output{x})
if err := s.Err(); err != nil {
t.Fatal(err)
}
if !strings.HasPrefix(grads3[0].Op.Name(), "sub/even_more_gradients/") {
t.Fatalf("Got name %v, wanted started with sub/even_more_gradients/", grads3[0].Op.Name())
if !strings.HasPrefix(grads3[0].Op.Name(), "sub/Gradients/") {
t.Fatalf("Got name %v, wanted started with sub/Gradients/", grads3[0].Op.Name())
}
grads4 := Gradients(sub, "even_more_gradients", []tf.Output{y0}, []tf.Output{x})
grads4 := Gradients(sub, []tf.Output{y0}, []tf.Output{x})
if err := s.Err(); err != nil {
t.Fatal(err)
}
if !strings.HasPrefix(grads4[0].Op.Name(), "sub/even_more_gradients_1/") {
t.Fatalf("Got name %v, wanted started with sub/even_more_gradients_1/", grads4[0].Op.Name())
if !strings.HasPrefix(grads4[0].Op.Name(), "sub/Gradients_1/") {
t.Fatalf("Got name %v, wanted started with sub/Gradients_1/", grads4[0].Op.Name())
}
}