Add support for bfloat16 in tf.print

PiperOrigin-RevId: 292736205
Change-Id: I3d09e6dbc1ea4e425719312dfcec0691e4f98676
This commit is contained in:
Gaurav Jain 2020-02-01 15:56:50 -08:00 committed by TensorFlower Gardener
parent f3ac48296b
commit adcbdc2bcc
2 changed files with 14 additions and 6 deletions

View File

@ -1015,6 +1015,10 @@ inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
return static_cast<float>(h);
}
inline float PrintOneElement(bfloat16 f, bool print_v2) {
return static_cast<float>(f);
}
// Print from left dim to right dim recursively.
template <typename T>
void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
@ -1156,6 +1160,9 @@ string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
}
const char* data = limit > 0 ? tensor_data().data() : nullptr;
switch (dtype()) {
case DT_BFLOAT16:
return SummarizeArray<bfloat16>(limit, num_elts, shape_, data, print_v2);
break;
case DT_HALF:
return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
print_v2);

View File

@ -178,12 +178,13 @@ class PrintV2Test(test.TestCase):
self.assertIn((expected + "\n"), printed.contents())
def testPrintFloatScalar(self):
tensor = ops.convert_to_tensor(434.43)
with self.captureWritesToStream(sys.stderr) as printed:
print_op = logging_ops.print_v2(tensor)
self.evaluate(print_op)
expected = "434.43"
self.assertIn((expected + "\n"), printed.contents())
for dtype in [dtypes.bfloat16, dtypes.half, dtypes.float32, dtypes.float64]:
tensor = ops.convert_to_tensor(43.5, dtype=dtype)
with self.captureWritesToStream(sys.stderr) as printed:
print_op = logging_ops.print_v2(tensor)
self.evaluate(print_op)
expected = "43.5"
self.assertIn((expected + "\n"), printed.contents())
def testPrintStringScalar(self):
tensor = ops.convert_to_tensor("scalar")