Merge pull request from Intel-tensorflow:cuixiaom/bf16_namescope_pr

PiperOrigin-RevId: 314593797
Change-Id: I8505d3d5195ad64f91776e3d18a583415cf353da
This commit is contained in:
TensorFlower Gardener 2020-06-03 13:35:44 -07:00
commit 463d464069
3 changed files with 41 additions and 4 deletions
tensorflow

View File

@ -70,11 +70,13 @@ def _get_custom_getter():
@tf_export(v1=['tpu.bfloat16_scope'])
@tf_contextlib.contextmanager
def bfloat16_scope():
def bfloat16_scope(name=None):
"""Scope class for bfloat16 variables so that the model uses custom getter.
This enables variables to be read as bfloat16 type when using get_variable.
"""
if name is None:
name = ''
with variable_scope.variable_scope(
'', custom_getter=_get_custom_getter()) as varscope:
name, custom_getter=_get_custom_getter()) as varscope:
yield varscope

View File

@ -24,15 +24,50 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.tpu import bfloat16
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
class BFloat16ScopeTest(test.TestCase):
def testScopeName(self):
def testDefaultScopeName(self):
"""Test if name for the variable scope is propagated correctly."""
with bfloat16.bfloat16_scope() as bf:
self.assertEqual(bf.name, "")
def testCustomScopeName(self):
"""Test if custom name for the variable scope is propagated correctly."""
name = 'bfloat16'
with bfloat16.bfloat16_scope('bfloat16') as bf:
self.assertEqual(bf.name, name)
def testVariableName(self):
"""Test if custom name for the variable scope is propagated correctly."""
g = ops.Graph()
with g.as_default():
a = variables.Variable(2.2, name='var_a')
b = variables.Variable(3.3, name='var_b')
d = variables.Variable(4.4, name='var_b')
with g.name_scope('scope1'):
with bfloat16.bfloat16_scope('bf16'):
a = math_ops.cast(a, dtypes.bfloat16)
b = math_ops.cast(b, dtypes.bfloat16)
c = math_ops.add(a, b, name='addition')
with bfloat16.bfloat16_scope():
d = math_ops.cast(d, dtypes.bfloat16)
math_ops.add(c, d, name='addition')
g_ops = g.get_operations()
ops_name = []
for op in g_ops:
ops_name.append(str(op.name))
self.assertIn('scope1/bf16/addition', ops_name)
self.assertIn('scope1/bf16/Cast', ops_name)
self.assertIn('scope1/addition', ops_name)
self.assertIn('scope1/Cast', ops_name)
@test_util.run_deprecated_v1
def testRequestedDType(self):
"""Test if requested dtype is honored in the getter.

View File

@ -18,7 +18,7 @@ tf_module {
}
member_method {
name: "bfloat16_scope"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "core"