diff --git a/tensorflow/go/op/gradients.go b/tensorflow/go/op/gradients.go
index 2dce134d5f3..2397e40bf40 100644
--- a/tensorflow/go/op/gradients.go
+++ b/tensorflow/go/op/gradients.go
@@ -21,19 +21,14 @@ import tf "github.com/tensorflow/tensorflow/tensorflow/go"
 // Gradients adds gradients computation ops to the graph according to scope.
 //
 // Arguments:
-//	prefix: unique string prefix applied before the names of nodes added to the graph to
-//		compute gradients. If null, will use "Gradients".
 //	y: output of the function to derive
 //	x: inputs of the function for which partial derivatives are computed
 //	dx: if not null, the partial derivatives of some loss function L w.r.t. y
 //
 // return the partial derivatives
-func Gradients(scope *Scope, prefix string, y []tf.Output, x []tf.Output, dx ...tf.Output) (output []tf.Output) {
+func Gradients(scope *Scope, y []tf.Output, x []tf.Output, dx ...tf.Output) (output []tf.Output) {
 	var err error
-	if prefix == "" {
-		prefix = "Gradients"
-	}
-	if output, err = scope.graph.AddGradients(scope.opName(scope.uniqueName(prefix)), y, x, dx); err != nil {
+	if output, err = scope.graph.AddGradients(scope.opName(scope.uniqueName("Gradients")), y, x, dx); err != nil {
 		scope.UpdateErr("Gradients", err)
 		return
 	}
diff --git a/tensorflow/go/op/gradients_test.go b/tensorflow/go/op/gradients_test.go
index 7d03fada745..2bcb3e88ebe 100644
--- a/tensorflow/go/op/gradients_test.go
+++ b/tensorflow/go/op/gradients_test.go
@@ -33,7 +33,7 @@ func TestAddGradients(t *testing.T) {
 		y2 = AddN(s.SubScope("y2"), []tf.Output{y0, x2})
 	)
 
-	grads0 := Gradients(s, "", []tf.Output{y1}, []tf.Output{x1})
+	grads0 := Gradients(s, []tf.Output{y1}, []tf.Output{x1})
 	if err := s.Err(); err != nil {
 		t.Fatal(err)
 	}
@@ -44,7 +44,7 @@ func TestAddGradients(t *testing.T) {
 		t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), tf.Float)
 	}
 
-	grads1 := Gradients(s, "", []tf.Output{y2}, []tf.Output{x1, x2})
+	grads1 := Gradients(s, []tf.Output{y2}, []tf.Output{x1, x2})
 	if err := s.Err(); err != nil {
 		t.Fatal(err)
 	}
@@ -98,7 +98,7 @@ func TestAddGradientsSums(t *testing.T) {
 		y1 = Square(s.SubScope("y1"), y0)
 	)
 
-	grad := Gradients(s, "", []tf.Output{y0, y1}, []tf.Output{x})
+	grad := Gradients(s, []tf.Output{y0, y1}, []tf.Output{x})
 	if err := s.Err(); err != nil {
 		t.Fatal(err)
 	}
@@ -139,7 +139,7 @@ func TestAddGradientsWithInitialValues(t *testing.T) {
 		y1 = Square(s.SubScope("y1"), y0)
 	)
 
-	grads0 := Gradients(s, "", []tf.Output{y1}, []tf.Output{y0})
+	grads0 := Gradients(s, []tf.Output{y1}, []tf.Output{y0})
 	if err := s.Err(); err != nil {
 		t.Fatal(err)
 	}
@@ -150,7 +150,7 @@ func TestAddGradientsWithInitialValues(t *testing.T) {
 		t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), tf.Float)
 	}
 
-	grads1 := Gradients(s, "", []tf.Output{y0}, []tf.Output{x}, grads0[0])
+	grads1 := Gradients(s, []tf.Output{y0}, []tf.Output{x}, grads0[0])
 	if err := s.Err(); err != nil {
 		t.Fatal(err)
 	}
@@ -190,7 +190,7 @@ func TestValidateGradientsNames(t *testing.T) {
 		y0 = Square(s.SubScope("y0"), x)
 	)
 
-	grads0 := Gradients(s, "", []tf.Output{y0}, []tf.Output{x})
+	grads0 := Gradients(s, []tf.Output{y0}, []tf.Output{x})
 	if err := s.Err(); err != nil {
 		t.Fatal(err)
 	}
@@ -198,7 +198,7 @@ func TestValidateGradientsNames(t *testing.T) {
 		t.Fatalf("Got name %v, wanted started with Gradients/", grads0[0].Op.Name())
 	}
 
-	grads1 := Gradients(s, "", []tf.Output{y0}, []tf.Output{x})
+	grads1 := Gradients(s, []tf.Output{y0}, []tf.Output{x})
 	if err := s.Err(); err != nil {
 		t.Fatal(err)
 	}
@@ -206,28 +206,20 @@ func TestValidateGradientsNames(t *testing.T) {
 		t.Fatalf("Got name %v, wanted started with Gradients_1/", grads1[0].Op.Name())
 	}
 
-	grads2 := Gradients(s, "more_gradients", []tf.Output{y0}, []tf.Output{x})
-	if err := s.Err(); err != nil {
-		t.Fatal(err)
-	}
-	if !strings.HasPrefix(grads2[0].Op.Name(), "more_gradients/") {
-		t.Fatalf("Got name %v, wanted started with more_gradients/", grads2[0].Op.Name())
-	}
-
 	sub := s.SubScope("sub")
-	grads3 := Gradients(sub, "even_more_gradients", []tf.Output{y0}, []tf.Output{x})
+	grads3 := Gradients(sub, []tf.Output{y0}, []tf.Output{x})
 	if err := s.Err(); err != nil {
 		t.Fatal(err)
 	}
-	if !strings.HasPrefix(grads3[0].Op.Name(), "sub/even_more_gradients/") {
-		t.Fatalf("Got name %v, wanted started with sub/even_more_gradients/", grads3[0].Op.Name())
+	if !strings.HasPrefix(grads3[0].Op.Name(), "sub/Gradients/") {
+		t.Fatalf("Got name %v, wanted started with sub/Gradients/", grads3[0].Op.Name())
 	}
 
-	grads4 := Gradients(sub, "even_more_gradients", []tf.Output{y0}, []tf.Output{x})
+	grads4 := Gradients(sub, []tf.Output{y0}, []tf.Output{x})
 	if err := s.Err(); err != nil {
 		t.Fatal(err)
 	}
-	if !strings.HasPrefix(grads4[0].Op.Name(), "sub/even_more_gradients_1/") {
-		t.Fatalf("Got name %v, wanted started with sub/even_more_gradients_1/", grads4[0].Op.Name())
+	if !strings.HasPrefix(grads4[0].Op.Name(), "sub/Gradients_1/") {
+		t.Fatalf("Got name %v, wanted started with sub/Gradients_1/", grads4[0].Op.Name())
 	}
 }