From 3146bc0a2240ac829437009c6cdc614b3869cd0a Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Wed, 7 Nov 2018 14:59:59 -0800 Subject: [PATCH] In the short term, making Feature Column V2 produce RefVariables instead of ResourceVariables till we figure out the performance regression issues. PiperOrigin-RevId: 220535259 --- tensorflow/python/feature_column/feature_column_v2.py | 10 ++++++++-- .../python/feature_column/feature_column_v2_test.py | 8 ++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index d97d41dd830..bd198ed53d3 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -184,6 +184,7 @@ class StateManager(object): shape, dtype=None, trainable=True, + use_resource=True, initializer=None): """Creates a new variable. @@ -193,12 +194,14 @@ class StateManager(object): shape: variable shape. dtype: The type of the variable. Defaults to `self.dtype` or `float32`. trainable: Whether this variable is trainable or not. + use_resource: If true, we use resource variables. Otherwise we use + RefVariable. initializer: initializer instance (callable). Returns: The created variable. """ - del feature_column, name, shape, dtype, trainable, initializer + del feature_column, name, shape, dtype, trainable, use_resource, initializer raise NotImplementedError('StateManager.create_variable') def add_variable(self, feature_column, var): @@ -270,6 +273,7 @@ class _StateManagerImpl(StateManager): shape, dtype=None, trainable=True, + use_resource=True, initializer=None): if name in self._cols_to_vars_map[feature_column]: raise ValueError('Variable already exists.') @@ -280,7 +284,7 @@ class _StateManagerImpl(StateManager): dtype=dtype, initializer=initializer, trainable=self._trainable and trainable, - use_resource=True, + use_resource=use_resource, # TODO(rohanj): Get rid of this hack once we have a mechanism for # specifying a default partitioner for an entire layer. In that case, # the default getter for Layers should work. @@ -2539,6 +2543,8 @@ class EmbeddingColumn( shape=embedding_shape, dtype=dtypes.float32, trainable=self.trainable, + # TODO(rohanj): Make this True when b/118500434 is fixed. + use_resource=False, initializer=self.initializer) def _get_dense_tensor_internal_helper(self, sparse_tensors, diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py index ab727752b49..45317a7d4a0 100644 --- a/tensorflow/python/feature_column/feature_column_v2_test.py +++ b/tensorflow/python/feature_column/feature_column_v2_test.py @@ -1816,6 +1816,8 @@ class LinearModelTest(test.TestCase): 'sparse_feature': [['a'], ['x']], } model(features) + for var in model.variables: + self.assertTrue(isinstance(var, variables_lib.RefVariable)) variable_names = [var.name for var in model.variables] self.assertItemsEqual([ 'linear_model/dense_feature_bucketized/weights:0', @@ -5592,6 +5594,7 @@ class _TestStateManager(fc.StateManager): shape, dtype=None, trainable=True, + use_resource=True, initializer=None): if feature_column not in self._all_variables: self._all_variables[feature_column] = {} @@ -5604,6 +5607,7 @@ class _TestStateManager(fc.StateManager): shape=shape, dtype=dtype, trainable=self._trainable and trainable, + use_resource=use_resource, initializer=initializer) var_dict[name] = var return var @@ -6182,6 +6186,8 @@ class EmbeddingColumnTest(test.TestCase): global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',), tuple([v.name for v in global_vars])) + for v in global_vars: + self.assertTrue(isinstance(v, variables_lib.RefVariable)) trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',), tuple([v.name for v in trainable_vars])) @@ -6964,6 +6970,8 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertItemsEqual( ['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'], tuple([v.name for v in global_vars])) + for v in global_vars: + self.assertTrue(isinstance(v, variables_lib.RefVariable)) trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) if trainable: self.assertItemsEqual(