Merge pull request #39597 from Intel-tensorflow:cuixiaom/bf16_namescope_pr
PiperOrigin-RevId: 314593797 Change-Id: I8505d3d5195ad64f91776e3d18a583415cf353da
This commit is contained in:
commit
463d464069
tensorflow
@ -70,11 +70,13 @@ def _get_custom_getter():
|
||||
|
||||
@tf_export(v1=['tpu.bfloat16_scope'])
|
||||
@tf_contextlib.contextmanager
|
||||
def bfloat16_scope():
|
||||
def bfloat16_scope(name=None):
|
||||
"""Scope class for bfloat16 variables so that the model uses custom getter.
|
||||
|
||||
This enables variables to be read as bfloat16 type when using get_variable.
|
||||
"""
|
||||
if name is None:
|
||||
name = ''
|
||||
with variable_scope.variable_scope(
|
||||
'', custom_getter=_get_custom_getter()) as varscope:
|
||||
name, custom_getter=_get_custom_getter()) as varscope:
|
||||
yield varscope
|
||||
|
@ -24,15 +24,50 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu import bfloat16
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
|
||||
class BFloat16ScopeTest(test.TestCase):
|
||||
|
||||
def testScopeName(self):
|
||||
def testDefaultScopeName(self):
|
||||
"""Test if name for the variable scope is propagated correctly."""
|
||||
with bfloat16.bfloat16_scope() as bf:
|
||||
self.assertEqual(bf.name, "")
|
||||
|
||||
def testCustomScopeName(self):
|
||||
"""Test if custom name for the variable scope is propagated correctly."""
|
||||
name = 'bfloat16'
|
||||
with bfloat16.bfloat16_scope('bfloat16') as bf:
|
||||
self.assertEqual(bf.name, name)
|
||||
|
||||
def testVariableName(self):
|
||||
"""Test if custom name for the variable scope is propagated correctly."""
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
a = variables.Variable(2.2, name='var_a')
|
||||
b = variables.Variable(3.3, name='var_b')
|
||||
d = variables.Variable(4.4, name='var_b')
|
||||
with g.name_scope('scope1'):
|
||||
with bfloat16.bfloat16_scope('bf16'):
|
||||
a = math_ops.cast(a, dtypes.bfloat16)
|
||||
b = math_ops.cast(b, dtypes.bfloat16)
|
||||
c = math_ops.add(a, b, name='addition')
|
||||
with bfloat16.bfloat16_scope():
|
||||
d = math_ops.cast(d, dtypes.bfloat16)
|
||||
math_ops.add(c, d, name='addition')
|
||||
|
||||
g_ops = g.get_operations()
|
||||
ops_name = []
|
||||
for op in g_ops:
|
||||
ops_name.append(str(op.name))
|
||||
|
||||
self.assertIn('scope1/bf16/addition', ops_name)
|
||||
self.assertIn('scope1/bf16/Cast', ops_name)
|
||||
self.assertIn('scope1/addition', ops_name)
|
||||
self.assertIn('scope1/Cast', ops_name)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRequestedDType(self):
|
||||
"""Test if requested dtype is honored in the getter.
|
||||
|
@ -18,7 +18,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "bfloat16_scope"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "core"
|
||||
|
Loading…
Reference in New Issue
Block a user