diff --git a/tensorflow/python/tpu/bfloat16.py b/tensorflow/python/tpu/bfloat16.py index 0e4a1441fcf..9761d7f7a0e 100644 --- a/tensorflow/python/tpu/bfloat16.py +++ b/tensorflow/python/tpu/bfloat16.py @@ -70,11 +70,13 @@ def _get_custom_getter(): @tf_export(v1=['tpu.bfloat16_scope']) @tf_contextlib.contextmanager -def bfloat16_scope(name=''): +def bfloat16_scope(name=None): """Scope class for bfloat16 variables so that the model uses custom getter. This enables variables to be read as bfloat16 type when using get_variable. """ + if name is None: + name = '' with variable_scope.variable_scope( name, custom_getter=_get_custom_getter()) as varscope: yield varscope