Experimental C and Python APIs to invoke TensorFlow kernels on concrete values.
PiperOrigin-RevId: 164902588
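
A quick usage sketch of the new C API (condensed from the TEST(CAPI, Execute)
case added in tensorflow/c/eager/c_api_test.cc below; error checking elided,
and the TF_Tensor construction left abstract):

  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);   // eager "session"
  TF_DeleteSessionOptions(opts);

  TF_Tensor* t = /* some CPU TF_Tensor, e.g. a 2x2 TF_FLOAT matrix */;
  TFE_TensorHandle* m = TFE_NewTensorHandle(t);      // wrap it in a handle
  TF_DeleteTensor(t);

  TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
  TFE_OpAddInput(matmul, m, status);
  TFE_OpAddInput(matmul, m, status);
  TFE_OpSetAttrType(matmul, "T", TFE_TensorHandleDataType(m));

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;                // in: array capacity; out: output count
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);

  TF_Tensor* result = TFE_TensorHandleResolve(retvals[0], status);
  // ... read TF_TensorData(result), then delete the tensors, handles, op,
  // and finally the context.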
commit 13eb3b90e9 (parent 7dfabcc01c)

Changed files:
  tensorflow/BUILD
  tensorflow/c/eager/
  tensorflow/contrib/cmake
  tensorflow/python/BUILD
  tensorflow/python/client
  tensorflow/python/eager/: BUILD, context.py, core.py, core_test.py,
    custom_gradient.py, execute.py, function.py, gen_op.bzl,
    graph_only_ops.py, memory_trace.py, python_eager_op_gen.cc,
    python_eager_op_gen.h, python_eager_op_gen_main.cc, pywrap_tfe.h,
    pywrap_tfe_src.cc, tape.py, tensor.py, test.py
  tensorflow/python/pywrap_tfe.i
  tensorflow/python/tensorflow.i
  tensorflow/tools/pip_package

tensorflow/BUILD
@@ -380,6 +380,7 @@ filegroup(
         "//tensorflow/java/src/main/native:all_files",
         "//tensorflow/python:all_files",
         "//tensorflow/python/debug:all_files",
+        "//tensorflow/python/eager:all_files",
         "//tensorflow/python/estimator:all_files",
         "//tensorflow/python/feature_column:all_files",
         "//tensorflow/python/kernel_tests:all_files",

tensorflow/c/eager/BUILD (new file, 67 lines)
@@ -0,0 +1,67 @@
# Experimental extensions to the C API for eager execution of kernels.

licenses(["notice"])  # Apache 2.0

cc_library(
    name = "c_api",
    srcs = ["c_api.cc"],
    hdrs = ["c_api.h"],
    visibility = [
        "//tensorflow:internal",
        "//tensorflow/python/eager:__pkg__",
    ],
    deps = [
        ":runtime",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_api_internal",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_test(
    name = "c_api_test",
    srcs = ["c_api_test.cc"],
    deps = [
        ":c_api",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

cc_library(
    name = "runtime",
    srcs = ["runtime.cc"],
    hdrs = ["runtime.h"],
    visibility = ["//tensorflow:internal"],
    deps = [
        "//tensorflow/c:c_api",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:tensorflow",
    ],
)

cc_test(
    name = "runtime_test",
    srcs = ["runtime_test.cc"],
    deps = [
        ":runtime",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:client_session",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:scope",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

tensorflow/c/eager/c_api.cc (new file, 561 lines)
@@ -0,0 +1,561 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api.h"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/runtime.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"

using tensorflow::int64;
using tensorflow::string;

namespace {
bool IsCPU(tensorflow::Device* d) {
  return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}

string DeviceName(tensorflow::Device* d) {
  return (d == nullptr) ? "cpu:0" : d->name();
}
}  // namespace

struct TFE_Context {
  explicit TFE_Context(TF_Session* s) : session(s) {}

  // TFE_Context is an extension of TF_Session. And TF_Session needs a
  // TF_Graph.
  TF_Session* session;

  tensorflow::mutex functions_mu;
  tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){
      tensorflow::OpRegistry::Global(), {}};

  // One FunctionLibraryRuntime per device.
  // func_libs[i] is the FunctionLibraryRuntime corresponding to
  // session->devices[i].
  std::vector<std::unique_ptr<tensorflow::FunctionLibraryRuntime> > func_libs;

  std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
                     tensorflow::Fprint128Hasher>
      kernel_cache;

  tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
    for (int i = 0; i < session->devices.size(); ++i) {
      if (session->devices[i] == d) {
        return func_libs[i].get();
      }
    }
    return nullptr;
  }

  const std::vector<tensorflow::Device*>& devices() {
    return session->devices;
  }
};

struct TFE_TensorHandle {
  TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d)
      : t(t), d(d) {}

  tensorflow::Tensor t;
  // TODO(ashankar): d == nullptr iff local CPU
  // This was expedient, but perhaps worth revisiting ('d' should always be a
  // valid pointer?)
  // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
  // provided with the appropriate TFE_Context.
  //
  // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a
  // TFE_TensorHandle does not outlive the TFE_Context from which it came?
  tensorflow::Device* d;
};

struct TFE_Op {
  TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
      : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {}

  bool const is_function() const { return attr_types == nullptr; }

  TFE_Context* ctx;  // Must outlive the TFE_Op.
  const char* name;
  tensorflow::AttrBuilder attrs;
  const tensorflow::AttrTypeMap* attr_types;
  std::vector<tensorflow::Tensor> inputs;
  std::vector<tensorflow::Device*> input_devices;
  tensorflow::Device* device;
};

extern "C" {

TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
  TF_Graph* graph = TF_NewGraph();
  TF_Session* session = TF_NewSession(graph, opts, status);
  if (status->status.ok()) {
    if (session->device_mgr == nullptr || session->devices.empty()) {
      status->status = tensorflow::errors::InvalidArgument(
          "Provided TF_SessionOptions are not compatible with eager execution "
          "(perhaps the TF_SessionOptions alluded to session execution in a "
          "remote address space?)");
    }
  }
  if (!status->status.ok()) {
    TF_DeleteGraph(graph);
    return nullptr;
  }

  TFE_Context* ret = new TFE_Context(session);
  ret->func_libs.resize(ret->devices().size());
  for (int i = 0; i < ret->devices().size(); ++i) {
    ret->func_libs[i] = tensorflow::NewFunctionLibraryRuntime(
        ret->session->device_mgr, opts->options.env, ret->devices()[i],
        TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {});
  }

  return ret;
}

void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
  status->status = tensorflow::Status::OK();
  tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
  TF_Graph* graph = ctx->session->graph;
  TF_DeleteSession(ctx->session, status);
  TF_DeleteGraph(graph);
  delete ctx;
}

TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
  return TF_SessionListDevices(ctx->session, status);
}

TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t) {
  return new TFE_TensorHandle(
      tensorflow::TensorCApi::MakeTensor(t->dtype, t->shape, t->buffer),
      nullptr);
}

void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; }

TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
  return static_cast<TF_DataType>(h->t.dtype());
}

int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); }

int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) {
  return h->t.dim_size(dim_index);
}

const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) {
  // This might be a bit confusing as a tensor on CPU can sometimes return
  // "CPU:0" and sometimes "/job:localhost/replica:0/task:0/cpu:0".
  // TODO(ashankar): Figure out which one would be nicer.
  return (h->d == nullptr) ? "CPU:0" : h->d->name().c_str();
}

TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
  if (!IsCPU(h->d)) {
    TF_SetStatus(status, TF_UNIMPLEMENTED,
                 tensorflow::strings::StrCat(
                     "TFE_TensorHandle can be resolved iff it is on CPU (this "
                     "handle is on ",
                     h->d->name(),
                     "). Consider using TFE_TensorHandleCopyToDevice to get a "
                     "copy of the tensor on CPU")
                     .c_str());
    return nullptr;
  }
  return tensorflow::TF_TensorFromTensor(h->t, status);
}

TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                               TFE_Context* ctx,
                                               const char* device_name,
                                               TF_Status* status) {
  tensorflow::Device* dstd = nullptr;
  status->status = ctx->session->device_mgr->LookupDevice(device_name, &dstd);
  if (!status->status.ok()) return nullptr;

  tensorflow::Device* srcd = h->d == nullptr ? ctx->devices()[0] : h->d;
  const bool src_cpu = IsCPU(srcd);
  const bool dst_cpu = IsCPU(dstd);
  if (!src_cpu && !dst_cpu) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat(
            "TFE_TensorHandleCopyToDevice requires either the source "
            "TFE_TensorHandle be on or the destination device be CPU (they "
            "are ",
            DeviceName(srcd), " and ", DeviceName(dstd), " in this call)")
            .c_str());
    return nullptr;
  }
  tensorflow::Tensor* src = &(h->t);
  if (src_cpu && dst_cpu) {
    // There must be a better way, but for now redirect through proto to ensure
    // that the underlying buffers are not shared.
    tensorflow::TensorProto proto;
    src->AsProtoTensorContent(&proto);
    tensorflow::Tensor dst(src->dtype(), src->shape());
    if (!dst.FromProto(proto)) {
      TF_SetStatus(
          status, TF_INTERNAL,
          tensorflow::strings::StrCat(
              "error copying between TFE_TensorHandles on CPU. Consider filing "
              "a bug report at https://github.com/tensorflow/tensorflow/issues "
              "mentioning version: ",
              TF_Version(), " and ", __FILE__, ":", __LINE__)
              .c_str());
      return nullptr;
    }
    return new TFE_TensorHandle(dst, nullptr);
  }
  if (src_cpu) {
    tensorflow::Tensor dst(
        dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(),
        src->shape());
    tensorflow::Notification n;
    dstd->tensorflow_gpu_device_info()->default_context->CopyCPUTensorToDevice(
        src, dstd, &dst, [status, &n](const tensorflow::Status& s) {
          status->status = s;
          n.Notify();
        });
    n.WaitForNotification();
    return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, dstd)
                                         : nullptr;
  }
  CHECK(dst_cpu);
  tensorflow::Tensor dst(src->dtype(), src->shape());
  tensorflow::Notification n;
  // TODO(ashankar): The Sync() call below may be more aggressive than
  // necessary. It is based on knowledge of implementation details - that
  // GPU devices are implemented using 3 streams - one for host->device copies,
  // one for device->host copies and one for sending operations to the GPU.
  // With that setup, Sync()ing across all 3 streams should be sufficient
  // but more than necessary (since it waits for operations that might have
  // nothing to do with this tensor to complete).
  status->status = srcd->Sync();
  if (!status->status.ok()) return nullptr;
  srcd->tensorflow_gpu_device_info()->default_context->CopyDeviceTensorToCPU(
      src, "IGNORE_MY_TENSOR_NAME", srcd, &dst,
      [status, &n](const tensorflow::Status& s) {
        status->status = s;
        n.Notify();
      });
  n.WaitForNotification();
  return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, nullptr)
                                       : nullptr;
}

TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  const char* name = op_or_function_name;  // Shorthand
  const tensorflow::AttrTypeMap* types;
  status->status = tensorflow::AttrTypeMapForOp(name, &types);
  if (status->status.ok()) return new TFE_Op(ctx, name, types);
  if (TF_GetCode(status) == TF_NOT_FOUND) {
    tensorflow::mutex_lock l(ctx->functions_mu);
    if (ctx->func_lib_def.Find(name) != nullptr) {
      status->status = tensorflow::Status::OK();
      return new TFE_Op(ctx, name, nullptr);
    }
  }
  return nullptr;
}

void TFE_DeleteOp(TFE_Op* op) { delete op; }

static void TFE_OpSetDeviceHelper(TFE_Op* op, tensorflow::Device* device,
                                  TF_Status* status) {
  // Questionable heuristic: Place the op on the same device as the first input
  // placed outside of host memory?
  if (IsCPU(op->device) && !IsCPU(device)) {
    op->device = device;
  }
}

void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx, const char* device_name,
                     TF_Status* status) {
  tensorflow::Device* d = nullptr;
  status->status = ctx->session->device_mgr->LookupDevice(device_name, &d);
  if (!status->status.ok()) return;
  TFE_OpSetDeviceHelper(op, d, status);
}

void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
  TFE_OpSetDeviceHelper(op, h->d, status);
  if (!status->status.ok()) return;
  op->inputs.push_back(h->t);
  op->input_devices.push_back(h->d);
  op->attrs.NumInputs(op->inputs.size());
}

TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                              unsigned char* is_list, TF_Status* status) {
  TF_AttrType ret;
  if (op->is_function()) {
    status->status = tensorflow::errors::Unimplemented(
        "TODO(apassos): Support for attributes for TensorFlow functions is "
        "not ready yet.");
    return TF_ATTR_INT;  // The compiler requires that we return something.
  }
  status->status =
      tensorflow::AttrTypeByName(op->attr_types, attr_name, &ret, is_list);
  return ret;
}

void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name,
                         const char* value) {
  op->attrs.Set(attr_name, value);
}

void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
  op->attrs.Set(attr_name, static_cast<int64>(value));
}

void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
  op->attrs.Set(attr_name, value);
}

void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name,
                       unsigned char value) {
  op->attrs.Set(attr_name, (value == 0) ? false : true);
}

void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
  op->attrs.Set(attr_name, static_cast<tensorflow::DataType>(value));
}

void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
                        const int num_dims, TF_Status* out_status) {
  if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 tensorflow::strings::StrCat(
                     "Value specified for `", attr_name, "` has ", num_dims,
                     " dimensions which is over the limit of ",
                     tensorflow::TensorShape::MaxDimensions(), ".")
                     .c_str());
    return;
  }
  tensorflow::TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }
  op->attrs.Set(attr_name, proto);
}

#define TFE_OP_SET_ATTR_LIST(fn, type)                                \
  void fn(TFE_Op* op, const char* attr_name, const type* values,     \
          int num_values) {                                           \
    op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice<const type>( \
                                 values, num_values));                \
  }
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*)
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
#undef TFE_OP_SET_ATTR_LIST

void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                          const int64_t* values, int num_values) {
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const int64>(
                    reinterpret_cast<const int64*>(values), num_values));
}

void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                           const TF_DataType* values, int num_values) {
  op->attrs.Set(
      attr_name,
      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
          reinterpret_cast<const tensorflow::DataType*>(values), num_values));
}

void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
                           const unsigned char* values, int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}

void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                            const int64_t** dims, const int* num_dims,
                            int num_values, TF_Status* out_status) {
  std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
      new tensorflow::TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
      TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                   tensorflow::strings::StrCat(
                       "Value specified for `", attr_name, "` has ",
                       num_dims_i, " dimensions which is over the limit of ",
                       tensorflow::TensorShape::MaxDimensions(), ".")
                       .c_str());
      return;
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
                    proto.get(), num_values));
}

namespace {

tensorflow::Status ValidateInputTypeAndPlacement(
    tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op,
    const tensorflow::OpKernel* kernel) {
  const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
  if (memtypes.size() != op->inputs.size()) {
    return tensorflow::errors::InvalidArgument(
        "expected ", memtypes.size(), " inputs, got ", op->inputs.size());
  }
  for (int i = 0; i < op->inputs.size(); ++i) {
    const tensorflow::Device* expected_device =
        memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
    const tensorflow::Device* actual_device =
        op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
    if (expected_device != actual_device) {
      return tensorflow::errors::InvalidArgument(
          "cannot compute ", op->name, " as input #", i,
          " was expected to be on ", expected_device->name(),
          " but is actually on ", actual_device->name(),
          " (operation running on ", op_device->name(), ")");
    }
    if (op->inputs[i].dtype() != kernel->input_type(i)) {
      return tensorflow::errors::InvalidArgument(
          "cannot compute ", op->name, " as input #", i,
          " was expected to be a ",
          tensorflow::DataType_Name(kernel->input_type(i)),
          " tensor but is a ",
          tensorflow::DataType_Name(op->inputs[i].dtype()), " tensor");
    }
  }
  return tensorflow::Status::OK();
}
}  // namespace

void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                 TF_Status* status) {
  TFE_Context* ctx = op->ctx;
  // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU
  tensorflow::Device* device =
      (op->device == nullptr) ? ctx->devices()[0] : op->device;
  std::vector<tensorflow::Tensor> outputs(1);
  const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
  tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name());
  tensorflow::KernelAndDevice* kernel =
      tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
  if (kernel == nullptr) {
    const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
    kernel = new tensorflow::KernelAndDevice();
    if (!op->is_function()) {
      status->status =
          tensorflow::KernelAndDevice::InitOp(device, ndef, kernel);
    } else {
      // Knowledge of the implementation of InitFn (and in-turn
      // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
      // will be accessed, so grab on to the lock.
      // See WARNING comment below - would be nice to rework to avoid this
      // subtlety.
      tensorflow::mutex_lock l(ctx->functions_mu);
      status->status = tensorflow::KernelAndDevice::InitFn(
          ndef, ctx->func_lib(device), kernel);
    }
    if (!status->status.ok()) {
      return;
    }
    tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
  }
  status->status = ValidateInputTypeAndPlacement(ctx->devices()[0], device, op,
                                                 kernel->kernel());
  output_memory_types = &kernel->kernel()->output_memory_types();
  if (!status->status.ok()) {
    return;
  }
  // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
  // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
  // which is GUARDED_BY(ctx->functions_mu). But knowledge of the
  // implementation of FunctionLibraryRuntime tells us that func_lib_def is
  // not accessed by FunctionLibraryRuntime::Run(), so there is no
  // thread-safety concern here. This is quite subtle. Re-work things to make
  // this better? (Would it make sense for FunctionLibraryRuntime to ensure
  // thread-safe access to FunctionLibraryDefinition?).
  status->status = kernel->Run(&op->inputs, &outputs);
  if (!status->status.ok()) return;
  *num_retvals = std::min<int>(*num_retvals, outputs.size());
  for (int i = 0; i < *num_retvals; ++i) {
    tensorflow::Device* d = IsCPU(device) ? nullptr : device;
    if (d != nullptr && output_memory_types != nullptr &&
        (*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
      d = nullptr;
    }
    retvals[i] = new TFE_TensorHandle(outputs[i], d);
  }
}

void TFE_ContextAddFunctionDef(TFE_Context* ctx,
                               const char* serialized_function_def,
                               size_t size, TF_Status* status) {
  tensorflow::FunctionDef function_def;
  if (!function_def.ParseFromArray(serialized_function_def, size)) {
    status->status =
        tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
    return;
  }
  tensorflow::mutex_lock l(ctx->functions_mu);
  status->status = ctx->func_lib_def.AddFunctionDef(function_def);
}

}  // extern "C"

TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
  return new TFE_TensorHandle(t, nullptr);
}

const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
    TFE_TensorHandle* h, TF_Status* status) {
  if (h->d != nullptr) {
    status->status = tensorflow::errors::FailedPrecondition(
        "TFE_TensorHandle is placed in device (not host) memory. Cannot "
        "return a tensorflow::Tensor");
    return nullptr;
  }
  return &h->t;
}

tensorflow/c/eager/c_api.h (new file, 159 lines)
@@ -0,0 +1,159 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_C_API_H_
#define TENSORFLOW_C_EAGER_C_API_H_

// C API extensions to experiment with eager execution of kernels.

#include "tensorflow/c/c_api.h"

#ifdef __cplusplus
extern "C" {
#endif

// "Context" under which operations/functions are executed. It encapsulates
// things like the available devices, resource manager etc.
//
// TODO(ashankar): Merge with TF_Session?
typedef struct TFE_Context TFE_Context;

extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts,
                                   TF_Status* status);
extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
                                             TF_Status* status);

// A handle to a tensor on a device.
//
// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
// type etc. Unlike a TF_Tensor, a TFE_TensorHandle may refer to such tensors
// placed in memory of different devices or remote address spaces.
typedef struct TFE_TensorHandle TFE_TensorHandle;

extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t);
extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h);
extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index);
extern const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h);
extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
                                          TF_Status* status);

// Create a new TFE_TensorHandle with the same contents as 'h' but placed
// in the memory of the device named 'device_name'.
//
// Currently requires at least one of the source or destination devices to
// be CPU (i.e., for the source or destination tensor to be placed in
// host memory).
extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                                      TFE_Context* ctx,
                                                      const char* device_name,
                                                      TF_Status* status);

// Description of the TensorFlow op to execute.
//
// Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e.,
// TFE_DeleteOp() is called before TFE_DeleteContext().
//
// Very similar to TF_OperationDescription with some differences:
// (1) TF_Output or TFE_TensorHandle* as arguments to TF_AddInput,
//     TF_AddInputList
// (2) TF_ColocateWith, TF_AddControlInput etc. do not make sense.
// (3) Implementation detail: Avoid use of NodeBuilder/NodeDefBuilder since
//     the additional sanity checks there seem unnecessary.
typedef struct TFE_Op TFE_Op;

extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                         TF_Status* status);
extern void TFE_DeleteOp(TFE_Op* op);

// TODO(ashankar): TFE_OpSetDevice and TFE_Execute should not have a
// TFE_Context parameter. Instead, the TFE_Context should be captured when
// creating the TFE_Op.
extern void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx,
                            const char* device_name, TF_Status* status);

extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h,
                           TF_Status* status);

extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                                     unsigned char* is_list,
                                     TF_Status* status);

extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name,
                                const char* value);
extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name,
                             int64_t value);
extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name,
                               float value);
extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name,
                              unsigned char value);
extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name,
                              TF_DataType value);
// If the number of dimensions is unknown, `num_dims` must be set to
// -1 and `dims` can be null. If a dimension is unknown, the
// corresponding entry in the `dims` array must be -1.
extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name,
                               const int64_t* dims, const int num_dims,
                               TF_Status* out_status);

extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
                                    const char** value, int num_values);
extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                                 const int64_t* values, int num_values);
extern void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
                                   const float* values, int num_values);
extern void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
                                  const unsigned char* values,
                                  int num_values);
extern void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                                  const TF_DataType* values, int num_values);
extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                                   const int64_t** dims, const int* num_dims,
                                   int num_values, TF_Status* out_status);

// Execute the operation defined by 'op' and return handles to computed
// tensors in 'retvals'.
//
// 'retvals' must point to a pre-allocated array of TFE_TensorHandle*
// and '*num_retvals' should be set to the size of this array.
//
// On return, 'num_retvals' will be set to the actual number of outputs
// returned by the operation.
extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals,
                        int* num_retvals, TF_Status* status);

// Add a function (serialized FunctionDef protocol buffer) to ctx so
// that it can be invoked using TFE_Execute.
extern void TFE_ContextAddFunctionDef(TFE_Context* ctx,
                                      const char* serialized_function_def,
                                      size_t size, TF_Status* status);

#ifdef __cplusplus
} /* end extern "C" */
#endif

#ifdef __cplusplus
// A workaround to ease conversion to and from numpy objects and
// TFE_TensorHandle's.
//
// TODO(ashankar): Figure out an alternative scheme that precludes the need for
// these API-boundary breaking methods.
namespace tensorflow {
class Tensor;
}  // namespace tensorflow

const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
    TFE_TensorHandle* h, TF_Status* status);
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t);
#endif

#endif  // TENSORFLOW_C_EAGER_C_API_H_
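
For example, per the TFE_OpSetAttrShape contract above, an unknown-rank shape
and a partially known rank-2 shape would be set as follows (a sketch only;
'op' and 'out_status' are as in the declarations above, and the two calls are
independent examples):

  TFE_OpSetAttrShape(op, "shape", nullptr, -1, out_status);  // unknown rank
  int64_t dims[] = {-1, 4};            // rank 2, first dimension unknown
  TFE_OpSetAttrShape(op, "shape", dims, 2, out_status);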

tensorflow/c/eager/c_api_test.cc (new file, 463 lines)
@@ -0,0 +1,463 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api.h"

#include <string.h>
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

using tensorflow::string;

namespace {

TFE_TensorHandle* TestMatrixTensorHandle() {
  int64_t dims[] = {2, 2};
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  TF_Tensor* t = TF_AllocateTensor(
      TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandle(t);
  TF_DeleteTensor(t);
  return th;
}

TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
  TF_Status* status = TF_NewStatus();

  TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, a, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, b, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
  TFE_OpSetAttrBool(op, "transpose_a", 0);
  TFE_OpSetAttrBool(op, "transpose_b", 0);
  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));

  return op;
}

// TODO(apassos) uncomment after rewriting to use the right benchmark API
// void BM_InitOp(benchmark::State& state) {
//   TF_Status* status = TF_NewStatus();
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   TF_DeleteSessionOptions(opts);

//   TFE_TensorHandle* m = TestMatrixTensorHandle();
//   for (auto _ : state) {
//     TFE_Op* matmul = MatMulOp(ctx, m, m);
//     TFE_DeleteOp(matmul);
//   }
//   TFE_DeleteTensorHandle(m);
//   TFE_DeleteContext(ctx, status);
//   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   TF_DeleteStatus(status);
// }
// BENCHMARK(BM_InitOp);

// void BM_Execute(benchmark::State& state) {
//   TF_Status* status = TF_NewStatus();
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   TF_DeleteSessionOptions(opts);

//   TFE_TensorHandle* m = TestMatrixTensorHandle();
//   TFE_Op* matmul = MatMulOp(ctx, m, m);
//   TFE_TensorHandle* retvals[1];
//   int num_retvals = 1;
//   for (auto _ : state) {
//     TFE_Execute(matmul, &retvals[0], &num_retvals, status);
//     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   }
//   TFE_DeleteOp(matmul);
//   TFE_DeleteTensorHandle(m);
//   TFE_DeleteContext(ctx, status);
//   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   TF_DeleteStatus(status);
// }
// BENCHMARK(BM_Execute);

TEST(CAPI, Context) {
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TF_DeleteSessionOptions(opts);

  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_DeleteContext(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  const int num_devices = TF_DeviceListCount(devices);
  EXPECT_GE(num_devices, 1) << "At least one CPU device should exist";
  for (int i = 0; i < num_devices; ++i) {
    EXPECT_NE("", TF_DeviceListName(devices, i, status)) << i;
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  }
  TF_DeleteDeviceList(devices);
  TF_DeleteStatus(status);
}

TEST(CAPI, TensorHandle) {
  TFE_TensorHandle* h = TestMatrixTensorHandle();
  EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));

  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
  ASSERT_EQ(16, TF_TensorByteSize(t));
  float data[4] = {0};
  memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t));
  EXPECT_EQ(1.0, data[0]);
  EXPECT_EQ(2.0, data[1]);
  EXPECT_EQ(3.0, data[2]);
  EXPECT_EQ(4.0, data[3]);
  TF_DeleteTensor(t);
  TFE_DeleteTensorHandle(h);
}

TEST(CAPI, TensorHandleCopyBetweenDevices) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TF_DeleteSessionOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
  TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  const int num_devices = TF_DeviceListCount(devices);

  const char* kCPUDevice = "CPU:0";
  for (int i = 0; i < num_devices; ++i) {
    const string name(TF_DeviceListName(devices, i, status.get()));
    if (TF_GetCode(status.get()) != TF_OK) {
      ADD_FAILURE() << i << " -- " << TF_Message(status.get());
      continue;
    }
    auto tag = tensorflow::strings::StrCat("Device #", i, " (", name, ")");
    // Copy to device
    TFE_TensorHandle* hdevice =
        TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
      continue;
    }
    // Copy back to CPU
    TFE_TensorHandle* hcopy =
        TFE_TensorHandleCopyToDevice(hdevice, ctx, kCPUDevice, status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
      continue;
    }
    TFE_DeleteTensorHandle(hdevice);

    // Ensure that the contents are the same!
    TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get());
    TFE_DeleteTensorHandle(hcopy);
    if (TF_GetCode(status.get()) != TF_OK) {
      ADD_FAILURE() << tag;
      continue;
    }
    EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy)) << tag;
    EXPECT_EQ(
        0,
        memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t)))
        << tag;
    TF_DeleteTensor(tcopy);
  }

  TF_DeleteDeviceList(devices);
  TF_DeleteTensor(t);
  TFE_DeleteTensorHandle(hcpu);
  TFE_DeleteContext(ctx, status.get());
  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}

TEST(CAPI, Execute) {
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteSessionOptions(opts);

  TFE_TensorHandle* m = TestMatrixTensorHandle();
  TFE_Op* matmul = MatMulOp(ctx, m, m);
  TFE_TensorHandle* retvals[2] = {nullptr};
  int num_retvals = 2;  // Should be reduced to 1 by the TFE_Execute call.
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(m);
  TFE_DeleteContext(ctx, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);

  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  TFE_DeleteTensorHandle(retvals[0]);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);
  TF_DeleteStatus(status);
}

string MatMulFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    signature {"
      "      name: 'MatMulFunction'"
      "      input_arg {"
      "        name: 'a'"
      "        type: DT_FLOAT"
      "      }"
      "      output_arg {"
      "        name: 'm'"
      "        type: DT_FLOAT"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'matmul'"
      "      op: 'MatMul'"
      "      input: 'a'"
      "      input: 'a'"
      "      attr {"
      "        key: 'T'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    ret {"
      "      key: 'm'"
      "      value: 'matmul:product'"
      "    }",
      &def));
  return def.SerializeAsString();
}

TEST(CAPI, FunctionDefAndExecute) {
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteSessionOptions(opts);

  string function_def = MatMulFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                            status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* m = TestMatrixTensorHandle();
  TFE_TensorHandle* retval[1] = {nullptr};
  int num_retvals = 1;
  TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, m, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Execute(op, &retval[0], &num_retvals, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);
  TFE_DeleteOp(op);
  TFE_DeleteTensorHandle(m);
  TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status);
  TFE_DeleteTensorHandle(retval[0]);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);
  TFE_DeleteContext(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}

// TODO(apassos) uncomment after rewriting to use the right benchmark API
// void BM_ExecuteFunction(benchmark::State& state) {
//   TF_Status* status = TF_NewStatus();
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   TF_DeleteSessionOptions(opts);

//   string function_def = MatMulFunction();
//   TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
//                             status);
//   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

//   TFE_TensorHandle* m = TestMatrixTensorHandle();
//   TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
//   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   TFE_OpAddInput(matmul, m, status);
//   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   TFE_TensorHandle* retval[1] = {nullptr};
//   int num_retvals = 1;
//   for (auto _ : state) {
//     TFE_Execute(matmul, &retval[0], &num_retvals, status);
//     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   }
//   TFE_DeleteTensorHandle(m);
//   TFE_DeleteTensorHandle(retval[0]);
//   TFE_DeleteContext(ctx, status);
//   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   TF_DeleteStatus(status);
// }
// BENCHMARK(BM_ExecuteFunction);

// TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
//                                  TF_Status* status) {
//   // Create the variable handle.
//   TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
//   if (TF_GetCode(status) != TF_OK) return nullptr;
//   TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
//   TFE_OpSetAttrShape(op, "shape", {}, 0, status);
//   TFE_OpSetAttrString(op, "container", "");
//   TFE_OpSetAttrString(op, "shared_name", "");
//   if (TF_GetCode(status) != TF_OK) return nullptr;
//   TFE_TensorHandle* var_handle = nullptr;
//   int num_retvals = 1;
//   TFE_Execute(op, &var_handle, &num_retvals, status);
//   TFE_DeleteOp(op);
//   if (TF_GetCode(status) != TF_OK) return nullptr;
//   CHECK_EQ(1, num_retvals);

//   // Assign 'value' to it.
//   op = TFE_NewOp(ctx, "AssignVariableOp", status);
//   if (TF_GetCode(status) != TF_OK) return nullptr;
//   TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
//   TFE_OpAddInput(op, var_handle, status);

//   // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
//   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
//       TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)),
//       TF_DeleteTensor);
//   memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));

//   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
//       value_handle(TFE_NewTensorHandle(t.get()), TFE_DeleteTensorHandle);

//   TFE_OpAddInput(op, value_handle.get(), status);
//   if (TF_GetCode(status) != TF_OK) return nullptr;

//   num_retvals = 0;
//   TFE_Execute(op, nullptr, &num_retvals, status);
//   TFE_DeleteOp(op);
//   if (TF_GetCode(status) != TF_OK) return nullptr;
//   CHECK_EQ(0, num_retvals);

//   return var_handle;
// }

// TEST(CAPI, Variables) {
//   // Variables use resource handles, so this is really a test for resource
//   // tensor handling.
//   TF_Status* status = TF_NewStatus();
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   TF_DeleteSessionOptions(opts);

//   TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
//   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

//   TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
//   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
//   TFE_OpAddInput(op, var_handle, status);
//   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   int num_retvals = 1;
//   TFE_TensorHandle* value_handle = nullptr;
//   TFE_Execute(op, &value_handle, &num_retvals, status);
//   TFE_DeleteOp(op);

//   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   ASSERT_EQ(1, num_retvals);
//   EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle));
//   EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle));
//   float value = 0.0f;
//   TF_Tensor* t = TFE_TensorHandleResolve(value_handle, status);
//   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
//   memcpy(&value, TF_TensorData(t), sizeof(float));
//   TF_DeleteTensor(t);
//   EXPECT_EQ(12.0, value);

//   TFE_DeleteTensorHandle(var_handle);
//   TFE_DeleteTensorHandle(value_handle);
//   TFE_DeleteContext(ctx, status);
//   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   TF_DeleteStatus(status);
// }

// void BM_ReadVariable(benchmark::State& state) {
//   TF_Status* status = TF_NewStatus();
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   TF_DeleteSessionOptions(opts);

//   TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
//   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

//   TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
//   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
//   TFE_OpAddInput(op, var_handle, status);
//   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

//   int num_retvals = 1;
//   TFE_TensorHandle* h = nullptr;
//   for (auto _ : state) {
//     TFE_Execute(op, &h, &num_retvals, status);
//     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//     CHECK_EQ(1, num_retvals);
//     CHECK(h);
//     CHECK_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
//     CHECK_EQ(0, TFE_TensorHandleNumDims(h));
//     h = nullptr;
//   }
//   TFE_DeleteOp(op);

//   TFE_DeleteTensorHandle(var_handle);
//   TFE_DeleteContext(ctx, status);
//   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
//   TF_DeleteStatus(status);
// }
// BENCHMARK(BM_ReadVariable);

}  // namespace

tensorflow/c/eager/runtime.cc (new file, 289 lines)
@@ -0,0 +1,289 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/runtime.h"

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"

namespace tensorflow {
namespace {

mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED);

std::unordered_map<string, const AttrTypeMap*>* OpNameToAttrTypeMap() {
  static auto* const m = new std::unordered_map<string, const AttrTypeMap*>;
  return m;
}

const uint32 kIsList = 1U << 31;

}  // namespace

Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
  mutex_lock l(g_op_name_to_attr_type_map_lock);
  *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
  if (*out != nullptr) return Status::OK();
  const OpRegistrationData* op_reg_data = nullptr;
  Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
  if (!s.ok()) return s;
  std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
  // TODO(agarwal): Avoid having to create this "registry" at runtime,
  // perhaps can be done at op registration time?
  for (const auto& attr : op_reg_data->op_def.attr()) {
    string type = attr.type();
    const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
    if (is_list) {
      type = type.substr(5, type.length() - 6);
    }
    uint32 t = is_list ? kIsList : 0;
    if (type == "string") {
      t |= TF_ATTR_STRING;
    } else if (type == "int") {
      t |= TF_ATTR_INT;
    } else if (type == "float") {
      t |= TF_ATTR_FLOAT;
    } else if (type == "bool") {
      t |= TF_ATTR_BOOL;
    } else if (type == "type") {
      t |= TF_ATTR_TYPE;
    } else if (type == "shape") {
      t |= TF_ATTR_SHAPE;
    } else if (type == "tensor") {
      t |= TF_ATTR_TENSOR;
    } else {
      return errors::Unimplemented(
          "TODO(agarwal): Enable support for ops with attributes of type '",
          type, "'");
    }
    gtl::InsertIfNotPresent(m.get(), attr.name(), t);
  }
  *out = m.get();
  (*OpNameToAttrTypeMap())[op_name] = m.release();
  return Status::OK();
}

Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
                      TF_AttrType* out, unsigned char* is_list) {
  CHECK(m);
  auto* t = gtl::FindOrNull(*m, attr_name);
  if (t == nullptr) {
    return errors::InvalidArgument("Attribute '", attr_name,
                                   "' does not exist for this operation");
  }
  *out = static_cast<TF_AttrType>(*t & ~kIsList);
  if (*t & kIsList) {
    *is_list = 1;
  } else {
    *is_list = 0;
  }
  return Status::OK();
}

#define DEFINE_SET_ATTR(value_type, value_field)                             \
  template <>                                                                \
  AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \
    value_field.push_back(std::make_pair(attr_name, value));                 \
    return *this;                                                            \
  }

DEFINE_SET_ATTR(StringPiece, string_attrs_);
DEFINE_SET_ATTR(float, float_attrs_);
DEFINE_SET_ATTR(int, int_attrs_);
DEFINE_SET_ATTR(bool, bool_attrs_);
DEFINE_SET_ATTR(tensorflow::DataType, type_attrs_);

#undef DEFINE_SET_ATTR

AttrBuilder& AttrBuilder::NumInputs(int n) {
  DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef.";
  num_inputs_ = n;
  return *this;
}

const NodeDef& AttrBuilder::BuildNodeDef() {
  if (node_def_finalized_) return *node_def_;
  MayBeInitializeNodeDef();
  for (int i = 0; i < num_inputs_; ++i) {
    node_def_->add_input("dummy_input");
  }
  for (const auto& p : string_attrs_) {
    SetInNodeDef(p.first, p.second);
  }
  for (const auto& p : int_attrs_) {
    SetInNodeDef(p.first, p.second);
  }
  for (const auto& p : float_attrs_) {
    SetInNodeDef(p.first, p.second);
  }
  for (const auto& p : bool_attrs_) {
    SetInNodeDef(p.first, p.second);
  }
  for (const auto& p : type_attrs_) {
    SetInNodeDef(p.first, p.second);
  }
  node_def_finalized_ = true;
  return *node_def_;
}

namespace {
inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
                                               const tensorflow::Fprint128& b) {
  return {tensorflow::FingerprintCat64(a.low64, b.low64),
          tensorflow::FingerprintCat64(a.high64, b.high64)};
}

void CombineUnordered(const tensorflow::Fprint128& a,
                      tensorflow::Fprint128* b) {
  b->low64 += a.low64;
  b->high64 += a.high64;
}
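
// Note: CombineUnordered is element-wise addition, which is commutative, so
// the cache key assembled from it in AttrBuilder::CacheKey below does not
// depend on the order in which attributes are visited.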

inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s,
                                            const tensorflow::Fprint128& b) {
  // TODO(agarwal): avoid ToString().
  tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString());
  return FingerprintCat128(a, b);
}

inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, uint64 b) {
  return CacheKeyHelper(s, {b, b});
}

}  // namespace

tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const {
  tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name_);
  f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
  if (node_def_ != nullptr) {
    // Some attributes are directly written to node_def_ instead of being
    // stored explicitly.
    string value;
    for (const auto& attr : node_def_->attr()) {
      attr.second.SerializeToString(&value);
      CombineUnordered(
          CacheKeyHelper(attr.first, tensorflow::Fingerprint128(value)), &f);
    }
    // Note that node_def_ may be created but not finalized. This can happen
    // when the creation was triggered by a call to Set, but BuildNodeDef has
    // not been called.
    if (node_def_finalized_) return f;
  }
  for (const auto& p : string_attrs_) {
    // TODO(agarwal): avoid ToString().
    CombineUnordered(CacheKeyHelper(p.first, tensorflow::Fingerprint128(
                                                 p.second.ToString())),
                     &f);
  }
  for (const auto& p : int_attrs_) {
    CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
                     &f);
  }
  static std::hash<float> float_hasher;
  for (const auto& p : float_attrs_) {
    CombineUnordered(
        CacheKeyHelper(p.first, static_cast<uint64>(float_hasher(p.second))),
        &f);
  }
  for (const auto& p : bool_attrs_) {
    CombineUnordered(CacheKeyHelper(p.first, p.second ? 1u : 0u), &f);
  }
  for (const auto& p : type_attrs_) {
    CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
                     &f);
  }
  return f;
}

void AttrBuilder::MayBeInitializeNodeDef() {
  if (node_def_ == nullptr) {
    node_def_.reset(new NodeDef());
    node_def_->set_name(op_name_);
    node_def_->set_op(op_name_);
  }
}

// static
Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
                               KernelAndDevice* out) {
  OpKernel* k = nullptr;
  Status s = CreateOpKernel(device->device_type().c_str(), device,
                            device->GetAllocator(AllocatorAttributes()),
                            nullptr, ndef, TF_GRAPH_DEF_VERSION, &k);
  out->device_ = device;
  out->kernel_.reset(k);
  out->flib_ = nullptr;
  return s;
}

// static
Status KernelAndDevice::InitFn(const NodeDef& ndef,
                               FunctionLibraryRuntime* flib,
                               KernelAndDevice* out) {
  OpKernel* k = nullptr;
  Status s = flib->CreateKernel(ndef, &k);
  out->device_ = flib->device();
  out->kernel_.reset(k);
  out->flib_ = flib;
  return s;
}

Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
                            std::vector<Tensor>* output_tensors) {
  gtl::InlinedVector<TensorValue, 4> inputs;
  for (Tensor& t : *input_tensors) {
    inputs.push_back(TensorValue(&t));
  }

  std::vector<AllocatorAttributes> out_attrs(kernel_->num_outputs());
  for (size_t i = 0; i < out_attrs.size(); ++i) {
    out_attrs[i].set_on_host(kernel_->output_memory_types()[i] ==
                             tensorflow::HOST_MEMORY);
  }

  OpKernelContext::Params params;
  params.device = device_;
  params.frame_iter = FrameAndIter(0, 0);
  params.inputs = &inputs;
  params.op_kernel = kernel_.get();
  params.resource_manager = device_->resource_manager();
  params.output_attr_array = gtl::vector_as_array(&out_attrs);
  params.function_library = flib_;
  params.slice_reader_cache = &slice_reader_cache_;
  // TODO(apassos): use a thread pool.
  std::function<void(std::function<void()>)> runner =
      [](std::function<void()> f) { f(); };
  params.runner = &runner;

  OpKernelContext context(&params);
  device_->Compute(kernel_.get(), &context);
  if (!context.status().ok()) return context.status();

  output_tensors->clear();
  for (int i = 0; i < context.num_outputs(); ++i) {
    output_tensors->push_back(Tensor(*context.mutable_output(i)));
  }
  return Status::OK();
}

}  // namespace tensorflow
193
tensorflow/c/eager/runtime.h
Normal file
@@ -0,0 +1,193 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_RUNTIME_H_
#define TENSORFLOW_C_EAGER_RUNTIME_H_

// Support for eager execution of TensorFlow kernels.

#include <memory>
#include <unordered_map>

#include "tensorflow/c/c_api.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"

namespace tensorflow {

// Maps attribute name to an encoding of the type of the attribute value.
// If the type is not a list type, the value is the same as the TF_AttrType
// type of the value. Else, the highest order bit is on, and the rest of the
// bits represent the TF_AttrType type of the values in the list.
typedef std::unordered_map<string, uint32> AttrTypeMap;
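
// For example (illustrative of the encoding above): a plain "int" attribute
// maps to TF_ATTR_INT, while a "list(int)" attribute maps to
// (1u << 31) | TF_ATTR_INT, i.e. the list bit OR'd with the element type.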

// Returns the AttrTypeMap for the TensorFlow operation named op_name.
Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out);

// Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'.
Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
                      TF_AttrType* out, unsigned char* is_list);

// KernelAndDevice::Init needs a NodeDef only to pass the attribute map
// through. An AttrBuilder is a convenience class to help with that -
// providing a smaller interface than NodeDefBuilder and avoiding expensive
// (unnecessary?) sanity checks (like number of inputs matching the OpDef -
// we only care about attributes here).
//
// TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which
// ones make sense to replicate.

// This is a helper class for creating a NodeDef. Additionally, this class
// allows computing a cache key based on fingerprinting the attributes of this
// NodeDef.
//
// Example usage:
// AttrBuilder a;
// a.NumInputs(2);
// a.Set("T", TF_FLOAT);
// tensorflow::Fprint128 cache_key = a.CacheKey("cpu:0");
// const NodeDef& n = a.BuildNodeDef();
//
// Note that all calls to Set and NumInputs should happen before calling
// BuildNodeDef. Also, calls to NumInputs or Set between multiple invocations
// to CacheKey may cause different values to be returned by CacheKey.
//
// For performance reasons, the class internally delays the actual
// construction of the NodeDef till BuildNodeDef is called, or Set is called
// with certain uncommon types (see template specializations of Set to see
// which types trigger a NodeDef creation).
class AttrBuilder {
 public:
  explicit AttrBuilder(const char* op)
      : op_name_(op),
        num_inputs_(0),
        node_def_(nullptr),
        node_def_finalized_(false) {}

  // Needed to work around call to ValidateNodeDef in CreateOpKernel.
  AttrBuilder& NumInputs(int n);

  template <class T>
  AttrBuilder& Set(StringPiece attr_name, T&& value) {
    MayBeInitializeNodeDef();
    return SetInNodeDef(attr_name, value);
  }

  tensorflow::Fprint128 CacheKey(const string& device) const;

  const NodeDef& BuildNodeDef();

 private:
  template <class T>
  using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>;

  void MayBeInitializeNodeDef();

  template <class T>
  AttrBuilder& SetInNodeDef(StringPiece attr_name, T&& value) {
    DCHECK(!node_def_finalized_) << "Calling SetInNodeDef after BuildNodeDef.";
    // Copied from NodeDefBuilder::Attr
    const AttrValue* found = AttrSlice(*node_def_).Find(attr_name);
    if (found == nullptr) {
      AddNodeAttr(attr_name, std::forward<T>(value), node_def_.get());
    } else {
      AttrValue attr_value;
      SetAttrValue(std::forward<T>(value), &attr_value);
      // TODO(ashankar): Do what is done in
      // NodeDefBuilder::CheckInconsistency(attr_name, *found, attr_value);
    }
    return *this;
  }

  AttrVec<StringPiece> string_attrs_;
  AttrVec<int> int_attrs_;
  AttrVec<float> float_attrs_;
  AttrVec<bool> bool_attrs_;
  AttrVec<tensorflow::DataType> type_attrs_;
  string op_name_;
  int num_inputs_;
  std::unique_ptr<NodeDef> node_def_;
  bool node_def_finalized_;
};  // class AttrBuilder

template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, bool&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name,
                              tensorflow::DataType&& value);

// KernelAndDevice encapsulates an instantiated kernel and the device it is on.
//
// Also see:
// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
// and
// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
class KernelAndDevice {
 public:
  // Populates 'out' with a kernel appropriate for 'ndef'.
  //
  // Assumes that 'ndef' refers to a primitive op (as opposed to a function).
  static Status InitOp(Device* device, const NodeDef& ndef,
                       KernelAndDevice* out);

  // Like InitOp but for functions defined in flib (i.e., ndef.op() refers to a
  // TensorFlow function in the FunctionLibraryRuntime).
  //
  // The provided FunctionLibraryRuntime MUST outlive all calls to
  // Run() on the returned KernelAndDevice.
  //
  // TODO(ashankar): There shouldn't be a need for a separate InitOp and InitFn.
  // The implementation of InitFn should work for both because
  // FunctionLibraryRuntime::CreateKernel will create a primitive op kernel if
  // appropriate. However, for now we keep them separate because I haven't
  // figured out thread-safety concerns around FunctionLibraryRuntime (in
  // particular, how the underlying FunctionLibraryDefinition might be mutated
  // by another thread as new functions are registered with it).
  // Conservatively, thread-safe usage of the FunctionLibraryRuntime is pushed
  // on to the caller (see locking in c_api.cc) for now. But I really should
  // dig into this so that both InitOp and InitFn can be collapsed to
  // FunctionLibraryRuntime::CreateKernel.
  static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib,
                       KernelAndDevice* out);

  KernelAndDevice() : device_(nullptr), flib_(nullptr) {}

  // TODO(ashankar): Handle list-valued inputs.
  Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs);

  const OpKernel* kernel() const { return kernel_.get(); }

 private:
  std::unique_ptr<OpKernel> kernel_;
  tensorflow::Device* device_;
  tensorflow::FunctionLibraryRuntime* flib_;
  tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
};
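
// A minimal usage sketch (mirroring the pattern in runtime_test.cc): build a
// NodeDef with AttrBuilder, initialize with KernelAndDevice::InitOp for some
// Device, then call Run() with vectors of input/output Tensors.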

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_RUNTIME_H_
160
tensorflow/c/eager/runtime_test.cc
Normal file
@@ -0,0 +1,160 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/runtime.h"

#include <memory>
#include <vector>

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {
namespace {

Device* CPUDevice() {
  return DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
}

TEST(AttrTypeMap, Lookup) {
  const AttrTypeMap* m = nullptr;
  Status s = AttrTypeMapForOp("ThisOpCannotPossiblyExist", &m);
  EXPECT_FALSE(s.ok());
  s = AttrTypeMapForOp("MatMul", &m);
  ASSERT_TRUE(s.ok()) << s;

  TF_AttrType t;
  unsigned char is_list = 1;
  s = AttrTypeByName(m, "ThisAttributeCannotPossiblyExist", &t, &is_list);
  EXPECT_FALSE(s.ok());
  EXPECT_NE(is_list, 0);
  s = AttrTypeByName(m, "transpose_a", &t, &is_list);
  ASSERT_TRUE(s.ok()) << s;
  EXPECT_EQ(TF_ATTR_BOOL, t);
  EXPECT_EQ(is_list, 0);

  s = AttrTypeMapForOp("Squeeze", &m);
  ASSERT_TRUE(s.ok()) << s;
  s = AttrTypeByName(m, "squeeze_dims", &t, &is_list);
  ASSERT_TRUE(s.ok()) << s;
  EXPECT_EQ(TF_ATTR_INT, t);
  EXPECT_NE(is_list, 0);
}

TEST(KernelAndDevice, Run) {
  Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
  std::vector<Tensor> inputs;
  inputs.push_back(t);
  inputs.push_back(t);
  NodeDef ndef(AttrBuilder("MatMul")
                   .Set("T", DT_FLOAT)
                   .Set("transpose_a", false)
                   .Set("transpose_b", false)
                   .NumInputs(inputs.size())
                   .BuildNodeDef());
  std::unique_ptr<Device> device(CPUDevice());
  KernelAndDevice kernel;
  Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel);
  ASSERT_TRUE(s.ok()) << s;
  std::vector<Tensor> outputs;
  s = kernel.Run(&inputs, &outputs);
  ASSERT_TRUE(s.ok()) << s;
  ASSERT_EQ(1, outputs.size());
  const Tensor& out = outputs[0];
  EXPECT_EQ(7, out.matrix<float>()(0, 0));
  EXPECT_EQ(10, out.matrix<float>()(0, 1));
  EXPECT_EQ(15, out.matrix<float>()(1, 0));
  EXPECT_EQ(22, out.matrix<float>()(1, 1));
}

// TODO(apassos) uncomment after rewriting to use the right benchmark API
// void BM_CreateGraph(benchmark::State& state) {
//   for (auto _ : state) {
//     Scope root = Scope::NewRootScope();
//     auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
//     auto M = ops::MatMul(root, C, C);
//     TF_CHECK_OK(root.status());
//   }
// }
// BENCHMARK(BM_CreateGraph);

// void BM_RunGraph(benchmark::State& state) {
//   Scope root = Scope::NewRootScope();
//   auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
//   auto M = ops::MatMul(root, C, C);
//   SessionOptions opts;
//   opts.config.set_inter_op_parallelism_threads(1);
//   opts.config.set_intra_op_parallelism_threads(1);
//   ClientSession sess(root, opts);
//   std::vector<Tensor> outputs;
//   for (auto _ : state) {
//     outputs.clear();
//     TF_CHECK_OK(sess.Run({M}, &outputs));
//   }
// }
// BENCHMARK(BM_RunGraph);

// void BM_CreateAndDestroySession(benchmark::State& state) {
//   Scope root = Scope::NewRootScope();
//   auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
//   auto M = ops::MatMul(root, C, C);
//   for (auto _ : state) {
//     ClientSession sess(root);
//   }
// }
// BENCHMARK(BM_CreateAndDestroySession);

// void BM_KernelAndDeviceInit(benchmark::State& state) {
//   NodeDef ndef(AttrBuilder("MatMul")
//                    .Set("T", DT_FLOAT)
//                    .Set("transpose_a", false)
//                    .Set("transpose_b", false)
//                    .NumInputs(2)
//                    .BuildNodeDef());
//   std::unique_ptr<Device> device(CPUDevice());
//   KernelAndDevice k;
//   for (auto _ : state) {
//     TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &k));
//   }
// }
// BENCHMARK(BM_KernelAndDeviceInit);

// void BM_KernelAndDeviceRun(benchmark::State& state) {
//   Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
//   std::vector<Tensor> inputs;
//   inputs.push_back(t);
//   inputs.push_back(t);
//   std::vector<Tensor> outputs;
//   NodeDef ndef(AttrBuilder("MatMul")
//                    .Set("T", DT_FLOAT)
//                    .Set("transpose_a", false)
//                    .Set("transpose_b", false)
//                    .NumInputs(inputs.size())
//                    .BuildNodeDef());
//   std::unique_ptr<Device> device(CPUDevice());
//   KernelAndDevice kernel;
//   TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &kernel));
//   for (auto _ : state) {
//     TF_CHECK_OK(kernel.Run(&inputs, &outputs));
//   }
// }
// BENCHMARK(BM_KernelAndDeviceRun);

}  // namespace
}  // namespace tensorflow
@@ -18,6 +18,10 @@
set(tf_c_srcs
    "${tensorflow_source_dir}/tensorflow/c/c_api.cc"
    "${tensorflow_source_dir}/tensorflow/c/c_api.h"
    "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc"
    "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h"
    "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc"
    "${tensorflow_source_dir}/tensorflow/c/eager/runtime.h"
    "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc"
    "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h"
    "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc"

@@ -755,6 +755,8 @@ add_custom_command(
set (pywrap_tensorflow_internal_src
    "${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.h"
    "${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.cc"
    "${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe.h"
"${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe_src.h"
|
||||
"${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.h"
|
||||
"${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.h"
|
||||
|
@ -60,6 +60,7 @@ INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|"
|
||||
|
||||
# Include if matched after exclude
|
||||
INCLUDE_RE = re.compile(r"^(TF_\w*)$|"
|
||||
r"^(TFE_\w*)$|"
|
||||
r"tensorflow::|"
|
||||
r"functor::|"
|
||||
r"perftools::gputools")
|
||||
|
@@ -2814,6 +2814,7 @@ tf_py_wrap_cc(
        "lib/io/py_record_reader.i",
        "lib/io/py_record_writer.i",
        "platform/base.i",
        "pywrap_tfe.i",
        "training/quantize_training.i",
        "training/server_lib.i",
        "util/kernel_registry.i",
@@ -2838,6 +2839,7 @@ tf_py_wrap_cc(
        "//tensorflow/c:checkpoint_reader",
        "//tensorflow/c:python_api",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
        "//tensorflow/core/distributed_runtime/rpc:grpc_session",
        "//tensorflow/core/grappler:grappler_item",
@@ -2850,6 +2852,7 @@ tf_py_wrap_cc(
        "//tensorflow/core/distributed_runtime:server_lib",
        "//tensorflow/core/profiler/internal:print_model_analysis",
        "//tensorflow/tools/graph_transforms:transform_graph_lib",
        "//tensorflow/python/eager:pywrap_tfe_lib",
        "//util/python:python_headers",
    ] + (tf_additional_lib_deps() +
         tf_additional_plugin_deps() +
@@ -56,7 +56,7 @@ tensorflow::ImportNumpy();
// const char*.
%typemap(in) (const char* target) {
  $1 = PyBytes_AsString($input);
  if (!$1) {
    // Python has raised an error.
    SWIG_fail;
  }
254
tensorflow/python/eager/BUILD
Normal file
@@ -0,0 +1,254 @@
licenses(["notice"])  # Apache 2.0

load("//tensorflow:tensorflow.bzl", "cuda_py_test")

cc_library(
    name = "pywrap_tfe_lib",
    srcs = ["pywrap_tfe_src.cc"],
    hdrs = ["pywrap_tfe.h"],
    visibility = ["//tensorflow:internal"],
    deps = [
        "//tensorflow/c:c_api",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/core:lib",
        "//tensorflow/python:numpy_lib",
        "//tensorflow/python:py_func_lib",
        "//util/python:python_headers",
    ],
)

py_library(
    name = "core",
    srcs = ["core.py"],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:internal"],
    deps = [
        ":context",
        ":memory_trace",
        ":tape",
        "//tensorflow/python:errors",
        "//tensorflow/python:pywrap_tensorflow",
    ],
)

py_library(
    name = "tensor",
    srcs = ["tensor.py"],
    srcs_version = "PY2AND3",
    visibility = ["//learning/brain/contrib/eager:__subpackages__"],
    deps = [
        ":context",
        ":core",
        ":tape",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:pywrap_tensorflow",
        "//tensorflow/python:tensor_shape",
        "//third_party/py/numpy",
    ],
)

py_library(
    name = "context",
    srcs = ["context.py"],
    srcs_version = "PY2AND3",
    visibility = ["//learning/brain/contrib/eager:__subpackages__"],
    deps = [
        "//tensorflow/python:device",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:pywrap_tensorflow",
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "tape",
    srcs = ["tape.py"],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:internal"],
    deps = [
        "//tensorflow/python:dtypes",
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "memory_trace",
    srcs = ["memory_trace.py"],
    srcs_version = "PY2AND3",
)

cuda_py_test(
    name = "core_test",
    srcs = ["core_test.py"],
    additional_deps = [
        ":context",
        ":core",
        ":execute",
        "//tensorflow/python:pywrap_tensorflow",
        ":tensor",
        ":test",
        "//third_party/py/numpy",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_test_lib",
    ],
)

py_library(
    name = "test",
    srcs = ["test.py"],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:internal"],
    deps = [
        ":context",
        "//tensorflow/python:client_testlib",
    ],
)

py_library(
    name = "execute",
    srcs = ["execute.py"],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:internal"],
    deps = [
        ":context",
        ":core",
        ":tape",
        ":tensor",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:lib",
        "//tensorflow/python:pywrap_tensorflow",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:util",
        "@six_archive//:six",
    ],
)

cc_library(
    name = "python_eager_op_gen",
    srcs = ["python_eager_op_gen.cc"],
    hdrs = ["python_eager_op_gen.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:op_gen_lib",
        "//tensorflow/core:proto_text",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/python:python_op_gen",
    ],
)

cc_library(
    name = "python_eager_op_gen_main",
    srcs = [
        "python_eager_op_gen_main.cc",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":python_eager_op_gen",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_binary(
    name = "python_eager_op_gen_demo",
    deps = [
        ":python_eager_op_gen_main",
        "//tensorflow/core:ops",
    ],
)

py_library(
    name = "custom_gradient",
    srcs = ["custom_gradient.py"],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:internal"],
    deps = [
        ":core",
        ":tape",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "graph_only_ops",
    srcs = ["graph_only_ops.py"],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:internal"],
    deps = [
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:framework_ops",
    ],
)

py_library(
    name = "framework_for_generated_wrappers",
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python/eager:execute",
    ],
)

py_library(
    name = "function",
    srcs = ["function.py"],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:internal"],
    deps = [
        ":graph_only_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:gradients",
        "//tensorflow/python:graph_to_function_def",
        "//tensorflow/python:pywrap_tensorflow",
        "//tensorflow/python:util",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:core",
        "//tensorflow/python/eager:execute",
        "//tensorflow/python/eager:tape",
        "//tensorflow/python/eager:tensor",
        "//third_party/py/numpy",
    ],
)

py_library(
    name = "pip_dependencies",
    visibility = ["//tensorflow:internal"],
    deps = [
        ":context",
        ":core",
        ":execute",
        ":tensor",
        ":test",
        "//tensorflow/python:pywrap_tensorflow",
    ],
)

# -----------------------------------------------------------------------------
# Google-internal targets.

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
333
tensorflow/python/eager/context.py
Normal file
@@ -0,0 +1,333 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for TensorFlow's "Eager" mode of execution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import threading

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.platform import app
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib

GRAPH_MODE = 0
EAGER_MODE = 1

# Default execution mode.
_default_mode = GRAPH_MODE


# TODO(agarwal): better name?
class _EagerContext(threading.local):
  """Thread local eager context."""

  def __init__(self):
    super(_EagerContext, self).__init__()
    self.device_index = -1
    self.mode = _default_mode
    self.scope_name = ""
    self.recording_summaries = False


# TODO(agarwal): rename to EagerContext / EagerRuntime ?
class Context(object):
  """Environment in which eager operations execute."""

  def __init__(self, graph=None):
    self._eager_context = _EagerContext()
    if not self.in_eager_mode():
      raise ValueError("Trying to create a Context in GRAPH_MODE")
    # Create a handle
    opts = pywrap_tensorflow.TF_NewSessionOptions(target=compat.as_bytes(""),
                                                  config=None)
    with errors.raise_exception_on_not_ok_status() as status:
      self._handle = pywrap_tensorflow.TFE_NewContext(opts, status)
    pywrap_tensorflow.TF_DeleteSessionOptions(opts)
    # Store list of devices
    self._devices = []
    with errors.raise_exception_on_not_ok_status() as status:
      device_list = pywrap_tensorflow.TFE_ContextListDevices(
          self._handle, status)
    try:
      for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
        with errors.raise_exception_on_not_ok_status() as status:
          dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i, status)
        self._devices.append(pydev.canonical_name(dev_name))
    finally:
      pywrap_tensorflow.TF_DeleteDeviceList(device_list)

    self._summary_writer_resource = None
    self._graph = graph or tf_ops.get_default_graph()

  def __del__(self):
    if self._handle is not None:
      with errors.raise_exception_on_not_ok_status() as status:
        pywrap_tensorflow.TFE_DeleteContext(self._handle, status)

  def __str__(self):
    lines = [
        "Eager TensorFlow environment with %d devices" % (len(self._devices))
    ]
    for i, d in enumerate(self._devices):
      lines.append(" Device %d: %s" % (i, d))
    return "\n".join(lines)

  @tf_contextlib.contextmanager
  def _mode(self, mode):
    ctx = self._eager_context
    old_mode = ctx.mode
    ctx.mode = mode
    try:
      yield
    finally:
      ctx.mode = old_mode

  def in_graph_mode(self):
    """Returns True if current thread is in GRAPH mode."""
    return self._eager_context.mode == GRAPH_MODE

  def in_eager_mode(self):
    """Returns True if current thread is in EAGER mode."""
    return self._eager_context.mode == EAGER_MODE

  @property
  def scope_name(self):
    """Returns scope name for the current thread."""
    return self._eager_context.scope_name

  @scope_name.setter
  def scope_name(self, s):
    """Sets scope name for the current thread."""
    self._eager_context.scope_name = s

  @property
  def summary_writer_resource(self):
    """Returns summary writer resource."""
    return self._summary_writer_resource

  @summary_writer_resource.setter
  def summary_writer_resource(self, resource):
    """Sets summary writer resource."""
    self._summary_writer_resource = resource
@property
|
||||
def recording_summaries(self):
|
||||
"""Returns True if recording summaries is enabled in current thread.."""
|
||||
return self._eager_context.recording_summaries
|
||||
|
||||
@recording_summaries.setter
|
||||
def recording_summaries(self, val):
|
||||
"""Enables recording summaries is enabled in current thread.."""
|
||||
self._eager_context.recording_summaries = val
|
||||

  # TODO(agarwal): remove?
  @property
  def _device_index(self):
    return self._eager_context.device_index

  # TODO(agarwal): remove?
  @_device_index.setter
  def _device_index(self, val):
    self._eager_context.device_index = val

  @property
  def device_name(self):
    """Returns the device name for the current thread."""
    index = self._device_index
    return None if index < 0 else self._devices[index]

  def devices(self):
    """List of the names of devices available to execute operations."""
    return self._devices

  def num_gpus(self):
    """The number of GPUs available to execute operations."""
    # TODO(ashankar): Use TF_DeviceListType to count GPU devices.
    return len(self._devices) - 1
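
  # (Note: the count above assumes the device list contains exactly one
  # non-GPU (CPU) entry, hence the "- 1"; see the TODO about
  # TF_DeviceListType.)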

  def as_default(self):
    """Returns a context manager to make self the default for this thread."""
    return _default_context_stack.get_controller(self)


class _DefaultContextStack(tf_ops._DefaultStack):  # pylint: disable=protected-access
  """A thread-local stack of Context objects."""

  def __init__(self):
    super(_DefaultContextStack, self).__init__()
    self._global_default_context = None

  def get_default(self):
    """Returns a thread local object if present, else a global default."""
    return (super(_DefaultContextStack, self).get_default() or
            self.global_default_context)

  @property
  def global_default_context(self):
    if self._global_default_context is None:
      self._global_default_context = Context()
    return self._global_default_context

  def reset(self):
    super(_DefaultContextStack, self).reset()
    self._global_default_context = None


_default_context_stack = _DefaultContextStack()


def get_default_context():
  """Returns a default Context object."""
  return _default_context_stack.get_default()


# TODO(agarwal): switch users to get_default_context and get rid of this
# function.
def context():
  return get_default_context()


def in_graph_mode():
  """Returns True if current thread is in GRAPH mode for default context."""
  return get_default_context().in_graph_mode()


def in_eager_mode():
  """Returns True if current thread is in EAGER mode for default context."""
  return get_default_context().in_eager_mode()


def graph_mode():
  """Context-manager to enable GRAPH mode for current thread."""
  return get_default_context()._mode(GRAPH_MODE)  # pylint: disable=protected-access


def eager_mode():
  """Context-manager to enable EAGER mode for current thread."""
  return get_default_context()._mode(EAGER_MODE)  # pylint: disable=protected-access
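
# A minimal illustration (assuming eager execution was enabled at program
# startup):
#
#   with graph_mode():
#     ...  # ops created here are added to a graph instead of running
#   with eager_mode():
#     ...  # ops created here execute immediately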


@contextlib.contextmanager
def namescope(name):
  """ContextManager for creating hierarchical name scopes."""
  ctx = get_default_context()
  old_name = ctx.scope_name
  ctx.scope_name = "%s/%s" % (old_name, name) if old_name else name
  try:
    yield
  finally:
    ctx.scope_name = old_name
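
# For example, nested scopes concatenate with "/":
#
#   with namescope("outer"):
#     with namescope("inner"):
#       scope_name()  # -> "outer/inner"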


def scope_name():
  """Name of the current scope."""
  return get_default_context().scope_name


@tf_contextlib.contextmanager
def device(name):
  """Context-manager to force placement of operations and Tensors on a device.

  For example:
  ```python
  with tfe.device('gpu:0'):
    with tfe.device('cpu:0'):
      shape = tfe.Tensor([], dtype=tf.int32)
    x = ops.truncated_normal(shape, tf.float32)
  ```
  will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
  operation runs on GPU 0.

  Args:
    name: Name of the device (see get_default_context().devices()), or None to
      enable automatic placement.

  Yields:
    Nothing.

  Raises:
    ValueError: If name does not correspond to a valid device.
  """
  device_index = -1
  ctx = get_default_context()
  if name is not None:
    name = pydev.canonical_name(name)
    all_devices = ctx.devices()
    for i, d in enumerate(all_devices):
      # TODO(ashankar): This will change when we have distributed support.
      # At that point, should not look for a string suffix but be able to
      # do a full string comparison.
      if d.endswith(name):
        device_index = i
        break
    if device_index < 0:
      raise ValueError("device {} does not match the available devices ({})".
                       format(name, all_devices))
  old_device_index = ctx._device_index  # pylint: disable=protected-access
  try:
    ctx._device_index = device_index  # pylint: disable=protected-access
    yield
  finally:
    ctx._device_index = old_device_index  # pylint: disable=protected-access


@contextlib.contextmanager
def record_summaries():
  """Context-manager to enable recording of summaries."""
  ctx = get_default_context()
  old = ctx.recording_summaries
  ctx.recording_summaries = True
  try:
    yield
  finally:
    ctx.recording_summaries = old


def should_record_summary():
  """True if a summary should be recorded now."""
  c = get_default_context()
  return c.recording_summaries and c.summary_writer_resource is not None


def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list.

  The program will run with eager execution enabled.

  Args:
    main: the main function to run
    argv: the arguments to pass to it
  """
  enable_eager_execution()
  app.run(main, argv)


# TODO(apassos): This should not be a part of the public API.
def enable_eager_execution():
  """Enables, for the rest of the lifetime of this program, eager execution.

  If this is not called immediately on program startup, it risks creating
  breakage and bugs.
  """
  global _default_mode
  assert _default_mode == GRAPH_MODE
  _default_mode = EAGER_MODE
88
tensorflow/python/eager/core.py
Normal file
@@ -0,0 +1,88 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for TensorFlow's "Eager" mode of execution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import memory_trace
from tensorflow.python.framework import errors

# Trace of execution and memory usage.
_active_trace = None

_uid_counter = 0
_uid_lock = threading.Lock()


def uid():
  """A unique (within this program execution) integer."""
  with _uid_lock:
    global _uid_counter
    _uid_counter += 1
    return _uid_counter


def _status_to_exception(code, message):
  try:
    error_class = errors.exception_type_from_error_code(code)
    return error_class(None, None, message)
  except KeyError:
    return errors.UnknownError(None, None, message, code)
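
# For example, error code 3 (INVALID_ARGUMENT in the TF error space) maps to
# errors.InvalidArgumentError, so _status_to_exception(3, "msg") returns an
# InvalidArgumentError; unrecognized codes fall back to UnknownError.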


class _NotOkStatusException(Exception):
  """Exception class to handle not ok Status."""

  def __init__(self, message, code):
    super(_NotOkStatusException, self).__init__()
    self.message = message
    self.code = code

  def __str__(self):
    e = _status_to_exception(self.code, self.message)
    return "%s: %s" % (e.__class__.__name__, e)


pywrap_tensorflow.TFE_Py_RegisterExceptionClass(_NotOkStatusException)


def enable_tracing():
  """Enables tracing of execution and memory usage.

  WARNING: tracing is not thread-safe.
  """
  global _active_trace
  _active_trace = memory_trace.MemoryTrace(
      len(context.get_default_context().devices()))


def flush_trace():
  """Flushes the active trace, if it exists.

  WARNING: tracing is not thread-safe.
  """
  if _active_trace is not None:
    _active_trace.flush_trace()


def active_trace():
  """Returns the current global active trace of execution and memory usage."""
  return _active_trace
488
tensorflow/python/eager/core_test.py
Normal file
@@ -0,0 +1,488 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for core."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading
import numpy as np

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import execute
from tensorflow.python.eager import tensor
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util


def truncated_normal(shape):
  return execute.execute(
      'TruncatedNormal',
      1,
      inputs=[shape],
      attrs=('dtype', dtypes.float32.as_datatype_enum, 'T',
             shape.dtype.as_datatype_enum, 'seed', 0, 'seed2', 0))[0]
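
# Note on the calling convention used throughout these tests: execute.execute
# takes attrs as a flat tuple of alternating (name, value) pairs, e.g.
# ('dtype', <enum>, 'seed', 0, 'seed2', 0).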


class TFETest(test_util.TensorFlowTestCase):

  def testContext(self):
    ctx = context.Context()
    self.assertFalse(ctx.in_graph_mode())
    self.assertTrue(ctx.in_eager_mode())
    self.assertEqual('', ctx.scope_name)
    self.assertEqual(-1, ctx._device_index)  # pylint: disable=protected-access
    self.assertFalse(ctx.recording_summaries)
    self.assertIsNone(ctx.summary_writer_resource)
    del ctx

  def testDefaultContext(self):
    orig = context.get_default_context()
    self.assertIs(context.get_default_context(), orig)
    c0 = context.Context()
    self.assertIs(context.get_default_context(), orig)
    context_manager_0 = c0.as_default()
    self.assertIs(context.get_default_context(), orig)
    with context_manager_0 as c0:
      self.assertIs(context.get_default_context(), c0)
      with context.Context().as_default() as c1:
        self.assertIs(context.get_default_context(), c1)
      self.assertIs(context.get_default_context(), c0)
    self.assertIs(context.get_default_context(), orig)

  def testContextWithThreads(self):

    def run_fn(ctx1):
      ctx2 = context.get_default_context()
      # Default contexts created in different threads are different.
      self.assertIsNot(ctx1, ctx2)
      # Check that default values of the context created in a different thread
      # are set correctly.
      self.assertFalse(ctx2.in_graph_mode())
      self.assertTrue(ctx2.in_eager_mode())
      self.assertEqual('', ctx2.scope_name)
      self.assertEqual(-1, ctx2._device_index)  # pylint: disable=protected-access
      self.assertFalse(ctx2.recording_summaries)
      self.assertIsNone(ctx2.summary_writer_resource)

    ctx1 = context.get_default_context()
    t = threading.Thread(target=run_fn, args=(ctx1,))
    t.start()
    t.join()

  def testScalarTensor(self):
    t = tensor.Tensor(3)
    self.assertEqual(t.numpy(), tensor.Tensor(np.array(3)).numpy())
    self.assertEqual(dtypes.int32, t.dtype)
    self.assertEqual(0, t.shape.ndims)
    self.assertAllEqual([], t.shape.as_list())

  def testTensorAndNumpyMatrix(self):
    expected = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
    actual = tensor.Tensor([[1.0, 2.0], [3.0, 4.0]])
    self.assertAllEqual(expected, actual.numpy())
    self.assertEqual(np.float32, actual.numpy().dtype)
    self.assertEqual(dtypes.float32, actual.dtype)
    self.assertAllEqual([2, 2], actual.shape.as_list())

  def testFloatDowncast(self):
    # Unless explicitly specified, float64->float32.
    t = tensor.Tensor(3.0)
    self.assertEqual(dtypes.float32, t.dtype)
    t = tensor.Tensor(3.0, dtype=dtypes.float64)
    self.assertEqual(dtypes.float64, t.dtype)

  def testBool(self):
    t = tensor.Tensor(False)
    if t:
      self.assertFalse(True)

  def testIntDowncast(self):
    t = tensor.Tensor(3)
    self.assertEqual(dtypes.int32, t.dtype)
    t = tensor.Tensor(3, dtype=dtypes.int64)
    self.assertEqual(dtypes.int64, t.dtype)
    t = tensor.Tensor(2**33)
    self.assertEqual(dtypes.int64, t.dtype)

  def testTensorCreationFailure(self):
    with self.assertRaises(Exception):
      # Should fail because each row of the Python object has a different
      # number of columns.
      self.assertEqual(None, tensor.Tensor([[1], [1, 2]]))

  def testTensorPlacement(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')

    x = tensor.Tensor(1.).as_gpu_tensor()
    with context.device('gpu:0'):
      y = tensor.Tensor(2.)
    # Add would fail if y were not on GPU
    result = execute.execute(
        'Add', 1, inputs=[x, y],
        attrs=('T', x.dtype.as_datatype_enum))[0].as_cpu_tensor().numpy()
    self.assertEqual(3, result)

  def testNumpyOrderHandling(self):
    n = np.array([[1, 2], [3, 4]], order='F')
    t = tensor.Tensor(n)
    self.assertAllEqual([[1, 2], [3, 4]], t.numpy())

  def testCopyBetweenDevices(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')

    cpu = tensor.Tensor([[1., 2.], [3., 4.]])
    c2g = cpu.as_gpu_tensor()
    # Exercise a copy from GPU to CPU, even though we ignore the value.
    _ = c2g.as_cpu_tensor()

    with self.assertRaises(errors.InvalidArgumentError):
      # c2g is on GPU. Copying between GPU devices fails
      # (must redirect through CPU for now).
      # TODO(ashankar): Perhaps the function should not fail and instead
      # facilitate the copy through host memory?
      c2g.as_gpu_tensor()

    # Invalid device
    with self.assertRaises(errors.InvalidArgumentError):
      cpu.as_gpu_tensor(context.context().num_gpus() + 1)

  def testNumpyForceCPU(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')

    cpu = tensor.Tensor([[1., 2.], [3., 4.]])
    c2g = cpu.as_gpu_tensor()
    self.assertAllEqual(c2g.numpy(), cpu.numpy())

  def testCopyFromCPUToCPU(self):
    ta = tensor.Tensor([[1, 2], [3, 4]])
    tb = ta.as_cpu_tensor()

    self.assertNotEqual(ta._handle, tb._handle)
    self.assertAllEqual(ta.numpy(), tb.numpy())

  def testRegisterExceptionClass(self):
    with self.assertRaises(TypeError):
      pywrap_tensorflow.TFE_Py_RegisterExceptionClass(str)
    pywrap_tensorflow.TFE_Py_RegisterExceptionClass(core._NotOkStatusException)  # pylint: disable=protected-access

  # TODO(agarwal): add tests passing incorrect typed values to attrs.
  def testExecuteBasic(self):
    three = tensor.Tensor(3)
    five = tensor.Tensor(5)
    product = execute.execute(
        'Mul',
        num_outputs=1,
        inputs=[three, five],
        attrs=('T', three.dtype.as_datatype_enum))[0]
    self.assertEqual(15, product.numpy())

  def testExecuteTooManyNumOutputs(self):
    # num_outputs provided is 50, but only one output is produced.
    # That should be okay.
    product = execute.execute(
        'Mul',
        num_outputs=50,
        inputs=[tensor.Tensor(3), tensor.Tensor(5)],
        attrs=('T', dtypes.int32.as_datatype_enum))[0]
    self.assertEqual(15, product.numpy())

  def testMatMulGPU(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')
    three = tensor.Tensor([[3.]]).as_gpu_tensor()
    five = tensor.Tensor([[5.]]).as_gpu_tensor()
    product = execute.execute(
        'MatMul',
        num_outputs=1,
        inputs=[three, five],
        attrs=('transpose_a', False, 'transpose_b', False, 'T',
               three.dtype.as_datatype_enum))[0]
    self.assertEqual([[15.0]], product.numpy())

  def testExecuteStringAttr(self):
    checked_three = execute.execute(
        'CheckNumerics',
        num_outputs=1,
        inputs=[tensor.Tensor(3.)],
        attrs=('message', 'just checking', 'T',
               dtypes.float32.as_datatype_enum))[0]
    self.assertEqual([[3]], checked_three.numpy())

  def testExecuteStringAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      _ = execute.execute(
          'CheckNumerics',
          num_outputs=1,
          inputs=[tensor.Tensor(3.)],
          attrs=('message', 1, 'T', dtypes.float32.as_datatype_enum))

  def testExecuteFloatAttr(self):
    almost_equal = execute.execute(
        'ApproximateEqual',
        num_outputs=1,
        inputs=[tensor.Tensor(3.0), tensor.Tensor(2.9)],
        attrs=('tolerance', 0.3, 'T', dtypes.float32.as_datatype_enum))[0]
    self.assertTrue(almost_equal.numpy())

  def testExecuteFloatAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      _ = execute.execute(
          'ApproximateEqual',
          num_outputs=1,
          inputs=[tensor.Tensor(3.0), tensor.Tensor(2.9)],
          attrs=('tolerance', '0.3', 'T', dtypes.float32.as_datatype_enum))

  def testExecuteIntAttr(self):
    total = execute.execute(
        'AddN',
        num_outputs=1,
        inputs=[tensor.Tensor(3), tensor.Tensor(4)],
        attrs=('T', dtypes.int32.as_datatype_enum, 'N', 2))[0]
    self.assertEqual(7, total.numpy())

  def testExecuteIntAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      _ = execute.execute(
          'AddN',
          num_outputs=1,
          inputs=[tensor.Tensor(3), tensor.Tensor(4)],
          attrs=('T', dtypes.int32.as_datatype_enum, 'N', '2'))

  # Looks like we don't have an existing op with list(bool) attrs.
  def testExecuteBoolAttr(self):
    product = execute.execute(
        'MatMul',
        num_outputs=1,
        inputs=[tensor.Tensor([[3]]), tensor.Tensor([[5]])],
        attrs=('transpose_a', True, 'transpose_b', False, 'T',
               dtypes.int32.as_datatype_enum))[0]
    self.assertEqual([[15]], product.numpy())

  def testExecuteShapeAttr(self):
    execute.execute(
        'VarHandleOp',
        num_outputs=1,
        inputs=[],
        attrs=('shape', [1, 2], 'dtype', dtypes.int32.as_datatype_enum,
               'container', '', 'shared_name', ''))

  def testExecuteShapeAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute.execute(
          'VarHandleOp',
          num_outputs=1,
          inputs=[],
          attrs=('shape', 1, 'dtype', dtypes.int32.as_datatype_enum,
                 'container', '', 'shared_name', ''))

  def testExecuteListStringAttr(self):
    execute.execute(
        'TensorSummary',
        num_outputs=1,
        inputs=[tensor.Tensor(3.0)],
        attrs=('T', dtypes.float32.as_datatype_enum, 'description',
               'tensor_summary', 'labels', ['3', 'summary'],
               'display_name', 'test'))

  def testExecuteListStringAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute.execute(
          'TensorSummary',
          num_outputs=1,
          inputs=[tensor.Tensor(3.0)],
          attrs=('T', dtypes.float32.as_datatype_enum, 'description', '',
                 'labels', 3, 'display_name', 'test'))

  def testExecuteListStringAttrBadListValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute.execute(
          'TensorSummary',
          num_outputs=1,
          inputs=[tensor.Tensor(3.0)],
          attrs=('T', dtypes.float32.as_datatype_enum, 'description', '',
                 'labels', [3], 'display_name', 'test'))

  def testExecuteListFloatAttr(self):
    b = execute.execute(
        'Bucketize',
        num_outputs=1,
        inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
        attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries',
               [4.0, 6.0]))[0]
    self.assertAllEqual([0, 1, 2], b.numpy())

  def testExecuteListFloatAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute.execute(
          'Bucketize',
          num_outputs=1,
          inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
          attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', 4.0))

  def testExecuteListFloatAttrBadListValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute.execute(
          'Bucketize',
          num_outputs=1,
          inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
          attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries',
                 ['4.0', '6.0']))

  def testExecuteListIntAttr(self):
    b = execute.execute(
        'Squeeze',
        num_outputs=1,
        inputs=[tensor.Tensor([[[3.0]]])],
        attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', [0, 2]))[0]
    self.assertAllEqual([3], b.numpy())

  def testExecuteListIntAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute.execute(
          'Squeeze',
          num_outputs=1,
          inputs=[tensor.Tensor([[[3.0]]])],
          attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', 0))
|
||||
|
||||
def testExecuteListIntAttrBadListValue(self):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
execute.execute(
|
||||
'Squeeze',
|
||||
num_outputs=1,
|
||||
inputs=[tensor.Tensor([[[3.0]]])],
|
||||
attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims',
|
||||
['0', '2']))
|
||||
|
||||
def testExecuteListTypeListShapeAttr(self):
|
||||
execute.execute(
|
||||
'Barrier',
|
||||
num_outputs=1,
|
||||
inputs=[],
|
||||
attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
|
||||
[[1, 2]], 'capacity', -1, 'container', '', 'shared_name', ''))
|
||||
|
||||
def testExecuteListTypeAttrBadValue(self):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
execute.execute(
|
||||
'Barrier',
|
||||
num_outputs=1,
|
||||
inputs=[],
|
||||
attrs=('component_types', dtypes.float64.as_datatype_enum, 'shapes',
|
||||
[[1, 2]], 'capacity', -1, 'container', '', 'shared_name', ''))
|
||||
|
||||
def testExecuteListTypeAttrBadListValue(self):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
execute.execute(
|
||||
'Barrier',
|
||||
num_outputs=1,
|
||||
inputs=[],
|
||||
attrs=('component_types', '1', 'shapes', [[1, 2]], 'capacity', -1,
|
||||
'container', '', 'shared_name', ''))
|
||||
|
||||
def testExecuteListShapeAttrBadValue(self):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
execute.execute(
|
||||
'Barrier',
|
||||
num_outputs=1,
|
||||
inputs=[],
|
||||
attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
|
||||
[1, 2], 'capacity', -1, 'container', '', 'shared_name', ''))
|
||||
|
||||
def testExecuteListShapeAttrBadListValue(self):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
execute.execute(
|
||||
'Barrier',
|
||||
num_outputs=1,
|
||||
inputs=[],
|
||||
attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
|
||||
[1], 'capacity', -1, 'container', '', 'shared_name', ''))
|
||||
|
||||
def testExecuteMultipleOutputs(self):
|
||||
split_dim = 1
|
||||
value = [[0, 1, 2], [3, 4, 5]]
|
||||
x1, x2, x3 = execute.execute(
|
||||
'Split',
|
||||
num_outputs=3,
|
||||
inputs=[tensor.Tensor(split_dim),
|
||||
tensor.Tensor(value)],
|
||||
attrs=('num_split', 3, 'T', dtypes.int32.as_datatype_enum))
|
||||
self.assertAllEqual([[0], [3]], x1.numpy())
|
||||
self.assertAllEqual([[1], [4]], x2.numpy())
|
||||
self.assertAllEqual([[2], [5]], x3.numpy())
|
||||
|
||||
def testExecuteBadNumOutputsArgument(self):
|
||||
with self.assertRaises(TypeError):
|
||||
execute.execute(
|
||||
'Relu', [],
|
||||
inputs=[tensor.Tensor(3.0)],
|
||||
attrs=('T', dtypes.float32.as_datatype_enum))
|
||||
|
||||
def testExecuteUnknownOp(self):
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
execute.execute('BlahBlahBlah', num_outputs=1, inputs=[], attrs=None)
|
||||
|
||||
def testExecuteUnknownAttr(self):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
execute.execute(
|
||||
'Identity',
|
||||
num_outputs=1,
|
||||
inputs=[tensor.Tensor(3)],
|
||||
attrs=('T', dtypes.int32.as_datatype_enum, 'unknown_attr', 'blah'))
|
||||
|
||||
def testComposition(self):
|
||||
|
||||
def add(x, y):
|
||||
return execute.execute(
|
||||
'Add',
|
||||
num_outputs=1,
|
||||
inputs=[x, y],
|
||||
attrs=('T', dtypes.int32.as_datatype_enum))[0]
|
||||
|
||||
x = tensor.Tensor(1)
|
||||
three_x = add(add(x, x), x)
|
||||
self.assertEquals(dtypes.int32, three_x.dtype)
|
||||
self.assertEquals(3, three_x.numpy())
|
||||
|
||||
def testOperationWithNoInputsRunsOnDevice(self):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest('No GPUs found')
|
||||
shape = tensor.Tensor([], dtype=dtypes.int32)
|
||||
|
||||
# x: Run the "TruncatedNormal" op CPU and copy result to GPU.
|
||||
x = truncated_normal(shape).as_gpu_tensor()
|
||||
# y: Explicitly run the "TruncatedNormal" op on GPU.
|
||||
with context.device('gpu:0'):
|
||||
y = truncated_normal(shape)
|
||||
# Add would fail if x and y were not on the same device.
|
||||
execute.execute('Add', 1, inputs=[x, y],
|
||||
attrs=('T', x.dtype.as_datatype_enum))
|
||||
|
||||
def testInvalidDevice(self):
|
||||
with self.assertRaises(ValueError):
|
||||
with context.device('pu:0'):
|
||||
_ = tensor.Tensor(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
70
tensorflow/python/eager/custom_gradient.py
Normal file
@ -0,0 +1,70 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decorator to override the gradient for a function."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from autograd import core as ag_core

from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor as _tensor
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.util import nest


def _watch_value_from_tape(tensor):
  for t in tape._tape_stack.stack:  # pylint: disable=protected-access
    w = t.value.tensors.get(tape.tensor_id(tensor), None)
    if w is not None:
      return w
  return tensor


def custom_gradient(f):
  """Decorator to define a function with a custom gradient.

  The input function is expected to return the tuple
    (results, gradient_function).

  The output function will return results while possibly recording the
  gradient_function and inputs in the tape.

  Args:
    f: function to be decorated.

  Returns:
    decorated function.
  """

  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    input_tensors = [_watch_value_from_tape(x) for x in args
                     if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
                     or ag_core.isnode(x)]
    result, grad_fn = f(*args, **kwargs)

    flat_result = nest.flatten(result)
    flat_result = [ag_core.getval(x) for x in flat_result]
    flat_result = tape.record_operation(
        flat_result,
        input_tensors,
        [],
        grad_fn)
    flat_result = list(flat_result)
    return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)

  return decorated
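A minimal usage sketch of the decorator above (editor's illustration, not part of the commit; the wrapped op and names are hypothetical, assuming a `math_ops`-style `log` wrapper is importable here):

# Hypothetical usage sketch for custom_gradient:
#
#   @custom_gradient
#   def log1p(x):
#     y = math_ops.log(1. + x)      # assumed op wrapper
#     def grad(dy):
#       return [dy / (1. + x)]      # one gradient per watched input
#     return y, grad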
241
tensorflow/python/eager/execute.py
Normal file
@ -0,0 +1,241 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions called by the generated code to execute an eager-mode op."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from autograd import core as ag_core
import six

from google.protobuf import text_format
from tensorflow.core.framework import tensor_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import compat


def execute(op_name, num_outputs, inputs, attrs=None, name=None):
  """Executes a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch.
      (Explicitly provided instead of being inferred for performance
      reasons).
    inputs: A list of inputs to the operation. Each entry should be a Tensor,
      or a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    name: Customized name for the operation.

  Returns:
    None if there are no outputs, a single Tensor object if there is one
    output, and a list of Tensor objects if there are multiple outputs.

  Raises:
    An exception on error.
  """
  ctx = context.get_default_context()
  # TODO(apassos) move this to convert_to_tensor
  inputs = [ag_core.getval(x) for x in inputs]
  # pylint: disable=protected-access
  input_handles = [c._handle for c in inputs]
  device_name = ctx.device_name
  try:
    outh = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
                                            op_name, input_handles, attrs,
                                            num_outputs)
    # pylint: enable=protected-access
  except core._NotOkStatusException as e:  # pylint: disable=protected-access
    raise core._status_to_exception(e.code, e.message)  # pylint: disable=protected-access

  tensors = [tensor._tensor_from_handle(x) for x in outh]  # pylint: disable=protected-access
  if core.active_trace() is not None:
    trace_name = name if name else op_name
    for t in tensors:
      # pylint: disable=protected-access
      core.active_trace().record_tensor(trace_name,
                                        tape.tensor_id(t),
                                        t._device_name(),
                                        t.shape.num_elements())
      # pylint: enable=protected-access
  return tensors


def record_gradient(unused_op_name, unused_inputs, unused_attrs, results,
                    unused_name):
  """Import backprop if you want gradients recorded."""
  return results


def make_float(v, arg_name):
  if not isinstance(v, compat.real_types):
    raise TypeError("Expected float for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return float(v)


def make_int(v, arg_name):
  if isinstance(v, six.string_types):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
  try:
    return int(v)
  except (ValueError, TypeError):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))


def make_str(v, arg_name):
  if not isinstance(v, compat.bytes_or_text_types):
    raise TypeError("Expected string for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return compat.as_bytes(v)  # Convert unicode strings to bytes.


def make_bool(v, arg_name):
  if not isinstance(v, bool):
    raise TypeError("Expected bool for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return v


def make_type(v, arg_name):
  try:
    v = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (arg_name, repr(v)))
  i = v.as_datatype_enum
  return i


def make_shape(v, arg_name):
  """Convert v into a list.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    None if the rank is unknown, otherwise a list of ints (or Nones in the
    positions where the dimension is unknown).
  """
  try:
    shape = tensor_shape.as_shape(v)
  except TypeError as e:
    raise TypeError("Error converting %s to a TensorShape: %s" % (arg_name, e))
  except ValueError as e:
    raise ValueError("Error converting %s to a TensorShape: %s" % (arg_name, e))
  if shape.ndims is None:
    return None
  else:
    return shape.as_list()


def make_tensor(v, arg_name):
  """Ensure v is a TensorProto."""
  if isinstance(v, tensor_pb2.TensorProto):
    return v
  elif isinstance(v, six.string_types):
    pb = tensor_pb2.TensorProto()
    text_format.Merge(v, pb)
    return pb
  raise TypeError(
      "Don't know how to convert %s to a TensorProto for argument '%s'" %
      (repr(v), arg_name))


def args_to_matching_eager(l, default_dtype=None):
  """Convert sequence `l` to eager Tensors sharing a single dtype."""
  # TODO(josh11b): Could we do a better job if we also passed in the
  # allowed dtypes when that was known?

  # Is some input already a Tensor with a dtype?
  dtype = None
  for t in l:
    if isinstance(ag_core.getval(t), tensor.Tensor):
      dtype = t.dtype
      break

  if dtype is None:
    # TODO(josh11b): At the moment, I don't think this can fail, but at some
    # point we likely should have some logic to prevent bad conversions.
    dtype = default_dtype

  if dtype is None:
    # Infer a dtype based on the first value, and use that dtype for the
    # remaining values.
    ret = []
    for t in l:
      ret.append(tensor.convert_to_eager_tensor(t, dtype))
      if dtype is None:
        dtype = ret[-1].dtype
  else:
    ret = [tensor.convert_to_eager_tensor(t, dtype) for t in l]

  return dtype, ret


def convert_to_mixed_eager_tensors(values):
  v = [t if isinstance(ag_core.getval(t), tensor.Tensor) else tensor.Tensor(t)
       for t in values]
  types = [t.dtype for t in v]
  return types, v


def args_to_mixed_eager_tensors(lists):
  """Converts a list of same-length lists of values to eager tensors."""
  assert len(lists) > 1

  # Generate an error if len(lists[i]) is not the same for all i.
  for l in lists[1:]:
    if len(l) != len(lists[0]):
      raise ValueError(
          "Expected list arguments to be the same length: %d != %d (%r vs. %r)"
          % (len(lists[0]), len(l), lists[0], l))
  # One output bucket per input list.
  lists_ret = [[] for _ in lists]

  # Convert the first element of each list first, then the second element, etc.
  types = []
  for i in range(len(lists[0])):
    dtype = None
    # If any list has a Tensor, use that dtype.
    for l in lists:
      if isinstance(ag_core.getval(l[i]), tensor.Tensor):
        dtype = l[i].dtype
        break
    if dtype is None:
      # Convert the first one and use its dtype.
      lists_ret[0].append(tensor.convert_to_eager_tensor(lists[0][i]))
      dtype = lists_ret[0][i].dtype
      for j in range(1, len(lists)):
        lists_ret[j].append(
            tensor.convert_to_eager_tensor(lists[j][i], dtype=dtype))
    else:
      # Convert everything to the found dtype.
      for j in range(len(lists)):
        lists_ret[j].append(
            tensor.convert_to_eager_tensor(lists[j][i], dtype=dtype))
    types.append(dtype)
  return types, lists_ret
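For orientation (editor's sketch, not part of the commit): the `attrs` tuple passed to `execute` alternates attr names with values already canonicalized by the `make_*` helpers above.

# Hypothetical sketch of the conventions the helpers above enforce:
#
#   make_type(dtypes.float32, 'T')   # -> the DataType enum int for DT_FLOAT
#   make_shape([1, None], 'shape')   # -> [1, None]; None marks unknown dims
#   execute('MatMul', num_outputs=1, inputs=[a, b],
#           attrs=('transpose_a', False, 'transpose_b', False,
#                  'T', make_type(dtypes.float32, 'T')))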
518
tensorflow/python/eager/function.py
Normal file
@ -0,0 +1,518 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Defun decorator for defining graph-mode functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import threading

from autograd import core as ag_core
import numpy as np

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.util import nest


# Thread-local storage for tfe Tensors which are referenced while evaluating a
# graph-mode function.
_scoped_captures = threading.local()
# _scoped_captures.tensors is either None or a map from tfe.Tensor id to a pair
# of a tfe tensor and its corresponding placeholder to pass as a function
# argument. The value should be None unless we're in function definition
# context.
_scoped_captures.tensors = None


@contextlib.contextmanager
def capture_tensors(captures):
  old = _scoped_captures.__dict__.get("tensors", None)
  try:
    _scoped_captures.tensors = captures
    yield
  finally:
    _scoped_captures.tensors = old


def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
  """Captures a tfe Tensor while building a graph-mode function.

  Creates a placeholder to pass the tensor as an argument.

  Arguments:
    value: A tfe.Tensor object.
    dtype: The datatype of the value produced by the node in the graph.
    name: Name of the node in the graph.
    as_ref: Ignored (required by register_tensor_conversion_function).

  Returns:
    A placeholder which will, at runtime, have the value of this tensor.

  Raises:
    ValueError: if called outside a defun context.
  """
  _ = as_ref
  tensor_map = _scoped_captures.tensors
  if tensor_map is None:
    raise ValueError(
        "Trying to use tfe.Tensor objects in a graph outside graph mode. "
        "To build a graph use tfe.defun or tfe.func_to_object.")
  captured_value = tensor_map.get(tape.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(dtype=dtype or value.dtype,
                                       shape=value.shape,
                                       name=name)
    if captured_value.dtype == dtypes.resource:
      captured_value._handle_data = value._handle_data  # pylint: disable=protected-access
    tensor_map[tape.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  return captured_value


# TODO(apassos): it'd be really nice if we could scope this registration.
ops.register_tensor_conversion_function(tensor.Tensor,
                                        _convert_to_graph_constant)


class _CapturingContext(object):
  """Tracks references to Tensors outside this context while it is active."""

  def __init__(self):
    # known_ops are ops which are created while this context is active.
    self.known_ops = set()

    # captured_tensors are all tensors referenced by ops in this context but
    # not produced in it.
    self.captured_tensors = set()

  def AddOp(self, op):  # pylint: disable=invalid-name
    if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
      raise ValueError("tfe.defun cannot capture variables created without "
                       "using tf.get_variable. Op: %s" % op)
    self.known_ops.add(op)
    for i in op.inputs:
      if i.op not in self.known_ops:
        self.captured_tensors.add(i)

  def __enter__(self):
    self._g = ops.get_default_graph()
    self._old = self._g._get_control_flow_context()  # pylint: disable=protected-access
    self._g._set_control_flow_context(self)  # pylint: disable=protected-access

  def __exit__(self, _, __, ___):  # pylint: disable=invalid-name
    self._g._set_control_flow_context(self._old)  # pylint: disable=protected-access


def _forward_name(n):
  """The name of a generated forward defun named n."""
  return "__forward_%s_%s" % (n, core.uid())


def _backward_name(n):
  """The name of a generated backward defun named n."""
  return "__backward_%s_%s" % (n, core.uid())


def _inference_name(n):
  """The name of a forward-but-no-gradient defun named n."""
  return "__inference_%s_%s" % (n, core.uid())


class _DefinedFunction(object):
  """Mocks the interface of tf _DefinedFunction."""

  def __init__(self, fdef):
    self.definition = fdef
    self.name = fdef.signature.name
    self.grad_func_name = None
    self.python_grad_func = None


def _map_sequence_obj_to_idx(sequence):
  """Maps objs in the sequence from id(obj) to sequence index."""
  return {id(x): i for i, x in enumerate(sequence)}


class _GraphModeFunction(object):
  """Callable object representing a graph-mode function.

  Args:
    input_placeholders: list of placeholder values to feed when calling
      the wrapped function.
    extra_inputs: Tensor inputs this function definition closed over which
      are passed as arguments. Need to track so gradients are supported
      correctly.
    fdef: the function definition we want to call.
    graph: the graph from which the fdef operations were pulled. Used as
      a context when computing gradients.
    operations: the subset of operations in the graph used in the function
      definition.
    func_outputs: the python outputs of the graph-mode function, with
      tensorflow.Tensor objects to be replaced by tfe values when called.
    func_outputs_to_fdef_outputs: Maps id(obj) in func_outputs to index of
      fdef's outputs. It allows mapping fdef output tensors to the nested
      func_outputs structure.
    output_shapes: List of shapes of all tensors which are output by the
      internal function.
  """

  def __init__(self,
               input_placeholders,
               extra_inputs,
               fdef,
               graph,
               operations,
               func_outputs,
               func_outputs_to_fdef_outputs,
               output_shapes):
    assert len(input_placeholders) == len(fdef.signature.input_arg), "%s %s" % (
        len(input_placeholders), len(fdef.signature.input_arg))
    self._input_placeholders = input_placeholders
    self._extra_inputs = list(extra_inputs)
    self._graph = graph
    self._has_backprop = False
    self._func_name = fdef.signature.name
    self._fdef = _DefinedFunction(fdef)
    self._num_outputs = len(fdef.signature.output_arg)
    self._ops = operations
    self._func_outputs = func_outputs
    if (isinstance(func_outputs, (ops.Tensor, type(None)))
        or ag_core.isnode(func_outputs)):
      self._returns = [func_outputs]
    else:
      self._returns = list(func_outputs)
    self._returns_to_fdef_outputs = func_outputs_to_fdef_outputs
    self._output_shapes = output_shapes

  def _compute_backprop(self):
    """Computes the backprop function object for this function."""
    self._has_backprop = True
    with self._graph.as_default(), context.graph_mode():
      c = _CapturingContext()
      with c:
        filtered_outputs = [ag_core.getval(x)
                            for x in self._returns if x is not None]
        self._out_grad_placeholders = [
            graph_placeholder(x.dtype, x.shape) for x in filtered_outputs]
        in_gradients = gradients_impl.gradients(
            filtered_outputs,
            self._input_placeholders,
            grad_ys=self._out_grad_placeholders)
        shapes = [x.shape for x in in_gradients if x is not None]
    captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
    forward_function_def = graph_to_function_def.graph_to_function_def(
        self._graph, self._ops,
        self._input_placeholders,
        filtered_outputs + captures)
    self._forward_fdef = _DefinedFunction(forward_function_def)
    _register_with_name(_forward_name(self._func_name),
                        forward_function_def)
    backward_outputs = [x for x in in_gradients if x is not None]
    all_inputs = self._out_grad_placeholders + captures
    backward_function_def = graph_to_function_def.graph_to_function_def(
        self._graph,
        [x.op for x in self._out_grad_placeholders] +
        list(sorted(c.known_ops, key=lambda x: x.name)),
        all_inputs,
        backward_outputs)
    _register_with_name(_backward_name(self._func_name), backward_function_def)
    self._backward_function = _GraphModeFunction(
        all_inputs, [], backward_function_def, self._graph, c.known_ops,
        in_gradients, _map_sequence_obj_to_idx(backward_outputs), shapes)

  def _backprop_call(self, args):
    """Calls the wrapped function and records the result on a tape."""
    all_args = args + self._extra_inputs
    signature = self._forward_fdef.definition.signature
    if context.in_graph_mode():
      g = ops.get_default_graph()
      g._add_function(self._forward_fdef)  # pylint: disable=protected-access
      unwrapped_args = [ag_core.getval(x) for x in all_args]
      op = g.create_op(signature.name,
                       [ops.convert_to_tensor(x) for x in unwrapped_args],
                       [dtypes.DType(x.type) for x in signature.output_arg],
                       op_def=signature,
                       name="FunctionCall",
                       compute_shapes=False)
      outputs = op.outputs
      outputs = [outputs] if isinstance(
          outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs)
      for i, s in enumerate(self._output_shapes):
        outputs[i].set_shape(s)
    else:
      outputs = execute.execute(
          signature.name,
          num_outputs=len(signature.output_arg),
          inputs=all_args)
    real_outputs = outputs[:len(self._returns)]
    side_outputs = outputs[len(self._returns):]
    watched_extra_inputs = []
    for t in self._extra_inputs:
      tid = tape.tensor_id(t)
      for t in tape._tape_stack.stack:  # pylint: disable=protected-access
        w = t.value.tensors.get(tid, None)
        if w is not None:
          watched_extra_inputs.append(w)
          break
      else:  # Note: for-else here done on purpose
        watched_extra_inputs.append(t)
    real_outputs = tape.record_operation(real_outputs,
                                         (args + watched_extra_inputs),
                                         side_outputs,
                                         self._backward_function)

    return self._build_call_outputs(self._returns, real_outputs)

  def __call__(self, *args):
    """Executes the passed function in eager mode."""
    tensor_inputs = [x for x in nest.flatten(args)
                     if isinstance(x, (tensor.Tensor, ops.Tensor,
                                       tensor.LazyZero))
                     or ag_core.isnode(x)]
    if tape.should_record(tensor_inputs) or any(
        tape.any_tape_has(t) for t in self._extra_inputs):
      if not self._has_backprop:
        self._compute_backprop()
      return self._backprop_call(tensor_inputs)

    if context.in_graph_mode():
      g = ops.get_default_graph()
      g._add_function(self._fdef)  # pylint: disable=protected-access
      signature = self._fdef.definition.signature
      args = list(tensor_inputs) + self._extra_inputs
      op = g.create_op(signature.name,
                       [ops.convert_to_tensor(x) for x in args],
                       [dtypes.DType(x.type) for x in signature.output_arg],
                       op_def=signature,
                       name="FunctionCall",
                       compute_shapes=False)
      result = op.outputs
      for i, s in enumerate(self._output_shapes):
        result[i].set_shape(s)
    else:
      tensor_inputs = [x.tensor() if isinstance(x, tensor.LazyZero) else x
                       for x in tensor_inputs]
      result = execute.execute(
          self._func_name,
          num_outputs=self._num_outputs,
          inputs=tensor_inputs + self._extra_inputs)

    return self._build_call_outputs(self._returns, result)

  def _build_call_outputs(self, func_outputs, result):
    """Maps the fdef output list to the actual output structure.

    Args:
      func_outputs: The outputs originally defined by the graph function. It
        could potentially be a nested structure.
      result: Output lists defined by FunctionDef.
    Returns:
      The actual call output.
    """
    if self._func_outputs is None:
      return None
    if isinstance(ag_core.getval(self._func_outputs), ops.Tensor):
      return result[0]

    outputs = []
    for o in func_outputs:
      vo = ag_core.getval(o)
      if isinstance(vo, ops.Tensor):
        outputs.append(result[self._returns_to_fdef_outputs[id(vo)]])
      elif type(vo) in (tuple, list):
        outputs.append(self._build_call_outputs(o, result))
      else:
        outputs.append(o)

    return tuple(outputs) if type(func_outputs) is tuple else outputs


def _get_defun_inputs(args):
  """Maps the input args to graph inputs."""
  ret = []
  for a in args:
    a = ag_core.getval(a)
    if isinstance(a, (tensor.LazyZero, ops.Tensor, tensor.Tensor)):
      ret.append(graph_placeholder(a.dtype, a.shape))
    elif type(a) in (tuple, list):
      ret.append(_get_defun_inputs(a))
    else:
      ret.append(a)
  return tuple(ret) if type(args) is tuple else ret


def _defun_internal(name, func, args, kwds):
  """Defines and returns a graph-mode version of func."""
  with context.graph_mode():
    tmp_graph = ops.Graph()
    with tmp_graph.as_default():
      func_inputs = _get_defun_inputs(args)

      captures = {}
      with capture_tensors(captures):
        func_outputs = func(*func_inputs, **kwds)
      ids = list(sorted(captures.keys()))
      if ids:
        extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
      else:
        extra_inputs = []
        extra_placeholders = []
      outputs_list = nest.flatten(func_outputs)
      output_shapes = [x.shape for x in outputs_list if x is not None]

    flat_inputs = [x for x in nest.flatten(func_inputs)
                   if isinstance(x, ops.Tensor)]
    all_inputs = flat_inputs + list(extra_placeholders)

    func_def_outputs = [ag_core.getval(x)
                        for x in outputs_list if x is not None]
    inference_function_def = graph_to_function_def.graph_to_function_def(
        tmp_graph, tmp_graph.get_operations(),
        all_inputs,
        func_def_outputs)
    # Register any other functions defined in the graph.
    # TODO(ashankar): Oh lord, forgive me for this lint travesty.
    for f in tmp_graph._functions.values():  # pylint: disable=protected-access
      # TODO(ashankar): What about the gradient registry?
      _register_with_name(f.name, f.definition)
    _register_with_name(_inference_name(name), inference_function_def)

    return _GraphModeFunction(
        all_inputs,
        extra_inputs,
        inference_function_def,
        tmp_graph,
        tmp_graph.get_operations(),
        func_outputs,
        _map_sequence_obj_to_idx(func_def_outputs),
        output_shapes)


# Defun uses this instead of Tensor as a cache key. Using dtype because
# TensorFlow graphs are not parametric wrt dtypes, and using shapes for
# performance reasons, as much TensorFlow code specializes on known shapes to
# produce slimmer graphs.
_TensorDtype = collections.namedtuple("_TensorDtype", ["dtype", "shape"])
_ZeroDtype = collections.namedtuple("_ZeroDtype", ["dtype", "shape"])


def _cache_key(x):
  """Cache key for tfe functions."""
  x = ag_core.getval(x)
  if isinstance(x, tensor.Tensor):
    return _TensorDtype(x.dtype, x._shape_tuple())  # pylint: disable=protected-access
  if isinstance(x, tensor.LazyZero):
    return _TensorDtype(x.dtype, tuple(x.shape.as_list()))  # pylint: disable=protected-access
  if isinstance(x, np.ndarray):
    return ("array", x.shape, tuple(x.reshape(-1)))
  if type(x) in (list, tuple):
    return tuple([_cache_key(a) for a in x])
  return x


def register_function_def(fdef):
  fdef_string = fdef.SerializeToString()
  with errors.raise_exception_on_not_ok_status() as status:
    pywrap_tensorflow.TFE_ContextAddFunctionDef(
        context.get_default_context()._handle,  # pylint: disable=protected-access
        fdef_string,
        len(fdef_string),
        status)


def _register_with_name(name, fdef):
  """Registers the function `fdef` with the name `name`."""
  fdef.signature.name = name
  register_function_def(fdef)


# TODO(apassos): better error messages for non-hashable arguments.
def named_defun(func, name):
  """Defines a function with a given name.

  See the documentation for `defun` for more information on the semantics of
  the function.

  Args:
    func: the function to be wrapped.
    name: the name given to it.

  Returns:
    the wrapped function.
  """
  arguments_to_functions = {}

  def decorated(*args, **kwds):
    """Decorated version of func."""
    # Macroexpand on non-Tensor arguments.
    cache_key = tuple(_cache_key(x) for x in args)
    assert all(not isinstance(x, tensor.Tensor) for x in kwds.values())
    cache_key = (cache_key, tuple(kwds.items()))

    if cache_key not in arguments_to_functions:
      arguments_to_functions[cache_key] = _defun_internal(
          name, func, args, kwds)
    return arguments_to_functions[cache_key](*args)

  return decorated


def defun(func):
  """Decorator to compile func into graph mode.

  defun converts a function that constructs a TensorFlow graph into a function
  that executes the graph. TensorFlow graphs typically execute faster and with
  a lower memory footprint than executing each of the operations that make up
  the function individually, as the TensorFlow runtime can optimize the graph
  and execute sub-operations in parallel.

  func must be a Python function that constructs a TensorFlow graph,
  typically using functions in the tensorflow module.

  Arguments to func can be either tfe.Tensor objects or Python
  objects. Non-Tensor python objects are treated as constants, and new function
  definitions are created internally based on their values.

  func must return a tf.Tensor (NOT a tfe.Tensor) or a list of tf.Tensor (NOT a
  tfe.Tensor). TODO(apassos) make the wrapped tfe ops return tf.Tensors when in
  graph mode.

  TODO(apassos): deal with captured global state. Deal with control flow.

  Args:
    func: function to be compiled.

  Returns:
    A callable that will execute the compiled function (and return zero
    or more tfe.Tensor objects).
  """
  return named_defun(func, func.__name__)
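A minimal usage sketch for defun (editor's illustration, not part of the commit; `math_ops` and the literal inputs are assumptions):

# Hypothetical usage sketch:
#
#   @defun
#   def matmul_add(a, b):
#     return math_ops.add(math_ops.matmul(a, a), b)   # graph-building ops
#
#   out = matmul_add(tensor.Tensor([[2.]]), tensor.Tensor([[1.]]))
#   # The first call traces func and registers a FunctionDef; later calls
#   # with the same dtypes/shapes reuse it via the _cache_key-based cache.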
45
tensorflow/python/eager/gen_op.bzl
Normal file
@ -0,0 +1,45 @@
"""For eager-mode Python."""

load("//tensorflow:tensorflow.bzl", "clean_dep", "tf_copts")


def tfe_gen_op_wrapper_py(name,
                          out=None,
                          visibility=None,
                          deps=[],
                          generated_target_name=None):
  """Generates an eager-mode Python op wrapper for an op library."""
  # Construct a cc_binary containing the specified ops.
  tool_name = "gen_" + name + "_py_wrappers_cc"
  if not deps:
    deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))]
  native.cc_binary(
      name=tool_name,
      linkopts=["-lm"],
      copts=tf_copts(),
      linkstatic=1,
      deps=([
          clean_dep("//tensorflow/python/eager:python_eager_op_gen_main")
      ] + deps),
      visibility=[clean_dep("//visibility:public")],)

  # Invoke the cc_binary above to generate a Python file.
  if not out:
    out = "gen_" + name + ".py"

  native.genrule(
      name=name + "_pygenrule",
      outs=[out],
      tools=[tool_name],
      cmd=("$(location " + tool_name + ") > $@"))

  # Make a py_library out of the generated Python file.
  if not generated_target_name:
    generated_target_name = name
  native.py_library(
      name=generated_target_name,
      srcs=[out],
      srcs_version="PY2AND3",
      visibility=visibility,
      deps=[
          clean_dep("//tensorflow/python/eager:framework_for_generated_wrappers"),
      ],)
50
tensorflow/python/eager/graph_only_ops.py
Normal file
@ -0,0 +1,50 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Graph-only versions of a few op functions, for internal use only."""

# Must be separate from array_ops to avoid a cyclic dependency.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import ops


def graph_zeros_like(tensor):
  """Graph-only version of tf.zeros_like(), for internal use only."""
  g = ops._get_graph_from_inputs([tensor])  # pylint: disable=protected-access
  with g.as_default(), ops.name_scope(None, "zeros_like", [tensor]) as name:
    tensor = ops.convert_to_tensor(tensor, name="tensor")
    dtype = tensor.dtype.base_dtype.as_datatype_enum
    dtype_value = attr_value_pb2.AttrValue(type=dtype)
    op = g.create_op("ZerosLike", [tensor], [dtype], input_types=[dtype],
                     attrs={"T": dtype_value}, name=name)
  result, = op.outputs
  return result


def graph_placeholder(dtype, shape, name=None):
  """Graph-only version of tf.placeholder(), for internal use only."""
  dtype = dtype.base_dtype.as_datatype_enum
  dtype_value = attr_value_pb2.AttrValue(type=dtype)
  shape = attr_value_pb2.AttrValue(shape=shape.as_proto())
  g = ops.get_default_graph()
  with ops.name_scope(name, "placeholder", []) as name:
    op = g.create_op("Placeholder", [], [dtype], input_types=[],
                     attrs={"dtype": dtype_value, "shape": shape}, name=name)
  result, = op.outputs
  return result
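For orientation (editor's sketch, not part of the commit; assumes `context`, `dtypes`, and `tensor_shape` imports as used elsewhere in this package):

# Hypothetical usage sketch while building a function graph:
#
#   with context.graph_mode(), ops.Graph().as_default():
#     x = graph_placeholder(dtypes.float32, tensor_shape.TensorShape([2, 2]))
#     z = graph_zeros_like(x)   # graph nodes only; nothing executes eagerly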
88
tensorflow/python/eager/memory_trace.py
Normal file
@ -0,0 +1,88 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility to trace per-device memory consumption over the course of execution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

TraceEntry = collections.namedtuple(
    "TraceEntry", ["op_name", "tensor_id", "mem_usage", "device", "size"])
TensorData = collections.namedtuple(
    "TensorData", ["op_name", "tensor_size", "device"])


class MemoryTrace(object):
  """Records a trace of memory usage over operation execution."""

  def __init__(self, n_devices):
    self.trace = []
    self.tensor_to_data = {}
    self.current_device_mem_usage = [0] * n_devices

  def record_tensor(self, op_name, tensor_id, device, size):
    self.current_device_mem_usage[device] += size
    self.tensor_to_data[tensor_id] = TensorData(op_name, size, device)
    self.trace.append(TraceEntry(op_name,
                                 tensor_id,
                                 self.current_device_mem_usage[:],
                                 device,
                                 size))

  def delete_tensor(self, tensor_id):
    if tensor_id not in self.tensor_to_data:
      return
    data = self.tensor_to_data.pop(tensor_id)
    self.current_device_mem_usage[data.device] -= data.tensor_size
    self.trace.append(TraceEntry(data.op_name,
                                 tensor_id,
                                 self.current_device_mem_usage[:],
                                 data.device,
                                 -data.tensor_size))

  def flush_trace(self):
    """Prints the formatted trace recorded so far."""
    longest_op_name = max(len(t.op_name) for t in self.trace)
    longest_op_name = max(longest_op_name, len("op_name"))
    longest_heap_size = max(max(len(str(d)) for d in t.mem_usage)
                            for t in self.trace)
    longest_heap_size = max(longest_heap_size, len("d0"))
    longest_id_len = max(len(str(t.tensor_id)) for t in self.trace)
    longest_id_len = max(longest_id_len, 2)
    first_line = []
    first_line.append("+/-")
    first_line.append("op_name".ljust(longest_op_name))
    first_line.append("id".ljust(longest_id_len))
    for i in range(len(self.current_device_mem_usage)):
      first_line.append(("d" + str(i)).ljust(longest_heap_size))
    first_line.append("size")
    print(" | ".join(first_line))
    for t in self.trace:
      line = []
      if t.size > 0:
        line.append("+ ")
      else:
        line.append("- ")
      line.append(t.op_name.ljust(longest_op_name))
      line.append(str(t.tensor_id).ljust(longest_id_len))
      for d in t.mem_usage:
        line.append(str(d).ljust(longest_heap_size))
      line.append(str(t.size))
      print(" | ".join(line))
    self.trace = []
    print()
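A small usage sketch of the class above (editor's illustration; the op names, ids, and byte sizes are made up):

# Hypothetical usage sketch:
#
#   trace = MemoryTrace(n_devices=2)
#   trace.record_tensor('MatMul', tensor_id=7, device=0, size=4096)
#   trace.record_tensor('Relu', tensor_id=8, device=1, size=1024)
#   trace.delete_tensor(7)   # subtracts 4096 bytes from device 0's total
#   trace.flush_trace()      # prints an aligned +/- table, clears the trace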
763
tensorflow/python/eager/python_eager_op_gen.cc
Normal file
@ -0,0 +1,763 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/eager/python_eager_op_gen.h"

#include <stdio.h>
#include <sstream>
#include <unordered_map>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/framework/python_op_gen_internal.h"

namespace tensorflow {
namespace {

const int kRightMargin = 78;

string AttrVarName(const string& attr_name,
                   std::unordered_map<string, string>* attr_expressions) {
  const string var = strings::StrCat("_attr_", attr_name);
  if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var;
  return var;
}

void AddInferredAttr(const string& attr_name, const string& value_expression,
                     string* result,
                     std::unordered_map<string, string>* attr_expressions) {
  strings::StrAppend(result, "  ", AttrVarName(attr_name, attr_expressions),
                     " = ", value_expression, "\n");
}

string VectorToTuple(const std::vector<string>& l) {
  if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
  string ret = "(";
  for (int i = 0; i < l.size(); ++i) {
    if (i > 0) {
      strings::StrAppend(&ret, ", ");
    }
    strings::StrAppend(&ret, l[i]);
  }
  strings::StrAppend(&ret, ")");
  return ret;
}

void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
               const string& var, string* result) {
  for (int i = 0; i < output_sizes.size(); ++i) {
    if (!output_sizes[i].empty()) {
      strings::StrAppend(result, prefix, var, " = ");
      if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
      if (i + 1 < output_sizes.size()) {
        // Special case i == 0 to avoid "0 +" in the generated code.
        if (i == 0) {
          strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
                             var, "[", output_sizes[i], ":]");
        } else {
          strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
                             output_sizes[i], "]] + ", var, "[", i, " + ",
                             output_sizes[i], ":]");
        }
      } else {
        strings::StrAppend(result, "[", var, "[", i, ":]]");
      }
      strings::StrAppend(result, "\n");
    }
  }
}

string TensorPBString(const TensorProto& pb) {
  // Note: This gets used in the argument list, and so must survive naive
  // word wrapping.
  return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
}

class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
 public:
  GenEagerPythonOp(const OpDef& op_def, const string& function_name)
      : python_op_gen_internal::GenPythonOp(op_def, function_name) {
    op_name_ = function_name_;
    op_name_.Consume("_");
  }
  ~GenEagerPythonOp() override {}

  string Code() override;

 protected:
  void ExpectListArg(const string& arg_name);
  void AddEagerInferredAttrs();
  void AddEagerInputCasts();
  void AddEagerAttrs();
  void AddEagerExecute(const string& num_outputs_expr);

  void AddAttrForArg(const string& attr, int arg_index) {
    gtl::InsertIfNotPresent(&inferred_attrs_, attr,
                            op_def_.input_arg(arg_index).name());
    auto iter = attr_to_args_.find(attr);
    if (iter == attr_to_args_.end()) {
      attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
    } else {
      iter->second.push_back(arg_index);
    }
  }

  // Returns a string expression representing a flattened list of all
  // the inputs given by `*input_indices` (or all inputs if
  // `input_indices` is nullptr). `*output_sizes` can be used to unflatten.
  string FlattenInputs(const std::vector<int>* input_indices,
                       std::vector<string>* output_sizes) const;

  StringPiece op_name_;
  typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
  AttrToArgMap attr_to_args_;
  std::unordered_map<string, string> attr_expressions_;
};

string GetEagerPythonOp(const OpDef& op_def, const string& function_name) {
  return GenEagerPythonOp(op_def, function_name).Code();
}

string GenEagerPythonOp::FlattenInputs(
    const std::vector<int>* input_indices,
    std::vector<string>* output_sizes) const {
  string inputs;
  enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
  const int n = input_indices != nullptr ? input_indices->size()
                                         : op_def_.input_arg_size();
  for (int j = 0; j < n; ++j) {
    const int i = input_indices ? (*input_indices)[j] : j;
    const auto& arg(op_def_.input_arg(i));
    const bool is_list =
        !arg.type_list_attr().empty() || !arg.number_attr().empty();
    if (is_list) {
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, "] + ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + ");
      }
      strings::StrAppend(&inputs, "list(", param_names_[i], ")");
      inputs_state = WAS_LIST_INPUT;
      if (output_sizes != nullptr) {
        if (!arg.number_attr().empty()) {
          output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
        } else {
          output_sizes->emplace_back(
              strings::StrCat("len(", param_names_[i], ")"));
        }
      }
    } else {
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, ", ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + [");
      } else {
        strings::StrAppend(&inputs, "[");
      }
      strings::StrAppend(&inputs, param_names_[i]);
      inputs_state = WAS_SOLO_INPUT;
      if (output_sizes != nullptr) output_sizes->emplace_back();
    }
  }
  if (inputs_state == STARTING) return "[]";
  if (inputs_state == WAS_SOLO_INPUT) {
    strings::StrAppend(&inputs, "]");
  }
  return inputs;
}

string GenEagerPythonOp::Code() {
  // This has all the input args followed by those attrs that don't have
  // defaults.
  std::vector<string> args_no_default;
  // The parameters with defaults (these have to be listed after those
  // without). No input args are included, just attrs.
  std::vector<std::pair<string, string>> args_with_defaults;
  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
    const auto& arg(op_def_.input_arg(i));
    args_no_default.push_back(arg.name());
    if (!arg.type_attr().empty()) {
      AddAttrForArg(arg.type_attr(), i);
    } else if (!arg.type_list_attr().empty()) {
      AddAttrForArg(arg.type_list_attr(), i);
    }
    if (!arg.number_attr().empty()) {
      AddAttrForArg(arg.number_attr(), i);
    }
  }
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    // Do not add inferred attrs to the Python function signature.
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      if (attr.has_default_value()) {
        if (attr.type() == "tensor") {
          args_with_defaults.emplace_back(
              attr.name(),
              strings::StrCat("_execute.make_tensor(",
                              TensorPBString(attr.default_value().tensor()),
                              ", \"", attr.name(), "\")"));
        } else if (attr.type() == "list(tensor)") {
          std::vector<string> pbtxt;
          for (const auto& pb : attr.default_value().list().tensor()) {
            pbtxt.emplace_back(TensorPBString(pb));
          }
          args_with_defaults.emplace_back(
              attr.name(),
              strings::StrCat("[_execute.make_tensor(_pb, \"", attr.name(),
                              "\") for _pb in ", VectorToTuple(pbtxt), "]"));
        } else {
          args_with_defaults.emplace_back(
              attr.name(), python_op_gen_internal::AttrValueToPython(
                               attr.type(), attr.default_value(), "_dtypes."));
        }
      } else {
        args_no_default.push_back(attr.name());
      }
    }
  }

  // Save the list of attr parameters (attrs that won't be inferred);
  // those with defaults go at the end.
  // Get the attrs in the order we want by taking the attrs without defaults
  // from the end of args_no_default, and adding args_with_defaults.
  attrs_.reserve(args_no_default.size() - op_def_.input_arg_size() +
                 args_with_defaults.size());
  attrs_.insert(attrs_.end(),
                args_no_default.begin() + op_def_.input_arg_size(),
                args_no_default.end());
  for (const auto& a : args_with_defaults) {
    attrs_.push_back(a.first);
  }

  param_names_.reserve(args_no_default.size() + args_with_defaults.size());
  string parameters;
  for (const string& name : args_no_default) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    const string param = python_op_gen_internal::AvoidPythonReserved(name);
    strings::StrAppend(&parameters, param);
    param_names_.push_back(param);
  }
  for (const auto& name_default : args_with_defaults) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    const string param =
        python_op_gen_internal::AvoidPythonReserved(name_default.first);
    strings::StrAppend(&parameters, param, "=", name_default.second);
    param_names_.push_back(param);
  }
  if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
  strings::StrAppend(&parameters, "name=None");

  AddDefLine(parameters);
  AddDocStringDescription();
  AddDocStringArgs();
  AddDocStringInputs();
  AddDocStringAttrs();
  strings::StrAppend(
      &result_,
      "    name: A name for the operation (optional, only for graph mode).\n");
  AddOutputGlobals();
  AddDocStringOutputs();
  strings::StrAppend(&result_, "  \"\"\"\n");

  // Function body.

  // Validate list inputs, infer length attrs.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
for (int i = 0; i < op_def_.attr_size(); ++i) {
|
||||
const auto& attr(op_def_.attr(i));
|
||||
if (attr.type() == "int") {
|
||||
auto arg_list = attr_to_args_.find(attr.name());
|
||||
if (arg_list != attr_to_args_.end()) {
|
||||
// Inferred int attrs are the lengths of inputs. Validate those
|
||||
// inputs are lists and have the same length.
|
||||
for (auto iter = arg_list->second.begin();
|
||||
iter != arg_list->second.end(); ++iter) {
|
||||
const string& arg_name = param_names_[*iter];
|
||||
ExpectListArg(arg_name);
|
||||
if (iter == arg_list->second.begin()) {
|
||||
AddInferredAttr(attr.name(), strings::StrCat("len(", arg_name, ")"),
|
||||
&result_, &attr_expressions_);
|
||||
} else {
|
||||
const auto& attr_var = attr_expressions_[attr.name()];
|
||||
strings::StrAppend(&result_, " if len(", arg_name,
|
||||
") != ", attr_var,
|
||||
":\n"
|
||||
" raise ValueError(\n"
|
||||
" \"List argument '",
|
||||
arg_name, "' to '", op_name_,
|
||||
"' Op with length %d \"\n"
|
||||
" \"must match length %d of argument '",
|
||||
inferred_attrs_[attr.name()],
|
||||
"'.\" %\n"
|
||||
" (len(",
|
||||
arg_name, "), ", attr_var, "))\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Values for non-inferred attrs.
|
||||
for (int i = 0; i < attrs_.size(); ++i) {
|
||||
const string& attr_name = attrs_[i];
|
||||
const string& param = param_names_[i + op_def_.input_arg_size()];
|
||||
const auto& attr = *FindAttr(attr_name, op_def_);
|
||||
StringPiece attr_type = attr.type();
|
||||
attr_expressions_[attr_name] = param;
|
||||
const int default_index = i - (attrs_.size() - args_with_defaults.size());
|
||||
if (default_index >= 0) {
|
||||
const string& default_value = args_with_defaults[default_index].second;
|
||||
strings::StrAppend(&result_, " if ", param, " is None:\n");
|
||||
strings::StrAppend(&result_, " ", param, " = ", default_value, "\n");
|
||||
}
|
||||
if (attr_type.starts_with("list(")) {
|
||||
ExpectListArg(param);
|
||||
}
|
||||
|
||||
if (attr_type == "string") {
|
||||
strings::StrAppend(&result_, " ", param, " = _execute.make_str(", param,
|
||||
", \"", param, "\")\n");
|
||||
} else if (attr_type == "list(string)") {
|
||||
strings::StrAppend(&result_, " ", param, " = [_execute.make_str(_s, \"",
|
||||
param, "\") for _s in ", param, "]\n");
|
||||
} else if (attr_type == "int") {
|
||||
strings::StrAppend(&result_, " ", param, " = _execute.make_int(", param,
|
||||
", \"", param, "\")\n");
|
||||
} else if (attr_type == "list(int)") {
|
||||
strings::StrAppend(&result_, " ", param, " = [_execute.make_int(_i, \"",
|
||||
param, "\") for _i in ", param, "]\n");
|
||||
} else if (attr_type == "float") {
|
||||
strings::StrAppend(&result_, " ", param, " = _execute.make_float(",
|
||||
param, ", \"", param, "\")\n");
|
||||
} else if (attr_type == "list(float)") {
|
||||
strings::StrAppend(&result_, " ", param,
|
||||
" = [_execute.make_float(_f, \"", param,
|
||||
"\") for _f in ", param, "]\n");
|
||||
} else if (attr_type == "bool") {
|
||||
strings::StrAppend(&result_, " ", param, " = _execute.make_bool(", param,
|
||||
", \"", param, "\")\n");
|
||||
} else if (attr_type == "list(bool)") {
|
||||
strings::StrAppend(&result_, " ", param, " = [_execute.make_bool(_b, \"",
|
||||
param, "\") for _b in ", param, "]\n");
|
||||
} else if (attr_type == "type") {
|
||||
strings::StrAppend(&result_, " ", param, " = _execute.make_type(", param,
|
||||
", \"", param, "\")\n");
|
||||
} else if (attr_type == "list(type)") {
|
||||
strings::StrAppend(&result_, " ", param, " = [_execute.make_type(_t, \"",
|
||||
param, "\") for _t in ", param, "]\n");
|
||||
} else if (attr_type == "shape") {
|
||||
strings::StrAppend(&result_, " ", param, " = _execute.make_shape(",
|
||||
param, ", \"", param, "\")\n");
|
||||
} else if (attr_type == "list(shape)") {
|
||||
strings::StrAppend(&result_, " ", param,
|
||||
" = [_execute.make_shape(_s, \"", param,
|
||||
"\") for _s in ", param, "]\n");
|
||||
} else if (attr_type == "tensor") {
|
||||
strings::StrAppend(&result_, " ", param, " = _execute.make_tensor(",
|
||||
param, ", \"", param, "\")\n");
|
||||
} else if (attr_type == "list(tensor)") {
|
||||
strings::StrAppend(&result_, " ", param,
|
||||
" = [_execute.make_tensor(_t, \"", param,
|
||||
"\") for _t in ", param, "]\n");
|
||||
} else if (attr_type != "func") {
|
||||
return strings::StrCat("# No definition for ", function_name_,
|
||||
" since we don't support attrs with type\n"
|
||||
"# '",
|
||||
attr_type, "' right now.\n\n");
|
||||
}
|
||||
}
|
||||
|
||||
// Figure out the list of inputs.
|
||||
const string inputs = FlattenInputs(nullptr, nullptr);
|
||||
|
||||
// Handle graph-mode case
|
||||
strings::StrAppend(&result_,
|
||||
" if _context.in_graph_mode():\n"
|
||||
" _, _, _op = _op_def_lib._apply_op_helper(\n");
|
||||
AddBodyNoReturn(" ");
|
||||
if (num_outs_ > 0) {
|
||||
strings::StrAppend(&result_, " _result = _op.outputs[:]\n");
|
||||
// Special case handling for stateful op with single list output
|
||||
// that might be empty.
|
||||
if (num_outs_ == 1 && op_def_.is_stateful() &&
|
||||
(!op_def_.output_arg(0).number_attr().empty() ||
|
||||
!op_def_.output_arg(0).type_list_attr().empty())) {
|
||||
// TODO(josh11b): Can skip this if the number_attr/type_list_attr has
|
||||
// a constraint indicating that this can never be empty.
|
||||
strings::StrAppend(&result_,
|
||||
" if not _result:\n"
|
||||
" return _op\n");
|
||||
}
|
||||
strings::StrAppend(&result_, " _inputs_flat = ", inputs, "\n");
|
||||
|
||||
// Compute graph-mode attrs.
|
||||
if (op_def_.attr_size() > 0) {
|
||||
string attr_values;
|
||||
for (int i = 0; i < op_def_.attr_size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&attr_values, ", ");
|
||||
const auto& attr_name(op_def_.attr(i).name());
|
||||
strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"",
|
||||
attr_name, "\")");
|
||||
}
|
||||
strings::StrAppend(&attr_values, ")");
|
||||
strings::StrAppend(&result_,
|
||||
WordWrap(" _attrs = (", attr_values, kRightMargin),
|
||||
"\n");
|
||||
} else {
|
||||
strings::StrAppend(&result_, " _attrs = None\n");
|
||||
}
|
||||
} else {
|
||||
strings::StrAppend(&result_, " return _op\n");
|
||||
}
|
||||
|
||||
// Handle eager-mode case
|
||||
strings::StrAppend(&result_, " else:\n");
|
||||
|
||||
// Expression representing the number of outputs.
|
||||
int num_fixed_outputs = 0;
|
||||
string num_outputs_expr;
|
||||
// If output i is list output, output_sizes[i] will be set to a
|
||||
// string with the python expression that will evaluate to its
|
||||
// length. output_sizes[i] is empty for non-list outputs.
|
||||
std::vector<string> output_sizes(num_outs_);
|
||||
for (int i = 0; i < num_outs_; ++i) {
|
||||
const auto& arg(op_def_.output_arg(i));
|
||||
if (!arg.number_attr().empty()) {
|
||||
if (!num_outputs_expr.empty()) {
|
||||
strings::StrAppend(&num_outputs_expr, " + ");
|
||||
}
|
||||
output_sizes[i] = attr_expressions_[arg.number_attr()];
|
||||
strings::StrAppend(&num_outputs_expr, output_sizes[i]);
|
||||
} else if (!arg.type_list_attr().empty()) {
|
||||
if (!num_outputs_expr.empty()) {
|
||||
strings::StrAppend(&num_outputs_expr, " + ");
|
||||
}
|
||||
// Have to be careful to use an expression that works in both
|
||||
// graph and eager paths here.
|
||||
const auto iter = inferred_attrs_.find(arg.type_list_attr());
|
||||
if (iter == inferred_attrs_.end()) {
|
||||
output_sizes[i] = strings::StrCat(
|
||||
"len(", attr_expressions_[arg.type_list_attr()], ")");
|
||||
} else {
|
||||
output_sizes[i] = strings::StrCat("len(", iter->second, ")");
|
||||
}
|
||||
strings::StrAppend(&num_outputs_expr, output_sizes[i]);
|
||||
} else {
|
||||
++num_fixed_outputs;
|
||||
}
|
||||
}
|
||||
if (num_fixed_outputs > 0) {
|
||||
if (!num_outputs_expr.empty()) {
|
||||
strings::StrAppend(&num_outputs_expr, " + ");
|
||||
}
|
||||
strings::StrAppend(&num_outputs_expr, num_fixed_outputs);
|
||||
} else if (num_outputs_expr.empty()) {
|
||||
num_outputs_expr = "0";
|
||||
}
|
||||
|
||||
bool eager_allowed = true;
|
||||
for (const auto& arg : op_def_.input_arg()) {
|
||||
if (arg.is_ref()) eager_allowed = false;
|
||||
}
|
||||
for (const auto& arg : op_def_.output_arg()) {
|
||||
if (arg.is_ref()) eager_allowed = false;
|
||||
}
|
||||
|
||||
if (eager_allowed) {
|
||||
AddEagerInferredAttrs();
|
||||
AddEagerInputCasts();
|
||||
strings::StrAppend(&result_, " _inputs_flat = ", inputs, "\n");
|
||||
AddEagerAttrs();
|
||||
AddEagerExecute(num_outputs_expr);
|
||||
} else {
|
||||
strings::StrAppend(&result_,
|
||||
" raise RuntimeError(\n"
|
||||
" \"",
|
||||
op_name_, " op does not support eager execution.\")\n");
|
||||
}
|
||||
|
||||
if (num_outs_ > 0) {
|
||||
strings::StrAppend(&result_, " _result = _execute.record_gradient(\n",
|
||||
" \"", op_def_.name(),
|
||||
"\", _inputs_flat, _attrs, _result, name)\n");
|
||||
if (num_outs_ == 1 && !output_sizes[0].empty()) {
|
||||
// Single list result.
|
||||
} else if (num_outs_ == 1) {
|
||||
// Execute returns a single-element list which we need to destructure.
|
||||
strings::StrAppend(&result_, " _result, = _result\n");
|
||||
} else {
|
||||
// Have multiple outputs, so we will need to reformat the return
|
||||
// value of execute() to be a list with one entry per op output
|
||||
// (that entry will be a list of tensors if that output is of list
|
||||
// type).
|
||||
// For list outputs, convert the right subrange of _result into a list.
|
||||
Unflatten(" ", output_sizes, "_result", &result_);
|
||||
// Convert to a named tuple.
|
||||
strings::StrAppend(&result_, " _result = _", op_def_.name(),
|
||||
"Output._make(_result)\n");
|
||||
}
|
||||
}
|
||||
strings::StrAppend(&result_, " return _result\n\n");
|
||||
return prelude_ + result_;
|
||||
}

void GenEagerPythonOp::ExpectListArg(const string& arg_name) {
  strings::StrAppend(&result_, "  if not isinstance(", arg_name,
                     ", (list, tuple)):\n"
                     "    raise TypeError(\n"
                     "        \"Expected list for '",
                     arg_name,
                     "' argument to \"\n"
                     "        \"'",
                     op_name_, "' Op, not %r.\" % ", arg_name, ")\n");
}

void GenEagerPythonOp::AddEagerInferredAttrs() {
  // Figure out values for inferred attrs, and cast to eager tensors.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    auto arg_list = attr_to_args_.find(attr.name());
    if (arg_list != attr_to_args_.end()) {
      if (attr.type() == "type") {
        std::vector<string> output_sizes;
        const string flattened =
            FlattenInputs(&arg_list->second, &output_sizes);
        string conversion =
            strings::StrCat("_execute.args_to_matching_eager(", flattened);
        if (attr.has_default_value()) {
          strings::StrAppend(
              &conversion, ", ",
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), attr.default_value(), "_dtypes."));
        }
        strings::StrAppend(&conversion, ")");
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        if (output_sizes.size() == 1) {
          // Avoid creating a temporary variable in the case where
          // we can easily assign to the right value directly.
          const string inputs_var = param_names_[arg_list->second.front()];
          if (output_sizes.front().empty()) {
            strings::StrAppend(&result_, "    ", var_name, ", (", inputs_var,
                               ",) = ", conversion, "\n");
          } else {
            strings::StrAppend(&result_, "    ", var_name, ", ", inputs_var,
                               " = ", conversion, "\n");
          }
        } else {
          const string inputs_var = strings::StrCat("_inputs_", attr.name());
          strings::StrAppend(&result_, "    ", var_name, ", ", inputs_var,
                             " = ", conversion, "\n");
          // Convert from a flat list of eager tensors back to the
          // parameter variables.
          Unflatten("    ", output_sizes, inputs_var, &result_);
          std::vector<string> p;
          for (int j : arg_list->second) {
            p.emplace_back(param_names_[j]);
          }
          strings::StrAppend(&result_, "    ", VectorToTuple(p), " = ",
                             inputs_var, "\n");
        }
        strings::StrAppend(&result_, "    ", var_name, " = ", var_name,
                           ".as_datatype_enum\n");
      } else if (attr.type() == "list(type)") {
        // NOTE: We ignore default values for these attrs, since it is
        // unclear how you would use it, and the one use case is
        // parse_single_sequence_example which only needs it for
        // backwards compatibility.
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        string inputs_var;
        string conversion;
        if (arg_list->second.size() > 1) {
          // If you have more than one list(tensor) argument, their types
          // have to match.
          std::vector<string> lists;
          for (auto iter = arg_list->second.begin();
               iter != arg_list->second.end(); ++iter) {
            lists.push_back(param_names_[*iter]);
          }
          inputs_var = VectorToTuple(lists);
          conversion = "_execute.args_to_mixed_eager_tensors";
        } else {
          // For one list(tensor) argument, we just convert every
          // element of the list to an eager tensor.
          inputs_var = param_names_[arg_list->second.front()];
          conversion = "_execute.convert_to_mixed_eager_tensors";
        }
        strings::StrAppend(&result_, "    ", var_name, ", ", inputs_var, " = ",
                           conversion, "(", inputs_var, ")\n");
        strings::StrAppend(&result_, "    ", var_name,
                           " = [_t.as_datatype_enum for _t in ", var_name,
                           "]\n");
      }
    }
  }
}

void GenEagerPythonOp::AddEagerInputCasts() {
  // Cast remaining args to eager tensors
  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
    const auto& arg(op_def_.input_arg(i));
    if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
    const string& param = param_names_[i];
    const string fn = arg.number_attr().empty() ? "" : "n_";
    const string dtype =
        python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
    strings::StrAppend(&result_, "    ", param, " = _tensor.convert_", fn,
                       "to_eager_tensor(", param, ", ", dtype, ")\n");
  }
}

void GenEagerPythonOp::AddEagerAttrs() {
  // Compute eager attrs
  if (op_def_.attr_size() > 0) {
    string attr_values;
    for (int i = 0; i < op_def_.attr_size(); ++i) {
      if (i > 0) strings::StrAppend(&attr_values, ", ");
      const auto& attr_name(op_def_.attr(i).name());
      strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
                         attr_expressions_[attr_name]);
    }
    strings::StrAppend(&attr_values, ")");
    strings::StrAppend(
        &result_, WordWrap("    _attrs = (", attr_values, kRightMargin), "\n");
  } else {
    strings::StrAppend(&result_, "    _attrs = None\n");
  }
}

void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) {
  const string return_prefix = "    _result = _execute.execute(";
  const string return_args =
      strings::StrCat("\"", op_def_.name(), "\", ", num_outputs_expr,
                      ", inputs=_inputs_flat, attrs=_attrs, name=name)");
  strings::StrAppend(&result_,
                     // Wrap the arguments, and indent to the (.
                     WordWrap(return_prefix, return_args, kRightMargin), "\n");
}

string GetEagerPythonOps(const OpList& ops,
                         const std::vector<string>& hidden_ops,
                         bool require_shapes) {
  string result;
  // Header
  // TODO(josh11b): Mention the library for which wrappers are being generated.
  strings::StrAppend(&result, R"("""Python wrappers for TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
"""

import collections as _collections

from tensorflow.python.eager import execute as _execute
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import tensor as _tensor
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import tensor_shape as _tensor_shape

from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
# Needed to trigger the call to _set_call_cpp_shape_fn.
from tensorflow.python.framework import common_shapes as _common_shapes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library

)");

  // We'll make a copy of ops that filters out descriptions.
  OpList cleaned_ops;
  auto out = cleaned_ops.mutable_op();
  out->Reserve(ops.op_size());
  for (const auto& op_def : ops.op()) {
    bool is_hidden = false;
    for (const string& hidden : hidden_ops) {
      if (op_def.name() == hidden) {
        is_hidden = true;
        break;
      }
    }

    string function_name;
    python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
                                                    &function_name);
    if (is_hidden) function_name = strings::StrCat("_", function_name);

    // When users create custom python wrappers, they may link in the
    // default op registry by accident, and because they can't
    // enumerate all 'hidden' symbols, this guard is to prevent
    // instantiating a python reserved word in their wrapper.
    if (python_op_gen_internal::IsPythonReserved(function_name)) {
      continue;
    }

    strings::StrAppend(&result, GetEagerPythonOp(op_def, function_name));

    if (!require_shapes) {
      strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
                         "\")(None)\n\n");
    }

    auto added = out->Add();
    *added = op_def;
    RemoveNonDeprecationDescriptionsFromOpDef(added);
  }

  result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
  op_list = _op_def_pb2.OpList()
  op_list.ParseFromString(op_list_proto_bytes)
  _op_def_registry.register_op_list(op_list)
  op_def_lib = _op_def_library.OpDefLibrary()
  op_def_lib.add_op_list(op_list)
  return op_def_lib
)");

  result.append("# ");
  auto ops_text = ProtoDebugString(cleaned_ops);
  str_util::StripTrailingWhitespace(&ops_text);
  result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true));
  result.append("\n");
  strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n",
                   str_util::CEscape(cleaned_ops.SerializeAsString()).c_str());
  return result;
}

}  // namespace

void PrintEagerPythonOps(const OpList& ops,
                         const std::vector<string>& hidden_ops,
                         bool require_shapes) {
  printf("%s", GetEagerPythonOps(ops, hidden_ops, require_shapes).c_str());
}

string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) {
  string op_list_str(op_list_buf, op_list_len);
  OpList ops;
  ops.ParseFromString(op_list_str);
  return GetEagerPythonOps(ops, {}, false);
}

}  // namespace tensorflow
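
As an illustration of the generator's output (hypothetical; the exact text depends on each op's OpDef, so treat this as a sketch rather than verbatim generator output), the emitted wrapper for a binary op "Mul" with one inferred type attr "T" and a single output would look roughly like:

def mul(x, y, name=None):
  """Returns x * y element-wise.

  Args:
    x: A `Tensor`.
    y: A `Tensor` of the same type as `x`.
    name: A name for the operation (optional, only for graph mode).

  Returns:
    A `Tensor`. Has the same type as `x`.
  """
  if _context.in_graph_mode():
    _, _, _op = _op_def_lib._apply_op_helper(
        "Mul", x=x, y=y, name=name)
    _result = _op.outputs[:]
    _inputs_flat = [x, y]
    _attrs = ("T", _op.get_attr("T"))
  else:
    # Infer the "T" attr from the inputs and cast them to eager tensors.
    _attr_T, _inputs_T = _execute.args_to_matching_eager([x, y])
    (x, y) = _inputs_T
    _attr_T = _attr_T.as_datatype_enum
    _inputs_flat = [x, y]
    _attrs = ("T", _attr_T)
    _result = _execute.execute("Mul", 1, inputs=_inputs_flat, attrs=_attrs,
                               name=name)
  _result = _execute.record_gradient(
      "Mul", _inputs_flat, _attrs, _result, name)
  _result, = _result
  return _result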
39
tensorflow/python/eager/python_eager_op_gen.h
Normal file
@@ -0,0 +1,39 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
#define THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_

#include <string>
#include <vector>
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// hidden_ops should be a list of Op names that should get a leading _
// in the output. Prints the output to stdout.
void PrintEagerPythonOps(const OpList& ops,
                         const std::vector<string>& hidden_ops,
                         bool require_shapes);

// Get the python wrappers for a list of ops in a OpList.
// `op_list_buf` should be a pointer to a buffer containing
// the binary encoded OpList proto, and `op_list_len` should be the
// length of that buffer.
string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len);

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
46
tensorflow/python/eager/python_eager_op_gen_main.cc
Normal file
@@ -0,0 +1,46 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/eager/python_eager_op_gen.h"

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/init_main.h"

namespace tensorflow {
namespace {

void PrintAllPythonOps(const std::vector<string>& hidden_ops) {
  OpList ops;
  OpRegistry::Global()->Export(false, &ops);
  PrintEagerPythonOps(ops, hidden_ops, true /* require_shapes */);
}

}  // namespace
}  // namespace tensorflow

int main(int argc, char* argv[]) {
  tensorflow::port::InitMain(argv[0], &argc, &argv);

  if (argc == 1) {
    tensorflow::PrintAllPythonOps({});
  } else {
    return -1;
  }
  return 0;
}
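
When invoked with no arguments, this binary prints eager-aware wrappers for every op in the global registry to stdout; a build rule can presumably redirect that output into a generated Python module.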
67
tensorflow/python/eager/pywrap_tfe.h
Normal file
@@ -0,0 +1,67 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include <Python.h>

typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>
    TFE_InputTensorHandles;
typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2>
    TFE_OutputTensorHandles;

// Execute a TensorFlow operation.
//
// 'device_name': Name of the device on which to execute the operation, or NULL
//                for automatic selection.
// 'op_name': Name of the TensorFlow op to execute.
// 'inputs': An array of TFE_TensorHandle*'s of size 'num_inputs'. These tensors
//           will be provided as input to the operation.
// 'attrs': A Python tuple alternating names and attr values.
// 'outputs': A pointer to a TFE_OutputTensorHandles in which outputs will be
//            placed. On success, its elements will be filled in and the
//            caller takes ownership of each returned TFE_TensorHandle.
//            'outputs' MUST be sized to be at least as large as the number
//            of tensors produced by the operation and will be resized to
//            the actual number of tensors produced.
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
                    const char* op_name, TFE_InputTensorHandles* inputs,
                    PyObject* attrs, TFE_OutputTensorHandles* outputs,
                    TF_Status* out_status);

// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
//
// The two may share underlying storage so changes to one may reflect in the
// other.
PyObject* TFE_Py_TensorHandleToNumpy(TFE_TensorHandle* h, TF_Status* status);

// Convert a Python numpy.ndarray object to a TFE_TensorHandle.
//
// The two may share underlying storage so changes to one may reflect in the
// other.
TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj);

// Registers e as the Exception class for handling not ok Status. Returns
// Py_None if registration succeeds, else throws a TypeError and returns NULL.
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);

// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using the
// class registered via TFE_Py_RegisterExceptionClass) and returns -1.
int TFE_Py_MayBeRaiseException(TF_Status* status);

#endif  // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
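
To make the attrs convention concrete, here is a hedged sketch of the tuple that SetOpAttrs (in pywrap_tfe_src.cc below) consumes; the attr names and values are hypothetical. Names alternate with values; scalar "type" attrs travel as DataType enum ints (DT_FLOAT is 1), and shape attrs as sequences of dims, with None marking an unknown dimension (sent to TFE_OpSetAttrShape as -1) or, in place of the whole sequence, an unknown rank:

attrs = ("T", 1,              # type attr: DataType enum value
         "N", 2,              # int attr
         "shape", [2, None])  # shape attr with one unknown dimension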
377
tensorflow/python/eager/pywrap_tfe_src.cc
Normal file
@@ -0,0 +1,377 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Must be included first.
#include "tensorflow/python/lib/core/numpy.h"

#include "tensorflow/python/eager/pywrap_tfe.h"

#include "tensorflow/c/c_api.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/python/lib/core/py_func.h"

using tensorflow::string;

namespace {

#define PARSE_VALUE(fn_name, type, check_fn, parse_fn)                       \
  bool fn_name(const string& key, PyObject* py_value, TF_Status* status,     \
               type* value) {                                                \
    if (check_fn(py_value)) {                                                \
      *value = static_cast<type>(parse_fn(py_value));                        \
      return true;                                                           \
    } else {                                                                 \
      TF_SetStatus(status, TF_INVALID_ARGUMENT,                              \
                   tensorflow::strings::StrCat(                              \
                       "Expecting " #type " value for attr ", key, ", got ", \
                       py_value->ob_type->tp_name)                           \
                       .c_str());                                            \
      return false;                                                          \
    }                                                                        \
  }

#if PY_MAJOR_VERSION >= 3
PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong)
PARSE_VALUE(ParseStringValue, const char*, PyUnicode_Check, PyUnicode_AsUTF8)
#else
PARSE_VALUE(ParseStringValue, const char*, PyString_Check, PyString_AsString)
PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
PARSE_VALUE(ParseInt64Value, int64_t, PyInt_Check, PyInt_AsLong)
#endif
PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)

#undef PARSE_VALUE

bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
                    unsigned char* value) {
  *value = PyObject_IsTrue(py_value);
  return true;
}

const char* ParseProtoValue(const string& key, const char* proto_name,
                            PyObject* py_value, size_t* size,
                            TF_Status* status) {
  char* output = nullptr;
  Py_ssize_t py_size;
#if PY_MAJOR_VERSION >= 3
  if (!PyUnicode_Check(py_value) ||
      (output = PyUnicode_AsUTF8AndSize(py_value, &py_size)) == nullptr) {
#else
  if (!PyString_Check(py_value) ||
      (PyString_AsStringAndSize(py_value, &output, &py_size) < 0)) {
#endif
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Expecting a string (serialized ",
                                    proto_name, ") value for attr ", key)
            .c_str());
    return nullptr;
  }
  *size = static_cast<size_t>(py_size);
  return output;
}

bool SetOpAttrList(TFE_Op* op, const char* key, PyObject* py_list,
                   TF_AttrType type, TF_Status* status) {
  if (!PySequence_Check(py_list)) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
                                    ", got ", py_list->ob_type->tp_name)
            .c_str());
    return false;
  }
  const int num_values = PySequence_Size(py_list);

#define PARSE_LIST(c_type, parse_fn)                                \
  std::unique_ptr<c_type[]> values(new c_type[num_values]);         \
  for (int i = 0; i < num_values; ++i) {                            \
    auto py_value = PySequence_ITEM(py_list, i);                    \
    if (!parse_fn(key, py_value, status, &values[i])) return false; \
  }

  if (type == TF_ATTR_STRING) {
    PARSE_LIST(const char*, ParseStringValue);
    TFE_OpSetAttrStringList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_INT) {
    PARSE_LIST(int64_t, ParseInt64Value);
    TFE_OpSetAttrIntList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_FLOAT) {
    PARSE_LIST(float, ParseFloatValue);
    TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_BOOL) {
    PARSE_LIST(unsigned char, ParseBoolValue);
    TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_TYPE) {
    PARSE_LIST(int, ParseIntValue);
    TFE_OpSetAttrTypeList(op, key,
                          reinterpret_cast<const TF_DataType*>(values.get()),
                          num_values);
  } else if (type == TF_ATTR_SHAPE) {
    // Make one pass through the input counting the total number of
    // dims across all the input lists.
    int total_dims = 0;
    for (int i = 0; i < num_values; ++i) {
      auto py_value = PySequence_ITEM(py_list, i);
      if (py_value != Py_None) {
        if (!PySequence_Check(py_value)) {
          TF_SetStatus(
              status, TF_INVALID_ARGUMENT,
              tensorflow::strings::StrCat(
                  "Expecting None or sequence value for element", i,
                  " of attr ", key, ", got ", py_value->ob_type->tp_name)
                  .c_str());
          return false;
        }
        const auto size = PySequence_Size(py_value);
        total_dims += size;
      }
    }
    // Allocate a buffer that can fit all of the dims together.
    std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
    // Copy the input dims into the buffer and set dims to point to
    // the start of each list's dims.
    std::unique_ptr<const int64_t* []> dims(new const int64_t*[num_values]);
    std::unique_ptr<int[]> num_dims(new int[num_values]);
    int64_t* offset = buffer.get();
    for (int i = 0; i < num_values; ++i) {
      auto py_value = PySequence_ITEM(py_list, i);
      if (py_value == Py_None) {
        dims[i] = nullptr;
        num_dims[i] = -1;
      } else {
        const auto size = PySequence_Size(py_value);
        dims[i] = offset;
        num_dims[i] = size;
        for (int j = 0; j < size; ++j) {
          auto inner_py_value = PySequence_ITEM(py_value, j);
          if (inner_py_value == Py_None) {
            *offset = -1;
          } else if (!ParseInt64Value(key, inner_py_value, status, offset)) {
            return false;
          }
          ++offset;
        }
      }
    }
    TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
                           status);
    if (TF_GetCode(status) != TF_OK) return false;
  } else {
    TF_SetStatus(status, TF_UNIMPLEMENTED,
                 tensorflow::strings::StrCat("Attr ", key,
                                             " has unhandled list type ", type)
                     .c_str());
    return false;
  }
#undef PARSE_LIST
  return true;
}

bool SetOpAttrScalar(TFE_Op* op, const char* key, PyObject* py_value,
                     TF_AttrType type, TF_Status* status) {
  if (type == TF_ATTR_STRING) {
    const char* value;
    if (!ParseStringValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrString(op, key, value);
  } else if (type == TF_ATTR_INT) {
    int64_t value;
    if (!ParseInt64Value(key, py_value, status, &value)) return false;
    TFE_OpSetAttrInt(op, key, value);
  } else if (type == TF_ATTR_FLOAT) {
    float value;
    if (!ParseFloatValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrFloat(op, key, value);
  } else if (type == TF_ATTR_BOOL) {
    unsigned char value;
    if (!ParseBoolValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrBool(op, key, value);
  } else if (type == TF_ATTR_TYPE) {
    int value;
    if (!ParseIntValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
  } else if (type == TF_ATTR_SHAPE) {
    if (py_value == Py_None) {
      TFE_OpSetAttrShape(op, key, nullptr, -1, status);
    } else {
      if (!PySequence_Check(py_value)) {
        TF_SetStatus(status, TF_INVALID_ARGUMENT,
                     tensorflow::strings::StrCat(
                         "Expecting None or sequence value for attr", key,
                         ", got ", py_value->ob_type->tp_name)
                         .c_str());
        return false;
      }
      const auto num_dims = PySequence_Size(py_value);
      std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
      for (int i = 0; i < num_dims; ++i) {
        auto inner_py_value = PySequence_ITEM(py_value, i);
        if (inner_py_value == Py_None) {
          dims[i] = -1;
        } else if (!ParseInt64Value(key, inner_py_value, status, &dims[i])) {
          return false;
        }
      }
      TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
    }
    if (TF_GetCode(status) != TF_OK) return false;
  } else {
    TF_SetStatus(
        status, TF_UNIMPLEMENTED,
        tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
            .c_str());
    return false;
  }
  return true;
}

void SetOpAttrs(TFE_Op* op, PyObject* attrs, TF_Status* out_status) {
  if (attrs == Py_None) return;
  if (!PyTuple_Check(attrs)) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Expecting an attrs tuple.");
    return;
  }
  Py_ssize_t len = PyTuple_GET_SIZE(attrs);
  if ((len & 1) != 0) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 "Expecting attrs tuple to have even length.");
    return;
  }
  // Parse attrs
  for (Py_ssize_t i = 0; i < len; i += 2) {
    PyObject* py_key = PyTuple_GET_ITEM(attrs, i);
    PyObject* py_value = PyTuple_GET_ITEM(attrs, i + 1);
#if PY_MAJOR_VERSION >= 3
    const char* key = PyUnicode_AsUTF8(py_key);
#else
    const char* key = PyString_AsString(py_key);
#endif
    unsigned char is_list = 0;
    const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
    if (TF_GetCode(out_status) != TF_OK) return;
    if (is_list != 0) {
      if (!SetOpAttrList(op, key, py_value, type, out_status)) return;
    } else {
      if (!SetOpAttrScalar(op, key, py_value, type, out_status)) return;
    }
  }
}
}  // namespace

void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
                    const char* op_name, TFE_InputTensorHandles* inputs,
                    PyObject* attrs, TFE_OutputTensorHandles* outputs,
                    TF_Status* out_status) {
  TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
  if (TF_GetCode(out_status) != TF_OK) return;
  if (device_name != nullptr) {
    TFE_OpSetDevice(op, ctx, device_name, out_status);
  }
  if (TF_GetCode(out_status) == TF_OK) {
    for (int i = 0; i < inputs->size() && TF_GetCode(out_status) == TF_OK;
         ++i) {
      TFE_OpAddInput(op, inputs->at(i), out_status);
    }
  }
  if (TF_GetCode(out_status) == TF_OK) {
    SetOpAttrs(op, attrs, out_status);
  }
  if (TF_GetCode(out_status) == TF_OK) {
    int num_outputs = outputs->size();
    TFE_Execute(op, outputs->data(), &num_outputs, out_status);
    outputs->resize(num_outputs);
  }
  if (TF_GetCode(out_status) != TF_OK) {
    TF_SetStatus(out_status, TF_GetCode(out_status),
                 tensorflow::strings::StrCat(TF_Message(out_status),
                                             " [Op:", op_name, "]")
                     .c_str());
  }
  TFE_DeleteOp(op);
}

PyObject* TFE_Py_TensorHandleToNumpy(TFE_TensorHandle* h, TF_Status* status) {
  const tensorflow::Tensor* t =
      TFE_TensorHandleUnderlyingTensorInHostMemory(h, status);
  if (TF_GetCode(status) != TF_OK) {
    Py_RETURN_NONE;
  }
  PyObject* ret = nullptr;
  auto cppstatus = tensorflow::ConvertTensorToNdarray(*t, &ret);
  if (!cppstatus.ok()) {
    TF_SetStatus(status, TF_Code(cppstatus.code()),
                 cppstatus.error_message().c_str());
  }
  if (ret != nullptr) return ret;
  Py_RETURN_NONE;
}

namespace {
// Python subclass of Exception that is created on not ok Status.
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
}  // namespace

TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj) {
  tensorflow::Tensor t;
  auto cppstatus = tensorflow::ConvertNdarrayToTensor(obj, &t);
  if (cppstatus.ok()) {
    return TFE_NewTensorHandle(t);
  } else {
    tensorflow::mutex_lock l(exception_class_mutex);
    auto msg = tensorflow::strings::StrCat(
        "failed to convert numpy ndarray to a Tensor (",
        cppstatus.error_message(), ")");
    if (exception_class != nullptr) {
      PyErr_SetObject(exception_class,
                      Py_BuildValue("si", msg.c_str(), TF_INVALID_ARGUMENT));
    } else {
      PyErr_SetString(PyExc_RuntimeError, msg.c_str());
    }
  }
  return nullptr;
}

PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
  tensorflow::mutex_lock l(exception_class_mutex);
  if (exception_class != nullptr) {
    Py_DECREF(exception_class);
  }
  if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
    exception_class = nullptr;
    PyErr_SetString(PyExc_TypeError,
                    "TFE_Py_RegisterExceptionClass: "
                    "Registered class should be subclass of Exception.");
    return nullptr;
  } else {
    Py_INCREF(e);
    exception_class = e;
    Py_RETURN_NONE;
  }
}

int TFE_Py_MayBeRaiseException(TF_Status* status) {
  if (TF_GetCode(status) == TF_OK) return 0;
  tensorflow::mutex_lock l(exception_class_mutex);
  if (exception_class != nullptr) {
    PyErr_SetObject(exception_class, Py_BuildValue("si", TF_Message(status),
                                                   TF_GetCode(status)));
  } else {
    PyErr_SetString(PyExc_RuntimeError, TF_Message(status));
  }
  return -1;
}
240
tensorflow/python/eager/tape.py
Normal file
@@ -0,0 +1,240 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradient tape utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

from autograd import container_types
from autograd import core as ag_core

from tensorflow.python.framework import dtypes
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib


def tensor_id(t):
  """Returns a unique identifier for this Tensor."""
  t = ag_core.getval(t)
  return t._id  # pylint: disable=protected-access


class ImplicitTape(object):
  """Global object which can watch tensors and wrap them with autograd."""

  def __init__(self):
    self.tensors = {}
    self.gradients = []

  def __eq__(self, other):
    return self is other

  def __hash__(self):
    return id(self)


@ag_core.primitive
def _watch_with_tape_internal(_, tensor):
  """Primitive to wrap a tensor around an ImplicitTape progenitor."""
  return tensor


def _watch_with_tape(tape, tensor):
  """Wraps a watched Tensor and keeps track of it in the implicit tape."""
  w = _watch_with_tape_internal(tape, tensor)
  if ag_core.isnode(tape):
    tape.value.tensors[tensor_id(tensor)] = w
  return w


def _watch_with_tape_vjp(g, ans, vs, gvs, tape, tensor):
  """Gradient for _watch_with_tape_internal."""
  del ans, gvs, tape

  def mut_add(implicit_tape):
    t = ag_core.getval(tensor)
    implicit_tape.gradients.append((t, g))
    return implicit_tape

  return ag_core.SparseObject(vs, mut_add)

_watch_with_tape_internal.defvjp(_watch_with_tape_vjp, argnum=0)
_watch_with_tape_internal.defvjp(
    lambda g, ans, vs, gvs, tape, tensor: g,
    argnum=1)


class ImplicitTapeVSpace(ag_core.VSpace):
  """VSpace needed to have ImplicitTape be a valid progenitor."""

  def zeros(self):
    return ImplicitTape()


class ImplicitTapeNode(ag_core.Node):
  """Node to wrap ImplicitTape in."""

  def __eq__(self, other):
    return self is other

  def __hash__(self):
    return id(self)

ag_core.register_node(ImplicitTapeNode, ImplicitTape)
ag_core.register_vspace(ImplicitTapeVSpace, ImplicitTape)


# TODO(apassos) try to not do this.
class NoneVSpace(ag_core.VSpace):
  """VSpace for python None."""

  def __init__(self, _):
    self.size = 0


ag_core.register_vspace(NoneVSpace, type(None))


class _TapeStack(threading.local):

  def __init__(self):
    super(_TapeStack, self).__init__()
    self._stack = []

  @property
  def stack(self):
    return self._stack

  @tf_contextlib.contextmanager
  def replace_stack(self, new_stack):
    old = self._stack
    self._stack = new_stack
    yield
    self._stack = old


# The global tape stack.
_tape_stack = _TapeStack()


def push_new_tape():
  """Pushes a new tape onto the tape stack."""
  progenitor = ag_core.new_progenitor(ImplicitTape())
  _tape_stack.stack.append(progenitor)
  ag_core.active_progenitors.add(progenitor)


def watch(tensor):
  """Marks this tensor to be watched by all tapes in the stack.

  Args:
    tensor: tensor to be watched.

  Returns:
    The tensor, potentially wrapped by all tapes in the stack.
  """
  for t in _tape_stack.stack:
    tensor = _watch_with_tape(t, tensor)
  return tensor


def pop_tape():
  """Pops the top tape in the stack, if any."""
  if _tape_stack.stack:
    return _tape_stack.stack.pop()
  return None


def any_tape_has(tensor):
  for t in _tape_stack.stack:
    if tensor_id(tensor) in t.value.tensors:
      return True
  return False


def should_record(tensors):
  """Returns true if any tape in the stack watches any of these tensors."""
  return any(ag_core.isnode(x) for x in tensors)


class _EagerSequenceNode(container_types.SequenceNode):
  """Eager version of SequenceNode, to live in EagerSequenceVSpace."""
  pass


class _EagerSequenceVSpace(container_types.SequenceVSpace):
  """Changes equality on SequenceVSpace to conform to tfe requirements."""

  def __init__(self, value):
    self.shape = [ag_core.vspace(x) for x in value]
    self.size = sum(s.size for s in self.shape)
    self.sequence_type = type(value)

  def __eq__(self, other):
    if type(self) != type(other):  # pylint: disable=unidiomatic-typecheck
      return False
    if len(self.shape) != len(other.shape):
      # TODO(apassos) function gradients sometimes return gradients for side
      # inputs which breaks this assertion. Understand how to fix it.
      return True
    for ss, os in zip(self.shape, other.shape):
      if ss != os:
        if isinstance(ss, NoneVSpace) or isinstance(os, NoneVSpace):
          continue
        if ss.dtype == dtypes.resource or os.dtype == dtypes.resource:
          continue
        return False
    return True


class _EagerList(list):
  """Type used to bypass SequenceVSpace."""

  def __init__(self, value):
    super(_EagerList, self).__init__(value)
    for v in value:
      assert not ag_core.isnode(v)

ag_core.register_vspace(_EagerSequenceVSpace, _EagerList)
ag_core.register_node(_EagerSequenceNode, _EagerList)


@ag_core.primitive
def _record_operation(output_tensors, input_tensors, side_outputs,
                      backward_function):
  del input_tensors, side_outputs, backward_function
  return _EagerList(output_tensors)


def record_operation(o, i, s, b):
  """Primitive to trigger autograd tracing on outputs from inputs."""
  inputs = container_types.make_sequence(_EagerList, *i)
  return _record_operation(o, inputs, s, b)


def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors,
                          side_outputs, backward_function):
  """Gradient for _record_operation."""
  del ans, vs, gvs, output_tensors, input_tensors
  backward_args = tuple(g) + tuple(side_outputs)
  if ag_core.isnode(backward_args):
    backward_args = list(backward_args)
  tensors = nest.flatten(backward_function(*backward_args))
  return _EagerList([ag_core.getval(t) for t in tensors])

_record_operation.defvjp(_record_operation_vjp, argnum=1)
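
A minimal sketch of driving this tape stack (hypothetical usage; the real callers live in the rest of the eager gradient machinery in this commit):

from tensorflow.python.eager import tape

tape.push_new_tape()           # start recording onto a fresh ImplicitTape
t = tape.watch(t)              # wrap a tensor so every stacked tape tracks it
if tape.should_record([t]):    # True once any input is an autograd node
  pass                         # ... execute ops, calling record_operation(...)
progenitor = tape.pop_tape()   # stop recording; returns the progenitor or None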
411
tensorflow/python/eager/tensor.py
Normal file
@@ -0,0 +1,411 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Experimental API for TensorFlow's "Eager" mode of execution."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from autograd import core as ag_core
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import core
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
|
||||
|
||||
class Tensor(object):
|
||||
"""A TensorFlow Tensor."""
|
||||
|
||||
def __init__(self, value, dtype=None):
|
||||
"""Creates a Tensor object from a Python object or numpy array.
|
||||
|
||||
May share storage with the numpy array, in which case changes to the numpy
|
||||
object will reflect
|
||||
in the Tensor.
|
||||
|
||||
Arguments:
|
||||
value: A numpy.array or a Python object to create a Tensor for.
|
||||
dtype: TensorFlow dtype for the returned Tensor. If None, one will be
|
||||
automatically selected.
|
||||
"""
|
||||
# TODO(ashankar): Evaluate if we can and perhaps share code with
|
||||
# tf.constant defined in
|
||||
# https://www.tensorflow.org/code/tensorflow/python/framework/constant_op.py
|
||||
self._id = core.uid()
|
||||
if not isinstance(value, np.ndarray):
|
||||
npt = None if dtype is None else dtype.as_numpy_dtype
|
||||
value = np.array(value, dtype=npt)
|
||||
if dtype is None:
|
||||
value = _maybe_modify_numpy_dtype_determination(value)
|
||||
elif dtype is not None:
|
||||
npt = dtype.as_numpy_dtype
|
||||
if npt != value.dtype:
|
||||
value = value.astype(npt)
|
||||
try:
|
||||
value = np.asarray(value, order="C")
|
||||
self._handle = pywrap_tensorflow.TFE_Py_NumpyToTensorHandle(value)
|
||||
except core._NotOkStatusException as e: # pylint: disable=protected-access
|
||||
raise core._status_to_exception(e.code, e.message) # pylint: disable=protected-access
|
||||
|
||||
    # Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
    # memory. This change approximates the same behavior for eager execution -
    # keeping int32 tensors in host memory.
    #
    # We do so to preclude the need for callers into such kernels from having
    # to explicitly place the int32 tensors in host memory. For example, prior
    # to this change one needed:
    #
    # with tfe.device('/gpu:0'):
    #   ...  # code here
    #   with tfe.device('/cpu:0'):
    #     shape = tfe.Tensor(...)
    #   y = tfe.ops.random_uniform(.., shape)
    #
    # Without the CPU device block tfe.ops.random_uniform would fail, since
    # the kernel expects the shape in host memory.
    #
    # After this change, the code simplifies to:
    #
    # with tfe.device('/gpu:0'):
    #   y = tfe.ops.random_uniform(.., tfe.Tensor(...))
    #
    # The approximation is not exact, since if there are GPU kernels which do
    # not require host memory for int32 tensors, there will be a discrepancy
    # between eager execution and TensorFlow graphs. However, as of July 2017,
    # there were no known GPU kernels that kept int32 tensors in device memory.
    if _in_gpu_device() and value.dtype != np.int32:
      ctx = context.get_default_context()
      # pylint: disable=protected-access
      device_name = ctx.device_name
      with errors.raise_exception_on_not_ok_status() as status:
        self._handle = pywrap_tensorflow.TFE_TensorHandleCopyToDevice(
            self._handle, ctx._handle, device_name, status)
      # pylint: enable=protected-access

    self._dtype = dtypes.as_dtype(
        pywrap_tensorflow.TFE_TensorHandleDataType(self._handle))

    # This mirrors tensorflow.python.framework.ops.Tensor._handle_data, which
    # will be None for tensors of a type other than DT_RESOURCE. For
    # DT_RESOURCE tensors, this will contain a serialized HandleData proto
    # with shape inference metadata about shapes and dtypes of resources
    # accessible from this handle.
    self._handle_data = None
    if core.active_trace() is not None:
      core.active_trace().record_tensor("MANUAL",
                                        tape.tensor_id(self),
                                        self._device_name(),
                                        self.shape.num_elements())

  def __del__(self):
    if (pywrap_tensorflow is not None
        and pywrap_tensorflow.TFE_DeleteTensorHandle is not None):
      pywrap_tensorflow.TFE_DeleteTensorHandle(self._handle)
    if core.active_trace() is not None:
      core.active_trace().delete_tensor(tape.tensor_id(self))

  def __str__(self):
    if self.dtype.is_numpy_compatible and self.shape.num_elements() > 0:
      n = self.numpy().reshape(-1)
      if self.shape.num_elements() > 5:
        return "tfe.Tensor(%s..., shape=%s, dtype=%s)" % (n[:5], self.shape,
                                                          self.dtype.name)
      else:
        return "tfe.Tensor(%s, dtype=%s)" % (
            np.array_str(self.numpy()).replace("\n", ""), self.dtype.name)
    return "tfe.Tensor(<unprintable>, shape=%s dtype=%s)" % (self.shape,
                                                             self.dtype.name)

  def __repr__(self):
    if self.dtype.is_numpy_compatible and self.shape.num_elements() > 0:
      n = self.numpy()
      # TODO(apassos): understand why self.numpy() sometimes returns
      # something that is not an array.
      if isinstance(n, np.ndarray):
        n = n.reshape(-1)
      if self.shape.num_elements() > 5:
        return "<tfe.Tensor at %s shape=%s dtype=%s>(%s..., min=%s, max=%s)" % (
            self._id, self.shape, self.dtype.name, n[:5], np.min(n), np.max(n))
      else:
        return "<tfe.Tensor at %s shape=%s dtype=%s>(%s)" % (self._id,
                                                             self.shape,
                                                             self.dtype.name, n)
    return "<tfe.Tensor at %s shape=%s dtype=%s>" % (self._id, self.shape,
                                                     self.dtype.name)

  @staticmethod
  def _override_operator(name, func):
    setattr(Tensor, name, func)

  def numpy(self):
    """Returns a numpy array with the same contents as the Tensor.

    The contents of the Tensor must be backed by host memory. The
    as_cpu_tensor() method can be used to ensure that this is true.

    TODO(ashankar,agarwal): Perhaps this should NOT reference the underlying
    buffer but instead always explicitly copy? Note that currently it may or
    may not copy, depending on whether the numpy data is properly aligned.

    Returns:
      A numpy array that may share memory with the Tensor object. Any changes
      to one may be reflected in the other.
    """
    # TODO(ashankar): This "with status" business seems expensive. Profile/avoid?
    cpu = self.as_cpu_tensor()
    with errors.raise_exception_on_not_ok_status() as status:
      return pywrap_tensorflow.TFE_Py_TensorHandleToNumpy(cpu._handle, status)  # pylint: disable=protected-access

  def _copy(self, ctx, device_name):
    """Copies this tensor to the given device."""
    # pylint: disable=protected-access
    # Creates a new tensor on the destination device.
    with errors.raise_exception_on_not_ok_status() as status:
      h = pywrap_tensorflow.TFE_TensorHandleCopyToDevice(
          self._handle, ctx._handle, device_name, status)
    new_tensor = _tensor_from_handle(h)
    if core.active_trace() is not None:
      core.active_trace().record_tensor("COPY",
                                        tape.tensor_id(new_tensor),
                                        new_tensor._device_name(),
                                        new_tensor.shape.num_elements())
    return new_tensor
    # pylint: enable=protected-access

  def _device_name(self):
    return pywrap_tensorflow.TFE_TensorHandleDeviceName(self._handle)

  @property
  def dtype(self):
    return self._dtype

  @property
  def shape(self):
    """The shape of this Tensor, as a TensorShape object."""
    n = pywrap_tensorflow.TFE_TensorHandleNumDims(self._handle)
    # As of May 2017, TFE_TensorHandle objects were always backed by concrete
    # tensors (which have a valid, known shape). There were vague plans to
    # change this so that the Tensor class can also represent Tensors that
    # have not yet been computed.
    # If that happens, handle that (e.g., if n < 0: return tensor_shape(None))
    # and also handle -1s returned by TFE_TensorHandleDim.
    assert n >= 0, "See comment in source code"
    return tensor_shape.TensorShape(
        [pywrap_tensorflow.TFE_TensorHandleDim(self._handle, x)
         for x in range(n)])

  def get_shape(self):
    """Alias of Tensor.shape."""
    return self.shape

  def _shape_tuple(self):
    """The shape of this Tensor, as a tuple.

    This is more performant than tuple(shape().as_list()) as it avoids
    creating two lists and one extra object. Marked private for now, as from
    an API perspective it would be better to have a single performant way of
    getting a shape rather than exposing shape() and shape_tuple()
    (and, heaven forbid, shape_list() etc. as well!). Punting on that for now,
    but ideally one would work things out and remove the need for this method.
    """
    n = pywrap_tensorflow.TFE_TensorHandleNumDims(self._handle)
    # As of May 2017, TFE_TensorHandle objects were always backed by concrete
    # tensors (which have a valid, known shape). There were vague plans to
    # change this so that the Tensor class can also represent Tensors that
    # have not yet been computed.
    # If that happens, handle that (e.g., if n < 0: return tensor_shape(None))
    # and also handle -1s returned by TFE_TensorHandleDim.
    assert n >= 0, "See comment in source code"
    return tuple(
        pywrap_tensorflow.TFE_TensorHandleDim(self._handle, x)
        for x in range(n))

  def as_cpu_tensor(self):
    """A copy of this Tensor with contents backed by host memory."""
    return self._copy(context.get_default_context(), "CPU:0")

  def as_gpu_tensor(self, gpu_index=0):
    """A copy of this Tensor with contents backed by memory on the GPU.

    Arguments:
      gpu_index: Identifies which GPU to place the contents of the returned
        Tensor on.

    Returns:
      A GPU-memory backed Tensor object initialized with the same contents
      as this Tensor.
    """
    return self._copy(context.get_default_context(), "GPU:" + str(gpu_index))

  def __bool__(self):
    if self._shape_tuple() != ():  # pylint: disable=g-explicit-bool-comparison
      raise ValueError(
          "Non-scalar tensor %s cannot be converted to boolean." % repr(self))
    if self.dtype != dtypes.bool:
      raise ValueError(
          "Non-boolean tensor %s cannot be converted to boolean." % repr(self))
    return bool(self.as_cpu_tensor().numpy())

  def __nonzero__(self):
    return self.__bool__()

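A minimal usage sketch of the class above (illustrative only, assuming this module is importable as tensorflow.python.eager.tensor and an eager context is available):

# Minimal sketch, not a test: exercises construction, dtype/shape accessors,
# numpy() and truth-value semantics of the Tensor class above.
import numpy as np

from tensorflow.python.eager.tensor import Tensor

t = Tensor([[1.0, 2.0], [3.0, 4.0]])
print(t.dtype)             # float32: numpy's float64 preference is overridden.
print(t.shape)             # A rank-2 TensorShape.
print(t._shape_tuple())    # (2, 2), the cheaper private accessor.
n = t.numpy()              # May share memory with the tensor's buffer.
print(bool(Tensor(True)))  # Only scalar bool tensors support truth testing.
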
class IndexedSlices(object):
  """A sparse representation of a set of tensor slices at given indices.

  This class is a simple wrapper for a pair of `Tensor` objects:

  * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
  * `indices`: A 1-D integer `Tensor` with shape `[D0]`.

  An `IndexedSlices` is typically used to represent a subset of a larger
  tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`.
  The values in `indices` are the indices in the first dimension of
  the slices that have been extracted from the larger tensor.

  The dense tensor `dense` represented by an `IndexedSlices` `slices`
  satisfies

  ```python
  dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
  ```

  The `IndexedSlices` class is used principally in the definition of
  gradients for operations that have sparse gradients
  (e.g. @{tf.gather}).
  """

  def __init__(self, values, indices, dense_shape):
    """Creates an `IndexedSlices`."""
    self._values = values
    self._indices = indices
    assert indices.shape[0] == values.shape[0]
    self._dense_shape = dense_shape

  @property
  def values(self):
    """A `Tensor` containing the values of the slices."""
    return self._values

  @property
  def indices(self):
    """A 1-D `Tensor` containing the indices of the slices."""
    return self._indices

  @property
  def dense_shape(self):
    """A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
    return self._dense_shape

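To make the docstring's recovery rule concrete, a small numpy-only sketch of densifying an IndexedSlices (an illustrative helper, not part of this module):

# Illustrative helper (not part of this module): materializes the dense
# tensor described by the IndexedSlices docstring's recovery rule.
import numpy as np

def _densify(values, indices, dense_shape):
  dense = np.zeros(dense_shape, dtype=values.dtype)
  for i, idx in enumerate(indices):
    dense[idx, ...] = values[i, ...]
  return dense

# Rows 0 and 3 of a [4, 2] dense tensor:
print(_densify(np.array([[1., 2.], [3., 4.]]), np.array([0, 3]), (4, 2)))
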
class _Op(object):
  """Fake op for LazyZero to make its python API tf.Tensor-like."""

  def __init__(self):
    self.type = "Zeros"


class LazyZero(object):
  """Lazily-instantiated zero-valued Tensor used as autograd accumulator."""

  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype
    self.op = _Op()

  def __add__(self, other):
    return other

  def __radd__(self, other):
    return other

  def numpy(self):
    return np.zeros(self.shape, self.dtype)

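The accumulator contract above is simply that adding a LazyZero is the identity, so no zero tensor is materialized until numpy() is called. A short sketch using the class above:

# Sketch of the accumulator contract: LazyZero + x returns x unchanged, and
# the zero array is only materialized on demand.
import numpy as np

z = LazyZero(shape=(2, 3), dtype=np.float32)
g = np.ones((2, 3), np.float32)
assert (z + g) is g  # __add__ returns the other operand untouched.
print(z.numpy())     # Zeros are materialized only here.
print(z.op.type)     # "Zeros", mimicking a tf.Tensor's producing op.
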
def convert_to_eager_tensor(t, dtype=None):
  if isinstance(ag_core.getval(t), Tensor):
    if dtype is not None and t.dtype != dtype:
      raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
    return t
  return Tensor(t, dtype=dtype)


def convert_n_to_eager_tensor(values, dtype):
  return [convert_to_eager_tensor(t, dtype) for t in values]

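Behavior sketch for the converters above (illustrative, assuming an eager context): existing Tensors pass through, optionally checked against the requested dtype, while other Python values are wrapped.

# Sketch of the converters' behavior.
from tensorflow.python.framework import dtypes

t = convert_to_eager_tensor([1, 2, 3])  # Wraps the list into a Tensor (int32).
assert convert_to_eager_tensor(t) is t  # Already a Tensor: passes through.
try:
  convert_to_eager_tensor(t, dtype=dtypes.float64)  # Mismatched dtype raises.
except TypeError as e:
  print(e)
ts = convert_n_to_eager_tensor([1.0, 2.0], dtype=None)  # A list of Tensors.
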
def _tensor_from_handle(handle):
  """'Private' constructor for the Tensor object.

  The existence of a 'handle' is an implementation detail that should be
  hidden from users of this module. Functions within this module do, however,
  need to create a Tensor object from a handle.

  One option would be to have an __init__(self, handle) method on the
  Tensor class, but that would make the existence and use of a handle
  'public'.

  Instead, this function avoids exposing a Tensor.__init__ that understands
  handles and yet allows functions within this module to create Tensor
  objects from a handle.

  Arguments:
    handle: A valid TFE_TensorHandle object.

  Returns:
    A Tensor object.
  """
  # pylint: disable=protected-access
  t = Tensor.__new__(Tensor)
  t._id = core.uid()
  t._handle = handle
  t._dtype = dtypes.as_dtype(pywrap_tensorflow.TFE_TensorHandleDataType(handle))
  t._handle_data = None
  return t
  # pylint: enable=protected-access

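The same pattern in isolation, for readers unfamiliar with it (a standalone toy, not TensorFlow code): __new__ allocates the object without running __init__, letting module-private code adopt an existing value directly.

# Standalone toy illustrating the _tensor_from_handle pattern above.
class Box(object):

  def __init__(self, value):
    self._data = [value]  # Public path: normalizes the input.


def _box_from_list(existing):
  b = Box.__new__(Box)  # Skips __init__.
  b._data = existing    # Adopts the list as-is, no normalization.
  return b
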

# TODO(ashankar): use actual device type.
def _in_gpu_device():
  return context.get_default_context()._device_index > 0  # pylint: disable=protected-access

def _maybe_modify_numpy_dtype_determination(np_array):
  """Tweak numpy dtype determination.

  numpy prefers int64 and float64; we prefer int32 and float32.
  (int32 is often used as the "shape" input to various operations,
  many of which only support int32 shapes.)
  This preference is copied from tensor_util.make_tensor_proto
  (https://goto.google.com/numpy_prefs_156503903)

  Args:
    np_array: A numpy ndarray.

  Returns:
    A numpy ndarray whose dtype may have been modified.
  """
  if np_array.dtype == np.float64:
    return np_array.astype(np.float32)
  if np_array.dtype == np.int64:
    # Downcast iff there is no precision loss.
    downcasted = np_array.astype(np.int32)
    if np.array_equal(downcasted, np_array):
      return downcasted
  return np_array
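Concretely, the preference rules above behave as follows (a small sketch; assumes numpy defaults Python ints to int64, as on most 64-bit Linux builds):

# Sketch of the dtype tweaks above.
import numpy as np

print(_maybe_modify_numpy_dtype_determination(np.array([1.5])).dtype)
# float32: float64 is always downcast.
print(_maybe_modify_numpy_dtype_determination(np.array([1, 2])).dtype)
# int32: the int64 values fit losslessly.
print(_maybe_modify_numpy_dtype_determination(np.array([2 ** 40])).dtype)
# int64: downcasting would lose precision, so the array is returned as-is.
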
28
tensorflow/python/eager/test.py
Normal file
@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for testing tfe code."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context as _context
from tensorflow.python.platform import test as _test
from tensorflow.python.platform.test import *  # pylint: disable=wildcard-import


def main(argv=None):
  _context.enable_eager_execution()
  _test.main(argv)
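A minimal sketch of a test module built on this wrapper (hypothetical test file, shown for illustration): calling main from here enables eager execution before the tests run.

# Hypothetical example test file using the wrapper above.
from tensorflow.python.eager import tensor
from tensorflow.python.eager import test


class TensorTest(test.TestCase):

  def testScalarShape(self):
    self.assertEqual((), tensor.Tensor(3.0)._shape_tuple())


if __name__ == "__main__":
  test.main()  # Enables eager execution, then runs the tests.
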
153
tensorflow/python/pywrap_tfe.i
Normal file
@ -0,0 +1,153 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

%ignore "";

%rename("%s") TFE_Py_RegisterExceptionClass;
%rename("%s") TFE_Py_NumpyToTensorHandle;
%rename("%s") TFE_NewContext;
%rename("%s") TFE_ContextListDevices;
%rename("%s") TFE_TensorHandleDataType;
%rename("%s") TFE_TensorHandleNumDims;
%rename("%s") TFE_DeleteTensorHandle;
%rename("%s") TFE_Py_Execute;
%rename("%s") TFE_ContextAddFunctionDef;
%rename("%s") TFE_TensorHandleDim;
%rename("%s") TFE_TensorHandleCopyToDevice;
%rename("%s") TFE_NewOp;
%rename("%s") TFE_Py_TensorHandleToNumpy;
%rename("%s") TFE_OpGetAttrType;


%{
#include "tensorflow/python/eager/pywrap_tfe.h"
%}

%typemap(out) TF_DataType {
  $result = PyInt_FromLong($1);
}

%typemap(out) int64_t {
  $result = PyInt_FromLong($1);
}

%typemap(out) TF_AttrType {
  $result = PyInt_FromLong($1);
}

%typemap(in, numinputs=0) unsigned char* is_list (unsigned char tmp) {
  $1 = &tmp;
}

%typemap(argout) unsigned char* is_list {
  if (*$1 == 1) {
    PyObject* list = PyList_New(1);
    PyList_SetItem(list, 0, $result);
    $result = list;
  }
}

%include "tensorflow/c/eager/c_api.h"

%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp) {
  $1 = &temp;
  if ($input != Py_None) {
    if (!PyList_Check($input)) {
      SWIG_exception_fail(SWIG_TypeError,
                          "must provide a list of Tensors as inputs");
    }
    Py_ssize_t len = PyList_Size($input);
    $1->resize(len);
    for (Py_ssize_t i = 0; i < len; ++i) {
      PyObject* elem = PyList_GetItem($input, i);
      if (!elem) {
        SWIG_fail;
      }
      void* thp = nullptr;
      int res = SWIG_ConvertPtr(elem, &thp,
                                $descriptor(TFE_TensorHandle*), 0 | 0);
      if (!SWIG_IsOK(res)) {
        SWIG_exception_fail(SWIG_ArgError(res),
                            "provided list of inputs contains objects other "
                            "than 'TFE_TensorHandle*'");
      }
      (*$1)[i] = reinterpret_cast<TFE_TensorHandle*>(thp);
    }
  }
}

// Temporary vector consumed by the argout typemap below.
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp) {
  if (!PyInt_Check($input)) {
    SWIG_exception_fail(SWIG_TypeError,
                        "expected an integer value (size of the number of "
                        "outputs of the operation)");
  }
  $1 = &temp;
  $1->resize(PyInt_AsLong($input), nullptr);
}

// Create a new Status object.
%typemap(in, numinputs=0) TF_Status *out_status {
  $1 = TF_NewStatus();
}

%typemap(freearg) (TF_Status* out_status) {
  TF_DeleteStatus($1);
}

%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) {
  if (TFE_Py_MayBeRaiseException($2)) {
    SWIG_fail;
  } else {
    int num_outputs = $1->size();
    $result = PyList_New(num_outputs);
    for (int i = 0; i < num_outputs; ++i) {
      PyList_SetItem($result, i, SWIG_NewPointerObj(SWIG_as_voidptr($1->at(i)),
                                                    $descriptor(TFE_TensorHandle*),
                                                    0 | 0));
    }
  }
}

// Note that we need to use a typemap for TFE_TensorHandle* so that we can call
// SWIG_fail in case the value is nullptr. Otherwise SWIG will wrap the
// nullptr and return it to python as an opaque object, and python does not
// know that it needs to check if an Exception has been raised.
// TODO(agarwal): check if we can get rid of this typemap.
%typemap(out) (TFE_TensorHandle*) {
  if ($1 == nullptr) {
    SWIG_fail;
  } else {
    $result = SWIG_NewPointerObj(SWIG_as_voidptr($1),
                                 $descriptor(TFE_TensorHandle*), 0 | 0);
  }
}

%include "tensorflow/python/eager/pywrap_tfe.h"

// Clear all typemaps.
%typemap(out) TF_DataType;
%typemap(out) int64_t;
%typemap(out) TF_AttrType;
%typemap(in, numinputs=0) TF_Status *out_status;
%typemap(argout) unsigned char* is_list;
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp);
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp);
%typemap(freearg) (TF_Status* out_status);
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status);
%typemap(out) (TFE_TensorHandle*);
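Taken together, the typemaps above shape the Python view of any wrapped call using these parameter types: inputs cross the boundary as a Python list of TFE_TensorHandle wrappers, the output count crosses as a plain integer, the TF_Status is created and freed internally, and a non-OK status surfaces as a raised exception while the output handles come back as a Python list. A self-contained mock of that calling convention (illustrative only; the real generated wrapper is declared via pywrap_tfe.h):

# Mock of the Python-side calling convention the typemaps create; not the
# real wrapper, just its observable input/output behavior.
def mock_wrapped_execute(inputs, num_outputs, fail=False):
  if not isinstance(inputs, list):
    raise TypeError("must provide a list of Tensors as inputs")
  if fail:
    raise RuntimeError("a non-OK TF_Status is converted into an exception")
  return ["<TFE_TensorHandle %d>" % i for i in range(num_outputs)]

print(mock_wrapped_execute(["h0", "h1"], 2))  # A list of output handles.
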
@ -17,6 +17,8 @@ limitations under the License.
 * The includes are intentionally not alphabetically sorted, as the order of
 * includes follows dependency order */

%include "tensorflow/python/pywrap_tfe.i"

%include "tensorflow/python/util/port.i"
%include "tensorflow/python/util/py_checkpoint_reader.i"
%include "tensorflow/python/util/stat_summarizer.i"
@ -45,3 +47,4 @@ limitations under the License.
%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"
%include "tensorflow/python/grappler/model_analyzer.i"
@ -175,6 +175,7 @@ sh_binary(
        "//tensorflow/python/debug:debug_pip",
        "//tensorflow/python/saved_model:saved_model",
        "//tensorflow/python:spectral_ops_test_util",
        "//tensorflow/python/eager:pip_dependencies",
        "//tensorflow/python/tools:tools_pip",
        "//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
    ],