Export the quantize_training functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.

PiperOrigin-RevId: 275945771
Change-Id: I77ed4389903b24d3e0fe12ed83aa441099b2b9bf
This commit is contained in:
Amit Patankar 2019-10-21 16:08:20 -07:00 committed by TensorFlower Gardener
parent b0a7160413
commit 7bf81e0bce
10 changed files with 140 additions and 93 deletions

View File

@ -351,6 +351,17 @@ filegroup(
],
)
filegroup(
name = "quantize_training_hdrs",
srcs = [
"graph/quantize_training.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "util_port",
srcs = ["util/port.cc"],

View File

@ -101,6 +101,7 @@ py_library(
":_pywrap_events_writer",
":_pywrap_kernel_registry",
":_pywrap_py_exception_registry",
":_pywrap_quantize_training",
":_pywrap_stat_summarizer",
":_pywrap_tfprof",
":_pywrap_util_port",
@ -474,6 +475,27 @@ tf_python_pybind_extension(
],
)
tf_python_pybind_extension(
name = "_pywrap_quantize_training",
srcs = [
"training/quantize_training_wrapper.cc",
],
hdrs = ["//tensorflow/core:quantize_training_hdrs"],
module_name = "_pywrap_quantize_training",
deps = [
":pybind11_lib",
":pybind11_proto",
":pybind11_status",
"//tensorflow/core:core_cpu_headers_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//third_party/python_runtime:headers",
"@com_google_absl//absl/strings",
"@pybind11",
],
)
tf_python_pybind_extension(
name = "_pywrap_stat_summarizer",
srcs = ["util/stat_summarizer_wrapper.cc"],
@ -890,6 +912,7 @@ py_library(
":_pywrap_kernel_registry",
":_pywrap_py_exception_registry",
":_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed.
":_pywrap_quantize_training",
":_pywrap_stat_summarizer",
":_pywrap_tfprof",
":_pywrap_util_port",
@ -5157,7 +5180,6 @@ tf_py_wrap_cc(
"platform/base.i",
"platform/stacktrace_handler.i",
"pywrap_tfe.i",
"training/quantize_training.i",
"util/py_checkpoint_reader.i",
"util/scoped_annotation.i",
"util/traceme.i",

View File

@ -54,6 +54,7 @@ from tensorflow.python import _pywrap_util_port
from tensorflow.python import _pywrap_stat_summarizer
from tensorflow.python import _pywrap_py_exception_registry
from tensorflow.python import _pywrap_kernel_registry
from tensorflow.python import _pywrap_quantize_training
# Protocol buffers
from tensorflow.core.framework.graph_pb2 import *
@ -120,6 +121,7 @@ from tensorflow.python.ops import gen_sendrecv_ops
# Import the names from python/training.py as train.Name.
from tensorflow.python.training import training as train
from tensorflow.python.training import quantize_training as _quantize_training
# Sub-package for performing i/o directly instead of via ops in a graph.
from tensorflow.python.lib.io import python_io

View File

@ -32,7 +32,6 @@ limitations under the License.
%include "tensorflow/lite/toco/python/toco.i"
%include "tensorflow/python/lib/io/file_io.i"
%include "tensorflow/python/training/quantize_training.i"
%include "tensorflow/python/framework/python_op_gen.i"

View File

@ -1,87 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/core/graph/quantize_training.h"
#include "tensorflow/core/lib/core/status.h"
static PyObject* DoQuantizeTrainingOnGraphDefHelper(
const string& input_graph,
int num_bits,
TF_Status* status) {
string result;
// TODO(suharshs): Make the QuantizeAndDequantizeV2 configurable.
tensorflow::Status s =
tensorflow::DoQuantizeTrainingOnSerializedGraphDef(input_graph, num_bits,
"QuantizeAndDequantizeV2", &result);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
Py_RETURN_NONE;
}
PyObject* py_str = PyBytes_FromStringAndSize(result.data(), result.size());
if (!py_str) {
Set_TF_Status_from_Status(status,
tensorflow::Status(tensorflow::error::INTERNAL,
"Failed to generate serialized string of the rewritten graph."));
Py_RETURN_NONE;
}
return py_str;
}
%}
%ignoreall
%unignore DoQuantizeTrainingOnGraphDefHelper;
// Wrap this function
PyObject* DoQuantizeTrainingOnGraphDefHelper(
const string& input_graph,
int num_bits,
TF_Status* status);
%insert("python") %{
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@deprecation.deprecated(
None,
"GraphDef quantized training rewriter is deprecated in the long term")
@tf_export(v1=["train.do_quantize_training_on_graphdef"])
def do_quantize_training_on_graphdef(input_graph, num_bits):
"""A general quantization scheme is being developed in `tf.contrib.quantize`.
Consider using that instead, though since it is in the tf.contrib namespace,
it is not subject to backward compatibility guarantees.
"""
from tensorflow.core.framework.graph_pb2 import GraphDef
from tensorflow.python.framework import errors
graph = GraphDef()
result_graph_string = DoQuantizeTrainingOnGraphDefHelper(
input_graph.SerializeToString(), num_bits)
graph.ParseFromString(result_graph_string)
return graph
do_quantize_training_on_graphdef._tf_api_names = [
'train.do_quantize_training_on_graphdef']
do_quantize_training_on_graphdef._tf_api_names_v1 = [
'train.do_quantize_training_on_graphdef']
%}
%unignoreall

View File

@ -0,0 +1,50 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Quantize training for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import graph_pb2
from tensorflow.python._pywrap_quantize_training import DoQuantizeTrainingOnGraphDefHelper
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# Migrated this python code from deprecated quantize_training.i
@deprecation.deprecated(
None,
"GraphDef quantized training rewriter is deprecated in the long term.")
@tf_export(v1=["train.do_quantize_training_on_graphdef"])
def do_quantize_training_on_graphdef(input_graph, num_bits):
"""A general quantization scheme is being developed in `tf.contrib.quantize`.
Consider using that instead, though since it is in the tf.contrib namespace,
it is not subject to backward compatibility guarantees.
Args:
input_graph: A `GraphDef`.
num_bits: The number of bits for quantize training.
Returns:
The graph with quantize training done.
"""
graph = graph_pb2.GraphDef()
result_graph_string = DoQuantizeTrainingOnGraphDefHelper(
input_graph.SerializeToString(), num_bits)
graph.ParseFromString(result_graph_string)
return graph

View File

@ -20,7 +20,6 @@ from __future__ import print_function
import os
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import importer
@ -29,6 +28,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import quantize_training
from tensorflow.python.training import saver as saver_module
@ -45,7 +45,7 @@ class PywrapQuantizeTrainingTest(test.TestCase):
self.assertEquals(c.eval(), 42.0)
self.assertEquals(len(sess.graph_def.node), 3)
result = pywrap_tensorflow.do_quantize_training_on_graphdef(
result = quantize_training.do_quantize_training_on_graphdef(
sess.graph_def, 8)
# We just want to guarantee that some rewrite happened.
@ -68,7 +68,7 @@ class PywrapQuantizeTrainingTest(test.TestCase):
saver = saver_module.Saver({'b': b})
result = pywrap_tensorflow.do_quantize_training_on_graphdef(
result = quantize_training.do_quantize_training_on_graphdef(
sess.graph_def, 8)
with ops.Graph().as_default() as g, session.Session(graph=g) as sess:

View File

@ -0,0 +1,48 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "include/pybind11/pybind11.h"
#include "tensorflow/core/graph/quantize_training.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
namespace py = pybind11;
namespace tensorflow {
static PyObject* DoQuantizeTrainingOnGraphDefHelper(const string& input_graph,
int num_bits) {
string result;
// TODO(suharshs): Make the QuantizeAndDequantizeV2 configurable.
tensorflow::MaybeRaiseFromStatus(
tensorflow::DoQuantizeTrainingOnSerializedGraphDef(
input_graph, num_bits, "QuantizeAndDequantizeV2", &result));
PyObject* py_str = PyBytes_FromStringAndSize(result.data(), result.size());
if (!py_str) {
tensorflow::MaybeRaiseFromStatus(tensorflow::errors::Internal(
"Failed to generate serialized string of the rewritten graph."));
}
return py_str;
}
} // namespace tensorflow
PYBIND11_MODULE(_pywrap_quantize_training, m) {
m.def("DoQuantizeTrainingOnGraphDefHelper",
[](const py::object input_graph, int num_bits) {
return tensorflow::pyo_or_throw(
tensorflow::DoQuantizeTrainingOnGraphDefHelper(
input_graph.cast<std::string>(), num_bits));
});
};

View File

@ -110,7 +110,6 @@ from tensorflow.python.training.training_util import create_global_step
from tensorflow.python.training.training_util import get_or_create_global_step
from tensorflow.python.training.warm_starting_util import VocabInfo
from tensorflow.python.training.warm_starting_util import warm_start
from tensorflow.python.pywrap_tensorflow import do_quantize_training_on_graphdef
from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
from tensorflow.python.util.tf_export import tf_export
@ -146,3 +145,4 @@ tf_export(v1=["train.SaverDef"])(SaverDef)
tf_export("train.SequenceExample")(SequenceExample)
tf_export("train.ServerDef")(ServerDef)
# pylint: enable=undefined-variable

View File

@ -69,6 +69,7 @@ tensorflow::Status::ok()
tensorflow::Device::attributes
tensorflow::DeviceFactory::AddDevices
tensorflow::SessionOptions::SessionOptions
tensorflow::DoQuantizeTrainingOnSerializedGraphDef
[protos_all] # device_lib
tensorflow::ConfigProto::ConfigProto
@ -81,3 +82,4 @@ tensorflow::PyExceptionRegistry::Lookup
[kernel_registry] # kernel_registry
tensorflow::swig::TryFindKernelClass