Experimental C and Python APIs to invoke TensorFlow kernels on concrete values.

PiperOrigin-RevId: 164902588
Alexandre Passos 2017-08-10 14:19:55 -07:00 committed by TensorFlower Gardener
parent 7dfabcc01c
commit 13eb3b90e9
34 changed files with 6207 additions and 1 deletion
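As a quick orientation, here is a minimal sketch of how the new C API is driven, distilled from the c_api.h declarations and the c_api_test.cc added below; status checks after each call are elided, and the helper name is just illustrative:

#include <stdint.h>
#include <string.h>

#include "tensorflow/c/eager/c_api.h"

// Sketch: run a 2x2 MatMul eagerly and read the product back on the host.
void EagerMatMulSketch(void) {
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TF_DeleteSessionOptions(opts);

  // Wrap a host TF_Tensor in a handle that eager ops consume.
  int64_t dims[] = {2, 2};
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(data));
  memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
  TFE_TensorHandle* m = TFE_NewTensorHandle(t);
  TF_DeleteTensor(t);

  // Describe the op: inputs and attributes are set directly on the TFE_Op.
  TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
  TFE_OpAddInput(matmul, m, status);
  TFE_OpAddInput(matmul, m, status);
  TFE_OpSetAttrBool(matmul, "transpose_a", 0);
  TFE_OpSetAttrBool(matmul, "transpose_b", 0);
  TFE_OpSetAttrType(matmul, "T", TFE_TensorHandleDataType(m));

  // On input, num_retvals is the capacity of retvals; on output, the number
  // of tensors actually produced by the op.
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, retvals, &num_retvals, status);

  // Resolve the result into host memory as a TF_Tensor (CPU handles only).
  TF_Tensor* product = TFE_TensorHandleResolve(retvals[0], status);

  TF_DeleteTensor(product);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(m);
  TFE_DeleteContext(ctx, status);
  TF_DeleteStatus(status);
}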

tensorflow/BUILD

@@ -380,6 +380,7 @@ filegroup(
"//tensorflow/java/src/main/native:all_files",
"//tensorflow/python:all_files",
"//tensorflow/python/debug:all_files",
"//tensorflow/python/eager:all_files",
"//tensorflow/python/estimator:all_files",
"//tensorflow/python/feature_column:all_files",
"//tensorflow/python/kernel_tests:all_files",

tensorflow/c/eager/BUILD

@@ -0,0 +1,67 @@
# Experimental extensions to the C API for eager execution of kernels.
licenses(["notice"]) # Apache 2.0
cc_library(
name = "c_api",
srcs = ["c_api.cc"],
hdrs = ["c_api.h"],
visibility = [
"//tensorflow:internal",
"//tensorflow/python/eager:__pkg__",
],
deps = [
":runtime",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
)
cc_test(
name = "c_api_test",
srcs = ["c_api_test.cc"],
deps = [
":c_api",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "runtime",
srcs = ["runtime.cc"],
hdrs = ["runtime.h"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
],
)
cc_test(
name = "runtime_test",
srcs = ["runtime_test.cc"],
deps = [
":runtime",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:client_session",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

tensorflow/c/eager/c_api.cc

@@ -0,0 +1,561 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/runtime.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
using tensorflow::int64;
using tensorflow::string;
namespace {
bool IsCPU(tensorflow::Device* d) {
return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}
string DeviceName(tensorflow::Device* d) {
return (d == nullptr) ? "cpu:0" : d->name();
}
} // namespace
struct TFE_Context {
explicit TFE_Context(TF_Session* s) : session(s) {}
// TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
TF_Session* session;
tensorflow::mutex functions_mu;
tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){
tensorflow::OpRegistry::Global(), {}};
// One FunctionLibraryRuntime per device.
// func_libs[i] is the FunctionLibraryRuntime corresponding to
// session->devices[i].
std::vector<std::unique_ptr<tensorflow::FunctionLibraryRuntime> > func_libs;
std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
tensorflow::Fprint128Hasher>
kernel_cache;
tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
for (int i = 0; i < session->devices.size(); ++i) {
if (session->devices[i] == d) {
return func_libs[i].get();
}
}
return nullptr;
}
const std::vector<tensorflow::Device*>& devices() { return session->devices; }
};
struct TFE_TensorHandle {
TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d)
: t(t), d(d) {}
tensorflow::Tensor t;
// TODO(ashankar): d == nullptr iff local CPU
// This was expedient, but perhaps worth revisiting ('d' should always be a
// valid pointer?)
// This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
// provided with the appropriate TFE_Context.
//
// TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a
// TFE_TensorHandle does not outlive the TFE_Context from which it came?
tensorflow::Device* d;
};
struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
: ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {}
bool is_function() const { return attr_types == nullptr; }
TFE_Context* ctx; // Must outlive the TFE_Op.
const char* name;
tensorflow::AttrBuilder attrs;
const tensorflow::AttrTypeMap* attr_types;
std::vector<tensorflow::Tensor> inputs;
std::vector<tensorflow::Device*> input_devices;
tensorflow::Device* device;
};
extern "C" {
TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
TF_Graph* graph = TF_NewGraph();
TF_Session* session = TF_NewSession(graph, opts, status);
if (status->status.ok()) {
if (session->device_mgr == nullptr || session->devices.empty()) {
status->status = tensorflow::errors::InvalidArgument(
"Provided TF_SessionOptions are not compatible with eager execution "
"(perhaps the TF_SessionOptions alluded to session execution in a "
"remote address space?)");
}
}
if (!status->status.ok()) {
TF_DeleteGraph(graph);
return nullptr;
}
TFE_Context* ret = new TFE_Context(session);
ret->func_libs.resize(ret->devices().size());
for (int i = 0; i < ret->devices().size(); ++i) {
ret->func_libs[i] = tensorflow::NewFunctionLibraryRuntime(
ret->session->device_mgr, opts->options.env, ret->devices()[i],
TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {});
}
return ret;
}
void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
status->status = tensorflow::Status::OK();
tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
TF_Graph* graph = ctx->session->graph;
TF_DeleteSession(ctx->session, status);
TF_DeleteGraph(graph);
delete ctx;
}
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
return TF_SessionListDevices(ctx->session, status);
}
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t) {
return new TFE_TensorHandle(
tensorflow::TensorCApi::MakeTensor(t->dtype, t->shape, t->buffer),
nullptr);
}
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; }
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
return static_cast<TF_DataType>(h->t.dtype());
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); }
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) {
return h->t.dim_size(dim_index);
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) {
// This might be a bit confusing as a tensor on CPU can sometimes return
// "CPU:0" and sometimes "/job:localhost/replica:0/task:0/cpu:0".
// TODO(ashankar): Figure out which one would be nicer.
return (h->d == nullptr) ? "CPU:0" : h->d->name().c_str();
}
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (!IsCPU(h->d)) {
TF_SetStatus(status, TF_UNIMPLEMENTED,
tensorflow::strings::StrCat(
"TFE_TensorHandle can be resolved iff it is on CPU (this "
"handle is on ",
h->d->name(),
"). Consider using TFE_TensorHandleCopyToDevice to get a "
"copy of the tensor on CPU")
.c_str());
return nullptr;
}
return tensorflow::TF_TensorFromTensor(h->t, status);
}
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status) {
tensorflow::Device* dstd = nullptr;
status->status = ctx->session->device_mgr->LookupDevice(device_name, &dstd);
if (!status->status.ok()) return nullptr;
tensorflow::Device* srcd = h->d == nullptr ? ctx->devices()[0] : h->d;
const bool src_cpu = IsCPU(srcd);
const bool dst_cpu = IsCPU(dstd);
if (!src_cpu && !dst_cpu) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"TFE_TensorHandleCopyToDevice requires either the source "
"TFE_TensorHandle be on or the destination device be CPU (they "
"are ",
DeviceName(srcd), " and ", DeviceName(dstd), " in this call)")
.c_str());
return nullptr;
}
tensorflow::Tensor* src = &(h->t);
if (src_cpu && dst_cpu) {
// There must be a better way, but for now redirect through proto to ensure
// that the underlying buffers are not shared.
tensorflow::TensorProto proto;
src->AsProtoTensorContent(&proto);
tensorflow::Tensor dst(src->dtype(), src->shape());
if (!dst.FromProto(proto)) {
TF_SetStatus(
status, TF_INTERNAL,
tensorflow::strings::StrCat(
"error copying between TFE_TensorHandles on CPU. Consider filing "
"a bug report at https://github.com/tensorflow/tensorflow/issues "
"mentioning version: ",
TF_Version(), " and ", __FILE__, ":", __LINE__)
.c_str());
return nullptr;
}
return new TFE_TensorHandle(dst, nullptr);
}
if (src_cpu) {
tensorflow::Tensor dst(
dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(),
src->shape());
tensorflow::Notification n;
dstd->tensorflow_gpu_device_info()->default_context->CopyCPUTensorToDevice(
src, dstd, &dst, [status, &n](const tensorflow::Status& s) {
status->status = s;
n.Notify();
});
n.WaitForNotification();
return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, dstd)
: nullptr;
}
CHECK(dst_cpu);
tensorflow::Tensor dst(src->dtype(), src->shape());
tensorflow::Notification n;
// TODO(ashankar): The Sync() call below may be more aggressive than
// necessary. It is based on knowledge of implementation details - that
// GPU devices are implemented using 3 streams - one for host->device copies,
// one for device->host copies and one for sending operations to the GPU.
// With that setup, Sync()ing across all 3 streams should be sufficient
// but more than necessary (since it waits for operations that might have
// nothing to do with this tensor to complete).
status->status = srcd->Sync();
if (!status->status.ok()) return nullptr;
srcd->tensorflow_gpu_device_info()->default_context->CopyDeviceTensorToCPU(
src, "IGNORE_MY_TENSOR_NAME", srcd, &dst,
[status, &n](const tensorflow::Status& s) {
status->status = s;
n.Notify();
});
n.WaitForNotification();
return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, nullptr)
: nullptr;
}
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
const char* name = op_or_function_name; // Shorthand
const tensorflow::AttrTypeMap* types;
status->status = tensorflow::AttrTypeMapForOp(name, &types);
if (status->status.ok()) return new TFE_Op(ctx, name, types);
if (TF_GetCode(status) == TF_NOT_FOUND) {
tensorflow::mutex_lock l(ctx->functions_mu);
if (ctx->func_lib_def.Find(name) != nullptr) {
status->status = tensorflow::Status::OK();
return new TFE_Op(ctx, name, nullptr);
}
}
return nullptr;
}
void TFE_DeleteOp(TFE_Op* op) { delete op; }
static void TFE_OpSetDeviceHelper(TFE_Op* op, tensorflow::Device* device,
TF_Status* status) {
// Questionable heuristic: Place the op on the same device as the first input
// placed outside of host memory?
if (IsCPU(op->device) && !IsCPU(device)) {
op->device = device;
}
}
void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx, const char* device_name,
TF_Status* status) {
tensorflow::Device* d = nullptr;
status->status = ctx->session->device_mgr->LookupDevice(device_name, &d);
if (!status->status.ok()) return;
TFE_OpSetDeviceHelper(op, d, status);
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
TFE_OpSetDeviceHelper(op, h->d, status);
if (!status->status.ok()) return;
op->inputs.push_back(h->t);
op->input_devices.push_back(h->d);
op->attrs.NumInputs(op->inputs.size());
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status) {
TF_AttrType ret;
if (op->is_function()) {
status->status = tensorflow::errors::Unimplemented(
"TODO(apassos): Support for attributes for TensorFlow functions is not "
"ready yet.");
return TF_ATTR_INT; // The compiler requires that we return something.
}
status->status =
tensorflow::AttrTypeByName(op->attr_types, attr_name, &ret, is_list);
return ret;
}
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) {
op->attrs.Set(attr_name, value);
}
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
op->attrs.Set(attr_name, static_cast<int64>(value));
}
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
op->attrs.Set(attr_name, value);
}
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
op->attrs.Set(attr_name, (value == 0) ? false : true);
}
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
op->attrs.Set(attr_name, static_cast<tensorflow::DataType>(value));
}
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
const int num_dims, TF_Status* out_status) {
if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Value specified for `", attr_name, "` has ", num_dims,
" dimensions which is over the limit of ",
tensorflow::TensorShape::MaxDimensions(), ".")
.c_str());
return;
}
tensorflow::TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
op->attrs.Set(attr_name, proto);
}
#define TFE_OP_SET_ATTR_LIST(fn, type) \
void fn(TFE_Op* op, const char* attr_name, const type* values, \
int num_values) { \
op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice<const type>( \
values, num_values)); \
}
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*)
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
#undef TFE_OP_SET_ATTR_LIST
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values) {
op->attrs.Set(attr_name,
tensorflow::gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
}
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values) {
op->attrs.Set(
attr_name,
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
}
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
const unsigned char* values, int num_values) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
op->attrs.Set(attr_name,
tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, TF_Status* out_status) {
std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
new tensorflow::TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Value specified for `", attr_name, "` has ", num_dims_i,
" dimensions which is over the limit of ",
tensorflow::TensorShape::MaxDimensions(), ".")
.c_str());
return;
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
op->attrs.Set(attr_name,
tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
proto.get(), num_values));
}
namespace {
tensorflow::Status ValidateInputTypeAndPlacement(
tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op,
const tensorflow::OpKernel* kernel) {
const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
if (memtypes.size() != op->inputs.size()) {
return tensorflow::errors::InvalidArgument(
"expected ", memtypes.size(), " inputs, got ", op->inputs.size());
}
for (int i = 0; i < op->inputs.size(); ++i) {
const tensorflow::Device* expected_device =
memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
const tensorflow::Device* actual_device =
op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
if (expected_device != actual_device) {
return tensorflow::errors::InvalidArgument(
"cannot compute ", op->name, " as input #", i,
" was expected to be on ", expected_device->name(),
" but is actually on ", actual_device->name(),
" (operation running on ", op_device->name(), ")");
}
if (op->inputs[i].dtype() != kernel->input_type(i)) {
return tensorflow::errors::InvalidArgument(
"cannot compute ", op->name, " as input #", i,
" was expected to be a ",
tensorflow::DataType_Name(kernel->input_type(i)), " tensor but is a ",
tensorflow::DataType_Name(op->inputs[i].dtype()), " tensor");
}
}
return tensorflow::Status::OK();
}
} // namespace
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
TFE_Context* ctx = op->ctx;
// TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU
tensorflow::Device* device =
(op->device == nullptr) ? ctx->devices()[0] : op->device;
std::vector<tensorflow::Tensor> outputs(1);
const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name());
tensorflow::KernelAndDevice* kernel =
tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
if (kernel == nullptr) {
const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
kernel = new tensorflow::KernelAndDevice();
if (!op->is_function()) {
status->status =
tensorflow::KernelAndDevice::InitOp(device, ndef, kernel);
} else {
// Knowledge of the implementation of InitFn (and, in turn,
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
// will be accessed, so grab on to the lock.
// See WARNING comment below - would be nice to rework to avoid this
// subtlety.
tensorflow::mutex_lock l(ctx->functions_mu);
status->status = tensorflow::KernelAndDevice::InitFn(
ndef, ctx->func_lib(device), kernel);
}
if (!status->status.ok()) {
return;
}
tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
}
status->status = ValidateInputTypeAndPlacement(ctx->devices()[0], device, op,
kernel->kernel());
output_memory_types = &kernel->kernel()->output_memory_types();
if (!status->status.ok()) {
return;
}
// WARNING: kernel->Run utilizes the FunctionLibraryRuntime
// (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
// which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation
// of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
// FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
// This is quite subtle. Re-work things to make this better? (Would it make
// sense for FunctionLibraryRuntime to ensure thread-safe access to
// FunctionLibraryDefinition?).
status->status = kernel->Run(&op->inputs, &outputs);
if (!status->status.ok()) return;
*num_retvals = std::min<int>(*num_retvals, outputs.size());
for (int i = 0; i < *num_retvals; ++i) {
tensorflow::Device* d = IsCPU(device) ? nullptr : device;
if (d != nullptr && output_memory_types != nullptr &&
(*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
d = nullptr;
}
retvals[i] = new TFE_TensorHandle(outputs[i], d);
}
}
void TFE_ContextAddFunctionDef(TFE_Context* ctx,
const char* serialized_function_def, size_t size,
TF_Status* status) {
tensorflow::FunctionDef function_def;
if (!function_def.ParseFromArray(serialized_function_def, size)) {
status->status =
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
return;
}
tensorflow::mutex_lock l(ctx->functions_mu);
status->status = ctx->func_lib_def.AddFunctionDef(function_def);
}
} // extern "C"
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
return new TFE_TensorHandle(t, nullptr);
}
const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
TFE_TensorHandle* h, TF_Status* status) {
if (h->d != nullptr) {
status->status = tensorflow::errors::FailedPrecondition(
"TFE_TensorHandle is placed in device (not host) memory. Cannot return "
"a tensorflow::Tensor");
return nullptr;
}
return &h->t;
}

tensorflow/c/eager/c_api.h

@@ -0,0 +1,159 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_H_
#define TENSORFLOW_C_EAGER_C_API_H_
// C API extensions to experiment with eager execution of kernels.
#include "tensorflow/c/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
// "Context" under which operations/functions are executed. It encapsulates
// things like the available devices, resource manager etc.
//
// TODO(ashankar): Merge with TF_Session?
typedef struct TFE_Context TFE_Context;
extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts,
TF_Status* status);
extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
TF_Status* status);
// A handle to a tensor on a device.
//
// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
// type etc. Unlike a TF_Tensor, a TFE_TensorHandle may refer to such tensors
// placed in memory of different devices or remote address spaces.
typedef struct TFE_TensorHandle TFE_TensorHandle;
extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t);
extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h);
extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index);
extern const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h);
extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
TF_Status* status);
// Create a new TFE_TensorHandle with the same contents as 'h' but placed
// in the memory of the device named 'device_name'.
//
// Currently requires at least one of the source or destination devices to
// be CPU (i.e., for the source or destination tensor to be placed in
// host memory).
extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status);
// Description of the TensorFlow op to execute.
//
// Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e.,
// TFE_DeleteOp() is called before TFE_DeleteContext().
//
// Very similar to TF_OperationDescription with some differences:
// (1) Inputs are TFE_TensorHandle* (see TFE_OpAddInput) rather than TF_Output
// arguments to TF_AddInput/TF_AddInputList.
// (2) TF_ColocateWith, TF_AddControlInput etc. do not make sense.
// (3) Implementation detail: Avoid use of NodeBuilder/NodeDefBuilder since
// the additional sanity checks there seem unnecessary.
typedef struct TFE_Op TFE_Op;
extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status);
extern void TFE_DeleteOp(TFE_Op* op);
// TODO(ashankar): TFE_OpSetDevice and TFE_Execute should not have a TFE_Context
// parameter. Instead, the TFE_Context should be captured when creating the
// TFE_Op.
extern void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx,
const char* device_name, TF_Status* status);
extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status);
extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status);
extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name,
const char* value);
extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value);
extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value);
extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name,
unsigned char value);
extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name,
TF_DataType value);
// If the number of dimensions is unknown, `num_dims` must be set to
// -1 and `dims` can be null. If a dimension is unknown, the
// corresponding entry in the `dims` array must be -1.
extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name,
const int64_t* dims, const int num_dims,
TF_Status* out_status);
extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
const char** value, int num_values);
extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values);
extern void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
const float* values, int num_values);
extern void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
const unsigned char* values, int num_values);
extern void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values);
extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, TF_Status* out_status);
// Execute the operation defined by 'op' and return handles to computed
// tensors in 'retvals'.
//
// 'retvals' must point to a pre-allocated array of TFE_TensorHandle*
// and '*num_retvals' should be set to the size of this array.
//
// On return, 'num_retvals' will be set to the actual number of outputs
// returned by the operation.
extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals,
int* num_retvals, TF_Status* status);
// Add a function (serialized FunctionDef protocol buffer) to ctx so
// that it can be invoked using TFE_Execute.
extern void TFE_ContextAddFunctionDef(TFE_Context* ctx,
const char* serialized_function_def,
size_t size, TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#ifdef __cplusplus
// A workaround to ease conversion between numpy objects and
// TFE_TensorHandles.
//
// TODO(ashankar): Figure out an alternative scheme that precludes the need for
// these API-boundary breaking methods.
namespace tensorflow {
class Tensor;
} // namespace tensorflow
const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
TFE_TensorHandle* h, TF_Status* status);
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t);
#endif
#endif // TENSORFLOW_C_EAGER_C_API_H_

tensorflow/c/eager/c_api_test.cc

@@ -0,0 +1,463 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include <string.h>
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
using tensorflow::string;
namespace {
TFE_TensorHandle* TestMatrixTensorHandle() {
int64_t dims[] = {2, 2};
float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
TF_Tensor* t = TF_AllocateTensor(
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandle(t);
TF_DeleteTensor(t);
return th;
}
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, a, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, b, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_OpSetAttrBool(op, "transpose_a", 0);
TFE_OpSetAttrBool(op, "transpose_b", 0);
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
return op;
}
// TODO(apassos) uncomment after rewriting to use the right benchmark API
// void BM_InitOp(benchmark::State& state) {
// TF_Status* status = TF_NewStatus();
// TF_SessionOptions* opts = TF_NewSessionOptions();
// TFE_Context* ctx = TFE_NewContext(opts, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteSessionOptions(opts);
// TFE_TensorHandle* m = TestMatrixTensorHandle();
// for (auto _ : state) {
// TFE_Op* matmul = MatMulOp(ctx, m, m);
// TFE_DeleteOp(matmul);
// }
// TFE_DeleteTensorHandle(m);
// TFE_DeleteContext(ctx, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteStatus(status);
// }
// BENCHMARK(BM_InitOp);
// void BM_Execute(benchmark::State& state) {
// TF_Status* status = TF_NewStatus();
// TF_SessionOptions* opts = TF_NewSessionOptions();
// TFE_Context* ctx = TFE_NewContext(opts, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteSessionOptions(opts);
// TFE_TensorHandle* m = TestMatrixTensorHandle();
// TFE_Op* matmul = MatMulOp(ctx, m, m);
// TFE_TensorHandle* retvals[1];
// int num_retvals = 1;
// for (auto _ : state) {
// TFE_Execute(matmul, &retvals[0], &num_retvals, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// }
// TFE_DeleteOp(matmul);
// TFE_DeleteTensorHandle(m);
// TFE_DeleteContext(ctx, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteStatus(status);
// }
// BENCHMARK(BM_Execute);
TEST(CAPI, Context) {
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TF_DeleteSessionOptions(opts);
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContext(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const int num_devices = TF_DeviceListCount(devices);
EXPECT_GE(num_devices, 1) << "At least one CPU device should exist";
for (int i = 0; i < num_devices; ++i) {
EXPECT_NE("", TF_DeviceListName(devices, i, status)) << i;
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
TF_DeleteDeviceList(devices);
TF_DeleteStatus(status);
}
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
ASSERT_EQ(16, TF_TensorByteSize(t));
float data[4] = {0};
memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t));
EXPECT_EQ(1.0, data[0]);
EXPECT_EQ(2.0, data[1]);
EXPECT_EQ(3.0, data[2]);
EXPECT_EQ(4.0, data[3]);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(h);
}
TEST(CAPI, TensorHandleCopyBetweenDevices) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TF_DeleteSessionOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const int num_devices = TF_DeviceListCount(devices);
const char* kCPUDevice = "CPU:0";
for (int i = 0; i < num_devices; ++i) {
const string name(TF_DeviceListName(devices, i, status.get()));
if (TF_GetCode(status.get()) != TF_OK) {
ADD_FAILURE() << i << " -- " << TF_Message(status.get());
continue;
}
auto tag = tensorflow::strings::StrCat("Device #", i, " (", name, ")");
// Copy to device
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
if (TF_GetCode(status.get()) != TF_OK) {
ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
continue;
}
// Copy back to CPU
TFE_TensorHandle* hcopy =
TFE_TensorHandleCopyToDevice(hdevice, ctx, kCPUDevice, status.get());
if (TF_GetCode(status.get()) != TF_OK) {
ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
continue;
}
TFE_DeleteTensorHandle(hdevice);
// Ensure that the contents are the same!
TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get());
TFE_DeleteTensorHandle(hcopy);
if (TF_GetCode(status.get()) != TF_OK) {
ADD_FAILURE() << tag;
continue;
}
EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy)) << tag;
EXPECT_EQ(
0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t)))
<< tag;
TF_DeleteTensor(tcopy);
}
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
TFE_DeleteContext(ctx, status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
TEST(CAPI, Execute) {
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[2] = {nullptr};
int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call.
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
TFE_DeleteContext(ctx, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
TFE_DeleteTensorHandle(retvals[0]);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TF_DeleteStatus(status);
}
string MatMulFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'MatMulFunction'"
" input_arg {"
" name: 'a'"
" type: DT_FLOAT"
" }"
" output_arg {"
" name: 'm'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'matmul'"
" op: 'MatMul'"
" input: 'a'"
" input: 'a'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'm'"
" value: 'matmul:product'"
" }",
&def));
return def.SerializeAsString();
}
TEST(CAPI, FunctionDefAndExecute) {
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
string function_def = MatMulFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* retval[1] = {nullptr};
int num_retvals = 1;
TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, m, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(op, &retval[0], &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TFE_DeleteOp(op);
TFE_DeleteTensorHandle(m);
TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status);
TFE_DeleteTensorHandle(retval[0]);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteContext(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
// TODO(apassos) uncomment after rewriting to use the right benchmark API
// void BM_ExecuteFunction(benchmark::State& state) {
// TF_Status* status = TF_NewStatus();
// TF_SessionOptions* opts = TF_NewSessionOptions();
// TFE_Context* ctx = TFE_NewContext(opts, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteSessionOptions(opts);
// string function_def = MatMulFunction();
// TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
// status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TFE_TensorHandle* m = TestMatrixTensorHandle();
// TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TFE_OpAddInput(matmul, m, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TFE_TensorHandle* retval[1] = {nullptr};
// int num_retvals = 1;
// for (auto _ : state) {
// TFE_Execute(matmul, &retval[0], &num_retvals, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// }
// TFE_DeleteTensorHandle(m);
// TFE_DeleteTensorHandle(retval[0]);
// TFE_DeleteContext(ctx, status);
// EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteStatus(status);
// }
// BENCHMARK(BM_ExecuteFunction);
// TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
// TF_Status* status) {
// // Create the variable handle.
// TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
// if (TF_GetCode(status) != TF_OK) return nullptr;
// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
// TFE_OpSetAttrShape(op, "shape", {}, 0, status);
// TFE_OpSetAttrString(op, "container", "");
// TFE_OpSetAttrString(op, "shared_name", "");
// if (TF_GetCode(status) != TF_OK) return nullptr;
// TFE_TensorHandle* var_handle = nullptr;
// int num_retvals = 1;
// TFE_Execute(op, &var_handle, &num_retvals, status);
// TFE_DeleteOp(op);
// if (TF_GetCode(status) != TF_OK) return nullptr;
// CHECK_EQ(1, num_retvals);
// // Assign 'value' to it.
// op = TFE_NewOp(ctx, "AssignVariableOp", status);
// if (TF_GetCode(status) != TF_OK) return nullptr;
// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
// TFE_OpAddInput(op, var_handle, status);
// // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
// std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
// TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)),
// TF_DeleteTensor);
// memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
// std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
// value_handle(TFE_NewTensorHandle(t.get()), TFE_DeleteTensorHandle);
// TFE_OpAddInput(op, value_handle.get(), status);
// if (TF_GetCode(status) != TF_OK) return nullptr;
// num_retvals = 0;
// TFE_Execute(op, nullptr, &num_retvals, status);
// TFE_DeleteOp(op);
// if (TF_GetCode(status) != TF_OK) return nullptr;
// CHECK_EQ(0, num_retvals);
// return var_handle;
// }
// TEST(CAPI, Variables) {
// // Variables use resource handles, so this is really a test for resource
// // tensor handling.
// TF_Status* status = TF_NewStatus();
// TF_SessionOptions* opts = TF_NewSessionOptions();
// TFE_Context* ctx = TFE_NewContext(opts, status);
// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteSessionOptions(opts);
// TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
// TFE_OpAddInput(op, var_handle, status);
// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// int num_retvals = 1;
// TFE_TensorHandle* value_handle = nullptr;
// TFE_Execute(op, &value_handle, &num_retvals, status);
// TFE_DeleteOp(op);
// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// ASSERT_EQ(1, num_retvals);
// EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle));
// EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle));
// float value = 0.0f;
// TF_Tensor* t = TFE_TensorHandleResolve(value_handle, status);
// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
// memcpy(&value, TF_TensorData(t), sizeof(float));
// TF_DeleteTensor(t);
// EXPECT_EQ(12.0, value);
// TFE_DeleteTensorHandle(var_handle);
// TFE_DeleteTensorHandle(value_handle);
// TFE_DeleteContext(ctx, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteStatus(status);
// }
// void BM_ReadVariable(benchmark::State& state) {
// TF_Status* status = TF_NewStatus();
// TF_SessionOptions* opts = TF_NewSessionOptions();
// TFE_Context* ctx = TFE_NewContext(opts, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteSessionOptions(opts);
// TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
// TFE_OpAddInput(op, var_handle, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// int num_retvals = 1;
// TFE_TensorHandle* h = nullptr;
// for (auto _ : state) {
// TFE_Execute(op, &h, &num_retvals, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// CHECK_EQ(1, num_retvals);
// CHECK(h);
// CHECK_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
// CHECK_EQ(0, TFE_TensorHandleNumDims(h));
// h = nullptr;
// }
// TFE_DeleteOp(op);
// TFE_DeleteTensorHandle(var_handle);
// TFE_DeleteContext(ctx, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteStatus(status);
// }
// BENCHMARK(BM_ReadVariable);
} // namespace

tensorflow/c/eager/runtime.cc

@@ -0,0 +1,289 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/runtime.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
namespace tensorflow {
namespace {
mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED);
std::unordered_map<string, const AttrTypeMap*>* OpNameToAttrTypeMap() {
static auto* const m = new std::unordered_map<string, const AttrTypeMap*>;
return m;
}
const uint32 kIsList = 1U << 31;
} // namespace
Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
mutex_lock l(g_op_name_to_attr_type_map_lock);
*out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
if (*out != nullptr) return Status::OK();
const OpRegistrationData* op_reg_data = nullptr;
Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
if (!s.ok()) return s;
std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
// TODO(agarwal): Avoid having to create this "registry" at runtime,
// perhaps can be done at op registration time?
for (const auto& attr : op_reg_data->op_def.attr()) {
string type = attr.type();
const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
if (is_list) {
type = type.substr(5, type.length() - 6);
}
uint32 t = is_list ? kIsList : 0;
if (type == "string") {
t |= TF_ATTR_STRING;
} else if (type == "int") {
t |= TF_ATTR_INT;
} else if (type == "float") {
t |= TF_ATTR_FLOAT;
} else if (type == "bool") {
t |= TF_ATTR_BOOL;
} else if (type == "type") {
t |= TF_ATTR_TYPE;
} else if (type == "shape") {
t |= TF_ATTR_SHAPE;
} else if (type == "tensor") {
t |= TF_ATTR_TENSOR;
} else {
return errors::Unimplemented(
"TODO(agarwal): Enable support for ops with attributes of type '",
type, "'");
}
gtl::InsertIfNotPresent(m.get(), attr.name(), t);
}
*out = m.get();
(*OpNameToAttrTypeMap())[op_name] = m.release();
return Status::OK();
}
Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
TF_AttrType* out, unsigned char* is_list) {
CHECK(m);
auto* t = gtl::FindOrNull(*m, attr_name);
if (t == nullptr) {
return errors::InvalidArgument("Attribute '", attr_name,
"' does not exist for this operation");
}
*out = static_cast<TF_AttrType>(*t & ~kIsList);
if (*t & kIsList) {
*is_list = 1;
} else {
*is_list = 0;
}
return Status::OK();
}
#define DEFINE_SET_ATTR(value_type, value_field) \
template <> \
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \
value_field.push_back(std::make_pair(attr_name, value)); \
return *this; \
}
DEFINE_SET_ATTR(StringPiece, string_attrs_);
DEFINE_SET_ATTR(float, float_attrs_);
DEFINE_SET_ATTR(int, int_attrs_);
DEFINE_SET_ATTR(bool, bool_attrs_);
DEFINE_SET_ATTR(tensorflow::DataType, type_attrs_);
#undef DEFINE_SET_ATTR
AttrBuilder& AttrBuilder::NumInputs(int n) {
DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef.";
num_inputs_ = n;
return *this;
}
const NodeDef& AttrBuilder::BuildNodeDef() {
if (node_def_finalized_) return *node_def_;
MayBeInitializeNodeDef();
for (int i = 0; i < num_inputs_; ++i) {
node_def_->add_input("dummy_input");
}
for (const auto& p : string_attrs_) {
SetInNodeDef(p.first, p.second);
}
for (const auto& p : int_attrs_) {
SetInNodeDef(p.first, p.second);
}
for (const auto& p : float_attrs_) {
SetInNodeDef(p.first, p.second);
}
for (const auto& p : bool_attrs_) {
SetInNodeDef(p.first, p.second);
}
for (const auto& p : type_attrs_) {
SetInNodeDef(p.first, p.second);
}
node_def_finalized_ = true;
return *node_def_;
}
namespace {
inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
const tensorflow::Fprint128& b) {
// Combine both the low and the high halves of the two fingerprints.
return {tensorflow::FingerprintCat64(a.low64, b.low64),
tensorflow::FingerprintCat64(a.high64, b.high64)};
}
void CombineUnordered(const tensorflow::Fprint128& a,
tensorflow::Fprint128* b) {
b->low64 += a.low64;
b->high64 += a.high64;
}
inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s,
const tensorflow::Fprint128& b) {
// TODO(agarwal): avoid ToString().
tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString());
return FingerprintCat128(a, b);
}
inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, uint64 b) {
return CacheKeyHelper(s, {b, b});
}
} // namespace
tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const {
tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name_);
f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
if (node_def_ != nullptr) {
// Some attributes are directly written to node_def_ instead of being
// stored explicitly.
string value;
for (const auto& attr : node_def_->attr()) {
attr.second.SerializeToString(&value);
CombineUnordered(
CacheKeyHelper(attr.first, tensorflow::Fingerprint128(value)), &f);
}
// Note that node_def_ may be created but not finalized. This can happen
// when the creation was triggered by a call to Set, but BuildNodeDef has
// not been called.
if (node_def_finalized_) return f;
}
for (const auto& p : string_attrs_) {
// TODO(agarwal): avoid ToString().
CombineUnordered(CacheKeyHelper(p.first, tensorflow::Fingerprint128(
p.second.ToString())),
&f);
}
for (const auto& p : int_attrs_) {
CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
&f);
}
static std::hash<float> float_hasher;
for (const auto& p : float_attrs_) {
CombineUnordered(
CacheKeyHelper(p.first, static_cast<uint64>(float_hasher(p.second))),
&f);
}
for (const auto& p : bool_attrs_) {
CombineUnordered(CacheKeyHelper(p.first, p.second ? 1u : 0u), &f);
}
for (const auto& p : type_attrs_) {
CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
&f);
}
return f;
}
void AttrBuilder::MayBeInitializeNodeDef() {
if (node_def_ == nullptr) {
node_def_.reset(new NodeDef());
node_def_->set_name(op_name_);
node_def_->set_op(op_name_);
}
}
// static
Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
KernelAndDevice* out) {
OpKernel* k = nullptr;
Status s = CreateOpKernel(device->device_type().c_str(), device,
device->GetAllocator(AllocatorAttributes()),
nullptr, ndef, TF_GRAPH_DEF_VERSION, &k);
out->device_ = device;
out->kernel_.reset(k);
out->flib_ = nullptr;
return s;
}
// static
Status KernelAndDevice::InitFn(const NodeDef& ndef,
FunctionLibraryRuntime* flib,
KernelAndDevice* out) {
OpKernel* k = nullptr;
Status s = flib->CreateKernel(ndef, &k);
out->device_ = flib->device();
out->kernel_.reset(k);
out->flib_ = flib;
return s;
}
Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
std::vector<Tensor>* output_tensors) {
gtl::InlinedVector<TensorValue, 4> inputs;
for (Tensor& t : *input_tensors) {
inputs.push_back(TensorValue(&t));
}
std::vector<AllocatorAttributes> out_attrs(kernel_->num_outputs());
for (size_t i = 0; i < out_attrs.size(); ++i) {
out_attrs[i].set_on_host(kernel_->output_memory_types()[i] ==
tensorflow::HOST_MEMORY);
}
OpKernelContext::Params params;
params.device = device_;
params.frame_iter = FrameAndIter(0, 0);
params.inputs = &inputs;
params.op_kernel = kernel_.get();
params.resource_manager = device_->resource_manager();
params.output_attr_array = gtl::vector_as_array(&out_attrs);
params.function_library = flib_;
params.slice_reader_cache = &slice_reader_cache_;
// TODO(apassos): use a thread pool.
std::function<void(std::function<void()>)> runner =
[](std::function<void()> f) { f(); };
params.runner = &runner;
OpKernelContext context(&params);
device_->Compute(kernel_.get(), &context);
if (!context.status().ok()) return context.status();
output_tensors->clear();
for (int i = 0; i < context.num_outputs(); ++i) {
output_tensors->push_back(Tensor(*context.mutable_output(i)));
}
return Status::OK();
}
} // namespace tensorflow

tensorflow/c/eager/runtime.h

@@ -0,0 +1,193 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_RUNTIME_H_
#define TENSORFLOW_C_EAGER_RUNTIME_H_
// Support for eager execution of TensorFlow kernels.
#include <memory>
#include <unordered_map>
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
namespace tensorflow {
// Maps attribute name to an encoding of the type of the attribute value.
// If the type is not a list type, the value is the same as the TF_AttrType type
// of the value. Else, the highest order bit is on, and the rest of the bits
// represent the TF_AttrType type of the values in the list.
typedef std::unordered_map<string, uint32> AttrTypeMap;
// Returns the AttrTypeMap for the TensorFlow operation named op_name.
Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out);
// Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'.
Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
TF_AttrType* out, unsigned char* is_list);
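// For example: with this encoding, a plain "float" attribute is stored as
// TF_ATTR_FLOAT, while a "list(float)" attribute is stored as
// (1u << 31) | TF_ATTR_FLOAT. AttrTypeByName strips that high-order bit into
// '*is_list' and returns the element type in '*out'.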
// KernelAndDevice::Init needs a NodeDef only to pass the attribute map through.
// An AttrBuilder is a convenience class to help with that - providing a smaller
// interface than NodeDefBuilder and avoiding expensive (unnecessary?) sanity
// checks (like number of inputs matching the OpDef - we only care about
// attributes here).
//
// TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which
// ones make sense to replicate.
// This is a helper class for creating a NodeDef. Additionally, this class
// allows computing a cache key based on fingerprinting the attributes of this
// NodeDef.
//
// Example usage:
// AttrBuilder a;
// a.NumInputs(2);
// a.Set("T", TF_FLOAT);
// uint64 cache_key = a.CacheKey("cpu:0");
// const NodeDef& n = a.BuildNodeDef();
//
// Note that all calls to Set and NumInputs should happen before calling
// BuildNodeDef. Also, calls to NumInputs or Set between invocations of
// CacheKey may cause CacheKey to return different values.
//
// For performance reasons, the class internally delays the actual construction
// of the NodeDef till BuildNodeDef is called, or Set is called with certain
// uncommon types (see template specializations of Set to see which types
// trigger a NodeDef creation).
class AttrBuilder {
public:
explicit AttrBuilder(const char* op)
: op_name_(op),
num_inputs_(0),
node_def_(nullptr),
node_def_finalized_(false) {}
// Needed to work around call to ValidateNodeDef in CreateOpKernel.
AttrBuilder& NumInputs(int n);
template <class T>
AttrBuilder& Set(StringPiece attr_name, T&& value) {
MayBeInitializeNodeDef();
return SetInNodeDef(attr_name, value);
}
tensorflow::Fprint128 CacheKey(const string& device) const;
const NodeDef& BuildNodeDef();
private:
template <class T>
using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>;
void MayBeInitializeNodeDef();
template <class T>
AttrBuilder& SetInNodeDef(StringPiece attr_name, T&& value) {
DCHECK(!node_def_finalized_) << "Calling SetInNodeDef after BuildNodeDef.";
// Copied from NodeDefBuilder::Attr
const AttrValue* found = AttrSlice(*node_def_).Find(attr_name);
if (found == nullptr) {
AddNodeAttr(attr_name, std::forward<T>(value), node_def_.get());
} else {
AttrValue attr_value;
SetAttrValue(std::forward<T>(value), &attr_value);
// TODO(ashankar): Do what is done in
// NodeDefBuilder::CheckInconsistency(attr_name, *found, attr_value);
}
return *this;
}
AttrVec<StringPiece> string_attrs_;
AttrVec<int> int_attrs_;
AttrVec<float> float_attrs_;
AttrVec<bool> bool_attrs_;
AttrVec<tensorflow::DataType> type_attrs_;
string op_name_;
int num_inputs_;
std::unique_ptr<NodeDef> node_def_;
bool node_def_finalized_;
};
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, bool&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name,
tensorflow::DataType&& value);
// KernelAndDevice encapsulates an instantiated kernel and the device it is on.
//
// Also see:
// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
// and
// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
class KernelAndDevice {
public:
// Populates 'out' with a kernel appropriate for 'ndef'.
//
// Assumes that 'ndef' refers to a primitive op (as opposed to a function).
static Status InitOp(Device* device, const NodeDef& ndef,
KernelAndDevice* out);
// Like InitOp but for functions defined in flib (i.e., ndef.op() refers to a
// TensorFlow function in the FunctionLibraryRuntime).
//
// The provided FunctionLibraryRuntime MUST outlive all calls to
// Run() on the returned KernelAndDevice.
//
// TODO(ashankar): There shouldn't be a need for a separate InitOp and InitFn.
// The implementation of InitFn should work for both because
// FunctionLibraryRuntime::CreateKernel will create a primitive op kernel if
// appropriate. However, for now we keep them separate because I haven't
// figured out thread-safety concerns around FunctionLibraryRuntime (in
// particular, how the underlying FunctionLibraryDefinition might be mutated
// by another thread as new functions are registered with it).
// Conservatively, thread-safe usage of the FunctionLibraryRuntime is pushed
// on to the caller (see locking in c_api.cc) for now. But I really should
// dig into this so that both InitOp and InitFn can be collapsed to
// FunctionLibraryRuntime::CreateKernel.
static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib,
KernelAndDevice* out);
KernelAndDevice() : device_(nullptr), flib_(nullptr) {}
// TODO(ashankar): Handle list-valued inputs.
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs);
const OpKernel* kernel() const { return kernel_.get(); }
private:
std::unique_ptr<OpKernel> kernel_;
tensorflow::Device* device_;
tensorflow::FunctionLibraryRuntime* flib_;
tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_RUNTIME_H_

View File

@ -0,0 +1,160 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/runtime.h"
#include <memory>
#include <vector>
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
Device* CPUDevice() {
return DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
}
TEST(AttrTypeMap, Lookup) {
const AttrTypeMap* m = nullptr;
Status s = AttrTypeMapForOp("ThisOpCannotPossiblyExist", &m);
EXPECT_FALSE(s.ok());
s = AttrTypeMapForOp("MatMul", &m);
ASSERT_TRUE(s.ok()) << s;
TF_AttrType t;
unsigned char is_list = 1;
  s = AttrTypeByName(m, "ThisAttributeCannotPossiblyExist", &t, &is_list);
EXPECT_FALSE(s.ok());
EXPECT_NE(is_list, 0);
s = AttrTypeByName(m, "transpose_a", &t, &is_list);
ASSERT_TRUE(s.ok()) << s;
EXPECT_EQ(TF_ATTR_BOOL, t);
EXPECT_EQ(is_list, 0);
s = AttrTypeMapForOp("Squeeze", &m);
ASSERT_TRUE(s.ok()) << s;
s = AttrTypeByName(m, "squeeze_dims", &t, &is_list);
ASSERT_TRUE(s.ok()) << s;
EXPECT_EQ(TF_ATTR_INT, t);
EXPECT_NE(is_list, 0);
}
TEST(KernelAndDevice, Run) {
Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
std::vector<Tensor> inputs;
inputs.push_back(t);
inputs.push_back(t);
NodeDef ndef(AttrBuilder("MatMul")
.Set("T", DT_FLOAT)
.Set("transpose_a", false)
.Set("transpose_b", false)
.NumInputs(inputs.size())
.BuildNodeDef());
std::unique_ptr<Device> device(CPUDevice());
KernelAndDevice kernel;
Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel);
ASSERT_TRUE(s.ok()) << s;
std::vector<Tensor> outputs;
s = kernel.Run(&inputs, &outputs);
ASSERT_TRUE(s.ok()) << s;
ASSERT_EQ(1, outputs.size());
const Tensor& out = outputs[0];
EXPECT_EQ(7, out.matrix<float>()(0, 0));
EXPECT_EQ(10, out.matrix<float>()(0, 1));
EXPECT_EQ(15, out.matrix<float>()(1, 0));
EXPECT_EQ(22, out.matrix<float>()(1, 1));
}
// TODO(apassos) uncomment after rewriting to use the right benchmark API
// void BM_CreateGraph(benchmark::State& state) {
// for (auto _ : state) {
// Scope root = Scope::NewRootScope();
// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
// auto M = ops::MatMul(root, C, C);
// TF_CHECK_OK(root.status());
// }
// }
// BENCHMARK(BM_CreateGraph);
// void BM_RunGraph(benchmark::State& state) {
// Scope root = Scope::NewRootScope();
// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
// auto M = ops::MatMul(root, C, C);
// SessionOptions opts;
// opts.config.set_inter_op_parallelism_threads(1);
// opts.config.set_intra_op_parallelism_threads(1);
// ClientSession sess(root, opts);
// std::vector<Tensor> outputs;
// for (auto _ : state) {
// outputs.clear();
// TF_CHECK_OK(sess.Run({M}, &outputs));
// }
// }
// BENCHMARK(BM_RunGraph);
// void BM_CreateAndDestroySession(benchmark::State& state) {
// Scope root = Scope::NewRootScope();
// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
// auto M = ops::MatMul(root, C, C);
// for (auto _ : state) {
// ClientSession sess(root);
// }
// }
// BENCHMARK(BM_CreateAndDestroySession);
// void BM_KernelAndDeviceInit(benchmark::State& state) {
// NodeDef ndef(AttrBuilder("MatMul")
// .Set("T", DT_FLOAT)
// .Set("transpose_a", false)
// .Set("transpose_b", false)
// .NumInputs(2)
// .BuildNodeDef());
// std::unique_ptr<Device> device(CPUDevice());
// KernelAndDevice k;
// for (auto _ : state) {
// TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &k));
// }
// }
// BENCHMARK(BM_KernelAndDeviceInit);
// void BM_KernelAndDeviceRun(benchmark::State& state) {
// Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
// std::vector<Tensor> inputs;
// inputs.push_back(t);
// inputs.push_back(t);
// std::vector<Tensor> outputs;
// NodeDef ndef(AttrBuilder("MatMul")
// .Set("T", DT_FLOAT)
// .Set("transpose_a", false)
// .Set("transpose_b", false)
// .NumInputs(inputs.size())
// .BuildNodeDef());
// std::unique_ptr<Device> device(CPUDevice());
// KernelAndDevice kernel;
// TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &kernel));
// for (auto _ : state) {
// TF_CHECK_OK(kernel.Run(&inputs, &outputs));
// }
// }
// BENCHMARK(BM_KernelAndDeviceRun);
} // namespace
} // namespace tensorflow

View File

@ -18,6 +18,10 @@
set(tf_c_srcs
"${tensorflow_source_dir}/tensorflow/c/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/c_api.h"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.h"
"${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/runtime.h"
"${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc"
"${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h"
"${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc"

View File

@ -755,6 +755,8 @@ add_custom_command(
set (pywrap_tensorflow_internal_src
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.h"
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.cc"
"${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe.h"
"${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe_src.h"
"${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.h"
"${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.h"

View File

@ -60,6 +60,7 @@ INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|"
# Include if matched after exclude
INCLUDE_RE = re.compile(r"^(TF_\w*)$|"
r"^(TFE_\w*)$|"
r"tensorflow::|"
r"functor::|"
r"perftools::gputools")

View File

@ -2814,6 +2814,7 @@ tf_py_wrap_cc(
"lib/io/py_record_reader.i",
"lib/io/py_record_writer.i",
"platform/base.i",
"pywrap_tfe.i",
"training/quantize_training.i",
"training/server_lib.i",
"util/kernel_registry.i",
@ -2838,6 +2839,7 @@ tf_py_wrap_cc(
"//tensorflow/c:checkpoint_reader",
"//tensorflow/c:python_api",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
"//tensorflow/core/grappler:grappler_item",
@ -2850,6 +2852,7 @@ tf_py_wrap_cc(
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/profiler/internal:print_model_analysis",
"//tensorflow/tools/graph_transforms:transform_graph_lib",
"//tensorflow/python/eager:pywrap_tfe_lib",
"//util/python:python_headers",
] + (tf_additional_lib_deps() +
tf_additional_plugin_deps() +

View File

@ -56,7 +56,7 @@ tensorflow::ImportNumpy();
// const char*.
%typemap(in) (const char* target) {
$1 = PyBytes_AsString($input);
if (!$1) {
if (!$1) {
// Python has raised an error.
SWIG_fail;
}

View File

@ -0,0 +1,254 @@
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
cc_library(
name = "pywrap_tfe_lib",
srcs = ["pywrap_tfe_src.cc"],
hdrs = ["pywrap_tfe.h"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/core:lib",
"//tensorflow/python:numpy_lib",
"//tensorflow/python:py_func_lib",
"//util/python:python_headers",
],
)
py_library(
name = "core",
srcs = ["core.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":context",
":memory_trace",
":tape",
"//tensorflow/python:errors",
"//tensorflow/python:pywrap_tensorflow",
],
)
py_library(
name = "tensor",
srcs = ["tensor.py"],
srcs_version = "PY2AND3",
visibility = ["//learning/brain/contrib/eager:__subpackages__"],
deps = [
":context",
":core",
":tape",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:tensor_shape",
"//third_party/py/numpy",
],
)
py_library(
name = "context",
srcs = ["context.py"],
srcs_version = "PY2AND3",
visibility = ["//learning/brain/contrib/eager:__subpackages__"],
deps = [
"//tensorflow/python:device",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:util",
],
)
py_library(
name = "tape",
srcs = ["tape.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:dtypes",
"//tensorflow/python:util",
],
)
py_library(
name = "memory_trace",
srcs = ["memory_trace.py"],
srcs_version = "PY2AND3",
)
cuda_py_test(
name = "core_test",
srcs = ["core_test.py"],
additional_deps = [
":context",
":core",
":execute",
"//tensorflow/python:pywrap_tensorflow",
":tensor",
":test",
"//third_party/py/numpy",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
],
)
py_library(
name = "test",
srcs = ["test.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":context",
"//tensorflow/python:client_testlib",
],
)
py_library(
name = "execute",
srcs = ["execute.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":context",
":core",
":tape",
":tensor",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:dtypes",
"//tensorflow/python:lib",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"@six_archive//:six",
],
)
cc_library(
name = "python_eager_op_gen",
srcs = ["python_eager_op_gen.cc"],
hdrs = ["python_eager_op_gen.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
"//tensorflow/python:python_op_gen",
],
)
cc_library(
name = "python_eager_op_gen_main",
srcs = [
"python_eager_op_gen_main.cc",
],
visibility = ["//visibility:public"],
deps = [
":python_eager_op_gen",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_binary(
name = "python_eager_op_gen_demo",
deps = [
":python_eager_op_gen_main",
"//tensorflow/core:ops",
],
)
py_library(
name = "custom_gradient",
srcs = ["custom_gradient.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":core",
":tape",
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
],
)
py_library(
name = "graph_only_ops",
srcs = ["graph_only_ops.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_ops",
],
)
py_library(
name = "framework_for_generated_wrappers",
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/eager:execute",
],
)
py_library(
name = "function",
srcs = ["function.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":graph_only_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:graph_to_function_def",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:core",
"//tensorflow/python/eager:execute",
"//tensorflow/python/eager:tape",
"//tensorflow/python/eager:tensor",
"//third_party/py/numpy",
],
)
py_library(
name = "pip_dependencies",
visibility = ["//tensorflow:internal"],
deps = [
":context",
":core",
":execute",
":tensor",
":test",
"//tensorflow/python:pywrap_tensorflow",
],
)
# -----------------------------------------------------------------------------
# Google-internal targets.
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,333 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for TensorFlow's "Eager" mode of execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import threading
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.platform import app
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib
GRAPH_MODE = 0
EAGER_MODE = 1
# Default execution mode.
_default_mode = GRAPH_MODE
# TODO(agarwal): better name ?
class _EagerContext(threading.local):
"""Thread local eager context."""
def __init__(self):
super(_EagerContext, self).__init__()
self.device_index = -1
self.mode = _default_mode
self.scope_name = ""
self.recording_summaries = False
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
class Context(object):
"""Environment in which eager operations execute."""
def __init__(self, graph=None):
self._eager_context = _EagerContext()
if not self.in_eager_mode():
raise ValueError("Trying to create a Context in GRAPH_MODE")
# Create a handle
opts = pywrap_tensorflow.TF_NewSessionOptions(target=compat.as_bytes(""),
config=None)
with errors.raise_exception_on_not_ok_status() as status:
self._handle = pywrap_tensorflow.TFE_NewContext(opts, status)
pywrap_tensorflow.TF_DeleteSessionOptions(opts)
# Store list of devices
self._devices = []
with errors.raise_exception_on_not_ok_status() as status:
device_list = pywrap_tensorflow.TFE_ContextListDevices(
self._handle, status)
try:
for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
with errors.raise_exception_on_not_ok_status() as status:
dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i, status)
self._devices.append(pydev.canonical_name(dev_name))
finally:
pywrap_tensorflow.TF_DeleteDeviceList(device_list)
self._summary_writer_resource = None
self._graph = graph or tf_ops.get_default_graph()
def __del__(self):
if self._handle is not None:
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.TFE_DeleteContext(self._handle, status)
def __str__(self):
lines = [
"Eager TensorFlow environment with %d devices" % (len(self._devices))
]
for i, d in enumerate(self._devices):
lines.append(" Device %d: %s" % (i, d))
return "\n".join(lines)
@tf_contextlib.contextmanager
def _mode(self, mode):
ctx = self._eager_context
old_mode = ctx.mode
ctx.mode = mode
try:
yield
finally:
ctx.mode = old_mode
def in_graph_mode(self):
"""Returns True if current thread is in GRAPH mode."""
return self._eager_context.mode == GRAPH_MODE
def in_eager_mode(self):
"""Returns True if current thread is in EAGER mode."""
return self._eager_context.mode == EAGER_MODE
@property
def scope_name(self):
"""Returns scope name for the current thread."""
return self._eager_context.scope_name
@scope_name.setter
def scope_name(self, s):
"""Sets scope name for the current thread."""
self._eager_context.scope_name = s
@property
def summary_writer_resource(self):
"""Returns summary writer resource."""
return self._summary_writer_resource
@summary_writer_resource.setter
def summary_writer_resource(self, resource):
"""Sets summary writer resource."""
self._summary_writer_resource = resource
  @property
  def recording_summaries(self):
    """Returns True if recording summaries is enabled in the current thread."""
return self._eager_context.recording_summaries
  @recording_summaries.setter
  def recording_summaries(self, val):
    """Sets whether recording summaries is enabled in the current thread."""
self._eager_context.recording_summaries = val
# TODO(agarwal): remove?
@property
def _device_index(self):
return self._eager_context.device_index
# TODO(agarwal): remove?
@_device_index.setter
def _device_index(self, val):
self._eager_context.device_index = val
@property
def device_name(self):
"""Returns the device name for the current thread."""
index = self._device_index
return None if index < 0 else self._devices[index]
def devices(self):
"""List of the names of devices available to execute operations."""
return self._devices
def num_gpus(self):
"""The number of GPUs available to execute operations."""
# TODO(ashankar): Use TF_DeviceListType to count GPU devices.
return len(self._devices) - 1
def as_default(self):
"""Returns a context manager to make self the default for this thread."""
return _default_context_stack.get_controller(self)
class _DefaultContextStack(tf_ops._DefaultStack): # pylint: disable=protected-access
"""A thread-local stack of Context objects."""
def __init__(self):
super(_DefaultContextStack, self).__init__()
self._global_default_context = None
def get_default(self):
"""Returns a thread local object if present, else a global default."""
return (super(_DefaultContextStack, self).get_default() or
self.global_default_context)
@property
def global_default_context(self):
if self._global_default_context is None:
self._global_default_context = Context()
return self._global_default_context
def reset(self):
super(_DefaultContextStack, self).reset()
self._global_default_context = None
_default_context_stack = _DefaultContextStack()
def get_default_context():
"""Returns a default Context object."""
return _default_context_stack.get_default()
# TODO(agarwal): switch users to get_default_context and get rid of this
# function.
def context():
return get_default_context()
def in_graph_mode():
"""Returns True if current thread is in GRAPH mode for default context."""
return get_default_context().in_graph_mode()
def in_eager_mode():
"""Returns True if current thread is in EAGER mode for default context."""
return get_default_context().in_eager_mode()
def graph_mode():
"""Context-manager to enable GRAPH mode for current thread."""
return get_default_context()._mode(GRAPH_MODE) # pylint: disable=protected-access
def eager_mode():
"""Context-manager to enable EAGER mode for current thread."""
return get_default_context()._mode(EAGER_MODE) # pylint: disable=protected-access
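# Illustrative sketch (not part of this module's API): the mode toggles are
# thread-local and restore the previous mode on exit. Assumes
# enable_eager_execution() was called at program startup, so the thread starts
# out in EAGER mode.
def _example_mode_toggling():  # hypothetical helper, for illustration only
  assert in_eager_mode()
  with graph_mode():
    assert in_graph_mode()
  assert in_eager_mode()  # the previous mode is restored when the block exits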
@contextlib.contextmanager
def namescope(name):
"""ContextManager for creating hierarchical name scopes."""
ctx = get_default_context()
old_name = ctx.scope_name
ctx.scope_name = "%s/%s" % (old_name, name) if old_name else name
try:
yield
finally:
ctx.scope_name = old_name
def scope_name():
"""Name of the current scope."""
return get_default_context().scope_name
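# Illustrative sketch (not part of this module's API): name scopes nest by
# joining the enclosing scope name and the new name with '/'. Assumes the
# current scope is empty when the helper is called.
def _example_namescope_nesting():  # hypothetical helper, for illustration only
  with namescope('outer'):
    with namescope('inner'):
      return scope_name()  # 'outer/inner'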
@tf_contextlib.contextmanager
def device(name):
"""Context-manager to force placement of operations and Tensors on a device.
For example:
```python
with tfe.device('gpu:0'):
with tfe.device('cpu:0'):
shape = tfe.Tensor([], dtype=tf.int32)
x = ops.truncated_normal(shape, tf.float32)
```
  will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
  operation runs on GPU 0.
Args:
name: Name of the device (see get_default_context().devices()), or None to
enable automatic placement.
Yields:
Nothing.
Raises:
ValueError: If name does not correspond to a valid device.
"""
device_index = -1
ctx = get_default_context()
if name is not None:
name = pydev.canonical_name(name)
all_devices = ctx.devices()
for i, d in enumerate(all_devices):
# TODO(ashankar): This will change when we have distributed support.
# At that point, should not look for a string suffix but be able to
# do a full string comparison.
if d.endswith(name):
device_index = i
break
if device_index < 0:
raise ValueError("device {} does not match the available devices ({})".
format(name, all_devices))
old_device_index = ctx._device_index # pylint: disable=protected-access
try:
ctx._device_index = device_index # pylint: disable=protected-access
yield
finally:
ctx._device_index = old_device_index # pylint: disable=protected-access
@contextlib.contextmanager
def record_summaries():
"""Context-manager to enable recording of summaries."""
ctx = get_default_context()
old = ctx.recording_summaries
ctx.recording_summaries = True
try:
yield
finally:
ctx.recording_summaries = old
def should_record_summary():
"""True if a summary should be recorded now."""
c = get_default_context()
return c.recording_summaries and c.summary_writer_resource is not None
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list.
The program will run with eager execution enabled.
Args:
main: the main function to run
argv: the arguments to pass to it
"""
enable_eager_execution()
app.run(main, argv)
# TODO(apassos): This should not be a part of the public API.
def enable_eager_execution():
"""Enables, for the rest of the lifetime of this program, eager execution.
  If not called immediately on startup, this risks creating breakage and bugs.
"""
global _default_mode
assert _default_mode == GRAPH_MODE
_default_mode = EAGER_MODE

View File

@ -0,0 +1,88 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for TensorFlow's "Eager" mode of execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import memory_trace
from tensorflow.python.framework import errors
# Trace of execution and memory usage.
_active_trace = None
_uid_counter = 0
_uid_lock = threading.Lock()
def uid():
"""A unique (within this program execution) integer."""
with _uid_lock:
global _uid_counter
_uid_counter += 1
return _uid_counter
def _status_to_exception(code, message):
try:
error_class = errors.exception_type_from_error_code(code)
return error_class(None, None, message)
except KeyError:
return errors.UnknownError(None, None, message, code)
class _NotOkStatusException(Exception):
"""Exception class to handle not ok Status."""
def __init__(self, message, code):
super(_NotOkStatusException, self).__init__()
self.message = message
self.code = code
def __str__(self):
e = _status_to_exception(self.code, self.message)
return "%s: %s" % (e.__class__.__name__, e)
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(_NotOkStatusException)
def enable_tracing():
"""Enables tracing of execution and memory usage.
WARNING: tracing is not thread-safe.
"""
global _active_trace
_active_trace = memory_trace.MemoryTrace(
len(context.get_default_context().devices()))
def flush_trace():
"""Flushes the active trace, if it exists.
WARNING: tracing is not thread-safe.
"""
if _active_trace is not None:
_active_trace.flush_trace()
def active_trace():
"""Returns the current global active trace of execution and memory usage."""
return _active_trace
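# Illustrative sketch (not part of this module's API): the expected lifecycle
# of the execution/memory trace. Assumes a default context exists and that
# eager ops run between the two calls; tracing is not thread-safe.
def _example_tracing_lifecycle():  # hypothetical helper, for illustration only
  enable_tracing()  # installs a MemoryTrace sized to the available devices
  # ... execute some eager operations here ...
  flush_trace()     # flushes whatever the active trace has recorded so far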

View File

@ -0,0 +1,488 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for core."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import execute
from tensorflow.python.eager import tensor
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
def truncated_normal(shape):
return execute.execute(
'TruncatedNormal',
1,
inputs=[shape],
attrs=('dtype', dtypes.float32.as_datatype_enum, 'T',
shape.dtype.as_datatype_enum, 'seed', 0, 'seed2', 0))[0]
class TFETest(test_util.TensorFlowTestCase):
def testContext(self):
ctx = context.Context()
self.assertFalse(ctx.in_graph_mode())
self.assertTrue(ctx.in_eager_mode())
self.assertEqual('', ctx.scope_name)
self.assertEqual(-1, ctx._device_index) # pylint: disable=protected-access
self.assertFalse(ctx.recording_summaries)
self.assertIsNone(ctx.summary_writer_resource)
del ctx
def testDefaultContext(self):
orig = context.get_default_context()
self.assertIs(context.get_default_context(), orig)
c0 = context.Context()
self.assertIs(context.get_default_context(), orig)
context_manager_0 = c0.as_default()
self.assertIs(context.get_default_context(), orig)
with context_manager_0 as c0:
self.assertIs(context.get_default_context(), c0)
with context.Context().as_default() as c1:
self.assertIs(context.get_default_context(), c1)
self.assertIs(context.get_default_context(), c0)
self.assertIs(context.get_default_context(), orig)
def testContextWithThreads(self):
def run_fn(ctx1):
ctx2 = context.get_default_context()
      # Default contexts created in different threads are different.
self.assertIsNot(ctx1, ctx2)
# Check that default values of the context created in a different thread
# are set correctly.
self.assertFalse(ctx2.in_graph_mode())
self.assertTrue(ctx2.in_eager_mode())
self.assertEqual('', ctx2.scope_name)
self.assertEqual(-1, ctx2._device_index) # pylint: disable=protected-access
self.assertFalse(ctx2.recording_summaries)
self.assertIsNone(ctx2.summary_writer_resource)
ctx1 = context.get_default_context()
t = threading.Thread(target=run_fn, args=(ctx1,))
t.start()
t.join()
def testScalarTensor(self):
t = tensor.Tensor(3)
self.assertEqual(t.numpy(), tensor.Tensor(np.array(3)).numpy())
self.assertEqual(dtypes.int32, t.dtype)
self.assertEqual(0, t.shape.ndims)
self.assertAllEqual([], t.shape.as_list())
def testTensorAndNumpyMatrix(self):
expected = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
actual = tensor.Tensor([[1.0, 2.0], [3.0, 4.0]])
self.assertAllEqual(expected, actual.numpy())
self.assertEqual(np.float32, actual.numpy().dtype)
self.assertEqual(dtypes.float32, actual.dtype)
self.assertAllEqual([2, 2], actual.shape.as_list())
def testFloatDowncast(self):
# Unless explicitly specified, float64->float32
t = tensor.Tensor(3.0)
self.assertEqual(dtypes.float32, t.dtype)
t = tensor.Tensor(3.0, dtype=dtypes.float64)
self.assertEqual(dtypes.float64, t.dtype)
def testBool(self):
t = tensor.Tensor(False)
if t:
self.assertFalse(True)
def testIntDowncast(self):
t = tensor.Tensor(3)
self.assertEqual(dtypes.int32, t.dtype)
t = tensor.Tensor(3, dtype=dtypes.int64)
self.assertEqual(dtypes.int64, t.dtype)
t = tensor.Tensor(2**33)
self.assertEqual(dtypes.int64, t.dtype)
def testTensorCreationFailure(self):
with self.assertRaises(Exception):
      # Should fail because each row of the Python object has a different
# number of columns.
self.assertEqual(None, tensor.Tensor([[1], [1, 2]]))
def testTensorPlacement(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
x = tensor.Tensor(1.).as_gpu_tensor()
with context.device('gpu:0'):
y = tensor.Tensor(2.)
# Add would fail if t2 were not on GPU
result = execute.execute(
'Add', 1, inputs=[x, y],
attrs=('T', x.dtype.as_datatype_enum))[0].as_cpu_tensor().numpy()
self.assertEqual(3, result)
def testNumpyOrderHandling(self):
n = np.array([[1, 2], [3, 4]], order='F')
t = tensor.Tensor(n)
self.assertAllEqual([[1, 2], [3, 4]], t.numpy())
def testCopyBetweenDevices(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
cpu = tensor.Tensor([[1., 2.], [3., 4.]])
c2g = cpu.as_gpu_tensor()
# Exercise a copy from GPU to CPU, even though we ignore the value.
_ = c2g.as_cpu_tensor()
with self.assertRaises(errors.InvalidArgumentError):
# c2g is on GPU. Copying between GPU devices fails
# (must redirect through CPU for now).
# TODO(ashankar): Perhaps the function should not fail and instead
      # facilitate the copy through host memory?
c2g.as_gpu_tensor()
# Invalid device
with self.assertRaises(errors.InvalidArgumentError):
cpu.as_gpu_tensor(context.context().num_gpus() + 1)
def testNumpyForceCPU(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
cpu = tensor.Tensor([[1., 2.], [3., 4.]])
c2g = cpu.as_gpu_tensor()
self.assertAllEqual(c2g.numpy(), cpu.numpy())
def testCopyFromCPUToCPU(self):
ta = tensor.Tensor([[1, 2], [3, 4]])
tb = ta.as_cpu_tensor()
self.assertNotEqual(ta._handle, tb._handle)
self.assertAllEqual(ta.numpy(), tb.numpy())
def testRegisterExceptionClass(self):
with self.assertRaises(TypeError):
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(str)
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(core._NotOkStatusException) # pylint: disable=protected-access
# TODO(agarwal): add tests passing incorrect typed values to attrs.
def testExecuteBasic(self):
three = tensor.Tensor(3)
five = tensor.Tensor(5)
product = execute.execute(
'Mul',
num_outputs=1,
inputs=[three, five],
attrs=('T', three.dtype.as_datatype_enum))[0]
self.assertEqual(15, product.numpy())
def testExecuteTooManyNumOutputs(self):
# num_outputs provided is 50, but only one output is produced.
# That should be okay.
product = execute.execute(
'Mul',
num_outputs=50,
inputs=[tensor.Tensor(3), tensor.Tensor(5)],
attrs=('T', dtypes.int32.as_datatype_enum))[0]
self.assertEqual(15, product.numpy())
def testMatMulGPU(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
three = tensor.Tensor([[3.]]).as_gpu_tensor()
five = tensor.Tensor([[5.]]).as_gpu_tensor()
product = execute.execute(
'MatMul',
num_outputs=1,
inputs=[three, five],
attrs=('transpose_a', False, 'transpose_b', False, 'T',
three.dtype.as_datatype_enum))[0]
self.assertEqual([[15.0]], product.numpy())
def testExecuteStringAttr(self):
checked_three = execute.execute(
'CheckNumerics',
num_outputs=1,
inputs=[tensor.Tensor(3.)],
attrs=('message', 'just checking', 'T',
dtypes.float32.as_datatype_enum))[0]
self.assertEqual([[3]], checked_three.numpy())
def testExecuteStringAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
_ = execute.execute(
'CheckNumerics',
num_outputs=1,
inputs=[tensor.Tensor(3.)],
attrs=('message', 1, 'T', dtypes.float32.as_datatype_enum))
def testExecuteFloatAttr(self):
almost_equal = execute.execute(
'ApproximateEqual',
num_outputs=1,
inputs=[tensor.Tensor(3.0), tensor.Tensor(2.9)],
attrs=('tolerance', 0.3, 'T', dtypes.float32.as_datatype_enum))[0]
self.assertTrue(almost_equal.numpy())
def testExecuteFloatAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
_ = execute.execute(
'ApproximateEqual',
num_outputs=1,
inputs=[tensor.Tensor(3.0), tensor.Tensor(2.9)],
attrs=('tolerance', '0.3', 'T', dtypes.float32.as_datatype_enum))
def testExecuteIntAttr(self):
total = execute.execute(
'AddN',
num_outputs=1,
inputs=[tensor.Tensor(3), tensor.Tensor(4)],
attrs=('T', dtypes.int32.as_datatype_enum, 'N', 2))[0]
self.assertEqual(7, total.numpy())
def testExecuteIntAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
_ = execute.execute(
'AddN',
num_outputs=1,
inputs=[tensor.Tensor(3), tensor.Tensor(4)],
attrs=('T', dtypes.int32.as_datatype_enum, 'N', '2'))
# Looks like we don't have an existing op with list(bool) attrs.
def testExecuteBoolAttr(self):
product = execute.execute(
'MatMul',
num_outputs=1,
inputs=[tensor.Tensor([[3]]), tensor.Tensor([[5]])],
attrs=('transpose_a', True, 'transpose_b', False, 'T',
dtypes.int32.as_datatype_enum))[0]
self.assertEqual([[15]], product.numpy())
def testExecuteShapeAttr(self):
execute.execute(
'VarHandleOp',
num_outputs=1,
inputs=[],
attrs=('shape', [1, 2], 'dtype', dtypes.int32.as_datatype_enum,
'container', '', 'shared_name', ''))
def testExecuteShapeAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'VarHandleOp',
num_outputs=1,
inputs=[],
attrs=('shape', 1, 'dtype', dtypes.int32.as_datatype_enum,
'container', '', 'shared_name', ''))
def testExecuteListStringAttr(self):
execute.execute(
'TensorSummary',
num_outputs=1,
inputs=[tensor.Tensor(3.0)],
attrs=('T', dtypes.float32.as_datatype_enum, 'description',
'tensor_summary', 'labels', ['3',
'summary'], 'display_name', 'test'))
def testExecuteListStringAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'TensorSummary',
num_outputs=1,
inputs=[tensor.Tensor(3.0)],
attrs=('T', dtypes.float32.as_datatype_enum, 'description', '',
'labels', 3, 'display_name', 'test'))
def testExecuteListStringAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'TensorSummary',
num_outputs=1,
inputs=[tensor.Tensor(3.0)],
attrs=('T', dtypes.float32.as_datatype_enum, 'description', '',
'labels', [3], 'display_name', 'test'))
def testExecuteListFloatAttr(self):
b = execute.execute(
'Bucketize',
num_outputs=1,
inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', [4.0,
6.0]))[0]
self.assertAllEqual([0, 1, 2], b.numpy())
def testExecuteListFloatAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Bucketize',
num_outputs=1,
inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', 4.0))
def testExecuteListFloatAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Bucketize',
num_outputs=1,
inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries',
['4.0', '6.0']))
def testExecuteListIntAttr(self):
b = execute.execute(
'Squeeze',
num_outputs=1,
inputs=[tensor.Tensor([[[3.0]]])],
attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', [0, 2]))[0]
self.assertAllEqual([3], b.numpy())
def testExecuteListIntAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Squeeze',
num_outputs=1,
inputs=[tensor.Tensor([[[3.0]]])],
attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', 0))
def testExecuteListIntAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Squeeze',
num_outputs=1,
inputs=[tensor.Tensor([[[3.0]]])],
attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims',
['0', '2']))
def testExecuteListTypeListShapeAttr(self):
execute.execute(
'Barrier',
num_outputs=1,
inputs=[],
attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
[[1, 2]], 'capacity', -1, 'container', '', 'shared_name', ''))
def testExecuteListTypeAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Barrier',
num_outputs=1,
inputs=[],
attrs=('component_types', dtypes.float64.as_datatype_enum, 'shapes',
[[1, 2]], 'capacity', -1, 'container', '', 'shared_name', ''))
def testExecuteListTypeAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Barrier',
num_outputs=1,
inputs=[],
attrs=('component_types', '1', 'shapes', [[1, 2]], 'capacity', -1,
'container', '', 'shared_name', ''))
def testExecuteListShapeAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Barrier',
num_outputs=1,
inputs=[],
attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
[1, 2], 'capacity', -1, 'container', '', 'shared_name', ''))
def testExecuteListShapeAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Barrier',
num_outputs=1,
inputs=[],
attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
[1], 'capacity', -1, 'container', '', 'shared_name', ''))
def testExecuteMultipleOutputs(self):
split_dim = 1
value = [[0, 1, 2], [3, 4, 5]]
x1, x2, x3 = execute.execute(
'Split',
num_outputs=3,
inputs=[tensor.Tensor(split_dim),
tensor.Tensor(value)],
attrs=('num_split', 3, 'T', dtypes.int32.as_datatype_enum))
self.assertAllEqual([[0], [3]], x1.numpy())
self.assertAllEqual([[1], [4]], x2.numpy())
self.assertAllEqual([[2], [5]], x3.numpy())
def testExecuteBadNumOutputsArgument(self):
with self.assertRaises(TypeError):
execute.execute(
'Relu', [],
inputs=[tensor.Tensor(3.0)],
attrs=('T', dtypes.float32.as_datatype_enum))
def testExecuteUnknownOp(self):
with self.assertRaises(errors.NotFoundError):
execute.execute('BlahBlahBlah', num_outputs=1, inputs=[], attrs=None)
def testExecuteUnknownAttr(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Identity',
num_outputs=1,
inputs=[tensor.Tensor(3)],
attrs=('T', dtypes.int32.as_datatype_enum, 'unknown_attr', 'blah'))
def testComposition(self):
def add(x, y):
return execute.execute(
'Add',
num_outputs=1,
inputs=[x, y],
attrs=('T', dtypes.int32.as_datatype_enum))[0]
x = tensor.Tensor(1)
three_x = add(add(x, x), x)
    self.assertEqual(dtypes.int32, three_x.dtype)
    self.assertEqual(3, three_x.numpy())
def testOperationWithNoInputsRunsOnDevice(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
shape = tensor.Tensor([], dtype=dtypes.int32)
# x: Run the "TruncatedNormal" op CPU and copy result to GPU.
x = truncated_normal(shape).as_gpu_tensor()
# y: Explicitly run the "TruncatedNormal" op on GPU.
with context.device('gpu:0'):
y = truncated_normal(shape)
# Add would fail if x and y were not on the same device.
execute.execute('Add', 1, inputs=[x, y],
attrs=('T', x.dtype.as_datatype_enum))
def testInvalidDevice(self):
with self.assertRaises(ValueError):
with context.device('pu:0'):
_ = tensor.Tensor(1)
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,70 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decorator to override the gradient for a function."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from autograd import core as ag_core
from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor as _tensor
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.util import nest
def _watch_value_from_tape(tensor):
for t in tape._tape_stack.stack: # pylint: disable=protected-access
w = t.value.tensors.get(tape.tensor_id(tensor), None)
if w is not None:
return w
return tensor
def custom_gradient(f):
"""Decorator to define a function with a custom gradient.
The input function is expected to return the tuple
(results, gradient_function)
The output function will return results while possibly recording the
gradient_function and inputs in the tape.
Args:
f: function to be decorated.
Returns:
decorated function.
"""
def decorated(*args, **kwargs):
"""Decorated function with custom gradient."""
input_tensors = [_watch_value_from_tape(x) for x in args
if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
or ag_core.isnode(x)]
result, grad_fn = f(*args, **kwargs)
flat_result = nest.flatten(result)
flat_result = [ag_core.getval(x) for x in flat_result]
flat_result = tape.record_operation(
flat_result,
input_tensors,
[],
grad_fn)
flat_result = list(flat_result)
return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)
return decorated
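# Illustrative sketch (hypothetical function, not part of this module): a
# decorated function must return the pair (results, gradient_function). The
# gradient function's signature is assumed here to take the output gradient
# and return a list of input gradients.
@custom_gradient
def _example_scaled_gradient_identity(x):
  def grad_fn(doutput):  # assumed signature, for illustration only
    return [2. * doutput]  # scale the incoming gradient by 2
  return x, grad_fn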

View File

@ -0,0 +1,241 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions called by the generated code to execute an eager-mode op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from autograd import core as ag_core
import six
from google.protobuf import text_format
from tensorflow.core.framework import tensor_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import compat
def execute(op_name, num_outputs, inputs, attrs=None, name=None):
"""Execute a TensorFlow operation.
Args:
op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
execute.
num_outputs: The number of outputs of the operation to fetch.
(Explicitly provided instead of being inferred for performance
reasons).
inputs: A list of inputs to the operation. Each entry should be a Tensor, or
a value which can be passed to the Tensor constructor to create one.
attrs: A tuple with alternating string attr names and attr values for this
operation.
name: Customized name for the operation.
Returns:
None if there are no outputs, a single Tensor object if there is one output
and a list of Tensor objects if there are multiple outputs.
Raises:
An exception on error.
"""
ctx = context.get_default_context()
# TODO(apassos) move this to convert_to_tensor
inputs = [ag_core.getval(x) for x in inputs]
# pylint: disable=protected-access
input_handles = [c._handle for c in inputs]
device_name = ctx.device_name
try:
outh = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
op_name, input_handles, attrs,
num_outputs)
# pylint: enable=protected-access
except core._NotOkStatusException as e: # pylint: disable=protected-access
raise core._status_to_exception(e.code, e.message) # pylint: disable=protected-access
# pylint: enable=protected-access
tensors = [tensor._tensor_from_handle(x) for x in outh] # pylint: disable=protected-access
if core.active_trace() is not None:
trace_name = name if name else op_name
for t in tensors:
# pylint: disable=protected-access
core.active_trace().record_tensor(trace_name,
tape.tensor_id(t),
t._device_name(),
t.shape.num_elements())
# pylint: enable=protected-access
return tensors
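# Illustrative sketch (not part of this module's API): calling execute()
# directly with the flat attrs tuple convention (alternating attr name and
# attr value), as the generated op code does. Assumes a usable default eager
# context.
def _example_execute_mul():  # hypothetical helper, for illustration only
  x = tensor.Tensor(3)
  y = tensor.Tensor(5)
  return execute(
      'Mul', num_outputs=1, inputs=[x, y],
      attrs=('T', x.dtype.as_datatype_enum))[0]  # a Tensor holding 15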
def record_gradient(unused_op_name, unused_inputs, unused_attrs, results,
unused_name):
"""Import backprop if you want gradients recorded."""
return results
def make_float(v, arg_name):
if not isinstance(v, compat.real_types):
raise TypeError("Expected float for argument '%s' not %s." %
(arg_name, repr(v)))
return float(v)
def make_int(v, arg_name):
if isinstance(v, six.string_types):
raise TypeError("Expected int for argument '%s' not %s." %
(arg_name, repr(v)))
try:
return int(v)
except (ValueError, TypeError):
raise TypeError("Expected int for argument '%s' not %s." %
(arg_name, repr(v)))
def make_str(v, arg_name):
if not isinstance(v, compat.bytes_or_text_types):
raise TypeError("Expected string for argument '%s' not %s." %
(arg_name, repr(v)))
return compat.as_bytes(v) # Convert unicode strings to bytes.
def make_bool(v, arg_name):
if not isinstance(v, bool):
raise TypeError("Expected bool for argument '%s' not %s." %
(arg_name, repr(v)))
return v
def make_type(v, arg_name):
try:
v = dtypes.as_dtype(v).base_dtype
except TypeError:
raise TypeError("Expected DataType for argument '%s' not %s." %
(arg_name, repr(v)))
i = v.as_datatype_enum
return i
def make_shape(v, arg_name):
"""Convert v into a list."""
# Args:
# v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
# arg_name: String, for error messages.
# Returns:
# None if the rank is unknown, otherwise a list of ints (or Nones in the
# position where the dimension is unknown).
try:
shape = tensor_shape.as_shape(v)
except TypeError as e:
raise TypeError("Error converting %s to a TensorShape: %s" % (arg_name, e))
except ValueError as e:
raise ValueError("Error converting %s to a TensorShape: %s" % (arg_name, e))
if shape.ndims is None:
return None
else:
return shape.as_list()
def make_tensor(v, arg_name):
"""Ensure v is a TensorProto."""
if isinstance(v, tensor_pb2.TensorProto):
return v
elif isinstance(v, six.string_types):
pb = tensor_pb2.TensorProto()
text_format.Merge(v, pb)
return pb
raise TypeError(
"Don't know how to convert %s to a TensorProto for argument '%s'" %
(repr(v), arg_name))
def args_to_matching_eager(l, default_dtype=None):
"""Convert sequence `l` to eager same-type Tensors."""
# TODO(josh11b): Could we do a better job if we also passed in the
# allowed dtypes when that was known?
# Is some input already a Tensor with a dtype?
dtype = None
for t in l:
if isinstance(ag_core.getval(t), tensor.Tensor):
dtype = t.dtype
break
if dtype is None:
# TODO(josh11b): At the moment, I don't think this can fail, but at some
# point we likely should have some logic to prevent bad conversions.
dtype = default_dtype
if dtype is None:
# Infer a dtype based on the first value, and use that dtype for the
# remaining values.
ret = []
for t in l:
ret.append(tensor.convert_to_eager_tensor(t, dtype))
if dtype is None:
dtype = ret[-1].dtype
else:
ret = [tensor.convert_to_eager_tensor(t, dtype) for t in l]
return dtype, ret
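# Illustrative sketch (not part of this module's API): mixed Python values
# adopt the dtype of the first Tensor in the list; Tensor(1.0) defaults to
# float32, so the plain Python floats below are converted to float32 as well.
def _example_args_to_matching_eager():  # hypothetical helper, for illustration
  dtype, tensors = args_to_matching_eager([tensor.Tensor(1.0), 2.0, 3.0])
  return dtype, tensors  # (dtypes.float32, three float32 Tensors)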
def convert_to_mixed_eager_tensors(values):
v = [t if isinstance(ag_core.getval(t), tensor.Tensor) else tensor.Tensor(t)
for t in values]
types = [t.dtype for t in v]
return types, v
def args_to_mixed_eager_tensors(lists):
"""Converts a list of same-length lists of values to eager tensors."""
assert len(lists) > 1
# Generate an error if len(lists[i]) is not the same for all i.
lists_ret = []
for l in lists[1:]:
if len(l) != len(lists[0]):
raise ValueError(
"Expected list arguments to be the same length: %d != %d (%r vs. %r)"
% (len(lists[0]), len(l), lists[0], l))
lists_ret.append([])
# Convert the first element of each list first, then the second element, etc.
types = []
for i in range(len(lists[0])):
dtype = None
# If any list has a Tensor, use that dtype
for l in lists:
if isinstance(ag_core.getval(l[i]), tensor.Tensor):
dtype = l[i].dtype
break
if dtype is None:
# Convert the first one and use its dtype.
lists_ret[0].append(tensor.convert_to_eager_tensor(lists[0][i]))
dtype = lists_ret[0][i].dtype
for j in range(1, len(lists)):
lists_ret[j].append(
tensor.convert_to_eager_tensor(lists[j][i], dtype=dtype))
else:
# Convert everything to the found dtype.
for j in range(len(lists)):
lists_ret[j].append(
tensor.convert_to_eager_tensor(lists[j][i], dtype=dtype))
types.append(dtype)
return types, lists_ret

View File

@ -0,0 +1,518 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Defun decorator for defining graph-mode functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import contextlib
import threading
from autograd import core as ag_core
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.util import nest
# Thread-local storage for tfe Tensors which are referenced while evaluating a
# graph-mode function.
_scoped_captures = threading.local()
# _scoped_captures.tensors is either None or a map from tfe.Tensor id to a pair
# of a tfe tensor and its corresponding placeholder to pass as a function
# argument. The value should be None unless we're in function definition
# context.
_scoped_captures.tensors = None
@contextlib.contextmanager
def capture_tensors(captures):
old = _scoped_captures.__dict__.get("tensors", None)
try:
_scoped_captures.tensors = captures
yield
finally:
_scoped_captures.tensors = old
def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
"""Captures a tfe Tensor while building a graph mode function.
Creates a placeholder to pass the tensor as an argument.
Arguments:
value: A tfe.Tensor object
dtype: The datatype of the value produced by the node in the graph.
name: Name of the node in the graph.
as_ref: Ignored (required by register_tensor_conversion_function).
Returns:
A placeholder which will, at runtime, have the value of this tensor.
Raises:
ValueError: if called outside a defun context.
"""
_ = as_ref
tensor_map = _scoped_captures.tensors
if tensor_map is None:
raise ValueError(
"Trying to use tfe.Tensor objects in a graph outside graph mode. "
"To build a graph use tfe.defun or tfe.func_to_object.")
captured_value = tensor_map.get(tape.tensor_id(value), None)
if captured_value is None:
captured_value = graph_placeholder(dtype=dtype or value.dtype,
shape=value.shape,
name=name)
if captured_value.dtype == dtypes.resource:
captured_value._handle_data = value._handle_data # pylint: disable=protected-access
tensor_map[tape.tensor_id(value)] = (value, captured_value)
else:
captured_value = captured_value[1]
return captured_value
# TODO(apassos): it'd be really nice if we could scope this registration.
ops.register_tensor_conversion_function(tensor.Tensor,
_convert_to_graph_constant)
class _CapturingContext(object):
"""Tracks references to Tensors outside this context while it is active."""
def __init__(self):
# known_ops are ops which are created while this context is active
self.known_ops = set()
# captured_tensors are all tensors referenced to by ops in this context but
# not produced in it
self.captured_tensors = set()
def AddOp(self, op): # pylint: disable=invalid-name
if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
raise ValueError("tfe.defun cannot capture variables created without "
"using tf.get_variable. Op: %s" % op)
self.known_ops.add(op)
for i in op.inputs:
if i.op not in self.known_ops:
self.captured_tensors.add(i)
def __enter__(self):
self._g = ops.get_default_graph()
self._old = self._g._get_control_flow_context() # pylint: disable=protected-access
self._g._set_control_flow_context(self) # pylint: disable=protected-access
def __exit__(self, _, __, ___): # pylint: disable=invalid-name
self._g._set_control_flow_context(self._old) # pylint: disable=protected-access
def _forward_name(n):
"""The name of a generated forward defun named n."""
return "__forward_%s_%s" % (n, core.uid())
def _backward_name(n):
"""The name of a generated backward defun named n."""
return "__backward_%s_%s" % (n, core.uid())
def _inference_name(n):
"""The name of a forward-but-no-gradient defun named n."""
return "__inference_%s_%s" % (n, core.uid())
class _DefinedFunction(object):
"""Mocks the interface of tf _DefinedFunction."""
def __init__(self, fdef):
self.definition = fdef
self.name = fdef.signature.name
self.grad_func_name = None
self.python_grad_func = None
def _map_sequence_obj_to_idx(sequence):
"""Maps objs in the sequence from id(obj) to sequence index."""
return {id(x): i for i, x in enumerate(sequence)}
class _GraphModeFunction(object):
"""Callable object representing a graph-mode function.
Args:
input_placeholders: list of placeholder values to feed when calling
the wrapped function.
    extra_inputs: Tensor inputs that this function definition closed over and
      which are passed as extra arguments when calling it. These need to be
      tracked so that gradients are computed correctly.
fdef: the function definition we want to call.
graph: the graph from which the fdef operations were pulled. Used as
a context when computing gradients.
operations: the subset of operations in the graph used in the function
definition.
func_outputs: the python outputs of the graph-mode function, with
tensorflow.Tensor objects to be replaced by tfe values when called.
func_outputs_to_fdef_outputs: Maps id(obj) in func_outputs to index of
      fdef's outputs. It allows mapping fdef output tensors to the nested
      func_outputs structure.
output_shapes: List of shapes of all tensors which are output by the
internal function.
"""
def __init__(self,
input_placeholders,
extra_inputs,
fdef,
graph,
operations,
func_outputs,
func_outputs_to_fdef_outputs,
output_shapes):
assert len(input_placeholders) == len(fdef.signature.input_arg), "%s %s" % (
len(input_placeholders), len(fdef.signature.input_arg))
self._input_placeholders = input_placeholders
self._extra_inputs = list(extra_inputs)
self._graph = graph
self._has_backprop = False
self._func_name = fdef.signature.name
self._fdef = _DefinedFunction(fdef)
self._num_outputs = len(fdef.signature.output_arg)
self._ops = operations
self._func_outputs = func_outputs
if (isinstance(func_outputs, (ops.Tensor, type(None)))
or ag_core.isnode(func_outputs)):
self._returns = [func_outputs]
else:
self._returns = list(func_outputs)
    self._returns_to_fdef_outputs = func_outputs_to_fdef_outputs
self._output_shapes = output_shapes
def _compute_backprop(self):
"""Computes the backprop function object for this function."""
self._has_backprop = True
with self._graph.as_default(), context.graph_mode():
c = _CapturingContext()
with c:
filtered_outputs = [ag_core.getval(x)
for x in self._returns if x is not None]
self._out_grad_placeholders = [
graph_placeholder(x.dtype, x.shape)
for x in filtered_outputs
]
in_gradients = gradients_impl.gradients(
filtered_outputs,
self._input_placeholders,
grad_ys=self._out_grad_placeholders)
shapes = [x.shape for x in in_gradients if x is not None]
captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
forward_function_def = graph_to_function_def.graph_to_function_def(
self._graph, self._ops,
self._input_placeholders,
filtered_outputs + captures)
self._forward_fdef = _DefinedFunction(forward_function_def)
_register_with_name(_forward_name(self._func_name),
forward_function_def)
backward_outputs = [x for x in in_gradients if x is not None]
all_inputs = self._out_grad_placeholders + captures
backward_function_def = graph_to_function_def.graph_to_function_def(
self._graph,
[x.op for x in self._out_grad_placeholders] +
list(sorted(c.known_ops, key=lambda x: x.name)),
all_inputs,
backward_outputs)
_register_with_name(_backward_name(self._func_name), backward_function_def)
self._backward_function = _GraphModeFunction(
all_inputs, [], backward_function_def, self._graph, c.known_ops,
in_gradients, _map_sequence_obj_to_idx(backward_outputs), shapes)
def _backprop_call(self, args):
"""Calls the wrapped function and records the result on a tape."""
all_args = args + self._extra_inputs
signature = self._forward_fdef.definition.signature
if context.in_graph_mode():
g = ops.get_default_graph()
g._add_function(self._forward_fdef) # pylint: disable=protected-access
unwrapped_args = [ag_core.getval(x) for x in all_args]
op = g.create_op(signature.name,
[ops.convert_to_tensor(x) for x in unwrapped_args],
[dtypes.DType(x.type) for x in signature.output_arg],
op_def=signature,
name="FunctionCall",
compute_shapes=False)
outputs = op.outputs
outputs = [outputs] if isinstance(
outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs)
for i, s in enumerate(self._output_shapes):
outputs[i].set_shape(s)
else:
outputs = execute.execute(
signature.name,
num_outputs=len(signature.output_arg),
inputs=all_args)
real_outputs = outputs[:len(self._returns)]
side_outputs = outputs[len(self._returns):]
watched_extra_inputs = []
    for t in self._extra_inputs:
      tid = tape.tensor_id(t)
      for tape_entry in tape._tape_stack.stack:  # pylint: disable=protected-access
        w = tape_entry.value.tensors.get(tid, None)
        if w is not None:
          watched_extra_inputs.append(w)
          break
      else:  # Note: for-else here done on purpose
        watched_extra_inputs.append(t)
real_outputs = tape.record_operation(real_outputs,
(args + watched_extra_inputs),
side_outputs,
self._backward_function)
return self._build_call_outputs(self._returns, real_outputs)
def __call__(self, *args):
"""Executes the passed function in eager mode."""
tensor_inputs = [x for x in nest.flatten(args)
if isinstance(x, (tensor.Tensor, ops.Tensor,
tensor.LazyZero))
or ag_core.isnode(x)]
if tape.should_record(tensor_inputs) or any(
tape.any_tape_has(t) for t in self._extra_inputs):
if not self._has_backprop:
self._compute_backprop()
return self._backprop_call(tensor_inputs)
if context.in_graph_mode():
g = ops.get_default_graph()
g._add_function(self._fdef) # pylint: disable=protected-access
signature = self._fdef.definition.signature
args = list(tensor_inputs) + self._extra_inputs
op = g.create_op(signature.name,
[ops.convert_to_tensor(x) for x in args],
[dtypes.DType(x.type) for x in signature.output_arg],
op_def=signature,
name="FunctionCall",
compute_shapes=False)
result = op.outputs
for i, s in enumerate(self._output_shapes):
result[i].set_shape(s)
else:
tensor_inputs = [x.tensor() if isinstance(x, tensor.LazyZero) else x
for x in tensor_inputs]
result = execute.execute(
self._func_name,
num_outputs=self._num_outputs,
inputs=tensor_inputs + self._extra_inputs)
return self._build_call_outputs(self._returns, result)
def _build_call_outputs(self, func_outputs, result):
"""Maps the fdef output list to actual output structure.
Args:
func_outputs: The outputs originally defined by the graph function. It
could potentially be a nested structure.
result: Output lists defined by FunctionDef.
Returns:
The actual call output.
"""
if self._func_outputs is None:
return None
if isinstance(ag_core.getval(self._func_outputs), ops.Tensor):
return result[0]
outputs = []
for o in func_outputs:
vo = ag_core.getval(o)
if isinstance(vo, ops.Tensor):
        outputs.append(result[self._returns_to_fdef_outputs[id(vo)]])
elif type(vo) in (tuple, list):
outputs.append(self._build_call_outputs(o, result))
else:
outputs.append(o)
return tuple(outputs) if type(func_outputs) is tuple else outputs
def _get_defun_inputs(args):
"""Maps the inputs args to graph inputs."""
ret = []
for a in args:
a = ag_core.getval(a)
if isinstance(a, (tensor.LazyZero, ops.Tensor, tensor.Tensor)):
ret.append(graph_placeholder(a.dtype, a.shape))
elif type(a) in (tuple, list):
ret.append(_get_defun_inputs(a))
else:
ret.append(a)
return tuple(ret) if type(args) is tuple else ret
def _defun_internal(name, func, args, kwds):
"""Defines and returns graph-mode version of func."""
with context.graph_mode():
tmp_graph = ops.Graph()
with tmp_graph.as_default():
func_inputs = _get_defun_inputs(args)
captures = {}
with capture_tensors(captures):
func_outputs = func(*func_inputs, **kwds)
ids = list(sorted(captures.keys()))
if ids:
extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
else:
extra_inputs = []
extra_placeholders = []
outputs_list = nest.flatten(func_outputs)
output_shapes = [x.shape for x in outputs_list if x is not None]
flat_inputs = [x for x in nest.flatten(func_inputs)
if isinstance(x, ops.Tensor)]
all_inputs = flat_inputs + list(extra_placeholders)
func_def_outputs = [ag_core.getval(x) for x in outputs_list if x is not None]
inference_function_def = graph_to_function_def.graph_to_function_def(
tmp_graph, tmp_graph.get_operations(),
all_inputs,
func_def_outputs)
# Register any other functions defined in the graph
# TODO(ashankar): Oh lord, forgive me for this lint travesty.
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
_register_with_name(f.name, f.definition)
_register_with_name(_inference_name(name), inference_function_def)
return _GraphModeFunction(
all_inputs,
extra_inputs,
inference_function_def,
tmp_graph,
tmp_graph.get_operations(),
func_outputs,
_map_sequence_obj_to_idx(func_def_outputs),
output_shapes)
# Defun uses this instead of Tensor as a cache key. Using dtype because
# TensorFlow graphs are not parametric wrt dtypes, and using shapes for
# performance reasons, as much TensorFlow code specializes on known shapes to
# produce slimmer graphs.
_TensorDtype = collections.namedtuple("_TensorDtype", ["dtype", "shape"])
_ZeroDtype = collections.namedtuple("_ZeroDtype", ["dtype", "shape"])
def _cache_key(x):
"""Cache key for tfe functions."""
x = ag_core.getval(x)
if isinstance(x, tensor.Tensor):
return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access
if isinstance(x, tensor.LazyZero):
return _TensorDtype(x.dtype, tuple(x.shape.as_list())) # pylint: disable=protected-access
if isinstance(x, np.ndarray):
return ("array", x.shape, tuple(x.reshape(-1)))
if type(x) in (list, tuple):
return tuple([_cache_key(a) for a in x])
return x
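# For illustration (assuming a hypothetical float32 2x2 tfe.Tensor t): a
# Tensor is keyed only by (dtype, shape), e.g.
#   _cache_key(t)                -> _TensorDtype(tf.float32, (2, 2))
# while a numpy array is keyed by shape *and* contents, e.g.
#   _cache_key(np.zeros((2, 2))) -> ("array", (2, 2), (0.0, 0.0, 0.0, 0.0))
# so same-shaped Tensor arguments reuse one traced graph function, whereas
# ndarrays with different values are treated as distinct constants.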
def register_function_def(fdef):
fdef_string = fdef.SerializeToString()
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.TFE_ContextAddFunctionDef(
context.get_default_context()._handle, # pylint: disable=protected-access
fdef_string,
len(fdef_string),
status)
def _register_with_name(name, fdef):
"""Registers the function `fdef` with the name `name`."""
fdef.signature.name = name
register_function_def(fdef)
# TODO(apassos): better error messages for non-hashable arguments.
def named_defun(func, name):
"""Defines a function with a given name.
See the documentation for `defun` for more information on the semantics of the
function.
Args:
func: the function to be wrapped.
name: the name given to it.
Returns:
the wrapped function.
"""
arguments_to_functions = {}
def decorated(*args, **kwds):
"""Decorated version of func."""
# Macroexpand on non-Tensor arguments
cache_key = tuple(_cache_key(x) for x in args)
assert all(not isinstance(x, tensor.Tensor) for x in kwds.values())
cache_key = (cache_key, tuple(kwds.items()))
if cache_key not in arguments_to_functions:
arguments_to_functions[cache_key] = _defun_internal(
name, func, args, kwds)
return arguments_to_functions[cache_key](*args)
return decorated
def defun(func):
"""Decorator to compile func into graph_mode.
defun converts a function that constructs a TensorFlow graph into a function
that executes the graph. TensorFlow graphs typically execute faster and with a
lower memory-footprint than executing each of the operations that make up the
function individually as the TensorFlow runtime can optimize the graph and
execute sub-operations in parallel.
func must be a Python function that constructs a TensorFlow graph,
typically using functions in the tensorflow module.
Arguments to func can be either tfe.Tensor objects or Python
objects. Non-Tensor python objects are treated as constants, and new function
definitions are created internally based on their values.
func must return a tf.Tensor (NOT a tfe.Tensor) or a list of tf.Tensor (NOT a
tfe.Tensor). TODO(apassos) make the wrapped tfe ops return tf.Tensors when in
graph mode.
TODO(apassos): deal with captured global state. Deal with control flow.
Args:
func: function to be compiled.
Returns:
A callable that will execute the compiled function (and return zero
or more tfe.Tensor objects)
"""
return named_defun(func, func.__name__)
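As a usage sketch (hypothetical example; it assumes eager execution has been
enabled elsewhere and that standard ops are available via
`import tensorflow as tf`, neither of which is set up in this file):

import tensorflow as tf

@defun
def dense(x, w, b):
  # Traced once per distinct input dtype/shape signature, then executed as a
  # graph function; later calls with the same signature reuse the cache that
  # named_defun maintains.
  return tf.nn.relu(tf.matmul(x, w) + b)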

View File

@@ -0,0 +1,45 @@
"""For eager-mode Python."""
load("//tensorflow:tensorflow.bzl", "clean_dep", "tf_copts")
def tfe_gen_op_wrapper_py(name,
out=None,
visibility=None,
deps=[],
generated_target_name=None):
"""Generate an eager-mode Python op wrapper for an op library."""
# Construct a cc_binary containing the specified ops.
tool_name = "gen_" + name + "_py_wrappers_cc"
if not deps:
deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))]
native.cc_binary(
name=tool_name,
linkopts=["-lm"],
copts=tf_copts(),
linkstatic=1,
deps=([
clean_dep("//tensorflow/python/eager:python_eager_op_gen_main")
] + deps),
visibility=[clean_dep("//visibility:public")],)
# Invoke the previous cc_binary to generate a python file.
if not out:
out = "gen_" + name + ".py"
native.genrule(
name=name + "_pygenrule",
outs=[out],
tools=[tool_name],
cmd=("$(location " + tool_name + ") > $@"))
# Make a py_library out of the generated python file.
if not generated_target_name:
generated_target_name = name
native.py_library(
name=generated_target_name,
srcs=[out],
srcs_version="PY2AND3",
visibility=visibility,
deps=[
clean_dep("//tensorflow/python/eager:framework_for_generated_wrappers"),
],)

View File

@@ -0,0 +1,50 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Graph-only versions of a few op functions, for internal use only."""
# Must be separate from array_ops to avoid a cyclic dependency.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import ops
def graph_zeros_like(tensor):
"""Graph-only version of tf.zeros_like(), for internal use only."""
g = ops._get_graph_from_inputs([tensor]) # pylint: disable=protected-access
with g.as_default(), ops.name_scope(None, "zeros_like", [tensor]) as name:
tensor = ops.convert_to_tensor(tensor, name="tensor")
dtype = tensor.dtype.base_dtype.as_datatype_enum
dtype_value = attr_value_pb2.AttrValue(type=dtype)
op = g.create_op("ZerosLike", [tensor], [dtype], input_types=[dtype],
attrs={"T": dtype_value}, name=name)
result, = op.outputs
return result
def graph_placeholder(dtype, shape, name=None):
"""Graph-only version of tf.placeholder(), for internal use only."""
dtype = dtype.base_dtype.as_datatype_enum
dtype_value = attr_value_pb2.AttrValue(type=dtype)
shape = attr_value_pb2.AttrValue(shape=shape.as_proto())
g = ops.get_default_graph()
with ops.name_scope(name, "placeholder", []) as name:
op = g.create_op("Placeholder", [], [dtype], input_types=[],
attrs={"dtype": dtype_value, "shape": shape}, name=name)
result, = op.outputs
return result
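A minimal usage sketch of these graph-only helpers (illustrative only; the
shape and name below are made up):

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape

with ops.Graph().as_default():
  x = graph_placeholder(dtypes.float32, tensor_shape.TensorShape([2, 3]),
                        name="x")
  z = graph_zeros_like(x)  # a ZerosLike op with the same dtype as x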

View File

@@ -0,0 +1,88 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility to trace per-device memory consumption across time over execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
TraceEntry = collections.namedtuple(
"TraceEntry", ["op_name", "tensor_id", "mem_usage", "device", "size"])
TensorData = collections.namedtuple(
"TensorData", ["op_name", "tensor_size", "device"])
class MemoryTrace(object):
"""Records a trace of memory usage over operation execution."""
def __init__(self, n_devices):
self.trace = []
self.tensor_to_data = {}
self.current_device_mem_usage = [0] * n_devices
def record_tensor(self, op_name, tensor_id, device, size):
self.current_device_mem_usage[device] += size
self.tensor_to_data[tensor_id] = TensorData(op_name, size, device)
self.trace.append(TraceEntry(op_name,
tensor_id,
self.current_device_mem_usage[:],
device,
size))
def delete_tensor(self, tensor_id):
if tensor_id not in self.tensor_to_data:
return
data = self.tensor_to_data.pop(tensor_id)
self.current_device_mem_usage[data.device] -= data.tensor_size
self.trace.append(TraceEntry(data.op_name,
tensor_id,
self.current_device_mem_usage[:],
data.device,
-data.tensor_size))
def flush_trace(self):
"""Prints the formatted trace recorded so far."""
longest_op_name = max(len(t.op_name) for t in self.trace)
longest_op_name = max(longest_op_name, len("op_name"))
longest_heap_size = max(max(len(str(d)) for d in t.mem_usage)
for t in self.trace)
longest_heap_size = max(longest_heap_size, len("d0"))
longest_id_len = max(len(str(t.tensor_id)) for t in self.trace)
longest_id_len = max(longest_id_len, 2)
first_line = []
first_line.append("+/-")
first_line.append("op_name".ljust(longest_op_name))
first_line.append("id".ljust(longest_id_len))
for i in range(len(self.current_device_mem_usage)):
first_line.append(("d"+str(i)).ljust(longest_heap_size))
first_line.append("size")
print(" | ".join(first_line))
for t in self.trace:
line = []
if t.size > 0:
line.append("+ ")
else:
line.append("- ")
line.append(t.op_name.ljust(longest_op_name))
line.append(str(t.tensor_id).ljust(longest_id_len))
for d in t.mem_usage:
line.append(str(d).ljust(longest_heap_size))
line.append(str(t.size))
print(" | ".join(line))
self.trace = []
print()
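A short usage sketch of the tracer (hypothetical op names and sizes; in
practice the eager runtime records and deletes tensors as they are allocated
and freed):

trace = MemoryTrace(n_devices=2)
trace.record_tensor("MatMul", tensor_id=1, device=0, size=4096)
trace.record_tensor("Add", tensor_id=2, device=1, size=1024)
trace.delete_tensor(1)
trace.flush_trace()  # prints one row per event with running per-device usage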

View File

@@ -0,0 +1,763 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/eager/python_eager_op_gen.h"
#include <stdio.h>
#include <sstream>
#include <unordered_map>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/framework/python_op_gen_internal.h"
namespace tensorflow {
namespace {
const int kRightMargin = 78;
string AttrVarName(const string& attr_name,
std::unordered_map<string, string>* attr_expressions) {
const string var = strings::StrCat("_attr_", attr_name);
if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var;
return var;
}
void AddInferredAttr(const string& attr_name, const string& value_expression,
string* result,
std::unordered_map<string, string>* attr_expressions) {
strings::StrAppend(result, " ", AttrVarName(attr_name, attr_expressions),
" = ", value_expression, "\n");
}
string VectorToTuple(const std::vector<string>& l) {
if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
string ret = "(";
for (int i = 0; i < l.size(); ++i) {
if (i > 0) {
strings::StrAppend(&ret, ", ");
}
strings::StrAppend(&ret, l[i]);
}
strings::StrAppend(&ret, ")");
return ret;
}
void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
const string& var, string* result) {
for (int i = 0; i < output_sizes.size(); ++i) {
if (!output_sizes[i].empty()) {
strings::StrAppend(result, prefix, var, " = ");
if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
if (i + 1 < output_sizes.size()) {
// Special case i == 0 to avoid "0 +" in the generated code.
if (i == 0) {
strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
var, "[", output_sizes[i], ":]");
} else {
strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
output_sizes[i], "]] + ", var, "[", i, " + ",
output_sizes[i], ":]");
}
} else {
strings::StrAppend(result, "[", var, "[", i, ":]]");
}
strings::StrAppend(result, "\n");
}
}
}
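// For concreteness: given a hypothetical op with two list outputs whose
// lengths live in _attr_N and _attr_M (output_sizes == {"_attr_N", "_attr_M"}),
// the Python emitted by this helper is roughly
//   _result = [_result[:_attr_N]] + _result[_attr_N:]
//   _result = _result[:1] + [_result[1:]]
// so a flat result [0, 1, 2, 3, 4] with _attr_N = 2 and _attr_M = 3 regroups
// into [[0, 1], [2, 3, 4]].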
string TensorPBString(const TensorProto& pb) {
// Note: This gets used in the argument list, and so must survive naive
// word wrapping.
return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
}
class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
public:
GenEagerPythonOp(const OpDef& op_def, const string& function_name)
: python_op_gen_internal::GenPythonOp(op_def, function_name) {
op_name_ = function_name_;
op_name_.Consume("_");
}
~GenEagerPythonOp() override {}
string Code() override;
protected:
void ExpectListArg(const string& arg_name);
void AddEagerInferredAttrs();
void AddEagerInputCasts();
void AddEagerAttrs();
void AddEagerExecute(const string& num_outputs_expr);
void AddAttrForArg(const string& attr, int arg_index) {
gtl::InsertIfNotPresent(&inferred_attrs_, attr,
op_def_.input_arg(arg_index).name());
auto iter = attr_to_args_.find(attr);
if (iter == attr_to_args_.end()) {
attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
} else {
iter->second.push_back(arg_index);
}
}
// Returns a string expression representing a flattened list of all
// the inputs given by `*input_indices` (or all inputs if
// `input_indices` is nullptr). `*output_sizes` can be used to unflatten.
string FlattenInputs(const std::vector<int>* input_indices,
std::vector<string>* output_sizes) const;
StringPiece op_name_;
typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
AttrToArgMap attr_to_args_;
std::unordered_map<string, string> attr_expressions_;
};
string GetEagerPythonOp(const OpDef& op_def, const string& function_name) {
return GenEagerPythonOp(op_def, function_name).Code();
}
string GenEagerPythonOp::FlattenInputs(
const std::vector<int>* input_indices,
std::vector<string>* output_sizes) const {
string inputs;
enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
const int n = input_indices != nullptr ? input_indices->size()
: op_def_.input_arg_size();
for (int j = 0; j < n; ++j) {
const int i = input_indices ? (*input_indices)[j] : j;
const auto& arg(op_def_.input_arg(i));
const bool is_list =
!arg.type_list_attr().empty() || !arg.number_attr().empty();
if (is_list) {
if (inputs_state == WAS_SOLO_INPUT) {
strings::StrAppend(&inputs, "] + ");
} else if (inputs_state == WAS_LIST_INPUT) {
strings::StrAppend(&inputs, " + ");
}
strings::StrAppend(&inputs, "list(", param_names_[i], ")");
inputs_state = WAS_LIST_INPUT;
if (output_sizes != nullptr) {
if (!arg.number_attr().empty()) {
output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
} else {
output_sizes->emplace_back(
strings::StrCat("len(", param_names_[i], ")"));
}
}
} else {
if (inputs_state == WAS_SOLO_INPUT) {
strings::StrAppend(&inputs, ", ");
} else if (inputs_state == WAS_LIST_INPUT) {
strings::StrAppend(&inputs, " + [");
} else {
strings::StrAppend(&inputs, "[");
}
strings::StrAppend(&inputs, param_names_[i]);
inputs_state = WAS_SOLO_INPUT;
if (output_sizes != nullptr) output_sizes->emplace_back();
}
}
if (inputs_state == STARTING) return "[]";
if (inputs_state == WAS_SOLO_INPUT) {
strings::StrAppend(&inputs, "]");
}
return inputs;
}
string GenEagerPythonOp::Code() {
// This has all the input args followed by those attrs that don't have
// defaults.
std::vector<string> args_no_default;
// The parameters with defaults (these have to be listed after those without).
// No input args are included, just attrs.
std::vector<std::pair<string, string>> args_with_defaults;
for (int i = 0; i < op_def_.input_arg_size(); ++i) {
const auto& arg(op_def_.input_arg(i));
args_no_default.push_back(arg.name());
if (!arg.type_attr().empty()) {
AddAttrForArg(arg.type_attr(), i);
} else if (!arg.type_list_attr().empty()) {
AddAttrForArg(arg.type_list_attr(), i);
}
if (!arg.number_attr().empty()) {
AddAttrForArg(arg.number_attr(), i);
}
}
for (int i = 0; i < op_def_.attr_size(); ++i) {
const auto& attr(op_def_.attr(i));
// Do not add inferred attrs to the Python function signature.
if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
if (attr.has_default_value()) {
if (attr.type() == "tensor") {
args_with_defaults.emplace_back(
attr.name(),
strings::StrCat("_execute.make_tensor(",
TensorPBString(attr.default_value().tensor()),
", \"", attr.name(), "\")"));
} else if (attr.type() == "list(tensor)") {
std::vector<string> pbtxt;
for (const auto& pb : attr.default_value().list().tensor()) {
pbtxt.emplace_back(TensorPBString(pb));
}
args_with_defaults.emplace_back(
attr.name(),
strings::StrCat("[_execute.make_tensor(_pb, \"", attr.name(),
"\") for _pb in ", VectorToTuple(pbtxt), "]"));
} else {
args_with_defaults.emplace_back(
attr.name(), python_op_gen_internal::AttrValueToPython(
attr.type(), attr.default_value(), "_dtypes."));
}
} else {
args_no_default.push_back(attr.name());
}
}
}
// Save the list of attr parameters (attrs that won't be inferred),
// those with defaults go at the end.
  // Get the attrs in the order we want by taking the attrs without defaults
  // from the end of args_no_default, and then appending args_with_defaults.
attrs_.reserve(args_no_default.size() - op_def_.input_arg_size() +
args_with_defaults.size());
attrs_.insert(attrs_.end(),
args_no_default.begin() + op_def_.input_arg_size(),
args_no_default.end());
for (const auto& a : args_with_defaults) {
attrs_.push_back(a.first);
}
param_names_.reserve(args_no_default.size() + args_with_defaults.size());
string parameters;
for (const string& name : args_no_default) {
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
const string param = python_op_gen_internal::AvoidPythonReserved(name);
strings::StrAppend(&parameters, param);
param_names_.push_back(param);
}
for (const auto& name_default : args_with_defaults) {
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
const string param =
python_op_gen_internal::AvoidPythonReserved(name_default.first);
strings::StrAppend(&parameters, param, "=", name_default.second);
param_names_.push_back(param);
}
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
strings::StrAppend(&parameters, "name=None");
AddDefLine(parameters);
AddDocStringDescription();
AddDocStringArgs();
AddDocStringInputs();
AddDocStringAttrs();
strings::StrAppend(
&result_,
" name: A name for the operation (optional, only for graph mode).\n");
AddOutputGlobals();
AddDocStringOutputs();
strings::StrAppend(&result_, " \"\"\"\n");
// Function body.
// Validate list inputs, infer length attrs.
for (int i = 0; i < op_def_.attr_size(); ++i) {
const auto& attr(op_def_.attr(i));
if (attr.type() == "int") {
auto arg_list = attr_to_args_.find(attr.name());
if (arg_list != attr_to_args_.end()) {
// Inferred int attrs are the lengths of inputs. Validate those
// inputs are lists and have the same length.
for (auto iter = arg_list->second.begin();
iter != arg_list->second.end(); ++iter) {
const string& arg_name = param_names_[*iter];
ExpectListArg(arg_name);
if (iter == arg_list->second.begin()) {
AddInferredAttr(attr.name(), strings::StrCat("len(", arg_name, ")"),
&result_, &attr_expressions_);
} else {
const auto& attr_var = attr_expressions_[attr.name()];
strings::StrAppend(&result_, " if len(", arg_name,
") != ", attr_var,
":\n"
" raise ValueError(\n"
" \"List argument '",
arg_name, "' to '", op_name_,
"' Op with length %d \"\n"
" \"must match length %d of argument '",
inferred_attrs_[attr.name()],
"'.\" %\n"
" (len(",
arg_name, "), ", attr_var, "))\n");
}
}
}
}
}
// Values for non-inferred attrs.
for (int i = 0; i < attrs_.size(); ++i) {
const string& attr_name = attrs_[i];
const string& param = param_names_[i + op_def_.input_arg_size()];
const auto& attr = *FindAttr(attr_name, op_def_);
StringPiece attr_type = attr.type();
attr_expressions_[attr_name] = param;
const int default_index = i - (attrs_.size() - args_with_defaults.size());
if (default_index >= 0) {
const string& default_value = args_with_defaults[default_index].second;
strings::StrAppend(&result_, " if ", param, " is None:\n");
strings::StrAppend(&result_, " ", param, " = ", default_value, "\n");
}
if (attr_type.starts_with("list(")) {
ExpectListArg(param);
}
if (attr_type == "string") {
strings::StrAppend(&result_, " ", param, " = _execute.make_str(", param,
", \"", param, "\")\n");
} else if (attr_type == "list(string)") {
strings::StrAppend(&result_, " ", param, " = [_execute.make_str(_s, \"",
param, "\") for _s in ", param, "]\n");
} else if (attr_type == "int") {
strings::StrAppend(&result_, " ", param, " = _execute.make_int(", param,
", \"", param, "\")\n");
} else if (attr_type == "list(int)") {
strings::StrAppend(&result_, " ", param, " = [_execute.make_int(_i, \"",
param, "\") for _i in ", param, "]\n");
} else if (attr_type == "float") {
strings::StrAppend(&result_, " ", param, " = _execute.make_float(",
param, ", \"", param, "\")\n");
} else if (attr_type == "list(float)") {
strings::StrAppend(&result_, " ", param,
" = [_execute.make_float(_f, \"", param,
"\") for _f in ", param, "]\n");
} else if (attr_type == "bool") {
strings::StrAppend(&result_, " ", param, " = _execute.make_bool(", param,
", \"", param, "\")\n");
} else if (attr_type == "list(bool)") {
strings::StrAppend(&result_, " ", param, " = [_execute.make_bool(_b, \"",
param, "\") for _b in ", param, "]\n");
} else if (attr_type == "type") {
strings::StrAppend(&result_, " ", param, " = _execute.make_type(", param,
", \"", param, "\")\n");
} else if (attr_type == "list(type)") {
strings::StrAppend(&result_, " ", param, " = [_execute.make_type(_t, \"",
param, "\") for _t in ", param, "]\n");
} else if (attr_type == "shape") {
strings::StrAppend(&result_, " ", param, " = _execute.make_shape(",
param, ", \"", param, "\")\n");
} else if (attr_type == "list(shape)") {
strings::StrAppend(&result_, " ", param,
" = [_execute.make_shape(_s, \"", param,
"\") for _s in ", param, "]\n");
} else if (attr_type == "tensor") {
strings::StrAppend(&result_, " ", param, " = _execute.make_tensor(",
param, ", \"", param, "\")\n");
} else if (attr_type == "list(tensor)") {
strings::StrAppend(&result_, " ", param,
" = [_execute.make_tensor(_t, \"", param,
"\") for _t in ", param, "]\n");
} else if (attr_type != "func") {
return strings::StrCat("# No definition for ", function_name_,
" since we don't support attrs with type\n"
"# '",
attr_type, "' right now.\n\n");
}
}
// Figure out the list of inputs.
const string inputs = FlattenInputs(nullptr, nullptr);
// Handle graph-mode case
strings::StrAppend(&result_,
" if _context.in_graph_mode():\n"
" _, _, _op = _op_def_lib._apply_op_helper(\n");
AddBodyNoReturn(" ");
if (num_outs_ > 0) {
strings::StrAppend(&result_, " _result = _op.outputs[:]\n");
// Special case handling for stateful op with single list output
// that might be empty.
if (num_outs_ == 1 && op_def_.is_stateful() &&
(!op_def_.output_arg(0).number_attr().empty() ||
!op_def_.output_arg(0).type_list_attr().empty())) {
// TODO(josh11b): Can skip this if the number_attr/type_list_attr has
// a constraint indicating that this can never be empty.
strings::StrAppend(&result_,
" if not _result:\n"
" return _op\n");
}
strings::StrAppend(&result_, " _inputs_flat = ", inputs, "\n");
// Compute graph-mode attrs.
if (op_def_.attr_size() > 0) {
string attr_values;
for (int i = 0; i < op_def_.attr_size(); ++i) {
if (i > 0) strings::StrAppend(&attr_values, ", ");
const auto& attr_name(op_def_.attr(i).name());
strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"",
attr_name, "\")");
}
strings::StrAppend(&attr_values, ")");
strings::StrAppend(&result_,
WordWrap(" _attrs = (", attr_values, kRightMargin),
"\n");
} else {
strings::StrAppend(&result_, " _attrs = None\n");
}
} else {
strings::StrAppend(&result_, " return _op\n");
}
// Handle eager-mode case
strings::StrAppend(&result_, " else:\n");
// Expression representing the number of outputs.
int num_fixed_outputs = 0;
string num_outputs_expr;
// If output i is list output, output_sizes[i] will be set to a
// string with the python expression that will evaluate to its
// length. output_sizes[i] is empty for non-list outputs.
std::vector<string> output_sizes(num_outs_);
for (int i = 0; i < num_outs_; ++i) {
const auto& arg(op_def_.output_arg(i));
if (!arg.number_attr().empty()) {
if (!num_outputs_expr.empty()) {
strings::StrAppend(&num_outputs_expr, " + ");
}
output_sizes[i] = attr_expressions_[arg.number_attr()];
strings::StrAppend(&num_outputs_expr, output_sizes[i]);
} else if (!arg.type_list_attr().empty()) {
if (!num_outputs_expr.empty()) {
strings::StrAppend(&num_outputs_expr, " + ");
}
// Have to be careful to use an expression that works in both
// graph and eager paths here.
const auto iter = inferred_attrs_.find(arg.type_list_attr());
if (iter == inferred_attrs_.end()) {
output_sizes[i] = strings::StrCat(
"len(", attr_expressions_[arg.type_list_attr()], ")");
} else {
output_sizes[i] = strings::StrCat("len(", iter->second, ")");
}
strings::StrAppend(&num_outputs_expr, output_sizes[i]);
} else {
++num_fixed_outputs;
}
}
if (num_fixed_outputs > 0) {
if (!num_outputs_expr.empty()) {
strings::StrAppend(&num_outputs_expr, " + ");
}
strings::StrAppend(&num_outputs_expr, num_fixed_outputs);
} else if (num_outputs_expr.empty()) {
num_outputs_expr = "0";
}
bool eager_allowed = true;
for (const auto& arg : op_def_.input_arg()) {
if (arg.is_ref()) eager_allowed = false;
}
for (const auto& arg : op_def_.output_arg()) {
if (arg.is_ref()) eager_allowed = false;
}
if (eager_allowed) {
AddEagerInferredAttrs();
AddEagerInputCasts();
strings::StrAppend(&result_, " _inputs_flat = ", inputs, "\n");
AddEagerAttrs();
AddEagerExecute(num_outputs_expr);
} else {
strings::StrAppend(&result_,
" raise RuntimeError(\n"
" \"",
op_name_, " op does not support eager execution.\")\n");
}
if (num_outs_ > 0) {
strings::StrAppend(&result_, " _result = _execute.record_gradient(\n",
" \"", op_def_.name(),
"\", _inputs_flat, _attrs, _result, name)\n");
if (num_outs_ == 1 && !output_sizes[0].empty()) {
// Single list result.
} else if (num_outs_ == 1) {
// Execute returns a single-element list which we need to destructure.
strings::StrAppend(&result_, " _result, = _result\n");
} else {
// Have multiple outputs, so we will need to reformat the return
// value of execute() to be a list with one entry per op output
// (that entry will be a list of tensors if that output is of list
// type).
// For list outputs, convert the right subrange of _result into a list.
Unflatten(" ", output_sizes, "_result", &result_);
// Convert to a named tuple.
strings::StrAppend(&result_, " _result = _", op_def_.name(),
"Output._make(_result)\n");
}
}
strings::StrAppend(&result_, " return _result\n\n");
return prelude_ + result_;
}
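// Taken together, the wrapper assembled by Code() has roughly this shape for
// a hypothetical single-output op Foo(x: T) -> y: T (a hand-written sketch;
// the exact _apply_op_helper argument text comes from AddBodyNoReturn, which
// is defined in python_op_gen_internal):
//
//   def foo(x, name=None):
//     if _context.in_graph_mode():
//       _, _, _op = _op_def_lib._apply_op_helper("Foo", x=x, name=name)
//       _result = _op.outputs[:]
//       _inputs_flat = [x]
//       _attrs = ("T", _op.get_attr("T"))
//     else:
//       _attr_T, (x,) = _execute.args_to_matching_eager([x])
//       _attr_T = _attr_T.as_datatype_enum
//       _inputs_flat = [x]
//       _attrs = ("T", _attr_T)
//       _result = _execute.execute("Foo", 1, inputs=_inputs_flat,
//                                  attrs=_attrs, name=name)
//     _result = _execute.record_gradient(
//         "Foo", _inputs_flat, _attrs, _result, name)
//     _result, = _result
//     return _result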
void GenEagerPythonOp::ExpectListArg(const string& arg_name) {
strings::StrAppend(&result_, " if not isinstance(", arg_name,
", (list, tuple)):\n"
" raise TypeError(\n"
" \"Expected list for '",
arg_name,
"' argument to \"\n"
" \"'",
op_name_, "' Op, not %r.\" % ", arg_name, ")\n");
}
void GenEagerPythonOp::AddEagerInferredAttrs() {
// Figure out values for inferred attrs, and cast to eager tensors.
for (int i = 0; i < op_def_.attr_size(); ++i) {
const auto& attr(op_def_.attr(i));
auto arg_list = attr_to_args_.find(attr.name());
if (arg_list != attr_to_args_.end()) {
if (attr.type() == "type") {
std::vector<string> output_sizes;
const string flattened =
FlattenInputs(&arg_list->second, &output_sizes);
string conversion =
strings::StrCat("_execute.args_to_matching_eager(", flattened);
if (attr.has_default_value()) {
strings::StrAppend(
&conversion, ", ",
python_op_gen_internal::AttrValueToPython(
attr.type(), attr.default_value(), "_dtypes."));
}
strings::StrAppend(&conversion, ")");
const string var_name = AttrVarName(attr.name(), &attr_expressions_);
if (output_sizes.size() == 1) {
// Avoid creating a temporary variable in the case where
// we can easily assign to the right value directly.
const string inputs_var = param_names_[arg_list->second.front()];
if (output_sizes.front().empty()) {
strings::StrAppend(&result_, " ", var_name, ", (", inputs_var,
",) = ", conversion, "\n");
} else {
strings::StrAppend(&result_, " ", var_name, ", ", inputs_var,
" = ", conversion, "\n");
}
} else {
const string inputs_var = strings::StrCat("_inputs_", attr.name());
strings::StrAppend(&result_, " ", var_name, ", ", inputs_var,
" = ", conversion, "\n");
// Convert from a flat list of eager tensors back to the
// parameter variables.
Unflatten(" ", output_sizes, inputs_var, &result_);
std::vector<string> p;
for (int j : arg_list->second) {
p.emplace_back(param_names_[j]);
}
strings::StrAppend(&result_, " ", VectorToTuple(p), " = ",
inputs_var, "\n");
}
strings::StrAppend(&result_, " ", var_name, " = ", var_name,
".as_datatype_enum\n");
} else if (attr.type() == "list(type)") {
// NOTE: We ignore default values for these attrs, since it is
// unclear how you would use it, and the one use case is
// parse_single_sequence_example which only needs it for
// backwards compatibility.
const string var_name = AttrVarName(attr.name(), &attr_expressions_);
string inputs_var;
string conversion;
if (arg_list->second.size() > 1) {
// If you have more than one list(tensor) argument, their types
// have to match.
std::vector<string> lists;
for (auto iter = arg_list->second.begin();
iter != arg_list->second.end(); ++iter) {
lists.push_back(param_names_[*iter]);
}
inputs_var = VectorToTuple(lists);
conversion = "_execute.args_to_mixed_eager_tensors";
} else {
// For one list(tensor) argument, we just convert every
// element of the list to an eager tensor.
inputs_var = param_names_[arg_list->second.front()];
conversion = "_execute.convert_to_mixed_eager_tensors";
}
strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, " = ",
conversion, "(", inputs_var, ")\n");
strings::StrAppend(&result_, " ", var_name,
" = [_t.as_datatype_enum for _t in ", var_name,
"]\n");
}
}
}
}
void GenEagerPythonOp::AddEagerInputCasts() {
// Cast remaining args to eager tensors
for (int i = 0; i < op_def_.input_arg_size(); ++i) {
const auto& arg(op_def_.input_arg(i));
if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
const string& param = param_names_[i];
const string fn = arg.number_attr().empty() ? "" : "n_";
const string dtype =
python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
strings::StrAppend(&result_, " ", param, " = _tensor.convert_", fn,
"to_eager_tensor(", param, ", ", dtype, ")\n");
}
}
void GenEagerPythonOp::AddEagerAttrs() {
// Compute eager attrs
if (op_def_.attr_size() > 0) {
string attr_values;
for (int i = 0; i < op_def_.attr_size(); ++i) {
if (i > 0) strings::StrAppend(&attr_values, ", ");
const auto& attr_name(op_def_.attr(i).name());
strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
attr_expressions_[attr_name]);
}
strings::StrAppend(&attr_values, ")");
strings::StrAppend(
&result_, WordWrap(" _attrs = (", attr_values, kRightMargin), "\n");
} else {
strings::StrAppend(&result_, " _attrs = None\n");
}
}
void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) {
const string return_prefix = " _result = _execute.execute(";
const string return_args =
strings::StrCat("\"", op_def_.name(), "\", ", num_outputs_expr,
", inputs=_inputs_flat, attrs=_attrs, name=name)");
strings::StrAppend(&result_,
// Wrap the arguments, and indent to the (.
WordWrap(return_prefix, return_args, kRightMargin), "\n");
}
string GetEagerPythonOps(const OpList& ops,
const std::vector<string>& hidden_ops,
bool require_shapes) {
string result;
// Header
// TODO(josh11b): Mention the library for which wrappers are being generated.
strings::StrAppend(&result, R"("""Python wrappers for TensorFlow ops.
This file is MACHINE GENERATED! Do not edit.
"""
import collections as _collections
from tensorflow.python.eager import execute as _execute
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import tensor as _tensor
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import tensor_shape as _tensor_shape
from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
# Needed to trigger the call to _set_call_cpp_shape_fn.
from tensorflow.python.framework import common_shapes as _common_shapes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
)");
// We'll make a copy of ops that filters out descriptions.
OpList cleaned_ops;
auto out = cleaned_ops.mutable_op();
out->Reserve(ops.op_size());
for (const auto& op_def : ops.op()) {
bool is_hidden = false;
for (const string& hidden : hidden_ops) {
if (op_def.name() == hidden) {
is_hidden = true;
break;
}
}
string function_name;
python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
&function_name);
if (is_hidden) function_name = strings::StrCat("_", function_name);
// When users create custom python wrappers, they may link in the
// default op registry by accident, and because they can't
// enumerate all 'hidden' symbols, this guard is to prevent
// instantiating a python reserved word in their wrapper.
if (python_op_gen_internal::IsPythonReserved(function_name)) {
continue;
}
strings::StrAppend(&result, GetEagerPythonOp(op_def, function_name));
if (!require_shapes) {
strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
"\")(None)\n\n");
}
auto added = out->Add();
*added = op_def;
RemoveNonDeprecationDescriptionsFromOpDef(added);
}
result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
op_list = _op_def_pb2.OpList()
op_list.ParseFromString(op_list_proto_bytes)
_op_def_registry.register_op_list(op_list)
op_def_lib = _op_def_library.OpDefLibrary()
op_def_lib.add_op_list(op_list)
return op_def_lib
)");
result.append("# ");
auto ops_text = ProtoDebugString(cleaned_ops);
str_util::StripTrailingWhitespace(&ops_text);
result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true));
result.append("\n");
strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n",
str_util::CEscape(cleaned_ops.SerializeAsString()).c_str());
return result;
}
} // namespace
void PrintEagerPythonOps(const OpList& ops,
const std::vector<string>& hidden_ops,
bool require_shapes) {
printf("%s", GetEagerPythonOps(ops, hidden_ops, require_shapes).c_str());
}
string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) {
string op_list_str(op_list_buf, op_list_len);
OpList ops;
ops.ParseFromString(op_list_str);
return GetEagerPythonOps(ops, {}, false);
}
} // namespace tensorflow

View File

@@ -0,0 +1,39 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
#define THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
#include <string>
#include <vector>
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// hidden_ops should be a list of Op names that should get a leading _
// in the output. Prints the output to stdout.
void PrintEagerPythonOps(const OpList& ops,
const std::vector<string>& hidden_ops,
bool require_shapes);
// Get the Python wrappers for a list of ops in an OpList.
// `op_list_buf` should be a pointer to a buffer containing
// the binary encoded OpList proto, and `op_list_len` should be the
// length of that buffer.
string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_

View File

@@ -0,0 +1,46 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/eager/python_eager_op_gen.h"
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/init_main.h"
namespace tensorflow {
namespace {
void PrintAllPythonOps(const std::vector<string>& hidden_ops) {
OpList ops;
OpRegistry::Global()->Export(false, &ops);
PrintEagerPythonOps(ops, hidden_ops, true /* require_shapes */);
}
} // namespace
} // namespace tensorflow
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc == 1) {
tensorflow::PrintAllPythonOps({});
} else {
return -1;
}
return 0;
}

View File

@@ -0,0 +1,67 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include <Python.h>
typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>
TFE_InputTensorHandles;
typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2>
TFE_OutputTensorHandles;
// Execute a TensorFlow operation.
//
// 'device_name': Name of the device on which to execute the operation, or NULL
// for automatic selection.
// 'op_name': Name of the TensorFlow op to execute.
// 'inputs': An array of TFE_TensorHandle*'s of size 'num_inputs'. These tensors
// will be provided as input to the operation.
// 'attrs': A Python tuple alternating names and attr values.
// 'outputs': A pointer to a TFE_OutputTensorHandles in which outputs will
//            be placed. On success, its elements will be filled in and the
// caller takes ownership of each returned TFE_TensorHandle.
// 'outputs' MUST be sized to be at least as large as the number
// of tensors produced by the operation and will be resized to
// the actual number of tensors produced.
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
const char* op_name, TFE_InputTensorHandles* inputs,
PyObject* attrs, TFE_OutputTensorHandles* outputs,
TF_Status* out_status);
// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
//
// The two may share underlying storage so changes to one may reflect in the
// other.
PyObject* TFE_Py_TensorHandleToNumpy(TFE_TensorHandle* h, TF_Status* status);
// Convert a Python numpy.ndarray object to a TFE_TensorHandle.
//
// The two may share underlying storage so changes to one may reflect in the
// other.
TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj);
// Registers e as the Exception class for handling not ok Status. Returns
// Py_None if registration succeeds, else throws a TypeError and returns NULL.
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using the
// class registered via TFE_Py_RegisterExceptionClass) and returns -1.
int TFE_Py_MayBeRaiseException(TF_Status* status);
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_

View File

@@ -0,0 +1,377 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Must be included first.
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/python/lib/core/py_func.h"
using tensorflow::string;
namespace {
#define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \
bool fn_name(const string& key, PyObject* py_value, TF_Status* status, \
type* value) { \
if (check_fn(py_value)) { \
*value = static_cast<type>(parse_fn(py_value)); \
return true; \
} else { \
TF_SetStatus(status, TF_INVALID_ARGUMENT, \
tensorflow::strings::StrCat( \
"Expecting " #type " value for attr ", key, ", got ", \
py_value->ob_type->tp_name) \
.c_str()); \
return false; \
} \
}
#if PY_MAJOR_VERSION >= 3
PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong)
PARSE_VALUE(ParseStringValue, const char*, PyUnicode_Check, PyUnicode_AsUTF8)
#else
PARSE_VALUE(ParseStringValue, const char*, PyString_Check, PyString_AsString)
PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
PARSE_VALUE(ParseInt64Value, int64_t, PyInt_Check, PyInt_AsLong)
#endif
PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
#undef PARSE_VALUE
bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
unsigned char* value) {
*value = PyObject_IsTrue(py_value);
return true;
}
const char* ParseProtoValue(const string& key, const char* proto_name,
PyObject* py_value, size_t* size,
TF_Status* status) {
char* output = nullptr;
Py_ssize_t py_size;
#if PY_MAJOR_VERSION >= 3
if (!PyUnicode_Check(py_value) ||
(output = PyUnicode_AsUTF8AndSize(py_value, &py_size)) == nullptr) {
#else
if (!PyString_Check(py_value) ||
(PyString_AsStringAndSize(py_value, &output, &py_size) < 0)) {
#endif
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat("Expecting a string (serialized ",
proto_name, ") value for attr ", key)
.c_str());
return nullptr;
}
*size = static_cast<size_t>(py_size);
return output;
}
bool SetOpAttrList(TFE_Op* op, const char* key, PyObject* py_list,
TF_AttrType type, TF_Status* status) {
if (!PySequence_Check(py_list)) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
", got ", py_list->ob_type->tp_name)
.c_str());
return false;
}
const int num_values = PySequence_Size(py_list);
#define PARSE_LIST(c_type, parse_fn) \
std::unique_ptr<c_type[]> values(new c_type[num_values]); \
for (int i = 0; i < num_values; ++i) { \
auto py_value = PySequence_ITEM(py_list, i); \
if (!parse_fn(key, py_value, status, &values[i])) return false; \
}
if (type == TF_ATTR_STRING) {
PARSE_LIST(const char*, ParseStringValue);
TFE_OpSetAttrStringList(op, key, values.get(), num_values);
} else if (type == TF_ATTR_INT) {
PARSE_LIST(int64_t, ParseInt64Value);
TFE_OpSetAttrIntList(op, key, values.get(), num_values);
} else if (type == TF_ATTR_FLOAT) {
PARSE_LIST(float, ParseFloatValue);
TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
} else if (type == TF_ATTR_BOOL) {
PARSE_LIST(unsigned char, ParseBoolValue);
TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
} else if (type == TF_ATTR_TYPE) {
PARSE_LIST(int, ParseIntValue);
TFE_OpSetAttrTypeList(op, key,
reinterpret_cast<const TF_DataType*>(values.get()),
num_values);
} else if (type == TF_ATTR_SHAPE) {
// Make one pass through the input counting the total number of
// dims across all the input lists.
int total_dims = 0;
for (int i = 0; i < num_values; ++i) {
auto py_value = PySequence_ITEM(py_list, i);
if (py_value != Py_None) {
if (!PySequence_Check(py_value)) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Expecting None or sequence value for element", i,
" of attr ", key, ", got ", py_value->ob_type->tp_name)
.c_str());
return false;
}
const auto size = PySequence_Size(py_value);
total_dims += size;
}
}
// Allocate a buffer that can fit all of the dims together.
std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
// Copy the input dims into the buffer and set dims to point to
// the start of each list's dims.
std::unique_ptr<const int64_t* []> dims(new const int64_t*[num_values]);
std::unique_ptr<int[]> num_dims(new int[num_values]);
int64_t* offset = buffer.get();
for (int i = 0; i < num_values; ++i) {
auto py_value = PySequence_ITEM(py_list, i);
if (py_value == Py_None) {
dims[i] = nullptr;
num_dims[i] = -1;
} else {
const auto size = PySequence_Size(py_value);
dims[i] = offset;
num_dims[i] = size;
for (int j = 0; j < size; ++j) {
auto inner_py_value = PySequence_ITEM(py_value, j);
if (inner_py_value == Py_None) {
*offset = -1;
} else if (!ParseInt64Value(key, inner_py_value, status, offset)) {
return false;
}
++offset;
}
}
}
TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
status);
if (TF_GetCode(status) != TF_OK) return false;
} else {
TF_SetStatus(status, TF_UNIMPLEMENTED,
tensorflow::strings::StrCat("Attr ", key,
" has unhandled list type ", type)
.c_str());
return false;
}
#undef PARSE_LIST
return true;
}
bool SetOpAttrScalar(TFE_Op* op, const char* key, PyObject* py_value,
TF_AttrType type, TF_Status* status) {
if (type == TF_ATTR_STRING) {
const char* value;
if (!ParseStringValue(key, py_value, status, &value)) return false;
TFE_OpSetAttrString(op, key, value);
} else if (type == TF_ATTR_INT) {
int64_t value;
if (!ParseInt64Value(key, py_value, status, &value)) return false;
TFE_OpSetAttrInt(op, key, value);
} else if (type == TF_ATTR_FLOAT) {
float value;
if (!ParseFloatValue(key, py_value, status, &value)) return false;
TFE_OpSetAttrFloat(op, key, value);
} else if (type == TF_ATTR_BOOL) {
unsigned char value;
if (!ParseBoolValue(key, py_value, status, &value)) return false;
TFE_OpSetAttrBool(op, key, value);
} else if (type == TF_ATTR_TYPE) {
int value;
if (!ParseIntValue(key, py_value, status, &value)) return false;
TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
} else if (type == TF_ATTR_SHAPE) {
if (py_value == Py_None) {
TFE_OpSetAttrShape(op, key, nullptr, -1, status);
} else {
if (!PySequence_Check(py_value)) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Expecting None or sequence value for attr", key,
", got ", py_value->ob_type->tp_name)
.c_str());
return false;
}
const auto num_dims = PySequence_Size(py_value);
std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
for (int i = 0; i < num_dims; ++i) {
auto inner_py_value = PySequence_ITEM(py_value, i);
if (inner_py_value == Py_None) {
dims[i] = -1;
} else if (!ParseInt64Value(key, inner_py_value, status, &dims[i])) {
return false;
}
}
TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
}
if (TF_GetCode(status) != TF_OK) return false;
} else {
TF_SetStatus(
status, TF_UNIMPLEMENTED,
tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
.c_str());
return false;
}
return true;
}
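// SetOpAttrs (below) expects `attrs` to be a flat tuple that alternates
// attribute names and values. A hedged, illustrative sketch of the expected
// layout (the attr names are hypothetical and depend on the op being run):
//   ("transpose_a", False, "transpose_b", False)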
void SetOpAttrs(TFE_Op* op, PyObject* attrs, TF_Status* out_status) {
if (attrs == Py_None) return;
if (!PyTuple_Check(attrs)) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Expecting an attrs tuple.");
return;
}
Py_ssize_t len = PyTuple_GET_SIZE(attrs);
if ((len & 1) != 0) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
"Expecting attrs tuple to have even length.");
return;
}
// Parse attrs
for (Py_ssize_t i = 0; i < len; i += 2) {
PyObject* py_key = PyTuple_GET_ITEM(attrs, i);
PyObject* py_value = PyTuple_GET_ITEM(attrs, i + 1);
#if PY_MAJOR_VERSION >= 3
const char* key = PyUnicode_AsUTF8(py_key);
#else
const char* key = PyString_AsString(py_key);
#endif
unsigned char is_list = 0;
const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
if (TF_GetCode(out_status) != TF_OK) return;
if (is_list != 0) {
if (!SetOpAttrList(op, key, py_value, type, out_status)) return;
} else {
if (!SetOpAttrScalar(op, key, py_value, type, out_status)) return;
}
}
}
} // namespace
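// Illustrative sketch of how TFE_Py_Execute is reached from Python through the
// SWIG wrapper in pywrap_tfe.i; the handle attributes and the attr tuple below
// are assumptions made only for this sketch:
//   outputs = pywrap_tensorflow.TFE_Py_Execute(
//       ctx._handle, None, "MatMul", [a._handle, b._handle],
//       ("transpose_a", False, "transpose_b", False), 1)
// The trailing integer is the expected number of outputs; the returned list of
// TFE_TensorHandle wrappers is resized to the count the kernel actually
// produced.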
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
const char* op_name, TFE_InputTensorHandles* inputs,
PyObject* attrs, TFE_OutputTensorHandles* outputs,
TF_Status* out_status) {
TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
if (TF_GetCode(out_status) != TF_OK) return;
if (device_name != nullptr) {
TFE_OpSetDevice(op, ctx, device_name, out_status);
}
if (TF_GetCode(out_status) == TF_OK) {
for (int i = 0; i < inputs->size() && TF_GetCode(out_status) == TF_OK;
++i) {
TFE_OpAddInput(op, inputs->at(i), out_status);
}
}
if (TF_GetCode(out_status) == TF_OK) {
SetOpAttrs(op, attrs, out_status);
}
if (TF_GetCode(out_status) == TF_OK) {
int num_outputs = outputs->size();
TFE_Execute(op, outputs->data(), &num_outputs, out_status);
outputs->resize(num_outputs);
}
if (TF_GetCode(out_status) != TF_OK) {
TF_SetStatus(out_status, TF_GetCode(out_status),
tensorflow::strings::StrCat(TF_Message(out_status),
" [Op:", op_name, "]")
.c_str());
}
TFE_DeleteOp(op);
}
PyObject* TFE_Py_TensorHandleToNumpy(TFE_TensorHandle* h, TF_Status* status) {
const tensorflow::Tensor* t =
TFE_TensorHandleUnderlyingTensorInHostMemory(h, status);
if (TF_GetCode(status) != TF_OK) {
Py_RETURN_NONE;
}
PyObject* ret = nullptr;
auto cppstatus = tensorflow::ConvertTensorToNdarray(*t, &ret);
if (!cppstatus.ok()) {
TF_SetStatus(status, TF_Code(cppstatus.code()),
cppstatus.error_message().c_str());
}
if (ret != nullptr) return ret;
Py_RETURN_NONE;
}
namespace {
// Python subclass of Exception that is created on not ok Status.
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
} // namespace
TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj) {
tensorflow::Tensor t;
auto cppstatus = tensorflow::ConvertNdarrayToTensor(obj, &t);
if (cppstatus.ok()) {
return TFE_NewTensorHandle(t);
} else {
tensorflow::mutex_lock l(exception_class_mutex);
auto msg = tensorflow::strings::StrCat(
"failed to convert numpy ndarray to a Tensor (",
cppstatus.error_message(), ")");
if (exception_class != nullptr) {
PyErr_SetObject(exception_class,
Py_BuildValue("si", msg.c_str(), TF_INVALID_ARGUMENT));
} else {
PyErr_SetString(PyExc_RuntimeError, msg.c_str());
}
}
return nullptr;
}
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
tensorflow::mutex_lock l(exception_class_mutex);
if (exception_class != nullptr) {
Py_DECREF(exception_class);
}
if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
exception_class = nullptr;
PyErr_SetString(PyExc_TypeError,
"TFE_Py_RegisterExceptionClass: "
"Registered class should be subclass of Exception.");
return nullptr;
} else {
Py_INCREF(e);
exception_class = e;
Py_RETURN_NONE;
}
}
int TFE_Py_MayBeRaiseException(TF_Status* status) {
if (TF_GetCode(status) == TF_OK) return 0;
tensorflow::mutex_lock l(exception_class_mutex);
if (exception_class != nullptr) {
PyErr_SetObject(exception_class, Py_BuildValue("si", TF_Message(status),
TF_GetCode(status)));
} else {
PyErr_SetString(PyExc_RuntimeError, TF_Message(status));
}
return -1;
}
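// Typical flow, sketched rather than specified exhaustively: the Python layer
// calls TFE_Py_RegisterExceptionClass once at import time with a subclass of
// Exception; wrappers then call TFE_Py_MayBeRaiseException(status) after each
// C API call, converting a non-OK status into that exception (or into a
// RuntimeError when no class has been registered).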
View File
@ -0,0 +1,240 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradient tape utilites."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from autograd import container_types
from autograd import core as ag_core
from tensorflow.python.framework import dtypes
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
def tensor_id(t):
"""Returns a unique identifier for this Tensor."""
t = ag_core.getval(t)
return t._id # pylint: disable=protected-access
class ImplicitTape(object):
"""Global object which can watch tensors and wrap them with autograd."""
def __init__(self):
self.tensors = {}
self.gradients = []
def __eq__(self, other):
return self is other
def __hash__(self):
return id(self)
@ag_core.primitive
def _watch_with_tape_internal(_, tensor):
"""Primitive to wrap a tensor around an ImplicitTape progenitor."""
return tensor
def _watch_with_tape(tape, tensor):
"""Wraps a watched Tensor and keeps track of it in the implicit tape."""
w = _watch_with_tape_internal(tape, tensor)
if ag_core.isnode(tape):
tape.value.tensors[tensor_id(tensor)] = w
return w
def _watch_with_tape_vjp(g, ans, vs, gvs, tape, tensor):
"""Gradient for _watch_with_tape_internal."""
del ans, gvs, tape
def mut_add(implicit_tape):
t = ag_core.getval(tensor)
implicit_tape.gradients.append((t, g))
return implicit_tape
return ag_core.SparseObject(vs, mut_add)
_watch_with_tape_internal.defvjp(_watch_with_tape_vjp, argnum=0)
_watch_with_tape_internal.defvjp(
lambda g, ans, vs, gvs, tape, tensor: g,
argnum=1)
class ImplicitTapeVSpace(ag_core.VSpace):
"""VSpace needed to have ImplicitTape be a valid progenitor."""
def zeros(self):
return ImplicitTape()
class ImplicitTapeNode(ag_core.Node):
"""Node to wrap ImplicitTape in."""
def __eq__(self, other):
return self is other
def __hash__(self):
return id(self)
ag_core.register_node(ImplicitTapeNode, ImplicitTape)
ag_core.register_vspace(ImplicitTapeVSpace, ImplicitTape)
# TODO(apassos) try to not do this.
class NoneVSpace(ag_core.VSpace):
"""VSpace for python None."""
def __init__(self, _):
self.size = 0
ag_core.register_vspace(NoneVSpace, type(None))
class _TapeStack(threading.local):
def __init__(self):
super(_TapeStack, self).__init__()
self._stack = []
@property
def stack(self):
return self._stack
@tf_contextlib.contextmanager
def replace_stack(self, new_stack):
old = self._stack
self._stack = new_stack
yield
self._stack = old
# The global tape stack.
_tape_stack = _TapeStack()
def push_new_tape():
"""Pushes a new tape onto the tape stack."""
progenitor = ag_core.new_progenitor(ImplicitTape())
_tape_stack.stack.append(progenitor)
ag_core.active_progenitors.add(progenitor)
def watch(tensor):
"""Marks this tensor to be watched by all tapes in the stack.
Args:
tensor: tensor to be watched.
Returns:
The tensor, potentially wrapped by all tapes in the stack.
"""
for t in _tape_stack.stack:
tensor = _watch_with_tape(t, tensor)
return tensor
def pop_tape():
"""Pops the top tape in the stack, if any."""
if _tape_stack.stack:
return _tape_stack.stack.pop()
return None
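# Illustrative sketch of how these helpers are meant to be combined (the op and
# variable names are hypothetical; gradient extraction itself lives in other
# modules):
#
#   push_new_tape()
#   x = watch(x)            # wrap x so operations on it are traced
#   y = some_eager_op(x)    # recorded, since x is now a tape node
#   tape_node = pop_tape()  # the ImplicitTape progenitor holding watched tensors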
def any_tape_has(tensor):
for t in _tape_stack.stack:
if tensor_id(tensor) in t.value.tensors:
return True
return False
def should_record(tensors):
"""Returns true if any tape in the stach watches any of these tensors."""
return any(ag_core.isnode(x) for x in tensors)
class _EagerSequenceNode(container_types.SequenceNode):
"""Eager version of SequenceNode, to live in EagerSequenceVSpace."""
pass
class _EagerSequenceVSpace(container_types.SequenceVSpace):
"""Changes equality on SequenceVSpace to conform to tfe requirements."""
def __init__(self, value):
self.shape = [ag_core.vspace(x) for x in value]
self.size = sum(s.size for s in self.shape)
self.sequence_type = type(value)
def __eq__(self, other):
if type(self) != type(other): # pylint: disable=unidiomatic-typecheck
return False
if len(self.shape) != len(other.shape):
# TODO(apassos) function gradients sometimes return gradients for side
# inputs which breaks this assertion. Understand how to fix it.
return True
for ss, os in zip(self.shape, other.shape):
if ss != os:
if isinstance(ss, NoneVSpace) or isinstance(os, NoneVSpace):
continue
if ss.dtype == dtypes.resource or os.dtype == dtypes.resource:
continue
return False
return True
class _EagerList(list):
"""Type used to bypass SequenceVSpace."""
def __init__(self, value):
super(_EagerList, self).__init__(value)
for v in value:
assert not ag_core.isnode(v)
ag_core.register_vspace(_EagerSequenceVSpace, _EagerList)
ag_core.register_node(_EagerSequenceNode, _EagerList)
@ag_core.primitive
def _record_operation(output_tensors, input_tensors, side_outputs,
backward_function):
del input_tensors, side_outputs, backward_function
return _EagerList(output_tensors)
def record_operation(o, i, s, b):
"""Primitive to trigger autograd tracing on outputs from inputs."""
inputs = container_types.make_sequence(_EagerList, *i)
return _record_operation(o, inputs, s, b)
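# Hypothetical example of how an op wrapper might use record_operation to make
# its outputs differentiable (`outs`, `ins` and `grad_fn` are illustrative
# names, not part of this module):
#
#   outs = record_operation(outs, ins, [], lambda *grads: grad_fn(*grads))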
def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors,
side_outputs, backward_function):
"""Gradient for _record_operation."""
del ans, vs, gvs, output_tensors, input_tensors
backward_args = tuple(g) + tuple(side_outputs)
if ag_core.isnode(backward_args):
backward_args = list(backward_args)
tensors = nest.flatten(backward_function(*backward_args))
return _EagerList([ag_core.getval(t) for t in tensors])
_record_operation.defvjp(_record_operation_vjp, argnum=1)
View File
@ -0,0 +1,411 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for TensorFlow's "Eager" mode of execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from autograd import core as ag_core
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import tape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
class Tensor(object):
"""A TensorFlow Tensor."""
def __init__(self, value, dtype=None):
"""Creates a Tensor object from a Python object or numpy array.
May share storage with the numpy array, in which case changes to the numpy
object will be reflected in the Tensor.
Arguments:
value: A numpy.array or a Python object to create a Tensor for.
dtype: TensorFlow dtype for the returned Tensor. If None, one will be
automatically selected.
"""
# TODO(ashankar): Evaluate if we can and perhaps share code with
# tf.constant defined in
# https://www.tensorflow.org/code/tensorflow/python/framework/constant_op.py
self._id = core.uid()
if not isinstance(value, np.ndarray):
npt = None if dtype is None else dtype.as_numpy_dtype
value = np.array(value, dtype=npt)
if dtype is None:
value = _maybe_modify_numpy_dtype_determination(value)
elif dtype is not None:
npt = dtype.as_numpy_dtype
if npt != value.dtype:
value = value.astype(npt)
try:
value = np.asarray(value, order="C")
self._handle = pywrap_tensorflow.TFE_Py_NumpyToTensorHandle(value)
except core._NotOkStatusException as e: # pylint: disable=protected-access
raise core._status_to_exception(e.code, e.message) # pylint: disable=protected-access
# Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
# memory. This change approximates the same behavior for eager execution -
# keeping int32 tensors in host memory.
#
# We do so to preclude the need for callers into such kernels from having to
# explicitly place the int32 tensors in host memory. For example, prior to
# this change one needed:
#
# with tfe.device('/gpu:0'):
# ... # code here
# with tfe.device('/cpu:0'):
# shape = tfe.Tensor(...)
# y = tfe.ops.random_uniform(.., shape)
#
# Without the CPU device block tfe.ops.random_uniform would fail since the
# kernel expects the shape in host memory.
#
# After this change, we simplify the code:
#
# with tfe.device('/gpu:0'):
# y = tfe.ops.random_uniform(.., tfe.Tensor(...))
#
# The approximation is not exact since if there are GPU kernels which do not
# require host memory for int32 tensors, there will be a discrepancy between
# eager execution and TensorFlow graphs. However, as of July 2017, there
# were no known GPU kernels that kept int32 tensors in device memory.
if _in_gpu_device() and value.dtype != np.int32:
ctx = context.get_default_context()
# pylint: disable=protected-access
device_name = ctx.device_name
with errors.raise_exception_on_not_ok_status() as status:
self._handle = pywrap_tensorflow.TFE_TensorHandleCopyToDevice(
self._handle, ctx._handle, device_name, status)
# pylint: enable=protected-access
self._dtype = dtypes.as_dtype(
pywrap_tensorflow.TFE_TensorHandleDataType(self._handle))
# This mirrors tensorflow.core.framework.ops.Tensor._handle_data, which will
# be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
# tensors, this will contain a serialized HandleData proto with shape
# inference metadata about shapes and dtypes of resources accessible from
# this handle.
self._handle_data = None
if core.active_trace() is not None:
core.active_trace().record_tensor("MANUAL",
tape.tensor_id(self),
self._device_name(),
self.shape.num_elements())
def __del__(self):
if (pywrap_tensorflow is not None
and pywrap_tensorflow.TFE_DeleteTensorHandle is not None):
pywrap_tensorflow.TFE_DeleteTensorHandle(self._handle)
if core.active_trace() is not None:
core.active_trace().delete_tensor(tape.tensor_id(self))
def __str__(self):
if self.dtype.is_numpy_compatible and self.shape.num_elements() > 0:
n = self.numpy().reshape(-1)
if self.shape.num_elements() > 5:
return "tfe.Tensor(%s..., shape=%s, dtype=%s)" % (n[:5], self.shape,
self.dtype.name)
else:
return "tfe.Tensor(%s, dtype=%s)" % (
np.array_str(self.numpy()).replace("\n", ""), self.dtype.name)
return "tfe.Tensor(<unprintable>, shape=%s dtype=%s)" % (self.shape,
self.dtype.name)
def __repr__(self):
if self.dtype.is_numpy_compatible and self.shape.num_elements() > 0:
n = self.numpy()
# TODO(apassos): understand why self.numpy() sometimes returns not
# an array.
if isinstance(n, np.ndarray):
n = n.reshape(-1)
if self.shape.num_elements() > 5:
return "<tfe.Tensor at %s shape=%s dtype=%s>(%s..., min=%s, max=%s)" % (
self._id, self.shape, self.dtype.name, n[:5], np.min(n), np.max(n))
else:
return "<tfe.Tensor at %s shape=%s dtype=%s>(%s)" % (self._id,
self.shape,
self.dtype.name, n)
return "<tfe.Tensor at %s shape=%s dtype=%s>" % (self._id, self.shape,
self.dtype.name)
@staticmethod
def _override_operator(name, func):
setattr(Tensor, name, func)
def numpy(self):
"""Returns a numpy array with the same contents as the Tensor.
The contents of the Tensor must be backed by host memory. The
as_cpu_tensor() method can be used to ensure that this is true.
TODO(ashankar,agarwal): Perhaps this should NOT reference the underlying
buffer but instead always explicitly copy? Note that currently it may or may
not copy based on whether the numpy data is properly aligned or not.
Returns:
A numpy array that may share memory with the Tensor object. Any changes
to one may be reflected in the other.
"""
# TODO(ashankar): This with status business seems expensive. Profile/avoid?
cpu = self.as_cpu_tensor()
with errors.raise_exception_on_not_ok_status() as status:
return pywrap_tensorflow.TFE_Py_TensorHandleToNumpy(cpu._handle, status) # pylint: disable=protected-access
def _copy(self, ctx, device_name):
"""Copies tensor to dest device."""
# pylint: disable=protected-access
# Creates a new tensor on the dest device.
with errors.raise_exception_on_not_ok_status() as status:
h = pywrap_tensorflow.TFE_TensorHandleCopyToDevice(
self._handle, ctx._handle, device_name, status)
new_tensor = _tensor_from_handle(h)
if core.active_trace() is not None:
core.active_trace().record_tensor("COPY",
tape.tensor_id(new_tensor),
new_tensor._device_name(),
new_tensor.shape.num_elements())
return new_tensor
# pylint: enable=protected-access
def _device_name(self):
return pywrap_tensorflow.TFE_TensorHandleDeviceName(self._handle)
@property
def dtype(self):
return self._dtype
@property
def shape(self):
"""The shape of this Tensor as a TensorShape object."""
n = pywrap_tensorflow.TFE_TensorHandleNumDims(self._handle)
# As of May 2017, TFE_TensorHandle objects were always backed by concrete
# tensors (which have a valid, known shape). There were vague plans to
# change this so that the Tensor class can also represent Tensors that have
# not yet been computed.
# If that happens, handle that (e.g., if n < 0: return tensor_shape(None))
# and also handle -1s returned by TFE_TensorHandleDim.
assert n >= 0, "See comment in source code"
return tensor_shape.TensorShape(
[pywrap_tensorflow.TFE_TensorHandleDim(self._handle, x)
for x in range(n)])
def get_shape(self):
"""Alias of Tensor.shape."""
return self.shape
def _shape_tuple(self):
"""The shape of this Tensor, as a tuple.
This is more performant than tuple(shape().as_list()) as it avoids
two list creations and one object creation. Marked private for now as, from an API
perspective, it would be better to have a single performant way of
getting a shape rather than exposing shape() and shape_tuple()
(and heaven forbid, shape_list() etc. as well!). Punting on that for now,
but ideally one would work things out and remove the need for this method.
"""
n = pywrap_tensorflow.TFE_TensorHandleNumDims(self._handle)
# As of May 2017, TFE_TensorHandle objects were always backed by concrete
# tensors (which have a valid, known shape). There were vague plans to
# change this so that the Tensor class can also represent Tensors that have
# not yet been computed.
# If that happens, handle that (e.g., if n < 0: return tensor_shape(None))
# and also handle -1s returned by TFE_TensorHandleDim.
assert n >= 0, "See comment in source code"
return tuple(
pywrap_tensorflow.TFE_TensorHandleDim(self._handle, x)
for x in range(n))
def as_cpu_tensor(self):
"""A copy of this Tensor with contents backed by host memory."""
return self._copy(context.get_default_context(), "CPU:0")
def as_gpu_tensor(self, gpu_index=0):
"""A copy of this Tensor with contents backed by memory on the GPU.
Arguments:
gpu_index: Identifies which GPU the contents of the returned Tensor will
be placed on.
Returns:
A GPU-memory backed Tensor object initialized with the same contents
as this Tensor.
"""
return self._copy(context.get_default_context(), "GPU:" + str(gpu_index))
def __bool__(self):
if self._shape_tuple() != (): # pylint: disable=g-explicit-bool-comparison
raise ValueError(
"Non-scalar tensor %s cannot be converted to boolean." % repr(self))
if self.dtype != dtypes.bool:
raise ValueError(
"Non-boolean tensor %s cannot be converted to boolean." % repr(self))
return bool(self.as_cpu_tensor().numpy())
def __nonzero__(self):
return self.__bool__()
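# Minimal usage sketch for the Tensor class above (values are illustrative;
# storage may be shared with the originating numpy array, as noted in
# __init__):
#
#   t = Tensor([[1.0, 2.0], [3.0, 4.0]])
#   t.shape            # TensorShape([2, 2])
#   t.numpy()          # numpy array backed by host memory
#   t.as_cpu_tensor()  # explicit host-memory copy of the tensor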
class IndexedSlices(object):
"""A sparse representation of a set of tensor slices at given indices.
This class is a simple wrapper for a pair of `Tensor` objects:
* `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
* `indices`: A 1-D integer `Tensor` with shape `[D0]`.
An `IndexedSlices` is typically used to represent a subset of a larger
tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`.
The values in `indices` are the indices in the first dimension of
the slices that have been extracted from the larger tensor.
The dense tensor `dense` represented by an `IndexedSlices` `slices` has
```python
dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
```
The `IndexedSlices` class is used principally in the definition of
gradients for operations that have sparse gradients
(e.g. @{tf.gather}).
"""
def __init__(self, values, indices, dense_shape):
"""Creates an `IndexedSlices`."""
self._values = values
self._indices = indices
assert indices.shape[0] == values.shape[0]
self._dense_shape = dense_shape
@property
def values(self):
"""A `Tensor` containing the values of the slices."""
return self._values
@property
def indices(self):
"""A 1-D `Tensor` containing the indices of the slices."""
return self._indices
@property
def dense_shape(self):
"""A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
return self._dense_shape
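# A small worked example of the representation above (all numbers are
# illustrative):
#
#   values  = Tensor([[1., 1.], [2., 2.]])   # shape [2, 2]
#   indices = Tensor([0, 3])                 # rows of the dense tensor
#   slices  = IndexedSlices(values, indices, Tensor([5, 2]))
#   # The dense equivalent is a [5, 2] tensor whose rows 1, 2 and 4 are zero.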
class _Op(object):
"""Fake op for _LazyZero to make its python API tf.Tensor-like."""
def __init__(self):
self.type = "Zeros"
class LazyZero(object):
"""Lazily-instantiated zero-valued Tensor used as autograd accumulator."""
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
self.op = _Op()
def __add__(self, other):
return other
def __radd__(self, other):
return other
def numpy(self):
return np.zeros(self.shape, self.dtype)
def convert_to_eager_tensor(t, dtype=None):
if isinstance(ag_core.getval(t), Tensor):
if dtype is not None and t.dtype != dtype:
raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
return t
return Tensor(t, dtype=dtype)
def convert_n_to_eager_tensor(values, dtype):
return [convert_to_eager_tensor(t, dtype) for t in values]
def _tensor_from_handle(handle):
"""'Private' constructor for the Tensor object.
The existence of a 'handle' is an implementation detail that should be hidden
from users of this module. Functions within this module do need to create a
Tensor object from a handle though.
One option would be to have an __init__(self, handle) method on the
Tensor class, but that would make the existence and use of a handle
'public'.
Instead, this function avoids exposing a Tensor.__init__ that understands
handles and yet allows functions within this module to create Tensor
objects from a handle.
Arguments:
handle: A valid TFE_TensorHandle object.
Returns:
A Tensor object.
"""
# pylint: disable=protected-access
t = Tensor.__new__(Tensor)
t._id = core.uid()
t._handle = handle
t._dtype = dtypes.as_dtype(pywrap_tensorflow.TFE_TensorHandleDataType(handle))
t._handle_data = None
return t
# pylint: enable=protected-access
# TODO(ashankar): use actual device type.
def _in_gpu_device():
return context.get_default_context()._device_index > 0 # pylint: disable=protected-access
def _maybe_modify_numpy_dtype_determination(np_array):
"""Tweak numpy dtype determination.
numpy prefers int64 and float64; we prefer int32 and float32.
(int32 is often used as the "shape" input to various operations,
many of which only support int32 shapes).
This preference is copied from tensor_util.make_tensor_proto
(https://goto.google.com/numpy_prefs_156503903)
Args:
np_array: A numpy ndarray
Returns:
A numpy ndarray whose dtype may have been modified.
"""
if np_array.dtype == np.float64:
return np_array.astype(np.float32)
if np_array.dtype == np.int64:
# Downcast iff there is no precision loss.
downcasted = np_array.astype(np.int32)
if np.array_equal(downcasted, np_array):
return downcasted
return np_array
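# For example (illustrative inputs, assuming numpy defaults to int64/float64 on
# this platform):
#   np.array([1, 2, 3])   -> downcast from int64 to int32 (no precision loss)
#   np.array([2 ** 40])   -> kept as int64 (an int32 downcast would change it)
#   np.array([1.0, 2.0])  -> converted from float64 to float32 unconditionally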
View File
@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for testing tfe code."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import context as _context
from tensorflow.python.platform import test as _test
from tensorflow.python.platform.test import * # pylint: disable=wildcard-import
def main(argv=None):
_context.enable_eager_execution()
_test.main(argv)
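# A hypothetical test module would use this file in place of tf.test so that
# every test runs with eager execution enabled (assuming this module is
# importable as tensorflow.python.eager.test):
#
#   from tensorflow.python.eager import test
#
#   class MyEagerTest(test.TestCase):
#     ...
#
#   if __name__ == "__main__":
#     test.main()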
View File
@ -0,0 +1,153 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%ignore "";
%rename("%s") TFE_Py_RegisterExceptionClass;
%rename("%s") TFE_Py_NumpyToTensorHandle;
%rename("%s") TFE_NewContext;
%rename("%s") TFE_ContextListDevices;
%rename("%s") TFE_TensorHandleDataType;
%rename("%s") TFE_TensorHandleNumDims;
%rename("%s") TFE_DeleteTensorHandle;
%rename("%s") TFE_Py_Execute;
%rename("%s") TFE_ContextAddFunctionDef;
%rename("%s") TFE_TensorHandleDim;
%rename("%s") TFE_TensorHandleCopyToDevice;
%rename("%s") TFE_NewOp;
%rename("%s") TFE_Py_TensorHandleToNumpy;
%rename("%s") TFE_OpGetAttrType;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"
%}
%typemap(out) TF_DataType {
$result = PyInt_FromLong($1);
}
%typemap(out) int64_t {
$result = PyInt_FromLong($1);
}
%typemap(out) TF_AttrType {
$result = PyInt_FromLong($1);
}
%typemap(in, numinputs=0) unsigned char* is_list (unsigned char tmp) {
$1 = &tmp;
}
%typemap(argout) unsigned char* is_list {
if (*$1 == 1) {
PyObject* list = PyList_New(1);
PyList_SetItem(list, 0, $result);
$result = list;
}
}
%include "tensorflow/c/eager/c_api.h"
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp) {
$1 = &temp;
if ($input != Py_None) {
if (!PyList_Check($input)) {
SWIG_exception_fail(SWIG_TypeError,
"must provide a list of Tensors as inputs");
}
Py_ssize_t len = PyList_Size($input);
$1->resize(len);
for (Py_ssize_t i = 0; i < len; ++i) {
PyObject* elem = PyList_GetItem($input, i);
if (!elem) {
SWIG_fail;
}
void* thp = nullptr;
int res = SWIG_ConvertPtr(elem, &thp,
$descriptor(TFE_TensorHandle*), 0 | 0);
if (!SWIG_IsOK(res)) {
SWIG_exception_fail(SWIG_ArgError(res),
"provided list of inputs contains objects other "
"than 'TFE_TensorHandle*'");
}
(*$1)[i] = reinterpret_cast<TFE_TensorHandle*>(thp);
}
}
}
// Temporary for the argout
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp) {
if (!PyInt_Check($input)) {
SWIG_exception_fail(SWIG_TypeError,
"expected an integer value (size of the number of "
"outputs of the operation)");
}
$1 = &temp;
$1->resize(PyInt_AsLong($input), nullptr);
}
// Create new Status object.
%typemap(in, numinputs=0) TF_Status *out_status {
$1 = TF_NewStatus();
}
%typemap(freearg) (TF_Status* out_status) {
TF_DeleteStatus($1);
}
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) {
if (TFE_Py_MayBeRaiseException($2)) {
SWIG_fail;
} else {
int num_outputs = $1->size();
$result = PyList_New(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
PyList_SetItem($result, i, SWIG_NewPointerObj(SWIG_as_voidptr($1->at(i)),
$descriptor(TFE_TensorHandle*),
0 | 0));
}
}
}
// Note that we need to use a typemap for TFE_TensorHandle* so that we can call
// SWIG_fail in case the value is nullptr. Otherwise SWIG will wrap the
// nullptr and return it to python as an opaque object, and python does not
// know that it needs to check if an Exception has been raised.
// TODO(agarwal): check if we can get rid of this typemap.
%typemap(out) (TFE_TensorHandle*) {
if ($1 == nullptr) {
SWIG_fail;
} else {
$result = SWIG_NewPointerObj(SWIG_as_voidptr($1),
$descriptor(TFE_TensorHandle*), 0 | 0);
}
}
%include "tensorflow/python/eager/pywrap_tfe.h"
// Clear all typemaps.
%typemap(out) TF_DataType;
%typemap(out) int64_t;
%typemap(out) TF_AttrType;
%typemap(in, numinputs=0) TF_Status *out_status;
%typemap(argout) unsigned char* is_list;
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp);
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp);
%typemap(in, numinputs=0) TF_Status *out_status;
%typemap(freearg) (TF_Status* out_status);
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status);
%typemap(out) (TFE_TensorHandle*);
View File
@ -17,6 +17,8 @@ limitations under the License.
* The includes are intentionally not alphabetically sorted, as the order of
* includes follows dependency order */
%include "tensorflow/python/pywrap_tfe.i"
%include "tensorflow/python/util/port.i"
%include "tensorflow/python/util/py_checkpoint_reader.i"
%include "tensorflow/python/util/stat_summarizer.i"
@ -45,3 +47,4 @@ limitations under the License.
%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"
%include "tensorflow/python/grappler/model_analyzer.i"
View File
@ -175,6 +175,7 @@ sh_binary(
"//tensorflow/python/debug:debug_pip",
"//tensorflow/python/saved_model:saved_model",
"//tensorflow/python:spectral_ops_test_util",
"//tensorflow/python/eager:pip_dependencies",
"//tensorflow/python/tools:tools_pip",
"//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
],