In the short term, making Feature Column V2 produce RefVariables instead of ResourceVariables until we figure out the performance regression issues.

PiperOrigin-RevId: 220535259
This commit is contained in:
Rohan Jain 2018-11-07 14:59:59 -08:00 committed by TensorFlower Gardener
parent 7003be098c
commit 3146bc0a22
2 changed files with 16 additions and 2 deletions

View File

@@ -184,6 +184,7 @@ class StateManager(object):
shape, shape,
dtype=None, dtype=None,
trainable=True, trainable=True,
use_resource=True,
initializer=None): initializer=None):
"""Creates a new variable. """Creates a new variable.
@@ -193,12 +194,14 @@ class StateManager(object):
shape: variable shape. shape: variable shape.
dtype: The type of the variable. Defaults to `self.dtype` or `float32`. dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
trainable: Whether this variable is trainable or not. trainable: Whether this variable is trainable or not.
use_resource: If true, we use resource variables. Otherwise we use
RefVariable.
initializer: initializer instance (callable). initializer: initializer instance (callable).
Returns: Returns:
The created variable. The created variable.
""" """
del feature_column, name, shape, dtype, trainable, initializer del feature_column, name, shape, dtype, trainable, use_resource, initializer
raise NotImplementedError('StateManager.create_variable') raise NotImplementedError('StateManager.create_variable')
def add_variable(self, feature_column, var): def add_variable(self, feature_column, var):
@@ -270,6 +273,7 @@ class _StateManagerImpl(StateManager):
shape, shape,
dtype=None, dtype=None,
trainable=True, trainable=True,
use_resource=True,
initializer=None): initializer=None):
if name in self._cols_to_vars_map[feature_column]: if name in self._cols_to_vars_map[feature_column]:
raise ValueError('Variable already exists.') raise ValueError('Variable already exists.')
@@ -280,7 +284,7 @@ class _StateManagerImpl(StateManager):
dtype=dtype, dtype=dtype,
initializer=initializer, initializer=initializer,
trainable=self._trainable and trainable, trainable=self._trainable and trainable,
use_resource=True, use_resource=use_resource,
# TODO(rohanj): Get rid of this hack once we have a mechanism for # TODO(rohanj): Get rid of this hack once we have a mechanism for
# specifying a default partitioner for an entire layer. In that case, # specifying a default partitioner for an entire layer. In that case,
# the default getter for Layers should work. # the default getter for Layers should work.
@@ -2539,6 +2543,8 @@ class EmbeddingColumn(
shape=embedding_shape, shape=embedding_shape,
dtype=dtypes.float32, dtype=dtypes.float32,
trainable=self.trainable, trainable=self.trainable,
# TODO(rohanj): Make this True when b/118500434 is fixed.
use_resource=False,
initializer=self.initializer) initializer=self.initializer)
def _get_dense_tensor_internal_helper(self, sparse_tensors, def _get_dense_tensor_internal_helper(self, sparse_tensors,

View File

@@ -1816,6 +1816,8 @@ class LinearModelTest(test.TestCase):
'sparse_feature': [['a'], ['x']], 'sparse_feature': [['a'], ['x']],
} }
model(features) model(features)
for var in model.variables:
self.assertTrue(isinstance(var, variables_lib.RefVariable))
variable_names = [var.name for var in model.variables] variable_names = [var.name for var in model.variables]
self.assertItemsEqual([ self.assertItemsEqual([
'linear_model/dense_feature_bucketized/weights:0', 'linear_model/dense_feature_bucketized/weights:0',
@@ -5592,6 +5594,7 @@ class _TestStateManager(fc.StateManager):
shape, shape,
dtype=None, dtype=None,
trainable=True, trainable=True,
use_resource=True,
initializer=None): initializer=None):
if feature_column not in self._all_variables: if feature_column not in self._all_variables:
self._all_variables[feature_column] = {} self._all_variables[feature_column] = {}
@@ -5604,6 +5607,7 @@ class _TestStateManager(fc.StateManager):
shape=shape, shape=shape,
dtype=dtype, dtype=dtype,
trainable=self._trainable and trainable, trainable=self._trainable and trainable,
use_resource=use_resource,
initializer=initializer) initializer=initializer)
var_dict[name] = var var_dict[name] = var
return var return var
@@ -6182,6 +6186,8 @@ class EmbeddingColumnTest(test.TestCase):
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',), self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
tuple([v.name for v in global_vars])) tuple([v.name for v in global_vars]))
for v in global_vars:
self.assertTrue(isinstance(v, variables_lib.RefVariable))
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',), self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
tuple([v.name for v in trainable_vars])) tuple([v.name for v in trainable_vars]))
@@ -6964,6 +6970,8 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertItemsEqual( self.assertItemsEqual(
['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'], ['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'],
tuple([v.name for v in global_vars])) tuple([v.name for v in global_vars]))
for v in global_vars:
self.assertTrue(isinstance(v, variables_lib.RefVariable))
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
if trainable: if trainable:
self.assertItemsEqual( self.assertItemsEqual(