Standardize name scopes used during model construction.
PiperOrigin-RevId: 301718829 Change-Id: I09d0ffe16b08c369b864290c9e33ebc3b0d85edb
This commit is contained in:
parent
e377b6dbcf
commit
7fe72602b4
@ -2083,16 +2083,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
self._dtype_policy = policy.Policy(value)
|
||||
|
||||
def _name_scope(self):
|
||||
name_scope = self.name
|
||||
current_name_scope = ops.get_name_scope()
|
||||
if current_name_scope:
|
||||
name_scope = current_name_scope + '/' + name_scope
|
||||
if name_scope:
|
||||
# Note that the trailing `/` prevents autogenerated
|
||||
# numerical suffixes to get appended. It will also fully reset
|
||||
# nested name scope (i.e. the outer name scope has no effect).
|
||||
name_scope += '/'
|
||||
return name_scope
|
||||
return self.name
|
||||
|
||||
def _init_set_name(self, name, zero_based=True):
|
||||
if not name:
|
||||
|
@ -936,30 +936,6 @@ class NameScopingTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(layer.bias.name, 'MyName/bias:0')
|
||||
self.assertEqual(layer.kernel.name, 'MyName/kernel:0')
|
||||
|
||||
def test_name_scope_functional_api(self):
|
||||
inputs = input_layer.Input((3,))
|
||||
layer = layers.Dense(10, name='MyName')
|
||||
_ = layer(inputs)
|
||||
self.assertEqual(layer.bias.name, 'MyName/bias:0')
|
||||
self.assertEqual(layer.kernel.name, 'MyName/kernel:0')
|
||||
|
||||
def test_name_scope_functional_api_nested(self):
|
||||
|
||||
class NestedLayer(base_layer.Layer):
|
||||
|
||||
def __init__(self, name='OuterName'):
|
||||
super(NestedLayer, self).__init__(name=name)
|
||||
self.dense = layers.Dense(10, name='InnerName')
|
||||
|
||||
def call(self, inputs):
|
||||
return self.dense(inputs)
|
||||
|
||||
inputs = input_layer.Input((3,))
|
||||
layer = NestedLayer()
|
||||
_ = layer(inputs)
|
||||
self.assertEqual(layer.dense.bias.name, 'OuterName/InnerName/bias:0')
|
||||
self.assertEqual(layer.dense.kernel.name, 'OuterName/InnerName/kernel:0')
|
||||
|
||||
def test_name_scope_sublayer(self):
|
||||
|
||||
class NameScopeTracker(base_layer.Layer):
|
||||
|
Loading…
x
Reference in New Issue
Block a user