Export the bfloat16 classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 283439638 Change-Id: I8ca8e5e4835995f78b8b1d78036a98de444508d3
This commit is contained in:
		
							parent
							
								
									71681bd691
								
							
						
					
					
						commit
						a46fa0b405
					
				| @ -400,6 +400,17 @@ cc_library( | |||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | tf_python_pybind_extension( | ||||||
|  |     name = "_pywrap_bfloat16", | ||||||
|  |     srcs = ["lib/core/bfloat16_wrapper.cc"], | ||||||
|  |     hdrs = ["lib/core/bfloat16.h"], | ||||||
|  |     module_name = "_pywrap_bfloat16", | ||||||
|  |     deps = [ | ||||||
|  |         "//third_party/python_runtime:headers", | ||||||
|  |         "@pybind11", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| cc_library( | cc_library( | ||||||
|     name = "ndarray_tensor_bridge", |     name = "ndarray_tensor_bridge", | ||||||
|     srcs = ["lib/core/ndarray_tensor_bridge.cc"], |     srcs = ["lib/core/ndarray_tensor_bridge.cc"], | ||||||
| @ -1158,6 +1169,7 @@ py_library( | |||||||
|     srcs_version = "PY2AND3", |     srcs_version = "PY2AND3", | ||||||
|     deps = [ |     deps = [ | ||||||
|         ":_dtypes", |         ":_dtypes", | ||||||
|  |         ":_pywrap_bfloat16", | ||||||
|         ":pywrap_tensorflow", |         ":pywrap_tensorflow", | ||||||
|         "//tensorflow/core:protos_all_py", |         "//tensorflow/core:protos_all_py", | ||||||
|     ], |     ], | ||||||
| @ -5442,7 +5454,6 @@ tf_py_wrap_cc( | |||||||
|         "grappler/cost_analyzer.i", |         "grappler/cost_analyzer.i", | ||||||
|         "grappler/item.i", |         "grappler/item.i", | ||||||
|         "grappler/tf_optimizer.i", |         "grappler/tf_optimizer.i", | ||||||
|         "lib/core/bfloat16.i", |  | ||||||
|         "lib/core/strings.i", |         "lib/core/strings.i", | ||||||
|         "lib/io/file_io.i", |         "lib/io/file_io.i", | ||||||
|         "lib/io/py_record_reader.i", |         "lib/io/py_record_reader.i", | ||||||
| @ -5528,6 +5539,7 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [ | |||||||
|     ":numpy_lib",  # checkpoint_reader |     ":numpy_lib",  # checkpoint_reader | ||||||
|     ":safe_ptr",  # checkpoint_reader |     ":safe_ptr",  # checkpoint_reader | ||||||
|     ":python_op_gen",  # python_op_gen |     ":python_op_gen",  # python_op_gen | ||||||
|  |     ":bfloat16_lib",  # bfloat16 | ||||||
|     "//tensorflow/core/util/tensor_bundle",  # checkpoint_reader |     "//tensorflow/core/util/tensor_bundle",  # checkpoint_reader | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -21,14 +21,15 @@ import numpy as np | |||||||
| from six.moves import builtins | from six.moves import builtins | ||||||
| 
 | 
 | ||||||
| from tensorflow.core.framework import types_pb2 | from tensorflow.core.framework import types_pb2 | ||||||
| # pywrap_tensorflow must be imported prior to _dtypes for the MacOS linker | # We need to import pywrap_tensorflow prior to the bfloat wrapper to avoid | ||||||
| # to resolve the protobufs properly. | # protobuf errors where a file is defined twice on MacOS. | ||||||
| # pylint: disable=unused-import,g-bad-import-order | # pylint: disable=invalid-import-order,g-bad-import-order | ||||||
| from tensorflow.python import pywrap_tensorflow | from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import | ||||||
|  | from tensorflow.python import _pywrap_bfloat16 | ||||||
| from tensorflow.python import _dtypes | from tensorflow.python import _dtypes | ||||||
| from tensorflow.python.util.tf_export import tf_export | from tensorflow.python.util.tf_export import tf_export | ||||||
| 
 | 
 | ||||||
| _np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type() | _np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # pylint: disable=slots-on-old-class | # pylint: disable=slots-on-old-class | ||||||
|  | |||||||
| @ -532,7 +532,9 @@ struct Bfloat16GeFunctor { | |||||||
| 
 | 
 | ||||||
| // Initializes the module.
 | // Initializes the module.
 | ||||||
| bool Initialize() { | bool Initialize() { | ||||||
|   // It's critical to import umath to avoid crash in open source build.
 |   // It's critical to ImportNumpy and import umath
 | ||||||
|  |   // to avoid crash in open source build.
 | ||||||
|  |   ImportNumpy(); | ||||||
|   import_umath1(false); |   import_umath1(false); | ||||||
| 
 | 
 | ||||||
|   Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy")); |   Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy")); | ||||||
|  | |||||||
| @ -24,12 +24,12 @@ import math | |||||||
| import numpy as np | import numpy as np | ||||||
| 
 | 
 | ||||||
| # pylint: disable=unused-import,g-bad-import-order | # pylint: disable=unused-import,g-bad-import-order | ||||||
| from tensorflow.python import pywrap_tensorflow | from tensorflow.python import _pywrap_bfloat16 | ||||||
| from tensorflow.python.framework import dtypes | from tensorflow.python.framework import dtypes | ||||||
| from tensorflow.python.platform import test | from tensorflow.python.platform import test | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| bfloat16 = pywrap_tensorflow.TF_bfloat16_type() | bfloat16 = _pywrap_bfloat16.TF_bfloat16_type() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Bfloat16Test(test.TestCase): | class Bfloat16Test(test.TestCase): | ||||||
|  | |||||||
| @ -1,4 +1,4 @@ | |||||||
| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | ||||||
| 
 | 
 | ||||||
| Licensed under the Apache License, Version 2.0 (the "License"); | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| you may not use this file except in compliance with the License. | you may not use this file except in compliance with the License. | ||||||
| @ -13,18 +13,12 @@ See the License for the specific language governing permissions and | |||||||
| limitations under the License. | limitations under the License. | ||||||
| ==============================================================================*/ | ==============================================================================*/ | ||||||
| 
 | 
 | ||||||
| %{ | #include "include/pybind11/pybind11.h" | ||||||
| #include "tensorflow/python/lib/core/bfloat16.h" | #include "tensorflow/python/lib/core/bfloat16.h" | ||||||
| %} |  | ||||||
| 
 | 
 | ||||||
| %init %{ | PYBIND11_MODULE(_pywrap_bfloat16, m) { | ||||||
| tensorflow::RegisterNumpyBfloat16(); |   tensorflow::RegisterNumpyBfloat16(); | ||||||
| %} |  | ||||||
| 
 | 
 | ||||||
| %{ |   m.def("TF_bfloat16_type", | ||||||
| PyObject* TF_bfloat16_type() { |         [] { return pybind11::handle(tensorflow::Bfloat16PyType()); }); | ||||||
|   return tensorflow::Bfloat16PyType(); |  | ||||||
| } | } | ||||||
| %} |  | ||||||
| 
 |  | ||||||
| PyObject* TF_bfloat16_type(); |  | ||||||
| @ -21,8 +21,6 @@ limitations under the License. | |||||||
| 
 | 
 | ||||||
| %include "tensorflow/python/client/tf_session.i" | %include "tensorflow/python/client/tf_session.i" | ||||||
| 
 | 
 | ||||||
| %include "tensorflow/python/lib/core/bfloat16.i" |  | ||||||
| 
 |  | ||||||
| %include "tensorflow/python/lib/io/file_io.i" | %include "tensorflow/python/lib/io/file_io.i" | ||||||
| 
 | 
 | ||||||
| %include "tensorflow/python/lib/io/py_record_reader.i" | %include "tensorflow/python/lib/io/py_record_reader.i" | ||||||
|  | |||||||
| @ -18,7 +18,6 @@ from __future__ import absolute_import | |||||||
| from __future__ import division | from __future__ import division | ||||||
| from __future__ import print_function | from __future__ import print_function | ||||||
| 
 | 
 | ||||||
| from tensorflow.python import pywrap_tensorflow |  | ||||||
| from tensorflow.python.eager import context | from tensorflow.python.eager import context | ||||||
| from tensorflow.python.framework import constant_op | from tensorflow.python.framework import constant_op | ||||||
| from tensorflow.python.framework import dtypes | from tensorflow.python.framework import dtypes | ||||||
| @ -131,7 +130,6 @@ class MovingAveragesTest(test.TestCase): | |||||||
| 
 | 
 | ||||||
|   @test_util.deprecated_graph_mode_only |   @test_util.deprecated_graph_mode_only | ||||||
|   def testWeightedMovingAverageBfloat16(self): |   def testWeightedMovingAverageBfloat16(self): | ||||||
|     bfloat16 = pywrap_tensorflow.TF_bfloat16_type() |  | ||||||
|     with self.cached_session() as sess: |     with self.cached_session() as sess: | ||||||
|       decay = 0.5 |       decay = 0.5 | ||||||
|       weight = array_ops.placeholder(dtypes.bfloat16, []) |       weight = array_ops.placeholder(dtypes.bfloat16, []) | ||||||
| @ -154,7 +152,8 @@ class MovingAveragesTest(test.TestCase): | |||||||
|       wma_array = sess.run(wma, feed_dict={val: val_2, weight: weight_2}) |       wma_array = sess.run(wma, feed_dict={val: val_2, weight: weight_2}) | ||||||
|       numerator_2 = numerator_1 * decay + val_2 * weight_2 * (1.0 - decay) |       numerator_2 = numerator_1 * decay + val_2 * weight_2 * (1.0 - decay) | ||||||
|       denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay) |       denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay) | ||||||
|       self.assertAllClose(bfloat16(numerator_2 / denominator_2), wma_array) |       self.assertAllClose( | ||||||
|  |           dtypes._np_bfloat16(numerator_2 / denominator_2), wma_array) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _Repeat(value, dim): | def _Repeat(value, dim): | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user