diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 9f889859522..3ec93e26910 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -351,6 +351,17 @@ filegroup( ], ) +filegroup( + name = "quantize_training_hdrs", + srcs = [ + "graph/quantize_training.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) + cc_library( name = "util_port", srcs = ["util/port.cc"], diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index cca367fef4c..4b2738725d5 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -101,6 +101,7 @@ py_library( ":_pywrap_events_writer", ":_pywrap_kernel_registry", ":_pywrap_py_exception_registry", + ":_pywrap_quantize_training", ":_pywrap_stat_summarizer", ":_pywrap_tfprof", ":_pywrap_util_port", @@ -474,6 +475,27 @@ tf_python_pybind_extension( ], ) +tf_python_pybind_extension( + name = "_pywrap_quantize_training", + srcs = [ + "training/quantize_training_wrapper.cc", + ], + hdrs = ["//tensorflow/core:quantize_training_hdrs"], + module_name = "_pywrap_quantize_training", + deps = [ + ":pybind11_lib", + ":pybind11_proto", + ":pybind11_status", + "//tensorflow/core:core_cpu_headers_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//third_party/python_runtime:headers", + "@com_google_absl//absl/strings", + "@pybind11", + ], +) + tf_python_pybind_extension( name = "_pywrap_stat_summarizer", srcs = ["util/stat_summarizer_wrapper.cc"], @@ -890,6 +912,7 @@ py_library( ":_pywrap_kernel_registry", ":_pywrap_py_exception_registry", ":_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed. + ":_pywrap_quantize_training", ":_pywrap_stat_summarizer", ":_pywrap_tfprof", ":_pywrap_util_port", @@ -5157,7 +5180,6 @@ tf_py_wrap_cc( "platform/base.i", "platform/stacktrace_handler.i", "pywrap_tfe.i", - "training/quantize_training.i", "util/py_checkpoint_reader.i", "util/scoped_annotation.i", "util/traceme.i", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 88c0ce90e3d..34127bb05c1 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -54,6 +54,7 @@ from tensorflow.python import _pywrap_util_port from tensorflow.python import _pywrap_stat_summarizer from tensorflow.python import _pywrap_py_exception_registry from tensorflow.python import _pywrap_kernel_registry +from tensorflow.python import _pywrap_quantize_training # Protocol buffers from tensorflow.core.framework.graph_pb2 import * @@ -120,6 +121,7 @@ from tensorflow.python.ops import gen_sendrecv_ops # Import the names from python/training.py as train.Name. from tensorflow.python.training import training as train +from tensorflow.python.training import quantize_training as _quantize_training # Sub-package for performing i/o directly instead of via ops in a graph. from tensorflow.python.lib.io import python_io diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index 56e859c2e2f..75629bc2abe 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -32,7 +32,6 @@ limitations under the License. %include "tensorflow/lite/toco/python/toco.i" %include "tensorflow/python/lib/io/file_io.i" -%include "tensorflow/python/training/quantize_training.i" %include "tensorflow/python/framework/python_op_gen.i" diff --git a/tensorflow/python/training/quantize_training.i b/tensorflow/python/training/quantize_training.i deleted file mode 100644 index dd4eb10d3c2..00000000000 --- a/tensorflow/python/training/quantize_training.i +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -%include "tensorflow/python/platform/base.i" - -%{ -#include "tensorflow/core/graph/quantize_training.h" -#include "tensorflow/core/lib/core/status.h" - -static PyObject* DoQuantizeTrainingOnGraphDefHelper( - const string& input_graph, - int num_bits, - TF_Status* status) { - string result; - // TODO(suharshs): Make the QuantizeAndDequantizeV2 configurable. - tensorflow::Status s = - tensorflow::DoQuantizeTrainingOnSerializedGraphDef(input_graph, num_bits, - "QuantizeAndDequantizeV2", &result); - if (!s.ok()) { - Set_TF_Status_from_Status(status, s); - Py_RETURN_NONE; - } - PyObject* py_str = PyBytes_FromStringAndSize(result.data(), result.size()); - if (!py_str) { - Set_TF_Status_from_Status(status, - tensorflow::Status(tensorflow::error::INTERNAL, - "Failed to generate serialized string of the rewritten graph.")); - Py_RETURN_NONE; - } - - return py_str; -} -%} - -%ignoreall -%unignore DoQuantizeTrainingOnGraphDefHelper; - -// Wrap this function -PyObject* DoQuantizeTrainingOnGraphDefHelper( - const string& input_graph, - int num_bits, - TF_Status* status); - - -%insert("python") %{ -from tensorflow.python.util import deprecation -from tensorflow.python.util.tf_export import tf_export - -@deprecation.deprecated( - None, - "GraphDef quantized training rewriter is deprecated in the long term") -@tf_export(v1=["train.do_quantize_training_on_graphdef"]) -def do_quantize_training_on_graphdef(input_graph, num_bits): - """A general quantization scheme is being developed in `tf.contrib.quantize`. - - Consider using that instead, though since it is in the tf.contrib namespace, - it is not subject to backward compatibility guarantees. - """ - from tensorflow.core.framework.graph_pb2 import GraphDef - from tensorflow.python.framework import errors - - graph = GraphDef() - result_graph_string = DoQuantizeTrainingOnGraphDefHelper( - input_graph.SerializeToString(), num_bits) - - graph.ParseFromString(result_graph_string) - return graph - -do_quantize_training_on_graphdef._tf_api_names = [ - 'train.do_quantize_training_on_graphdef'] -do_quantize_training_on_graphdef._tf_api_names_v1 = [ - 'train.do_quantize_training_on_graphdef'] -%} - -%unignoreall diff --git a/tensorflow/python/training/quantize_training.py b/tensorflow/python/training/quantize_training.py new file mode 100644 index 00000000000..f6f0456d90d --- /dev/null +++ b/tensorflow/python/training/quantize_training.py @@ -0,0 +1,50 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Quantize training for TensorFlow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.framework import graph_pb2 +from tensorflow.python._pywrap_quantize_training import DoQuantizeTrainingOnGraphDefHelper +from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export + + +# Migrated this python code from deprecated quantize_training.i +@deprecation.deprecated( + None, + "GraphDef quantized training rewriter is deprecated in the long term.") +@tf_export(v1=["train.do_quantize_training_on_graphdef"]) +def do_quantize_training_on_graphdef(input_graph, num_bits): + """A general quantization scheme is being developed in `tf.contrib.quantize`. + + Consider using that instead, though since it is in the tf.contrib namespace, + it is not subject to backward compatibility guarantees. + + Args: + input_graph: A `GraphDef`. + num_bits: The number of bits for quantize training. + + Returns: + The graph with quantize training done. + """ + + graph = graph_pb2.GraphDef() + result_graph_string = DoQuantizeTrainingOnGraphDefHelper( + input_graph.SerializeToString(), num_bits) + + graph.ParseFromString(result_graph_string) + return graph diff --git a/tensorflow/python/training/quantize_training_test.py b/tensorflow/python/training/quantize_training_test.py index 2352af7e99b..813ea2416bf 100644 --- a/tensorflow/python/training/quantize_training_test.py +++ b/tensorflow/python/training/quantize_training_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import os -from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import importer @@ -29,6 +28,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import quantize_training from tensorflow.python.training import saver as saver_module @@ -45,7 +45,7 @@ class PywrapQuantizeTrainingTest(test.TestCase): self.assertEquals(c.eval(), 42.0) self.assertEquals(len(sess.graph_def.node), 3) - result = pywrap_tensorflow.do_quantize_training_on_graphdef( + result = quantize_training.do_quantize_training_on_graphdef( sess.graph_def, 8) # We just want to guarantee that some rewrite happened. @@ -68,7 +68,7 @@ class PywrapQuantizeTrainingTest(test.TestCase): saver = saver_module.Saver({'b': b}) - result = pywrap_tensorflow.do_quantize_training_on_graphdef( + result = quantize_training.do_quantize_training_on_graphdef( sess.graph_def, 8) with ops.Graph().as_default() as g, session.Session(graph=g) as sess: diff --git a/tensorflow/python/training/quantize_training_wrapper.cc b/tensorflow/python/training/quantize_training_wrapper.cc new file mode 100644 index 00000000000..f4173553ed6 --- /dev/null +++ b/tensorflow/python/training/quantize_training_wrapper.cc @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "include/pybind11/pybind11.h" +#include "tensorflow/core/graph/quantize_training.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +namespace py = pybind11; + +namespace tensorflow { +static PyObject* DoQuantizeTrainingOnGraphDefHelper(const string& input_graph, + int num_bits) { + string result; + // TODO(suharshs): Make the QuantizeAndDequantizeV2 configurable. + tensorflow::MaybeRaiseFromStatus( + tensorflow::DoQuantizeTrainingOnSerializedGraphDef( + input_graph, num_bits, "QuantizeAndDequantizeV2", &result)); + + PyObject* py_str = PyBytes_FromStringAndSize(result.data(), result.size()); + if (!py_str) { + tensorflow::MaybeRaiseFromStatus(tensorflow::errors::Internal( + "Failed to generate serialized string of the rewritten graph.")); + } + return py_str; +} +} // namespace tensorflow + +PYBIND11_MODULE(_pywrap_quantize_training, m) { + m.def("DoQuantizeTrainingOnGraphDefHelper", + [](const py::object input_graph, int num_bits) { + return tensorflow::pyo_or_throw( + tensorflow::DoQuantizeTrainingOnGraphDefHelper( + input_graph.cast(), num_bits)); + }); +}; diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index 50acc0882b7..287b36bb615 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -110,7 +110,6 @@ from tensorflow.python.training.training_util import create_global_step from tensorflow.python.training.training_util import get_or_create_global_step from tensorflow.python.training.warm_starting_util import VocabInfo from tensorflow.python.training.warm_starting_util import warm_start -from tensorflow.python.pywrap_tensorflow import do_quantize_training_on_graphdef from tensorflow.python.pywrap_tensorflow import NewCheckpointReader from tensorflow.python.util.tf_export import tf_export @@ -146,3 +145,4 @@ tf_export(v1=["train.SaverDef"])(SaverDef) tf_export("train.SequenceExample")(SequenceExample) tf_export("train.ServerDef")(ServerDef) # pylint: enable=undefined-variable + diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 86cbcb0667a..072c449525c 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -69,6 +69,7 @@ tensorflow::Status::ok() tensorflow::Device::attributes tensorflow::DeviceFactory::AddDevices tensorflow::SessionOptions::SessionOptions +tensorflow::DoQuantizeTrainingOnSerializedGraphDef [protos_all] # device_lib tensorflow::ConfigProto::ConfigProto @@ -81,3 +82,4 @@ tensorflow::PyExceptionRegistry::Lookup [kernel_registry] # kernel_registry tensorflow::swig::TryFindKernelClass +