Work around incompatibility between inspect.isgeneratorfunction and pybind11.

PiperOrigin-RevId: 297665999
Change-Id: I940aad736cec85b9a32038107380a6e3ecc6c344
This commit is contained in:
Dan Moldovan 2020-02-27 12:54:16 -08:00 committed by TensorFlower Gardener
parent 260a840659
commit d675fbb6cf
6 changed files with 67 additions and 1 deletions

View File

@ -73,6 +73,7 @@ tf_py_test(
deps = [ deps = [
":impl", ":impl",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/impl/testing:pybind_for_testing",
"@gast_archive//:gast", "@gast_archive//:gast",
], ],
) )

View File

@ -400,7 +400,9 @@ def is_whitelisted(
logging.log(2, 'Whitelisted: %s: %s', o, rule) logging.log(2, 'Whitelisted: %s: %s', o, rule)
return True return True
if tf_inspect.isgeneratorfunction(o): # The check for __code__ below is because isgeneratorfunction crashes
# without one.
if hasattr(o, '__code__') and tf_inspect.isgeneratorfunction(o):
logging.warn( logging.warn(
'Entity %s appears to be a generator function. It will not be converted' 'Entity %s appears to be a generator function. It will not be converted'
' by AutoGraph.', o) ' by AutoGraph.', o)

View File

@ -32,6 +32,7 @@ from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.impl import api from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.impl import conversion from tensorflow.python.autograph.impl import conversion
from tensorflow.python.autograph.impl.testing import pybind_for_testing
from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import parser
from tensorflow.python.eager import function from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
@ -119,6 +120,13 @@ class ConversionTest(test.TestCase):
self.assertTrue(conversion.is_whitelisted(bound_method)) self.assertTrue(conversion.is_whitelisted(bound_method))
def test_is_whitelisted_pybind(self):
test_object = pybind_for_testing.TestClassDef()
with test.mock.patch.object(config, 'CONVERSION_RULES', ()):
# TODO(mdan): This should return True for functions and methods.
# Note: currently, native bindings are whitelisted by a separate check.
self.assertFalse(conversion.is_whitelisted(test_object.method))
def test_convert_entity_to_ast_unsupported_types(self): def test_convert_entity_to_ast_unsupported_types(self):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
program_ctx = self._simple_program_ctx() program_ctx = self._simple_program_ctx()

View File

@ -0,0 +1,15 @@
load("//tensorflow:tensorflow.bzl", "pybind_extension")
package(
default_visibility = ["//tensorflow:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
pybind_extension(
name = "pybind_for_testing",
srcs = ["pybind_for_testing.cc"],
module_name = "pybind_for_testing",
deps = [
"@pybind11",
],
)

View File

@ -0,0 +1,39 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================
#include "include/pybind11/pybind11.h"
#include "include/pybind11/pytypes.h"
#include "include/pybind11/stl.h"
namespace autograph {
namespace py = pybind11;
class TestClassDef {
public:
TestClassDef() = default;
py::object Method() const;
};
py::object TestClassDef::Method() const { return py::none(); }
PYBIND11_MODULE(pybind_for_testing, m) {
py::class_<TestClassDef>(m, "TestClassDef")
.def(py::init<>())
.def("method", &TestClassDef::Method);
}
} // namespace autograph

View File

@ -84,6 +84,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/lite/python:tflite_convert", "//tensorflow/lite/python:tflite_convert",
"//tensorflow/lite/toco/python:toco_from_protos", "//tensorflow/lite/toco/python:toco_from_protos",
"//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/autograph/core:test_lib",
"//tensorflow/python/autograph/impl/testing:pybind_for_testing",
"//tensorflow/python/autograph/pyct/testing", "//tensorflow/python/autograph/pyct/testing",
"//tensorflow/python/autograph/pyct/common_transformers:common_transformers", "//tensorflow/python/autograph/pyct/common_transformers:common_transformers",
"//tensorflow/python/compiler:compiler", "//tensorflow/python/compiler:compiler",