Export the TF Session classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 292259851 Change-Id: If5abe93f9cf25018d185e220d4dfbc216b5f3b32
This commit is contained in:
parent
4ca8bf54d7
commit
a02fe6c24a
tensorflow
c
core/common_runtime
python
BUILD
client
pywrap_tf_session.pysession.pysession_list_devices_test.pytf_session.itf_session_wrapper.cctf_sessionrun_wrapper.i
data/experimental/kernel_tests
debug
eager
framework
c_api_util.pyerrors_impl.pyerrors_test.pyfunction.pyimporter.pykernels.pyload_library.pymeta_graph.pyops.pysmart_cond.pytest_util.pyversions.py
lib/core
ops
pywrap_tensorflow.pytensorflow.itfe_wrapper.cctpu
training
tools/def_file_filter
@ -57,6 +57,7 @@ filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"c_api_internal.h",
|
||||
"python_api.h",
|
||||
"tf_status_helper.h",
|
||||
"tf_status_internal.h",
|
||||
"tf_tensor_internal.h",
|
||||
@ -98,6 +99,17 @@ tf_cuda_library(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_tf_session_hdrs",
|
||||
srcs = [
|
||||
"python_api.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_attrtype",
|
||||
hdrs = ["tf_attrtype.h"],
|
||||
|
@ -18,6 +18,9 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Adjust value in third_party/tensorflow/python/client/tf_session_wrapper.cc
|
||||
// in the get_tensor_handle_key function if adjusting the value for
|
||||
// kTensorHandleResourceTypeName.
|
||||
const char* SessionState::kTensorHandleResourceTypeName = "TensorHandle";
|
||||
|
||||
Status SessionState::GetTensor(const string& handle, Tensor* tensor) {
|
||||
|
@ -173,6 +173,7 @@ py_library(
|
||||
":platform",
|
||||
":proto_ops",
|
||||
":pywrap_tensorflow",
|
||||
":pywrap_tf_session",
|
||||
":pywrap_tfe",
|
||||
":rnn_ops_gen",
|
||||
":saver_test_utils",
|
||||
@ -559,6 +560,58 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "pywrap_tf_session",
|
||||
srcs = ["client/pywrap_tf_session.py"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":_pywrap_tf_session",
|
||||
":pywrap_tensorflow",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_tf_session",
|
||||
srcs = ["client/tf_session_wrapper.cc"],
|
||||
hdrs = [
|
||||
"client/tf_session_helper.h",
|
||||
"lib/core/numpy.h",
|
||||
"lib/core/safe_ptr.h",
|
||||
"//tensorflow/c:headers",
|
||||
"//tensorflow/c:pywrap_required_hdrs",
|
||||
"//tensorflow/c/eager:headers",
|
||||
"//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
|
||||
"//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/core/framework:pywrap_required_hdrs",
|
||||
],
|
||||
module_name = "_pywrap_tf_session",
|
||||
deps = [
|
||||
":pybind11_lib",
|
||||
":pybind11_status",
|
||||
"//third_party/py/numpy:headers",
|
||||
"@pybind11",
|
||||
"//third_party/python_runtime:headers",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:core_cpu_headers_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
] + if_static(
|
||||
extra_deps = [
|
||||
"//tensorflow/core:eager_service_proto_cc",
|
||||
"//tensorflow/core:master_proto_cc",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
],
|
||||
otherwise = [
|
||||
"//tensorflow/core:eager_service_proto_cc_headers_only",
|
||||
"//tensorflow/core:master_proto_cc_headers_only",
|
||||
"//tensorflow/core:worker_proto_cc_headers_only",
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_tfprof",
|
||||
srcs = ["util/tfprof_wrapper.cc"],
|
||||
@ -1173,6 +1226,7 @@ py_library(
|
||||
":lib",
|
||||
":platform",
|
||||
":pywrap_tensorflow",
|
||||
":pywrap_tf_session",
|
||||
":pywrap_tfe",
|
||||
":pywrap_mlir",
|
||||
":random_seed",
|
||||
@ -1194,7 +1248,7 @@ py_library(
|
||||
srcs = ["framework/c_api_util.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":pywrap_tensorflow",
|
||||
":pywrap_tf_session",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
)
|
||||
@ -1271,6 +1325,7 @@ py_library(
|
||||
":_pywrap_py_exception_registry",
|
||||
":c_api_util",
|
||||
":error_interpolation",
|
||||
":pywrap_tf_session",
|
||||
":util",
|
||||
],
|
||||
)
|
||||
@ -1296,6 +1351,7 @@ py_library(
|
||||
":framework_ops",
|
||||
":graph_to_function_def",
|
||||
":op_def_registry",
|
||||
":pywrap_tf_session",
|
||||
":util",
|
||||
":variable_scope",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
@ -1387,7 +1443,7 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":pywrap_tensorflow",
|
||||
":pywrap_tf_session",
|
||||
":util",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
@ -1629,6 +1685,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":control_flow_ops",
|
||||
":pywrap_tf_session",
|
||||
":tensor_util",
|
||||
],
|
||||
)
|
||||
@ -1801,7 +1858,7 @@ py_library(
|
||||
srcs = ["framework/versions.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":pywrap_tensorflow",
|
||||
":pywrap_tf_session",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1837,7 +1894,7 @@ py_library(
|
||||
":gpu_util",
|
||||
":platform",
|
||||
":platform_test",
|
||||
":pywrap_tensorflow",
|
||||
":pywrap_tf_session",
|
||||
":random_seed",
|
||||
":resource_variable_ops",
|
||||
":session",
|
||||
@ -3244,6 +3301,7 @@ py_library(
|
||||
":functional_ops_gen",
|
||||
":gradients_util",
|
||||
":list_ops",
|
||||
":pywrap_tf_session",
|
||||
":tensor_array_ops",
|
||||
":tensor_shape",
|
||||
":tensor_util",
|
||||
@ -3334,6 +3392,7 @@ py_library(
|
||||
deps = [
|
||||
":gradients_impl",
|
||||
":gradients_util",
|
||||
":pywrap_tf_session",
|
||||
":unconnected_gradients",
|
||||
"//tensorflow/python/eager:forwardprop",
|
||||
"//tensorflow/python/eager:function",
|
||||
@ -3776,6 +3835,7 @@ py_library(
|
||||
":framework_for_generated_wrappers",
|
||||
":math_ops",
|
||||
":math_ops_gen",
|
||||
":pywrap_tf_session",
|
||||
":tensor_util",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//third_party/py/numpy",
|
||||
@ -3838,6 +3898,7 @@ py_library(
|
||||
":array_ops_gen",
|
||||
":dtypes",
|
||||
":framework_ops",
|
||||
":pywrap_tf_session",
|
||||
":resource_variable_ops_gen",
|
||||
":tensor_shape",
|
||||
":util",
|
||||
@ -5070,7 +5131,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":framework",
|
||||
":pywrap_tensorflow",
|
||||
":pywrap_tf_session",
|
||||
":util",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
@ -5588,8 +5649,6 @@ tf_py_wrap_cc(
|
||||
name = "pywrap_tensorflow_internal",
|
||||
srcs = ["tensorflow.i"],
|
||||
swig_includes = [
|
||||
"client/tf_session.i",
|
||||
"client/tf_sessionrun_wrapper.i",
|
||||
"grappler/cluster.i",
|
||||
"grappler/cost_analyzer.i",
|
||||
"grappler/item.i",
|
||||
@ -5683,6 +5742,10 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
|
||||
"//tensorflow/core/profiler/lib:profiler_session", # tfe
|
||||
"//tensorflow/c:tf_status_helper", # tfe
|
||||
"//tensorflow/compiler/mlir/python:mlir", # mlir
|
||||
"//tensorflow/core:op_gen_lib", # tf_session
|
||||
"//tensorflow/core:core_cpu_base_no_ops", # tf_session
|
||||
"//tensorflow/c:python_api", # tf_session
|
||||
"//tensorflow/python:tf_session_helper", # tf_session
|
||||
]
|
||||
|
||||
# Filter the DEF file to reduce the number of symbols to 64K or less.
|
||||
|
70
tensorflow/python/client/pywrap_tf_session.py
Normal file
70
tensorflow/python/client/pywrap_tf_session.py
Normal file
@ -0,0 +1,70 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Python module for Session ops, vars, and functions exported by pybind11."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=invalid-import-order,g-bad-import-order, wildcard-import, unused-import
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python._pywrap_tf_session import *
|
||||
from tensorflow.python._pywrap_tf_session import _TF_SetTarget
|
||||
from tensorflow.python._pywrap_tf_session import _TF_SetConfig
|
||||
from tensorflow.python._pywrap_tf_session import _TF_NewSessionOptions
|
||||
|
||||
# Convert versions to strings for Python2 and keep api_compatibility_test green.
|
||||
# We can remove this hack once we remove Python2 presubmits. pybind11 can only
|
||||
# return unicode for Python2 even with py::str.
|
||||
# https://pybind11.readthedocs.io/en/stable/advanced/cast/strings.html#returning-c-strings-to-python
|
||||
# pylint: disable=undefined-variable
|
||||
__version__ = str(get_version())
|
||||
__git_version__ = str(get_git_version())
|
||||
__compiler_version__ = str(get_compiler_version())
|
||||
__cxx11_abi_flag__ = get_cxx11_abi_flag()
|
||||
__monolithic_build__ = get_monolithic_build()
|
||||
|
||||
# User getters to hold attributes rather than pybind11's m.attr due to
|
||||
# b/145559202.
|
||||
GRAPH_DEF_VERSION = get_graph_def_version()
|
||||
GRAPH_DEF_VERSION_MIN_CONSUMER = get_graph_def_version_min_consumer()
|
||||
GRAPH_DEF_VERSION_MIN_PRODUCER = get_graph_def_version_min_producer()
|
||||
TENSOR_HANDLE_KEY = get_tensor_handle_key()
|
||||
|
||||
# pylint: enable=undefined-variable
|
||||
|
||||
|
||||
# Disable pylint invalid name warnings for legacy functions.
|
||||
# pylint: disable=invalid-name
|
||||
def TF_NewSessionOptions(target=None, config=None):
|
||||
# NOTE: target and config are validated in the session constructor.
|
||||
opts = _TF_NewSessionOptions()
|
||||
if target is not None:
|
||||
_TF_SetTarget(opts, target)
|
||||
if config is not None:
|
||||
config_str = config.SerializeToString()
|
||||
_TF_SetConfig(opts, config_str)
|
||||
return opts
|
||||
|
||||
|
||||
# Disable pylind undefined-variable as the variable is exported in the shared
|
||||
# object via pybind11.
|
||||
# pylint: disable=undefined-variable
|
||||
def TF_Reset(target, containers=None, config=None):
|
||||
opts = TF_NewSessionOptions(target=target, config=config)
|
||||
try:
|
||||
TF_Reset_wrapper(opts, containers)
|
||||
finally:
|
||||
TF_DeleteSessionOptions(opts)
|
@ -28,7 +28,7 @@ import wrapt
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as tf_session
|
||||
from tensorflow.python.client import pywrap_tf_session as tf_session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import monitoring
|
||||
from tensorflow.python.framework import device
|
||||
|
@ -21,7 +21,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.core.protobuf import cluster_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as tf_session
|
||||
from tensorflow.python.client import pywrap_tf_session as tf_session
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
|
@ -1,877 +0,0 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include "tensorflow/python/lib/core/strings.i"
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
|
||||
#include "tensorflow/c/python_api.h"
|
||||
#include "tensorflow/core/framework/session_state.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/python/client/tf_session_helper.h"
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
// We were getting lucky on imports with safe_ptr.h being placed prior to
|
||||
// tf_session which imported safe_ptr. We also need pywrap_tfe.h to cast
|
||||
// one of the inputs to a graph function from a Python string to const char*.
|
||||
|
||||
|
||||
// Helper function to convert a Python list of Tensors to a C++ vector of
|
||||
// TF_Outputs.
|
||||
//
|
||||
// Returns true if successful. Otherwise, returns false and sets error_msg.
|
||||
bool PyTensorListToVector(PyObject* py_tensor_list,
|
||||
std::vector<TF_Output>* vec,
|
||||
string* error_msg) {
|
||||
if (!PyList_Check(py_tensor_list)) {
|
||||
*error_msg = "expected Python list.";
|
||||
return false;
|
||||
}
|
||||
size_t size = PyList_Size(py_tensor_list);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
PyObject* item = PyList_GetItem(py_tensor_list, i);
|
||||
TF_Output* input_ptr;
|
||||
if (!SWIG_IsOK(SWIG_ConvertPtr(item, reinterpret_cast<void**>(&input_ptr),
|
||||
SWIGTYPE_p_TF_Output, 0))) {
|
||||
*error_msg = "expected Python list of wrapped TF_Output objects. "
|
||||
"Found python list of something else.";
|
||||
return false;
|
||||
}
|
||||
vec->push_back(*input_ptr);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Helper function to convert a TF_Output to a wrapped TF_Output Python object.
|
||||
PyObject* CreateWrappedTFOutput(TF_Output tf_output) {
|
||||
// We used heap-allocated pointers in the Python runtime (this is what SWIG
|
||||
// generates by default for functions returning TF_Output).
|
||||
TF_Output* tf_output_ptr = new TF_Output(tf_output);
|
||||
// Use SWIG_POINTER_OWN so the TF_Output* is deleted by Python.
|
||||
return SWIG_NewPointerObj(tf_output_ptr, SWIGTYPE_p_TF_Output,
|
||||
SWIG_POINTER_OWN);
|
||||
}
|
||||
|
||||
// Helper function to convert a TF_Operation to a wrapped TF_Operation Python
|
||||
// object.
|
||||
PyObject* CreateWrappedTFOperation(TF_Operation* tf_operation) {
|
||||
// No flags since operation is owned by TF_Graph.
|
||||
return SWIG_NewPointerObj(tf_operation, SWIGTYPE_p_TF_Operation, 0);
|
||||
}
|
||||
|
||||
// Helper function to convert a Python list of ints to a C++ vector of int64s
|
||||
void PyInt64ListToVector(PyObject* py_int_seq, std::vector<int64_t>* vec) {
|
||||
int size = PySequence_Fast_GET_SIZE(py_int_seq);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
PyObject* item = PySequence_Fast_GET_ITEM(py_int_seq, i);
|
||||
vec->push_back(PyLong_AsLongLong(item));
|
||||
}
|
||||
}
|
||||
|
||||
%}
|
||||
|
||||
%include "tensorflow/c/tf_datatype.h"
|
||||
%include "tensorflow/c/tf_status.h"
|
||||
|
||||
%include "tensorflow/python/client/tf_sessionrun_wrapper.i"
|
||||
|
||||
// Required to use PyArray_* functions.
|
||||
%init %{
|
||||
tensorflow::ImportNumpy();
|
||||
%}
|
||||
|
||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
||||
// Hence the 'const_cast'.
|
||||
%typemap(in) const char* op_name {
|
||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
||||
}
|
||||
|
||||
|
||||
// TensorFlow version and GraphDef versions
|
||||
%constant const char* __version__ = TF_VERSION_STRING;
|
||||
%constant int GRAPH_DEF_VERSION = TF_GRAPH_DEF_VERSION;
|
||||
%constant int GRAPH_DEF_VERSION_MIN_CONSUMER = TF_GRAPH_DEF_VERSION_MIN_CONSUMER;
|
||||
%constant int GRAPH_DEF_VERSION_MIN_PRODUCER = TF_GRAPH_DEF_VERSION_MIN_PRODUCER;
|
||||
|
||||
// Git version information
|
||||
%constant const char* __git_version__ = tf_git_version();
|
||||
|
||||
// Compiler
|
||||
%constant const char* __compiler_version__ = tf_compiler_version();
|
||||
|
||||
// _GLIBCXX_USE_CXX11_ABI flag value
|
||||
%constant const int __cxx11_abi_flag__ = tf_cxx11_abi_flag();
|
||||
|
||||
// Flag indicating whether the build is monolithic
|
||||
%constant const int __monolithic_build__ = tf_monolithic_build();
|
||||
|
||||
// Release the Python GIL for the duration of most methods.
|
||||
%exception {
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
$action
|
||||
Py_END_ALLOW_THREADS;
|
||||
}
|
||||
|
||||
// The target input to TF_SetTarget() is passed as a null-terminated
|
||||
// const char*.
|
||||
%typemap(in) (const char* target) {
|
||||
$1 = PyBytes_AsString($input);
|
||||
if (!$1) {
|
||||
// Python has raised an error.
|
||||
SWIG_fail;
|
||||
}
|
||||
}
|
||||
|
||||
// Constants used by TensorHandle (get_session_handle).
|
||||
%constant const char* TENSOR_HANDLE_KEY = tensorflow::SessionState::kTensorHandleResourceTypeName;
|
||||
|
||||
// Convert TF_OperationName output to unicode python string
|
||||
%typemap(out) const char* TF_OperationName {
|
||||
$result = PyUnicode_FromString($1);
|
||||
}
|
||||
|
||||
// Convert TF_OperationOpType output to unicode python string
|
||||
%typemap(out) const char* TF_OperationOpType {
|
||||
$result = PyUnicode_FromString($1);
|
||||
}
|
||||
|
||||
// Convert TF_DeviceListMemoryBytes and TF_Dim int64_t output to Python integers
|
||||
%typemap(out) int64_t {
|
||||
$result = PyLong_FromLongLong($1);
|
||||
}
|
||||
|
||||
// Convert TF_DeviceListIncarnation uint64_t output to Python integer
|
||||
%typemap(out) uint64_t {
|
||||
$result = PyLong_FromUnsignedLongLong($1);
|
||||
}
|
||||
|
||||
// Convert TF_OperationGetAttrType TF_DataType* out-argument to Python integer.
|
||||
%typemap(in, numinputs=0) TF_DataType *value (TF_DataType temp) {
|
||||
$1 = &temp;
|
||||
}
|
||||
%typemap(argout) TF_DataType *value {
|
||||
$result = PyInt_FromLong(*$1);
|
||||
}
|
||||
|
||||
// Convert TF_OperationGetAttrBool unsigned char* out-argument to Python bool.
|
||||
%typemap(in, numinputs=0) unsigned char *value (unsigned char temp) {
|
||||
$1 = &temp;
|
||||
}
|
||||
%typemap(argout) unsigned char *value {
|
||||
$result = PyBool_FromLong(*$1);
|
||||
}
|
||||
|
||||
// Convert TF_OperationGetAttrInt int64_t* out-argument to Python bool.
|
||||
%typemap(in, numinputs=0) int64_t *value (int64_t temp) {
|
||||
$1 = &temp;
|
||||
}
|
||||
%typemap(argout) int64_t *value {
|
||||
$result = PyLong_FromLongLong(*$1);
|
||||
}
|
||||
|
||||
// We use TF_OperationGetControlInputs_wrapper instead of
|
||||
// TF_OperationGetControlInputs
|
||||
%ignore TF_OperationGetControlInputs;
|
||||
%unignore TF_OperationGetControlInputs_wrapper;
|
||||
// See comment for "%noexception TF_SessionRun_wrapper;"
|
||||
%noexception TF_OperationGetControlInputs_wrapper;
|
||||
|
||||
|
||||
// Migrate one function from pywrap_tfe.i
|
||||
%include "tensorflow/c/c_api_experimental.h"
|
||||
%unignore TF_ImportGraphDefOptionsSetValidateColocationConstraints;
|
||||
%noexception TF_ImportGraphDefOptionsSetValidateColocationConstraints;
|
||||
|
||||
// Build a Python list of TF_Operation* and return it.
|
||||
%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlInputs_wrapper {
|
||||
$result = PyList_New($1.size());
|
||||
if (!$result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < $1.size(); ++i) {
|
||||
PyList_SET_ITEM($result, i, CreateWrappedTFOperation($1[i]));
|
||||
}
|
||||
}
|
||||
|
||||
// We use TF_OperationGetControlOutputs_wrapper instead of
|
||||
// TF_OperationGetControlOutputs
|
||||
%ignore TF_OperationGetControlOutputs;
|
||||
%unignore TF_OperationGetControlOutputs_wrapper;
|
||||
// See comment for "%noexception TF_SessionRun_wrapper;"
|
||||
%noexception TF_OperationGetControlOutputs_wrapper;
|
||||
|
||||
// Build a Python list of TF_Operation* and return it.
|
||||
%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlOutputs_wrapper {
|
||||
$result = PyList_New($1.size());
|
||||
if (!$result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < $1.size(); ++i) {
|
||||
PyList_SET_ITEM($result, i, CreateWrappedTFOperation($1[i]));
|
||||
}
|
||||
}
|
||||
|
||||
%ignore TF_OperationOutputConsumers;
|
||||
%unignore TF_OperationOutputConsumers_wrapper;
|
||||
// See comment for "%noexception TF_SessionRun_wrapper;"
|
||||
%noexception TF_OperationGetOutputConsumers_wrapper;
|
||||
|
||||
// Build a Python list of unicode strings and return it. (Operation names are
|
||||
// always represented as unicode.)
|
||||
%typemap(out) std::vector<const char*>
|
||||
tensorflow::TF_OperationOutputConsumers_wrapper {
|
||||
$result = PyList_New($1.size());
|
||||
if (!$result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < $1.size(); ++i) {
|
||||
PyList_SET_ITEM($result, i, PyUnicode_FromString($1[i]));
|
||||
}
|
||||
}
|
||||
|
||||
%unignore GetOperationInputs;
|
||||
// See comment for "%noexception TF_SessionRun_wrapper;"
|
||||
%noexception GetOperationInputs;
|
||||
|
||||
// Build a Python list of TF_Outputs and return it.
|
||||
// TODO(skyewm): is there some way to generalize this pattern? Maybe a macro?
|
||||
%typemap(out) std::vector<TF_Output> tensorflow::GetOperationInputs {
|
||||
$result = PyList_New($1.size());
|
||||
if (!$result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
|
||||
}
|
||||
|
||||
// Unwrap the generated SwigValueWrapper<std::vector<TF_Output>>
|
||||
const std::vector<TF_Output>& tf_outputs = $1;
|
||||
for (size_t i = 0; i < tf_outputs.size(); ++i) {
|
||||
PyList_SET_ITEM($result, i, CreateWrappedTFOutput(tf_outputs[i]));
|
||||
}
|
||||
}
|
||||
|
||||
%ignore TF_ImportGraphDefResultsMissingUnusedInputMappings;
|
||||
%unignore TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper;
|
||||
// See comment for "%noexception TF_SessionRun_wrapper;"
|
||||
%noexception TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper;
|
||||
|
||||
%typemap(out) std::vector<string>
|
||||
TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
|
||||
$result = PyList_New($1.size());
|
||||
if (!$result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
|
||||
}
|
||||
for (size_t i = 0; i < $1.size(); ++i) {
|
||||
const string& input_str = $1[i];
|
||||
PyList_SET_ITEM($result, i, PyBytes_FromStringAndSize(input_str.data(),
|
||||
input_str.size()));
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper()
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Converts a python list of strings to NameVector.
|
||||
// Has multiple users including feeds/fetches names and function output names
|
||||
%typemap(in) const tensorflow::NameVector& (
|
||||
tensorflow::NameVector temp,
|
||||
tensorflow::Safe_PyObjectPtr temp_string_list(
|
||||
tensorflow::make_safe(static_cast<PyObject*>(nullptr)))) {
|
||||
if (!PyList_Check($input)) {
|
||||
SWIG_exception_fail(
|
||||
SWIG_TypeError,
|
||||
tensorflow::strings::Printf(
|
||||
"Expected a python list for conversion "
|
||||
"to tensorflow::NameVector but got %s",
|
||||
Py_TYPE($input)->tp_name).c_str());
|
||||
}
|
||||
|
||||
Py_ssize_t len = PyList_Size($input);
|
||||
|
||||
temp_string_list = tensorflow::make_safe(PyList_New(len));
|
||||
if (!temp_string_list) {
|
||||
SWIG_exception_fail(
|
||||
SWIG_MemoryError,
|
||||
tensorflow::strings::Printf("Failed to create a list of size %zd",
|
||||
len).c_str());
|
||||
}
|
||||
|
||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
||||
PyObject* elem = PyList_GetItem($input, i);
|
||||
if (!elem) {
|
||||
SWIG_fail;
|
||||
}
|
||||
|
||||
// Keep a reference to the string in case the incoming list is modified.
|
||||
PyList_SET_ITEM(temp_string_list.get(), i, elem);
|
||||
Py_INCREF(elem);
|
||||
|
||||
char* string_elem = PyBytes_AsString(elem);
|
||||
if (!string_elem) {
|
||||
SWIG_exception_fail(
|
||||
SWIG_TypeError,
|
||||
tensorflow::strings::Printf(
|
||||
"Element %zd was of type %s instead of a string",
|
||||
i, Py_TYPE(elem)->tp_name).c_str());
|
||||
}
|
||||
|
||||
// TODO(mrry): Avoid copying the fetch name in, if this impacts performance.
|
||||
temp.push_back(string_elem);
|
||||
}
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
// Define temporaries for the argout outputs.
|
||||
%typemap(in, numinputs=0) tensorflow::PyObjectVector* out_values (
|
||||
tensorflow::PyObjectVector temp) {
|
||||
$1 = &temp;
|
||||
}
|
||||
// TODO(iga): move this and the corresponding typemap(argout) to
|
||||
// tf_sessionrun_wrapper.i once we get rid of this code for DeprecatedSession.
|
||||
%typemap(in, numinputs=0) char** out_handle (
|
||||
char* temp) {
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
// Build a Python list of outputs and return it.
|
||||
%typemap(argout) tensorflow::PyObjectVector* out_values {
|
||||
std::vector<tensorflow::Safe_PyObjectPtr> out_values_safe;
|
||||
for (size_t i = 0; i < $1->size(); ++i) {
|
||||
out_values_safe.emplace_back(tensorflow::make_safe($1->at(i)));
|
||||
}
|
||||
|
||||
$result = PyList_New($1->size());
|
||||
if (!$result) {
|
||||
SWIG_exception_fail(
|
||||
SWIG_MemoryError,
|
||||
tensorflow::strings::Printf("Failed to create a list of size %zd",
|
||||
$1->size()).c_str());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < $1->size(); ++i) {
|
||||
PyList_SET_ITEM($result, i, $1->at(i));
|
||||
out_values_safe[i].release();
|
||||
}
|
||||
}
|
||||
|
||||
// Return the handle as a python string object.
|
||||
%typemap(argout) char** out_handle {
|
||||
%#if PY_MAJOR_VERSION < 3
|
||||
$result = PyString_FromStringAndSize(
|
||||
%#else
|
||||
$result = PyUnicode_FromStringAndSize(
|
||||
%#endif
|
||||
*$1, *$1 == nullptr ? 0 : strlen(*$1));
|
||||
delete[] *$1;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// END TYPEMAPS FOR tensorflow::TF_Run_wrapper()
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Typemap for TF_Status* inputs that automatically unwraps a ScopedTFStatus.
|
||||
// This can also handle a wrapped TF_Status* input.
|
||||
%typemap(in) (TF_Status*) {
|
||||
PyObject* wrapped_tf_status;
|
||||
if (strcmp(Py_TYPE($input)->tp_name, "ScopedTFStatus") == 0) {
|
||||
DCHECK(PyObject_HasAttrString($input, "status"))
|
||||
<< "ScopedTFStatus.status not found! Do you need to modify "
|
||||
"tf_session.i?";
|
||||
wrapped_tf_status = PyObject_GetAttrString($input, "status");
|
||||
} else {
|
||||
// Assume wrapped TF_Status*
|
||||
wrapped_tf_status = $input;
|
||||
}
|
||||
DCHECK_EQ(strcmp(Py_TYPE(wrapped_tf_status)->tp_name, "SwigPyObject"), 0)
|
||||
<< Py_TYPE(wrapped_tf_status)->tp_name;
|
||||
|
||||
// The following is the default SWIG code generated for TF_Status*
|
||||
void* tf_status = nullptr;
|
||||
int r = SWIG_ConvertPtr(wrapped_tf_status, &tf_status,
|
||||
$descriptor(TF_Status*), 0 | 0);
|
||||
if (!SWIG_IsOK(r)) {
|
||||
SWIG_exception_fail(
|
||||
SWIG_ArgError(r),
|
||||
"in method '_TF_DeleteStatus', argument 1 of type 'TF_Status *'");
|
||||
}
|
||||
$1 = reinterpret_cast<TF_Status*>(tf_status);
|
||||
}
|
||||
|
||||
// Typemap for functions that return a TF_Buffer struct. This typemap creates a
|
||||
// Python string from the TF_Buffer and returns it. The TF_Buffer.data string
|
||||
// is not expected to be NULL-terminated, and TF_Buffer.length does not count
|
||||
// the terminator.
|
||||
%typemap(out) TF_Buffer (TF_GetOpList,TF_GetBuffer) {
|
||||
$result = PyBytes_FromStringAndSize(
|
||||
reinterpret_cast<const char*>($1.data), $1.length);
|
||||
}
|
||||
|
||||
// Converts input Python list of wrapped TF_Outputs into a single array
|
||||
%typemap(in) (const TF_Output* inputs, int num_inputs)
|
||||
(std::vector<TF_Output> inputs) {
|
||||
string error_msg;
|
||||
if (!PyTensorListToVector($input, &inputs, &error_msg)) {
|
||||
SWIG_exception_fail(SWIG_TypeError, ("$symname: " + error_msg).c_str());
|
||||
}
|
||||
$1 = inputs.data();
|
||||
$2 = inputs.size();
|
||||
}
|
||||
|
||||
// Typemaps for TF_ImportGraphDefResultsReturnOutputs
|
||||
%typemap(in, numinputs=0) (int* num_outputs, TF_Output** outputs)
|
||||
(int num_outputs, TF_Output* outputs) {
|
||||
$1 = &num_outputs;
|
||||
$2 = &outputs;
|
||||
}
|
||||
|
||||
%typemap(argout) (int* num_outputs, TF_Output** outputs) {
|
||||
$result = PyList_New(*$1);
|
||||
if (!$result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
|
||||
}
|
||||
int num_outputs = *$1;
|
||||
TF_Output* outputs = *$2;
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
PyList_SET_ITEM($result, i, CreateWrappedTFOutput(outputs[i]));
|
||||
}
|
||||
}
|
||||
|
||||
// Typemaps for TF_ImportGraphDefResultsReturnOperations
|
||||
%typemap(in, numinputs=0) (int* num_opers, TF_Operation*** opers)
|
||||
(int num_opers, TF_Operation** opers) {
|
||||
$1 = &num_opers;
|
||||
$2 = &opers;
|
||||
}
|
||||
|
||||
%typemap(argout) (int* num_opers, TF_Operation*** opers) {
|
||||
$result = PyList_New(*$1);
|
||||
if (!$result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
|
||||
}
|
||||
int num_opers = *$1;
|
||||
TF_Operation** opers = *$2;
|
||||
for (int i = 0; i < num_opers; ++i) {
|
||||
PyList_SET_ITEM($result, i, CreateWrappedTFOperation(opers[i]));
|
||||
}
|
||||
}
|
||||
|
||||
// Typemaps for TF_GraphNextOperation().
|
||||
%typemap(in) size_t* pos (size_t pos) {
|
||||
pos = PyLong_AsUnsignedLong($input);
|
||||
$1 = &pos;
|
||||
}
|
||||
|
||||
// Returns a (TF_Operation*, int pos) tuple.
|
||||
%typemap(argout) size_t* pos {
|
||||
PyObject* new_result = PyTuple_New(2);
|
||||
if (!new_result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create tuple");
|
||||
}
|
||||
// Steals $result reference
|
||||
PyTuple_SET_ITEM(new_result, 0, $result);
|
||||
PyTuple_SET_ITEM(new_result, 1, PyLong_FromSize_t(*$1));
|
||||
$result = new_result;
|
||||
}
|
||||
|
||||
%typemap(in, numinputs=0) int64_t* out_handle (int64_t out_handle) {
|
||||
$1 = &out_handle;
|
||||
}
|
||||
|
||||
%typemap(argout) int64_t* out_handle {
|
||||
$result = PyLong_FromLongLong(*$1);
|
||||
}
|
||||
|
||||
%typemap(in) int64_t handle {
|
||||
if (!PyLong_Check($input)) {
|
||||
SWIG_exception_fail(
|
||||
SWIG_TypeError,
|
||||
tensorflow::strings::Printf(
|
||||
"Expected a python long for conversion to callable handle but got %s",
|
||||
Py_TYPE($input)->tp_name).c_str());
|
||||
}
|
||||
$1 = PyLong_AsLongLong($input);
|
||||
}
|
||||
|
||||
// Override default py3 behavior of attempting to encode into Unicode.
|
||||
%typemap(out) std::string tensorflow::GetHandleShapeAndType {
|
||||
$result = PyBytes_FromStringAndSize($1.data(), $1.size());
|
||||
}
|
||||
|
||||
// TODO(skyewm): SWIG emits a warning for the const char* in TF_WhileParams,
|
||||
// skip for now
|
||||
%ignore TF_WhileParams;
|
||||
%ignore TF_NewWhile;
|
||||
%ignore TF_FinishWhile;
|
||||
%ignore TF_AbortWhile;
|
||||
|
||||
// These are defined below, avoid duplicate definitions
|
||||
%ignore TF_Run;
|
||||
%ignore TF_PRun;
|
||||
%ignore TF_PRunSetup;
|
||||
|
||||
// We use TF_SessionRun_wrapper instead of TF_SessionRun
|
||||
%ignore TF_SessionRun;
|
||||
%unignore TF_SessionRun_wrapper;
|
||||
// The %exception block above releases the Python GIL for the length of each
|
||||
// wrapped method. We disable this behavior for TF_SessionRun_wrapper because it
|
||||
// uses Python method(s) that expect the GIL to be held (at least
|
||||
// PyArray_Return, maybe others).
|
||||
%noexception TF_SessionRun_wrapper;
|
||||
|
||||
// We use TF_SessionPRunSetup_wrapper instead of TF_SessionPRunSetup
|
||||
%ignore TF_SessionPRunSetup;
|
||||
%unignore TF_SessionPRunSetup_wrapper;
|
||||
// See comment for "%noexception TF_SessionRun_wrapper;"
|
||||
%noexception TF_SessionPRunSetup_wrapper;
|
||||
|
||||
// We use TF_SessionPRun_wrapper instead of TF_SessionPRun
|
||||
%ignore TF_SessionPRun;
|
||||
%unignore TF_SessionPRun_wrapper;
|
||||
// See comment for "%noexception TF_SessionRun_wrapper;"
|
||||
%noexception TF_SessionPRun_wrapper;
|
||||
|
||||
%unignore TF_DeprecatedSessionMakeCallable;
|
||||
%unignore TF_SessionMakeCallable;
|
||||
%unignore TF_DeprecatedSessionRunCallable;
|
||||
%unignore TF_SessionRunCallable;
|
||||
%unignore TF_DeprecatedSessionReleaseCallable;
|
||||
%unignore TF_SessionReleaseCallable;
|
||||
|
||||
// See comment for "%noexception TF_SessionRun_wrapper;"
|
||||
%noexception TF_DeprecatedSessionRunCallable;
|
||||
%noexception TF_SessionRunCallable;
|
||||
|
||||
%rename("_TF_SetTarget") TF_SetTarget;
|
||||
%rename("_TF_SetConfig") TF_SetConfig;
|
||||
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
|
||||
|
||||
%include "tensorflow/c/c_api.h"
|
||||
%include "tensorflow/c/tf_attrtype.h"
|
||||
%include "tensorflow/c/python_api.h"
|
||||
|
||||
|
||||
%ignoreall
|
||||
%insert("python") %{
|
||||
def TF_NewSessionOptions(target=None, config=None):
|
||||
# NOTE: target and config are validated in the session constructor.
|
||||
opts = _TF_NewSessionOptions()
|
||||
if target is not None:
|
||||
_TF_SetTarget(opts, target)
|
||||
if config is not None:
|
||||
from tensorflow.python.framework import errors
|
||||
config_str = config.SerializeToString()
|
||||
_TF_SetConfig(opts, config_str)
|
||||
return opts
|
||||
%}
|
||||
|
||||
// Include the wrapper for TF_Run from tf_session_helper.h.
|
||||
|
||||
// The %exception block above releases the Python GIL for the length
|
||||
// of each wrapped method. We disable this behavior for TF_Run
|
||||
// because it uses the Python allocator.
|
||||
%noexception tensorflow::TF_Run_wrapper;
|
||||
%rename(TF_Run) tensorflow::TF_Run_wrapper;
|
||||
%unignore tensorflow;
|
||||
%unignore TF_Run;
|
||||
%unignore EqualGraphDefWrapper;
|
||||
%unignore EqualAttrValueWrapper;
|
||||
|
||||
// Include the wrapper for TF_PRunSetup from tf_session_helper.h.
|
||||
|
||||
// The %exception block above releases the Python GIL for the length
|
||||
// of each wrapped method. We disable this behavior for TF_PRunSetup
|
||||
// because it uses the Python allocator.
|
||||
%noexception tensorflow::TF_PRunSetup_wrapper;
|
||||
%rename(TF_PRunSetup) tensorflow::TF_PRunSetup_wrapper;
|
||||
%unignore tensorflow;
|
||||
%unignore TF_PRunSetup;
|
||||
|
||||
// Include the wrapper for TF_PRun from tf_session_helper.h.
|
||||
|
||||
// The %exception block above releases the Python GIL for the length
|
||||
// of each wrapped method. We disable this behavior for TF_PRun
|
||||
// because it uses the Python allocator.
|
||||
%noexception tensorflow::TF_PRun_wrapper;
|
||||
%rename(TF_PRun) tensorflow::TF_PRun_wrapper;
|
||||
%unignore tensorflow;
|
||||
%unignore TF_PRun;
|
||||
|
||||
%unignore tensorflow::TF_Reset_wrapper;
|
||||
%insert("python") %{
|
||||
def TF_Reset(target, containers=None, config=None):
|
||||
from tensorflow.python.framework import errors
|
||||
opts = TF_NewSessionOptions(target=target, config=config)
|
||||
try:
|
||||
TF_Reset_wrapper(opts, containers)
|
||||
finally:
|
||||
TF_DeleteSessionOptions(opts)
|
||||
%}
|
||||
|
||||
// We use TF_GraphToFunction_wrapper instead of TF_GraphToFunction
|
||||
%ignore TF_GraphToFunction;
|
||||
// TF_GraphToFunction_wrapper does not use any Python methods and
|
||||
// does not require GIL to be held.
|
||||
%unignore TF_GraphToFunction_wrapper;
|
||||
|
||||
// $input is a Python list of wrapped TF_Operations
|
||||
%typemap(in) (const std::vector<TF_Operation*>* opers)
|
||||
(std::vector<TF_Operation*> opers) {
|
||||
if ($input != Py_None) {
|
||||
if (!PyList_Check($input)) {
|
||||
SWIG_exception_fail(SWIG_TypeError, "$symname: expected list");
|
||||
}
|
||||
size_t size = PyList_Size($input);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
PyObject* item = PyList_GetItem($input, i);
|
||||
TF_Operation* oper_ptr;
|
||||
SWIG_ConvertPtr(item, reinterpret_cast<void**>(&oper_ptr),
|
||||
$descriptor(TF_Operation*), 0);
|
||||
opers.push_back(oper_ptr);
|
||||
}
|
||||
$1 = &opers;
|
||||
} else {
|
||||
$1 = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// $input is a Python list of wrapped TF_Operations
|
||||
%typemap(in) (const std::vector<TF_Operation*>* control_outputs)
|
||||
(std::vector<TF_Operation*> control_outputs) {
|
||||
if ($input != Py_None) {
|
||||
if (!PyList_Check($input)) {
|
||||
SWIG_exception_fail(SWIG_TypeError, "$symname: expected list");
|
||||
}
|
||||
size_t size = PyList_Size($input);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
PyObject* item = PyList_GetItem($input, i);
|
||||
TF_Operation* oper_ptr;
|
||||
SWIG_ConvertPtr(item, reinterpret_cast<void**>(&oper_ptr),
|
||||
$descriptor(TF_Operation*), 0);
|
||||
control_outputs.push_back(oper_ptr);
|
||||
}
|
||||
$1 = &control_outputs;
|
||||
} else {
|
||||
$1 = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Typemaps for TF_GraphGetTensorShapeHelper.
|
||||
|
||||
// Convert from C++ integer vector to Python list of ints.
|
||||
%typemap(out) tensorflow::gtl::InlinedVector<int64_t, 6>
|
||||
tensorflow::TF_GraphGetTensorShapeHelper {
|
||||
$result = PyList_New($1.size());
|
||||
if (!$result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < $1.size(); ++i) {
|
||||
PyList_SET_ITEM($result, i, PyLong_FromLongLong($1[i]));
|
||||
}
|
||||
}
|
||||
|
||||
%typemap(in, numinputs=0) bool* unknown_shape (bool temp) {
|
||||
$1=&temp;
|
||||
}
|
||||
|
||||
// Returns a (list(int), bool) tuple.
|
||||
%typemap(argout) bool* unknown_shape {
|
||||
PyObject* new_result = PyTuple_New(2);
|
||||
if (!new_result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create tuple");
|
||||
}
|
||||
// Steals $result reference
|
||||
PyTuple_SET_ITEM(new_result, 0, $result);
|
||||
PyTuple_SET_ITEM(new_result, 1, PyBool_FromLong(*$1));
|
||||
$result = new_result;
|
||||
}
|
||||
|
||||
%unignore tensorflow;
|
||||
%unignore TF_GraphGetTensorShapeHelper;
|
||||
%ignore TF_GraphGetTensorShape;
|
||||
|
||||
// We use TF_GraphSetTensorShape_wrapper instead of
|
||||
// TF_GraphSetTensorShape
|
||||
%ignore TF_GraphSetTensorShape;
|
||||
%unignore tensorflow;
|
||||
%unignore TF_GraphSetTensorShape_wrapper;
|
||||
|
||||
// $input is a Python list of ints to a vector<int> for TF_GraphSetTensorShape_wrapper
|
||||
%typemap(in) (const std::vector<int64_t>& dims)
|
||||
(std::vector<int64_t> dims_local){
|
||||
if ($input != Py_None) {
|
||||
PyObject* py_int_seq = PySequence_Fast($input, tensorflow::strings::Printf(
|
||||
"$symname: expected list but got %s ",
|
||||
Py_TYPE($input)->tp_name).c_str());
|
||||
if (py_int_seq == nullptr) {
|
||||
SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
|
||||
"$symname: PySequence_Fast returned NULL.").c_str());
|
||||
}
|
||||
PyInt64ListToVector(py_int_seq, &dims_local);
|
||||
Py_DECREF(py_int_seq);
|
||||
$1 = &dims_local;
|
||||
} else {
|
||||
$1 = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// We use TF_GraphGetTensorShape_wrapper instead of
|
||||
// TF_GraphGetTensorShape
|
||||
%ignore TF_GraphGetTensorShape;
|
||||
%unignore tensorflow;
|
||||
%unignore TF_GraphGetTensorShape_wrapper;
|
||||
|
||||
// Build a Python list of ints and return it.
|
||||
%typemap(out) std::vector<int64_t> tensorflow::TF_GraphGetTensorShape_wrapper {
|
||||
$result = PyList_New($1.size());
|
||||
if (!$result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < $1.size(); ++i) {
|
||||
PyList_SET_ITEM($result, i, PyLong_FromLongLong($1[i]));
|
||||
}
|
||||
}
|
||||
|
||||
// We use TF_GraphSetOutputHandleShapesAndTypes_wrapper instead of
|
||||
// TF_GraphSetOutputHandleShapesAndTypes
|
||||
%ignore TF_GraphSetOutputHandleShapesAndTypes;
|
||||
%unignore tensorflow;
|
||||
%unignore TF_GraphSetOutputHandleShapesAndTypes_wrapper;
|
||||
|
||||
// The space between the double angle brackets below looks extraneous, but
|
||||
// our version of SWIG cannot parse ">>".
|
||||
%typemap(in) (const std::vector<std::vector<int64_t> >& shapes)
|
||||
(std::vector<std::vector<int64_t> > shapes_local){
|
||||
PyObject* seq = PySequence_Fast($input, tensorflow::strings::Printf(
|
||||
"$symname: expected list but got %s ",
|
||||
Py_TYPE($input)->tp_name).c_str());
|
||||
if (seq == nullptr) {
|
||||
SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
|
||||
"$symname: PySequence_Fast returned NULL.").c_str());
|
||||
}
|
||||
|
||||
int size = PySequence_Fast_GET_SIZE(seq);
|
||||
if (size == 0) {
|
||||
SWIG_exception_fail(SWIG_ValueError, tensorflow::strings::Printf(
|
||||
"$symname: shapes list must be non-empty").c_str());
|
||||
}
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
|
||||
std::vector<int64_t> dims;
|
||||
if (item != Py_None) {
|
||||
PyObject* py_int_seq = PySequence_Fast(item, tensorflow::strings::Printf(
|
||||
"$symname: expected list but got %s ",
|
||||
Py_TYPE($input)->tp_name).c_str());
|
||||
if (py_int_seq == nullptr) {
|
||||
SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
|
||||
"$symname: PySequence_Fast returned NULL.").c_str());
|
||||
}
|
||||
PyInt64ListToVector(py_int_seq, &dims);
|
||||
Py_DECREF(py_int_seq);
|
||||
}
|
||||
shapes_local.push_back(dims);
|
||||
}
|
||||
|
||||
Py_DECREF(seq);
|
||||
$1 = &shapes_local;
|
||||
}
|
||||
|
||||
%typemap(in) (const std::vector<int>& ranks)
|
||||
(std::vector<int> ranks_local){
|
||||
PyObject* seq = PySequence_Fast($input, tensorflow::strings::Printf(
|
||||
"$symname: expected list but got %s ",
|
||||
Py_TYPE($input)->tp_name).c_str());
|
||||
if (seq == nullptr) {
|
||||
SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
|
||||
"$symname: PySequence_Fast returned NULL.").c_str());
|
||||
}
|
||||
|
||||
int size = PySequence_Fast_GET_SIZE(seq);
|
||||
if (size == 0) {
|
||||
SWIG_exception_fail(SWIG_ValueError, tensorflow::strings::Printf(
|
||||
"$symname: shapes list must be non-empty").c_str());
|
||||
}
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
|
||||
ranks_local.push_back((int) PyInt_AsLong(item));
|
||||
}
|
||||
|
||||
Py_DECREF(seq);
|
||||
$1 = &ranks_local;
|
||||
}
|
||||
|
||||
%typemap(in) (const std::vector<TF_DataType>& types)
|
||||
(std::vector<TF_DataType> types_local){
|
||||
PyObject* seq = PySequence_Fast($input, tensorflow::strings::Printf(
|
||||
"$symname: expected list but got %s ",
|
||||
Py_TYPE($input)->tp_name).c_str());
|
||||
if (seq == nullptr) {
|
||||
SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
|
||||
"$symname: PySequence_Fast returned NULL.").c_str());
|
||||
}
|
||||
|
||||
int size = PySequence_Fast_GET_SIZE(seq);
|
||||
if (size == 0) {
|
||||
SWIG_exception_fail(SWIG_ValueError, tensorflow::strings::Printf(
|
||||
"$symname: shapes list must be non-empty").c_str());
|
||||
}
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
|
||||
types_local.push_back((TF_DataType) PyInt_AsLong(item));
|
||||
}
|
||||
|
||||
Py_DECREF(seq);
|
||||
$1 = &types_local;
|
||||
}
|
||||
|
||||
%unignore TF_CreatePlaceholders;
|
||||
// See comment for "%noexception TF_SessionRun_wrapper;"
|
||||
%noexception TF_CreatePlaceholders;
|
||||
|
||||
// Build a Python list of TF_Output and return it.
|
||||
%typemap(out) std::vector<TF_Output> tensorflow::TF_CreatePlaceholders {
|
||||
$result = PyList_New($1.size());
|
||||
if (!$result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
|
||||
}
|
||||
|
||||
// Unwrap the generated SwigValueWrapper<std::vector<TF_Output>>
|
||||
const std::vector<TF_Output>& tf_outputs = $1;
|
||||
for (size_t i = 0; i < tf_outputs.size(); ++i) {
|
||||
PyList_SET_ITEM($result, i, CreateWrappedTFOutput(tf_outputs[i]));
|
||||
}
|
||||
}
|
||||
|
||||
%unignore TF_NewSessionRef;
|
||||
%unignore SetRequireShapeInferenceFns;
|
||||
%unignore TF_TryEvaluateConstant_wrapper;
|
||||
%noexception TF_TryEvaluateConstant_wrapper;
|
||||
%unignore ExtendSession;
|
||||
%unignore HandleShapeAndType;
|
||||
|
||||
%include "tensorflow/python/client/tf_session_helper.h"
|
||||
|
||||
%unignoreall
|
1202
tensorflow/python/client/tf_session_wrapper.cc
Normal file
1202
tensorflow/python/client/tf_session_wrapper.cc
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,102 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// SWIG typemaps for TF_SessionRun_wrapper()
|
||||
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
#include "tensorflow/python/client/tf_session_helper.h"
|
||||
%}
|
||||
|
||||
// Required to use PyArray_* functions.
|
||||
%init %{
|
||||
tensorflow::ImportNumpy();
|
||||
%}
|
||||
|
||||
// $input is a Python dict mapping wrapped TF_Outputs to ndarrays.
|
||||
%typemap(in) (const std::vector<TF_Output>& inputs,
|
||||
const std::vector<PyObject*>& input_ndarrays)
|
||||
(std::vector<TF_Output> inputs, std::vector<PyObject*> input_ndarrays) {
|
||||
if (!PyDict_Check($input)) {
|
||||
SWIG_exception_fail(SWIG_TypeError, "$symname: expected dict");
|
||||
}
|
||||
PyObject* key;
|
||||
PyObject* value;
|
||||
Py_ssize_t pos = 0;
|
||||
while (PyDict_Next($input, &pos, &key, &value)) {
|
||||
TF_Output* input_ptr;
|
||||
SWIG_ConvertPtr(key, reinterpret_cast<void**>(&input_ptr),
|
||||
SWIGTYPE_p_TF_Output, 0);
|
||||
inputs.push_back(*input_ptr);
|
||||
|
||||
if (!PyArray_Check(value)) {
|
||||
SWIG_exception_fail(
|
||||
SWIG_TypeError,
|
||||
"$symname: expected all values in input dict to be ndarray");
|
||||
}
|
||||
input_ndarrays.push_back(value);
|
||||
}
|
||||
$1 = &inputs;
|
||||
$2 = &input_ndarrays;
|
||||
}
|
||||
|
||||
// $input is a Python list of wrapped TF_Operations
|
||||
%typemap(in) (const std::vector<TF_Operation*>& targets)
|
||||
(std::vector<TF_Operation*> targets) {
|
||||
if (!PyList_Check($input)) {
|
||||
SWIG_exception_fail(SWIG_TypeError, "$symname: expected list");
|
||||
}
|
||||
size_t size = PyList_Size($input);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
PyObject* item = PyList_GetItem($input, i);
|
||||
TF_Operation* oper_ptr;
|
||||
SWIG_ConvertPtr(item, reinterpret_cast<void**>(&oper_ptr),
|
||||
SWIGTYPE_p_TF_Operation, 0);
|
||||
targets.push_back(oper_ptr);
|
||||
}
|
||||
$1 = &targets;
|
||||
}
|
||||
|
||||
// $input is a Python list of wrapped TF_Outputs
|
||||
%typemap(in) (const std::vector<TF_Output>& outputs)
|
||||
(std::vector<TF_Output> outputs) {
|
||||
string error_msg;
|
||||
if (!PyTensorListToVector($input, &outputs, &error_msg)) {
|
||||
SWIG_exception_fail(SWIG_TypeError, ("$symname: " + error_msg).c_str());
|
||||
}
|
||||
$1 = &outputs;
|
||||
}
|
||||
|
||||
// Apply the typemap above to inputs as well
|
||||
%typemap(in) (const std::vector<TF_Output>& inputs) =
|
||||
(const std::vector<TF_Output>& outputs);
|
||||
|
||||
// Create temporary py_outputs_vec variable to store return value
|
||||
%typemap(in, numinputs=0) (std::vector<PyObject*>* py_outputs)
|
||||
(std::vector<PyObject*> py_outputs_vec) {
|
||||
$1 = &py_outputs_vec;
|
||||
}
|
||||
|
||||
// Convert py_outputs to returned Python list
|
||||
%typemap(argout) (std::vector<PyObject*>* py_outputs) {
|
||||
$result = PyList_New($1->size());
|
||||
if (!$result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
|
||||
}
|
||||
for (int i = 0; i < $1->size(); ++i) {
|
||||
PyList_SET_ITEM($result, i, (*$1)[i]);
|
||||
}
|
||||
}
|
@ -22,7 +22,7 @@ from absl.testing import parameterized
|
||||
from tensorflow.core.protobuf import cluster_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import tensorflow_server_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.data.experimental.ops import distribute
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
@ -220,7 +220,7 @@ class RemoteReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def setUp(self):
|
||||
super(RemoteReplicateTest, self).setUp()
|
||||
# Start the local server.
|
||||
local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()
|
||||
local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
|
||||
context.set_server_def(
|
||||
server_def=_get_server_def(
|
||||
JOB_NAME,
|
||||
|
@ -265,6 +265,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:pywrap_tf_session",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
@ -1151,6 +1152,7 @@ py_test(
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:pywrap_tf_session",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
@ -26,7 +26,7 @@ import traceback
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow_internal
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.platform import gfile
|
||||
|
||||
HELP_INDENT = " "
|
||||
@ -142,7 +142,7 @@ def get_tensorflow_version_lines(include_dependency_versions=False):
|
||||
Returns:
|
||||
A formatted, multi-line `RichTextLines` object.
|
||||
"""
|
||||
lines = ["TensorFlow version: %s" % pywrap_tensorflow_internal.__version__]
|
||||
lines = ["TensorFlow version: %s" % pywrap_tf_session.__version__]
|
||||
lines.append("")
|
||||
if include_dependency_versions:
|
||||
lines.append("Dependency version(s):")
|
||||
|
@ -23,7 +23,7 @@ import tempfile
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow_internal
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.debug.cli import debugger_cli_common
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import gfile
|
||||
@ -1160,15 +1160,13 @@ class GetTensorFlowVersionLinesTest(test_util.TensorFlowTestCase):
|
||||
def testGetVersionWithoutDependencies(self):
|
||||
out = debugger_cli_common.get_tensorflow_version_lines()
|
||||
self.assertEqual(2, len(out.lines))
|
||||
self.assertEqual(
|
||||
"TensorFlow version: %s" % pywrap_tensorflow_internal.__version__,
|
||||
out.lines[0])
|
||||
self.assertEqual("TensorFlow version: %s" % pywrap_tf_session.__version__,
|
||||
out.lines[0])
|
||||
|
||||
def testGetVersionWithDependencies(self):
|
||||
out = debugger_cli_common.get_tensorflow_version_lines(True)
|
||||
self.assertIn(
|
||||
"TensorFlow version: %s" % pywrap_tensorflow_internal.__version__,
|
||||
out.lines)
|
||||
self.assertIn("TensorFlow version: %s" % pywrap_tf_session.__version__,
|
||||
out.lines)
|
||||
self.assertIn(" numpy: %s" % np.__version__, out.lines)
|
||||
|
||||
|
||||
|
@ -143,13 +143,14 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":eager_util",
|
||||
":executor",
|
||||
":monitoring",
|
||||
"//tensorflow/python:c_api_util",
|
||||
"//tensorflow/python:device",
|
||||
"//tensorflow/python:device_spec",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:pywrap_tf_session",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:tf2",
|
||||
"//tensorflow/python:util",
|
||||
@ -177,7 +178,8 @@ py_library(
|
||||
"//third_party/py/tf_agents:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
":eager_util",
|
||||
"//tensorflow/python:c_api_util",
|
||||
"//tensorflow/python:pywrap_tf_session",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
@ -200,7 +202,8 @@ py_library(
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":context",
|
||||
":eager_util",
|
||||
"//tensorflow/python:c_api_util",
|
||||
"//tensorflow/python:pywrap_tf_session",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
@ -223,7 +226,8 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":eager_util",
|
||||
"//tensorflow/python:c_api_util",
|
||||
"//tensorflow/python:pywrap_tf_session",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
],
|
||||
)
|
||||
@ -488,6 +492,7 @@ py_library(
|
||||
"//tensorflow/python:func_graph",
|
||||
"//tensorflow/python:gradients_impl",
|
||||
"//tensorflow/python:graph_to_function_def",
|
||||
"//tensorflow/python:pywrap_tf_session",
|
||||
"//tensorflow/python:util",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
@ -554,17 +559,6 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "eager_util",
|
||||
srcs = ["eager_util.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "benchmarks_test",
|
||||
srcs = ["benchmarks_test.py"],
|
||||
|
@ -32,9 +32,10 @@ from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.eager import eager_util as c_api_util
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.eager import executor
|
||||
from tensorflow.python.eager import monitoring
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import is_in_graph_mode
|
||||
@ -789,7 +790,7 @@ class Context(object):
|
||||
self.ensure_initialized()
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_)
|
||||
address_space = pywrap_tfe.TF_GetBuffer(buffer_).decode("utf-8")
|
||||
address_space = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8")
|
||||
return address_space
|
||||
|
||||
# TODO(fishx): remove this property.
|
||||
@ -1537,7 +1538,7 @@ class Context(object):
|
||||
return None
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_)
|
||||
proto_data = pywrap_tfe.TF_GetBuffer(buffer_)
|
||||
proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
run_metadata.ParseFromString(compat.as_bytes(proto_data))
|
||||
return run_metadata
|
||||
|
@ -1,61 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utilities for using the TensorFlow Eager using the C API."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tfe as c_api
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
|
||||
# We temporarily need a duplicate tf_buffer function in eager_util. The
|
||||
# c_api_util is still relying on SWIG and is thus incompatible until
|
||||
# we migrate over. We can delete this once we migrate tf_session.i
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def tf_buffer(data=None):
|
||||
"""Context manager that creates and deletes TF_Buffer.
|
||||
|
||||
Example usage:
|
||||
with tf_buffer() as buf:
|
||||
# get serialized graph def into buf
|
||||
...
|
||||
proto_data = c_api.TF_GetBuffer(buf)
|
||||
graph_def.ParseFromString(compat.as_bytes(proto_data))
|
||||
# buf has been deleted
|
||||
|
||||
with tf_buffer(some_string) as buf:
|
||||
c_api.TF_SomeFunction(buf)
|
||||
# buf has been deleted
|
||||
|
||||
Args:
|
||||
data: An optional `bytes`, `str`, or `unicode` object. If not None, the
|
||||
yielded buffer will contain this data.
|
||||
|
||||
Yields:
|
||||
Created TF_Buffer
|
||||
"""
|
||||
if data:
|
||||
buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
|
||||
else:
|
||||
buf = c_api.TF_NewBuffer()
|
||||
try:
|
||||
yield buf
|
||||
finally:
|
||||
c_api.TF_DeleteBuffer(buf)
|
@ -33,8 +33,8 @@ from six.moves import map
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import function_pb2
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import backprop_util
|
||||
from tensorflow.python.eager import context
|
||||
@ -482,7 +482,7 @@ class _EagerDefinedFunction(object):
|
||||
output_names = []
|
||||
else:
|
||||
output_names = []
|
||||
fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
|
||||
fn = pywrap_tf_session.TF_GraphToFunction_wrapper(
|
||||
graph._c_graph, # pylint: disable=protected-access
|
||||
compat.as_str(name),
|
||||
False,
|
||||
@ -499,14 +499,14 @@ class _EagerDefinedFunction(object):
|
||||
serialized = attr_value.SerializeToString()
|
||||
# TODO(iga): this creates and deletes a new TF_Status for every attr.
|
||||
# It might be worth creating a convenient way to re-use status.
|
||||
pywrap_tensorflow.TF_FunctionSetAttrValueProto(
|
||||
fn, compat.as_str(name), serialized)
|
||||
pywrap_tf_session.TF_FunctionSetAttrValueProto(fn, compat.as_str(name),
|
||||
serialized)
|
||||
|
||||
# TODO(apassos) avoid creating a FunctionDef (specially to grab the
|
||||
# signature, but also in general it's nice not to depend on it.
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
|
||||
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
|
||||
pywrap_tf_session.TF_FunctionToFunctionDef(fn, buffer_)
|
||||
proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
|
||||
function_def = function_pb2.FunctionDef()
|
||||
function_def.ParseFromString(compat.as_bytes(proto_data))
|
||||
self._name = compat.as_bytes(function_def.signature.name)
|
||||
|
@ -22,7 +22,8 @@ import collections
|
||||
|
||||
from tensorflow.core.framework import summary_pb2
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import eager_util as c_api_util
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
_MetricMethod = collections.namedtuple('MetricMethod', 'create delete get_cell')
|
||||
@ -258,7 +259,7 @@ class StringGaugeCell(object):
|
||||
"""Retrieves the current value."""
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tfe.TFE_MonitoringStringGaugeCellValue(self._cell, buffer_)
|
||||
value = pywrap_tfe.TF_GetBuffer(buffer_).decode('utf-8')
|
||||
value = pywrap_tf_session.TF_GetBuffer(buffer_).decode('utf-8')
|
||||
return value
|
||||
|
||||
|
||||
@ -361,7 +362,7 @@ class SamplerCell(object):
|
||||
"""
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tfe.TFE_MonitoringSamplerCellValue(self._cell, buffer_)
|
||||
proto_data = pywrap_tfe.TF_GetBuffer(buffer_)
|
||||
proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
|
||||
histogram_proto = summary_pb2.HistogramProto()
|
||||
histogram_proto.ParseFromString(compat.as_bytes(proto_data))
|
||||
return histogram_proto
|
||||
|
@ -40,8 +40,9 @@ import threading
|
||||
|
||||
from tensorflow.python import _pywrap_events_writer
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import eager_util as c_api_util
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import compat
|
||||
@ -101,7 +102,7 @@ def stop():
|
||||
context.context().executor.wait()
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tfe.TFE_ProfilerSerializeToString(_profiler, buffer_)
|
||||
result = pywrap_tfe.TF_GetBuffer(buffer_)
|
||||
result = pywrap_tf_session.TF_GetBuffer(buffer_)
|
||||
pywrap_tfe.TFE_DeleteProfiler(_profiler)
|
||||
_profiler = None
|
||||
_run_num += 1
|
||||
|
@ -19,7 +19,8 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import eager_util as c_api_util
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
|
||||
@ -74,4 +75,4 @@ def monitor(service_addr,
|
||||
pywrap_tfe.TFE_ProfilerClientMonitor(service_addr, duration_ms,
|
||||
monitoring_level, display_timestamp,
|
||||
buffer_)
|
||||
return pywrap_tfe.TF_GetBuffer(buffer_)
|
||||
return pywrap_tf_session.TF_GetBuffer(buffer_)
|
||||
|
@ -21,7 +21,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import api_def_pb2
|
||||
from tensorflow.core.framework import op_def_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python.client import pywrap_tf_session as c_api
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
|
@ -23,7 +23,7 @@ import warnings
|
||||
|
||||
from tensorflow.core.lib.core import error_codes_pb2
|
||||
from tensorflow.python import _pywrap_py_exception_registry
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python.client import pywrap_tf_session as c_api
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import error_interpolation
|
||||
from tensorflow.python.util import compat
|
||||
|
@ -23,7 +23,7 @@ import pickle
|
||||
import warnings
|
||||
|
||||
from tensorflow.core.lib.core import error_codes_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import _pywrap_file_io
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import errors_impl
|
||||
@ -112,7 +112,7 @@ class ErrorsTest(test.TestCase):
|
||||
|
||||
def testStatusDoesNotLeak(self):
|
||||
try:
|
||||
pywrap_tensorflow.DeleteFile(compat.as_bytes("/DOES_NOT_EXIST/"))
|
||||
_pywrap_file_io.DeleteFile(compat.as_bytes("/DOES_NOT_EXIST/"))
|
||||
except:
|
||||
pass
|
||||
gc.collect()
|
||||
|
@ -27,7 +27,7 @@ import hashlib
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import function_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python.client import pywrap_tf_session as c_api
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
@ -20,8 +20,8 @@ from __future__ import print_function
|
||||
import contextlib
|
||||
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import pywrap_tf_session as c_api
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.framework import errors
|
||||
|
@ -19,7 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import kernel_def_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python.client import pywrap_tf_session as c_api
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ import platform
|
||||
import sys
|
||||
|
||||
from tensorflow.python import _pywrap_python_op_gen
|
||||
from tensorflow.python import pywrap_tensorflow as py_tf
|
||||
from tensorflow.python.client import pywrap_tf_session as py_tf
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
@ -32,7 +32,7 @@ from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.framework import op_def_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.core.protobuf import saver_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.client import pywrap_tf_session as c_api
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import error_interpolation
|
||||
from tensorflow.python.framework import graph_io
|
||||
@ -446,9 +446,9 @@ def _is_default_attr_value(op_def, attr_name, attr_value):
|
||||
if attr_def.name == attr_name:
|
||||
if not attr_def.HasField("default_value"):
|
||||
return False
|
||||
# pywrap_tensorflow.EqualAttrValueWrapper returns an empty string
|
||||
# c_api.EqualAttrValueWrapper returns an empty string
|
||||
# if both arguments represent an equivalent AttrValue instance.
|
||||
return not pywrap_tensorflow.EqualAttrValueWrapper(
|
||||
return not c_api.EqualAttrValueWrapper(
|
||||
attr_value.SerializeToString(),
|
||||
attr_def.default_value.SerializeToString())
|
||||
return False
|
||||
|
@ -38,11 +38,12 @@ from tensorflow.core.framework import versions_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
# pywrap_tensorflow must be imported first to avoid profobuf issues.
|
||||
# (b/143110113)
|
||||
# pylint: disable=invalid-import-order,g-bad-import-order
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python import pywrap_tfe as c_api_new
|
||||
# pylint: enable=invalid-import-order,g-bad-import-order
|
||||
# pylint: disable=invalid-import-order,g-bad-import-order,unused-import
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
# pylint: enable=invalid-import-order,g-bad-import-order,unused-import
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import core
|
||||
from tensorflow.python.eager import monitoring
|
||||
@ -254,7 +255,7 @@ def register_dense_tensor_like_type(tensor_type):
|
||||
|
||||
def uid():
|
||||
"""A unique (within this program execution) integer."""
|
||||
return c_api_new.TFE_Py_UID()
|
||||
return pywrap_tfe.TFE_Py_UID()
|
||||
|
||||
|
||||
def numpy_text(tensor, is_repr=False):
|
||||
@ -502,13 +503,13 @@ class Tensor(_TensorLike):
|
||||
def _c_api_shape(self):
|
||||
"""Returns the TensorShape of this tensor according to the C API."""
|
||||
c_graph = self._op._graph._c_graph # pylint: disable=protected-access
|
||||
shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
|
||||
shape_vec, unknown_shape = pywrap_tf_session.TF_GraphGetTensorShapeHelper(
|
||||
c_graph, self._as_tf_output())
|
||||
if unknown_shape:
|
||||
return tensor_shape.unknown_shape()
|
||||
else:
|
||||
shape_vector = [None if d == -1 else d for d in shape_vector]
|
||||
return tensor_shape.TensorShape(shape_vector)
|
||||
shape_vec = [None if d == -1 else d for d in shape_vec]
|
||||
return tensor_shape.TensorShape(shape_vec)
|
||||
|
||||
@property
|
||||
def _shape(self):
|
||||
@ -649,7 +650,7 @@ class Tensor(_TensorLike):
|
||||
else:
|
||||
dim_list.append(dim.value)
|
||||
try:
|
||||
c_api.TF_GraphSetTensorShape_wrapper(
|
||||
pywrap_tf_session.TF_GraphSetTensorShape_wrapper(
|
||||
self._op._graph._c_graph, # pylint: disable=protected-access
|
||||
self._as_tf_output(),
|
||||
dim_list,
|
||||
@ -669,7 +670,7 @@ class Tensor(_TensorLike):
|
||||
Returns:
|
||||
A list of `Operation`s.
|
||||
"""
|
||||
consumer_names = c_api.TF_OperationOutputConsumers_wrapper(
|
||||
consumer_names = pywrap_tf_session.TF_OperationOutputConsumers_wrapper(
|
||||
self._as_tf_output())
|
||||
# pylint: disable=protected-access
|
||||
return [
|
||||
@ -1160,7 +1161,7 @@ class _EagerTensorBase(Tensor):
|
||||
|
||||
# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
|
||||
# registers it with the current module.
|
||||
EagerTensor = c_api_new.TFE_Py_InitEagerTensor(_EagerTensorBase)
|
||||
EagerTensor = pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase)
|
||||
|
||||
|
||||
register_dense_tensor_like_type(Tensor)
|
||||
@ -1633,20 +1634,22 @@ def _create_c_op(graph, node_def, inputs, control_inputs, op_def=None):
|
||||
# Refactor so we don't have to do this here.
|
||||
inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.attr)
|
||||
# pylint: disable=protected-access
|
||||
op_desc = c_api.TF_NewOperation(graph._c_graph, compat.as_str(node_def.op),
|
||||
compat.as_str(node_def.name))
|
||||
op_desc = pywrap_tf_session.TF_NewOperation(graph._c_graph,
|
||||
compat.as_str(node_def.op),
|
||||
compat.as_str(node_def.name))
|
||||
if node_def.device:
|
||||
c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device))
|
||||
pywrap_tf_session.TF_SetDevice(op_desc, compat.as_str(node_def.device))
|
||||
# Add inputs
|
||||
for op_input in inputs:
|
||||
if isinstance(op_input, (list, tuple)):
|
||||
c_api.TF_AddInputList(op_desc, [t._as_tf_output() for t in op_input])
|
||||
pywrap_tf_session.TF_AddInputList(op_desc,
|
||||
[t._as_tf_output() for t in op_input])
|
||||
else:
|
||||
c_api.TF_AddInput(op_desc, op_input._as_tf_output())
|
||||
pywrap_tf_session.TF_AddInput(op_desc, op_input._as_tf_output())
|
||||
|
||||
# Add control inputs
|
||||
for control_input in control_inputs:
|
||||
c_api.TF_AddControlInput(op_desc, control_input._c_op)
|
||||
pywrap_tf_session.TF_AddControlInput(op_desc, control_input._c_op)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Add attrs
|
||||
@ -1654,10 +1657,11 @@ def _create_c_op(graph, node_def, inputs, control_inputs, op_def=None):
|
||||
serialized = attr_value.SerializeToString()
|
||||
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
|
||||
# It might be worth creating a convenient way to re-use the same status.
|
||||
c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized)
|
||||
pywrap_tf_session.TF_SetAttrValueProto(op_desc, compat.as_str(name),
|
||||
serialized)
|
||||
|
||||
try:
|
||||
c_op = c_api.TF_FinishOperation(op_desc)
|
||||
c_op = pywrap_tf_session.TF_FinishOperation(op_desc)
|
||||
except errors.InvalidArgumentError as e:
|
||||
# Convert to ValueError for backwards compatibility.
|
||||
raise ValueError(str(e))
|
||||
@ -1744,7 +1748,7 @@ class Operation(object):
|
||||
if not _VALID_OP_NAME_REGEX.match(node_def.name):
|
||||
raise ValueError("'%s' is not a valid node name" % node_def.name)
|
||||
c_op = None
|
||||
elif type(node_def).__name__ == "SwigPyObject":
|
||||
elif type(node_def).__name__ == "TF_Operation":
|
||||
assert inputs is None
|
||||
assert output_types is None
|
||||
assert control_inputs is None
|
||||
@ -1814,7 +1818,7 @@ class Operation(object):
|
||||
# Initialize self._c_op.
|
||||
if c_op:
|
||||
self._c_op = c_op
|
||||
op_def = g._get_op_def(c_api.TF_OperationOpType(c_op))
|
||||
op_def = g._get_op_def(pywrap_tf_session.TF_OperationOpType(c_op))
|
||||
name = self.name
|
||||
else:
|
||||
if op_def is None:
|
||||
@ -1827,11 +1831,11 @@ class Operation(object):
|
||||
self._is_stateful = op_def.is_stateful
|
||||
|
||||
# Initialize self._outputs.
|
||||
num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
|
||||
num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
|
||||
self._outputs = []
|
||||
for i in range(num_outputs):
|
||||
tf_output = c_api_util.tf_output(self._c_op, i)
|
||||
output_type = c_api.TF_OperationOutputType(tf_output)
|
||||
output_type = pywrap_tf_session.TF_OperationOutputType(tf_output)
|
||||
tensor = Tensor._create_with_tf_output(self, i, output_type, tf_output) # pylint: disable=protected-access
|
||||
self._outputs.append(tensor)
|
||||
|
||||
@ -1900,7 +1904,7 @@ class Operation(object):
|
||||
@property
|
||||
def name(self):
|
||||
"""The full name of this operation."""
|
||||
return c_api.TF_OperationName(self._c_op)
|
||||
return pywrap_tf_session.TF_OperationName(self._c_op)
|
||||
|
||||
@property
|
||||
def _id(self):
|
||||
@ -1916,7 +1920,7 @@ class Operation(object):
|
||||
assigned, or an empty string if it has not been assigned to a
|
||||
device.
|
||||
"""
|
||||
return c_api.TF_OperationDevice(self._c_op)
|
||||
return pywrap_tf_session.TF_OperationDevice(self._c_op)
|
||||
|
||||
@property
|
||||
def _device_assignments(self):
|
||||
@ -1991,33 +1995,28 @@ class Operation(object):
|
||||
Returns:
|
||||
List of the types of the Tensors computed by this operation.
|
||||
Each element in the list is an integer whose value is one of
|
||||
the TF_DataType enums defined in c_api.h
|
||||
the TF_DataType enums defined in pywrap_tf_session.h
|
||||
The length of this list indicates the number of output endpoints
|
||||
of the operation.
|
||||
"""
|
||||
num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
|
||||
num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
|
||||
output_types = [
|
||||
c_api.TF_OperationOutputType(self._tf_output(i))
|
||||
int(pywrap_tf_session.TF_OperationOutputType(self._tf_output(i)))
|
||||
for i in xrange(num_outputs)
|
||||
]
|
||||
# In all the tests we have output_types that are passed into
|
||||
# Operation.__init__ are a list of ints (which is illegal according
|
||||
# to the docstring), but input_types are instances of DType.
|
||||
# This extra assert is to catch if we ever use DType for output_types.
|
||||
if output_types:
|
||||
assert isinstance(output_types[0], int)
|
||||
|
||||
return output_types
|
||||
|
||||
def _tf_output(self, output_idx):
|
||||
"""Create and return a new TF_Output for output_idx'th output of this op."""
|
||||
tf_output = c_api.TF_Output()
|
||||
tf_output = pywrap_tf_session.TF_Output()
|
||||
tf_output.oper = self._c_op
|
||||
tf_output.index = output_idx
|
||||
return tf_output
|
||||
|
||||
def _tf_input(self, input_idx):
|
||||
"""Create and return a new TF_Input for input_idx'th input of this op."""
|
||||
tf_input = c_api.TF_Input()
|
||||
tf_input = pywrap_tf_session.TF_Input()
|
||||
tf_input.oper = self._c_op
|
||||
tf_input.index = input_idx
|
||||
return tf_input
|
||||
@ -2040,7 +2039,7 @@ class Operation(object):
|
||||
Args:
|
||||
device_str: A string specifying where to place this op.
|
||||
"""
|
||||
c_api.SetRequestedDevice(
|
||||
pywrap_tf_session.SetRequestedDevice(
|
||||
self._graph._c_graph, # pylint: disable=protected-access
|
||||
self._c_op, # pylint: disable=protected-access
|
||||
device_str)
|
||||
@ -2065,7 +2064,7 @@ class Operation(object):
|
||||
|
||||
# Reset cached inputs.
|
||||
self._inputs_val = None
|
||||
c_api.UpdateEdge(
|
||||
pywrap_tf_session.UpdateEdge(
|
||||
self._graph._c_graph, # pylint: disable=protected-access
|
||||
tensor._as_tf_output(), # pylint: disable=protected-access
|
||||
self._tf_input(index))
|
||||
@ -2090,7 +2089,7 @@ class Operation(object):
|
||||
|
||||
# Reset cached inputs.
|
||||
self._inputs_val = None
|
||||
c_api.AddWhileInputHack(
|
||||
pywrap_tf_session.AddWhileInputHack(
|
||||
self._graph._c_graph, # pylint: disable=protected-access
|
||||
tensor._as_tf_output(), # pylint: disable=protected-access
|
||||
self._c_op)
|
||||
@ -2108,7 +2107,10 @@ class Operation(object):
|
||||
for op in ops:
|
||||
if not isinstance(op, Operation):
|
||||
raise TypeError("op must be an Operation: %s" % op)
|
||||
c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
|
||||
pywrap_tf_session.AddControlInput(
|
||||
self._graph._c_graph, # pylint: disable=protected-access
|
||||
self._c_op, # pylint: disable=protected-access
|
||||
op._c_op) # pylint: disable=protected-access
|
||||
|
||||
def _add_control_input(self, op):
|
||||
"""Add a new control input to this operation.
|
||||
@ -2122,11 +2124,14 @@ class Operation(object):
|
||||
"""
|
||||
if not isinstance(op, Operation):
|
||||
raise TypeError("op must be an Operation: %s" % op)
|
||||
c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
|
||||
pywrap_tf_session.AddControlInput(
|
||||
self._graph._c_graph, # pylint: disable=protected-access
|
||||
self._c_op, # pylint: disable=protected-access
|
||||
op._c_op) # pylint: disable=protected-access
|
||||
|
||||
def _remove_all_control_inputs(self):
|
||||
"""Removes any control inputs to this operation."""
|
||||
c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access
|
||||
pywrap_tf_session.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access
|
||||
|
||||
def _add_outputs(self, types, shapes):
|
||||
"""Adds new Tensors to self.outputs.
|
||||
@ -2161,16 +2166,18 @@ class Operation(object):
|
||||
"""The sequence of `Tensor` objects representing the data inputs of this op."""
|
||||
if self._inputs_val is None:
|
||||
# pylint: disable=protected-access
|
||||
self._inputs_val = tuple(map(self.graph._get_tensor_by_tf_output,
|
||||
c_api.GetOperationInputs(self._c_op)))
|
||||
self._inputs_val = tuple(
|
||||
map(self.graph._get_tensor_by_tf_output,
|
||||
pywrap_tf_session.GetOperationInputs(self._c_op)))
|
||||
# pylint: enable=protected-access
|
||||
return self._inputs_val
|
||||
|
||||
@property
|
||||
def _input_types(self):
|
||||
num_inputs = c_api.TF_OperationNumInputs(self._c_op)
|
||||
num_inputs = pywrap_tf_session.TF_OperationNumInputs(self._c_op)
|
||||
input_types = [
|
||||
dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i)))
|
||||
dtypes.as_dtype(
|
||||
pywrap_tf_session.TF_OperationInputType(self._tf_input(i)))
|
||||
for i in xrange(num_inputs)
|
||||
]
|
||||
return input_types
|
||||
@ -2189,11 +2196,12 @@ class Operation(object):
|
||||
A list of `Operation` objects.
|
||||
|
||||
"""
|
||||
control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
|
||||
control_c_ops = pywrap_tf_session.TF_OperationGetControlInputs_wrapper(
|
||||
self._c_op)
|
||||
# pylint: disable=protected-access
|
||||
return [
|
||||
self.graph._get_operation_by_name_unsafe(c_api.TF_OperationName(c_op))
|
||||
for c_op in control_c_ops
|
||||
self.graph._get_operation_by_name_unsafe(
|
||||
pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
|
||||
]
|
||||
# pylint: enable=protected-access
|
||||
|
||||
@ -2208,18 +2216,19 @@ class Operation(object):
|
||||
A list of `Operation` objects.
|
||||
|
||||
"""
|
||||
control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op)
|
||||
control_c_ops = pywrap_tf_session.TF_OperationGetControlOutputs_wrapper(
|
||||
self._c_op)
|
||||
# pylint: disable=protected-access
|
||||
return [
|
||||
self.graph._get_operation_by_name_unsafe(c_api.TF_OperationName(c_op))
|
||||
for c_op in control_c_ops
|
||||
self.graph._get_operation_by_name_unsafe(
|
||||
pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
|
||||
]
|
||||
# pylint: enable=protected-access
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
"""The type of the op (e.g. `"MatMul"`)."""
|
||||
return c_api.TF_OperationOpType(self._c_op)
|
||||
return pywrap_tf_session.TF_OperationOpType(self._c_op)
|
||||
|
||||
@property
|
||||
def graph(self):
|
||||
@ -2238,8 +2247,8 @@ class Operation(object):
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
c_api.TF_OperationToNodeDef(self._c_op, buf)
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
pywrap_tf_session.TF_OperationToNodeDef(self._c_op, buf)
|
||||
data = pywrap_tf_session.TF_GetBuffer(buf)
|
||||
node_def = node_def_pb2.NodeDef()
|
||||
node_def.ParseFromString(compat.as_bytes(data))
|
||||
return node_def
|
||||
@ -2264,17 +2273,18 @@ class Operation(object):
|
||||
|
||||
def _set_attr(self, attr_name, attr_value):
|
||||
"""Private method used to set an attribute in the node_def."""
|
||||
buf = c_api.TF_NewBufferFromString(
|
||||
buf = pywrap_tf_session.TF_NewBufferFromString(
|
||||
compat.as_bytes(attr_value.SerializeToString()))
|
||||
try:
|
||||
self._set_attr_with_buf(attr_name, buf)
|
||||
finally:
|
||||
c_api.TF_DeleteBuffer(buf)
|
||||
pywrap_tf_session.TF_DeleteBuffer(buf)
|
||||
|
||||
def _set_attr_with_buf(self, attr_name, attr_buf):
|
||||
"""Set an attr in the node_def with a pre-allocated buffer."""
|
||||
# pylint: disable=protected-access
|
||||
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, attr_buf)
|
||||
pywrap_tf_session.SetAttr(self._graph._c_graph, self._c_op, attr_name,
|
||||
attr_buf)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def _set_func_attr(self, attr_name, func_name):
|
||||
@ -2307,7 +2317,7 @@ class Operation(object):
|
||||
def _clear_attr(self, attr_name):
|
||||
"""Private method used to clear an attribute in the node_def."""
|
||||
# pylint: disable=protected-access
|
||||
c_api.ClearAttr(self._graph._c_graph, self._c_op, attr_name)
|
||||
pywrap_tf_session.ClearAttr(self._graph._c_graph, self._c_op, attr_name)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def get_attr(self, name):
|
||||
@ -2325,8 +2335,8 @@ class Operation(object):
|
||||
fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func")
|
||||
try:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
pywrap_tf_session.TF_OperationGetAttrValueProto(self._c_op, name, buf)
|
||||
data = pywrap_tf_session.TF_GetBuffer(buf)
|
||||
except errors.InvalidArgumentError as e:
|
||||
# Convert to ValueError for backwards compatibility.
|
||||
raise ValueError(str(e))
|
||||
@ -2352,7 +2362,7 @@ class Operation(object):
|
||||
def _get_attr_type(self, name):
|
||||
"""Returns the `DType` value of the attr of this op with the given `name`."""
|
||||
try:
|
||||
dtype_enum = c_api.TF_OperationGetAttrType(self._c_op, name)
|
||||
dtype_enum = pywrap_tf_session.TF_OperationGetAttrType(self._c_op, name)
|
||||
return _DTYPES_INTERN_TABLE[dtype_enum]
|
||||
except errors.InvalidArgumentError as e:
|
||||
# Convert to ValueError for backwards compatibility.
|
||||
@ -2361,7 +2371,7 @@ class Operation(object):
|
||||
def _get_attr_bool(self, name):
|
||||
"""Returns the `bool` value of the attr of this op with the given `name`."""
|
||||
try:
|
||||
return c_api.TF_OperationGetAttrBool(self._c_op, name)
|
||||
return pywrap_tf_session.TF_OperationGetAttrBool(self._c_op, name)
|
||||
except errors.InvalidArgumentError as e:
|
||||
# Convert to ValueError for backwards compatibility.
|
||||
raise ValueError(str(e))
|
||||
@ -2369,7 +2379,7 @@ class Operation(object):
|
||||
def _get_attr_int(self, name):
|
||||
"""Returns the `int` value of the attr of this op with the given `name`."""
|
||||
try:
|
||||
return c_api.TF_OperationGetAttrInt(self._c_op, name)
|
||||
return pywrap_tf_session.TF_OperationGetAttrInt(self._c_op, name)
|
||||
except errors.InvalidArgumentError as e:
|
||||
# Convert to ValueError for backwards compatibility.
|
||||
raise ValueError(str(e))
|
||||
@ -2812,7 +2822,7 @@ class Graph(object):
|
||||
# The C API requires all ops to have shape functions. Disable this
|
||||
# requirement (many custom ops do not have shape functions, and we don't
|
||||
# want to break these existing cases).
|
||||
c_api.SetRequireShapeInferenceFns(self._c_graph, False)
|
||||
pywrap_tf_session.SetRequireShapeInferenceFns(self._c_graph, False)
|
||||
if tf2.enabled():
|
||||
self.switch_to_thread_local()
|
||||
|
||||
@ -2945,8 +2955,8 @@ class Graph(object):
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
c_api.TF_GraphVersions(self._c_graph, buf)
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
pywrap_tf_session.TF_GraphVersions(self._c_graph, buf)
|
||||
data = pywrap_tf_session.TF_GetBuffer(buf)
|
||||
version_def = versions_pb2.VersionDef()
|
||||
version_def.ParseFromString(compat.as_bytes(data))
|
||||
return version_def
|
||||
@ -3047,8 +3057,8 @@ class Graph(object):
|
||||
# pylint: enable=line-too-long
|
||||
with self._lock:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
c_api.TF_GraphToGraphDef(self._c_graph, buf)
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
pywrap_tf_session.TF_GraphToGraphDef(self._c_graph, buf)
|
||||
data = pywrap_tf_session.TF_GetBuffer(buf)
|
||||
graph = graph_pb2.GraphDef()
|
||||
graph.ParseFromString(compat.as_bytes(data))
|
||||
# Strip the experimental library field iff it's empty.
|
||||
@ -3181,7 +3191,8 @@ class Graph(object):
|
||||
# pylint: disable=protected-access
|
||||
gradient = (
|
||||
function._grad_func._c_func.func if function._grad_func else None)
|
||||
c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient)
|
||||
pywrap_tf_session.TF_GraphCopyFunction(self._c_graph, function._c_func.func,
|
||||
gradient)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
self._functions[compat.as_str(name)] = function
|
||||
@ -3658,7 +3669,7 @@ class Graph(object):
|
||||
return self._nodes_by_name[name]
|
||||
|
||||
def _get_operation_by_tf_operation(self, tf_oper):
|
||||
op_name = c_api.TF_OperationName(tf_oper)
|
||||
op_name = pywrap_tf_session.TF_OperationName(tf_oper)
|
||||
return self._get_operation_by_name_unsafe(op_name)
|
||||
|
||||
def get_tensor_by_name(self, name):
|
||||
@ -3711,9 +3722,10 @@ class Graph(object):
|
||||
except KeyError:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
# pylint: disable=protected-access
|
||||
c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf)
|
||||
pywrap_tf_session.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type),
|
||||
buf)
|
||||
# pylint: enable=protected-access
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
data = pywrap_tf_session.TF_GetBuffer(buf)
|
||||
op_def = op_def_pb2.OpDef()
|
||||
op_def.ParseFromString(compat.as_bytes(data))
|
||||
self._op_def_cache[type] = op_def
|
||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python.client import pywrap_tf_session as c_api
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
|
@ -44,9 +44,9 @@ from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python import _pywrap_stacktrace_handler
|
||||
from tensorflow.python import _pywrap_util_port
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import device_lib
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat.compat import forward_compatibility_horizon
|
||||
from tensorflow.python.eager import context
|
||||
@ -199,7 +199,7 @@ def assert_equal_graph_def(actual, expected, checkpoint_v2=False,
|
||||
_strip_hash_table_shared_name(actual)
|
||||
_strip_hash_table_shared_name(expected)
|
||||
|
||||
diff = pywrap_tensorflow.EqualGraphDefWrapper(actual.SerializeToString(),
|
||||
diff = pywrap_tf_session.EqualGraphDefWrapper(actual.SerializeToString(),
|
||||
expected.SerializeToString())
|
||||
if diff:
|
||||
raise AssertionError(compat.as_str(diff))
|
||||
@ -1694,10 +1694,10 @@ def enable_tf_xla_constant_folding(description):
|
||||
def decorator(f):
|
||||
|
||||
def decorated(self, *args, **kwargs):
|
||||
original_var = pywrap_tensorflow.TF_GetXlaConstantFoldingDisabled()
|
||||
pywrap_tensorflow.TF_SetXlaConstantFoldingDisabled(False)
|
||||
original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled()
|
||||
pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False)
|
||||
result = f(self, *args, **kwargs)
|
||||
pywrap_tensorflow.TF_SetXlaConstantFoldingDisabled(original_var)
|
||||
pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var)
|
||||
return result
|
||||
|
||||
return decorated
|
||||
@ -1799,9 +1799,9 @@ def xla_allow_fallback(description): # pylint: disable=unused-argument
|
||||
# Update the global XLABuildOpsPassFlags to enable lazy compilation,
|
||||
# which allows the compiler to fall back to TF classic. Remember the
|
||||
# old value so that we can reset it.
|
||||
old_value = pywrap_tensorflow.TF_SetXlaEnableLazyCompilation(True)
|
||||
old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True)
|
||||
result = func(self, *args, **kwargs)
|
||||
pywrap_tensorflow.TF_SetXlaEnableLazyCompilation(old_value)
|
||||
pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value)
|
||||
return result
|
||||
else:
|
||||
return func(self, *args, **kwargs)
|
||||
@ -1835,13 +1835,13 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
|
||||
super(TensorFlowTestCase, self).__init__(methodName)
|
||||
if is_xla_enabled():
|
||||
pywrap_tensorflow.TF_SetXlaAutoJitMode("2")
|
||||
pywrap_tensorflow.TF_SetXlaMinClusterSize(1)
|
||||
pywrap_tensorflow.TF_SetXlaEnableLazyCompilation(False)
|
||||
pywrap_tensorflow.TF_SetTfXlaCpuGlobalJit(True)
|
||||
pywrap_tf_session.TF_SetXlaAutoJitMode("2")
|
||||
pywrap_tf_session.TF_SetXlaMinClusterSize(1)
|
||||
pywrap_tf_session.TF_SetXlaEnableLazyCompilation(False)
|
||||
pywrap_tf_session.TF_SetTfXlaCpuGlobalJit(True)
|
||||
# Constant folding secretly runs code on TF:Classic CPU, so we also
|
||||
# disable it here.
|
||||
pywrap_tensorflow.TF_SetXlaConstantFoldingDisabled(True)
|
||||
pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True)
|
||||
|
||||
self._threads = []
|
||||
self._tempdir = None
|
||||
|
@ -19,14 +19,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
__version__ = pywrap_tensorflow.__version__
|
||||
__git_version__ = pywrap_tensorflow.__git_version__
|
||||
__compiler_version__ = pywrap_tensorflow.__compiler_version__
|
||||
__cxx11_abi_flag__ = pywrap_tensorflow.__cxx11_abi_flag__
|
||||
__monolithic_build__ = pywrap_tensorflow.__monolithic_build__
|
||||
__version__ = pywrap_tf_session.__version__
|
||||
__git_version__ = pywrap_tf_session.__git_version__
|
||||
__compiler_version__ = pywrap_tf_session.__compiler_version__
|
||||
__cxx11_abi_flag__ = pywrap_tf_session.__cxx11_abi_flag__
|
||||
__monolithic_build__ = pywrap_tf_session.__monolithic_build__
|
||||
|
||||
VERSION = __version__
|
||||
tf_export(
|
||||
@ -61,13 +61,13 @@ tf_export(
|
||||
"sysconfig.MONOLITHIC_BUILD", "MONOLITHIC_BUILD", "__monolithic_build__"
|
||||
]).export_constant(__name__, "MONOLITHIC_BUILD")
|
||||
|
||||
GRAPH_DEF_VERSION = pywrap_tensorflow.GRAPH_DEF_VERSION
|
||||
GRAPH_DEF_VERSION = pywrap_tf_session.GRAPH_DEF_VERSION
|
||||
tf_export(
|
||||
"version.GRAPH_DEF_VERSION",
|
||||
v1=["version.GRAPH_DEF_VERSION", "GRAPH_DEF_VERSION"]).export_constant(
|
||||
__name__, "GRAPH_DEF_VERSION")
|
||||
GRAPH_DEF_VERSION_MIN_CONSUMER = (
|
||||
pywrap_tensorflow.GRAPH_DEF_VERSION_MIN_CONSUMER)
|
||||
pywrap_tf_session.GRAPH_DEF_VERSION_MIN_CONSUMER)
|
||||
tf_export(
|
||||
"version.GRAPH_DEF_VERSION_MIN_CONSUMER",
|
||||
v1=[
|
||||
@ -75,7 +75,7 @@ tf_export(
|
||||
"GRAPH_DEF_VERSION_MIN_CONSUMER"
|
||||
]).export_constant(__name__, "GRAPH_DEF_VERSION_MIN_CONSUMER")
|
||||
GRAPH_DEF_VERSION_MIN_PRODUCER = (
|
||||
pywrap_tensorflow.GRAPH_DEF_VERSION_MIN_PRODUCER)
|
||||
pywrap_tf_session.GRAPH_DEF_VERSION_MIN_PRODUCER)
|
||||
tf_export(
|
||||
"version.GRAPH_DEF_VERSION_MIN_PRODUCER",
|
||||
v1=[
|
||||
|
@ -88,6 +88,20 @@ inline void MaybeRaiseRegisteredFromTFStatus(TF_Status* status) {
|
||||
}
|
||||
}
|
||||
|
||||
inline void MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status* status) {
|
||||
TF_Code code = TF_GetCode(status);
|
||||
if (code != TF_OK) {
|
||||
// Acquire GIL for throwing exception.
|
||||
pybind11::gil_scoped_acquire acquire;
|
||||
|
||||
PyErr_SetObject(PyExceptionRegistry::Lookup(code),
|
||||
pybind11::make_tuple(pybind11::none(), pybind11::none(),
|
||||
TF_Message(status))
|
||||
.ptr());
|
||||
throw pybind11::error_already_set();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
namespace pybind11 {
|
||||
|
@ -19,8 +19,8 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -1128,7 +1128,7 @@ def _BroadcastToGrad(op, grad):
|
||||
input_value_shape = array_ops.shape(input_value)
|
||||
if not context.executing_eagerly():
|
||||
broadcast_shape_static = tensor_shape.TensorShape(
|
||||
pywrap_tensorflow.TF_TryEvaluateConstant_wrapper(
|
||||
pywrap_tf_session.TF_TryEvaluateConstant_wrapper(
|
||||
broadcast_shape.graph._c_graph, broadcast_shape._as_tf_output())) # pylint: disable=protected-access
|
||||
if broadcast_shape_static.is_fully_defined():
|
||||
broadcast_shape = constant_op.constant(
|
||||
|
@ -17,7 +17,7 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import tape as tape_lib
|
||||
@ -66,7 +66,7 @@ def copy_handle_data(source_t, target_t):
|
||||
and handle_data.is_set
|
||||
and handle_data.shape_and_type):
|
||||
# pylint: disable=protected-access
|
||||
pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
|
||||
pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
|
||||
target_t._as_tf_output(),
|
||||
handle_data.SerializeToString())
|
||||
# pylint: enable=protected-access
|
||||
@ -76,10 +76,12 @@ def copy_handle_data(source_t, target_t):
|
||||
ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
|
||||
shapes = [[d.size for d in s.dim] # pylint: disable=g-complex-comprehension
|
||||
if not s.unknown_rank else None for s in shapes]
|
||||
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
|
||||
pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
|
||||
target_t._op._graph._c_graph, # pylint: disable=protected-access
|
||||
target_t._as_tf_output(), # pylint: disable=protected-access
|
||||
shapes, ranks, types)
|
||||
shapes,
|
||||
ranks,
|
||||
types)
|
||||
|
||||
|
||||
@tf_export("custom_gradient")
|
||||
|
@ -19,7 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python.client import pywrap_tf_session as c_api
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
|
@ -26,7 +26,7 @@ import weakref
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import variable_pb2
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -56,7 +56,7 @@ from tensorflow.python.util.deprecation import deprecated_args
|
||||
def get_resource_handle_data(graph_op):
|
||||
assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck
|
||||
|
||||
handle_data = pywrap_tensorflow.GetHandleShapeAndType(
|
||||
handle_data = pywrap_tf_session.GetHandleShapeAndType(
|
||||
graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access
|
||||
|
||||
return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
|
||||
@ -91,10 +91,12 @@ def _set_handle_shapes_and_types(tensor, handle_data, graph_mode):
|
||||
ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
|
||||
shapes = [[d.size for d in s.dim] # pylint: disable=g-complex-comprehension
|
||||
if not s.unknown_rank else None for s in shapes]
|
||||
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
|
||||
pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
|
||||
tensor._op._graph._c_graph, # pylint: disable=protected-access
|
||||
tensor._as_tf_output(), # pylint: disable=protected-access
|
||||
shapes, ranks, types)
|
||||
shapes,
|
||||
ranks,
|
||||
types)
|
||||
|
||||
|
||||
def _combine_handle_data(handle, initial_value):
|
||||
|
@ -23,7 +23,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.framework import resource_handle_pb2
|
||||
from tensorflow.python import pywrap_tensorflow_internal
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -71,8 +71,7 @@ class TensorHandle(object):
|
||||
if not self._resource_handle:
|
||||
self._resource_handle = resource_handle_pb2.ResourceHandleProto()
|
||||
self._resource_handle.device = self._handle.split(";")[-1]
|
||||
self._resource_handle.container = (
|
||||
pywrap_tensorflow_internal.TENSOR_HANDLE_KEY)
|
||||
self._resource_handle.container = (pywrap_tf_session.TENSOR_HANDLE_KEY)
|
||||
self._resource_handle.name = self._handle
|
||||
return self._resource_handle
|
||||
|
||||
|
@ -24,7 +24,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python.client import pywrap_tf_session as c_api
|
||||
from tensorflow.python.eager import backprop_util
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
@ -56,11 +56,6 @@ try:
|
||||
sys.setdlopenflags(_default_dlopen_flags | ctypes.RTLD_LOCAL)
|
||||
|
||||
from tensorflow.python.pywrap_tensorflow_internal import *
|
||||
from tensorflow.python.pywrap_tensorflow_internal import __version__
|
||||
from tensorflow.python.pywrap_tensorflow_internal import __git_version__
|
||||
from tensorflow.python.pywrap_tensorflow_internal import __compiler_version__
|
||||
from tensorflow.python.pywrap_tensorflow_internal import __cxx11_abi_flag__
|
||||
from tensorflow.python.pywrap_tensorflow_internal import __monolithic_build__
|
||||
|
||||
if _use_dlopen_global_flags:
|
||||
pywrap_dlopen_global_flags.reset_dlopen_flags()
|
||||
|
@ -17,8 +17,6 @@ limitations under the License.
|
||||
* The includes are intentionally not alphabetically sorted, as the order of
|
||||
* includes follows dependency order */
|
||||
|
||||
%include "tensorflow/python/client/tf_session.i"
|
||||
|
||||
%include "tensorflow/python/grappler/cluster.i"
|
||||
%include "tensorflow/python/grappler/item.i"
|
||||
%include "tensorflow/python/grappler/tf_optimizer.i"
|
||||
|
@ -322,7 +322,6 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
|
||||
py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList");
|
||||
py::class_<TF_Function> TF_Function_class(m, "TF_Function");
|
||||
py::class_<TF_Buffer> TF_Buffer_class(m, "TF_Buffer");
|
||||
|
||||
m.def("TFE_Py_RegisterExceptionClass", [](const py::handle& e) {
|
||||
return tensorflow::pyo_or_throw(TFE_Py_RegisterExceptionClass(e.ptr()));
|
||||
@ -369,12 +368,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
|
||||
TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
|
||||
});
|
||||
m.def("TFE_ContextAddFunction", [](py::handle& ctx, py::handle& func) {
|
||||
m.def("TFE_ContextAddFunction", [](py::handle& ctx, TF_Function* func) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
SwigPyObject* sstable_swig = reinterpret_cast<SwigPyObject*>(func.ptr());
|
||||
auto function = reinterpret_cast<TF_Function*>(sstable_swig->ptr);
|
||||
TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), function,
|
||||
TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), func,
|
||||
status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
});
|
||||
@ -1066,13 +1063,6 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
// Util buffer helper functions
|
||||
m.def("TF_NewBufferFromString", &TF_NewBufferFromString,
|
||||
py::return_value_policy::reference);
|
||||
m.def("TF_NewBuffer", &TF_NewBuffer, py::return_value_policy::reference);
|
||||
m.def("TF_GetBuffer", [](TF_Buffer* buf) {
|
||||
return tensorflow::pyo_or_throw(PyBytes_FromStringAndSize(
|
||||
reinterpret_cast<const char*>(buf->data), buf->length));
|
||||
});
|
||||
m.def("TF_DeleteBuffer", &TF_DeleteBuffer,
|
||||
py::return_value_policy::reference);
|
||||
|
||||
// C API Enum
|
||||
|
||||
|
@ -25,7 +25,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
@ -253,11 +253,11 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
"""An internal class to help manage the TF_Buffer lifetime."""
|
||||
|
||||
def __init__(self, buf_string):
|
||||
self._buffer = pywrap_tensorflow.TF_NewBufferFromString(
|
||||
self._buffer = pywrap_tf_session.TF_NewBufferFromString(
|
||||
compat.as_bytes(buf_string))
|
||||
|
||||
def __del__(self):
|
||||
pywrap_tensorflow.TF_DeleteBuffer(self._buffer)
|
||||
pywrap_tf_session.TF_DeleteBuffer(self._buffer)
|
||||
|
||||
def __init__(self, name, num_replicas, pivot):
|
||||
"""Builds a new TPUReplicateContext.
|
||||
|
@ -21,7 +21,7 @@ from __future__ import print_function
|
||||
from tensorflow.core.protobuf import cluster_pb2
|
||||
from tensorflow.core.protobuf import device_filters_pb2
|
||||
from tensorflow.core.protobuf import tensorflow_server_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python.client import pywrap_tf_session as c_api
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import deprecation
|
||||
|
@ -134,15 +134,25 @@ def get_symbols(path_to_lib, re_filter):
|
||||
|
||||
# Example symbol line:
|
||||
# 954 00000000 SECT2BD notype () External | ?IsSequence@swig@tensorflow@@YA_NPEAU_object@@@Z (bool __cdecl tensorflow::swig::IsSequence(struct _object *))
|
||||
# Anomaly symbol line:
|
||||
# 00B 00000000 SECT4 notype External | _tensorflow_numpy_api.
|
||||
sym_filtered = []
|
||||
re_filter_comp = re.compile(r"{}".format(re_filter))
|
||||
|
||||
# Filter out symbol from the split line (`sym_split` in the for loop below).
|
||||
sym_line_filter = r".*\s+\| (.*) \(.*"
|
||||
sym_line_filter_anomaly = r".*\s+\| (.*)"
|
||||
|
||||
for sym_line in sym_split:
|
||||
if re_filter_comp.search(sym_line):
|
||||
sym = re.match(sym_line_filter, sym_line).groups()[0]
|
||||
try:
|
||||
sym = re.match(sym_line_filter, sym_line).groups()[0]
|
||||
except AttributeError:
|
||||
try:
|
||||
sym = re.match(sym_line_filter_anomaly, sym_line).groups()[0]
|
||||
except:
|
||||
raise RuntimeError("Unable to find the following symbol:[%s]" % sym_line)
|
||||
|
||||
sym_filtered.append(sym)
|
||||
|
||||
return sym_filtered
|
||||
|
@ -75,12 +75,13 @@ tensorflow::Status::code
|
||||
tensorflow::Status::error_message
|
||||
tensorflow::Status::ok()
|
||||
|
||||
[core_cpu_impl] # device_lib tfe
|
||||
[core_cpu_impl] # device_lib, tfe, tf_session
|
||||
tensorflow::Device::attributes
|
||||
tensorflow::DeviceFactory::AddDevices
|
||||
tensorflow::SessionOptions::SessionOptions
|
||||
tensorflow::DoQuantizeTrainingOnSerializedGraphDef
|
||||
tensorflow::DeviceFactory::ListAllPhysicalDevices
|
||||
tensorflow::SessionState::kTensorHandleResourceTypeName
|
||||
|
||||
[protos_all] # device_lib, dtypes
|
||||
tensorflow::DataType_IsValid
|
||||
@ -195,3 +196,50 @@ tensorflow::ExperimentalRunPassPipeline
|
||||
tensorflow::ExperimentalConvertSavedModelV1ToMlir
|
||||
tensorflow::ExperimentalConvertSavedModelToMlir
|
||||
tensorflow::ImportGraphDef
|
||||
|
||||
[op_gen_lib] # tf_session
|
||||
tensorflow::ApiDefMap::~ApiDefMap
|
||||
|
||||
[core_cpu_base_no_ops] # tf_session
|
||||
tensorflow::ShapeRefiner::~ShapeRefiner
|
||||
|
||||
[python_api] # tf_session
|
||||
tensorflow::AddControlInput
|
||||
tensorflow::SetAttr
|
||||
tensorflow::ClearAttr
|
||||
tensorflow::SetRequestedDevice
|
||||
tensorflow::UpdateEdge
|
||||
tensorflow::RemoveAllControlInputs
|
||||
tensorflow::SetRequireShapeInferenceFns
|
||||
tensorflow::ExtendSession
|
||||
tensorflow::GetHandleShapeAndType
|
||||
tensorflow::SetHandleShapeAndType
|
||||
tensorflow::AddWhileInputHack
|
||||
|
||||
[numpy_lib] # tf_session
|
||||
tensorflow::ImportNumpy
|
||||
_tensorflow_numpy_api
|
||||
|
||||
[tf_session_helper] # tf_session
|
||||
tensorflow::TF_NewSessionRef
|
||||
tensorflow::TF_SessionMakeCallable
|
||||
tensorflow::TF_SessionRunCallable
|
||||
tensorflow::TF_SessionReleaseCallable
|
||||
tensorflow::TF_Reset_wrapper
|
||||
tensorflow::EqualGraphDefWrapper
|
||||
tensorflow::EqualAttrValueWrapper
|
||||
tensorflow::TF_GraphGetTensorShapeHelper
|
||||
tensorflow::TF_SessionRun_wrapper
|
||||
tensorflow::TF_SessionPRunSetup_wrapper
|
||||
tensorflow::TF_SessionPRun_wrapper
|
||||
tensorflow::GetOperationInputs
|
||||
tensorflow::TF_OperationGetControlInputs_wrapper
|
||||
tensorflow::TF_OperationGetControlOutputs_wrapper
|
||||
tensorflow::TF_OperationOutputConsumers_wrapper
|
||||
tensorflow::TF_GraphToFunction_wrapper
|
||||
tensorflow::TF_GraphSetOutputHandleShapesAndTypes_wrapper
|
||||
tensorflow::TF_CreatePlaceholders
|
||||
tensorflow::TF_GraphSetTensorShape_wrapper
|
||||
tensorflow::TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper
|
||||
tensorflow::TF_TryEvaluateConstant_wrapper
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user