Stop using uniqueName for Gradients in gradient.go
To keep the code style consistent with other functions.
This commit is contained in:
parent
fbd637423b
commit
310f454e92
@ -28,7 +28,7 @@ import tf "github.com/tensorflow/tensorflow/tensorflow/go"
|
||||
// return the partial derivatives
|
||||
func Gradients(scope *Scope, y []tf.Output, x []tf.Output, dx ...tf.Output) (output []tf.Output) {
|
||||
var err error
|
||||
if output, err = scope.graph.AddGradients(scope.opName(scope.uniqueName("Gradients")), y, x, dx); err != nil {
|
||||
if output, err = scope.graph.AddGradients(scope.opName("Gradients"), y, x, dx); err != nil {
|
||||
scope.UpdateErr("Gradients", err)
|
||||
return
|
||||
}
|
||||
|
@ -44,8 +44,9 @@ func TestAddGradients(t *testing.T) {
|
||||
t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), tf.Float)
|
||||
}
|
||||
|
||||
grads1 := Gradients(s, []tf.Output{y2}, []tf.Output{x1, x2})
|
||||
if err := s.Err(); err != nil {
|
||||
sub := s.SubScope("sub")
|
||||
grads1 := Gradients(sub, []tf.Output{y2}, []tf.Output{x1, x2})
|
||||
if err := sub.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(grads1) != 2 {
|
||||
@ -58,7 +59,7 @@ func TestAddGradients(t *testing.T) {
|
||||
t.Fatalf("Got DataType %v, wanted %v", grads1[1].DataType(), tf.Float)
|
||||
}
|
||||
|
||||
graph, err := s.Finalize()
|
||||
graph, err := sub.Finalize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -150,8 +151,9 @@ func TestAddGradientsWithInitialValues(t *testing.T) {
|
||||
t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), tf.Float)
|
||||
}
|
||||
|
||||
grads1 := Gradients(s, []tf.Output{y0}, []tf.Output{x}, grads0[0])
|
||||
if err := s.Err(); err != nil {
|
||||
sub := s.SubScope("sub")
|
||||
grads1 := Gradients(sub, []tf.Output{y0}, []tf.Output{x}, grads0[0])
|
||||
if err := sub.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(grads1) != 1 {
|
||||
@ -161,7 +163,7 @@ func TestAddGradientsWithInitialValues(t *testing.T) {
|
||||
t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), tf.Float)
|
||||
}
|
||||
|
||||
graph, err := s.Finalize()
|
||||
graph, err := sub.Finalize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -198,28 +200,17 @@ func TestValidateGradientsNames(t *testing.T) {
|
||||
t.Fatalf("Got name %v, wanted started with Gradients/", grads0[0].Op.Name())
|
||||
}
|
||||
|
||||
grads1 := Gradients(s, []tf.Output{y0}, []tf.Output{x})
|
||||
if err := s.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.HasPrefix(grads1[0].Op.Name(), "Gradients_1/") {
|
||||
t.Fatalf("Got name %v, wanted started with Gradients_1/", grads1[0].Op.Name())
|
||||
}
|
||||
|
||||
sub := s.SubScope("sub")
|
||||
grads3 := Gradients(sub, []tf.Output{y0}, []tf.Output{x})
|
||||
grads1 := Gradients(sub, []tf.Output{y0}, []tf.Output{x})
|
||||
if err := s.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.HasPrefix(grads3[0].Op.Name(), "sub/Gradients/") {
|
||||
t.Fatalf("Got name %v, wanted started with sub/Gradients/", grads3[0].Op.Name())
|
||||
if !strings.HasPrefix(grads1[0].Op.Name(), "sub/Gradients/") {
|
||||
t.Fatalf("Got name %v, wanted started with sub/Gradients/", grads1[0].Op.Name())
|
||||
}
|
||||
|
||||
grads4 := Gradients(sub, []tf.Output{y0}, []tf.Output{x})
|
||||
if err := s.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.HasPrefix(grads4[0].Op.Name(), "sub/Gradients_1/") {
|
||||
t.Fatalf("Got name %v, wanted started with sub/Gradients_1/", grads4[0].Op.Name())
|
||||
Gradients(sub, []tf.Output{y0}, []tf.Output{x})
|
||||
if err := s.Err(); err == nil {
|
||||
t.Error("Gradients should have failed if executed more than once for scope of the same namespace")
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user