Export the quantize_training functions from C++ to Python with pybind11 instead of SWIG. This is part of a larger effort to deprecate SWIG and, together with modularization, eventually break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA already uses the pybind11 macros. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 275945771
Change-Id: I77ed4389903b24d3e0fe12ed83aa441099b2b9bf
Parent: b0a7160413
Commit: 7bf81e0bce
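For context on the pattern this change adopts, below is a minimal, illustrative sketch of a pybind11 extension. The module name `_example` and the `Rewrite()` helper are placeholders invented for the sketch; the actual binding added by this commit is tensorflow/python/training/quantize_training_wrapper.cc in the diff that follows.

/* Minimal sketch (illustrative only, not part of this change): exposing a
   C++ function to Python with pybind11. "_example" and Rewrite() are
   placeholders for the real DoQuantizeTrainingOnGraphDefHelper binding. */
#include <pybind11/pybind11.h>

#include <string>

namespace py = pybind11;

// Placeholder for a routine that rewrites a serialized GraphDef and returns
// the rewritten bytes (the real wrapper calls
// tensorflow::DoQuantizeTrainingOnSerializedGraphDef).
static std::string Rewrite(const std::string& serialized_graph, int num_bits) {
  (void)num_bits;           // unused in this sketch
  return serialized_graph;  // a real implementation would transform the graph
}

// PYBIND11_MODULE replaces the SWIG %module / %{ ... %} boilerplate: the macro
// emits the module init function and m.def() registers the callable. Returning
// py::bytes keeps the serialized proto binary-safe, mirroring the PyBytes
// result produced by the SWIG wrapper being removed here.
PYBIND11_MODULE(_example, m) {
  m.def("rewrite",
        [](const std::string& serialized_graph, int num_bits) {
          return py::bytes(Rewrite(serialized_graph, num_bits));
        },
        py::arg("serialized_graph"), py::arg("num_bits"));
}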
@@ -351,6 +351,17 @@ filegroup(
     ],
 )
 
+filegroup(
+    name = "quantize_training_hdrs",
+    srcs = [
+        "graph/quantize_training.h",
+    ],
+    visibility = [
+        "//tensorflow/core:__pkg__",
+        "//tensorflow/python:__pkg__",
+    ],
+)
+
 cc_library(
     name = "util_port",
     srcs = ["util/port.cc"],
@@ -101,6 +101,7 @@ py_library(
         ":_pywrap_events_writer",
         ":_pywrap_kernel_registry",
         ":_pywrap_py_exception_registry",
+        ":_pywrap_quantize_training",
         ":_pywrap_stat_summarizer",
         ":_pywrap_tfprof",
         ":_pywrap_util_port",
@@ -474,6 +475,27 @@ tf_python_pybind_extension(
     ],
 )
 
+tf_python_pybind_extension(
+    name = "_pywrap_quantize_training",
+    srcs = [
+        "training/quantize_training_wrapper.cc",
+    ],
+    hdrs = ["//tensorflow/core:quantize_training_hdrs"],
+    module_name = "_pywrap_quantize_training",
+    deps = [
+        ":pybind11_lib",
+        ":pybind11_proto",
+        ":pybind11_status",
+        "//tensorflow/core:core_cpu_headers_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//third_party/python_runtime:headers",
+        "@com_google_absl//absl/strings",
+        "@pybind11",
+    ],
+)
+
 tf_python_pybind_extension(
     name = "_pywrap_stat_summarizer",
     srcs = ["util/stat_summarizer_wrapper.cc"],
@@ -890,6 +912,7 @@ py_library(
         ":_pywrap_kernel_registry",
         ":_pywrap_py_exception_registry",
         ":_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed.
+        ":_pywrap_quantize_training",
         ":_pywrap_stat_summarizer",
        ":_pywrap_tfprof",
         ":_pywrap_util_port",
@@ -5157,7 +5180,6 @@ tf_py_wrap_cc(
         "platform/base.i",
         "platform/stacktrace_handler.i",
         "pywrap_tfe.i",
-        "training/quantize_training.i",
         "util/py_checkpoint_reader.i",
         "util/scoped_annotation.i",
         "util/traceme.i",
@@ -54,6 +54,7 @@ from tensorflow.python import _pywrap_util_port
 from tensorflow.python import _pywrap_stat_summarizer
 from tensorflow.python import _pywrap_py_exception_registry
 from tensorflow.python import _pywrap_kernel_registry
+from tensorflow.python import _pywrap_quantize_training
 
 # Protocol buffers
 from tensorflow.core.framework.graph_pb2 import *
@@ -120,6 +121,7 @@ from tensorflow.python.ops import gen_sendrecv_ops
 
 # Import the names from python/training.py as train.Name.
 from tensorflow.python.training import training as train
+from tensorflow.python.training import quantize_training as _quantize_training
 
 # Sub-package for performing i/o directly instead of via ops in a graph.
 from tensorflow.python.lib.io import python_io
@@ -32,7 +32,6 @@ limitations under the License.
 %include "tensorflow/lite/toco/python/toco.i"
 
 %include "tensorflow/python/lib/io/file_io.i"
-%include "tensorflow/python/training/quantize_training.i"
 
 %include "tensorflow/python/framework/python_op_gen.i"
@@ -1,87 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-%include "tensorflow/python/platform/base.i"
-
-%{
-#include "tensorflow/core/graph/quantize_training.h"
-#include "tensorflow/core/lib/core/status.h"
-
-static PyObject* DoQuantizeTrainingOnGraphDefHelper(
-    const string& input_graph,
-    int num_bits,
-    TF_Status* status) {
-  string result;
-  // TODO(suharshs): Make the QuantizeAndDequantizeV2 configurable.
-  tensorflow::Status s =
-      tensorflow::DoQuantizeTrainingOnSerializedGraphDef(input_graph, num_bits,
-          "QuantizeAndDequantizeV2", &result);
-  if (!s.ok()) {
-    Set_TF_Status_from_Status(status, s);
-    Py_RETURN_NONE;
-  }
-  PyObject* py_str = PyBytes_FromStringAndSize(result.data(), result.size());
-  if (!py_str) {
-    Set_TF_Status_from_Status(status,
-        tensorflow::Status(tensorflow::error::INTERNAL,
-            "Failed to generate serialized string of the rewritten graph."));
-    Py_RETURN_NONE;
-  }
-
-  return py_str;
-}
-%}
-
-%ignoreall
-%unignore DoQuantizeTrainingOnGraphDefHelper;
-
-// Wrap this function
-PyObject* DoQuantizeTrainingOnGraphDefHelper(
-    const string& input_graph,
-    int num_bits,
-    TF_Status* status);
-
-
-%insert("python") %{
-from tensorflow.python.util import deprecation
-from tensorflow.python.util.tf_export import tf_export
-
-@deprecation.deprecated(
-    None,
-    "GraphDef quantized training rewriter is deprecated in the long term")
-@tf_export(v1=["train.do_quantize_training_on_graphdef"])
-def do_quantize_training_on_graphdef(input_graph, num_bits):
-  """A general quantization scheme is being developed in `tf.contrib.quantize`.
-
-  Consider using that instead, though since it is in the tf.contrib namespace,
-  it is not subject to backward compatibility guarantees.
-  """
-  from tensorflow.core.framework.graph_pb2 import GraphDef
-  from tensorflow.python.framework import errors
-
-  graph = GraphDef()
-  result_graph_string = DoQuantizeTrainingOnGraphDefHelper(
-      input_graph.SerializeToString(), num_bits)
-
-  graph.ParseFromString(result_graph_string)
-  return graph
-
-do_quantize_training_on_graphdef._tf_api_names = [
-    'train.do_quantize_training_on_graphdef']
-do_quantize_training_on_graphdef._tf_api_names_v1 = [
-    'train.do_quantize_training_on_graphdef']
-%}
-
-%unignoreall
tensorflow/python/training/quantize_training.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Quantize training for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python._pywrap_quantize_training import DoQuantizeTrainingOnGraphDefHelper
+from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
+
+
+# Migrated this python code from deprecated quantize_training.i
+@deprecation.deprecated(
+    None,
+    "GraphDef quantized training rewriter is deprecated in the long term.")
+@tf_export(v1=["train.do_quantize_training_on_graphdef"])
+def do_quantize_training_on_graphdef(input_graph, num_bits):
+  """A general quantization scheme is being developed in `tf.contrib.quantize`.
+
+  Consider using that instead, though since it is in the tf.contrib namespace,
+  it is not subject to backward compatibility guarantees.
+
+  Args:
+    input_graph: A `GraphDef`.
+    num_bits: The number of bits for quantize training.
+
+  Returns:
+    The graph with quantize training done.
+  """
+
+  graph = graph_pb2.GraphDef()
+  result_graph_string = DoQuantizeTrainingOnGraphDefHelper(
+      input_graph.SerializeToString(), num_bits)
+
+  graph.ParseFromString(result_graph_string)
+  return graph
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import os
 
-from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import importer
@@ -29,6 +28,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.training import quantize_training
 from tensorflow.python.training import saver as saver_module
 
 
@@ -45,7 +45,7 @@ class PywrapQuantizeTrainingTest(test.TestCase):
       self.assertEquals(c.eval(), 42.0)
       self.assertEquals(len(sess.graph_def.node), 3)
 
-      result = pywrap_tensorflow.do_quantize_training_on_graphdef(
+      result = quantize_training.do_quantize_training_on_graphdef(
           sess.graph_def, 8)
 
       # We just want to guarantee that some rewrite happened.
@@ -68,7 +68,7 @@ class PywrapQuantizeTrainingTest(test.TestCase):
 
       saver = saver_module.Saver({'b': b})
 
-      result = pywrap_tensorflow.do_quantize_training_on_graphdef(
+      result = quantize_training.do_quantize_training_on_graphdef(
          sess.graph_def, 8)
 
       with ops.Graph().as_default() as g, session.Session(graph=g) as sess:
tensorflow/python/training/quantize_training_wrapper.cc (new file, 48 lines)
@@ -0,0 +1,48 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "include/pybind11/pybind11.h"
+#include "tensorflow/core/graph/quantize_training.h"
+#include "tensorflow/python/lib/core/pybind11_lib.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+
+namespace py = pybind11;
+
+namespace tensorflow {
+static PyObject* DoQuantizeTrainingOnGraphDefHelper(const string& input_graph,
+                                                    int num_bits) {
+  string result;
+  // TODO(suharshs): Make the QuantizeAndDequantizeV2 configurable.
+  tensorflow::MaybeRaiseFromStatus(
+      tensorflow::DoQuantizeTrainingOnSerializedGraphDef(
+          input_graph, num_bits, "QuantizeAndDequantizeV2", &result));
+
+  PyObject* py_str = PyBytes_FromStringAndSize(result.data(), result.size());
+  if (!py_str) {
+    tensorflow::MaybeRaiseFromStatus(tensorflow::errors::Internal(
+        "Failed to generate serialized string of the rewritten graph."));
+  }
+  return py_str;
+}
+}  // namespace tensorflow
+
+PYBIND11_MODULE(_pywrap_quantize_training, m) {
+  m.def("DoQuantizeTrainingOnGraphDefHelper",
+        [](const py::object input_graph, int num_bits) {
+          return tensorflow::pyo_or_throw(
+              tensorflow::DoQuantizeTrainingOnGraphDefHelper(
+                  input_graph.cast<std::string>(), num_bits));
+        });
+};
@@ -110,7 +110,6 @@ from tensorflow.python.training.training_util import create_global_step
 from tensorflow.python.training.training_util import get_or_create_global_step
 from tensorflow.python.training.warm_starting_util import VocabInfo
 from tensorflow.python.training.warm_starting_util import warm_start
-from tensorflow.python.pywrap_tensorflow import do_quantize_training_on_graphdef
 from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
 from tensorflow.python.util.tf_export import tf_export
 
@@ -146,3 +145,4 @@ tf_export(v1=["train.SaverDef"])(SaverDef)
 tf_export("train.SequenceExample")(SequenceExample)
 tf_export("train.ServerDef")(ServerDef)
 # pylint: enable=undefined-variable
@@ -69,6 +69,7 @@ tensorflow::Status::ok()
 tensorflow::Device::attributes
 tensorflow::DeviceFactory::AddDevices
 tensorflow::SessionOptions::SessionOptions
+tensorflow::DoQuantizeTrainingOnSerializedGraphDef
 
 [protos_all] # device_lib
 tensorflow::ConfigProto::ConfigProto
@@ -81,3 +82,4 @@ tensorflow::PyExceptionRegistry::Lookup
 
 [kernel_registry] # kernel_registry
 tensorflow::swig::TryFindKernelClass