diff --git a/tensorflow/python/keras/mixed_precision/autocast_variable.py b/tensorflow/python/keras/mixed_precision/autocast_variable.py index 3cacee0cb82..6882a055a68 100644 --- a/tensorflow/python/keras/mixed_precision/autocast_variable.py +++ b/tensorflow/python/keras/mixed_precision/autocast_variable.py @@ -34,6 +34,19 @@ from tensorflow.python.types import core _autocast_dtype = threading.local() +def numpy_text(tensor, is_repr=False): + """Human readable representation of a tensor's numpy value.""" + if tensor.dtype.is_numpy_compatible: + # pylint: disable=protected-access + text = repr(tensor._numpy()) if is_repr else str(tensor._numpy()) + # pylint: enable=protected-access + else: + text = '' + if '\n' in text: + text = '\n' + text + return text + + class AutoCastVariable(variables.Variable, core.Tensor): """Variable that will cast itself to a different dtype in applicable contexts. @@ -144,7 +157,7 @@ class AutoCastVariable(variables.Variable, core.Tensor): 'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}, ' 'numpy={np_repr}>') return repr_str.format( - v=self, np_repr=ops.numpy_text(self.read_value(), is_repr=True)) + v=self, np_repr=numpy_text(self.read_value(), is_repr=True)) else: repr_str = ("') @@ -534,5 +547,3 @@ class enable_auto_cast_variables(object): # pylint:disable=invalid-name def __exit__(self, type_arg, value_arg, traceback_arg): _autocast_dtype.dtype = self._prev_dtype - -