[Go]: Support device annotations when constructing graphs.

PiperOrigin-RevId: 204225504
This commit is contained in:
Asim Shankar 2018-07-11 18:17:25 -07:00 committed by TensorFlower Gardener
parent 5574d6041a
commit 4ddcd6999a
5 changed files with 84 additions and 5 deletions

View File

@ -177,7 +177,14 @@ type OpSpec struct {
// being added. // being added.
ControlDependencies []*Operation ControlDependencies []*Operation
// Other possible fields: Device, ColocateWith. // The device on which the operation should be executed.
// If omitted, an appropriate device will automatically be selected.
//
// For example, if set of "/device:GPU:0", then the operation will
// execute on GPU #0.
Device string
// Other possible fields: ColocateWith.
} }
// AddOperation adds an operation to g. // AddOperation adds an operation to g.
@ -225,6 +232,11 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) {
return nil, fmt.Errorf("%v (memory will be leaked)", err) return nil, fmt.Errorf("%v (memory will be leaked)", err)
} }
} }
if len(args.Device) > 0 {
cdevice := C.CString(args.Device)
C.TF_SetDevice(cdesc, cdevice)
C.free(unsafe.Pointer(cdevice))
}
c := C.TF_FinishOperation(cdesc, status.c) c := C.TF_FinishOperation(cdesc, status.c)
if err := status.Err(); err != nil { if err := status.Err(); err != nil {
return nil, err return nil, err

View File

@ -37,6 +37,7 @@ type Scope struct {
namemap map[string]int namemap map[string]int
namespace string namespace string
controlDependencies []*tf.Operation controlDependencies []*tf.Operation
device string
err *scopeErr err *scopeErr
} }
@ -82,6 +83,7 @@ func (s *Scope) AddOperation(args tf.OpSpec) *tf.Operation {
args.Name = s.namespace + "/" + args.Name args.Name = s.namespace + "/" + args.Name
} }
args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...) args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...)
args.Device = s.device
op, err := s.graph.AddOperation(args) op, err := s.graph.AddOperation(args)
if err != nil { if err != nil {
s.UpdateErr(args.Type, err) s.UpdateErr(args.Type, err)
@ -98,10 +100,12 @@ func (s *Scope) SubScope(namespace string) *Scope {
namespace = s.namespace + "/" + namespace namespace = s.namespace + "/" + namespace
} }
return &Scope{ return &Scope{
graph: s.graph, graph: s.graph,
namemap: make(map[string]int), namemap: make(map[string]int),
namespace: namespace, namespace: namespace,
err: s.err, controlDependencies: s.controlDependencies,
device: s.device,
err: s.err,
} }
} }
@ -123,6 +127,25 @@ func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope {
namemap: s.namemap, namemap: s.namemap,
namespace: s.namespace, namespace: s.namespace,
controlDependencies: deps, controlDependencies: deps,
device: s.device,
err: s.err,
}
}
// WithDevice returns a new Scope which will cause all operations added to the
// graph to execute on devices that match the provided device specification.
//
// For example, WithDevice("/device:GPU:0") will cause operations added to
// the graph to execute on GPU #0.
//
// An empty string removes any device restrictions.
func (s *Scope) WithDevice(device string) *Scope {
return &Scope{
graph: s.graph,
namemap: s.namemap,
namespace: s.namespace,
controlDependencies: s.controlDependencies,
device: device,
err: s.err, err: s.err,
} }
} }

View File

@ -112,6 +112,21 @@ func TestControlDependencies(t *testing.T) {
} }
} }
func TestDevice(t *testing.T) {
s := NewScope()
matrix := Const(s, [][]float32{{3.0}})
s = s.WithDevice("/device:GPU:0")
square := MatMul(s.SubScope("square"), matrix, matrix)
s = s.WithDevice("")
cube := MatMul(s.SubScope("cube"), square, matrix)
if got, want := square.Op.Device(), "/device:GPU:0"; got != want {
t.Errorf("Got %q, want %q", got, want)
}
if got, want := cube.Op.Device(), ""; got != want {
t.Errorf("Got %q, want %q", got, want)
}
}
func TestScopeFinalize(t *testing.T) { func TestScopeFinalize(t *testing.T) {
var ( var (
root = NewScope() root = NewScope()

View File

@ -45,6 +45,12 @@ func (op *Operation) NumOutputs() int {
return int(C.TF_OperationNumOutputs(op.c)) return int(C.TF_OperationNumOutputs(op.c))
} }
// Device returns a specification of the device on which this operation
// will be executed, or the empty string if there is no such specification.
func (op *Operation) Device() string {
return C.GoString(C.TF_OperationDevice(op.c))
}
// OutputListSize returns the size of the list of Outputs that is produced by a // OutputListSize returns the size of the list of Outputs that is produced by a
// named output of op. // named output of op.
// //

View File

@ -228,6 +228,29 @@ func TestOperationConsumers(t *testing.T) {
} }
} }
func TestOperationDevice(t *testing.T) {
graph := NewGraph()
v, err := NewTensor(float32(1.0))
if err != nil {
t.Fatal(err)
}
op, err := graph.AddOperation(OpSpec{
Type: "Const",
Name: "Const",
Attrs: map[string]interface{}{
"dtype": v.DataType(),
"value": v,
},
Device: "/device:GPU:0",
})
if err != nil {
t.Fatal(err)
}
if got, want := op.Device(), "/device:GPU:0"; got != want {
t.Errorf("Got %q, want %q", got, want)
}
}
func forceGC() { func forceGC() {
var mem runtime.MemStats var mem runtime.MemStats
runtime.ReadMemStats(&mem) runtime.ReadMemStats(&mem)