Utility class for calling Python context mangers from c++

PiperOrigin-RevId: 329006341
Change-Id: I0a4e827900108fb57aa05ac2f96143d7ea14e12e
This commit is contained in:
Edward Loper 2020-08-28 14:09:34 -07:00 committed by TensorFlower Gardener
parent 6dcc90955d
commit 3b87d2932a
5 changed files with 355 additions and 0 deletions

View File

@ -1645,6 +1645,40 @@ py_library(
],
)
cc_library(
name = "py_context_manager",
srcs = ["framework/py_context_manager.cc"],
hdrs = ["framework/py_context_manager.h"],
deps = [
":safe_pyobject_ptr",
"//tensorflow/core:lib", # for core/platform/logging.h
"//third_party/python_runtime:headers",
],
)
# Pybind extension used by py_context_manager_test.
tf_python_pybind_extension(
name = "_py_context_manager",
srcs = ["framework/py_context_manager_pybind.cc"],
module_name = "_py_context_manager",
deps = [
":py_context_manager",
"//third_party/python_runtime:headers",
"@pybind11",
],
)
tf_py_test(
name = "py_context_manager_test",
srcs = ["framework/py_context_manager_test.py"],
python_version = "PY3",
tags = ["no_pip"],
tfrt_enabled = True,
deps = [
":_py_context_manager",
],
)
cc_library(
name = "op_def_util_cc",
srcs = ["framework/op_def_util.cc"],

View File

@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/framework/py_context_manager.h"
#include <map>
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
bool PyContextManager::Enter(PyObject* py_context_manager) {
if (context_manager_) {
PyErr_SetString(
PyExc_ValueError,
"tensorflow::PyContextManager::Enter must be called at most once.");
}
if (!py_context_manager) return false;
context_manager_.reset(py_context_manager);
static char _enter[] = "__enter__";
var_.reset(PyObject_CallMethod(context_manager_.get(), _enter, nullptr));
return var_ != nullptr;
}
PyContextManager::~PyContextManager() {
if (var_) {
static char _exit[] = "__exit__";
static char _ooo[] = "OOO";
if (PyErr_Occurred()) {
PyObject *type, *value, *traceback;
PyErr_Fetch(&type, &value, &traceback);
value = value ? value : Py_None;
traceback = traceback ? traceback : Py_None;
Safe_PyObjectPtr result(PyObject_CallMethod(
context_manager_.get(), _exit, _ooo, type, value, traceback));
if (result) {
if (PyObject_IsTrue(result.get())) {
PyErr_SetString(
PyExc_ValueError,
"tensorflow::PyContextManager::Enter does not support "
"context managers that suppress exceptions.");
} else {
PyErr_Restore(type, value, traceback);
}
}
} else {
PyObject* result = PyObject_CallMethod(context_manager_.get(), _exit,
_ooo, Py_None, Py_None, Py_None);
if (result) {
Py_DECREF(result);
} else {
LOG(ERROR)
<< "A context manager wrapped by tensorflow::PyContextManager "
"raised a new exception from its __new__ method. This behavior "
"is not supported by PyContextManager, and the exception is "
"being suppressed.";
PyErr_Clear();
}
}
}
}
} // namespace tensorflow

View File

@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PY_CONTEXT_MANAGER_H_
#define TENSORFLOW_PYTHON_FRAMEWORK_PY_CONTEXT_MANAGER_H_
#include <Python.h>
#include <string>
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
namespace tensorflow {
// Class that wraps a Python context manager, and calls the `__enter__` and
// `__exit__` methods at appropriate times:
//
// * When `PyContextManager::Enter(cm)` is called, the context manager `cm`
// is stored, and `cm.__enter__` is called. The result can be retrieved
// with `PyContextManager::var()`.
// * When the `PyContextManager` is destroyed, then `cm.__exit__` is called
// (with information about any active exception).
// * `PyContextManager::Enter(cm)` may be called at most once. If
// `PyContextManager::Enter()` is never called, then the destructor is a
// no-op (i.e., `__exit__` is not called).
//
// PyContextManager places two restrictons on the wrapped context managers:
//
// 1. The context manager may not suppress exceptions -- i.e., `__exit__`
// may not return a True value. If it does, then a new exception will be
// set, indicating that this is unuspported.
// 2. The context manager may not raise an exception from `__exit__` if the
// an exception is not active when it is called. If it does, then an error
// message will be logged, indicating that this is unsupported, and the
// exception will be suppressed.
//
// These restrictions are both intended to ensure that the state of
// PyErr_Occured is unchanged by PyContextManager's destructor. This is
// important, because changing the state of PyErr_Occurred in the destructor
// would mean that we are returning a nullptr with no exception set, or
// returning a non-null value with an exception set (both of which are invalid).
class PyContextManager {
public:
// Calls `py_context_manager.__enter__()`, and stores the result in `var`.
// Return true if `__enter__` succeeds, or false if `__enter__` raises an
// exception. (Also returns false if `py_context_manager` is nullptr.)
//
// Steals a reference to `py_context_manager`. (This reference is deleted
// when the destructor is called.)
bool Enter(PyObject* py_context_manager);
// Calls `py_context_manager.__exit__()`.
~PyContextManager();
// Returns the variable returned by `context_manager.__enter__()`.
// (This is the `var` bound by `with context_manager as var`.)
// Returns a borrowed reference.
PyObject* var() { return var_.get(); }
protected:
Safe_PyObjectPtr context_manager_;
Safe_PyObjectPtr var_;
};
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_FRAMEWORK_PY_CONTEXT_MANAGER_H_

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/python/framework/py_context_manager.h"
namespace py = pybind11;
namespace {
// Test harness for PyContextManager. Creates a PyContextManager `cm` that
// wraps `context_manager`, calls `cm.Enter()`, and then calls `body_func`
// with `cm.var()`. Returns the result of the function.
py::handle TestPyContextManager(py::handle context_manager,
py::handle body_func) {
tensorflow::Safe_PyObjectPtr result;
{
tensorflow::PyContextManager cm;
Py_INCREF(context_manager.ptr()); // cm.Enter steals a reference.
if (!cm.Enter(context_manager.ptr())) {
throw py::error_already_set();
}
result.reset(
PyObject_CallFunctionObjArgs(body_func.ptr(), cm.var(), nullptr));
}
// cm gets destroyed here.
if (result) {
return result.release();
} else {
throw py::error_already_set();
}
}
} // namespace
PYBIND11_MODULE(_py_context_manager, m) {
m.def("test_py_context_manager", TestPyContextManager);
}

View File

@ -0,0 +1,118 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework._py_context_manager."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import _py_context_manager
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
class TestContextManager(object):
def __init__(self, behavior="basic"):
self.log = []
self.behavior = behavior
def __enter__(self):
self.log.append("__enter__()")
if self.behavior == "raise_from_enter":
raise ValueError("exception in __enter__")
return "var"
def __exit__(self, ex_type, ex_value, ex_tb):
self.log.append("__exit__(%s, %s, %s)" % (ex_type, ex_value, ex_tb))
if self.behavior == "raise_from_exit":
raise ValueError("exception in __exit__")
if self.behavior == "suppress_exception":
return True
# Expected log when the body doesn't raise an exception.
NO_EXCEPTION_LOG = """\
__enter__()
body('var')
__exit__(None, None, None)"""
# Expected log when the body does raise an exception. (Regular expression.)
EXCEPTION_LOG = """\
__enter__\\(\\)
body\\('var'\\)
__exit__\\(<class 'ValueError'>, Foo, <traceback object.*>\\)"""
class OpDefUtilTest(test_util.TensorFlowTestCase):
def testBasic(self):
cm = TestContextManager()
def body(var):
cm.log.append("body(%r)" % var)
_py_context_manager.test_py_context_manager(cm, body)
self.assertEqual("\n".join(cm.log), NO_EXCEPTION_LOG)
def testBodyRaisesException(self):
cm = TestContextManager()
def body(var):
cm.log.append("body(%r)" % var)
raise ValueError("Foo")
with self.assertRaisesRegexp(ValueError, "Foo"):
_py_context_manager.test_py_context_manager(cm, body)
self.assertRegex("\n".join(cm.log), EXCEPTION_LOG)
def testEnterRaisesException(self):
cm = TestContextManager("raise_from_enter")
def body(var):
cm.log.append("body(%r)" % var)
with self.assertRaisesRegexp(ValueError, "exception in __enter__"):
_py_context_manager.test_py_context_manager(cm, body)
self.assertEqual("\n".join(cm.log), "__enter__()")
# Test behavior in unsupported case where __exit__ raises an exception.
def testExitRaisesException(self):
cm = TestContextManager("raise_from_exit")
def body(var):
cm.log.append("body(%r)" % var)
# Note: this does *not* raise an exception (but does log a warning):
_py_context_manager.test_py_context_manager(cm, body)
self.assertEqual("\n".join(cm.log), NO_EXCEPTION_LOG)
# Test behavior in unsupported case where __exit__ suppresses exception.
def testExitSuppressesException(self):
cm = TestContextManager("suppress_exception")
def body(var):
cm.log.append("body(%r)" % var)
raise ValueError("Foo")
with self.assertRaisesRegexp(
ValueError, "tensorflow::PyContextManager::Enter does not support "
"context managers that suppress exception"):
_py_context_manager.test_py_context_manager(cm, body)
self.assertRegex("\n".join(cm.log), EXCEPTION_LOG)
if __name__ == "__main__":
googletest.main()