From b66e4e833c5aacc31d0feaa629f2d064766a7a0b Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Thu, 7 Nov 2019 09:20:14 -0800 Subject: [PATCH] Export the checkpoint reader classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information. PiperOrigin-RevId: 279101529 Change-Id: I25502ed3d3718499abca41f5614681f41e4c7199 --- tensorflow/c/BUILD | 9 + tensorflow/python/BUILD | 50 +++++- tensorflow/python/keras/engine/network.py | 4 +- tensorflow/python/lib/core/ndarray_tensor.cc | 13 +- tensorflow/python/lib/core/ndarray_tensor.h | 3 - tensorflow/python/pywrap_tfe.i | 9 + tensorflow/python/tensorflow.i | 2 - tensorflow/python/tools/freeze_graph.py | 4 +- tensorflow/python/tools/inspect_checkpoint.py | 4 +- .../python/training/checkpoint_utils.py | 4 +- .../python/training/py_checkpoint_reader.py | 99 +++++++++++ tensorflow/python/training/saver.py | 4 +- tensorflow/python/training/saver_test.py | 6 +- .../training/tracking/benchmarks_test.py | 4 +- tensorflow/python/training/tracking/util.py | 10 +- tensorflow/python/training/training.py | 3 +- .../python/training/warm_starting_util.py | 6 +- tensorflow/python/util/py_checkpoint_reader.i | 167 ------------------ .../util/py_checkpoint_reader_wrapper.cc | 150 ++++++++++++++++ .../tools/def_file_filter/symbols_pybind.txt | 18 ++ 20 files changed, 364 insertions(+), 205 deletions(-) create mode 100644 tensorflow/python/training/py_checkpoint_reader.py delete mode 100644 tensorflow/python/util/py_checkpoint_reader.i create mode 100644 tensorflow/python/util/py_checkpoint_reader_wrapper.cc diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index d15713f7875..a533fddb12f 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -293,6 +293,15 @@ exports_files( visibility = ["//visibility:public"], ) +filegroup( + name = "checkpoint_reader_hdrs", + srcs = [ + "checkpoint_reader.h", + "tf_status_helper.h", + ], + visibility = ["//tensorflow:__subpackages__"], +) + tf_cuda_library( name = "tf_status_helper", srcs = ["tf_status_helper.cc"], diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index b1cd03f35b9..6452c4293ba 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -98,6 +98,7 @@ py_library( "//third_party/py/tensorflow_core:__subpackages__", ], deps = [ + ":_pywrap_checkpoint_reader", ":_pywrap_events_writer", ":_pywrap_kernel_registry", ":_pywrap_py_exception_registry", @@ -629,6 +630,32 @@ tf_python_pybind_extension( ], ) +tf_python_pybind_extension( + name = "_pywrap_checkpoint_reader", + srcs = ["util/py_checkpoint_reader_wrapper.cc"], + hdrs = [ + "lib/core/ndarray_tensor.h", + "lib/core/safe_ptr.h", + ":py_exception_registry_hdr", + "//tensorflow/c:checkpoint_reader_hdrs", + "//tensorflow/c:headers", + "//tensorflow/c/eager:headers", + ], + module_name = "_pywrap_checkpoint_reader", + deps = [ + ":pybind11_lib", + ":pybind11_status", + "//tensorflow/core:lib", + "//tensorflow/core:op_gen_lib", + "//tensorflow/core:protos_all", + "//tensorflow/core/util/tensor_bundle:tensor_bundle_headers_lib", + "//third_party/py/numpy:headers", + "//third_party/python_runtime:headers", + "@com_google_absl//absl/strings", + "@pybind11", + ], +) + filegroup( name = "py_exception_registry_hdr", srcs = [ @@ -996,6 +1023,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":_pywrap_checkpoint_reader", ":_pywrap_debug_events_writer", ":_pywrap_events_writer", ":_pywrap_kernel_registry", @@ -4730,6 +4758,7 @@ py_library( ":math_ops", ":mixed_precision", ":platform", + ":py_checkpoint_reader", ":pywrap_tensorflow", ":random_ops", ":resource_variable_ops", @@ -4785,6 +4814,17 @@ py_library( ], ) +py_library( + name = "py_checkpoint_reader", + srcs = ["training/py_checkpoint_reader.py"], + deps = [ + ":_pywrap_checkpoint_reader", + ":dtypes", + ":errors", + ":util", + ], +) + py_library( name = "checkpoint_management", srcs = ["training/checkpoint_management.py"], @@ -5284,7 +5324,6 @@ tf_py_wrap_cc( "lib/io/py_record_writer.i", "platform/base.i", "pywrap_tfe.i", - "util/py_checkpoint_reader.i", "//tensorflow/compiler/mlir/python:mlir.i", ], # add win_def_file for pywrap_tensorflow @@ -5336,6 +5375,7 @@ tf_py_wrap_cc( "//tensorflow/tools/graph_transforms:transform_graph_lib", "//tensorflow/lite/toco/python:toco_python_api", "//tensorflow/python/eager:pywrap_tfe_lib", + "//tensorflow/core/util/tensor_bundle:tensor_bundle", ] + (tf_additional_lib_deps() + tf_additional_plugin_deps()) + if_ngraph([ "@ngraph_tf//:ngraph_tf", @@ -5391,9 +5431,14 @@ genrule( "//tensorflow/core/profiler/internal:python_traceme", # traceme "//tensorflow/core/profiler/internal:traceme_recorder", # traceme ":py_exception_registry", # py_exception_registry - ":kernel_registry", + ":kernel_registry", # kernel_registry "//tensorflow/lite/toco/python:toco_python_api", # toco "//tensorflow/tools/graph_transforms:transform_graph_lib", # transform_graph + "//tensorflow/c:checkpoint_reader", # checkpoint_reader + ":ndarray_tensor", # checkpoint_reader + ":numpy_lib", # checkpoint_reader + ":safe_ptr", # checkpoint_reader + "//tensorflow/core/util/tensor_bundle", # checkpoint_reader ], outs = ["pybind_symbol_target_libs_file.txt"], cmd = select({ @@ -6211,7 +6256,6 @@ tf_py_test( ":io_ops", ":partitioned_variables", ":platform", - ":pywrap_tensorflow", ":resource_variable_ops", ":state_ops", ":training", diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index a0769264e05..68fbc799660 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -30,7 +30,6 @@ import numpy as np import six from six.moves import zip # pylint: disable=redefined-builtin -from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors @@ -54,6 +53,7 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import py_checkpoint_reader from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking import data_structures from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils @@ -1176,7 +1176,7 @@ class Network(base_layer.Layer): save_format = 'h5' else: try: - pywrap_tensorflow.NewCheckpointReader(filepath) + py_checkpoint_reader.NewCheckpointReader(filepath) save_format = 'tf' except errors_impl.DataLossError: # The checkpoint is not readable in TensorFlow format. Try HDF5. diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc index 50e90e32fe0..8c8362972be 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.cc +++ b/tensorflow/python/lib/core/ndarray_tensor.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/python/lib/core/bfloat16.h" #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" +#include "tensorflow/python/lib/core/numpy.h" namespace tensorflow { namespace { @@ -527,12 +528,12 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) { size_t size = 0; void* encoded = nullptr; TF_RETURN_IF_ERROR(EncodePyBytesArray(array, nelems, &size, &encoded)); - *out_tensor = - make_safe(TF_NewTensor(dtype, dims.data(), dims.size(), encoded, size, - [](void* data, size_t len, void* arg) { - delete[] reinterpret_cast(data); - }, - nullptr)); + *out_tensor = make_safe(TF_NewTensor( + dtype, dims.data(), dims.size(), encoded, size, + [](void* data, size_t len, void* arg) { + delete[] reinterpret_cast(data); + }, + nullptr)); } return Status::OK(); } diff --git a/tensorflow/python/lib/core/ndarray_tensor.h b/tensorflow/python/lib/core/ndarray_tensor.h index 3f994524893..c5cd24cff2d 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.h +++ b/tensorflow/python/lib/core/ndarray_tensor.h @@ -16,9 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_ #define TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_ -// Must be included first. -#include "tensorflow/python/lib/core/numpy.h" - #include "tensorflow/c/c_api.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 0473c342c26..25106769c15 100755 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -13,7 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +%include "tensorflow/python/lib/core/strings.i" %include "tensorflow/python/platform/base.i" + +%{ +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/python/lib/core/ndarray_tensor.h" +#include "tensorflow/python/lib/core/safe_ptr.h" +%} + + %include "tensorflow/c/tf_datatype.h" %include "tensorflow/c/tf_status.h" diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index c922071712f..b36024d513f 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -17,8 +17,6 @@ limitations under the License. * The includes are intentionally not alphabetically sorted, as the order of * includes follows dependency order */ -%include "tensorflow/python/util/py_checkpoint_reader.i" - %include "tensorflow/python/pywrap_tfe.i" diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index e955e7bb06e..9ffc98f1743 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -46,7 +46,6 @@ from google.protobuf import text_format from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import saver_pb2 from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef -from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import session from tensorflow.python.framework import graph_util from tensorflow.python.framework import importer @@ -56,6 +55,7 @@ from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import tag_constants from tensorflow.python.tools import saved_model_utils from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import py_checkpoint_reader from tensorflow.python.training import saver as saver_lib @@ -161,7 +161,7 @@ def freeze_graph_with_def_protos(input_graph_def, loader.load(sess, saved_model_tags, input_saved_model_dir) else: var_list = {} - reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint) + reader = py_checkpoint_reader.NewCheckpointReader(input_checkpoint) var_to_shape_map = reader.get_variable_to_shape_map() # List of all partition variables. Because the condition is heuristic diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py index c59b9548add..01e98c2c7f4 100644 --- a/tensorflow/python/tools/inspect_checkpoint.py +++ b/tensorflow/python/tools/inspect_checkpoint.py @@ -23,9 +23,9 @@ import sys import numpy as np -from tensorflow.python import pywrap_tensorflow from tensorflow.python.platform import app from tensorflow.python.platform import flags +from tensorflow.python.training import py_checkpoint_reader FLAGS = None @@ -72,7 +72,7 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors, count_exclude_pattern: Regex string, pattern to exclude tensors when count. """ try: - reader = pywrap_tensorflow.NewCheckpointReader(file_name) + reader = py_checkpoint_reader.NewCheckpointReader(file_name) if all_tensors or all_tensor_names: var_to_shape_map = reader.get_variable_to_shape_map() var_to_dtype_map = reader.get_variable_to_dtype_map() diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index d470021ad68..0a7f4d6e5f7 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -21,7 +21,6 @@ from __future__ import print_function import time import six -from tensorflow.python import pywrap_tensorflow from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.framework import ops from tensorflow.python.ops import io_ops @@ -31,6 +30,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import py_checkpoint_reader from tensorflow.python.training.saving import saveable_object_util from tensorflow.python.util.tf_export import tf_export @@ -63,7 +63,7 @@ def load_checkpoint(ckpt_dir_or_file): if filename is None: raise ValueError("Couldn't find 'checkpoint' file or checkpoints in " "given directory %s" % ckpt_dir_or_file) - return pywrap_tensorflow.NewCheckpointReader(filename) + return py_checkpoint_reader.NewCheckpointReader(filename) @tf_export("train.load_variable") diff --git a/tensorflow/python/training/py_checkpoint_reader.py b/tensorflow/python/training/py_checkpoint_reader.py new file mode 100644 index 00000000000..83ab6e21304 --- /dev/null +++ b/tensorflow/python/training/py_checkpoint_reader.py @@ -0,0 +1,99 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Extending CheckpointReader for TensorFlow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python._pywrap_checkpoint_reader import CheckpointReader +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export + + +def error_translator(e): + """Translate the tensor_slice_reader.cc errors.""" + # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the + # issue with throwing python exceptions from C++. + error_message = str(e) + if 'not found in checkpoint' in error_message or ( + 'Failed to find any ' + 'matching files for') in error_message: + raise errors_impl.NotFoundError(None, None, error_message) + elif 'Sliced checkpoints are not supported' in error_message or ( + 'Data type ' + 'not ' + 'supported') in error_message: + raise errors_impl.UnimplementedError(None, None, error_message) + elif 'Failed to get matching files on' in error_message: + raise errors_impl.InvalidArgumentError(None, None, error_message) + elif 'Unable to open table file' in error_message: + raise errors_impl.DataLossError(None, None, error_message) + elif 'Failed to find the saved tensor slices' in error_message: + raise errors_impl.InternalError(None, None, error_message) + else: + raise errors_impl.OpError(None, None, error_message, errors_impl.UNKNOWN) + + +def get_variable_to_dtype_map(self): + return { + name: dtypes.DType(type_enum) + for name, type_enum in self._GetVariableToDataTypeMap().items() # pylint: disable=protected-access + } + +CheckpointReader.get_variable_to_dtype_map = get_variable_to_dtype_map + + +def has_tensor(self, tensor_str): + return self._HasTensor(compat.as_bytes(tensor_str)) # pylint: disable=protected-access + +CheckpointReader.has_tensor = has_tensor + + +def get_tensor(self, tensor_str): + """Get the tensor from the Checkpoint object.""" + try: + return CheckpointReader.CheckpointReader_GetTensor( + self, compat.as_bytes(tensor_str)) + # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the + # issue with throwing python exceptions from C++. + except RuntimeError as e: + error_translator(e) + + +CheckpointReader.get_tensor = get_tensor + + +# Disable invalid name to keep backwards compatibility with that function. +# It was previously exported from py_checkpoint_reader.i which did not conform +# to pylint checks. +# pylint: disable=invalid-name +@tf_export(v1=['train.NewCheckpointReader']) +def NewCheckpointReader(filepattern): + """A function that returns a CheckPointReader. + + Args: + filepattern: The filename. + + Returns: + A CheckpointReader object. + """ + try: + return CheckpointReader(compat.as_bytes(filepattern)) + # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the + # issue with throwing python exceptions from C++. + except RuntimeError as e: + error_translator(e) diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index d65297fb30d..6d2b339e3a8 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -32,7 +32,6 @@ import numpy as np from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saver_pb2 from tensorflow.core.protobuf import trackable_object_graph_pb2 -from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.framework import constant_op @@ -49,6 +48,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import py_checkpoint_reader from tensorflow.python.training import training_util from tensorflow.python.training.saving import saveable_object from tensorflow.python.training.saving import saveable_object_util @@ -1614,7 +1614,7 @@ def object_graph_key_mapping(checkpoint_path): Returns: Dictionary mapping tensor names to checkpoint keys. """ - reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) + reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path) object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY) object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph()) object_graph_proto.ParseFromString(object_graph_string) diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 99492bc5890..40cd26fd2d0 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -34,7 +34,6 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import queue_runner_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import saver_pb2 -from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context @@ -69,6 +68,7 @@ from tensorflow.python.summary import summary from tensorflow.python.training import adam from tensorflow.python.training import checkpoint_management from tensorflow.python.training import gradient_descent +from tensorflow.python.training import py_checkpoint_reader from tensorflow.python.training import queue_runner_impl from tensorflow.python.training import saver as saver_module from tensorflow.python.training import saver_test_utils @@ -2468,7 +2468,7 @@ class CheckpointReaderTest(test.TestCase): save.save(sess, save_path) # Creates a reader. - reader = pywrap_tensorflow.NewCheckpointReader(save_path) + reader = py_checkpoint_reader.NewCheckpointReader(save_path) # Verifies that the tensors exist. self.assertTrue(reader.has_tensor("v0")) self.assertTrue(reader.has_tensor("v1")) @@ -2493,7 +2493,7 @@ class CheckpointReaderTest(test.TestCase): def testNonexistentPath(self): with self.assertRaisesRegexp(errors.NotFoundError, "Unsuccessful TensorSliceReader"): - pywrap_tensorflow.NewCheckpointReader("non-existent") + py_checkpoint_reader.NewCheckpointReader("non-existent") class CheckpointReaderForV2Test(CheckpointReaderTest): diff --git a/tensorflow/python/training/tracking/benchmarks_test.py b/tensorflow/python/training/tracking/benchmarks_test.py index 7514d9f54cf..666adf78c58 100644 --- a/tensorflow/python/training/tracking/benchmarks_test.py +++ b/tensorflow/python/training/tracking/benchmarks_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import os import time -from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.module import module @@ -29,6 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_io_ops from tensorflow.python.platform import test +from tensorflow.python.training import py_checkpoint_reader from tensorflow.python.training.saving import saveable_object from tensorflow.python.training.tracking import base from tensorflow.python.training.tracking import util @@ -116,7 +116,7 @@ class SavingBenchmarks(test.Benchmark): def benchmark_raw_restore(self): checkpoint_path = _save_checkpoint() - all_names, all_dtypes = zip(*pywrap_tensorflow.NewCheckpointReader( + all_names, all_dtypes = zip(*py_checkpoint_reader.NewCheckpointReader( checkpoint_path).get_variable_to_dtype_map().items()) def _call_restore_v2(): diff --git a/tensorflow/python/training/tracking/util.py b/tensorflow/python/training/tracking/util.py index ed4319c0617..d40e00f9a12 100644 --- a/tensorflow/python/training/tracking/util.py +++ b/tensorflow/python/training/tracking/util.py @@ -25,7 +25,6 @@ import weakref import six from tensorflow.core.protobuf import trackable_object_graph_pb2 -from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import session as session_lib from tensorflow.python.eager import context from tensorflow.python.eager import def_function @@ -43,6 +42,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import py_checkpoint_reader from tensorflow.python.training import saver as v1_saver_lib from tensorflow.python.training.saving import functional_saver from tensorflow.python.training.saving import saveable_object_util @@ -200,7 +200,7 @@ class _CheckpointRestoreCoordinator(object): self.all_python_objects = object_identity.ObjectIdentityWeakSet() self.save_path_tensor = save_path_tensor self.save_path_string = save_path - self.dtype_map = pywrap_tensorflow.NewCheckpointReader( + self.dtype_map = py_checkpoint_reader.NewCheckpointReader( save_path).get_variable_to_dtype_map() # A NewCheckpointReader for the most recent checkpoint, for streaming Python # state restoration. @@ -272,7 +272,7 @@ class _CheckpointRestoreCoordinator(object): if reader is None: # Lazily create the NewCheckpointReader, since this requires file access # and we may not have any Python saveables. - reader = pywrap_tensorflow.NewCheckpointReader(self.save_path_string) + reader = py_checkpoint_reader.NewCheckpointReader(self.save_path_string) spec_names = [spec.name for spec in saveable.specs] saveable.python_restore([reader.get_tensor(name) for name in spec_names]) @@ -462,7 +462,7 @@ def object_metadata(save_path): Raises: ValueError: If an object graph was not found in the checkpoint. """ - reader = pywrap_tensorflow.NewCheckpointReader(save_path) + reader = py_checkpoint_reader.NewCheckpointReader(save_path) try: object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY) except errors_impl.NotFoundError: @@ -1238,7 +1238,7 @@ class TrackableSaver(object): """ if save_path is None: return InitializationOnlyStatus(self._graph_view, ops.uid()) - reader = pywrap_tensorflow.NewCheckpointReader(save_path) + reader = py_checkpoint_reader.NewCheckpointReader(save_path) graph_building = not context.executing_eagerly() if graph_building: dtype_map = None diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index 287b36bb615..91a67f18f7c 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -110,7 +110,7 @@ from tensorflow.python.training.training_util import create_global_step from tensorflow.python.training.training_util import get_or_create_global_step from tensorflow.python.training.warm_starting_util import VocabInfo from tensorflow.python.training.warm_starting_util import warm_start -from tensorflow.python.pywrap_tensorflow import NewCheckpointReader +from tensorflow.python.training.py_checkpoint_reader import NewCheckpointReader from tensorflow.python.util.tf_export import tf_export # pylint: disable=wildcard-import @@ -145,4 +145,3 @@ tf_export(v1=["train.SaverDef"])(SaverDef) tf_export("train.SequenceExample")(SequenceExample) tf_export("train.ServerDef")(ServerDef) # pylint: enable=undefined-variable - diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py index 8fe059c2b47..fa334a41875 100644 --- a/tensorflow/python/training/warm_starting_util.py +++ b/tensorflow/python/training/warm_starting_util.py @@ -335,14 +335,16 @@ def _get_grouped_variables(vars_to_warm_start): ValueError: If vars_to_warm_start is not a string, `None`, a list of `Variables`, or a list of strings. """ - if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None: + # TODO(b/143899805): Remove unicode checks when deprecating Python2. + if isinstance(vars_to_warm_start, + six.string_types) or vars_to_warm_start is None: # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match # everything (in TRAINABLE_VARIABLES) here. logging.info("Warm-starting variables only in TRAINABLE_VARIABLES.") list_of_vars = ops.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start) elif isinstance(vars_to_warm_start, list): - if all(isinstance(v, str) for v in vars_to_warm_start): + if all(isinstance(v, six.string_types) for v in vars_to_warm_start): list_of_vars = [] for v in vars_to_warm_start: list_of_vars += ops.get_collection( diff --git a/tensorflow/python/util/py_checkpoint_reader.i b/tensorflow/python/util/py_checkpoint_reader.i deleted file mode 100644 index 7a2ff22be99..00000000000 --- a/tensorflow/python/util/py_checkpoint_reader.i +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -%include "tensorflow/python/lib/core/strings.i" -%include "tensorflow/python/platform/base.i" - -%{ -#include "tensorflow/c/checkpoint_reader.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/python/lib/core/ndarray_tensor.h" -#include "tensorflow/python/lib/core/safe_ptr.h" -%} - -%typemap(out) const tensorflow::checkpoint::TensorSliceReader::VarToShapeMap& { - tensorflow::Safe_PyObjectPtr output_map(tensorflow::make_safe(PyDict_New())); - for (auto v : *$1) { -%#if PY_MAJOR_VERSION >= 3 - tensorflow::Safe_PyObjectPtr key( - tensorflow::make_safe(PyUnicode_FromStringAndSize(v.first.c_str(), - v.first.size()))); -%#else - tensorflow::Safe_PyObjectPtr key( - tensorflow::make_safe(PyString_FromStringAndSize(v.first.c_str(), - v.first.size()))); -%#endif - if (!key) { - SWIG_fail; - } - size_t dims = v.second.dims(); - tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyList_New(dims))); - if (!value) { - SWIG_fail; - } - for (size_t i = 0; i < dims; ++i) { -%#if PY_MAJOR_VERSION >= 3 - tensorflow::Safe_PyObjectPtr dim_value( - tensorflow::make_safe(PyLong_FromLong(v.second.dim_size(i)))); -%#else - tensorflow::Safe_PyObjectPtr dim_value( - tensorflow::make_safe(PyInt_FromLong(v.second.dim_size(i)))); -%#endif - if (!dim_value) { - SWIG_fail; - } - PyList_SET_ITEM(value.get(), i, dim_value.release()); - } - if (PyDict_SetItem(output_map.get(), key.get(), value.get()) == -1) { - SWIG_fail; - } else { - key.release(); - value.release(); - } - } - - $result = output_map.release(); -} - -%typemap(out) const tensorflow::checkpoint::TensorSliceReader::VarToDataTypeMap& { - tensorflow::Safe_PyObjectPtr output_map(tensorflow::make_safe(PyDict_New())); - for (auto v : *$1) { -%#if PY_MAJOR_VERSION >= 3 - tensorflow::Safe_PyObjectPtr key( - tensorflow::make_safe(PyUnicode_FromStringAndSize(v.first.c_str(), v.first.size()))); -%#else - tensorflow::Safe_PyObjectPtr key( - tensorflow::make_safe(PyString_FromStringAndSize(v.first.c_str(), v.first.size()))); -%#endif - if (!key) { - SWIG_fail; - } -%#if PY_MAJOR_VERSION >= 3 - tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyLong_FromLong(v.second))); -%#else - tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyInt_FromLong(v.second))); -%#endif - if (!value) { - SWIG_fail; - } - if (PyDict_SetItem(output_map.get(), key.get(), value.get()) == -1) { - SWIG_fail; - } else { - key.release(); - value.release(); - } - } - - $result = output_map.release(); -} - -%{ -static PyObject* CheckpointReader_GetTensor( - tensorflow::checkpoint::CheckpointReader* reader, - const string& name, - TF_Status* status) { - PyObject* py_obj = Py_None; - std::unique_ptr tensor; - reader->GetTensor(name, &tensor, status); - if (TF_GetCode(status) == TF_OK) { - tensorflow::Status s = - tensorflow::TensorToNdarray(*tensor.get(), &py_obj); - if (!s.ok()) { - Set_TF_Status_from_Status(status, s); - } - } - return PyArray_Return(reinterpret_cast(py_obj)); -} -%} - -// Wrap this function. -PyObject* CheckpointReader_GetTensor( - tensorflow::checkpoint::CheckpointReader* reader, - const string& name, - TF_Status* status); - -%ignoreall - -%unignore tensorflow; -%unignore tensorflow::checkpoint; -%unignore tensorflow::checkpoint::CheckpointReader; -%unignore tensorflow::checkpoint::CheckpointReader::CheckpointReader; -%unignore tensorflow::checkpoint::CheckpointReader::~CheckpointReader; -%rename("debug_string") tensorflow::checkpoint::CheckpointReader::DebugString; -%rename("get_variable_to_shape_map") tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap; -%rename("_GetVariableToDataTypeMap") tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap; -%rename("_HasTensor") tensorflow::checkpoint::CheckpointReader::HasTensor; -%unignore CheckpointReader_GetTensor; - -%extend tensorflow::checkpoint::CheckpointReader { -%insert("python") %{ - def get_variable_to_dtype_map(self): - from tensorflow.python.framework import dtypes - return {name: dtypes.DType(type_enum) - for name, type_enum in self._GetVariableToDataTypeMap().items()} - - def has_tensor(self, tensor_str): - from tensorflow.python.util import compat - return self._HasTensor(compat.as_bytes(tensor_str)) - - def get_tensor(self, tensor_str): - from tensorflow.python.util import compat - - return CheckpointReader_GetTensor(self, compat.as_bytes(tensor_str)) -%} -} - -%insert("python") %{ -def NewCheckpointReader(filepattern): - from tensorflow.python.util import compat - return CheckpointReader(compat.as_bytes(filepattern)) - -NewCheckpointReader._tf_api_names_v1 = ['train.NewCheckpointReader'] -%} - -%include "tensorflow/c/checkpoint_reader.h" -%unignoreall diff --git a/tensorflow/python/util/py_checkpoint_reader_wrapper.cc b/tensorflow/python/util/py_checkpoint_reader_wrapper.cc new file mode 100644 index 00000000000..a7076f6ee29 --- /dev/null +++ b/tensorflow/python/util/py_checkpoint_reader_wrapper.cc @@ -0,0 +1,150 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Disallow Numpy 1.7 deprecated symbols. +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include "numpy/arrayobject.h" +#include "numpy/ufuncobject.h" +#include "include/pybind11/chrono.h" +#include "include/pybind11/complex.h" +#include "include/pybind11/functional.h" +#include "include/pybind11/pybind11.h" +#include "include/pybind11/stl.h" +#include "tensorflow/c/checkpoint_reader.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/python/lib/core/ndarray_tensor.h" +#include "tensorflow/python/lib/core/py_exception_registry.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" +#include "tensorflow/python/lib/core/safe_ptr.h" + +namespace py = pybind11; + +// TODO(amitpatankar): Move the custom type casters to separate common header +// only libraries. + +namespace pybind11 { +namespace detail { + +/* This is a custom type caster for the TensorShape object. For more + * documentation please refer to this link: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html#custom-type-casters + * The PyCheckpointReader methods sometimes return the `TensorShape` object + * and the `DataType` object as outputs. This custom type caster helps Python + * handle it's conversion from C++ to Python. Since we do not accept these + * classes as arguments from Python, it is not necessary to define the `load` + * function to cast the object from Python to a C++ object. + */ + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(tensorflow::TensorShape, _("tensorflow::TensorShape")); + + static handle cast(const tensorflow::TensorShape& src, + return_value_policy unused_policy, handle unused_handle) { + // TODO(amitpatankar): Simplify handling TensorShape as output later. + size_t dims = src.dims(); + tensorflow::Safe_PyObjectPtr value(PyList_New(dims)); + for (size_t i = 0; i < dims; ++i) { +#if PY_MAJOR_VERSION >= 3 + tensorflow::Safe_PyObjectPtr dim_value( + tensorflow::make_safe(PyLong_FromLong(src.dim_size(i)))); +#else + tensorflow::Safe_PyObjectPtr dim_value( + tensorflow::make_safe(PyInt_FromLong(src.dim_size(i)))); +#endif + PyList_SET_ITEM(value.get(), i, dim_value.release()); + } + + return value.release(); + } +}; + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(tensorflow::DataType, _("tensorflow::DataType")); + + static handle cast(const tensorflow::DataType& src, + return_value_policy unused_policy, handle unused_handle) { +#if PY_MAJOR_VERSION >= 3 + tensorflow::Safe_PyObjectPtr value( + tensorflow::make_safe(PyLong_FromLong(src))); +#else + tensorflow::Safe_PyObjectPtr value( + tensorflow::make_safe(PyInt_FromLong(src))); +#endif + return value.release(); + } +}; + +} // namespace detail +} // namespace pybind11 + +namespace tensorflow { + +static py::object CheckpointReader_GetTensor( + tensorflow::checkpoint::CheckpointReader* reader, const string& name) { + Safe_TF_StatusPtr status = make_safe(TF_NewStatus()); + PyObject* py_obj = Py_None; + std::unique_ptr tensor; + reader->GetTensor(name, &tensor, status.get()); + + // Error handling if unable to get Tensor. + tensorflow::MaybeRaiseFromTFStatus(status.get()); + + tensorflow::MaybeRaiseFromStatus( + tensorflow::TensorToNdarray(*tensor, &py_obj)); + + return tensorflow::pyo_or_throw( + PyArray_Return(reinterpret_cast(py_obj))); +} + +} // namespace tensorflow + +PYBIND11_MODULE(_pywrap_checkpoint_reader, m) { + // Initialization code to use numpy types in the type casters. + import_array1(); + py::class_ checkpoint_reader_class( + m, "CheckpointReader"); + checkpoint_reader_class + .def(py::init([](const std::string& filename) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + // pybind11 support smart pointers and will own freeing the memory when + // complete. + // https://pybind11.readthedocs.io/en/master/advanced/smart_ptrs.html#std-unique-ptr + auto checkpoint = + std::make_unique( + filename, status.get()); + tensorflow::MaybeRaiseFromTFStatus(status.get()); + return checkpoint; + })) + .def("debug_string", + [](tensorflow::checkpoint::CheckpointReader& self) { + return py::bytes(self.DebugString()); + }) + .def("get_variable_to_shape_map", + &tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap) + .def("_GetVariableToDataTypeMap", + &tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap) + .def("_HasTensor", &tensorflow::checkpoint::CheckpointReader::HasTensor) + .def_static("CheckpointReader_GetTensor", + &tensorflow::CheckpointReader_GetTensor); +}; diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index c4d92c8bbba..e773c77502c 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -107,3 +107,21 @@ tensorflow::profiler::PythonTraceMe::IsEnabled [traceme_recorder] # traceme tensorflow::profiler::internal::g_trace_level +[checkpoint_reader] # py_checkpoint_reader +tensorflow::checkpoint::CheckpointReader +tensorflow::checkpoint::CheckpointReader::Init +tensorflow::checkpoint::CheckpointReader::DebugString +tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap +tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap +tensorflow::checkpoint::CheckpointReader::GetTensor +tensorflow::checkpoint::CheckpointReader::HasTensor + +[tensor_bundle] # py_checkpoint_reader +tensorflow::BundleReader::BundleReader +tensorflow::BundleReader::~BundleReader + +[ndarray_tensor] # py_checkpoint_reader +tensorflow::TensorToNdarray + +[safe_ptr] # py_checkpoint_reader +tensorflow::detail::PyDecrefDeleter