110 lines
3.0 KiB
Go
110 lines
3.0 KiB
Go
/*
Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
|
|
|
package tensorflow
|
|
|
|
// #include <stdlib.h>
|
|
// #include "tensorflow/c/c_api.h"
|
|
// #include "tensorflow/c/eager/c_api.h"
|
|
import "C"
|
|
import (
|
|
"fmt"
|
|
"runtime"
|
|
)
|
|
|
|
// ContextOptions contains configuration information for an eager execution
// Context.
type ContextOptions struct {
	// Config is a binary-serialized representation of the
	// tensorflow.ConfigProto protocol message
	// (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
	// An empty Config is not forwarded to the C API.
	Config []byte

	// Async sets the default execution mode of the Context. It is handed to
	// TFE_ContextOptionsSetAsync as 1 when true and 0 when false.
	Async bool
}
|
|
|
|
// c converts the ContextOptions to the C API's TF_ContextOptions.
|
|
// Caller takes ownership of returned object.
|
|
func (o *ContextOptions) c() (*C.TFE_ContextOptions, error) {
|
|
opt := C.TFE_NewContextOptions()
|
|
if o == nil {
|
|
return opt, nil
|
|
}
|
|
|
|
if sz := len(o.Config); sz > 0 {
|
|
status := newStatus()
|
|
cConfig := C.CBytes(o.Config)
|
|
C.TFE_ContextOptionsSetConfig(opt, cConfig, C.size_t(sz), status.c)
|
|
C.free(cConfig)
|
|
if err := status.Err(); err != nil {
|
|
C.TFE_DeleteContextOptions(opt)
|
|
return nil, fmt.Errorf("invalid ContextOptions.Config: %v", err)
|
|
}
|
|
}
|
|
|
|
var async uint8
|
|
if o.Async {
|
|
async = 1
|
|
}
|
|
C.TFE_ContextOptionsSetAsync(opt, C.uchar(async))
|
|
|
|
return opt, nil
|
|
}
|
|
|
|
// Context for executing operations eagerly.
//
// A Context allows operations to be executed immediately. It encapsulates
// information such as the available devices, resource manager etc. It also
// allows the user to configure execution using a ConfigProto, as they can
// configure a Session when executing a Graph.
type Context struct {
	// c is the underlying C API handle. It is created by TFE_NewContext in
	// NewContext and deleted by the finalizer method.
	c *C.TFE_Context
}
|
|
|
|
// NewContext creates a new context for eager execution.
|
|
// options may be nil to use the default options.
|
|
func NewContext(options *ContextOptions) (*Context, error) {
|
|
status := newStatus()
|
|
cOpt, err := options.c()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer C.TFE_DeleteContextOptions(cOpt)
|
|
cContext := C.TFE_NewContext(cOpt, status.c)
|
|
if err := status.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c := &Context{c: cContext}
|
|
runtime.SetFinalizer(c, (*Context).finalizer)
|
|
return c, nil
|
|
}
|
|
|
|
func (c *Context) finalizer() {
|
|
C.TFE_DeleteContext(c.c)
|
|
}
|
|
|
|
// ListDevices returns the list of devices associated with a Context.
|
|
func (c *Context) ListDevices() ([]Device, error) {
|
|
status := newStatus()
|
|
devicesList := C.TFE_ContextListDevices(c.c, status.c)
|
|
if err := status.Err(); err != nil {
|
|
return nil, fmt.Errorf("SessionListDevices() failed: %v", err)
|
|
}
|
|
defer C.TF_DeleteDeviceList(devicesList)
|
|
return deviceSliceFromDeviceList(devicesList)
|
|
}
|