Add resources to a dict instead of a list.
This ensures that resources are stored under their feature column and resource name rather than by their position in a list. It should prevent accidentally loading variables that belong to a different feature column. Since some dependent libraries use add_resource to attach non-trackable objects, only resources that inherit from trackable.Trackable are attached to the layer.
PiperOrigin-RevId: 297587680
Change-Id: I09bee8029851a2166985265550a587be098eee96
This commit is contained in:
parent
25e67f8720
commit
1f9c010786
@ -164,6 +164,7 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import checkpoint_utils
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import deprecation
|
||||
@ -236,7 +237,7 @@ class StateManager(object):
|
||||
def add_resource(self, feature_column, name, resource):
|
||||
"""Creates a new resource.
|
||||
|
||||
Resources can be things such as tables etc.
|
||||
Resources can be things such as tables, variables, trackables, etc.
|
||||
|
||||
Args:
|
||||
feature_column: A `FeatureColumn` object this resource corresponds to.
|
||||
@ -249,10 +250,22 @@ class StateManager(object):
|
||||
del feature_column, name, resource
|
||||
raise NotImplementedError('StateManager.add_resource')
|
||||
|
||||
def has_resource(self, feature_column, name):
|
||||
"""Returns true iff a resource with the same name exists.
|
||||
|
||||
Resources can be things such as tables, variables, trackables, etc.
|
||||
|
||||
Args:
|
||||
feature_column: A `FeatureColumn` object this resource corresponds to.
|
||||
name: Name of the resource.
|
||||
"""
|
||||
del feature_column, name
|
||||
raise NotImplementedError('StateManager.has_resource')
|
||||
|
||||
def get_resource(self, feature_column, name):
|
||||
"""Returns an already created resource.
|
||||
|
||||
Resources can be things such as tables etc.
|
||||
Resources can be things such as tables, variables, trackables, etc.
|
||||
|
||||
Args:
|
||||
feature_column: A `FeatureColumn` object this resource corresponds to.
|
||||
@ -275,11 +288,8 @@ class _StateManagerImpl(StateManager):
|
||||
self._trainable = trainable
|
||||
self._layer = layer
|
||||
if self._layer is not None and not hasattr(self._layer, '_resources'):
|
||||
self._layer._resources = [] # pylint: disable=protected-access
|
||||
self._layer._resources = data_structures.Mapping() # pylint: disable=protected-access
|
||||
self._cols_to_vars_map = collections.defaultdict(lambda: {})
|
||||
# TODO(vbardiovsky): Make sure the resources are tracked by moving them to
|
||||
# the layer (inheriting from AutoTrackable), e.g.:
|
||||
# self._layer._resources_map = data_structures.Mapping()
|
||||
self._cols_to_resources_map = collections.defaultdict(lambda: {})
|
||||
|
||||
def create_variable(self,
|
||||
@ -323,15 +333,25 @@ class _StateManagerImpl(StateManager):
|
||||
return self._cols_to_vars_map[feature_column][name]
|
||||
raise ValueError('Variable does not exist.')
|
||||
|
||||
def add_resource(self, feature_column, name, resource):
|
||||
self._cols_to_resources_map[feature_column][name] = resource
|
||||
if self._layer is not None:
|
||||
self._layer._resources.append(resource) # pylint: disable=protected-access
|
||||
def add_resource(self, feature_column, resource_name, resource):
|
||||
self._cols_to_resources_map[feature_column][resource_name] = resource
|
||||
# pylint: disable=protected-access
|
||||
if self._layer is not None and isinstance(resource, trackable.Trackable):
|
||||
# Add trackable resources to the layer for serialization.
|
||||
if feature_column.name not in self._layer._resources:
|
||||
self._layer._resources[feature_column.name] = data_structures.Mapping()
|
||||
if resource_name not in self._layer._resources[feature_column.name]:
|
||||
self._layer._resources[feature_column.name][resource_name] = resource
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def get_resource(self, feature_column, name):
|
||||
if name in self._cols_to_resources_map[feature_column]:
|
||||
return self._cols_to_resources_map[feature_column][name]
|
||||
raise ValueError('Resource does not exist.')
|
||||
def has_resource(self, feature_column, resource_name):
|
||||
return resource_name in self._cols_to_resources_map[feature_column]
|
||||
|
||||
def get_resource(self, feature_column, resource_name):
|
||||
if (feature_column not in self._cols_to_resources_map or
|
||||
resource_name not in self._cols_to_resources_map[feature_column]):
|
||||
raise ValueError('Resource does not exist.')
|
||||
return self._cols_to_resources_map[feature_column][resource_name]
|
||||
|
||||
|
||||
class _StateManagerImplV2(_StateManagerImpl):
|
||||
@ -3736,15 +3756,20 @@ class VocabularyFileCategoricalColumn(
|
||||
input_tensor = math_ops.cast(input_tensor, dtypes.int64)
|
||||
|
||||
name = '{}_lookup'.format(self.key)
|
||||
table = lookup_ops.index_table_from_file(
|
||||
vocabulary_file=self.vocabulary_file,
|
||||
num_oov_buckets=self.num_oov_buckets,
|
||||
vocab_size=self.vocabulary_size,
|
||||
default_value=self.default_value,
|
||||
key_dtype=key_dtype,
|
||||
name=name)
|
||||
if state_manager is not None:
|
||||
state_manager.add_resource(self, name, table)
|
||||
if state_manager is None or not state_manager.has_resource(self, name):
|
||||
with ops.init_scope():
|
||||
table = lookup_ops.index_table_from_file(
|
||||
vocabulary_file=self.vocabulary_file,
|
||||
num_oov_buckets=self.num_oov_buckets,
|
||||
vocab_size=self.vocabulary_size,
|
||||
default_value=self.default_value,
|
||||
key_dtype=key_dtype,
|
||||
name=name)
|
||||
if state_manager is not None:
|
||||
state_manager.add_resource(self, name, table)
|
||||
else:
|
||||
# Reuse the table from the previous run.
|
||||
table = state_manager.get_resource(self, name)
|
||||
return table.lookup(input_tensor)
|
||||
|
||||
def transform_feature(self, transformation_cache, state_manager):
|
||||
@ -3851,14 +3876,19 @@ class VocabularyListCategoricalColumn(
|
||||
input_tensor = math_ops.cast(input_tensor, dtypes.int64)
|
||||
|
||||
name = '{}_lookup'.format(self.key)
|
||||
table = lookup_ops.index_table_from_tensor(
|
||||
vocabulary_list=tuple(self.vocabulary_list),
|
||||
default_value=self.default_value,
|
||||
num_oov_buckets=self.num_oov_buckets,
|
||||
dtype=key_dtype,
|
||||
name=name)
|
||||
if state_manager is not None:
|
||||
state_manager.add_resource(self, name, table)
|
||||
if state_manager is None or not state_manager.has_resource(self, name):
|
||||
with ops.init_scope():
|
||||
table = lookup_ops.index_table_from_tensor(
|
||||
vocabulary_list=tuple(self.vocabulary_list),
|
||||
default_value=self.default_value,
|
||||
num_oov_buckets=self.num_oov_buckets,
|
||||
dtype=key_dtype,
|
||||
name=name)
|
||||
if state_manager is not None:
|
||||
state_manager.add_resource(self, name, table)
|
||||
else:
|
||||
# Reuse the table from the previous run.
|
||||
table = state_manager.get_resource(self, name)
|
||||
return table.lookup(input_tensor)
|
||||
|
||||
def transform_feature(self, transformation_cache, state_manager):
|
||||
|
Loading…
x
Reference in New Issue
Block a user