Update feature_column to not rely on Keras initializer.

This is trying to remove the deps from Tensorflow to Keras.

PiperOrigin-RevId: 315619009
Change-Id: I0f39881eb91ab2003aa5a4f600fc95b53333c0bc
This commit is contained in:
Scott Zhu 2020-06-09 20:55:37 -07:00 committed by TensorFlower Gardener
parent 012401fb38
commit 3d53fd6875
3 changed files with 23 additions and 13 deletions

View File

@ -90,7 +90,6 @@ py_library(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/keras:initializers",
"//tensorflow/python/keras/utils:generic_utils",
"//tensorflow/python/training/tracking",
"//tensorflow/python/training/tracking:data_structures",

View File

@ -144,12 +144,12 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
# TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
# of the main repo.
from tensorflow.python.keras import initializers
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
@ -165,6 +165,7 @@ from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export
@ -588,7 +589,7 @@ def embedding_column(categorical_column,
'Embedding of column_name: {}'.format(
categorical_column.name))
if initializer is None:
initializer = initializers.truncated_normal(
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=1 / math.sqrt(dimension))
return EmbeddingColumn(
@ -730,7 +731,7 @@ def shared_embedding_columns(categorical_columns,
if (initializer is not None) and (not callable(initializer)):
raise ValueError('initializer must be callable if specified.')
if initializer is None:
initializer = initializers.truncated_normal(
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=1. / math.sqrt(dimension))
# Sort the columns so the default collection name is deterministic even if the
@ -913,7 +914,7 @@ def shared_embedding_columns_v2(categorical_columns,
if (initializer is not None) and (not callable(initializer)):
raise ValueError('initializer must be callable if specified.')
if initializer is None:
initializer = initializers.truncated_normal(
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=1. / math.sqrt(dimension))
# Sort the columns so the default collection name is deterministic even if the
@ -3030,7 +3031,8 @@ class EmbeddingColumn(
config = dict(zip(self._fields, self))
config['categorical_column'] = serialize_feature_column(
self.categorical_column)
config['initializer'] = initializers.serialize(self.initializer)
config['initializer'] = generic_utils.serialize_keras_object(
self.initializer)
return config
@classmethod
@ -3043,8 +3045,11 @@ class EmbeddingColumn(
kwargs = _standardize_and_copy_config(config)
kwargs['categorical_column'] = deserialize_feature_column(
config['categorical_column'], custom_objects, columns_by_name)
kwargs['initializer'] = initializers.deserialize(
config['initializer'], custom_objects=custom_objects)
all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass))
kwargs['initializer'] = generic_utils.deserialize_keras_object(
config['initializer'],
module_objects=all_initializers,
custom_objects=custom_objects)
return cls(**kwargs)

View File

@ -40,7 +40,6 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.keras import initializers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
@ -117,6 +116,7 @@ class LazyColumnTest(test.TestCase):
class TransformCounter(BaseFeatureColumnForTests):
def __init__(self):
super(TransformCounter, self).__init__()
self.num_transform = 0
@property
@ -4285,6 +4285,7 @@ class TransformFeaturesTest(test.TestCase):
class _LoggerColumn(BaseFeatureColumnForTests):
def __init__(self, name):
super(_LoggerColumn, self).__init__()
self._name = name
@property
@ -5362,9 +5363,6 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
self.assertEqual([categorical_column], embedding_column.parents)
config = embedding_column.get_config()
# initializer config contains `dtype` in v1.
initializer_config = initializers.serialize(initializers.truncated_normal(
mean=0.0, stddev=1 / np.sqrt(2)))
self.assertEqual(
{
'categorical_column': {
@ -5378,7 +5376,15 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
'ckpt_to_load_from': None,
'combiner': 'mean',
'dimension': 2,
'initializer': initializer_config,
'initializer': {
'class_name': 'TruncatedNormal',
'config': {
'dtype': 'float32',
'stddev': 0.7071067811865475,
'seed': None,
'mean': 0.0
}
},
'max_norm': None,
'tensor_name_in_ckpt': None,
'trainable': True,