From f2a08fbecdd5ad4bd826c5efd7df60126ab716be Mon Sep 17 00:00:00 2001 From: James Keeling Date: Mon, 17 Dec 2018 10:43:38 -0800 Subject: [PATCH] Add Go wrapper around Eager C API's Context. This is part of a series of changes to provide a thin Go wrapper around the Eager C API. PiperOrigin-RevId: 225851549 --- tensorflow/go/BUILD | 1 + tensorflow/go/context.go | 109 ++++++++++++++++++++++++++++++++++ tensorflow/go/context_test.go | 57 ++++++++++++++++++ 3 files changed, 167 insertions(+) create mode 100644 tensorflow/go/context.go create mode 100644 tensorflow/go/context_test.go diff --git a/tensorflow/go/BUILD b/tensorflow/go/BUILD index f16cffac994..62d6b4f57c2 100644 --- a/tensorflow/go/BUILD +++ b/tensorflow/go/BUILD @@ -17,6 +17,7 @@ sh_test( ":all_files", # Go sources "//tensorflow:libtensorflow.so", # C library "//tensorflow/c:headers", # C library header + "//tensorflow/c/eager:headers", # Eager C library header "//tensorflow/cc/saved_model:saved_model_half_plus_two", # Testdata for LoadSavedModel ], ) diff --git a/tensorflow/go/context.go b/tensorflow/go/context.go new file mode 100644 index 00000000000..04f86282af3 --- /dev/null +++ b/tensorflow/go/context.go @@ -0,0 +1,109 @@ +/* +Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tensorflow + +// #include +// #include "tensorflow/c/c_api.h" +// #include "tensorflow/c/eager/c_api.h" +import "C" +import ( + "fmt" + "runtime" +) + +// ContextOptions contains configuration information for a session +type ContextOptions struct { + // Config is a binary-serialized representation of the + // tensorflow.ConfigProto protocol message + // (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto). + Config []byte + + // Sets the default execution mode + Async bool +} + +// c converts the ContextOptions to the C API's TF_ContextOptions. +// Caller takes ownership of returned object. +func (o *ContextOptions) c() (*C.TFE_ContextOptions, error) { + opt := C.TFE_NewContextOptions() + if o == nil { + return opt, nil + } + + if sz := len(o.Config); sz > 0 { + status := newStatus() + cConfig := C.CBytes(o.Config) + C.TFE_ContextOptionsSetConfig(opt, cConfig, C.size_t(sz), status.c) + C.free(cConfig) + if err := status.Err(); err != nil { + C.TFE_DeleteContextOptions(opt) + return nil, fmt.Errorf("invalid ContextOptions.Config: %v", err) + } + } + + var async uint8 + if o.Async { + async = 1 + } + C.TFE_ContextOptionsSetAsync(opt, C.uchar(async)) + + return opt, nil +} + +// Context for executing operations eagerly. +// +// A Context allows operations to be executed immediately. It encapsulates +// information such as the available devices, resource manager etc. It also +// allows the user to configure execution using a ConfigProto, as they can +// configure a Session when executing a Graph. +type Context struct { + c *C.TFE_Context +} + +// NewContext creates a new context for eager execution. +// options may be nil to use the default options. +func NewContext(options *ContextOptions) (*Context, error) { + status := newStatus() + cOpt, err := options.c() + if err != nil { + return nil, err + } + defer C.TFE_DeleteContextOptions(cOpt) + cContext := C.TFE_NewContext(cOpt, status.c) + if err := status.Err(); err != nil { + return nil, err + } + + c := &Context{c: cContext} + runtime.SetFinalizer(c, (*Context).finalizer) + return c, nil +} + +func (c *Context) finalizer() { + C.TFE_DeleteContext(c.c) +} + +// ListDevices returns the list of devices associated with a Context. +func (c *Context) ListDevices() ([]Device, error) { + status := newStatus() + devicesList := C.TFE_ContextListDevices(c.c, status.c) + if err := status.Err(); err != nil { + return nil, fmt.Errorf("SessionListDevices() failed: %v", err) + } + defer C.TF_DeleteDeviceList(devicesList) + return deviceSliceFromDeviceList(devicesList) +} diff --git a/tensorflow/go/context_test.go b/tensorflow/go/context_test.go new file mode 100644 index 00000000000..ce4005da242 --- /dev/null +++ b/tensorflow/go/context_test.go @@ -0,0 +1,57 @@ +/* +Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tensorflow + +import ( + "fmt" + "testing" +) + +func TestContextConfigSetAsync(t *testing.T) { + tests := []bool{false, true} + for _, test := range tests { + t.Run(fmt.Sprint(test), func(t *testing.T) { + opt := &ContextOptions{Async: test} + if _, err := NewContext(opt); err != nil { + t.Fatal(err) + } + }) + } +} + +func TestContextConfigListDevices(t *testing.T) { + c, err := NewContext(nil) + if err != nil { + t.Fatal(err) + } + devs, err := c.ListDevices() + if err != nil { + t.Fatal(err) + } + if len(devs) < 1 { + t.Fatalf("No devices found using ListDevices()") + } + foundCPUDevice := false + for _, d := range devs { + if d.Type == "CPU" { + foundCPUDevice = true + } + } + if !foundCPUDevice { + t.Error("Failed to find CPU device using ListDevices()") + } +}