go: Add input mapping option when importing Graph

This commit is contained in:
Chris Tessum 2020-05-01 11:12:54 -05:00
parent 3f3ab6a249
commit 2328b196ba
2 changed files with 97 additions and 0 deletions

View File

@ -61,9 +61,33 @@ type GraphImportOptions struct {
// Execution device
Device string
// inputMapping defines a mapping between Outputs in the graph
// and Outputs they should be replaced with.
inputMapping map[struct {
Name string
Index int
}]Output
// TODO: extend this structure to support more options from TF_ImportGraphDefOptions
}
// AddInputMapping adds a mapping between an Output in the imported graph
// and an Ouput in the destination graph that it should be replaced with,
// where src:srcIndex is the name of the Operation and Output index to
// replace and dst is the output to replace it with.
func (o *GraphImportOptions) AddInputMapping(src string, srcIndex int, dst Output) {
if o.inputMapping == nil {
o.inputMapping = make(map[struct {
Name string
Index int
}]Output)
}
o.inputMapping[struct {
Name string
Index int
}{src, srcIndex}] = dst
}
// NewGraph returns a new Graph.
func NewGraph() *Graph {
g := &Graph{C.TF_NewGraph()}
@ -122,6 +146,12 @@ func (g *Graph) ImportWithOptions(def []byte, options GraphImportOptions) error
C.TF_ImportGraphDefOptionsSetDefaultDevice(opts, cdev)
}
for src, dst := range options.inputMapping {
cSrcName := C.CString(src.Name)
C.TF_ImportGraphDefOptionsAddInputMapping(opts, cSrcName, C.int(src.Index), dst.c())
C.free(unsafe.Pointer(cSrcName))
}
buf := C.TF_NewBuffer()
defer C.TF_DeleteBuffer(buf)
buf.length = C.size_t(len(def))

View File

@ -82,6 +82,73 @@ func TestGraphWriteToAndImport(t *testing.T) {
}
}
func TestGraphInputMapping(t *testing.T) {
// Construct a graph
g := NewGraph()
v, err := NewTensor(int64(1))
if err != nil {
t.Fatal(err)
}
input, err := Placeholder(g, "input", v.DataType())
if err != nil {
t.Fatal(err)
}
neg, err := Neg(g, "neg", input)
if err != nil {
t.Fatal(err)
}
// Serialize the graph
buf := new(bytes.Buffer)
if _, err := g.WriteTo(buf); err != nil {
t.Fatal(err)
}
g = NewGraph()
v, err = NewTensor(int64(1))
if err != nil {
t.Fatal(err)
}
replacement, err := Placeholder(g, "replacement", v.DataType())
if err != nil {
t.Fatal(err)
}
options := GraphImportOptions{
Prefix: "imported",
}
options.AddInputMapping("input", 0, replacement)
// Import it into the same graph, with a prefix and replacement
if err := g.ImportWithOptions(buf.Bytes(), options); err != nil {
t.Error(err)
}
if err := hasOperations(g, "replacement", "imported/neg"); err != nil {
t.Error(err)
}
sess, err := NewSession(g, nil)
if err != nil {
t.Fatal(err)
}
neg = g.Operation("imported/neg").Output(0)
outputs, err := sess.Run(
map[Output]*Tensor{replacement: v},
[]Output{neg},
nil)
if err != nil {
t.Fatal(err)
}
if len(outputs) != 1 {
t.Fatal(len(outputs))
}
if outputs[0].Value().(int64) != -1 {
t.Fatalf("Got %v, wanted int64 -1", outputs[0].Value())
}
}
func TestGraphAddGradients(t *testing.T) {
g := NewGraph()
x1, err := Placeholder(g, "x1", Float)