Move the _ThreadLocalData class used by eager Context from Python (eager/context.py) to c++ (pywrap_tfe.h).

PiperOrigin-RevId: 330995559
Change-Id: I01bd358760b25f166a3e876fea8d83c9c0e8b513
This commit is contained in:
Edward Loper 2020-09-10 12:44:06 -07:00 committed by TensorFlower Gardener
parent 3db52f724a
commit 8e1ef83d9e
7 changed files with 387 additions and 89 deletions

View File

@ -80,8 +80,8 @@ def c_tfe_py_fastpath_execute(a,
assert ctx.executing_eagerly(
), "The prototype doesn't contain C code for graph construction"
try:
return pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", name, ctx.op_callbacks,
return pywrap_tfe.TFE_Py_FastPathExecute(ctx, ctx._handle,
"MatMul", name,
a, b, "transpose_a", transpose_a,
"transpose_b", transpose_b)
except core._NotOkStatusException as e:

View File

@ -183,21 +183,6 @@ class _TensorCaches(threading.local):
return self._zeros_cache
class _ThreadLocalData(threading.local):
"""Thread local storage for the eager context."""
def __init__(self):
super(_ThreadLocalData, self).__init__()
self.device_spec = _starting_device_spec
self.device_name = ""
self.is_eager = default_execution_mode == EAGER_MODE
self.scope_name = ""
self.function_call_options = None
self.executor = None
self.op_callbacks = []
self.invoking_op_callbacks = False
ContextSwitch = collections.namedtuple(
"ContextSwitch", ["is_building_function", "enter_context_fn",
"device_stack"])
@ -420,7 +405,10 @@ class Context(object):
_tensor_caches_map[self._id] = _TensorCaches()
self._config = config
self._thread_local_data = _ThreadLocalData()
self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData(
self,
is_eager=lambda: default_execution_mode == EAGER_MODE,
device_spec=_starting_device_spec)
self._context_switches = _ContextSwitchStack(self.executing_eagerly())
self._context_handle = None
self._context_devices = None

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>
TFE_InputTensorHandles;
@ -259,16 +260,15 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
// it will simply fail with a NotImplementedError.
//
// The "args" PyObject* is meant to be a tuple with the following structure:
// Item 1: The TFE Context
// Item 2: device_name: Name of the device on which to execute the operation,
// or NULL for automatic selection.
// Item 3: op_name: Name of the TensorFlow op to execute.
// Item 4: name: An optional name for the operation.
// Item 5: List representing all callbacks to execute after successful
// op execute.
// Item 6 onwards: inputs - This is a list of inputs followed by a list of
// Item 1: The Python eager Context object
// Item 2: op_name: Name of the TensorFlow op to execute.
// Item 3: name: An optional name for the operation.
// Item 4 onwards: inputs - This is a list of inputs followed by a list of
// attrs. It is not necessary for type attrs to be present.
//
// Note: the device_name and op_callbacks, which were previously passed
// as arguments, are now read via GetEagerContextThreadLocalData().
//
// This is named _C since there doesn't seem to be any way to make it visible
// in the SWIG interface without renaming due to the use of the %native
// directive.
@ -394,4 +394,59 @@ PyObject* GetPyEagerContext();
TF_Status* GetStatus();
// Returns the pre-allocated status to the code.
void ReturnStatus(TF_Status* status);
namespace tensorflow {

// Thread-local data associated with a Python eager Context object.
//
// There is one instance per (eager Context, thread) pair; it replaces the
// Python-level `_ThreadLocalData` class that previously lived in
// tensorflow/python/eager/context.py.  Every Safe_PyObjectPtr field holds an
// owned reference to a Python object (released when the struct is destroyed).
//
// TODO(edloper): Consider changing device_name and scope_name to a const char*
// (with nullptr used for None). However, note that existing code (e.g.
// TFE_TensorHandleCache::Lookup) assumes that the lifetime of these strings
// extends beyond the point where their value is changed; so we'd need to make
// sure that the strings stay alive (maybe using PyUnicode_InternInPlace?)
struct EagerContextThreadLocalData {
  bool is_eager = false;               // Default comes from the Make() callable.
  bool invoking_op_callbacks = false;
  tensorflow::Safe_PyObjectPtr device_name;            // Initialized to "".
  tensorflow::Safe_PyObjectPtr scope_name;             // Initialized to "".
  tensorflow::Safe_PyObjectPtr device_spec;            // Initialized to the Make() default.
  tensorflow::Safe_PyObjectPtr function_call_options;  // Initialized to None.
  tensorflow::Safe_PyObjectPtr executor;               // Initialized to None.
  tensorflow::Safe_PyObjectPtr op_callbacks;           // Initialized to [].
};

// Create a thread-local-data structure associated with py_eager_context.
// `is_eager` and `device_spec` are used to supply default values for those
// fields whenever a new thread-local instance is created for py_eager_context.
// (`is_eager` is a Python callable that is invoked to produce the default;
// `device_spec` is used as the default value directly.)
//
// This function assumes that the Python GIL is held (and does not perform its
// own locking).
void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
                                     PyObject* is_eager,
                                     PyObject* device_spec);

// Returns the thread-local instance of EagerContextThreadLocalData that is
// associated with the given Python Context object. If an instance has not
// yet been created for `py_eager_context` in this thread, then a new one is
// created, and initialized with the default values specified in
// MakeEagerContextThreadLocalData.
//
// Returns nullptr (with a Python exception set) on failure, including when
// MakeEagerContextThreadLocalData was never called for `py_eager_context`.
EagerContextThreadLocalData* GetEagerContextThreadLocalData(
    PyObject* py_eager_context);

// Free data structures used to track py_eager_context.
//
// This frees global state associated with py_eager_context, as well as thread-
// local state associated with py_eager_context and the current thread. If you
// wish to destroy thread-local state associated with a single py_eager_context
// for multiple threads, then you must call this method from each thread.
//
// Thread-local state associated with eager contexts is also automatically
// cleaned up when the thread is destroyed.
// NOTE(review): the per-thread map itself is a raw heap pointer that does not
// appear to be deleted on thread exit — confirm whether that is intentional.
//
// This function assumes that the Python GIL is held (and does not perform its
// own locking).
void DestroyEagerContextThreadLocalData(PyObject* py_eager_context);

}  // namespace tensorflow
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_

View File

@ -2953,7 +2953,14 @@ PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
}
namespace {
static const int kFastPathExecuteInputStartIndex = 5;
// Indices for the "args" tuple that's passed to TFE_Py_FastPathExecute_C.
// Positional layout: (context, op_name, name, *inputs, *attrs).
enum FastPathExecuteArgIndex {
  FAST_PATH_EXECUTE_ARG_CONTEXT = 0,     // The Python eager Context object.
  FAST_PATH_EXECUTE_ARG_OP_NAME = 1,     // Name of the TensorFlow op to run.
  FAST_PATH_EXECUTE_ARG_NAME = 2,        // Optional name for the operation.
  FAST_PATH_EXECUTE_ARG_INPUT_START = 3  // First input; attrs follow inputs.
};
PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) {
#if PY_MAJOR_VERSION >= 3
@ -3063,7 +3070,7 @@ tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
for (const auto& input_info : it->second) {
PyObject* item = PyTuple_GET_ITEM(
op_exec_info->args, kFastPathExecuteInputStartIndex + input_info.i);
op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i);
if (input_info.is_list) {
tensorflow::Safe_PyObjectPtr fast_item(
PySequence_Fast(item, "Unable to allocate"));
@ -3526,19 +3533,26 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
tensorflow::profiler::TraceMe activity(
"TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo);
Py_ssize_t args_size = PyTuple_GET_SIZE(args);
if (args_size < kFastPathExecuteInputStartIndex) {
if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) {
PyErr_SetString(
PyExc_ValueError,
Printf("There must be at least %d items in the input tuple.",
kFastPathExecuteInputStartIndex)
FAST_PATH_EXECUTE_ARG_INPUT_START)
.c_str());
return nullptr;
}
FastPathOpExecInfo op_exec_info;
PyObject* py_eager_context =
PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT);
// TODO(edloper): Use an interned string here.
PyObject* eager_context_handle =
PyObject_GetAttrString(py_eager_context, "_context_handle");
TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
PyCapsule_GetPointer(eager_context_handle, nullptr));
op_exec_info.ctx = ctx;
op_exec_info.args = args;
@ -3550,10 +3564,15 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
return nullptr;
}
op_exec_info.device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1));
op_exec_info.op_name = PyTuple_GET_ITEM(args, 2);
op_exec_info.name = PyTuple_GET_ITEM(args, 3);
op_exec_info.callbacks = PyTuple_GET_ITEM(args, 4);
auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context);
if (tld == nullptr) {
return nullptr;
}
op_exec_info.device_name = GetDeviceName(tld->device_name.get());
op_exec_info.callbacks = tld->op_callbacks.get();
op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME);
op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME);
// TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
// (similar to benchmark_tf_gradient_function_*). Also consider using an
@ -3591,18 +3610,19 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
if (op_def == nullptr) return nullptr;
if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) {
if (args_size <
FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) {
PyErr_SetString(
PyExc_ValueError,
Printf("Tuple size smaller than intended. Expected to be at least %d, "
"was %ld",
kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
args_size)
.c_str());
return nullptr;
}
if (!CheckInputsOk(args, kFastPathExecuteInputStartIndex, *op_def)) {
if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) {
RaiseFallbackException(
"This function does not handle the case of the path where "
"all inputs are not already EagerTensors.");
@ -3618,7 +3638,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
// Set non-inferred attrs, including setting defaults if the attr is passed in
// as None.
for (int i = kFastPathExecuteInputStartIndex + op_def->input_arg_size();
for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size();
i < args_size; i += 2) {
PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
const char* attr_name = TFE_GetPythonString(py_attr_name);
@ -3675,7 +3695,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
const auto& input_arg = op_def->input_arg(i);
PyObject* input =
PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i);
PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i);
if (!input_arg.number_attr().empty()) {
// The item is a homogeneous list.
if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
@ -3820,7 +3840,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
if (op_exec_info.run_callbacks) {
if (!RunCallbacks(
op_exec_info, args,
kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
*flattened_inputs, *flattened_attrs, flat_result.get())) {
return nullptr;
}
@ -4203,3 +4223,125 @@ PyObject* GetPyEagerContext() {
Py_INCREF(py_context);
return py_context;
}
namespace {

// Default values for thread_local_data fields.
//
// `is_eager` is a Python callable that returns the default value for the
// is_eager field; `device_spec` is used as the default value directly.
struct EagerContextThreadLocalDataDefaults {
  tensorflow::Safe_PyObjectPtr is_eager;
  tensorflow::Safe_PyObjectPtr device_spec;
};

// Maps each py_eager_context object to its thread_local_data.
//
// Note: we need to use the python Context object as the key here (and not
// its handle object), because the handle object isn't created until the
// context is initialized; but thread_local_data is potentially accessed
// before then.
using EagerContextThreadLocalDataMap = absl::flat_hash_map<
    PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>;
// One map per thread, lazily allocated on first use.
// NOTE(review): this raw pointer does not appear to be freed on thread exit;
// confirm whether the per-thread map is intentionally leaked.
thread_local EagerContextThreadLocalDataMap*
    eager_context_thread_local_data_map = nullptr;

// Maps each py_eager_context object to default values.
// Shared by all threads; accessed only with the Python GIL held (per the
// contract documented on Make/DestroyEagerContextThreadLocalData).
using EagerContextThreadLocalDataDefaultsMap =
    absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>;
EagerContextThreadLocalDataDefaultsMap*
    eager_context_thread_local_data_defaults = nullptr;

}  // namespace
namespace tensorflow {
// Registers the per-context default values used when a thread first asks for
// thread-local data for `py_eager_context`.
//
// `is_eager` is a Python callable producing the default is_eager value;
// `device_spec` is used as the default device spec directly.  Both references
// are retained (Py_INCREF) and owned by the defaults map.
//
// On a duplicate registration, sets an AssertionError and returns without
// touching the existing defaults.  Assumes the Python GIL is held.
void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
                                     PyObject* is_eager,
                                     PyObject* device_spec) {
  DCheckPyGilState();
  if (eager_context_thread_local_data_defaults == nullptr) {
    // Lazily allocate the (never-freed) global defaults map.
    eager_context_thread_local_data_defaults =
        new EagerContextThreadLocalDataDefaultsMap();
  }
  if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) {
    PyErr_SetString(PyExc_AssertionError,
                    "MakeEagerContextThreadLocalData may not be called "
                    "twice on the same eager Context object.");
    // BUG FIX: previously fell through with a Python error pending and
    // silently overwrote the already-registered defaults.
    return;
  }
  auto& defaults =
      (*eager_context_thread_local_data_defaults)[py_eager_context];
  Py_INCREF(is_eager);
  defaults.is_eager.reset(is_eager);
  Py_INCREF(device_spec);
  defaults.device_spec.reset(device_spec);
}
// Returns this thread's EagerContextThreadLocalData for `py_eager_context`,
// creating and default-initializing it on first access in this thread.
//
// Returns nullptr with a Python exception set on failure — including when
// MakeEagerContextThreadLocalData was never called for this context.
EagerContextThreadLocalData* GetEagerContextThreadLocalData(
    PyObject* py_eager_context) {
  if (eager_context_thread_local_data_defaults == nullptr) {
    PyErr_SetString(PyExc_AssertionError,
                    "MakeEagerContextThreadLocalData must be called "
                    "before GetEagerContextThreadLocalData.");
    return nullptr;
  }
  auto defaults =
      eager_context_thread_local_data_defaults->find(py_eager_context);
  if (defaults == eager_context_thread_local_data_defaults->end()) {
    PyErr_SetString(PyExc_AssertionError,
                    "MakeEagerContextThreadLocalData must be called "
                    "before GetEagerContextThreadLocalData.");
    return nullptr;
  }
  if (eager_context_thread_local_data_map == nullptr) {
    // Lazily allocate this thread's map on first use.
    eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
  }
  auto& thread_local_data =
      (*eager_context_thread_local_data_map)[py_eager_context];
  if (!thread_local_data) {
    thread_local_data.reset(new EagerContextThreadLocalData());

    // BUG FIX: on any initialization failure below, remove the
    // partially-initialized entry from the map.  Previously it was left
    // behind, so a subsequent call returned a struct whose null PyObject
    // fields would crash Py_INCREF in the accessors.
    auto cleanup_and_fail = [py_eager_context]() -> EagerContextThreadLocalData* {
      eager_context_thread_local_data_map->erase(py_eager_context);
      return nullptr;  // Python error is already set by the failed call.
    };

    // is_eager default is produced by calling the registered Python callable.
    Safe_PyObjectPtr is_eager(PyObject_CallFunctionObjArgs(
        defaults->second.is_eager.get(), nullptr));
    if (!is_eager) return cleanup_and_fail();
    int is_eager_value = PyObject_IsTrue(is_eager.get());
    // BUG FIX: PyObject_IsTrue returns -1 on error; previously that was
    // silently stored as `true`.
    if (is_eager_value < 0) return cleanup_and_fail();
    thread_local_data->is_eager = (is_eager_value != 0);

#if PY_MAJOR_VERSION >= 3
    PyObject* scope_name = PyUnicode_FromString("");
#else
    PyObject* scope_name = PyString_FromString("");
#endif
    if (scope_name == nullptr) return cleanup_and_fail();
    thread_local_data->scope_name.reset(scope_name);

#if PY_MAJOR_VERSION >= 3
    PyObject* device_name = PyUnicode_FromString("");
#else
    PyObject* device_name = PyString_FromString("");
#endif
    if (device_name == nullptr) return cleanup_and_fail();
    thread_local_data->device_name.reset(device_name);

    // device_spec default: share the registered object (new owned reference).
    Py_INCREF(defaults->second.device_spec.get());
    thread_local_data->device_spec.reset(defaults->second.device_spec.get());

    Py_INCREF(Py_None);
    thread_local_data->function_call_options.reset(Py_None);
    Py_INCREF(Py_None);
    thread_local_data->executor.reset(Py_None);

    PyObject* op_callbacks = PyList_New(0);
    if (op_callbacks == nullptr) return cleanup_and_fail();
    thread_local_data->op_callbacks.reset(op_callbacks);
  }
  return thread_local_data.get();
}
// Drop all bookkeeping for `py_eager_context`: its registered defaults in the
// shared map, and the calling thread's thread-local data for it (if either
// exists).  Assumes the Python GIL is held.
void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) {
  DCheckPyGilState();
  auto* defaults_map = eager_context_thread_local_data_defaults;
  if (defaults_map != nullptr) {
    defaults_map->erase(py_eager_context);
  }
  auto* per_thread_map = eager_context_thread_local_data_map;
  if (per_thread_map != nullptr) {
    per_thread_map->erase(py_eager_context);
  }
}
} // namespace tensorflow

View File

@ -57,16 +57,15 @@ class Tests(test.TestCase):
self.assertAllClose(
math_ops.matmul(a_2_by_2, b_2_by_2),
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", None, None, a_2_by_2,
b_2_by_2, "transpose_a", False,
"transpose_b", False))
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None,
a_2_by_2, b_2_by_2, "transpose_a",
False, "transpose_b", False))
self.assertAllClose(
math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True),
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", None, None, a_100_by_784,
b_100_by_784, "transpose_a", False,
"transpose_b", True))
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None,
a_100_by_784, b_100_by_784,
"transpose_a", False, "transpose_b",
True))
@test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created
@ -76,14 +75,12 @@ class Tests(test.TestCase):
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
m = resource_variable_ops.ResourceVariable(a_2_by_2)
x = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", None, None, m, m,
"transpose_a", False, "transpose_b",
False)
y = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", None, None, a_2_by_2,
a_2_by_2, "transpose_a", False,
x = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None, m,
m, "transpose_a", False,
"transpose_b", False)
y = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None,
a_2_by_2, a_2_by_2, "transpose_a",
False, "transpose_b", False)
self.assertAllEqual(x, y)
@ -96,10 +93,9 @@ class Tests(test.TestCase):
with backprop.GradientTape(persistent=True) as tape:
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
tape.watch(a_2_by_2)
z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", None, None, a_2_by_2,
a_2_by_2, "transpose_a", False,
"transpose_b", False)
z = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None,
a_2_by_2, a_2_by_2, "transpose_a",
False, "transpose_b", False)
dz_dy = tape.gradient(z, [a_2_by_2])[0]
self.assertAllEqual(dz_dy.numpy(),
constant_op.constant(4.0, shape=[2, 2]).numpy())
@ -114,10 +110,9 @@ class Tests(test.TestCase):
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
m = resource_variable_ops.ResourceVariable(a_2_by_2)
tape.watch(m)
z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", None, None, m, m,
"transpose_a", False, "transpose_b",
False)
z = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None, m,
m, "transpose_a", False,
"transpose_b", False)
dz_dy = tape.gradient(z, [m])[0]
self.assertAllEqual(dz_dy.numpy(),
constant_op.constant(4.0, shape=[2, 2]).numpy())
@ -134,8 +129,8 @@ class Tests(test.TestCase):
self.assertAllClose(
math_ops.add_n([a_2_by_2, b_2_by_2]),
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "AddN",
None, None, [a_2_by_2, b_2_by_2]))
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "AddN", None,
[a_2_by_2, b_2_by_2]))
# Tests homogeneous list op
@test_util.assert_no_new_tensors
@ -150,8 +145,7 @@ class Tests(test.TestCase):
with backprop.GradientTape(persistent=True) as tape:
tape.watch(a_2_by_2)
tape.watch(b_2_by_2)
z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"AddN", None, None,
z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "AddN", None,
[a_2_by_2, b_2_by_2])
z2 = math_ops.add_n([a_2_by_2, b_2_by_2])
dz1_dy = tape.gradient(z1, [a_2_by_2])[0]
@ -170,8 +164,7 @@ class Tests(test.TestCase):
self.assertAllClose(
array_ops.identity_n([a_2_by_2, b_2_by_2]),
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"IdentityN", None, None,
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "IdentityN", None,
[a_2_by_2, b_2_by_2]))
# Tests heterogeneous list op
@ -187,9 +180,8 @@ class Tests(test.TestCase):
with backprop.GradientTape(persistent=True) as tape:
tape.watch(a_2_by_2)
tape.watch(b_2_by_2)
z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"IdentityN", None, None,
[a_2_by_2, b_2_by_2])
z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx, "IdentityN",
None, [a_2_by_2, b_2_by_2])
z2 = array_ops.identity_n([a_2_by_2, b_2_by_2])
dz1_dy = tape.gradient(z1[0], [a_2_by_2])[0]
dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0]
@ -208,18 +200,17 @@ class Tests(test.TestCase):
# Not enough base params
with self.assertRaisesRegex(ValueError,
"at least 5 items in the input tuple"):
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Identity")
"at least 3 items in the input tuple"):
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "Identity")
# Not enough inputs
with self.assertRaisesRegex(ValueError, "Expected to be at least 6, was 5"):
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx_handle, "Identity",
None, [])
with self.assertRaisesRegex(ValueError, "Expected to be at least 4, was 3"):
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "Identity", None)
# Bad type
with self.assertRaisesRegex(TypeError, "expected a string for op_name"):
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, ctx_handle,
None, [], a_2_by_2)
pywrap_tfe.TFE_Py_FastPathExecute(ctx, ctx_handle, None,
a_2_by_2)
@test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created
@ -229,11 +220,9 @@ class Tests(test.TestCase):
ctx = context.context()
ctx.ensure_initialized()
ctx_handle = ctx._handle
with self.assertRaises(core._FallbackException):
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Split",
None, None, split_dim, value,
"num_split", -1)
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "Split", None,
split_dim, value, "num_split", -1)
@test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created
@ -273,9 +262,9 @@ class Tests(test.TestCase):
ctx = context.context()
ctx.ensure_initialized()
with self.assertRaises(core._FallbackException):
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "MatMul",
None, None, m, m, "transpose_a", False,
"transpose_b", False)
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "MatMul", None, m, m,
"transpose_a", False, "transpose_b",
False)
def testOpDefDefaultType(self):
im = np.random.randint(

View File

@ -928,8 +928,7 @@ bool GenEagerPythonOp::AddEagerFallbackCode(
void GenEagerPythonOp::AddEagerFastPathExecute() {
string fastpath_execute_params =
strings::StrCat("_ctx._context_handle, tld.device_name, \"",
op_def_.name(), "\", ", "name, tld.op_callbacks");
strings::StrCat("_ctx, \"", op_def_.name(), "\", ", "name");
string fallback_params;
for (int i = 0; i < api_def_.in_arg_size(); i++) {

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
#include "tensorflow/python/util/util.h"
namespace py = pybind11;
@ -286,6 +287,98 @@ static py::object TFE_ClearScalarCache() {
} // namespace tensorflow
namespace {
// Wrapper around the EagerContextThreadLocalData struct (defined in
// pywrap_tfe.h), so it can be accessed from Python.
//
// For PyObject* fields, the get_*() methods return a new reference; and the
// set_*() methods create a new reference (i.e., they do not steal a reference).
class EagerContextThreadLocalDataWrapper {
 public:
  // Registers default values for `py_eager_context` via
  // MakeEagerContextThreadLocalData.  `is_eager` is a Python callable that
  // produces the default is_eager value; `device_spec` is used directly.
  explicit EagerContextThreadLocalDataWrapper(py::handle py_eager_context,
                                              py::handle is_eager,
                                              py::handle device_spec)
      : py_eager_context_(py_eager_context.ptr()) {
    tensorflow::MakeEagerContextThreadLocalData(
        py_eager_context.ptr(), is_eager.ptr(), device_spec.ptr());
  }

  // Frees the global defaults plus the destroying thread's thread-local
  // state for the wrapped context.
  ~EagerContextThreadLocalDataWrapper() {
    tensorflow::DestroyEagerContextThreadLocalData(py_eager_context_);
  }

  bool get_is_eager() const { return GetData()->is_eager; }
  void set_is_eager(bool v) { GetData()->is_eager = v; }

  bool get_invoking_op_callbacks() const {
    return GetData()->invoking_op_callbacks;
  }
  void set_invoking_op_callbacks(bool v) {
    GetData()->invoking_op_callbacks = v;
  }

  py::handle get_device_name() const {
    return GetPyObject(&GetData()->device_name);
  }
  void set_device_name(py::handle v) {
    SetPyObject(v, &GetData()->device_name);
  }

  py::handle get_scope_name() const {
    return GetPyObject(&GetData()->scope_name);
  }
  void set_scope_name(py::handle v) { SetPyObject(v, &GetData()->scope_name); }

  py::handle get_device_spec() const {
    return GetPyObject(&GetData()->device_spec);
  }
  void set_device_spec(py::handle v) {
    SetPyObject(v, &GetData()->device_spec);
  }

  py::handle get_function_call_options() const {
    return GetPyObject(&GetData()->function_call_options);
  }
  void set_function_call_options(py::handle v) {
    SetPyObject(v, &GetData()->function_call_options);
  }

  py::handle get_executor() const { return GetPyObject(&GetData()->executor); }
  void set_executor(py::handle v) { SetPyObject(v, &GetData()->executor); }

  py::handle get_op_callbacks() const {
    return GetPyObject(&GetData()->op_callbacks);
  }
  void set_op_callbacks(py::handle v) {
    SetPyObject(v, &GetData()->op_callbacks);
  }

 private:
  // Looks up (creating on first use) this thread's data for the wrapped
  // context.  Propagates failure to pybind11 by throwing error_already_set,
  // which re-raises the Python exception set by
  // GetEagerContextThreadLocalData.
  tensorflow::EagerContextThreadLocalData* GetData() const {
    auto* result =
        tensorflow::GetEagerContextThreadLocalData(py_eager_context_);
    if (!result) {
      throw py::error_already_set();
    }
    return result;
  }

  // Returns a NEW reference to the stored object (caller owns it).
  // NOTE(review): assumes the field is non-null — Py_INCREF(nullptr) would
  // crash.  GetEagerContextThreadLocalData initializes every field, so this
  // should hold; confirm no setter can store a null pointer.
  py::handle GetPyObject(tensorflow::Safe_PyObjectPtr* obj) const {
    Py_INCREF(obj->get());
    return obj->get();
  }

  // Stores a new owned reference to `value` (does not steal the caller's
  // reference); the previously stored reference is released by reset().
  void SetPyObject(py::handle value, tensorflow::Safe_PyObjectPtr* ptr) {
    Py_INCREF(value.ptr());
    ptr->reset(value.ptr());
  }

  PyObject* py_eager_context_;  // not owned (borrowed reference).
};
} // namespace
// py::return_value_policy::reference is defined as specified by the
// pybind11 documents listed here.
// https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies
@ -1272,6 +1365,38 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
py::class_<EagerContextThreadLocalDataWrapper>(m,
"EagerContextThreadLocalData")
.def(py::init<py::handle, py::handle, py::handle>(),
py::arg("py_eager_context"), py::arg("is_eager"),
py::arg("device_spec"))
.def_property("is_eager",
&EagerContextThreadLocalDataWrapper::get_is_eager,
&EagerContextThreadLocalDataWrapper::set_is_eager)
.def_property(
"invoking_op_callbacks",
&EagerContextThreadLocalDataWrapper::get_invoking_op_callbacks,
&EagerContextThreadLocalDataWrapper::set_invoking_op_callbacks)
.def_property("device_name",
&EagerContextThreadLocalDataWrapper::get_device_name,
&EagerContextThreadLocalDataWrapper::set_device_name)
.def_property("scope_name",
&EagerContextThreadLocalDataWrapper::get_scope_name,
&EagerContextThreadLocalDataWrapper::set_scope_name)
.def_property("device_spec",
&EagerContextThreadLocalDataWrapper::get_device_spec,
&EagerContextThreadLocalDataWrapper::set_device_spec)
.def_property(
"function_call_options",
&EagerContextThreadLocalDataWrapper::get_function_call_options,
&EagerContextThreadLocalDataWrapper::set_function_call_options)
.def_property("executor",
&EagerContextThreadLocalDataWrapper::get_executor,
&EagerContextThreadLocalDataWrapper::set_executor)
.def_property("op_callbacks",
&EagerContextThreadLocalDataWrapper::get_op_callbacks,
&EagerContextThreadLocalDataWrapper::set_op_callbacks);
// C API Enum
py::enum_<TFE_ContextDevicePlacementPolicy>(