Update feature_column to not rely on Keras initializer.
This is trying to remove the deps from Tensorflow to Keras. PiperOrigin-RevId: 315619009 Change-Id: I0f39881eb91ab2003aa5a4f600fc95b53333c0bc
This commit is contained in:
parent
012401fb38
commit
3d53fd6875
tensorflow/python/feature_column
@ -90,7 +90,6 @@ py_library(
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/keras:initializers",
|
||||
"//tensorflow/python/keras/utils:generic_utils",
|
||||
"//tensorflow/python/training/tracking",
|
||||
"//tensorflow/python/training/tracking:data_structures",
|
||||
|
@ -144,12 +144,12 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
# TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
|
||||
# of the main repo.
|
||||
from tensorflow.python.keras import initializers
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
@ -165,6 +165,7 @@ from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -588,7 +589,7 @@ def embedding_column(categorical_column,
|
||||
'Embedding of column_name: {}'.format(
|
||||
categorical_column.name))
|
||||
if initializer is None:
|
||||
initializer = initializers.truncated_normal(
|
||||
initializer = init_ops.truncated_normal_initializer(
|
||||
mean=0.0, stddev=1 / math.sqrt(dimension))
|
||||
|
||||
return EmbeddingColumn(
|
||||
@ -730,7 +731,7 @@ def shared_embedding_columns(categorical_columns,
|
||||
if (initializer is not None) and (not callable(initializer)):
|
||||
raise ValueError('initializer must be callable if specified.')
|
||||
if initializer is None:
|
||||
initializer = initializers.truncated_normal(
|
||||
initializer = init_ops.truncated_normal_initializer(
|
||||
mean=0.0, stddev=1. / math.sqrt(dimension))
|
||||
|
||||
# Sort the columns so the default collection name is deterministic even if the
|
||||
@ -913,7 +914,7 @@ def shared_embedding_columns_v2(categorical_columns,
|
||||
if (initializer is not None) and (not callable(initializer)):
|
||||
raise ValueError('initializer must be callable if specified.')
|
||||
if initializer is None:
|
||||
initializer = initializers.truncated_normal(
|
||||
initializer = init_ops.truncated_normal_initializer(
|
||||
mean=0.0, stddev=1. / math.sqrt(dimension))
|
||||
|
||||
# Sort the columns so the default collection name is deterministic even if the
|
||||
@ -3030,7 +3031,8 @@ class EmbeddingColumn(
|
||||
config = dict(zip(self._fields, self))
|
||||
config['categorical_column'] = serialize_feature_column(
|
||||
self.categorical_column)
|
||||
config['initializer'] = initializers.serialize(self.initializer)
|
||||
config['initializer'] = generic_utils.serialize_keras_object(
|
||||
self.initializer)
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
@ -3043,8 +3045,11 @@ class EmbeddingColumn(
|
||||
kwargs = _standardize_and_copy_config(config)
|
||||
kwargs['categorical_column'] = deserialize_feature_column(
|
||||
config['categorical_column'], custom_objects, columns_by_name)
|
||||
kwargs['initializer'] = initializers.deserialize(
|
||||
config['initializer'], custom_objects=custom_objects)
|
||||
all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass))
|
||||
kwargs['initializer'] = generic_utils.deserialize_keras_object(
|
||||
config['initializer'],
|
||||
module_objects=all_initializers,
|
||||
custom_objects=custom_objects)
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
|
@ -40,7 +40,6 @@ from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import initializers
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
@ -117,6 +116,7 @@ class LazyColumnTest(test.TestCase):
|
||||
class TransformCounter(BaseFeatureColumnForTests):
|
||||
|
||||
def __init__(self):
|
||||
super(TransformCounter, self).__init__()
|
||||
self.num_transform = 0
|
||||
|
||||
@property
|
||||
@ -4285,6 +4285,7 @@ class TransformFeaturesTest(test.TestCase):
|
||||
class _LoggerColumn(BaseFeatureColumnForTests):
|
||||
|
||||
def __init__(self, name):
|
||||
super(_LoggerColumn, self).__init__()
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
@ -5362,9 +5363,6 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual([categorical_column], embedding_column.parents)
|
||||
|
||||
config = embedding_column.get_config()
|
||||
# initializer config contains `dtype` in v1.
|
||||
initializer_config = initializers.serialize(initializers.truncated_normal(
|
||||
mean=0.0, stddev=1 / np.sqrt(2)))
|
||||
self.assertEqual(
|
||||
{
|
||||
'categorical_column': {
|
||||
@ -5378,7 +5376,15 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||
'ckpt_to_load_from': None,
|
||||
'combiner': 'mean',
|
||||
'dimension': 2,
|
||||
'initializer': initializer_config,
|
||||
'initializer': {
|
||||
'class_name': 'TruncatedNormal',
|
||||
'config': {
|
||||
'dtype': 'float32',
|
||||
'stddev': 0.7071067811865475,
|
||||
'seed': None,
|
||||
'mean': 0.0
|
||||
}
|
||||
},
|
||||
'max_norm': None,
|
||||
'tensor_name_in_ckpt': None,
|
||||
'trainable': True,
|
||||
|
Loading…
Reference in New Issue
Block a user