Update feature_column to not rely on Keras initializer.
This is trying to remove the deps from Tensorflow to Keras. PiperOrigin-RevId: 315619009 Change-Id: I0f39881eb91ab2003aa5a4f600fc95b53333c0bc
This commit is contained in:
parent
012401fb38
commit
3d53fd6875
@ -90,7 +90,6 @@ py_library(
|
|||||||
"//tensorflow/python:variable_scope",
|
"//tensorflow/python:variable_scope",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"//tensorflow/python/keras:initializers",
|
|
||||||
"//tensorflow/python/keras/utils:generic_utils",
|
"//tensorflow/python/keras/utils:generic_utils",
|
||||||
"//tensorflow/python/training/tracking",
|
"//tensorflow/python/training/tracking",
|
||||||
"//tensorflow/python/training/tracking:data_structures",
|
"//tensorflow/python/training/tracking:data_structures",
|
||||||
|
@ -144,12 +144,12 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
|
|||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
# TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
|
# TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
|
||||||
# of the main repo.
|
# of the main repo.
|
||||||
from tensorflow.python.keras import initializers
|
|
||||||
from tensorflow.python.keras.utils import generic_utils
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import embedding_ops
|
from tensorflow.python.ops import embedding_ops
|
||||||
|
from tensorflow.python.ops import init_ops
|
||||||
from tensorflow.python.ops import lookup_ops
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import parsing_ops
|
from tensorflow.python.ops import parsing_ops
|
||||||
@ -165,6 +165,7 @@ from tensorflow.python.training.tracking import data_structures
|
|||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
from tensorflow.python.util import tf_inspect
|
||||||
from tensorflow.python.util.compat import collections_abc
|
from tensorflow.python.util.compat import collections_abc
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
@ -588,7 +589,7 @@ def embedding_column(categorical_column,
|
|||||||
'Embedding of column_name: {}'.format(
|
'Embedding of column_name: {}'.format(
|
||||||
categorical_column.name))
|
categorical_column.name))
|
||||||
if initializer is None:
|
if initializer is None:
|
||||||
initializer = initializers.truncated_normal(
|
initializer = init_ops.truncated_normal_initializer(
|
||||||
mean=0.0, stddev=1 / math.sqrt(dimension))
|
mean=0.0, stddev=1 / math.sqrt(dimension))
|
||||||
|
|
||||||
return EmbeddingColumn(
|
return EmbeddingColumn(
|
||||||
@ -730,7 +731,7 @@ def shared_embedding_columns(categorical_columns,
|
|||||||
if (initializer is not None) and (not callable(initializer)):
|
if (initializer is not None) and (not callable(initializer)):
|
||||||
raise ValueError('initializer must be callable if specified.')
|
raise ValueError('initializer must be callable if specified.')
|
||||||
if initializer is None:
|
if initializer is None:
|
||||||
initializer = initializers.truncated_normal(
|
initializer = init_ops.truncated_normal_initializer(
|
||||||
mean=0.0, stddev=1. / math.sqrt(dimension))
|
mean=0.0, stddev=1. / math.sqrt(dimension))
|
||||||
|
|
||||||
# Sort the columns so the default collection name is deterministic even if the
|
# Sort the columns so the default collection name is deterministic even if the
|
||||||
@ -913,7 +914,7 @@ def shared_embedding_columns_v2(categorical_columns,
|
|||||||
if (initializer is not None) and (not callable(initializer)):
|
if (initializer is not None) and (not callable(initializer)):
|
||||||
raise ValueError('initializer must be callable if specified.')
|
raise ValueError('initializer must be callable if specified.')
|
||||||
if initializer is None:
|
if initializer is None:
|
||||||
initializer = initializers.truncated_normal(
|
initializer = init_ops.truncated_normal_initializer(
|
||||||
mean=0.0, stddev=1. / math.sqrt(dimension))
|
mean=0.0, stddev=1. / math.sqrt(dimension))
|
||||||
|
|
||||||
# Sort the columns so the default collection name is deterministic even if the
|
# Sort the columns so the default collection name is deterministic even if the
|
||||||
@ -3030,7 +3031,8 @@ class EmbeddingColumn(
|
|||||||
config = dict(zip(self._fields, self))
|
config = dict(zip(self._fields, self))
|
||||||
config['categorical_column'] = serialize_feature_column(
|
config['categorical_column'] = serialize_feature_column(
|
||||||
self.categorical_column)
|
self.categorical_column)
|
||||||
config['initializer'] = initializers.serialize(self.initializer)
|
config['initializer'] = generic_utils.serialize_keras_object(
|
||||||
|
self.initializer)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -3043,8 +3045,11 @@ class EmbeddingColumn(
|
|||||||
kwargs = _standardize_and_copy_config(config)
|
kwargs = _standardize_and_copy_config(config)
|
||||||
kwargs['categorical_column'] = deserialize_feature_column(
|
kwargs['categorical_column'] = deserialize_feature_column(
|
||||||
config['categorical_column'], custom_objects, columns_by_name)
|
config['categorical_column'], custom_objects, columns_by_name)
|
||||||
kwargs['initializer'] = initializers.deserialize(
|
all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass))
|
||||||
config['initializer'], custom_objects=custom_objects)
|
kwargs['initializer'] = generic_utils.deserialize_keras_object(
|
||||||
|
config['initializer'],
|
||||||
|
module_objects=all_initializers,
|
||||||
|
custom_objects=custom_objects)
|
||||||
return cls(**kwargs)
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,7 +40,6 @@ from tensorflow.python.framework import errors
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.keras import initializers
|
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import lookup_ops
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import parsing_ops
|
from tensorflow.python.ops import parsing_ops
|
||||||
@ -117,6 +116,7 @@ class LazyColumnTest(test.TestCase):
|
|||||||
class TransformCounter(BaseFeatureColumnForTests):
|
class TransformCounter(BaseFeatureColumnForTests):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
super(TransformCounter, self).__init__()
|
||||||
self.num_transform = 0
|
self.num_transform = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -4285,6 +4285,7 @@ class TransformFeaturesTest(test.TestCase):
|
|||||||
class _LoggerColumn(BaseFeatureColumnForTests):
|
class _LoggerColumn(BaseFeatureColumnForTests):
|
||||||
|
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
|
super(_LoggerColumn, self).__init__()
|
||||||
self._name = name
|
self._name = name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -5362,9 +5363,6 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual([categorical_column], embedding_column.parents)
|
self.assertEqual([categorical_column], embedding_column.parents)
|
||||||
|
|
||||||
config = embedding_column.get_config()
|
config = embedding_column.get_config()
|
||||||
# initializer config contains `dtype` in v1.
|
|
||||||
initializer_config = initializers.serialize(initializers.truncated_normal(
|
|
||||||
mean=0.0, stddev=1 / np.sqrt(2)))
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
{
|
{
|
||||||
'categorical_column': {
|
'categorical_column': {
|
||||||
@ -5378,7 +5376,15 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
|||||||
'ckpt_to_load_from': None,
|
'ckpt_to_load_from': None,
|
||||||
'combiner': 'mean',
|
'combiner': 'mean',
|
||||||
'dimension': 2,
|
'dimension': 2,
|
||||||
'initializer': initializer_config,
|
'initializer': {
|
||||||
|
'class_name': 'TruncatedNormal',
|
||||||
|
'config': {
|
||||||
|
'dtype': 'float32',
|
||||||
|
'stddev': 0.7071067811865475,
|
||||||
|
'seed': None,
|
||||||
|
'mean': 0.0
|
||||||
|
}
|
||||||
|
},
|
||||||
'max_norm': None,
|
'max_norm': None,
|
||||||
'tensor_name_in_ckpt': None,
|
'tensor_name_in_ckpt': None,
|
||||||
'trainable': True,
|
'trainable': True,
|
||||||
|
Loading…
Reference in New Issue
Block a user