Move the _ThreadLocalData class used by the eager Context from Python
(eager/context.py) to C++ (pywrap_tfe.h).
PiperOrigin-RevId: 330995559
Change-Id: I01bd358760b25f166a3e876fea8d83c9c0e8b513
parent 3db52f724a
commit 8e1ef83d9e
@@ -80,8 +80,8 @@ def c_tfe_py_fastpath_execute(a,
   assert ctx.executing_eagerly(
   ), "The prototype doesn't contain C code for graph construction"
   try:
-    return pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
-                                             "MatMul", name, ctx.op_callbacks,
-                                             a, b, "transpose_a", transpose_a,
-                                             "transpose_b", transpose_b)
+    return pywrap_tfe.TFE_Py_FastPathExecute(ctx,
+                                             "MatMul", name,
+                                             a, b, "transpose_a", transpose_a,
+                                             "transpose_b", transpose_b)
   except core._NotOkStatusException as e:
@@ -183,21 +183,6 @@ class _TensorCaches(threading.local):
     return self._zeros_cache
 
 
-class _ThreadLocalData(threading.local):
-  """Thread local storage for the eager context."""
-
-  def __init__(self):
-    super(_ThreadLocalData, self).__init__()
-    self.device_spec = _starting_device_spec
-    self.device_name = ""
-    self.is_eager = default_execution_mode == EAGER_MODE
-    self.scope_name = ""
-    self.function_call_options = None
-    self.executor = None
-    self.op_callbacks = []
-    self.invoking_op_callbacks = False
-
-
 ContextSwitch = collections.namedtuple(
     "ContextSwitch", ["is_building_function", "enter_context_fn",
                       "device_stack"])
@@ -420,7 +405,10 @@ class Context(object):
     _tensor_caches_map[self._id] = _TensorCaches()
 
     self._config = config
-    self._thread_local_data = _ThreadLocalData()
+    self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData(
+        self,
+        is_eager=lambda: default_execution_mode == EAGER_MODE,
+        device_spec=_starting_device_spec)
     self._context_switches = _ContextSwitchStack(self.executing_eagerly())
     self._context_handle = None
     self._context_devices = None
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
 
 typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>
     TFE_InputTensorHandles;
@@ -259,16 +260,15 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
 // it will simply fail with a NotImplementedError.
 //
 // The "args" PyObject* is meant to be a tuple with the following structure:
-// Item 1: The TFE Context
-// Item 2: device_name: Name of the device on which to execute the operation,
-//         or NULL for automatic selection.
-// Item 3: op_name: Name of the TensorFlow op to execute.
-// Item 4: name: An optional name for the operation.
-// Item 5: List representing all callbacks to execute after successful
-//         op execute.
-// Item 6 onwards: inputs - This is a list of inputs followed by a list of
+// Item 1: The Python eager Context object
+// Item 2: op_name: Name of the TensorFlow op to execute.
+// Item 3: name: An optional name for the operation.
+// Item 4 onwards: inputs - This is a list of inputs followed by a list of
 //         attrs. It is not necessary for type attrs to be present.
 //
+// Note: the device_name and op_callbacks, which were previously passed
+// as arguments, are now read via GetEagerContextThreadLocalData().
+//
 // This is named _C since there doesn't seem to be any way to make it visible
 // in the SWIG interface without renaming due to the use of the %native
 // directive.
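To make the tuple layout above concrete, here is a small Python illustration; the index constants mirror the FastPathExecuteArgIndex enum added to pywrap_tfe_src.cc below, while the helper function is purely illustrative and not part of the commit:

# Indices into the "args" tuple passed to TFE_Py_FastPathExecute_C.
FAST_PATH_EXECUTE_ARG_CONTEXT = 0      # Item 1: the Python eager Context
FAST_PATH_EXECUTE_ARG_OP_NAME = 1      # Item 2: op_name, e.g. "MatMul"
FAST_PATH_EXECUTE_ARG_NAME = 2         # Item 3: optional operation name
FAST_PATH_EXECUTE_ARG_INPUT_START = 3  # Item 4 onwards: inputs, then attrs


def split_fast_path_args(args):
  """Splits an args tuple according to the layout documented above."""
  ctx = args[FAST_PATH_EXECUTE_ARG_CONTEXT]
  op_name = args[FAST_PATH_EXECUTE_ARG_OP_NAME]
  name = args[FAST_PATH_EXECUTE_ARG_NAME]
  inputs_and_attrs = args[FAST_PATH_EXECUTE_ARG_INPUT_START:]
  return ctx, op_name, name, inputs_and_attrs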
@@ -394,4 +394,59 @@ PyObject* GetPyEagerContext();
 TF_Status* GetStatus();
 // Returns the pre-allocated status to the code.
 void ReturnStatus(TF_Status* status);
 
+namespace tensorflow {
+
+// Thread-local data associated with a Python eager Context object.
+//
+// TODO(edloper): Consider changing device_name and scope_name to a const char*
+// (with nullptr used for None). However, note that existing code (e.g.
+// TFE_TensorHandleCache::Lookup) assumes that the lifetime of these strings
+// extends beyond the point where their value is changed; so we'd need to make
+// sure that the strings stay alive (maybe using PyUnicode_InternInPlace?)
+struct EagerContextThreadLocalData {
+  bool is_eager = false;
+  bool invoking_op_callbacks = false;
+  tensorflow::Safe_PyObjectPtr device_name;
+  tensorflow::Safe_PyObjectPtr scope_name;
+  tensorflow::Safe_PyObjectPtr device_spec;
+  tensorflow::Safe_PyObjectPtr function_call_options;
+  tensorflow::Safe_PyObjectPtr executor;
+  tensorflow::Safe_PyObjectPtr op_callbacks;
+};
+
+// Create a thread-local-data structure associated with py_eager_context.
+// `is_eager` and `device_spec` are used to supply default values for those
+// fields whenever a new thread-local instance is created for
+// py_eager_context.
+//
+// This function assumes that the Python GIL is held (and does not perform its
+// own locking).
+void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
+                                     PyObject* is_eager,
+                                     PyObject* device_spec);
+
+// Returns the thread-local instance of EagerContextThreadLocalData that is
+// associated with the given Python Context object. If an instance has not
+// yet been created for `py_eager_context` in this thread, then a new one is
+// created, and initialized with the default values specified in
+// MakeEagerContextThreadLocalData.
+EagerContextThreadLocalData* GetEagerContextThreadLocalData(
+    PyObject* py_eager_context);
+
+// Free data structures used to track py_eager_context.
+//
+// This frees global state associated with py_eager_context, as well as thread-
+// local state associated with py_eager_context and the current thread. If you
+// wish to destroy thread-local state associated with a single py_eager_context
+// for multiple threads, then you must call this method from each thread.
+//
+// Thread-local state associated with eager contexts is also automatically
+// cleaned up when the thread is destroyed.
+//
+// This function assumes that the Python GIL is held (and does not perform its
+// own locking).
+void DestroyEagerContextThreadLocalData(PyObject* py_eager_context);
+
+}  // namespace tensorflow
+
 #endif  // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
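A consequence of these declarations worth spelling out: each thread lazily gets its own EagerContextThreadLocalData, initialized from the `is_eager`/`device_spec` defaults, so a value set in one thread is invisible in others. A hedged sketch of that behavior from Python, assuming a build that includes this commit (where Context._thread_local_data is the pybind wrapper registered in tfe_wrapper.cc below):

import threading

from tensorflow.python.eager import context

ctx = context.context()
ctx._thread_local_data.scope_name = "outer/"  # set in the main thread


def read_scope_name():
  # A fresh thread creates its own copy on first access, initialized with
  # the defaults, so it sees "" rather than "outer/".
  print(repr(ctx._thread_local_data.scope_name))

t = threading.Thread(target=read_scope_name)
t.start()
t.join()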
@@ -2953,7 +2953,14 @@ PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
 }
 
 namespace {
-static const int kFastPathExecuteInputStartIndex = 5;
+
+// Indices for the "args" tuple that's passed to TFE_Py_FastPathExecute_C.
+enum FastPathExecuteArgIndex {
+  FAST_PATH_EXECUTE_ARG_CONTEXT = 0,
+  FAST_PATH_EXECUTE_ARG_OP_NAME = 1,
+  FAST_PATH_EXECUTE_ARG_NAME = 2,
+  FAST_PATH_EXECUTE_ARG_INPUT_START = 3
+};
 
 PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) {
 #if PY_MAJOR_VERSION >= 3
@@ -3063,7 +3070,7 @@ tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
 
   for (const auto& input_info : it->second) {
     PyObject* item = PyTuple_GET_ITEM(
-        op_exec_info->args, kFastPathExecuteInputStartIndex + input_info.i);
+        op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i);
     if (input_info.is_list) {
       tensorflow::Safe_PyObjectPtr fast_item(
           PySequence_Fast(item, "Unable to allocate"));
@@ -3526,19 +3533,26 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
   tensorflow::profiler::TraceMe activity(
       "TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo);
   Py_ssize_t args_size = PyTuple_GET_SIZE(args);
-  if (args_size < kFastPathExecuteInputStartIndex) {
+  if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) {
     PyErr_SetString(
         PyExc_ValueError,
         Printf("There must be at least %d items in the input tuple.",
-               kFastPathExecuteInputStartIndex)
+               FAST_PATH_EXECUTE_ARG_INPUT_START)
             .c_str());
     return nullptr;
   }
 
   FastPathOpExecInfo op_exec_info;
 
+  PyObject* py_eager_context =
+      PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT);
+
+  // TODO(edloper): Use interned string here
+  PyObject* eager_context_handle =
+      PyObject_GetAttrString(py_eager_context, "_context_handle");
+
   TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
-      PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
+      PyCapsule_GetPointer(eager_context_handle, nullptr));
   op_exec_info.ctx = ctx;
   op_exec_info.args = args;
 
@@ -3550,10 +3564,15 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
     return nullptr;
   }
 
-  op_exec_info.device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1));
-  op_exec_info.op_name = PyTuple_GET_ITEM(args, 2);
-  op_exec_info.name = PyTuple_GET_ITEM(args, 3);
-  op_exec_info.callbacks = PyTuple_GET_ITEM(args, 4);
+  auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context);
+  if (tld == nullptr) {
+    return nullptr;
+  }
+  op_exec_info.device_name = GetDeviceName(tld->device_name.get());
+  op_exec_info.callbacks = tld->op_callbacks.get();
+
+  op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME);
+  op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME);
 
   // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
   // (similar to benchmark_tf_gradient_function_*). Also consider using an
@@ -3591,18 +3610,19 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
   const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
   if (op_def == nullptr) return nullptr;
 
-  if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) {
+  if (args_size <
+      FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) {
     PyErr_SetString(
         PyExc_ValueError,
         Printf("Tuple size smaller than intended. Expected to be at least %d, "
                "was %ld",
-               kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
+               FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
                args_size)
             .c_str());
     return nullptr;
   }
 
-  if (!CheckInputsOk(args, kFastPathExecuteInputStartIndex, *op_def)) {
+  if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) {
     RaiseFallbackException(
         "This function does not handle the case of the path where "
         "all inputs are not already EagerTensors.");
@@ -3618,7 +3638,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
 
   // Set non-inferred attrs, including setting defaults if the attr is passed in
   // as None.
-  for (int i = kFastPathExecuteInputStartIndex + op_def->input_arg_size();
+  for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size();
        i < args_size; i += 2) {
     PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
     const char* attr_name = TFE_GetPythonString(py_attr_name);
@@ -3675,7 +3695,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
     const auto& input_arg = op_def->input_arg(i);
 
     PyObject* input =
-        PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i);
+        PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i);
     if (!input_arg.number_attr().empty()) {
       // The item is a homogeneous list.
       if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
@@ -3820,7 +3840,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
   if (op_exec_info.run_callbacks) {
     if (!RunCallbacks(
             op_exec_info, args,
-            kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
+            FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
             *flattened_inputs, *flattened_attrs, flat_result.get())) {
       return nullptr;
     }
@@ -4203,3 +4223,125 @@ PyObject* GetPyEagerContext() {
   Py_INCREF(py_context);
   return py_context;
 }
+
+namespace {
+
+// Default values for thread_local_data fields.
+struct EagerContextThreadLocalDataDefaults {
+  tensorflow::Safe_PyObjectPtr is_eager;
+  tensorflow::Safe_PyObjectPtr device_spec;
+};
+
+// Maps each py_eager_context object to its thread_local_data.
+//
+// Note: we need to use the python Context object as the key here (and not
+// its handle object), because the handle object isn't created until the
+// context is initialized; but thread_local_data is potentially accessed
+// before then.
+using EagerContextThreadLocalDataMap = absl::flat_hash_map<
+    PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>;
+thread_local EagerContextThreadLocalDataMap*
+    eager_context_thread_local_data_map = nullptr;
+
+// Maps each py_eager_context object to default values.
+using EagerContextThreadLocalDataDefaultsMap =
+    absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>;
+EagerContextThreadLocalDataDefaultsMap*
+    eager_context_thread_local_data_defaults = nullptr;
+
+}  // namespace
+
+namespace tensorflow {
+
+void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
+                                     PyObject* is_eager,
+                                     PyObject* device_spec) {
+  DCheckPyGilState();
+  if (eager_context_thread_local_data_defaults == nullptr) {
+    eager_context_thread_local_data_defaults =
+        new EagerContextThreadLocalDataDefaultsMap();
+  }
+  if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) {
+    PyErr_SetString(PyExc_AssertionError,
+                    "MakeEagerContextThreadLocalData may not be called "
+                    "twice on the same eager Context object.");
+  }
+  auto& defaults =
+      (*eager_context_thread_local_data_defaults)[py_eager_context];
+  Py_INCREF(is_eager);
+  defaults.is_eager.reset(is_eager);
+  Py_INCREF(device_spec);
+  defaults.device_spec.reset(device_spec);
+}
+
+EagerContextThreadLocalData* GetEagerContextThreadLocalData(
+    PyObject* py_eager_context) {
+  if (eager_context_thread_local_data_defaults == nullptr) {
+    PyErr_SetString(PyExc_AssertionError,
+                    "MakeEagerContextThreadLocalData must be called "
+                    "before GetEagerContextThreadLocalData.");
+    return nullptr;
+  }
+  auto defaults =
+      eager_context_thread_local_data_defaults->find(py_eager_context);
+  if (defaults == eager_context_thread_local_data_defaults->end()) {
+    PyErr_SetString(PyExc_AssertionError,
+                    "MakeEagerContextThreadLocalData must be called "
+                    "before GetEagerContextThreadLocalData.");
+    return nullptr;
+  }
+
+  if (eager_context_thread_local_data_map == nullptr) {
+    eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
+  }
+  auto& thread_local_data =
+      (*eager_context_thread_local_data_map)[py_eager_context];
+
+  if (!thread_local_data) {
+    thread_local_data.reset(new EagerContextThreadLocalData());
+
+    Safe_PyObjectPtr is_eager(PyObject_CallFunctionObjArgs(
+        defaults->second.is_eager.get(), nullptr));
+    if (!is_eager) return nullptr;
+    thread_local_data->is_eager = PyObject_IsTrue(is_eager.get());
+
+#if PY_MAJOR_VERSION >= 3
+    PyObject* scope_name = PyUnicode_FromString("");
+#else
+    PyObject* scope_name = PyString_FromString("");
+#endif
+    thread_local_data->scope_name.reset(scope_name);
+
+#if PY_MAJOR_VERSION >= 3
+    PyObject* device_name = PyUnicode_FromString("");
+#else
+    PyObject* device_name = PyString_FromString("");
+#endif
+    thread_local_data->device_name.reset(device_name);
+
+    Py_INCREF(defaults->second.device_spec.get());
+    thread_local_data->device_spec.reset(defaults->second.device_spec.get());
+
+    Py_INCREF(Py_None);
+    thread_local_data->function_call_options.reset(Py_None);
+
+    Py_INCREF(Py_None);
+    thread_local_data->executor.reset(Py_None);
+
+    thread_local_data->op_callbacks.reset(PyList_New(0));
+  }
+  return thread_local_data.get();
+}
+
+void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) {
+  DCheckPyGilState();
+  if (eager_context_thread_local_data_defaults) {
+    eager_context_thread_local_data_defaults->erase(py_eager_context);
+  }
+  if (eager_context_thread_local_data_map) {
+    eager_context_thread_local_data_map->erase(py_eager_context);
+  }
+}
+
+}  // namespace tensorflow
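The storage scheme implemented above is: a process-global map of per-context defaults plus a thread_local map from context to data, so each lookup resolves (context, thread) to a lazily initialized record, keyed by the Python Context object because its handle may not exist yet. A rough Python analogue of that logic, as an illustration of the design rather than code from the commit:

import threading

_defaults = {}                   # context -> (is_eager_fn, device_spec)
_per_thread = threading.local()  # holds a per-thread {context: fields} map


def make_tld(ctx, is_eager_fn, device_spec):
  assert ctx not in _defaults, "may not be called twice for one Context"
  _defaults[ctx] = (is_eager_fn, device_spec)


def get_tld(ctx):
  assert ctx in _defaults, "make_tld must be called first"
  tld_map = getattr(_per_thread, "map", None)
  if tld_map is None:
    tld_map = _per_thread.map = {}
  if ctx not in tld_map:
    is_eager_fn, device_spec = _defaults[ctx]
    # Lazily initialize this thread's record from the stored defaults.
    tld_map[ctx] = dict(is_eager=bool(is_eager_fn()), device_name="",
                        scope_name="", device_spec=device_spec,
                        function_call_options=None, executor=None,
                        op_callbacks=[], invoking_op_callbacks=False)
  return tld_map[ctx]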
@@ -57,16 +57,15 @@ class Tests(test.TestCase):
 
     self.assertAllClose(
         math_ops.matmul(a_2_by_2, b_2_by_2),
-        pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
-                                          "MatMul", None, None, a_2_by_2,
-                                          b_2_by_2, "transpose_a", False,
-                                          "transpose_b", False))
+        pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None,
+                                          a_2_by_2, b_2_by_2, "transpose_a",
+                                          False, "transpose_b", False))
     self.assertAllClose(
         math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True),
-        pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
-                                          "MatMul", None, None, a_100_by_784,
-                                          b_100_by_784, "transpose_a", False,
-                                          "transpose_b", True))
+        pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None,
+                                          a_100_by_784, b_100_by_784,
+                                          "transpose_a", False, "transpose_b",
+                                          True))
 
   @test_util.assert_no_new_tensors
   @test_util.assert_no_garbage_created
@@ -76,14 +75,12 @@ class Tests(test.TestCase):
 
     a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
     m = resource_variable_ops.ResourceVariable(a_2_by_2)
-    x = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
-                                          "MatMul", None, None, m, m,
-                                          "transpose_a", False, "transpose_b",
-                                          False)
-    y = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
-                                          "MatMul", None, None, a_2_by_2,
-                                          a_2_by_2, "transpose_a", False,
-                                          "transpose_b", False)
+    x = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None, m,
+                                          m, "transpose_a", False,
+                                          "transpose_b", False)
+    y = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None,
+                                          a_2_by_2, a_2_by_2, "transpose_a",
+                                          False, "transpose_b", False)
 
     self.assertAllEqual(x, y)
 
@@ -96,10 +93,9 @@ class Tests(test.TestCase):
     with backprop.GradientTape(persistent=True) as tape:
       a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
       tape.watch(a_2_by_2)
-      z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
-                                            "MatMul", None, None, a_2_by_2,
-                                            a_2_by_2, "transpose_a", False,
-                                            "transpose_b", False)
+      z = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None,
+                                            a_2_by_2, a_2_by_2, "transpose_a",
+                                            False, "transpose_b", False)
     dz_dy = tape.gradient(z, [a_2_by_2])[0]
     self.assertAllEqual(dz_dy.numpy(),
                         constant_op.constant(4.0, shape=[2, 2]).numpy())
@@ -114,10 +110,9 @@ class Tests(test.TestCase):
       a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
       m = resource_variable_ops.ResourceVariable(a_2_by_2)
       tape.watch(m)
-      z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
-                                            "MatMul", None, None, m, m,
-                                            "transpose_a", False, "transpose_b",
-                                            False)
+      z = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None, m,
+                                            m, "transpose_a", False,
+                                            "transpose_b", False)
     dz_dy = tape.gradient(z, [m])[0]
     self.assertAllEqual(dz_dy.numpy(),
                         constant_op.constant(4.0, shape=[2, 2]).numpy())
@@ -134,8 +129,8 @@ class Tests(test.TestCase):
 
     self.assertAllClose(
         math_ops.add_n([a_2_by_2, b_2_by_2]),
-        pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "AddN",
-                                          None, None, [a_2_by_2, b_2_by_2]))
+        pywrap_tfe.TFE_Py_FastPathExecute(ctx, "AddN", None,
+                                          [a_2_by_2, b_2_by_2]))
 
   # Tests homogeneous list op
   @test_util.assert_no_new_tensors
@@ -150,8 +145,7 @@ class Tests(test.TestCase):
     with backprop.GradientTape(persistent=True) as tape:
       tape.watch(a_2_by_2)
       tape.watch(b_2_by_2)
-      z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
-                                             "AddN", None, None,
-                                             [a_2_by_2, b_2_by_2])
+      z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "AddN", None,
+                                             [a_2_by_2, b_2_by_2])
       z2 = math_ops.add_n([a_2_by_2, b_2_by_2])
     dz1_dy = tape.gradient(z1, [a_2_by_2])[0]
@@ -170,8 +164,7 @@ class Tests(test.TestCase):
 
     self.assertAllClose(
         array_ops.identity_n([a_2_by_2, b_2_by_2]),
-        pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
-                                          "IdentityN", None, None,
-                                          [a_2_by_2, b_2_by_2]))
+        pywrap_tfe.TFE_Py_FastPathExecute(ctx, "IdentityN", None,
+                                          [a_2_by_2, b_2_by_2]))
 
     # Tests heterogeneous list op
@@ -187,9 +180,8 @@ class Tests(test.TestCase):
     with backprop.GradientTape(persistent=True) as tape:
       tape.watch(a_2_by_2)
       tape.watch(b_2_by_2)
-      z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
-                                             "IdentityN", None, None,
-                                             [a_2_by_2, b_2_by_2])
+      z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "IdentityN",
+                                             None, [a_2_by_2, b_2_by_2])
       z2 = array_ops.identity_n([a_2_by_2, b_2_by_2])
     dz1_dy = tape.gradient(z1[0], [a_2_by_2])[0]
     dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0]
@@ -208,18 +200,17 @@ class Tests(test.TestCase):
 
     # Not enough base params
     with self.assertRaisesRegex(ValueError,
-                                "at least 5 items in the input tuple"):
-      pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Identity")
+                                "at least 3 items in the input tuple"):
+      pywrap_tfe.TFE_Py_FastPathExecute(ctx, "Identity")
 
     # Not enough inputs
-    with self.assertRaisesRegex(ValueError, "Expected to be at least 6, was 5"):
-      pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx_handle, "Identity",
-                                        None, [])
+    with self.assertRaisesRegex(ValueError, "Expected to be at least 4, was 3"):
+      pywrap_tfe.TFE_Py_FastPathExecute(ctx, "Identity", None)
 
     # Bad type
     with self.assertRaisesRegex(TypeError, "expected a string for op_name"):
-      pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, ctx_handle,
-                                        None, [], a_2_by_2)
+      pywrap_tfe.TFE_Py_FastPathExecute(ctx, ctx_handle, None,
+                                        a_2_by_2)
 
   @test_util.assert_no_new_tensors
   @test_util.assert_no_garbage_created
@@ -229,11 +220,9 @@ class Tests(test.TestCase):
     ctx = context.context()
     ctx.ensure_initialized()
 
-    ctx_handle = ctx._handle
     with self.assertRaises(core._FallbackException):
-      pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Split",
-                                        None, None, split_dim, value,
-                                        "num_split", -1)
+      pywrap_tfe.TFE_Py_FastPathExecute(ctx, "Split", None,
+                                        split_dim, value, "num_split", -1)
 
   @test_util.assert_no_new_tensors
   @test_util.assert_no_garbage_created
@@ -273,9 +262,9 @@ class Tests(test.TestCase):
     ctx = context.context()
     ctx.ensure_initialized()
     with self.assertRaises(core._FallbackException):
-      pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "MatMul",
-                                        None, None, m, m, "transpose_a", False,
-                                        "transpose_b", False)
+      pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None, m, m,
+                                        "transpose_a", False, "transpose_b",
+                                        False)
 
   def testOpDefDefaultType(self):
     im = np.random.randint(
@@ -928,8 +928,7 @@ bool GenEagerPythonOp::AddEagerFallbackCode(
 
 void GenEagerPythonOp::AddEagerFastPathExecute() {
   string fastpath_execute_params =
-      strings::StrCat("_ctx._context_handle, tld.device_name, \"",
-                      op_def_.name(), "\", ", "name, tld.op_callbacks");
+      strings::StrCat("_ctx, \"", op_def_.name(), "\", ", "name");
   string fallback_params;
 
   for (int i = 0; i < api_def_.in_arg_size(); i++) {
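For context, this is the code-generator counterpart of the signature change: generated op wrappers now pass _ctx instead of the handle/device/callbacks triple. Roughly what an emitted fast path looks like after this change (a sketch based on the new fastpath_execute_params, with the surrounding generated boilerplate abbreviated):

from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core


def add_n(inputs, name=None):
  _ctx = _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Previously: TFE_Py_FastPathExecute(_ctx._context_handle,
      #     tld.device_name, "AddN", name, tld.op_callbacks, inputs)
      return pywrap_tfe.TFE_Py_FastPathExecute(_ctx, "AddN", name, inputs)
    except _core._FallbackException:
      pass  # fall back to the slow path (elided in this sketch)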
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/python/lib/core/pybind11_lib.h"
 #include "tensorflow/python/lib/core/pybind11_status.h"
 #include "tensorflow/python/lib/core/safe_ptr.h"
+#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
 #include "tensorflow/python/util/util.h"
 
 namespace py = pybind11;
@@ -286,6 +287,98 @@ static py::object TFE_ClearScalarCache() {
 
 }  // namespace tensorflow
 
+namespace {
+
+// Wrapper around the EagerContextThreadLocalData struct (defined in
+// pywrap_tfe.h), so it can be accessed from Python.
+//
+// For PyObject* fields, the get_*() methods return a new reference; and the
+// set_*() methods create a new reference (i.e., they do not steal a reference).
+class EagerContextThreadLocalDataWrapper {
+ public:
+  explicit EagerContextThreadLocalDataWrapper(py::handle py_eager_context,
+                                              py::handle is_eager,
+                                              py::handle device_spec)
+      : py_eager_context_(py_eager_context.ptr()) {
+    tensorflow::MakeEagerContextThreadLocalData(
+        py_eager_context.ptr(), is_eager.ptr(), device_spec.ptr());
+  }
+
+  ~EagerContextThreadLocalDataWrapper() {
+    tensorflow::DestroyEagerContextThreadLocalData(py_eager_context_);
+  }
+
+  bool get_is_eager() const { return GetData()->is_eager; }
+  void set_is_eager(bool v) { GetData()->is_eager = v; }
+
+  bool get_invoking_op_callbacks() const {
+    return GetData()->invoking_op_callbacks;
+  }
+  void set_invoking_op_callbacks(bool v) {
+    GetData()->invoking_op_callbacks = v;
+  }
+
+  py::handle get_device_name() const {
+    return GetPyObject(&GetData()->device_name);
+  }
+  void set_device_name(py::handle v) {
+    SetPyObject(v, &GetData()->device_name);
+  }
+
+  py::handle get_scope_name() const {
+    return GetPyObject(&GetData()->scope_name);
+  }
+  void set_scope_name(py::handle v) { SetPyObject(v, &GetData()->scope_name); }
+
+  py::handle get_device_spec() const {
+    return GetPyObject(&GetData()->device_spec);
+  }
+  void set_device_spec(py::handle v) {
+    SetPyObject(v, &GetData()->device_spec);
+  }
+
+  py::handle get_function_call_options() const {
+    return GetPyObject(&GetData()->function_call_options);
+  }
+  void set_function_call_options(py::handle v) {
+    SetPyObject(v, &GetData()->function_call_options);
+  }
+
+  py::handle get_executor() const { return GetPyObject(&GetData()->executor); }
+  void set_executor(py::handle v) { SetPyObject(v, &GetData()->executor); }
+
+  py::handle get_op_callbacks() const {
+    return GetPyObject(&GetData()->op_callbacks);
+  }
+  void set_op_callbacks(py::handle v) {
+    SetPyObject(v, &GetData()->op_callbacks);
+  }
+
+ private:
+  tensorflow::EagerContextThreadLocalData* GetData() const {
+    auto* result =
+        tensorflow::GetEagerContextThreadLocalData(py_eager_context_);
+    if (!result) {
+      throw py::error_already_set();
+    }
+    return result;
+  }
+
+  py::handle GetPyObject(tensorflow::Safe_PyObjectPtr* obj) const {
+    Py_INCREF(obj->get());
+    return obj->get();
+  }
+
+  void SetPyObject(py::handle value, tensorflow::Safe_PyObjectPtr* ptr) {
+    Py_INCREF(value.ptr());
+    ptr->reset(value.ptr());
+  }
+
+  PyObject* py_eager_context_;  // not owned (borrowed reference).
+};
+
+}  // namespace
+
 // py::return_value_policy::reference is defined as specified by the
 // pybind11 documents listed here.
 // https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies
@@ -1272,6 +1365,38 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
   });
 
+  py::class_<EagerContextThreadLocalDataWrapper>(m,
+                                                 "EagerContextThreadLocalData")
+      .def(py::init<py::handle, py::handle, py::handle>(),
+           py::arg("py_eager_context"), py::arg("is_eager"),
+           py::arg("device_spec"))
+      .def_property("is_eager",
+                    &EagerContextThreadLocalDataWrapper::get_is_eager,
+                    &EagerContextThreadLocalDataWrapper::set_is_eager)
+      .def_property(
+          "invoking_op_callbacks",
+          &EagerContextThreadLocalDataWrapper::get_invoking_op_callbacks,
+          &EagerContextThreadLocalDataWrapper::set_invoking_op_callbacks)
+      .def_property("device_name",
+                    &EagerContextThreadLocalDataWrapper::get_device_name,
+                    &EagerContextThreadLocalDataWrapper::set_device_name)
+      .def_property("scope_name",
+                    &EagerContextThreadLocalDataWrapper::get_scope_name,
+                    &EagerContextThreadLocalDataWrapper::set_scope_name)
+      .def_property("device_spec",
+                    &EagerContextThreadLocalDataWrapper::get_device_spec,
+                    &EagerContextThreadLocalDataWrapper::set_device_spec)
+      .def_property(
+          "function_call_options",
+          &EagerContextThreadLocalDataWrapper::get_function_call_options,
+          &EagerContextThreadLocalDataWrapper::set_function_call_options)
+      .def_property("executor",
+                    &EagerContextThreadLocalDataWrapper::get_executor,
+                    &EagerContextThreadLocalDataWrapper::set_executor)
+      .def_property("op_callbacks",
+                    &EagerContextThreadLocalDataWrapper::get_op_callbacks,
+                    &EagerContextThreadLocalDataWrapper::set_op_callbacks);
+
   // C API Enum
 
   py::enum_<TFE_ContextDevicePlacementPolicy>(
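Finally, a hedged sketch of the binding registered above as seen from Python; FakeContext is a hypothetical stand-in used so the example doesn't collide with the real Context, which already constructs a wrapper for itself in __init__:

from tensorflow.python import pywrap_tfe


class FakeContext(object):
  """Hypothetical stand-in for an eager Context object."""


fake_ctx = FakeContext()
tld = pywrap_tfe.EagerContextThreadLocalData(
    fake_ctx, is_eager=lambda: False, device_spec=None)

tld.scope_name = "outer/"       # routed to set_scope_name (new reference)
print(tld.scope_name)           # routed to get_scope_name (new reference)
tld.invoking_op_callbacks = True
print(tld.is_eager)             # False: from the is_eager() default above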