In the short term, make Feature Column V2 produce RefVariables instead of ResourceVariables until we figure out the performance regression issues.

PiperOrigin-RevId: 220535259
Rohan Jain 2018-11-07 14:59:59 -08:00 committed by TensorFlower Gardener
parent 7003be098c
commit 3146bc0a22
2 changed files with 16 additions and 2 deletions
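For context, `use_resource` is the switch that selects the variable implementation under TF 1.x `variable_scope`. A minimal sketch of the distinction this change relies on (TF 1.x graph mode; the internal imports mirror the ones used in the test file below):

    import tensorflow as tf
    from tensorflow.python.ops import resource_variable_ops
    from tensorflow.python.ops import variables as variables_lib

    with tf.Graph().as_default():
      # use_resource=False produces the legacy reference-semantics variable.
      ref_var = tf.get_variable('ref_weights', shape=[2, 3], use_resource=False)
      # use_resource=True produces a ResourceVariable, which feature column v2
      # created unconditionally before this change.
      res_var = tf.get_variable('res_weights', shape=[2, 3], use_resource=True)
      assert isinstance(ref_var, variables_lib.RefVariable)
      assert isinstance(res_var, resource_variable_ops.ResourceVariable)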


@@ -184,6 +184,7 @@ class StateManager(object):
shape,
dtype=None,
trainable=True,
+use_resource=True,
initializer=None):
"""Creates a new variable.
@@ -193,12 +194,14 @@ class StateManager(object):
shape: variable shape.
dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
trainable: Whether this variable is trainable or not.
+use_resource: If true, we use resource variables. Otherwise we use
+  RefVariable.
initializer: initializer instance (callable).
Returns:
The created variable.
"""
-del feature_column, name, shape, dtype, trainable, initializer
+del feature_column, name, shape, dtype, trainable, use_resource, initializer
raise NotImplementedError('StateManager.create_variable')
def add_variable(self, feature_column, var):
@@ -270,6 +273,7 @@ class _StateManagerImpl(StateManager):
shape,
dtype=None,
trainable=True,
+use_resource=True,
initializer=None):
if name in self._cols_to_vars_map[feature_column]:
raise ValueError('Variable already exists.')
@@ -280,7 +284,7 @@ class _StateManagerImpl(StateManager):
dtype=dtype,
initializer=initializer,
trainable=self._trainable and trainable,
-use_resource=True,
+use_resource=use_resource,
# TODO(rohanj): Get rid of this hack once we have a mechanism for
# specifying a default partitioner for an entire layer. In that case,
# the default getter for Layers should work.
@@ -2539,6 +2543,8 @@ class EmbeddingColumn(
shape=embedding_shape,
dtype=dtypes.float32,
trainable=self.trainable,
+# TODO(rohanj): Make this True when b/118500434 is fixed.
+use_resource=False,
initializer=self.initializer)
def _get_dense_tensor_internal_helper(self, sparse_tensors,
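Taken together, the changes above mean any `StateManager` implementation must now accept `use_resource` and forward it when creating variables. A minimal sketch of a conforming subclass (the class name is hypothetical; it forwards the flag to `tf.get_variable`, much as `_StateManagerImpl` does above):

    import tensorflow as tf
    from tensorflow.python.feature_column import feature_column_v2 as fc

    class SimpleStateManager(fc.StateManager):  # hypothetical example class
      def __init__(self, trainable=True):
        self._trainable = trainable
        self._vars = {}  # feature_column -> {name: variable}

      def create_variable(self, feature_column, name, shape, dtype=None,
                          trainable=True, use_resource=True, initializer=None):
        # Forward use_resource so callers such as EmbeddingColumn can request
        # RefVariables until b/118500434 is fixed.
        var = tf.get_variable(name, shape=shape, dtype=dtype,
                              trainable=self._trainable and trainable,
                              use_resource=use_resource,
                              initializer=initializer)
        self._vars.setdefault(feature_column, {})[name] = var
        return var

      def get_variable(self, feature_column, name):
        return self._vars[feature_column][name]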


@@ -1816,6 +1816,8 @@ class LinearModelTest(test.TestCase):
'sparse_feature': [['a'], ['x']],
}
model(features)
+for var in model.variables:
+  self.assertTrue(isinstance(var, variables_lib.RefVariable))
variable_names = [var.name for var in model.variables]
self.assertItemsEqual([
'linear_model/dense_feature_bucketized/weights:0',
@@ -5592,6 +5594,7 @@ class _TestStateManager(fc.StateManager):
shape,
dtype=None,
trainable=True,
+use_resource=True,
initializer=None):
if feature_column not in self._all_variables:
self._all_variables[feature_column] = {}
@@ -5604,6 +5607,7 @@ class _TestStateManager(fc.StateManager):
shape=shape,
dtype=dtype,
trainable=self._trainable and trainable,
+use_resource=use_resource,
initializer=initializer)
var_dict[name] = var
return var
@@ -6182,6 +6186,8 @@ class EmbeddingColumnTest(test.TestCase):
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
tuple([v.name for v in global_vars]))
+for v in global_vars:
+  self.assertTrue(isinstance(v, variables_lib.RefVariable))
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
tuple([v.name for v in trainable_vars]))
@@ -6964,6 +6970,8 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertItemsEqual(
['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'],
tuple([v.name for v in global_vars]))
+for v in global_vars:
+  self.assertTrue(isinstance(v, variables_lib.RefVariable))
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
if trainable:
self.assertItemsEqual(