Standardize name scopes used during model construction in v2.
PiperOrigin-RevId: 301852763 Change-Id: I3b4281f64ec4f3fe8e5f25901a1581ebc63de057
This commit is contained in:
parent
6a15e29467
commit
afb11783a5
@ -111,6 +111,7 @@ py_library(
|
||||
":base_layer_utils",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:tf2",
|
||||
"//tensorflow/python/data",
|
||||
"//tensorflow/python/distribute:distribute_coordinator",
|
||||
"//tensorflow/python/distribute:distribute_lib",
|
||||
|
||||
@ -30,6 +30,7 @@ from six.moves import zip # pylint: disable=redefined-builtin
|
||||
|
||||
from google.protobuf import json_format
|
||||
from tensorflow.core.framework import node_def_pb2
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.autograph.core import ag_ctx
|
||||
from tensorflow.python.autograph.impl import api as autograph
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
@ -2083,7 +2084,18 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
self._dtype_policy = policy.Policy(value)
|
||||
|
||||
def _name_scope(self):
|
||||
return self.name
|
||||
if not tf2.enabled():
|
||||
return self.name
|
||||
name_scope = self.name
|
||||
current_name_scope = ops.get_name_scope()
|
||||
if current_name_scope:
|
||||
name_scope = current_name_scope + '/' + name_scope
|
||||
if name_scope:
|
||||
# Note that the trailing `/` prevents autogenerated
|
||||
# numerical suffixes to get appended. It will also fully reset
|
||||
# nested name scope (i.e. the outer name scope has no effect).
|
||||
name_scope += '/'
|
||||
return name_scope
|
||||
|
||||
def _init_set_name(self, name, zero_based=True):
|
||||
if not name:
|
||||
|
||||
@ -936,6 +936,30 @@ class NameScopingTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(layer.bias.name, 'MyName/bias:0')
|
||||
self.assertEqual(layer.kernel.name, 'MyName/kernel:0')
|
||||
|
||||
def test_name_scope_functional_api(self):
|
||||
inputs = input_layer.Input((3,))
|
||||
layer = layers.Dense(10, name='MyName')
|
||||
_ = layer(inputs)
|
||||
self.assertEqual(layer.bias.name, 'MyName/bias:0')
|
||||
self.assertEqual(layer.kernel.name, 'MyName/kernel:0')
|
||||
|
||||
def test_name_scope_functional_api_nested(self):
|
||||
|
||||
class NestedLayer(base_layer.Layer):
|
||||
|
||||
def __init__(self, name='OuterName'):
|
||||
super(NestedLayer, self).__init__(name=name)
|
||||
self.dense = layers.Dense(10, name='InnerName')
|
||||
|
||||
def call(self, inputs):
|
||||
return self.dense(inputs)
|
||||
|
||||
inputs = input_layer.Input((3,))
|
||||
layer = NestedLayer()
|
||||
_ = layer(inputs)
|
||||
self.assertEqual(layer.dense.bias.name, 'OuterName/InnerName/bias:0')
|
||||
self.assertEqual(layer.dense.kernel.name, 'OuterName/InnerName/kernel:0')
|
||||
|
||||
def test_name_scope_sublayer(self):
|
||||
|
||||
class NameScopeTracker(base_layer.Layer):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user