From 7bd345bcbb9afbd9e980efac0586b53a0ef0a9b0 Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Tue, 17 Dec 2019 19:37:26 -0800 Subject: [PATCH] Export the Eager classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information. PiperOrigin-RevId: 286110711 Change-Id: I7bf6f6f4ce1d6bf3e8a3e40ef4a83f82333800f6 --- tensorflow/c/BUILD | 14 + tensorflow/c/eager/BUILD | 12 + tensorflow/core/common_runtime/eager/BUILD | 17 + tensorflow/core/distributed_runtime/BUILD | 17 + .../core/distributed_runtime/eager/BUILD | 13 + tensorflow/core/framework/BUILD | 12 + tensorflow/core/profiler/internal/BUILD | 11 + tensorflow/core/profiler/lib/BUILD | 11 + tensorflow/python/BUILD | 72 +- tensorflow/python/client/tf_session.i | 25 + tensorflow/python/distribute/BUILD | 2 +- .../python/distribute/mirrored_strategy.py | 4 +- tensorflow/python/eager/BUILD | 58 +- tensorflow/python/eager/backprop.py | 27 +- tensorflow/python/eager/backprop_test.py | 14 +- tensorflow/python/eager/benchmarks_test.py | 16 +- tensorflow/python/eager/cancellation.py | 10 +- tensorflow/python/eager/context.py | 115 +- tensorflow/python/eager/core.py | 6 +- tensorflow/python/eager/core_test.py | 6 +- tensorflow/python/eager/eager_util.py | 61 + tensorflow/python/eager/execute.py | 14 +- tensorflow/python/eager/executor.py | 14 +- tensorflow/python/eager/forwardprop.py | 17 +- tensorflow/python/eager/forwardprop_test.py | 22 +- tensorflow/python/eager/forwardprop_util.py | 9 +- tensorflow/python/eager/function.py | 11 +- tensorflow/python/eager/imperative_grad.py | 4 +- tensorflow/python/eager/monitoring.py | 124 +- tensorflow/python/eager/profiler.py 
| 18 +- tensorflow/python/eager/profiler_client.py | 14 +- tensorflow/python/eager/pywrap_tfe_test.py | 91 +- tensorflow/python/eager/remote.py | 4 +- tensorflow/python/eager/tape.py | 44 +- tensorflow/python/eager/tensor_test.py | 26 +- tensorflow/python/framework/ops.py | 9 +- tensorflow/python/framework/python_op_gen.cc | 5 +- tensorflow/python/ops/array_grad.py | 5 +- tensorflow/python/ops/logging_ops.py | 4 +- tensorflow/python/platform/base.i | 44 + tensorflow/python/pywrap_tfe.i | 515 -------- tensorflow/python/pywrap_tfe.py | 29 + tensorflow/python/tensorflow.i | 2 - tensorflow/python/tfe_wrapper.cc | 1099 +++++++++++++++++ tensorflow/tf_exported_symbols.lds | 1 + tensorflow/tf_version_script.lds | 1 + .../tools/def_file_filter/symbols_pybind.txt | 70 +- 47 files changed, 1866 insertions(+), 853 deletions(-) create mode 100644 tensorflow/python/eager/eager_util.py delete mode 100755 tensorflow/python/pywrap_tfe.i create mode 100644 tensorflow/python/pywrap_tfe.py create mode 100644 tensorflow/python/tfe_wrapper.cc diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index efe01f7e049..76a02090c3b 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -53,6 +53,20 @@ filegroup( visibility = ["//visibility:public"], ) +filegroup( + name = "pywrap_eager_hdrs", + srcs = [ + "c_api_internal.h", + "tf_status_helper.h", + "tf_status_internal.h", + "tf_tensor_internal.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) + tf_cuda_library( name = "c_api_internal", hdrs = [ diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 130e9a0c3c7..92e994183a2 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -88,6 +88,18 @@ tf_cuda_library( alwayslink = 1, ) +filegroup( + name = "pywrap_eager_hdrs", + srcs = [ + "c_api_experimental.h", + "c_api_internal.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) + tf_cuda_library( name = 
"c_api_internal", srcs = ["c_api_experimental.h"], diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index e6825cb2090..5119dcdf562 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -439,6 +439,23 @@ tf_cc_test( ], ) +filegroup( + name = "pywrap_eager_hdrs", + srcs = [ + "attr_builder.h", + "context.h", + "eager_executor.h", + "eager_operation.h", + "kernel_and_device.h", + "tensor_handle.h", + "tensor_handle_data.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) + filegroup( name = "srcs", srcs = glob( diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index c2da0e778da..2156dcfc3d3 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -783,3 +783,20 @@ tf_cc_test( "//tensorflow/core:worker_proto_cc", ], ) + +filegroup( + name = "pywrap_eager_hdrs", + srcs = [ + "call_options.h", + "message_wrappers.h", + "rendezvous_mgr_interface.h", + "server_lib.h", + "worker_cache.h", + "worker_env.h", + "worker_interface.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD index 6cd525b317d..a4f7309e07a 100644 --- a/tensorflow/core/distributed_runtime/eager/BUILD +++ b/tensorflow/core/distributed_runtime/eager/BUILD @@ -216,3 +216,16 @@ cc_library( "@com_google_absl//absl/types:optional", ], ) + +filegroup( + name = "pywrap_eager_hdrs", + srcs = [ + "eager_client.h", + "remote_tensor_handle.h", + "remote_tensor_handle_data.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 2c8ad2a4697..23b18a0759b 100644 --- a/tensorflow/core/framework/BUILD 
+++ b/tensorflow/core/framework/BUILD @@ -853,6 +853,18 @@ tf_cc_tests( ], ) +filegroup( + name = "pywrap_eager_hdrs", + srcs = [ + "op_gen_lib.h", + "rendezvous.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) + # All framewrok protos are self-contained, i.e. they only import other # protos from the same package, so we can build the protos here and then # link them from core:protos_all without circular dependencies. diff --git a/tensorflow/core/profiler/internal/BUILD b/tensorflow/core/profiler/internal/BUILD index 3e9f80807cf..304e5253072 100644 --- a/tensorflow/core/profiler/internal/BUILD +++ b/tensorflow/core/profiler/internal/BUILD @@ -523,3 +523,14 @@ tf_cc_test( "//tensorflow/core:testlib", ], ) + +filegroup( + name = "pywrap_eager_hdrs", + srcs = [ + "profiler_interface.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD index e64a8e1fcc6..215eb1559d5 100644 --- a/tensorflow/core/profiler/lib/BUILD +++ b/tensorflow/core/profiler/lib/BUILD @@ -43,6 +43,17 @@ tf_cuda_library( alwayslink = True, ) +filegroup( + name = "pywrap_eager_hdrs", + srcs = [ + "profiler_session.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) + cc_library( name = "traceme", hdrs = ["traceme.h"], diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 253c69a7347..b3339f9bead 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -171,6 +171,7 @@ py_library( ":platform", ":proto_ops", ":pywrap_tensorflow", + ":pywrap_tfe", ":rnn_ops_gen", ":saver_test_utils", ":script_ops", @@ -251,6 +252,7 @@ py_library( deps = [ ":_pywrap_util_port", ":lib", + ":pywrap_tfe", ":util", "//tensorflow/core:protos_all_py", "@absl_py//absl:app", @@ -477,13 +479,13 @@ cc_library( cc_library( name = "pybind11_status", hdrs = [ + 
"lib/core/py_exception_registry.h", "lib/core/pybind11_status.h", "//tensorflow/c:headers", ], features = ["-parse_headers"], visibility = tf_external_workspace_visible(visibility), deps = [ - ":py_exception_registry", "//tensorflow/c:tf_status_headers", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -1110,6 +1112,7 @@ py_library( ":lib", ":platform", ":pywrap_tensorflow", + ":pywrap_tfe", ":random_seed", ":sparse_tensor", ":tensor_spec", @@ -5492,7 +5495,6 @@ tf_py_wrap_cc( "lib/io/py_record_reader.i", "lib/io/py_record_writer.i", "platform/base.i", - "pywrap_tfe.i", "//tensorflow/compiler/mlir/python:mlir.i", ], # add win_def_file for pywrap_tensorflow @@ -5573,7 +5575,12 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [ ":safe_ptr", # checkpoint_reader ":python_op_gen", # python_op_gen ":bfloat16_lib", # bfloat16 + "//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib "//tensorflow/core/util/tensor_bundle", # checkpoint_reader + "//tensorflow/core/common_runtime/eager:eager_executor", # tfe + "//tensorflow/core/common_runtime/eager:context", # tfe + "//tensorflow/core/profiler/lib:profiler_session", # tfe + "//tensorflow/c:tf_status_helper", # tfe ] # Filter the DEF file to reduce the number of symbols to 64K or less. 
@@ -7555,6 +7562,67 @@ py_library( ], ) +py_library( + name = "pywrap_tfe", + srcs = ["pywrap_tfe.py"], + visibility = ["//visibility:public"], + deps = [ + ":_pywrap_tfe", + ":pywrap_tensorflow", + ], +) + +tf_python_pybind_extension( + name = "_pywrap_tfe", + srcs = ["tfe_wrapper.cc"], + hdrs = [ + "lib/core/safe_ptr.h", + "util/util.h", + ":py_exception_registry_hdr", + "//tensorflow/c:headers", + "//tensorflow/c:pywrap_eager_hdrs", + "//tensorflow/c/eager:headers", + "//tensorflow/c/eager:pywrap_eager_hdrs", + "//tensorflow/core/common_runtime/eager:pywrap_eager_hdrs", + "//tensorflow/core/distributed_runtime:pywrap_eager_hdrs", + "//tensorflow/core/distributed_runtime/eager:pywrap_eager_hdrs", + "//tensorflow/core/framework:pywrap_eager_hdrs", + "//tensorflow/core/profiler/internal:pywrap_eager_hdrs", + "//tensorflow/core/profiler/lib:pywrap_eager_hdrs", + "//tensorflow/python/eager:pywrap_eager_hdrs", + ], + module_name = "_pywrap_tfe", + deps = [ + ":pybind11_lib", + ":pybind11_status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@pybind11", + "//third_party/python_runtime:headers", + "//tensorflow/core:core_cpu_headers_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:platform", + "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + ] + if_static( + extra_deps = [ + "//tensorflow/core:eager_service_proto_cc", + "//tensorflow/core:master_proto_cc", + "//tensorflow/core:worker_proto_cc", + ], + otherwise = [ + "//tensorflow/core:eager_service_proto_cc_headers_only", + "//tensorflow/core:master_proto_cc_headers_only", + "//tensorflow/core:worker_proto_cc_headers_only", + ], + ), +) + tf_python_pybind_extension( name = "_pywrap_graph_analyzer", srcs = 
["grappler/graph_analyzer_tool_wrapper.cc"], diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 4abef7b6ec5..bf8536e641f 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +%include "tensorflow/python/lib/core/strings.i" %include "tensorflow/python/platform/base.i" %{ @@ -23,6 +24,13 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" #include "tensorflow/python/client/tf_session_helper.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/python/lib/core/safe_ptr.h" +#include "tensorflow/python/eager/pywrap_tfe.h" +// We were getting lucky on imports with safe_ptr.h being placed prior to +// tf_session which imported safe_ptr. We also need pywrap_tfe.h to cast +// one of the inputs to a graph function from a Python string to const char*. + // Helper function to convert a Python list of Tensors to a C++ vector of // TF_Outputs. @@ -78,6 +86,9 @@ void PyInt64ListToVector(PyObject* py_int_seq, std::vector* vec) { %} +%include "tensorflow/c/tf_datatype.h" +%include "tensorflow/c/tf_status.h" + %include "tensorflow/python/client/tf_sessionrun_wrapper.i" // Required to use PyArray_* functions. @@ -85,6 +96,14 @@ void PyInt64ListToVector(PyObject* py_int_seq, std::vector* vec) { tensorflow::ImportNumpy(); %} +// For const parameters in a function, SWIG pretty much ignores the const. +// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13 +// Hence the 'const_cast'. 
+%typemap(in) const char* op_name { + $1 = const_cast(TFE_GetPythonString($input)); +} + + // TensorFlow version and GraphDef versions %constant const char* __version__ = TF_VERSION_STRING; %constant int GRAPH_DEF_VERSION = TF_GRAPH_DEF_VERSION; @@ -174,6 +193,12 @@ tensorflow::ImportNumpy(); // See comment for "%noexception TF_SessionRun_wrapper;" %noexception TF_OperationGetControlInputs_wrapper; + +// Migrate one function from pywrap_tfe.i +%include "tensorflow/c/c_api_experimental.h" +%unignore TF_ImportGraphDefOptionsSetValidateColocationConstraints; +%noexception TF_ImportGraphDefOptionsSetValidateColocationConstraints; + // Build a Python list of TF_Operation* and return it. %typemap(out) std::vector tensorflow::TF_OperationGetControlInputs_wrapper { $result = PyList_New($1.size()); diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 16ed490dd8b..ff60fe6bf3a 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -268,7 +268,7 @@ py_library( "//tensorflow/python:device", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:pywrap_tfe", "//tensorflow/python:summary_ops_v2", "//tensorflow/python:tensor_util", "//tensorflow/python:training", diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 50f35b04fc3..729bb341b6f 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -24,7 +24,7 @@ import functools import threading import weakref -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.impl import api as autograph from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib @@ -944,7 +944,7 @@ class 
_MirroredReplicaThread(threading.Thread): self.record_thread_local_summary_state() self.record_thread_local_eager_context_state() self.context_device_policy = ( - pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy( + pywrap_tfe.TFE_ContextGetDevicePlacementPolicy( ctx._context_handle)) # pylint: disable=protected-access self.graph = ops.get_default_graph() with ops.init_scope(): diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index d869a3b627e..ad792ab70ba 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -56,6 +56,18 @@ cc_library( ], ) +filegroup( + name = "pywrap_eager_hdrs", + srcs = [ + "pywrap_tensor_conversion.h", + "pywrap_tfe.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) + # Transitive dependencies of this target will be included in the pip package. py_library( name = "eager_pip", @@ -90,7 +102,7 @@ py_library( deps = [ ":context", "//tensorflow/python:errors", - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:pywrap_tfe", ], ) @@ -100,7 +112,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:pywrap_tfe", ], ) @@ -121,7 +133,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:pywrap_tfe", ], ) @@ -131,13 +143,14 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ + ":eager_util", ":executor", ":monitoring", "//tensorflow/python:device", "//tensorflow/python:device_spec", "//tensorflow/python:errors", "//tensorflow/python:platform", - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:pywrap_tfe", "//tensorflow/python:tf2", "//tensorflow/python:util", "//third_party/py/numpy", @@ -164,8 +177,8 @@ py_library( "//third_party/py/tf_agents:__subpackages__", ], deps = [ - 
"//tensorflow/python:c_api_util", - "//tensorflow/python:pywrap_tensorflow", + ":eager_util", + "//tensorflow/python:pywrap_tfe", "//tensorflow/python:util", ], ) @@ -187,7 +200,8 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":context", - "//tensorflow/python:pywrap_tensorflow", + ":eager_util", + "//tensorflow/python:pywrap_tfe", "//tensorflow/python:util", ], ) @@ -209,7 +223,8 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python:pywrap_tensorflow", + ":eager_util", + "//tensorflow/python:pywrap_tfe", ], ) @@ -298,7 +313,7 @@ cuda_py_test( "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:pywrap_tfe", "//third_party/py/numpy", ], ) @@ -410,7 +425,7 @@ py_library( "//tensorflow/core:protos_all_py", "//tensorflow/python:dtypes", "//tensorflow/python:lib", - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:pywrap_tfe", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "@six_archive//:six", @@ -496,7 +511,7 @@ py_library( "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:pywrap_tfe", "//tensorflow/python:tensor_shape", "//tensorflow/python:unconnected_gradients", "//tensorflow/python:util", @@ -524,7 +539,7 @@ py_library( deps = [ ":forwardprop_util", "//tensorflow/python:platform", - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:pywrap_tfe", "//tensorflow/python:util", ], ) @@ -535,7 +550,18 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:pywrap_tfe", + ], +) + +py_library( + name = "eager_util", + srcs = ["eager_util.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = 
[ + "//tensorflow/python:pywrap_tfe", + "//tensorflow/python:util", ], ) @@ -552,7 +578,7 @@ cuda_py_test( ":remote", ":test", "//tensorflow/python:math_ops", - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:pywrap_tfe", "//tensorflow/python:random_ops", "//tensorflow/python/keras", "//third_party/py/numpy", @@ -637,7 +663,7 @@ tf_py_test( ":test", "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:pywrap_tfe", "//tensorflow/python:random_ops", "//tensorflow/python:test_ops", "//third_party/py/numpy", @@ -649,7 +675,7 @@ py_library( srcs = ["imperative_grad.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:pywrap_tfe", "//tensorflow/python:unconnected_gradients", "//tensorflow/python:util", ], diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index e2a4992996f..d51597f4cbe 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -24,7 +24,7 @@ import sys import six -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python import _pywrap_utils from tensorflow.python.eager import backprop_util from tensorflow.python.eager import context @@ -71,19 +71,25 @@ def op_attr_type(op_type, attr_name): except KeyError: context.ensure_initialized() h = context.context()._handle # pylint: disable=protected-access - attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType(h, op_type, attr_name) + attr_type = pywrap_tfe.TFE_OpNameGetAttrType(h, op_type, attr_name) _op_attr_type_cache[(op_type, attr_name)] = attr_type return attr_type def make_attr(attr_type, value): - if attr_type == pywrap_tensorflow.TF_ATTR_TYPE: + # pybind11 enums do not return the raw value like SWIG enums do. They are + # useful when comparing amongst each other but not direct integers as we are + # doing in most tests. 
+ # https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types + # TODO(amitpatankar): After all SWIG transitions, convert the enum comparisons + # from integer value to class. + if attr_type == int(pywrap_tfe.TF_ATTR_TYPE): return dtypes.as_dtype(value) - elif attr_type == [pywrap_tensorflow.TF_ATTR_TYPE]: + elif attr_type == [int(pywrap_tfe.TF_ATTR_TYPE)]: return [dtypes.as_dtype(v) for v in value] - elif attr_type == pywrap_tensorflow.TF_ATTR_SHAPE: + elif attr_type == int(pywrap_tfe.TF_ATTR_SHAPE): return tensor_shape.as_shape(value).as_proto() - elif attr_type == [pywrap_tensorflow.TF_ATTR_SHAPE]: + elif attr_type == [int(pywrap_tfe.TF_ATTR_SHAPE)]: return [tensor_shape.as_shape(v).as_proto() for v in value] elif isinstance(value, str): return value.encode() @@ -141,16 +147,15 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs, return grad_fn(mock_op, *out_grads) -pywrap_tensorflow.TFE_Py_RegisterGradientFunction(_gradient_function) +pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function) def _must_record_gradient(): - return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty() + return not pywrap_tfe.TFE_Py_TapeSetIsEmpty() def _record_gradient(op_name, inputs, attrs, results): - return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs, - results) + return pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, results) execute.must_record_gradient = _must_record_gradient @@ -688,7 +693,7 @@ _default_vspace = imperative_grad.VSpace( zeros_like_fn=default_gradient.zeros_like, ones_like_fn=default_gradient.ones_like, graph_shape_fn=gen_array_ops.shape) -pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace) +pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace) def _handle_or_self(x): diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 62a808f44d7..7ffaefeac98 100644 --- a/tensorflow/python/eager/backprop_test.py +++ 
b/tensorflow/python/eager/backprop_test.py @@ -21,7 +21,7 @@ import functools from absl.testing import parameterized import numpy as np -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function @@ -1014,19 +1014,19 @@ class BackpropTest(test.TestCase, parameterized.TestCase): def testGetAttrType(self): typ = backprop.op_attr_type('Add', 'T') - self.assertEqual(typ, pywrap_tensorflow.TF_ATTR_TYPE) + self.assertEqual(typ, int(pywrap_tfe.TF_ATTR_TYPE)) def testGetAttrList(self): typ = backprop.op_attr_type('MaxPool', 'ksize') - self.assertEqual(typ, [pywrap_tensorflow.TF_ATTR_INT]) + self.assertEqual(typ, [int(pywrap_tfe.TF_ATTR_INT)]) def testMakeAttrType(self): self.assertEqual(dtypes.float32, - backprop.make_attr(pywrap_tensorflow.TF_ATTR_TYPE, 1)) + backprop.make_attr(int(pywrap_tfe.TF_ATTR_TYPE), 1)) def testMakeAttrTypeList(self): self.assertEqual([dtypes.float32], - backprop.make_attr([pywrap_tensorflow.TF_ATTR_TYPE], [1])) + backprop.make_attr([int(pywrap_tfe.TF_ATTR_TYPE)], [1])) def testMulType(self): @@ -1040,7 +1040,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase): def testMakeAttrShape(self): for s in ([], None, [1, 2, 3], [None, None], [1, None, 3]): expected = tensor_shape.TensorShape(s).as_proto() - actual = backprop.make_attr(pywrap_tensorflow.TF_ATTR_SHAPE, s) + actual = backprop.make_attr(int(pywrap_tfe.TF_ATTR_SHAPE), s) self.assertEqual( expected, actual, @@ -1051,7 +1051,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase): shape_list = [[], None, [1, 2, 3], [None, None], [1, None, 3]] self.assertEqual( [tensor_shape.TensorShape(s).as_proto() for s in shape_list], - backprop.make_attr([pywrap_tensorflow.TF_ATTR_SHAPE], shape_list)) + backprop.make_attr([int(pywrap_tfe.TF_ATTR_SHAPE)], shape_list)) def testArgsGradientFunction(self): diff --git 
a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 79abb3bda5d..50b81303606 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -39,7 +39,7 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python import keras -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import backprop # pylint: disable=unused-import from tensorflow.python.eager import context @@ -76,10 +76,10 @@ def c_tfe_py_fastpath_execute(a, assert ctx.executing_eagerly( ), "The prototype doesn't contain C code for graph construction" try: - return pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "MatMul", name, - ctx.op_callbacks, a, b, "transpose_a", transpose_a, - "transpose_b", transpose_b) + return pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, + "MatMul", name, ctx.op_callbacks, + a, b, "transpose_a", transpose_a, + "transpose_b", transpose_b) except core._NotOkStatusException as e: if name is not None: message = e.message + " name: " + name @@ -339,8 +339,7 @@ class MicroBenchmarks(test.Benchmark): inputs = [m] def f(): - pywrap_tensorflow.TFE_Py_Execute(ctx_handle, None, "Identity", inputs, - attrs, 1) + pywrap_tfe.TFE_Py_Execute(ctx_handle, None, "Identity", inputs, attrs, 1) self._run(f, 30000) @@ -406,8 +405,7 @@ class MicroBenchmarks(test.Benchmark): m.dtype.as_datatype_enum) def func(): - pywrap_tensorflow.TFE_Py_Execute(ctx_handle, device, "MatMul", inputs, - attrs, 1) + pywrap_tfe.TFE_Py_Execute(ctx_handle, device, "MatMul", inputs, attrs, 1) self._run(func, num_iters) diff --git a/tensorflow/python/eager/cancellation.py b/tensorflow/python/eager/cancellation.py index 308289b5826..e01ce384b76 100644 --- a/tensorflow/python/eager/cancellation.py +++ 
b/tensorflow/python/eager/cancellation.py @@ -18,27 +18,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe class CancellationManager(object): """A mechanism for cancelling blocking computation.""" def __init__(self): - self._impl = pywrap_tensorflow.TFE_NewCancellationManager() + self._impl = pywrap_tfe.TFE_NewCancellationManager() @property def is_cancelled(self): """Returns `True` if `CancellationManager.start_cancel` has been called.""" - return pywrap_tensorflow.TFE_CancellationManagerIsCancelled(self._impl) + return pywrap_tfe.TFE_CancellationManagerIsCancelled(self._impl) def start_cancel(self): """Cancels blocking operations that have been registered with this object.""" - pywrap_tensorflow.TFE_CancellationManagerStartCancel(self._impl) + pywrap_tfe.TFE_CancellationManagerStartCancel(self._impl) def get_cancelable_function(self, concrete_function): # pylint: disable=protected-access return concrete_function._experimental_with_cancellation_manager(self) def __del__(self): - pywrap_tensorflow.TFE_DeleteCancellationManager(self._impl) + pywrap_tfe.TFE_DeleteCancellationManager(self._impl) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index b18f3ebad37..c7c35511115 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -29,11 +29,11 @@ import six from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python import tf2 +from tensorflow.python.eager import eager_util as c_api_util from tensorflow.python.eager import executor from tensorflow.python.eager import monitoring -from tensorflow.python.framework import c_api_util from tensorflow.python.framework import device as pydev from 
tensorflow.python.util import compat from tensorflow.python.util import is_in_graph_mode @@ -54,17 +54,17 @@ _starting_device_spec = pydev.DeviceSpec.from_string("") _MAXINT32 = 2**31 - 1 -DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT -DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN -DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT +DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT +DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN +DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT DEVICE_PLACEMENT_SILENT_FOR_INT32 = ( - pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) + pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) SYNC = 0 ASYNC = 1 -MIRRORING_NONE = pywrap_tensorflow.TFE_MIRRORING_NONE -MIRRORING_ALL = pywrap_tensorflow.TFE_MIRRORING_ALL +MIRRORING_NONE = pywrap_tfe.TFE_MIRRORING_NONE +MIRRORING_ALL = pywrap_tfe.TFE_MIRRORING_ALL _KEEP_ALIVE_SECS = 600 @@ -444,7 +444,7 @@ class Context(object): self._rng = random.Random(seed) # Also clear the kernel cache, to reset any existing seeds if self._context_handle is not None: - pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle) + pywrap_tfe.TFE_ContextClearCaches(self._context_handle) def _internal_operation_seed(self): """Returns a fake operation seed. 
@@ -463,12 +463,11 @@ class Context(object): # Store list of devices logical_devices = [] context_devices = [] - device_list = pywrap_tensorflow.TFE_ContextListDevices( - self._context_handle) + device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle) try: self._num_gpus = 0 - for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)): - dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i) + for i in range(pywrap_tfe.TF_DeviceListCount(device_list)): + dev_name = pywrap_tfe.TF_DeviceListName(device_list, i) context_devices.append(pydev.canonical_name(dev_name)) spec = pydev.DeviceSpec.from_string(dev_name) # If the job is localhost, we assume that the cluster has not yet been @@ -477,14 +476,14 @@ class Context(object): spec = spec.replace(job=None, replica=None, task=None) logical_devices.append( LogicalDevice(name=spec.to_string(), device_type=spec.device_type)) - dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i) + dev_type = pywrap_tfe.TF_DeviceListType(device_list, i) if dev_type == "GPU": self._num_gpus += 1 finally: self._logical_devices = logical_devices self._context_devices = context_devices - pywrap_tensorflow.TF_DeleteDeviceList(device_list) + pywrap_tfe.TF_DeleteDeviceList(device_list) def ensure_initialized(self): """Initialize handle and devices if not already done so.""" @@ -494,36 +493,34 @@ class Context(object): if self._initialized: return assert self._context_devices is None - opts = pywrap_tensorflow.TFE_NewContextOptions() + opts = pywrap_tfe.TFE_NewContextOptions() try: config_str = self.config.SerializeToString() - pywrap_tensorflow.TFE_ContextOptionsSetConfig(opts, config_str) + pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str) if self._device_policy is not None: - pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy( + pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy( opts, self._device_policy) if self._mirroring_policy is not None: - 
pywrap_tensorflow.TFE_ContextOptionsSetMirroringPolicy( + pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy( opts, self._mirroring_policy) if self._default_is_async == ASYNC: - pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True) + pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True) if self._lazy_remote_inputs_copy is not None: - pywrap_tensorflow.TFE_ContextOptionsSetLazyRemoteInputsCopy( + pywrap_tfe.TFE_ContextOptionsSetLazyRemoteInputsCopy( opts, self._lazy_remote_inputs_copy) - context_handle = pywrap_tensorflow.TFE_NewContext(opts) + context_handle = pywrap_tfe.TFE_NewContext(opts) finally: - pywrap_tensorflow.TFE_DeleteContextOptions(opts) + pywrap_tfe.TFE_DeleteContextOptions(opts) assert not (self._server_def and self._collective_ops_server_def), ( "Cannot enable remote execution as well as collective ops at the " "moment. If this is important to you, please file an issue.") if self._server_def is not None: server_def_str = self._server_def.SerializeToString() - pywrap_tensorflow.TFE_ContextSetServerDef(context_handle, - _KEEP_ALIVE_SECS, - server_def_str) + pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS, + server_def_str) elif self._collective_ops_server_def is not None: server_def_str = self._collective_ops_server_def.SerializeToString() - pywrap_tensorflow.TFE_EnableCollectiveOps(context_handle, - server_def_str) + pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str) self._context_handle = context_handle self._initialize_logical_devices() @@ -532,7 +529,7 @@ class Context(object): def _clear_caches(self): self.ones_rank_cache().flush() self.zeros_cache().flush() - pywrap_tensorflow.TFE_ClearScalarCache() + pywrap_tfe.TFE_ClearScalarCache() def get_server_def(self): return self._server_def @@ -563,8 +560,8 @@ class Context(object): if self._context_handle: server_def_str = server_def.SerializeToString() - pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle, - keep_alive_secs, server_def_str) + 
pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs, + server_def_str) self._initialize_logical_devices() # Clear all the caches in case there are remote tensors in them. @@ -592,9 +589,8 @@ class Context(object): if self._context_handle: server_def_str = server_def.SerializeToString() - pywrap_tensorflow.TFE_ContextUpdateServerDef(self._context_handle, - keep_alive_secs, - server_def_str) + pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle, + keep_alive_secs, server_def_str) self._initialize_logical_devices() self._clear_caches() @@ -614,8 +610,7 @@ class Context(object): """ # TODO(yuefengz): support checking multiple workers. if self._context_handle: - return pywrap_tensorflow.TFE_ContextCheckAlive(self._context_handle, - worker_name) + return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name) else: raise ValueError("Context is not initialized.") @@ -808,8 +803,8 @@ class Context(object): self.executor.wait() executor_new = executor.new_executor(enable_async) self._thread_local_data.executor = executor_new - pywrap_tensorflow.TFE_ContextSetExecutorForThread( - self._context_handle, executor_new.handle()) + pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, + executor_new.handle()) else: self._default_is_async = enable_async @@ -823,13 +818,12 @@ class Context(object): def executor(self): ensure_initialized() return executor.Executor( - pywrap_tensorflow.TFE_ContextGetExecutorForThread(self._context_handle)) + pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle)) @executor.setter def executor(self, e): ensure_initialized() - pywrap_tensorflow.TFE_ContextSetExecutorForThread(self._context_handle, - e.handle()) + pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle()) @property def config(self): @@ -1015,7 +1009,7 @@ class Context(object): fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper). 
""" self.ensure_initialized() - pywrap_tensorflow.TFE_ContextAddFunction(self._handle, fn) + pywrap_tfe.TFE_ContextAddFunction(self._handle, fn) def add_function_def(self, fdef): """Add a function definition to the context. @@ -1028,8 +1022,8 @@ class Context(object): """ self.ensure_initialized() fdef_string = fdef.SerializeToString() - pywrap_tensorflow.TFE_ContextAddFunctionDef( - self._handle, fdef_string, len(fdef_string)) + pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string, + len(fdef_string)) def remove_function(self, name): """Remove a function from the context. @@ -1040,12 +1034,12 @@ class Context(object): name: function signature name. """ self.ensure_initialized() - pywrap_tensorflow.TFE_ContextRemoveFunction(self._handle, name) + pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name) def has_function(self, name): """Check if a function `name` is registered.""" self.ensure_initialized() - return bool(pywrap_tensorflow.TFE_ContextHasFunction(self._handle, name)) + return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name)) def add_op_callback(self, callback): """Add a post-op callback to the context. 
@@ -1101,7 +1095,7 @@ class Context(object): if self._physical_devices is not None: return - devs = pywrap_tensorflow.TF_ListPhysicalDevices() + devs = pywrap_tfe.TF_ListPhysicalDevices() self._physical_devices = [ PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1]) for d in devs] @@ -1434,7 +1428,7 @@ class Context(object): def device_policy(self): # Only get the policy from the context if it has already been initialized if self._context_handle is not None: - return pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(self._handle) + return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle) return self._device_policy @@ -1448,14 +1442,14 @@ class Context(object): # Only set the policy if the context has already been initialized if self._context_handle is not None: - pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy( + pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy( self._handle, self._device_policy) @property def mirroring_policy(self): # Only get the policy from the context if it has already been initialized if self._context_handle is not None: - return pywrap_tensorflow.TFE_ContextGetMirroringPolicy(self._handle) + return pywrap_tfe.TFE_ContextGetMirroringPolicy(self._handle) return self._mirroring_policy @@ -1469,7 +1463,7 @@ class Context(object): # Only set the policy if the context has already been initialized if self._context_handle is not None: - pywrap_tensorflow.TFE_ContextSetThreadLocalMirroringPolicy( + pywrap_tfe.TFE_ContextSetThreadLocalMirroringPolicy( self._handle, self._mirroring_policy) @property @@ -1495,13 +1489,13 @@ class Context(object): and to stop tracing call context.disable_run_metadata(). 
""" self.ensure_initialized() - pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._handle) + pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle) def disable_run_metadata(self): """Disables tracing of op execution via RunMetadata.""" if not self._context_handle: return - pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle) + pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle) def enable_graph_collection(self): """Enables graph collection of executed functions. @@ -1510,13 +1504,13 @@ class Context(object): and to stop collecting graphs call context.disable_graph_collection(). """ self.ensure_initialized() - pywrap_tensorflow.TFE_ContextEnableGraphCollection(self._handle) + pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle) def disable_graph_collection(self): """Disables graph collection of executed functions.""" if not self._context_handle: return - pywrap_tensorflow.TFE_ContextDisableGraphCollection(self._context_handle) + pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle) def export_run_metadata(self): """Returns a RunMetadata proto with accumulated information. 
@@ -1530,9 +1524,8 @@ class Context(object): if not self._context_handle: return None with c_api_util.tf_buffer() as buffer_: - pywrap_tensorflow.TFE_ContextExportRunMetadata( - self._context_handle, buffer_) - proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) + pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_) + proto_data = pywrap_tfe.TF_GetBuffer(buffer_) run_metadata = config_pb2.RunMetadata() run_metadata.ParseFromString(compat.as_bytes(proto_data)) return run_metadata @@ -1543,10 +1536,10 @@ class Context(object): return self._context_switches def start_step(self): - pywrap_tensorflow.TFE_ContextStartStep(self._handle) + pywrap_tfe.TFE_ContextStartStep(self._handle) def end_step(self): - pywrap_tensorflow.TFE_ContextEndStep(self._handle) + pywrap_tfe.TFE_ContextEndStep(self._handle) class _EagerDeviceContext(object): @@ -1608,7 +1601,7 @@ _context_lock = threading.Lock() def _set_context_locked(ctx): global _context - pywrap_tensorflow.TFE_Py_SetEagerContext(ctx) + pywrap_tfe.TFE_Py_SetEagerContext(ctx) _context = ctx diff --git a/tensorflow/python/eager/core.py b/tensorflow/python/eager/core.py index e168b4bd5ff..5216afd12e8 100644 --- a/tensorflow/python/eager/core.py +++ b/tensorflow/python/eager/core.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.framework import errors # Trace of execution and memory usage. 
@@ -46,7 +46,7 @@ class _NotOkStatusException(Exception): return "%s: %s" % (e.__class__.__name__, e) -pywrap_tensorflow.TFE_Py_RegisterExceptionClass(_NotOkStatusException) +pywrap_tfe.TFE_Py_RegisterExceptionClass(_NotOkStatusException) class _FallbackException(Exception): @@ -71,4 +71,4 @@ class _SymbolicException(Exception): pass -pywrap_tensorflow.TFE_Py_RegisterFallbackExceptionClass(_FallbackException) +pywrap_tfe.TFE_Py_RegisterFallbackExceptionClass(_FallbackException) diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index ca224fc0f05..8993efd4085 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -26,7 +26,7 @@ import threading import numpy as np from tensorflow.core.protobuf import config_pb2 -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.eager import context from tensorflow.python.eager import core from tensorflow.python.eager import def_function @@ -602,8 +602,8 @@ class TFETest(test_util.TensorFlowTestCase): def testRegisterExceptionClass(self): with self.assertRaises(TypeError): - pywrap_tensorflow.TFE_Py_RegisterExceptionClass(str) - pywrap_tensorflow.TFE_Py_RegisterExceptionClass(core._NotOkStatusException) # pylint: disable=protected-access + pywrap_tfe.TFE_Py_RegisterExceptionClass(str) + pywrap_tfe.TFE_Py_RegisterExceptionClass(core._NotOkStatusException) # pylint: disable=protected-access # TODO(agarwal): add tests passing incorrect typed values to attrs. def testExecuteBasic(self): diff --git a/tensorflow/python/eager/eager_util.py b/tensorflow/python/eager/eager_util.py new file mode 100644 index 00000000000..7d369c876d6 --- /dev/null +++ b/tensorflow/python/eager/eager_util.py @@ -0,0 +1,61 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for using TensorFlow Eager via the C API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import pywrap_tfe as c_api +from tensorflow.python.util import compat +from tensorflow.python.util import tf_contextlib + + +# We temporarily need a duplicate tf_buffer function in eager_util. The +# c_api_util is still relying on SWIG and is thus incompatible until +# we migrate over. We can delete this once we migrate tf_session.i + + +@tf_contextlib.contextmanager +def tf_buffer(data=None): + """Context manager that creates and deletes TF_Buffer. + + Example usage: + with tf_buffer() as buf: + # get serialized graph def into buf + ... + proto_data = c_api.TF_GetBuffer(buf) + graph_def.ParseFromString(compat.as_bytes(proto_data)) + # buf has been deleted + + with tf_buffer(some_string) as buf: + c_api.TF_SomeFunction(buf) + # buf has been deleted + + Args: + data: An optional `bytes`, `str`, or `unicode` object. If not None, the + yielded buffer will contain this data. 
+ + Yields: + Created TF_Buffer + """ + if data: + buf = c_api.TF_NewBufferFromString(compat.as_bytes(data)) + else: + buf = c_api.TF_NewBuffer() + try: + yield buf + finally: + c_api.TF_DeleteBuffer(buf) diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py index 7a1de7aa305..e206262309e 100644 --- a/tensorflow/python/eager/execute.py +++ b/tensorflow/python/eager/execute.py @@ -22,7 +22,7 @@ import six from google.protobuf import text_format from tensorflow.core.framework import tensor_pb2 -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.eager import core from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -56,9 +56,8 @@ def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None): # pylint: disable=protected-access try: ctx.ensure_initialized() - tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name, - op_name, inputs, attrs, - num_outputs) + tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name, + inputs, attrs, num_outputs) except core._NotOkStatusException as e: if name is not None: message = e.message + " name: " + name @@ -111,9 +110,10 @@ def execute_with_cancellation(op_name, # pylint: disable=protected-access try: ctx.ensure_initialized() - tensors = pywrap_tensorflow.TFE_Py_ExecuteCancelable( - ctx._handle, device_name, op_name, inputs, attrs, - cancellation_manager._impl, num_outputs) + tensors = pywrap_tfe.TFE_Py_ExecuteCancelable(ctx._handle, device_name, + op_name, inputs, attrs, + cancellation_manager._impl, + num_outputs) except core._NotOkStatusException as e: if name is not None: message = e.message + " name: " + name diff --git a/tensorflow/python/eager/executor.py b/tensorflow/python/eager/executor.py index be844015dd0..cd2bf0d0398 100644 --- a/tensorflow/python/eager/executor.py +++ b/tensorflow/python/eager/executor.py @@ -18,7 +18,7 @@ from __future__ import 
absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe class Executor(object): @@ -45,8 +45,8 @@ class Executor(object): def __del__(self): try: - # pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle) - pywrap_tensorflow.TFE_DeleteExecutor(self._handle) + # pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle) + pywrap_tfe.TFE_DeleteExecutor(self._handle) except TypeError: # Suppress some exceptions, mainly for the case when we're running on # module deletion. Things that can go wrong include the pywrap module @@ -57,20 +57,20 @@ class Executor(object): # partially unloaded. def is_async(self): - return pywrap_tensorflow.TFE_ExecutorIsAsync(self._handle) + return pywrap_tfe.TFE_ExecutorIsAsync(self._handle) def handle(self): return self._handle def wait(self): """Waits for ops dispatched in this executor to finish.""" - pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle) + pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle) def clear_error(self): """Clears errors raised in this executor during execution.""" - pywrap_tensorflow.TFE_ExecutorClearError(self._handle) + pywrap_tfe.TFE_ExecutorClearError(self._handle) def new_executor(enable_async): - handle = pywrap_tensorflow.TFE_NewExecutor(enable_async) + handle = pywrap_tfe.TFE_NewExecutor(enable_async) return Executor(handle) diff --git a/tensorflow/python/eager/forwardprop.py b/tensorflow/python/eager/forwardprop.py index 6ddaedc2fdb..973e130ef0f 100644 --- a/tensorflow/python/eager/forwardprop.py +++ b/tensorflow/python/eager/forwardprop.py @@ -20,7 +20,7 @@ from __future__ import print_function import threading -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop_util from tensorflow.python.eager import def_function @@ 
-166,7 +166,8 @@ def _jvp_dispatch(op_name, attr_tuple, inputs, outputs, tangents): return _jvp_relaxed_shapes( op_name, attr_tuple, inputs, outputs, tangents) -pywrap_tensorflow.TFE_Py_RegisterJVPFunction(_jvp_dispatch) + +pywrap_tfe.TFE_Py_RegisterJVPFunction(_jvp_dispatch) @tf_export("autodiff.ForwardAccumulator", v1=[]) @@ -300,7 +301,7 @@ class ForwardAccumulator(object): ValueError: If the same tensor or variable is specified multiple times in `primals`. """ - self._accumulator = pywrap_tensorflow.TFE_Py_ForwardAccumulatorNew() + self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew() self._recording = False primal_ids = set() for primal in nest.flatten(primals): @@ -323,13 +324,13 @@ class ForwardAccumulator(object): def _push_accumulator(self): if self._recording: raise ValueError("Accumulator is already recording.") - pywrap_tensorflow.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator) + pywrap_tfe.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator) self._recording = True def _pop_accumulator(self): if not self._recording: raise ValueError("Accumulator is not recording.") - pywrap_tensorflow.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator) + pywrap_tfe.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator) self._recording = False def _watch(self, primals, tangents): @@ -358,7 +359,7 @@ class ForwardAccumulator(object): # Run convert_to_tensor to get the captured handle from whichever # function we're running if necessary. t = ops.convert_to_tensor(t.handle) - pywrap_tensorflow.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g) + pywrap_tfe.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g) def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE): """Fetches the Jacobian-vector product computed for `primals`. 
@@ -384,8 +385,8 @@ class ForwardAccumulator(object): def _fetch_jvp(tensor): if hasattr(tensor, "handle"): tensor = ops.convert_to_tensor(tensor.handle) - result = pywrap_tensorflow.TFE_Py_ForwardAccumulatorJVP( - self._accumulator, tensor) + result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(self._accumulator, + tensor) if result is None and unconnected_gradients == UnconnectedGradients.ZERO: return array_ops.zeros_like(tensor) return result diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py index dd1854c8797..0f88ee2d4a6 100644 --- a/tensorflow/python/eager/forwardprop_test.py +++ b/tensorflow/python/eager/forwardprop_test.py @@ -24,7 +24,7 @@ import weakref from absl.testing import parameterized import numpy as np -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.eager import backprop from tensorflow.python.eager import def_function @@ -236,13 +236,13 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase): x = constant_op.constant(1.) with forwardprop.ForwardAccumulator(x, 2.) 
as acc: y = x + x - pywrap_tensorflow.TFE_Py_RegisterJVPFunction( + pywrap_tfe.TFE_Py_RegisterJVPFunction( lambda *args, **kwargs: [constant_op.constant(-15.)]) z = x + x self.assertAllClose(4., acc.jvp(y)) self.assertAllClose(-15., acc.jvp(z)) finally: - pywrap_tensorflow.TFE_Py_RegisterJVPFunction(previous_fn) + pywrap_tfe.TFE_Py_RegisterJVPFunction(previous_fn) @test_util.assert_no_new_pyobjects_executing_eagerly def testFunctionCacheLimited(self): @@ -738,19 +738,19 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase): with forwardprop.ForwardAccumulator(c, c_tangent) as acc: with backprop.GradientTape() as tape: self.assertFalse(tape_lib.should_record_backprop([c])) - self.assertEqual( - 1, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c])) + self.assertEqual(1, + pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c])) tape.watch(c) - self.assertEqual( - 2, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c])) + self.assertEqual(2, + pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c])) self.assertTrue(tape_lib.should_record_backprop([c])) with tape_lib.stop_recording(): - self.assertEqual( - 0, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c])) + self.assertEqual(0, + pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c])) self.assertFalse(tape_lib.should_record_backprop([c])) d = c * 2. 
- self.assertEqual( - 2, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c])) + self.assertEqual(2, + pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c])) self.assertTrue(tape_lib.should_record_backprop([c])) self.assertFalse(tape_lib.should_record_backprop([d])) self.assertIsNone(acc.jvp(d)) diff --git a/tensorflow/python/eager/forwardprop_util.py b/tensorflow/python/eager/forwardprop_util.py index 07aa9511bfe..f618525d01b 100644 --- a/tensorflow/python/eager/forwardprop_util.py +++ b/tensorflow/python/eager/forwardprop_util.py @@ -24,7 +24,7 @@ from __future__ import print_function import collections import contextlib -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe class TangentInfo( @@ -54,8 +54,7 @@ def pack_tangents(tensors): tangents: A flat list of Tensors. Best interpreted as a sequence to be appended to `tensors`. """ - return TangentInfo( - *pywrap_tensorflow.TFE_Py_PackJVPs(tensors)) + return TangentInfo(*pywrap_tfe.TFE_Py_PackJVPs(tensors)) @contextlib.contextmanager @@ -73,7 +72,7 @@ def push_forwardprop_state(): None (used for its side effect). 
""" try: - pywrap_tensorflow.TFE_Py_ForwardAccumulatorPushState() + pywrap_tfe.TFE_Py_ForwardAccumulatorPushState() yield finally: - pywrap_tensorflow.TFE_Py_ForwardAccumulatorPopState() + pywrap_tfe.TFE_Py_ForwardAccumulatorPopState() diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index c0c6ed6bb54..7b8c5f33e77 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -32,8 +32,9 @@ from six.moves import map from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python import _pywrap_utils +from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop_util from tensorflow.python.eager import context @@ -1098,7 +1099,7 @@ class _TapeGradientFunctions(object): forward_function.signature.name, forward_outputs, forward_inputs, py_backward, None) output_indices, output_tangents = ( - pywrap_tensorflow.TFE_Py_PackJVPs(forward_outputs)) + pywrap_tfe.TFE_Py_PackJVPs(forward_outputs)) output_tangents = [forward_wrapper_graph.capture(t) for t in output_tangents] return _ForwardWrapper( @@ -1732,7 +1733,7 @@ class ConcreteFunction(object): "Tensor." % (self._func_graph.name, i, str(arg))) args = tensor_inputs + captured_inputs possible_gradient_type = ( - pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes(args)) + pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(args)) if (possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE and executing_eagerly): # No tape is watching; skip to running the function. 
@@ -2552,8 +2553,8 @@ class Function(object): """Computes the cache key given inputs and execution context.""" if self.input_signature is None: inputs = (args, kwargs) if kwargs else args - input_signature = pywrap_tensorflow.TFE_Py_EncodeArg( - inputs, include_tensor_ranks_only) + input_signature = pywrap_tfe.TFE_Py_EncodeArg(inputs, + include_tensor_ranks_only) else: del args, kwargs assert not include_tensor_ranks_only diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py index 3aacbe4ab61..3d6cbb50fbe 100644 --- a/tensorflow/python/eager/imperative_grad.py +++ b/tensorflow/python/eager/imperative_grad.py @@ -20,7 +20,7 @@ from __future__ import print_function import collections -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients from tensorflow.python.util import compat @@ -68,7 +68,7 @@ def imperative_grad(tape, raise ValueError( "Unknown value for unconnected_gradients: %r" % unconnected_gradients) - return pywrap_tensorflow.TFE_Py_TapeGradient( + return pywrap_tfe.TFE_Py_TapeGradient( tape._tape, # pylint: disable=protected-access target, sources, diff --git a/tensorflow/python/eager/monitoring.py b/tensorflow/python/eager/monitoring.py index 09838f143e0..b0c6d23f35b 100644 --- a/tensorflow/python/eager/monitoring.py +++ b/tensorflow/python/eager/monitoring.py @@ -21,80 +21,80 @@ from __future__ import print_function import collections from tensorflow.core.framework import summary_pb2 -from tensorflow.python import pywrap_tensorflow -from tensorflow.python.framework import c_api_util +from tensorflow.python import pywrap_tfe +from tensorflow.python.eager import eager_util as c_api_util from tensorflow.python.util import compat _MetricMethod = collections.namedtuple('MetricMethod', 'create delete get_cell') _counter_methods = [ _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewCounter0, - 
delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter0, - get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter0), + create=pywrap_tfe.TFE_MonitoringNewCounter0, + delete=pywrap_tfe.TFE_MonitoringDeleteCounter0, + get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter0), _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewCounter1, - delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter1, - get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter1), + create=pywrap_tfe.TFE_MonitoringNewCounter1, + delete=pywrap_tfe.TFE_MonitoringDeleteCounter1, + get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter1), _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewCounter2, - delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter2, - get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter2), + create=pywrap_tfe.TFE_MonitoringNewCounter2, + delete=pywrap_tfe.TFE_MonitoringDeleteCounter2, + get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter2), ] _int_gauge_methods = [ _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewIntGauge0, - delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge0, - get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge0), + create=pywrap_tfe.TFE_MonitoringNewIntGauge0, + delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge0, + get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge0), _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewIntGauge1, - delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge1, - get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge1), + create=pywrap_tfe.TFE_MonitoringNewIntGauge1, + delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge1, + get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge1), _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewIntGauge2, - delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge2, - get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge2), + create=pywrap_tfe.TFE_MonitoringNewIntGauge2, + delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge2, + get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge2), 
] _string_gauge_methods = [ _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewStringGauge0, - delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge0, - get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge0), + create=pywrap_tfe.TFE_MonitoringNewStringGauge0, + delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge0, + get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge0), _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewStringGauge1, - delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge1, - get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge1), + create=pywrap_tfe.TFE_MonitoringNewStringGauge1, + delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge1, + get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge1), _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewStringGauge2, - delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge2, - get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge2), + create=pywrap_tfe.TFE_MonitoringNewStringGauge2, + delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge2, + get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge2), ] _bool_gauge_methods = [ _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge0, - delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge0, - get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge0), + create=pywrap_tfe.TFE_MonitoringNewBoolGauge0, + delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge0, + get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge0), _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge1, - delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge1, - get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge1), + create=pywrap_tfe.TFE_MonitoringNewBoolGauge1, + delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge1, + get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge1), _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge2, - delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge2, - 
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge2), + create=pywrap_tfe.TFE_MonitoringNewBoolGauge2, + delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge2, + get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge2), ] _sampler_methods = [ _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewSampler0, - delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler0, - get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler0), + create=pywrap_tfe.TFE_MonitoringNewSampler0, + delete=pywrap_tfe.TFE_MonitoringDeleteSampler0, + get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler0), _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewSampler1, - delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler1, - get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler1), + create=pywrap_tfe.TFE_MonitoringNewSampler1, + delete=pywrap_tfe.TFE_MonitoringDeleteSampler1, + get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler1), _MetricMethod( - create=pywrap_tensorflow.TFE_MonitoringNewSampler2, - delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler2, - get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler2), + create=pywrap_tfe.TFE_MonitoringNewSampler2, + delete=pywrap_tfe.TFE_MonitoringDeleteSampler2, + get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler2), ] @@ -156,11 +156,11 @@ class CounterCell(object): Args: value: non-negative value. """ - pywrap_tensorflow.TFE_MonitoringCounterCellIncrementBy(self._cell, value) + pywrap_tfe.TFE_MonitoringCounterCellIncrementBy(self._cell, value) def value(self): """Retrieves the current value.""" - return pywrap_tensorflow.TFE_MonitoringCounterCellValue(self._cell) + return pywrap_tfe.TFE_MonitoringCounterCellValue(self._cell) class Counter(Metric): @@ -204,11 +204,11 @@ class IntGaugeCell(object): Args: value: integer value. 
""" - pywrap_tensorflow.TFE_MonitoringIntGaugeCellSet(self._cell, value) + pywrap_tfe.TFE_MonitoringIntGaugeCellSet(self._cell, value) def value(self): """Retrieves the current value.""" - return pywrap_tensorflow.TFE_MonitoringIntGaugeCellValue(self._cell) + return pywrap_tfe.TFE_MonitoringIntGaugeCellValue(self._cell) class IntGauge(Metric): @@ -252,13 +252,13 @@ class StringGaugeCell(object): Args: value: string value. """ - pywrap_tensorflow.TFE_MonitoringStringGaugeCellSet(self._cell, value) + pywrap_tfe.TFE_MonitoringStringGaugeCellSet(self._cell, value) def value(self): """Retrieves the current value.""" with c_api_util.tf_buffer() as buffer_: - pywrap_tensorflow.TFE_MonitoringStringGaugeCellValue(self._cell, buffer_) - value = pywrap_tensorflow.TF_GetBuffer(buffer_).decode('utf-8') + pywrap_tfe.TFE_MonitoringStringGaugeCellValue(self._cell, buffer_) + value = pywrap_tfe.TF_GetBuffer(buffer_).decode('utf-8') return value @@ -303,11 +303,11 @@ class BoolGaugeCell(object): Args: value: bool value. """ - pywrap_tensorflow.TFE_MonitoringBoolGaugeCellSet(self._cell, value) + pywrap_tfe.TFE_MonitoringBoolGaugeCellSet(self._cell, value) def value(self): """Retrieves the current value.""" - return pywrap_tensorflow.TFE_MonitoringBoolGaugeCellValue(self._cell) + return pywrap_tfe.TFE_MonitoringBoolGaugeCellValue(self._cell) class BoolGauge(Metric): @@ -351,7 +351,7 @@ class SamplerCell(object): Args: value: float value. """ - pywrap_tensorflow.TFE_MonitoringSamplerCellAdd(self._cell, value) + pywrap_tfe.TFE_MonitoringSamplerCellAdd(self._cell, value) def value(self): """Retrieves the current distribution of samples. @@ -360,8 +360,8 @@ class SamplerCell(object): A HistogramProto describing the distribution of samples. 
""" with c_api_util.tf_buffer() as buffer_: - pywrap_tensorflow.TFE_MonitoringSamplerCellValue(self._cell, buffer_) - proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) + pywrap_tfe.TFE_MonitoringSamplerCellValue(self._cell, buffer_) + proto_data = pywrap_tfe.TF_GetBuffer(buffer_) histogram_proto = summary_pb2.HistogramProto() histogram_proto.ParseFromString(compat.as_bytes(proto_data)) return histogram_proto @@ -379,7 +379,7 @@ class Buckets(object): self.buckets = buckets def __del__(self): - pywrap_tensorflow.TFE_MonitoringDeleteBuckets(self.buckets) + pywrap_tfe.TFE_MonitoringDeleteBuckets(self.buckets) class ExponentialBuckets(Buckets): @@ -399,8 +399,8 @@ class ExponentialBuckets(Buckets): bucket_count: integer """ super(ExponentialBuckets, self).__init__( - pywrap_tensorflow.TFE_MonitoringNewExponentialBuckets( - scale, growth_factor, bucket_count)) + pywrap_tfe.TFE_MonitoringNewExponentialBuckets(scale, growth_factor, + bucket_count)) class Sampler(Metric): diff --git a/tensorflow/python/eager/profiler.py b/tensorflow/python/eager/profiler.py index e2ba5f4d593..e91700b86ac 100644 --- a/tensorflow/python/eager/profiler.py +++ b/tensorflow/python/eager/profiler.py @@ -39,9 +39,9 @@ import os import threading from tensorflow.python import _pywrap_events_writer -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.eager import context -from tensorflow.python.framework import c_api_util +from tensorflow.python.eager import eager_util as c_api_util from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat @@ -74,8 +74,8 @@ def start(): raise ProfilerAlreadyRunningError('Another profiler is running.') if context.default_execution_mode == context.EAGER_MODE: context.ensure_initialized() - _profiler = pywrap_tensorflow.TFE_NewProfiler() - if not pywrap_tensorflow.TFE_ProfilerIsOk(_profiler): + _profiler = 
pywrap_tfe.TFE_NewProfiler() + if not pywrap_tfe.TFE_ProfilerIsOk(_profiler): logging.warning('Another profiler session is running which is probably ' 'created by profiler server. Please avoid using profiler ' 'server and profiler APIs at the same time.') @@ -100,11 +100,9 @@ def stop(): if context.default_execution_mode == context.EAGER_MODE: context.context().executor.wait() with c_api_util.tf_buffer() as buffer_: - pywrap_tensorflow.TFE_ProfilerSerializeToString( - _profiler, - buffer_) - result = pywrap_tensorflow.TF_GetBuffer(buffer_) - pywrap_tensorflow.TFE_DeleteProfiler(_profiler) + pywrap_tfe.TFE_ProfilerSerializeToString(_profiler, buffer_) + result = pywrap_tfe.TF_GetBuffer(buffer_) + pywrap_tfe.TFE_DeleteProfiler(_profiler) _profiler = None _run_num += 1 return result @@ -159,7 +157,7 @@ def start_profiler_server(port): """ if context.default_execution_mode == context.EAGER_MODE: context.ensure_initialized() - pywrap_tensorflow.TFE_StartProfilerServer(port) + pywrap_tfe.TFE_StartProfilerServer(port) class Profiler(object): diff --git a/tensorflow/python/eager/profiler_client.py b/tensorflow/python/eager/profiler_client.py index 5d6fcb47b71..c59f8eed216 100644 --- a/tensorflow/python/eager/profiler_client.py +++ b/tensorflow/python/eager/profiler_client.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python import pywrap_tensorflow -from tensorflow.python.framework import c_api_util +from tensorflow.python import pywrap_tfe +from tensorflow.python.eager import eager_util as c_api_util from tensorflow.python.framework import errors @@ -46,7 +46,7 @@ def start_tracing(service_addr, Raises: UnavailableError: If no trace event is collected. 
""" - if not pywrap_tensorflow.TFE_ProfilerClientStartTracing( + if not pywrap_tfe.TFE_ProfilerClientStartTracing( service_addr, logdir, worker_list, include_dataset_ops, duration_ms, num_tracing_attempts): raise errors.UnavailableError(None, None, 'No trace event is collected.') @@ -71,7 +71,7 @@ def monitor(service_addr, A string of monitoring output. """ with c_api_util.tf_buffer() as buffer_: - pywrap_tensorflow.TFE_ProfilerClientMonitor(service_addr, duration_ms, - monitoring_level, - display_timestamp, buffer_) - return pywrap_tensorflow.TF_GetBuffer(buffer_) + pywrap_tfe.TFE_ProfilerClientMonitor(service_addr, duration_ms, + monitoring_level, display_timestamp, + buffer_) + return pywrap_tfe.TF_GetBuffer(buffer_) diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py index e29d9b7321a..f8ede96738c 100644 --- a/tensorflow/python/eager/pywrap_tfe_test.py +++ b/tensorflow/python/eager/pywrap_tfe_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import core @@ -54,14 +54,16 @@ class Tests(test.TestCase): self.assertAllClose( math_ops.matmul(a_2_by_2, b_2_by_2), - pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2, - b_2_by_2, "transpose_a", False, "transpose_b", False)) + pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, + "MatMul", None, None, a_2_by_2, + b_2_by_2, "transpose_a", False, + "transpose_b", False)) self.assertAllClose( math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True), - pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "MatMul", None, None, a_100_by_784, - b_100_by_784, "transpose_a", False, "transpose_b", True)) + pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, 
ctx.device_name, + "MatMul", None, None, a_100_by_784, + b_100_by_784, "transpose_a", False, + "transpose_b", True)) @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created @@ -71,12 +73,14 @@ class Tests(test.TestCase): a_2_by_2 = constant_op.constant(1.0, shape=[2, 2]) m = resource_variable_ops.ResourceVariable(a_2_by_2) - x = pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "MatMul", None, None, m, m, "transpose_a", - False, "transpose_b", False) - y = pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2, a_2_by_2, - "transpose_a", False, "transpose_b", False) + x = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, + "MatMul", None, None, m, m, + "transpose_a", False, "transpose_b", + False) + y = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, + "MatMul", None, None, a_2_by_2, + a_2_by_2, "transpose_a", False, + "transpose_b", False) self.assertAllEqual(x, y) @@ -89,9 +93,10 @@ class Tests(test.TestCase): with backprop.GradientTape(persistent=True) as tape: a_2_by_2 = constant_op.constant(1.0, shape=[2, 2]) tape.watch(a_2_by_2) - z = pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2, - a_2_by_2, "transpose_a", False, "transpose_b", False) + z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, + "MatMul", None, None, a_2_by_2, + a_2_by_2, "transpose_a", False, + "transpose_b", False) dz_dy = tape.gradient(z, [a_2_by_2])[0] self.assertAllEqual(dz_dy.numpy(), constant_op.constant(4.0, shape=[2, 2]).numpy()) @@ -106,9 +111,10 @@ class Tests(test.TestCase): a_2_by_2 = constant_op.constant(1.0, shape=[2, 2]) m = resource_variable_ops.ResourceVariable(a_2_by_2) tape.watch(m) - z = pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "MatMul", None, None, m, m, - "transpose_a", False, "transpose_b", False) + z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, 
ctx.device_name, + "MatMul", None, None, m, m, + "transpose_a", False, "transpose_b", + False) dz_dy = tape.gradient(z, [m])[0] self.assertAllEqual(dz_dy.numpy(), constant_op.constant(4.0, shape=[2, 2]).numpy()) @@ -125,9 +131,8 @@ class Tests(test.TestCase): self.assertAllClose( math_ops.add_n([a_2_by_2, b_2_by_2]), - pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, - "AddN", None, None, - [a_2_by_2, b_2_by_2])) + pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "AddN", + None, None, [a_2_by_2, b_2_by_2])) # Tests homogeneous list op @test_util.assert_no_new_tensors @@ -142,9 +147,9 @@ class Tests(test.TestCase): with backprop.GradientTape(persistent=True) as tape: tape.watch(a_2_by_2) tape.watch(b_2_by_2) - z1 = pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "AddN", None, None, - [a_2_by_2, b_2_by_2]) + z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, + "AddN", None, None, + [a_2_by_2, b_2_by_2]) z2 = math_ops.add_n([a_2_by_2, b_2_by_2]) dz1_dy = tape.gradient(z1, [a_2_by_2])[0] dz2_dy = tape.gradient(z2, [a_2_by_2])[0] @@ -162,9 +167,9 @@ class Tests(test.TestCase): self.assertAllClose( array_ops.identity_n([a_2_by_2, b_2_by_2]), - pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, - "IdentityN", None, None, - [a_2_by_2, b_2_by_2])) + pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, + "IdentityN", None, None, + [a_2_by_2, b_2_by_2])) # Tests heterogeneous list op @test_util.assert_no_new_tensors @@ -179,9 +184,9 @@ class Tests(test.TestCase): with backprop.GradientTape(persistent=True) as tape: tape.watch(a_2_by_2) tape.watch(b_2_by_2) - z1 = pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "IdentityN", None, None, - [a_2_by_2, b_2_by_2]) + z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, + "IdentityN", None, None, + [a_2_by_2, b_2_by_2]) z2 = array_ops.identity_n([a_2_by_2, b_2_by_2]) dz1_dy = 
tape.gradient(z1[0], [a_2_by_2])[0] dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0] @@ -201,19 +206,18 @@ class Tests(test.TestCase): # Not enough base params with self.assertRaisesRegexp(ValueError, "at least 5 items in the input tuple"): - pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, - "Identity") + pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Identity") # Not enough inputs with self.assertRaisesRegexp(ValueError, "Expected to be at least 6, was 5"): - pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx_handle, - "Identity", None, []) + pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx_handle, "Identity", + None, []) # Bad type with self.assertRaisesRegexp(TypeError, "expected a string for op_name"): - pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, - ctx_handle, None, [], a_2_by_2) + pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, ctx_handle, + None, [], a_2_by_2) @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created @@ -225,9 +229,9 @@ class Tests(test.TestCase): ctx_handle = ctx._handle with self.assertRaises(core._FallbackException): - pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, - "Split", None, None, split_dim, - value, "num_split", -1) + pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Split", + None, None, split_dim, value, + "num_split", -1) @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created @@ -266,10 +270,9 @@ class Tests(test.TestCase): ctx = context.context() ctx.ensure_initialized() with self.assertRaises(core._FallbackException): - pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, - "MatMul", None, None, m, m, - "transpose_a", False, - "transpose_b", False) + pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "MatMul", + None, None, m, m, "transpose_a", False, + "transpose_b", False) def testOpDefDefaultType(self): im = np.random.randint( diff --git 
a/tensorflow/python/eager/remote.py b/tensorflow/python/eager/remote.py index c7e579ac2a9..64309f9c9d7 100644 --- a/tensorflow/python/eager/remote.py +++ b/tensorflow/python/eager/remote.py @@ -22,7 +22,7 @@ import copy from absl import logging from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.distribute import device_util from tensorflow.python.distribute.cluster_resolver import cluster_resolver from tensorflow.python.eager import context @@ -127,7 +127,7 @@ def connect_to_cluster(cluster_spec_or_resolver, # Automatically add local job, if not part of the cluster spec. if job_name not in cluster_spec.jobs: - local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie() + local_port = pywrap_tfe.TF_PickUnusedPortOrDie() job_def = cluster_def.job.add() job_def.name = job_name # TODO(fishx): Update this to make sure remote worker has valid ip address diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index b73aaa449f7..70a48c8b0da 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -20,7 +20,7 @@ from __future__ import print_function import contextlib -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.util.lazy_loader import LazyLoader # There is a circular dependency between this, ops.py, and @@ -39,24 +39,23 @@ class Tape(object): self._tape = tape def watched_variables(self): - return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape) + return pywrap_tfe.TFE_Py_TapeWatchedVariables(self._tape) def push_new_tape(persistent=False, watch_accessed_variables=True): """Pushes a new tape onto the tape stack.""" - tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent, - watch_accessed_variables) + tape = pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables) return Tape(tape) def push_tape(tape): """Pushes 
an existing tape onto the tape stack.""" - pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access + pywrap_tfe.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access def watch(tape, tensor): """Marks this tensor to be watched by the given tape.""" - pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access + pywrap_tfe.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access def watch_variable(tape, variable): @@ -68,7 +67,7 @@ def watch_variable(tape, variable): else: variables = strategy.experimental_local_results(variable) for var in variables: - pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access + pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access def variable_accessed(variable): @@ -84,7 +83,7 @@ def variable_accessed(variable): else: variables = strategy.experimental_local_results(variable) for var in variables: - pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var) + pywrap_tfe.TFE_Py_TapeVariableAccessed(var) def variables_accessed(variables): @@ -107,25 +106,25 @@ def variables_accessed(variables): accessed.extend(strategy.experimental_local_results(variable)) for var in accessed: - pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var) + pywrap_tfe.TFE_Py_TapeVariableAccessed(var) def pop_tape(tape): """Pops the given tape in the stack.""" - pywrap_tensorflow.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access + pywrap_tfe.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access @contextlib.contextmanager def stop_recording(): """Stop all gradient recording (backprop and forwardprop).""" - is_stopped = pywrap_tensorflow.TFE_Py_TapeSetIsStopped() + is_stopped = pywrap_tfe.TFE_Py_TapeSetIsStopped() try: if not is_stopped: - pywrap_tensorflow.TFE_Py_TapeSetStopOnThread() + pywrap_tfe.TFE_Py_TapeSetStopOnThread() yield finally: if not is_stopped: - 
pywrap_tensorflow.TFE_Py_TapeSetRestartOnThread() + pywrap_tfe.TFE_Py_TapeSetRestartOnThread() def should_record_backprop(tensors): @@ -139,22 +138,23 @@ def should_record_backprop(tensors): Returns: Boolean, whether any tape watches any of `tensors`. """ - return pywrap_tensorflow.TFE_Py_TapeSetShouldRecordBackprop(tensors) + return pywrap_tfe.TFE_Py_TapeSetShouldRecordBackprop(tensors) def record_operation(op_type, output_tensors, input_tensors, backward_function, forward_function=None): """Records the operation on all tapes in the stack.""" - pywrap_tensorflow.TFE_Py_TapeSetRecordOperation( - op_type, output_tensors, input_tensors, backward_function, - forward_function) + pywrap_tfe.TFE_Py_TapeSetRecordOperation(op_type, output_tensors, + input_tensors, backward_function, + forward_function) def record_operation_backprop_only(op_type, output_tensors, input_tensors, backward_function): """Records the operation on all backward tapes in the stack.""" - pywrap_tensorflow.TFE_Py_TapeSetRecordOperationBackprop( - op_type, output_tensors, input_tensors, backward_function) + pywrap_tfe.TFE_Py_TapeSetRecordOperationBackprop(op_type, output_tensors, + input_tensors, + backward_function) def record_operation_forwardprop_only(op_type, output_tensors, input_tensors, @@ -174,16 +174,16 @@ def record_operation_forwardprop_only(op_type, output_tensors, input_tensors, Typically these will have come from TFE_Py_PackForwardGradients. May be None or an empty sequence if there are no JVP outputs from the operation. 
""" - pywrap_tensorflow.TFE_Py_TapeSetRecordOperationForwardprop( + pywrap_tfe.TFE_Py_TapeSetRecordOperationForwardprop( op_type, output_tensors, input_tensors, backward_function, forwardprop_output_indices) def delete_trace(tensor_id): """Deletes traces for this Tensor from all tapes in the stack.""" - pywrap_tensorflow.TFE_Py_TapeSetDeleteTrace(tensor_id) + pywrap_tfe.TFE_Py_TapeSetDeleteTrace(tensor_id) def could_possibly_record(): """Returns True if any tape is active.""" - return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty() + return not pywrap_tfe.TFE_Py_TapeSetIsEmpty() diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 5f4b75b8bbd..fd961671b52 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -26,7 +26,7 @@ import unittest import numpy as np import six -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.eager import context from tensorflow.python.eager import core from tensorflow.python.eager import test @@ -435,14 +435,14 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): t2 = _create_tensor([[1, 2, 5], [3, 4, 5]], dtype=dtypes.int32) t3 = _create_tensor([[1], [3], [5], [6]], dtype=dtypes.int32) - r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 0) + r = pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2, t3], 0) self.assertAllEqual(np.array([3, 2, 4]), r.numpy()) - r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 1) + r = pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2, t3], 1) self.assertAllEqual(np.array([2, 3, 1]), r.numpy()) def testEmptyTensorList(self): - a = pywrap_tensorflow.TFE_Py_TensorShapeSlice([], 0) + a = pywrap_tfe.TFE_Py_TensorShapeSlice([], 0) self.assertTrue(isinstance(a, ops.EagerTensor)) self.assertEqual(0, a.numpy().size) @@ -452,12 +452,12 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp( TypeError, r"Expected a list of 
EagerTensors but element 1 has type \"str\""): - pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, "abc"], 0) + pywrap_tfe.TFE_Py_TensorShapeSlice([t1, "abc"], 0) with self.assertRaisesRegexp( TypeError, r"Expected a list of EagerTensors but element 0 has type \"int\""): - pywrap_tensorflow.TFE_Py_TensorShapeSlice([2, t1], 0) + pywrap_tfe.TFE_Py_TensorShapeSlice([2, t1], 0) def testTensorListNotList(self): t1 = _create_tensor([1, 2], dtype=dtypes.int32) @@ -465,7 +465,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp( TypeError, r"tensors argument must be a list or a tuple. Got.*EagerTensor"): - pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2) + pywrap_tfe.TFE_Py_TensorShapeSlice(t1, -2) def testNegativeSliceDim(self): t1 = _create_tensor([1, 2], dtype=dtypes.int32) @@ -473,7 +473,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp( ValueError, r"Slice dimension must be non-negative. Got -2"): - pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], -2) + pywrap_tfe.TFE_Py_TensorShapeSlice([t1], -2) def testUnicode(self): self.assertEqual(constant_op.constant(u"asdf").numpy(), b"asdf") @@ -493,31 +493,31 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): IndexError, r"Slice dimension \(2\) must be smaller than rank of all tensors, " "but tensor at index 0 has rank 2"): - pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], 2) + pywrap_tfe.TFE_Py_TensorShapeSlice([t1], 2) with self.assertRaisesRegexp( IndexError, r"Slice dimension \(1\) must be smaller than rank of all tensors, " "but tensor at index 0 has rank 1"): - pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2], 1) + pywrap_tfe.TFE_Py_TensorShapeSlice([t2], 1) with self.assertRaisesRegexp( IndexError, r"Slice dimension \(1\) must be smaller than rank of all tensors, " "but tensor at index 1 has rank 1"): - pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2], 1) + pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2], 1) with self.assertRaisesRegexp( 
IndexError, r"Slice dimension \(0\) must be smaller than rank of all tensors, " "but tensor at index 0 has rank 0"): - pywrap_tensorflow.TFE_Py_TensorShapeSlice([t3], 0) + pywrap_tfe.TFE_Py_TensorShapeSlice([t3], 0) with self.assertRaisesRegexp( IndexError, r"Slice dimension \(0\) must be smaller than rank of all tensors, " "but tensor at index 2 has rank 0"): - pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0) + pywrap_tfe.TFE_Py_TensorShapeSlice([t2, t1, t3], 0) @test_util.assert_no_new_pyobjects_executing_eagerly def testTensorDir(self): diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 55af78c822c..f149a61dfc9 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -36,7 +36,12 @@ from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import versions_pb2 from tensorflow.core.protobuf import config_pb2 +# pywrap_tensorflow must be imported first to avoid profobuf issues. +# (b/143110113) +# pylint: disable=invalid-import-order,g-bad-import-order from tensorflow.python import pywrap_tensorflow as c_api +from tensorflow.python import pywrap_tfe as c_api_new +# pylint: enable=invalid-import-order,g-bad-import-order from tensorflow.python import tf2 from tensorflow.python.eager import context from tensorflow.python.eager import core @@ -249,7 +254,7 @@ def register_dense_tensor_like_type(tensor_type): def uid(): """A unique (within this program execution) integer.""" - return c_api.TFE_Py_UID() + return c_api_new.TFE_Py_UID() def numpy_text(tensor, is_repr=False): @@ -1135,7 +1140,7 @@ class _EagerTensorBase(Tensor): # This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and # registers it with the current module. 
-EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase) +EagerTensor = c_api_new.TFE_Py_InitEagerTensor(_EagerTensorBase) register_dense_tensor_like_type(Tensor) diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 093afc112d7..5912c26a5a0 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -789,8 +789,7 @@ void GenEagerPythonOp::AddEagerFastPathExecute() { strings::StrAppend(&result_, " try:\n"); strings::StrAppend( - &result_, " ", - "_result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n", + &result_, " ", "_result = pywrap_tfe.TFE_Py_FastPathExecute(\n", WordWrap(strings::StrCat(" "), strings::StrCat(fastpath_execute_params, ")"), kRightMargin), "\n"); @@ -1000,7 +999,7 @@ This file is MACHINE GENERATED! Do not edit. import collections -from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow +from tensorflow.python import pywrap_tfe as pywrap_tfe from tensorflow.python.eager import context as _context from tensorflow.python.eager import core as _core from tensorflow.python.eager import execute as _execute diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index a08e81bbc49..2757495875f 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -115,8 +116,8 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index): non_neg_concat_dim = ( concat_dim._numpy().item(0) % input_values[0]._rank()) # pylint: disable=protected-access # All inputs are guaranteed to be EagerTensors in eager mode - sizes = 
pywrap_tensorflow.TFE_Py_TensorShapeSlice(input_values, - non_neg_concat_dim) + sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values, + non_neg_concat_dim) out_grads = array_ops.split(grad, sizes, non_neg_concat_dim) else: if constant_op.is_constant(concat_dim): diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py index 04a5ddd1503..3a01ffc4704 100644 --- a/tensorflow/python/ops/logging_ops.py +++ b/tensorflow/python/ops/logging_ops.py @@ -26,7 +26,7 @@ import sys from absl import logging import six -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import pywrap_tfe from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -45,7 +45,7 @@ from tensorflow.python.util.tf_export import tf_export # Register printing to the cell output if we are in a Colab or Jupyter Notebook. try: get_ipython() # Exists in an ipython env like Jupyter or Colab - pywrap_tensorflow.TFE_Py_EnableInteractivePythonLogging() + pywrap_tfe.TFE_Py_EnableInteractivePythonLogging() except NameError: pass diff --git a/tensorflow/python/platform/base.i b/tensorflow/python/platform/base.i index 25fffcfb2d2..65a56f91b93 100644 --- a/tensorflow/python/platform/base.i +++ b/tensorflow/python/platform/base.i @@ -20,6 +20,7 @@ limitations under the License. 
#include #include "tensorflow/c/tf_status.h" #include "tensorflow/core/platform/types.h" + #include "tensorflow/c/tf_datatype.h" #include "tensorflow/python/lib/core/py_exception_registry.h" using tensorflow::uint64; @@ -233,7 +234,50 @@ _COPY_TYPEMAPS(unsigned int, mode_t); %define override %enddef #endif + +// This was originally included in pywrap_tfe.i, but is used by tf_session.i %include "tensorflow/c/tf_status.h" +%include "tensorflow/c/tf_datatype.h" + +%typemap(in) (const void* proto) { + char* c_string; + Py_ssize_t py_size; + // PyBytes_AsStringAndSize() does not copy but simply interprets the input + if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) { + // Python has raised an error (likely TypeError or UnicodeEncodeError). + SWIG_fail; + } + $1 = static_cast(c_string); +} + +%typemap(in) int64_t { + $1 = PyLong_AsLongLong($input); +} + +%typemap(out) TF_DataType { + $result = PyInt_FromLong($1); +} + +%typemap(out) int64_t { + $result = PyInt_FromLong($1); +} + +%typemap(out) TF_AttrType { + $result = PyInt_FromLong($1); +} + +%typemap(in, numinputs=0) unsigned char* is_list (unsigned char tmp) { + tmp = 0; + $1 = &tmp; +} + +%typemap(argout) unsigned char* is_list { + if (*$1 == 1) { + PyObject* list = PyList_New(1); + PyList_SetItem(list, 0, $result); + $result = list; + } +} // Typemaps to automatically raise a Python exception from bad output TF_Status. // TODO(b/77295559): expand this to all TF_Status* output params and deprecate diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i deleted file mode 100755 index d28a1d1f593..00000000000 --- a/tensorflow/python/pywrap_tfe.i +++ /dev/null @@ -1,515 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -%include "tensorflow/python/lib/core/strings.i" -%include "tensorflow/python/platform/base.i" - -%{ -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/python/lib/core/ndarray_tensor.h" -#include "tensorflow/python/lib/core/safe_ptr.h" -%} - - -%include "tensorflow/c/tf_datatype.h" -%include "tensorflow/c/tf_status.h" - -%ignoreall; - -%rename("%s") TF_SetXlaEnableLazyCompilation; -%rename("%s") TF_SetTfXlaCpuGlobalJit; -%rename("%s") TF_SetXlaAutoJitMode; -%rename("%s") TF_SetXlaConstantFoldingDisabled; -%rename("%s") TF_GetXlaConstantFoldingDisabled; -%rename("%s") TF_SetXlaMinClusterSize; -%rename("%s") TFE_NewContext; -%rename("%s") TFE_DeleteContext; -%rename("%s") TFE_ContextListDevices; -%rename("%s") TFE_ContextAddFunction; -%rename("%s") TFE_ContextAddFunctionDef; -%rename("%s") TFE_ContextRemoveFunction; -%rename("%s") TFE_ContextHasFunction; -%rename("%s") TFE_ContextEnableRunMetadata; -%rename("%s") TFE_ContextDisableRunMetadata; -%rename("%s") TFE_ContextEnableGraphCollection; -%rename("%s") TFE_ContextDisableGraphCollection; -%rename("%s") TFE_ContextExportRunMetadata; -%rename("%s") TFE_ContextClearCaches; -%rename("%s") TFE_ContextGetDevicePlacementPolicy; -%rename("%s") TFE_ContextGetMirroringPolicy; -%rename("%s") TFE_ContextSetThreadLocalDevicePlacementPolicy; -%rename("%s") TFE_ContextSetThreadLocalMirroringPolicy; -%rename("%s") TFE_ContextSetServerDef; -%rename("%s") TFE_ContextUpdateServerDef; -%rename("%s") TFE_ContextCheckAlive; -%rename("%s") 
TFE_NewExecutor; -%rename("%s") TFE_DeleteExecutor; -%rename("%s") TFE_ExecutorIsAsync; -%rename("%s") TFE_ExecutorWaitForAllPendingNodes; -%rename("%s") TFE_ExecutorClearError; -%rename("%s") TFE_ContextSetExecutorForThread; -%rename("%s") TFE_ContextGetExecutorForThread; -%rename("%s") TFE_NewProfiler; -%rename("%s") TFE_ProfilerIsOk; -%rename("%s") TFE_DeleteProfiler; -%rename("%s") TFE_ProfilerSerializeToString; -%rename("%s") TFE_StartProfilerServer; -%rename("%s") TFE_ProfilerClientStartTracing; -%rename("%s") TFE_ProfilerClientMonitor; -%rename("%s") TFE_OpNameGetAttrType; -%rename("%s") TFE_Py_InitEagerTensor; -%rename("%s") TFE_Py_SetEagerTensorProfiler; -%rename("%s") TFE_Py_RegisterExceptionClass; -%rename("%s") TFE_Py_RegisterJVPFunction; -%rename("%s") TFE_Py_RegisterGradientFunction; -%rename("%s") TFE_Py_RegisterFallbackExceptionClass; -%rename("%s") TFE_Py_Execute; -%rename("%s") TFE_Py_ExecuteCancelable; -%rename("%s") TFE_Py_FastPathExecute; -%rename("%s") TFE_Py_RecordGradient; -%rename("%s") TFE_Py_UID; -%rename("%s") TFE_Py_TapeSetNew; -%rename("%s") TFE_Py_TapeSetAdd; -%rename("%s") TFE_Py_TapeSetRemove; -%rename("%s") TFE_Py_TapeSetStopOnThread; -%rename("%s") TFE_Py_TapeSetRestartOnThread; -%rename("%s") TFE_Py_TapeSetIsStopped; -%rename("%s") TFE_Py_TapeSetIsEmpty; -%rename("%s") TFE_Py_TapeSetShouldRecordBackprop; -%rename("%s") TFE_Py_TapeSetPossibleGradientTypes; -%rename("%s") TFE_Py_TapeSetDeleteTrace; -%rename("%s") TFE_Py_TapeSetRecordOperation; -%rename("%s") TFE_Py_TapeSetRecordOperationBackprop; -%rename("%s") TFE_Py_TapeSetRecordOperationForwardprop; -%rename("%s") TFE_Py_TapeGradient; -%rename("%s") TFE_Py_TapeVariableAccessed; -%rename("%s") TFE_Py_TapeWatch; -%rename("%s") TFE_Py_TapeWatchVariable; -%rename("%s") TFE_Py_TapeWatchedVariables; -%rename("%s") TFE_Py_ForwardAccumulatorNew; -%rename("%s") TFE_Py_ForwardAccumulatorSetAdd; -%rename("%s") TFE_Py_ForwardAccumulatorSetRemove; -%rename("%s") 
TFE_Py_ForwardAccumulatorWatch; -%rename("%s") TFE_Py_ForwardAccumulatorJVP; -%rename("%s") TFE_Py_ForwardAccumulatorPushState; -%rename("%s") TFE_Py_ForwardAccumulatorPopState; -%rename("%s") TFE_Py_PackJVPs; -%rename("%s") TFE_NewContextOptions; -%rename("%s") TFE_ContextOptionsSetConfig; -%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy; -%rename("%s") TFE_ContextOptionsSetMirroringPolicy; -%rename("%s") TFE_ContextOptionsSetAsync; -%rename("%s") TFE_ContextOptionsSetLazyRemoteInputsCopy; -%rename("%s") TFE_DeleteContextOptions; -%rename("%s") TFE_Py_TensorShapeSlice; -%rename("%s") TFE_Py_TensorShapeOnDevice; -%rename("%s") TFE_Py_EnableInteractivePythonLogging; -%rename("%s") TFE_Py_SetEagerContext; -%rename("%s") TFE_ContextStartStep; -%rename("%s") TFE_ContextEndStep; -%rename("%s") TFE_Py_RegisterVSpace; -%rename("%s") TFE_Py_EncodeArg; -%rename("%s") TFE_EnableCollectiveOps; -%rename("%s") TF_ListPhysicalDevices; -%rename("%s") TF_PickUnusedPortOrDie; -%rename("%s") TFE_MonitoringCounterCellIncrementBy; -%rename("%s") TFE_MonitoringCounterCellValue; -%rename("%s") TFE_MonitoringNewCounter0; -%rename("%s") TFE_MonitoringDeleteCounter0; -%rename("%s") TFE_MonitoringGetCellCounter0; -%rename("%s") TFE_MonitoringNewCounter1; -%rename("%s") TFE_MonitoringDeleteCounter1; -%rename("%s") TFE_MonitoringGetCellCounter1; -%rename("%s") TFE_MonitoringNewCounter2; -%rename("%s") TFE_MonitoringDeleteCounter2; -%rename("%s") TFE_MonitoringGetCellCounter2; -%rename("%s") TFE_MonitoringIntGaugeCellSet; -%rename("%s") TFE_MonitoringIntGaugeCellValue; -%rename("%s") TFE_MonitoringNewIntGauge0; -%rename("%s") TFE_MonitoringDeleteIntGauge0; -%rename("%s") TFE_MonitoringGetCellIntGauge0; -%rename("%s") TFE_MonitoringNewIntGauge1; -%rename("%s") TFE_MonitoringDeleteIntGauge1; -%rename("%s") TFE_MonitoringGetCellIntGauge1; -%rename("%s") TFE_MonitoringNewIntGauge2; -%rename("%s") TFE_MonitoringDeleteIntGauge2; -%rename("%s") TFE_MonitoringGetCellIntGauge2; -%rename("%s") 
TFE_MonitoringStringGaugeCellSet; -%rename("%s") TFE_MonitoringStringGaugeCellValue; -%rename("%s") TFE_MonitoringNewStringGauge0; -%rename("%s") TFE_MonitoringDeleteStringGauge0; -%rename("%s") TFE_MonitoringGetCellStringGauge0; -%rename("%s") TFE_MonitoringNewStringGauge1; -%rename("%s") TFE_MonitoringDeleteStringGauge1; -%rename("%s") TFE_MonitoringGetCellStringGauge1; -%rename("%s") TFE_MonitoringNewStringGauge2; -%rename("%s") TFE_MonitoringDeleteStringGauge2; -%rename("%s") TFE_MonitoringGetCellStringGauge2; -%rename("%s") TFE_MonitoringBoolGaugeCellSet; -%rename("%s") TFE_MonitoringBoolGaugeCellValue; -%rename("%s") TFE_MonitoringNewBoolGauge0; -%rename("%s") TFE_MonitoringDeleteBoolGauge0; -%rename("%s") TFE_MonitoringGetCellBoolGauge0; -%rename("%s") TFE_MonitoringNewBoolGauge1; -%rename("%s") TFE_MonitoringDeleteBoolGauge1; -%rename("%s") TFE_MonitoringGetCellBoolGauge1; -%rename("%s") TFE_MonitoringNewBoolGauge2; -%rename("%s") TFE_MonitoringDeleteBoolGauge2; -%rename("%s") TFE_MonitoringGetCellBoolGauge2; -%rename("%s") TFE_MonitoringSamplerCellAdd; -%rename("%s") TFE_MonitoringSamplerCellValue; -%rename("%s") TFE_MonitoringNewExponentialBuckets; -%rename("%s") TFE_MonitoringDeleteBuckets; -%rename("%s") TFE_MonitoringNewSampler0; -%rename("%s") TFE_MonitoringDeleteSampler0; -%rename("%s") TFE_MonitoringGetCellSampler0; -%rename("%s") TFE_MonitoringNewSampler1; -%rename("%s") TFE_MonitoringDeleteSampler1; -%rename("%s") TFE_MonitoringGetCellSampler1; -%rename("%s") TFE_MonitoringNewSampler2; -%rename("%s") TFE_MonitoringDeleteSampler2; -%rename("%s") TFE_MonitoringGetCellSampler2; -%rename("%s") TFE_NewCancellationManager; -%rename("%s") TFE_CancellationManagerIsCancelled; -%rename("%s") TFE_CancellationManagerStartCancel; -%rename("%s") TFE_DeleteCancellationManager; -%rename("%s") TF_ImportGraphDefOptionsSetValidateColocationConstraints; -%rename("%s") TFE_ClearScalarCache; - -%{ -#include "tensorflow/python/eager/pywrap_tfe.h" -#include 
"tensorflow/python/util/util.h" -#include "tensorflow/c/c_api_experimental.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/core/common_runtime/device_factory.h" - -static PyObject* TF_ListPhysicalDevices(TF_Status* status) { - std::vector devices; - tensorflow::Status s = tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices); - tensorflow::Set_TF_Status_from_Status(status, s); - if (!s.ok()) { - Py_RETURN_NONE; - }; - PyObject* result = PyList_New(devices.size()); - int i = 0; - for (auto& dev : devices) { - PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size()); - PyList_SetItem(result, i, dev_obj); - ++i; - } - return result; -} -%} -static PyObject* TF_ListPhysicalDevices(TF_Status* status); - -%{ -#include "tensorflow/python/eager/pywrap_tensor_conversion.h" - -static PyObject* TFE_ClearScalarCache() { - tensorflow::TFE_TensorHandleCache::Get()->Clear(); - Py_RETURN_NONE; -} -%} -static PyObject* TFE_ClearScalarCache(); - -%typemap(in) (const void* proto) { - char* c_string; - Py_ssize_t py_size; - // PyBytes_AsStringAndSize() does not copy but simply interprets the input - if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) { - // Python has raised an error (likely TypeError or UnicodeEncodeError). - SWIG_fail; - } - $1 = static_cast(c_string); -} - -%typemap(in) int64_t { - $1 = PyLong_AsLongLong($input); -} - -%typemap(out) TF_DataType { - $result = PyInt_FromLong($1); -} - -%typemap(out) int64_t { - $result = PyInt_FromLong($1); -} - -%typemap(out) TF_AttrType { - $result = PyInt_FromLong($1); -} - -%typemap(in, numinputs=0) unsigned char* is_list (unsigned char tmp) { - tmp = 0; - $1 = &tmp; -} - -%typemap(argout) unsigned char* is_list { - if (*$1 == 1) { - PyObject* list = PyList_New(1); - PyList_SetItem(list, 0, $result); - $result = list; - } -} - -// For const parameters in a function, SWIG pretty much ignores the const. 
-// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13 -// Hence the 'const_cast'. -%typemap(in) const char* serialized_function_def { - $1 = const_cast(TFE_GetPythonString($input)); -} - -// For const parameters in a function, SWIG pretty much ignores the const. -// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13 -// Hence the 'const_cast'. -%typemap(in) const char* device_name { - if ($input == Py_None) { - $1 = nullptr; - } else { - $1 = const_cast(TFE_GetPythonString($input)); - } -} - -// For const parameters in a function, SWIG pretty much ignores the const. -// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13 -// Hence the 'const_cast'. -%typemap(in) const char* op_name { - $1 = const_cast(TFE_GetPythonString($input)); -} - -// For const parameters in a function, SWIG pretty much ignores the const. -// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13 -// Hence the 'const_cast'. -%typemap(in) const char* name { - $1 = const_cast(TFE_GetPythonString($input)); -} - - -// For const parameters in a function, SWIG pretty much ignores the const. -// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13 -// Hence the 'const_cast'. -%typemap(in) const char* description { - $1 = const_cast(TFE_GetPythonString($input)); -} - -// For const parameters in a function, SWIG pretty much ignores the const. -// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13 -// Hence the 'const_cast'. -%typemap(in) const char* label { - $1 = const_cast(TFE_GetPythonString($input)); -} - -// For const parameters in a function, SWIG pretty much ignores the const. -// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13 -// Hence the 'const_cast'. -%typemap(in) const char* label1 { - $1 = const_cast(TFE_GetPythonString($input)); -} - -// For const parameters in a function, SWIG pretty much ignores the const. -// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13 -// Hence the 'const_cast'. 
-%typemap(in) const char* label2 { - $1 = const_cast(TFE_GetPythonString($input)); -} - -// For const parameters in a function, SWIG pretty much ignores the const. -// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13 -// Hence the 'const_cast'. -%typemap(in) const char* service_addr { - $1 = const_cast(TFE_GetPythonString($input)); -} - -// For const parameters in a function, SWIG pretty much ignores the const. -// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13 -// Hence the 'const_cast'. -%typemap(in) const char* logdir { - $1 = const_cast(TFE_GetPythonString($input)); -} - -// For const parameters in a function, SWIG pretty much ignores the const. -// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13 -// Hence the 'const_cast'. -%typemap(in) const char* worker_list { - $1 = const_cast(TFE_GetPythonString($input)); -} - -%typemap(in) (TFE_Context*) { - $1 = (TFE_Context*)PyCapsule_GetPointer($input, nullptr); - -} -%typemap(out) (TFE_Context*) { - // When the TFE_Context* returned is a nullptr, we expect the status is not - // OK. This will raise an error (happens in another typemap). 
- if ($1 != nullptr) { - $result = PyCapsule_New($1, nullptr, TFE_DeleteContextCapsule); - } -} - -%rename("%s") TFE_ContextDevicePlacementPolicy; -%rename("%s") TFE_DEVICE_PLACEMENT_EXPLICIT; -%rename("%s") TFE_DEVICE_PLACEMENT_WARN; -%rename("%s") TFE_DEVICE_PLACEMENT_SILENT; -%rename("%s") TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32; - -%rename("%s") TFE_ContextMirroringPolicy; -%rename("%s") TFE_MIRRORING_NONE; -%rename("%s") TFE_MIRRORING_ALL; - -%include "tensorflow/c/eager/c_api.h" - -%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp) { - $1 = &temp; - if ($input != Py_None) { - if (!PyList_Check($input)) { - SWIG_exception_fail(SWIG_TypeError, - "must provide a list of Tensors as inputs"); - } - Py_ssize_t len = PyList_Size($input); - $1->resize(len); - for (Py_ssize_t i = 0; i < len; ++i) { - PyObject* elem = PyList_GetItem($input, i); - if (!elem) { - SWIG_fail; - } - if (EagerTensor_CheckExact(elem)) { - (*$1)[i] = EagerTensor_Handle(elem); - } else if (tensorflow::swig::IsEagerTensorSlow(elem)) { - // Use equivalent of object.__getattribute__ to get the underlying - // tf wrapped EagerTensor (if there is one). - tensorflow::Safe_PyObjectPtr tf_should_use_attr( -#if PY_MAJOR_VERSION < 3 - PyString_InternFromString("_tf_should_use_wrapped_value") -#else - PyUnicode_InternFromString("_tf_should_use_wrapped_value") -#endif - ); - tensorflow::Safe_PyObjectPtr value_attr( - PyObject_GenericGetAttr(elem, tf_should_use_attr.get())); - if (value_attr) { - // This is an EagerTensor wrapped inside a TFShouldUse wrapped object. - (*$1)[i] = EagerTensor_Handle(value_attr.get()); - } else { - // This is a subclass of EagerTensor that we don't support. - PyErr_Clear(); - SWIG_exception_fail( - SWIG_TypeError, - tensorflow::strings::StrCat( - "Saw an object that is an instance of a strict subclass of " - "EagerTensor, which is not supported. 
Item ", - i, " is type: ", elem->ob_type->tp_name) - .c_str()); - } - } else if (tensorflow::swig::IsTensor(elem)) { - // If it isnt an EagerTensor, but is still a Tensor, it must be a graph - // tensor. - tensorflow::Safe_PyObjectPtr name_attr( - PyObject_GetAttrString(elem, "name")); - SWIG_exception_fail( - SWIG_TypeError, - tensorflow::strings::StrCat( - "An op outside of the function building code is being passed\n" - "a \"Graph\" tensor. It is possible to have Graph tensors\n" - "leak out of the function building context by including a\n" - "tf.init_scope in your function building code.\n" - "For example, the following function will fail:\n", - " @tf.function\n", - " def has_init_scope():\n", - " my_constant = tf.constant(1.)\n", - " with tf.init_scope():\n", - " added = my_constant * 2\n", - "The graph tensor has name: ", - name_attr ? TFE_GetPythonString(name_attr.get()) : "" - ).c_str()); - } else { - SWIG_exception_fail( - SWIG_TypeError, - tensorflow::strings::StrCat( - "provided list of inputs contains objects other " - "than 'EagerTensor'. Item ", - i, " is type: ", elem->ob_type->tp_name).c_str()); - } - } - } -} - -// Temporary for the argout -%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp) { - if (!PyInt_Check($input)) { - SWIG_exception_fail(SWIG_TypeError, - "expected an integer value (size of the number of " - "outputs of the operation)"); - } - $1 = &temp; - long sz = PyInt_AsLong($input); - if (sz > 0) { - $1->resize(PyInt_AsLong($input), nullptr); - } -} - -// Create new Status object. 
-%typemap(in, numinputs=0) TF_Status *out_status { - $1 = GetStatus(); -} - -%typemap(freearg) (TF_Status* out_status) { - ReturnStatus($1); -} - -%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) { - if (MaybeRaiseExceptionFromTFStatus($2, nullptr)) { - SWIG_fail; - } else { - int num_outputs = $1->size(); - Py_CLEAR($result); - $result = PyList_New(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - PyObject *output; - output = EagerTensorFromHandle($1->at(i)); - PyList_SetItem($result, i, output); - } - } -} - -// SWIG usually unwraps the tuple that the native Python/C interface generates. -// Since we wanted to have a function with a variable length of arguments, we -// used the native Python/C interface directly (which by default supports -// passing all arguments as a tuple). -%native(TFE_Py_FastPathExecute) TFE_Py_FastPathExecute_C; - -%include "tensorflow/python/eager/pywrap_tfe.h" -%include "tensorflow/c/c_api_experimental.h" -%include "tensorflow/c/eager/c_api_experimental.h" - -// Clear all typemaps. -%typemap(out) TF_DataType; -%typemap(in) int64_t; -%typemap(out) int64_t; -%typemap(out) TF_AttrType; -%typemap(in, numinputs=0) TF_Status *out_status; -%typemap(argout) unsigned char* is_list; -%typemap(in) const char* description; -%typemap(in) const char* label1; -%typemap(in) const char* label2; -%typemap(in) (TFE_Context*); -%typemap(out) (TFE_Context*); -%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp); -%typemap(in, numinputs=0) TF_Status *out_status; -%typemap(freearg) (TF_Status* out_status); -%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status); -%typemap(in) (const void* proto); - -%unignoreall diff --git a/tensorflow/python/pywrap_tfe.py b/tensorflow/python/pywrap_tfe.py new file mode 100644 index 00000000000..8c591e9bf45 --- /dev/null +++ b/tensorflow/python/pywrap_tfe.py @@ -0,0 +1,29 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python module for TFE ops and functions exported by pybind11. + +This module is created because we are splitting out eager bindings from +pywrap_tensorflow. This is causing some issues where Graphs are not properly +initialized when running eager code. Once the graph architecture has been +removed from pywrap_tensorflow as well, we can remove this file. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=invalid-import-order,g-bad-import-order, wildcard-import, unused-import +from tensorflow.python import pywrap_tensorflow +from tensorflow.python._pywrap_tfe import * diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index 761e6f376f8..e7077185343 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -17,8 +17,6 @@ limitations under the License. 
* The includes are intentionally not alphabetically sorted, as the order of * includes follows dependency order */ -%include "tensorflow/python/pywrap_tfe.i" - %include "tensorflow/python/client/tf_session.i" %include "tensorflow/python/lib/io/file_io.i" diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc new file mode 100644 index 00000000000..284159762a8 --- /dev/null +++ b/tensorflow/python/tfe_wrapper.cc @@ -0,0 +1,1099 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License");; +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "Python.h" +#include "include/pybind11/chrono.h" +#include "include/pybind11/complex.h" +#include "include/pybind11/functional.h" +#include "include/pybind11/pybind11.h" +#include "include/pybind11/stl.h" +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/python/eager/pywrap_tensor_conversion.h" +#include "tensorflow/python/eager/pywrap_tfe.h" +#include "tensorflow/python/lib/core/py_exception_registry.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" +#include "tensorflow/python/lib/core/safe_ptr.h" +#include "tensorflow/python/util/util.h" + +namespace py = pybind11; + +PYBIND11_MAKE_OPAQUE(TFE_Executor); +PYBIND11_MAKE_OPAQUE(TFE_ContextOptions); +PYBIND11_MAKE_OPAQUE(TFE_CancellationManager); +PYBIND11_MAKE_OPAQUE(TFE_Profiler); + +PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter2); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge0); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge1); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge2); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge0); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge1); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge2); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge0); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge1); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge2); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler0); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler1); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler2); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounterCell); 
+PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGaugeCell); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGaugeCell); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGaugeCell); +PYBIND11_MAKE_OPAQUE(TFE_MonitoringSamplerCell); + +PYBIND11_MAKE_OPAQUE(TF_DeviceList); +PYBIND11_MAKE_OPAQUE(TF_Function); +PYBIND11_MAKE_OPAQUE(TF_Buffer); + +// Eager helper functions migrated from pywrap_tfe.i. + +namespace tensorflow { + +// We cannot use Context as an opaque type. SWIG also had +// difficult directly passing the pointer around. These +// typemaps are migrated over from pywrap_tfe.i. I tried +// using a custom type caster, but we get segfaults periodically. + +// TODO(amitpatankar): Move input and output logic of Context into a +// pybind11 custom type caster. + +TFE_Context* InputTFE_Context(const py::handle& ctx) { + return static_cast(PyCapsule_GetPointer(ctx.ptr(), nullptr)); +} + +PyObject* OutputTFE_Context(TFE_Context* context) { + return PyCapsule_New(context, nullptr, TFE_DeleteContextCapsule); +} + +TF_Buffer* ProtoStringToTFBuffer(PyObject* input) { + // Convert a Python string object to TF_Buffer. + char* c_string; + Py_ssize_t py_size; + // PyBytes_AsStringAndSize() does not copy but simply interprets the input + if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) { + // Python has raised an error (likely TypeError or UnicodeEncodeError). + throw py::error_already_set(); + } + return TF_NewBufferFromString(static_cast(c_string), + static_cast(py_size)); +} + +// These functions are typemaps from the Python side. I did not use +// a custom type caster since the logic is slightly harder to follow. This +// converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`. 
+TFE_InputTensorHandles InputTFE_InputTensorHandles( + const py::handle& input_tensors) { + TFE_InputTensorHandles input_tensor_handles; + if (input_tensors.ptr() != Py_None) { + if (!PyList_Check(input_tensors.ptr())) { + tensorflow::throwTypeError("must provide a list of Tensors as inputs"); + } + Py_ssize_t len = PyList_Size(input_tensors.ptr()); + input_tensor_handles.resize(len); + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* elem = PyList_GetItem(input_tensors.ptr(), i); + if (!elem) { + tensorflow::throwTypeError("Input Tensor does not exist."); + } + if (EagerTensor_CheckExact(elem)) { + (input_tensor_handles)[i] = EagerTensor_Handle(elem); + } else if (tensorflow::swig::IsEagerTensorSlow(elem)) { + // Use equivalent of object.__getattribute__ to get the underlying + // tf wrapped EagerTensor (if there is one). + tensorflow::Safe_PyObjectPtr tf_should_use_attr( +#if PY_MAJOR_VERSION < 3 + PyString_InternFromString("_tf_should_use_wrapped_value") +#else + PyUnicode_InternFromString("_tf_should_use_wrapped_value") +#endif + ); + tensorflow::Safe_PyObjectPtr value_attr( + PyObject_GenericGetAttr(elem, tf_should_use_attr.get())); + if (value_attr) { + // This is an EagerTensor wrapped inside a TFShouldUse wrapped object. + (input_tensor_handles)[i] = EagerTensor_Handle(value_attr.get()); + } else { + // This is a subclass of EagerTensor that we don't support. + PyErr_Clear(); + tensorflow::throwTypeError( + tensorflow::strings::StrCat( + "Saw an object that is an instance of a strict subclass of " + "EagerTensor, which is not supported. Item ", + i, " is type: ", elem->ob_type->tp_name) + .c_str()); + } + } else if (tensorflow::swig::IsTensor(elem)) { + // If it isnt an EagerTensor, but is still a Tensor, it must be a graph + // tensor. 
+ tensorflow::Safe_PyObjectPtr name_attr( + PyObject_GetAttrString(elem, "name")); + tensorflow::throwTypeError( + tensorflow::strings::StrCat( + "An op outside of the function building code is being passed\n" + "a \"Graph\" tensor. It is possible to have Graph tensors\n" + "leak out of the function building context by including a\n" + "tf.init_scope in your function building code.\n" + "For example, the following function will fail:\n", + " @tf.function\n", " def has_init_scope():\n", + " my_constant = tf.constant(1.)\n", + " with tf.init_scope():\n", + " added = my_constant * 2\n", + "The graph tensor has name: ", + name_attr ? TFE_GetPythonString(name_attr.get()) : "") + .c_str()); + } else { + tensorflow::throwTypeError( + tensorflow::strings::StrCat( + "provided list of inputs contains objects other " + "than 'EagerTensor'. Item ", + i, " is type: ", elem->ob_type->tp_name) + .c_str()); + } + } + } + return input_tensor_handles; +} + +// These functions are typemaps from the Python side. I did not use +// a custom type caster since the logic is slightly harder to follow. This +// converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`. +// This function actually takes a number rather than an output Tensor holder. 
+TFE_OutputTensorHandles InputTFE_OutputTensorHandles( + const py::handle& num_outputs) { + TFE_OutputTensorHandles output_tensor_handles; +#if PY_MAJOR_VERSION < 3 + if (!PyInt_Check(num_outputs.ptr())) { +#else + if (!PyLong_Check(num_outputs.ptr())) { +#endif + PyErr_SetString(PyExc_TypeError, + "expected an integer value (size of the number of " + "outputs of the operation)"); + throw py::error_already_set(); + } +#if PY_MAJOR_VERSION < 3 + long sz = PyInt_AsLong(num_outputs.ptr()); // NOLINT +#else + long sz = PyLong_AsLong(num_outputs.ptr()); // NOLINT +#endif + if (sz > 0) { +#if PY_MAJOR_VERSION < 3 + output_tensor_handles.resize(PyInt_AsLong(num_outputs.ptr()), nullptr); +#else + output_tensor_handles.resize(PyLong_AsLong(num_outputs.ptr()), nullptr); +#endif + } + return output_tensor_handles; +} + +// This function was created from fusing the typemap logic in platform/base.i. +py::object TFE_Py_ExecuteCancelable_wrapper( + const py::handle& context, const char* device_name, const char* op_name, + const py::handle& inputs, const py::handle& attrs, + TFE_CancellationManager* cancellation_manager, + const py::handle& num_outputs) { + TFE_Context* ctx = tensorflow::InputTFE_Context(context); + TFE_InputTensorHandles input_tensor_handles = + InputTFE_InputTensorHandles(inputs); + TFE_OutputTensorHandles output_tensor_handles = + InputTFE_OutputTensorHandles(num_outputs); + tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); + TFE_Py_ExecuteCancelable(ctx, device_name, op_name, &input_tensor_handles, + attrs.ptr(), cancellation_manager, + &output_tensor_handles, status.get()); + + int output_len = output_tensor_handles.size(); + PyObject* output_list = PyList_New(output_len); + for (int i = 0; i < output_len; ++i) { + PyObject* output; + output = EagerTensorFromHandle(output_tensor_handles.at(i)); + PyList_SetItem(output_list, i, output); + } + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return 
tensorflow::pyo_or_throw(output_list); +} + +static py::object TF_ListPhysicalDevices() { + tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); + std::vector devices; + tensorflow::Status s = + tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices); + tensorflow::Set_TF_Status_from_Status(status.get(), s); + if (!s.ok()) { + return py::none(); + } + PyObject* result = PyList_New(devices.size()); + int i = 0; + for (auto& dev : devices) { + PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size()); + PyList_SetItem(result, i, dev_obj); + ++i; + } + return tensorflow::pyo_or_throw(result); +} + +static py::object TFE_ClearScalarCache() { + tensorflow::TFE_TensorHandleCache::Get()->Clear(); + return py::none(); +} + +} // namespace tensorflow + +// py::return_value_policy::reference is defined as specified by the +// pybind11 documents listed here. +// https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies +// This means that C++ maintains ownership of the object. We +// are only assigning this to functions that return opaque types. 
+ +PYBIND11_MODULE(_pywrap_tfe, m) { + py::class_ TFE_Context_class(m, "TFE_Context"); + py::class_ TFE_Executor_class(m, "TFE_Executor"); + py::class_ TFE_ContextOptions_class(m, + "TFE_ContextOptions"); + py::class_ TFE_MonitoringCounter0_class( + m, "TFE_MonitoringCounter0"); + py::class_ TFE_MonitoringCounter1_class( + m, "TFE_MonitoringCounter1"); + py::class_ TFE_MonitoringCounter2_class( + m, "TFE_MonitoringCounter2"); + py::class_ TFE_MonitoringStringGauge0_class( + m, "TFE_MonitoringStringGauge0"); + py::class_ TFE_MonitoringStringGauge1_class( + m, "TFE_MonitoringStringGauge1"); + py::class_ TFE_MonitoringStringGauge2_class( + m, "TFE_MonitoringStringGauge2"); + py::class_ TFE_MonitoringIntGauge0_class( + m, "TFE_MonitoringIntGauge0"); + py::class_ TFE_MonitoringIntGauge1_class( + m, "TFE_MonitoringIntGauge1"); + py::class_ TFE_MonitoringIntGauge2_class( + m, "TFE_MonitoringIntGauge2"); + py::class_ TFE_MonitoringBoolGauge0_class( + m, "TFE_MonitoringBoolGauge0"); + py::class_ TFE_MonitoringBoolGauge1_class( + m, "TFE_MonitoringBoolGauge1"); + py::class_ TFE_MonitoringBoolGauge2_class( + m, "TFE_MonitoringBoolGauge2"); + py::class_ TFE_MonitoringCounterCell_class( + m, "TFE_MonitoringCounterCell"); + py::class_ TFE_MonitoringIntGaugeCell_class( + m, "TFE_MonitoringIntGaugeCell"); + py::class_ TFE_MonitoringStringGaugeCell_class( + m, "TFE_MonitoringStringGaugeCell"); + py::class_ TFE_MonitoringBoolGaugeCell_class( + m, "TFE_MonitoringBoolGaugeCell"); + py::class_ TFE_MonitoringSamplerCell_class( + m, "TFE_MonitoringSamplerCell"); + py::class_ TFE_MonitoringBuckets_class( + m, "TFE_MonitoringBuckets"); + py::class_ TFE_MonitoringSampler0_class( + m, "TFE_MonitoringSampler0"); + py::class_ TFE_MonitoringSampler1_class( + m, "TFE_MonitoringSampler1"); + py::class_ TFE_MonitoringSampler2_class( + m, "TFE_MonitoringSampler2"); + py::class_ TFE_CancellationManager_class( + m, "TFE_CancellationManager"); + py::class_ TFE_Profiler_class(m, "TFE_Profiler"); + + 
py::class_ TF_DeviceList_class(m, "TF_DeviceList"); + py::class_ TF_Function_class(m, "TF_Function"); + py::class_ TF_Buffer_class(m, "TF_Buffer"); + + m.def("TFE_Py_RegisterExceptionClass", [](const py::handle& e) { + return tensorflow::pyo_or_throw(TFE_Py_RegisterExceptionClass(e.ptr())); + }); + m.def("TFE_Py_RegisterFallbackExceptionClass", [](const py::handle& e) { + return tensorflow::pyo_or_throw( + TFE_Py_RegisterFallbackExceptionClass(e.ptr())); + }); + + // XLA Eager Logic + m.def("TF_SetXlaEnableLazyCompilation", &TF_SetXlaEnableLazyCompilation); + m.def("TF_SetTfXlaCpuGlobalJit", &TF_SetTfXlaCpuGlobalJit); + m.def("TF_SetXlaAutoJitMode", &TF_SetXlaAutoJitMode); + m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled); + m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled); + m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize); + + // // TFE_Context Logic + m.def( + "TFE_NewContext", + [](const TFE_ContextOptions* opts) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + TFE_Context* context = TFE_NewContext(opts, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return tensorflow::pyo_or_throw(tensorflow::OutputTFE_Context(context)); + }, + py::return_value_policy::reference); + m.def("TFE_DeleteContext", [](py::handle& o) { + TFE_DeleteContext(tensorflow::InputTFE_Context(o)); + }); + m.def( + "TFE_ContextListDevices", + [](py::handle& o) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = TFE_ContextListDevices(tensorflow::InputTFE_Context(o), + status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_ContextAddFunction", [](py::handle& ctx, py::handle& func) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + SwigPyObject* sstable_swig = 
reinterpret_cast(func.ptr()); + auto function = reinterpret_cast(sstable_swig->ptr); + TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), function, + status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + }); + m.def("TFE_ContextAddFunctionDef", + [](py::handle& ctx, const char* serialized_function_def, size_t size) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + TFE_ContextAddFunctionDef(tensorflow::InputTFE_Context(ctx), + serialized_function_def, size, + status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + }); + m.def("TFE_ContextRemoveFunction", [](py::handle& ctx, const char* name) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + TFE_ContextRemoveFunction(tensorflow::InputTFE_Context(ctx), name, + status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + }); + m.def("TFE_ContextHasFunction", [](py::handle& ctx, const char* name) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = + TFE_ContextHasFunction(tensorflow::InputTFE_Context(ctx), name); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }); + m.def("TFE_ContextEnableRunMetadata", [](py::handle& ctx) { + TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx)); + }); + m.def("TFE_ContextDisableRunMetadata", [](py::handle& ctx) { + TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx)); + }); + m.def("TFE_ContextEnableGraphCollection", [](py::handle& ctx) { + TFE_ContextEnableGraphCollection(tensorflow::InputTFE_Context(ctx)); + }); + m.def("TFE_ContextDisableGraphCollection", [](py::handle& ctx) { + TFE_ContextDisableGraphCollection(tensorflow::InputTFE_Context(ctx)); + }); + m.def("TFE_ContextExportRunMetadata", [](py::handle& ctx, TF_Buffer& buf) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + 
TFE_ContextExportRunMetadata(tensorflow::InputTFE_Context(ctx), &buf, + status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + }); + m.def("TFE_ContextClearCaches", [](py::handle& o) { + TFE_ContextClearCaches(tensorflow::InputTFE_Context(o)); + }); + m.def("TFE_ContextGetDevicePlacementPolicy", [](py::handle& ctx) { + return TFE_ContextGetDevicePlacementPolicy( + tensorflow::InputTFE_Context(ctx)); + }); + m.def("TFE_ContextGetMirroringPolicy", [](py::handle& ctx) { + return TFE_ContextGetMirroringPolicy(tensorflow::InputTFE_Context(ctx)); + }); + m.def("TFE_ContextSetThreadLocalDevicePlacementPolicy", + [](py::handle& ctx, TFE_ContextDevicePlacementPolicy policy) { + TFE_ContextSetThreadLocalDevicePlacementPolicy( + tensorflow::InputTFE_Context(ctx), policy); + }); + m.def("TFE_ContextSetThreadLocalMirroringPolicy", + [](py::handle& ctx, TFE_ContextMirroringPolicy policy) { + TFE_ContextSetThreadLocalMirroringPolicy( + tensorflow::InputTFE_Context(ctx), policy); + }); + m.def("TFE_ContextSetServerDef", [](py::handle& ctx, int keep_alive_secs, + py::str proto) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + tensorflow::Safe_TF_BufferPtr buf = + tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr())); + TFE_ContextSetServerDef(tensorflow::InputTFE_Context(ctx), keep_alive_secs, + buf.get()->data, buf.get()->length, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + }); + m.def("TFE_ContextUpdateServerDef", [](py::handle& ctx, int keep_alive_secs, + py::str proto) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + tensorflow::Safe_TF_BufferPtr buf = + tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr())); + TFE_ContextUpdateServerDef(tensorflow::InputTFE_Context(ctx), + keep_alive_secs, buf.get()->data, + buf.get()->length, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + }); + 
m.def("TFE_ContextCheckAlive", [](py::handle& ctx, const char* worker_name) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + bool output = TFE_ContextCheckAlive(tensorflow::InputTFE_Context(ctx), + worker_name, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }); + + // TFE_Executor logic + m.def( + "TFE_NewExecutor", + [](const bool is_async) { + TFE_Executor* exc = TFE_NewExecutor(is_async); + return exc; + }, + py::return_value_policy::reference); + m.def("TFE_DeleteExecutor", &TFE_DeleteExecutor); + m.def("TFE_ExecutorIsAsync", &TFE_ExecutorIsAsync); + m.def("TFE_ExecutorWaitForAllPendingNodes", [](TFE_Executor& exc) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + TFE_ExecutorWaitForAllPendingNodes(&exc, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + }); + m.def("TFE_ExecutorClearError", &TFE_ExecutorClearError); + m.def("TFE_ContextSetExecutorForThread", [](py::handle& ctx, + TFE_Executor& exc) { + TFE_ContextSetExecutorForThread(tensorflow::InputTFE_Context(ctx), &exc); + }); + m.def( + "TFE_ContextGetExecutorForThread", + [](py::handle& o) { + return TFE_ContextGetExecutorForThread(tensorflow::InputTFE_Context(o)); + }, + py::return_value_policy::reference); + + // Profiler Logic + m.def("TFE_NewProfiler", &TFE_NewProfiler, + py::return_value_policy::reference); + m.def("TFE_ProfilerIsOk", &TFE_ProfilerIsOk); + m.def("TFE_DeleteProfiler", &TFE_DeleteProfiler); + m.def("TFE_ProfilerSerializeToString", + [](TFE_Profiler& profiler, TF_Buffer& buf) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + TFE_ProfilerSerializeToString(&profiler, &buf, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + }); + m.def("TFE_StartProfilerServer", &TFE_StartProfilerServer); + m.def( + "TFE_ProfilerClientStartTracing", + [](const char* service_addr, const char* 
logdir, const char* worker_list, + bool include_dataset_ops, int duration_ms, int num_tracing_attempts) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + bool output = TFE_ProfilerClientStartTracing( + service_addr, logdir, worker_list, include_dataset_ops, duration_ms, + num_tracing_attempts, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }); + m.def("TFE_ProfilerClientMonitor", + [](const char* service_addr, int duration_ms, int monitoring_level, + bool display_timestamp, TF_Buffer& result) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + TFE_ProfilerClientMonitor(service_addr, duration_ms, monitoring_level, + display_timestamp, &result, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + }); + + m.def("TFE_OpNameGetAttrType", + [](py::handle& ctx, const char* op_or_function_name, + const char* attr_name) { + int temp = 0; + unsigned char* is_list = reinterpret_cast(&temp); + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = TFE_OpNameGetAttrType(tensorflow::InputTFE_Context(ctx), + op_or_function_name, attr_name, + is_list, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); +#if PY_MAJOR_VERSION < 3 + PyObject* output_pyo = PyInt_FromLong(output); +#else + PyObject* output_pyo = PyLong_FromLong(output); +#endif + if (*is_list == 1) { + PyObject* list = PyList_New(1); + PyList_SetItem(list, 0, output_pyo); + return tensorflow::pyo_or_throw(list); + } + return tensorflow::pyo_or_throw(output_pyo); + }); + m.def("TFE_Py_InitEagerTensor", [](const py::handle& o) { + return tensorflow::pyo_or_throw(TFE_Py_InitEagerTensor(o.ptr())); + }); + m.def("TFE_Py_SetEagerTensorProfiler", &TFE_Py_SetEagerTensorProfiler); + m.def("TFE_Py_RegisterJVPFunction", [](const py::handle& o) { + return tensorflow::pyo_or_throw(TFE_Py_RegisterJVPFunction(o.ptr())); + }); 
+ m.def("TFE_Py_RegisterGradientFunction", [](const py::handle& o) { + return tensorflow::pyo_or_throw(TFE_Py_RegisterGradientFunction(o.ptr())); + }); + m.def("TFE_Py_Execute", + [](const py::handle& context, const char* device_name, + const char* op_name, const py::handle& inputs, + const py::handle& attrs, const py::handle& num_outputs) { + return tensorflow::TFE_Py_ExecuteCancelable_wrapper( + context, device_name, op_name, inputs, attrs.ptr(), nullptr, + num_outputs); + }); + m.def( + "TFE_Py_ExecuteCancelable", + [](const py::handle& context, const char* device_name, + const char* op_name, const py::handle& inputs, const py::handle& attrs, + TFE_CancellationManager& cancellation_manager, + const py::handle& num_outputs) { + return tensorflow::TFE_Py_ExecuteCancelable_wrapper( + context, device_name, op_name, inputs, attrs.ptr(), + &cancellation_manager, num_outputs); + }); + m.def("TFE_Py_FastPathExecute", [](const py::args args) { + // First argument is a PyObject which is unused. + // https://docs.python.org/3/c-api/structures.html#METH_VARARGS + // TFE_Py_FastPathExecute requires error checking prior to returning. 
+ return tensorflow::pyo_or_throw( + TFE_Py_FastPathExecute_C(nullptr, args.ptr())); + }); + m.def("TFE_Py_RecordGradient", + [](const py::handle& op_name, const py::handle& inputs, + const py::handle& attrs, const py::handle& results) { + return tensorflow::pyo_or_throw(TFE_Py_RecordGradient( + op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr())); + }); + m.def("TFE_Py_UID", []() { return tensorflow::pyo_or_throw(TFE_Py_UID()); }); + + // TFE_Py_Tape Logic + m.def("TFE_Py_TapeSetNew", [](const py::handle& persistent, + const py::handle& watch_accessed_variables) { + return tensorflow::pyo_or_throw( + TFE_Py_TapeSetNew(persistent.ptr(), watch_accessed_variables.ptr())); + }); + m.def("TFE_Py_TapeSetAdd", + [](const py::handle& tape) { TFE_Py_TapeSetAdd(tape.ptr()); }); + m.def("TFE_Py_TapeSetRemove", + [](const py::handle& tape) { TFE_Py_TapeSetRemove(tape.ptr()); }); + m.def("TFE_Py_TapeSetStopOnThread", &TFE_Py_TapeSetStopOnThread); + m.def("TFE_Py_TapeSetRestartOnThread", &TFE_Py_TapeSetRestartOnThread); + m.def("TFE_Py_TapeSetIsStopped", + []() { return tensorflow::pyo_or_throw(TFE_Py_TapeSetIsStopped()); }); + m.def("TFE_Py_TapeSetIsEmpty", + []() { return tensorflow::pyo_or_throw(TFE_Py_TapeSetIsEmpty()); }); + m.def("TFE_Py_TapeSetShouldRecordBackprop", [](const py::handle& tensors) { + return tensorflow::pyo_or_throw( + TFE_Py_TapeSetShouldRecordBackprop(tensors.ptr())); + }); + m.def("TFE_Py_TapeSetPossibleGradientTypes", [](const py::handle& tensors) { + return tensorflow::pyo_or_throw( + TFE_Py_TapeSetPossibleGradientTypes(tensors.ptr())); + }); + m.def("TFE_Py_TapeSetDeleteTrace", &TFE_Py_TapeSetDeleteTrace); + m.def("TFE_Py_TapeSetRecordOperation", + [](const py::handle& op_type, const py::handle& output_tensors, + const py::handle& input_tensors, const py::handle& backward_function, + const py::handle& forward_function) { + return tensorflow::pyo_or_throw(TFE_Py_TapeSetRecordOperation( + op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(), + 
backward_function.ptr(), forward_function.ptr())); + }); + m.def( + "TFE_Py_TapeSetRecordOperationBackprop", + [](const py::handle& op_type, const py::handle& output_tensors, + const py::handle& input_tensors, const py::handle& backward_function) { + return tensorflow::pyo_or_throw(TFE_Py_TapeSetRecordOperationBackprop( + op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(), + backward_function.ptr())); + }); + m.def("TFE_Py_TapeSetRecordOperationForwardprop", + [](const py::handle& op_type, const py::handle& output_tensors, + const py::handle& input_tensors, const py::handle& backward_function, + const py::handle& forwardprop_output_indices) { + return tensorflow::pyo_or_throw( + TFE_Py_TapeSetRecordOperationForwardprop( + op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(), + backward_function.ptr(), forwardprop_output_indices.ptr())); + }); + m.def("TFE_Py_TapeGradient", + [](const py::handle& tape, const py::handle& target, + const py::handle& sources, const py::handle& output_gradients, + const py::handle& sources_raw, + const py::handle& unconnected_gradients) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + PyObject* output = TFE_Py_TapeGradient( + tape.ptr(), target.ptr(), sources.ptr(), output_gradients.ptr(), + sources_raw.ptr(), unconnected_gradients.ptr(), status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return tensorflow::pyo_or_throw(output); + }); + + m.def("TFE_Py_TapeVariableAccessed", [](const py::handle& variable) { + TFE_Py_TapeVariableAccessed(variable.ptr()); + }); + m.def("TFE_Py_TapeWatch", + [](const py::handle& tape, const py::handle& tensor) { + TFE_Py_TapeWatch(tape.ptr(), tensor.ptr()); + }); + m.def("TFE_Py_TapeWatchVariable", + [](const py::handle& tape, const py::handle& variable) { + TFE_Py_TapeWatchVariable(tape.ptr(), variable.ptr()); + }); + m.def("TFE_Py_TapeWatchedVariables", [](const py::handle& tape) { + return 
tensorflow::pyo_or_throw(TFE_Py_TapeWatchedVariables(tape.ptr())); + }); + + m.def("TFE_Py_ForwardAccumulatorNew", []() { + return tensorflow::pyo_or_throw(TFE_Py_ForwardAccumulatorNew()); + }); + + m.def("TFE_Py_ForwardAccumulatorSetAdd", [](const py::handle& accumulator) { + return tensorflow::pyo_or_throw( + TFE_Py_ForwardAccumulatorSetAdd(accumulator.ptr())); + }); + m.def("TFE_Py_ForwardAccumulatorSetRemove", + [](const py::handle& accumulator) { + TFE_Py_ForwardAccumulatorSetRemove(accumulator.ptr()); + }); + + m.def("TFE_Py_ForwardAccumulatorWatch", + [](const py::handle& accumulator, const py::handle& tensor, + const py::handle& tangent) { + TFE_Py_ForwardAccumulatorWatch(accumulator.ptr(), tensor.ptr(), + tangent.ptr()); + }); + m.def("TFE_Py_ForwardAccumulatorJVP", + [](const py::handle& accumulator, const py::handle& tensor) { + return tensorflow::pyo_or_throw( + TFE_Py_ForwardAccumulatorJVP(accumulator.ptr(), tensor.ptr())); + }); + m.def("TFE_Py_ForwardAccumulatorPushState", []() { + return tensorflow::pyo_or_throw(TFE_Py_ForwardAccumulatorPushState()); + }); + m.def("TFE_Py_ForwardAccumulatorPopState", []() { + return tensorflow::pyo_or_throw(TFE_Py_ForwardAccumulatorPopState()); + }); + m.def("TFE_Py_PackJVPs", [](const py::handle& tensors) { + return tensorflow::pyo_or_throw(TFE_Py_PackJVPs(tensors.ptr())); + }); + + // TFE_ContextOptions Logic + m.def("TFE_NewContextOptions", &TFE_NewContextOptions, + py::return_value_policy::reference); + m.def("TFE_ContextOptionsSetConfig", [](TFE_ContextOptions* options, + py::str proto) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + tensorflow::Safe_TF_BufferPtr buf = + tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr())); + TFE_ContextOptionsSetConfig(options, buf.get()->data, buf.get()->length, + status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + }); + m.def("TFE_ContextOptionsSetDevicePlacementPolicy", + 
&TFE_ContextOptionsSetDevicePlacementPolicy); + m.def("TFE_ContextOptionsSetLazyRemoteInputsCopy", + &TFE_ContextOptionsSetLazyRemoteInputsCopy); + m.def("TFE_ContextOptionsSetMirroringPolicy", + &TFE_ContextOptionsSetMirroringPolicy); + m.def("TFE_ContextOptionsSetAsync", &TFE_ContextOptionsSetAsync); + m.def("TFE_DeleteContextOptions", &TFE_DeleteContextOptions, + py::return_value_policy::reference); + + // TFE_Py_TensorShape Logic + m.def("TFE_Py_TensorShapeSlice", + [](const py::handle& tensors, int slice_dim) { + return tensorflow::pyo_or_throw( + TFE_Py_TensorShapeSlice(tensors.ptr(), slice_dim)); + }); + m.def("TFE_Py_TensorShapeOnDevice", [](const py::handle& tensors, + int slice_dim) { + return tensorflow::pyo_or_throw(TFE_Py_TensorShapeOnDevice(tensors.ptr())); + }); + m.def("TFE_Py_EnableInteractivePythonLogging", + &TFE_Py_EnableInteractivePythonLogging); + + // Additional Context Logic + m.def("TFE_Py_SetEagerContext", [](const py::handle& o) { + return tensorflow::pyo_or_throw(TFE_Py_SetEagerContext(o.ptr())); + }); + m.def("TFE_ContextStartStep", [](py::handle& o) { + TFE_ContextStartStep(tensorflow::InputTFE_Context(o.ptr())); + }); + m.def("TFE_ContextEndStep", &TFE_ContextEndStep); + m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) { + return tensorflow::pyo_or_throw(TFE_Py_RegisterVSpace(o.ptr())); + }); + m.def("TFE_Py_EncodeArg", + [](const py::handle& o, bool include_tensor_ranks_only) { + return tensorflow::pyo_or_throw( + TFE_Py_EncodeArg(o.ptr(), include_tensor_ranks_only)); + }); + m.def("TFE_EnableCollectiveOps", [](const py::handle& ctx, py::str proto) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + tensorflow::Safe_TF_BufferPtr buf = + tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr())); + TFE_EnableCollectiveOps(tensorflow::InputTFE_Context(ctx), buf.get()->data, + buf.get()->length, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + }); + 
m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices); + m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList, + py::return_value_policy::reference); + m.def("TF_DeviceListCount", &TF_DeviceListCount); + m.def("TF_DeviceListName", [](const TF_DeviceList* list, int index) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = TF_DeviceListName(list, index, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }); + m.def("TF_DeviceListType", [](const TF_DeviceList* list, int index) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = TF_DeviceListType(list, index, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }); + + m.def("TF_PickUnusedPortOrDie", &TF_PickUnusedPortOrDie); + + // TFE_MonitoringCounter Logic + m.def("TFE_MonitoringCounterCellIncrementBy", + &TFE_MonitoringCounterCellIncrementBy); + m.def("TFE_MonitoringCounterCellValue", &TFE_MonitoringCounterCellValue); + m.def( + "TFE_MonitoringNewCounter0", + [](const char* name, const char* description) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = + TFE_MonitoringNewCounter0(name, status.get(), description); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteCounter0", &TFE_MonitoringDeleteCounter0, + py::return_value_policy::reference); + m.def("TFE_MonitoringGetCellCounter0", &TFE_MonitoringGetCellCounter0, + py::return_value_policy::reference); + m.def( + "TFE_MonitoringNewCounter1", + [](const char* name, const char* description, const char* label1) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = + TFE_MonitoringNewCounter1(name, status.get(), description, label1); + 
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteCounter1", &TFE_MonitoringDeleteCounter1, + py::return_value_policy::reference); + m.def("TFE_MonitoringGetCellCounter1", &TFE_MonitoringGetCellCounter1, + py::return_value_policy::reference); + m.def( + "TFE_MonitoringNewCounter2", + [](const char* name, const char* description, const char* label1, + const char* label2) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = TFE_MonitoringNewCounter2(name, status.get(), description, + label1, label2); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteCounter2", &TFE_MonitoringDeleteCounter2, + py::return_value_policy::reference); + m.def("TFE_MonitoringGetCellCounter2", &TFE_MonitoringGetCellCounter2, + py::return_value_policy::reference); + + // TFE_MonitoringIntGauge Logic + m.def("TFE_MonitoringIntGaugeCellSet", &TFE_MonitoringIntGaugeCellSet); + m.def("TFE_MonitoringIntGaugeCellValue", &TFE_MonitoringIntGaugeCellValue); + m.def( + "TFE_MonitoringNewIntGauge0", + [](const char* name, const char* description) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = + TFE_MonitoringNewIntGauge0(name, status.get(), description); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteIntGauge0", &TFE_MonitoringDeleteIntGauge0, + py::return_value_policy::reference); + m.def("TFE_MonitoringGetCellIntGauge0", &TFE_MonitoringGetCellIntGauge0, + py::return_value_policy::reference); + m.def( + "TFE_MonitoringNewIntGauge1", + [](const char* name, const char* description, const char* label1) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = + 
TFE_MonitoringNewIntGauge1(name, status.get(), description, label1); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteIntGauge1", &TFE_MonitoringDeleteIntGauge1, + py::return_value_policy::reference); + m.def("TFE_MonitoringGetCellIntGauge1", &TFE_MonitoringGetCellIntGauge1, + py::return_value_policy::reference); + m.def( + "TFE_MonitoringNewIntGauge2", + [](const char* name, const char* description, const char* label1, + const char* label2) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = TFE_MonitoringNewIntGauge2(name, status.get(), + description, label1, label2); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteIntGauge2", &TFE_MonitoringDeleteIntGauge2, + py::return_value_policy::reference); + m.def("TFE_MonitoringGetCellIntGauge2", &TFE_MonitoringGetCellIntGauge2, + py::return_value_policy::reference); + m.def("TFE_MonitoringStringGaugeCellSet", &TFE_MonitoringStringGaugeCellSet); + m.def("TFE_MonitoringStringGaugeCellValue", + &TFE_MonitoringStringGaugeCellValue); + m.def( + "TFE_MonitoringNewStringGauge0", + [](const char* name, const char* description) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = + TFE_MonitoringNewStringGauge0(name, status.get(), description); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + + // TFE_MonitoringStringGauge Logic + m.def("TFE_MonitoringDeleteStringGauge0", &TFE_MonitoringDeleteStringGauge0); + m.def("TFE_MonitoringGetCellStringGauge0", &TFE_MonitoringGetCellStringGauge0, + py::return_value_policy::reference); + m.def( + "TFE_MonitoringNewStringGauge1", + [](const char* name, const char* description, const char* label1) { + 
tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = TFE_MonitoringNewStringGauge1(name, status.get(), + description, label1); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteStringGauge1", &TFE_MonitoringDeleteStringGauge1); + m.def("TFE_MonitoringGetCellStringGauge1", &TFE_MonitoringGetCellStringGauge1, + py::return_value_policy::reference); + m.def( + "TFE_MonitoringNewStringGauge2", + [](const char* name, const char* description, const char* label1, + const char* label2) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = TFE_MonitoringNewStringGauge2( + name, status.get(), description, label1, label2); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteStringGauge2", &TFE_MonitoringDeleteStringGauge2); + m.def("TFE_MonitoringGetCellStringGauge2", &TFE_MonitoringGetCellStringGauge2, + py::return_value_policy::reference); + + // TFE_MonitoringBoolGauge Logic + m.def("TFE_MonitoringBoolGaugeCellSet", &TFE_MonitoringBoolGaugeCellSet); + m.def("TFE_MonitoringBoolGaugeCellValue", &TFE_MonitoringBoolGaugeCellValue); + m.def( + "TFE_MonitoringNewBoolGauge0", + [](const char* name, const char* description) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = + TFE_MonitoringNewBoolGauge0(name, status.get(), description); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteBoolGauge0", &TFE_MonitoringDeleteBoolGauge0, + py::return_value_policy::reference); + m.def("TFE_MonitoringGetCellBoolGauge0", &TFE_MonitoringGetCellBoolGauge0, + py::return_value_policy::reference); + m.def( + "TFE_MonitoringNewBoolGauge1", + [](const char* name, 
const char* description, const char* label1) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = TFE_MonitoringNewBoolGauge1(name, status.get(), + description, label1); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteBoolGauge1", &TFE_MonitoringDeleteBoolGauge1, + py::return_value_policy::reference); + m.def("TFE_MonitoringGetCellBoolGauge1", &TFE_MonitoringGetCellBoolGauge1, + py::return_value_policy::reference); + m.def( + "TFE_MonitoringNewBoolGauge2", + [](const char* name, const char* description, const char* label1, + const char* label2) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = TFE_MonitoringNewBoolGauge2(name, status.get(), + description, label1, label2); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteBoolGauge2", &TFE_MonitoringDeleteBoolGauge2, + py::return_value_policy::reference); + m.def("TFE_MonitoringGetCellBoolGauge2", &TFE_MonitoringGetCellBoolGauge2, + py::return_value_policy::reference); + + // TFE_MonitoringSampler Logic + m.def("TFE_MonitoringSamplerCellAdd", &TFE_MonitoringSamplerCellAdd); + m.def("TFE_MonitoringSamplerCellValue", &TFE_MonitoringSamplerCellValue); + m.def("TFE_MonitoringNewExponentialBuckets", + &TFE_MonitoringNewExponentialBuckets, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteBuckets", &TFE_MonitoringDeleteBuckets, + py::return_value_policy::reference); + m.def( + "TFE_MonitoringNewSampler0", + [](const char* name, TFE_MonitoringBuckets* buckets, + const char* description) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = + TFE_MonitoringNewSampler0(name, buckets, status.get(), description); + 
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteSampler0", &TFE_MonitoringDeleteSampler0, + py::return_value_policy::reference); + m.def("TFE_MonitoringGetCellSampler0", &TFE_MonitoringGetCellSampler0, + py::return_value_policy::reference); + m.def( + "TFE_MonitoringNewSampler1", + [](const char* name, TFE_MonitoringBuckets* buckets, + const char* description, const char* label1) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = TFE_MonitoringNewSampler1(name, buckets, status.get(), + description, label1); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteSampler1", &TFE_MonitoringDeleteSampler1, + py::return_value_policy::reference); + m.def("TFE_MonitoringGetCellSampler1", &TFE_MonitoringGetCellSampler1, + py::return_value_policy::reference); + m.def( + "TFE_MonitoringNewSampler2", + [](const char* name, TFE_MonitoringBuckets* buckets, + const char* description, const char* label1, const char* label2) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = TFE_MonitoringNewSampler2(name, buckets, status.get(), + description, label1, label2); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TFE_MonitoringDeleteSampler2", &TFE_MonitoringDeleteSampler2, + py::return_value_policy::reference); + m.def("TFE_MonitoringGetCellSampler2", &TFE_MonitoringGetCellSampler2, + py::return_value_policy::reference); + + // TFE_CancellationManager Logic + m.def("TFE_NewCancellationManager", &TFE_NewCancellationManager, + py::return_value_policy::reference); + m.def("TFE_CancellationManagerIsCancelled", + &TFE_CancellationManagerIsCancelled); + m.def("TFE_CancellationManagerStartCancel", + 
&TFE_CancellationManagerStartCancel); + m.def("TFE_DeleteCancellationManager", &TFE_DeleteCancellationManager, + py::return_value_policy::reference); + + m.def("TFE_ClearScalarCache", &tensorflow::TFE_ClearScalarCache); + + // Util buffer helper functions + m.def("TF_NewBufferFromString", &TF_NewBufferFromString, + py::return_value_policy::reference); + m.def("TF_NewBuffer", &TF_NewBuffer, py::return_value_policy::reference); + m.def("TF_GetBuffer", [](TF_Buffer* buf) { + return tensorflow::pyo_or_throw(PyBytes_FromStringAndSize( + reinterpret_cast<const char*>(buf->data), buf->length)); + }); + m.def("TF_DeleteBuffer", &TF_DeleteBuffer, + py::return_value_policy::reference); + + // C API Enum + + py::enum_<TFE_ContextDevicePlacementPolicy>( + m, "TFE_ContextDevicePlacementPolicy") + .value("TFE_DEVICE_PLACEMENT_EXPLICIT", TFE_DEVICE_PLACEMENT_EXPLICIT) + .value("TFE_DEVICE_PLACEMENT_WARN", TFE_DEVICE_PLACEMENT_WARN) + .value("TFE_DEVICE_PLACEMENT_SILENT", TFE_DEVICE_PLACEMENT_SILENT) + .value("TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32", + TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) + .export_values(); + + py::enum_<TF_AttrType>(m, "TF_AttrType") + .value("TF_ATTR_STRING", TF_ATTR_STRING) + .value("TF_ATTR_INT", TF_ATTR_INT) + .value("TF_ATTR_FLOAT", TF_ATTR_FLOAT) + .value("TF_ATTR_BOOL", TF_ATTR_BOOL) + .value("TF_ATTR_TYPE", TF_ATTR_TYPE) + .value("TF_ATTR_SHAPE", TF_ATTR_SHAPE) + .value("TF_ATTR_TENSOR", TF_ATTR_TENSOR) + .value("TF_ATTR_PLACEHOLDER", TF_ATTR_PLACEHOLDER) + .value("TF_ATTR_FUNC", TF_ATTR_FUNC) + .export_values(); + + py::enum_<TFE_ContextMirroringPolicy>(m, "TFE_ContextMirroringPolicy") + .value("TFE_MIRRORING_NONE", TFE_MIRRORING_NONE) + .value("TFE_MIRRORING_ALL", TFE_MIRRORING_ALL) + .export_values(); +}; diff --git a/tensorflow/tf_exported_symbols.lds b/tensorflow/tf_exported_symbols.lds index bed2ab4aae4..7e5b06432e0 100644 --- a/tensorflow/tf_exported_symbols.lds +++ b/tensorflow/tf_exported_symbols.lds @@ -3,6 +3,7 @@ *perftools*gputools* *tf_* *TF_* +*Eager* *TFE_* *nsync_* *stream_executor* diff --git
a/tensorflow/tf_version_script.lds b/tensorflow/tf_version_script.lds index f74644b7a14..ed2395cf913 100644 --- a/tensorflow/tf_version_script.lds +++ b/tensorflow/tf_version_script.lds @@ -4,6 +4,7 @@ tensorflow { *toco*; *perftools*gputools*; *TF_*; + *Eager*; *TFE_*; *nsync_*; *stream_executor*; diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 052e757441f..e657edc4fbf 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -1,4 +1,4 @@ -[cpp_python_util] # util +[cpp_python_util] # util tfe tensorflow::swig::IsSequence tensorflow::swig::IsSequenceOrComposite tensorflow::swig::IsCompositeTensor @@ -17,6 +17,7 @@ tensorflow::swig::IsSequenceForData tensorflow::swig::FlattenForData tensorflow::swig::AssertSameStructureForData tensorflow::swig::RegisterType +tensorflow::swig::IsEagerTensorSlow [util_port] # util_port tensorflow::IsGoogleCudaEnabled @@ -74,11 +75,12 @@ tensorflow::Status::code tensorflow::Status::error_message tensorflow::Status::ok() -[core_cpu_impl] # device_lib +[core_cpu_impl] # device_lib tfe tensorflow::Device::attributes tensorflow::DeviceFactory::AddDevices tensorflow::SessionOptions::SessionOptions tensorflow::DoQuantizeTrainingOnSerializedGraphDef +tensorflow::DeviceFactory::ListAllPhysicalDevices [protos_all] # device_lib, dtypes tensorflow::DataType_IsValid @@ -123,3 +125,67 @@ tensorflow::make_safe [python_op_gen] # python_op_gen tensorflow::GetPythonWrappers + +[pywrap_tfe_lib] # tfe +tensorflow::TFE_TensorHandleCache +tensorflow::TFE_TensorHandleCache::Clear +EagerTensor_CheckExact +EagerTensorFromHandle +EagerTensor_Handle +TFE_Py_ExecuteCancelable +TFE_Py_RegisterExceptionClass +TFE_Py_RegisterVSpace +TFE_Py_RegisterFallbackExceptionClass +TFE_Py_RegisterGradientFunction +TFE_Py_RegisterJVPFunction +TFE_GetPythonString +TFE_Py_UID +TFE_DeleteContextCapsule +TFE_Py_InitEagerTensor 
+TFE_Py_SetEagerTensorProfiler +TFE_Py_TapeSetNew +TFE_Py_TapeSetRemove +TFE_Py_TapeSetAdd +TFE_Py_TapeSetIsEmpty +TFE_Py_TapeSetShouldRecordBackprop +TFE_Py_TapeSetPossibleGradientTypes +TFE_Py_TapeWatch +TFE_Py_TapeSetDeleteTrace +TFE_Py_TapeSetStopOnThread +TFE_Py_TapeSetRestartOnThread +TFE_Py_TapeSetIsStopped +TFE_Py_TapeSetRecordOperation +TFE_Py_TapeSetRecordOperationBackprop +TFE_Py_TapeSetRecordOperationForwardprop +TFE_Py_TapeVariableAccessed +TFE_Py_TapeWatchVariable +TFE_Py_TapeGradient +TFE_Py_FastPathExecute_C +TFE_Py_RecordGradient +TFE_Py_TapeWatchedVariables +TFE_Py_ForwardAccumulatorNew +TFE_Py_ForwardAccumulatorSetAdd +TFE_Py_ForwardAccumulatorSetRemove +TFE_Py_ForwardAccumulatorWatch +TFE_Py_ForwardAccumulatorJVP +TFE_Py_ForwardAccumulatorPushState +TFE_Py_ForwardAccumulatorPopState +TFE_Py_PackJVPs +TFE_Py_TensorShapeSlice +TFE_Py_TensorShapeOnDevice +TFE_Py_EncodeArg +TFE_Py_EnableInteractivePythonLogging +TFE_Py_SetEagerContext + +[eager_executor] # tfe +tensorflow::EagerExecutor::~EagerExecutor +tensorflow::EagerContext::WaitForAndCloseRemoteContexts + +[profiler_session] # tfe +tensorflow::ProfilerSession::~ProfilerSession + +[tf_status_helper] # tfe +tensorflow::Set_TF_Status_from_Status + +[context] # tfe +tensorflow::EagerContext::WaitForAndCloseRemoteContexts