Combine AddGradients and AddGradientsWithPrefix Methods in golang
This commit is contained in:
parent
ccccbe7259
commit
a9a6c8efec
tensorflow/go
@ -147,18 +147,11 @@ func (g *Graph) Operations() []Operation {
|
||||
return ops
|
||||
}
|
||||
|
||||
// AddGradients adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s,
|
||||
// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
|
||||
// This is a simplified version of AddGradientsWithPrefix() without prefix
|
||||
func (g *Graph) AddGradients(y []Output, x []Output, dx []Output) ([]Output, error) {
|
||||
return g.AddGradientsWithPrefix("", y, x, dx)
|
||||
}
|
||||
|
||||
// AddGradientsWithPrefix adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s,
|
||||
// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
|
||||
// This is a variant of AddGradients that allows to caller to pass a custom
|
||||
// name prefix to the operations added to a graph to compute the gradients.
|
||||
func (g *Graph) AddGradientsWithPrefix(prefix string, y []Output, x []Output, dx []Output) ([]Output, error) {
|
||||
// This methods allows to caller to pass a custom name prefix to the operations
|
||||
// added to a graph to compute the gradients.
|
||||
func (g *Graph) AddGradients(prefix string, y []Output, x []Output, dx []Output) ([]Output, error) {
|
||||
var (
|
||||
cprefix = C.CString(prefix)
|
||||
|
||||
|
@ -119,7 +119,7 @@ func TestGraphAddGradients(t *testing.T) {
|
||||
}
|
||||
y2 := op2.Output(0)
|
||||
|
||||
grads0, err := g.AddGradients([]Output{y1}, []Output{x1}, nil)
|
||||
grads0, err := g.AddGradients("", []Output{y1}, []Output{x1}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -130,7 +130,7 @@ func TestGraphAddGradients(t *testing.T) {
|
||||
t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), Float)
|
||||
}
|
||||
|
||||
grads1, err := g.AddGradients([]Output{y2}, []Output{x1, x2}, nil)
|
||||
grads1, err := g.AddGradients("", []Output{y2}, []Output{x1, x2}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -194,7 +194,7 @@ func TestGraphAddGradientsSums(t *testing.T) {
|
||||
})
|
||||
y1 := op1.Output(0)
|
||||
|
||||
grad, err := g.AddGradients([]Output{y0, y1}, []Output{x}, nil)
|
||||
grad, err := g.AddGradients("", []Output{y0, y1}, []Output{x}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -245,7 +245,7 @@ func TestGraphAddGradientsWithInitialValuesToGraph(t *testing.T) {
|
||||
}
|
||||
y1 := op1.Output(0)
|
||||
|
||||
grads0, err := g.AddGradients([]Output{y1}, []Output{y0}, nil)
|
||||
grads0, err := g.AddGradients("", []Output{y1}, []Output{y0}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -256,7 +256,7 @@ func TestGraphAddGradientsWithInitialValuesToGraph(t *testing.T) {
|
||||
t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), Float)
|
||||
}
|
||||
|
||||
grads1, err := g.AddGradients([]Output{y0}, []Output{x}, []Output{grads0[0]})
|
||||
grads1, err := g.AddGradients("", []Output{y0}, []Output{x}, []Output{grads0[0]})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -301,7 +301,7 @@ func TestGraphValidateGradientsNames(t *testing.T) {
|
||||
}
|
||||
y0 := op0.Output(0)
|
||||
|
||||
grads0, err := g.AddGradients([]Output{y0}, []Output{x}, nil)
|
||||
grads0, err := g.AddGradients("", []Output{y0}, []Output{x}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -309,7 +309,7 @@ func TestGraphValidateGradientsNames(t *testing.T) {
|
||||
t.Fatalf("Got name %v, wanted started with gradients/", grads0[0].Op.Name())
|
||||
}
|
||||
|
||||
grads1, err := g.AddGradients([]Output{y0}, []Output{x}, nil)
|
||||
grads1, err := g.AddGradients("", []Output{y0}, []Output{x}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -317,7 +317,7 @@ func TestGraphValidateGradientsNames(t *testing.T) {
|
||||
t.Fatalf("Got name %v, wanted started with gradients_1/", grads1[0].Op.Name())
|
||||
}
|
||||
|
||||
grads2, err := g.AddGradientsWithPrefix("more_gradients", []Output{y0}, []Output{x}, nil)
|
||||
grads2, err := g.AddGradients("more_gradients", []Output{y0}, []Output{x}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -325,7 +325,7 @@ func TestGraphValidateGradientsNames(t *testing.T) {
|
||||
t.Fatalf("Got name %v, wanted started with more_gradients/", grads2[0].Op.Name())
|
||||
}
|
||||
|
||||
grads3, err := g.AddGradientsWithPrefix("even_more_gradients", []Output{y0}, []Output{x}, nil)
|
||||
grads3, err := g.AddGradients("even_more_gradients", []Output{y0}, []Output{x}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user