diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index b3b2c9cc20a..ac28c3ac5bd 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -61,9 +61,33 @@ type GraphImportOptions struct { // Execution device Device string + // inputMapping defines a mapping between Outputs in the graph + // and Outputs they should be replaced with. + inputMapping map[struct { + Name string + Index int + }]Output + // TODO: extend this structure to support more options from TF_ImportGraphDefOptions } +// AddInputMapping adds a mapping between an Output in the imported graph +// and an Ouput in the destination graph that it should be replaced with, +// where src:srcIndex is the name of the Operation and Output index to +// replace and dst is the output to replace it with. +func (o *GraphImportOptions) AddInputMapping(src string, srcIndex int, dst Output) { + if o.inputMapping == nil { + o.inputMapping = make(map[struct { + Name string + Index int + }]Output) + } + o.inputMapping[struct { + Name string + Index int + }{src, srcIndex}] = dst +} + // NewGraph returns a new Graph. func NewGraph() *Graph { g := &Graph{C.TF_NewGraph()} @@ -122,6 +146,12 @@ func (g *Graph) ImportWithOptions(def []byte, options GraphImportOptions) error C.TF_ImportGraphDefOptionsSetDefaultDevice(opts, cdev) } + for src, dst := range options.inputMapping { + cSrcName := C.CString(src.Name) + C.TF_ImportGraphDefOptionsAddInputMapping(opts, cSrcName, C.int(src.Index), dst.c()) + C.free(unsafe.Pointer(cSrcName)) + } + buf := C.TF_NewBuffer() defer C.TF_DeleteBuffer(buf) buf.length = C.size_t(len(def)) diff --git a/tensorflow/go/graph_test.go b/tensorflow/go/graph_test.go index 067c7db5c3c..bb112303807 100644 --- a/tensorflow/go/graph_test.go +++ b/tensorflow/go/graph_test.go @@ -82,6 +82,73 @@ func TestGraphWriteToAndImport(t *testing.T) { } } +func TestGraphInputMapping(t *testing.T) { + // Construct a graph + g := NewGraph() + v, err := NewTensor(int64(1)) + if err != nil { + t.Fatal(err) + } + input, err := Placeholder(g, "input", v.DataType()) + if err != nil { + t.Fatal(err) + } + neg, err := Neg(g, "neg", input) + if err != nil { + t.Fatal(err) + } + + // Serialize the graph + buf := new(bytes.Buffer) + if _, err := g.WriteTo(buf); err != nil { + t.Fatal(err) + } + + g = NewGraph() + v, err = NewTensor(int64(1)) + if err != nil { + t.Fatal(err) + } + + replacement, err := Placeholder(g, "replacement", v.DataType()) + if err != nil { + t.Fatal(err) + } + + options := GraphImportOptions{ + Prefix: "imported", + } + options.AddInputMapping("input", 0, replacement) + // Import it into the same graph, with a prefix and replacement + if err := g.ImportWithOptions(buf.Bytes(), options); err != nil { + t.Error(err) + } + if err := hasOperations(g, "replacement", "imported/neg"); err != nil { + t.Error(err) + } + + sess, err := NewSession(g, nil) + if err != nil { + t.Fatal(err) + } + + neg = g.Operation("imported/neg").Output(0) + + outputs, err := sess.Run( + map[Output]*Tensor{replacement: v}, + []Output{neg}, + nil) + if err != nil { + t.Fatal(err) + } + if len(outputs) != 1 { + t.Fatal(len(outputs)) + } + if outputs[0].Value().(int64) != -1 { + t.Fatalf("Got %v, wanted int64 -1", outputs[0].Value()) + } +} + func TestGraphAddGradients(t *testing.T) { g := NewGraph() x1, err := Placeholder(g, "x1", Float)