Export the checkpoint reader classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.

PiperOrigin-RevId: 279101529
Change-Id: I25502ed3d3718499abca41f5614681f41e4c7199
This commit is contained in:
Amit Patankar 2019-11-07 09:20:14 -08:00 committed by TensorFlower Gardener
parent 8977382c53
commit b66e4e833c
20 changed files with 364 additions and 205 deletions

View File

@ -293,6 +293,15 @@ exports_files(
visibility = ["//visibility:public"],
)
filegroup(
name = "checkpoint_reader_hdrs",
srcs = [
"checkpoint_reader.h",
"tf_status_helper.h",
],
visibility = ["//tensorflow:__subpackages__"],
)
tf_cuda_library(
name = "tf_status_helper",
srcs = ["tf_status_helper.cc"],

View File

@ -98,6 +98,7 @@ py_library(
"//third_party/py/tensorflow_core:__subpackages__",
],
deps = [
":_pywrap_checkpoint_reader",
":_pywrap_events_writer",
":_pywrap_kernel_registry",
":_pywrap_py_exception_registry",
@ -629,6 +630,32 @@ tf_python_pybind_extension(
],
)
tf_python_pybind_extension(
name = "_pywrap_checkpoint_reader",
srcs = ["util/py_checkpoint_reader_wrapper.cc"],
hdrs = [
"lib/core/ndarray_tensor.h",
"lib/core/safe_ptr.h",
":py_exception_registry_hdr",
"//tensorflow/c:checkpoint_reader_hdrs",
"//tensorflow/c:headers",
"//tensorflow/c/eager:headers",
],
module_name = "_pywrap_checkpoint_reader",
deps = [
":pybind11_lib",
":pybind11_status",
"//tensorflow/core:lib",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:protos_all",
"//tensorflow/core/util/tensor_bundle:tensor_bundle_headers_lib",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
"@com_google_absl//absl/strings",
"@pybind11",
],
)
filegroup(
name = "py_exception_registry_hdr",
srcs = [
@ -996,6 +1023,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
":_pywrap_checkpoint_reader",
":_pywrap_debug_events_writer",
":_pywrap_events_writer",
":_pywrap_kernel_registry",
@ -4730,6 +4758,7 @@ py_library(
":math_ops",
":mixed_precision",
":platform",
":py_checkpoint_reader",
":pywrap_tensorflow",
":random_ops",
":resource_variable_ops",
@ -4785,6 +4814,17 @@ py_library(
],
)
py_library(
name = "py_checkpoint_reader",
srcs = ["training/py_checkpoint_reader.py"],
deps = [
":_pywrap_checkpoint_reader",
":dtypes",
":errors",
":util",
],
)
py_library(
name = "checkpoint_management",
srcs = ["training/checkpoint_management.py"],
@ -5284,7 +5324,6 @@ tf_py_wrap_cc(
"lib/io/py_record_writer.i",
"platform/base.i",
"pywrap_tfe.i",
"util/py_checkpoint_reader.i",
"//tensorflow/compiler/mlir/python:mlir.i",
],
# add win_def_file for pywrap_tensorflow
@ -5336,6 +5375,7 @@ tf_py_wrap_cc(
"//tensorflow/tools/graph_transforms:transform_graph_lib",
"//tensorflow/lite/toco/python:toco_python_api",
"//tensorflow/python/eager:pywrap_tfe_lib",
"//tensorflow/core/util/tensor_bundle:tensor_bundle",
] + (tf_additional_lib_deps() +
tf_additional_plugin_deps()) + if_ngraph([
"@ngraph_tf//:ngraph_tf",
@ -5391,9 +5431,14 @@ genrule(
"//tensorflow/core/profiler/internal:python_traceme", # traceme
"//tensorflow/core/profiler/internal:traceme_recorder", # traceme
":py_exception_registry", # py_exception_registry
":kernel_registry",
":kernel_registry", # kernel_registry
"//tensorflow/lite/toco/python:toco_python_api", # toco
"//tensorflow/tools/graph_transforms:transform_graph_lib", # transform_graph
"//tensorflow/c:checkpoint_reader", # checkpoint_reader
":ndarray_tensor", # checkpoint_reader
":numpy_lib", # checkpoint_reader
":safe_ptr", # checkpoint_reader
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
],
outs = ["pybind_symbol_target_libs_file.txt"],
cmd = select({
@ -6211,7 +6256,6 @@ tf_py_test(
":io_ops",
":partitioned_variables",
":platform",
":pywrap_tensorflow",
":resource_variable_ops",
":state_ops",
":training",

View File

@ -30,7 +30,6 @@ import numpy as np
import six
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
@ -54,6 +53,7 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
@ -1176,7 +1176,7 @@ class Network(base_layer.Layer):
save_format = 'h5'
else:
try:
pywrap_tensorflow.NewCheckpointReader(filepath)
py_checkpoint_reader.NewCheckpointReader(filepath)
save_format = 'tf'
except errors_impl.DataLossError:
# The checkpoint is not readable in TensorFlow format. Try HDF5.

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/lib/core/bfloat16.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/numpy.h"
namespace tensorflow {
namespace {
@ -527,12 +528,12 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
size_t size = 0;
void* encoded = nullptr;
TF_RETURN_IF_ERROR(EncodePyBytesArray(array, nelems, &size, &encoded));
*out_tensor =
make_safe(TF_NewTensor(dtype, dims.data(), dims.size(), encoded, size,
[](void* data, size_t len, void* arg) {
delete[] reinterpret_cast<char*>(data);
},
nullptr));
*out_tensor = make_safe(TF_NewTensor(
dtype, dims.data(), dims.size(), encoded, size,
[](void* data, size_t len, void* arg) {
delete[] reinterpret_cast<char*>(data);
},
nullptr));
}
return Status::OK();
}

View File

@ -16,9 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_
#define TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_
// Must be included first.
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor.h"

View File

@ -13,7 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
%}
%include "tensorflow/c/tf_datatype.h"
%include "tensorflow/c/tf_status.h"

View File

@ -17,8 +17,6 @@ limitations under the License.
* The includes are intentionally not alphabetically sorted, as the order of
* includes follows dependency order */
%include "tensorflow/python/util/py_checkpoint_reader.i"
%include "tensorflow/python/pywrap_tfe.i"

View File

@ -46,7 +46,6 @@ from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
@ -56,6 +55,7 @@ from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import saver as saver_lib
@ -161,7 +161,7 @@ def freeze_graph_with_def_protos(input_graph_def,
loader.load(sess, saved_model_tags, input_saved_model_dir)
else:
var_list = {}
reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
reader = py_checkpoint_reader.NewCheckpointReader(input_checkpoint)
var_to_shape_map = reader.get_variable_to_shape_map()
# List of all partition variables. Because the condition is heuristic

View File

@ -23,9 +23,9 @@ import sys
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
from tensorflow.python.training import py_checkpoint_reader
FLAGS = None
@ -72,7 +72,7 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
count_exclude_pattern: Regex string, pattern to exclude tensors when count.
"""
try:
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
reader = py_checkpoint_reader.NewCheckpointReader(file_name)
if all_tensors or all_tensor_names:
var_to_shape_map = reader.get_variable_to_shape_map()
var_to_dtype_map = reader.get_variable_to_dtype_map()

View File

@ -21,7 +21,6 @@ from __future__ import print_function
import time
import six
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
@ -31,6 +30,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export
@ -63,7 +63,7 @@ def load_checkpoint(ckpt_dir_or_file):
if filename is None:
raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
"given directory %s" % ckpt_dir_or_file)
return pywrap_tensorflow.NewCheckpointReader(filename)
return py_checkpoint_reader.NewCheckpointReader(filename)
@tf_export("train.load_variable")

View File

@ -0,0 +1,99 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extending CheckpointReader for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python._pywrap_checkpoint_reader import CheckpointReader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
def error_translator(e):
"""Translate the tensor_slice_reader.cc errors."""
# TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
# issue with throwing python exceptions from C++.
error_message = str(e)
if 'not found in checkpoint' in error_message or (
'Failed to find any '
'matching files for') in error_message:
raise errors_impl.NotFoundError(None, None, error_message)
elif 'Sliced checkpoints are not supported' in error_message or (
'Data type '
'not '
'supported') in error_message:
raise errors_impl.UnimplementedError(None, None, error_message)
elif 'Failed to get matching files on' in error_message:
raise errors_impl.InvalidArgumentError(None, None, error_message)
elif 'Unable to open table file' in error_message:
raise errors_impl.DataLossError(None, None, error_message)
elif 'Failed to find the saved tensor slices' in error_message:
raise errors_impl.InternalError(None, None, error_message)
else:
raise errors_impl.OpError(None, None, error_message, errors_impl.UNKNOWN)
def get_variable_to_dtype_map(self):
return {
name: dtypes.DType(type_enum)
for name, type_enum in self._GetVariableToDataTypeMap().items() # pylint: disable=protected-access
}
CheckpointReader.get_variable_to_dtype_map = get_variable_to_dtype_map
def has_tensor(self, tensor_str):
return self._HasTensor(compat.as_bytes(tensor_str)) # pylint: disable=protected-access
CheckpointReader.has_tensor = has_tensor
def get_tensor(self, tensor_str):
"""Get the tensor from the Checkpoint object."""
try:
return CheckpointReader.CheckpointReader_GetTensor(
self, compat.as_bytes(tensor_str))
# TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
# issue with throwing python exceptions from C++.
except RuntimeError as e:
error_translator(e)
CheckpointReader.get_tensor = get_tensor
# Disable invalid name to keep backwards compatibility with that function.
# It was previously exported from py_checkpoint_reader.i which did not conform
# to pylint checks.
# pylint: disable=invalid-name
@tf_export(v1=['train.NewCheckpointReader'])
def NewCheckpointReader(filepattern):
"""A function that returns a CheckPointReader.
Args:
filepattern: The filename.
Returns:
A CheckpointReader object.
"""
try:
return CheckpointReader(compat.as_bytes(filepattern))
# TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
# issue with throwing python exceptions from C++.
except RuntimeError as e:
error_translator(e)

View File

@ -32,7 +32,6 @@ import numpy as np
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@ -49,6 +48,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import training_util
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
@ -1614,7 +1614,7 @@ def object_graph_key_mapping(checkpoint_path):
Returns:
Dictionary mapping tensor names to checkpoint keys.
"""
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)

View File

@ -34,7 +34,6 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
@ -69,6 +68,7 @@ from tensorflow.python.summary import summary
from tensorflow.python.training import adam
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import saver_test_utils
@ -2468,7 +2468,7 @@ class CheckpointReaderTest(test.TestCase):
save.save(sess, save_path)
# Creates a reader.
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
reader = py_checkpoint_reader.NewCheckpointReader(save_path)
# Verifies that the tensors exist.
self.assertTrue(reader.has_tensor("v0"))
self.assertTrue(reader.has_tensor("v1"))
@ -2493,7 +2493,7 @@ class CheckpointReaderTest(test.TestCase):
def testNonexistentPath(self):
with self.assertRaisesRegexp(errors.NotFoundError,
"Unsuccessful TensorSliceReader"):
pywrap_tensorflow.NewCheckpointReader("non-existent")
py_checkpoint_reader.NewCheckpointReader("non-existent")
class CheckpointReaderForV2Test(CheckpointReaderTest):

View File

@ -21,7 +21,6 @@ from __future__ import print_function
import os
import time
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.module import module
@ -29,6 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.platform import test
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import util
@ -116,7 +116,7 @@ class SavingBenchmarks(test.Benchmark):
def benchmark_raw_restore(self):
checkpoint_path = _save_checkpoint()
all_names, all_dtypes = zip(*pywrap_tensorflow.NewCheckpointReader(
all_names, all_dtypes = zip(*py_checkpoint_reader.NewCheckpointReader(
checkpoint_path).get_variable_to_dtype_map().items())
def _call_restore_v2():

View File

@ -25,7 +25,6 @@ import weakref
import six
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@ -43,6 +42,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import saver as v1_saver_lib
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.saving import saveable_object_util
@ -200,7 +200,7 @@ class _CheckpointRestoreCoordinator(object):
self.all_python_objects = object_identity.ObjectIdentityWeakSet()
self.save_path_tensor = save_path_tensor
self.save_path_string = save_path
self.dtype_map = pywrap_tensorflow.NewCheckpointReader(
self.dtype_map = py_checkpoint_reader.NewCheckpointReader(
save_path).get_variable_to_dtype_map()
# A NewCheckpointReader for the most recent checkpoint, for streaming Python
# state restoration.
@ -272,7 +272,7 @@ class _CheckpointRestoreCoordinator(object):
if reader is None:
# Lazily create the NewCheckpointReader, since this requires file access
# and we may not have any Python saveables.
reader = pywrap_tensorflow.NewCheckpointReader(self.save_path_string)
reader = py_checkpoint_reader.NewCheckpointReader(self.save_path_string)
spec_names = [spec.name for spec in saveable.specs]
saveable.python_restore([reader.get_tensor(name) for name in spec_names])
@ -462,7 +462,7 @@ def object_metadata(save_path):
Raises:
ValueError: If an object graph was not found in the checkpoint.
"""
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
reader = py_checkpoint_reader.NewCheckpointReader(save_path)
try:
object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
except errors_impl.NotFoundError:
@ -1238,7 +1238,7 @@ class TrackableSaver(object):
"""
if save_path is None:
return InitializationOnlyStatus(self._graph_view, ops.uid())
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
reader = py_checkpoint_reader.NewCheckpointReader(save_path)
graph_building = not context.executing_eagerly()
if graph_building:
dtype_map = None

View File

@ -110,7 +110,7 @@ from tensorflow.python.training.training_util import create_global_step
from tensorflow.python.training.training_util import get_or_create_global_step
from tensorflow.python.training.warm_starting_util import VocabInfo
from tensorflow.python.training.warm_starting_util import warm_start
from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
from tensorflow.python.training.py_checkpoint_reader import NewCheckpointReader
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=wildcard-import
@ -145,4 +145,3 @@ tf_export(v1=["train.SaverDef"])(SaverDef)
tf_export("train.SequenceExample")(SequenceExample)
tf_export("train.ServerDef")(ServerDef)
# pylint: enable=undefined-variable

View File

@ -335,14 +335,16 @@ def _get_grouped_variables(vars_to_warm_start):
ValueError: If vars_to_warm_start is not a string, `None`, a list of
`Variables`, or a list of strings.
"""
if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None:
# TODO(b/143899805): Remove unicode checks when deprecating Python2.
if isinstance(vars_to_warm_start,
six.string_types) or vars_to_warm_start is None:
# Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
# everything (in TRAINABLE_VARIABLES) here.
logging.info("Warm-starting variables only in TRAINABLE_VARIABLES.")
list_of_vars = ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start)
elif isinstance(vars_to_warm_start, list):
if all(isinstance(v, str) for v in vars_to_warm_start):
if all(isinstance(v, six.string_types) for v in vars_to_warm_start):
list_of_vars = []
for v in vars_to_warm_start:
list_of_vars += ops.get_collection(

View File

@ -1,167 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
%}
%typemap(out) const tensorflow::checkpoint::TensorSliceReader::VarToShapeMap& {
tensorflow::Safe_PyObjectPtr output_map(tensorflow::make_safe(PyDict_New()));
for (auto v : *$1) {
%#if PY_MAJOR_VERSION >= 3
tensorflow::Safe_PyObjectPtr key(
tensorflow::make_safe(PyUnicode_FromStringAndSize(v.first.c_str(),
v.first.size())));
%#else
tensorflow::Safe_PyObjectPtr key(
tensorflow::make_safe(PyString_FromStringAndSize(v.first.c_str(),
v.first.size())));
%#endif
if (!key) {
SWIG_fail;
}
size_t dims = v.second.dims();
tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyList_New(dims)));
if (!value) {
SWIG_fail;
}
for (size_t i = 0; i < dims; ++i) {
%#if PY_MAJOR_VERSION >= 3
tensorflow::Safe_PyObjectPtr dim_value(
tensorflow::make_safe(PyLong_FromLong(v.second.dim_size(i))));
%#else
tensorflow::Safe_PyObjectPtr dim_value(
tensorflow::make_safe(PyInt_FromLong(v.second.dim_size(i))));
%#endif
if (!dim_value) {
SWIG_fail;
}
PyList_SET_ITEM(value.get(), i, dim_value.release());
}
if (PyDict_SetItem(output_map.get(), key.get(), value.get()) == -1) {
SWIG_fail;
} else {
key.release();
value.release();
}
}
$result = output_map.release();
}
%typemap(out) const tensorflow::checkpoint::TensorSliceReader::VarToDataTypeMap& {
tensorflow::Safe_PyObjectPtr output_map(tensorflow::make_safe(PyDict_New()));
for (auto v : *$1) {
%#if PY_MAJOR_VERSION >= 3
tensorflow::Safe_PyObjectPtr key(
tensorflow::make_safe(PyUnicode_FromStringAndSize(v.first.c_str(), v.first.size())));
%#else
tensorflow::Safe_PyObjectPtr key(
tensorflow::make_safe(PyString_FromStringAndSize(v.first.c_str(), v.first.size())));
%#endif
if (!key) {
SWIG_fail;
}
%#if PY_MAJOR_VERSION >= 3
tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyLong_FromLong(v.second)));
%#else
tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyInt_FromLong(v.second)));
%#endif
if (!value) {
SWIG_fail;
}
if (PyDict_SetItem(output_map.get(), key.get(), value.get()) == -1) {
SWIG_fail;
} else {
key.release();
value.release();
}
}
$result = output_map.release();
}
%{
static PyObject* CheckpointReader_GetTensor(
tensorflow::checkpoint::CheckpointReader* reader,
const string& name,
TF_Status* status) {
PyObject* py_obj = Py_None;
std::unique_ptr<tensorflow::Tensor> tensor;
reader->GetTensor(name, &tensor, status);
if (TF_GetCode(status) == TF_OK) {
tensorflow::Status s =
tensorflow::TensorToNdarray(*tensor.get(), &py_obj);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
}
return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_obj));
}
%}
// Wrap this function.
PyObject* CheckpointReader_GetTensor(
tensorflow::checkpoint::CheckpointReader* reader,
const string& name,
TF_Status* status);
%ignoreall
%unignore tensorflow;
%unignore tensorflow::checkpoint;
%unignore tensorflow::checkpoint::CheckpointReader;
%unignore tensorflow::checkpoint::CheckpointReader::CheckpointReader;
%unignore tensorflow::checkpoint::CheckpointReader::~CheckpointReader;
%rename("debug_string") tensorflow::checkpoint::CheckpointReader::DebugString;
%rename("get_variable_to_shape_map") tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap;
%rename("_GetVariableToDataTypeMap") tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap;
%rename("_HasTensor") tensorflow::checkpoint::CheckpointReader::HasTensor;
%unignore CheckpointReader_GetTensor;
%extend tensorflow::checkpoint::CheckpointReader {
%insert("python") %{
def get_variable_to_dtype_map(self):
from tensorflow.python.framework import dtypes
return {name: dtypes.DType(type_enum)
for name, type_enum in self._GetVariableToDataTypeMap().items()}
def has_tensor(self, tensor_str):
from tensorflow.python.util import compat
return self._HasTensor(compat.as_bytes(tensor_str))
def get_tensor(self, tensor_str):
from tensorflow.python.util import compat
return CheckpointReader_GetTensor(self, compat.as_bytes(tensor_str))
%}
}
%insert("python") %{
def NewCheckpointReader(filepattern):
from tensorflow.python.util import compat
return CheckpointReader(compat.as_bytes(filepattern))
NewCheckpointReader._tf_api_names_v1 = ['train.NewCheckpointReader']
%}
%include "tensorflow/c/checkpoint_reader.h"
%unignoreall

View File

@ -0,0 +1,150 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Disallow Numpy 1.7 deprecated symbols.
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include "numpy/arrayobject.h"
#include "numpy/ufuncobject.h"
#include "include/pybind11/chrono.h"
#include "include/pybind11/complex.h"
#include "include/pybind11/functional.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
namespace py = pybind11;
// TODO(amitpatankar): Move the custom type casters to separate common header
// only libraries.
namespace pybind11 {
namespace detail {
/* This is a custom type caster for the TensorShape object. For more
* documentation please refer to this link:
* https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html#custom-type-casters
* The PyCheckpointReader methods sometimes return the `TensorShape` object
* and the `DataType` object as outputs. This custom type caster helps Python
* handle it's conversion from C++ to Python. Since we do not accept these
* classes as arguments from Python, it is not necessary to define the `load`
* function to cast the object from Python to a C++ object.
*/
template <>
struct type_caster<tensorflow::TensorShape> {
public:
PYBIND11_TYPE_CASTER(tensorflow::TensorShape, _("tensorflow::TensorShape"));
static handle cast(const tensorflow::TensorShape& src,
return_value_policy unused_policy, handle unused_handle) {
// TODO(amitpatankar): Simplify handling TensorShape as output later.
size_t dims = src.dims();
tensorflow::Safe_PyObjectPtr value(PyList_New(dims));
for (size_t i = 0; i < dims; ++i) {
#if PY_MAJOR_VERSION >= 3
tensorflow::Safe_PyObjectPtr dim_value(
tensorflow::make_safe(PyLong_FromLong(src.dim_size(i))));
#else
tensorflow::Safe_PyObjectPtr dim_value(
tensorflow::make_safe(PyInt_FromLong(src.dim_size(i))));
#endif
PyList_SET_ITEM(value.get(), i, dim_value.release());
}
return value.release();
}
};
template <>
struct type_caster<tensorflow::DataType> {
public:
PYBIND11_TYPE_CASTER(tensorflow::DataType, _("tensorflow::DataType"));
static handle cast(const tensorflow::DataType& src,
return_value_policy unused_policy, handle unused_handle) {
#if PY_MAJOR_VERSION >= 3
tensorflow::Safe_PyObjectPtr value(
tensorflow::make_safe(PyLong_FromLong(src)));
#else
tensorflow::Safe_PyObjectPtr value(
tensorflow::make_safe(PyInt_FromLong(src)));
#endif
return value.release();
}
};
} // namespace detail
} // namespace pybind11
namespace tensorflow {
static py::object CheckpointReader_GetTensor(
tensorflow::checkpoint::CheckpointReader* reader, const string& name) {
Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
PyObject* py_obj = Py_None;
std::unique_ptr<tensorflow::Tensor> tensor;
reader->GetTensor(name, &tensor, status.get());
// Error handling if unable to get Tensor.
tensorflow::MaybeRaiseFromTFStatus(status.get());
tensorflow::MaybeRaiseFromStatus(
tensorflow::TensorToNdarray(*tensor, &py_obj));
return tensorflow::pyo_or_throw(
PyArray_Return(reinterpret_cast<PyArrayObject*>(py_obj)));
}
} // namespace tensorflow
PYBIND11_MODULE(_pywrap_checkpoint_reader, m) {
// Initialization code to use numpy types in the type casters.
import_array1();
py::class_<tensorflow::checkpoint::CheckpointReader> checkpoint_reader_class(
m, "CheckpointReader");
checkpoint_reader_class
.def(py::init([](const std::string& filename) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
// pybind11 support smart pointers and will own freeing the memory when
// complete.
// https://pybind11.readthedocs.io/en/master/advanced/smart_ptrs.html#std-unique-ptr
auto checkpoint =
std::make_unique<tensorflow::checkpoint::CheckpointReader>(
filename, status.get());
tensorflow::MaybeRaiseFromTFStatus(status.get());
return checkpoint;
}))
.def("debug_string",
[](tensorflow::checkpoint::CheckpointReader& self) {
return py::bytes(self.DebugString());
})
.def("get_variable_to_shape_map",
&tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap)
.def("_GetVariableToDataTypeMap",
&tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap)
.def("_HasTensor", &tensorflow::checkpoint::CheckpointReader::HasTensor)
.def_static("CheckpointReader_GetTensor",
&tensorflow::CheckpointReader_GetTensor);
};

View File

@ -107,3 +107,21 @@ tensorflow::profiler::PythonTraceMe::IsEnabled
[traceme_recorder] # traceme
tensorflow::profiler::internal::g_trace_level
[checkpoint_reader] # py_checkpoint_reader
tensorflow::checkpoint::CheckpointReader
tensorflow::checkpoint::CheckpointReader::Init
tensorflow::checkpoint::CheckpointReader::DebugString
tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap
tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap
tensorflow::checkpoint::CheckpointReader::GetTensor
tensorflow::checkpoint::CheckpointReader::HasTensor
[tensor_bundle] # py_checkpoint_reader
tensorflow::BundleReader::BundleReader
tensorflow::BundleReader::~BundleReader
[ndarray_tensor] # py_checkpoint_reader
tensorflow::TensorToNdarray
[safe_ptr] # py_checkpoint_reader
tensorflow::detail::PyDecrefDeleter