go: Add input mapping option when importing Graph
This commit is contained in:
parent
3f3ab6a249
commit
2328b196ba
@ -61,9 +61,33 @@ type GraphImportOptions struct {
|
||||
// Execution device
|
||||
Device string
|
||||
|
||||
// inputMapping defines a mapping between Outputs in the graph
|
||||
// and Outputs they should be replaced with.
|
||||
inputMapping map[struct {
|
||||
Name string
|
||||
Index int
|
||||
}]Output
|
||||
|
||||
// TODO: extend this structure to support more options from TF_ImportGraphDefOptions
|
||||
}
|
||||
|
||||
// AddInputMapping adds a mapping between an Output in the imported graph
|
||||
// and an Ouput in the destination graph that it should be replaced with,
|
||||
// where src:srcIndex is the name of the Operation and Output index to
|
||||
// replace and dst is the output to replace it with.
|
||||
func (o *GraphImportOptions) AddInputMapping(src string, srcIndex int, dst Output) {
|
||||
if o.inputMapping == nil {
|
||||
o.inputMapping = make(map[struct {
|
||||
Name string
|
||||
Index int
|
||||
}]Output)
|
||||
}
|
||||
o.inputMapping[struct {
|
||||
Name string
|
||||
Index int
|
||||
}{src, srcIndex}] = dst
|
||||
}
|
||||
|
||||
// NewGraph returns a new Graph.
|
||||
func NewGraph() *Graph {
|
||||
g := &Graph{C.TF_NewGraph()}
|
||||
@ -122,6 +146,12 @@ func (g *Graph) ImportWithOptions(def []byte, options GraphImportOptions) error
|
||||
C.TF_ImportGraphDefOptionsSetDefaultDevice(opts, cdev)
|
||||
}
|
||||
|
||||
for src, dst := range options.inputMapping {
|
||||
cSrcName := C.CString(src.Name)
|
||||
C.TF_ImportGraphDefOptionsAddInputMapping(opts, cSrcName, C.int(src.Index), dst.c())
|
||||
C.free(unsafe.Pointer(cSrcName))
|
||||
}
|
||||
|
||||
buf := C.TF_NewBuffer()
|
||||
defer C.TF_DeleteBuffer(buf)
|
||||
buf.length = C.size_t(len(def))
|
||||
|
@ -82,6 +82,73 @@ func TestGraphWriteToAndImport(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphInputMapping(t *testing.T) {
|
||||
// Construct a graph
|
||||
g := NewGraph()
|
||||
v, err := NewTensor(int64(1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
input, err := Placeholder(g, "input", v.DataType())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
neg, err := Neg(g, "neg", input)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Serialize the graph
|
||||
buf := new(bytes.Buffer)
|
||||
if _, err := g.WriteTo(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
g = NewGraph()
|
||||
v, err = NewTensor(int64(1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
replacement, err := Placeholder(g, "replacement", v.DataType())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
options := GraphImportOptions{
|
||||
Prefix: "imported",
|
||||
}
|
||||
options.AddInputMapping("input", 0, replacement)
|
||||
// Import it into the same graph, with a prefix and replacement
|
||||
if err := g.ImportWithOptions(buf.Bytes(), options); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := hasOperations(g, "replacement", "imported/neg"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
sess, err := NewSession(g, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
neg = g.Operation("imported/neg").Output(0)
|
||||
|
||||
outputs, err := sess.Run(
|
||||
map[Output]*Tensor{replacement: v},
|
||||
[]Output{neg},
|
||||
nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(outputs) != 1 {
|
||||
t.Fatal(len(outputs))
|
||||
}
|
||||
if outputs[0].Value().(int64) != -1 {
|
||||
t.Fatalf("Got %v, wanted int64 -1", outputs[0].Value())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphAddGradients(t *testing.T) {
|
||||
g := NewGraph()
|
||||
x1, err := Placeholder(g, "x1", Float)
|
||||
|
Loading…
Reference in New Issue
Block a user