From 5d52b95279be57076a794c2f334c150a26566360 Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Abrams Date: Wed, 29 Nov 2017 22:25:37 -0800 Subject: [PATCH] Adds Operations() method to Graph There is currently no way to list all of the operations in a graph from the go api. This patch ads an Operations() method to retrieve the list using the existing TF_GraphNextOperation c api. The graph_test was modified to include testing this new method. Signed-off-by: Vishvananda Ishaya Abrams --- tensorflow/go/graph.go | 14 ++++++++++++++ tensorflow/go/graph_test.go | 22 +++++++++++++++++++--- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index 46c600eab17..a40aded3bf4 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -114,6 +114,20 @@ func (g *Graph) Operation(name string) *Operation { return &Operation{cop, g} } +// Operations returns a list of all operations in the graph +func (g *Graph) Operations() []Operation { + var pos C.size_t = 0 + ops := []Operation{} + for { + cop := C.TF_GraphNextOperation(g.c, &pos) + if cop == nil { + break + } + ops = append(ops, Operation{cop, g}) + } + return ops +} + // OpSpec is the specification of an Operation to be added to a Graph // (using Graph.AddOperation). type OpSpec struct { diff --git a/tensorflow/go/graph_test.go b/tensorflow/go/graph_test.go index c3120bc7203..b8d65c54f69 100644 --- a/tensorflow/go/graph_test.go +++ b/tensorflow/go/graph_test.go @@ -29,10 +29,26 @@ func hasOperations(g *Graph, ops ...string) error { missing = append(missing, op) } } - if len(missing) == 0 { - return nil + if len(missing) != 0 { + return fmt.Errorf("Graph does not have the operations %v", missing) } - return fmt.Errorf("Graph does not have the operations %v", missing) + + inList := map[string]bool{} + for _, op := range g.Operations() { + inList[op.Name()] = true + } + + for _, op := range ops { + if !inList[op] { + missing = append(missing, op) + } + } + + if len(missing) != 0 { + return fmt.Errorf("Operations %v are missing from graph.Operations()", missing) + } + + return nil } func TestGraphWriteToAndImport(t *testing.T) {