Export the bfloat16 classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 283439638 Change-Id: I8ca8e5e4835995f78b8b1d78036a98de444508d3
This commit is contained in:
parent
71681bd691
commit
a46fa0b405
@ -400,6 +400,17 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_bfloat16",
|
||||||
|
srcs = ["lib/core/bfloat16_wrapper.cc"],
|
||||||
|
hdrs = ["lib/core/bfloat16.h"],
|
||||||
|
module_name = "_pywrap_bfloat16",
|
||||||
|
deps = [
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "ndarray_tensor_bridge",
|
name = "ndarray_tensor_bridge",
|
||||||
srcs = ["lib/core/ndarray_tensor_bridge.cc"],
|
srcs = ["lib/core/ndarray_tensor_bridge.cc"],
|
||||||
@ -1158,6 +1169,7 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":_dtypes",
|
":_dtypes",
|
||||||
|
":_pywrap_bfloat16",
|
||||||
":pywrap_tensorflow",
|
":pywrap_tensorflow",
|
||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
],
|
],
|
||||||
@ -5442,7 +5454,6 @@ tf_py_wrap_cc(
|
|||||||
"grappler/cost_analyzer.i",
|
"grappler/cost_analyzer.i",
|
||||||
"grappler/item.i",
|
"grappler/item.i",
|
||||||
"grappler/tf_optimizer.i",
|
"grappler/tf_optimizer.i",
|
||||||
"lib/core/bfloat16.i",
|
|
||||||
"lib/core/strings.i",
|
"lib/core/strings.i",
|
||||||
"lib/io/file_io.i",
|
"lib/io/file_io.i",
|
||||||
"lib/io/py_record_reader.i",
|
"lib/io/py_record_reader.i",
|
||||||
@ -5528,6 +5539,7 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
|
|||||||
":numpy_lib", # checkpoint_reader
|
":numpy_lib", # checkpoint_reader
|
||||||
":safe_ptr", # checkpoint_reader
|
":safe_ptr", # checkpoint_reader
|
||||||
":python_op_gen", # python_op_gen
|
":python_op_gen", # python_op_gen
|
||||||
|
":bfloat16_lib", # bfloat16
|
||||||
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
|
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -21,14 +21,15 @@ import numpy as np
|
|||||||
from six.moves import builtins
|
from six.moves import builtins
|
||||||
|
|
||||||
from tensorflow.core.framework import types_pb2
|
from tensorflow.core.framework import types_pb2
|
||||||
# pywrap_tensorflow must be imported prior to _dtypes for the MacOS linker
|
# We need to import pywrap_tensorflow prior to the bfloat wrapper to avoid
|
||||||
# to resolve the protobufs properly.
|
# protobuf errors where a file is defined twice on MacOS.
|
||||||
# pylint: disable=unused-import,g-bad-import-order
|
# pylint: disable=invalid-import-order,g-bad-import-order
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||||
|
from tensorflow.python import _pywrap_bfloat16
|
||||||
from tensorflow.python import _dtypes
|
from tensorflow.python import _dtypes
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
|
_np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=slots-on-old-class
|
# pylint: disable=slots-on-old-class
|
||||||
|
|||||||
@ -532,7 +532,9 @@ struct Bfloat16GeFunctor {
|
|||||||
|
|
||||||
// Initializes the module.
|
// Initializes the module.
|
||||||
bool Initialize() {
|
bool Initialize() {
|
||||||
// It's critical to import umath to avoid crash in open source build.
|
// It's critical to ImportNumpy and import umath
|
||||||
|
// to avoid crash in open source build.
|
||||||
|
ImportNumpy();
|
||||||
import_umath1(false);
|
import_umath1(false);
|
||||||
|
|
||||||
Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy"));
|
Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy"));
|
||||||
|
|||||||
@ -24,12 +24,12 @@ import math
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# pylint: disable=unused-import,g-bad-import-order
|
# pylint: disable=unused-import,g-bad-import-order
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import _pywrap_bfloat16
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
|
bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
|
||||||
|
|
||||||
|
|
||||||
class Bfloat16Test(test.TestCase):
|
class Bfloat16Test(test.TestCase):
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -13,18 +13,12 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
%{
|
#include "include/pybind11/pybind11.h"
|
||||||
#include "tensorflow/python/lib/core/bfloat16.h"
|
#include "tensorflow/python/lib/core/bfloat16.h"
|
||||||
%}
|
|
||||||
|
|
||||||
%init %{
|
PYBIND11_MODULE(_pywrap_bfloat16, m) {
|
||||||
tensorflow::RegisterNumpyBfloat16();
|
tensorflow::RegisterNumpyBfloat16();
|
||||||
%}
|
|
||||||
|
|
||||||
%{
|
m.def("TF_bfloat16_type",
|
||||||
PyObject* TF_bfloat16_type() {
|
[] { return pybind11::handle(tensorflow::Bfloat16PyType()); });
|
||||||
return tensorflow::Bfloat16PyType();
|
|
||||||
}
|
}
|
||||||
%}
|
|
||||||
|
|
||||||
PyObject* TF_bfloat16_type();
|
|
||||||
@ -21,8 +21,6 @@ limitations under the License.
|
|||||||
|
|
||||||
%include "tensorflow/python/client/tf_session.i"
|
%include "tensorflow/python/client/tf_session.i"
|
||||||
|
|
||||||
%include "tensorflow/python/lib/core/bfloat16.i"
|
|
||||||
|
|
||||||
%include "tensorflow/python/lib/io/file_io.i"
|
%include "tensorflow/python/lib/io/file_io.i"
|
||||||
|
|
||||||
%include "tensorflow/python/lib/io/py_record_reader.i"
|
%include "tensorflow/python/lib/io/py_record_reader.i"
|
||||||
|
|||||||
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -131,7 +130,6 @@ class MovingAveragesTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.deprecated_graph_mode_only
|
@test_util.deprecated_graph_mode_only
|
||||||
def testWeightedMovingAverageBfloat16(self):
|
def testWeightedMovingAverageBfloat16(self):
|
||||||
bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
decay = 0.5
|
decay = 0.5
|
||||||
weight = array_ops.placeholder(dtypes.bfloat16, [])
|
weight = array_ops.placeholder(dtypes.bfloat16, [])
|
||||||
@ -154,7 +152,8 @@ class MovingAveragesTest(test.TestCase):
|
|||||||
wma_array = sess.run(wma, feed_dict={val: val_2, weight: weight_2})
|
wma_array = sess.run(wma, feed_dict={val: val_2, weight: weight_2})
|
||||||
numerator_2 = numerator_1 * decay + val_2 * weight_2 * (1.0 - decay)
|
numerator_2 = numerator_1 * decay + val_2 * weight_2 * (1.0 - decay)
|
||||||
denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay)
|
denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay)
|
||||||
self.assertAllClose(bfloat16(numerator_2 / denominator_2), wma_array)
|
self.assertAllClose(
|
||||||
|
dtypes._np_bfloat16(numerator_2 / denominator_2), wma_array)
|
||||||
|
|
||||||
|
|
||||||
def _Repeat(value, dim):
|
def _Repeat(value, dim):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user