Export the Eager classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.

PiperOrigin-RevId: 286110711
Change-Id: I7bf6f6f4ce1d6bf3e8a3e40ef4a83f82333800f6
This commit is contained in:
Amit Patankar 2019-12-17 19:37:26 -08:00 committed by TensorFlower Gardener
parent 03341c4342
commit 7bd345bcbb
47 changed files with 1866 additions and 853 deletions

View File

@ -53,6 +53,20 @@ filegroup(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"c_api_internal.h",
"tf_status_helper.h",
"tf_status_internal.h",
"tf_tensor_internal.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
tf_cuda_library( tf_cuda_library(
name = "c_api_internal", name = "c_api_internal",
hdrs = [ hdrs = [

View File

@ -88,6 +88,18 @@ tf_cuda_library(
alwayslink = 1, alwayslink = 1,
) )
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
tf_cuda_library( tf_cuda_library(
name = "c_api_internal", name = "c_api_internal",
srcs = ["c_api_experimental.h"], srcs = ["c_api_experimental.h"],

View File

@ -439,6 +439,23 @@ tf_cc_test(
], ],
) )
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"attr_builder.h",
"context.h",
"eager_executor.h",
"eager_operation.h",
"kernel_and_device.h",
"tensor_handle.h",
"tensor_handle_data.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
filegroup( filegroup(
name = "srcs", name = "srcs",
srcs = glob( srcs = glob(

View File

@ -783,3 +783,20 @@ tf_cc_test(
"//tensorflow/core:worker_proto_cc", "//tensorflow/core:worker_proto_cc",
], ],
) )
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"call_options.h",
"message_wrappers.h",
"rendezvous_mgr_interface.h",
"server_lib.h",
"worker_cache.h",
"worker_env.h",
"worker_interface.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)

View File

@ -216,3 +216,16 @@ cc_library(
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
], ],
) )
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"eager_client.h",
"remote_tensor_handle.h",
"remote_tensor_handle_data.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)

View File

@ -853,6 +853,18 @@ tf_cc_tests(
], ],
) )
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"op_gen_lib.h",
"rendezvous.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
# All framewrok protos are self-contained, i.e. they only import other # All framewrok protos are self-contained, i.e. they only import other
# protos from the same package, so we can build the protos here and then # protos from the same package, so we can build the protos here and then
# link them from core:protos_all without circular dependencies. # link them from core:protos_all without circular dependencies.

View File

@ -523,3 +523,14 @@ tf_cc_test(
"//tensorflow/core:testlib", "//tensorflow/core:testlib",
], ],
) )
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"profiler_interface.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)

View File

@ -43,6 +43,17 @@ tf_cuda_library(
alwayslink = True, alwayslink = True,
) )
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"profiler_session.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
cc_library( cc_library(
name = "traceme", name = "traceme",
hdrs = ["traceme.h"], hdrs = ["traceme.h"],

View File

@ -171,6 +171,7 @@ py_library(
":platform", ":platform",
":proto_ops", ":proto_ops",
":pywrap_tensorflow", ":pywrap_tensorflow",
":pywrap_tfe",
":rnn_ops_gen", ":rnn_ops_gen",
":saver_test_utils", ":saver_test_utils",
":script_ops", ":script_ops",
@ -251,6 +252,7 @@ py_library(
deps = [ deps = [
":_pywrap_util_port", ":_pywrap_util_port",
":lib", ":lib",
":pywrap_tfe",
":util", ":util",
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",
"@absl_py//absl:app", "@absl_py//absl:app",
@ -477,13 +479,13 @@ cc_library(
cc_library( cc_library(
name = "pybind11_status", name = "pybind11_status",
hdrs = [ hdrs = [
"lib/core/py_exception_registry.h",
"lib/core/pybind11_status.h", "lib/core/pybind11_status.h",
"//tensorflow/c:headers", "//tensorflow/c:headers",
], ],
features = ["-parse_headers"], features = ["-parse_headers"],
visibility = tf_external_workspace_visible(visibility), visibility = tf_external_workspace_visible(visibility),
deps = [ deps = [
":py_exception_registry",
"//tensorflow/c:tf_status_headers", "//tensorflow/c:tf_status_headers",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
@ -1110,6 +1112,7 @@ py_library(
":lib", ":lib",
":platform", ":platform",
":pywrap_tensorflow", ":pywrap_tensorflow",
":pywrap_tfe",
":random_seed", ":random_seed",
":sparse_tensor", ":sparse_tensor",
":tensor_spec", ":tensor_spec",
@ -5492,7 +5495,6 @@ tf_py_wrap_cc(
"lib/io/py_record_reader.i", "lib/io/py_record_reader.i",
"lib/io/py_record_writer.i", "lib/io/py_record_writer.i",
"platform/base.i", "platform/base.i",
"pywrap_tfe.i",
"//tensorflow/compiler/mlir/python:mlir.i", "//tensorflow/compiler/mlir/python:mlir.i",
], ],
# add win_def_file for pywrap_tensorflow # add win_def_file for pywrap_tensorflow
@ -5573,7 +5575,12 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
":safe_ptr", # checkpoint_reader ":safe_ptr", # checkpoint_reader
":python_op_gen", # python_op_gen ":python_op_gen", # python_op_gen
":bfloat16_lib", # bfloat16 ":bfloat16_lib", # bfloat16
"//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader "//tensorflow/core/util/tensor_bundle", # checkpoint_reader
"//tensorflow/core/common_runtime/eager:eager_executor", # tfe
"//tensorflow/core/common_runtime/eager:context", # tfe
"//tensorflow/core/profiler/lib:profiler_session", # tfe
"//tensorflow/c:tf_status_helper", # tfe
] ]
# Filter the DEF file to reduce the number of symbols to 64K or less. # Filter the DEF file to reduce the number of symbols to 64K or less.
@ -7555,6 +7562,67 @@ py_library(
], ],
) )
py_library(
name = "pywrap_tfe",
srcs = ["pywrap_tfe.py"],
visibility = ["//visibility:public"],
deps = [
":_pywrap_tfe",
":pywrap_tensorflow",
],
)
tf_python_pybind_extension(
name = "_pywrap_tfe",
srcs = ["tfe_wrapper.cc"],
hdrs = [
"lib/core/safe_ptr.h",
"util/util.h",
":py_exception_registry_hdr",
"//tensorflow/c:headers",
"//tensorflow/c:pywrap_eager_hdrs",
"//tensorflow/c/eager:headers",
"//tensorflow/c/eager:pywrap_eager_hdrs",
"//tensorflow/core/common_runtime/eager:pywrap_eager_hdrs",
"//tensorflow/core/distributed_runtime:pywrap_eager_hdrs",
"//tensorflow/core/distributed_runtime/eager:pywrap_eager_hdrs",
"//tensorflow/core/framework:pywrap_eager_hdrs",
"//tensorflow/core/profiler/internal:pywrap_eager_hdrs",
"//tensorflow/core/profiler/lib:pywrap_eager_hdrs",
"//tensorflow/python/eager:pywrap_eager_hdrs",
],
module_name = "_pywrap_tfe",
deps = [
":pybind11_lib",
":pybind11_status",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@pybind11",
"//third_party/python_runtime:headers",
"//tensorflow/core:core_cpu_headers_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:platform",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
] + if_static(
extra_deps = [
"//tensorflow/core:eager_service_proto_cc",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:worker_proto_cc",
],
otherwise = [
"//tensorflow/core:eager_service_proto_cc_headers_only",
"//tensorflow/core:master_proto_cc_headers_only",
"//tensorflow/core:worker_proto_cc_headers_only",
],
),
)
tf_python_pybind_extension( tf_python_pybind_extension(
name = "_pywrap_graph_analyzer", name = "_pywrap_graph_analyzer",
srcs = ["grappler/graph_analyzer_tool_wrapper.cc"], srcs = ["grappler/graph_analyzer_tool_wrapper.cc"],

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i" %include "tensorflow/python/platform/base.i"
%{ %{
@ -23,6 +24,13 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h" #include "tensorflow/core/public/version.h"
#include "tensorflow/python/client/tf_session_helper.h" #include "tensorflow/python/client/tf_session_helper.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
// We were getting lucky on imports with safe_ptr.h being placed prior to
// tf_session which imported safe_ptr. We also need pywrap_tfe.h to cast
// one of the inputs to a graph function from a Python string to const char*.
// Helper function to convert a Python list of Tensors to a C++ vector of // Helper function to convert a Python list of Tensors to a C++ vector of
// TF_Outputs. // TF_Outputs.
@ -78,6 +86,9 @@ void PyInt64ListToVector(PyObject* py_int_seq, std::vector<int64_t>* vec) {
%} %}
%include "tensorflow/c/tf_datatype.h"
%include "tensorflow/c/tf_status.h"
%include "tensorflow/python/client/tf_sessionrun_wrapper.i" %include "tensorflow/python/client/tf_sessionrun_wrapper.i"
// Required to use PyArray_* functions. // Required to use PyArray_* functions.
@ -85,6 +96,14 @@ void PyInt64ListToVector(PyObject* py_int_seq, std::vector<int64_t>* vec) {
tensorflow::ImportNumpy(); tensorflow::ImportNumpy();
%} %}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* op_name {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// TensorFlow version and GraphDef versions // TensorFlow version and GraphDef versions
%constant const char* __version__ = TF_VERSION_STRING; %constant const char* __version__ = TF_VERSION_STRING;
%constant int GRAPH_DEF_VERSION = TF_GRAPH_DEF_VERSION; %constant int GRAPH_DEF_VERSION = TF_GRAPH_DEF_VERSION;
@ -174,6 +193,12 @@ tensorflow::ImportNumpy();
// See comment for "%noexception TF_SessionRun_wrapper;" // See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_OperationGetControlInputs_wrapper; %noexception TF_OperationGetControlInputs_wrapper;
// Migrate one function from pywrap_tfe.i
%include "tensorflow/c/c_api_experimental.h"
%unignore TF_ImportGraphDefOptionsSetValidateColocationConstraints;
%noexception TF_ImportGraphDefOptionsSetValidateColocationConstraints;
// Build a Python list of TF_Operation* and return it. // Build a Python list of TF_Operation* and return it.
%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlInputs_wrapper { %typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlInputs_wrapper {
$result = PyList_New($1.size()); $result = PyList_New($1.size());

View File

@ -268,7 +268,7 @@ py_library(
"//tensorflow/python:device", "//tensorflow/python:device",
"//tensorflow/python:dtypes", "//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tfe",
"//tensorflow/python:summary_ops_v2", "//tensorflow/python:summary_ops_v2",
"//tensorflow/python:tensor_util", "//tensorflow/python:tensor_util",
"//tensorflow/python:training", "//tensorflow/python:training",

View File

@ -24,7 +24,7 @@ import functools
import threading import threading
import weakref import weakref
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
@ -944,7 +944,7 @@ class _MirroredReplicaThread(threading.Thread):
self.record_thread_local_summary_state() self.record_thread_local_summary_state()
self.record_thread_local_eager_context_state() self.record_thread_local_eager_context_state()
self.context_device_policy = ( self.context_device_policy = (
pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy( pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
ctx._context_handle)) # pylint: disable=protected-access ctx._context_handle)) # pylint: disable=protected-access
self.graph = ops.get_default_graph() self.graph = ops.get_default_graph()
with ops.init_scope(): with ops.init_scope():

View File

@ -56,6 +56,18 @@ cc_library(
], ],
) )
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"pywrap_tensor_conversion.h",
"pywrap_tfe.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
# Transitive dependencies of this target will be included in the pip package. # Transitive dependencies of this target will be included in the pip package.
py_library( py_library(
name = "eager_pip", name = "eager_pip",
@ -90,7 +102,7 @@ py_library(
deps = [ deps = [
":context", ":context",
"//tensorflow/python:errors", "//tensorflow/python:errors",
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tfe",
], ],
) )
@ -100,7 +112,7 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"], visibility = ["//tensorflow:internal"],
deps = [ deps = [
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tfe",
], ],
) )
@ -121,7 +133,7 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"], visibility = ["//tensorflow:internal"],
deps = [ deps = [
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tfe",
], ],
) )
@ -131,13 +143,14 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"], visibility = ["//tensorflow:internal"],
deps = [ deps = [
":eager_util",
":executor", ":executor",
":monitoring", ":monitoring",
"//tensorflow/python:device", "//tensorflow/python:device",
"//tensorflow/python:device_spec", "//tensorflow/python:device_spec",
"//tensorflow/python:errors", "//tensorflow/python:errors",
"//tensorflow/python:platform", "//tensorflow/python:platform",
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tfe",
"//tensorflow/python:tf2", "//tensorflow/python:tf2",
"//tensorflow/python:util", "//tensorflow/python:util",
"//third_party/py/numpy", "//third_party/py/numpy",
@ -164,8 +177,8 @@ py_library(
"//third_party/py/tf_agents:__subpackages__", "//third_party/py/tf_agents:__subpackages__",
], ],
deps = [ deps = [
"//tensorflow/python:c_api_util", ":eager_util",
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tfe",
"//tensorflow/python:util", "//tensorflow/python:util",
], ],
) )
@ -187,7 +200,8 @@ py_library(
visibility = ["//tensorflow:internal"], visibility = ["//tensorflow:internal"],
deps = [ deps = [
":context", ":context",
"//tensorflow/python:pywrap_tensorflow", ":eager_util",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:util", "//tensorflow/python:util",
], ],
) )
@ -209,7 +223,8 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"], visibility = ["//tensorflow:internal"],
deps = [ deps = [
"//tensorflow/python:pywrap_tensorflow", ":eager_util",
"//tensorflow/python:pywrap_tfe",
], ],
) )
@ -298,7 +313,7 @@ cuda_py_test(
"//tensorflow/python:errors", "//tensorflow/python:errors",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tfe",
"//third_party/py/numpy", "//third_party/py/numpy",
], ],
) )
@ -410,7 +425,7 @@ py_library(
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",
"//tensorflow/python:dtypes", "//tensorflow/python:dtypes",
"//tensorflow/python:lib", "//tensorflow/python:lib",
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tfe",
"//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_shape",
"//tensorflow/python:util", "//tensorflow/python:util",
"@six_archive//:six", "@six_archive//:six",
@ -496,7 +511,7 @@ py_library(
"//tensorflow/python:errors", "//tensorflow/python:errors",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tfe",
"//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_shape",
"//tensorflow/python:unconnected_gradients", "//tensorflow/python:unconnected_gradients",
"//tensorflow/python:util", "//tensorflow/python:util",
@ -524,7 +539,7 @@ py_library(
deps = [ deps = [
":forwardprop_util", ":forwardprop_util",
"//tensorflow/python:platform", "//tensorflow/python:platform",
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tfe",
"//tensorflow/python:util", "//tensorflow/python:util",
], ],
) )
@ -535,7 +550,18 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"], visibility = ["//tensorflow:internal"],
deps = [ deps = [
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tfe",
],
)
py_library(
name = "eager_util",
srcs = ["eager_util.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:util",
], ],
) )
@ -552,7 +578,7 @@ cuda_py_test(
":remote", ":remote",
":test", ":test",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tfe",
"//tensorflow/python:random_ops", "//tensorflow/python:random_ops",
"//tensorflow/python/keras", "//tensorflow/python/keras",
"//third_party/py/numpy", "//third_party/py/numpy",
@ -637,7 +663,7 @@ tf_py_test(
":test", ":test",
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tfe",
"//tensorflow/python:random_ops", "//tensorflow/python:random_ops",
"//tensorflow/python:test_ops", "//tensorflow/python:test_ops",
"//third_party/py/numpy", "//third_party/py/numpy",
@ -649,7 +675,7 @@ py_library(
srcs = ["imperative_grad.py"], srcs = ["imperative_grad.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tfe",
"//tensorflow/python:unconnected_gradients", "//tensorflow/python:unconnected_gradients",
"//tensorflow/python:util", "//tensorflow/python:util",
], ],

View File

@ -24,7 +24,7 @@ import sys
import six import six
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python import _pywrap_utils from tensorflow.python import _pywrap_utils
from tensorflow.python.eager import backprop_util from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -71,19 +71,25 @@ def op_attr_type(op_type, attr_name):
except KeyError: except KeyError:
context.ensure_initialized() context.ensure_initialized()
h = context.context()._handle # pylint: disable=protected-access h = context.context()._handle # pylint: disable=protected-access
attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType(h, op_type, attr_name) attr_type = pywrap_tfe.TFE_OpNameGetAttrType(h, op_type, attr_name)
_op_attr_type_cache[(op_type, attr_name)] = attr_type _op_attr_type_cache[(op_type, attr_name)] = attr_type
return attr_type return attr_type
def make_attr(attr_type, value): def make_attr(attr_type, value):
if attr_type == pywrap_tensorflow.TF_ATTR_TYPE: # pybind11 enums do not return the raw value like SWIG enums do. They are
# useful when comparing amongst each other but not direct integers as we are
# doing in most tests.
# https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types
# TODO(amitpatankar): After all SWIG transitions, convert the enum comparisons
# from integer value to class.
if attr_type == int(pywrap_tfe.TF_ATTR_TYPE):
return dtypes.as_dtype(value) return dtypes.as_dtype(value)
elif attr_type == [pywrap_tensorflow.TF_ATTR_TYPE]: elif attr_type == [int(pywrap_tfe.TF_ATTR_TYPE)]:
return [dtypes.as_dtype(v) for v in value] return [dtypes.as_dtype(v) for v in value]
elif attr_type == pywrap_tensorflow.TF_ATTR_SHAPE: elif attr_type == int(pywrap_tfe.TF_ATTR_SHAPE):
return tensor_shape.as_shape(value).as_proto() return tensor_shape.as_shape(value).as_proto()
elif attr_type == [pywrap_tensorflow.TF_ATTR_SHAPE]: elif attr_type == [int(pywrap_tfe.TF_ATTR_SHAPE)]:
return [tensor_shape.as_shape(v).as_proto() for v in value] return [tensor_shape.as_shape(v).as_proto() for v in value]
elif isinstance(value, str): elif isinstance(value, str):
return value.encode() return value.encode()
@ -141,16 +147,15 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
return grad_fn(mock_op, *out_grads) return grad_fn(mock_op, *out_grads)
pywrap_tensorflow.TFE_Py_RegisterGradientFunction(_gradient_function) pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function)
def _must_record_gradient(): def _must_record_gradient():
return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty() return not pywrap_tfe.TFE_Py_TapeSetIsEmpty()
def _record_gradient(op_name, inputs, attrs, results): def _record_gradient(op_name, inputs, attrs, results):
return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs, return pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, results)
results)
execute.must_record_gradient = _must_record_gradient execute.must_record_gradient = _must_record_gradient
@ -688,7 +693,7 @@ _default_vspace = imperative_grad.VSpace(
zeros_like_fn=default_gradient.zeros_like, zeros_like_fn=default_gradient.zeros_like,
ones_like_fn=default_gradient.ones_like, ones_like_fn=default_gradient.ones_like,
graph_shape_fn=gen_array_ops.shape) graph_shape_fn=gen_array_ops.shape)
pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace) pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace)
def _handle_or_self(x): def _handle_or_self(x):

View File

@ -21,7 +21,7 @@ import functools
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
@ -1014,19 +1014,19 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
def testGetAttrType(self): def testGetAttrType(self):
typ = backprop.op_attr_type('Add', 'T') typ = backprop.op_attr_type('Add', 'T')
self.assertEqual(typ, pywrap_tensorflow.TF_ATTR_TYPE) self.assertEqual(typ, int(pywrap_tfe.TF_ATTR_TYPE))
def testGetAttrList(self): def testGetAttrList(self):
typ = backprop.op_attr_type('MaxPool', 'ksize') typ = backprop.op_attr_type('MaxPool', 'ksize')
self.assertEqual(typ, [pywrap_tensorflow.TF_ATTR_INT]) self.assertEqual(typ, [int(pywrap_tfe.TF_ATTR_INT)])
def testMakeAttrType(self): def testMakeAttrType(self):
self.assertEqual(dtypes.float32, self.assertEqual(dtypes.float32,
backprop.make_attr(pywrap_tensorflow.TF_ATTR_TYPE, 1)) backprop.make_attr(int(pywrap_tfe.TF_ATTR_TYPE), 1))
def testMakeAttrTypeList(self): def testMakeAttrTypeList(self):
self.assertEqual([dtypes.float32], self.assertEqual([dtypes.float32],
backprop.make_attr([pywrap_tensorflow.TF_ATTR_TYPE], [1])) backprop.make_attr([int(pywrap_tfe.TF_ATTR_TYPE)], [1]))
def testMulType(self): def testMulType(self):
@ -1040,7 +1040,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
def testMakeAttrShape(self): def testMakeAttrShape(self):
for s in ([], None, [1, 2, 3], [None, None], [1, None, 3]): for s in ([], None, [1, 2, 3], [None, None], [1, None, 3]):
expected = tensor_shape.TensorShape(s).as_proto() expected = tensor_shape.TensorShape(s).as_proto()
actual = backprop.make_attr(pywrap_tensorflow.TF_ATTR_SHAPE, s) actual = backprop.make_attr(int(pywrap_tfe.TF_ATTR_SHAPE), s)
self.assertEqual( self.assertEqual(
expected, expected,
actual, actual,
@ -1051,7 +1051,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
shape_list = [[], None, [1, 2, 3], [None, None], [1, None, 3]] shape_list = [[], None, [1, 2, 3], [None, None], [1, None, 3]]
self.assertEqual( self.assertEqual(
[tensor_shape.TensorShape(s).as_proto() for s in shape_list], [tensor_shape.TensorShape(s).as_proto() for s in shape_list],
backprop.make_attr([pywrap_tensorflow.TF_ATTR_SHAPE], shape_list)) backprop.make_attr([int(pywrap_tfe.TF_ATTR_SHAPE)], shape_list))
def testArgsGradientFunction(self): def testArgsGradientFunction(self):

View File

@ -39,7 +39,7 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop # pylint: disable=unused-import from tensorflow.python.eager import backprop # pylint: disable=unused-import
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -76,10 +76,10 @@ def c_tfe_py_fastpath_execute(a,
assert ctx.executing_eagerly( assert ctx.executing_eagerly(
), "The prototype doesn't contain C code for graph construction" ), "The prototype doesn't contain C code for graph construction"
try: try:
return pywrap_tensorflow.TFE_Py_FastPathExecute( return pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
ctx._handle, ctx.device_name, "MatMul", name, "MatMul", name, ctx.op_callbacks,
ctx.op_callbacks, a, b, "transpose_a", transpose_a, a, b, "transpose_a", transpose_a,
"transpose_b", transpose_b) "transpose_b", transpose_b)
except core._NotOkStatusException as e: except core._NotOkStatusException as e:
if name is not None: if name is not None:
message = e.message + " name: " + name message = e.message + " name: " + name
@ -339,8 +339,7 @@ class MicroBenchmarks(test.Benchmark):
inputs = [m] inputs = [m]
def f(): def f():
pywrap_tensorflow.TFE_Py_Execute(ctx_handle, None, "Identity", inputs, pywrap_tfe.TFE_Py_Execute(ctx_handle, None, "Identity", inputs, attrs, 1)
attrs, 1)
self._run(f, 30000) self._run(f, 30000)
@ -406,8 +405,7 @@ class MicroBenchmarks(test.Benchmark):
m.dtype.as_datatype_enum) m.dtype.as_datatype_enum)
def func(): def func():
pywrap_tensorflow.TFE_Py_Execute(ctx_handle, device, "MatMul", inputs, pywrap_tfe.TFE_Py_Execute(ctx_handle, device, "MatMul", inputs, attrs, 1)
attrs, 1)
self._run(func, num_iters) self._run(func, num_iters)

View File

@ -18,27 +18,27 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
class CancellationManager(object): class CancellationManager(object):
"""A mechanism for cancelling blocking computation.""" """A mechanism for cancelling blocking computation."""
def __init__(self): def __init__(self):
self._impl = pywrap_tensorflow.TFE_NewCancellationManager() self._impl = pywrap_tfe.TFE_NewCancellationManager()
@property @property
def is_cancelled(self): def is_cancelled(self):
"""Returns `True` if `CancellationManager.start_cancel` has been called.""" """Returns `True` if `CancellationManager.start_cancel` has been called."""
return pywrap_tensorflow.TFE_CancellationManagerIsCancelled(self._impl) return pywrap_tfe.TFE_CancellationManagerIsCancelled(self._impl)
def start_cancel(self): def start_cancel(self):
"""Cancels blocking operations that have been registered with this object.""" """Cancels blocking operations that have been registered with this object."""
pywrap_tensorflow.TFE_CancellationManagerStartCancel(self._impl) pywrap_tfe.TFE_CancellationManagerStartCancel(self._impl)
def get_cancelable_function(self, concrete_function): def get_cancelable_function(self, concrete_function):
# pylint: disable=protected-access # pylint: disable=protected-access
return concrete_function._experimental_with_cancellation_manager(self) return concrete_function._experimental_with_cancellation_manager(self)
def __del__(self): def __del__(self):
pywrap_tensorflow.TFE_DeleteCancellationManager(self._impl) pywrap_tfe.TFE_DeleteCancellationManager(self._impl)

View File

@ -29,11 +29,11 @@ import six
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python import tf2 from tensorflow.python import tf2
from tensorflow.python.eager import eager_util as c_api_util
from tensorflow.python.eager import executor from tensorflow.python.eager import executor
from tensorflow.python.eager import monitoring from tensorflow.python.eager import monitoring
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev from tensorflow.python.framework import device as pydev
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import is_in_graph_mode from tensorflow.python.util import is_in_graph_mode
@ -54,17 +54,17 @@ _starting_device_spec = pydev.DeviceSpec.from_string("")
_MAXINT32 = 2**31 - 1 _MAXINT32 = 2**31 - 1
DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT
DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN
DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT
DEVICE_PLACEMENT_SILENT_FOR_INT32 = ( DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
SYNC = 0 SYNC = 0
ASYNC = 1 ASYNC = 1
MIRRORING_NONE = pywrap_tensorflow.TFE_MIRRORING_NONE MIRRORING_NONE = pywrap_tfe.TFE_MIRRORING_NONE
MIRRORING_ALL = pywrap_tensorflow.TFE_MIRRORING_ALL MIRRORING_ALL = pywrap_tfe.TFE_MIRRORING_ALL
_KEEP_ALIVE_SECS = 600 _KEEP_ALIVE_SECS = 600
@ -444,7 +444,7 @@ class Context(object):
self._rng = random.Random(seed) self._rng = random.Random(seed)
# Also clear the kernel cache, to reset any existing seeds # Also clear the kernel cache, to reset any existing seeds
if self._context_handle is not None: if self._context_handle is not None:
pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle) pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
def _internal_operation_seed(self): def _internal_operation_seed(self):
"""Returns a fake operation seed. """Returns a fake operation seed.
@ -463,12 +463,11 @@ class Context(object):
# Store list of devices # Store list of devices
logical_devices = [] logical_devices = []
context_devices = [] context_devices = []
device_list = pywrap_tensorflow.TFE_ContextListDevices( device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle)
self._context_handle)
try: try:
self._num_gpus = 0 self._num_gpus = 0
for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)): for i in range(pywrap_tfe.TF_DeviceListCount(device_list)):
dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i) dev_name = pywrap_tfe.TF_DeviceListName(device_list, i)
context_devices.append(pydev.canonical_name(dev_name)) context_devices.append(pydev.canonical_name(dev_name))
spec = pydev.DeviceSpec.from_string(dev_name) spec = pydev.DeviceSpec.from_string(dev_name)
# If the job is localhost, we assume that the cluster has not yet been # If the job is localhost, we assume that the cluster has not yet been
@ -477,14 +476,14 @@ class Context(object):
spec = spec.replace(job=None, replica=None, task=None) spec = spec.replace(job=None, replica=None, task=None)
logical_devices.append( logical_devices.append(
LogicalDevice(name=spec.to_string(), device_type=spec.device_type)) LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i) dev_type = pywrap_tfe.TF_DeviceListType(device_list, i)
if dev_type == "GPU": if dev_type == "GPU":
self._num_gpus += 1 self._num_gpus += 1
finally: finally:
self._logical_devices = logical_devices self._logical_devices = logical_devices
self._context_devices = context_devices self._context_devices = context_devices
pywrap_tensorflow.TF_DeleteDeviceList(device_list) pywrap_tfe.TF_DeleteDeviceList(device_list)
def ensure_initialized(self): def ensure_initialized(self):
"""Initialize handle and devices if not already done so.""" """Initialize handle and devices if not already done so."""
@ -494,36 +493,34 @@ class Context(object):
if self._initialized: if self._initialized:
return return
assert self._context_devices is None assert self._context_devices is None
opts = pywrap_tensorflow.TFE_NewContextOptions() opts = pywrap_tfe.TFE_NewContextOptions()
try: try:
config_str = self.config.SerializeToString() config_str = self.config.SerializeToString()
pywrap_tensorflow.TFE_ContextOptionsSetConfig(opts, config_str) pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str)
if self._device_policy is not None: if self._device_policy is not None:
pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy( pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy(
opts, self._device_policy) opts, self._device_policy)
if self._mirroring_policy is not None: if self._mirroring_policy is not None:
pywrap_tensorflow.TFE_ContextOptionsSetMirroringPolicy( pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy(
opts, self._mirroring_policy) opts, self._mirroring_policy)
if self._default_is_async == ASYNC: if self._default_is_async == ASYNC:
pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True) pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True)
if self._lazy_remote_inputs_copy is not None: if self._lazy_remote_inputs_copy is not None:
pywrap_tensorflow.TFE_ContextOptionsSetLazyRemoteInputsCopy( pywrap_tfe.TFE_ContextOptionsSetLazyRemoteInputsCopy(
opts, self._lazy_remote_inputs_copy) opts, self._lazy_remote_inputs_copy)
context_handle = pywrap_tensorflow.TFE_NewContext(opts) context_handle = pywrap_tfe.TFE_NewContext(opts)
finally: finally:
pywrap_tensorflow.TFE_DeleteContextOptions(opts) pywrap_tfe.TFE_DeleteContextOptions(opts)
assert not (self._server_def and self._collective_ops_server_def), ( assert not (self._server_def and self._collective_ops_server_def), (
"Cannot enable remote execution as well as collective ops at the " "Cannot enable remote execution as well as collective ops at the "
"moment. If this is important to you, please file an issue.") "moment. If this is important to you, please file an issue.")
if self._server_def is not None: if self._server_def is not None:
server_def_str = self._server_def.SerializeToString() server_def_str = self._server_def.SerializeToString()
pywrap_tensorflow.TFE_ContextSetServerDef(context_handle, pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS,
_KEEP_ALIVE_SECS, server_def_str)
server_def_str)
elif self._collective_ops_server_def is not None: elif self._collective_ops_server_def is not None:
server_def_str = self._collective_ops_server_def.SerializeToString() server_def_str = self._collective_ops_server_def.SerializeToString()
pywrap_tensorflow.TFE_EnableCollectiveOps(context_handle, pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str)
server_def_str)
self._context_handle = context_handle self._context_handle = context_handle
self._initialize_logical_devices() self._initialize_logical_devices()
@ -532,7 +529,7 @@ class Context(object):
def _clear_caches(self): def _clear_caches(self):
self.ones_rank_cache().flush() self.ones_rank_cache().flush()
self.zeros_cache().flush() self.zeros_cache().flush()
pywrap_tensorflow.TFE_ClearScalarCache() pywrap_tfe.TFE_ClearScalarCache()
def get_server_def(self): def get_server_def(self):
return self._server_def return self._server_def
@ -563,8 +560,8 @@ class Context(object):
if self._context_handle: if self._context_handle:
server_def_str = server_def.SerializeToString() server_def_str = server_def.SerializeToString()
pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle, pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs,
keep_alive_secs, server_def_str) server_def_str)
self._initialize_logical_devices() self._initialize_logical_devices()
# Clear all the caches in case there are remote tensors in them. # Clear all the caches in case there are remote tensors in them.
@ -592,9 +589,8 @@ class Context(object):
if self._context_handle: if self._context_handle:
server_def_str = server_def.SerializeToString() server_def_str = server_def.SerializeToString()
pywrap_tensorflow.TFE_ContextUpdateServerDef(self._context_handle, pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle,
keep_alive_secs, keep_alive_secs, server_def_str)
server_def_str)
self._initialize_logical_devices() self._initialize_logical_devices()
self._clear_caches() self._clear_caches()
@ -614,8 +610,7 @@ class Context(object):
""" """
# TODO(yuefengz): support checking multiple workers. # TODO(yuefengz): support checking multiple workers.
if self._context_handle: if self._context_handle:
return pywrap_tensorflow.TFE_ContextCheckAlive(self._context_handle, return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name)
worker_name)
else: else:
raise ValueError("Context is not initialized.") raise ValueError("Context is not initialized.")
@ -808,8 +803,8 @@ class Context(object):
self.executor.wait() self.executor.wait()
executor_new = executor.new_executor(enable_async) executor_new = executor.new_executor(enable_async)
self._thread_local_data.executor = executor_new self._thread_local_data.executor = executor_new
pywrap_tensorflow.TFE_ContextSetExecutorForThread( pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle,
self._context_handle, executor_new.handle()) executor_new.handle())
else: else:
self._default_is_async = enable_async self._default_is_async = enable_async
@ -823,13 +818,12 @@ class Context(object):
def executor(self): def executor(self):
ensure_initialized() ensure_initialized()
return executor.Executor( return executor.Executor(
pywrap_tensorflow.TFE_ContextGetExecutorForThread(self._context_handle)) pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle))
@executor.setter @executor.setter
def executor(self, e): def executor(self, e):
ensure_initialized() ensure_initialized()
pywrap_tensorflow.TFE_ContextSetExecutorForThread(self._context_handle, pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle())
e.handle())
@property @property
def config(self): def config(self):
@ -1015,7 +1009,7 @@ class Context(object):
fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper). fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
""" """
self.ensure_initialized() self.ensure_initialized()
pywrap_tensorflow.TFE_ContextAddFunction(self._handle, fn) pywrap_tfe.TFE_ContextAddFunction(self._handle, fn)
def add_function_def(self, fdef): def add_function_def(self, fdef):
"""Add a function definition to the context. """Add a function definition to the context.
@ -1028,8 +1022,8 @@ class Context(object):
""" """
self.ensure_initialized() self.ensure_initialized()
fdef_string = fdef.SerializeToString() fdef_string = fdef.SerializeToString()
pywrap_tensorflow.TFE_ContextAddFunctionDef( pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string,
self._handle, fdef_string, len(fdef_string)) len(fdef_string))
def remove_function(self, name): def remove_function(self, name):
"""Remove a function from the context. """Remove a function from the context.
@ -1040,12 +1034,12 @@ class Context(object):
name: function signature name. name: function signature name.
""" """
self.ensure_initialized() self.ensure_initialized()
pywrap_tensorflow.TFE_ContextRemoveFunction(self._handle, name) pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name)
def has_function(self, name): def has_function(self, name):
"""Check if a function `name` is registered.""" """Check if a function `name` is registered."""
self.ensure_initialized() self.ensure_initialized()
return bool(pywrap_tensorflow.TFE_ContextHasFunction(self._handle, name)) return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name))
def add_op_callback(self, callback): def add_op_callback(self, callback):
"""Add a post-op callback to the context. """Add a post-op callback to the context.
@ -1101,7 +1095,7 @@ class Context(object):
if self._physical_devices is not None: if self._physical_devices is not None:
return return
devs = pywrap_tensorflow.TF_ListPhysicalDevices() devs = pywrap_tfe.TF_ListPhysicalDevices()
self._physical_devices = [ self._physical_devices = [
PhysicalDevice(name=d.decode(), PhysicalDevice(name=d.decode(),
device_type=d.decode().split(":")[1]) for d in devs] device_type=d.decode().split(":")[1]) for d in devs]
@ -1434,7 +1428,7 @@ class Context(object):
def device_policy(self): def device_policy(self):
# Only get the policy from the context if it has already been initialized # Only get the policy from the context if it has already been initialized
if self._context_handle is not None: if self._context_handle is not None:
return pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(self._handle) return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle)
return self._device_policy return self._device_policy
@ -1448,14 +1442,14 @@ class Context(object):
# Only set the policy if the context has already been initialized # Only set the policy if the context has already been initialized
if self._context_handle is not None: if self._context_handle is not None:
pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy( pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy(
self._handle, self._device_policy) self._handle, self._device_policy)
@property @property
def mirroring_policy(self): def mirroring_policy(self):
# Only get the policy from the context if it has already been initialized # Only get the policy from the context if it has already been initialized
if self._context_handle is not None: if self._context_handle is not None:
return pywrap_tensorflow.TFE_ContextGetMirroringPolicy(self._handle) return pywrap_tfe.TFE_ContextGetMirroringPolicy(self._handle)
return self._mirroring_policy return self._mirroring_policy
@ -1469,7 +1463,7 @@ class Context(object):
# Only set the policy if the context has already been initialized # Only set the policy if the context has already been initialized
if self._context_handle is not None: if self._context_handle is not None:
pywrap_tensorflow.TFE_ContextSetThreadLocalMirroringPolicy( pywrap_tfe.TFE_ContextSetThreadLocalMirroringPolicy(
self._handle, self._mirroring_policy) self._handle, self._mirroring_policy)
@property @property
@ -1495,13 +1489,13 @@ class Context(object):
and to stop tracing call context.disable_run_metadata(). and to stop tracing call context.disable_run_metadata().
""" """
self.ensure_initialized() self.ensure_initialized()
pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._handle) pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle)
def disable_run_metadata(self): def disable_run_metadata(self):
"""Disables tracing of op execution via RunMetadata.""" """Disables tracing of op execution via RunMetadata."""
if not self._context_handle: if not self._context_handle:
return return
pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle) pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle)
def enable_graph_collection(self): def enable_graph_collection(self):
"""Enables graph collection of executed functions. """Enables graph collection of executed functions.
@ -1510,13 +1504,13 @@ class Context(object):
and to stop collecting graphs call context.disable_graph_collection(). and to stop collecting graphs call context.disable_graph_collection().
""" """
self.ensure_initialized() self.ensure_initialized()
pywrap_tensorflow.TFE_ContextEnableGraphCollection(self._handle) pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle)
def disable_graph_collection(self): def disable_graph_collection(self):
"""Disables graph collection of executed functions.""" """Disables graph collection of executed functions."""
if not self._context_handle: if not self._context_handle:
return return
pywrap_tensorflow.TFE_ContextDisableGraphCollection(self._context_handle) pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle)
def export_run_metadata(self): def export_run_metadata(self):
"""Returns a RunMetadata proto with accumulated information. """Returns a RunMetadata proto with accumulated information.
@ -1530,9 +1524,8 @@ class Context(object):
if not self._context_handle: if not self._context_handle:
return None return None
with c_api_util.tf_buffer() as buffer_: with c_api_util.tf_buffer() as buffer_:
pywrap_tensorflow.TFE_ContextExportRunMetadata( pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_)
self._context_handle, buffer_) proto_data = pywrap_tfe.TF_GetBuffer(buffer_)
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
run_metadata = config_pb2.RunMetadata() run_metadata = config_pb2.RunMetadata()
run_metadata.ParseFromString(compat.as_bytes(proto_data)) run_metadata.ParseFromString(compat.as_bytes(proto_data))
return run_metadata return run_metadata
@ -1543,10 +1536,10 @@ class Context(object):
return self._context_switches return self._context_switches
def start_step(self): def start_step(self):
pywrap_tensorflow.TFE_ContextStartStep(self._handle) pywrap_tfe.TFE_ContextStartStep(self._handle)
def end_step(self): def end_step(self):
pywrap_tensorflow.TFE_ContextEndStep(self._handle) pywrap_tfe.TFE_ContextEndStep(self._handle)
class _EagerDeviceContext(object): class _EagerDeviceContext(object):
@ -1608,7 +1601,7 @@ _context_lock = threading.Lock()
def _set_context_locked(ctx): def _set_context_locked(ctx):
global _context global _context
pywrap_tensorflow.TFE_Py_SetEagerContext(ctx) pywrap_tfe.TFE_Py_SetEagerContext(ctx)
_context = ctx _context = ctx

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
# Trace of execution and memory usage. # Trace of execution and memory usage.
@ -46,7 +46,7 @@ class _NotOkStatusException(Exception):
return "%s: %s" % (e.__class__.__name__, e) return "%s: %s" % (e.__class__.__name__, e)
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(_NotOkStatusException) pywrap_tfe.TFE_Py_RegisterExceptionClass(_NotOkStatusException)
class _FallbackException(Exception): class _FallbackException(Exception):
@ -71,4 +71,4 @@ class _SymbolicException(Exception):
pass pass
pywrap_tensorflow.TFE_Py_RegisterFallbackExceptionClass(_FallbackException) pywrap_tfe.TFE_Py_RegisterFallbackExceptionClass(_FallbackException)

View File

@ -26,7 +26,7 @@ import threading
import numpy as np import numpy as np
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import core from tensorflow.python.eager import core
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
@ -602,8 +602,8 @@ class TFETest(test_util.TensorFlowTestCase):
def testRegisterExceptionClass(self): def testRegisterExceptionClass(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(str) pywrap_tfe.TFE_Py_RegisterExceptionClass(str)
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(core._NotOkStatusException) # pylint: disable=protected-access pywrap_tfe.TFE_Py_RegisterExceptionClass(core._NotOkStatusException) # pylint: disable=protected-access
# TODO(agarwal): add tests passing incorrect typed values to attrs. # TODO(agarwal): add tests passing incorrect typed values to attrs.
def testExecuteBasic(self): def testExecuteBasic(self):

View File

@ -0,0 +1,61 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for using the TensorFlow Eager using the C API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tfe as c_api
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib
# We temporarily need a duplicate tf_buffer function in eager_util. The
# c_api_util is still relying on SWIG and is thus incompatible until
# we migrate over. We can delete this once we migrate tf_session.i
@tf_contextlib.contextmanager
def tf_buffer(data=None):
"""Context manager that creates and deletes TF_Buffer.
Example usage:
with tf_buffer() as buf:
# get serialized graph def into buf
...
proto_data = c_api.TF_GetBuffer(buf)
graph_def.ParseFromString(compat.as_bytes(proto_data))
# buf has been deleted
with tf_buffer(some_string) as buf:
c_api.TF_SomeFunction(buf)
# buf has been deleted
Args:
data: An optional `bytes`, `str`, or `unicode` object. If not None, the
yielded buffer will contain this data.
Yields:
Created TF_Buffer
"""
if data:
buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
else:
buf = c_api.TF_NewBuffer()
try:
yield buf
finally:
c_api.TF_DeleteBuffer(buf)

View File

@ -22,7 +22,7 @@ import six
from google.protobuf import text_format from google.protobuf import text_format
from tensorflow.core.framework import tensor_pb2 from tensorflow.core.framework import tensor_pb2
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import core from tensorflow.python.eager import core
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -56,9 +56,8 @@ def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
# pylint: disable=protected-access # pylint: disable=protected-access
try: try:
ctx.ensure_initialized() ctx.ensure_initialized()
tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name, tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
op_name, inputs, attrs, inputs, attrs, num_outputs)
num_outputs)
except core._NotOkStatusException as e: except core._NotOkStatusException as e:
if name is not None: if name is not None:
message = e.message + " name: " + name message = e.message + " name: " + name
@ -111,9 +110,10 @@ def execute_with_cancellation(op_name,
# pylint: disable=protected-access # pylint: disable=protected-access
try: try:
ctx.ensure_initialized() ctx.ensure_initialized()
tensors = pywrap_tensorflow.TFE_Py_ExecuteCancelable( tensors = pywrap_tfe.TFE_Py_ExecuteCancelable(ctx._handle, device_name,
ctx._handle, device_name, op_name, inputs, attrs, op_name, inputs, attrs,
cancellation_manager._impl, num_outputs) cancellation_manager._impl,
num_outputs)
except core._NotOkStatusException as e: except core._NotOkStatusException as e:
if name is not None: if name is not None:
message = e.message + " name: " + name message = e.message + " name: " + name

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
class Executor(object): class Executor(object):
@ -45,8 +45,8 @@ class Executor(object):
def __del__(self): def __del__(self):
try: try:
# pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle) # pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
pywrap_tensorflow.TFE_DeleteExecutor(self._handle) pywrap_tfe.TFE_DeleteExecutor(self._handle)
except TypeError: except TypeError:
# Suppress some exceptions, mainly for the case when we're running on # Suppress some exceptions, mainly for the case when we're running on
# module deletion. Things that can go wrong include the pywrap module # module deletion. Things that can go wrong include the pywrap module
@ -57,20 +57,20 @@ class Executor(object):
# partially unloaded. # partially unloaded.
def is_async(self): def is_async(self):
return pywrap_tensorflow.TFE_ExecutorIsAsync(self._handle) return pywrap_tfe.TFE_ExecutorIsAsync(self._handle)
def handle(self): def handle(self):
return self._handle return self._handle
def wait(self): def wait(self):
"""Waits for ops dispatched in this executor to finish.""" """Waits for ops dispatched in this executor to finish."""
pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle) pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
def clear_error(self): def clear_error(self):
"""Clears errors raised in this executor during execution.""" """Clears errors raised in this executor during execution."""
pywrap_tensorflow.TFE_ExecutorClearError(self._handle) pywrap_tfe.TFE_ExecutorClearError(self._handle)
def new_executor(enable_async): def new_executor(enable_async):
handle = pywrap_tensorflow.TFE_NewExecutor(enable_async) handle = pywrap_tfe.TFE_NewExecutor(enable_async)
return Executor(handle) return Executor(handle)

View File

@ -20,7 +20,7 @@ from __future__ import print_function
import threading import threading
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
@ -166,7 +166,8 @@ def _jvp_dispatch(op_name, attr_tuple, inputs, outputs, tangents):
return _jvp_relaxed_shapes( return _jvp_relaxed_shapes(
op_name, attr_tuple, inputs, outputs, tangents) op_name, attr_tuple, inputs, outputs, tangents)
pywrap_tensorflow.TFE_Py_RegisterJVPFunction(_jvp_dispatch)
pywrap_tfe.TFE_Py_RegisterJVPFunction(_jvp_dispatch)
@tf_export("autodiff.ForwardAccumulator", v1=[]) @tf_export("autodiff.ForwardAccumulator", v1=[])
@ -300,7 +301,7 @@ class ForwardAccumulator(object):
ValueError: If the same tensor or variable is specified multiple times in ValueError: If the same tensor or variable is specified multiple times in
`primals`. `primals`.
""" """
self._accumulator = pywrap_tensorflow.TFE_Py_ForwardAccumulatorNew() self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew()
self._recording = False self._recording = False
primal_ids = set() primal_ids = set()
for primal in nest.flatten(primals): for primal in nest.flatten(primals):
@ -323,13 +324,13 @@ class ForwardAccumulator(object):
def _push_accumulator(self): def _push_accumulator(self):
if self._recording: if self._recording:
raise ValueError("Accumulator is already recording.") raise ValueError("Accumulator is already recording.")
pywrap_tensorflow.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator) pywrap_tfe.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator)
self._recording = True self._recording = True
def _pop_accumulator(self): def _pop_accumulator(self):
if not self._recording: if not self._recording:
raise ValueError("Accumulator is not recording.") raise ValueError("Accumulator is not recording.")
pywrap_tensorflow.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator) pywrap_tfe.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator)
self._recording = False self._recording = False
def _watch(self, primals, tangents): def _watch(self, primals, tangents):
@ -358,7 +359,7 @@ class ForwardAccumulator(object):
# Run convert_to_tensor to get the captured handle from whichever # Run convert_to_tensor to get the captured handle from whichever
# function we're running if necessary. # function we're running if necessary.
t = ops.convert_to_tensor(t.handle) t = ops.convert_to_tensor(t.handle)
pywrap_tensorflow.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g) pywrap_tfe.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g)
def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE): def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE):
"""Fetches the Jacobian-vector product computed for `primals`. """Fetches the Jacobian-vector product computed for `primals`.
@ -384,8 +385,8 @@ class ForwardAccumulator(object):
def _fetch_jvp(tensor): def _fetch_jvp(tensor):
if hasattr(tensor, "handle"): if hasattr(tensor, "handle"):
tensor = ops.convert_to_tensor(tensor.handle) tensor = ops.convert_to_tensor(tensor.handle)
result = pywrap_tensorflow.TFE_Py_ForwardAccumulatorJVP( result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(self._accumulator,
self._accumulator, tensor) tensor)
if result is None and unconnected_gradients == UnconnectedGradients.ZERO: if result is None and unconnected_gradients == UnconnectedGradients.ZERO:
return array_ops.zeros_like(tensor) return array_ops.zeros_like(tensor)
return result return result

View File

@ -24,7 +24,7 @@ import weakref
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
@ -236,13 +236,13 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
x = constant_op.constant(1.) x = constant_op.constant(1.)
with forwardprop.ForwardAccumulator(x, 2.) as acc: with forwardprop.ForwardAccumulator(x, 2.) as acc:
y = x + x y = x + x
pywrap_tensorflow.TFE_Py_RegisterJVPFunction( pywrap_tfe.TFE_Py_RegisterJVPFunction(
lambda *args, **kwargs: [constant_op.constant(-15.)]) lambda *args, **kwargs: [constant_op.constant(-15.)])
z = x + x z = x + x
self.assertAllClose(4., acc.jvp(y)) self.assertAllClose(4., acc.jvp(y))
self.assertAllClose(-15., acc.jvp(z)) self.assertAllClose(-15., acc.jvp(z))
finally: finally:
pywrap_tensorflow.TFE_Py_RegisterJVPFunction(previous_fn) pywrap_tfe.TFE_Py_RegisterJVPFunction(previous_fn)
@test_util.assert_no_new_pyobjects_executing_eagerly @test_util.assert_no_new_pyobjects_executing_eagerly
def testFunctionCacheLimited(self): def testFunctionCacheLimited(self):
@ -738,19 +738,19 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
with forwardprop.ForwardAccumulator(c, c_tangent) as acc: with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
with backprop.GradientTape() as tape: with backprop.GradientTape() as tape:
self.assertFalse(tape_lib.should_record_backprop([c])) self.assertFalse(tape_lib.should_record_backprop([c]))
self.assertEqual( self.assertEqual(1,
1, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c])) pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
tape.watch(c) tape.watch(c)
self.assertEqual( self.assertEqual(2,
2, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c])) pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
self.assertTrue(tape_lib.should_record_backprop([c])) self.assertTrue(tape_lib.should_record_backprop([c]))
with tape_lib.stop_recording(): with tape_lib.stop_recording():
self.assertEqual( self.assertEqual(0,
0, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c])) pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
self.assertFalse(tape_lib.should_record_backprop([c])) self.assertFalse(tape_lib.should_record_backprop([c]))
d = c * 2. d = c * 2.
self.assertEqual( self.assertEqual(2,
2, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c])) pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
self.assertTrue(tape_lib.should_record_backprop([c])) self.assertTrue(tape_lib.should_record_backprop([c]))
self.assertFalse(tape_lib.should_record_backprop([d])) self.assertFalse(tape_lib.should_record_backprop([d]))
self.assertIsNone(acc.jvp(d)) self.assertIsNone(acc.jvp(d))

View File

@ -24,7 +24,7 @@ from __future__ import print_function
import collections import collections
import contextlib import contextlib
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
class TangentInfo( class TangentInfo(
@ -54,8 +54,7 @@ def pack_tangents(tensors):
tangents: A flat list of Tensors. Best interpreted as a sequence to be tangents: A flat list of Tensors. Best interpreted as a sequence to be
appended to `tensors`. appended to `tensors`.
""" """
return TangentInfo( return TangentInfo(*pywrap_tfe.TFE_Py_PackJVPs(tensors))
*pywrap_tensorflow.TFE_Py_PackJVPs(tensors))
@contextlib.contextmanager @contextlib.contextmanager
@ -73,7 +72,7 @@ def push_forwardprop_state():
None (used for its side effect). None (used for its side effect).
""" """
try: try:
pywrap_tensorflow.TFE_Py_ForwardAccumulatorPushState() pywrap_tfe.TFE_Py_ForwardAccumulatorPushState()
yield yield
finally: finally:
pywrap_tensorflow.TFE_Py_ForwardAccumulatorPopState() pywrap_tfe.TFE_Py_ForwardAccumulatorPopState()

View File

@ -32,8 +32,9 @@ from six.moves import map
from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2 from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python import _pywrap_utils from tensorflow.python import _pywrap_utils
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -1098,7 +1099,7 @@ class _TapeGradientFunctions(object):
forward_function.signature.name, forward_function.signature.name,
forward_outputs, forward_inputs, py_backward, None) forward_outputs, forward_inputs, py_backward, None)
output_indices, output_tangents = ( output_indices, output_tangents = (
pywrap_tensorflow.TFE_Py_PackJVPs(forward_outputs)) pywrap_tfe.TFE_Py_PackJVPs(forward_outputs))
output_tangents = [forward_wrapper_graph.capture(t) output_tangents = [forward_wrapper_graph.capture(t)
for t in output_tangents] for t in output_tangents]
return _ForwardWrapper( return _ForwardWrapper(
@ -1732,7 +1733,7 @@ class ConcreteFunction(object):
"Tensor." % (self._func_graph.name, i, str(arg))) "Tensor." % (self._func_graph.name, i, str(arg)))
args = tensor_inputs + captured_inputs args = tensor_inputs + captured_inputs
possible_gradient_type = ( possible_gradient_type = (
pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes(args)) pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(args))
if (possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE if (possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE
and executing_eagerly): and executing_eagerly):
# No tape is watching; skip to running the function. # No tape is watching; skip to running the function.
@ -2552,8 +2553,8 @@ class Function(object):
"""Computes the cache key given inputs and execution context.""" """Computes the cache key given inputs and execution context."""
if self.input_signature is None: if self.input_signature is None:
inputs = (args, kwargs) if kwargs else args inputs = (args, kwargs) if kwargs else args
input_signature = pywrap_tensorflow.TFE_Py_EncodeArg( input_signature = pywrap_tfe.TFE_Py_EncodeArg(inputs,
inputs, include_tensor_ranks_only) include_tensor_ranks_only)
else: else:
del args, kwargs del args, kwargs
assert not include_tensor_ranks_only assert not include_tensor_ranks_only

View File

@ -20,7 +20,7 @@ from __future__ import print_function
import collections import collections
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.util import compat from tensorflow.python.util import compat
@ -68,7 +68,7 @@ def imperative_grad(tape,
raise ValueError( raise ValueError(
"Unknown value for unconnected_gradients: %r" % unconnected_gradients) "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
return pywrap_tensorflow.TFE_Py_TapeGradient( return pywrap_tfe.TFE_Py_TapeGradient(
tape._tape, # pylint: disable=protected-access tape._tape, # pylint: disable=protected-access
target, target,
sources, sources,

View File

@ -21,80 +21,80 @@ from __future__ import print_function
import collections import collections
from tensorflow.core.framework import summary_pb2 from tensorflow.core.framework import summary_pb2
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.framework import c_api_util from tensorflow.python.eager import eager_util as c_api_util
from tensorflow.python.util import compat from tensorflow.python.util import compat
_MetricMethod = collections.namedtuple('MetricMethod', 'create delete get_cell') _MetricMethod = collections.namedtuple('MetricMethod', 'create delete get_cell')
_counter_methods = [ _counter_methods = [
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewCounter0, create=pywrap_tfe.TFE_MonitoringNewCounter0,
delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter0, delete=pywrap_tfe.TFE_MonitoringDeleteCounter0,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter0), get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter0),
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewCounter1, create=pywrap_tfe.TFE_MonitoringNewCounter1,
delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter1, delete=pywrap_tfe.TFE_MonitoringDeleteCounter1,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter1), get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter1),
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewCounter2, create=pywrap_tfe.TFE_MonitoringNewCounter2,
delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter2, delete=pywrap_tfe.TFE_MonitoringDeleteCounter2,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter2), get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter2),
] ]
_int_gauge_methods = [ _int_gauge_methods = [
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewIntGauge0, create=pywrap_tfe.TFE_MonitoringNewIntGauge0,
delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge0, delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge0,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge0), get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge0),
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewIntGauge1, create=pywrap_tfe.TFE_MonitoringNewIntGauge1,
delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge1, delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge1,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge1), get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge1),
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewIntGauge2, create=pywrap_tfe.TFE_MonitoringNewIntGauge2,
delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge2, delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge2,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge2), get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge2),
] ]
_string_gauge_methods = [ _string_gauge_methods = [
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewStringGauge0, create=pywrap_tfe.TFE_MonitoringNewStringGauge0,
delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge0, delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge0,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge0), get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge0),
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewStringGauge1, create=pywrap_tfe.TFE_MonitoringNewStringGauge1,
delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge1, delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge1,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge1), get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge1),
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewStringGauge2, create=pywrap_tfe.TFE_MonitoringNewStringGauge2,
delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge2, delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge2,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge2), get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge2),
] ]
_bool_gauge_methods = [ _bool_gauge_methods = [
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge0, create=pywrap_tfe.TFE_MonitoringNewBoolGauge0,
delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge0, delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge0,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge0), get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge0),
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge1, create=pywrap_tfe.TFE_MonitoringNewBoolGauge1,
delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge1, delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge1,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge1), get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge1),
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge2, create=pywrap_tfe.TFE_MonitoringNewBoolGauge2,
delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge2, delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge2,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge2), get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge2),
] ]
_sampler_methods = [ _sampler_methods = [
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewSampler0, create=pywrap_tfe.TFE_MonitoringNewSampler0,
delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler0, delete=pywrap_tfe.TFE_MonitoringDeleteSampler0,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler0), get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler0),
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewSampler1, create=pywrap_tfe.TFE_MonitoringNewSampler1,
delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler1, delete=pywrap_tfe.TFE_MonitoringDeleteSampler1,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler1), get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler1),
_MetricMethod( _MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewSampler2, create=pywrap_tfe.TFE_MonitoringNewSampler2,
delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler2, delete=pywrap_tfe.TFE_MonitoringDeleteSampler2,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler2), get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler2),
] ]
@ -156,11 +156,11 @@ class CounterCell(object):
Args: Args:
value: non-negative value. value: non-negative value.
""" """
pywrap_tensorflow.TFE_MonitoringCounterCellIncrementBy(self._cell, value) pywrap_tfe.TFE_MonitoringCounterCellIncrementBy(self._cell, value)
def value(self): def value(self):
"""Retrieves the current value.""" """Retrieves the current value."""
return pywrap_tensorflow.TFE_MonitoringCounterCellValue(self._cell) return pywrap_tfe.TFE_MonitoringCounterCellValue(self._cell)
class Counter(Metric): class Counter(Metric):
@ -204,11 +204,11 @@ class IntGaugeCell(object):
Args: Args:
value: integer value. value: integer value.
""" """
pywrap_tensorflow.TFE_MonitoringIntGaugeCellSet(self._cell, value) pywrap_tfe.TFE_MonitoringIntGaugeCellSet(self._cell, value)
def value(self): def value(self):
"""Retrieves the current value.""" """Retrieves the current value."""
return pywrap_tensorflow.TFE_MonitoringIntGaugeCellValue(self._cell) return pywrap_tfe.TFE_MonitoringIntGaugeCellValue(self._cell)
class IntGauge(Metric): class IntGauge(Metric):
@ -252,13 +252,13 @@ class StringGaugeCell(object):
Args: Args:
value: string value. value: string value.
""" """
pywrap_tensorflow.TFE_MonitoringStringGaugeCellSet(self._cell, value) pywrap_tfe.TFE_MonitoringStringGaugeCellSet(self._cell, value)
def value(self): def value(self):
"""Retrieves the current value.""" """Retrieves the current value."""
with c_api_util.tf_buffer() as buffer_: with c_api_util.tf_buffer() as buffer_:
pywrap_tensorflow.TFE_MonitoringStringGaugeCellValue(self._cell, buffer_) pywrap_tfe.TFE_MonitoringStringGaugeCellValue(self._cell, buffer_)
value = pywrap_tensorflow.TF_GetBuffer(buffer_).decode('utf-8') value = pywrap_tfe.TF_GetBuffer(buffer_).decode('utf-8')
return value return value
@ -303,11 +303,11 @@ class BoolGaugeCell(object):
Args: Args:
value: bool value. value: bool value.
""" """
pywrap_tensorflow.TFE_MonitoringBoolGaugeCellSet(self._cell, value) pywrap_tfe.TFE_MonitoringBoolGaugeCellSet(self._cell, value)
def value(self): def value(self):
"""Retrieves the current value.""" """Retrieves the current value."""
return pywrap_tensorflow.TFE_MonitoringBoolGaugeCellValue(self._cell) return pywrap_tfe.TFE_MonitoringBoolGaugeCellValue(self._cell)
class BoolGauge(Metric): class BoolGauge(Metric):
@ -351,7 +351,7 @@ class SamplerCell(object):
Args: Args:
value: float value. value: float value.
""" """
pywrap_tensorflow.TFE_MonitoringSamplerCellAdd(self._cell, value) pywrap_tfe.TFE_MonitoringSamplerCellAdd(self._cell, value)
def value(self): def value(self):
"""Retrieves the current distribution of samples. """Retrieves the current distribution of samples.
@ -360,8 +360,8 @@ class SamplerCell(object):
A HistogramProto describing the distribution of samples. A HistogramProto describing the distribution of samples.
""" """
with c_api_util.tf_buffer() as buffer_: with c_api_util.tf_buffer() as buffer_:
pywrap_tensorflow.TFE_MonitoringSamplerCellValue(self._cell, buffer_) pywrap_tfe.TFE_MonitoringSamplerCellValue(self._cell, buffer_)
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) proto_data = pywrap_tfe.TF_GetBuffer(buffer_)
histogram_proto = summary_pb2.HistogramProto() histogram_proto = summary_pb2.HistogramProto()
histogram_proto.ParseFromString(compat.as_bytes(proto_data)) histogram_proto.ParseFromString(compat.as_bytes(proto_data))
return histogram_proto return histogram_proto
@ -379,7 +379,7 @@ class Buckets(object):
self.buckets = buckets self.buckets = buckets
def __del__(self): def __del__(self):
pywrap_tensorflow.TFE_MonitoringDeleteBuckets(self.buckets) pywrap_tfe.TFE_MonitoringDeleteBuckets(self.buckets)
class ExponentialBuckets(Buckets): class ExponentialBuckets(Buckets):
@ -399,8 +399,8 @@ class ExponentialBuckets(Buckets):
bucket_count: integer bucket_count: integer
""" """
super(ExponentialBuckets, self).__init__( super(ExponentialBuckets, self).__init__(
pywrap_tensorflow.TFE_MonitoringNewExponentialBuckets( pywrap_tfe.TFE_MonitoringNewExponentialBuckets(scale, growth_factor,
scale, growth_factor, bucket_count)) bucket_count))
class Sampler(Metric): class Sampler(Metric):

View File

@ -39,9 +39,9 @@ import os
import threading import threading
from tensorflow.python import _pywrap_events_writer from tensorflow.python import _pywrap_events_writer
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import c_api_util from tensorflow.python.eager import eager_util as c_api_util
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat from tensorflow.python.util import compat
@ -74,8 +74,8 @@ def start():
raise ProfilerAlreadyRunningError('Another profiler is running.') raise ProfilerAlreadyRunningError('Another profiler is running.')
if context.default_execution_mode == context.EAGER_MODE: if context.default_execution_mode == context.EAGER_MODE:
context.ensure_initialized() context.ensure_initialized()
_profiler = pywrap_tensorflow.TFE_NewProfiler() _profiler = pywrap_tfe.TFE_NewProfiler()
if not pywrap_tensorflow.TFE_ProfilerIsOk(_profiler): if not pywrap_tfe.TFE_ProfilerIsOk(_profiler):
logging.warning('Another profiler session is running which is probably ' logging.warning('Another profiler session is running which is probably '
'created by profiler server. Please avoid using profiler ' 'created by profiler server. Please avoid using profiler '
'server and profiler APIs at the same time.') 'server and profiler APIs at the same time.')
@ -100,11 +100,9 @@ def stop():
if context.default_execution_mode == context.EAGER_MODE: if context.default_execution_mode == context.EAGER_MODE:
context.context().executor.wait() context.context().executor.wait()
with c_api_util.tf_buffer() as buffer_: with c_api_util.tf_buffer() as buffer_:
pywrap_tensorflow.TFE_ProfilerSerializeToString( pywrap_tfe.TFE_ProfilerSerializeToString(_profiler, buffer_)
_profiler, result = pywrap_tfe.TF_GetBuffer(buffer_)
buffer_) pywrap_tfe.TFE_DeleteProfiler(_profiler)
result = pywrap_tensorflow.TF_GetBuffer(buffer_)
pywrap_tensorflow.TFE_DeleteProfiler(_profiler)
_profiler = None _profiler = None
_run_num += 1 _run_num += 1
return result return result
@ -159,7 +157,7 @@ def start_profiler_server(port):
""" """
if context.default_execution_mode == context.EAGER_MODE: if context.default_execution_mode == context.EAGER_MODE:
context.ensure_initialized() context.ensure_initialized()
pywrap_tensorflow.TFE_StartProfilerServer(port) pywrap_tfe.TFE_StartProfilerServer(port)
class Profiler(object): class Profiler(object):

View File

@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.framework import c_api_util from tensorflow.python.eager import eager_util as c_api_util
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
@ -46,7 +46,7 @@ def start_tracing(service_addr,
Raises: Raises:
UnavailableError: If no trace event is collected. UnavailableError: If no trace event is collected.
""" """
if not pywrap_tensorflow.TFE_ProfilerClientStartTracing( if not pywrap_tfe.TFE_ProfilerClientStartTracing(
service_addr, logdir, worker_list, include_dataset_ops, duration_ms, service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
num_tracing_attempts): num_tracing_attempts):
raise errors.UnavailableError(None, None, 'No trace event is collected.') raise errors.UnavailableError(None, None, 'No trace event is collected.')
@ -71,7 +71,7 @@ def monitor(service_addr,
A string of monitoring output. A string of monitoring output.
""" """
with c_api_util.tf_buffer() as buffer_: with c_api_util.tf_buffer() as buffer_:
pywrap_tensorflow.TFE_ProfilerClientMonitor(service_addr, duration_ms, pywrap_tfe.TFE_ProfilerClientMonitor(service_addr, duration_ms,
monitoring_level, monitoring_level, display_timestamp,
display_timestamp, buffer_) buffer_)
return pywrap_tensorflow.TF_GetBuffer(buffer_) return pywrap_tfe.TF_GetBuffer(buffer_)

View File

@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import core from tensorflow.python.eager import core
@ -54,14 +54,16 @@ class Tests(test.TestCase):
self.assertAllClose( self.assertAllClose(
math_ops.matmul(a_2_by_2, b_2_by_2), math_ops.matmul(a_2_by_2, b_2_by_2),
pywrap_tensorflow.TFE_Py_FastPathExecute( pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2, "MatMul", None, None, a_2_by_2,
b_2_by_2, "transpose_a", False, "transpose_b", False)) b_2_by_2, "transpose_a", False,
"transpose_b", False))
self.assertAllClose( self.assertAllClose(
math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True), math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True),
pywrap_tensorflow.TFE_Py_FastPathExecute( pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
ctx._handle, ctx.device_name, "MatMul", None, None, a_100_by_784, "MatMul", None, None, a_100_by_784,
b_100_by_784, "transpose_a", False, "transpose_b", True)) b_100_by_784, "transpose_a", False,
"transpose_b", True))
@test_util.assert_no_new_tensors @test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created @test_util.assert_no_garbage_created
@ -71,12 +73,14 @@ class Tests(test.TestCase):
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2]) a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
m = resource_variable_ops.ResourceVariable(a_2_by_2) m = resource_variable_ops.ResourceVariable(a_2_by_2)
x = pywrap_tensorflow.TFE_Py_FastPathExecute( x = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
ctx._handle, ctx.device_name, "MatMul", None, None, m, m, "transpose_a", "MatMul", None, None, m, m,
False, "transpose_b", False) "transpose_a", False, "transpose_b",
y = pywrap_tensorflow.TFE_Py_FastPathExecute( False)
ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2, a_2_by_2, y = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"transpose_a", False, "transpose_b", False) "MatMul", None, None, a_2_by_2,
a_2_by_2, "transpose_a", False,
"transpose_b", False)
self.assertAllEqual(x, y) self.assertAllEqual(x, y)
@ -89,9 +93,10 @@ class Tests(test.TestCase):
with backprop.GradientTape(persistent=True) as tape: with backprop.GradientTape(persistent=True) as tape:
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2]) a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
tape.watch(a_2_by_2) tape.watch(a_2_by_2)
z = pywrap_tensorflow.TFE_Py_FastPathExecute( z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2, "MatMul", None, None, a_2_by_2,
a_2_by_2, "transpose_a", False, "transpose_b", False) a_2_by_2, "transpose_a", False,
"transpose_b", False)
dz_dy = tape.gradient(z, [a_2_by_2])[0] dz_dy = tape.gradient(z, [a_2_by_2])[0]
self.assertAllEqual(dz_dy.numpy(), self.assertAllEqual(dz_dy.numpy(),
constant_op.constant(4.0, shape=[2, 2]).numpy()) constant_op.constant(4.0, shape=[2, 2]).numpy())
@ -106,9 +111,10 @@ class Tests(test.TestCase):
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2]) a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
m = resource_variable_ops.ResourceVariable(a_2_by_2) m = resource_variable_ops.ResourceVariable(a_2_by_2)
tape.watch(m) tape.watch(m)
z = pywrap_tensorflow.TFE_Py_FastPathExecute( z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
ctx._handle, ctx.device_name, "MatMul", None, None, m, m, "MatMul", None, None, m, m,
"transpose_a", False, "transpose_b", False) "transpose_a", False, "transpose_b",
False)
dz_dy = tape.gradient(z, [m])[0] dz_dy = tape.gradient(z, [m])[0]
self.assertAllEqual(dz_dy.numpy(), self.assertAllEqual(dz_dy.numpy(),
constant_op.constant(4.0, shape=[2, 2]).numpy()) constant_op.constant(4.0, shape=[2, 2]).numpy())
@ -125,9 +131,8 @@ class Tests(test.TestCase):
self.assertAllClose( self.assertAllClose(
math_ops.add_n([a_2_by_2, b_2_by_2]), math_ops.add_n([a_2_by_2, b_2_by_2]),
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "AddN",
"AddN", None, None, None, None, [a_2_by_2, b_2_by_2]))
[a_2_by_2, b_2_by_2]))
# Tests homogeneous list op # Tests homogeneous list op
@test_util.assert_no_new_tensors @test_util.assert_no_new_tensors
@ -142,9 +147,9 @@ class Tests(test.TestCase):
with backprop.GradientTape(persistent=True) as tape: with backprop.GradientTape(persistent=True) as tape:
tape.watch(a_2_by_2) tape.watch(a_2_by_2)
tape.watch(b_2_by_2) tape.watch(b_2_by_2)
z1 = pywrap_tensorflow.TFE_Py_FastPathExecute( z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
ctx._handle, ctx.device_name, "AddN", None, None, "AddN", None, None,
[a_2_by_2, b_2_by_2]) [a_2_by_2, b_2_by_2])
z2 = math_ops.add_n([a_2_by_2, b_2_by_2]) z2 = math_ops.add_n([a_2_by_2, b_2_by_2])
dz1_dy = tape.gradient(z1, [a_2_by_2])[0] dz1_dy = tape.gradient(z1, [a_2_by_2])[0]
dz2_dy = tape.gradient(z2, [a_2_by_2])[0] dz2_dy = tape.gradient(z2, [a_2_by_2])[0]
@ -162,9 +167,9 @@ class Tests(test.TestCase):
self.assertAllClose( self.assertAllClose(
array_ops.identity_n([a_2_by_2, b_2_by_2]), array_ops.identity_n([a_2_by_2, b_2_by_2]),
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"IdentityN", None, None, "IdentityN", None, None,
[a_2_by_2, b_2_by_2])) [a_2_by_2, b_2_by_2]))
# Tests heterogeneous list op # Tests heterogeneous list op
@test_util.assert_no_new_tensors @test_util.assert_no_new_tensors
@ -179,9 +184,9 @@ class Tests(test.TestCase):
with backprop.GradientTape(persistent=True) as tape: with backprop.GradientTape(persistent=True) as tape:
tape.watch(a_2_by_2) tape.watch(a_2_by_2)
tape.watch(b_2_by_2) tape.watch(b_2_by_2)
z1 = pywrap_tensorflow.TFE_Py_FastPathExecute( z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
ctx._handle, ctx.device_name, "IdentityN", None, None, "IdentityN", None, None,
[a_2_by_2, b_2_by_2]) [a_2_by_2, b_2_by_2])
z2 = array_ops.identity_n([a_2_by_2, b_2_by_2]) z2 = array_ops.identity_n([a_2_by_2, b_2_by_2])
dz1_dy = tape.gradient(z1[0], [a_2_by_2])[0] dz1_dy = tape.gradient(z1[0], [a_2_by_2])[0]
dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0] dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0]
@ -201,19 +206,18 @@ class Tests(test.TestCase):
# Not enough base params # Not enough base params
with self.assertRaisesRegexp(ValueError, with self.assertRaisesRegexp(ValueError,
"at least 5 items in the input tuple"): "at least 5 items in the input tuple"):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Identity")
"Identity")
# Not enough inputs # Not enough inputs
with self.assertRaisesRegexp(ValueError, with self.assertRaisesRegexp(ValueError,
"Expected to be at least 6, was 5"): "Expected to be at least 6, was 5"):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx_handle, pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx_handle, "Identity",
"Identity", None, []) None, [])
# Bad type # Bad type
with self.assertRaisesRegexp(TypeError, "expected a string for op_name"): with self.assertRaisesRegexp(TypeError, "expected a string for op_name"):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, ctx_handle,
ctx_handle, None, [], a_2_by_2) None, [], a_2_by_2)
@test_util.assert_no_new_tensors @test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created @test_util.assert_no_garbage_created
@ -225,9 +229,9 @@ class Tests(test.TestCase):
ctx_handle = ctx._handle ctx_handle = ctx._handle
with self.assertRaises(core._FallbackException): with self.assertRaises(core._FallbackException):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Split",
"Split", None, None, split_dim, None, None, split_dim, value,
value, "num_split", -1) "num_split", -1)
@test_util.assert_no_new_tensors @test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created @test_util.assert_no_garbage_created
@ -266,10 +270,9 @@ class Tests(test.TestCase):
ctx = context.context() ctx = context.context()
ctx.ensure_initialized() ctx.ensure_initialized()
with self.assertRaises(core._FallbackException): with self.assertRaises(core._FallbackException):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "MatMul",
"MatMul", None, None, m, m, None, None, m, m, "transpose_a", False,
"transpose_a", False, "transpose_b", False)
"transpose_b", False)
def testOpDefDefaultType(self): def testOpDefDefaultType(self):
im = np.random.randint( im = np.random.randint(

View File

@ -22,7 +22,7 @@ import copy
from absl import logging from absl import logging
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import device_util from tensorflow.python.distribute import device_util
from tensorflow.python.distribute.cluster_resolver import cluster_resolver from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -127,7 +127,7 @@ def connect_to_cluster(cluster_spec_or_resolver,
# Automatically add local job, if not part of the cluster spec. # Automatically add local job, if not part of the cluster spec.
if job_name not in cluster_spec.jobs: if job_name not in cluster_spec.jobs:
local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie() local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
job_def = cluster_def.job.add() job_def = cluster_def.job.add()
job_def.name = job_name job_def.name = job_name
# TODO(fishx): Update this to make sure remote worker has valid ip address # TODO(fishx): Update this to make sure remote worker has valid ip address

View File

@ -20,7 +20,7 @@ from __future__ import print_function
import contextlib import contextlib
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.util.lazy_loader import LazyLoader from tensorflow.python.util.lazy_loader import LazyLoader
# There is a circular dependency between this, ops.py, and # There is a circular dependency between this, ops.py, and
@ -39,24 +39,23 @@ class Tape(object):
self._tape = tape self._tape = tape
def watched_variables(self): def watched_variables(self):
return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape) return pywrap_tfe.TFE_Py_TapeWatchedVariables(self._tape)
def push_new_tape(persistent=False, watch_accessed_variables=True): def push_new_tape(persistent=False, watch_accessed_variables=True):
"""Pushes a new tape onto the tape stack.""" """Pushes a new tape onto the tape stack."""
tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent, tape = pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables)
watch_accessed_variables)
return Tape(tape) return Tape(tape)
def push_tape(tape): def push_tape(tape):
"""Pushes an existing tape onto the tape stack.""" """Pushes an existing tape onto the tape stack."""
pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access pywrap_tfe.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access
def watch(tape, tensor): def watch(tape, tensor):
"""Marks this tensor to be watched by the given tape.""" """Marks this tensor to be watched by the given tape."""
pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access pywrap_tfe.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
def watch_variable(tape, variable): def watch_variable(tape, variable):
@ -68,7 +67,7 @@ def watch_variable(tape, variable):
else: else:
variables = strategy.experimental_local_results(variable) variables = strategy.experimental_local_results(variable)
for var in variables: for var in variables:
pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
def variable_accessed(variable): def variable_accessed(variable):
@ -84,7 +83,7 @@ def variable_accessed(variable):
else: else:
variables = strategy.experimental_local_results(variable) variables = strategy.experimental_local_results(variable)
for var in variables: for var in variables:
pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var) pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
def variables_accessed(variables): def variables_accessed(variables):
@ -107,25 +106,25 @@ def variables_accessed(variables):
accessed.extend(strategy.experimental_local_results(variable)) accessed.extend(strategy.experimental_local_results(variable))
for var in accessed: for var in accessed:
pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var) pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
def pop_tape(tape): def pop_tape(tape):
"""Pops the given tape in the stack.""" """Pops the given tape in the stack."""
pywrap_tensorflow.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access pywrap_tfe.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access
@contextlib.contextmanager @contextlib.contextmanager
def stop_recording(): def stop_recording():
"""Stop all gradient recording (backprop and forwardprop).""" """Stop all gradient recording (backprop and forwardprop)."""
is_stopped = pywrap_tensorflow.TFE_Py_TapeSetIsStopped() is_stopped = pywrap_tfe.TFE_Py_TapeSetIsStopped()
try: try:
if not is_stopped: if not is_stopped:
pywrap_tensorflow.TFE_Py_TapeSetStopOnThread() pywrap_tfe.TFE_Py_TapeSetStopOnThread()
yield yield
finally: finally:
if not is_stopped: if not is_stopped:
pywrap_tensorflow.TFE_Py_TapeSetRestartOnThread() pywrap_tfe.TFE_Py_TapeSetRestartOnThread()
def should_record_backprop(tensors): def should_record_backprop(tensors):
@ -139,22 +138,23 @@ def should_record_backprop(tensors):
Returns: Returns:
Boolean, whether any tape watches any of `tensors`. Boolean, whether any tape watches any of `tensors`.
""" """
return pywrap_tensorflow.TFE_Py_TapeSetShouldRecordBackprop(tensors) return pywrap_tfe.TFE_Py_TapeSetShouldRecordBackprop(tensors)
def record_operation(op_type, output_tensors, input_tensors, backward_function, def record_operation(op_type, output_tensors, input_tensors, backward_function,
forward_function=None): forward_function=None):
"""Records the operation on all tapes in the stack.""" """Records the operation on all tapes in the stack."""
pywrap_tensorflow.TFE_Py_TapeSetRecordOperation( pywrap_tfe.TFE_Py_TapeSetRecordOperation(op_type, output_tensors,
op_type, output_tensors, input_tensors, backward_function, input_tensors, backward_function,
forward_function) forward_function)
def record_operation_backprop_only(op_type, output_tensors, input_tensors, def record_operation_backprop_only(op_type, output_tensors, input_tensors,
backward_function): backward_function):
"""Records the operation on all backward tapes in the stack.""" """Records the operation on all backward tapes in the stack."""
pywrap_tensorflow.TFE_Py_TapeSetRecordOperationBackprop( pywrap_tfe.TFE_Py_TapeSetRecordOperationBackprop(op_type, output_tensors,
op_type, output_tensors, input_tensors, backward_function) input_tensors,
backward_function)
def record_operation_forwardprop_only(op_type, output_tensors, input_tensors, def record_operation_forwardprop_only(op_type, output_tensors, input_tensors,
@ -174,16 +174,16 @@ def record_operation_forwardprop_only(op_type, output_tensors, input_tensors,
Typically these will have come from TFE_Py_PackForwardGradients. May be Typically these will have come from TFE_Py_PackForwardGradients. May be
None or an empty sequence if there are no JVP outputs from the operation. None or an empty sequence if there are no JVP outputs from the operation.
""" """
pywrap_tensorflow.TFE_Py_TapeSetRecordOperationForwardprop( pywrap_tfe.TFE_Py_TapeSetRecordOperationForwardprop(
op_type, output_tensors, input_tensors, backward_function, op_type, output_tensors, input_tensors, backward_function,
forwardprop_output_indices) forwardprop_output_indices)
def delete_trace(tensor_id): def delete_trace(tensor_id):
"""Deletes traces for this Tensor from all tapes in the stack.""" """Deletes traces for this Tensor from all tapes in the stack."""
pywrap_tensorflow.TFE_Py_TapeSetDeleteTrace(tensor_id) pywrap_tfe.TFE_Py_TapeSetDeleteTrace(tensor_id)
def could_possibly_record(): def could_possibly_record():
"""Returns True if any tape is active.""" """Returns True if any tape is active."""
return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty() return not pywrap_tfe.TFE_Py_TapeSetIsEmpty()

View File

@ -26,7 +26,7 @@ import unittest
import numpy as np import numpy as np
import six import six
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import core from tensorflow.python.eager import core
from tensorflow.python.eager import test from tensorflow.python.eager import test
@ -435,14 +435,14 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
t2 = _create_tensor([[1, 2, 5], [3, 4, 5]], dtype=dtypes.int32) t2 = _create_tensor([[1, 2, 5], [3, 4, 5]], dtype=dtypes.int32)
t3 = _create_tensor([[1], [3], [5], [6]], dtype=dtypes.int32) t3 = _create_tensor([[1], [3], [5], [6]], dtype=dtypes.int32)
r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 0) r = pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2, t3], 0)
self.assertAllEqual(np.array([3, 2, 4]), r.numpy()) self.assertAllEqual(np.array([3, 2, 4]), r.numpy())
r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 1) r = pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2, t3], 1)
self.assertAllEqual(np.array([2, 3, 1]), r.numpy()) self.assertAllEqual(np.array([2, 3, 1]), r.numpy())
def testEmptyTensorList(self): def testEmptyTensorList(self):
a = pywrap_tensorflow.TFE_Py_TensorShapeSlice([], 0) a = pywrap_tfe.TFE_Py_TensorShapeSlice([], 0)
self.assertTrue(isinstance(a, ops.EagerTensor)) self.assertTrue(isinstance(a, ops.EagerTensor))
self.assertEqual(0, a.numpy().size) self.assertEqual(0, a.numpy().size)
@ -452,12 +452,12 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
TypeError, TypeError,
r"Expected a list of EagerTensors but element 1 has type \"str\""): r"Expected a list of EagerTensors but element 1 has type \"str\""):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, "abc"], 0) pywrap_tfe.TFE_Py_TensorShapeSlice([t1, "abc"], 0)
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
TypeError, TypeError,
r"Expected a list of EagerTensors but element 0 has type \"int\""): r"Expected a list of EagerTensors but element 0 has type \"int\""):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([2, t1], 0) pywrap_tfe.TFE_Py_TensorShapeSlice([2, t1], 0)
def testTensorListNotList(self): def testTensorListNotList(self):
t1 = _create_tensor([1, 2], dtype=dtypes.int32) t1 = _create_tensor([1, 2], dtype=dtypes.int32)
@ -465,7 +465,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
TypeError, TypeError,
r"tensors argument must be a list or a tuple. Got.*EagerTensor"): r"tensors argument must be a list or a tuple. Got.*EagerTensor"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2) pywrap_tfe.TFE_Py_TensorShapeSlice(t1, -2)
def testNegativeSliceDim(self): def testNegativeSliceDim(self):
t1 = _create_tensor([1, 2], dtype=dtypes.int32) t1 = _create_tensor([1, 2], dtype=dtypes.int32)
@ -473,7 +473,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
ValueError, ValueError,
r"Slice dimension must be non-negative. Got -2"): r"Slice dimension must be non-negative. Got -2"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], -2) pywrap_tfe.TFE_Py_TensorShapeSlice([t1], -2)
def testUnicode(self): def testUnicode(self):
self.assertEqual(constant_op.constant(u"asdf").numpy(), b"asdf") self.assertEqual(constant_op.constant(u"asdf").numpy(), b"asdf")
@ -493,31 +493,31 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
IndexError, IndexError,
r"Slice dimension \(2\) must be smaller than rank of all tensors, " r"Slice dimension \(2\) must be smaller than rank of all tensors, "
"but tensor at index 0 has rank 2"): "but tensor at index 0 has rank 2"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], 2) pywrap_tfe.TFE_Py_TensorShapeSlice([t1], 2)
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
IndexError, IndexError,
r"Slice dimension \(1\) must be smaller than rank of all tensors, " r"Slice dimension \(1\) must be smaller than rank of all tensors, "
"but tensor at index 0 has rank 1"): "but tensor at index 0 has rank 1"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2], 1) pywrap_tfe.TFE_Py_TensorShapeSlice([t2], 1)
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
IndexError, IndexError,
r"Slice dimension \(1\) must be smaller than rank of all tensors, " r"Slice dimension \(1\) must be smaller than rank of all tensors, "
"but tensor at index 1 has rank 1"): "but tensor at index 1 has rank 1"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2], 1) pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2], 1)
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
IndexError, IndexError,
r"Slice dimension \(0\) must be smaller than rank of all tensors, " r"Slice dimension \(0\) must be smaller than rank of all tensors, "
"but tensor at index 0 has rank 0"): "but tensor at index 0 has rank 0"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t3], 0) pywrap_tfe.TFE_Py_TensorShapeSlice([t3], 0)
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
IndexError, IndexError,
r"Slice dimension \(0\) must be smaller than rank of all tensors, " r"Slice dimension \(0\) must be smaller than rank of all tensors, "
"but tensor at index 2 has rank 0"): "but tensor at index 2 has rank 0"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0) pywrap_tfe.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
@test_util.assert_no_new_pyobjects_executing_eagerly @test_util.assert_no_new_pyobjects_executing_eagerly
def testTensorDir(self): def testTensorDir(self):

View File

@ -36,7 +36,12 @@ from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import versions_pb2 from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
# pywrap_tensorflow must be imported first to avoid profobuf issues.
# (b/143110113)
# pylint: disable=invalid-import-order,g-bad-import-order
from tensorflow.python import pywrap_tensorflow as c_api from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python import pywrap_tfe as c_api_new
# pylint: enable=invalid-import-order,g-bad-import-order
from tensorflow.python import tf2 from tensorflow.python import tf2
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import core from tensorflow.python.eager import core
@ -249,7 +254,7 @@ def register_dense_tensor_like_type(tensor_type):
def uid(): def uid():
"""A unique (within this program execution) integer.""" """A unique (within this program execution) integer."""
return c_api.TFE_Py_UID() return c_api_new.TFE_Py_UID()
def numpy_text(tensor, is_repr=False): def numpy_text(tensor, is_repr=False):
@ -1135,7 +1140,7 @@ class _EagerTensorBase(Tensor):
# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and # This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
# registers it with the current module. # registers it with the current module.
EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase) EagerTensor = c_api_new.TFE_Py_InitEagerTensor(_EagerTensorBase)
register_dense_tensor_like_type(Tensor) register_dense_tensor_like_type(Tensor)

View File

@ -789,8 +789,7 @@ void GenEagerPythonOp::AddEagerFastPathExecute() {
strings::StrAppend(&result_, " try:\n"); strings::StrAppend(&result_, " try:\n");
strings::StrAppend( strings::StrAppend(
&result_, " ", &result_, " ", "_result = pywrap_tfe.TFE_Py_FastPathExecute(\n",
"_result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n",
WordWrap(strings::StrCat(" "), WordWrap(strings::StrCat(" "),
strings::StrCat(fastpath_execute_params, ")"), kRightMargin), strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
"\n"); "\n");
@ -1000,7 +999,7 @@ This file is MACHINE GENERATED! Do not edit.
import collections import collections
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute from tensorflow.python.eager import execute as _execute

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -115,8 +116,8 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
non_neg_concat_dim = ( non_neg_concat_dim = (
concat_dim._numpy().item(0) % input_values[0]._rank()) # pylint: disable=protected-access concat_dim._numpy().item(0) % input_values[0]._rank()) # pylint: disable=protected-access
# All inputs are guaranteed to be EagerTensors in eager mode # All inputs are guaranteed to be EagerTensors in eager mode
sizes = pywrap_tensorflow.TFE_Py_TensorShapeSlice(input_values, sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values,
non_neg_concat_dim) non_neg_concat_dim)
out_grads = array_ops.split(grad, sizes, non_neg_concat_dim) out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
else: else:
if constant_op.is_constant(concat_dim): if constant_op.is_constant(concat_dim):

View File

@ -26,7 +26,7 @@ import sys
from absl import logging from absl import logging
import six import six
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tfe
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
@ -45,7 +45,7 @@ from tensorflow.python.util.tf_export import tf_export
# Register printing to the cell output if we are in a Colab or Jupyter Notebook. # Register printing to the cell output if we are in a Colab or Jupyter Notebook.
try: try:
get_ipython() # Exists in an ipython env like Jupyter or Colab get_ipython() # Exists in an ipython env like Jupyter or Colab
pywrap_tensorflow.TFE_Py_EnableInteractivePythonLogging() pywrap_tfe.TFE_Py_EnableInteractivePythonLogging()
except NameError: except NameError:
pass pass

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/python/lib/core/py_exception_registry.h" #include "tensorflow/python/lib/core/py_exception_registry.h"
using tensorflow::uint64; using tensorflow::uint64;
@ -233,7 +234,50 @@ _COPY_TYPEMAPS(unsigned int, mode_t);
%define override %enddef %define override %enddef
#endif #endif
// This was originally included in pywrap_tfe.i, but is used by tf_session.i
%include "tensorflow/c/tf_status.h" %include "tensorflow/c/tf_status.h"
%include "tensorflow/c/tf_datatype.h"
%typemap(in) (const void* proto) {
char* c_string;
Py_ssize_t py_size;
// PyBytes_AsStringAndSize() does not copy but simply interprets the input
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
$1 = static_cast<void*>(c_string);
}
%typemap(in) int64_t {
$1 = PyLong_AsLongLong($input);
}
%typemap(out) TF_DataType {
$result = PyInt_FromLong($1);
}
%typemap(out) int64_t {
$result = PyInt_FromLong($1);
}
%typemap(out) TF_AttrType {
$result = PyInt_FromLong($1);
}
%typemap(in, numinputs=0) unsigned char* is_list (unsigned char tmp) {
tmp = 0;
$1 = &tmp;
}
%typemap(argout) unsigned char* is_list {
if (*$1 == 1) {
PyObject* list = PyList_New(1);
PyList_SetItem(list, 0, $result);
$result = list;
}
}
// Typemaps to automatically raise a Python exception from bad output TF_Status. // Typemaps to automatically raise a Python exception from bad output TF_Status.
// TODO(b/77295559): expand this to all TF_Status* output params and deprecate // TODO(b/77295559): expand this to all TF_Status* output params and deprecate

View File

@ -1,515 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
%}
%include "tensorflow/c/tf_datatype.h"
%include "tensorflow/c/tf_status.h"
%ignoreall;
%rename("%s") TF_SetXlaEnableLazyCompilation;
%rename("%s") TF_SetTfXlaCpuGlobalJit;
%rename("%s") TF_SetXlaAutoJitMode;
%rename("%s") TF_SetXlaConstantFoldingDisabled;
%rename("%s") TF_GetXlaConstantFoldingDisabled;
%rename("%s") TF_SetXlaMinClusterSize;
%rename("%s") TFE_NewContext;
%rename("%s") TFE_DeleteContext;
%rename("%s") TFE_ContextListDevices;
%rename("%s") TFE_ContextAddFunction;
%rename("%s") TFE_ContextAddFunctionDef;
%rename("%s") TFE_ContextRemoveFunction;
%rename("%s") TFE_ContextHasFunction;
%rename("%s") TFE_ContextEnableRunMetadata;
%rename("%s") TFE_ContextDisableRunMetadata;
%rename("%s") TFE_ContextEnableGraphCollection;
%rename("%s") TFE_ContextDisableGraphCollection;
%rename("%s") TFE_ContextExportRunMetadata;
%rename("%s") TFE_ContextClearCaches;
%rename("%s") TFE_ContextGetDevicePlacementPolicy;
%rename("%s") TFE_ContextGetMirroringPolicy;
%rename("%s") TFE_ContextSetThreadLocalDevicePlacementPolicy;
%rename("%s") TFE_ContextSetThreadLocalMirroringPolicy;
%rename("%s") TFE_ContextSetServerDef;
%rename("%s") TFE_ContextUpdateServerDef;
%rename("%s") TFE_ContextCheckAlive;
%rename("%s") TFE_NewExecutor;
%rename("%s") TFE_DeleteExecutor;
%rename("%s") TFE_ExecutorIsAsync;
%rename("%s") TFE_ExecutorWaitForAllPendingNodes;
%rename("%s") TFE_ExecutorClearError;
%rename("%s") TFE_ContextSetExecutorForThread;
%rename("%s") TFE_ContextGetExecutorForThread;
%rename("%s") TFE_NewProfiler;
%rename("%s") TFE_ProfilerIsOk;
%rename("%s") TFE_DeleteProfiler;
%rename("%s") TFE_ProfilerSerializeToString;
%rename("%s") TFE_StartProfilerServer;
%rename("%s") TFE_ProfilerClientStartTracing;
%rename("%s") TFE_ProfilerClientMonitor;
%rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor;
%rename("%s") TFE_Py_SetEagerTensorProfiler;
%rename("%s") TFE_Py_RegisterExceptionClass;
%rename("%s") TFE_Py_RegisterJVPFunction;
%rename("%s") TFE_Py_RegisterGradientFunction;
%rename("%s") TFE_Py_RegisterFallbackExceptionClass;
%rename("%s") TFE_Py_Execute;
%rename("%s") TFE_Py_ExecuteCancelable;
%rename("%s") TFE_Py_FastPathExecute;
%rename("%s") TFE_Py_RecordGradient;
%rename("%s") TFE_Py_UID;
%rename("%s") TFE_Py_TapeSetNew;
%rename("%s") TFE_Py_TapeSetAdd;
%rename("%s") TFE_Py_TapeSetRemove;
%rename("%s") TFE_Py_TapeSetStopOnThread;
%rename("%s") TFE_Py_TapeSetRestartOnThread;
%rename("%s") TFE_Py_TapeSetIsStopped;
%rename("%s") TFE_Py_TapeSetIsEmpty;
%rename("%s") TFE_Py_TapeSetShouldRecordBackprop;
%rename("%s") TFE_Py_TapeSetPossibleGradientTypes;
%rename("%s") TFE_Py_TapeSetDeleteTrace;
%rename("%s") TFE_Py_TapeSetRecordOperation;
%rename("%s") TFE_Py_TapeSetRecordOperationBackprop;
%rename("%s") TFE_Py_TapeSetRecordOperationForwardprop;
%rename("%s") TFE_Py_TapeGradient;
%rename("%s") TFE_Py_TapeVariableAccessed;
%rename("%s") TFE_Py_TapeWatch;
%rename("%s") TFE_Py_TapeWatchVariable;
%rename("%s") TFE_Py_TapeWatchedVariables;
%rename("%s") TFE_Py_ForwardAccumulatorNew;
%rename("%s") TFE_Py_ForwardAccumulatorSetAdd;
%rename("%s") TFE_Py_ForwardAccumulatorSetRemove;
%rename("%s") TFE_Py_ForwardAccumulatorWatch;
%rename("%s") TFE_Py_ForwardAccumulatorJVP;
%rename("%s") TFE_Py_ForwardAccumulatorPushState;
%rename("%s") TFE_Py_ForwardAccumulatorPopState;
%rename("%s") TFE_Py_PackJVPs;
%rename("%s") TFE_NewContextOptions;
%rename("%s") TFE_ContextOptionsSetConfig;
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
%rename("%s") TFE_ContextOptionsSetMirroringPolicy;
%rename("%s") TFE_ContextOptionsSetAsync;
%rename("%s") TFE_ContextOptionsSetLazyRemoteInputsCopy;
%rename("%s") TFE_DeleteContextOptions;
%rename("%s") TFE_Py_TensorShapeSlice;
%rename("%s") TFE_Py_TensorShapeOnDevice;
%rename("%s") TFE_Py_EnableInteractivePythonLogging;
%rename("%s") TFE_Py_SetEagerContext;
%rename("%s") TFE_ContextStartStep;
%rename("%s") TFE_ContextEndStep;
%rename("%s") TFE_Py_RegisterVSpace;
%rename("%s") TFE_Py_EncodeArg;
%rename("%s") TFE_EnableCollectiveOps;
%rename("%s") TF_ListPhysicalDevices;
%rename("%s") TF_PickUnusedPortOrDie;
%rename("%s") TFE_MonitoringCounterCellIncrementBy;
%rename("%s") TFE_MonitoringCounterCellValue;
%rename("%s") TFE_MonitoringNewCounter0;
%rename("%s") TFE_MonitoringDeleteCounter0;
%rename("%s") TFE_MonitoringGetCellCounter0;
%rename("%s") TFE_MonitoringNewCounter1;
%rename("%s") TFE_MonitoringDeleteCounter1;
%rename("%s") TFE_MonitoringGetCellCounter1;
%rename("%s") TFE_MonitoringNewCounter2;
%rename("%s") TFE_MonitoringDeleteCounter2;
%rename("%s") TFE_MonitoringGetCellCounter2;
%rename("%s") TFE_MonitoringIntGaugeCellSet;
%rename("%s") TFE_MonitoringIntGaugeCellValue;
%rename("%s") TFE_MonitoringNewIntGauge0;
%rename("%s") TFE_MonitoringDeleteIntGauge0;
%rename("%s") TFE_MonitoringGetCellIntGauge0;
%rename("%s") TFE_MonitoringNewIntGauge1;
%rename("%s") TFE_MonitoringDeleteIntGauge1;
%rename("%s") TFE_MonitoringGetCellIntGauge1;
%rename("%s") TFE_MonitoringNewIntGauge2;
%rename("%s") TFE_MonitoringDeleteIntGauge2;
%rename("%s") TFE_MonitoringGetCellIntGauge2;
%rename("%s") TFE_MonitoringStringGaugeCellSet;
%rename("%s") TFE_MonitoringStringGaugeCellValue;
%rename("%s") TFE_MonitoringNewStringGauge0;
%rename("%s") TFE_MonitoringDeleteStringGauge0;
%rename("%s") TFE_MonitoringGetCellStringGauge0;
%rename("%s") TFE_MonitoringNewStringGauge1;
%rename("%s") TFE_MonitoringDeleteStringGauge1;
%rename("%s") TFE_MonitoringGetCellStringGauge1;
%rename("%s") TFE_MonitoringNewStringGauge2;
%rename("%s") TFE_MonitoringDeleteStringGauge2;
%rename("%s") TFE_MonitoringGetCellStringGauge2;
%rename("%s") TFE_MonitoringBoolGaugeCellSet;
%rename("%s") TFE_MonitoringBoolGaugeCellValue;
%rename("%s") TFE_MonitoringNewBoolGauge0;
%rename("%s") TFE_MonitoringDeleteBoolGauge0;
%rename("%s") TFE_MonitoringGetCellBoolGauge0;
%rename("%s") TFE_MonitoringNewBoolGauge1;
%rename("%s") TFE_MonitoringDeleteBoolGauge1;
%rename("%s") TFE_MonitoringGetCellBoolGauge1;
%rename("%s") TFE_MonitoringNewBoolGauge2;
%rename("%s") TFE_MonitoringDeleteBoolGauge2;
%rename("%s") TFE_MonitoringGetCellBoolGauge2;
%rename("%s") TFE_MonitoringSamplerCellAdd;
%rename("%s") TFE_MonitoringSamplerCellValue;
%rename("%s") TFE_MonitoringNewExponentialBuckets;
%rename("%s") TFE_MonitoringDeleteBuckets;
%rename("%s") TFE_MonitoringNewSampler0;
%rename("%s") TFE_MonitoringDeleteSampler0;
%rename("%s") TFE_MonitoringGetCellSampler0;
%rename("%s") TFE_MonitoringNewSampler1;
%rename("%s") TFE_MonitoringDeleteSampler1;
%rename("%s") TFE_MonitoringGetCellSampler1;
%rename("%s") TFE_MonitoringNewSampler2;
%rename("%s") TFE_MonitoringDeleteSampler2;
%rename("%s") TFE_MonitoringGetCellSampler2;
%rename("%s") TFE_NewCancellationManager;
%rename("%s") TFE_CancellationManagerIsCancelled;
%rename("%s") TFE_CancellationManagerStartCancel;
%rename("%s") TFE_DeleteCancellationManager;
%rename("%s") TF_ImportGraphDefOptionsSetValidateColocationConstraints;
%rename("%s") TFE_ClearScalarCache;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/util/util.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/common_runtime/device_factory.h"
static PyObject* TF_ListPhysicalDevices(TF_Status* status) {
std::vector<string> devices;
tensorflow::Status s = tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices);
tensorflow::Set_TF_Status_from_Status(status, s);
if (!s.ok()) {
Py_RETURN_NONE;
};
PyObject* result = PyList_New(devices.size());
int i = 0;
for (auto& dev : devices) {
PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
PyList_SetItem(result, i, dev_obj);
++i;
}
return result;
}
%}
static PyObject* TF_ListPhysicalDevices(TF_Status* status);
%{
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
static PyObject* TFE_ClearScalarCache() {
tensorflow::TFE_TensorHandleCache::Get()->Clear();
Py_RETURN_NONE;
}
%}
static PyObject* TFE_ClearScalarCache();
%typemap(in) (const void* proto) {
char* c_string;
Py_ssize_t py_size;
// PyBytes_AsStringAndSize() does not copy but simply interprets the input
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
$1 = static_cast<void*>(c_string);
}
%typemap(in) int64_t {
$1 = PyLong_AsLongLong($input);
}
%typemap(out) TF_DataType {
$result = PyInt_FromLong($1);
}
%typemap(out) int64_t {
$result = PyInt_FromLong($1);
}
%typemap(out) TF_AttrType {
$result = PyInt_FromLong($1);
}
%typemap(in, numinputs=0) unsigned char* is_list (unsigned char tmp) {
tmp = 0;
$1 = &tmp;
}
%typemap(argout) unsigned char* is_list {
if (*$1 == 1) {
PyObject* list = PyList_New(1);
PyList_SetItem(list, 0, $result);
$result = list;
}
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* serialized_function_def {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* device_name {
if ($input == Py_None) {
$1 = nullptr;
} else {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* op_name {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* name {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* description {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* label {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* label1 {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* label2 {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* service_addr {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* logdir {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* worker_list {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
%typemap(in) (TFE_Context*) {
$1 = (TFE_Context*)PyCapsule_GetPointer($input, nullptr);
}
%typemap(out) (TFE_Context*) {
// When the TFE_Context* returned is a nullptr, we expect the status is not
// OK. This will raise an error (happens in another typemap).
if ($1 != nullptr) {
$result = PyCapsule_New($1, nullptr, TFE_DeleteContextCapsule);
}
}
%rename("%s") TFE_ContextDevicePlacementPolicy;
%rename("%s") TFE_DEVICE_PLACEMENT_EXPLICIT;
%rename("%s") TFE_DEVICE_PLACEMENT_WARN;
%rename("%s") TFE_DEVICE_PLACEMENT_SILENT;
%rename("%s") TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32;
%rename("%s") TFE_ContextMirroringPolicy;
%rename("%s") TFE_MIRRORING_NONE;
%rename("%s") TFE_MIRRORING_ALL;
%include "tensorflow/c/eager/c_api.h"
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp) {
$1 = &temp;
if ($input != Py_None) {
if (!PyList_Check($input)) {
SWIG_exception_fail(SWIG_TypeError,
"must provide a list of Tensors as inputs");
}
Py_ssize_t len = PyList_Size($input);
$1->resize(len);
for (Py_ssize_t i = 0; i < len; ++i) {
PyObject* elem = PyList_GetItem($input, i);
if (!elem) {
SWIG_fail;
}
if (EagerTensor_CheckExact(elem)) {
(*$1)[i] = EagerTensor_Handle(elem);
} else if (tensorflow::swig::IsEagerTensorSlow(elem)) {
// Use equivalent of object.__getattribute__ to get the underlying
// tf wrapped EagerTensor (if there is one).
tensorflow::Safe_PyObjectPtr tf_should_use_attr(
#if PY_MAJOR_VERSION < 3
PyString_InternFromString("_tf_should_use_wrapped_value")
#else
PyUnicode_InternFromString("_tf_should_use_wrapped_value")
#endif
);
tensorflow::Safe_PyObjectPtr value_attr(
PyObject_GenericGetAttr(elem, tf_should_use_attr.get()));
if (value_attr) {
// This is an EagerTensor wrapped inside a TFShouldUse wrapped object.
(*$1)[i] = EagerTensor_Handle(value_attr.get());
} else {
// This is a subclass of EagerTensor that we don't support.
PyErr_Clear();
SWIG_exception_fail(
SWIG_TypeError,
tensorflow::strings::StrCat(
"Saw an object that is an instance of a strict subclass of "
"EagerTensor, which is not supported. Item ",
i, " is type: ", elem->ob_type->tp_name)
.c_str());
}
} else if (tensorflow::swig::IsTensor(elem)) {
// If it isnt an EagerTensor, but is still a Tensor, it must be a graph
// tensor.
tensorflow::Safe_PyObjectPtr name_attr(
PyObject_GetAttrString(elem, "name"));
SWIG_exception_fail(
SWIG_TypeError,
tensorflow::strings::StrCat(
"An op outside of the function building code is being passed\n"
"a \"Graph\" tensor. It is possible to have Graph tensors\n"
"leak out of the function building context by including a\n"
"tf.init_scope in your function building code.\n"
"For example, the following function will fail:\n",
" @tf.function\n",
" def has_init_scope():\n",
" my_constant = tf.constant(1.)\n",
" with tf.init_scope():\n",
" added = my_constant * 2\n",
"The graph tensor has name: ",
name_attr ? TFE_GetPythonString(name_attr.get()) : "<unknown>"
).c_str());
} else {
SWIG_exception_fail(
SWIG_TypeError,
tensorflow::strings::StrCat(
"provided list of inputs contains objects other "
"than 'EagerTensor'. Item ",
i, " is type: ", elem->ob_type->tp_name).c_str());
}
}
}
}
// Temporary for the argout
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp) {
if (!PyInt_Check($input)) {
SWIG_exception_fail(SWIG_TypeError,
"expected an integer value (size of the number of "
"outputs of the operation)");
}
$1 = &temp;
long sz = PyInt_AsLong($input);
if (sz > 0) {
$1->resize(PyInt_AsLong($input), nullptr);
}
}
// Create new Status object.
%typemap(in, numinputs=0) TF_Status *out_status {
$1 = GetStatus();
}
%typemap(freearg) (TF_Status* out_status) {
ReturnStatus($1);
}
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) {
if (MaybeRaiseExceptionFromTFStatus($2, nullptr)) {
SWIG_fail;
} else {
int num_outputs = $1->size();
Py_CLEAR($result);
$result = PyList_New(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
PyObject *output;
output = EagerTensorFromHandle($1->at(i));
PyList_SetItem($result, i, output);
}
}
}
// SWIG usually unwraps the tuple that the native Python/C interface generates.
// Since we wanted to have a function with a variable length of arguments, we
// used the native Python/C interface directly (which by default supports
// passing all arguments as a tuple).
%native(TFE_Py_FastPathExecute) TFE_Py_FastPathExecute_C;
%include "tensorflow/python/eager/pywrap_tfe.h"
%include "tensorflow/c/c_api_experimental.h"
%include "tensorflow/c/eager/c_api_experimental.h"
// Clear all typemaps.
%typemap(out) TF_DataType;
%typemap(in) int64_t;
%typemap(out) int64_t;
%typemap(out) TF_AttrType;
%typemap(in, numinputs=0) TF_Status *out_status;
%typemap(argout) unsigned char* is_list;
%typemap(in) const char* description;
%typemap(in) const char* label1;
%typemap(in) const char* label2;
%typemap(in) (TFE_Context*);
%typemap(out) (TFE_Context*);
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp);
%typemap(in, numinputs=0) TF_Status *out_status;
%typemap(freearg) (TF_Status* out_status);
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status);
%typemap(in) (const void* proto);
%unignoreall

View File

@ -0,0 +1,29 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python module for TFE ops and functions exported by pybind11.
This module is created because we are splitting out eager bindings from
pywrap_tensorflow. This is causing some issues where Graphs are not properly
initialized when running eager code. Once the graph architecture has been
removed from pywrap_tensorflow as well, we can remove this file.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=invalid-import-order,g-bad-import-order, wildcard-import, unused-import
from tensorflow.python import pywrap_tensorflow
from tensorflow.python._pywrap_tfe import *

View File

@ -17,8 +17,6 @@ limitations under the License.
* The includes are intentionally not alphabetically sorted, as the order of * The includes are intentionally not alphabetically sorted, as the order of
* includes follows dependency order */ * includes follows dependency order */
%include "tensorflow/python/pywrap_tfe.i"
%include "tensorflow/python/client/tf_session.i" %include "tensorflow/python/client/tf_session.i"
%include "tensorflow/python/lib/io/file_io.i" %include "tensorflow/python/lib/io/file_io.i"

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@
*perftools*gputools* *perftools*gputools*
*tf_* *tf_*
*TF_* *TF_*
*Eager*
*TFE_* *TFE_*
*nsync_* *nsync_*
*stream_executor* *stream_executor*

View File

@ -4,6 +4,7 @@ tensorflow {
*toco*; *toco*;
*perftools*gputools*; *perftools*gputools*;
*TF_*; *TF_*;
*Eager*;
*TFE_*; *TFE_*;
*nsync_*; *nsync_*;
*stream_executor*; *stream_executor*;

View File

@ -1,4 +1,4 @@
[cpp_python_util] # util [cpp_python_util] # util tfe
tensorflow::swig::IsSequence tensorflow::swig::IsSequence
tensorflow::swig::IsSequenceOrComposite tensorflow::swig::IsSequenceOrComposite
tensorflow::swig::IsCompositeTensor tensorflow::swig::IsCompositeTensor
@ -17,6 +17,7 @@ tensorflow::swig::IsSequenceForData
tensorflow::swig::FlattenForData tensorflow::swig::FlattenForData
tensorflow::swig::AssertSameStructureForData tensorflow::swig::AssertSameStructureForData
tensorflow::swig::RegisterType tensorflow::swig::RegisterType
tensorflow::swig::IsEagerTensorSlow
[util_port] # util_port [util_port] # util_port
tensorflow::IsGoogleCudaEnabled tensorflow::IsGoogleCudaEnabled
@ -74,11 +75,12 @@ tensorflow::Status::code
tensorflow::Status::error_message tensorflow::Status::error_message
tensorflow::Status::ok() tensorflow::Status::ok()
[core_cpu_impl] # device_lib [core_cpu_impl] # device_lib tfe
tensorflow::Device::attributes tensorflow::Device::attributes
tensorflow::DeviceFactory::AddDevices tensorflow::DeviceFactory::AddDevices
tensorflow::SessionOptions::SessionOptions tensorflow::SessionOptions::SessionOptions
tensorflow::DoQuantizeTrainingOnSerializedGraphDef tensorflow::DoQuantizeTrainingOnSerializedGraphDef
tensorflow::DeviceFactory::ListAllPhysicalDevices
[protos_all] # device_lib, dtypes [protos_all] # device_lib, dtypes
tensorflow::DataType_IsValid tensorflow::DataType_IsValid
@ -123,3 +125,67 @@ tensorflow::make_safe
[python_op_gen] # python_op_gen [python_op_gen] # python_op_gen
tensorflow::GetPythonWrappers tensorflow::GetPythonWrappers
[pywrap_tfe_lib] # tfe
tensorflow::TFE_TensorHandleCache
tensorflow::TFE_TensorHandleCache::Clear
EagerTensor_CheckExact
EagerTensorFromHandle
EagerTensor_Handle
TFE_Py_ExecuteCancelable
TFE_Py_RegisterExceptionClass
TFE_Py_RegisterVSpace
TFE_Py_RegisterFallbackExceptionClass
TFE_Py_RegisterGradientFunction
TFE_Py_RegisterJVPFunction
TFE_GetPythonString
TFE_Py_UID
TFE_DeleteContextCapsule
TFE_Py_InitEagerTensor
TFE_Py_SetEagerTensorProfiler
TFE_Py_TapeSetNew
TFE_Py_TapeSetRemove
TFE_Py_TapeSetAdd
TFE_Py_TapeSetIsEmpty
TFE_Py_TapeSetShouldRecordBackprop
TFE_Py_TapeSetPossibleGradientTypes
TFE_Py_TapeWatch
TFE_Py_TapeSetDeleteTrace
TFE_Py_TapeSetStopOnThread
TFE_Py_TapeSetRestartOnThread
TFE_Py_TapeSetIsStopped
TFE_Py_TapeSetRecordOperation
TFE_Py_TapeSetRecordOperationBackprop
TFE_Py_TapeSetRecordOperationForwardprop
TFE_Py_TapeVariableAccessed
TFE_Py_TapeWatchVariable
TFE_Py_TapeGradient
TFE_Py_FastPathExecute_C
TFE_Py_RecordGradient
TFE_Py_TapeWatchedVariables
TFE_Py_ForwardAccumulatorNew
TFE_Py_ForwardAccumulatorSetAdd
TFE_Py_ForwardAccumulatorSetRemove
TFE_Py_ForwardAccumulatorWatch
TFE_Py_ForwardAccumulatorJVP
TFE_Py_ForwardAccumulatorPushState
TFE_Py_ForwardAccumulatorPopState
TFE_Py_PackJVPs
TFE_Py_TensorShapeSlice
TFE_Py_TensorShapeOnDevice
TFE_Py_EncodeArg
TFE_Py_EnableInteractivePythonLogging
TFE_Py_SetEagerContext
[eager_executor] # tfe
tensorflow::EagerExecutor::~EagerExecutor
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
[profiler_session] # tfe
tensorflow::ProfilerSession::~ProfilerSession
[tf_status_helper] # tfe
tensorflow::Set_TF_Status_from_Status
[context] # tfe
tensorflow::EagerContext::WaitForAndCloseRemoteContexts