Converted py_func.i to pybind11

This is part of a larger effort to deprecate swig and eventually with
modularization break pywrap_tensorflow into smaller components.
Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md
for more information.

PiperOrigin-RevId: 271409413
This commit is contained in:
Sergei Lebedev 2019-09-26 12:44:11 -07:00 committed by TensorFlower Gardener
parent 5ac1558cdb
commit afa2418f90
9 changed files with 42 additions and 24 deletions

View File

@ -19,7 +19,7 @@ visibility = [
"//bazel_pip/tensorflow/lite/toco/python:__pkg__",
]
load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_v2", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test", "cc_header_only_library")
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
load("//tensorflow:tensorflow.bzl", "pybind_extension")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
@ -518,6 +518,24 @@ cc_library(
],
)
cc_header_only_library(
name = "py_func_headers_lib",
deps = [
":py_func_lib",
],
)
tf_python_pybind_extension(
name = "_pywrap_py_func",
srcs = ["lib/core/py_func_wrapper.cc"],
module_name = "_pywrap_py_func",
deps = [
":py_func_headers_lib",
"//third_party/python_runtime:headers",
"@pybind11",
],
)
cc_library(
name = "safe_ptr",
srcs = ["lib/core/safe_ptr.cc"],
@ -3639,6 +3657,7 @@ py_library(
srcs = ["ops/script_ops.py"],
srcs_version = "PY2AND3",
deps = [
":_pywrap_py_func",
":array_ops",
":framework_for_generated_wrappers",
":script_ops_gen",
@ -5008,7 +5027,6 @@ tf_py_wrap_cc(
"grappler/tf_optimizer.i",
"lib/core/bfloat16.i",
"lib/core/py_exception_registry.i",
"lib/core/py_func.i",
"lib/core/strings.i",
"lib/io/file_io.i",
"lib/io/py_record_reader.i",
@ -5117,6 +5135,7 @@ genrule(
name = "pybind_symbol_target_libs_file",
srcs = [
":cpp_python_util", # util
":py_func_lib", # py_func
"//tensorflow/core:util_port", # util_port
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer

View File

@ -17,6 +17,11 @@ limitations under the License.
#include <Python.h>
// clang-format: off
// Must be inlcluded first.
#include "tensorflow/python/lib/core/numpy.h"
// clang-format: on
#include <array>
#include "numpy/arrayobject.h"
@ -25,7 +30,9 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"

View File

@ -16,11 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_
#define TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_
// Must be included first
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include <Python.h>
namespace tensorflow {

View File

@ -1,4 +1,4 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,17 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/platform/base.i"
%{
#include "include/pybind11/pybind11.h"
#include "tensorflow/python/lib/core/py_func.h"
%}
%ignoreall
namespace py = pybind11;
%unignore tensorflow;
%unignore tensorflow::InitializePyTrampoline;
%include "tensorflow/python/lib/core/py_func.h"
%unignoreall
PYBIND11_MODULE(_pywrap_py_func, m) {
m.def("initialize_py_trampoline", [](py::object trampoline) {
return tensorflow::InitializePyTrampoline(trampoline.ptr());
});
}

View File

@ -28,7 +28,7 @@ import weakref
import numpy as np
import six
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import _pywrap_py_func
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@ -259,7 +259,7 @@ class FuncRegistry(object):
# Global registry for py functions.
_py_funcs = FuncRegistry()
pywrap_tensorflow.InitializePyTrampoline(_py_funcs)
_pywrap_py_func.initialize_py_trampoline(_py_funcs)
def _internal_py_func(func,

View File

@ -21,7 +21,6 @@ limitations under the License.
%include "tensorflow/python/pywrap_tfe.i"
%include "tensorflow/python/lib/core/py_func.i"
%include "tensorflow/python/lib/core/py_exception_registry.i"
%include "tensorflow/python/lib/io/py_record_reader.i"

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/py_func.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
%}

View File

@ -20,7 +20,6 @@ limitations under the License.
%{
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/stat_summarizer.h"
#include "tensorflow/python/lib/core/py_func.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"

View File

@ -51,3 +51,6 @@ tensorflow::EventsWriter::InitWithSuffix
tensorflow::EventsWriter::WriteSerializedEvent
tensorflow::EventsWriter::Flush
tensorflow::EventsWriter::Close
[py_func_lib] # py_func
tensorflow::InitializePyTrampoline