diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index 27dc2d84c7a..6fe2b6b86dd 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -147,18 +147,11 @@ func (g *Graph) Operations() []Operation { return ops } -// AddGradients adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, -// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... -// This is a simplified version of AddGradientsWithPrefix() without prefix -func (g *Graph) AddGradients(y []Output, x []Output, dx []Output) ([]Output, error) { - return g.AddGradientsWithPrefix("", y, x, dx) -} - // AddGradientsWithPrefix adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, // i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... -// This is a variant of AddGradients that allows to caller to pass a custom -// name prefix to the operations added to a graph to compute the gradients. -func (g *Graph) AddGradientsWithPrefix(prefix string, y []Output, x []Output, dx []Output) ([]Output, error) { +// This methods allows to caller to pass a custom name prefix to the operations +// added to a graph to compute the gradients. +func (g *Graph) AddGradients(prefix string, y []Output, x []Output, dx []Output) ([]Output, error) { var ( cprefix = C.CString(prefix) diff --git a/tensorflow/go/graph_test.go b/tensorflow/go/graph_test.go index d9126f36acf..d8f32dbaa93 100644 --- a/tensorflow/go/graph_test.go +++ b/tensorflow/go/graph_test.go @@ -119,7 +119,7 @@ func TestGraphAddGradients(t *testing.T) { } y2 := op2.Output(0) - grads0, err := g.AddGradients([]Output{y1}, []Output{x1}, nil) + grads0, err := g.AddGradients("", []Output{y1}, []Output{x1}, nil) if err != nil { t.Fatal(err) } @@ -130,7 +130,7 @@ func TestGraphAddGradients(t *testing.T) { t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), Float) } - grads1, err := g.AddGradients([]Output{y2}, []Output{x1, x2}, nil) + grads1, err := g.AddGradients("", []Output{y2}, []Output{x1, x2}, nil) if err != nil { t.Fatal(err) } @@ -194,7 +194,7 @@ func TestGraphAddGradientsSums(t *testing.T) { }) y1 := op1.Output(0) - grad, err := g.AddGradients([]Output{y0, y1}, []Output{x}, nil) + grad, err := g.AddGradients("", []Output{y0, y1}, []Output{x}, nil) if err != nil { t.Fatal(err) } @@ -245,7 +245,7 @@ func TestGraphAddGradientsWithInitialValuesToGraph(t *testing.T) { } y1 := op1.Output(0) - grads0, err := g.AddGradients([]Output{y1}, []Output{y0}, nil) + grads0, err := g.AddGradients("", []Output{y1}, []Output{y0}, nil) if err != nil { t.Fatal(err) } @@ -256,7 +256,7 @@ func TestGraphAddGradientsWithInitialValuesToGraph(t *testing.T) { t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), Float) } - grads1, err := g.AddGradients([]Output{y0}, []Output{x}, []Output{grads0[0]}) + grads1, err := g.AddGradients("", []Output{y0}, []Output{x}, []Output{grads0[0]}) if err != nil { t.Fatal(err) } @@ -301,7 +301,7 @@ func TestGraphValidateGradientsNames(t *testing.T) { } y0 := op0.Output(0) - grads0, err := g.AddGradients([]Output{y0}, []Output{x}, nil) + grads0, err := g.AddGradients("", []Output{y0}, []Output{x}, nil) if err != nil { t.Fatal(err) } @@ -309,7 +309,7 @@ func TestGraphValidateGradientsNames(t *testing.T) { t.Fatalf("Got name %v, wanted started with gradients/", grads0[0].Op.Name()) } - grads1, err := g.AddGradients([]Output{y0}, []Output{x}, nil) + grads1, err := g.AddGradients("", []Output{y0}, []Output{x}, nil) if err != nil { t.Fatal(err) } @@ -317,7 +317,7 @@ func TestGraphValidateGradientsNames(t *testing.T) { t.Fatalf("Got name %v, wanted started with gradients_1/", grads1[0].Op.Name()) } - grads2, err := g.AddGradientsWithPrefix("more_gradients", []Output{y0}, []Output{x}, nil) + grads2, err := g.AddGradients("more_gradients", []Output{y0}, []Output{x}, nil) if err != nil { t.Fatal(err) } @@ -325,7 +325,7 @@ func TestGraphValidateGradientsNames(t *testing.T) { t.Fatalf("Got name %v, wanted started with more_gradients/", grads2[0].Op.Name()) } - grads3, err := g.AddGradientsWithPrefix("even_more_gradients", []Output{y0}, []Output{x}, nil) + grads3, err := g.AddGradients("even_more_gradients", []Output{y0}, []Output{x}, nil) if err != nil { t.Fatal(err) }