Add support for bfloat16 in tf.print
PiperOrigin-RevId: 292736205 Change-Id: I3d09e6dbc1ea4e425719312dfcec0691e4f98676
This commit is contained in:
parent
f3ac48296b
commit
adcbdc2bcc
tensorflow
@ -1015,6 +1015,10 @@ inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
|
|||||||
return static_cast<float>(h);
|
return static_cast<float>(h);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline float PrintOneElement(bfloat16 f, bool print_v2) {
|
||||||
|
return static_cast<float>(f);
|
||||||
|
}
|
||||||
|
|
||||||
// Print from left dim to right dim recursively.
|
// Print from left dim to right dim recursively.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
|
void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
|
||||||
@ -1156,6 +1160,9 @@ string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
|
|||||||
}
|
}
|
||||||
const char* data = limit > 0 ? tensor_data().data() : nullptr;
|
const char* data = limit > 0 ? tensor_data().data() : nullptr;
|
||||||
switch (dtype()) {
|
switch (dtype()) {
|
||||||
|
case DT_BFLOAT16:
|
||||||
|
return SummarizeArray<bfloat16>(limit, num_elts, shape_, data, print_v2);
|
||||||
|
break;
|
||||||
case DT_HALF:
|
case DT_HALF:
|
||||||
return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
|
return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
|
||||||
print_v2);
|
print_v2);
|
||||||
|
@ -178,11 +178,12 @@ class PrintV2Test(test.TestCase):
|
|||||||
self.assertIn((expected + "\n"), printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
def testPrintFloatScalar(self):
|
def testPrintFloatScalar(self):
|
||||||
tensor = ops.convert_to_tensor(434.43)
|
for dtype in [dtypes.bfloat16, dtypes.half, dtypes.float32, dtypes.float64]:
|
||||||
|
tensor = ops.convert_to_tensor(43.5, dtype=dtype)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor)
|
print_op = logging_ops.print_v2(tensor)
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
expected = "434.43"
|
expected = "43.5"
|
||||||
self.assertIn((expected + "\n"), printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
def testPrintStringScalar(self):
|
def testPrintStringScalar(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user