Add resources to a dict instead of a list.

This will make sure the variables are named properly and not by position in the array. It should prevent accidental loading of variables from different feature columns.

Since some dependent libraries use add_resource to attach non-trackable objects, we make sure that only the ones inheriting from trackable.Trackable get put onto the layer.

PiperOrigin-RevId: 297587680
Change-Id: I09bee8029851a2166985265550a587be098eee96
This commit is contained in:
Vojtech Bardiovsky 2020-02-27 07:06:33 -08:00 committed by TensorFlower Gardener
parent 25e67f8720
commit 1f9c010786

View File

@ -164,6 +164,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import deprecation
@ -236,7 +237,7 @@ class StateManager(object):
def add_resource(self, feature_column, name, resource):
"""Creates a new resource.
Resources can be things such as tables etc.
Resources can be things such as tables, variables, trackables, etc.
Args:
feature_column: A `FeatureColumn` object this resource corresponds to.
@ -249,10 +250,22 @@ class StateManager(object):
del feature_column, name, resource
raise NotImplementedError('StateManager.add_resource')
def has_resource(self, feature_column, name):
"""Returns true iff a resource with same name exists.
Resources can be things such as tables, variables, trackables, etc.
Args:
feature_column: A `FeatureColumn` object this variable corresponds to.
name: Name of the resource.
"""
del feature_column, name
raise NotImplementedError('StateManager.has_resource')
def get_resource(self, feature_column, name):
"""Returns an already created resource.
Resources can be things such as tables etc.
Resources can be things such as tables, variables, trackables, etc.
Args:
feature_column: A `FeatureColumn` object this variable corresponds to.
@ -275,11 +288,8 @@ class _StateManagerImpl(StateManager):
self._trainable = trainable
self._layer = layer
if self._layer is not None and not hasattr(self._layer, '_resources'):
self._layer._resources = [] # pylint: disable=protected-access
self._layer._resources = data_structures.Mapping() # pylint: disable=protected-access
self._cols_to_vars_map = collections.defaultdict(lambda: {})
# TODO(vbardiovsky): Make sure the resources are tracked by moving them to
# the layer (inheriting from AutoTrackable), e.g.:
# self._layer._resources_map = data_structures.Mapping()
self._cols_to_resources_map = collections.defaultdict(lambda: {})
def create_variable(self,
@ -323,15 +333,25 @@ class _StateManagerImpl(StateManager):
return self._cols_to_vars_map[feature_column][name]
raise ValueError('Variable does not exist.')
def add_resource(self, feature_column, name, resource):
self._cols_to_resources_map[feature_column][name] = resource
if self._layer is not None:
self._layer._resources.append(resource) # pylint: disable=protected-access
def add_resource(self, feature_column, resource_name, resource):
self._cols_to_resources_map[feature_column][resource_name] = resource
# pylint: disable=protected-access
if self._layer is not None and isinstance(resource, trackable.Trackable):
# Add trackable resources to the layer for serialization.
if feature_column.name not in self._layer._resources:
self._layer._resources[feature_column.name] = data_structures.Mapping()
if resource_name not in self._layer._resources[feature_column.name]:
self._layer._resources[feature_column.name][resource_name] = resource
# pylint: enable=protected-access
def get_resource(self, feature_column, name):
if name in self._cols_to_resources_map[feature_column]:
return self._cols_to_resources_map[feature_column][name]
raise ValueError('Resource does not exist.')
def has_resource(self, feature_column, resource_name):
return resource_name in self._cols_to_resources_map[feature_column]
def get_resource(self, feature_column, resource_name):
if (feature_column not in self._cols_to_resources_map or
resource_name not in self._cols_to_resources_map[feature_column]):
raise ValueError('Resource does not exist.')
return self._cols_to_resources_map[feature_column][resource_name]
class _StateManagerImplV2(_StateManagerImpl):
@ -3736,15 +3756,20 @@ class VocabularyFileCategoricalColumn(
input_tensor = math_ops.cast(input_tensor, dtypes.int64)
name = '{}_lookup'.format(self.key)
table = lookup_ops.index_table_from_file(
vocabulary_file=self.vocabulary_file,
num_oov_buckets=self.num_oov_buckets,
vocab_size=self.vocabulary_size,
default_value=self.default_value,
key_dtype=key_dtype,
name=name)
if state_manager is not None:
state_manager.add_resource(self, name, table)
if state_manager is None or not state_manager.has_resource(self, name):
with ops.init_scope():
table = lookup_ops.index_table_from_file(
vocabulary_file=self.vocabulary_file,
num_oov_buckets=self.num_oov_buckets,
vocab_size=self.vocabulary_size,
default_value=self.default_value,
key_dtype=key_dtype,
name=name)
if state_manager is not None:
state_manager.add_resource(self, name, table)
else:
# Reuse the table from the previous run.
table = state_manager.get_resource(self, name)
return table.lookup(input_tensor)
def transform_feature(self, transformation_cache, state_manager):
@ -3851,14 +3876,19 @@ class VocabularyListCategoricalColumn(
input_tensor = math_ops.cast(input_tensor, dtypes.int64)
name = '{}_lookup'.format(self.key)
table = lookup_ops.index_table_from_tensor(
vocabulary_list=tuple(self.vocabulary_list),
default_value=self.default_value,
num_oov_buckets=self.num_oov_buckets,
dtype=key_dtype,
name=name)
if state_manager is not None:
state_manager.add_resource(self, name, table)
if state_manager is None or not state_manager.has_resource(self, name):
with ops.init_scope():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=tuple(self.vocabulary_list),
default_value=self.default_value,
num_oov_buckets=self.num_oov_buckets,
dtype=key_dtype,
name=name)
if state_manager is not None:
state_manager.add_resource(self, name, table)
else:
# Reuse the table from the previous run.
table = state_manager.get_resource(self, name)
return table.lookup(input_tensor)
def transform_feature(self, transformation_cache, state_manager):