Support adding gradient operations to a graph in the Go bindings

This commit is contained in:
Cibifang 2018-08-27 09:28:30 +08:00
parent 09792df012
commit ccccbe7259
2 changed files with 320 additions and 0 deletions

View File

@ -147,6 +147,73 @@ func (g *Graph) Operations() []Operation {
return ops
}
// AddGradients adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s,
// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
//
// It is shorthand for AddGradientsWithPrefix with an empty name prefix, so the
// generated gradient operations are placed under the default naming scope.
func (g *Graph) AddGradients(y []Output, x []Output, dx []Output) ([]Output, error) {
	return g.AddGradientsWithPrefix("", y, x, dx)
}
// AddGradientsWithPrefix adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s,
// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
//
// This is a variant of AddGradients that allows the caller to pass a custom
// name prefix to the operations added to the graph to compute the gradients.
//
// It returns one gradient Output per entry in x, or an error if the C API
// reports one (e.g. no gradient is registered for an op on the path).
func (g *Graph) AddGradientsWithPrefix(prefix string, y []Output, x []Output, dx []Output) ([]Output, error) {
	var (
		cprefix *C.char
		cy      = make([]C.TF_Output, len(y))
		cx      = make([]C.TF_Output, len(x))
		cdx     = make([]C.TF_Output, len(dx))
		cdy     = make([]C.TF_Output, len(x)) // one gradient result per x
		pcy     *C.TF_Output
		pcx     *C.TF_Output
		pcdx    *C.TF_Output
		pcdy    *C.TF_Output
		status  = newStatus()
	)
	// TF_AddGradientsWithPrefix expects a nil prefix (not "") when no custom
	// prefix is desired. Allocate the C string only when it is actually used,
	// and free it afterwards: C.CString allocates with malloc and the previous
	// code leaked it on every call.
	if len(prefix) != 0 {
		cprefix = C.CString(prefix)
		defer C.free(unsafe.Pointer(cprefix))
	}
	if len(y) > 0 {
		pcy = &cy[0]
		for i, o := range y {
			cy[i] = o.c()
		}
	}
	if len(x) > 0 {
		pcx = &cx[0]
		for i, o := range x {
			cx[i] = o.c()
		}
		pcdy = &cdy[0]
	}
	if len(dx) > 0 {
		pcdx = &cdx[0]
		for i, o := range dx {
			cdx[i] = o.c()
		}
	}
	C.TF_AddGradientsWithPrefix(g.c, cprefix, pcy, C.int(len(y)), pcx, C.int(len(x)), pcdx, status.c, pcdy)
	if err := status.Err(); err != nil {
		return nil, err
	}
	dy := make([]Output, len(x))
	for i, co := range cdy {
		op := &Operation{co.oper, g}
		dy[i] = Output{op, int(co.index)}
	}
	return dy, nil
}
// OpSpec is the specification of an Operation to be added to a Graph
// (using Graph.AddOperation).
type OpSpec struct {

View File

@ -19,6 +19,7 @@ package tensorflow
import (
"bytes"
"fmt"
"strings"
"testing"
)
@ -80,3 +81,255 @@ func TestGraphWriteToAndImport(t *testing.T) {
t.Error(err)
}
}
// TestGraphAddGradients builds a small graph and checks both single-target and
// multi-input gradient requests, then runs a session to verify the values.
//
// Graph under test:
//   y0 = x1^2
//   y1 = y0^2 = x1^4
//   y2 = y0 + x2 = x1^2 + x2
func TestGraphAddGradients(t *testing.T) {
g := NewGraph()
x1, err := Placeholder(g, "x1", Float)
if err != nil {
t.Fatal(err)
}
x2, err := Placeholder(g, "x2", Float)
if err != nil {
t.Fatal(err)
}
op0, err := g.AddOperation(OpSpec{
Type: "Square",
Name: "y0",
Input: []Input{x1},
})
if err != nil {
t.Fatal(err)
}
y0 := op0.Output(0)
op1, err := g.AddOperation(OpSpec{
Type: "Square",
Name: "y1",
Input: []Input{y0},
})
if err != nil {
t.Fatal(err)
}
y1 := op1.Output(0)
op2, err := g.AddOperation(OpSpec{
Type: "AddN",
Input: []Input{OutputList([]Output{y0, x2})},
})
if err != nil {
t.Fatal(err)
}
y2 := op2.Output(0)
// grads0[0] = dy1/dx1 = 4*x1^3.
grads0, err := g.AddGradients([]Output{y1}, []Output{x1}, nil)
if err != nil {
t.Fatal(err)
}
if len(grads0) != 1 {
t.Fatal(len(grads0))
}
if grads0[0].DataType() != Float {
t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), Float)
}
// grads1 = [dy2/dx1, dy2/dx2] = [2*x1, 1].
grads1, err := g.AddGradients([]Output{y2}, []Output{x1, x2}, nil)
if err != nil {
t.Fatal(err)
}
if len(grads1) != 2 {
t.Fatal(len(grads1))
}
if grads1[0].DataType() != Float {
t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), Float)
}
if grads1[1].DataType() != Float {
t.Fatalf("Got DataType %v, wanted %v", grads1[1].DataType(), Float)
}
sess, err := NewSession(g, nil)
if err != nil {
t.Fatal(err)
}
// Evaluate all three gradients at x1=3, x2=2.
c1, _ := NewTensor(float32(3.0))
c2, _ := NewTensor(float32(2.0))
outputs, err := sess.Run(
map[Output]*Tensor{x1: c1, x2: c2},
[]Output{grads0[0], grads1[0], grads1[1]},
nil)
if err != nil {
t.Fatal(err)
}
if len(outputs) != 3 {
t.Fatal(len(outputs))
}
// dy1/dx1 = 4*3^3 = 108.
if outputs[0].Value().(float32) != 108.0 {
t.Fatalf("Got %v, wanted float 108.0", outputs[0].Value())
}
// dy2/dx1 = 2*3 = 6.
if outputs[1].Value().(float32) != 6.0 {
t.Fatalf("Got %v, wanted float 6.0", outputs[1].Value())
}
// dy2/dx2 = 1.
if outputs[2].Value().(float32) != 1.0 {
t.Fatalf("Got %v, wanted float 1.0", outputs[2].Value())
}
}
// TestGraphAddGradientsSums verifies that requesting gradients of several ys
// w.r.t. one x yields the derivative of their sum:
//   y0 = x^2, y1 = y0^2 = x^4
//   d(y0+y1)/dx = 2x + 4x^3 = 6 + 108 = 114 at x=3.
func TestGraphAddGradientsSums(t *testing.T) {
	g := NewGraph()
	x, err := Placeholder(g, "x", Float)
	if err != nil {
		t.Fatal(err)
	}
	op0, err := g.AddOperation(OpSpec{
		Type:  "Square",
		Name:  "y0",
		Input: []Input{x},
	})
	if err != nil {
		t.Fatal(err)
	}
	y0 := op0.Output(0)
	op1, err := g.AddOperation(OpSpec{
		Type:  "Square",
		Name:  "y1",
		Input: []Input{y0},
	})
	// Bug fix: this error was previously ignored, so a failed AddOperation
	// would surface later as a confusing nil-dereference.
	if err != nil {
		t.Fatal(err)
	}
	y1 := op1.Output(0)
	grad, err := g.AddGradients([]Output{y0, y1}, []Output{x}, nil)
	if err != nil {
		t.Fatal(err)
	}
	if len(grad) != 1 {
		t.Fatal(len(grad))
	}
	if grad[0].DataType() != Float {
		t.Fatalf("Got DataType %v, wanted %v", grad[0].DataType(), Float)
	}
	sess, err := NewSession(g, nil)
	if err != nil {
		t.Fatal(err)
	}
	c, _ := NewTensor(float32(3.0))
	outputs, err := sess.Run(
		map[Output]*Tensor{x: c},
		[]Output{grad[0]},
		nil)
	if err != nil {
		t.Fatal(err)
	}
	if outputs[0].Value().(float32) != 114.0 {
		t.Fatalf("Got %v, wanted float 114.0", outputs[0].Value())
	}
}
// TestGraphAddGradientsWithInitialValuesToGraph checks chaining gradients via
// the dx argument: first compute dy1/dy0, then feed it as the initial value
// when computing dy0/dx, yielding dy1/dx by the chain rule:
//   y0 = x^2, y1 = y0^2  =>  dy1/dx = 4x^3 = 108 at x=3.
func TestGraphAddGradientsWithInitialValuesToGraph(t *testing.T) {
	g := NewGraph()
	x, err := Placeholder(g, "x", Float)
	// Bug fix: this error was previously ignored before err was reassigned.
	if err != nil {
		t.Fatal(err)
	}
	op0, err := g.AddOperation(OpSpec{
		Type:  "Square",
		Name:  "y0",
		Input: []Input{x},
	})
	if err != nil {
		t.Fatal(err)
	}
	y0 := op0.Output(0)
	op1, err := g.AddOperation(OpSpec{
		Type:  "Square",
		Name:  "y1",
		Input: []Input{y0},
	})
	if err != nil {
		t.Fatal(err)
	}
	y1 := op1.Output(0)
	grads0, err := g.AddGradients([]Output{y1}, []Output{y0}, nil)
	if err != nil {
		t.Fatal(err)
	}
	if len(grads0) != 1 {
		t.Fatal(len(grads0))
	}
	if grads0[0].DataType() != Float {
		t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), Float)
	}
	grads1, err := g.AddGradients([]Output{y0}, []Output{x}, []Output{grads0[0]})
	if err != nil {
		t.Fatal(err)
	}
	if len(grads1) != 1 {
		t.Fatal(len(grads1))
	}
	if grads1[0].DataType() != Float {
		t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), Float)
	}
	sess, err := NewSession(g, nil)
	if err != nil {
		t.Fatal(err)
	}
	c, _ := NewTensor(float32(3.0))
	outputs, err := sess.Run(
		map[Output]*Tensor{x: c},
		[]Output{grads1[0]},
		nil)
	if err != nil {
		t.Fatal(err)
	}
	if outputs[0].Value().(float32) != 108.0 {
		t.Fatalf("Got %v, wanted float 108.0", outputs[0].Value())
	}
}
// TestGraphValidateGradientsNames checks the naming of gradient operations:
// repeated calls without a prefix get the auto-deduplicated "gradients",
// "gradients_1", ... scopes, while explicit prefixes are used verbatim.
func TestGraphValidateGradientsNames(t *testing.T) {
	g := NewGraph()
	x, err := Placeholder(g, "x", Float)
	if err != nil {
		t.Fatal(err)
	}
	op0, err := g.AddOperation(OpSpec{
		Type:  "Square",
		Name:  "y0",
		Input: []Input{x},
	})
	if err != nil {
		t.Fatal(err)
	}
	y0 := op0.Output(0)

	// checkPrefix asserts that exactly one gradient came back and that its
	// operation name starts with the expected scope prefix. Centralizing the
	// checks also guards the previously-unchecked grads[0] indexing.
	checkPrefix := func(grads []Output, err error, prefix string) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
		if len(grads) != 1 {
			t.Fatal(len(grads))
		}
		if !strings.HasPrefix(grads[0].Op.Name(), prefix) {
			t.Fatalf("Got name %v, wanted started with %v", grads[0].Op.Name(), prefix)
		}
	}

	grads0, err := g.AddGradients([]Output{y0}, []Output{x}, nil)
	checkPrefix(grads0, err, "gradients/")

	grads1, err := g.AddGradients([]Output{y0}, []Output{x}, nil)
	checkPrefix(grads1, err, "gradients_1/")

	grads2, err := g.AddGradientsWithPrefix("more_gradients", []Output{y0}, []Output{x}, nil)
	checkPrefix(grads2, err, "more_gradients/")

	grads3, err := g.AddGradientsWithPrefix("even_more_gradients", []Output{y0}, []Output{x}, nil)
	checkPrefix(grads3, err, "even_more_gradients/")
}