Remove from func Gradients() in gradients.go
This commit is contained in:
parent
72ba6b2271
commit
8c9d5eb52b
@ -21,19 +21,14 @@ import tf "github.com/tensorflow/tensorflow/tensorflow/go"
|
||||
// Gradients adds gradients computation ops to the graph according to scope.
|
||||
//
|
||||
// Arguments:
|
||||
// prefix: unique string prefix applied before the names of nodes added to the graph to
|
||||
// compute gradients. If null, will use "Gradients".
|
||||
// y: output of the function to derive
|
||||
// x: inputs of the function for which partial derivatives are computed
|
||||
// dx: if not null, the partial derivatives of some loss function L w.r.t. y
|
||||
//
|
||||
// return the partial derivatives
|
||||
func Gradients(scope *Scope, prefix string, y []tf.Output, x []tf.Output, dx ...tf.Output) (output []tf.Output) {
|
||||
func Gradients(scope *Scope, y []tf.Output, x []tf.Output, dx ...tf.Output) (output []tf.Output) {
|
||||
var err error
|
||||
if prefix == "" {
|
||||
prefix = "Gradients"
|
||||
}
|
||||
if output, err = scope.graph.AddGradients(scope.opName(scope.uniqueName(prefix)), y, x, dx); err != nil {
|
||||
if output, err = scope.graph.AddGradients(scope.opName(scope.uniqueName("Gradients")), y, x, dx); err != nil {
|
||||
scope.UpdateErr("Gradients", err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -33,7 +33,7 @@ func TestAddGradients(t *testing.T) {
|
||||
y2 = AddN(s.SubScope("y2"), []tf.Output{y0, x2})
|
||||
)
|
||||
|
||||
grads0 := Gradients(s, "", []tf.Output{y1}, []tf.Output{x1})
|
||||
grads0 := Gradients(s, []tf.Output{y1}, []tf.Output{x1})
|
||||
if err := s.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -44,7 +44,7 @@ func TestAddGradients(t *testing.T) {
|
||||
t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), tf.Float)
|
||||
}
|
||||
|
||||
grads1 := Gradients(s, "", []tf.Output{y2}, []tf.Output{x1, x2})
|
||||
grads1 := Gradients(s, []tf.Output{y2}, []tf.Output{x1, x2})
|
||||
if err := s.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -98,7 +98,7 @@ func TestAddGradientsSums(t *testing.T) {
|
||||
y1 = Square(s.SubScope("y1"), y0)
|
||||
)
|
||||
|
||||
grad := Gradients(s, "", []tf.Output{y0, y1}, []tf.Output{x})
|
||||
grad := Gradients(s, []tf.Output{y0, y1}, []tf.Output{x})
|
||||
if err := s.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -139,7 +139,7 @@ func TestAddGradientsWithInitialValues(t *testing.T) {
|
||||
y1 = Square(s.SubScope("y1"), y0)
|
||||
)
|
||||
|
||||
grads0 := Gradients(s, "", []tf.Output{y1}, []tf.Output{y0})
|
||||
grads0 := Gradients(s, []tf.Output{y1}, []tf.Output{y0})
|
||||
if err := s.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -150,7 +150,7 @@ func TestAddGradientsWithInitialValues(t *testing.T) {
|
||||
t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), tf.Float)
|
||||
}
|
||||
|
||||
grads1 := Gradients(s, "", []tf.Output{y0}, []tf.Output{x}, grads0[0])
|
||||
grads1 := Gradients(s, []tf.Output{y0}, []tf.Output{x}, grads0[0])
|
||||
if err := s.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -190,7 +190,7 @@ func TestValidateGradientsNames(t *testing.T) {
|
||||
y0 = Square(s.SubScope("y0"), x)
|
||||
)
|
||||
|
||||
grads0 := Gradients(s, "", []tf.Output{y0}, []tf.Output{x})
|
||||
grads0 := Gradients(s, []tf.Output{y0}, []tf.Output{x})
|
||||
if err := s.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -198,7 +198,7 @@ func TestValidateGradientsNames(t *testing.T) {
|
||||
t.Fatalf("Got name %v, wanted started with Gradients/", grads0[0].Op.Name())
|
||||
}
|
||||
|
||||
grads1 := Gradients(s, "", []tf.Output{y0}, []tf.Output{x})
|
||||
grads1 := Gradients(s, []tf.Output{y0}, []tf.Output{x})
|
||||
if err := s.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -206,28 +206,20 @@ func TestValidateGradientsNames(t *testing.T) {
|
||||
t.Fatalf("Got name %v, wanted started with Gradients_1/", grads1[0].Op.Name())
|
||||
}
|
||||
|
||||
grads2 := Gradients(s, "more_gradients", []tf.Output{y0}, []tf.Output{x})
|
||||
if err := s.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.HasPrefix(grads2[0].Op.Name(), "more_gradients/") {
|
||||
t.Fatalf("Got name %v, wanted started with more_gradients/", grads2[0].Op.Name())
|
||||
}
|
||||
|
||||
sub := s.SubScope("sub")
|
||||
grads3 := Gradients(sub, "even_more_gradients", []tf.Output{y0}, []tf.Output{x})
|
||||
grads3 := Gradients(sub, []tf.Output{y0}, []tf.Output{x})
|
||||
if err := s.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.HasPrefix(grads3[0].Op.Name(), "sub/even_more_gradients/") {
|
||||
t.Fatalf("Got name %v, wanted started with sub/even_more_gradients/", grads3[0].Op.Name())
|
||||
if !strings.HasPrefix(grads3[0].Op.Name(), "sub/Gradients/") {
|
||||
t.Fatalf("Got name %v, wanted started with sub/Gradients/", grads3[0].Op.Name())
|
||||
}
|
||||
|
||||
grads4 := Gradients(sub, "even_more_gradients", []tf.Output{y0}, []tf.Output{x})
|
||||
grads4 := Gradients(sub, []tf.Output{y0}, []tf.Output{x})
|
||||
if err := s.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.HasPrefix(grads4[0].Op.Name(), "sub/even_more_gradients_1/") {
|
||||
t.Fatalf("Got name %v, wanted started with sub/even_more_gradients_1/", grads4[0].Op.Name())
|
||||
if !strings.HasPrefix(grads4[0].Op.Name(), "sub/Gradients_1/") {
|
||||
t.Fatalf("Got name %v, wanted started with sub/Gradients_1/", grads4[0].Op.Name())
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user