Stop using uniqueName for Gradients in gradient.go

To keep the code style consistent with other functions.
This commit is contained in:
Cibifang 2018-10-16 16:53:52 +08:00
parent fbd637423b
commit 310f454e92
2 changed files with 15 additions and 24 deletions

View File

@ -28,7 +28,7 @@ import tf "github.com/tensorflow/tensorflow/tensorflow/go"
// return the partial derivatives
func Gradients(scope *Scope, y []tf.Output, x []tf.Output, dx ...tf.Output) (output []tf.Output) {
var err error
if output, err = scope.graph.AddGradients(scope.opName(scope.uniqueName("Gradients")), y, x, dx); err != nil {
if output, err = scope.graph.AddGradients(scope.opName("Gradients"), y, x, dx); err != nil {
scope.UpdateErr("Gradients", err)
return
}

View File

@ -44,8 +44,9 @@ func TestAddGradients(t *testing.T) {
t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), tf.Float)
}
grads1 := Gradients(s, []tf.Output{y2}, []tf.Output{x1, x2})
if err := s.Err(); err != nil {
sub := s.SubScope("sub")
grads1 := Gradients(sub, []tf.Output{y2}, []tf.Output{x1, x2})
if err := sub.Err(); err != nil {
t.Fatal(err)
}
if len(grads1) != 2 {
@ -58,7 +59,7 @@ func TestAddGradients(t *testing.T) {
t.Fatalf("Got DataType %v, wanted %v", grads1[1].DataType(), tf.Float)
}
graph, err := s.Finalize()
graph, err := sub.Finalize()
if err != nil {
t.Fatal(err)
}
@ -150,8 +151,9 @@ func TestAddGradientsWithInitialValues(t *testing.T) {
t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), tf.Float)
}
grads1 := Gradients(s, []tf.Output{y0}, []tf.Output{x}, grads0[0])
if err := s.Err(); err != nil {
sub := s.SubScope("sub")
grads1 := Gradients(sub, []tf.Output{y0}, []tf.Output{x}, grads0[0])
if err := sub.Err(); err != nil {
t.Fatal(err)
}
if len(grads1) != 1 {
@ -161,7 +163,7 @@ func TestAddGradientsWithInitialValues(t *testing.T) {
t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), tf.Float)
}
graph, err := s.Finalize()
graph, err := sub.Finalize()
if err != nil {
t.Fatal(err)
}
@ -198,28 +200,17 @@ func TestValidateGradientsNames(t *testing.T) {
t.Fatalf("Got name %v, wanted started with Gradients/", grads0[0].Op.Name())
}
grads1 := Gradients(s, []tf.Output{y0}, []tf.Output{x})
if err := s.Err(); err != nil {
t.Fatal(err)
}
if !strings.HasPrefix(grads1[0].Op.Name(), "Gradients_1/") {
t.Fatalf("Got name %v, wanted started with Gradients_1/", grads1[0].Op.Name())
}
sub := s.SubScope("sub")
grads3 := Gradients(sub, []tf.Output{y0}, []tf.Output{x})
grads1 := Gradients(sub, []tf.Output{y0}, []tf.Output{x})
if err := s.Err(); err != nil {
t.Fatal(err)
}
if !strings.HasPrefix(grads3[0].Op.Name(), "sub/Gradients/") {
t.Fatalf("Got name %v, wanted started with sub/Gradients/", grads3[0].Op.Name())
if !strings.HasPrefix(grads1[0].Op.Name(), "sub/Gradients/") {
t.Fatalf("Got name %v, wanted started with sub/Gradients/", grads1[0].Op.Name())
}
grads4 := Gradients(sub, []tf.Output{y0}, []tf.Output{x})
if err := s.Err(); err != nil {
t.Fatal(err)
}
if !strings.HasPrefix(grads4[0].Op.Name(), "sub/Gradients_1/") {
t.Fatalf("Got name %v, wanted started with sub/Gradients_1/", grads4[0].Op.Name())
Gradients(sub, []tf.Output{y0}, []tf.Output{x})
if err := s.Err(); err == nil {
t.Error("Gradients should have failed if executed more than once for scope of the same namespace")
}
}