[Go]: Support device annotations when constructing graphs.
PiperOrigin-RevId: 204225504
This commit is contained in:
parent
5574d6041a
commit
4ddcd6999a
@ -177,7 +177,14 @@ type OpSpec struct {
|
||||
// being added.
|
||||
ControlDependencies []*Operation
|
||||
|
||||
// Other possible fields: Device, ColocateWith.
|
||||
// The device on which the operation should be executed.
|
||||
// If omitted, an appropriate device will automatically be selected.
|
||||
//
|
||||
// For example, if set of "/device:GPU:0", then the operation will
|
||||
// execute on GPU #0.
|
||||
Device string
|
||||
|
||||
// Other possible fields: ColocateWith.
|
||||
}
|
||||
|
||||
// AddOperation adds an operation to g.
|
||||
@ -225,6 +232,11 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) {
|
||||
return nil, fmt.Errorf("%v (memory will be leaked)", err)
|
||||
}
|
||||
}
|
||||
if len(args.Device) > 0 {
|
||||
cdevice := C.CString(args.Device)
|
||||
C.TF_SetDevice(cdesc, cdevice)
|
||||
C.free(unsafe.Pointer(cdevice))
|
||||
}
|
||||
c := C.TF_FinishOperation(cdesc, status.c)
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, err
|
||||
|
@ -37,6 +37,7 @@ type Scope struct {
|
||||
namemap map[string]int
|
||||
namespace string
|
||||
controlDependencies []*tf.Operation
|
||||
device string
|
||||
err *scopeErr
|
||||
}
|
||||
|
||||
@ -82,6 +83,7 @@ func (s *Scope) AddOperation(args tf.OpSpec) *tf.Operation {
|
||||
args.Name = s.namespace + "/" + args.Name
|
||||
}
|
||||
args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...)
|
||||
args.Device = s.device
|
||||
op, err := s.graph.AddOperation(args)
|
||||
if err != nil {
|
||||
s.UpdateErr(args.Type, err)
|
||||
@ -98,10 +100,12 @@ func (s *Scope) SubScope(namespace string) *Scope {
|
||||
namespace = s.namespace + "/" + namespace
|
||||
}
|
||||
return &Scope{
|
||||
graph: s.graph,
|
||||
namemap: make(map[string]int),
|
||||
namespace: namespace,
|
||||
err: s.err,
|
||||
graph: s.graph,
|
||||
namemap: make(map[string]int),
|
||||
namespace: namespace,
|
||||
controlDependencies: s.controlDependencies,
|
||||
device: s.device,
|
||||
err: s.err,
|
||||
}
|
||||
}
|
||||
|
||||
@ -123,6 +127,25 @@ func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope {
|
||||
namemap: s.namemap,
|
||||
namespace: s.namespace,
|
||||
controlDependencies: deps,
|
||||
device: s.device,
|
||||
err: s.err,
|
||||
}
|
||||
}
|
||||
|
||||
// WithDevice returns a new Scope which will cause all operations added to the
|
||||
// graph to execute on devices that match the provided device specification.
|
||||
//
|
||||
// For example, WithDevice("/device:GPU:0") will cause operations added to
|
||||
// the graph to execute on GPU #0.
|
||||
//
|
||||
// An empty string removes any device restrictions.
|
||||
func (s *Scope) WithDevice(device string) *Scope {
|
||||
return &Scope{
|
||||
graph: s.graph,
|
||||
namemap: s.namemap,
|
||||
namespace: s.namespace,
|
||||
controlDependencies: s.controlDependencies,
|
||||
device: device,
|
||||
err: s.err,
|
||||
}
|
||||
}
|
||||
|
@ -112,6 +112,21 @@ func TestControlDependencies(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDevice(t *testing.T) {
|
||||
s := NewScope()
|
||||
matrix := Const(s, [][]float32{{3.0}})
|
||||
s = s.WithDevice("/device:GPU:0")
|
||||
square := MatMul(s.SubScope("square"), matrix, matrix)
|
||||
s = s.WithDevice("")
|
||||
cube := MatMul(s.SubScope("cube"), square, matrix)
|
||||
if got, want := square.Op.Device(), "/device:GPU:0"; got != want {
|
||||
t.Errorf("Got %q, want %q", got, want)
|
||||
}
|
||||
if got, want := cube.Op.Device(), ""; got != want {
|
||||
t.Errorf("Got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopeFinalize(t *testing.T) {
|
||||
var (
|
||||
root = NewScope()
|
||||
|
@ -45,6 +45,12 @@ func (op *Operation) NumOutputs() int {
|
||||
return int(C.TF_OperationNumOutputs(op.c))
|
||||
}
|
||||
|
||||
// Device returns a specification of the device on which this operation
|
||||
// will be executed, or the empty string if there is no such specification.
|
||||
func (op *Operation) Device() string {
|
||||
return C.GoString(C.TF_OperationDevice(op.c))
|
||||
}
|
||||
|
||||
// OutputListSize returns the size of the list of Outputs that is produced by a
|
||||
// named output of op.
|
||||
//
|
||||
|
@ -228,6 +228,29 @@ func TestOperationConsumers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOperationDevice(t *testing.T) {
|
||||
graph := NewGraph()
|
||||
v, err := NewTensor(float32(1.0))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
op, err := graph.AddOperation(OpSpec{
|
||||
Type: "Const",
|
||||
Name: "Const",
|
||||
Attrs: map[string]interface{}{
|
||||
"dtype": v.DataType(),
|
||||
"value": v,
|
||||
},
|
||||
Device: "/device:GPU:0",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := op.Device(), "/device:GPU:0"; got != want {
|
||||
t.Errorf("Got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func forceGC() {
|
||||
var mem runtime.MemStats
|
||||
runtime.ReadMemStats(&mem)
|
||||
|
Loading…
Reference in New Issue
Block a user