Export the checkpoint reader classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 279101529 Change-Id: I25502ed3d3718499abca41f5614681f41e4c7199
This commit is contained in:
parent
8977382c53
commit
b66e4e833c
@ -293,6 +293,15 @@ exports_files(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "checkpoint_reader_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"checkpoint_reader.h",
|
||||||
|
"tf_status_helper.h",
|
||||||
|
],
|
||||||
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cuda_library(
|
tf_cuda_library(
|
||||||
name = "tf_status_helper",
|
name = "tf_status_helper",
|
||||||
srcs = ["tf_status_helper.cc"],
|
srcs = ["tf_status_helper.cc"],
|
||||||
|
@ -98,6 +98,7 @@ py_library(
|
|||||||
"//third_party/py/tensorflow_core:__subpackages__",
|
"//third_party/py/tensorflow_core:__subpackages__",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":_pywrap_checkpoint_reader",
|
||||||
":_pywrap_events_writer",
|
":_pywrap_events_writer",
|
||||||
":_pywrap_kernel_registry",
|
":_pywrap_kernel_registry",
|
||||||
":_pywrap_py_exception_registry",
|
":_pywrap_py_exception_registry",
|
||||||
@ -629,6 +630,32 @@ tf_python_pybind_extension(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_checkpoint_reader",
|
||||||
|
srcs = ["util/py_checkpoint_reader_wrapper.cc"],
|
||||||
|
hdrs = [
|
||||||
|
"lib/core/ndarray_tensor.h",
|
||||||
|
"lib/core/safe_ptr.h",
|
||||||
|
":py_exception_registry_hdr",
|
||||||
|
"//tensorflow/c:checkpoint_reader_hdrs",
|
||||||
|
"//tensorflow/c:headers",
|
||||||
|
"//tensorflow/c/eager:headers",
|
||||||
|
],
|
||||||
|
module_name = "_pywrap_checkpoint_reader",
|
||||||
|
deps = [
|
||||||
|
":pybind11_lib",
|
||||||
|
":pybind11_status",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:op_gen_lib",
|
||||||
|
"//tensorflow/core:protos_all",
|
||||||
|
"//tensorflow/core/util/tensor_bundle:tensor_bundle_headers_lib",
|
||||||
|
"//third_party/py/numpy:headers",
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "py_exception_registry_hdr",
|
name = "py_exception_registry_hdr",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -996,6 +1023,7 @@ py_library(
|
|||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":_pywrap_checkpoint_reader",
|
||||||
":_pywrap_debug_events_writer",
|
":_pywrap_debug_events_writer",
|
||||||
":_pywrap_events_writer",
|
":_pywrap_events_writer",
|
||||||
":_pywrap_kernel_registry",
|
":_pywrap_kernel_registry",
|
||||||
@ -4730,6 +4758,7 @@ py_library(
|
|||||||
":math_ops",
|
":math_ops",
|
||||||
":mixed_precision",
|
":mixed_precision",
|
||||||
":platform",
|
":platform",
|
||||||
|
":py_checkpoint_reader",
|
||||||
":pywrap_tensorflow",
|
":pywrap_tensorflow",
|
||||||
":random_ops",
|
":random_ops",
|
||||||
":resource_variable_ops",
|
":resource_variable_ops",
|
||||||
@ -4785,6 +4814,17 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "py_checkpoint_reader",
|
||||||
|
srcs = ["training/py_checkpoint_reader.py"],
|
||||||
|
deps = [
|
||||||
|
":_pywrap_checkpoint_reader",
|
||||||
|
":dtypes",
|
||||||
|
":errors",
|
||||||
|
":util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "checkpoint_management",
|
name = "checkpoint_management",
|
||||||
srcs = ["training/checkpoint_management.py"],
|
srcs = ["training/checkpoint_management.py"],
|
||||||
@ -5284,7 +5324,6 @@ tf_py_wrap_cc(
|
|||||||
"lib/io/py_record_writer.i",
|
"lib/io/py_record_writer.i",
|
||||||
"platform/base.i",
|
"platform/base.i",
|
||||||
"pywrap_tfe.i",
|
"pywrap_tfe.i",
|
||||||
"util/py_checkpoint_reader.i",
|
|
||||||
"//tensorflow/compiler/mlir/python:mlir.i",
|
"//tensorflow/compiler/mlir/python:mlir.i",
|
||||||
],
|
],
|
||||||
# add win_def_file for pywrap_tensorflow
|
# add win_def_file for pywrap_tensorflow
|
||||||
@ -5336,6 +5375,7 @@ tf_py_wrap_cc(
|
|||||||
"//tensorflow/tools/graph_transforms:transform_graph_lib",
|
"//tensorflow/tools/graph_transforms:transform_graph_lib",
|
||||||
"//tensorflow/lite/toco/python:toco_python_api",
|
"//tensorflow/lite/toco/python:toco_python_api",
|
||||||
"//tensorflow/python/eager:pywrap_tfe_lib",
|
"//tensorflow/python/eager:pywrap_tfe_lib",
|
||||||
|
"//tensorflow/core/util/tensor_bundle:tensor_bundle",
|
||||||
] + (tf_additional_lib_deps() +
|
] + (tf_additional_lib_deps() +
|
||||||
tf_additional_plugin_deps()) + if_ngraph([
|
tf_additional_plugin_deps()) + if_ngraph([
|
||||||
"@ngraph_tf//:ngraph_tf",
|
"@ngraph_tf//:ngraph_tf",
|
||||||
@ -5391,9 +5431,14 @@ genrule(
|
|||||||
"//tensorflow/core/profiler/internal:python_traceme", # traceme
|
"//tensorflow/core/profiler/internal:python_traceme", # traceme
|
||||||
"//tensorflow/core/profiler/internal:traceme_recorder", # traceme
|
"//tensorflow/core/profiler/internal:traceme_recorder", # traceme
|
||||||
":py_exception_registry", # py_exception_registry
|
":py_exception_registry", # py_exception_registry
|
||||||
":kernel_registry",
|
":kernel_registry", # kernel_registry
|
||||||
"//tensorflow/lite/toco/python:toco_python_api", # toco
|
"//tensorflow/lite/toco/python:toco_python_api", # toco
|
||||||
"//tensorflow/tools/graph_transforms:transform_graph_lib", # transform_graph
|
"//tensorflow/tools/graph_transforms:transform_graph_lib", # transform_graph
|
||||||
|
"//tensorflow/c:checkpoint_reader", # checkpoint_reader
|
||||||
|
":ndarray_tensor", # checkpoint_reader
|
||||||
|
":numpy_lib", # checkpoint_reader
|
||||||
|
":safe_ptr", # checkpoint_reader
|
||||||
|
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
|
||||||
],
|
],
|
||||||
outs = ["pybind_symbol_target_libs_file.txt"],
|
outs = ["pybind_symbol_target_libs_file.txt"],
|
||||||
cmd = select({
|
cmd = select({
|
||||||
@ -6211,7 +6256,6 @@ tf_py_test(
|
|||||||
":io_ops",
|
":io_ops",
|
||||||
":partitioned_variables",
|
":partitioned_variables",
|
||||||
":platform",
|
":platform",
|
||||||
":pywrap_tensorflow",
|
|
||||||
":resource_variable_ops",
|
":resource_variable_ops",
|
||||||
":state_ops",
|
":state_ops",
|
||||||
":training",
|
":training",
|
||||||
|
@ -30,7 +30,6 @@ import numpy as np
|
|||||||
import six
|
import six
|
||||||
from six.moves import zip # pylint: disable=redefined-builtin
|
from six.moves import zip # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
@ -54,6 +53,7 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
|||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import checkpoint_management
|
from tensorflow.python.training import checkpoint_management
|
||||||
|
from tensorflow.python.training import py_checkpoint_reader
|
||||||
from tensorflow.python.training.tracking import base as trackable
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
from tensorflow.python.training.tracking import data_structures
|
from tensorflow.python.training.tracking import data_structures
|
||||||
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
|
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
|
||||||
@ -1176,7 +1176,7 @@ class Network(base_layer.Layer):
|
|||||||
save_format = 'h5'
|
save_format = 'h5'
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
pywrap_tensorflow.NewCheckpointReader(filepath)
|
py_checkpoint_reader.NewCheckpointReader(filepath)
|
||||||
save_format = 'tf'
|
save_format = 'tf'
|
||||||
except errors_impl.DataLossError:
|
except errors_impl.DataLossError:
|
||||||
# The checkpoint is not readable in TensorFlow format. Try HDF5.
|
# The checkpoint is not readable in TensorFlow format. Try HDF5.
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/python/lib/core/bfloat16.h"
|
#include "tensorflow/python/lib/core/bfloat16.h"
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||||
|
#include "tensorflow/python/lib/core/numpy.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
@ -527,8 +528,8 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
|
|||||||
size_t size = 0;
|
size_t size = 0;
|
||||||
void* encoded = nullptr;
|
void* encoded = nullptr;
|
||||||
TF_RETURN_IF_ERROR(EncodePyBytesArray(array, nelems, &size, &encoded));
|
TF_RETURN_IF_ERROR(EncodePyBytesArray(array, nelems, &size, &encoded));
|
||||||
*out_tensor =
|
*out_tensor = make_safe(TF_NewTensor(
|
||||||
make_safe(TF_NewTensor(dtype, dims.data(), dims.size(), encoded, size,
|
dtype, dims.data(), dims.size(), encoded, size,
|
||||||
[](void* data, size_t len, void* arg) {
|
[](void* data, size_t len, void* arg) {
|
||||||
delete[] reinterpret_cast<char*>(data);
|
delete[] reinterpret_cast<char*>(data);
|
||||||
},
|
},
|
||||||
|
@ -16,9 +16,6 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_
|
#ifndef TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_
|
||||||
#define TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_
|
#define TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_
|
||||||
|
|
||||||
// Must be included first.
|
|
||||||
#include "tensorflow/python/lib/core/numpy.h"
|
|
||||||
|
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
@ -13,7 +13,16 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
%include "tensorflow/python/lib/core/strings.i"
|
||||||
%include "tensorflow/python/platform/base.i"
|
%include "tensorflow/python/platform/base.i"
|
||||||
|
|
||||||
|
%{
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
||||||
|
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||||
|
%}
|
||||||
|
|
||||||
|
|
||||||
%include "tensorflow/c/tf_datatype.h"
|
%include "tensorflow/c/tf_datatype.h"
|
||||||
%include "tensorflow/c/tf_status.h"
|
%include "tensorflow/c/tf_status.h"
|
||||||
|
|
||||||
|
@ -17,8 +17,6 @@ limitations under the License.
|
|||||||
* The includes are intentionally not alphabetically sorted, as the order of
|
* The includes are intentionally not alphabetically sorted, as the order of
|
||||||
* includes follows dependency order */
|
* includes follows dependency order */
|
||||||
|
|
||||||
%include "tensorflow/python/util/py_checkpoint_reader.i"
|
|
||||||
|
|
||||||
%include "tensorflow/python/pywrap_tfe.i"
|
%include "tensorflow/python/pywrap_tfe.i"
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,7 +46,6 @@ from google.protobuf import text_format
|
|||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.core.protobuf import saver_pb2
|
from tensorflow.core.protobuf import saver_pb2
|
||||||
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
|
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.framework import graph_util
|
from tensorflow.python.framework import graph_util
|
||||||
from tensorflow.python.framework import importer
|
from tensorflow.python.framework import importer
|
||||||
@ -56,6 +55,7 @@ from tensorflow.python.saved_model import loader
|
|||||||
from tensorflow.python.saved_model import tag_constants
|
from tensorflow.python.saved_model import tag_constants
|
||||||
from tensorflow.python.tools import saved_model_utils
|
from tensorflow.python.tools import saved_model_utils
|
||||||
from tensorflow.python.training import checkpoint_management
|
from tensorflow.python.training import checkpoint_management
|
||||||
|
from tensorflow.python.training import py_checkpoint_reader
|
||||||
from tensorflow.python.training import saver as saver_lib
|
from tensorflow.python.training import saver as saver_lib
|
||||||
|
|
||||||
|
|
||||||
@ -161,7 +161,7 @@ def freeze_graph_with_def_protos(input_graph_def,
|
|||||||
loader.load(sess, saved_model_tags, input_saved_model_dir)
|
loader.load(sess, saved_model_tags, input_saved_model_dir)
|
||||||
else:
|
else:
|
||||||
var_list = {}
|
var_list = {}
|
||||||
reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
|
reader = py_checkpoint_reader.NewCheckpointReader(input_checkpoint)
|
||||||
var_to_shape_map = reader.get_variable_to_shape_map()
|
var_to_shape_map = reader.get_variable_to_shape_map()
|
||||||
|
|
||||||
# List of all partition variables. Because the condition is heuristic
|
# List of all partition variables. Because the condition is heuristic
|
||||||
|
@ -23,9 +23,9 @@ import sys
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
from tensorflow.python.platform import app
|
from tensorflow.python.platform import app
|
||||||
from tensorflow.python.platform import flags
|
from tensorflow.python.platform import flags
|
||||||
|
from tensorflow.python.training import py_checkpoint_reader
|
||||||
|
|
||||||
FLAGS = None
|
FLAGS = None
|
||||||
|
|
||||||
@ -72,7 +72,7 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
|
|||||||
count_exclude_pattern: Regex string, pattern to exclude tensors when count.
|
count_exclude_pattern: Regex string, pattern to exclude tensors when count.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
|
reader = py_checkpoint_reader.NewCheckpointReader(file_name)
|
||||||
if all_tensors or all_tensor_names:
|
if all_tensors or all_tensor_names:
|
||||||
var_to_shape_map = reader.get_variable_to_shape_map()
|
var_to_shape_map = reader.get_variable_to_shape_map()
|
||||||
var_to_dtype_map = reader.get_variable_to_dtype_map()
|
var_to_dtype_map = reader.get_variable_to_dtype_map()
|
||||||
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
|||||||
import time
|
import time
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import io_ops
|
from tensorflow.python.ops import io_ops
|
||||||
@ -31,6 +30,7 @@ from tensorflow.python.ops import variables
|
|||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import checkpoint_management
|
from tensorflow.python.training import checkpoint_management
|
||||||
|
from tensorflow.python.training import py_checkpoint_reader
|
||||||
from tensorflow.python.training.saving import saveable_object_util
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ def load_checkpoint(ckpt_dir_or_file):
|
|||||||
if filename is None:
|
if filename is None:
|
||||||
raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
|
raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
|
||||||
"given directory %s" % ckpt_dir_or_file)
|
"given directory %s" % ckpt_dir_or_file)
|
||||||
return pywrap_tensorflow.NewCheckpointReader(filename)
|
return py_checkpoint_reader.NewCheckpointReader(filename)
|
||||||
|
|
||||||
|
|
||||||
@tf_export("train.load_variable")
|
@tf_export("train.load_variable")
|
||||||
|
99
tensorflow/python/training/py_checkpoint_reader.py
Normal file
99
tensorflow/python/training/py_checkpoint_reader.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Extending CheckpointReader for TensorFlow."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python._pywrap_checkpoint_reader import CheckpointReader
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors_impl
|
||||||
|
from tensorflow.python.util import compat
|
||||||
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
|
def error_translator(e):
|
||||||
|
"""Translate the tensor_slice_reader.cc errors."""
|
||||||
|
# TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
|
||||||
|
# issue with throwing python exceptions from C++.
|
||||||
|
error_message = str(e)
|
||||||
|
if 'not found in checkpoint' in error_message or (
|
||||||
|
'Failed to find any '
|
||||||
|
'matching files for') in error_message:
|
||||||
|
raise errors_impl.NotFoundError(None, None, error_message)
|
||||||
|
elif 'Sliced checkpoints are not supported' in error_message or (
|
||||||
|
'Data type '
|
||||||
|
'not '
|
||||||
|
'supported') in error_message:
|
||||||
|
raise errors_impl.UnimplementedError(None, None, error_message)
|
||||||
|
elif 'Failed to get matching files on' in error_message:
|
||||||
|
raise errors_impl.InvalidArgumentError(None, None, error_message)
|
||||||
|
elif 'Unable to open table file' in error_message:
|
||||||
|
raise errors_impl.DataLossError(None, None, error_message)
|
||||||
|
elif 'Failed to find the saved tensor slices' in error_message:
|
||||||
|
raise errors_impl.InternalError(None, None, error_message)
|
||||||
|
else:
|
||||||
|
raise errors_impl.OpError(None, None, error_message, errors_impl.UNKNOWN)
|
||||||
|
|
||||||
|
|
||||||
|
def get_variable_to_dtype_map(self):
|
||||||
|
return {
|
||||||
|
name: dtypes.DType(type_enum)
|
||||||
|
for name, type_enum in self._GetVariableToDataTypeMap().items() # pylint: disable=protected-access
|
||||||
|
}
|
||||||
|
|
||||||
|
CheckpointReader.get_variable_to_dtype_map = get_variable_to_dtype_map
|
||||||
|
|
||||||
|
|
||||||
|
def has_tensor(self, tensor_str):
|
||||||
|
return self._HasTensor(compat.as_bytes(tensor_str)) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
CheckpointReader.has_tensor = has_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def get_tensor(self, tensor_str):
|
||||||
|
"""Get the tensor from the Checkpoint object."""
|
||||||
|
try:
|
||||||
|
return CheckpointReader.CheckpointReader_GetTensor(
|
||||||
|
self, compat.as_bytes(tensor_str))
|
||||||
|
# TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
|
||||||
|
# issue with throwing python exceptions from C++.
|
||||||
|
except RuntimeError as e:
|
||||||
|
error_translator(e)
|
||||||
|
|
||||||
|
|
||||||
|
CheckpointReader.get_tensor = get_tensor
|
||||||
|
|
||||||
|
|
||||||
|
# Disable invalid name to keep backwards compatibility with that function.
|
||||||
|
# It was previously exported from py_checkpoint_reader.i which did not conform
|
||||||
|
# to pylint checks.
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
@tf_export(v1=['train.NewCheckpointReader'])
|
||||||
|
def NewCheckpointReader(filepattern):
|
||||||
|
"""A function that returns a CheckPointReader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filepattern: The filename.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A CheckpointReader object.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return CheckpointReader(compat.as_bytes(filepattern))
|
||||||
|
# TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
|
||||||
|
# issue with throwing python exceptions from C++.
|
||||||
|
except RuntimeError as e:
|
||||||
|
error_translator(e)
|
@ -32,7 +32,6 @@ import numpy as np
|
|||||||
from tensorflow.core.protobuf import meta_graph_pb2
|
from tensorflow.core.protobuf import meta_graph_pb2
|
||||||
from tensorflow.core.protobuf import saver_pb2
|
from tensorflow.core.protobuf import saver_pb2
|
||||||
from tensorflow.core.protobuf import trackable_object_graph_pb2
|
from tensorflow.core.protobuf import trackable_object_graph_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -49,6 +48,7 @@ from tensorflow.python.ops import variables
|
|||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import checkpoint_management
|
from tensorflow.python.training import checkpoint_management
|
||||||
|
from tensorflow.python.training import py_checkpoint_reader
|
||||||
from tensorflow.python.training import training_util
|
from tensorflow.python.training import training_util
|
||||||
from tensorflow.python.training.saving import saveable_object
|
from tensorflow.python.training.saving import saveable_object
|
||||||
from tensorflow.python.training.saving import saveable_object_util
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
@ -1614,7 +1614,7 @@ def object_graph_key_mapping(checkpoint_path):
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary mapping tensor names to checkpoint keys.
|
Dictionary mapping tensor names to checkpoint keys.
|
||||||
"""
|
"""
|
||||||
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
|
reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
|
||||||
object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
|
object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
|
||||||
object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
|
object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
|
||||||
object_graph_proto.ParseFromString(object_graph_string)
|
object_graph_proto.ParseFromString(object_graph_string)
|
||||||
|
@ -34,7 +34,6 @@ from tensorflow.core.protobuf import meta_graph_pb2
|
|||||||
from tensorflow.core.protobuf import queue_runner_pb2
|
from tensorflow.core.protobuf import queue_runner_pb2
|
||||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||||
from tensorflow.core.protobuf import saver_pb2
|
from tensorflow.core.protobuf import saver_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -69,6 +68,7 @@ from tensorflow.python.summary import summary
|
|||||||
from tensorflow.python.training import adam
|
from tensorflow.python.training import adam
|
||||||
from tensorflow.python.training import checkpoint_management
|
from tensorflow.python.training import checkpoint_management
|
||||||
from tensorflow.python.training import gradient_descent
|
from tensorflow.python.training import gradient_descent
|
||||||
|
from tensorflow.python.training import py_checkpoint_reader
|
||||||
from tensorflow.python.training import queue_runner_impl
|
from tensorflow.python.training import queue_runner_impl
|
||||||
from tensorflow.python.training import saver as saver_module
|
from tensorflow.python.training import saver as saver_module
|
||||||
from tensorflow.python.training import saver_test_utils
|
from tensorflow.python.training import saver_test_utils
|
||||||
@ -2468,7 +2468,7 @@ class CheckpointReaderTest(test.TestCase):
|
|||||||
save.save(sess, save_path)
|
save.save(sess, save_path)
|
||||||
|
|
||||||
# Creates a reader.
|
# Creates a reader.
|
||||||
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
|
reader = py_checkpoint_reader.NewCheckpointReader(save_path)
|
||||||
# Verifies that the tensors exist.
|
# Verifies that the tensors exist.
|
||||||
self.assertTrue(reader.has_tensor("v0"))
|
self.assertTrue(reader.has_tensor("v0"))
|
||||||
self.assertTrue(reader.has_tensor("v1"))
|
self.assertTrue(reader.has_tensor("v1"))
|
||||||
@ -2493,7 +2493,7 @@ class CheckpointReaderTest(test.TestCase):
|
|||||||
def testNonexistentPath(self):
|
def testNonexistentPath(self):
|
||||||
with self.assertRaisesRegexp(errors.NotFoundError,
|
with self.assertRaisesRegexp(errors.NotFoundError,
|
||||||
"Unsuccessful TensorSliceReader"):
|
"Unsuccessful TensorSliceReader"):
|
||||||
pywrap_tensorflow.NewCheckpointReader("non-existent")
|
py_checkpoint_reader.NewCheckpointReader("non-existent")
|
||||||
|
|
||||||
|
|
||||||
class CheckpointReaderForV2Test(CheckpointReaderTest):
|
class CheckpointReaderForV2Test(CheckpointReaderTest):
|
||||||
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.module import module
|
from tensorflow.python.module import module
|
||||||
@ -29,6 +28,7 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import gen_io_ops
|
from tensorflow.python.ops import gen_io_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.training import py_checkpoint_reader
|
||||||
from tensorflow.python.training.saving import saveable_object
|
from tensorflow.python.training.saving import saveable_object
|
||||||
from tensorflow.python.training.tracking import base
|
from tensorflow.python.training.tracking import base
|
||||||
from tensorflow.python.training.tracking import util
|
from tensorflow.python.training.tracking import util
|
||||||
@ -116,7 +116,7 @@ class SavingBenchmarks(test.Benchmark):
|
|||||||
|
|
||||||
def benchmark_raw_restore(self):
|
def benchmark_raw_restore(self):
|
||||||
checkpoint_path = _save_checkpoint()
|
checkpoint_path = _save_checkpoint()
|
||||||
all_names, all_dtypes = zip(*pywrap_tensorflow.NewCheckpointReader(
|
all_names, all_dtypes = zip(*py_checkpoint_reader.NewCheckpointReader(
|
||||||
checkpoint_path).get_variable_to_dtype_map().items())
|
checkpoint_path).get_variable_to_dtype_map().items())
|
||||||
|
|
||||||
def _call_restore_v2():
|
def _call_restore_v2():
|
||||||
|
@ -25,7 +25,6 @@ import weakref
|
|||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.core.protobuf import trackable_object_graph_pb2
|
from tensorflow.core.protobuf import trackable_object_graph_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
from tensorflow.python.client import session as session_lib
|
from tensorflow.python.client import session as session_lib
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
@ -43,6 +42,7 @@ from tensorflow.python.ops import variable_scope
|
|||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import checkpoint_management
|
from tensorflow.python.training import checkpoint_management
|
||||||
|
from tensorflow.python.training import py_checkpoint_reader
|
||||||
from tensorflow.python.training import saver as v1_saver_lib
|
from tensorflow.python.training import saver as v1_saver_lib
|
||||||
from tensorflow.python.training.saving import functional_saver
|
from tensorflow.python.training.saving import functional_saver
|
||||||
from tensorflow.python.training.saving import saveable_object_util
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
@ -200,7 +200,7 @@ class _CheckpointRestoreCoordinator(object):
|
|||||||
self.all_python_objects = object_identity.ObjectIdentityWeakSet()
|
self.all_python_objects = object_identity.ObjectIdentityWeakSet()
|
||||||
self.save_path_tensor = save_path_tensor
|
self.save_path_tensor = save_path_tensor
|
||||||
self.save_path_string = save_path
|
self.save_path_string = save_path
|
||||||
self.dtype_map = pywrap_tensorflow.NewCheckpointReader(
|
self.dtype_map = py_checkpoint_reader.NewCheckpointReader(
|
||||||
save_path).get_variable_to_dtype_map()
|
save_path).get_variable_to_dtype_map()
|
||||||
# A NewCheckpointReader for the most recent checkpoint, for streaming Python
|
# A NewCheckpointReader for the most recent checkpoint, for streaming Python
|
||||||
# state restoration.
|
# state restoration.
|
||||||
@ -272,7 +272,7 @@ class _CheckpointRestoreCoordinator(object):
|
|||||||
if reader is None:
|
if reader is None:
|
||||||
# Lazily create the NewCheckpointReader, since this requires file access
|
# Lazily create the NewCheckpointReader, since this requires file access
|
||||||
# and we may not have any Python saveables.
|
# and we may not have any Python saveables.
|
||||||
reader = pywrap_tensorflow.NewCheckpointReader(self.save_path_string)
|
reader = py_checkpoint_reader.NewCheckpointReader(self.save_path_string)
|
||||||
spec_names = [spec.name for spec in saveable.specs]
|
spec_names = [spec.name for spec in saveable.specs]
|
||||||
saveable.python_restore([reader.get_tensor(name) for name in spec_names])
|
saveable.python_restore([reader.get_tensor(name) for name in spec_names])
|
||||||
|
|
||||||
@ -462,7 +462,7 @@ def object_metadata(save_path):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If an object graph was not found in the checkpoint.
|
ValueError: If an object graph was not found in the checkpoint.
|
||||||
"""
|
"""
|
||||||
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
|
reader = py_checkpoint_reader.NewCheckpointReader(save_path)
|
||||||
try:
|
try:
|
||||||
object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
|
object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
|
||||||
except errors_impl.NotFoundError:
|
except errors_impl.NotFoundError:
|
||||||
@ -1238,7 +1238,7 @@ class TrackableSaver(object):
|
|||||||
"""
|
"""
|
||||||
if save_path is None:
|
if save_path is None:
|
||||||
return InitializationOnlyStatus(self._graph_view, ops.uid())
|
return InitializationOnlyStatus(self._graph_view, ops.uid())
|
||||||
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
|
reader = py_checkpoint_reader.NewCheckpointReader(save_path)
|
||||||
graph_building = not context.executing_eagerly()
|
graph_building = not context.executing_eagerly()
|
||||||
if graph_building:
|
if graph_building:
|
||||||
dtype_map = None
|
dtype_map = None
|
||||||
|
@ -110,7 +110,7 @@ from tensorflow.python.training.training_util import create_global_step
|
|||||||
from tensorflow.python.training.training_util import get_or_create_global_step
|
from tensorflow.python.training.training_util import get_or_create_global_step
|
||||||
from tensorflow.python.training.warm_starting_util import VocabInfo
|
from tensorflow.python.training.warm_starting_util import VocabInfo
|
||||||
from tensorflow.python.training.warm_starting_util import warm_start
|
from tensorflow.python.training.warm_starting_util import warm_start
|
||||||
from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
|
from tensorflow.python.training.py_checkpoint_reader import NewCheckpointReader
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
# pylint: disable=wildcard-import
|
# pylint: disable=wildcard-import
|
||||||
@ -145,4 +145,3 @@ tf_export(v1=["train.SaverDef"])(SaverDef)
|
|||||||
tf_export("train.SequenceExample")(SequenceExample)
|
tf_export("train.SequenceExample")(SequenceExample)
|
||||||
tf_export("train.ServerDef")(ServerDef)
|
tf_export("train.ServerDef")(ServerDef)
|
||||||
# pylint: enable=undefined-variable
|
# pylint: enable=undefined-variable
|
||||||
|
|
||||||
|
@ -335,14 +335,16 @@ def _get_grouped_variables(vars_to_warm_start):
|
|||||||
ValueError: If vars_to_warm_start is not a string, `None`, a list of
|
ValueError: If vars_to_warm_start is not a string, `None`, a list of
|
||||||
`Variables`, or a list of strings.
|
`Variables`, or a list of strings.
|
||||||
"""
|
"""
|
||||||
if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None:
|
# TODO(b/143899805): Remove unicode checks when deprecating Python2.
|
||||||
|
if isinstance(vars_to_warm_start,
|
||||||
|
six.string_types) or vars_to_warm_start is None:
|
||||||
# Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
|
# Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
|
||||||
# everything (in TRAINABLE_VARIABLES) here.
|
# everything (in TRAINABLE_VARIABLES) here.
|
||||||
logging.info("Warm-starting variables only in TRAINABLE_VARIABLES.")
|
logging.info("Warm-starting variables only in TRAINABLE_VARIABLES.")
|
||||||
list_of_vars = ops.get_collection(
|
list_of_vars = ops.get_collection(
|
||||||
ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start)
|
ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start)
|
||||||
elif isinstance(vars_to_warm_start, list):
|
elif isinstance(vars_to_warm_start, list):
|
||||||
if all(isinstance(v, str) for v in vars_to_warm_start):
|
if all(isinstance(v, six.string_types) for v in vars_to_warm_start):
|
||||||
list_of_vars = []
|
list_of_vars = []
|
||||||
for v in vars_to_warm_start:
|
for v in vars_to_warm_start:
|
||||||
list_of_vars += ops.get_collection(
|
list_of_vars += ops.get_collection(
|
||||||
|
@ -1,167 +0,0 @@
|
|||||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
%include "tensorflow/python/lib/core/strings.i"
|
|
||||||
%include "tensorflow/python/platform/base.i"
|
|
||||||
|
|
||||||
%{
|
|
||||||
#include "tensorflow/c/checkpoint_reader.h"
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
|
||||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
|
||||||
%}
|
|
||||||
|
|
||||||
%typemap(out) const tensorflow::checkpoint::TensorSliceReader::VarToShapeMap& {
|
|
||||||
tensorflow::Safe_PyObjectPtr output_map(tensorflow::make_safe(PyDict_New()));
|
|
||||||
for (auto v : *$1) {
|
|
||||||
%#if PY_MAJOR_VERSION >= 3
|
|
||||||
tensorflow::Safe_PyObjectPtr key(
|
|
||||||
tensorflow::make_safe(PyUnicode_FromStringAndSize(v.first.c_str(),
|
|
||||||
v.first.size())));
|
|
||||||
%#else
|
|
||||||
tensorflow::Safe_PyObjectPtr key(
|
|
||||||
tensorflow::make_safe(PyString_FromStringAndSize(v.first.c_str(),
|
|
||||||
v.first.size())));
|
|
||||||
%#endif
|
|
||||||
if (!key) {
|
|
||||||
SWIG_fail;
|
|
||||||
}
|
|
||||||
size_t dims = v.second.dims();
|
|
||||||
tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyList_New(dims)));
|
|
||||||
if (!value) {
|
|
||||||
SWIG_fail;
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < dims; ++i) {
|
|
||||||
%#if PY_MAJOR_VERSION >= 3
|
|
||||||
tensorflow::Safe_PyObjectPtr dim_value(
|
|
||||||
tensorflow::make_safe(PyLong_FromLong(v.second.dim_size(i))));
|
|
||||||
%#else
|
|
||||||
tensorflow::Safe_PyObjectPtr dim_value(
|
|
||||||
tensorflow::make_safe(PyInt_FromLong(v.second.dim_size(i))));
|
|
||||||
%#endif
|
|
||||||
if (!dim_value) {
|
|
||||||
SWIG_fail;
|
|
||||||
}
|
|
||||||
PyList_SET_ITEM(value.get(), i, dim_value.release());
|
|
||||||
}
|
|
||||||
if (PyDict_SetItem(output_map.get(), key.get(), value.get()) == -1) {
|
|
||||||
SWIG_fail;
|
|
||||||
} else {
|
|
||||||
key.release();
|
|
||||||
value.release();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
$result = output_map.release();
|
|
||||||
}
|
|
||||||
|
|
||||||
%typemap(out) const tensorflow::checkpoint::TensorSliceReader::VarToDataTypeMap& {
|
|
||||||
tensorflow::Safe_PyObjectPtr output_map(tensorflow::make_safe(PyDict_New()));
|
|
||||||
for (auto v : *$1) {
|
|
||||||
%#if PY_MAJOR_VERSION >= 3
|
|
||||||
tensorflow::Safe_PyObjectPtr key(
|
|
||||||
tensorflow::make_safe(PyUnicode_FromStringAndSize(v.first.c_str(), v.first.size())));
|
|
||||||
%#else
|
|
||||||
tensorflow::Safe_PyObjectPtr key(
|
|
||||||
tensorflow::make_safe(PyString_FromStringAndSize(v.first.c_str(), v.first.size())));
|
|
||||||
%#endif
|
|
||||||
if (!key) {
|
|
||||||
SWIG_fail;
|
|
||||||
}
|
|
||||||
%#if PY_MAJOR_VERSION >= 3
|
|
||||||
tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyLong_FromLong(v.second)));
|
|
||||||
%#else
|
|
||||||
tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyInt_FromLong(v.second)));
|
|
||||||
%#endif
|
|
||||||
if (!value) {
|
|
||||||
SWIG_fail;
|
|
||||||
}
|
|
||||||
if (PyDict_SetItem(output_map.get(), key.get(), value.get()) == -1) {
|
|
||||||
SWIG_fail;
|
|
||||||
} else {
|
|
||||||
key.release();
|
|
||||||
value.release();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
$result = output_map.release();
|
|
||||||
}
|
|
||||||
|
|
||||||
%{
|
|
||||||
static PyObject* CheckpointReader_GetTensor(
|
|
||||||
tensorflow::checkpoint::CheckpointReader* reader,
|
|
||||||
const string& name,
|
|
||||||
TF_Status* status) {
|
|
||||||
PyObject* py_obj = Py_None;
|
|
||||||
std::unique_ptr<tensorflow::Tensor> tensor;
|
|
||||||
reader->GetTensor(name, &tensor, status);
|
|
||||||
if (TF_GetCode(status) == TF_OK) {
|
|
||||||
tensorflow::Status s =
|
|
||||||
tensorflow::TensorToNdarray(*tensor.get(), &py_obj);
|
|
||||||
if (!s.ok()) {
|
|
||||||
Set_TF_Status_from_Status(status, s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_obj));
|
|
||||||
}
|
|
||||||
%}
|
|
||||||
|
|
||||||
// Wrap this function.
|
|
||||||
PyObject* CheckpointReader_GetTensor(
|
|
||||||
tensorflow::checkpoint::CheckpointReader* reader,
|
|
||||||
const string& name,
|
|
||||||
TF_Status* status);
|
|
||||||
|
|
||||||
%ignoreall
|
|
||||||
|
|
||||||
%unignore tensorflow;
|
|
||||||
%unignore tensorflow::checkpoint;
|
|
||||||
%unignore tensorflow::checkpoint::CheckpointReader;
|
|
||||||
%unignore tensorflow::checkpoint::CheckpointReader::CheckpointReader;
|
|
||||||
%unignore tensorflow::checkpoint::CheckpointReader::~CheckpointReader;
|
|
||||||
%rename("debug_string") tensorflow::checkpoint::CheckpointReader::DebugString;
|
|
||||||
%rename("get_variable_to_shape_map") tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap;
|
|
||||||
%rename("_GetVariableToDataTypeMap") tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap;
|
|
||||||
%rename("_HasTensor") tensorflow::checkpoint::CheckpointReader::HasTensor;
|
|
||||||
%unignore CheckpointReader_GetTensor;
|
|
||||||
|
|
||||||
%extend tensorflow::checkpoint::CheckpointReader {
|
|
||||||
%insert("python") %{
|
|
||||||
def get_variable_to_dtype_map(self):
|
|
||||||
from tensorflow.python.framework import dtypes
|
|
||||||
return {name: dtypes.DType(type_enum)
|
|
||||||
for name, type_enum in self._GetVariableToDataTypeMap().items()}
|
|
||||||
|
|
||||||
def has_tensor(self, tensor_str):
|
|
||||||
from tensorflow.python.util import compat
|
|
||||||
return self._HasTensor(compat.as_bytes(tensor_str))
|
|
||||||
|
|
||||||
def get_tensor(self, tensor_str):
|
|
||||||
from tensorflow.python.util import compat
|
|
||||||
|
|
||||||
return CheckpointReader_GetTensor(self, compat.as_bytes(tensor_str))
|
|
||||||
%}
|
|
||||||
}
|
|
||||||
|
|
||||||
%insert("python") %{
|
|
||||||
def NewCheckpointReader(filepattern):
|
|
||||||
from tensorflow.python.util import compat
|
|
||||||
return CheckpointReader(compat.as_bytes(filepattern))
|
|
||||||
|
|
||||||
NewCheckpointReader._tf_api_names_v1 = ['train.NewCheckpointReader']
|
|
||||||
%}
|
|
||||||
|
|
||||||
%include "tensorflow/c/checkpoint_reader.h"
|
|
||||||
%unignoreall
|
|
150
tensorflow/python/util/py_checkpoint_reader_wrapper.cc
Normal file
150
tensorflow/python/util/py_checkpoint_reader_wrapper.cc
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// Disallow Numpy 1.7 deprecated symbols.
|
||||||
|
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||||
|
|
||||||
|
#include "numpy/arrayobject.h"
|
||||||
|
#include "numpy/ufuncobject.h"
|
||||||
|
#include "include/pybind11/chrono.h"
|
||||||
|
#include "include/pybind11/complex.h"
|
||||||
|
#include "include/pybind11/functional.h"
|
||||||
|
#include "include/pybind11/pybind11.h"
|
||||||
|
#include "include/pybind11/stl.h"
|
||||||
|
#include "tensorflow/c/checkpoint_reader.h"
|
||||||
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
||||||
|
#include "tensorflow/python/lib/core/py_exception_registry.h"
|
||||||
|
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||||
|
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||||
|
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
// TODO(amitpatankar): Move the custom type casters to separate common header
|
||||||
|
// only libraries.
|
||||||
|
|
||||||
|
namespace pybind11 {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
/* This is a custom type caster for the TensorShape object. For more
|
||||||
|
* documentation please refer to this link:
|
||||||
|
* https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html#custom-type-casters
|
||||||
|
* The PyCheckpointReader methods sometimes return the `TensorShape` object
|
||||||
|
* and the `DataType` object as outputs. This custom type caster helps Python
|
||||||
|
* handle it's conversion from C++ to Python. Since we do not accept these
|
||||||
|
* classes as arguments from Python, it is not necessary to define the `load`
|
||||||
|
* function to cast the object from Python to a C++ object.
|
||||||
|
*/
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct type_caster<tensorflow::TensorShape> {
|
||||||
|
public:
|
||||||
|
PYBIND11_TYPE_CASTER(tensorflow::TensorShape, _("tensorflow::TensorShape"));
|
||||||
|
|
||||||
|
static handle cast(const tensorflow::TensorShape& src,
|
||||||
|
return_value_policy unused_policy, handle unused_handle) {
|
||||||
|
// TODO(amitpatankar): Simplify handling TensorShape as output later.
|
||||||
|
size_t dims = src.dims();
|
||||||
|
tensorflow::Safe_PyObjectPtr value(PyList_New(dims));
|
||||||
|
for (size_t i = 0; i < dims; ++i) {
|
||||||
|
#if PY_MAJOR_VERSION >= 3
|
||||||
|
tensorflow::Safe_PyObjectPtr dim_value(
|
||||||
|
tensorflow::make_safe(PyLong_FromLong(src.dim_size(i))));
|
||||||
|
#else
|
||||||
|
tensorflow::Safe_PyObjectPtr dim_value(
|
||||||
|
tensorflow::make_safe(PyInt_FromLong(src.dim_size(i))));
|
||||||
|
#endif
|
||||||
|
PyList_SET_ITEM(value.get(), i, dim_value.release());
|
||||||
|
}
|
||||||
|
|
||||||
|
return value.release();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct type_caster<tensorflow::DataType> {
|
||||||
|
public:
|
||||||
|
PYBIND11_TYPE_CASTER(tensorflow::DataType, _("tensorflow::DataType"));
|
||||||
|
|
||||||
|
static handle cast(const tensorflow::DataType& src,
|
||||||
|
return_value_policy unused_policy, handle unused_handle) {
|
||||||
|
#if PY_MAJOR_VERSION >= 3
|
||||||
|
tensorflow::Safe_PyObjectPtr value(
|
||||||
|
tensorflow::make_safe(PyLong_FromLong(src)));
|
||||||
|
#else
|
||||||
|
tensorflow::Safe_PyObjectPtr value(
|
||||||
|
tensorflow::make_safe(PyInt_FromLong(src)));
|
||||||
|
#endif
|
||||||
|
return value.release();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace pybind11
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
static py::object CheckpointReader_GetTensor(
|
||||||
|
tensorflow::checkpoint::CheckpointReader* reader, const string& name) {
|
||||||
|
Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
|
||||||
|
PyObject* py_obj = Py_None;
|
||||||
|
std::unique_ptr<tensorflow::Tensor> tensor;
|
||||||
|
reader->GetTensor(name, &tensor, status.get());
|
||||||
|
|
||||||
|
// Error handling if unable to get Tensor.
|
||||||
|
tensorflow::MaybeRaiseFromTFStatus(status.get());
|
||||||
|
|
||||||
|
tensorflow::MaybeRaiseFromStatus(
|
||||||
|
tensorflow::TensorToNdarray(*tensor, &py_obj));
|
||||||
|
|
||||||
|
return tensorflow::pyo_or_throw(
|
||||||
|
PyArray_Return(reinterpret_cast<PyArrayObject*>(py_obj)));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_pywrap_checkpoint_reader, m) {
|
||||||
|
// Initialization code to use numpy types in the type casters.
|
||||||
|
import_array1();
|
||||||
|
py::class_<tensorflow::checkpoint::CheckpointReader> checkpoint_reader_class(
|
||||||
|
m, "CheckpointReader");
|
||||||
|
checkpoint_reader_class
|
||||||
|
.def(py::init([](const std::string& filename) {
|
||||||
|
tensorflow::Safe_TF_StatusPtr status =
|
||||||
|
tensorflow::make_safe(TF_NewStatus());
|
||||||
|
// pybind11 support smart pointers and will own freeing the memory when
|
||||||
|
// complete.
|
||||||
|
// https://pybind11.readthedocs.io/en/master/advanced/smart_ptrs.html#std-unique-ptr
|
||||||
|
auto checkpoint =
|
||||||
|
std::make_unique<tensorflow::checkpoint::CheckpointReader>(
|
||||||
|
filename, status.get());
|
||||||
|
tensorflow::MaybeRaiseFromTFStatus(status.get());
|
||||||
|
return checkpoint;
|
||||||
|
}))
|
||||||
|
.def("debug_string",
|
||||||
|
[](tensorflow::checkpoint::CheckpointReader& self) {
|
||||||
|
return py::bytes(self.DebugString());
|
||||||
|
})
|
||||||
|
.def("get_variable_to_shape_map",
|
||||||
|
&tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap)
|
||||||
|
.def("_GetVariableToDataTypeMap",
|
||||||
|
&tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap)
|
||||||
|
.def("_HasTensor", &tensorflow::checkpoint::CheckpointReader::HasTensor)
|
||||||
|
.def_static("CheckpointReader_GetTensor",
|
||||||
|
&tensorflow::CheckpointReader_GetTensor);
|
||||||
|
};
|
@ -107,3 +107,21 @@ tensorflow::profiler::PythonTraceMe::IsEnabled
|
|||||||
[traceme_recorder] # traceme
|
[traceme_recorder] # traceme
|
||||||
tensorflow::profiler::internal::g_trace_level
|
tensorflow::profiler::internal::g_trace_level
|
||||||
|
|
||||||
|
[checkpoint_reader] # py_checkpoint_reader
|
||||||
|
tensorflow::checkpoint::CheckpointReader
|
||||||
|
tensorflow::checkpoint::CheckpointReader::Init
|
||||||
|
tensorflow::checkpoint::CheckpointReader::DebugString
|
||||||
|
tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap
|
||||||
|
tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap
|
||||||
|
tensorflow::checkpoint::CheckpointReader::GetTensor
|
||||||
|
tensorflow::checkpoint::CheckpointReader::HasTensor
|
||||||
|
|
||||||
|
[tensor_bundle] # py_checkpoint_reader
|
||||||
|
tensorflow::BundleReader::BundleReader
|
||||||
|
tensorflow::BundleReader::~BundleReader
|
||||||
|
|
||||||
|
[ndarray_tensor] # py_checkpoint_reader
|
||||||
|
tensorflow::TensorToNdarray
|
||||||
|
|
||||||
|
[safe_ptr] # py_checkpoint_reader
|
||||||
|
tensorflow::detail::PyDecrefDeleter
|
||||||
|
Loading…
Reference in New Issue
Block a user