In the short term, making Feature Column V2 produce RefVariables instead of ResourceVariables till we figure out the performance regression issues.
PiperOrigin-RevId: 220535259
This commit is contained in:
parent
7003be098c
commit
3146bc0a22
@ -184,6 +184,7 @@ class StateManager(object):
|
|||||||
shape,
|
shape,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
trainable=True,
|
trainable=True,
|
||||||
|
use_resource=True,
|
||||||
initializer=None):
|
initializer=None):
|
||||||
"""Creates a new variable.
|
"""Creates a new variable.
|
||||||
|
|
||||||
@ -193,12 +194,14 @@ class StateManager(object):
|
|||||||
shape: variable shape.
|
shape: variable shape.
|
||||||
dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
|
dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
|
||||||
trainable: Whether this variable is trainable or not.
|
trainable: Whether this variable is trainable or not.
|
||||||
|
use_resource: If true, we use resource variables. Otherwise we use
|
||||||
|
RefVariable.
|
||||||
initializer: initializer instance (callable).
|
initializer: initializer instance (callable).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The created variable.
|
The created variable.
|
||||||
"""
|
"""
|
||||||
del feature_column, name, shape, dtype, trainable, initializer
|
del feature_column, name, shape, dtype, trainable, use_resource, initializer
|
||||||
raise NotImplementedError('StateManager.create_variable')
|
raise NotImplementedError('StateManager.create_variable')
|
||||||
|
|
||||||
def add_variable(self, feature_column, var):
|
def add_variable(self, feature_column, var):
|
||||||
@ -270,6 +273,7 @@ class _StateManagerImpl(StateManager):
|
|||||||
shape,
|
shape,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
trainable=True,
|
trainable=True,
|
||||||
|
use_resource=True,
|
||||||
initializer=None):
|
initializer=None):
|
||||||
if name in self._cols_to_vars_map[feature_column]:
|
if name in self._cols_to_vars_map[feature_column]:
|
||||||
raise ValueError('Variable already exists.')
|
raise ValueError('Variable already exists.')
|
||||||
@ -280,7 +284,7 @@ class _StateManagerImpl(StateManager):
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
initializer=initializer,
|
initializer=initializer,
|
||||||
trainable=self._trainable and trainable,
|
trainable=self._trainable and trainable,
|
||||||
use_resource=True,
|
use_resource=use_resource,
|
||||||
# TODO(rohanj): Get rid of this hack once we have a mechanism for
|
# TODO(rohanj): Get rid of this hack once we have a mechanism for
|
||||||
# specifying a default partitioner for an entire layer. In that case,
|
# specifying a default partitioner for an entire layer. In that case,
|
||||||
# the default getter for Layers should work.
|
# the default getter for Layers should work.
|
||||||
@ -2539,6 +2543,8 @@ class EmbeddingColumn(
|
|||||||
shape=embedding_shape,
|
shape=embedding_shape,
|
||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
trainable=self.trainable,
|
trainable=self.trainable,
|
||||||
|
# TODO(rohanj): Make this True when b/118500434 is fixed.
|
||||||
|
use_resource=False,
|
||||||
initializer=self.initializer)
|
initializer=self.initializer)
|
||||||
|
|
||||||
def _get_dense_tensor_internal_helper(self, sparse_tensors,
|
def _get_dense_tensor_internal_helper(self, sparse_tensors,
|
||||||
|
@ -1816,6 +1816,8 @@ class LinearModelTest(test.TestCase):
|
|||||||
'sparse_feature': [['a'], ['x']],
|
'sparse_feature': [['a'], ['x']],
|
||||||
}
|
}
|
||||||
model(features)
|
model(features)
|
||||||
|
for var in model.variables:
|
||||||
|
self.assertTrue(isinstance(var, variables_lib.RefVariable))
|
||||||
variable_names = [var.name for var in model.variables]
|
variable_names = [var.name for var in model.variables]
|
||||||
self.assertItemsEqual([
|
self.assertItemsEqual([
|
||||||
'linear_model/dense_feature_bucketized/weights:0',
|
'linear_model/dense_feature_bucketized/weights:0',
|
||||||
@ -5592,6 +5594,7 @@ class _TestStateManager(fc.StateManager):
|
|||||||
shape,
|
shape,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
trainable=True,
|
trainable=True,
|
||||||
|
use_resource=True,
|
||||||
initializer=None):
|
initializer=None):
|
||||||
if feature_column not in self._all_variables:
|
if feature_column not in self._all_variables:
|
||||||
self._all_variables[feature_column] = {}
|
self._all_variables[feature_column] = {}
|
||||||
@ -5604,6 +5607,7 @@ class _TestStateManager(fc.StateManager):
|
|||||||
shape=shape,
|
shape=shape,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trainable=self._trainable and trainable,
|
trainable=self._trainable and trainable,
|
||||||
|
use_resource=use_resource,
|
||||||
initializer=initializer)
|
initializer=initializer)
|
||||||
var_dict[name] = var
|
var_dict[name] = var
|
||||||
return var
|
return var
|
||||||
@ -6182,6 +6186,8 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
|
self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
|
||||||
tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
|
for v in global_vars:
|
||||||
|
self.assertTrue(isinstance(v, variables_lib.RefVariable))
|
||||||
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
|
self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
|
||||||
tuple([v.name for v in trainable_vars]))
|
tuple([v.name for v in trainable_vars]))
|
||||||
@ -6964,6 +6970,8 @@ class SharedEmbeddingColumnTest(test.TestCase):
|
|||||||
self.assertItemsEqual(
|
self.assertItemsEqual(
|
||||||
['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'],
|
['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'],
|
||||||
tuple([v.name for v in global_vars]))
|
tuple([v.name for v in global_vars]))
|
||||||
|
for v in global_vars:
|
||||||
|
self.assertTrue(isinstance(v, variables_lib.RefVariable))
|
||||||
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||||
if trainable:
|
if trainable:
|
||||||
self.assertItemsEqual(
|
self.assertItemsEqual(
|
||||||
|
Loading…
Reference in New Issue
Block a user