[Go]: Support device annotations when constructing graphs.

PiperOrigin-RevId: 204225504
This commit is contained in:
Asim Shankar 2018-07-11 18:17:25 -07:00 committed by TensorFlower Gardener
parent 5574d6041a
commit 4ddcd6999a
5 changed files with 84 additions and 5 deletions

View File

@ -177,7 +177,14 @@ type OpSpec struct {
// being added.
ControlDependencies []*Operation
// Other possible fields: Device, ColocateWith.
// The device on which the operation should be executed.
// If omitted, an appropriate device will automatically be selected.
//
// For example, if set of "/device:GPU:0", then the operation will
// execute on GPU #0.
Device string
// Other possible fields: ColocateWith.
}
// AddOperation adds an operation to g.
@ -225,6 +232,11 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) {
return nil, fmt.Errorf("%v (memory will be leaked)", err)
}
}
if len(args.Device) > 0 {
cdevice := C.CString(args.Device)
C.TF_SetDevice(cdesc, cdevice)
C.free(unsafe.Pointer(cdevice))
}
c := C.TF_FinishOperation(cdesc, status.c)
if err := status.Err(); err != nil {
return nil, err

View File

@ -37,6 +37,7 @@ type Scope struct {
namemap map[string]int
namespace string
controlDependencies []*tf.Operation
device string
err *scopeErr
}
@ -82,6 +83,7 @@ func (s *Scope) AddOperation(args tf.OpSpec) *tf.Operation {
args.Name = s.namespace + "/" + args.Name
}
args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...)
args.Device = s.device
op, err := s.graph.AddOperation(args)
if err != nil {
s.UpdateErr(args.Type, err)
@ -98,10 +100,12 @@ func (s *Scope) SubScope(namespace string) *Scope {
namespace = s.namespace + "/" + namespace
}
return &Scope{
graph: s.graph,
namemap: make(map[string]int),
namespace: namespace,
err: s.err,
graph: s.graph,
namemap: make(map[string]int),
namespace: namespace,
controlDependencies: s.controlDependencies,
device: s.device,
err: s.err,
}
}
@ -123,6 +127,25 @@ func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope {
namemap: s.namemap,
namespace: s.namespace,
controlDependencies: deps,
device: s.device,
err: s.err,
}
}
// WithDevice returns a new Scope which will cause all operations added to the
// graph to execute on devices that match the provided device specification.
//
// For example, WithDevice("/device:GPU:0") will cause operations added to
// the graph to execute on GPU #0.
//
// An empty string removes any device restrictions.
func (s *Scope) WithDevice(device string) *Scope {
return &Scope{
graph: s.graph,
namemap: s.namemap,
namespace: s.namespace,
controlDependencies: s.controlDependencies,
device: device,
err: s.err,
}
}

View File

@ -112,6 +112,21 @@ func TestControlDependencies(t *testing.T) {
}
}
func TestDevice(t *testing.T) {
s := NewScope()
matrix := Const(s, [][]float32{{3.0}})
s = s.WithDevice("/device:GPU:0")
square := MatMul(s.SubScope("square"), matrix, matrix)
s = s.WithDevice("")
cube := MatMul(s.SubScope("cube"), square, matrix)
if got, want := square.Op.Device(), "/device:GPU:0"; got != want {
t.Errorf("Got %q, want %q", got, want)
}
if got, want := cube.Op.Device(), ""; got != want {
t.Errorf("Got %q, want %q", got, want)
}
}
func TestScopeFinalize(t *testing.T) {
var (
root = NewScope()

View File

@ -45,6 +45,12 @@ func (op *Operation) NumOutputs() int {
return int(C.TF_OperationNumOutputs(op.c))
}
// Device returns a specification of the device on which this operation
// will be executed, or the empty string if there is no such specification.
func (op *Operation) Device() string {
return C.GoString(C.TF_OperationDevice(op.c))
}
// OutputListSize returns the size of the list of Outputs that is produced by a
// named output of op.
//

View File

@ -228,6 +228,29 @@ func TestOperationConsumers(t *testing.T) {
}
}
func TestOperationDevice(t *testing.T) {
graph := NewGraph()
v, err := NewTensor(float32(1.0))
if err != nil {
t.Fatal(err)
}
op, err := graph.AddOperation(OpSpec{
Type: "Const",
Name: "Const",
Attrs: map[string]interface{}{
"dtype": v.DataType(),
"value": v,
},
Device: "/device:GPU:0",
})
if err != nil {
t.Fatal(err)
}
if got, want := op.Device(), "/device:GPU:0"; got != want {
t.Errorf("Got %q, want %q", got, want)
}
}
func forceGC() {
var mem runtime.MemStats
runtime.ReadMemStats(&mem)