Export the checkpoint reader classes and functions from C++ to Python with pybind11 instead of SWIG. This is part of a larger effort to deprecate SWIG and, through modularization, eventually break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA already uses the pybind11 macros. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 279101529 Change-Id: I25502ed3d3718499abca41f5614681f41e4c7199
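For downstream Python code this is a call-site swap rather than a behavioral change; a minimal sketch of the before/after usage, assuming a hypothetical checkpoint path and using only the module paths and method names that appear in the diff below:

    # Before: the reader came from the SWIG-generated pywrap_tensorflow module.
    # from tensorflow.python import pywrap_tensorflow
    # reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)

    # After: the reader is re-exported by a pure-Python shim over the new
    # pybind11 extension module _pywrap_checkpoint_reader.
    from tensorflow.python.training import py_checkpoint_reader

    checkpoint_path = '/tmp/model.ckpt'  # hypothetical path
    reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
    print(reader.get_variable_to_shape_map())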
This commit is contained in: parent 8977382c53, commit b66e4e833c
@@ -293,6 +293,15 @@ exports_files(
     visibility = ["//visibility:public"],
 )
 
+filegroup(
+    name = "checkpoint_reader_hdrs",
+    srcs = [
+        "checkpoint_reader.h",
+        "tf_status_helper.h",
+    ],
+    visibility = ["//tensorflow:__subpackages__"],
+)
+
 tf_cuda_library(
     name = "tf_status_helper",
     srcs = ["tf_status_helper.cc"],
@@ -98,6 +98,7 @@ py_library(
         "//third_party/py/tensorflow_core:__subpackages__",
     ],
     deps = [
+        ":_pywrap_checkpoint_reader",
         ":_pywrap_events_writer",
         ":_pywrap_kernel_registry",
         ":_pywrap_py_exception_registry",
@@ -629,6 +630,32 @@ tf_python_pybind_extension(
     ],
 )
 
+tf_python_pybind_extension(
+    name = "_pywrap_checkpoint_reader",
+    srcs = ["util/py_checkpoint_reader_wrapper.cc"],
+    hdrs = [
+        "lib/core/ndarray_tensor.h",
+        "lib/core/safe_ptr.h",
+        ":py_exception_registry_hdr",
+        "//tensorflow/c:checkpoint_reader_hdrs",
+        "//tensorflow/c:headers",
+        "//tensorflow/c/eager:headers",
+    ],
+    module_name = "_pywrap_checkpoint_reader",
+    deps = [
+        ":pybind11_lib",
+        ":pybind11_status",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:op_gen_lib",
+        "//tensorflow/core:protos_all",
+        "//tensorflow/core/util/tensor_bundle:tensor_bundle_headers_lib",
+        "//third_party/py/numpy:headers",
+        "//third_party/python_runtime:headers",
+        "@com_google_absl//absl/strings",
+        "@pybind11",
+    ],
+)
+
 filegroup(
     name = "py_exception_registry_hdr",
     srcs = [
@@ -996,6 +1023,7 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        ":_pywrap_checkpoint_reader",
         ":_pywrap_debug_events_writer",
         ":_pywrap_events_writer",
         ":_pywrap_kernel_registry",
@@ -4730,6 +4758,7 @@ py_library(
         ":math_ops",
         ":mixed_precision",
         ":platform",
+        ":py_checkpoint_reader",
         ":pywrap_tensorflow",
         ":random_ops",
         ":resource_variable_ops",
@@ -4785,6 +4814,17 @@ py_library(
     ],
 )
 
+py_library(
+    name = "py_checkpoint_reader",
+    srcs = ["training/py_checkpoint_reader.py"],
+    deps = [
+        ":_pywrap_checkpoint_reader",
+        ":dtypes",
+        ":errors",
+        ":util",
+    ],
+)
+
 py_library(
     name = "checkpoint_management",
     srcs = ["training/checkpoint_management.py"],
@@ -5284,7 +5324,6 @@ tf_py_wrap_cc(
         "lib/io/py_record_writer.i",
         "platform/base.i",
         "pywrap_tfe.i",
-        "util/py_checkpoint_reader.i",
         "//tensorflow/compiler/mlir/python:mlir.i",
     ],
     # add win_def_file for pywrap_tensorflow
@@ -5336,6 +5375,7 @@ tf_py_wrap_cc(
         "//tensorflow/tools/graph_transforms:transform_graph_lib",
         "//tensorflow/lite/toco/python:toco_python_api",
         "//tensorflow/python/eager:pywrap_tfe_lib",
+        "//tensorflow/core/util/tensor_bundle:tensor_bundle",
     ] + (tf_additional_lib_deps() +
          tf_additional_plugin_deps()) + if_ngraph([
         "@ngraph_tf//:ngraph_tf",
@@ -5391,9 +5431,14 @@ genrule(
         "//tensorflow/core/profiler/internal:python_traceme",  # traceme
         "//tensorflow/core/profiler/internal:traceme_recorder",  # traceme
         ":py_exception_registry",  # py_exception_registry
-        ":kernel_registry",
+        ":kernel_registry",  # kernel_registry
         "//tensorflow/lite/toco/python:toco_python_api",  # toco
         "//tensorflow/tools/graph_transforms:transform_graph_lib",  # transform_graph
+        "//tensorflow/c:checkpoint_reader",  # checkpoint_reader
+        ":ndarray_tensor",  # checkpoint_reader
+        ":numpy_lib",  # checkpoint_reader
+        ":safe_ptr",  # checkpoint_reader
+        "//tensorflow/core/util/tensor_bundle",  # checkpoint_reader
     ],
     outs = ["pybind_symbol_target_libs_file.txt"],
     cmd = select({
@@ -6211,7 +6256,6 @@ tf_py_test(
         ":io_ops",
         ":partitioned_variables",
         ":platform",
-        ":pywrap_tensorflow",
         ":resource_variable_ops",
         ":state_ops",
         ":training",
@@ -30,7 +30,6 @@ import numpy as np
import six
from six.moves import zip  # pylint: disable=redefined-builtin

-from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors

@@ -54,6 +53,7 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils

@@ -1176,7 +1176,7 @@ class Network(base_layer.Layer):
      save_format = 'h5'
    else:
      try:
-        pywrap_tensorflow.NewCheckpointReader(filepath)
+        py_checkpoint_reader.NewCheckpointReader(filepath)
        save_format = 'tf'
      except errors_impl.DataLossError:
        # The checkpoint is not readable in TensorFlow format. Try HDF5.
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/lib/core/bfloat16.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
+#include "tensorflow/python/lib/core/numpy.h"

namespace tensorflow {
namespace {

@@ -527,12 +528,12 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
    size_t size = 0;
    void* encoded = nullptr;
    TF_RETURN_IF_ERROR(EncodePyBytesArray(array, nelems, &size, &encoded));
-    *out_tensor =
-        make_safe(TF_NewTensor(dtype, dims.data(), dims.size(), encoded, size,
-                               [](void* data, size_t len, void* arg) {
-                                 delete[] reinterpret_cast<char*>(data);
-                               },
-                               nullptr));
+    *out_tensor = make_safe(TF_NewTensor(
+        dtype, dims.data(), dims.size(), encoded, size,
+        [](void* data, size_t len, void* arg) {
+          delete[] reinterpret_cast<char*>(data);
+        },
+        nullptr));
  }
  return Status::OK();
}
@@ -16,9 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_
#define TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_

-// Must be included first.
-#include "tensorflow/python/lib/core/numpy.h"
-
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor.h"
@@ -13,7 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"

+%{
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/python/lib/core/ndarray_tensor.h"
+#include "tensorflow/python/lib/core/safe_ptr.h"
+%}
+
+%include "tensorflow/c/tf_datatype.h"
+%include "tensorflow/c/tf_status.h"
+
@@ -17,8 +17,6 @@
 * The includes are intentionally not alphabetically sorted, as the order of
 * includes follows dependency order */

-%include "tensorflow/python/util/py_checkpoint_reader.i"
-
%include "tensorflow/python/pywrap_tfe.i"
@@ -46,7 +46,6 @@ from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
-from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer

@@ -56,6 +55,7 @@ from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import saver as saver_lib


@@ -161,7 +161,7 @@ def freeze_graph_with_def_protos(input_graph_def,
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
-      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
+      reader = py_checkpoint_reader.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
@@ -23,9 +23,9 @@ import sys

import numpy as np

-from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
+from tensorflow.python.training import py_checkpoint_reader

FLAGS = None


@@ -72,7 +72,7 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
    count_exclude_pattern: Regex string, pattern to exclude tensors when count.
  """
  try:
-    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
+    reader = py_checkpoint_reader.NewCheckpointReader(file_name)
    if all_tensors or all_tensor_names:
      var_to_shape_map = reader.get_variable_to_shape_map()
      var_to_dtype_map = reader.get_variable_to_dtype_map()
@@ -21,7 +21,6 @@ from __future__ import print_function
import time
import six

-from tensorflow.python import pywrap_tensorflow
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops

@@ -31,6 +30,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export


@@ -63,7 +63,7 @@ def load_checkpoint(ckpt_dir_or_file):
  if filename is None:
    raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                     "given directory %s" % ckpt_dir_or_file)
-  return pywrap_tensorflow.NewCheckpointReader(filename)
+  return py_checkpoint_reader.NewCheckpointReader(filename)


@tf_export("train.load_variable")
tensorflow/python/training/py_checkpoint_reader.py (new file, 99 lines)
@@ -0,0 +1,99 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extending CheckpointReader for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python._pywrap_checkpoint_reader import CheckpointReader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


def error_translator(e):
  """Translate the tensor_slice_reader.cc errors."""
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  error_message = str(e)
  if 'not found in checkpoint' in error_message or (
      'Failed to find any '
      'matching files for') in error_message:
    raise errors_impl.NotFoundError(None, None, error_message)
  elif 'Sliced checkpoints are not supported' in error_message or (
      'Data type '
      'not '
      'supported') in error_message:
    raise errors_impl.UnimplementedError(None, None, error_message)
  elif 'Failed to get matching files on' in error_message:
    raise errors_impl.InvalidArgumentError(None, None, error_message)
  elif 'Unable to open table file' in error_message:
    raise errors_impl.DataLossError(None, None, error_message)
  elif 'Failed to find the saved tensor slices' in error_message:
    raise errors_impl.InternalError(None, None, error_message)
  else:
    raise errors_impl.OpError(None, None, error_message, errors_impl.UNKNOWN)


def get_variable_to_dtype_map(self):
  return {
      name: dtypes.DType(type_enum)
      for name, type_enum in self._GetVariableToDataTypeMap().items()  # pylint: disable=protected-access
  }

CheckpointReader.get_variable_to_dtype_map = get_variable_to_dtype_map


def has_tensor(self, tensor_str):
  return self._HasTensor(compat.as_bytes(tensor_str))  # pylint: disable=protected-access

CheckpointReader.has_tensor = has_tensor


def get_tensor(self, tensor_str):
  """Get the tensor from the Checkpoint object."""
  try:
    return CheckpointReader.CheckpointReader_GetTensor(
        self, compat.as_bytes(tensor_str))
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  except RuntimeError as e:
    error_translator(e)


CheckpointReader.get_tensor = get_tensor


# Disable invalid name to keep backwards compatibility with that function.
# It was previously exported from py_checkpoint_reader.i which did not conform
# to pylint checks.
# pylint: disable=invalid-name
@tf_export(v1=['train.NewCheckpointReader'])
def NewCheckpointReader(filepattern):
  """A function that returns a CheckPointReader.

  Args:
    filepattern: The filename.

  Returns:
    A CheckpointReader object.
  """
  try:
    return CheckpointReader(compat.as_bytes(filepattern))
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  except RuntimeError as e:
    error_translator(e)
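The shim above keeps the old NewCheckpointReader surface intact while routing errors through error_translator; a short usage sketch, assuming a hypothetical checkpoint path (the method names are exactly the ones patched onto CheckpointReader above):

    from tensorflow.python.framework import errors_impl
    from tensorflow.python.training import py_checkpoint_reader

    try:
      reader = py_checkpoint_reader.NewCheckpointReader('/tmp/model.ckpt')  # hypothetical path
      if reader.has_tensor('v0'):
        print(reader.get_tensor('v0'))             # numpy ndarray
      print(reader.get_variable_to_shape_map())    # {tensor name: [dims]}
      print(reader.get_variable_to_dtype_map())    # {tensor name: tf.DType}
    except errors_impl.NotFoundError as e:
      # error_translator() has already mapped the C++ RuntimeError onto a TF error class.
      print('missing checkpoint or tensor:', e)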
@@ -32,7 +32,6 @@ import numpy as np
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf import trackable_object_graph_pb2
-from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op

@@ -49,6 +48,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import training_util
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util

@@ -1614,7 +1614,7 @@ def object_graph_key_mapping(checkpoint_path):
  Returns:
    Dictionary mapping tensor names to checkpoint keys.
  """
-  reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
+  reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
  object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
  object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
  object_graph_proto.ParseFromString(object_graph_string)
@@ -34,7 +34,6 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.protobuf import saver_pb2
-from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context

@@ -69,6 +68,7 @@ from tensorflow.python.summary import summary
from tensorflow.python.training import adam
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import saver_test_utils

@@ -2468,7 +2468,7 @@ class CheckpointReaderTest(test.TestCase):
      save.save(sess, save_path)

      # Creates a reader.
-      reader = pywrap_tensorflow.NewCheckpointReader(save_path)
+      reader = py_checkpoint_reader.NewCheckpointReader(save_path)
      # Verifies that the tensors exist.
      self.assertTrue(reader.has_tensor("v0"))
      self.assertTrue(reader.has_tensor("v1"))

@@ -2493,7 +2493,7 @@ class CheckpointReaderTest(test.TestCase):
  def testNonexistentPath(self):
    with self.assertRaisesRegexp(errors.NotFoundError,
                                 "Unsuccessful TensorSliceReader"):
-      pywrap_tensorflow.NewCheckpointReader("non-existent")
+      py_checkpoint_reader.NewCheckpointReader("non-existent")


class CheckpointReaderForV2Test(CheckpointReaderTest):
@@ -21,7 +21,6 @@ from __future__ import print_function
import os
import time

-from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.module import module

@@ -29,6 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.platform import test
+from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import util

@@ -116,7 +116,7 @@ class SavingBenchmarks(test.Benchmark):

  def benchmark_raw_restore(self):
    checkpoint_path = _save_checkpoint()
-    all_names, all_dtypes = zip(*pywrap_tensorflow.NewCheckpointReader(
+    all_names, all_dtypes = zip(*py_checkpoint_reader.NewCheckpointReader(
        checkpoint_path).get_variable_to_dtype_map().items())

    def _call_restore_v2():
@@ -25,7 +25,6 @@ import weakref
import six

from tensorflow.core.protobuf import trackable_object_graph_pb2
-from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function

@@ -43,6 +42,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import saver as v1_saver_lib
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.saving import saveable_object_util

@@ -200,7 +200,7 @@ class _CheckpointRestoreCoordinator(object):
    self.all_python_objects = object_identity.ObjectIdentityWeakSet()
    self.save_path_tensor = save_path_tensor
    self.save_path_string = save_path
-    self.dtype_map = pywrap_tensorflow.NewCheckpointReader(
+    self.dtype_map = py_checkpoint_reader.NewCheckpointReader(
        save_path).get_variable_to_dtype_map()
    # A NewCheckpointReader for the most recent checkpoint, for streaming Python
    # state restoration.

@@ -272,7 +272,7 @@ class _CheckpointRestoreCoordinator(object):
    if reader is None:
      # Lazily create the NewCheckpointReader, since this requires file access
      # and we may not have any Python saveables.
-      reader = pywrap_tensorflow.NewCheckpointReader(self.save_path_string)
+      reader = py_checkpoint_reader.NewCheckpointReader(self.save_path_string)
    spec_names = [spec.name for spec in saveable.specs]
    saveable.python_restore([reader.get_tensor(name) for name in spec_names])


@@ -462,7 +462,7 @@ def object_metadata(save_path):
  Raises:
    ValueError: If an object graph was not found in the checkpoint.
  """
-  reader = pywrap_tensorflow.NewCheckpointReader(save_path)
+  reader = py_checkpoint_reader.NewCheckpointReader(save_path)
  try:
    object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
  except errors_impl.NotFoundError:

@@ -1238,7 +1238,7 @@ class TrackableSaver(object):
    """
    if save_path is None:
      return InitializationOnlyStatus(self._graph_view, ops.uid())
-    reader = pywrap_tensorflow.NewCheckpointReader(save_path)
+    reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    graph_building = not context.executing_eagerly()
    if graph_building:
      dtype_map = None
@@ -110,7 +110,7 @@ from tensorflow.python.training.training_util import create_global_step
from tensorflow.python.training.training_util import get_or_create_global_step
from tensorflow.python.training.warm_starting_util import VocabInfo
from tensorflow.python.training.warm_starting_util import warm_start
-from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
+from tensorflow.python.training.py_checkpoint_reader import NewCheckpointReader
from tensorflow.python.util.tf_export import tf_export

# pylint: disable=wildcard-import

@@ -145,4 +145,3 @@ tf_export(v1=["train.SaverDef"])(SaverDef)
tf_export("train.SequenceExample")(SequenceExample)
tf_export("train.ServerDef")(ServerDef)
# pylint: enable=undefined-variable
-
@@ -335,14 +335,16 @@ def _get_grouped_variables(vars_to_warm_start):
    ValueError: If vars_to_warm_start is not a string, `None`, a list of
      `Variables`, or a list of strings.
  """
-  if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None:
+  # TODO(b/143899805): Remove unicode checks when deprecating Python2.
+  if isinstance(vars_to_warm_start,
+                six.string_types) or vars_to_warm_start is None:
    # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
    # everything (in TRAINABLE_VARIABLES) here.
    logging.info("Warm-starting variables only in TRAINABLE_VARIABLES.")
    list_of_vars = ops.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start)
  elif isinstance(vars_to_warm_start, list):
-    if all(isinstance(v, str) for v in vars_to_warm_start):
+    if all(isinstance(v, six.string_types) for v in vars_to_warm_start):
      list_of_vars = []
      for v in vars_to_warm_start:
        list_of_vars += ops.get_collection(
tensorflow/python/util/py_checkpoint_reader.i (deleted, 167 lines)
@@ -1,167 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"

%{
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
%}

%typemap(out) const tensorflow::checkpoint::TensorSliceReader::VarToShapeMap& {
  tensorflow::Safe_PyObjectPtr output_map(tensorflow::make_safe(PyDict_New()));
  for (auto v : *$1) {
%#if PY_MAJOR_VERSION >= 3
    tensorflow::Safe_PyObjectPtr key(
        tensorflow::make_safe(PyUnicode_FromStringAndSize(v.first.c_str(),
                                                          v.first.size())));
%#else
    tensorflow::Safe_PyObjectPtr key(
        tensorflow::make_safe(PyString_FromStringAndSize(v.first.c_str(),
                                                         v.first.size())));
%#endif
    if (!key) {
      SWIG_fail;
    }
    size_t dims = v.second.dims();
    tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyList_New(dims)));
    if (!value) {
      SWIG_fail;
    }
    for (size_t i = 0; i < dims; ++i) {
%#if PY_MAJOR_VERSION >= 3
      tensorflow::Safe_PyObjectPtr dim_value(
          tensorflow::make_safe(PyLong_FromLong(v.second.dim_size(i))));
%#else
      tensorflow::Safe_PyObjectPtr dim_value(
          tensorflow::make_safe(PyInt_FromLong(v.second.dim_size(i))));
%#endif
      if (!dim_value) {
        SWIG_fail;
      }
      PyList_SET_ITEM(value.get(), i, dim_value.release());
    }
    if (PyDict_SetItem(output_map.get(), key.get(), value.get()) == -1) {
      SWIG_fail;
    } else {
      key.release();
      value.release();
    }
  }

  $result = output_map.release();
}

%typemap(out) const tensorflow::checkpoint::TensorSliceReader::VarToDataTypeMap& {
  tensorflow::Safe_PyObjectPtr output_map(tensorflow::make_safe(PyDict_New()));
  for (auto v : *$1) {
%#if PY_MAJOR_VERSION >= 3
    tensorflow::Safe_PyObjectPtr key(
        tensorflow::make_safe(PyUnicode_FromStringAndSize(v.first.c_str(), v.first.size())));
%#else
    tensorflow::Safe_PyObjectPtr key(
        tensorflow::make_safe(PyString_FromStringAndSize(v.first.c_str(), v.first.size())));
%#endif
    if (!key) {
      SWIG_fail;
    }
%#if PY_MAJOR_VERSION >= 3
    tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyLong_FromLong(v.second)));
%#else
    tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyInt_FromLong(v.second)));
%#endif
    if (!value) {
      SWIG_fail;
    }
    if (PyDict_SetItem(output_map.get(), key.get(), value.get()) == -1) {
      SWIG_fail;
    } else {
      key.release();
      value.release();
    }
  }

  $result = output_map.release();
}

%{
static PyObject* CheckpointReader_GetTensor(
    tensorflow::checkpoint::CheckpointReader* reader,
    const string& name,
    TF_Status* status) {
  PyObject* py_obj = Py_None;
  std::unique_ptr<tensorflow::Tensor> tensor;
  reader->GetTensor(name, &tensor, status);
  if (TF_GetCode(status) == TF_OK) {
    tensorflow::Status s =
        tensorflow::TensorToNdarray(*tensor.get(), &py_obj);
    if (!s.ok()) {
      Set_TF_Status_from_Status(status, s);
    }
  }
  return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_obj));
}
%}

// Wrap this function.
PyObject* CheckpointReader_GetTensor(
    tensorflow::checkpoint::CheckpointReader* reader,
    const string& name,
    TF_Status* status);

%ignoreall

%unignore tensorflow;
%unignore tensorflow::checkpoint;
%unignore tensorflow::checkpoint::CheckpointReader;
%unignore tensorflow::checkpoint::CheckpointReader::CheckpointReader;
%unignore tensorflow::checkpoint::CheckpointReader::~CheckpointReader;
%rename("debug_string") tensorflow::checkpoint::CheckpointReader::DebugString;
%rename("get_variable_to_shape_map") tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap;
%rename("_GetVariableToDataTypeMap") tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap;
%rename("_HasTensor") tensorflow::checkpoint::CheckpointReader::HasTensor;
%unignore CheckpointReader_GetTensor;

%extend tensorflow::checkpoint::CheckpointReader {
%insert("python") %{
  def get_variable_to_dtype_map(self):
    from tensorflow.python.framework import dtypes
    return {name: dtypes.DType(type_enum)
            for name, type_enum in self._GetVariableToDataTypeMap().items()}

  def has_tensor(self, tensor_str):
    from tensorflow.python.util import compat
    return self._HasTensor(compat.as_bytes(tensor_str))

  def get_tensor(self, tensor_str):
    from tensorflow.python.util import compat

    return CheckpointReader_GetTensor(self, compat.as_bytes(tensor_str))
%}
}

%insert("python") %{
def NewCheckpointReader(filepattern):
  from tensorflow.python.util import compat
  return CheckpointReader(compat.as_bytes(filepattern))

NewCheckpointReader._tf_api_names_v1 = ['train.NewCheckpointReader']
%}

%include "tensorflow/c/checkpoint_reader.h"
%unignoreall
tensorflow/python/util/py_checkpoint_reader_wrapper.cc (new file, 150 lines)
@@ -0,0 +1,150 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Disallow Numpy 1.7 deprecated symbols.
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION

#include "numpy/arrayobject.h"
#include "numpy/ufuncobject.h"
#include "include/pybind11/chrono.h"
#include "include/pybind11/complex.h"
#include "include/pybind11/functional.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_ptr.h"

namespace py = pybind11;

// TODO(amitpatankar): Move the custom type casters to separate common header
// only libraries.

namespace pybind11 {
namespace detail {

/* This is a custom type caster for the TensorShape object. For more
 * documentation please refer to this link:
 * https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html#custom-type-casters
 * The PyCheckpointReader methods sometimes return the `TensorShape` object
 * and the `DataType` object as outputs. This custom type caster helps Python
 * handle it's conversion from C++ to Python. Since we do not accept these
 * classes as arguments from Python, it is not necessary to define the `load`
 * function to cast the object from Python to a C++ object.
 */

template <>
struct type_caster<tensorflow::TensorShape> {
 public:
  PYBIND11_TYPE_CASTER(tensorflow::TensorShape, _("tensorflow::TensorShape"));

  static handle cast(const tensorflow::TensorShape& src,
                     return_value_policy unused_policy, handle unused_handle) {
    // TODO(amitpatankar): Simplify handling TensorShape as output later.
    size_t dims = src.dims();
    tensorflow::Safe_PyObjectPtr value(PyList_New(dims));
    for (size_t i = 0; i < dims; ++i) {
#if PY_MAJOR_VERSION >= 3
      tensorflow::Safe_PyObjectPtr dim_value(
          tensorflow::make_safe(PyLong_FromLong(src.dim_size(i))));
#else
      tensorflow::Safe_PyObjectPtr dim_value(
          tensorflow::make_safe(PyInt_FromLong(src.dim_size(i))));
#endif
      PyList_SET_ITEM(value.get(), i, dim_value.release());
    }

    return value.release();
  }
};

template <>
struct type_caster<tensorflow::DataType> {
 public:
  PYBIND11_TYPE_CASTER(tensorflow::DataType, _("tensorflow::DataType"));

  static handle cast(const tensorflow::DataType& src,
                     return_value_policy unused_policy, handle unused_handle) {
#if PY_MAJOR_VERSION >= 3
    tensorflow::Safe_PyObjectPtr value(
        tensorflow::make_safe(PyLong_FromLong(src)));
#else
    tensorflow::Safe_PyObjectPtr value(
        tensorflow::make_safe(PyInt_FromLong(src)));
#endif
    return value.release();
  }
};

}  // namespace detail
}  // namespace pybind11

namespace tensorflow {

static py::object CheckpointReader_GetTensor(
    tensorflow::checkpoint::CheckpointReader* reader, const string& name) {
  Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
  PyObject* py_obj = Py_None;
  std::unique_ptr<tensorflow::Tensor> tensor;
  reader->GetTensor(name, &tensor, status.get());

  // Error handling if unable to get Tensor.
  tensorflow::MaybeRaiseFromTFStatus(status.get());

  tensorflow::MaybeRaiseFromStatus(
      tensorflow::TensorToNdarray(*tensor, &py_obj));

  return tensorflow::pyo_or_throw(
      PyArray_Return(reinterpret_cast<PyArrayObject*>(py_obj)));
}

}  // namespace tensorflow

PYBIND11_MODULE(_pywrap_checkpoint_reader, m) {
  // Initialization code to use numpy types in the type casters.
  import_array1();
  py::class_<tensorflow::checkpoint::CheckpointReader> checkpoint_reader_class(
      m, "CheckpointReader");
  checkpoint_reader_class
      .def(py::init([](const std::string& filename) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        // pybind11 support smart pointers and will own freeing the memory when
        // complete.
        // https://pybind11.readthedocs.io/en/master/advanced/smart_ptrs.html#std-unique-ptr
        auto checkpoint =
            std::make_unique<tensorflow::checkpoint::CheckpointReader>(
                filename, status.get());
        tensorflow::MaybeRaiseFromTFStatus(status.get());
        return checkpoint;
      }))
      .def("debug_string",
           [](tensorflow::checkpoint::CheckpointReader& self) {
             return py::bytes(self.DebugString());
           })
      .def("get_variable_to_shape_map",
           &tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap)
      .def("_GetVariableToDataTypeMap",
           &tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap)
      .def("_HasTensor", &tensorflow::checkpoint::CheckpointReader::HasTensor)
      .def_static("CheckpointReader_GetTensor",
                  &tensorflow::CheckpointReader_GetTensor);
};
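For orientation, the extension module defined above is exactly what py_checkpoint_reader.py imports; a hedged sketch of the raw surface it exposes (names are taken from the .def() calls above, the path is hypothetical), which also shows why the Python shim adds the friendlier bytes-handling wrappers:

    from tensorflow.python._pywrap_checkpoint_reader import CheckpointReader
    from tensorflow.python.util import compat

    reader = CheckpointReader(compat.as_bytes('/tmp/model.ckpt'))  # hypothetical path
    print(reader.debug_string())                     # bytes, via py::bytes(DebugString())
    print(reader._HasTensor(compat.as_bytes('v0')))  # raw method; has_tensor() wraps it
    # get_tensor, has_tensor and get_variable_to_dtype_map are attached in
    # py_checkpoint_reader.py rather than defined in this C++ module.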
@@ -107,3 +107,21 @@ tensorflow::profiler::PythonTraceMe::IsEnabled
[traceme_recorder] # traceme
tensorflow::profiler::internal::g_trace_level

+[checkpoint_reader] # py_checkpoint_reader
+tensorflow::checkpoint::CheckpointReader
+tensorflow::checkpoint::CheckpointReader::Init
+tensorflow::checkpoint::CheckpointReader::DebugString
+tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap
+tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap
+tensorflow::checkpoint::CheckpointReader::GetTensor
+tensorflow::checkpoint::CheckpointReader::HasTensor
+
+[tensor_bundle] # py_checkpoint_reader
+tensorflow::BundleReader::BundleReader
+tensorflow::BundleReader::~BundleReader
+
+[ndarray_tensor] # py_checkpoint_reader
+tensorflow::TensorToNdarray
+
+[safe_ptr] # py_checkpoint_reader
+tensorflow::detail::PyDecrefDeleter