From adcbdc2bcca2ce1e5843c2fc04cf439a0debfb05 Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Sat, 1 Feb 2020 15:56:50 -0800 Subject: [PATCH] Add support for bfloat16 in tf.print PiperOrigin-RevId: 292736205 Change-Id: I3d09e6dbc1ea4e425719312dfcec0691e4f98676 --- tensorflow/core/framework/tensor.cc | 7 +++++++ tensorflow/python/kernel_tests/logging_ops_test.py | 13 +++++++------ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index dbca4879cf9..3a47cd35cbf 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -1015,6 +1015,10 @@ inline float PrintOneElement(const Eigen::half& h, bool print_v2) { return static_cast(h); } +inline float PrintOneElement(bfloat16 f, bool print_v2) { + return static_cast(f); +} + // Print from left dim to right dim recursively. template void PrintOneDim(int dim_index, const gtl::InlinedVector& shape, @@ -1156,6 +1160,9 @@ string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const { } const char* data = limit > 0 ? tensor_data().data() : nullptr; switch (dtype()) { + case DT_BFLOAT16: + return SummarizeArray(limit, num_elts, shape_, data, print_v2); + break; case DT_HALF: return SummarizeArray(limit, num_elts, shape_, data, print_v2); diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py index bfd57e5d6ab..5beb785ac2b 100644 --- a/tensorflow/python/kernel_tests/logging_ops_test.py +++ b/tensorflow/python/kernel_tests/logging_ops_test.py @@ -178,12 +178,13 @@ class PrintV2Test(test.TestCase): self.assertIn((expected + "\n"), printed.contents()) def testPrintFloatScalar(self): - tensor = ops.convert_to_tensor(434.43) - with self.captureWritesToStream(sys.stderr) as printed: - print_op = logging_ops.print_v2(tensor) - self.evaluate(print_op) - expected = "434.43" - self.assertIn((expected + "\n"), printed.contents()) + for dtype in [dtypes.bfloat16, dtypes.half, dtypes.float32, dtypes.float64]: + tensor = ops.convert_to_tensor(43.5, dtype=dtype) + with self.captureWritesToStream(sys.stderr) as printed: + print_op = logging_ops.print_v2(tensor) + self.evaluate(print_op) + expected = "43.5" + self.assertIn((expected + "\n"), printed.contents()) def testPrintStringScalar(self): tensor = ops.convert_to_tensor("scalar")