4390 lines
154 KiB
C++
4390 lines
154 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include <atomic>
|
|
#include <cstring>
|
|
#include <unordered_map>
|
|
|
|
#include "absl/debugging/leak_check.h"
|
|
#include "absl/strings/str_cat.h"
|
|
#include "absl/types/variant.h"
|
|
#include "tensorflow/c/c_api.h"
|
|
#include "tensorflow/c/c_api_internal.h"
|
|
#include "tensorflow/c/eager/c_api.h"
|
|
#include "tensorflow/c/eager/c_api_internal.h"
|
|
#include "tensorflow/c/eager/tape.h"
|
|
#include "tensorflow/c/eager/tfe_context_internal.h"
|
|
#include "tensorflow/c/eager/tfe_op_internal.h"
|
|
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
|
#include "tensorflow/c/tf_status.h"
|
|
#include "tensorflow/core/framework/types.pb.h"
|
|
#include "tensorflow/core/lib/core/errors.h"
|
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
|
#include "tensorflow/core/lib/gtl/compactptrset.h"
|
|
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
|
#include "tensorflow/core/lib/strings/strcat.h"
|
|
#include "tensorflow/core/lib/strings/stringprintf.h"
|
|
#include "tensorflow/core/platform/casts.h"
|
|
#include "tensorflow/core/platform/mutex.h"
|
|
#include "tensorflow/core/platform/protobuf.h"
|
|
#include "tensorflow/core/platform/status.h"
|
|
#include "tensorflow/core/platform/types.h"
|
|
#include "tensorflow/core/profiler/lib/traceme.h"
|
|
#include "tensorflow/core/util/managed_stack_trace.h"
|
|
#include "tensorflow/python/eager/pywrap_gradient_exclusions.h"
|
|
#include "tensorflow/python/eager/pywrap_tensor.h"
|
|
#include "tensorflow/python/eager/pywrap_tfe.h"
|
|
#include "tensorflow/python/lib/core/py_util.h"
|
|
#include "tensorflow/python/lib/core/safe_ptr.h"
|
|
#include "tensorflow/python/util/stack_trace.h"
|
|
#include "tensorflow/python/util/util.h"
|
|
|
|
using tensorflow::Status;
|
|
using tensorflow::string;
|
|
using tensorflow::strings::Printf;
|
|
|
|
namespace {
|
|
// NOTE: Items are retrieved from and returned to these unique_ptrs, and they
|
|
// act as arenas. This is important if the same thread requests 2 items without
|
|
// releasing one.
|
|
// The following sequence of events on the same thread will still succeed:
|
|
// - GetOp <- Returns existing.
|
|
// - GetOp <- Allocates and returns a new pointer.
|
|
// - ReleaseOp <- Sets the item in the unique_ptr.
|
|
// - ReleaseOp <- Sets the item in the unique_ptr, deleting the old one.
|
|
// This occurs when a PyFunc kernel is run. This behavior makes it safe in that
|
|
// case, as well as the case where python decides to reuse the underlying
|
|
// C++ thread in 2 python threads case.
|
|
struct OpDeleter {
|
|
void operator()(TFE_Op* op) const { TFE_DeleteOp(op); }
|
|
};
|
|
thread_local std::unordered_map<TFE_Context*,
|
|
std::unique_ptr<TFE_Op, OpDeleter>>
|
|
thread_local_eager_operation_map; // NOLINT
|
|
thread_local std::unique_ptr<TF_Status> thread_local_tf_status = // NOLINT
|
|
nullptr;
|
|
|
|
std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) {
|
|
auto it = thread_local_eager_operation_map.find(ctx);
|
|
if (it == thread_local_eager_operation_map.end()) {
|
|
return nullptr;
|
|
}
|
|
return std::move(it->second);
|
|
}
|
|
|
|
TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
|
|
const char* raw_device_name, TF_Status* status) {
|
|
auto op = ReleaseThreadLocalOp(ctx);
|
|
if (!op) {
|
|
op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation()));
|
|
}
|
|
status->status =
|
|
tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name);
|
|
if (!status->status.ok()) {
|
|
op.reset();
|
|
}
|
|
return op.release();
|
|
}
|
|
|
|
void ReturnOp(TFE_Context* ctx, TFE_Op* op) {
|
|
if (op) {
|
|
tensorflow::unwrap(op)->Clear();
|
|
thread_local_eager_operation_map[ctx].reset(op);
|
|
}
|
|
}
|
|
|
|
TF_Status* ReleaseThreadLocalStatus() {
|
|
if (thread_local_tf_status == nullptr) {
|
|
return nullptr;
|
|
}
|
|
return thread_local_tf_status.release();
|
|
}
|
|
|
|
struct InputInfo {
|
|
InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}
|
|
|
|
int i;
|
|
bool is_list = false;
|
|
};
|
|
|
|
// Takes in output gradients, returns input gradients.
|
|
typedef std::function<PyObject*(PyObject*,
|
|
const std::vector<tensorflow::int64>&)>
|
|
PyBackwardFunction;
|
|
|
|
using AttrToInputsMap =
|
|
tensorflow::gtl::FlatMap<string,
|
|
tensorflow::gtl::InlinedVector<InputInfo, 4>>;
|
|
|
|
tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
|
|
static auto* all_attr_to_input_maps =
|
|
new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
|
|
return all_attr_to_input_maps;
|
|
}
|
|
|
|
// This function doesn't use a lock, since we depend on the GIL directly.
|
|
AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) {
|
|
#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
|
|
DCHECK(PyGILState_Check())
|
|
<< "This function needs to hold the GIL when called.";
|
|
#endif
|
|
auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();
|
|
auto* output =
|
|
tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
|
|
if (output != nullptr) {
|
|
return output;
|
|
}
|
|
|
|
std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);
|
|
|
|
// Store a list of InputIndex -> List of corresponding inputs.
|
|
for (int i = 0; i < op_def.input_arg_size(); i++) {
|
|
if (!op_def.input_arg(i).type_attr().empty()) {
|
|
auto it = m->find(op_def.input_arg(i).type_attr());
|
|
if (it == m->end()) {
|
|
it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
|
|
}
|
|
it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
|
|
}
|
|
}
|
|
|
|
auto* retval = m.get();
|
|
(*all_attr_to_input_maps)[op_def.name()] = m.release();
|
|
|
|
return retval;
|
|
}
|
|
|
|
// This function doesn't use a lock, since we depend on the GIL directly.
|
|
tensorflow::gtl::FlatMap<
|
|
string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>*
|
|
GetAllAttrToDefaultsMaps() {
|
|
static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap<
|
|
string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>;
|
|
return all_attr_to_defaults_maps;
|
|
}
|
|
|
|
tensorflow::gtl::FlatMap<string, tensorflow::DataType>*
|
|
GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) {
|
|
#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
|
|
DCHECK(PyGILState_Check())
|
|
<< "This function needs to hold the GIL when called.";
|
|
#endif
|
|
auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps();
|
|
auto* output =
|
|
tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name());
|
|
if (output != nullptr) {
|
|
return output;
|
|
}
|
|
|
|
auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>;
|
|
|
|
for (const auto& attr : op_def.attr()) {
|
|
if (attr.type() == "type" && attr.has_default_value()) {
|
|
new_map->insert({attr.name(), attr.default_value().type()});
|
|
}
|
|
}
|
|
|
|
(*all_attr_to_defaults_maps)[op_def.name()] = new_map;
|
|
|
|
return new_map;
|
|
}
|
|
|
|
struct FastPathOpExecInfo {
|
|
TFE_Context* ctx;
|
|
const char* device_name;
|
|
|
|
bool run_callbacks;
|
|
bool run_post_exec_callbacks;
|
|
bool run_gradient_callback;
|
|
|
|
// The op name of the main op being executed.
|
|
PyObject* name;
|
|
// The op type name of the main op being executed.
|
|
PyObject* op_name;
|
|
PyObject* callbacks;
|
|
|
|
// All the args passed into the FastPathOpExecInfo.
|
|
PyObject* args;
|
|
|
|
// DTypes can come from another input that has the same attr. So build that
|
|
// map.
|
|
const AttrToInputsMap* attr_to_inputs_map;
|
|
const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
|
|
tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
|
|
};
|
|
|
|
#define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \
|
|
bool fn_name(const string& key, PyObject* py_value, TF_Status* status, \
|
|
type* value) { \
|
|
if (check_fn(py_value)) { \
|
|
*value = static_cast<type>(parse_fn(py_value)); \
|
|
return true; \
|
|
} else { \
|
|
TF_SetStatus(status, TF_INVALID_ARGUMENT, \
|
|
tensorflow::strings::StrCat( \
|
|
"Expecting " #type " value for attr ", key, ", got ", \
|
|
py_value->ob_type->tp_name) \
|
|
.c_str()); \
|
|
return false; \
|
|
} \
|
|
}
|
|
|
|
#if PY_MAJOR_VERSION >= 3
|
|
PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
|
|
PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong)
|
|
#else
|
|
PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
|
|
#endif
|
|
PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
|
|
#undef PARSE_VALUE
|
|
|
|
#if PY_MAJOR_VERSION < 3
|
|
bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
|
|
int64_t* value) {
|
|
if (PyInt_Check(py_value)) {
|
|
*value = static_cast<int64_t>(PyInt_AsLong(py_value));
|
|
return true;
|
|
} else if (PyLong_Check(py_value)) {
|
|
*value = static_cast<int64_t>(PyLong_AsLong(py_value));
|
|
return true;
|
|
}
|
|
TF_SetStatus(
|
|
status, TF_INVALID_ARGUMENT,
|
|
tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
|
|
", got ", py_value->ob_type->tp_name)
|
|
.c_str());
|
|
return false;
|
|
}
|
|
#endif
|
|
|
|
Py_ssize_t TensorShapeNumDims(PyObject* value) {
|
|
const auto size = PySequence_Size(value);
|
|
if (size == -1) {
|
|
// TensorShape.__len__ raises an error in the scenario where the shape is an
|
|
// unknown, which needs to be cleared.
|
|
// TODO(nareshmodi): ensure that this is actually a TensorShape.
|
|
PyErr_Clear();
|
|
}
|
|
return size;
|
|
}
|
|
|
|
bool IsInteger(PyObject* py_value) {
|
|
#if PY_MAJOR_VERSION >= 3
|
|
return PyLong_Check(py_value);
|
|
#else
|
|
return PyInt_Check(py_value) || PyLong_Check(py_value);
|
|
#endif
|
|
}
|
|
|
|
// This function considers a Dimension._value of None to be valid, and sets the
|
|
// value to be -1 in that case.
|
|
bool ParseDimensionValue(const string& key, PyObject* py_value,
|
|
TF_Status* status, int64_t* value) {
|
|
if (IsInteger(py_value)) {
|
|
return ParseInt64Value(key, py_value, status, value);
|
|
}
|
|
|
|
tensorflow::Safe_PyObjectPtr dimension_value(
|
|
PyObject_GetAttrString(py_value, "_value"));
|
|
if (dimension_value == nullptr) {
|
|
PyErr_Clear();
|
|
TF_SetStatus(
|
|
status, TF_INVALID_ARGUMENT,
|
|
tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
|
|
", got ", py_value->ob_type->tp_name)
|
|
.c_str());
|
|
return false;
|
|
}
|
|
|
|
if (dimension_value.get() == Py_None) {
|
|
*value = -1;
|
|
return true;
|
|
}
|
|
|
|
return ParseInt64Value(key, dimension_value.get(), status, value);
|
|
}
|
|
|
|
bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
|
|
tensorflow::StringPiece* value) {
|
|
if (PyBytes_Check(py_value)) {
|
|
Py_ssize_t size = 0;
|
|
char* buf = nullptr;
|
|
if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
|
|
*value = tensorflow::StringPiece(buf, size);
|
|
return true;
|
|
}
|
|
#if PY_MAJOR_VERSION >= 3
|
|
if (PyUnicode_Check(py_value)) {
|
|
Py_ssize_t size = 0;
|
|
const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
|
|
if (buf == nullptr) return false;
|
|
*value = tensorflow::StringPiece(buf, size);
|
|
return true;
|
|
}
|
|
#endif
|
|
TF_SetStatus(
|
|
status, TF_INVALID_ARGUMENT,
|
|
tensorflow::strings::StrCat("Expecting a string value for attr ", key,
|
|
", got ", py_value->ob_type->tp_name)
|
|
.c_str());
|
|
return false;
|
|
}
|
|
|
|
bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
|
|
unsigned char* value) {
|
|
*value = PyObject_IsTrue(py_value);
|
|
return true;
|
|
}
|
|
|
|
// The passed in py_value is expected to be an object of the python type
|
|
// dtypes.DType or an int.
|
|
bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
|
|
int* value) {
|
|
if (IsInteger(py_value)) {
|
|
return ParseIntValue(key, py_value, status, value);
|
|
}
|
|
|
|
tensorflow::Safe_PyObjectPtr py_type_enum(
|
|
PyObject_GetAttrString(py_value, "_type_enum"));
|
|
if (py_type_enum == nullptr) {
|
|
PyErr_Clear();
|
|
TF_SetStatus(
|
|
status, TF_INVALID_ARGUMENT,
|
|
tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
|
|
", got ", py_value->ob_type->tp_name)
|
|
.c_str());
|
|
return false;
|
|
}
|
|
|
|
return ParseIntValue(key, py_type_enum.get(), status, value);
|
|
}
|
|
|
|
bool SetOpAttrList(
|
|
TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_list,
|
|
TF_AttrType type,
|
|
tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
|
|
TF_Status* status) {
|
|
if (!PySequence_Check(py_list)) {
|
|
TF_SetStatus(
|
|
status, TF_INVALID_ARGUMENT,
|
|
tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
|
|
", got ", py_list->ob_type->tp_name)
|
|
.c_str());
|
|
return false;
|
|
}
|
|
const int num_values = PySequence_Size(py_list);
|
|
if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;
|
|
|
|
#define PARSE_LIST(c_type, parse_fn) \
|
|
std::unique_ptr<c_type[]> values(new c_type[num_values]); \
|
|
for (int i = 0; i < num_values; ++i) { \
|
|
tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); \
|
|
if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
|
|
}
|
|
|
|
if (type == TF_ATTR_STRING) {
|
|
std::unique_ptr<const void*[]> values(new const void*[num_values]);
|
|
std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
|
|
for (int i = 0; i < num_values; ++i) {
|
|
tensorflow::StringPiece value;
|
|
tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
|
|
if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
|
|
values[i] = value.data();
|
|
lengths[i] = value.size();
|
|
}
|
|
TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
|
|
} else if (type == TF_ATTR_INT) {
|
|
PARSE_LIST(int64_t, ParseInt64Value);
|
|
TFE_OpSetAttrIntList(op, key, values.get(), num_values);
|
|
} else if (type == TF_ATTR_FLOAT) {
|
|
PARSE_LIST(float, ParseFloatValue);
|
|
TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
|
|
} else if (type == TF_ATTR_BOOL) {
|
|
PARSE_LIST(unsigned char, ParseBoolValue);
|
|
TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
|
|
} else if (type == TF_ATTR_TYPE) {
|
|
PARSE_LIST(int, ParseTypeValue);
|
|
TFE_OpSetAttrTypeList(op, key,
|
|
reinterpret_cast<const TF_DataType*>(values.get()),
|
|
num_values);
|
|
} else if (type == TF_ATTR_SHAPE) {
|
|
// Make one pass through the input counting the total number of
|
|
// dims across all the input lists.
|
|
int total_dims = 0;
|
|
for (int i = 0; i < num_values; ++i) {
|
|
tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
|
|
if (py_value.get() != Py_None) {
|
|
if (!PySequence_Check(py_value.get())) {
|
|
TF_SetStatus(
|
|
status, TF_INVALID_ARGUMENT,
|
|
tensorflow::strings::StrCat(
|
|
"Expecting None or sequence value for element", i,
|
|
" of attr ", key, ", got ", py_value->ob_type->tp_name)
|
|
.c_str());
|
|
return false;
|
|
}
|
|
const auto size = TensorShapeNumDims(py_value.get());
|
|
if (size >= 0) {
|
|
total_dims += size;
|
|
}
|
|
}
|
|
}
|
|
// Allocate a buffer that can fit all of the dims together.
|
|
std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
|
|
// Copy the input dims into the buffer and set dims to point to
|
|
// the start of each list's dims.
|
|
std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
|
|
std::unique_ptr<int[]> num_dims(new int[num_values]);
|
|
int64_t* offset = buffer.get();
|
|
for (int i = 0; i < num_values; ++i) {
|
|
tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
|
|
if (py_value.get() == Py_None) {
|
|
dims[i] = nullptr;
|
|
num_dims[i] = -1;
|
|
} else {
|
|
const auto size = TensorShapeNumDims(py_value.get());
|
|
if (size == -1) {
|
|
dims[i] = nullptr;
|
|
num_dims[i] = -1;
|
|
continue;
|
|
}
|
|
dims[i] = offset;
|
|
num_dims[i] = size;
|
|
for (int j = 0; j < size; ++j) {
|
|
tensorflow::Safe_PyObjectPtr inner_py_value(
|
|
PySequence_ITEM(py_value.get(), j));
|
|
if (inner_py_value.get() == Py_None) {
|
|
*offset = -1;
|
|
} else if (!ParseDimensionValue(key, inner_py_value.get(), status,
|
|
offset)) {
|
|
return false;
|
|
}
|
|
++offset;
|
|
}
|
|
}
|
|
}
|
|
TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
|
|
status);
|
|
if (!status->status.ok()) return false;
|
|
} else if (type == TF_ATTR_FUNC) {
|
|
std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
|
|
for (int i = 0; i < num_values; ++i) {
|
|
tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
|
|
// Allow:
|
|
// (1) String function name, OR
|
|
// (2) A Python object with a .name attribute
|
|
// (A crude test for being a
|
|
// tensorflow.python.framework.function._DefinedFunction)
|
|
// (which is what the various "defun" or "Defun" decorators do).
|
|
// And in the future also allow an object that can encapsulate
|
|
// the function name and its attribute values.
|
|
tensorflow::StringPiece func_name;
|
|
if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
|
|
PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name");
|
|
if (name_attr == nullptr ||
|
|
!ParseStringValue(key, name_attr, status, &func_name)) {
|
|
TF_SetStatus(
|
|
status, TF_INVALID_ARGUMENT,
|
|
tensorflow::strings::StrCat(
|
|
"unable to set function value attribute from a ",
|
|
py_value.get()->ob_type->tp_name,
|
|
" object. If you think this is an error, please file an "
|
|
"issue at "
|
|
"https://github.com/tensorflow/tensorflow/issues/new")
|
|
.c_str());
|
|
return false;
|
|
}
|
|
}
|
|
funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
|
|
if (!status->status.ok()) return false;
|
|
}
|
|
TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
|
|
if (!status->status.ok()) return false;
|
|
} else {
|
|
TF_SetStatus(status, TF_UNIMPLEMENTED,
|
|
tensorflow::strings::StrCat("Attr ", key,
|
|
" has unhandled list type ", type)
|
|
.c_str());
|
|
return false;
|
|
}
|
|
#undef PARSE_LIST
|
|
return true;
|
|
}
|
|
|
|
TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
|
|
TF_Status* status) {
|
|
TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
|
|
for (const auto& attr : func.attr()) {
|
|
if (!status->status.ok()) return nullptr;
|
|
SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
|
|
if (!status->status.ok()) return nullptr;
|
|
}
|
|
return func_op;
|
|
}
|
|
|
|
void SetOpAttrListDefault(
|
|
TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
|
|
const char* key, TF_AttrType type,
|
|
tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
|
|
TF_Status* status) {
|
|
if (type == TF_ATTR_STRING) {
|
|
int num_values = attr.default_value().list().s_size();
|
|
std::unique_ptr<const void*[]> values(new const void*[num_values]);
|
|
std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
|
|
(*attr_list_sizes)[key] = num_values;
|
|
for (int i = 0; i < num_values; i++) {
|
|
const string& v = attr.default_value().list().s(i);
|
|
values[i] = v.data();
|
|
lengths[i] = v.size();
|
|
}
|
|
TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
|
|
} else if (type == TF_ATTR_INT) {
|
|
int num_values = attr.default_value().list().i_size();
|
|
std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
|
|
(*attr_list_sizes)[key] = num_values;
|
|
for (int i = 0; i < num_values; i++) {
|
|
values[i] = attr.default_value().list().i(i);
|
|
}
|
|
TFE_OpSetAttrIntList(op, key, values.get(), num_values);
|
|
} else if (type == TF_ATTR_FLOAT) {
|
|
int num_values = attr.default_value().list().f_size();
|
|
std::unique_ptr<float[]> values(new float[num_values]);
|
|
(*attr_list_sizes)[key] = num_values;
|
|
for (int i = 0; i < num_values; i++) {
|
|
values[i] = attr.default_value().list().f(i);
|
|
}
|
|
TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
|
|
} else if (type == TF_ATTR_BOOL) {
|
|
int num_values = attr.default_value().list().b_size();
|
|
std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
|
|
(*attr_list_sizes)[key] = num_values;
|
|
for (int i = 0; i < num_values; i++) {
|
|
values[i] = attr.default_value().list().b(i);
|
|
}
|
|
TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
|
|
} else if (type == TF_ATTR_TYPE) {
|
|
int num_values = attr.default_value().list().type_size();
|
|
std::unique_ptr<int[]> values(new int[num_values]);
|
|
(*attr_list_sizes)[key] = num_values;
|
|
for (int i = 0; i < num_values; i++) {
|
|
values[i] = attr.default_value().list().type(i);
|
|
}
|
|
TFE_OpSetAttrTypeList(op, key,
|
|
reinterpret_cast<const TF_DataType*>(values.get()),
|
|
attr.default_value().list().type_size());
|
|
} else if (type == TF_ATTR_SHAPE) {
|
|
int num_values = attr.default_value().list().shape_size();
|
|
(*attr_list_sizes)[key] = num_values;
|
|
int total_dims = 0;
|
|
for (int i = 0; i < num_values; ++i) {
|
|
if (!attr.default_value().list().shape(i).unknown_rank()) {
|
|
total_dims += attr.default_value().list().shape(i).dim_size();
|
|
}
|
|
}
|
|
// Allocate a buffer that can fit all of the dims together.
|
|
std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
|
|
// Copy the input dims into the buffer and set dims to point to
|
|
// the start of each list's dims.
|
|
std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
|
|
std::unique_ptr<int[]> num_dims(new int[num_values]);
|
|
int64_t* offset = buffer.get();
|
|
for (int i = 0; i < num_values; ++i) {
|
|
const auto& shape = attr.default_value().list().shape(i);
|
|
if (shape.unknown_rank()) {
|
|
dims[i] = nullptr;
|
|
num_dims[i] = -1;
|
|
} else {
|
|
for (int j = 0; j < shape.dim_size(); j++) {
|
|
*offset = shape.dim(j).size();
|
|
++offset;
|
|
}
|
|
}
|
|
}
|
|
TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
|
|
status);
|
|
} else if (type == TF_ATTR_FUNC) {
|
|
int num_values = attr.default_value().list().func_size();
|
|
(*attr_list_sizes)[key] = num_values;
|
|
std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
|
|
for (int i = 0; i < num_values; i++) {
|
|
funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
|
|
}
|
|
TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
|
|
} else {
|
|
TF_SetStatus(status, TF_UNIMPLEMENTED,
|
|
"Lists of tensors are not yet implemented for default valued "
|
|
"attributes for an operation.");
|
|
}
|
|
}
|
|
|
|
bool SetOpAttrScalar(
|
|
TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_value,
|
|
TF_AttrType type,
|
|
tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
|
|
TF_Status* status) {
|
|
if (type == TF_ATTR_STRING) {
|
|
tensorflow::StringPiece value;
|
|
if (!ParseStringValue(key, py_value, status, &value)) return false;
|
|
TFE_OpSetAttrString(op, key, value.data(), value.size());
|
|
} else if (type == TF_ATTR_INT) {
|
|
int64_t value;
|
|
if (!ParseInt64Value(key, py_value, status, &value)) return false;
|
|
TFE_OpSetAttrInt(op, key, value);
|
|
// attr_list_sizes is set for all int attributes (since at this point we are
|
|
// not aware if that attribute might be used to calculate the size of an
|
|
// output list or not).
|
|
if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
|
|
} else if (type == TF_ATTR_FLOAT) {
|
|
float value;
|
|
if (!ParseFloatValue(key, py_value, status, &value)) return false;
|
|
TFE_OpSetAttrFloat(op, key, value);
|
|
} else if (type == TF_ATTR_BOOL) {
|
|
unsigned char value;
|
|
if (!ParseBoolValue(key, py_value, status, &value)) return false;
|
|
TFE_OpSetAttrBool(op, key, value);
|
|
} else if (type == TF_ATTR_TYPE) {
|
|
int value;
|
|
if (!ParseTypeValue(key, py_value, status, &value)) return false;
|
|
TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
|
|
} else if (type == TF_ATTR_SHAPE) {
|
|
if (py_value == Py_None) {
|
|
TFE_OpSetAttrShape(op, key, nullptr, -1, status);
|
|
} else {
|
|
if (!PySequence_Check(py_value)) {
|
|
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
|
tensorflow::strings::StrCat(
|
|
"Expecting None or sequence value for attr", key,
|
|
", got ", py_value->ob_type->tp_name)
|
|
.c_str());
|
|
return false;
|
|
}
|
|
const auto num_dims = TensorShapeNumDims(py_value);
|
|
if (num_dims == -1) {
|
|
TFE_OpSetAttrShape(op, key, nullptr, -1, status);
|
|
return true;
|
|
}
|
|
std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
|
|
for (int i = 0; i < num_dims; ++i) {
|
|
tensorflow::Safe_PyObjectPtr inner_py_value(
|
|
PySequence_ITEM(py_value, i));
|
|
if (inner_py_value.get() == Py_None) {
|
|
dims[i] = -1;
|
|
} else if (!ParseDimensionValue(key, inner_py_value.get(), status,
|
|
&dims[i])) {
|
|
return false;
|
|
}
|
|
}
|
|
TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
|
|
}
|
|
if (!status->status.ok()) return false;
|
|
} else if (type == TF_ATTR_FUNC) {
|
|
// Allow:
|
|
// (1) String function name, OR
|
|
// (2) A Python object with a .name attribute
|
|
// (A crude test for being a
|
|
// tensorflow.python.framework.function._DefinedFunction)
|
|
// (which is what the various "defun" or "Defun" decorators do).
|
|
// And in the future also allow an object that can encapsulate
|
|
// the function name and its attribute values.
|
|
tensorflow::StringPiece func_name;
|
|
if (!ParseStringValue(key, py_value, status, &func_name)) {
|
|
PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
|
|
if (name_attr == nullptr ||
|
|
!ParseStringValue(key, name_attr, status, &func_name)) {
|
|
TF_SetStatus(
|
|
status, TF_INVALID_ARGUMENT,
|
|
tensorflow::strings::StrCat(
|
|
"unable to set function value attribute from a ",
|
|
py_value->ob_type->tp_name,
|
|
" object. If you think this is an error, please file an issue "
|
|
"at https://github.com/tensorflow/tensorflow/issues/new")
|
|
.c_str());
|
|
return false;
|
|
}
|
|
}
|
|
TF_SetStatus(status, TF_OK, "");
|
|
TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
|
|
} else {
|
|
TF_SetStatus(
|
|
status, TF_UNIMPLEMENTED,
|
|
tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
|
|
.c_str());
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
void SetOpAttrScalarDefault(
|
|
TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
|
|
const char* attr_name,
|
|
tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
|
|
TF_Status* status) {
|
|
SetOpAttrValueScalar(ctx, op, default_value, attr_name, status);
|
|
if (default_value.value_case() == tensorflow::AttrValue::kI) {
|
|
(*attr_list_sizes)[attr_name] = default_value.i();
|
|
}
|
|
}
|
|
|
|
// start_index is the index at which the Tuple/List attrs will start getting
|
|
// processed.
|
|
void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
|
|
TF_Status* out_status) {
|
|
if (attrs == Py_None) return;
|
|
Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
|
|
if ((len & 1) != 0) {
|
|
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
|
|
"Expecting attrs tuple to have even length.");
|
|
return;
|
|
}
|
|
// Parse attrs
|
|
for (Py_ssize_t i = 0; i < len; i += 2) {
|
|
PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
|
|
PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
|
|
#if PY_MAJOR_VERSION >= 3
|
|
const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
|
|
: PyUnicode_AsUTF8(py_key);
|
|
#else
|
|
const char* key = PyBytes_AsString(py_key);
|
|
#endif
|
|
unsigned char is_list = 0;
|
|
const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
|
|
if (!out_status->status.ok()) return;
|
|
if (is_list != 0) {
|
|
if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
|
|
return;
|
|
} else {
|
|
if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
// This function will set the op attrs required. If an attr has the value of
|
|
// None, then it will read the AttrDef to get the default value and set that
|
|
// instead. Any failure in this function will simply fall back to the slow
|
|
// path.
|
|
void SetOpAttrWithDefaults(
|
|
TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
|
|
const char* attr_name, PyObject* attr_value,
|
|
tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
|
|
TF_Status* status) {
|
|
unsigned char is_list = 0;
|
|
const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
|
|
if (!status->status.ok()) return;
|
|
if (attr_value == Py_None) {
|
|
if (is_list != 0) {
|
|
SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
|
|
status);
|
|
} else {
|
|
SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
|
|
attr_list_sizes, status);
|
|
}
|
|
} else {
|
|
if (is_list != 0) {
|
|
SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
|
|
status);
|
|
} else {
|
|
SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
|
|
status);
|
|
}
|
|
}
|
|
}
|
|
|
|
PyObject* GetPythonObjectFromInt(int num) {
|
|
#if PY_MAJOR_VERSION >= 3
|
|
return PyLong_FromLong(num);
|
|
#else
|
|
return PyInt_FromLong(num);
|
|
#endif
|
|
}
|
|
|
|
// Python subclass of Exception that is created on not ok Status.
|
|
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
|
|
PyObject* exception_class TF_GUARDED_BY(exception_class_mutex) = nullptr;
|
|
|
|
// Python subclass of Exception that is created to signal fallback.
|
|
PyObject* fallback_exception_class = nullptr;
|
|
|
|
// Python function that returns input gradients given output gradients.
|
|
PyObject* gradient_function = nullptr;
|
|
|
|
// Python function that returns output gradients given input gradients.
|
|
PyObject* forward_gradient_function = nullptr;
|
|
|
|
static std::atomic<int64_t> _uid;
|
|
|
|
} // namespace
|
|
|
|
TF_Status* GetStatus() {
|
|
TF_Status* maybe_status = ReleaseThreadLocalStatus();
|
|
if (maybe_status) {
|
|
TF_SetStatus(maybe_status, TF_OK, "");
|
|
return maybe_status;
|
|
} else {
|
|
return TF_NewStatus();
|
|
}
|
|
}
|
|
|
|
void ReturnStatus(TF_Status* status) {
|
|
TF_SetStatus(status, TF_OK, "");
|
|
thread_local_tf_status.reset(status);
|
|
}
|
|
|
|
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
|
|
const char* op_name, TFE_InputTensorHandles* inputs,
|
|
PyObject* attrs, TFE_OutputTensorHandles* outputs,
|
|
TF_Status* out_status) {
|
|
TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs,
|
|
/*cancellation_manager=*/nullptr, outputs,
|
|
out_status);
|
|
}
|
|
|
|
void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
|
|
const char* op_name,
|
|
TFE_InputTensorHandles* inputs, PyObject* attrs,
|
|
TFE_CancellationManager* cancellation_manager,
|
|
TFE_OutputTensorHandles* outputs,
|
|
TF_Status* out_status) {
|
|
tensorflow::profiler::TraceMe activity(
|
|
"TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo);
|
|
|
|
TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);
|
|
|
|
auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); });
|
|
if (!out_status->status.ok()) return;
|
|
|
|
tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
|
|
tensorflow::StackTrace::kStackTraceInitialSize));
|
|
|
|
for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
|
|
TFE_OpAddInput(op, inputs->at(i), out_status);
|
|
}
|
|
if (cancellation_manager && out_status->status.ok()) {
|
|
TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
|
|
}
|
|
if (out_status->status.ok()) {
|
|
SetOpAttrs(ctx, op, attrs, 0, out_status);
|
|
}
|
|
Py_BEGIN_ALLOW_THREADS;
|
|
|
|
int num_outputs = outputs->size();
|
|
|
|
if (out_status->status.ok()) {
|
|
TFE_Execute(op, outputs->data(), &num_outputs, out_status);
|
|
}
|
|
|
|
if (out_status->status.ok()) {
|
|
outputs->resize(num_outputs);
|
|
} else {
|
|
TF_SetStatus(out_status, TF_GetCode(out_status),
|
|
tensorflow::strings::StrCat(TF_Message(out_status),
|
|
" [Op:", op_name, "]")
|
|
.c_str());
|
|
}
|
|
|
|
Py_END_ALLOW_THREADS;
|
|
}
|
|
|
|
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
|
|
tensorflow::mutex_lock l(exception_class_mutex);
|
|
if (exception_class != nullptr) {
|
|
Py_DECREF(exception_class);
|
|
}
|
|
if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
|
|
exception_class = nullptr;
|
|
PyErr_SetString(PyExc_TypeError,
|
|
"TFE_Py_RegisterExceptionClass: "
|
|
"Registered class should be subclass of Exception.");
|
|
return nullptr;
|
|
}
|
|
|
|
Py_INCREF(e);
|
|
exception_class = e;
|
|
Py_RETURN_NONE;
|
|
}
|
|
|
|
PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
|
|
if (fallback_exception_class != nullptr) {
|
|
Py_DECREF(fallback_exception_class);
|
|
}
|
|
if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
|
|
fallback_exception_class = nullptr;
|
|
PyErr_SetString(PyExc_TypeError,
|
|
"TFE_Py_RegisterFallbackExceptionClass: "
|
|
"Registered class should be subclass of Exception.");
|
|
return nullptr;
|
|
} else {
|
|
Py_INCREF(e);
|
|
fallback_exception_class = e;
|
|
Py_RETURN_NONE;
|
|
}
|
|
}
|
|
|
|
PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
|
|
if (gradient_function != nullptr) {
|
|
Py_DECREF(gradient_function);
|
|
}
|
|
if (!PyCallable_Check(e)) {
|
|
gradient_function = nullptr;
|
|
PyErr_SetString(PyExc_TypeError,
|
|
"TFE_Py_RegisterGradientFunction: "
|
|
"Registered object should be function.");
|
|
return nullptr;
|
|
} else {
|
|
Py_INCREF(e);
|
|
gradient_function = e;
|
|
Py_RETURN_NONE;
|
|
}
|
|
}
|
|
|
|
PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) {
|
|
if (forward_gradient_function != nullptr) {
|
|
Py_DECREF(forward_gradient_function);
|
|
}
|
|
if (!PyCallable_Check(e)) {
|
|
forward_gradient_function = nullptr;
|
|
PyErr_SetString(PyExc_TypeError,
|
|
"TFE_Py_RegisterJVPFunction: "
|
|
"Registered object should be function.");
|
|
return nullptr;
|
|
} else {
|
|
Py_INCREF(e);
|
|
forward_gradient_function = e;
|
|
Py_RETURN_NONE;
|
|
}
|
|
}
|
|
|
|
void RaiseFallbackException(const char* message) {
|
|
if (fallback_exception_class != nullptr) {
|
|
PyErr_SetString(fallback_exception_class, message);
|
|
return;
|
|
}
|
|
|
|
PyErr_SetString(
|
|
PyExc_RuntimeError,
|
|
tensorflow::strings::StrCat(
|
|
"Fallback exception type not set, attempting to fallback due to ",
|
|
message)
|
|
.data());
|
|
}
|
|
|
|
// Format and return `status`' error message with the attached stack trace if
|
|
// available. `status` must have an error.
|
|
std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) {
|
|
tensorflow::DCheckPyGilState();
|
|
DCHECK(!status.ok());
|
|
|
|
if (status.stack_trace().empty()) return status.error_message();
|
|
|
|
const std::vector<tensorflow::StackFrame>& stack_trace = status.stack_trace();
|
|
|
|
PyObject* linecache = PyImport_ImportModule("linecache");
|
|
PyObject* getline =
|
|
PyObject_GetAttr(linecache, PyUnicode_FromString("getline"));
|
|
DCHECK(getline);
|
|
|
|
std::ostringstream result;
|
|
result << "Exception originated from\n\n";
|
|
|
|
for (const tensorflow::StackFrame& stack_frame : stack_trace) {
|
|
PyObject* line_str_obj = PyObject_CallFunction(
|
|
getline, const_cast<char*>("si"), stack_frame.file_name.c_str(),
|
|
stack_frame.line_number);
|
|
tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj);
|
|
tensorflow::str_util::RemoveWhitespaceContext(&line_str);
|
|
result << " File \"" << stack_frame.file_name << "\", line "
|
|
<< stack_frame.line_number << ", in " << stack_frame.function_name
|
|
<< '\n';
|
|
|
|
if (!line_str.empty()) result << " " << line_str << '\n';
|
|
Py_XDECREF(line_str_obj);
|
|
}
|
|
|
|
Py_DecRef(getline);
|
|
Py_DecRef(linecache);
|
|
|
|
result << '\n' << status.error_message();
|
|
return result.str();
|
|
}
|
|
|
|
int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
|
|
if (status->status.ok()) return 0;
|
|
const char* msg = TF_Message(status);
|
|
if (exception == nullptr) {
|
|
tensorflow::mutex_lock l(exception_class_mutex);
|
|
if (exception_class != nullptr) {
|
|
tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
|
|
"si", FormatErrorStatusStackTrace(status->status).c_str(),
|
|
TF_GetCode(status)));
|
|
if (PyErr_Occurred()) {
|
|
// NOTE: This hides the actual error (i.e. the reason `status` was not
|
|
// TF_OK), but there is nothing we can do at this point since we can't
|
|
// generate a reasonable error from the status.
|
|
// Consider adding a message explaining this.
|
|
return -1;
|
|
}
|
|
PyErr_SetObject(exception_class, val.get());
|
|
return -1;
|
|
} else {
|
|
exception = PyExc_RuntimeError;
|
|
}
|
|
}
|
|
// May be update already set exception.
|
|
PyErr_SetString(exception, msg);
|
|
return -1;
|
|
}
|
|
|
|
int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
|
|
PyObject* exception) {
|
|
if (status.ok()) return 0;
|
|
const char* msg = status.error_message().c_str();
|
|
if (exception == nullptr) {
|
|
tensorflow::mutex_lock l(exception_class_mutex);
|
|
if (exception_class != nullptr) {
|
|
tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
|
|
"si", FormatErrorStatusStackTrace(status).c_str(), status.code()));
|
|
PyErr_SetObject(exception_class, val.get());
|
|
return -1;
|
|
} else {
|
|
exception = PyExc_RuntimeError;
|
|
}
|
|
}
|
|
// May be update already set exception.
|
|
PyErr_SetString(exception, msg);
|
|
return -1;
|
|
}
|
|
|
|
const char* TFE_GetPythonString(PyObject* o) {
|
|
#if PY_MAJOR_VERSION >= 3
|
|
if (PyBytes_Check(o)) {
|
|
return PyBytes_AsString(o);
|
|
} else {
|
|
return PyUnicode_AsUTF8(o);
|
|
}
|
|
#else
|
|
return PyBytes_AsString(o);
|
|
#endif
|
|
}
|
|
|
|
int64_t get_uid() { return _uid++; }
|
|
|
|
PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
|
|
|
|
void TFE_DeleteContextCapsule(PyObject* context) {
|
|
TFE_Context* ctx =
|
|
reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
|
|
auto op = ReleaseThreadLocalOp(ctx);
|
|
op.reset();
|
|
TFE_DeleteContext(ctx);
|
|
}
|
|
|
|
static tensorflow::int64 MakeInt(PyObject* integer) {
|
|
#if PY_MAJOR_VERSION >= 3
|
|
return PyLong_AsLong(integer);
|
|
#else
|
|
return PyInt_AsLong(integer);
|
|
#endif
|
|
}
|
|
|
|
static tensorflow::int64 FastTensorId(PyObject* tensor) {
|
|
if (EagerTensor_CheckExact(tensor)) {
|
|
return PyEagerTensor_ID(tensor);
|
|
}
|
|
PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
|
|
if (id_field == nullptr) {
|
|
return -1;
|
|
}
|
|
tensorflow::int64 id = MakeInt(id_field);
|
|
Py_DECREF(id_field);
|
|
return id;
|
|
}
|
|
|
|
namespace tensorflow {
|
|
DataType PyTensor_DataType(PyObject* tensor) {
|
|
if (EagerTensor_CheckExact(tensor)) {
|
|
return PyEagerTensor_Dtype(tensor);
|
|
} else {
|
|
#if PY_MAJOR_VERSION < 3
|
|
// Python 2.x:
|
|
static PyObject* dtype_attr = PyString_InternFromString("dtype");
|
|
static PyObject* type_enum_attr = PyString_InternFromString("_type_enum");
|
|
#else
|
|
// Python 3.x:
|
|
static PyObject* dtype_attr = PyUnicode_InternFromString("dtype");
|
|
static PyObject* type_enum_attr = PyUnicode_InternFromString("_type_enum");
|
|
#endif
|
|
Safe_PyObjectPtr dtype_field(PyObject_GetAttr(tensor, dtype_attr));
|
|
if (!dtype_field) {
|
|
return DT_INVALID;
|
|
}
|
|
|
|
Safe_PyObjectPtr enum_field(
|
|
PyObject_GetAttr(dtype_field.get(), type_enum_attr));
|
|
if (!enum_field) {
|
|
return DT_INVALID;
|
|
}
|
|
|
|
return static_cast<DataType>(MakeInt(enum_field.get()));
|
|
}
|
|
}
|
|
} // namespace tensorflow
|
|
|
|
class PyTapeTensor {
|
|
public:
|
|
PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
|
|
const tensorflow::TensorShape& shape)
|
|
: id_(id), dtype_(dtype), shape_(shape) {}
|
|
PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
|
|
PyObject* shape)
|
|
: id_(id), dtype_(dtype), shape_(shape) {
|
|
Py_INCREF(absl::get<1>(shape_));
|
|
}
|
|
PyTapeTensor(const PyTapeTensor& other) {
|
|
id_ = other.id_;
|
|
dtype_ = other.dtype_;
|
|
shape_ = other.shape_;
|
|
if (shape_.index() == 1) {
|
|
Py_INCREF(absl::get<1>(shape_));
|
|
}
|
|
}
|
|
|
|
~PyTapeTensor() {
|
|
if (shape_.index() == 1) {
|
|
Py_DECREF(absl::get<1>(shape_));
|
|
}
|
|
}
|
|
PyObject* GetShape() const;
|
|
PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); }
|
|
tensorflow::int64 GetID() const { return id_; }
|
|
tensorflow::DataType GetDType() const { return dtype_; }
|
|
|
|
PyObject* OnesLike() const;
|
|
PyObject* ZerosLike() const;
|
|
|
|
private:
|
|
tensorflow::int64 id_;
|
|
tensorflow::DataType dtype_;
|
|
|
|
// Note that if shape_.index() == 1, meaning shape_ contains a PyObject, that
|
|
// PyObject is the tensor itself. This is used to support tf.shape(tensor) for
|
|
// partially-defined shapes and tf.zeros_like(tensor) for variant-dtype
|
|
// tensors.
|
|
absl::variant<tensorflow::TensorShape, PyObject*> shape_;
|
|
};
|
|
|
|
static PyTapeTensor TapeTensorFromTensor(PyObject* tensor);
|
|
|
|
class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
|
|
PyTapeTensor> {
|
|
public:
|
|
explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
|
|
Py_INCREF(py_vspace_);
|
|
}
|
|
|
|
tensorflow::Status Initialize() {
|
|
num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
|
|
if (num_elements_ == nullptr) {
|
|
return tensorflow::errors::InvalidArgument("invalid vspace");
|
|
}
|
|
aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
|
|
if (aggregate_fn_ == nullptr) {
|
|
return tensorflow::errors::InvalidArgument("invalid vspace");
|
|
}
|
|
zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
|
|
if (zeros_fn_ == nullptr) {
|
|
return tensorflow::errors::InvalidArgument("invalid vspace");
|
|
}
|
|
zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn");
|
|
if (zeros_like_fn_ == nullptr) {
|
|
return tensorflow::errors::InvalidArgument("invalid vspace");
|
|
}
|
|
ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
|
|
if (ones_fn_ == nullptr) {
|
|
return tensorflow::errors::InvalidArgument("invalid vspace");
|
|
}
|
|
ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn");
|
|
if (ones_like_fn_ == nullptr) {
|
|
return tensorflow::errors::InvalidArgument("invalid vspace");
|
|
}
|
|
graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
|
|
if (graph_shape_fn_ == nullptr) {
|
|
return tensorflow::errors::InvalidArgument("invalid vspace");
|
|
}
|
|
return tensorflow::Status::OK();
|
|
}
|
|
|
|
~PyVSpace() override {
|
|
Py_XDECREF(num_elements_);
|
|
Py_XDECREF(aggregate_fn_);
|
|
Py_XDECREF(zeros_fn_);
|
|
Py_XDECREF(zeros_like_fn_);
|
|
Py_XDECREF(ones_fn_);
|
|
Py_XDECREF(ones_like_fn_);
|
|
Py_XDECREF(graph_shape_fn_);
|
|
|
|
Py_DECREF(py_vspace_);
|
|
}
|
|
|
|
tensorflow::int64 NumElements(PyObject* tensor) const final {
|
|
if (EagerTensor_CheckExact(tensor)) {
|
|
return PyEagerTensor_NumElements(tensor);
|
|
}
|
|
PyObject* arglist =
|
|
Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
|
|
PyObject* result = PyEval_CallObject(num_elements_, arglist);
|
|
Py_DECREF(arglist);
|
|
if (result == nullptr) {
|
|
// The caller detects whether a python exception has been raised.
|
|
return -1;
|
|
}
|
|
tensorflow::int64 r = MakeInt(result);
|
|
Py_DECREF(result);
|
|
return r;
|
|
}
|
|
|
|
PyObject* AggregateGradients(
|
|
tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
|
|
PyObject* list = PyList_New(gradient_tensors.size());
|
|
for (int i = 0; i < gradient_tensors.size(); ++i) {
|
|
// Note: stealing a reference to the gradient tensors.
|
|
CHECK(gradient_tensors[i] != nullptr);
|
|
CHECK(gradient_tensors[i] != Py_None);
|
|
PyList_SET_ITEM(list, i,
|
|
reinterpret_cast<PyObject*>(gradient_tensors[i]));
|
|
}
|
|
PyObject* arglist = Py_BuildValue("(O)", list);
|
|
CHECK(arglist != nullptr);
|
|
PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
|
|
Py_DECREF(arglist);
|
|
Py_DECREF(list);
|
|
return result;
|
|
}
|
|
|
|
tensorflow::int64 TensorId(PyObject* tensor) const final {
|
|
return FastTensorId(tensor);
|
|
}
|
|
|
|
void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
|
|
|
|
PyObject* Ones(PyObject* shape, PyObject* dtype) const {
|
|
if (PyErr_Occurred()) {
|
|
return nullptr;
|
|
}
|
|
PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
|
|
PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
|
|
Py_DECREF(arg_list);
|
|
return result;
|
|
}
|
|
|
|
PyObject* OnesLike(PyObject* tensor) const {
|
|
if (PyErr_Occurred()) {
|
|
return nullptr;
|
|
}
|
|
return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL);
|
|
}
|
|
|
|
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
|
Status BuildOnesLike(const PyTapeTensor& t,
|
|
PyObject** result) const override {
|
|
*result = t.OnesLike();
|
|
return Status::OK();
|
|
}
|
|
|
|
PyObject* Zeros(PyObject* shape, PyObject* dtype) const {
|
|
if (PyErr_Occurred()) {
|
|
return nullptr;
|
|
}
|
|
PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
|
|
PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
|
|
Py_DECREF(arg_list);
|
|
return result;
|
|
}
|
|
|
|
PyObject* ZerosLike(PyObject* tensor) const {
|
|
if (PyErr_Occurred()) {
|
|
return nullptr;
|
|
}
|
|
return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL);
|
|
}
|
|
|
|
PyObject* GraphShape(PyObject* tensor) const {
|
|
PyObject* arg_list = Py_BuildValue("(O)", tensor);
|
|
PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
|
|
Py_DECREF(arg_list);
|
|
return result;
|
|
}
|
|
|
|
tensorflow::Status CallBackwardFunction(
|
|
const string& op_type, PyBackwardFunction* backward_function,
|
|
const std::vector<tensorflow::int64>& unneeded_gradients,
|
|
tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
|
|
absl::Span<PyObject*> result) const final {
|
|
PyObject* grads = PyTuple_New(output_gradients.size());
|
|
for (int i = 0; i < output_gradients.size(); ++i) {
|
|
if (output_gradients[i] == nullptr) {
|
|
Py_INCREF(Py_None);
|
|
PyTuple_SET_ITEM(grads, i, Py_None);
|
|
} else {
|
|
PyTuple_SET_ITEM(grads, i,
|
|
reinterpret_cast<PyObject*>(output_gradients[i]));
|
|
}
|
|
}
|
|
PyObject* py_result = (*backward_function)(grads, unneeded_gradients);
|
|
Py_DECREF(grads);
|
|
if (py_result == nullptr) {
|
|
return tensorflow::errors::Internal("gradient function threw exceptions");
|
|
}
|
|
PyObject* seq =
|
|
PySequence_Fast(py_result, "expected a sequence of gradients");
|
|
if (seq == nullptr) {
|
|
return tensorflow::errors::InvalidArgument(
|
|
"gradient function did not return a list");
|
|
}
|
|
int len = PySequence_Fast_GET_SIZE(seq);
|
|
if (len != result.size()) {
|
|
return tensorflow::errors::Internal(
|
|
"Recorded operation '", op_type,
|
|
"' returned too few gradients. Expected ", result.size(),
|
|
" but received ", len);
|
|
}
|
|
PyObject** seq_array = PySequence_Fast_ITEMS(seq);
|
|
VLOG(1) << "Gradient length is " << len;
|
|
for (int i = 0; i < len; ++i) {
|
|
PyObject* item = seq_array[i];
|
|
if (item == Py_None) {
|
|
result[i] = nullptr;
|
|
} else {
|
|
Py_INCREF(item);
|
|
result[i] = item;
|
|
}
|
|
}
|
|
Py_DECREF(seq);
|
|
Py_DECREF(py_result);
|
|
return tensorflow::Status::OK();
|
|
}
|
|
|
|
void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
|
|
|
|
PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final {
|
|
return TapeTensorFromTensor(tensor);
|
|
}
|
|
|
|
private:
|
|
PyObject* py_vspace_;
|
|
|
|
PyObject* num_elements_;
|
|
PyObject* aggregate_fn_;
|
|
PyObject* zeros_fn_;
|
|
PyObject* zeros_like_fn_;
|
|
PyObject* ones_fn_;
|
|
PyObject* ones_like_fn_;
|
|
PyObject* graph_shape_fn_;
|
|
};
|
|
PyVSpace* py_vspace = nullptr;
|
|
|
|
bool HasAccumulator();
|
|
|
|
PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
|
|
if (py_vspace != nullptr) {
|
|
if (HasAccumulator()) {
|
|
// Accumulators reference py_vspace, so we can't swap it out while one is
|
|
// active. This is unlikely to ever happen.
|
|
MaybeRaiseExceptionFromStatus(
|
|
tensorflow::errors::Internal(
|
|
"Can't change the vspace implementation while a "
|
|
"forward accumulator is active."),
|
|
nullptr);
|
|
}
|
|
delete py_vspace;
|
|
}
|
|
|
|
py_vspace = new PyVSpace(e);
|
|
auto status = py_vspace->Initialize();
|
|
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
|
|
delete py_vspace;
|
|
return nullptr;
|
|
}
|
|
|
|
Py_RETURN_NONE;
|
|
}
|
|
|
|
PyObject* PyTapeTensor::GetShape() const {
|
|
if (shape_.index() == 0) {
|
|
auto& shape = absl::get<0>(shape_);
|
|
PyObject* py_shape = PyTuple_New(shape.dims());
|
|
for (int i = 0; i < shape.dims(); ++i) {
|
|
PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
|
|
}
|
|
|
|
return py_shape;
|
|
}
|
|
|
|
return py_vspace->GraphShape(absl::get<1>(shape_));
|
|
}
|
|
|
|
PyObject* PyTapeTensor::OnesLike() const {
|
|
if (shape_.index() == 1) {
|
|
PyObject* tensor = absl::get<1>(shape_);
|
|
return py_vspace->OnesLike(tensor);
|
|
}
|
|
PyObject* py_shape = GetShape();
|
|
PyObject* dtype_field = GetPyDType();
|
|
PyObject* result = py_vspace->Ones(py_shape, dtype_field);
|
|
Py_DECREF(dtype_field);
|
|
Py_DECREF(py_shape);
|
|
return result;
|
|
}
|
|
|
|
PyObject* PyTapeTensor::ZerosLike() const {
|
|
if (shape_.index() == 1) {
|
|
PyObject* tensor = absl::get<1>(shape_);
|
|
return py_vspace->ZerosLike(tensor);
|
|
}
|
|
PyObject* py_shape = GetShape();
|
|
PyObject* dtype_field = GetPyDType();
|
|
PyObject* result = py_vspace->Zeros(py_shape, dtype_field);
|
|
Py_DECREF(dtype_field);
|
|
Py_DECREF(py_shape);
|
|
return result;
|
|
}
|
|
|
|
// Keeps track of all variables that have been accessed during execution.
|
|
class VariableWatcher {
|
|
public:
|
|
VariableWatcher() {}
|
|
|
|
~VariableWatcher() {
|
|
for (const IdAndVariable& v : watched_variables_) {
|
|
Py_DECREF(v.variable);
|
|
}
|
|
}
|
|
|
|
tensorflow::int64 WatchVariable(PyObject* v) {
|
|
tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
|
|
if (handle == nullptr) {
|
|
return -1;
|
|
}
|
|
tensorflow::int64 id = FastTensorId(handle.get());
|
|
|
|
tensorflow::mutex_lock l(watched_variables_mu_);
|
|
auto insert_result = watched_variables_.emplace(id, v);
|
|
|
|
if (insert_result.second) {
|
|
// Only increment the reference count if we aren't already watching this
|
|
// variable.
|
|
Py_INCREF(v);
|
|
}
|
|
|
|
return id;
|
|
}
|
|
|
|
PyObject* GetVariablesAsPyTuple() {
|
|
tensorflow::mutex_lock l(watched_variables_mu_);
|
|
PyObject* result = PyTuple_New(watched_variables_.size());
|
|
Py_ssize_t pos = 0;
|
|
for (const IdAndVariable& id_and_variable : watched_variables_) {
|
|
PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
|
|
Py_INCREF(id_and_variable.variable);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
private:
|
|
// We store an IdAndVariable in the map since the map needs to be locked
|
|
// during insert, but should not call back into python during insert to avoid
|
|
// deadlocking with the GIL.
|
|
struct IdAndVariable {
|
|
tensorflow::int64 id;
|
|
PyObject* variable;
|
|
|
|
IdAndVariable(tensorflow::int64 id, PyObject* variable)
|
|
: id(id), variable(variable) {}
|
|
};
|
|
struct CompareById {
|
|
bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
|
|
return lhs.id < rhs.id;
|
|
}
|
|
};
|
|
|
|
tensorflow::mutex watched_variables_mu_;
|
|
std::set<IdAndVariable, CompareById> watched_variables_
|
|
TF_GUARDED_BY(watched_variables_mu_);
|
|
};
|
|
|
|
class GradientTape
|
|
: public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
|
|
PyTapeTensor> {
|
|
public:
|
|
explicit GradientTape(bool persistent, bool watch_accessed_variables)
|
|
: tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
|
|
PyTapeTensor>(persistent),
|
|
watch_accessed_variables_(watch_accessed_variables) {}
|
|
|
|
virtual ~GradientTape() {}
|
|
|
|
void VariableAccessed(PyObject* v) {
|
|
if (watch_accessed_variables_) {
|
|
WatchVariable(v);
|
|
}
|
|
}
|
|
|
|
void WatchVariable(PyObject* v) {
|
|
tensorflow::int64 id = variable_watcher_.WatchVariable(v);
|
|
|
|
if (!PyErr_Occurred()) {
|
|
this->Watch(id);
|
|
}
|
|
}
|
|
|
|
PyObject* GetVariablesAsPyTuple() {
|
|
return variable_watcher_.GetVariablesAsPyTuple();
|
|
}
|
|
|
|
private:
|
|
bool watch_accessed_variables_;
|
|
VariableWatcher variable_watcher_;
|
|
};
|
|
|
|
typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
|
|
PyTapeTensor>
|
|
ForwardAccumulator;
|
|
|
|
// Incremented when a GradientTape or accumulator is newly added to a set, and
|
|
// used to enforce an ordering between them.
|
|
std::atomic_uint_fast64_t tape_nesting_id_counter(0);
|
|
|
|
typedef struct {
|
|
PyObject_HEAD
|
|
/* Type-specific fields go here. */
|
|
GradientTape* tape;
|
|
// A nesting order between GradientTapes and ForwardAccumulators, used to
|
|
// ensure that GradientTapes do not watch the products of outer
|
|
// ForwardAccumulators.
|
|
tensorflow::int64 nesting_id;
|
|
} TFE_Py_Tape;
|
|
|
|
static void TFE_Py_Tape_Delete(PyObject* tape) {
|
|
delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
|
|
Py_TYPE(tape)->tp_free(tape);
|
|
}
|
|
|
|
static PyTypeObject TFE_Py_Tape_Type = {
|
|
PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
|
|
sizeof(TFE_Py_Tape), /* tp_basicsize */
|
|
0, /* tp_itemsize */
|
|
&TFE_Py_Tape_Delete, /* tp_dealloc */
|
|
#if PY_VERSION_HEX < 0x03080000
|
|
nullptr, /* tp_print */
|
|
#else
|
|
0, /* tp_vectorcall_offset */
|
|
#endif
|
|
nullptr, /* tp_getattr */
|
|
nullptr, /* tp_setattr */
|
|
nullptr, /* tp_reserved */
|
|
nullptr, /* tp_repr */
|
|
nullptr, /* tp_as_number */
|
|
nullptr, /* tp_as_sequence */
|
|
nullptr, /* tp_as_mapping */
|
|
nullptr, /* tp_hash */
|
|
nullptr, /* tp_call */
|
|
nullptr, /* tp_str */
|
|
nullptr, /* tp_getattro */
|
|
nullptr, /* tp_setattro */
|
|
nullptr, /* tp_as_buffer */
|
|
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
|
"TFE_Py_Tape objects", /* tp_doc */
|
|
};
|
|
|
|
typedef struct {
|
|
PyObject_HEAD
|
|
/* Type-specific fields go here. */
|
|
ForwardAccumulator* accumulator;
|
|
// A nesting order between GradientTapes and ForwardAccumulators, used to
|
|
// ensure that GradientTapes do not watch the products of outer
|
|
// ForwardAccumulators.
|
|
tensorflow::int64 nesting_id;
|
|
} TFE_Py_ForwardAccumulator;
|
|
|
|
static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) {
|
|
delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)->accumulator;
|
|
Py_TYPE(accumulator)->tp_free(accumulator);
|
|
}
|
|
|
|
static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
|
|
PyVarObject_HEAD_INIT(nullptr, 0) "ForwardAccumulator", /* tp_name */
|
|
sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */
|
|
0, /* tp_itemsize */
|
|
&TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */
|
|
#if PY_VERSION_HEX < 0x03080000
|
|
nullptr, /* tp_print */
|
|
#else
|
|
0, /* tp_vectorcall_offset */
|
|
#endif
|
|
nullptr, /* tp_getattr */
|
|
nullptr, /* tp_setattr */
|
|
nullptr, /* tp_reserved */
|
|
nullptr, /* tp_repr */
|
|
nullptr, /* tp_as_number */
|
|
nullptr, /* tp_as_sequence */
|
|
nullptr, /* tp_as_mapping */
|
|
nullptr, /* tp_hash */
|
|
nullptr, /* tp_call */
|
|
nullptr, /* tp_str */
|
|
nullptr, /* tp_getattro */
|
|
nullptr, /* tp_setattro */
|
|
nullptr, /* tp_as_buffer */
|
|
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
|
"TFE_Py_ForwardAccumulator objects", /* tp_doc */
|
|
};
|
|
|
|
typedef struct {
|
|
PyObject_HEAD
|
|
/* Type-specific fields go here. */
|
|
VariableWatcher* variable_watcher;
|
|
} TFE_Py_VariableWatcher;
|
|
|
|
static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
|
|
delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
|
|
->variable_watcher;
|
|
Py_TYPE(variable_watcher)->tp_free(variable_watcher);
|
|
}
|
|
|
|
static PyTypeObject TFE_Py_VariableWatcher_Type = {
|
|
PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
|
|
sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */
|
|
0, /* tp_itemsize */
|
|
&TFE_Py_VariableWatcher_Delete, /* tp_dealloc */
|
|
#if PY_VERSION_HEX < 0x03080000
|
|
nullptr, /* tp_print */
|
|
#else
|
|
0, /* tp_vectorcall_offset */
|
|
#endif
|
|
nullptr, /* tp_getattr */
|
|
nullptr, /* tp_setattr */
|
|
nullptr, /* tp_reserved */
|
|
nullptr, /* tp_repr */
|
|
nullptr, /* tp_as_number */
|
|
nullptr, /* tp_as_sequence */
|
|
nullptr, /* tp_as_mapping */
|
|
nullptr, /* tp_hash */
|
|
nullptr, /* tp_call */
|
|
nullptr, /* tp_str */
|
|
nullptr, /* tp_getattro */
|
|
nullptr, /* tp_setattro */
|
|
nullptr, /* tp_as_buffer */
|
|
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
|
"TFE_Py_VariableWatcher objects", /* tp_doc */
|
|
};
|
|
|
|
// Note: in the current design no mutex is needed here because of the python
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when we decide to not hold the GIL while manipulating the
// tape stack.
|
|
tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
|
|
thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>
|
|
tape_set = nullptr;
|
|
if (tape_set == nullptr) {
|
|
tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>);
|
|
}
|
|
return tape_set.get();
|
|
}
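
// Illustrative note (an addition, not part of the original tape machinery):
// the lazy thread_local unique_ptr pattern above is reused below for the
// accumulator and variable-watcher sets. Each Python thread gets its own
// stack of tapes, allocation is deferred until a tape is first used on that
// thread, and the GIL (see the note above) is what makes the returned set
// safe to mutate. A sketch of how callers use it:
//
//   GetTapeSet()->insert(tape);  // only affects the current thread's tapes
//   bool recording = !GetTapeSet()->empty();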
|
|
|
|
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
|
|
GetVariableWatcherSet() {
|
|
thread_local std::unique_ptr<
|
|
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
|
|
variable_watcher_set = nullptr;
|
|
if (variable_watcher_set == nullptr) {
|
|
variable_watcher_set.reset(
|
|
new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
|
|
}
|
|
return variable_watcher_set.get();
|
|
}
|
|
|
|
// A linked hash set, where iteration is in insertion order.
//
// Nested accumulators rely on op recording happening in insertion order, so an
// unordered data structure like CompactPointerSet is not suitable. Outer
// accumulators need to observe operations first so they know to watch the
// inner accumulator's jvp computation.
//
// Not thread safe.
class AccumulatorSet {
 public:
  // Returns true if `element` was newly inserted, false if it already exists.
  bool insert(TFE_Py_ForwardAccumulator* element) {
    if (map_.find(element) != map_.end()) {
      return false;
    }
    ListType::iterator it = ordered_.insert(ordered_.end(), element);
    map_.insert(std::make_pair(element, it));
    return true;
  }

  void erase(TFE_Py_ForwardAccumulator* element) {
    MapType::iterator existing = map_.find(element);
    if (existing == map_.end()) {
      return;
    }
    ListType::iterator list_position = existing->second;
    map_.erase(existing);
    ordered_.erase(list_position);
  }

  bool empty() const { return ordered_.empty(); }

  size_t size() const { return ordered_.size(); }

 private:
  typedef std::list<TFE_Py_ForwardAccumulator*> ListType;
  typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*,
                                   ListType::iterator>
      MapType;

 public:
  typedef ListType::const_iterator const_iterator;
  typedef ListType::const_reverse_iterator const_reverse_iterator;

  const_iterator begin() const { return ordered_.begin(); }
  const_iterator end() const { return ordered_.end(); }

  const_reverse_iterator rbegin() const { return ordered_.rbegin(); }
  const_reverse_iterator rend() const { return ordered_.rend(); }

 private:
  MapType map_;
  ListType ordered_;
};
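
// Example usage (an illustrative sketch only; `outer` and `inner` stand for
// hypothetical TFE_Py_ForwardAccumulator pointers): forward iteration visits
// elements in insertion order, and reverse iteration visits the most recently
// inserted accumulator first, which is how the recording code below walks from
// inner to outer accumulators.
//
//   AccumulatorSet accumulators;
//   accumulators.insert(outer);  // inserted first
//   accumulators.insert(inner);
//   for (TFE_Py_ForwardAccumulator* a : accumulators) {
//     // visits `outer`, then `inner`
//   }
//   for (auto it = accumulators.rbegin(); it != accumulators.rend(); ++it) {
//     // visits `inner`, then `outer`
//   }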
|
|
|
|
AccumulatorSet* GetAccumulatorSet() {
|
|
thread_local std::unique_ptr<AccumulatorSet> accumulator_set{nullptr};
|
|
if (accumulator_set == nullptr) {
|
|
accumulator_set.reset(new AccumulatorSet);
|
|
}
|
|
return accumulator_set.get();
|
|
}
|
|
|
|
inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); }
|
|
|
|
inline bool HasGradientTape() { return !GetTapeSet()->empty(); }
|
|
|
|
inline bool HasAccumulatorOrTape() {
|
|
return HasGradientTape() || HasAccumulator();
|
|
}
|
|
|
|
// A safe copy of a set, used for tapes and accumulators. The copy is not
// affected by other python threads changing the set of active tapes.
|
|
template <typename ContainerType>
|
|
class SafeSetCopy {
|
|
public:
|
|
explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) {
|
|
for (auto* member : set_copy_) {
|
|
Py_INCREF(member);
|
|
}
|
|
}
|
|
|
|
~SafeSetCopy() {
|
|
for (auto* member : set_copy_) {
|
|
Py_DECREF(member);
|
|
}
|
|
}
|
|
|
|
typename ContainerType::const_iterator begin() const {
|
|
return set_copy_.begin();
|
|
}
|
|
|
|
typename ContainerType::const_iterator end() const { return set_copy_.end(); }
|
|
|
|
bool empty() const { return set_copy_.empty(); }
|
|
size_t size() const { return set_copy_.size(); }
|
|
|
|
protected:
|
|
ContainerType set_copy_;
|
|
};
|
|
|
|
class SafeTapeSet
|
|
: public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> {
|
|
public:
|
|
SafeTapeSet()
|
|
: SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>(
|
|
*GetTapeSet()) {}
|
|
};
|
|
|
|
class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
|
|
public:
|
|
SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {}
|
|
|
|
typename AccumulatorSet::const_reverse_iterator rbegin() const {
|
|
return set_copy_.rbegin();
|
|
}
|
|
|
|
typename AccumulatorSet::const_reverse_iterator rend() const {
|
|
return set_copy_.rend();
|
|
}
|
|
};
|
|
|
|
class SafeVariableWatcherSet
|
|
: public SafeSetCopy<
|
|
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
|
|
public:
|
|
SafeVariableWatcherSet()
|
|
: SafeSetCopy<
|
|
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
|
|
*GetVariableWatcherSet()) {}
|
|
};
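
// Example usage (illustrative sketch): the Safe* wrappers are for iteration
// that may call back into Python, which can in turn add or remove tapes while
// we are still iterating. The copied set plus the Py_INCREF/Py_DECREF pair in
// SafeSetCopy keeps every element alive for the duration of the loop;
// `variable` below is a hypothetical PyObject*.
//
//   for (TFE_Py_Tape* tape : SafeTapeSet()) {
//     tape->tape->VariableAccessed(variable);
//   }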
|
|
|
|
bool* ThreadTapeIsStopped() {
|
|
thread_local bool thread_tape_is_stopped{false};
|
|
return &thread_tape_is_stopped;
|
|
}
|
|
|
|
void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
|
|
|
|
void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
|
|
|
|
PyObject* TFE_Py_TapeSetIsStopped() {
|
|
if (*ThreadTapeIsStopped()) {
|
|
Py_RETURN_TRUE;
|
|
}
|
|
Py_RETURN_FALSE;
|
|
}
|
|
|
|
PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
|
|
PyObject* watch_accessed_variables) {
|
|
TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
|
|
if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
|
|
TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
|
|
tape->tape = new GradientTape(persistent == Py_True,
|
|
watch_accessed_variables == Py_True);
|
|
Py_INCREF(tape);
|
|
tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
|
|
GetTapeSet()->insert(tape);
|
|
return reinterpret_cast<PyObject*>(tape);
|
|
}
|
|
|
|
void TFE_Py_TapeSetAdd(PyObject* tape) {
|
|
Py_INCREF(tape);
|
|
TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape);
|
|
if (!GetTapeSet()->insert(tfe_tape).second) {
|
|
// Already exists in the tape set.
|
|
Py_DECREF(tape);
|
|
} else {
|
|
tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
|
|
}
|
|
}
|
|
|
|
PyObject* TFE_Py_TapeSetIsEmpty() {
|
|
if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
|
|
Py_RETURN_TRUE;
|
|
}
|
|
Py_RETURN_FALSE;
|
|
}
|
|
|
|
void TFE_Py_TapeSetRemove(PyObject* tape) {
|
|
auto* stack = GetTapeSet();
|
|
stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
|
|
  // We kept a reference to the tape in the set to ensure it wouldn't get
  // deleted under us; cleaning it up here.
|
|
Py_DECREF(tape);
|
|
}
|
|
|
|
static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
|
|
if (list == Py_None) {
|
|
return {};
|
|
}
|
|
PyObject* seq = PySequence_Fast(list, "expected a sequence");
|
|
if (seq == nullptr) {
|
|
return {};
|
|
}
|
|
int len = PySequence_Size(list);
|
|
PyObject** seq_array = PySequence_Fast_ITEMS(seq);
|
|
std::vector<tensorflow::int64> tensor_ids;
|
|
tensor_ids.reserve(len);
|
|
for (int i = 0; i < len; ++i) {
|
|
PyObject* item = seq_array[i];
|
|
#if PY_MAJOR_VERSION >= 3
|
|
if (PyLong_Check(item)) {
|
|
#else
|
|
if (PyLong_Check(item) || PyInt_Check(item)) {
|
|
#endif
|
|
tensorflow::int64 id = MakeInt(item);
|
|
tensor_ids.push_back(id);
|
|
} else {
|
|
tensor_ids.push_back(-1);
|
|
}
|
|
}
|
|
Py_DECREF(seq);
|
|
return tensor_ids;
|
|
}
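
// Example (illustrative): MakeIntList(Py_None) returns an empty vector, and a
// Python sequence such as (2, None, 3) becomes {2, -1, 3}, since entries that
// are not integers are mapped to -1. Callers like TapeTensorFromTensor decide
// later how to interpret the -1 placeholders.
//
//   std::vector<tensorflow::int64> dims = MakeIntList(shape_tuple);
//   // e.g. {2, -1, 3} for a hypothetical `shape_tuple` of (2, None, 3)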
|
|
|
|
// Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
// null. Returns true on success and false on a Python exception.
|
|
bool TensorShapesAndDtypes(PyObject* tensors,
|
|
std::vector<tensorflow::int64>* tensor_ids,
|
|
std::vector<tensorflow::DataType>* dtypes) {
|
|
tensorflow::Safe_PyObjectPtr seq(
|
|
PySequence_Fast(tensors, "expected a sequence"));
|
|
if (seq == nullptr) {
|
|
return false;
|
|
}
|
|
int len = PySequence_Fast_GET_SIZE(seq.get());
|
|
PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
|
|
tensor_ids->reserve(len);
|
|
dtypes->reserve(len);
|
|
for (int i = 0; i < len; ++i) {
|
|
PyObject* item = seq_array[i];
|
|
tensor_ids->push_back(FastTensorId(item));
|
|
dtypes->push_back(tensorflow::PyTensor_DataType(item));
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool TapeCouldPossiblyRecord(PyObject* tensors) {
|
|
if (tensors == Py_None) {
|
|
return false;
|
|
}
|
|
if (*ThreadTapeIsStopped()) {
|
|
return false;
|
|
}
|
|
if (!HasAccumulatorOrTape()) {
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); }
|
|
|
|
bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); }
|
|
|
|
PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) {
|
|
if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) {
|
|
Py_RETURN_FALSE;
|
|
}
|
|
// TODO(apassos) consider not building a list and changing the API to check
|
|
// each tensor individually.
|
|
std::vector<tensorflow::int64> tensor_ids;
|
|
std::vector<tensorflow::DataType> dtypes;
|
|
if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
|
|
return nullptr;
|
|
}
|
|
auto tape_set = *GetTapeSet();
|
|
for (TFE_Py_Tape* tape : tape_set) {
|
|
if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
|
|
Py_RETURN_TRUE;
|
|
}
|
|
}
|
|
|
|
Py_RETURN_FALSE;
|
|
}
|
|
|
|
PyObject* TFE_Py_ForwardAccumulatorPushState() {
|
|
auto forward_accumulators = *GetAccumulatorSet();
|
|
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
|
|
accumulator->accumulator->PushState();
|
|
}
|
|
Py_RETURN_NONE;
|
|
}
|
|
|
|
PyObject* TFE_Py_ForwardAccumulatorPopState() {
|
|
auto forward_accumulators = *GetAccumulatorSet();
|
|
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
|
|
accumulator->accumulator->PopState();
|
|
}
|
|
Py_RETURN_NONE;
|
|
}
|
|
|
|
PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
|
|
if (!TapeCouldPossiblyRecord(tensors)) {
|
|
return GetPythonObjectFromInt(0);
|
|
}
|
|
std::vector<tensorflow::int64> tensor_ids;
|
|
std::vector<tensorflow::DataType> dtypes;
|
|
if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
|
|
return nullptr;
|
|
}
|
|
|
|
  // If there is a persistent tape watching, or if there are multiple tapes
  // watching, we'll return immediately indicating that higher-order tape
  // gradients are possible.
|
|
bool some_tape_watching = false;
|
|
if (CouldBackprop()) {
|
|
auto tape_set = *GetTapeSet();
|
|
for (TFE_Py_Tape* tape : tape_set) {
|
|
if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
|
|
if (tape->tape->IsPersistent() || some_tape_watching) {
|
|
// Either this is the second tape watching, or this tape is
|
|
// persistent: higher-order gradients are possible.
|
|
return GetPythonObjectFromInt(2);
|
|
}
|
|
some_tape_watching = true;
|
|
}
|
|
}
|
|
}
|
|
if (CouldForwardprop()) {
|
|
auto forward_accumulators = *GetAccumulatorSet();
|
|
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
|
|
if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
|
|
if (some_tape_watching) {
|
|
// This is the second tape watching: higher-order gradients are
|
|
// possible. Note that there's no equivalent of persistence for
|
|
// forward-mode.
|
|
return GetPythonObjectFromInt(2);
|
|
}
|
|
some_tape_watching = true;
|
|
}
|
|
}
|
|
}
|
|
if (some_tape_watching) {
|
|
// There's exactly one non-persistent tape. The user can request first-order
|
|
// gradients but won't be able to get higher-order tape gradients.
|
|
return GetPythonObjectFromInt(1);
|
|
} else {
|
|
// There are no tapes. The user can't request tape gradients.
|
|
return GetPythonObjectFromInt(0);
|
|
}
|
|
}
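
// Illustrative summary of the return value (an int wrapped as a PyObject):
//
//   PyObject* kind = TFE_Py_TapeSetPossibleGradientTypes(tensors);
//   // 0 -> nothing is watching these tensors; no tape gradients possible.
//   // 1 -> exactly one non-persistent watcher (a tape or a single
//   //      accumulator); first-order gradients only.
//   // 2 -> a persistent tape, several tapes, or a tape plus an accumulator
//   //      is watching; higher-order gradients are possible.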
|
|
|
|
void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
|
|
if (!CouldBackprop()) {
|
|
return;
|
|
}
|
|
tensorflow::int64 tensor_id = FastTensorId(tensor);
|
|
if (PyErr_Occurred()) {
|
|
return;
|
|
}
|
|
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
|
|
}
|
|
|
|
bool ListContainsNone(PyObject* list) {
|
|
if (list == Py_None) return true;
|
|
tensorflow::Safe_PyObjectPtr seq(
|
|
PySequence_Fast(list, "expected a sequence"));
|
|
if (seq == nullptr) {
|
|
return false;
|
|
}
|
|
|
|
int len = PySequence_Size(list);
|
|
PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
|
|
for (int i = 0; i < len; ++i) {
|
|
PyObject* item = seq_array[i];
|
|
if (item == Py_None) return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
|
|
if (EagerTensor_CheckExact(tensor)) {
|
|
tensorflow::ImmediateExecutionTensorHandle* handle =
|
|
tensorflow::unwrap(EagerTensor_Handle(tensor));
|
|
tensorflow::int64 id = PyEagerTensor_ID(tensor);
|
|
tensorflow::DataType dtype =
|
|
static_cast<tensorflow::DataType>(handle->DataType());
|
|
if (dtype == tensorflow::DT_VARIANT) {
|
|
return PyTapeTensor(id, dtype, tensor);
|
|
}
|
|
|
|
tensorflow::TensorShape tensor_shape;
|
|
int num_dims;
|
|
tensorflow::Status status = handle->NumDims(&num_dims);
|
|
if (status.ok()) {
|
|
for (int i = 0; i < num_dims; ++i) {
|
|
tensorflow::int64 dim_size;
|
|
status = handle->Dim(i, &dim_size);
|
|
if (!status.ok()) break;
|
|
tensor_shape.AddDim(dim_size);
|
|
}
|
|
}
|
|
|
|
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
|
|
return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
|
|
tensorflow::TensorShape({}));
|
|
} else {
|
|
return PyTapeTensor(id, dtype, tensor_shape);
|
|
}
|
|
}
|
|
tensorflow::int64 id = FastTensorId(tensor);
|
|
if (PyErr_Occurred()) {
|
|
return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
|
|
tensorflow::TensorShape({}));
|
|
}
|
|
PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
|
|
PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
|
|
Py_DECREF(dtype_object);
|
|
tensorflow::DataType dtype =
|
|
static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
|
|
Py_DECREF(dtype_enum);
|
|
if (PyErr_Occurred()) {
|
|
return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
|
|
tensorflow::TensorShape({}));
|
|
}
|
|
static char _shape_tuple[] = "_shape_tuple";
|
|
tensorflow::Safe_PyObjectPtr shape_tuple(
|
|
PyObject_CallMethod(tensor, _shape_tuple, nullptr));
|
|
if (PyErr_Occurred()) {
|
|
return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
|
|
tensorflow::TensorShape({}));
|
|
}
|
|
|
|
if (ListContainsNone(shape_tuple.get()) || dtype == tensorflow::DT_VARIANT) {
|
|
return PyTapeTensor(id, dtype, tensor);
|
|
}
|
|
|
|
auto l = MakeIntList(shape_tuple.get());
|
|
  // Replace -1, which represents accidental Nones that can occur in graph
  // mode and can cause errors in shape construction, with 0.
|
|
for (auto& c : l) {
|
|
if (c < 0) {
|
|
c = 0;
|
|
}
|
|
}
|
|
tensorflow::TensorShape shape(l);
|
|
return PyTapeTensor(id, dtype, shape);
|
|
}
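
// Illustrative summary of the cases above (tensor names are hypothetical):
//
//   TapeTensorFromTensor(eager_tensor)    // id, dtype and static shape taken
//                                         // directly from the handle
//   TapeTensorFromTensor(variant_tensor)  // DT_VARIANT: keeps a reference to
//                                         // the tensor instead of a shape
//   TapeTensorFromTensor(graph_tensor)    // uses `dtype` and `_shape_tuple()`;
//                                         // a None in the tuple keeps the
//                                         // tensor, otherwise negative dims
//                                         // are clamped to 0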
|
|
|
|
// Populates output_info from output_seq, which must come from PySequence_Fast.
//
// Does not take ownership of output_seq. Returns true on success and false if
// a Python exception has been set.
|
|
bool TapeTensorsFromTensorSequence(PyObject* output_seq,
|
|
std::vector<PyTapeTensor>* output_info) {
|
|
Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq);
|
|
PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
|
|
output_info->reserve(output_len);
|
|
for (Py_ssize_t i = 0; i < output_len; ++i) {
|
|
output_info->push_back(TapeTensorFromTensor(output_seq_array[i]));
|
|
if (PyErr_Occurred() != nullptr) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
|
|
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
|
|
if (seq == nullptr) {
|
|
return {};
|
|
}
|
|
int len = PySequence_Fast_GET_SIZE(seq);
|
|
PyObject** seq_array = PySequence_Fast_ITEMS(seq);
|
|
std::vector<tensorflow::int64> list;
|
|
list.reserve(len);
|
|
for (int i = 0; i < len; ++i) {
|
|
PyObject* tensor = seq_array[i];
|
|
list.push_back(FastTensorId(tensor));
|
|
if (PyErr_Occurred()) {
|
|
Py_DECREF(seq);
|
|
return list;
|
|
}
|
|
}
|
|
Py_DECREF(seq);
|
|
return list;
|
|
}
|
|
|
|
void TFE_Py_TapeVariableAccessed(PyObject* variable) {
|
|
if (!CouldBackprop()) {
|
|
return;
|
|
}
|
|
for (TFE_Py_Tape* tape : SafeTapeSet()) {
|
|
tape->tape->VariableAccessed(variable);
|
|
}
|
|
}
|
|
|
|
void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
|
|
if (!CouldBackprop()) {
|
|
return;
|
|
}
|
|
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
|
|
}
|
|
|
|
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
|
|
return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
|
|
}
|
|
|
|
PyObject* TFE_Py_VariableWatcherNew() {
|
|
TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
|
|
if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
|
|
TFE_Py_VariableWatcher* variable_watcher =
|
|
PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
|
|
variable_watcher->variable_watcher = new VariableWatcher();
|
|
Py_INCREF(variable_watcher);
|
|
GetVariableWatcherSet()->insert(variable_watcher);
|
|
return reinterpret_cast<PyObject*>(variable_watcher);
|
|
}
|
|
|
|
void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
|
|
auto* stack = GetVariableWatcherSet();
|
|
stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
|
|
// We kept a reference to the variable watcher in the set to ensure it
|
|
// wouldn't get deleted under us; cleaning it up here.
|
|
Py_DECREF(variable_watcher);
|
|
}
|
|
|
|
void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
|
|
for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
|
|
variable_watcher->variable_watcher->WatchVariable(variable);
|
|
}
|
|
}
|
|
|
|
PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
|
|
return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
|
|
->variable_watcher->GetVariablesAsPyTuple();
|
|
}
|
|
|
|
namespace {
|
|
std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
|
|
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
|
|
if (seq == nullptr) {
|
|
return {};
|
|
}
|
|
int len = PySequence_Fast_GET_SIZE(seq);
|
|
PyObject** seq_array = PySequence_Fast_ITEMS(seq);
|
|
std::vector<tensorflow::DataType> list;
|
|
list.reserve(len);
|
|
for (int i = 0; i < len; ++i) {
|
|
PyObject* tensor = seq_array[i];
|
|
list.push_back(tensorflow::PyTensor_DataType(tensor));
|
|
}
|
|
Py_DECREF(seq);
|
|
return list;
|
|
}
|
|
|
|
PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id,
|
|
PyObject* weak_tensor_ref) {
|
|
tensorflow::int64 parsed_tensor_id = MakeInt(tensor_id);
|
|
for (TFE_Py_ForwardAccumulator* accumulator : *GetAccumulatorSet()) {
|
|
accumulator->accumulator->DeleteGradient(parsed_tensor_id);
|
|
}
|
|
Py_DECREF(weak_tensor_ref);
|
|
Py_DECREF(tensor_id);
|
|
Py_INCREF(Py_None);
|
|
return Py_None;
|
|
}
|
|
|
|
static PyMethodDef forward_accumulator_delete_gradient_method_def = {
|
|
"ForwardAccumulatorDeleteGradient", ForwardAccumulatorDeleteGradient,
|
|
METH_O, "ForwardAccumulatorDeleteGradient"};
|
|
|
|
void RegisterForwardAccumulatorCleanup(PyObject* tensor,
|
|
tensorflow::int64 tensor_id) {
|
|
tensorflow::Safe_PyObjectPtr callback(
|
|
PyCFunction_New(&forward_accumulator_delete_gradient_method_def,
|
|
PyLong_FromLong(tensor_id)));
|
|
  // We need to keep a reference to the weakref active if we want our callback
  // called. The callback itself now owns the weakref object and the tensor ID
  // object.
|
|
PyWeakref_NewRef(tensor, callback.get());
|
|
}
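
// Illustrative sketch of the lifetime protocol set up above: when `tensor` is
// garbage collected, CPython invokes the weakref's callback with the (now
// dead) weakref as its argument, which runs ForwardAccumulatorDeleteGradient
// with the tensor id bound as `self`. Roughly:
//
//   PyObject* payload = PyLong_FromLong(tensor_id);        // becomes `self`
//   PyObject* cb = PyCFunction_New(&method_def, payload);  // METH_O callback
//   PyWeakref_NewRef(tensor, cb);  // cb(weakref) fires when `tensor` dies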
|
|
|
|
void TapeSetRecordBackprop(
|
|
const string& op_type, const std::vector<PyTapeTensor>& output_info,
|
|
const std::vector<tensorflow::int64>& input_ids,
|
|
const std::vector<tensorflow::DataType>& input_dtypes,
|
|
const std::function<PyBackwardFunction*()>& backward_function_getter,
|
|
const std::function<void(PyBackwardFunction*)>& backward_function_killer,
|
|
tensorflow::uint64 max_gradient_tape_id) {
|
|
if (!CouldBackprop()) {
|
|
return;
|
|
}
|
|
for (TFE_Py_Tape* tape : SafeTapeSet()) {
|
|
if (tape->nesting_id < max_gradient_tape_id) {
|
|
tape->tape->RecordOperation(op_type, output_info, input_ids, input_dtypes,
|
|
backward_function_getter,
|
|
backward_function_killer);
|
|
}
|
|
}
|
|
}
|
|
|
|
bool TapeSetRecordForwardprop(
|
|
const string& op_type, PyObject* output_seq,
|
|
const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors,
|
|
const std::vector<tensorflow::int64>& input_ids,
|
|
const std::vector<tensorflow::DataType>& input_dtypes,
|
|
const std::function<PyBackwardFunction*()>& backward_function_getter,
|
|
const std::function<void(PyBackwardFunction*)>& backward_function_killer,
|
|
const tensorflow::eager::ForwardFunction<PyObject>* forward_function,
|
|
PyObject* forwardprop_output_indices,
|
|
tensorflow::uint64* max_gradient_tape_id) {
|
|
*max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max();
|
|
if (!CouldForwardprop()) {
|
|
return true;
|
|
}
|
|
auto accumulator_set = SafeAccumulatorSet();
|
|
tensorflow::Safe_PyObjectPtr input_seq(
|
|
PySequence_Fast(input_tensors, "expected a sequence of tensors"));
|
|
if (input_seq == nullptr || PyErr_Occurred()) return false;
|
|
Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get());
|
|
PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
|
|
for (int i = 0; i < output_info.size(); ++i) {
|
|
RegisterForwardAccumulatorCleanup(output_seq_array[i],
|
|
output_info[i].GetID());
|
|
}
|
|
if (forwardprop_output_indices != nullptr &&
|
|
forwardprop_output_indices != Py_None) {
|
|
tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast(
|
|
forwardprop_output_indices, "Expected a sequence of indices"));
|
|
if (indices_fast == nullptr || PyErr_Occurred()) {
|
|
return false;
|
|
}
|
|
if (PySequence_Fast_GET_SIZE(indices_fast.get()) !=
|
|
accumulator_set.size()) {
|
|
MaybeRaiseExceptionFromStatus(
|
|
tensorflow::errors::Internal(
|
|
"Accumulators were added or removed from the active set "
|
|
"between packing and unpacking."),
|
|
nullptr);
|
|
}
|
|
PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get());
|
|
Py_ssize_t accumulator_index = 0;
|
|
for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin();
|
|
it != accumulator_set.rend(); ++it, ++accumulator_index) {
|
|
tensorflow::Safe_PyObjectPtr jvp_index_seq(
|
|
PySequence_Fast(indices_fast_array[accumulator_index],
|
|
"Expected a sequence of jvp indices."));
|
|
if (jvp_index_seq == nullptr || PyErr_Occurred()) {
|
|
return false;
|
|
}
|
|
Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get());
|
|
PyObject** jvp_index_seq_array =
|
|
PySequence_Fast_ITEMS(jvp_index_seq.get());
|
|
for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) {
|
|
PyObject* tuple = jvp_index_seq_array[jvp_index];
|
|
tensorflow::int64 primal_tensor_id =
|
|
output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID();
|
|
(*it)->accumulator->Watch(
|
|
primal_tensor_id,
|
|
output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]);
|
|
}
|
|
}
|
|
} else {
|
|
std::vector<PyTapeTensor> input_info;
|
|
input_info.reserve(input_len);
|
|
PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get());
|
|
for (Py_ssize_t i = 0; i < input_len; ++i) {
|
|
input_info.push_back(TapeTensorFromTensor(input_seq_array[i]));
|
|
}
|
|
for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) {
|
|
tensorflow::Status status = accumulator->accumulator->Accumulate(
|
|
op_type, input_info, output_info, input_ids, input_dtypes,
|
|
forward_function, backward_function_getter, backward_function_killer);
|
|
if (PyErr_Occurred()) return false; // Don't swallow Python exceptions.
|
|
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
|
|
return false;
|
|
}
|
|
if (accumulator->accumulator->BusyAccumulating()) {
|
|
        // Ensure inner accumulators don't see outer accumulators' jvps. This
        // mostly happens on its own, with some potentially surprising
        // exceptions, so the blanket policy is for consistency.
|
|
*max_gradient_tape_id = accumulator->nesting_id;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) {
|
|
PyObject* py_input_tangents = PyTuple_New(input_tangents.size());
|
|
for (int i = 0; i < input_tangents.size(); ++i) {
|
|
PyObject* element;
|
|
if (input_tangents[i] == nullptr) {
|
|
element = Py_None;
|
|
} else {
|
|
element = input_tangents[i];
|
|
}
|
|
Py_INCREF(element);
|
|
PyTuple_SET_ITEM(py_input_tangents, i, element);
|
|
}
|
|
return py_input_tangents;
|
|
}
|
|
|
|
tensorflow::Status ParseTangentOutputs(
|
|
PyObject* user_output, std::vector<PyObject*>* output_tangents) {
|
|
if (user_output == Py_None) {
|
|
// No connected gradients.
|
|
return tensorflow::Status::OK();
|
|
}
|
|
tensorflow::Safe_PyObjectPtr fast_result(
|
|
PySequence_Fast(user_output, "expected a sequence of forward gradients"));
|
|
if (fast_result == nullptr) {
|
|
return tensorflow::errors::InvalidArgument(
|
|
"forward gradient function did not return a sequence.");
|
|
}
|
|
int len = PySequence_Fast_GET_SIZE(fast_result.get());
|
|
PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get());
|
|
output_tangents->reserve(len);
|
|
for (int i = 0; i < len; ++i) {
|
|
PyObject* item = fast_result_array[i];
|
|
if (item == Py_None) {
|
|
output_tangents->push_back(nullptr);
|
|
} else {
|
|
Py_INCREF(item);
|
|
output_tangents->push_back(item);
|
|
}
|
|
}
|
|
return tensorflow::Status::OK();
|
|
}
|
|
|
|
// Calls the registered forward_gradient_function, computing `output_tangents`
// from `input_tangents`. `output_tangents` must not be null.
//
// `op_name`, `attrs`, `inputs`, and `results` describe the operation for which
// the forward function is being called.
|
tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
|
|
PyObject* inputs, PyObject* results,
|
|
const std::vector<PyObject*>& input_tangents,
|
|
std::vector<PyObject*>* output_tangents,
|
|
bool use_batch) {
|
|
if (forward_gradient_function == nullptr) {
|
|
return tensorflow::errors::Internal(
|
|
"No forward gradient function registered.");
|
|
}
|
|
tensorflow::Safe_PyObjectPtr py_input_tangents(
|
|
TangentsAsPyTuple(input_tangents));
|
|
|
|
// Normalize the input sequence to a tuple so it works with function
|
|
// caching; otherwise it may be an opaque _InputList object.
|
|
tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
|
|
PyObject* to_batch = (use_batch) ? Py_True : Py_False;
|
|
tensorflow::Safe_PyObjectPtr callback_args(
|
|
Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results,
|
|
py_input_tangents.get(), to_batch));
|
|
tensorflow::Safe_PyObjectPtr py_result(
|
|
PyObject_CallObject(forward_gradient_function, callback_args.get()));
|
|
if (py_result == nullptr || PyErr_Occurred()) {
|
|
return tensorflow::errors::Internal(
|
|
"forward gradient function threw exceptions");
|
|
}
|
|
return ParseTangentOutputs(py_result.get(), output_tangents);
|
|
}
|
|
|
|
// Like CallJVPFunction, but calls a pre-bound forward function.
// These are passed in from a record_gradient argument.
|
|
tensorflow::Status CallOpSpecificJVPFunction(
|
|
PyObject* op_specific_forward_function,
|
|
const std::vector<PyObject*>& input_tangents,
|
|
std::vector<PyObject*>* output_tangents) {
|
|
tensorflow::Safe_PyObjectPtr py_input_tangents(
|
|
TangentsAsPyTuple(input_tangents));
|
|
|
|
tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject(
|
|
op_specific_forward_function, py_input_tangents.get()));
|
|
if (py_result == nullptr || PyErr_Occurred()) {
|
|
return tensorflow::errors::Internal(
|
|
"forward gradient function threw exceptions");
|
|
}
|
|
return ParseTangentOutputs(py_result.get(), output_tangents);
|
|
}
|
|
|
|
bool ParseOpTypeString(PyObject* op_type, string* op_type_string) {
|
|
if (PyBytes_Check(op_type)) {
|
|
*op_type_string = PyBytes_AsString(op_type);
|
|
} else if (PyUnicode_Check(op_type)) {
|
|
#if PY_MAJOR_VERSION >= 3
|
|
*op_type_string = PyUnicode_AsUTF8(op_type);
|
|
#else
|
|
PyObject* py_str = PyUnicode_AsUTF8String(op_type);
|
|
if (py_str == nullptr) {
|
|
return false;
|
|
}
|
|
*op_type_string = PyBytes_AS_STRING(py_str);
|
|
Py_DECREF(py_str);
|
|
#endif
|
|
} else {
|
|
PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool TapeSetRecordOperation(
|
|
PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors,
|
|
const std::vector<tensorflow::int64>& input_ids,
|
|
const std::vector<tensorflow::DataType>& input_dtypes,
|
|
const std::function<PyBackwardFunction*()>& backward_function_getter,
|
|
const std::function<void(PyBackwardFunction*)>& backward_function_killer,
|
|
const tensorflow::eager::ForwardFunction<PyObject>* forward_function) {
|
|
std::vector<PyTapeTensor> output_info;
|
|
tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
|
|
output_tensors, "expected a sequence of integer tensor ids"));
|
|
if (PyErr_Occurred() ||
|
|
!TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
|
|
return false;
|
|
}
|
|
string op_type_str;
|
|
if (!ParseOpTypeString(op_type, &op_type_str)) {
|
|
return false;
|
|
}
|
|
tensorflow::uint64 max_gradient_tape_id;
|
|
if (!TapeSetRecordForwardprop(
|
|
op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
|
|
input_dtypes, backward_function_getter, backward_function_killer,
|
|
forward_function, nullptr /* No special-cased jvps. */,
|
|
&max_gradient_tape_id)) {
|
|
return false;
|
|
}
|
|
TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
|
|
backward_function_getter, backward_function_killer,
|
|
max_gradient_tape_id);
|
|
return true;
|
|
}
|
|
} // namespace
|
|
|
|
PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
|
|
PyObject* output_tensors,
|
|
PyObject* input_tensors,
|
|
PyObject* backward_function,
|
|
PyObject* forward_function) {
|
|
if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) {
|
|
Py_RETURN_NONE;
|
|
}
|
|
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
|
|
if (PyErr_Occurred()) return nullptr;
|
|
|
|
std::vector<tensorflow::DataType> input_dtypes =
|
|
MakeTensorDtypeList(input_tensors);
|
|
if (PyErr_Occurred()) return nullptr;
|
|
|
|
std::function<PyBackwardFunction*()> backward_function_getter(
|
|
[backward_function]() {
|
|
Py_INCREF(backward_function);
|
|
PyBackwardFunction* function = new PyBackwardFunction(
|
|
[backward_function](PyObject* out_grads,
|
|
const std::vector<tensorflow::int64>& unused) {
|
|
return PyObject_CallObject(backward_function, out_grads);
|
|
});
|
|
return function;
|
|
});
|
|
std::function<void(PyBackwardFunction*)> backward_function_killer(
|
|
[backward_function](PyBackwardFunction* py_backward_function) {
|
|
Py_DECREF(backward_function);
|
|
delete py_backward_function;
|
|
});
|
|
|
|
if (forward_function == Py_None) {
|
|
if (!TapeSetRecordOperation(
|
|
op_type, input_tensors, output_tensors, input_ids, input_dtypes,
|
|
backward_function_getter, backward_function_killer,
|
|
nullptr /* No special-cased forward function */)) {
|
|
return nullptr;
|
|
}
|
|
} else {
|
|
tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
|
|
[forward_function](const std::vector<PyObject*>& input_tangents,
|
|
std::vector<PyObject*>* output_tangents,
|
|
bool use_batch = false) {
|
|
return CallOpSpecificJVPFunction(forward_function, input_tangents,
|
|
output_tangents);
|
|
});
|
|
if (!TapeSetRecordOperation(
|
|
op_type, input_tensors, output_tensors, input_ids, input_dtypes,
|
|
backward_function_getter, backward_function_killer,
|
|
&wrapped_forward_function)) {
|
|
return nullptr;
|
|
}
|
|
}
|
|
Py_RETURN_NONE;
|
|
}
|
|
|
|
PyObject* TFE_Py_TapeSetRecordOperationForwardprop(
|
|
PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors,
|
|
PyObject* backward_function, PyObject* forwardprop_output_indices) {
|
|
if (!HasAccumulator() || *ThreadTapeIsStopped()) {
|
|
Py_RETURN_NONE;
|
|
}
|
|
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
|
|
if (PyErr_Occurred()) return nullptr;
|
|
|
|
std::vector<tensorflow::DataType> input_dtypes =
|
|
MakeTensorDtypeList(input_tensors);
|
|
if (PyErr_Occurred()) return nullptr;
|
|
|
|
std::function<PyBackwardFunction*()> backward_function_getter(
|
|
[backward_function]() {
|
|
Py_INCREF(backward_function);
|
|
PyBackwardFunction* function = new PyBackwardFunction(
|
|
[backward_function](PyObject* out_grads,
|
|
const std::vector<tensorflow::int64>& unused) {
|
|
return PyObject_CallObject(backward_function, out_grads);
|
|
});
|
|
return function;
|
|
});
|
|
std::function<void(PyBackwardFunction*)> backward_function_killer(
|
|
[backward_function](PyBackwardFunction* py_backward_function) {
|
|
Py_DECREF(backward_function);
|
|
delete py_backward_function;
|
|
});
|
|
std::vector<PyTapeTensor> output_info;
|
|
tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
|
|
output_tensors, "expected a sequence of integer tensor ids"));
|
|
if (PyErr_Occurred() ||
|
|
!TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
|
|
return nullptr;
|
|
}
|
|
string op_type_str;
|
|
if (!ParseOpTypeString(op_type, &op_type_str)) {
|
|
return nullptr;
|
|
}
|
|
tensorflow::uint64 max_gradient_tape_id;
|
|
if (!TapeSetRecordForwardprop(
|
|
op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
|
|
input_dtypes, backward_function_getter, backward_function_killer,
|
|
nullptr /* no special-cased forward function */,
|
|
forwardprop_output_indices, &max_gradient_tape_id)) {
|
|
return nullptr;
|
|
}
|
|
Py_RETURN_NONE;
|
|
}
|
|
|
|
PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type,
|
|
PyObject* output_tensors,
|
|
PyObject* input_tensors,
|
|
PyObject* backward_function) {
|
|
if (!CouldBackprop()) {
|
|
Py_RETURN_NONE;
|
|
}
|
|
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
|
|
if (PyErr_Occurred()) return nullptr;
|
|
|
|
std::vector<tensorflow::DataType> input_dtypes =
|
|
MakeTensorDtypeList(input_tensors);
|
|
if (PyErr_Occurred()) return nullptr;
|
|
|
|
std::function<PyBackwardFunction*()> backward_function_getter(
|
|
[backward_function]() {
|
|
Py_INCREF(backward_function);
|
|
PyBackwardFunction* function = new PyBackwardFunction(
|
|
[backward_function](PyObject* out_grads,
|
|
const std::vector<tensorflow::int64>& unused) {
|
|
return PyObject_CallObject(backward_function, out_grads);
|
|
});
|
|
return function;
|
|
});
|
|
std::function<void(PyBackwardFunction*)> backward_function_killer(
|
|
[backward_function](PyBackwardFunction* py_backward_function) {
|
|
Py_DECREF(backward_function);
|
|
delete py_backward_function;
|
|
});
|
|
std::vector<PyTapeTensor> output_info;
|
|
tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
|
|
output_tensors, "expected a sequence of integer tensor ids"));
|
|
if (PyErr_Occurred() ||
|
|
!TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
|
|
return nullptr;
|
|
}
|
|
string op_type_str;
|
|
if (!ParseOpTypeString(op_type, &op_type_str)) {
|
|
return nullptr;
|
|
}
|
|
TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
|
|
backward_function_getter, backward_function_killer,
|
|
// No filtering based on relative ordering with forward
|
|
// accumulators.
|
|
std::numeric_limits<tensorflow::uint64>::max());
|
|
Py_RETURN_NONE;
|
|
}
|
|
|
|
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
|
|
for (TFE_Py_Tape* tape : *GetTapeSet()) {
|
|
tape->tape->DeleteTrace(tensor_id);
|
|
}
|
|
}
|
|
|
|
std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
|
|
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
|
|
if (seq == nullptr) {
|
|
return {};
|
|
}
|
|
int len = PySequence_Fast_GET_SIZE(seq);
|
|
PyObject** seq_array = PySequence_Fast_ITEMS(seq);
|
|
std::vector<PyObject*> list(seq_array, seq_array + len);
|
|
Py_DECREF(seq);
|
|
return list;
|
|
}
|
|
|
|
PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
|
|
PyObject* sources, PyObject* output_gradients,
|
|
PyObject* sources_raw,
|
|
PyObject* unconnected_gradients,
|
|
TF_Status* status) {
|
|
TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
|
|
if (!tape_obj->tape->IsPersistent()) {
|
|
auto* tape_set = GetTapeSet();
|
|
if (tape_set->find(tape_obj) != tape_set->end()) {
|
|
      PyErr_SetString(PyExc_RuntimeError,
                      "gradient() cannot be invoked within the "
                      "GradientTape context (i.e., while operations are being "
                      "recorded). Either move the call to gradient() to be "
                      "outside the 'with tf.GradientTape' block, or "
                      "use a persistent tape: "
                      "'with tf.GradientTape(persistent=True)'");
|
|
return nullptr;
|
|
}
|
|
}
|
|
|
|
std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
|
|
if (PyErr_Occurred()) {
|
|
return nullptr;
|
|
}
|
|
std::vector<tensorflow::int64> sources_vec = MakeTensorIDList(sources);
|
|
if (PyErr_Occurred()) {
|
|
return nullptr;
|
|
}
|
|
tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(sources_vec.begin(),
|
|
sources_vec.end());
|
|
|
|
tensorflow::Safe_PyObjectPtr seq =
|
|
tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
|
|
int len = PySequence_Fast_GET_SIZE(seq.get());
|
|
PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
|
|
std::unordered_map<tensorflow::int64, PyTapeTensor>
|
|
source_tensors_that_are_targets;
|
|
for (int i = 0; i < len; ++i) {
|
|
tensorflow::int64 target_id = target_vec[i];
|
|
if (sources_set.find(target_id) != sources_set.end()) {
|
|
auto tensor = seq_array[i];
|
|
source_tensors_that_are_targets.insert(
|
|
std::make_pair(target_id, TapeTensorFromTensor(tensor)));
|
|
}
|
|
if (PyErr_Occurred()) {
|
|
return nullptr;
|
|
}
|
|
}
|
|
if (PyErr_Occurred()) {
|
|
return nullptr;
|
|
}
|
|
|
|
std::vector<PyObject*> outgrad_vec;
|
|
if (output_gradients != Py_None) {
|
|
outgrad_vec = MakeTensorList(output_gradients);
|
|
if (PyErr_Occurred()) {
|
|
return nullptr;
|
|
}
|
|
for (PyObject* tensor : outgrad_vec) {
|
|
      // Calling the backward function will eat a reference to the tensors in
      // outgrad_vec, so we need to increase their reference count.
|
|
Py_INCREF(tensor);
|
|
}
|
|
}
|
|
std::vector<PyObject*> result(sources_vec.size());
|
|
status->status = tape_obj->tape->ComputeGradient(
|
|
*py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
|
|
outgrad_vec, absl::MakeSpan(result));
|
|
if (!status->status.ok()) {
|
|
if (PyErr_Occurred()) {
|
|
      // Do not propagate the erroneous status as that would swallow the
      // exception which caused the problem.
|
|
status->status = tensorflow::Status::OK();
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
bool unconnected_gradients_zero =
|
|
strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
|
|
std::vector<PyObject*> sources_obj;
|
|
if (unconnected_gradients_zero) {
|
|
// Uses the "raw" sources here so it can properly make a zeros tensor even
|
|
// if there are resource variables as sources.
|
|
sources_obj = MakeTensorList(sources_raw);
|
|
}
|
|
|
|
if (!result.empty()) {
|
|
PyObject* py_result = PyList_New(result.size());
|
|
tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
|
|
for (int i = 0; i < result.size(); ++i) {
|
|
if (result[i] == nullptr) {
|
|
if (unconnected_gradients_zero) {
|
|
// generate a zeros tensor in the shape of sources[i]
|
|
tensorflow::DataType dtype =
|
|
tensorflow::PyTensor_DataType(sources_obj[i]);
|
|
PyTapeTensor tensor =
|
|
PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
|
|
result[i] = tensor.ZerosLike();
|
|
} else {
|
|
Py_INCREF(Py_None);
|
|
result[i] = Py_None;
|
|
}
|
|
} else if (seen_results.find(result[i]) != seen_results.end()) {
|
|
Py_INCREF(result[i]);
|
|
}
|
|
seen_results.insert(result[i]);
|
|
PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
|
|
}
|
|
return py_result;
|
|
}
|
|
return PyList_New(0);
|
|
}
|
|
|
|
PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) {
|
|
TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
|
|
if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
|
|
TFE_Py_ForwardAccumulator* accumulator =
|
|
PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type);
|
|
if (py_vspace == nullptr) {
|
|
MaybeRaiseExceptionFromStatus(
|
|
tensorflow::errors::Internal(
|
|
"ForwardAccumulator requires a PyVSpace to be registered."),
|
|
nullptr);
|
|
}
|
|
accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch);
|
|
return reinterpret_cast<PyObject*>(accumulator);
|
|
}
|
|
|
|
PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) {
|
|
TFE_Py_ForwardAccumulator* c_accumulator(
|
|
reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
|
|
c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1);
|
|
if (GetAccumulatorSet()->insert(c_accumulator)) {
|
|
Py_INCREF(accumulator);
|
|
Py_RETURN_NONE;
|
|
} else {
|
|
MaybeRaiseExceptionFromStatus(
|
|
tensorflow::errors::Internal(
|
|
"A ForwardAccumulator was added to the active set twice."),
|
|
nullptr);
|
|
return nullptr;
|
|
}
|
|
}
|
|
|
|
void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) {
|
|
GetAccumulatorSet()->erase(
|
|
reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
|
|
Py_DECREF(accumulator);
|
|
}
|
|
|
|
void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
|
|
PyObject* tangent) {
|
|
tensorflow::int64 tensor_id = FastTensorId(tensor);
|
|
reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
|
|
->accumulator->Watch(tensor_id, tangent);
|
|
RegisterForwardAccumulatorCleanup(tensor, tensor_id);
|
|
}
|
|
|
|
// Returns a new reference to the JVP Tensor.
|
|
PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator,
|
|
PyObject* tensor) {
|
|
PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
|
|
->accumulator->FetchJVP(FastTensorId(tensor));
|
|
if (jvp == nullptr) {
|
|
jvp = Py_None;
|
|
}
|
|
Py_INCREF(jvp);
|
|
return jvp;
|
|
}
|
|
|
|
PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
|
|
if (!TapeCouldPossiblyRecord(tensors)) {
|
|
tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0));
|
|
tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0));
|
|
return PyTuple_Pack(2, empty_tuple.get(), empty_list.get());
|
|
}
|
|
auto accumulators = *GetAccumulatorSet();
|
|
tensorflow::Safe_PyObjectPtr tensors_fast(
|
|
PySequence_Fast(tensors, "Expected a sequence of input Tensors."));
|
|
if (tensors_fast == nullptr || PyErr_Occurred()) {
|
|
return nullptr;
|
|
}
|
|
std::vector<tensorflow::int64> augmented_input_ids;
|
|
int len = PySequence_Fast_GET_SIZE(tensors_fast.get());
|
|
PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get());
|
|
for (Py_ssize_t position = 0; position < len; ++position) {
|
|
PyObject* input = tensors_fast_array[position];
|
|
if (input == Py_None) {
|
|
continue;
|
|
}
|
|
tensorflow::DataType input_dtype(tensorflow::PyTensor_DataType(input));
|
|
if (input_dtype == tensorflow::DT_INVALID) {
|
|
return nullptr;
|
|
}
|
|
augmented_input_ids.push_back(FastTensorId(input));
|
|
}
|
|
if (PyErr_Occurred()) {
|
|
return nullptr;
|
|
}
|
|
  // Find the innermost accumulator such that all outer accumulators are
  // recording. Any more deeply nested accumulators will not have their JVPs
  // saved.
|
|
AccumulatorSet::const_iterator innermost_all_recording = accumulators.begin();
|
|
for (; innermost_all_recording != accumulators.end();
|
|
++innermost_all_recording) {
|
|
if ((*innermost_all_recording)->accumulator->BusyAccumulating()) {
|
|
break;
|
|
}
|
|
}
|
|
AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording(
|
|
innermost_all_recording);
|
|
|
|
bool saving_jvps = false;
|
|
tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size()));
|
|
std::vector<PyObject*> new_tensors;
|
|
Py_ssize_t accumulator_index = 0;
|
|
  // Start with the innermost accumulators to give outer accumulators a chance
  // to find their higher-order JVPs.
|
|
for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin();
|
|
it != accumulators.rend(); ++it, ++accumulator_index) {
|
|
std::vector<tensorflow::int64> new_input_ids;
|
|
std::vector<std::pair<tensorflow::int64, tensorflow::int64>>
|
|
accumulator_indices;
|
|
if (it == reverse_innermost_all_recording) {
|
|
saving_jvps = true;
|
|
}
|
|
if (saving_jvps) {
|
|
for (int input_index = 0; input_index < augmented_input_ids.size();
|
|
++input_index) {
|
|
tensorflow::int64 existing_input = augmented_input_ids[input_index];
|
|
PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input);
|
|
if (jvp != nullptr) {
|
|
new_tensors.push_back(jvp);
|
|
new_input_ids.push_back(FastTensorId(jvp));
|
|
accumulator_indices.emplace_back(
|
|
input_index,
|
|
augmented_input_ids.size() + new_input_ids.size() - 1);
|
|
}
|
|
}
|
|
}
|
|
tensorflow::Safe_PyObjectPtr accumulator_indices_py(
|
|
PyTuple_New(accumulator_indices.size()));
|
|
for (int i = 0; i < accumulator_indices.size(); ++i) {
|
|
tensorflow::Safe_PyObjectPtr from_index(
|
|
GetPythonObjectFromInt(accumulator_indices[i].first));
|
|
tensorflow::Safe_PyObjectPtr to_index(
|
|
GetPythonObjectFromInt(accumulator_indices[i].second));
|
|
PyTuple_SetItem(accumulator_indices_py.get(), i,
|
|
PyTuple_Pack(2, from_index.get(), to_index.get()));
|
|
}
|
|
PyTuple_SetItem(all_indices.get(), accumulator_index,
|
|
accumulator_indices_py.release());
|
|
augmented_input_ids.insert(augmented_input_ids.end(), new_input_ids.begin(),
|
|
new_input_ids.end());
|
|
}
|
|
|
|
tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size()));
|
|
for (int i = 0; i < new_tensors.size(); ++i) {
|
|
PyObject* jvp = new_tensors[i];
|
|
Py_INCREF(jvp);
|
|
PyList_SET_ITEM(new_tensors_py.get(), i, jvp);
|
|
}
|
|
return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get());
|
|
}
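
// Illustrative sketch of the packed result: the first element is a tuple with
// one entry per active accumulator (innermost first); each entry is itself a
// tuple of (input position, position of that input's JVP in the augmented
// input list). The second element is the list of new JVP tensors to append to
// the original inputs. TapeSetRecordForwardprop above consumes index pairs in
// this style through its forwardprop_output_indices argument.
//
//   PyObject* packed = TFE_Py_PackJVPs(inputs);
//   // packed == (per_accumulator_index_pairs, jvp_tensors)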
|
|
|
|
namespace {
|
|
|
|
// Indices for the "args" tuple that's passed to TFE_Py_FastPathExecute_C.
|
|
enum FastPathExecuteArgIndex {
|
|
FAST_PATH_EXECUTE_ARG_CONTEXT = 0,
|
|
FAST_PATH_EXECUTE_ARG_OP_NAME = 1,
|
|
FAST_PATH_EXECUTE_ARG_NAME = 2,
|
|
FAST_PATH_EXECUTE_ARG_INPUT_START = 3
|
|
};
|
|
|
|
PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) {
|
|
#if PY_MAJOR_VERSION >= 3
|
|
return PyUnicode_FromStringAndSize(s.data(), s.size());
|
|
#else
|
|
return PyBytes_FromStringAndSize(s.data(), s.size());
|
|
#endif
|
|
}
|
|
|
|
bool CheckResourceVariable(PyObject* item) {
|
|
if (tensorflow::swig::IsResourceVariable(item)) {
|
|
tensorflow::Safe_PyObjectPtr handle(
|
|
PyObject_GetAttrString(item, "_handle"));
|
|
return EagerTensor_CheckExact(handle.get());
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
bool IsNumberType(PyObject* item) {
|
|
#if PY_MAJOR_VERSION >= 3
|
|
return PyFloat_Check(item) || PyLong_Check(item);
|
|
#else
|
|
return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item);
|
|
#endif
|
|
}
|
|
|
|
bool CheckOneInput(PyObject* item) {
|
|
if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
|
|
PyArray_Check(item) || IsNumberType(item)) {
|
|
return true;
|
|
}
|
|
|
|
  // Sequences are not properly handled. Sequences with purely python numeric
  // types work, but sequences with mixes of EagerTensors and python numeric
  // types don't work.
  // TODO(nareshmodi): fix
|
|
return false;
|
|
}
|
|
|
|
bool CheckInputsOk(PyObject* seq, int start_index,
|
|
const tensorflow::OpDef& op_def) {
|
|
for (int i = 0; i < op_def.input_arg_size(); i++) {
|
|
PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
|
|
if (!op_def.input_arg(i).number_attr().empty() ||
|
|
!op_def.input_arg(i).type_list_attr().empty()) {
|
|
// This item should be a seq input.
|
|
if (!PySequence_Check(item)) {
|
|
VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
|
|
<< "\", Input \"" << op_def.input_arg(i).name()
|
|
<< "\" since we expected a sequence, but got "
|
|
<< item->ob_type->tp_name;
|
|
return false;
|
|
}
|
|
tensorflow::Safe_PyObjectPtr fast_item(
|
|
PySequence_Fast(item, "Could not parse sequence."));
|
|
if (fast_item.get() == nullptr) {
|
|
return false;
|
|
}
|
|
int len = PySequence_Fast_GET_SIZE(fast_item.get());
|
|
PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
|
|
for (Py_ssize_t j = 0; j < len; j++) {
|
|
PyObject* inner_item = fast_item_array[j];
|
|
if (!CheckOneInput(inner_item)) {
|
|
VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
|
|
<< "\", Input \"" << op_def.input_arg(i).name()
|
|
<< "\", Index " << j
|
|
<< " since we expected an EagerTensor/ResourceVariable, "
|
|
"but got "
|
|
<< inner_item->ob_type->tp_name;
|
|
return false;
|
|
}
|
|
}
|
|
} else if (!CheckOneInput(item)) {
|
|
VLOG(1)
|
|
<< "Falling back to slow path for Op \"" << op_def.name()
|
|
<< "\", Input \"" << op_def.input_arg(i).name()
|
|
<< "\" since we expected an EagerTensor/ResourceVariable, but got "
|
|
<< item->ob_type->tp_name;
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
tensorflow::DataType MaybeGetDType(PyObject* item) {
|
|
if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) {
|
|
return tensorflow::PyTensor_DataType(item);
|
|
}
|
|
|
|
return tensorflow::DT_INVALID;
|
|
}
|
|
|
|
tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
|
|
FastPathOpExecInfo* op_exec_info) {
|
|
auto cached_it = op_exec_info->cached_dtypes.find(attr);
|
|
if (cached_it != op_exec_info->cached_dtypes.end()) {
|
|
return cached_it->second;
|
|
}
|
|
|
|
auto it = op_exec_info->attr_to_inputs_map->find(attr);
|
|
if (it == op_exec_info->attr_to_inputs_map->end()) {
|
|
// No other inputs - this should never happen.
|
|
return tensorflow::DT_INVALID;
|
|
}
|
|
|
|
for (const auto& input_info : it->second) {
|
|
PyObject* item = PyTuple_GET_ITEM(
|
|
op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i);
|
|
if (input_info.is_list) {
|
|
tensorflow::Safe_PyObjectPtr fast_item(
|
|
PySequence_Fast(item, "Unable to allocate"));
|
|
int len = PySequence_Fast_GET_SIZE(fast_item.get());
|
|
PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
|
|
for (int i = 0; i < len; i++) {
|
|
auto dtype = MaybeGetDType(fast_item_array[i]);
|
|
if (dtype != tensorflow::DT_INVALID) return dtype;
|
|
}
|
|
} else {
|
|
auto dtype = MaybeGetDType(item);
|
|
if (dtype != tensorflow::DT_INVALID) return dtype;
|
|
}
|
|
}
|
|
|
|
auto default_it = op_exec_info->default_dtypes->find(attr);
|
|
if (default_it != op_exec_info->default_dtypes->end()) {
|
|
return default_it->second;
|
|
}
|
|
|
|
return tensorflow::DT_INVALID;
|
|
}
|
|
|
|
PyObject* CopySequenceSettingIndicesToNull(
|
|
PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
|
|
tensorflow::Safe_PyObjectPtr fast_seq(
|
|
PySequence_Fast(seq, "unable to allocate"));
|
|
int len = PySequence_Fast_GET_SIZE(fast_seq.get());
|
|
PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get());
|
|
PyObject* result = PyTuple_New(len);
|
|
for (int i = 0; i < len; i++) {
|
|
PyObject* item;
|
|
if (indices.find(i) != indices.end()) {
|
|
item = Py_None;
|
|
} else {
|
|
item = fast_seq_array[i];
|
|
}
|
|
Py_INCREF(item);
|
|
PyTuple_SET_ITEM(result, i, item);
|
|
}
|
|
return result;
|
|
}
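
// Example (illustrative): with indices {1}, an input tuple (a, b, c) is copied
// to (a, None, c). RecordGradient below uses this to blank out op inputs and
// outputs that the registered gradient function never reads, so the backward
// function does not keep them alive unnecessarily.
//
//   tensorflow::gtl::FlatSet<int> unused;
//   unused.insert(1);
//   PyObject* pruned = CopySequenceSettingIndicesToNull(seq, unused);
//   // `pruned` is a new tuple reference owned by the caller.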
|
|
|
|
PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
|
PyObject* results,
|
|
PyObject* forward_pass_name_scope = nullptr) {
|
|
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
|
|
if (PyErr_Occurred()) return nullptr;
|
|
std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
|
|
if (PyErr_Occurred()) return nullptr;
|
|
|
|
bool should_record = false;
|
|
for (TFE_Py_Tape* tape : SafeTapeSet()) {
|
|
if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
|
|
should_record = true;
|
|
break;
|
|
}
|
|
}
|
|
if (!should_record) {
|
|
for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) {
|
|
if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) {
|
|
should_record = true;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
if (!should_record) Py_RETURN_NONE;
|
|
|
|
string c_op_name = TFE_GetPythonString(op_name);
|
|
|
|
PyObject* op_outputs;
|
|
bool op_outputs_tuple_created = false;
|
|
|
|
if (const auto unused_output_indices =
|
|
OpGradientUnusedOutputIndices(c_op_name)) {
|
|
if (unused_output_indices->empty()) {
|
|
op_outputs = Py_None;
|
|
} else {
|
|
op_outputs_tuple_created = true;
|
|
op_outputs =
|
|
CopySequenceSettingIndicesToNull(results, *unused_output_indices);
|
|
}
|
|
} else {
|
|
op_outputs = results;
|
|
}
|
|
|
|
PyObject* op_inputs;
|
|
bool op_inputs_tuple_created = false;
|
|
|
|
if (const auto unused_input_indices =
|
|
OpGradientUnusedInputIndices(c_op_name)) {
|
|
if (unused_input_indices->empty()) {
|
|
op_inputs = Py_None;
|
|
} else {
|
|
op_inputs_tuple_created = true;
|
|
op_inputs =
|
|
CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
|
|
}
|
|
} else {
|
|
op_inputs = inputs;
|
|
}
|
|
|
|
tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
|
|
[op_name, attrs, inputs, results](
|
|
const std::vector<PyObject*>& input_tangents,
|
|
std::vector<PyObject*>* output_tangents, bool use_batch) {
|
|
return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
|
|
output_tangents, use_batch);
|
|
});
|
|
tensorflow::eager::ForwardFunction<PyObject>* forward_function;
|
|
if (c_op_name == "While" || c_op_name == "StatelessWhile" ||
|
|
c_op_name == "If" || c_op_name == "StatelessIf") {
|
|
    // Control flow contains non-hashable attributes. Handling them in Python
    // is a headache, so instead we'll stay as close to GradientTape's handling
    // as possible (a null forward function means the accumulator forwards to a
    // tape).
    //
    // This is safe to do since we'll only see control flow when graph
    // building, in which case we can rely on pruning.
|
|
forward_function = nullptr;
|
|
} else {
|
|
forward_function = &py_forward_function;
|
|
}
|
|
|
|
PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
|
|
|
|
if (!forward_pass_name_scope) forward_pass_name_scope = Py_None;
|
|
|
|
TapeSetRecordOperation(
|
|
op_name, inputs, results, input_ids, input_dtypes,
|
|
[op_name, attrs, num_inputs, op_inputs, op_outputs,
|
|
forward_pass_name_scope]() {
|
|
Py_INCREF(op_name);
|
|
Py_INCREF(attrs);
|
|
Py_INCREF(num_inputs);
|
|
Py_INCREF(op_inputs);
|
|
Py_INCREF(op_outputs);
|
|
Py_INCREF(forward_pass_name_scope);
|
|
PyBackwardFunction* function = new PyBackwardFunction(
|
|
[op_name, attrs, num_inputs, op_inputs, op_outputs,
|
|
forward_pass_name_scope](
|
|
PyObject* output_grads,
|
|
const std::vector<tensorflow::int64>& unneeded_gradients) {
|
|
if (PyErr_Occurred()) {
|
|
return static_cast<PyObject*>(nullptr);
|
|
}
|
|
tensorflow::Safe_PyObjectPtr skip_input_indices;
|
|
if (!unneeded_gradients.empty()) {
|
|
skip_input_indices.reset(
|
|
PyTuple_New(unneeded_gradients.size()));
|
|
for (int i = 0; i < unneeded_gradients.size(); i++) {
|
|
PyTuple_SET_ITEM(
|
|
skip_input_indices.get(), i,
|
|
GetPythonObjectFromInt(unneeded_gradients[i]));
|
|
}
|
|
} else {
|
|
Py_INCREF(Py_None);
|
|
skip_input_indices.reset(Py_None);
|
|
}
|
|
tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
|
|
"OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
|
|
output_grads, skip_input_indices.get(),
|
|
forward_pass_name_scope));
|
|
|
|
tensorflow::Safe_PyObjectPtr result(
|
|
PyObject_CallObject(gradient_function, callback_args.get()));
|
|
|
|
if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);
|
|
|
|
return tensorflow::swig::Flatten(result.get());
|
|
});
|
|
return function;
|
|
},
|
|
[op_name, attrs, num_inputs, op_inputs, op_outputs,
|
|
forward_pass_name_scope](PyBackwardFunction* backward_function) {
|
|
Py_DECREF(op_name);
|
|
Py_DECREF(attrs);
|
|
Py_DECREF(num_inputs);
|
|
Py_DECREF(op_inputs);
|
|
Py_DECREF(op_outputs);
|
|
Py_DECREF(forward_pass_name_scope);
|
|
|
|
delete backward_function;
|
|
},
|
|
forward_function);
|
|
|
|
Py_DECREF(num_inputs);
|
|
if (op_outputs_tuple_created) Py_DECREF(op_outputs);
|
|
if (op_inputs_tuple_created) Py_DECREF(op_inputs);
|
|
|
|
if (PyErr_Occurred()) {
|
|
return nullptr;
|
|
}
|
|
|
|
Py_RETURN_NONE;
|
|
}
|
|
|
|
void MaybeNotifyVariableAccessed(PyObject* input) {
  DCHECK(CheckResourceVariable(input));
  DCHECK(PyObject_HasAttrString(input, "_trainable"));

  tensorflow::Safe_PyObjectPtr trainable(
      PyObject_GetAttrString(input, "_trainable"));
  if (trainable.get() == Py_False) return;
  TFE_Py_TapeVariableAccessed(input);
  TFE_Py_VariableWatcherVariableAccessed(input);
}

bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
                    PyObject* input, tensorflow::Safe_PyObjectPtr* output,
                    TF_Status* status) {
  MaybeNotifyVariableAccessed(input);

  TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
  auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;

  TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;

  // Set dtype
  DCHECK(PyObject_HasAttrString(input, "_dtype"));
  tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype"));
  int value;
  if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) {
    return false;
  }
  TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value));

  // Get handle
  tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle"));
  if (!EagerTensor_CheckExact(handle.get())) return false;
  TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;

  int num_retvals = 1;
  TFE_TensorHandle* output_handle;
  TFE_Execute(op, &output_handle, &num_retvals, status);
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;

  // Always create the py object (and correctly DECREF it) from the returned
  // value, else the data will leak.
  output->reset(EagerTensorFromHandle(output_handle));

  // TODO(nareshmodi): Should we run post exec callbacks here?
  if (parent_op_exec_info.run_gradient_callback) {
    tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
    PyTuple_SET_ITEM(inputs.get(), 0, handle.release());

    tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
    Py_INCREF(output->get());  // stay alive after since tuple steals.
    PyTuple_SET_ITEM(outputs.get(), 0, output->get());

    tensorflow::Safe_PyObjectPtr op_string(
        GetPythonObjectFromString("ReadVariableOp"));
    if (!RecordGradient(op_string.get(), inputs.get(), Py_None,
                        outputs.get())) {
      return false;
    }
  }

  return true;
}

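// For intuition only: the fast path above roughly mirrors what the Python
// layer would otherwise do for a ResourceVariable input (an illustrative
// sketch, not code from this file; gen_resource_variable_ops is the generated
// Python op wrapper module):
//
//   handle = variable._handle   # an EagerTensor holding the resource
//   value = gen_resource_variable_ops.read_variable_op(handle,
//                                                      dtype=variable._dtype)
//
// i.e. the variable is read through a ReadVariableOp so that an EagerTensor
// (rather than the variable object) is fed to the downstream op, and the read
// itself is recorded on any active tape via RecordGradient.
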
// Supports 3 cases at the moment:
//  i) input is an EagerTensor.
//  ii) input is a ResourceVariable - in this case, the is_variable param is
//      set to true.
//  iii) input is an arbitrary python list/tuple (note, this handling doesn't
//       support packing).
//
// NOTE: dtype_hint_getter must *always* return a PyObject that can be
// decref'd. So if no hint is found, Py_RETURN_NONE (which correctly
// increfs Py_None).
//
// NOTE: This function sets a python error directly, and returns false.
// TF_Status is only passed since we don't want to have to reallocate it.
bool ConvertToTensor(
    const FastPathOpExecInfo& op_exec_info, PyObject* input,
    tensorflow::Safe_PyObjectPtr* output_handle,
    // This gets a hint for this particular input.
    const std::function<tensorflow::DataType()>& dtype_hint_getter,
    // This sets the dtype after conversion is complete.
    const std::function<void(const tensorflow::DataType dtype)>& dtype_setter,
    TF_Status* status) {
  if (EagerTensor_CheckExact(input)) {
    Py_INCREF(input);
    output_handle->reset(input);
    return true;
  } else if (CheckResourceVariable(input)) {
    return ReadVariableOp(op_exec_info, input, output_handle, status);
  }

  // The hint comes from a supposedly similarly typed tensor.
  tensorflow::DataType dtype_hint = dtype_hint_getter();

  TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor(
      op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name);
  if (handle == nullptr) {
    return MaybeRaiseExceptionFromTFStatus(status, nullptr);
  }

  output_handle->reset(EagerTensorFromHandle(handle));
  dtype_setter(
      static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle)));

  return true;
}

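// For intuition, the three accepted inputs look roughly like this on the
// Python side (illustrative examples only, not code from this file):
//
//   tf.constant([1.0, 2.0])   # case (i): already an EagerTensor, reused as-is
//   tf.Variable(3.0)          # case (ii): read via ReadVariableOp above
//   [[1, 2], [3, 4]]          # case (iii): converted, with the dtype hint
//                             #   taken from a sibling input when available
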
// Adds input and type attr to the op, and to the list of flattened
// inputs/attrs.
bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
                  const bool add_type_attr,
                  const tensorflow::OpDef::ArgDef& input_arg,
                  std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
                  std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
                  TFE_Op* op, TF_Status* status) {
  // py_eager_tensor's ownership is transferred to flattened_inputs if it is
  // required, else the object is destroyed and DECREF'd when the object goes
  // out of scope in this function.
  tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;

  if (!ConvertToTensor(
          *op_exec_info, input, &py_eager_tensor,
          [&]() {
            if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
              return input_arg.type();
            }
            return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
          },
          [&](const tensorflow::DataType dtype) {
            op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype;
          },
          status)) {
    return false;
  }

  TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());

  if (add_type_attr && !input_arg.type_attr().empty()) {
    auto dtype = TFE_TensorHandleDataType(input_handle);
    TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
    if (flattened_attrs != nullptr) {
      flattened_attrs->emplace_back(
          GetPythonObjectFromString(input_arg.type_attr()));
      flattened_attrs->emplace_back(PyLong_FromLong(dtype));
    }
  }

  if (flattened_inputs != nullptr) {
    flattened_inputs->emplace_back(std::move(py_eager_tensor));
  }

  TFE_OpAddInput(op, input_handle, status);
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
    return false;
  }

  return true;
}

const char* GetDeviceName(PyObject* py_device_name) {
  if (py_device_name != Py_None) {
    return TFE_GetPythonString(py_device_name);
  }
  return nullptr;
}

bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
  if (!PySequence_Check(seq)) {
    PyErr_SetString(PyExc_TypeError,
                    Printf("expected a sequence for attr %s, got %s instead",
                           attr_name.data(), seq->ob_type->tp_name)
                        .data());

    return false;
  }
  if (PyArray_Check(seq) &&
      PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) {
    PyErr_SetString(PyExc_ValueError,
                    Printf("expected a sequence for attr %s, got an ndarray "
                           "with rank %d instead",
                           attr_name.data(),
                           PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)))
                        .data());
    return false;
  }
  return true;
}

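// RunCallbacks below invokes the gradient-recording callback and any
// post-execution op callbacks for an op executed on the fast path. Each
// post-execution callback receives the positional arguments assembled via
// Py_BuildValue in that function, i.e. roughly (an illustrative Python
// signature, assuming a user-registered op callback):
//
//   def callback(op_type, inputs, attrs, outputs, op_name): ...
//
// where `attrs` is a flat tuple alternating attr names and attr values.
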
bool RunCallbacks(
    const FastPathOpExecInfo& op_exec_info, PyObject* args,
    int num_inferred_attrs,
    const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs,
    const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs,
    PyObject* flattened_result) {
  DCHECK(op_exec_info.run_callbacks);

  tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size()));
  for (int i = 0; i < flattened_inputs.size(); i++) {
    PyObject* input = flattened_inputs[i].get();
    Py_INCREF(input);
    PyTuple_SET_ITEM(inputs.get(), i, input);
  }

  int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - num_inferred_attrs;
  int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
  tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));

  for (int i = 0; i < num_non_inferred_attrs; i++) {
    auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i);
    Py_INCREF(attr);
    PyTuple_SET_ITEM(attrs.get(), i, attr);
  }

  for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
    PyObject* attr_or_name =
        flattened_attrs.at(i - num_non_inferred_attrs).get();
    Py_INCREF(attr_or_name);
    PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
  }

  if (op_exec_info.run_gradient_callback) {
    if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
                        flattened_result)) {
      return false;
    }
  }

  if (op_exec_info.run_post_exec_callbacks) {
    tensorflow::Safe_PyObjectPtr callback_args(
        Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(),
                      flattened_result, op_exec_info.name));
    for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
      PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i);
      if (!PyCallable_Check(callback_fn)) {
        PyErr_SetString(
            PyExc_TypeError,
            Printf("expected a function for "
                   "post execution callback in index %ld, got %s instead",
                   i, callback_fn->ob_type->tp_name)
                .c_str());
        return false;
      }
      PyObject* callback_result =
          PyObject_CallObject(callback_fn, callback_args.get());
      if (!callback_result) {
        return false;
      }
      Py_DECREF(callback_result);
    }
  }

  return true;
}

}  // namespace

PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
  tensorflow::profiler::TraceMe activity(
      "TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo);
  Py_ssize_t args_size = PyTuple_GET_SIZE(args);
  if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) {
    PyErr_SetString(
        PyExc_ValueError,
        Printf("There must be at least %d items in the input tuple.",
               FAST_PATH_EXECUTE_ARG_INPUT_START)
            .c_str());
    return nullptr;
  }

  FastPathOpExecInfo op_exec_info;

  PyObject* py_eager_context =
      PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT);

  // TODO(edoper): Use interned string here
  PyObject* eager_context_handle =
      PyObject_GetAttrString(py_eager_context, "_context_handle");

  TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
      PyCapsule_GetPointer(eager_context_handle, nullptr));
  op_exec_info.ctx = ctx;
  op_exec_info.args = args;

  if (ctx == nullptr) {
    // The context hasn't been initialized. It will be in the slow path.
    RaiseFallbackException(
        "This function does not handle the case of the path where "
        "all inputs are not already EagerTensors.");
    return nullptr;
  }

  auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context);
  if (tld == nullptr) {
    return nullptr;
  }
  op_exec_info.device_name = GetDeviceName(tld->device_name.get());
  op_exec_info.callbacks = tld->op_callbacks.get();

  op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME);
  op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME);

  // TODO(nareshmodi): Add a benchmark for the fast-path with gradient
  // callbacks (similar to benchmark_tf_gradient_function_*). Also consider
  // using an InlinedVector for flattened_attrs and flattened_inputs if the
  // benchmarks point out problems with heap allocs.
  op_exec_info.run_gradient_callback =
      !*ThreadTapeIsStopped() && HasAccumulatorOrTape();
  op_exec_info.run_post_exec_callbacks =
      op_exec_info.callbacks != Py_None &&
      PyList_Size(op_exec_info.callbacks) > 0;
  op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
                               op_exec_info.run_post_exec_callbacks;

  TF_Status* status = GetStatus();
  const char* op_name = TFE_GetPythonString(op_exec_info.op_name);
  if (op_name == nullptr) {
    PyErr_SetString(PyExc_TypeError,
                    Printf("expected a string for op_name, got %s instead",
                           op_exec_info.op_name->ob_type->tp_name)
                        .c_str());
    return nullptr;
  }

  TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status);

  auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] {
    ReturnStatus(status);
    ReturnOp(ctx, op);
  });

  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
    return nullptr;
  }

  tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
      tensorflow::StackTrace::kStackTraceInitialSize));

  const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
  if (op_def == nullptr) return nullptr;

  if (args_size <
      FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) {
    PyErr_SetString(
        PyExc_ValueError,
        Printf("Tuple size smaller than intended. Expected to be at least %d, "
               "was %ld",
               FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
               args_size)
            .c_str());
    return nullptr;
  }

  if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) {
    RaiseFallbackException(
        "This function does not handle the case of the path where "
        "all inputs are not already EagerTensors.");
    return nullptr;
  }

  op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def);
  op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def);

  // Mapping of attr name to size - used to calculate the number of values
  // to be expected by the TFE_Execute run.
  tensorflow::gtl::FlatMap<string, tensorflow::int64> attr_list_sizes;

  // Set non-inferred attrs, including setting defaults if the attr is passed
  // in as None.
  for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size();
       i < args_size; i += 2) {
    PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
    const char* attr_name = TFE_GetPythonString(py_attr_name);
    PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);

    // Not creating an index since most of the time there are not more than a
    // few attrs.
    // TODO(nareshmodi): Maybe include the index as part of the
    // OpRegistrationData.
    for (const auto& attr : op_def->attr()) {
      if (tensorflow::StringPiece(attr_name) == attr.name()) {
        SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value,
                              &attr_list_sizes, status);

        if (!status->status.ok()) {
          VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
                  << "\" since we are unable to set the value for attr \""
                  << attr.name() << "\" due to: " << TF_Message(status);
          RaiseFallbackException(TF_Message(status));
          return nullptr;
        }

        break;
      }
    }
  }

  // Flat attrs and inputs as required by the record_gradient call. The attrs
  // here only contain inferred attrs (non-inferred attrs are added directly
  // from the input args).
  // All items in flattened_attrs and flattened_inputs contain
  // Safe_PyObjectPtr - any time something steals a reference to this, it must
  // INCREF.
  // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
  // directly.
  std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
      nullptr;
  std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
      nullptr;

  // TODO(nareshmodi): Encapsulate callbacks information into a struct.
  if (op_exec_info.run_callbacks) {
    flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
    flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
  }

  // Add inferred attrs and inputs.
  // The following code might set duplicate type attrs. This will result in
  // the CacheKey for the generated AttrBuilder possibly differing from
  // those where the type attrs are correctly set. Inconsistent CacheKeys
  // for ops means that there might be unnecessarily duplicated kernels.
  // TODO(nareshmodi): Fix this.
  for (int i = 0; i < op_def->input_arg_size(); i++) {
    const auto& input_arg = op_def->input_arg(i);

    PyObject* input =
        PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i);
    if (!input_arg.number_attr().empty()) {
      // The item is a homogeneous list.
      if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
      tensorflow::Safe_PyObjectPtr fast_input(
          PySequence_Fast(input, "Could not parse sequence."));
      if (fast_input.get() == nullptr) {
        return nullptr;
      }
      Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
      PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());

      TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
      if (op_exec_info.run_callbacks) {
        flattened_attrs->emplace_back(
            GetPythonObjectFromString(input_arg.number_attr()));
        flattened_attrs->emplace_back(PyLong_FromLong(len));
      }
      attr_list_sizes[input_arg.number_attr()] = len;

      if (len > 0) {
        // First item adds the type attr.
        if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg,
                          flattened_attrs.get(), flattened_inputs.get(), op,
                          status)) {
          return nullptr;
        }

        for (Py_ssize_t j = 1; j < len; j++) {
          // Since the list is homogeneous, we don't need to re-add the attr.
          if (!AddInputToOp(&op_exec_info, fast_input_array[j], false,
                            input_arg, nullptr /* flattened_attrs */,
                            flattened_inputs.get(), op, status)) {
            return nullptr;
          }
        }
      }
    } else if (!input_arg.type_list_attr().empty()) {
      // The item is a heterogeneous list.
      if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
        return nullptr;
      }
      tensorflow::Safe_PyObjectPtr fast_input(
          PySequence_Fast(input, "Could not parse sequence."));
      if (fast_input.get() == nullptr) {
        return nullptr;
      }
      const string& attr_name = input_arg.type_list_attr();
      Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
      PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
      tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
      PyObject* py_attr_value = nullptr;
      if (op_exec_info.run_callbacks) {
        py_attr_value = PyTuple_New(len);
      }
      for (Py_ssize_t j = 0; j < len; j++) {
        PyObject* py_input = fast_input_array[j];
        tensorflow::Safe_PyObjectPtr py_eager_tensor;
        if (!ConvertToTensor(
                op_exec_info, py_input, &py_eager_tensor,
                []() { return tensorflow::DT_INVALID; },
                [](const tensorflow::DataType dtype) {}, status)) {
          return nullptr;
        }

        TFE_TensorHandle* input_handle =
            EagerTensor_Handle(py_eager_tensor.get());

        attr_value[j] = TFE_TensorHandleDataType(input_handle);

        TFE_OpAddInput(op, input_handle, status);
        if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
          return nullptr;
        }

        if (op_exec_info.run_callbacks) {
          flattened_inputs->emplace_back(std::move(py_eager_tensor));

          PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
        }
      }
      if (op_exec_info.run_callbacks) {
        flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name));
        flattened_attrs->emplace_back(py_attr_value);
      }
      TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
                            attr_value.size());
      attr_list_sizes[attr_name] = len;
    } else {
      // The item is a single item.
      if (!AddInputToOp(&op_exec_info, input, true, input_arg,
                        flattened_attrs.get(), flattened_inputs.get(), op,
                        status)) {
        return nullptr;
      }
    }
  }

  int64_t num_outputs = 0;
  for (int i = 0; i < op_def->output_arg_size(); i++) {
    const auto& output_arg = op_def->output_arg(i);
    int64_t delta = 1;
    if (!output_arg.number_attr().empty()) {
      delta = attr_list_sizes[output_arg.number_attr()];
    } else if (!output_arg.type_list_attr().empty()) {
      delta = attr_list_sizes[output_arg.type_list_attr()];
    }
    if (delta < 0) {
      RaiseFallbackException(
          "Attributes suggest that the size of an output list is less than 0");
      return nullptr;
    }
    num_outputs += delta;
  }

  // If number of retvals is larger than int32, we error out.
  if (static_cast<int64_t>(static_cast<int32_t>(num_outputs)) != num_outputs) {
    PyErr_SetString(
        PyExc_ValueError,
        Printf("Number of outputs is too big: %ld", num_outputs).c_str());
    return nullptr;
  }
  int num_retvals = num_outputs;

  tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);

  Py_BEGIN_ALLOW_THREADS;
  TFE_Execute(op, retvals.data(), &num_retvals, status);
  Py_END_ALLOW_THREADS;

  if (!status->status.ok()) {
    // Augment the status with the op_name for easier debugging similar to
    // TFE_Py_Execute.
    std::vector<tensorflow::StackFrame> stack_trace =
        status->status.stack_trace();
    status->status = tensorflow::Status(
        status->status.code(),
        tensorflow::strings::StrCat(
            TF_Message(status),
            " [Op:", TFE_GetPythonString(op_exec_info.op_name), "]"),
        std::move(stack_trace));

    MaybeRaiseExceptionFromTFStatus(status, nullptr);
    return nullptr;
  }

  tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
  for (int i = 0; i < num_retvals; ++i) {
    PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
  }

  if (op_exec_info.run_callbacks) {
    if (!RunCallbacks(
            op_exec_info, args,
            FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
            *flattened_inputs, *flattened_attrs, flat_result.get())) {
      return nullptr;
    }
  }

  // Unflatten results.
  if (op_def->output_arg_size() == 0) {
    Py_RETURN_NONE;
  }

  if (op_def->output_arg_size() == 1) {
    if (!op_def->output_arg(0).number_attr().empty() ||
        !op_def->output_arg(0).type_list_attr().empty()) {
      return flat_result.release();
    } else {
      auto* result = PyList_GET_ITEM(flat_result.get(), 0);
      Py_INCREF(result);
      return result;
    }
  }

  // Correctly output the results that are made into a namedtuple.
  PyObject* result = PyList_New(op_def->output_arg_size());
  int flat_result_index = 0;
  for (int i = 0; i < op_def->output_arg_size(); i++) {
    if (!op_def->output_arg(i).number_attr().empty()) {
      int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
      PyObject* inner_list = PyList_New(list_length);
      for (int j = 0; j < list_length; j++) {
        PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
        Py_INCREF(obj);
        PyList_SET_ITEM(inner_list, j, obj);
      }
      PyList_SET_ITEM(result, i, inner_list);
    } else if (!op_def->output_arg(i).type_list_attr().empty()) {
      int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
      PyObject* inner_list = PyList_New(list_length);
      for (int j = 0; j < list_length; j++) {
        PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
        Py_INCREF(obj);
        PyList_SET_ITEM(inner_list, j, obj);
      }
      PyList_SET_ITEM(result, i, inner_list);
    } else {
      PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
      Py_INCREF(obj);
      PyList_SET_ITEM(result, i, obj);
    }
  }
  return result;
}

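// Illustrative call site (a sketch of the tuple the generated Python op
// wrappers are expected to pass in; the op and attr names below are examples,
// not taken from this file):
//
//   TFE_Py_FastPathExecute(_ctx, "BiasAdd", name,       # context, op name, name
//                          value, bias,                 # inputs, in OpDef order
//                          "data_format", data_format)  # non-inferred attrs
//
// i.e. the tuple is (context, op_name, name, <inputs...>, <attr name/value
// pairs...>), matching the FAST_PATH_EXECUTE_ARG_* offsets used above.
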
PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
                                PyObject* attrs, PyObject* results,
                                PyObject* forward_pass_name_scope) {
  if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
    Py_RETURN_NONE;
  }

  return RecordGradient(op_name, inputs, attrs, results,
                        forward_pass_name_scope);
}

namespace {
const char kTensor[] = "T";
const char kList[] = "L";
const char kListEnd[] = "l";
const char kTuple[] = "U";
const char kTupleEnd[] = "u";
const char kDict[] = "D";
const char kRaw[] = "R";
const char kShape[] = "s";
const char kShapeDelim[] = "-";
const char kDType[] = "d";
const char kNone[] = "n";
const char kCompositeTensor[] = "C";
const char kAttrs[] = "A";
const char kAttrsEnd[] = "a";

struct EncodeResult {
  string str;
  std::vector<PyObject*> objects;

  PyObject* ToPyTuple() {
    PyObject* result = PyTuple_New(2);

    PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str));

    if (objects.empty()) {
      Py_INCREF(Py_None);
      PyTuple_SET_ITEM(result, 1, Py_None);
    } else {
      PyObject* objects_tuple = PyTuple_New(objects.size());

      for (int i = 0; i < objects.size(); i++) {
        PyTuple_SET_ITEM(objects_tuple, i, objects[i]);
      }

      PyTuple_SET_ITEM(result, 1, objects_tuple);
    }

    return result;
  }
};

tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
                                       bool include_tensor_ranks_only,
                                       EncodeResult* result) {
  if (EagerTensor_CheckExact(arg)) {
    tensorflow::ImmediateExecutionTensorHandle* handle =
        tensorflow::unwrap(EagerTensor_Handle(arg));

    absl::StrAppend(&result->str, kDType,
                    static_cast<tensorflow::DataType>(handle->DataType()));
    absl::StrAppend(&result->str, kShape);

    int num_dims;
    tensorflow::Status status = handle->NumDims(&num_dims);
    if (!status.ok()) return status;

    if (include_tensor_ranks_only) {
      absl::StrAppend(&result->str, num_dims);
    } else {
      for (int i = 0; i < num_dims; ++i) {
        tensorflow::int64 dim_size;
        status = handle->Dim(i, &dim_size);
        if (!status.ok()) return status;
        absl::StrAppend(&result->str, dim_size, kShapeDelim);
      }
    }
    return tensorflow::Status::OK();
  }

  tensorflow::Safe_PyObjectPtr dtype_object(
      PyObject_GetAttrString(arg, "dtype"));

  if (dtype_object == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "ops.Tensor object doesn't have dtype() attr.");
  }

  tensorflow::Safe_PyObjectPtr dtype_enum(
      PyObject_GetAttrString(dtype_object.get(), "_type_enum"));

  if (dtype_enum == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "ops.Tensor's dtype object doesn't have _type_enum() attr.");
  }

  tensorflow::DataType dtype =
      static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));

  absl::StrAppend(&result->str, kDType, dtype);

  static char _shape_tuple[] = "_shape_tuple";
  tensorflow::Safe_PyObjectPtr shape_tuple(
      PyObject_CallMethod(arg, _shape_tuple, nullptr));

  if (shape_tuple == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "ops.Tensor object doesn't have _shape_tuple() method.");
  }

  if (shape_tuple.get() == Py_None) {
    // Unknown shape, encode that directly.
    absl::StrAppend(&result->str, kNone);
    return tensorflow::Status::OK();
  }

  absl::StrAppend(&result->str, kShape);
  tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
      shape_tuple.get(), "shape_tuple didn't return a sequence"));

  int len = PySequence_Fast_GET_SIZE(shape_seq.get());
  PyObject** shape_seq_array = PySequence_Fast_ITEMS(shape_seq.get());

  if (include_tensor_ranks_only) {
    absl::StrAppend(&result->str, len);
  } else {
    for (int i = 0; i < len; ++i) {
      PyObject* item = shape_seq_array[i];
      if (item == Py_None) {
        absl::StrAppend(&result->str, kNone);
      } else {
        absl::StrAppend(&result->str, MakeInt(item));
      }
    }
  }
  return tensorflow::Status::OK();
}

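// Example of the resulting encoding (illustrative; the digits depend on the
// DataType enum values in types.proto): a float32 EagerTensor of shape [2, 3]
// encodes as "d1s2-3-" (dtype marker, enum value, shape marker, dims), or as
// "d1s2" (rank instead of dims) when include_tensor_ranks_only is true. Once
// TFE_Py_EncodeArgHelper below adds the kTensor/kList markers, a Python list
// of two such tensors would encode roughly as "LTd1s2-3-Td1s2-3-l".
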
tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
                                          bool include_tensor_ranks_only,
                                          EncodeResult* result);

// Encodes a list/tuple `arg` by appending `type`, the encoding of each item
// (kNone for None entries), and then `end_type` to `result`. It doesn't check
// the type of the sequence itself; callers choose the begin/end markers.
tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
                                         const char* end_type,
                                         bool include_tensor_ranks_only,
                                         EncodeResult* result) {
  tensorflow::Safe_PyObjectPtr arg_seq(
      PySequence_Fast(arg, "unable to create seq from list/tuple"));

  absl::StrAppend(&result->str, type);
  int len = PySequence_Fast_GET_SIZE(arg_seq.get());
  PyObject** arg_seq_array = PySequence_Fast_ITEMS(arg_seq.get());
  for (int i = 0; i < len; ++i) {
    PyObject* item = arg_seq_array[i];
    if (item == Py_None) {
      absl::StrAppend(&result->str, kNone);
    } else {
      TF_RETURN_IF_ERROR(
          TFE_Py_EncodeArgHelper(item, include_tensor_ranks_only, result));
    }
  }
  absl::StrAppend(&result->str, end_type);

  return tensorflow::Status::OK();
}

tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
                                          bool include_tensor_ranks_only,
                                          EncodeResult* result) {
  if (tensorflow::swig::IsTensor(arg)) {
    absl::StrAppend(&result->str, kTensor);
    TF_RETURN_IF_ERROR(
        TFE_Py_EncodeTensor(arg, include_tensor_ranks_only, result));
  } else if (PyList_Check(arg)) {
    TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
        arg, kList, kListEnd, include_tensor_ranks_only, result));
  } else if (tensorflow::swig::IsTuple(arg)) {
    TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
        arg, kTuple, kTupleEnd, include_tensor_ranks_only, result));
  } else if (tensorflow::swig::IsMapping(arg)) {
    tensorflow::Safe_PyObjectPtr keys(tensorflow::swig::MappingKeys(arg));
    if (PyList_Sort(keys.get()) == -1) {
      return tensorflow::errors::Internal("Unable to sort keys");
    }

    absl::StrAppend(&result->str, kDict);
    int len = PyList_Size(keys.get());

    for (int i = 0; i < len; i++) {
      PyObject* key = PyList_GetItem(keys.get(), i);
      TF_RETURN_IF_ERROR(
          TFE_Py_EncodeArgHelper(key, include_tensor_ranks_only, result));
      tensorflow::Safe_PyObjectPtr value(PyObject_GetItem(arg, key));
      TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(
          value.get(), include_tensor_ranks_only, result));
    }
  } else if (tensorflow::swig::IsCompositeTensor(arg)) {
    absl::StrAppend(&result->str, kCompositeTensor);

    // Add the typespec to the list of objects. (Do *not* use a weakref,
    // since the type spec is often a temporary object.)
    PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec"));
    if (type_spec == nullptr) {
      return tensorflow::errors::InvalidArgument(
          "Error while reading CompositeTensor._type_spec.");
    }
    result->objects.push_back(type_spec);
  } else if (tensorflow::swig::IsTypeSpec(arg)) {
    // Add the typespec (not a weakref) in case it's a temporary object.
    absl::StrAppend(&result->str, kRaw);
    Py_INCREF(arg);
    result->objects.push_back(arg);
  } else if (tensorflow::swig::IsAttrs(arg)) {
    absl::StrAppend(&result->str, kAttrs);
    tensorflow::Safe_PyObjectPtr attrs(
        PyObject_GetAttrString(arg, "__attrs_attrs__"));
    tensorflow::Safe_PyObjectPtr iter(PyObject_GetIter(attrs.get()));
    for (tensorflow::Safe_PyObjectPtr item(PyIter_Next(iter.get())); item;
         item.reset(PyIter_Next(iter.get()))) {
      tensorflow::Safe_PyObjectPtr name(
          PyObject_GetAttrString(item.get(), "name"));
      tensorflow::Safe_PyObjectPtr attr_arg(PyObject_GetAttr(arg, name.get()));
      TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(
          attr_arg.get(), include_tensor_ranks_only, result));
    }
    absl::StrAppend(&result->str, kAttrsEnd);
  } else {
    PyObject* object = PyWeakref_NewRef(arg, nullptr);

    if (object == nullptr) {
      PyErr_Clear();

      object = arg;
      Py_INCREF(object);
    }

    absl::StrAppend(&result->str, kRaw);
    result->objects.push_back(object);
  }

  return tensorflow::Status::OK();
}

}  // namespace

// `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
// are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
// are used for both performance reasons, as much TensorFlow code specializes
// on known shapes to produce slimmer graphs, and correctness, as some
// high-level APIs require shapes to be fully-known.
//
// `include_tensor_ranks_only` allows caching on arguments excluding shape
// info, so that a slow path using relaxed shape can rely on a cache key that
// excludes shapes.
PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only) {
  EncodeResult result;
  const auto status =
      TFE_Py_EncodeArgHelper(arg, include_tensor_ranks_only, &result);
  if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
    return nullptr;
  }

  return result.ToPyTuple();
}

// Prints incoming messages directly to Python's stdout using Python's C API.
// This is necessary in Jupyter notebooks and colabs, where messages written
// to the C-level stdout don't go to the notebook cell outputs, but calls to
// Python's stdout do.
void PrintToPythonStdout(const char* msg) {
  if (Py_IsInitialized()) {
    PyGILState_STATE py_threadstate;
    py_threadstate = PyGILState_Ensure();

    string string_msg = msg;
    // PySys_WriteStdout truncates strings over 1000 bytes, so
    // we write the message in chunks small enough to not be truncated.
    int CHUNK_SIZE = 900;
    auto len = string_msg.length();
    for (int i = 0; i < len; i += CHUNK_SIZE) {
      PySys_WriteStdout("%s", string_msg.substr(i, CHUNK_SIZE).c_str());
    }

    // Force flushing to make sure print newlines aren't interleaved in
    // some colab environments
    PyRun_SimpleString("import sys; sys.stdout.flush()");

    PyGILState_Release(py_threadstate);
  }
}

// Register PrintToPythonStdout as a log listener, to allow
// printing in colabs and jupyter notebooks to work.
void TFE_Py_EnableInteractivePythonLogging() {
  static bool enabled_interactive_logging = false;
  if (!enabled_interactive_logging) {
    enabled_interactive_logging = true;
    TF_RegisterLogListener(PrintToPythonStdout);
  }
}

namespace {
// weak reference to Python Context object currently active
PyObject* weak_eager_context = nullptr;
}  // namespace

PyObject* TFE_Py_SetEagerContext(PyObject* py_context) {
  Py_XDECREF(weak_eager_context);
  weak_eager_context = PyWeakref_NewRef(py_context, nullptr);
  if (weak_eager_context == nullptr) {
    return nullptr;
  }
  Py_RETURN_NONE;
}

PyObject* GetPyEagerContext() {
  if (weak_eager_context == nullptr) {
    PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set");
    return nullptr;
  }
  PyObject* py_context = PyWeakref_GET_OBJECT(weak_eager_context);
  if (py_context == Py_None) {
    PyErr_SetString(PyExc_RuntimeError, "Eager context has been destroyed");
    return nullptr;
  }
  Py_INCREF(py_context);
  return py_context;
}

namespace {

// Default values for thread_local_data fields.
struct EagerContextThreadLocalDataDefaults {
  tensorflow::Safe_PyObjectPtr is_eager;
  tensorflow::Safe_PyObjectPtr device_spec;
};

// Maps each py_eager_context object to its thread_local_data.
//
// Note: we need to use the python Context object as the key here (and not
// its handle object), because the handle object isn't created until the
// context is initialized; but thread_local_data is potentially accessed
// before then.
using EagerContextThreadLocalDataMap = absl::flat_hash_map<
    PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>;
thread_local EagerContextThreadLocalDataMap*
    eager_context_thread_local_data_map = nullptr;

// Maps each py_eager_context object to default values.
using EagerContextThreadLocalDataDefaultsMap =
    absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>;
EagerContextThreadLocalDataDefaultsMap*
    eager_context_thread_local_data_defaults = nullptr;

}  // namespace

namespace tensorflow {

void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
                                     PyObject* is_eager,
                                     PyObject* device_spec) {
  DCheckPyGilState();
  if (eager_context_thread_local_data_defaults == nullptr) {
    absl::LeakCheckDisabler disabler;
    eager_context_thread_local_data_defaults =
        new EagerContextThreadLocalDataDefaultsMap();
  }
  if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) {
    PyErr_SetString(PyExc_AssertionError,
                    "MakeEagerContextThreadLocalData may not be called "
                    "twice on the same eager Context object.");
  }

  auto& defaults =
      (*eager_context_thread_local_data_defaults)[py_eager_context];
  Py_INCREF(is_eager);
  defaults.is_eager.reset(is_eager);
  Py_INCREF(device_spec);
  defaults.device_spec.reset(device_spec);
}

EagerContextThreadLocalData* GetEagerContextThreadLocalData(
    PyObject* py_eager_context) {
  if (eager_context_thread_local_data_defaults == nullptr) {
    PyErr_SetString(PyExc_AssertionError,
                    "MakeEagerContextThreadLocalData must be called "
                    "before GetEagerContextThreadLocalData.");
    return nullptr;
  }
  auto defaults =
      eager_context_thread_local_data_defaults->find(py_eager_context);
  if (defaults == eager_context_thread_local_data_defaults->end()) {
    PyErr_SetString(PyExc_AssertionError,
                    "MakeEagerContextThreadLocalData must be called "
                    "before GetEagerContextThreadLocalData.");
    return nullptr;
  }

  if (eager_context_thread_local_data_map == nullptr) {
    absl::LeakCheckDisabler disabler;
    eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
  }
  auto& thread_local_data =
      (*eager_context_thread_local_data_map)[py_eager_context];

  if (!thread_local_data) {
    thread_local_data.reset(new EagerContextThreadLocalData());

    Safe_PyObjectPtr is_eager(PyObject_CallFunctionObjArgs(
        defaults->second.is_eager.get(), nullptr));
    if (!is_eager) return nullptr;
    thread_local_data->is_eager = PyObject_IsTrue(is_eager.get());

#if PY_MAJOR_VERSION >= 3
    PyObject* scope_name = PyUnicode_FromString("");
#else
    PyObject* scope_name = PyString_FromString("");
#endif
    thread_local_data->scope_name.reset(scope_name);

#if PY_MAJOR_VERSION >= 3
    PyObject* device_name = PyUnicode_FromString("");
#else
    PyObject* device_name = PyString_FromString("");
#endif
    thread_local_data->device_name.reset(device_name);

    Py_INCREF(defaults->second.device_spec.get());
    thread_local_data->device_spec.reset(defaults->second.device_spec.get());

    Py_INCREF(Py_None);
    thread_local_data->function_call_options.reset(Py_None);

    Py_INCREF(Py_None);
    thread_local_data->executor.reset(Py_None);

    thread_local_data->op_callbacks.reset(PyList_New(0));
  }
  return thread_local_data.get();
}

void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) {
  DCheckPyGilState();
  if (eager_context_thread_local_data_defaults) {
    eager_context_thread_local_data_defaults->erase(py_eager_context);
  }
  if (eager_context_thread_local_data_map) {
    eager_context_thread_local_data_map->erase(py_eager_context);
  }
}

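// Typical lifecycle of the per-context thread-local data above (a sketch of
// how the Python side is expected to drive it, for intuition only):
//
//   - MakeEagerContextThreadLocalData(py_context, is_eager_fn, device_spec)
//     is called once when the Python Context object is constructed; is_eager
//     is a callable that reports whether eager execution is enabled.
//   - GetEagerContextThreadLocalData(py_context) is then called from any
//     thread (e.g. by TFE_Py_FastPathExecute_C above) and lazily builds that
//     thread's copy of the data from the recorded defaults.
//   - DestroyEagerContextThreadLocalData(py_context) is called when the
//     Python Context object is destroyed, dropping the shared defaults and
//     the calling thread's cached entry.
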
}  // namespace tensorflow