Move the _ThreadLocalData class used by eager Context
from Python (eager/context.py) to c++ (pywrap_tfe.h).
PiperOrigin-RevId: 330995559 Change-Id: I01bd358760b25f166a3e876fea8d83c9c0e8b513
This commit is contained in:
parent
3db52f724a
commit
8e1ef83d9e
tensorflow/python
@ -80,8 +80,8 @@ def c_tfe_py_fastpath_execute(a,
|
||||
assert ctx.executing_eagerly(
|
||||
), "The prototype doesn't contain C code for graph construction"
|
||||
try:
|
||||
return pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", name, ctx.op_callbacks,
|
||||
return pywrap_tfe.TFE_Py_FastPathExecute(ctx, ctx._handle,
|
||||
"MatMul", name,
|
||||
a, b, "transpose_a", transpose_a,
|
||||
"transpose_b", transpose_b)
|
||||
except core._NotOkStatusException as e:
|
||||
|
@ -183,21 +183,6 @@ class _TensorCaches(threading.local):
|
||||
return self._zeros_cache
|
||||
|
||||
|
||||
class _ThreadLocalData(threading.local):
|
||||
"""Thread local storage for the eager context."""
|
||||
|
||||
def __init__(self):
|
||||
super(_ThreadLocalData, self).__init__()
|
||||
self.device_spec = _starting_device_spec
|
||||
self.device_name = ""
|
||||
self.is_eager = default_execution_mode == EAGER_MODE
|
||||
self.scope_name = ""
|
||||
self.function_call_options = None
|
||||
self.executor = None
|
||||
self.op_callbacks = []
|
||||
self.invoking_op_callbacks = False
|
||||
|
||||
|
||||
ContextSwitch = collections.namedtuple(
|
||||
"ContextSwitch", ["is_building_function", "enter_context_fn",
|
||||
"device_stack"])
|
||||
@ -420,7 +405,10 @@ class Context(object):
|
||||
_tensor_caches_map[self._id] = _TensorCaches()
|
||||
|
||||
self._config = config
|
||||
self._thread_local_data = _ThreadLocalData()
|
||||
self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData(
|
||||
self,
|
||||
is_eager=lambda: default_execution_mode == EAGER_MODE,
|
||||
device_spec=_starting_device_spec)
|
||||
self._context_switches = _ContextSwitchStack(self.executing_eagerly())
|
||||
self._context_handle = None
|
||||
self._context_devices = None
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
|
||||
|
||||
typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>
|
||||
TFE_InputTensorHandles;
|
||||
@ -259,16 +260,15 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
|
||||
// it will simply fail with a NotImplementedError.
|
||||
//
|
||||
// The "args" PyObject* is meant to be a tuple with the following structure:
|
||||
// Item 1: The TFE Context
|
||||
// Item 2: device_name: Name of the device on which to execute the operation,
|
||||
// or NULL for automatic selection.
|
||||
// Item 3: op_name: Name of the TensorFlow op to execute.
|
||||
// Item 4: name: An optional name for the operation.
|
||||
// Item 5: List representing all callbacks to execute after successful
|
||||
// op execute.
|
||||
// Item 6 onwards: inputs - This is a list of inputs followed by a list of
|
||||
// Item 1: The Python eager Context object
|
||||
// Item 2: op_name: Name of the TensorFlow op to execute.
|
||||
// Item 3: name: An optional name for the operation.
|
||||
// Item 4 onwards: inputs - This is a list of inputs followed by a list of
|
||||
// attrs. It is not necessary for type attrs to be present.
|
||||
//
|
||||
// Note: the device_name and op_callbacks, which were previously passed
|
||||
// as arguments, are now read via GetEagerContextThreadLocalData().
|
||||
//
|
||||
// This is named _C since there doesn't seem to be any way to make it visible
|
||||
// in the SWIG interface without renaming due to the use of the %native
|
||||
// directive.
|
||||
@ -394,4 +394,59 @@ PyObject* GetPyEagerContext();
|
||||
TF_Status* GetStatus();
|
||||
// Returns the pre-allocated status to the code.
|
||||
void ReturnStatus(TF_Status* status);
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Thread-local data associated with a Python eager Context object.
|
||||
//
|
||||
// TODO(edloper): Consider changing device_name and scope_name to a const char*
|
||||
// (with nullptr used for None). However, note that existing code (e.g.
|
||||
// TFE_TensorHandleCache::Lookup) assumes that the lifetime of these strings
|
||||
// extends beyond the point where their value is changed; so we'd need to make
|
||||
// sure that the strings stay alive (maybe using PyUnicode_InternInPlace?)
|
||||
struct EagerContextThreadLocalData {
|
||||
bool is_eager = false;
|
||||
bool invoking_op_callbacks = false;
|
||||
tensorflow::Safe_PyObjectPtr device_name;
|
||||
tensorflow::Safe_PyObjectPtr scope_name;
|
||||
tensorflow::Safe_PyObjectPtr device_spec;
|
||||
tensorflow::Safe_PyObjectPtr function_call_options;
|
||||
tensorflow::Safe_PyObjectPtr executor;
|
||||
tensorflow::Safe_PyObjectPtr op_callbacks;
|
||||
};
|
||||
|
||||
// Create a thread-local-data structure associated with py_eager_context.
|
||||
// `is_eager` and `device_spec` are used to supply default values for those
|
||||
// fields whenever a new thread-local instance is created for py_eager_tensor.
|
||||
//
|
||||
// This function assumes that the Python GIL is held (and does not perform its
|
||||
// own locking).
|
||||
void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
|
||||
PyObject* is_eager,
|
||||
PyObject* device_spec);
|
||||
|
||||
// Returns the thread-local instance of EagerContextThreadLocalData that is
|
||||
// associated with the given Python Context object. If an instance has not
|
||||
// yet been created for `py_eager_context` in this thread, then a new one is
|
||||
// created, and initialized with the default values specified in
|
||||
// MakeEagerContextThreadLocalData.
|
||||
EagerContextThreadLocalData* GetEagerContextThreadLocalData(
|
||||
PyObject* py_eager_context);
|
||||
|
||||
// Free data structures used to track py_eager_context.
|
||||
//
|
||||
// This frees global state associated with py_eager_context, as well as thread-
|
||||
// local state associated with py_eager_context and the current thread. If you
|
||||
// wish to destroy thread-local state associated with a single py_eager_context
|
||||
// for multiple threads, then you must call this method from each thread.
|
||||
//
|
||||
// Thread-local state assocaited with eager contexts is also automatically
|
||||
// cleaned up when the thread is destroyed.
|
||||
//
|
||||
// This function assumes that the Python GIL is held (and does not perform its
|
||||
// own locking).
|
||||
void DestroyEagerContextThreadLocalData(PyObject* py_eager_context);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
|
||||
|
@ -2953,7 +2953,14 @@ PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
static const int kFastPathExecuteInputStartIndex = 5;
|
||||
|
||||
// Indices for the "args" tuple that's passed to TFE_Py_FastPathExecute_C.
|
||||
enum FastPathExecuteArgIndex {
|
||||
FAST_PATH_EXECUTE_ARG_CONTEXT = 0,
|
||||
FAST_PATH_EXECUTE_ARG_OP_NAME = 1,
|
||||
FAST_PATH_EXECUTE_ARG_NAME = 2,
|
||||
FAST_PATH_EXECUTE_ARG_INPUT_START = 3
|
||||
};
|
||||
|
||||
PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) {
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
@ -3063,7 +3070,7 @@ tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
|
||||
|
||||
for (const auto& input_info : it->second) {
|
||||
PyObject* item = PyTuple_GET_ITEM(
|
||||
op_exec_info->args, kFastPathExecuteInputStartIndex + input_info.i);
|
||||
op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i);
|
||||
if (input_info.is_list) {
|
||||
tensorflow::Safe_PyObjectPtr fast_item(
|
||||
PySequence_Fast(item, "Unable to allocate"));
|
||||
@ -3526,19 +3533,26 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
|
||||
tensorflow::profiler::TraceMe activity(
|
||||
"TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
Py_ssize_t args_size = PyTuple_GET_SIZE(args);
|
||||
if (args_size < kFastPathExecuteInputStartIndex) {
|
||||
if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) {
|
||||
PyErr_SetString(
|
||||
PyExc_ValueError,
|
||||
Printf("There must be at least %d items in the input tuple.",
|
||||
kFastPathExecuteInputStartIndex)
|
||||
FAST_PATH_EXECUTE_ARG_INPUT_START)
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
FastPathOpExecInfo op_exec_info;
|
||||
|
||||
PyObject* py_eager_context =
|
||||
PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT);
|
||||
|
||||
// TODO(edoper): Use interned string here
|
||||
PyObject* eager_context_handle =
|
||||
PyObject_GetAttrString(py_eager_context, "_context_handle");
|
||||
|
||||
TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
|
||||
PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
|
||||
PyCapsule_GetPointer(eager_context_handle, nullptr));
|
||||
op_exec_info.ctx = ctx;
|
||||
op_exec_info.args = args;
|
||||
|
||||
@ -3550,10 +3564,15 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
op_exec_info.device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1));
|
||||
op_exec_info.op_name = PyTuple_GET_ITEM(args, 2);
|
||||
op_exec_info.name = PyTuple_GET_ITEM(args, 3);
|
||||
op_exec_info.callbacks = PyTuple_GET_ITEM(args, 4);
|
||||
auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context);
|
||||
if (tld == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
op_exec_info.device_name = GetDeviceName(tld->device_name.get());
|
||||
op_exec_info.callbacks = tld->op_callbacks.get();
|
||||
|
||||
op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME);
|
||||
op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME);
|
||||
|
||||
// TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
|
||||
// (similar to benchmark_tf_gradient_function_*). Also consider using an
|
||||
@ -3591,18 +3610,19 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
|
||||
const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
|
||||
if (op_def == nullptr) return nullptr;
|
||||
|
||||
if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) {
|
||||
if (args_size <
|
||||
FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) {
|
||||
PyErr_SetString(
|
||||
PyExc_ValueError,
|
||||
Printf("Tuple size smaller than intended. Expected to be at least %d, "
|
||||
"was %ld",
|
||||
kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
|
||||
FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
|
||||
args_size)
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!CheckInputsOk(args, kFastPathExecuteInputStartIndex, *op_def)) {
|
||||
if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) {
|
||||
RaiseFallbackException(
|
||||
"This function does not handle the case of the path where "
|
||||
"all inputs are not already EagerTensors.");
|
||||
@ -3618,7 +3638,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
|
||||
|
||||
// Set non-inferred attrs, including setting defaults if the attr is passed in
|
||||
// as None.
|
||||
for (int i = kFastPathExecuteInputStartIndex + op_def->input_arg_size();
|
||||
for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size();
|
||||
i < args_size; i += 2) {
|
||||
PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
|
||||
const char* attr_name = TFE_GetPythonString(py_attr_name);
|
||||
@ -3675,7 +3695,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
|
||||
const auto& input_arg = op_def->input_arg(i);
|
||||
|
||||
PyObject* input =
|
||||
PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i);
|
||||
PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i);
|
||||
if (!input_arg.number_attr().empty()) {
|
||||
// The item is a homogeneous list.
|
||||
if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
|
||||
@ -3820,7 +3840,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
|
||||
if (op_exec_info.run_callbacks) {
|
||||
if (!RunCallbacks(
|
||||
op_exec_info, args,
|
||||
kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
|
||||
FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
|
||||
*flattened_inputs, *flattened_attrs, flat_result.get())) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -4203,3 +4223,125 @@ PyObject* GetPyEagerContext() {
|
||||
Py_INCREF(py_context);
|
||||
return py_context;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Default values for thread_local_data fields.
|
||||
struct EagerContextThreadLocalDataDefaults {
|
||||
tensorflow::Safe_PyObjectPtr is_eager;
|
||||
tensorflow::Safe_PyObjectPtr device_spec;
|
||||
};
|
||||
|
||||
// Maps each py_eager_context object to its thread_local_data.
|
||||
//
|
||||
// Note: we need to use the python Context object as the key here (and not
|
||||
// its handle object), because the handle object isn't created until the
|
||||
// context is initialized; but thread_local_data is potentially accessed
|
||||
// before then.
|
||||
using EagerContextThreadLocalDataMap = absl::flat_hash_map<
|
||||
PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>;
|
||||
thread_local EagerContextThreadLocalDataMap*
|
||||
eager_context_thread_local_data_map = nullptr;
|
||||
|
||||
// Maps each py_eager_context object to default values.
|
||||
using EagerContextThreadLocalDataDefaultsMap =
|
||||
absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>;
|
||||
EagerContextThreadLocalDataDefaultsMap*
|
||||
eager_context_thread_local_data_defaults = nullptr;
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
|
||||
PyObject* is_eager,
|
||||
PyObject* device_spec) {
|
||||
DCheckPyGilState();
|
||||
if (eager_context_thread_local_data_defaults == nullptr) {
|
||||
eager_context_thread_local_data_defaults =
|
||||
new EagerContextThreadLocalDataDefaultsMap();
|
||||
}
|
||||
if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) {
|
||||
PyErr_SetString(PyExc_AssertionError,
|
||||
"MakeEagerContextThreadLocalData may not be called "
|
||||
"twice on the same eager Context object.");
|
||||
}
|
||||
|
||||
auto& defaults =
|
||||
(*eager_context_thread_local_data_defaults)[py_eager_context];
|
||||
Py_INCREF(is_eager);
|
||||
defaults.is_eager.reset(is_eager);
|
||||
Py_INCREF(device_spec);
|
||||
defaults.device_spec.reset(device_spec);
|
||||
}
|
||||
|
||||
EagerContextThreadLocalData* GetEagerContextThreadLocalData(
|
||||
PyObject* py_eager_context) {
|
||||
if (eager_context_thread_local_data_defaults == nullptr) {
|
||||
PyErr_SetString(PyExc_AssertionError,
|
||||
"MakeEagerContextThreadLocalData must be called "
|
||||
"before GetEagerContextThreadLocalData.");
|
||||
return nullptr;
|
||||
}
|
||||
auto defaults =
|
||||
eager_context_thread_local_data_defaults->find(py_eager_context);
|
||||
if (defaults == eager_context_thread_local_data_defaults->end()) {
|
||||
PyErr_SetString(PyExc_AssertionError,
|
||||
"MakeEagerContextThreadLocalData must be called "
|
||||
"before GetEagerContextThreadLocalData.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (eager_context_thread_local_data_map == nullptr) {
|
||||
eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
|
||||
}
|
||||
auto& thread_local_data =
|
||||
(*eager_context_thread_local_data_map)[py_eager_context];
|
||||
|
||||
if (!thread_local_data) {
|
||||
thread_local_data.reset(new EagerContextThreadLocalData());
|
||||
|
||||
Safe_PyObjectPtr is_eager(PyObject_CallFunctionObjArgs(
|
||||
defaults->second.is_eager.get(), nullptr));
|
||||
if (!is_eager) return nullptr;
|
||||
thread_local_data->is_eager = PyObject_IsTrue(is_eager.get());
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
PyObject* scope_name = PyUnicode_FromString("");
|
||||
#else
|
||||
PyObject* scope_name = PyString_FromString("");
|
||||
#endif
|
||||
thread_local_data->scope_name.reset(scope_name);
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
PyObject* device_name = PyUnicode_FromString("");
|
||||
#else
|
||||
PyObject* device_name = PyString_FromString("");
|
||||
#endif
|
||||
thread_local_data->device_name.reset(device_name);
|
||||
|
||||
Py_INCREF(defaults->second.device_spec.get());
|
||||
thread_local_data->device_spec.reset(defaults->second.device_spec.get());
|
||||
|
||||
Py_INCREF(Py_None);
|
||||
thread_local_data->function_call_options.reset(Py_None);
|
||||
|
||||
Py_INCREF(Py_None);
|
||||
thread_local_data->executor.reset(Py_None);
|
||||
|
||||
thread_local_data->op_callbacks.reset(PyList_New(0));
|
||||
}
|
||||
return thread_local_data.get();
|
||||
}
|
||||
|
||||
void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) {
|
||||
DCheckPyGilState();
|
||||
if (eager_context_thread_local_data_defaults) {
|
||||
eager_context_thread_local_data_defaults->erase(py_eager_context);
|
||||
}
|
||||
if (eager_context_thread_local_data_map) {
|
||||
eager_context_thread_local_data_map->erase(py_eager_context);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -57,16 +57,15 @@ class Tests(test.TestCase):
|
||||
|
||||
self.assertAllClose(
|
||||
math_ops.matmul(a_2_by_2, b_2_by_2),
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", None, None, a_2_by_2,
|
||||
b_2_by_2, "transpose_a", False,
|
||||
"transpose_b", False))
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None,
|
||||
a_2_by_2, b_2_by_2, "transpose_a",
|
||||
False, "transpose_b", False))
|
||||
self.assertAllClose(
|
||||
math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True),
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", None, None, a_100_by_784,
|
||||
b_100_by_784, "transpose_a", False,
|
||||
"transpose_b", True))
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None,
|
||||
a_100_by_784, b_100_by_784,
|
||||
"transpose_a", False, "transpose_b",
|
||||
True))
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_garbage_created
|
||||
@ -76,14 +75,12 @@ class Tests(test.TestCase):
|
||||
|
||||
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
|
||||
m = resource_variable_ops.ResourceVariable(a_2_by_2)
|
||||
x = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", None, None, m, m,
|
||||
"transpose_a", False, "transpose_b",
|
||||
False)
|
||||
y = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", None, None, a_2_by_2,
|
||||
a_2_by_2, "transpose_a", False,
|
||||
x = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None, m,
|
||||
m, "transpose_a", False,
|
||||
"transpose_b", False)
|
||||
y = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None,
|
||||
a_2_by_2, a_2_by_2, "transpose_a",
|
||||
False, "transpose_b", False)
|
||||
|
||||
self.assertAllEqual(x, y)
|
||||
|
||||
@ -96,10 +93,9 @@ class Tests(test.TestCase):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
|
||||
tape.watch(a_2_by_2)
|
||||
z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", None, None, a_2_by_2,
|
||||
a_2_by_2, "transpose_a", False,
|
||||
"transpose_b", False)
|
||||
z = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None,
|
||||
a_2_by_2, a_2_by_2, "transpose_a",
|
||||
False, "transpose_b", False)
|
||||
dz_dy = tape.gradient(z, [a_2_by_2])[0]
|
||||
self.assertAllEqual(dz_dy.numpy(),
|
||||
constant_op.constant(4.0, shape=[2, 2]).numpy())
|
||||
@ -114,10 +110,9 @@ class Tests(test.TestCase):
|
||||
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
|
||||
m = resource_variable_ops.ResourceVariable(a_2_by_2)
|
||||
tape.watch(m)
|
||||
z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", None, None, m, m,
|
||||
"transpose_a", False, "transpose_b",
|
||||
False)
|
||||
z = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None, m,
|
||||
m, "transpose_a", False,
|
||||
"transpose_b", False)
|
||||
dz_dy = tape.gradient(z, [m])[0]
|
||||
self.assertAllEqual(dz_dy.numpy(),
|
||||
constant_op.constant(4.0, shape=[2, 2]).numpy())
|
||||
@ -134,8 +129,8 @@ class Tests(test.TestCase):
|
||||
|
||||
self.assertAllClose(
|
||||
math_ops.add_n([a_2_by_2, b_2_by_2]),
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "AddN",
|
||||
None, None, [a_2_by_2, b_2_by_2]))
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "AddN", None,
|
||||
[a_2_by_2, b_2_by_2]))
|
||||
|
||||
# Tests homogeneous list op
|
||||
@test_util.assert_no_new_tensors
|
||||
@ -150,8 +145,7 @@ class Tests(test.TestCase):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
tape.watch(a_2_by_2)
|
||||
tape.watch(b_2_by_2)
|
||||
z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"AddN", None, None,
|
||||
z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "AddN", None,
|
||||
[a_2_by_2, b_2_by_2])
|
||||
z2 = math_ops.add_n([a_2_by_2, b_2_by_2])
|
||||
dz1_dy = tape.gradient(z1, [a_2_by_2])[0]
|
||||
@ -170,8 +164,7 @@ class Tests(test.TestCase):
|
||||
|
||||
self.assertAllClose(
|
||||
array_ops.identity_n([a_2_by_2, b_2_by_2]),
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"IdentityN", None, None,
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "IdentityN", None,
|
||||
[a_2_by_2, b_2_by_2]))
|
||||
|
||||
# Tests heterogeneous list op
|
||||
@ -187,9 +180,8 @@ class Tests(test.TestCase):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
tape.watch(a_2_by_2)
|
||||
tape.watch(b_2_by_2)
|
||||
z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"IdentityN", None, None,
|
||||
[a_2_by_2, b_2_by_2])
|
||||
z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "IdentityN",
|
||||
None, [a_2_by_2, b_2_by_2])
|
||||
z2 = array_ops.identity_n([a_2_by_2, b_2_by_2])
|
||||
dz1_dy = tape.gradient(z1[0], [a_2_by_2])[0]
|
||||
dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0]
|
||||
@ -208,18 +200,17 @@ class Tests(test.TestCase):
|
||||
|
||||
# Not enough base params
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"at least 5 items in the input tuple"):
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Identity")
|
||||
"at least 3 items in the input tuple"):
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "Identity")
|
||||
|
||||
# Not enough inputs
|
||||
with self.assertRaisesRegex(ValueError, "Expected to be at least 6, was 5"):
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx_handle, "Identity",
|
||||
None, [])
|
||||
with self.assertRaisesRegex(ValueError, "Expected to be at least 4, was 3"):
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "Identity", None)
|
||||
|
||||
# Bad type
|
||||
with self.assertRaisesRegex(TypeError, "expected a string for op_name"):
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, ctx_handle,
|
||||
None, [], a_2_by_2)
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx, ctx_handle, None,
|
||||
a_2_by_2)
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_garbage_created
|
||||
@ -229,11 +220,9 @@ class Tests(test.TestCase):
|
||||
ctx = context.context()
|
||||
ctx.ensure_initialized()
|
||||
|
||||
ctx_handle = ctx._handle
|
||||
with self.assertRaises(core._FallbackException):
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Split",
|
||||
None, None, split_dim, value,
|
||||
"num_split", -1)
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "Split", None,
|
||||
split_dim, value, "num_split", -1)
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_garbage_created
|
||||
@ -273,9 +262,9 @@ class Tests(test.TestCase):
|
||||
ctx = context.context()
|
||||
ctx.ensure_initialized()
|
||||
with self.assertRaises(core._FallbackException):
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "MatMul",
|
||||
None, None, m, m, "transpose_a", False,
|
||||
"transpose_b", False)
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None, m, m,
|
||||
"transpose_a", False, "transpose_b",
|
||||
False)
|
||||
|
||||
def testOpDefDefaultType(self):
|
||||
im = np.random.randint(
|
||||
|
@ -928,8 +928,7 @@ bool GenEagerPythonOp::AddEagerFallbackCode(
|
||||
|
||||
void GenEagerPythonOp::AddEagerFastPathExecute() {
|
||||
string fastpath_execute_params =
|
||||
strings::StrCat("_ctx._context_handle, tld.device_name, \"",
|
||||
op_def_.name(), "\", ", "name, tld.op_callbacks");
|
||||
strings::StrCat("_ctx, \"", op_def_.name(), "\", ", "name");
|
||||
string fallback_params;
|
||||
|
||||
for (int i = 0; i < api_def_.in_arg_size(); i++) {
|
||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
|
||||
#include "tensorflow/python/util/util.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
@ -286,6 +287,98 @@ static py::object TFE_ClearScalarCache() {
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
namespace {
|
||||
|
||||
// Wrapper around the EagerContextThreadLocalData struct (defined in
|
||||
// pywrap_tfe.h), so it can be accessed from Python.
|
||||
//
|
||||
// For PyObject* fields, the get_*() methods return a new reference; and the
|
||||
// set_*() methods create a new reference (i.e., they do not steal a reference).
|
||||
class EagerContextThreadLocalDataWrapper {
|
||||
public:
|
||||
explicit EagerContextThreadLocalDataWrapper(py::handle py_eager_context,
|
||||
py::handle is_eager,
|
||||
py::handle device_spec)
|
||||
: py_eager_context_(py_eager_context.ptr()) {
|
||||
tensorflow::MakeEagerContextThreadLocalData(
|
||||
py_eager_context.ptr(), is_eager.ptr(), device_spec.ptr());
|
||||
}
|
||||
|
||||
~EagerContextThreadLocalDataWrapper() {
|
||||
tensorflow::DestroyEagerContextThreadLocalData(py_eager_context_);
|
||||
}
|
||||
|
||||
bool get_is_eager() const { return GetData()->is_eager; }
|
||||
void set_is_eager(bool v) { GetData()->is_eager = v; }
|
||||
|
||||
bool get_invoking_op_callbacks() const {
|
||||
return GetData()->invoking_op_callbacks;
|
||||
}
|
||||
void set_invoking_op_callbacks(bool v) {
|
||||
GetData()->invoking_op_callbacks = v;
|
||||
}
|
||||
|
||||
py::handle get_device_name() const {
|
||||
return GetPyObject(&GetData()->device_name);
|
||||
}
|
||||
void set_device_name(py::handle v) {
|
||||
SetPyObject(v, &GetData()->device_name);
|
||||
}
|
||||
|
||||
py::handle get_scope_name() const {
|
||||
return GetPyObject(&GetData()->scope_name);
|
||||
}
|
||||
void set_scope_name(py::handle v) { SetPyObject(v, &GetData()->scope_name); }
|
||||
|
||||
py::handle get_device_spec() const {
|
||||
return GetPyObject(&GetData()->device_spec);
|
||||
}
|
||||
void set_device_spec(py::handle v) {
|
||||
SetPyObject(v, &GetData()->device_spec);
|
||||
}
|
||||
|
||||
py::handle get_function_call_options() const {
|
||||
return GetPyObject(&GetData()->function_call_options);
|
||||
}
|
||||
void set_function_call_options(py::handle v) {
|
||||
SetPyObject(v, &GetData()->function_call_options);
|
||||
}
|
||||
|
||||
py::handle get_executor() const { return GetPyObject(&GetData()->executor); }
|
||||
void set_executor(py::handle v) { SetPyObject(v, &GetData()->executor); }
|
||||
|
||||
py::handle get_op_callbacks() const {
|
||||
return GetPyObject(&GetData()->op_callbacks);
|
||||
}
|
||||
void set_op_callbacks(py::handle v) {
|
||||
SetPyObject(v, &GetData()->op_callbacks);
|
||||
}
|
||||
|
||||
private:
|
||||
tensorflow::EagerContextThreadLocalData* GetData() const {
|
||||
auto* result =
|
||||
tensorflow::GetEagerContextThreadLocalData(py_eager_context_);
|
||||
if (!result) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
py::handle GetPyObject(tensorflow::Safe_PyObjectPtr* obj) const {
|
||||
Py_INCREF(obj->get());
|
||||
return obj->get();
|
||||
}
|
||||
|
||||
void SetPyObject(py::handle value, tensorflow::Safe_PyObjectPtr* ptr) {
|
||||
Py_INCREF(value.ptr());
|
||||
ptr->reset(value.ptr());
|
||||
}
|
||||
|
||||
PyObject* py_eager_context_; // not owned (borrowed reference).
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// py::return_value_policy::reference is defined as specified by the
|
||||
// pybind11 documents listed here.
|
||||
// https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies
|
||||
@ -1272,6 +1365,38 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
});
|
||||
|
||||
py::class_<EagerContextThreadLocalDataWrapper>(m,
|
||||
"EagerContextThreadLocalData")
|
||||
.def(py::init<py::handle, py::handle, py::handle>(),
|
||||
py::arg("py_eager_context"), py::arg("is_eager"),
|
||||
py::arg("device_spec"))
|
||||
.def_property("is_eager",
|
||||
&EagerContextThreadLocalDataWrapper::get_is_eager,
|
||||
&EagerContextThreadLocalDataWrapper::set_is_eager)
|
||||
.def_property(
|
||||
"invoking_op_callbacks",
|
||||
&EagerContextThreadLocalDataWrapper::get_invoking_op_callbacks,
|
||||
&EagerContextThreadLocalDataWrapper::set_invoking_op_callbacks)
|
||||
.def_property("device_name",
|
||||
&EagerContextThreadLocalDataWrapper::get_device_name,
|
||||
&EagerContextThreadLocalDataWrapper::set_device_name)
|
||||
.def_property("scope_name",
|
||||
&EagerContextThreadLocalDataWrapper::get_scope_name,
|
||||
&EagerContextThreadLocalDataWrapper::set_scope_name)
|
||||
.def_property("device_spec",
|
||||
&EagerContextThreadLocalDataWrapper::get_device_spec,
|
||||
&EagerContextThreadLocalDataWrapper::set_device_spec)
|
||||
.def_property(
|
||||
"function_call_options",
|
||||
&EagerContextThreadLocalDataWrapper::get_function_call_options,
|
||||
&EagerContextThreadLocalDataWrapper::set_function_call_options)
|
||||
.def_property("executor",
|
||||
&EagerContextThreadLocalDataWrapper::get_executor,
|
||||
&EagerContextThreadLocalDataWrapper::set_executor)
|
||||
.def_property("op_callbacks",
|
||||
&EagerContextThreadLocalDataWrapper::get_op_callbacks,
|
||||
&EagerContextThreadLocalDataWrapper::set_op_callbacks);
|
||||
|
||||
// C API Enum
|
||||
|
||||
py::enum_<TFE_ContextDevicePlacementPolicy>(
|
||||
|
Loading…
Reference in New Issue
Block a user