Work around incompatibility between inspect.isgeneratorfunction and pybind11.

PiperOrigin-RevId: 297665999
Change-Id: I940aad736cec85b9a32038107380a6e3ecc6c344
This commit is contained in:
Dan Moldovan 2020-02-27 12:54:16 -08:00 committed by TensorFlower Gardener
parent 260a840659
commit d675fbb6cf
6 changed files with 67 additions and 1 deletions

View File

@ -73,6 +73,7 @@ tf_py_test(
deps = [
":impl",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/impl/testing:pybind_for_testing",
"@gast_archive//:gast",
],
)

View File

@ -400,7 +400,9 @@ def is_whitelisted(
logging.log(2, 'Whitelisted: %s: %s', o, rule)
return True
if tf_inspect.isgeneratorfunction(o):
# The check for __code__ below is because isgeneratorfunction crashes
# without one.
if hasattr(o, '__code__') and tf_inspect.isgeneratorfunction(o):
logging.warn(
'Entity %s appears to be a generator function. It will not be converted'
' by AutoGraph.', o)

View File

@ -32,6 +32,7 @@ from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.impl import conversion
from tensorflow.python.autograph.impl.testing import pybind_for_testing
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
@ -119,6 +120,13 @@ class ConversionTest(test.TestCase):
self.assertTrue(conversion.is_whitelisted(bound_method))
def test_is_whitelisted_pybind(self):
test_object = pybind_for_testing.TestClassDef()
with test.mock.patch.object(config, 'CONVERSION_RULES', ()):
# TODO(mdan): This should return True for functions and methods.
# Note: currently, native bindings are whitelisted by a separate check.
self.assertFalse(conversion.is_whitelisted(test_object.method))
def test_convert_entity_to_ast_unsupported_types(self):
with self.assertRaises(NotImplementedError):
program_ctx = self._simple_program_ctx()

View File

@ -0,0 +1,15 @@
load("//tensorflow:tensorflow.bzl", "pybind_extension")
package(
default_visibility = ["//tensorflow:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
pybind_extension(
name = "pybind_for_testing",
srcs = ["pybind_for_testing.cc"],
module_name = "pybind_for_testing",
deps = [
"@pybind11",
],
)

View File

@ -0,0 +1,39 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================
#include "include/pybind11/pybind11.h"
#include "include/pybind11/pytypes.h"
#include "include/pybind11/stl.h"
namespace autograph {
namespace py = pybind11;
class TestClassDef {
public:
TestClassDef() = default;
py::object Method() const;
};
py::object TestClassDef::Method() const { return py::none(); }
PYBIND11_MODULE(pybind_for_testing, m) {
py::class_<TestClassDef>(m, "TestClassDef")
.def(py::init<>())
.def("method", &TestClassDef::Method);
}
} // namespace autograph

View File

@ -84,6 +84,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/lite/python:tflite_convert",
"//tensorflow/lite/toco/python:toco_from_protos",
"//tensorflow/python/autograph/core:test_lib",
"//tensorflow/python/autograph/impl/testing:pybind_for_testing",
"//tensorflow/python/autograph/pyct/testing",
"//tensorflow/python/autograph/pyct/common_transformers:common_transformers",
"//tensorflow/python/compiler:compiler",