Standardize name scopes used during model construction in v2.

PiperOrigin-RevId: 301852763
Change-Id: I3b4281f64ec4f3fe8e5f25901a1581ebc63de057
This commit is contained in:
Francois Chollet 2020-03-19 11:05:11 -07:00 committed by TensorFlower Gardener
parent 6a15e29467
commit afb11783a5
3 changed files with 38 additions and 1 deletions

View File

@ -111,6 +111,7 @@ py_library(
":base_layer_utils",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:constant_op",
"//tensorflow/python:tf2",
"//tensorflow/python/data",
"//tensorflow/python/distribute:distribute_coordinator",
"//tensorflow/python/distribute:distribute_lib",

View File

@ -30,6 +30,7 @@ from six.moves import zip # pylint: disable=redefined-builtin
from google.protobuf import json_format
from tensorflow.core.framework import node_def_pb2
from tensorflow.python import tf2
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribution_strategy_context as ds_context
@ -2083,7 +2084,18 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
self._dtype_policy = policy.Policy(value)
def _name_scope(self):
return self.name
if not tf2.enabled():
return self.name
name_scope = self.name
current_name_scope = ops.get_name_scope()
if current_name_scope:
name_scope = current_name_scope + '/' + name_scope
if name_scope:
# Note that the trailing `/` prevents autogenerated
# numerical suffixes to get appended. It will also fully reset
# nested name scope (i.e. the outer name scope has no effect).
name_scope += '/'
return name_scope
def _init_set_name(self, name, zero_based=True):
if not name:

View File

@ -936,6 +936,30 @@ class NameScopingTest(keras_parameterized.TestCase):
self.assertEqual(layer.bias.name, 'MyName/bias:0')
self.assertEqual(layer.kernel.name, 'MyName/kernel:0')
def test_name_scope_functional_api(self):
inputs = input_layer.Input((3,))
layer = layers.Dense(10, name='MyName')
_ = layer(inputs)
self.assertEqual(layer.bias.name, 'MyName/bias:0')
self.assertEqual(layer.kernel.name, 'MyName/kernel:0')
def test_name_scope_functional_api_nested(self):
class NestedLayer(base_layer.Layer):
def __init__(self, name='OuterName'):
super(NestedLayer, self).__init__(name=name)
self.dense = layers.Dense(10, name='InnerName')
def call(self, inputs):
return self.dense(inputs)
inputs = input_layer.Input((3,))
layer = NestedLayer()
_ = layer(inputs)
self.assertEqual(layer.dense.bias.name, 'OuterName/InnerName/bias:0')
self.assertEqual(layer.dense.kernel.name, 'OuterName/InnerName/kernel:0')
def test_name_scope_sublayer(self):
class NameScopeTracker(base_layer.Layer):