Export the quantize_training functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 275945771 Change-Id: I77ed4389903b24d3e0fe12ed83aa441099b2b9bf
This commit is contained in:
parent
b0a7160413
commit
7bf81e0bce
@ -351,6 +351,17 @@ filegroup(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "quantize_training_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"graph/quantize_training.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/core:__pkg__",
|
||||||
|
"//tensorflow/python:__pkg__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "util_port",
|
name = "util_port",
|
||||||
srcs = ["util/port.cc"],
|
srcs = ["util/port.cc"],
|
||||||
|
@ -101,6 +101,7 @@ py_library(
|
|||||||
":_pywrap_events_writer",
|
":_pywrap_events_writer",
|
||||||
":_pywrap_kernel_registry",
|
":_pywrap_kernel_registry",
|
||||||
":_pywrap_py_exception_registry",
|
":_pywrap_py_exception_registry",
|
||||||
|
":_pywrap_quantize_training",
|
||||||
":_pywrap_stat_summarizer",
|
":_pywrap_stat_summarizer",
|
||||||
":_pywrap_tfprof",
|
":_pywrap_tfprof",
|
||||||
":_pywrap_util_port",
|
":_pywrap_util_port",
|
||||||
@ -474,6 +475,27 @@ tf_python_pybind_extension(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_quantize_training",
|
||||||
|
srcs = [
|
||||||
|
"training/quantize_training_wrapper.cc",
|
||||||
|
],
|
||||||
|
hdrs = ["//tensorflow/core:quantize_training_hdrs"],
|
||||||
|
module_name = "_pywrap_quantize_training",
|
||||||
|
deps = [
|
||||||
|
":pybind11_lib",
|
||||||
|
":pybind11_proto",
|
||||||
|
":pybind11_status",
|
||||||
|
"//tensorflow/core:core_cpu_headers_lib",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_python_pybind_extension(
|
tf_python_pybind_extension(
|
||||||
name = "_pywrap_stat_summarizer",
|
name = "_pywrap_stat_summarizer",
|
||||||
srcs = ["util/stat_summarizer_wrapper.cc"],
|
srcs = ["util/stat_summarizer_wrapper.cc"],
|
||||||
@ -890,6 +912,7 @@ py_library(
|
|||||||
":_pywrap_kernel_registry",
|
":_pywrap_kernel_registry",
|
||||||
":_pywrap_py_exception_registry",
|
":_pywrap_py_exception_registry",
|
||||||
":_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed.
|
":_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed.
|
||||||
|
":_pywrap_quantize_training",
|
||||||
":_pywrap_stat_summarizer",
|
":_pywrap_stat_summarizer",
|
||||||
":_pywrap_tfprof",
|
":_pywrap_tfprof",
|
||||||
":_pywrap_util_port",
|
":_pywrap_util_port",
|
||||||
@ -5157,7 +5180,6 @@ tf_py_wrap_cc(
|
|||||||
"platform/base.i",
|
"platform/base.i",
|
||||||
"platform/stacktrace_handler.i",
|
"platform/stacktrace_handler.i",
|
||||||
"pywrap_tfe.i",
|
"pywrap_tfe.i",
|
||||||
"training/quantize_training.i",
|
|
||||||
"util/py_checkpoint_reader.i",
|
"util/py_checkpoint_reader.i",
|
||||||
"util/scoped_annotation.i",
|
"util/scoped_annotation.i",
|
||||||
"util/traceme.i",
|
"util/traceme.i",
|
||||||
|
@ -54,6 +54,7 @@ from tensorflow.python import _pywrap_util_port
|
|||||||
from tensorflow.python import _pywrap_stat_summarizer
|
from tensorflow.python import _pywrap_stat_summarizer
|
||||||
from tensorflow.python import _pywrap_py_exception_registry
|
from tensorflow.python import _pywrap_py_exception_registry
|
||||||
from tensorflow.python import _pywrap_kernel_registry
|
from tensorflow.python import _pywrap_kernel_registry
|
||||||
|
from tensorflow.python import _pywrap_quantize_training
|
||||||
|
|
||||||
# Protocol buffers
|
# Protocol buffers
|
||||||
from tensorflow.core.framework.graph_pb2 import *
|
from tensorflow.core.framework.graph_pb2 import *
|
||||||
@ -120,6 +121,7 @@ from tensorflow.python.ops import gen_sendrecv_ops
|
|||||||
|
|
||||||
# Import the names from python/training.py as train.Name.
|
# Import the names from python/training.py as train.Name.
|
||||||
from tensorflow.python.training import training as train
|
from tensorflow.python.training import training as train
|
||||||
|
from tensorflow.python.training import quantize_training as _quantize_training
|
||||||
|
|
||||||
# Sub-package for performing i/o directly instead of via ops in a graph.
|
# Sub-package for performing i/o directly instead of via ops in a graph.
|
||||||
from tensorflow.python.lib.io import python_io
|
from tensorflow.python.lib.io import python_io
|
||||||
|
@ -32,7 +32,6 @@ limitations under the License.
|
|||||||
%include "tensorflow/lite/toco/python/toco.i"
|
%include "tensorflow/lite/toco/python/toco.i"
|
||||||
|
|
||||||
%include "tensorflow/python/lib/io/file_io.i"
|
%include "tensorflow/python/lib/io/file_io.i"
|
||||||
%include "tensorflow/python/training/quantize_training.i"
|
|
||||||
|
|
||||||
%include "tensorflow/python/framework/python_op_gen.i"
|
%include "tensorflow/python/framework/python_op_gen.i"
|
||||||
|
|
||||||
|
@ -1,87 +0,0 @@
|
|||||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
%include "tensorflow/python/platform/base.i"
|
|
||||||
|
|
||||||
%{
|
|
||||||
#include "tensorflow/core/graph/quantize_training.h"
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
|
||||||
|
|
||||||
static PyObject* DoQuantizeTrainingOnGraphDefHelper(
|
|
||||||
const string& input_graph,
|
|
||||||
int num_bits,
|
|
||||||
TF_Status* status) {
|
|
||||||
string result;
|
|
||||||
// TODO(suharshs): Make the QuantizeAndDequantizeV2 configurable.
|
|
||||||
tensorflow::Status s =
|
|
||||||
tensorflow::DoQuantizeTrainingOnSerializedGraphDef(input_graph, num_bits,
|
|
||||||
"QuantizeAndDequantizeV2", &result);
|
|
||||||
if (!s.ok()) {
|
|
||||||
Set_TF_Status_from_Status(status, s);
|
|
||||||
Py_RETURN_NONE;
|
|
||||||
}
|
|
||||||
PyObject* py_str = PyBytes_FromStringAndSize(result.data(), result.size());
|
|
||||||
if (!py_str) {
|
|
||||||
Set_TF_Status_from_Status(status,
|
|
||||||
tensorflow::Status(tensorflow::error::INTERNAL,
|
|
||||||
"Failed to generate serialized string of the rewritten graph."));
|
|
||||||
Py_RETURN_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
return py_str;
|
|
||||||
}
|
|
||||||
%}
|
|
||||||
|
|
||||||
%ignoreall
|
|
||||||
%unignore DoQuantizeTrainingOnGraphDefHelper;
|
|
||||||
|
|
||||||
// Wrap this function
|
|
||||||
PyObject* DoQuantizeTrainingOnGraphDefHelper(
|
|
||||||
const string& input_graph,
|
|
||||||
int num_bits,
|
|
||||||
TF_Status* status);
|
|
||||||
|
|
||||||
|
|
||||||
%insert("python") %{
|
|
||||||
from tensorflow.python.util import deprecation
|
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
|
||||||
|
|
||||||
@deprecation.deprecated(
|
|
||||||
None,
|
|
||||||
"GraphDef quantized training rewriter is deprecated in the long term")
|
|
||||||
@tf_export(v1=["train.do_quantize_training_on_graphdef"])
|
|
||||||
def do_quantize_training_on_graphdef(input_graph, num_bits):
|
|
||||||
"""A general quantization scheme is being developed in `tf.contrib.quantize`.
|
|
||||||
|
|
||||||
Consider using that instead, though since it is in the tf.contrib namespace,
|
|
||||||
it is not subject to backward compatibility guarantees.
|
|
||||||
"""
|
|
||||||
from tensorflow.core.framework.graph_pb2 import GraphDef
|
|
||||||
from tensorflow.python.framework import errors
|
|
||||||
|
|
||||||
graph = GraphDef()
|
|
||||||
result_graph_string = DoQuantizeTrainingOnGraphDefHelper(
|
|
||||||
input_graph.SerializeToString(), num_bits)
|
|
||||||
|
|
||||||
graph.ParseFromString(result_graph_string)
|
|
||||||
return graph
|
|
||||||
|
|
||||||
do_quantize_training_on_graphdef._tf_api_names = [
|
|
||||||
'train.do_quantize_training_on_graphdef']
|
|
||||||
do_quantize_training_on_graphdef._tf_api_names_v1 = [
|
|
||||||
'train.do_quantize_training_on_graphdef']
|
|
||||||
%}
|
|
||||||
|
|
||||||
%unignoreall
|
|
50
tensorflow/python/training/quantize_training.py
Normal file
50
tensorflow/python/training/quantize_training.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Quantize training for TensorFlow."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.core.framework import graph_pb2
|
||||||
|
from tensorflow.python._pywrap_quantize_training import DoQuantizeTrainingOnGraphDefHelper
|
||||||
|
from tensorflow.python.util import deprecation
|
||||||
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
|
# Migrated this python code from deprecated quantize_training.i
|
||||||
|
@deprecation.deprecated(
|
||||||
|
None,
|
||||||
|
"GraphDef quantized training rewriter is deprecated in the long term.")
|
||||||
|
@tf_export(v1=["train.do_quantize_training_on_graphdef"])
|
||||||
|
def do_quantize_training_on_graphdef(input_graph, num_bits):
|
||||||
|
"""A general quantization scheme is being developed in `tf.contrib.quantize`.
|
||||||
|
|
||||||
|
Consider using that instead, though since it is in the tf.contrib namespace,
|
||||||
|
it is not subject to backward compatibility guarantees.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_graph: A `GraphDef`.
|
||||||
|
num_bits: The number of bits for quantize training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The graph with quantize training done.
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph = graph_pb2.GraphDef()
|
||||||
|
result_graph_string = DoQuantizeTrainingOnGraphDefHelper(
|
||||||
|
input_graph.SerializeToString(), num_bits)
|
||||||
|
|
||||||
|
graph.ParseFromString(result_graph_string)
|
||||||
|
return graph
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import importer
|
from tensorflow.python.framework import importer
|
||||||
@ -29,6 +28,7 @@ from tensorflow.python.framework import test_util
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.training import quantize_training
|
||||||
from tensorflow.python.training import saver as saver_module
|
from tensorflow.python.training import saver as saver_module
|
||||||
|
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ class PywrapQuantizeTrainingTest(test.TestCase):
|
|||||||
self.assertEquals(c.eval(), 42.0)
|
self.assertEquals(c.eval(), 42.0)
|
||||||
self.assertEquals(len(sess.graph_def.node), 3)
|
self.assertEquals(len(sess.graph_def.node), 3)
|
||||||
|
|
||||||
result = pywrap_tensorflow.do_quantize_training_on_graphdef(
|
result = quantize_training.do_quantize_training_on_graphdef(
|
||||||
sess.graph_def, 8)
|
sess.graph_def, 8)
|
||||||
|
|
||||||
# We just want to guarantee that some rewrite happened.
|
# We just want to guarantee that some rewrite happened.
|
||||||
@ -68,7 +68,7 @@ class PywrapQuantizeTrainingTest(test.TestCase):
|
|||||||
|
|
||||||
saver = saver_module.Saver({'b': b})
|
saver = saver_module.Saver({'b': b})
|
||||||
|
|
||||||
result = pywrap_tensorflow.do_quantize_training_on_graphdef(
|
result = quantize_training.do_quantize_training_on_graphdef(
|
||||||
sess.graph_def, 8)
|
sess.graph_def, 8)
|
||||||
|
|
||||||
with ops.Graph().as_default() as g, session.Session(graph=g) as sess:
|
with ops.Graph().as_default() as g, session.Session(graph=g) as sess:
|
||||||
|
48
tensorflow/python/training/quantize_training_wrapper.cc
Normal file
48
tensorflow/python/training/quantize_training_wrapper.cc
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "include/pybind11/pybind11.h"
|
||||||
|
#include "tensorflow/core/graph/quantize_training.h"
|
||||||
|
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||||
|
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
static PyObject* DoQuantizeTrainingOnGraphDefHelper(const string& input_graph,
|
||||||
|
int num_bits) {
|
||||||
|
string result;
|
||||||
|
// TODO(suharshs): Make the QuantizeAndDequantizeV2 configurable.
|
||||||
|
tensorflow::MaybeRaiseFromStatus(
|
||||||
|
tensorflow::DoQuantizeTrainingOnSerializedGraphDef(
|
||||||
|
input_graph, num_bits, "QuantizeAndDequantizeV2", &result));
|
||||||
|
|
||||||
|
PyObject* py_str = PyBytes_FromStringAndSize(result.data(), result.size());
|
||||||
|
if (!py_str) {
|
||||||
|
tensorflow::MaybeRaiseFromStatus(tensorflow::errors::Internal(
|
||||||
|
"Failed to generate serialized string of the rewritten graph."));
|
||||||
|
}
|
||||||
|
return py_str;
|
||||||
|
}
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_pywrap_quantize_training, m) {
|
||||||
|
m.def("DoQuantizeTrainingOnGraphDefHelper",
|
||||||
|
[](const py::object input_graph, int num_bits) {
|
||||||
|
return tensorflow::pyo_or_throw(
|
||||||
|
tensorflow::DoQuantizeTrainingOnGraphDefHelper(
|
||||||
|
input_graph.cast<std::string>(), num_bits));
|
||||||
|
});
|
||||||
|
};
|
@ -110,7 +110,6 @@ from tensorflow.python.training.training_util import create_global_step
|
|||||||
from tensorflow.python.training.training_util import get_or_create_global_step
|
from tensorflow.python.training.training_util import get_or_create_global_step
|
||||||
from tensorflow.python.training.warm_starting_util import VocabInfo
|
from tensorflow.python.training.warm_starting_util import VocabInfo
|
||||||
from tensorflow.python.training.warm_starting_util import warm_start
|
from tensorflow.python.training.warm_starting_util import warm_start
|
||||||
from tensorflow.python.pywrap_tensorflow import do_quantize_training_on_graphdef
|
|
||||||
from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
|
from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
@ -146,3 +145,4 @@ tf_export(v1=["train.SaverDef"])(SaverDef)
|
|||||||
tf_export("train.SequenceExample")(SequenceExample)
|
tf_export("train.SequenceExample")(SequenceExample)
|
||||||
tf_export("train.ServerDef")(ServerDef)
|
tf_export("train.ServerDef")(ServerDef)
|
||||||
# pylint: enable=undefined-variable
|
# pylint: enable=undefined-variable
|
||||||
|
|
||||||
|
@ -69,6 +69,7 @@ tensorflow::Status::ok()
|
|||||||
tensorflow::Device::attributes
|
tensorflow::Device::attributes
|
||||||
tensorflow::DeviceFactory::AddDevices
|
tensorflow::DeviceFactory::AddDevices
|
||||||
tensorflow::SessionOptions::SessionOptions
|
tensorflow::SessionOptions::SessionOptions
|
||||||
|
tensorflow::DoQuantizeTrainingOnSerializedGraphDef
|
||||||
|
|
||||||
[protos_all] # device_lib
|
[protos_all] # device_lib
|
||||||
tensorflow::ConfigProto::ConfigProto
|
tensorflow::ConfigProto::ConfigProto
|
||||||
@ -81,3 +82,4 @@ tensorflow::PyExceptionRegistry::Lookup
|
|||||||
|
|
||||||
[kernel_registry] # kernel_registry
|
[kernel_registry] # kernel_registry
|
||||||
tensorflow::swig::TryFindKernelClass
|
tensorflow::swig::TryFindKernelClass
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user