Go: Add PartialRun support.
Change: 147783087
This commit is contained in:
parent
5cdf2afa52
commit
c35e3b523a
@ -16,4 +16,14 @@ package tensorflow
|
||||
|
||||
// #cgo LDFLAGS: -ltensorflow
|
||||
// #cgo CFLAGS: -I${SRCDIR}/../../
|
||||
//
|
||||
// // TODO(ashankar): Remove this after TensorFlow 1.1 has been released.
|
||||
// // Till then, the TensorFlow C API binary releases do not contain
|
||||
// // the TF_DeletePRunHandle symbol. We work around that by
|
||||
// // implementing the equivalent in session.cpp
|
||||
// extern void tfDeletePRunHandle(const char*);
|
||||
import "C"
|
||||
|
||||
func deletePRunHandle(h *C.char) {
|
||||
C.tfDeletePRunHandle(h)
|
||||
}
|
||||
|
24
tensorflow/go/session.cpp
Normal file
24
tensorflow/go/session.cpp
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// TODO(ashankar): Remove this file when TensorFlow 1.1 is released.
|
||||
// See lib.go for details.
|
||||
|
||||
extern "C" {
|
||||
extern void tfDeletePRunHandle(const char* h);
|
||||
}
|
||||
|
||||
void tfDeletePRunHandle(const char* h) {
|
||||
delete[] h;
|
||||
}
|
@ -59,14 +59,14 @@ func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Run the graph with the associated session starting with the supplied inputs.
|
||||
// inputs and outputs may be set to nil. Runs, but does not return Tensors
|
||||
// for operations specified in targets.
|
||||
// Run the graph with the associated session starting with the supplied feeds
|
||||
// to compute the value of the requested fetches. Runs, but does not return
|
||||
// Tensors for operations specified in targets.
|
||||
//
|
||||
// On success, returns the Tensor outputs in the same order as supplied in
|
||||
// the outputs argument. If outputs is set to nil, the returned Tensor outputs
|
||||
// On success, returns the fetched Tensors in the same order as supplied in
|
||||
// the fetches argument. If fetches is set to nil, the returned Tensor fetches
|
||||
// is empty.
|
||||
func (s *Session) Run(inputs map[Output]*Tensor, outputs []Output, targets []*Operation) ([]*Tensor, error) {
|
||||
func (s *Session) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) {
|
||||
s.mu.Lock()
|
||||
if s.c == nil {
|
||||
s.mu.Unlock()
|
||||
@ -76,56 +76,126 @@ func (s *Session) Run(inputs map[Output]*Tensor, outputs []Output, targets []*Op
|
||||
s.mu.Unlock()
|
||||
defer s.wg.Done()
|
||||
|
||||
var inputPorts []C.TF_Output
|
||||
var inputValues []*C.TF_Tensor
|
||||
if inputs != nil {
|
||||
for port, tensor := range inputs {
|
||||
inputPorts = append(inputPorts, port.c())
|
||||
inputValues = append(inputValues, tensor.c)
|
||||
}
|
||||
}
|
||||
|
||||
var outputPorts []C.TF_Output
|
||||
for _, port := range outputs {
|
||||
outputPorts = append(outputPorts, port.c())
|
||||
}
|
||||
outputValues := make([]*C.TF_Tensor, len(outputs))
|
||||
var cTargets []*C.TF_Operation
|
||||
for _, target := range targets {
|
||||
cTargets = append(cTargets, target.c)
|
||||
}
|
||||
|
||||
c := newCRunArgs(feeds, fetches, targets)
|
||||
status := newStatus()
|
||||
var inputPortsPtr *C.TF_Output
|
||||
var inputValuesPtr **C.TF_Tensor
|
||||
if len(inputPorts) > 0 {
|
||||
inputPortsPtr = &inputPorts[0]
|
||||
inputValuesPtr = &inputValues[0]
|
||||
}
|
||||
|
||||
var outputPortsPtr *C.TF_Output
|
||||
var outputValuesPtr **C.TF_Tensor
|
||||
if len(outputPorts) > 0 {
|
||||
outputPortsPtr = &outputPorts[0]
|
||||
outputValuesPtr = &outputValues[0]
|
||||
}
|
||||
|
||||
var cTargetsPtr **C.TF_Operation
|
||||
if len(cTargets) > 0 {
|
||||
cTargetsPtr = &cTargets[0]
|
||||
}
|
||||
|
||||
C.TF_SessionRun(s.c, nil, inputPortsPtr, inputValuesPtr, C.int(len(inputPorts)), outputPortsPtr, outputValuesPtr, C.int(len(outputPorts)), cTargetsPtr, C.int(len(cTargets)), nil, status.c)
|
||||
C.TF_SessionRun(s.c, nil,
|
||||
ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)),
|
||||
ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)),
|
||||
ptrOperation(c.targets), C.int(len(targets)),
|
||||
nil, status.c)
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.toGo(), nil
|
||||
}
|
||||
|
||||
tensors := make([]*Tensor, len(outputValues))
|
||||
for i, val := range outputValues {
|
||||
tensors[i] = newTensorFromC(val)
|
||||
// PartialRun enables incremental evaluation of graphs.
|
||||
//
|
||||
// PartialRun allows the caller to pause the evaluation of a graph, run
|
||||
// arbitrary code that depends on the intermediate computation of the graph,
|
||||
// and then resume graph execution. The results of the arbitrary code can be
|
||||
// fed into the graph when resuming execution. In contrast, Session.Run
|
||||
// executes the graph to compute the requested fetches using the provided feeds
|
||||
// and discards all intermediate state (e.g., value of intermediate tensors)
|
||||
// when it returns.
|
||||
//
|
||||
// For example, consider a graph for unsupervised training of a neural network
|
||||
// model. PartialRun can be used to pause execution after the forward pass of
|
||||
// the network, let the caller actuate the output (e.g., play a game, actuate a
|
||||
// robot etc.), determine the error/loss and then feed this calculated loss
|
||||
// when resuming the backward pass of the graph.
|
||||
type PartialRun struct {
|
||||
session *Session
|
||||
handle *C.char
|
||||
}
|
||||
|
||||
// Run resumes execution of the graph to compute the requested fetches and
|
||||
// targets with the provided feeds.
|
||||
func (pr *PartialRun) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) {
|
||||
var (
|
||||
c = newCRunArgs(feeds, fetches, targets)
|
||||
status = newStatus()
|
||||
s = pr.session
|
||||
)
|
||||
s.mu.Lock()
|
||||
if s.c == nil {
|
||||
s.mu.Unlock()
|
||||
return nil, errors.New("session is closed")
|
||||
}
|
||||
s.wg.Add(1)
|
||||
s.mu.Unlock()
|
||||
defer s.wg.Done()
|
||||
|
||||
C.TF_SessionPRun(s.c, pr.handle,
|
||||
ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)),
|
||||
ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)),
|
||||
ptrOperation(c.targets), C.int(len(targets)),
|
||||
status.c)
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.toGo(), nil
|
||||
}
|
||||
|
||||
// NewPartialRun sets up the graph for incremental evaluation.
|
||||
//
|
||||
// All values of feeds, fetches and targets that may be provided to Run calls
|
||||
// on the returned PartialRun need to be provided to NewPartialRun.
|
||||
//
|
||||
// See documentation for the PartialRun type.
|
||||
func (s *Session) NewPartialRun(feeds, fetches []Output, targets []*Operation) (*PartialRun, error) {
|
||||
var (
|
||||
cfeeds = make([]C.TF_Output, len(feeds))
|
||||
cfetches = make([]C.TF_Output, len(fetches))
|
||||
ctargets = make([]*C.TF_Operation, len(targets))
|
||||
|
||||
pcfeeds *C.TF_Output
|
||||
pcfetches *C.TF_Output
|
||||
pctargets **C.TF_Operation
|
||||
|
||||
status = newStatus()
|
||||
)
|
||||
if len(feeds) > 0 {
|
||||
pcfeeds = &cfeeds[0]
|
||||
for i, o := range feeds {
|
||||
cfeeds[i] = o.c()
|
||||
}
|
||||
}
|
||||
if len(fetches) > 0 {
|
||||
pcfetches = &cfetches[0]
|
||||
for i, o := range fetches {
|
||||
cfetches[i] = o.c()
|
||||
}
|
||||
}
|
||||
if len(targets) > 0 {
|
||||
pctargets = &ctargets[0]
|
||||
for i, o := range targets {
|
||||
ctargets[i] = o.c
|
||||
}
|
||||
}
|
||||
|
||||
return tensors, nil
|
||||
s.mu.Lock()
|
||||
if s.c == nil {
|
||||
s.mu.Unlock()
|
||||
return nil, errors.New("session is closed")
|
||||
}
|
||||
s.wg.Add(1)
|
||||
s.mu.Unlock()
|
||||
defer s.wg.Done()
|
||||
|
||||
pr := &PartialRun{session: s}
|
||||
C.TF_SessionPRunSetup(s.c,
|
||||
pcfeeds, C.int(len(feeds)),
|
||||
pcfetches, C.int(len(fetches)),
|
||||
pctargets, C.int(len(targets)),
|
||||
&pr.handle, status.c)
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
runtime.SetFinalizer(pr, func(pr *PartialRun) {
|
||||
deletePRunHandle(pr.handle)
|
||||
})
|
||||
return pr, nil
|
||||
}
|
||||
|
||||
// Close a session. This contacts any other processes associated with this
|
||||
@ -187,3 +257,61 @@ func (o *SessionOptions) c() *C.TF_SessionOptions {
|
||||
C.free(unsafe.Pointer(t))
|
||||
return opt
|
||||
}
|
||||
|
||||
// cRunArgs translates the arguments to Session.Run and PartialRun.Run into
|
||||
// values suitable for C library calls.
|
||||
type cRunArgs struct {
|
||||
feeds []C.TF_Output
|
||||
feedTensors []*C.TF_Tensor
|
||||
fetches []C.TF_Output
|
||||
fetchTensors []*C.TF_Tensor
|
||||
targets []*C.TF_Operation
|
||||
}
|
||||
|
||||
func newCRunArgs(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) *cRunArgs {
|
||||
c := &cRunArgs{
|
||||
fetches: make([]C.TF_Output, len(fetches)),
|
||||
fetchTensors: make([]*C.TF_Tensor, len(fetches)),
|
||||
targets: make([]*C.TF_Operation, len(targets)),
|
||||
}
|
||||
for o, t := range feeds {
|
||||
c.feeds = append(c.feeds, o.c())
|
||||
c.feedTensors = append(c.feedTensors, t.c)
|
||||
}
|
||||
for i, o := range fetches {
|
||||
c.fetches[i] = o.c()
|
||||
}
|
||||
for i, t := range targets {
|
||||
c.targets[i] = t.c
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *cRunArgs) toGo() []*Tensor {
|
||||
ret := make([]*Tensor, len(c.fetchTensors))
|
||||
for i, ct := range c.fetchTensors {
|
||||
ret[i] = newTensorFromC(ct)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func ptrOutput(l []C.TF_Output) *C.TF_Output {
|
||||
if len(l) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &l[0]
|
||||
}
|
||||
|
||||
func ptrTensor(l []*C.TF_Tensor) **C.TF_Tensor {
|
||||
if len(l) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &l[0]
|
||||
}
|
||||
|
||||
func ptrOperation(l []*C.TF_Operation) **C.TF_Operation {
|
||||
if len(l) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &l[0]
|
||||
}
|
||||
|
@ -181,3 +181,68 @@ func TestConcurrency(t *testing.T) {
|
||||
t.Errorf("Close() 2: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExamplePartialRun() {
|
||||
var (
|
||||
// Create a graph: a + 2 + 3 + b.
|
||||
//
|
||||
// Skipping error handling for brevity of this example.
|
||||
// The 'op' package can be used to make graph construction code
|
||||
// with error handling more succinct.
|
||||
g = NewGraph()
|
||||
a, _ = Placeholder(g, "a", Int32)
|
||||
b, _ = Placeholder(g, "b", Int32)
|
||||
two, _ = Const(g, "Two", int32(2))
|
||||
three, _ = Const(g, "Three", int32(3))
|
||||
|
||||
plus2, _ = Add(g, "plus2", a, two) // a + 2
|
||||
plus3, _ = Add(g, "plus3", plus2, three) // (a + 2) + 3
|
||||
plusB, _ = Add(g, "plusB", plus3, b) // ((a + 2) + 3) + b
|
||||
|
||||
)
|
||||
sess, err := NewSession(g, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer sess.Close()
|
||||
|
||||
// All the feeds, fetches and targets for subsequent PartialRun.Run
|
||||
// calls must be provided at setup.
|
||||
pr, err := sess.NewPartialRun(
|
||||
[]Output{a, b},
|
||||
[]Output{plus2, plusB},
|
||||
[]*Operation{plus3.Op},
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Feed 'a=1', fetch 'plus2', and compute (but do not fetch) 'plus3'.
|
||||
// Imagine this to be the forward pass of unsupervised neural network
|
||||
// training of a robot.
|
||||
val, _ := NewTensor(int32(1))
|
||||
fetches, err := pr.Run(
|
||||
map[Output]*Tensor{a: val},
|
||||
[]Output{plus2},
|
||||
nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
v1 := fetches[0].Value().(int32)
|
||||
|
||||
// Now, feed 'b=4', fetch 'plusB=a+2+3+b'
|
||||
// Imagine this to be the result of actuating the robot to determine
|
||||
// the error produced by the current state of the neural network.
|
||||
val, _ = NewTensor(int32(4))
|
||||
fetches, err = pr.Run(
|
||||
map[Output]*Tensor{b: val},
|
||||
[]Output{plusB},
|
||||
nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
v2 := fetches[0].Value().(int32)
|
||||
|
||||
fmt.Println(v1, v2)
|
||||
// Output: 3 10
|
||||
}
|
||||
|
@ -46,9 +46,18 @@ func Const(g *Graph, name string, value interface{}) (Output, error) {
|
||||
|
||||
func Neg(g *Graph, name string, port Output) (Output, error) {
|
||||
op, err := g.AddOperation(OpSpec{
|
||||
Type: "Neg",
|
||||
Name: name,
|
||||
Type: "Neg",
|
||||
Name: name,
|
||||
Input: []Input{port},
|
||||
})
|
||||
return op.Output(0), err
|
||||
}
|
||||
|
||||
func Add(g *Graph, name string, x, y Output) (Output, error) {
|
||||
op, err := g.AddOperation(OpSpec{
|
||||
Type: "Add",
|
||||
Name: name,
|
||||
Input: []Input{x, y},
|
||||
})
|
||||
return op.Output(0), err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user