Add support for bfloat16 in tf.print
PiperOrigin-RevId: 292736205 Change-Id: I3d09e6dbc1ea4e425719312dfcec0691e4f98676
This commit is contained in:
parent
f3ac48296b
commit
adcbdc2bcc
@ -1015,6 +1015,10 @@ inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
|
||||
return static_cast<float>(h);
|
||||
}
|
||||
|
||||
// Widen a bfloat16 element to float for printing, so the standard
// float stream formatting applies. The print_v2 flag is accepted for
// signature parity with the other PrintOneElement overloads (e.g. the
// Eigen::half one above) but does not change the conversion.
inline float PrintOneElement(bfloat16 value, bool print_v2) {
  return static_cast<float>(value);
}
|
||||
|
||||
// Print from left dim to right dim recursively.
|
||||
template <typename T>
|
||||
void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
|
||||
@ -1156,6 +1160,9 @@ string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
|
||||
}
|
||||
const char* data = limit > 0 ? tensor_data().data() : nullptr;
|
||||
switch (dtype()) {
|
||||
case DT_BFLOAT16:
|
||||
return SummarizeArray<bfloat16>(limit, num_elts, shape_, data, print_v2);
|
||||
break;
|
||||
case DT_HALF:
|
||||
return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
|
||||
print_v2);
|
||||
|
@ -178,12 +178,13 @@ class PrintV2Test(test.TestCase):
|
||||
self.assertIn((expected + "\n"), printed.contents())
|
||||
|
||||
def testPrintFloatScalar(self):
  """Scalar float tensors print their value (plus newline) to stderr.

  Covers the default dtype first, then every supported floating-point
  dtype explicitly, including bfloat16 and half.
  """
  # Default-dtype scalar.
  scalar = ops.convert_to_tensor(434.43)
  with self.captureWritesToStream(sys.stderr) as printed:
    print_op = logging_ops.print_v2(scalar)
    self.evaluate(print_op)
  self.assertIn("434.43" + "\n", printed.contents())

  # Explicitly-typed scalars: each dtype must render the same text.
  for float_dtype in [dtypes.bfloat16, dtypes.half, dtypes.float32,
                      dtypes.float64]:
    typed_scalar = ops.convert_to_tensor(43.5, dtype=float_dtype)
    with self.captureWritesToStream(sys.stderr) as printed:
      print_op = logging_ops.print_v2(typed_scalar)
      self.evaluate(print_op)
    self.assertIn("43.5" + "\n", printed.contents())
|
||||
|
||||
def testPrintStringScalar(self):
|
||||
tensor = ops.convert_to_tensor("scalar")
|
||||
|
Loading…
Reference in New Issue
Block a user