Export the TF Session classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.

PiperOrigin-RevId: 292259851
Change-Id: If5abe93f9cf25018d185e220d4dfbc216b5f3b32
This commit is contained in:
Amit Patankar 2020-01-29 18:25:16 -08:00 committed by TensorFlower Gardener
parent 4ca8bf54d7
commit a02fe6c24a
46 changed files with 1615 additions and 1237 deletions

View File

@ -57,6 +57,7 @@ filegroup(
name = "pywrap_required_hdrs",
srcs = [
"c_api_internal.h",
"python_api.h",
"tf_status_helper.h",
"tf_status_internal.h",
"tf_tensor_internal.h",
@ -98,6 +99,17 @@ tf_cuda_library(
],
)
filegroup(
name = "pywrap_tf_session_hdrs",
srcs = [
"python_api.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "tf_attrtype",
hdrs = ["tf_attrtype.h"],

View File

@ -18,6 +18,9 @@ limitations under the License.
namespace tensorflow {
// Adjust value in third_party/tensorflow/python/client/tf_session_wrapper.cc
// in the get_tensor_handle_key function if adjusting the value for
// kTensorHandleResourceTypeName.
const char* SessionState::kTensorHandleResourceTypeName = "TensorHandle";
Status SessionState::GetTensor(const string& handle, Tensor* tensor) {

View File

@ -173,6 +173,7 @@ py_library(
":platform",
":proto_ops",
":pywrap_tensorflow",
":pywrap_tf_session",
":pywrap_tfe",
":rnn_ops_gen",
":saver_test_utils",
@ -559,6 +560,58 @@ cc_library(
alwayslink = 1,
)
py_library(
name = "pywrap_tf_session",
srcs = ["client/pywrap_tf_session.py"],
visibility = ["//visibility:public"],
deps = [
":_pywrap_tf_session",
":pywrap_tensorflow",
],
)
tf_python_pybind_extension(
name = "_pywrap_tf_session",
srcs = ["client/tf_session_wrapper.cc"],
hdrs = [
"client/tf_session_helper.h",
"lib/core/numpy.h",
"lib/core/safe_ptr.h",
"//tensorflow/c:headers",
"//tensorflow/c:pywrap_required_hdrs",
"//tensorflow/c/eager:headers",
"//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
"//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
"//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
"//tensorflow/core/framework:pywrap_required_hdrs",
],
module_name = "_pywrap_tf_session",
deps = [
":pybind11_lib",
":pybind11_status",
"//third_party/py/numpy:headers",
"@pybind11",
"//third_party/python_runtime:headers",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:framework",
"//tensorflow/core:core_cpu_headers_lib",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/types:optional",
] + if_static(
extra_deps = [
"//tensorflow/core:eager_service_proto_cc",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:worker_proto_cc",
],
otherwise = [
"//tensorflow/core:eager_service_proto_cc_headers_only",
"//tensorflow/core:master_proto_cc_headers_only",
"//tensorflow/core:worker_proto_cc_headers_only",
],
),
)
tf_python_pybind_extension(
name = "_pywrap_tfprof",
srcs = ["util/tfprof_wrapper.cc"],
@ -1173,6 +1226,7 @@ py_library(
":lib",
":platform",
":pywrap_tensorflow",
":pywrap_tf_session",
":pywrap_tfe",
":pywrap_mlir",
":random_seed",
@ -1194,7 +1248,7 @@ py_library(
srcs = ["framework/c_api_util.py"],
srcs_version = "PY2AND3",
deps = [
":pywrap_tensorflow",
":pywrap_tf_session",
"//tensorflow/core:protos_all_py",
],
)
@ -1271,6 +1325,7 @@ py_library(
":_pywrap_py_exception_registry",
":c_api_util",
":error_interpolation",
":pywrap_tf_session",
":util",
],
)
@ -1296,6 +1351,7 @@ py_library(
":framework_ops",
":graph_to_function_def",
":op_def_registry",
":pywrap_tf_session",
":util",
":variable_scope",
"//tensorflow/core:protos_all_py",
@ -1387,7 +1443,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
":pywrap_tensorflow",
":pywrap_tf_session",
":util",
"//tensorflow/core:protos_all_py",
],
@ -1629,6 +1685,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":control_flow_ops",
":pywrap_tf_session",
":tensor_util",
],
)
@ -1801,7 +1858,7 @@ py_library(
srcs = ["framework/versions.py"],
srcs_version = "PY2AND3",
deps = [
":pywrap_tensorflow",
":pywrap_tf_session",
],
)
@ -1837,7 +1894,7 @@ py_library(
":gpu_util",
":platform",
":platform_test",
":pywrap_tensorflow",
":pywrap_tf_session",
":random_seed",
":resource_variable_ops",
":session",
@ -3244,6 +3301,7 @@ py_library(
":functional_ops_gen",
":gradients_util",
":list_ops",
":pywrap_tf_session",
":tensor_array_ops",
":tensor_shape",
":tensor_util",
@ -3334,6 +3392,7 @@ py_library(
deps = [
":gradients_impl",
":gradients_util",
":pywrap_tf_session",
":unconnected_gradients",
"//tensorflow/python/eager:forwardprop",
"//tensorflow/python/eager:function",
@ -3776,6 +3835,7 @@ py_library(
":framework_for_generated_wrappers",
":math_ops",
":math_ops_gen",
":pywrap_tf_session",
":tensor_util",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
@ -3838,6 +3898,7 @@ py_library(
":array_ops_gen",
":dtypes",
":framework_ops",
":pywrap_tf_session",
":resource_variable_ops_gen",
":tensor_shape",
":util",
@ -5070,7 +5131,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":framework",
":pywrap_tensorflow",
":pywrap_tf_session",
":util",
"//tensorflow/core:protos_all_py",
],
@ -5588,8 +5649,6 @@ tf_py_wrap_cc(
name = "pywrap_tensorflow_internal",
srcs = ["tensorflow.i"],
swig_includes = [
"client/tf_session.i",
"client/tf_sessionrun_wrapper.i",
"grappler/cluster.i",
"grappler/cost_analyzer.i",
"grappler/item.i",
@ -5683,6 +5742,10 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
"//tensorflow/core/profiler/lib:profiler_session", # tfe
"//tensorflow/c:tf_status_helper", # tfe
"//tensorflow/compiler/mlir/python:mlir", # mlir
"//tensorflow/core:op_gen_lib", # tf_session
"//tensorflow/core:core_cpu_base_no_ops", # tf_session
"//tensorflow/c:python_api", # tf_session
"//tensorflow/python:tf_session_helper", # tf_session
]
# Filter the DEF file to reduce the number of symbols to 64K or less.

View File

@ -0,0 +1,70 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python module for Session ops, vars, and functions exported by pybind11."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=invalid-import-order,g-bad-import-order, wildcard-import, unused-import
from tensorflow.python import pywrap_tensorflow
from tensorflow.python._pywrap_tf_session import *
from tensorflow.python._pywrap_tf_session import _TF_SetTarget
from tensorflow.python._pywrap_tf_session import _TF_SetConfig
from tensorflow.python._pywrap_tf_session import _TF_NewSessionOptions
# Convert versions to strings for Python2 and keep api_compatibility_test green.
# We can remove this hack once we remove Python2 presubmits. pybind11 can only
# return unicode for Python2 even with py::str.
# https://pybind11.readthedocs.io/en/stable/advanced/cast/strings.html#returning-c-strings-to-python
# pylint: disable=undefined-variable
__version__ = str(get_version())
__git_version__ = str(get_git_version())
__compiler_version__ = str(get_compiler_version())
__cxx11_abi_flag__ = get_cxx11_abi_flag()
__monolithic_build__ = get_monolithic_build()
# User getters to hold attributes rather than pybind11's m.attr due to
# b/145559202.
GRAPH_DEF_VERSION = get_graph_def_version()
GRAPH_DEF_VERSION_MIN_CONSUMER = get_graph_def_version_min_consumer()
GRAPH_DEF_VERSION_MIN_PRODUCER = get_graph_def_version_min_producer()
TENSOR_HANDLE_KEY = get_tensor_handle_key()
# pylint: enable=undefined-variable
# Disable pylint invalid name warnings for legacy functions.
# pylint: disable=invalid-name
def TF_NewSessionOptions(target=None, config=None):
# NOTE: target and config are validated in the session constructor.
opts = _TF_NewSessionOptions()
if target is not None:
_TF_SetTarget(opts, target)
if config is not None:
config_str = config.SerializeToString()
_TF_SetConfig(opts, config_str)
return opts
# Disable pylind undefined-variable as the variable is exported in the shared
# object via pybind11.
# pylint: disable=undefined-variable
def TF_Reset(target, containers=None, config=None):
opts = TF_NewSessionOptions(target=target, config=config)
try:
TF_Reset_wrapper(opts, containers)
finally:
TF_DeleteSessionOptions(opts)

View File

@ -28,7 +28,7 @@ import wrapt
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.client import pywrap_tf_session as tf_session
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import device

View File

@ -21,7 +21,7 @@ from __future__ import print_function
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.client import pywrap_tf_session as tf_session
from tensorflow.python.client import session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops

View File

@ -1,877 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/c/python_api.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/python/client/tf_session_helper.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
// We were getting lucky on imports with safe_ptr.h being placed prior to
// tf_session which imported safe_ptr. We also need pywrap_tfe.h to cast
// one of the inputs to a graph function from a Python string to const char*.
// Helper function to convert a Python list of Tensors to a C++ vector of
// TF_Outputs.
//
// Returns true if successful. Otherwise, returns false and sets error_msg.
bool PyTensorListToVector(PyObject* py_tensor_list,
std::vector<TF_Output>* vec,
string* error_msg) {
if (!PyList_Check(py_tensor_list)) {
*error_msg = "expected Python list.";
return false;
}
size_t size = PyList_Size(py_tensor_list);
for (int i = 0; i < size; ++i) {
PyObject* item = PyList_GetItem(py_tensor_list, i);
TF_Output* input_ptr;
if (!SWIG_IsOK(SWIG_ConvertPtr(item, reinterpret_cast<void**>(&input_ptr),
SWIGTYPE_p_TF_Output, 0))) {
*error_msg = "expected Python list of wrapped TF_Output objects. "
"Found python list of something else.";
return false;
}
vec->push_back(*input_ptr);
}
return true;
}
// Helper function to convert a TF_Output to a wrapped TF_Output Python object.
PyObject* CreateWrappedTFOutput(TF_Output tf_output) {
// We used heap-allocated pointers in the Python runtime (this is what SWIG
// generates by default for functions returning TF_Output).
TF_Output* tf_output_ptr = new TF_Output(tf_output);
// Use SWIG_POINTER_OWN so the TF_Output* is deleted by Python.
return SWIG_NewPointerObj(tf_output_ptr, SWIGTYPE_p_TF_Output,
SWIG_POINTER_OWN);
}
// Helper function to convert a TF_Operation to a wrapped TF_Operation Python
// object.
PyObject* CreateWrappedTFOperation(TF_Operation* tf_operation) {
// No flags since operation is owned by TF_Graph.
return SWIG_NewPointerObj(tf_operation, SWIGTYPE_p_TF_Operation, 0);
}
// Helper function to convert a Python list of ints to a C++ vector of int64s
void PyInt64ListToVector(PyObject* py_int_seq, std::vector<int64_t>* vec) {
int size = PySequence_Fast_GET_SIZE(py_int_seq);
for (int i = 0; i < size; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(py_int_seq, i);
vec->push_back(PyLong_AsLongLong(item));
}
}
%}
%include "tensorflow/c/tf_datatype.h"
%include "tensorflow/c/tf_status.h"
%include "tensorflow/python/client/tf_sessionrun_wrapper.i"
// Required to use PyArray_* functions.
%init %{
tensorflow::ImportNumpy();
%}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* op_name {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// TensorFlow version and GraphDef versions
%constant const char* __version__ = TF_VERSION_STRING;
%constant int GRAPH_DEF_VERSION = TF_GRAPH_DEF_VERSION;
%constant int GRAPH_DEF_VERSION_MIN_CONSUMER = TF_GRAPH_DEF_VERSION_MIN_CONSUMER;
%constant int GRAPH_DEF_VERSION_MIN_PRODUCER = TF_GRAPH_DEF_VERSION_MIN_PRODUCER;
// Git version information
%constant const char* __git_version__ = tf_git_version();
// Compiler
%constant const char* __compiler_version__ = tf_compiler_version();
// _GLIBCXX_USE_CXX11_ABI flag value
%constant const int __cxx11_abi_flag__ = tf_cxx11_abi_flag();
// Flag indicating whether the build is monolithic
%constant const int __monolithic_build__ = tf_monolithic_build();
// Release the Python GIL for the duration of most methods.
%exception {
Py_BEGIN_ALLOW_THREADS;
$action
Py_END_ALLOW_THREADS;
}
// The target input to TF_SetTarget() is passed as a null-terminated
// const char*.
%typemap(in) (const char* target) {
$1 = PyBytes_AsString($input);
if (!$1) {
// Python has raised an error.
SWIG_fail;
}
}
// Constants used by TensorHandle (get_session_handle).
%constant const char* TENSOR_HANDLE_KEY = tensorflow::SessionState::kTensorHandleResourceTypeName;
// Convert TF_OperationName output to unicode python string
%typemap(out) const char* TF_OperationName {
$result = PyUnicode_FromString($1);
}
// Convert TF_OperationOpType output to unicode python string
%typemap(out) const char* TF_OperationOpType {
$result = PyUnicode_FromString($1);
}
// Convert TF_DeviceListMemoryBytes and TF_Dim int64_t output to Python integers
%typemap(out) int64_t {
$result = PyLong_FromLongLong($1);
}
// Convert TF_DeviceListIncarnation uint64_t output to Python integer
%typemap(out) uint64_t {
$result = PyLong_FromUnsignedLongLong($1);
}
// Convert TF_OperationGetAttrType TF_DataType* out-argument to Python integer.
%typemap(in, numinputs=0) TF_DataType *value (TF_DataType temp) {
$1 = &temp;
}
%typemap(argout) TF_DataType *value {
$result = PyInt_FromLong(*$1);
}
// Convert TF_OperationGetAttrBool unsigned char* out-argument to Python bool.
%typemap(in, numinputs=0) unsigned char *value (unsigned char temp) {
$1 = &temp;
}
%typemap(argout) unsigned char *value {
$result = PyBool_FromLong(*$1);
}
// Convert TF_OperationGetAttrInt int64_t* out-argument to Python bool.
%typemap(in, numinputs=0) int64_t *value (int64_t temp) {
$1 = &temp;
}
%typemap(argout) int64_t *value {
$result = PyLong_FromLongLong(*$1);
}
// We use TF_OperationGetControlInputs_wrapper instead of
// TF_OperationGetControlInputs
%ignore TF_OperationGetControlInputs;
%unignore TF_OperationGetControlInputs_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_OperationGetControlInputs_wrapper;
// Migrate one function from pywrap_tfe.i
%include "tensorflow/c/c_api_experimental.h"
%unignore TF_ImportGraphDefOptionsSetValidateColocationConstraints;
%noexception TF_ImportGraphDefOptionsSetValidateColocationConstraints;
// Build a Python list of TF_Operation* and return it.
%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlInputs_wrapper {
$result = PyList_New($1.size());
if (!$result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
}
for (size_t i = 0; i < $1.size(); ++i) {
PyList_SET_ITEM($result, i, CreateWrappedTFOperation($1[i]));
}
}
// We use TF_OperationGetControlOutputs_wrapper instead of
// TF_OperationGetControlOutputs
%ignore TF_OperationGetControlOutputs;
%unignore TF_OperationGetControlOutputs_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_OperationGetControlOutputs_wrapper;
// Build a Python list of TF_Operation* and return it.
%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlOutputs_wrapper {
$result = PyList_New($1.size());
if (!$result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
}
for (size_t i = 0; i < $1.size(); ++i) {
PyList_SET_ITEM($result, i, CreateWrappedTFOperation($1[i]));
}
}
%ignore TF_OperationOutputConsumers;
%unignore TF_OperationOutputConsumers_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_OperationGetOutputConsumers_wrapper;
// Build a Python list of unicode strings and return it. (Operation names are
// always represented as unicode.)
%typemap(out) std::vector<const char*>
tensorflow::TF_OperationOutputConsumers_wrapper {
$result = PyList_New($1.size());
if (!$result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
}
for (size_t i = 0; i < $1.size(); ++i) {
PyList_SET_ITEM($result, i, PyUnicode_FromString($1[i]));
}
}
%unignore GetOperationInputs;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception GetOperationInputs;
// Build a Python list of TF_Outputs and return it.
// TODO(skyewm): is there some way to generalize this pattern? Maybe a macro?
%typemap(out) std::vector<TF_Output> tensorflow::GetOperationInputs {
$result = PyList_New($1.size());
if (!$result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
}
// Unwrap the generated SwigValueWrapper<std::vector<TF_Output>>
const std::vector<TF_Output>& tf_outputs = $1;
for (size_t i = 0; i < tf_outputs.size(); ++i) {
PyList_SET_ITEM($result, i, CreateWrappedTFOutput(tf_outputs[i]));
}
}
%ignore TF_ImportGraphDefResultsMissingUnusedInputMappings;
%unignore TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper;
%typemap(out) std::vector<string>
TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
$result = PyList_New($1.size());
if (!$result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
}
for (size_t i = 0; i < $1.size(); ++i) {
const string& input_str = $1[i];
PyList_SET_ITEM($result, i, PyBytes_FromStringAndSize(input_str.data(),
input_str.size()));
}
}
////////////////////////////////////////////////////////////////////////////////
// BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper()
////////////////////////////////////////////////////////////////////////////////
// Converts a python list of strings to NameVector.
// Has multiple users including feeds/fetches names and function output names
%typemap(in) const tensorflow::NameVector& (
tensorflow::NameVector temp,
tensorflow::Safe_PyObjectPtr temp_string_list(
tensorflow::make_safe(static_cast<PyObject*>(nullptr)))) {
if (!PyList_Check($input)) {
SWIG_exception_fail(
SWIG_TypeError,
tensorflow::strings::Printf(
"Expected a python list for conversion "
"to tensorflow::NameVector but got %s",
Py_TYPE($input)->tp_name).c_str());
}
Py_ssize_t len = PyList_Size($input);
temp_string_list = tensorflow::make_safe(PyList_New(len));
if (!temp_string_list) {
SWIG_exception_fail(
SWIG_MemoryError,
tensorflow::strings::Printf("Failed to create a list of size %zd",
len).c_str());
}
for (Py_ssize_t i = 0; i < len; ++i) {
PyObject* elem = PyList_GetItem($input, i);
if (!elem) {
SWIG_fail;
}
// Keep a reference to the string in case the incoming list is modified.
PyList_SET_ITEM(temp_string_list.get(), i, elem);
Py_INCREF(elem);
char* string_elem = PyBytes_AsString(elem);
if (!string_elem) {
SWIG_exception_fail(
SWIG_TypeError,
tensorflow::strings::Printf(
"Element %zd was of type %s instead of a string",
i, Py_TYPE(elem)->tp_name).c_str());
}
// TODO(mrry): Avoid copying the fetch name in, if this impacts performance.
temp.push_back(string_elem);
}
$1 = &temp;
}
// Define temporaries for the argout outputs.
%typemap(in, numinputs=0) tensorflow::PyObjectVector* out_values (
tensorflow::PyObjectVector temp) {
$1 = &temp;
}
// TODO(iga): move this and the corresponding typemap(argout) to
// tf_sessionrun_wrapper.i once we get rid of this code for DeprecatedSession.
%typemap(in, numinputs=0) char** out_handle (
char* temp) {
$1 = &temp;
}
// Build a Python list of outputs and return it.
%typemap(argout) tensorflow::PyObjectVector* out_values {
std::vector<tensorflow::Safe_PyObjectPtr> out_values_safe;
for (size_t i = 0; i < $1->size(); ++i) {
out_values_safe.emplace_back(tensorflow::make_safe($1->at(i)));
}
$result = PyList_New($1->size());
if (!$result) {
SWIG_exception_fail(
SWIG_MemoryError,
tensorflow::strings::Printf("Failed to create a list of size %zd",
$1->size()).c_str());
}
for (size_t i = 0; i < $1->size(); ++i) {
PyList_SET_ITEM($result, i, $1->at(i));
out_values_safe[i].release();
}
}
// Return the handle as a python string object.
%typemap(argout) char** out_handle {
%#if PY_MAJOR_VERSION < 3
$result = PyString_FromStringAndSize(
%#else
$result = PyUnicode_FromStringAndSize(
%#endif
*$1, *$1 == nullptr ? 0 : strlen(*$1));
delete[] *$1;
}
////////////////////////////////////////////////////////////////////////////////
// END TYPEMAPS FOR tensorflow::TF_Run_wrapper()
////////////////////////////////////////////////////////////////////////////////
// Typemap for TF_Status* inputs that automatically unwraps a ScopedTFStatus.
// This can also handle a wrapped TF_Status* input.
%typemap(in) (TF_Status*) {
PyObject* wrapped_tf_status;
if (strcmp(Py_TYPE($input)->tp_name, "ScopedTFStatus") == 0) {
DCHECK(PyObject_HasAttrString($input, "status"))
<< "ScopedTFStatus.status not found! Do you need to modify "
"tf_session.i?";
wrapped_tf_status = PyObject_GetAttrString($input, "status");
} else {
// Assume wrapped TF_Status*
wrapped_tf_status = $input;
}
DCHECK_EQ(strcmp(Py_TYPE(wrapped_tf_status)->tp_name, "SwigPyObject"), 0)
<< Py_TYPE(wrapped_tf_status)->tp_name;
// The following is the default SWIG code generated for TF_Status*
void* tf_status = nullptr;
int r = SWIG_ConvertPtr(wrapped_tf_status, &tf_status,
$descriptor(TF_Status*), 0 | 0);
if (!SWIG_IsOK(r)) {
SWIG_exception_fail(
SWIG_ArgError(r),
"in method '_TF_DeleteStatus', argument 1 of type 'TF_Status *'");
}
$1 = reinterpret_cast<TF_Status*>(tf_status);
}
// Typemap for functions that return a TF_Buffer struct. This typemap creates a
// Python string from the TF_Buffer and returns it. The TF_Buffer.data string
// is not expected to be NULL-terminated, and TF_Buffer.length does not count
// the terminator.
%typemap(out) TF_Buffer (TF_GetOpList,TF_GetBuffer) {
$result = PyBytes_FromStringAndSize(
reinterpret_cast<const char*>($1.data), $1.length);
}
// Converts input Python list of wrapped TF_Outputs into a single array
%typemap(in) (const TF_Output* inputs, int num_inputs)
(std::vector<TF_Output> inputs) {
string error_msg;
if (!PyTensorListToVector($input, &inputs, &error_msg)) {
SWIG_exception_fail(SWIG_TypeError, ("$symname: " + error_msg).c_str());
}
$1 = inputs.data();
$2 = inputs.size();
}
// Typemaps for TF_ImportGraphDefResultsReturnOutputs
%typemap(in, numinputs=0) (int* num_outputs, TF_Output** outputs)
(int num_outputs, TF_Output* outputs) {
$1 = &num_outputs;
$2 = &outputs;
}
%typemap(argout) (int* num_outputs, TF_Output** outputs) {
$result = PyList_New(*$1);
if (!$result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
}
int num_outputs = *$1;
TF_Output* outputs = *$2;
for (int i = 0; i < num_outputs; ++i) {
PyList_SET_ITEM($result, i, CreateWrappedTFOutput(outputs[i]));
}
}
// Typemaps for TF_ImportGraphDefResultsReturnOperations
%typemap(in, numinputs=0) (int* num_opers, TF_Operation*** opers)
(int num_opers, TF_Operation** opers) {
$1 = &num_opers;
$2 = &opers;
}
%typemap(argout) (int* num_opers, TF_Operation*** opers) {
$result = PyList_New(*$1);
if (!$result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
}
int num_opers = *$1;
TF_Operation** opers = *$2;
for (int i = 0; i < num_opers; ++i) {
PyList_SET_ITEM($result, i, CreateWrappedTFOperation(opers[i]));
}
}
// Typemaps for TF_GraphNextOperation().
%typemap(in) size_t* pos (size_t pos) {
pos = PyLong_AsUnsignedLong($input);
$1 = &pos;
}
// Returns a (TF_Operation*, int pos) tuple.
%typemap(argout) size_t* pos {
PyObject* new_result = PyTuple_New(2);
if (!new_result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create tuple");
}
// Steals $result reference
PyTuple_SET_ITEM(new_result, 0, $result);
PyTuple_SET_ITEM(new_result, 1, PyLong_FromSize_t(*$1));
$result = new_result;
}
%typemap(in, numinputs=0) int64_t* out_handle (int64_t out_handle) {
$1 = &out_handle;
}
%typemap(argout) int64_t* out_handle {
$result = PyLong_FromLongLong(*$1);
}
%typemap(in) int64_t handle {
if (!PyLong_Check($input)) {
SWIG_exception_fail(
SWIG_TypeError,
tensorflow::strings::Printf(
"Expected a python long for conversion to callable handle but got %s",
Py_TYPE($input)->tp_name).c_str());
}
$1 = PyLong_AsLongLong($input);
}
// Override default py3 behavior of attempting to encode into Unicode.
%typemap(out) std::string tensorflow::GetHandleShapeAndType {
$result = PyBytes_FromStringAndSize($1.data(), $1.size());
}
// TODO(skyewm): SWIG emits a warning for the const char* in TF_WhileParams,
// skip for now
%ignore TF_WhileParams;
%ignore TF_NewWhile;
%ignore TF_FinishWhile;
%ignore TF_AbortWhile;
// These are defined below, avoid duplicate definitions
%ignore TF_Run;
%ignore TF_PRun;
%ignore TF_PRunSetup;
// We use TF_SessionRun_wrapper instead of TF_SessionRun
%ignore TF_SessionRun;
%unignore TF_SessionRun_wrapper;
// The %exception block above releases the Python GIL for the length of each
// wrapped method. We disable this behavior for TF_SessionRun_wrapper because it
// uses Python method(s) that expect the GIL to be held (at least
// PyArray_Return, maybe others).
%noexception TF_SessionRun_wrapper;
// We use TF_SessionPRunSetup_wrapper instead of TF_SessionPRunSetup
%ignore TF_SessionPRunSetup;
%unignore TF_SessionPRunSetup_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_SessionPRunSetup_wrapper;
// We use TF_SessionPRun_wrapper instead of TF_SessionPRun
%ignore TF_SessionPRun;
%unignore TF_SessionPRun_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_SessionPRun_wrapper;
%unignore TF_DeprecatedSessionMakeCallable;
%unignore TF_SessionMakeCallable;
%unignore TF_DeprecatedSessionRunCallable;
%unignore TF_SessionRunCallable;
%unignore TF_DeprecatedSessionReleaseCallable;
%unignore TF_SessionReleaseCallable;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_DeprecatedSessionRunCallable;
%noexception TF_SessionRunCallable;
%rename("_TF_SetTarget") TF_SetTarget;
%rename("_TF_SetConfig") TF_SetConfig;
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
%include "tensorflow/c/c_api.h"
%include "tensorflow/c/tf_attrtype.h"
%include "tensorflow/c/python_api.h"
%ignoreall
%insert("python") %{
def TF_NewSessionOptions(target=None, config=None):
# NOTE: target and config are validated in the session constructor.
opts = _TF_NewSessionOptions()
if target is not None:
_TF_SetTarget(opts, target)
if config is not None:
from tensorflow.python.framework import errors
config_str = config.SerializeToString()
_TF_SetConfig(opts, config_str)
return opts
%}
// Include the wrapper for TF_Run from tf_session_helper.h.
// The %exception block above releases the Python GIL for the length
// of each wrapped method. We disable this behavior for TF_Run
// because it uses the Python allocator.
%noexception tensorflow::TF_Run_wrapper;
%rename(TF_Run) tensorflow::TF_Run_wrapper;
%unignore tensorflow;
%unignore TF_Run;
%unignore EqualGraphDefWrapper;
%unignore EqualAttrValueWrapper;
// Include the wrapper for TF_PRunSetup from tf_session_helper.h.
// The %exception block above releases the Python GIL for the length
// of each wrapped method. We disable this behavior for TF_PRunSetup
// because it uses the Python allocator.
%noexception tensorflow::TF_PRunSetup_wrapper;
%rename(TF_PRunSetup) tensorflow::TF_PRunSetup_wrapper;
%unignore tensorflow;
%unignore TF_PRunSetup;
// Include the wrapper for TF_PRun from tf_session_helper.h.
// The %exception block above releases the Python GIL for the length
// of each wrapped method. We disable this behavior for TF_PRun
// because it uses the Python allocator.
%noexception tensorflow::TF_PRun_wrapper;
%rename(TF_PRun) tensorflow::TF_PRun_wrapper;
%unignore tensorflow;
%unignore TF_PRun;
%unignore tensorflow::TF_Reset_wrapper;
%insert("python") %{
def TF_Reset(target, containers=None, config=None):
from tensorflow.python.framework import errors
opts = TF_NewSessionOptions(target=target, config=config)
try:
TF_Reset_wrapper(opts, containers)
finally:
TF_DeleteSessionOptions(opts)
%}
// We use TF_GraphToFunction_wrapper instead of TF_GraphToFunction
%ignore TF_GraphToFunction;
// TF_GraphToFunction_wrapper does not use any Python methods and
// does not require GIL to be held.
%unignore TF_GraphToFunction_wrapper;
// $input is a Python list of wrapped TF_Operations
%typemap(in) (const std::vector<TF_Operation*>* opers)
(std::vector<TF_Operation*> opers) {
if ($input != Py_None) {
if (!PyList_Check($input)) {
SWIG_exception_fail(SWIG_TypeError, "$symname: expected list");
}
size_t size = PyList_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* item = PyList_GetItem($input, i);
TF_Operation* oper_ptr;
SWIG_ConvertPtr(item, reinterpret_cast<void**>(&oper_ptr),
$descriptor(TF_Operation*), 0);
opers.push_back(oper_ptr);
}
$1 = &opers;
} else {
$1 = nullptr;
}
}
// $input is a Python list of wrapped TF_Operations
%typemap(in) (const std::vector<TF_Operation*>* control_outputs)
(std::vector<TF_Operation*> control_outputs) {
if ($input != Py_None) {
if (!PyList_Check($input)) {
SWIG_exception_fail(SWIG_TypeError, "$symname: expected list");
}
size_t size = PyList_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* item = PyList_GetItem($input, i);
TF_Operation* oper_ptr;
SWIG_ConvertPtr(item, reinterpret_cast<void**>(&oper_ptr),
$descriptor(TF_Operation*), 0);
control_outputs.push_back(oper_ptr);
}
$1 = &control_outputs;
} else {
$1 = nullptr;
}
}
// Typemaps for TF_GraphGetTensorShapeHelper.
// Convert from C++ integer vector to Python list of ints.
%typemap(out) tensorflow::gtl::InlinedVector<int64_t, 6>
tensorflow::TF_GraphGetTensorShapeHelper {
$result = PyList_New($1.size());
if (!$result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
}
for (size_t i = 0; i < $1.size(); ++i) {
PyList_SET_ITEM($result, i, PyLong_FromLongLong($1[i]));
}
}
%typemap(in, numinputs=0) bool* unknown_shape (bool temp) {
$1=&temp;
}
// Returns a (list(int), bool) tuple.
%typemap(argout) bool* unknown_shape {
PyObject* new_result = PyTuple_New(2);
if (!new_result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create tuple");
}
// Steals $result reference
PyTuple_SET_ITEM(new_result, 0, $result);
PyTuple_SET_ITEM(new_result, 1, PyBool_FromLong(*$1));
$result = new_result;
}
%unignore tensorflow;
%unignore TF_GraphGetTensorShapeHelper;
%ignore TF_GraphGetTensorShape;
// We use TF_GraphSetTensorShape_wrapper instead of
// TF_GraphSetTensorShape
%ignore TF_GraphSetTensorShape;
%unignore tensorflow;
%unignore TF_GraphSetTensorShape_wrapper;
// $input is a Python list of ints to a vector<int> for TF_GraphSetTensorShape_wrapper
%typemap(in) (const std::vector<int64_t>& dims)
(std::vector<int64_t> dims_local){
if ($input != Py_None) {
PyObject* py_int_seq = PySequence_Fast($input, tensorflow::strings::Printf(
"$symname: expected list but got %s ",
Py_TYPE($input)->tp_name).c_str());
if (py_int_seq == nullptr) {
SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
"$symname: PySequence_Fast returned NULL.").c_str());
}
PyInt64ListToVector(py_int_seq, &dims_local);
Py_DECREF(py_int_seq);
$1 = &dims_local;
} else {
$1 = nullptr;
}
}
// We use TF_GraphGetTensorShape_wrapper instead of
// TF_GraphGetTensorShape
%ignore TF_GraphGetTensorShape;
%unignore tensorflow;
%unignore TF_GraphGetTensorShape_wrapper;
// Build a Python list of ints and return it.
%typemap(out) std::vector<int64_t> tensorflow::TF_GraphGetTensorShape_wrapper {
$result = PyList_New($1.size());
if (!$result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
}
for (size_t i = 0; i < $1.size(); ++i) {
PyList_SET_ITEM($result, i, PyLong_FromLongLong($1[i]));
}
}
// We use TF_GraphSetOutputHandleShapesAndTypes_wrapper instead of
// TF_GraphSetOutputHandleShapesAndTypes
%ignore TF_GraphSetOutputHandleShapesAndTypes;
%unignore tensorflow;
%unignore TF_GraphSetOutputHandleShapesAndTypes_wrapper;
// The space between the double angle brackets below looks extraneous, but
// our version of SWIG cannot parse ">>".
%typemap(in) (const std::vector<std::vector<int64_t> >& shapes)
(std::vector<std::vector<int64_t> > shapes_local){
PyObject* seq = PySequence_Fast($input, tensorflow::strings::Printf(
"$symname: expected list but got %s ",
Py_TYPE($input)->tp_name).c_str());
if (seq == nullptr) {
SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
"$symname: PySequence_Fast returned NULL.").c_str());
}
int size = PySequence_Fast_GET_SIZE(seq);
if (size == 0) {
SWIG_exception_fail(SWIG_ValueError, tensorflow::strings::Printf(
"$symname: shapes list must be non-empty").c_str());
}
for (int i = 0; i < size; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
std::vector<int64_t> dims;
if (item != Py_None) {
PyObject* py_int_seq = PySequence_Fast(item, tensorflow::strings::Printf(
"$symname: expected list but got %s ",
Py_TYPE($input)->tp_name).c_str());
if (py_int_seq == nullptr) {
SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
"$symname: PySequence_Fast returned NULL.").c_str());
}
PyInt64ListToVector(py_int_seq, &dims);
Py_DECREF(py_int_seq);
}
shapes_local.push_back(dims);
}
Py_DECREF(seq);
$1 = &shapes_local;
}
%typemap(in) (const std::vector<int>& ranks)
(std::vector<int> ranks_local){
PyObject* seq = PySequence_Fast($input, tensorflow::strings::Printf(
"$symname: expected list but got %s ",
Py_TYPE($input)->tp_name).c_str());
if (seq == nullptr) {
SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
"$symname: PySequence_Fast returned NULL.").c_str());
}
int size = PySequence_Fast_GET_SIZE(seq);
if (size == 0) {
SWIG_exception_fail(SWIG_ValueError, tensorflow::strings::Printf(
"$symname: shapes list must be non-empty").c_str());
}
for (int i = 0; i < size; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
ranks_local.push_back((int) PyInt_AsLong(item));
}
Py_DECREF(seq);
$1 = &ranks_local;
}
%typemap(in) (const std::vector<TF_DataType>& types)
(std::vector<TF_DataType> types_local){
PyObject* seq = PySequence_Fast($input, tensorflow::strings::Printf(
"$symname: expected list but got %s ",
Py_TYPE($input)->tp_name).c_str());
if (seq == nullptr) {
SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
"$symname: PySequence_Fast returned NULL.").c_str());
}
int size = PySequence_Fast_GET_SIZE(seq);
if (size == 0) {
SWIG_exception_fail(SWIG_ValueError, tensorflow::strings::Printf(
"$symname: shapes list must be non-empty").c_str());
}
for (int i = 0; i < size; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
types_local.push_back((TF_DataType) PyInt_AsLong(item));
}
Py_DECREF(seq);
$1 = &types_local;
}
%unignore TF_CreatePlaceholders;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_CreatePlaceholders;
// Build a Python list of TF_Output and return it.
%typemap(out) std::vector<TF_Output> tensorflow::TF_CreatePlaceholders {
$result = PyList_New($1.size());
if (!$result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
}
// Unwrap the generated SwigValueWrapper<std::vector<TF_Output>>
const std::vector<TF_Output>& tf_outputs = $1;
for (size_t i = 0; i < tf_outputs.size(); ++i) {
PyList_SET_ITEM($result, i, CreateWrappedTFOutput(tf_outputs[i]));
}
}
%unignore TF_NewSessionRef;
%unignore SetRequireShapeInferenceFns;
%unignore TF_TryEvaluateConstant_wrapper;
%noexception TF_TryEvaluateConstant_wrapper;
%unignore ExtendSession;
%unignore HandleShapeAndType;
%include "tensorflow/python/client/tf_session_helper.h"
%unignoreall

File diff suppressed because it is too large Load Diff

View File

@ -1,102 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// SWIG typemaps for TF_SessionRun_wrapper()
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/python/client/tf_session_helper.h"
%}
// Required to use PyArray_* functions.
%init %{
tensorflow::ImportNumpy();
%}
// $input is a Python dict mapping wrapped TF_Outputs to ndarrays.
%typemap(in) (const std::vector<TF_Output>& inputs,
const std::vector<PyObject*>& input_ndarrays)
(std::vector<TF_Output> inputs, std::vector<PyObject*> input_ndarrays) {
if (!PyDict_Check($input)) {
SWIG_exception_fail(SWIG_TypeError, "$symname: expected dict");
}
PyObject* key;
PyObject* value;
Py_ssize_t pos = 0;
while (PyDict_Next($input, &pos, &key, &value)) {
TF_Output* input_ptr;
SWIG_ConvertPtr(key, reinterpret_cast<void**>(&input_ptr),
SWIGTYPE_p_TF_Output, 0);
inputs.push_back(*input_ptr);
if (!PyArray_Check(value)) {
SWIG_exception_fail(
SWIG_TypeError,
"$symname: expected all values in input dict to be ndarray");
}
input_ndarrays.push_back(value);
}
$1 = &inputs;
$2 = &input_ndarrays;
}
// $input is a Python list of wrapped TF_Operations
%typemap(in) (const std::vector<TF_Operation*>& targets)
(std::vector<TF_Operation*> targets) {
if (!PyList_Check($input)) {
SWIG_exception_fail(SWIG_TypeError, "$symname: expected list");
}
size_t size = PyList_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* item = PyList_GetItem($input, i);
TF_Operation* oper_ptr;
SWIG_ConvertPtr(item, reinterpret_cast<void**>(&oper_ptr),
SWIGTYPE_p_TF_Operation, 0);
targets.push_back(oper_ptr);
}
$1 = &targets;
}
// $input is a Python list of wrapped TF_Outputs
%typemap(in) (const std::vector<TF_Output>& outputs)
(std::vector<TF_Output> outputs) {
string error_msg;
if (!PyTensorListToVector($input, &outputs, &error_msg)) {
SWIG_exception_fail(SWIG_TypeError, ("$symname: " + error_msg).c_str());
}
$1 = &outputs;
}
// Apply the typemap above to inputs as well
%typemap(in) (const std::vector<TF_Output>& inputs) =
(const std::vector<TF_Output>& outputs);
// Create temporary py_outputs_vec variable to store return value
%typemap(in, numinputs=0) (std::vector<PyObject*>* py_outputs)
(std::vector<PyObject*> py_outputs_vec) {
$1 = &py_outputs_vec;
}
// Convert py_outputs to returned Python list
%typemap(argout) (std::vector<PyObject*>* py_outputs) {
$result = PyList_New($1->size());
if (!$result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
}
for (int i = 0; i < $1->size(); ++i) {
PyList_SET_ITEM($result, i, (*$1)[i]);
}
}

View File

@ -22,7 +22,7 @@ from absl.testing import parameterized
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.kernel_tests import test_base
@ -220,7 +220,7 @@ class RemoteReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
def setUp(self):
super(RemoteReplicateTest, self).setUp()
# Start the local server.
local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()
local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
context.set_server_def(
server_def=_get_server_def(
JOB_NAME,

View File

@ -265,6 +265,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:platform",
"//tensorflow/python:pywrap_tf_session",
"//third_party/py/numpy",
"@six_archive//:six",
],
@ -1151,6 +1152,7 @@ py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"//tensorflow/python:pywrap_tf_session",
"//third_party/py/numpy",
],
)

View File

@ -26,7 +26,7 @@ import traceback
import numpy as np
import six
from tensorflow.python import pywrap_tensorflow_internal
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.platform import gfile
HELP_INDENT = " "
@ -142,7 +142,7 @@ def get_tensorflow_version_lines(include_dependency_versions=False):
Returns:
A formatted, multi-line `RichTextLines` object.
"""
lines = ["TensorFlow version: %s" % pywrap_tensorflow_internal.__version__]
lines = ["TensorFlow version: %s" % pywrap_tf_session.__version__]
lines.append("")
if include_dependency_versions:
lines.append("Dependency version(s):")

View File

@ -23,7 +23,7 @@ import tempfile
import numpy as np
from tensorflow.python import pywrap_tensorflow_internal
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.framework import test_util
from tensorflow.python.platform import gfile
@ -1160,15 +1160,13 @@ class GetTensorFlowVersionLinesTest(test_util.TensorFlowTestCase):
def testGetVersionWithoutDependencies(self):
out = debugger_cli_common.get_tensorflow_version_lines()
self.assertEqual(2, len(out.lines))
self.assertEqual(
"TensorFlow version: %s" % pywrap_tensorflow_internal.__version__,
out.lines[0])
self.assertEqual("TensorFlow version: %s" % pywrap_tf_session.__version__,
out.lines[0])
def testGetVersionWithDependencies(self):
out = debugger_cli_common.get_tensorflow_version_lines(True)
self.assertIn(
"TensorFlow version: %s" % pywrap_tensorflow_internal.__version__,
out.lines)
self.assertIn("TensorFlow version: %s" % pywrap_tf_session.__version__,
out.lines)
self.assertIn(" numpy: %s" % np.__version__, out.lines)

View File

@ -143,13 +143,14 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":eager_util",
":executor",
":monitoring",
"//tensorflow/python:c_api_util",
"//tensorflow/python:device",
"//tensorflow/python:device_spec",
"//tensorflow/python:errors",
"//tensorflow/python:platform",
"//tensorflow/python:pywrap_tf_session",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:tf2",
"//tensorflow/python:util",
@ -177,7 +178,8 @@ py_library(
"//third_party/py/tf_agents:__subpackages__",
],
deps = [
":eager_util",
"//tensorflow/python:c_api_util",
"//tensorflow/python:pywrap_tf_session",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:util",
],
@ -200,7 +202,8 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":context",
":eager_util",
"//tensorflow/python:c_api_util",
"//tensorflow/python:pywrap_tf_session",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:util",
],
@ -223,7 +226,8 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":eager_util",
"//tensorflow/python:c_api_util",
"//tensorflow/python:pywrap_tf_session",
"//tensorflow/python:pywrap_tfe",
],
)
@ -488,6 +492,7 @@ py_library(
"//tensorflow/python:func_graph",
"//tensorflow/python:gradients_impl",
"//tensorflow/python:graph_to_function_def",
"//tensorflow/python:pywrap_tf_session",
"//tensorflow/python:util",
"//third_party/py/numpy",
"@six_archive//:six",
@ -554,17 +559,6 @@ py_library(
],
)
py_library(
name = "eager_util",
srcs = ["eager_util.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:util",
],
)
cuda_py_test(
name = "benchmarks_test",
srcs = ["benchmarks_test.py"],

View File

@ -32,9 +32,10 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python import tf2
from tensorflow.python.eager import eager_util as c_api_util
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import executor
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.util import compat
from tensorflow.python.util import is_in_graph_mode
@ -789,7 +790,7 @@ class Context(object):
self.ensure_initialized()
with c_api_util.tf_buffer() as buffer_:
pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_)
address_space = pywrap_tfe.TF_GetBuffer(buffer_).decode("utf-8")
address_space = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8")
return address_space
# TODO(fishx): remove this property.
@ -1537,7 +1538,7 @@ class Context(object):
return None
with c_api_util.tf_buffer() as buffer_:
pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_)
proto_data = pywrap_tfe.TF_GetBuffer(buffer_)
proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
run_metadata = config_pb2.RunMetadata()
run_metadata.ParseFromString(compat.as_bytes(proto_data))
return run_metadata

View File

@ -1,61 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for using the TensorFlow Eager using the C API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tfe as c_api
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib
# We temporarily need a duplicate tf_buffer function in eager_util. The
# c_api_util is still relying on SWIG and is thus incompatible until
# we migrate over. We can delete this once we migrate tf_session.i
@tf_contextlib.contextmanager
def tf_buffer(data=None):
"""Context manager that creates and deletes TF_Buffer.
Example usage:
with tf_buffer() as buf:
# get serialized graph def into buf
...
proto_data = c_api.TF_GetBuffer(buf)
graph_def.ParseFromString(compat.as_bytes(proto_data))
# buf has been deleted
with tf_buffer(some_string) as buf:
c_api.TF_SomeFunction(buf)
# buf has been deleted
Args:
data: An optional `bytes`, `str`, or `unicode` object. If not None, the
yielded buffer will contain this data.
Yields:
Created TF_Buffer
"""
if data:
buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
else:
buf = c_api.TF_NewBuffer()
try:
yield buf
finally:
c_api.TF_DeleteBuffer(buf)

View File

@ -33,8 +33,8 @@ from six.moves import map
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import _pywrap_utils
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
@ -482,7 +482,7 @@ class _EagerDefinedFunction(object):
output_names = []
else:
output_names = []
fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
fn = pywrap_tf_session.TF_GraphToFunction_wrapper(
graph._c_graph, # pylint: disable=protected-access
compat.as_str(name),
False,
@ -499,14 +499,14 @@ class _EagerDefinedFunction(object):
serialized = attr_value.SerializeToString()
# TODO(iga): this creates and deletes a new TF_Status for every attr.
# It might be worth creating a convenient way to re-use status.
pywrap_tensorflow.TF_FunctionSetAttrValueProto(
fn, compat.as_str(name), serialized)
pywrap_tf_session.TF_FunctionSetAttrValueProto(fn, compat.as_str(name),
serialized)
# TODO(apassos) avoid creating a FunctionDef (specially to grab the
# signature, but also in general it's nice not to depend on it.
with c_api_util.tf_buffer() as buffer_:
pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
pywrap_tf_session.TF_FunctionToFunctionDef(fn, buffer_)
proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
function_def = function_pb2.FunctionDef()
function_def.ParseFromString(compat.as_bytes(proto_data))
self._name = compat.as_bytes(function_def.signature.name)

View File

@ -22,7 +22,8 @@ import collections
from tensorflow.core.framework import summary_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import eager_util as c_api_util
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.framework import c_api_util
from tensorflow.python.util import compat
_MetricMethod = collections.namedtuple('MetricMethod', 'create delete get_cell')
@ -258,7 +259,7 @@ class StringGaugeCell(object):
"""Retrieves the current value."""
with c_api_util.tf_buffer() as buffer_:
pywrap_tfe.TFE_MonitoringStringGaugeCellValue(self._cell, buffer_)
value = pywrap_tfe.TF_GetBuffer(buffer_).decode('utf-8')
value = pywrap_tf_session.TF_GetBuffer(buffer_).decode('utf-8')
return value
@ -361,7 +362,7 @@ class SamplerCell(object):
"""
with c_api_util.tf_buffer() as buffer_:
pywrap_tfe.TFE_MonitoringSamplerCellValue(self._cell, buffer_)
proto_data = pywrap_tfe.TF_GetBuffer(buffer_)
proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
histogram_proto = summary_pb2.HistogramProto()
histogram_proto.ParseFromString(compat.as_bytes(proto_data))
return histogram_proto

View File

@ -40,8 +40,9 @@ import threading
from tensorflow.python import _pywrap_events_writer
from tensorflow.python import pywrap_tfe
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import context
from tensorflow.python.eager import eager_util as c_api_util
from tensorflow.python.framework import c_api_util
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
@ -101,7 +102,7 @@ def stop():
context.context().executor.wait()
with c_api_util.tf_buffer() as buffer_:
pywrap_tfe.TFE_ProfilerSerializeToString(_profiler, buffer_)
result = pywrap_tfe.TF_GetBuffer(buffer_)
result = pywrap_tf_session.TF_GetBuffer(buffer_)
pywrap_tfe.TFE_DeleteProfiler(_profiler)
_profiler = None
_run_num += 1

View File

@ -19,7 +19,8 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import eager_util as c_api_util
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import errors
@ -74,4 +75,4 @@ def monitor(service_addr,
pywrap_tfe.TFE_ProfilerClientMonitor(service_addr, duration_ms,
monitoring_level, display_timestamp,
buffer_)
return pywrap_tfe.TF_GetBuffer(buffer_)
return pywrap_tf_session.TF_GetBuffer(buffer_)

View File

@ -21,7 +21,7 @@ from __future__ import print_function
from tensorflow.core.framework import api_def_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib

View File

@ -23,7 +23,7 @@ import warnings
from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.python import _pywrap_py_exception_registry
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import error_interpolation
from tensorflow.python.util import compat

View File

@ -23,7 +23,7 @@ import pickle
import warnings
from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import _pywrap_file_io
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
@ -112,7 +112,7 @@ class ErrorsTest(test.TestCase):
def testStatusDoesNotLeak(self):
try:
pywrap_tensorflow.DeleteFile(compat.as_bytes("/DOES_NOT_EXIST/"))
_pywrap_file_io.DeleteFile(compat.as_bytes("/DOES_NOT_EXIST/"))
except:
pass
gc.collect()

View File

@ -27,7 +27,7 @@ import hashlib
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import context
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import dtypes

View File

@ -20,8 +20,8 @@ from __future__ import print_function
import contextlib
from tensorflow.core.framework import graph_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python import tf2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors

View File

@ -19,7 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import kernel_def_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.util import compat

View File

@ -26,7 +26,7 @@ import platform
import sys
from tensorflow.python import _pywrap_python_op_gen
from tensorflow.python import pywrap_tensorflow as py_tf
from tensorflow.python.client import pywrap_tf_session as py_tf
from tensorflow.python.lib.io import file_io
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

View File

@ -32,7 +32,7 @@ from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import context
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import graph_io
@ -446,9 +446,9 @@ def _is_default_attr_value(op_def, attr_name, attr_value):
if attr_def.name == attr_name:
if not attr_def.HasField("default_value"):
return False
# pywrap_tensorflow.EqualAttrValueWrapper returns an empty string
# c_api.EqualAttrValueWrapper returns an empty string
# if both arguments represent an equivalent AttrValue instance.
return not pywrap_tensorflow.EqualAttrValueWrapper(
return not c_api.EqualAttrValueWrapper(
attr_value.SerializeToString(),
attr_def.default_value.SerializeToString())
return False

View File

@ -38,11 +38,12 @@ from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import config_pb2
# pywrap_tensorflow must be imported first to avoid profobuf issues.
# (b/143110113)
# pylint: disable=invalid-import-order,g-bad-import-order
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python import pywrap_tfe as c_api_new
# pylint: enable=invalid-import-order,g-bad-import-order
# pylint: disable=invalid-import-order,g-bad-import-order,unused-import
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
# pylint: enable=invalid-import-order,g-bad-import-order,unused-import
from tensorflow.python import tf2
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import monitoring
@ -254,7 +255,7 @@ def register_dense_tensor_like_type(tensor_type):
def uid():
"""A unique (within this program execution) integer."""
return c_api_new.TFE_Py_UID()
return pywrap_tfe.TFE_Py_UID()
def numpy_text(tensor, is_repr=False):
@ -502,13 +503,13 @@ class Tensor(_TensorLike):
def _c_api_shape(self):
"""Returns the TensorShape of this tensor according to the C API."""
c_graph = self._op._graph._c_graph # pylint: disable=protected-access
shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
shape_vec, unknown_shape = pywrap_tf_session.TF_GraphGetTensorShapeHelper(
c_graph, self._as_tf_output())
if unknown_shape:
return tensor_shape.unknown_shape()
else:
shape_vector = [None if d == -1 else d for d in shape_vector]
return tensor_shape.TensorShape(shape_vector)
shape_vec = [None if d == -1 else d for d in shape_vec]
return tensor_shape.TensorShape(shape_vec)
@property
def _shape(self):
@ -649,7 +650,7 @@ class Tensor(_TensorLike):
else:
dim_list.append(dim.value)
try:
c_api.TF_GraphSetTensorShape_wrapper(
pywrap_tf_session.TF_GraphSetTensorShape_wrapper(
self._op._graph._c_graph, # pylint: disable=protected-access
self._as_tf_output(),
dim_list,
@ -669,7 +670,7 @@ class Tensor(_TensorLike):
Returns:
A list of `Operation`s.
"""
consumer_names = c_api.TF_OperationOutputConsumers_wrapper(
consumer_names = pywrap_tf_session.TF_OperationOutputConsumers_wrapper(
self._as_tf_output())
# pylint: disable=protected-access
return [
@ -1160,7 +1161,7 @@ class _EagerTensorBase(Tensor):
# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
# registers it with the current module.
EagerTensor = c_api_new.TFE_Py_InitEagerTensor(_EagerTensorBase)
EagerTensor = pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase)
register_dense_tensor_like_type(Tensor)
@ -1633,20 +1634,22 @@ def _create_c_op(graph, node_def, inputs, control_inputs, op_def=None):
# Refactor so we don't have to do this here.
inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.attr)
# pylint: disable=protected-access
op_desc = c_api.TF_NewOperation(graph._c_graph, compat.as_str(node_def.op),
compat.as_str(node_def.name))
op_desc = pywrap_tf_session.TF_NewOperation(graph._c_graph,
compat.as_str(node_def.op),
compat.as_str(node_def.name))
if node_def.device:
c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device))
pywrap_tf_session.TF_SetDevice(op_desc, compat.as_str(node_def.device))
# Add inputs
for op_input in inputs:
if isinstance(op_input, (list, tuple)):
c_api.TF_AddInputList(op_desc, [t._as_tf_output() for t in op_input])
pywrap_tf_session.TF_AddInputList(op_desc,
[t._as_tf_output() for t in op_input])
else:
c_api.TF_AddInput(op_desc, op_input._as_tf_output())
pywrap_tf_session.TF_AddInput(op_desc, op_input._as_tf_output())
# Add control inputs
for control_input in control_inputs:
c_api.TF_AddControlInput(op_desc, control_input._c_op)
pywrap_tf_session.TF_AddControlInput(op_desc, control_input._c_op)
# pylint: enable=protected-access
# Add attrs
@ -1654,10 +1657,11 @@ def _create_c_op(graph, node_def, inputs, control_inputs, op_def=None):
serialized = attr_value.SerializeToString()
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
# It might be worth creating a convenient way to re-use the same status.
c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized)
pywrap_tf_session.TF_SetAttrValueProto(op_desc, compat.as_str(name),
serialized)
try:
c_op = c_api.TF_FinishOperation(op_desc)
c_op = pywrap_tf_session.TF_FinishOperation(op_desc)
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
@ -1744,7 +1748,7 @@ class Operation(object):
if not _VALID_OP_NAME_REGEX.match(node_def.name):
raise ValueError("'%s' is not a valid node name" % node_def.name)
c_op = None
elif type(node_def).__name__ == "SwigPyObject":
elif type(node_def).__name__ == "TF_Operation":
assert inputs is None
assert output_types is None
assert control_inputs is None
@ -1814,7 +1818,7 @@ class Operation(object):
# Initialize self._c_op.
if c_op:
self._c_op = c_op
op_def = g._get_op_def(c_api.TF_OperationOpType(c_op))
op_def = g._get_op_def(pywrap_tf_session.TF_OperationOpType(c_op))
name = self.name
else:
if op_def is None:
@ -1827,11 +1831,11 @@ class Operation(object):
self._is_stateful = op_def.is_stateful
# Initialize self._outputs.
num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
self._outputs = []
for i in range(num_outputs):
tf_output = c_api_util.tf_output(self._c_op, i)
output_type = c_api.TF_OperationOutputType(tf_output)
output_type = pywrap_tf_session.TF_OperationOutputType(tf_output)
tensor = Tensor._create_with_tf_output(self, i, output_type, tf_output) # pylint: disable=protected-access
self._outputs.append(tensor)
@ -1900,7 +1904,7 @@ class Operation(object):
@property
def name(self):
"""The full name of this operation."""
return c_api.TF_OperationName(self._c_op)
return pywrap_tf_session.TF_OperationName(self._c_op)
@property
def _id(self):
@ -1916,7 +1920,7 @@ class Operation(object):
assigned, or an empty string if it has not been assigned to a
device.
"""
return c_api.TF_OperationDevice(self._c_op)
return pywrap_tf_session.TF_OperationDevice(self._c_op)
@property
def _device_assignments(self):
@ -1991,33 +1995,28 @@ class Operation(object):
Returns:
List of the types of the Tensors computed by this operation.
Each element in the list is an integer whose value is one of
the TF_DataType enums defined in c_api.h
the TF_DataType enums defined in pywrap_tf_session.h
The length of this list indicates the number of output endpoints
of the operation.
"""
num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
output_types = [
c_api.TF_OperationOutputType(self._tf_output(i))
int(pywrap_tf_session.TF_OperationOutputType(self._tf_output(i)))
for i in xrange(num_outputs)
]
# In all the tests we have output_types that are passed into
# Operation.__init__ are a list of ints (which is illegal according
# to the docstring), but input_types are instances of DType.
# This extra assert is to catch if we ever use DType for output_types.
if output_types:
assert isinstance(output_types[0], int)
return output_types
def _tf_output(self, output_idx):
"""Create and return a new TF_Output for output_idx'th output of this op."""
tf_output = c_api.TF_Output()
tf_output = pywrap_tf_session.TF_Output()
tf_output.oper = self._c_op
tf_output.index = output_idx
return tf_output
def _tf_input(self, input_idx):
"""Create and return a new TF_Input for input_idx'th input of this op."""
tf_input = c_api.TF_Input()
tf_input = pywrap_tf_session.TF_Input()
tf_input.oper = self._c_op
tf_input.index = input_idx
return tf_input
@ -2040,7 +2039,7 @@ class Operation(object):
Args:
device_str: A string specifying where to place this op.
"""
c_api.SetRequestedDevice(
pywrap_tf_session.SetRequestedDevice(
self._graph._c_graph, # pylint: disable=protected-access
self._c_op, # pylint: disable=protected-access
device_str)
@ -2065,7 +2064,7 @@ class Operation(object):
# Reset cached inputs.
self._inputs_val = None
c_api.UpdateEdge(
pywrap_tf_session.UpdateEdge(
self._graph._c_graph, # pylint: disable=protected-access
tensor._as_tf_output(), # pylint: disable=protected-access
self._tf_input(index))
@ -2090,7 +2089,7 @@ class Operation(object):
# Reset cached inputs.
self._inputs_val = None
c_api.AddWhileInputHack(
pywrap_tf_session.AddWhileInputHack(
self._graph._c_graph, # pylint: disable=protected-access
tensor._as_tf_output(), # pylint: disable=protected-access
self._c_op)
@ -2108,7 +2107,10 @@ class Operation(object):
for op in ops:
if not isinstance(op, Operation):
raise TypeError("op must be an Operation: %s" % op)
c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
pywrap_tf_session.AddControlInput(
self._graph._c_graph, # pylint: disable=protected-access
self._c_op, # pylint: disable=protected-access
op._c_op) # pylint: disable=protected-access
def _add_control_input(self, op):
"""Add a new control input to this operation.
@ -2122,11 +2124,14 @@ class Operation(object):
"""
if not isinstance(op, Operation):
raise TypeError("op must be an Operation: %s" % op)
c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
pywrap_tf_session.AddControlInput(
self._graph._c_graph, # pylint: disable=protected-access
self._c_op, # pylint: disable=protected-access
op._c_op) # pylint: disable=protected-access
def _remove_all_control_inputs(self):
"""Removes any control inputs to this operation."""
c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access
pywrap_tf_session.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access
def _add_outputs(self, types, shapes):
"""Adds new Tensors to self.outputs.
@ -2161,16 +2166,18 @@ class Operation(object):
"""The sequence of `Tensor` objects representing the data inputs of this op."""
if self._inputs_val is None:
# pylint: disable=protected-access
self._inputs_val = tuple(map(self.graph._get_tensor_by_tf_output,
c_api.GetOperationInputs(self._c_op)))
self._inputs_val = tuple(
map(self.graph._get_tensor_by_tf_output,
pywrap_tf_session.GetOperationInputs(self._c_op)))
# pylint: enable=protected-access
return self._inputs_val
@property
def _input_types(self):
num_inputs = c_api.TF_OperationNumInputs(self._c_op)
num_inputs = pywrap_tf_session.TF_OperationNumInputs(self._c_op)
input_types = [
dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i)))
dtypes.as_dtype(
pywrap_tf_session.TF_OperationInputType(self._tf_input(i)))
for i in xrange(num_inputs)
]
return input_types
@ -2189,11 +2196,12 @@ class Operation(object):
A list of `Operation` objects.
"""
control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
control_c_ops = pywrap_tf_session.TF_OperationGetControlInputs_wrapper(
self._c_op)
# pylint: disable=protected-access
return [
self.graph._get_operation_by_name_unsafe(c_api.TF_OperationName(c_op))
for c_op in control_c_ops
self.graph._get_operation_by_name_unsafe(
pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
]
# pylint: enable=protected-access
@ -2208,18 +2216,19 @@ class Operation(object):
A list of `Operation` objects.
"""
control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op)
control_c_ops = pywrap_tf_session.TF_OperationGetControlOutputs_wrapper(
self._c_op)
# pylint: disable=protected-access
return [
self.graph._get_operation_by_name_unsafe(c_api.TF_OperationName(c_op))
for c_op in control_c_ops
self.graph._get_operation_by_name_unsafe(
pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
]
# pylint: enable=protected-access
@property
def type(self):
"""The type of the op (e.g. `"MatMul"`)."""
return c_api.TF_OperationOpType(self._c_op)
return pywrap_tf_session.TF_OperationOpType(self._c_op)
@property
def graph(self):
@ -2238,8 +2247,8 @@ class Operation(object):
"""
# pylint: enable=line-too-long
with c_api_util.tf_buffer() as buf:
c_api.TF_OperationToNodeDef(self._c_op, buf)
data = c_api.TF_GetBuffer(buf)
pywrap_tf_session.TF_OperationToNodeDef(self._c_op, buf)
data = pywrap_tf_session.TF_GetBuffer(buf)
node_def = node_def_pb2.NodeDef()
node_def.ParseFromString(compat.as_bytes(data))
return node_def
@ -2264,17 +2273,18 @@ class Operation(object):
def _set_attr(self, attr_name, attr_value):
"""Private method used to set an attribute in the node_def."""
buf = c_api.TF_NewBufferFromString(
buf = pywrap_tf_session.TF_NewBufferFromString(
compat.as_bytes(attr_value.SerializeToString()))
try:
self._set_attr_with_buf(attr_name, buf)
finally:
c_api.TF_DeleteBuffer(buf)
pywrap_tf_session.TF_DeleteBuffer(buf)
def _set_attr_with_buf(self, attr_name, attr_buf):
"""Set an attr in the node_def with a pre-allocated buffer."""
# pylint: disable=protected-access
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, attr_buf)
pywrap_tf_session.SetAttr(self._graph._c_graph, self._c_op, attr_name,
attr_buf)
# pylint: enable=protected-access
def _set_func_attr(self, attr_name, func_name):
@ -2307,7 +2317,7 @@ class Operation(object):
def _clear_attr(self, attr_name):
"""Private method used to clear an attribute in the node_def."""
# pylint: disable=protected-access
c_api.ClearAttr(self._graph._c_graph, self._c_op, attr_name)
pywrap_tf_session.ClearAttr(self._graph._c_graph, self._c_op, attr_name)
# pylint: enable=protected-access
def get_attr(self, name):
@ -2325,8 +2335,8 @@ class Operation(object):
fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func")
try:
with c_api_util.tf_buffer() as buf:
c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
data = c_api.TF_GetBuffer(buf)
pywrap_tf_session.TF_OperationGetAttrValueProto(self._c_op, name, buf)
data = pywrap_tf_session.TF_GetBuffer(buf)
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
@ -2352,7 +2362,7 @@ class Operation(object):
def _get_attr_type(self, name):
"""Returns the `DType` value of the attr of this op with the given `name`."""
try:
dtype_enum = c_api.TF_OperationGetAttrType(self._c_op, name)
dtype_enum = pywrap_tf_session.TF_OperationGetAttrType(self._c_op, name)
return _DTYPES_INTERN_TABLE[dtype_enum]
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
@ -2361,7 +2371,7 @@ class Operation(object):
def _get_attr_bool(self, name):
"""Returns the `bool` value of the attr of this op with the given `name`."""
try:
return c_api.TF_OperationGetAttrBool(self._c_op, name)
return pywrap_tf_session.TF_OperationGetAttrBool(self._c_op, name)
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
@ -2369,7 +2379,7 @@ class Operation(object):
def _get_attr_int(self, name):
"""Returns the `int` value of the attr of this op with the given `name`."""
try:
return c_api.TF_OperationGetAttrInt(self._c_op, name)
return pywrap_tf_session.TF_OperationGetAttrInt(self._c_op, name)
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
@ -2812,7 +2822,7 @@ class Graph(object):
# The C API requires all ops to have shape functions. Disable this
# requirement (many custom ops do not have shape functions, and we don't
# want to break these existing cases).
c_api.SetRequireShapeInferenceFns(self._c_graph, False)
pywrap_tf_session.SetRequireShapeInferenceFns(self._c_graph, False)
if tf2.enabled():
self.switch_to_thread_local()
@ -2945,8 +2955,8 @@ class Graph(object):
"""
# pylint: enable=line-too-long
with c_api_util.tf_buffer() as buf:
c_api.TF_GraphVersions(self._c_graph, buf)
data = c_api.TF_GetBuffer(buf)
pywrap_tf_session.TF_GraphVersions(self._c_graph, buf)
data = pywrap_tf_session.TF_GetBuffer(buf)
version_def = versions_pb2.VersionDef()
version_def.ParseFromString(compat.as_bytes(data))
return version_def
@ -3047,8 +3057,8 @@ class Graph(object):
# pylint: enable=line-too-long
with self._lock:
with c_api_util.tf_buffer() as buf:
c_api.TF_GraphToGraphDef(self._c_graph, buf)
data = c_api.TF_GetBuffer(buf)
pywrap_tf_session.TF_GraphToGraphDef(self._c_graph, buf)
data = pywrap_tf_session.TF_GetBuffer(buf)
graph = graph_pb2.GraphDef()
graph.ParseFromString(compat.as_bytes(data))
# Strip the experimental library field iff it's empty.
@ -3181,7 +3191,8 @@ class Graph(object):
# pylint: disable=protected-access
gradient = (
function._grad_func._c_func.func if function._grad_func else None)
c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient)
pywrap_tf_session.TF_GraphCopyFunction(self._c_graph, function._c_func.func,
gradient)
# pylint: enable=protected-access
self._functions[compat.as_str(name)] = function
@ -3658,7 +3669,7 @@ class Graph(object):
return self._nodes_by_name[name]
def _get_operation_by_tf_operation(self, tf_oper):
op_name = c_api.TF_OperationName(tf_oper)
op_name = pywrap_tf_session.TF_OperationName(tf_oper)
return self._get_operation_by_name_unsafe(op_name)
def get_tensor_by_name(self, name):
@ -3711,9 +3722,10 @@ class Graph(object):
except KeyError:
with c_api_util.tf_buffer() as buf:
# pylint: disable=protected-access
c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf)
pywrap_tf_session.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type),
buf)
# pylint: enable=protected-access
data = c_api.TF_GetBuffer(buf)
data = pywrap_tf_session.TF_GetBuffer(buf)
op_def = op_def_pb2.OpDef()
op_def.ParseFromString(compat.as_bytes(data))
self._op_def_cache[type] = op_def

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops

View File

@ -44,9 +44,9 @@ from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import _pywrap_stacktrace_handler
from tensorflow.python import _pywrap_util_port
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import tf2
from tensorflow.python.client import device_lib
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.client import session
from tensorflow.python.compat.compat import forward_compatibility_horizon
from tensorflow.python.eager import context
@ -199,7 +199,7 @@ def assert_equal_graph_def(actual, expected, checkpoint_v2=False,
_strip_hash_table_shared_name(actual)
_strip_hash_table_shared_name(expected)
diff = pywrap_tensorflow.EqualGraphDefWrapper(actual.SerializeToString(),
diff = pywrap_tf_session.EqualGraphDefWrapper(actual.SerializeToString(),
expected.SerializeToString())
if diff:
raise AssertionError(compat.as_str(diff))
@ -1694,10 +1694,10 @@ def enable_tf_xla_constant_folding(description):
def decorator(f):
def decorated(self, *args, **kwargs):
original_var = pywrap_tensorflow.TF_GetXlaConstantFoldingDisabled()
pywrap_tensorflow.TF_SetXlaConstantFoldingDisabled(False)
original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled()
pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False)
result = f(self, *args, **kwargs)
pywrap_tensorflow.TF_SetXlaConstantFoldingDisabled(original_var)
pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var)
return result
return decorated
@ -1799,9 +1799,9 @@ def xla_allow_fallback(description): # pylint: disable=unused-argument
# Update the global XLABuildOpsPassFlags to enable lazy compilation,
# which allows the compiler to fall back to TF classic. Remember the
# old value so that we can reset it.
old_value = pywrap_tensorflow.TF_SetXlaEnableLazyCompilation(True)
old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True)
result = func(self, *args, **kwargs)
pywrap_tensorflow.TF_SetXlaEnableLazyCompilation(old_value)
pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value)
return result
else:
return func(self, *args, **kwargs)
@ -1835,13 +1835,13 @@ class TensorFlowTestCase(googletest.TestCase):
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
super(TensorFlowTestCase, self).__init__(methodName)
if is_xla_enabled():
pywrap_tensorflow.TF_SetXlaAutoJitMode("2")
pywrap_tensorflow.TF_SetXlaMinClusterSize(1)
pywrap_tensorflow.TF_SetXlaEnableLazyCompilation(False)
pywrap_tensorflow.TF_SetTfXlaCpuGlobalJit(True)
pywrap_tf_session.TF_SetXlaAutoJitMode("2")
pywrap_tf_session.TF_SetXlaMinClusterSize(1)
pywrap_tf_session.TF_SetXlaEnableLazyCompilation(False)
pywrap_tf_session.TF_SetTfXlaCpuGlobalJit(True)
# Constant folding secretly runs code on TF:Classic CPU, so we also
# disable it here.
pywrap_tensorflow.TF_SetXlaConstantFoldingDisabled(True)
pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True)
self._threads = []
self._tempdir = None

View File

@ -19,14 +19,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.util.tf_export import tf_export
__version__ = pywrap_tensorflow.__version__
__git_version__ = pywrap_tensorflow.__git_version__
__compiler_version__ = pywrap_tensorflow.__compiler_version__
__cxx11_abi_flag__ = pywrap_tensorflow.__cxx11_abi_flag__
__monolithic_build__ = pywrap_tensorflow.__monolithic_build__
__version__ = pywrap_tf_session.__version__
__git_version__ = pywrap_tf_session.__git_version__
__compiler_version__ = pywrap_tf_session.__compiler_version__
__cxx11_abi_flag__ = pywrap_tf_session.__cxx11_abi_flag__
__monolithic_build__ = pywrap_tf_session.__monolithic_build__
VERSION = __version__
tf_export(
@ -61,13 +61,13 @@ tf_export(
"sysconfig.MONOLITHIC_BUILD", "MONOLITHIC_BUILD", "__monolithic_build__"
]).export_constant(__name__, "MONOLITHIC_BUILD")
GRAPH_DEF_VERSION = pywrap_tensorflow.GRAPH_DEF_VERSION
GRAPH_DEF_VERSION = pywrap_tf_session.GRAPH_DEF_VERSION
tf_export(
"version.GRAPH_DEF_VERSION",
v1=["version.GRAPH_DEF_VERSION", "GRAPH_DEF_VERSION"]).export_constant(
__name__, "GRAPH_DEF_VERSION")
GRAPH_DEF_VERSION_MIN_CONSUMER = (
pywrap_tensorflow.GRAPH_DEF_VERSION_MIN_CONSUMER)
pywrap_tf_session.GRAPH_DEF_VERSION_MIN_CONSUMER)
tf_export(
"version.GRAPH_DEF_VERSION_MIN_CONSUMER",
v1=[
@ -75,7 +75,7 @@ tf_export(
"GRAPH_DEF_VERSION_MIN_CONSUMER"
]).export_constant(__name__, "GRAPH_DEF_VERSION_MIN_CONSUMER")
GRAPH_DEF_VERSION_MIN_PRODUCER = (
pywrap_tensorflow.GRAPH_DEF_VERSION_MIN_PRODUCER)
pywrap_tf_session.GRAPH_DEF_VERSION_MIN_PRODUCER)
tf_export(
"version.GRAPH_DEF_VERSION_MIN_PRODUCER",
v1=[

View File

@ -88,6 +88,20 @@ inline void MaybeRaiseRegisteredFromTFStatus(TF_Status* status) {
}
}
inline void MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status* status) {
TF_Code code = TF_GetCode(status);
if (code != TF_OK) {
// Acquire GIL for throwing exception.
pybind11::gil_scoped_acquire acquire;
PyErr_SetObject(PyExceptionRegistry::Lookup(code),
pybind11::make_tuple(pybind11::none(), pybind11::none(),
TF_Message(status))
.ptr());
throw pybind11::error_already_set();
}
}
} // namespace tensorflow
namespace pybind11 {

View File

@ -19,8 +19,8 @@ from __future__ import division
from __future__ import print_function
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -1128,7 +1128,7 @@ def _BroadcastToGrad(op, grad):
input_value_shape = array_ops.shape(input_value)
if not context.executing_eagerly():
broadcast_shape_static = tensor_shape.TensorShape(
pywrap_tensorflow.TF_TryEvaluateConstant_wrapper(
pywrap_tf_session.TF_TryEvaluateConstant_wrapper(
broadcast_shape.graph._c_graph, broadcast_shape._as_tf_output())) # pylint: disable=protected-access
if broadcast_shape_static.is_fully_defined():
broadcast_shape = constant_op.constant(

View File

@ -17,7 +17,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
@ -66,7 +66,7 @@ def copy_handle_data(source_t, target_t):
and handle_data.is_set
and handle_data.shape_and_type):
# pylint: disable=protected-access
pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
target_t._as_tf_output(),
handle_data.SerializeToString())
# pylint: enable=protected-access
@ -76,10 +76,12 @@ def copy_handle_data(source_t, target_t):
ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
shapes = [[d.size for d in s.dim] # pylint: disable=g-complex-comprehension
if not s.unknown_rank else None for s in shapes]
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
target_t._op._graph._c_graph, # pylint: disable=protected-access
target_t._as_tf_output(), # pylint: disable=protected-access
shapes, ranks, types)
shapes,
ranks,
types)
@tf_export("custom_gradient")

View File

@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op

View File

@ -26,7 +26,7 @@ import weakref
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python import _pywrap_utils
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
@ -56,7 +56,7 @@ from tensorflow.python.util.deprecation import deprecated_args
def get_resource_handle_data(graph_op):
assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck
handle_data = pywrap_tensorflow.GetHandleShapeAndType(
handle_data = pywrap_tf_session.GetHandleShapeAndType(
graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access
return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
@ -91,10 +91,12 @@ def _set_handle_shapes_and_types(tensor, handle_data, graph_mode):
ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
shapes = [[d.size for d in s.dim] # pylint: disable=g-complex-comprehension
if not s.unknown_rank else None for s in shapes]
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
tensor._op._graph._c_graph, # pylint: disable=protected-access
tensor._as_tf_output(), # pylint: disable=protected-access
shapes, ranks, types)
shapes,
ranks,
types)
def _combine_handle_data(handle, initial_value):

View File

@ -23,7 +23,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.core.framework import resource_handle_pb2
from tensorflow.python import pywrap_tensorflow_internal
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -71,8 +71,7 @@ class TensorHandle(object):
if not self._resource_handle:
self._resource_handle = resource_handle_pb2.ResourceHandleProto()
self._resource_handle.device = self._handle.split(";")[-1]
self._resource_handle.container = (
pywrap_tensorflow_internal.TENSOR_HANDLE_KEY)
self._resource_handle.container = (pywrap_tf_session.TENSOR_HANDLE_KEY)
self._resource_handle.name = self._handle
return self._resource_handle

View File

@ -24,7 +24,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes

View File

@ -56,11 +56,6 @@ try:
sys.setdlopenflags(_default_dlopen_flags | ctypes.RTLD_LOCAL)
from tensorflow.python.pywrap_tensorflow_internal import *
from tensorflow.python.pywrap_tensorflow_internal import __version__
from tensorflow.python.pywrap_tensorflow_internal import __git_version__
from tensorflow.python.pywrap_tensorflow_internal import __compiler_version__
from tensorflow.python.pywrap_tensorflow_internal import __cxx11_abi_flag__
from tensorflow.python.pywrap_tensorflow_internal import __monolithic_build__
if _use_dlopen_global_flags:
pywrap_dlopen_global_flags.reset_dlopen_flags()

View File

@ -17,8 +17,6 @@ limitations under the License.
* The includes are intentionally not alphabetically sorted, as the order of
* includes follows dependency order */
%include "tensorflow/python/client/tf_session.i"
%include "tensorflow/python/grappler/cluster.i"
%include "tensorflow/python/grappler/item.i"
%include "tensorflow/python/grappler/tf_optimizer.i"

View File

@ -322,7 +322,6 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList");
py::class_<TF_Function> TF_Function_class(m, "TF_Function");
py::class_<TF_Buffer> TF_Buffer_class(m, "TF_Buffer");
m.def("TFE_Py_RegisterExceptionClass", [](const py::handle& e) {
return tensorflow::pyo_or_throw(TFE_Py_RegisterExceptionClass(e.ptr()));
@ -369,12 +368,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
});
m.def("TFE_ContextAddFunction", [](py::handle& ctx, py::handle& func) {
m.def("TFE_ContextAddFunction", [](py::handle& ctx, TF_Function* func) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
SwigPyObject* sstable_swig = reinterpret_cast<SwigPyObject*>(func.ptr());
auto function = reinterpret_cast<TF_Function*>(sstable_swig->ptr);
TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), function,
TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), func,
status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
@ -1066,13 +1063,6 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
// Util buffer helper functions
m.def("TF_NewBufferFromString", &TF_NewBufferFromString,
py::return_value_policy::reference);
m.def("TF_NewBuffer", &TF_NewBuffer, py::return_value_policy::reference);
m.def("TF_GetBuffer", [](TF_Buffer* buf) {
return tensorflow::pyo_or_throw(PyBytes_FromStringAndSize(
reinterpret_cast<const char*>(buf->data), buf->length));
});
m.def("TF_DeleteBuffer", &TF_DeleteBuffer,
py::return_value_policy::reference);
// C API Enum

View File

@ -25,7 +25,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.compiler.xla import xla
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
@ -253,11 +253,11 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
"""An internal class to help manage the TF_Buffer lifetime."""
def __init__(self, buf_string):
self._buffer = pywrap_tensorflow.TF_NewBufferFromString(
self._buffer = pywrap_tf_session.TF_NewBufferFromString(
compat.as_bytes(buf_string))
def __del__(self):
pywrap_tensorflow.TF_DeleteBuffer(self._buffer)
pywrap_tf_session.TF_DeleteBuffer(self._buffer)
def __init__(self, name, num_replicas, pivot):
"""Builds a new TPUReplicateContext.

View File

@ -21,7 +21,7 @@ from __future__ import print_function
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import device_filters_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation

View File

@ -134,15 +134,25 @@ def get_symbols(path_to_lib, re_filter):
# Example symbol line:
# 954 00000000 SECT2BD notype () External | ?IsSequence@swig@tensorflow@@YA_NPEAU_object@@@Z (bool __cdecl tensorflow::swig::IsSequence(struct _object *))
# Anomaly symbol line:
# 00B 00000000 SECT4 notype External | _tensorflow_numpy_api.
sym_filtered = []
re_filter_comp = re.compile(r"{}".format(re_filter))
# Filter out symbol from the split line (`sym_split` in the for loop below).
sym_line_filter = r".*\s+\| (.*) \(.*"
sym_line_filter_anomaly = r".*\s+\| (.*)"
for sym_line in sym_split:
if re_filter_comp.search(sym_line):
sym = re.match(sym_line_filter, sym_line).groups()[0]
try:
sym = re.match(sym_line_filter, sym_line).groups()[0]
except AttributeError:
try:
sym = re.match(sym_line_filter_anomaly, sym_line).groups()[0]
except:
raise RuntimeError("Unable to find the following symbol:[%s]" % sym_line)
sym_filtered.append(sym)
return sym_filtered

View File

@ -75,12 +75,13 @@ tensorflow::Status::code
tensorflow::Status::error_message
tensorflow::Status::ok()
[core_cpu_impl] # device_lib tfe
[core_cpu_impl] # device_lib, tfe, tf_session
tensorflow::Device::attributes
tensorflow::DeviceFactory::AddDevices
tensorflow::SessionOptions::SessionOptions
tensorflow::DoQuantizeTrainingOnSerializedGraphDef
tensorflow::DeviceFactory::ListAllPhysicalDevices
tensorflow::SessionState::kTensorHandleResourceTypeName
[protos_all] # device_lib, dtypes
tensorflow::DataType_IsValid
@ -195,3 +196,50 @@ tensorflow::ExperimentalRunPassPipeline
tensorflow::ExperimentalConvertSavedModelV1ToMlir
tensorflow::ExperimentalConvertSavedModelToMlir
tensorflow::ImportGraphDef
[op_gen_lib] # tf_session
tensorflow::ApiDefMap::~ApiDefMap
[core_cpu_base_no_ops] # tf_session
tensorflow::ShapeRefiner::~ShapeRefiner
[python_api] # tf_session
tensorflow::AddControlInput
tensorflow::SetAttr
tensorflow::ClearAttr
tensorflow::SetRequestedDevice
tensorflow::UpdateEdge
tensorflow::RemoveAllControlInputs
tensorflow::SetRequireShapeInferenceFns
tensorflow::ExtendSession
tensorflow::GetHandleShapeAndType
tensorflow::SetHandleShapeAndType
tensorflow::AddWhileInputHack
[numpy_lib] # tf_session
tensorflow::ImportNumpy
_tensorflow_numpy_api
[tf_session_helper] # tf_session
tensorflow::TF_NewSessionRef
tensorflow::TF_SessionMakeCallable
tensorflow::TF_SessionRunCallable
tensorflow::TF_SessionReleaseCallable
tensorflow::TF_Reset_wrapper
tensorflow::EqualGraphDefWrapper
tensorflow::EqualAttrValueWrapper
tensorflow::TF_GraphGetTensorShapeHelper
tensorflow::TF_SessionRun_wrapper
tensorflow::TF_SessionPRunSetup_wrapper
tensorflow::TF_SessionPRun_wrapper
tensorflow::GetOperationInputs
tensorflow::TF_OperationGetControlInputs_wrapper
tensorflow::TF_OperationGetControlOutputs_wrapper
tensorflow::TF_OperationOutputConsumers_wrapper
tensorflow::TF_GraphToFunction_wrapper
tensorflow::TF_GraphSetOutputHandleShapesAndTypes_wrapper
tensorflow::TF_CreatePlaceholders
tensorflow::TF_GraphSetTensorShape_wrapper
tensorflow::TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper
tensorflow::TF_TryEvaluateConstant_wrapper