In the short term, making Feature Column V2 produce RefVariables instead of ResourceVariables till we figure out the performance regression issues.
PiperOrigin-RevId: 220535259
This commit is contained in:
parent
7003be098c
commit
3146bc0a22
@ -184,6 +184,7 @@ class StateManager(object):
|
||||
shape,
|
||||
dtype=None,
|
||||
trainable=True,
|
||||
use_resource=True,
|
||||
initializer=None):
|
||||
"""Creates a new variable.
|
||||
|
||||
@ -193,12 +194,14 @@ class StateManager(object):
|
||||
shape: variable shape.
|
||||
dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
|
||||
trainable: Whether this variable is trainable or not.
|
||||
use_resource: If true, we use resource variables. Otherwise we use
|
||||
RefVariable.
|
||||
initializer: initializer instance (callable).
|
||||
|
||||
Returns:
|
||||
The created variable.
|
||||
"""
|
||||
del feature_column, name, shape, dtype, trainable, initializer
|
||||
del feature_column, name, shape, dtype, trainable, use_resource, initializer
|
||||
raise NotImplementedError('StateManager.create_variable')
|
||||
|
||||
def add_variable(self, feature_column, var):
|
||||
@ -270,6 +273,7 @@ class _StateManagerImpl(StateManager):
|
||||
shape,
|
||||
dtype=None,
|
||||
trainable=True,
|
||||
use_resource=True,
|
||||
initializer=None):
|
||||
if name in self._cols_to_vars_map[feature_column]:
|
||||
raise ValueError('Variable already exists.')
|
||||
@ -280,7 +284,7 @@ class _StateManagerImpl(StateManager):
|
||||
dtype=dtype,
|
||||
initializer=initializer,
|
||||
trainable=self._trainable and trainable,
|
||||
use_resource=True,
|
||||
use_resource=use_resource,
|
||||
# TODO(rohanj): Get rid of this hack once we have a mechanism for
|
||||
# specifying a default partitioner for an entire layer. In that case,
|
||||
# the default getter for Layers should work.
|
||||
@ -2539,6 +2543,8 @@ class EmbeddingColumn(
|
||||
shape=embedding_shape,
|
||||
dtype=dtypes.float32,
|
||||
trainable=self.trainable,
|
||||
# TODO(rohanj): Make this True when b/118500434 is fixed.
|
||||
use_resource=False,
|
||||
initializer=self.initializer)
|
||||
|
||||
def _get_dense_tensor_internal_helper(self, sparse_tensors,
|
||||
|
@ -1816,6 +1816,8 @@ class LinearModelTest(test.TestCase):
|
||||
'sparse_feature': [['a'], ['x']],
|
||||
}
|
||||
model(features)
|
||||
for var in model.variables:
|
||||
self.assertTrue(isinstance(var, variables_lib.RefVariable))
|
||||
variable_names = [var.name for var in model.variables]
|
||||
self.assertItemsEqual([
|
||||
'linear_model/dense_feature_bucketized/weights:0',
|
||||
@ -5592,6 +5594,7 @@ class _TestStateManager(fc.StateManager):
|
||||
shape,
|
||||
dtype=None,
|
||||
trainable=True,
|
||||
use_resource=True,
|
||||
initializer=None):
|
||||
if feature_column not in self._all_variables:
|
||||
self._all_variables[feature_column] = {}
|
||||
@ -5604,6 +5607,7 @@ class _TestStateManager(fc.StateManager):
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
trainable=self._trainable and trainable,
|
||||
use_resource=use_resource,
|
||||
initializer=initializer)
|
||||
var_dict[name] = var
|
||||
return var
|
||||
@ -6182,6 +6186,8 @@ class EmbeddingColumnTest(test.TestCase):
|
||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
|
||||
tuple([v.name for v in global_vars]))
|
||||
for v in global_vars:
|
||||
self.assertTrue(isinstance(v, variables_lib.RefVariable))
|
||||
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||
self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
|
||||
tuple([v.name for v in trainable_vars]))
|
||||
@ -6964,6 +6970,8 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
||||
self.assertItemsEqual(
|
||||
['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'],
|
||||
tuple([v.name for v in global_vars]))
|
||||
for v in global_vars:
|
||||
self.assertTrue(isinstance(v, variables_lib.RefVariable))
|
||||
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||
if trainable:
|
||||
self.assertItemsEqual(
|
||||
|
Loading…
Reference in New Issue
Block a user