Utility class for calling Python context mangers from c++
PiperOrigin-RevId: 329006341 Change-Id: I0a4e827900108fb57aa05ac2f96143d7ea14e12e
This commit is contained in:
parent
6dcc90955d
commit
3b87d2932a
tensorflow/python
@ -1645,6 +1645,40 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "py_context_manager",
|
||||
srcs = ["framework/py_context_manager.cc"],
|
||||
hdrs = ["framework/py_context_manager.h"],
|
||||
deps = [
|
||||
":safe_pyobject_ptr",
|
||||
"//tensorflow/core:lib", # for core/platform/logging.h
|
||||
"//third_party/python_runtime:headers",
|
||||
],
|
||||
)
|
||||
|
||||
# Pybind extension used by py_context_manager_test.
|
||||
tf_python_pybind_extension(
|
||||
name = "_py_context_manager",
|
||||
srcs = ["framework/py_context_manager_pybind.cc"],
|
||||
module_name = "_py_context_manager",
|
||||
deps = [
|
||||
":py_context_manager",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "py_context_manager_test",
|
||||
srcs = ["framework/py_context_manager_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = ["no_pip"],
|
||||
tfrt_enabled = True,
|
||||
deps = [
|
||||
":_py_context_manager",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "op_def_util_cc",
|
||||
srcs = ["framework/op_def_util.cc"],
|
||||
|
74
tensorflow/python/framework/py_context_manager.cc
Normal file
74
tensorflow/python/framework/py_context_manager.cc
Normal file
@ -0,0 +1,74 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/python/framework/py_context_manager.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
bool PyContextManager::Enter(PyObject* py_context_manager) {
|
||||
if (context_manager_) {
|
||||
PyErr_SetString(
|
||||
PyExc_ValueError,
|
||||
"tensorflow::PyContextManager::Enter must be called at most once.");
|
||||
}
|
||||
if (!py_context_manager) return false;
|
||||
context_manager_.reset(py_context_manager);
|
||||
static char _enter[] = "__enter__";
|
||||
var_.reset(PyObject_CallMethod(context_manager_.get(), _enter, nullptr));
|
||||
return var_ != nullptr;
|
||||
}
|
||||
|
||||
PyContextManager::~PyContextManager() {
|
||||
if (var_) {
|
||||
static char _exit[] = "__exit__";
|
||||
static char _ooo[] = "OOO";
|
||||
if (PyErr_Occurred()) {
|
||||
PyObject *type, *value, *traceback;
|
||||
PyErr_Fetch(&type, &value, &traceback);
|
||||
value = value ? value : Py_None;
|
||||
traceback = traceback ? traceback : Py_None;
|
||||
Safe_PyObjectPtr result(PyObject_CallMethod(
|
||||
context_manager_.get(), _exit, _ooo, type, value, traceback));
|
||||
if (result) {
|
||||
if (PyObject_IsTrue(result.get())) {
|
||||
PyErr_SetString(
|
||||
PyExc_ValueError,
|
||||
"tensorflow::PyContextManager::Enter does not support "
|
||||
"context managers that suppress exceptions.");
|
||||
} else {
|
||||
PyErr_Restore(type, value, traceback);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
PyObject* result = PyObject_CallMethod(context_manager_.get(), _exit,
|
||||
_ooo, Py_None, Py_None, Py_None);
|
||||
if (result) {
|
||||
Py_DECREF(result);
|
||||
} else {
|
||||
LOG(ERROR)
|
||||
<< "A context manager wrapped by tensorflow::PyContextManager "
|
||||
"raised a new exception from its __new__ method. This behavior "
|
||||
"is not supported by PyContextManager, and the exception is "
|
||||
"being suppressed.";
|
||||
PyErr_Clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
78
tensorflow/python/framework/py_context_manager.h
Normal file
78
tensorflow/python/framework/py_context_manager.h
Normal file
@ -0,0 +1,78 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PY_CONTEXT_MANAGER_H_
|
||||
#define TENSORFLOW_PYTHON_FRAMEWORK_PY_CONTEXT_MANAGER_H_
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Class that wraps a Python context manager, and calls the `__enter__` and
|
||||
// `__exit__` methods at appropriate times:
|
||||
//
|
||||
// * When `PyContextManager::Enter(cm)` is called, the context manager `cm`
|
||||
// is stored, and `cm.__enter__` is called. The result can be retrieved
|
||||
// with `PyContextManager::var()`.
|
||||
// * When the `PyContextManager` is destroyed, then `cm.__exit__` is called
|
||||
// (with information about any active exception).
|
||||
// * `PyContextManager::Enter(cm)` may be called at most once. If
|
||||
// `PyContextManager::Enter()` is never called, then the destructor is a
|
||||
// no-op (i.e., `__exit__` is not called).
|
||||
//
|
||||
// PyContextManager places two restrictons on the wrapped context managers:
|
||||
//
|
||||
// 1. The context manager may not suppress exceptions -- i.e., `__exit__`
|
||||
// may not return a True value. If it does, then a new exception will be
|
||||
// set, indicating that this is unuspported.
|
||||
// 2. The context manager may not raise an exception from `__exit__` if the
|
||||
// an exception is not active when it is called. If it does, then an error
|
||||
// message will be logged, indicating that this is unsupported, and the
|
||||
// exception will be suppressed.
|
||||
//
|
||||
// These restrictions are both intended to ensure that the state of
|
||||
// PyErr_Occured is unchanged by PyContextManager's destructor. This is
|
||||
// important, because changing the state of PyErr_Occurred in the destructor
|
||||
// would mean that we are returning a nullptr with no exception set, or
|
||||
// returning a non-null value with an exception set (both of which are invalid).
|
||||
class PyContextManager {
|
||||
public:
|
||||
// Calls `py_context_manager.__enter__()`, and stores the result in `var`.
|
||||
// Return true if `__enter__` succeeds, or false if `__enter__` raises an
|
||||
// exception. (Also returns false if `py_context_manager` is nullptr.)
|
||||
//
|
||||
// Steals a reference to `py_context_manager`. (This reference is deleted
|
||||
// when the destructor is called.)
|
||||
bool Enter(PyObject* py_context_manager);
|
||||
|
||||
// Calls `py_context_manager.__exit__()`.
|
||||
~PyContextManager();
|
||||
|
||||
// Returns the variable returned by `context_manager.__enter__()`.
|
||||
// (This is the `var` bound by `with context_manager as var`.)
|
||||
// Returns a borrowed reference.
|
||||
PyObject* var() { return var_.get(); }
|
||||
|
||||
protected:
|
||||
Safe_PyObjectPtr context_manager_;
|
||||
Safe_PyObjectPtr var_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_FRAMEWORK_PY_CONTEXT_MANAGER_H_
|
51
tensorflow/python/framework/py_context_manager_pybind.cc
Normal file
51
tensorflow/python/framework/py_context_manager_pybind.cc
Normal file
@ -0,0 +1,51 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/python/framework/py_context_manager.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace {
|
||||
|
||||
// Test harness for PyContextManager. Creates a PyContextManager `cm` that
|
||||
// wraps `context_manager`, calls `cm.Enter()`, and then calls `body_func`
|
||||
// with `cm.var()`. Returns the result of the function.
|
||||
py::handle TestPyContextManager(py::handle context_manager,
|
||||
py::handle body_func) {
|
||||
tensorflow::Safe_PyObjectPtr result;
|
||||
{
|
||||
tensorflow::PyContextManager cm;
|
||||
Py_INCREF(context_manager.ptr()); // cm.Enter steals a reference.
|
||||
if (!cm.Enter(context_manager.ptr())) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
result.reset(
|
||||
PyObject_CallFunctionObjArgs(body_func.ptr(), cm.var(), nullptr));
|
||||
}
|
||||
// cm gets destroyed here.
|
||||
|
||||
if (result) {
|
||||
return result.release();
|
||||
} else {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
PYBIND11_MODULE(_py_context_manager, m) {
|
||||
m.def("test_py_context_manager", TestPyContextManager);
|
||||
}
|
118
tensorflow/python/framework/py_context_manager_test.py
Normal file
118
tensorflow/python/framework/py_context_manager_test.py
Normal file
@ -0,0 +1,118 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.python.framework._py_context_manager."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import _py_context_manager
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class TestContextManager(object):
|
||||
|
||||
def __init__(self, behavior="basic"):
|
||||
self.log = []
|
||||
self.behavior = behavior
|
||||
|
||||
def __enter__(self):
|
||||
self.log.append("__enter__()")
|
||||
if self.behavior == "raise_from_enter":
|
||||
raise ValueError("exception in __enter__")
|
||||
return "var"
|
||||
|
||||
def __exit__(self, ex_type, ex_value, ex_tb):
|
||||
self.log.append("__exit__(%s, %s, %s)" % (ex_type, ex_value, ex_tb))
|
||||
if self.behavior == "raise_from_exit":
|
||||
raise ValueError("exception in __exit__")
|
||||
if self.behavior == "suppress_exception":
|
||||
return True
|
||||
|
||||
|
||||
# Expected log when the body doesn't raise an exception.
|
||||
NO_EXCEPTION_LOG = """\
|
||||
__enter__()
|
||||
body('var')
|
||||
__exit__(None, None, None)"""
|
||||
|
||||
# Expected log when the body does raise an exception. (Regular expression.)
|
||||
EXCEPTION_LOG = """\
|
||||
__enter__\\(\\)
|
||||
body\\('var'\\)
|
||||
__exit__\\(<class 'ValueError'>, Foo, <traceback object.*>\\)"""
|
||||
|
||||
|
||||
class OpDefUtilTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testBasic(self):
|
||||
cm = TestContextManager()
|
||||
|
||||
def body(var):
|
||||
cm.log.append("body(%r)" % var)
|
||||
|
||||
_py_context_manager.test_py_context_manager(cm, body)
|
||||
self.assertEqual("\n".join(cm.log), NO_EXCEPTION_LOG)
|
||||
|
||||
def testBodyRaisesException(self):
|
||||
cm = TestContextManager()
|
||||
|
||||
def body(var):
|
||||
cm.log.append("body(%r)" % var)
|
||||
raise ValueError("Foo")
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Foo"):
|
||||
_py_context_manager.test_py_context_manager(cm, body)
|
||||
self.assertRegex("\n".join(cm.log), EXCEPTION_LOG)
|
||||
|
||||
def testEnterRaisesException(self):
|
||||
cm = TestContextManager("raise_from_enter")
|
||||
|
||||
def body(var):
|
||||
cm.log.append("body(%r)" % var)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "exception in __enter__"):
|
||||
_py_context_manager.test_py_context_manager(cm, body)
|
||||
self.assertEqual("\n".join(cm.log), "__enter__()")
|
||||
|
||||
# Test behavior in unsupported case where __exit__ raises an exception.
|
||||
def testExitRaisesException(self):
|
||||
cm = TestContextManager("raise_from_exit")
|
||||
|
||||
def body(var):
|
||||
cm.log.append("body(%r)" % var)
|
||||
|
||||
# Note: this does *not* raise an exception (but does log a warning):
|
||||
_py_context_manager.test_py_context_manager(cm, body)
|
||||
self.assertEqual("\n".join(cm.log), NO_EXCEPTION_LOG)
|
||||
|
||||
# Test behavior in unsupported case where __exit__ suppresses exception.
|
||||
def testExitSuppressesException(self):
|
||||
cm = TestContextManager("suppress_exception")
|
||||
|
||||
def body(var):
|
||||
cm.log.append("body(%r)" % var)
|
||||
raise ValueError("Foo")
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "tensorflow::PyContextManager::Enter does not support "
|
||||
"context managers that suppress exception"):
|
||||
_py_context_manager.test_py_context_manager(cm, body)
|
||||
self.assertRegex("\n".join(cm.log), EXCEPTION_LOG)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
Loading…
Reference in New Issue
Block a user