Use initializers instead of init_ops in feature column serialization.
PiperOrigin-RevId: 304910252 Change-Id: I61246eaf716c926ee07ea3c6edafd1eb72b74595
This commit is contained in:
parent
626e41ed1e
commit
248675bf77
@ -152,7 +152,6 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
@ -543,7 +542,7 @@ class _LinearModelLayer(Layer):
|
||||
name='weights',
|
||||
dtype=dtypes.float32,
|
||||
shape=(first_dim, self._units),
|
||||
initializer=init_ops.zeros_initializer(),
|
||||
initializer=initializers.zeros(),
|
||||
trainable=self.trainable)
|
||||
|
||||
# Create a bias variable.
|
||||
@ -551,7 +550,7 @@ class _LinearModelLayer(Layer):
|
||||
name='bias_weights',
|
||||
dtype=dtypes.float32,
|
||||
shape=[self._units],
|
||||
initializer=init_ops.zeros_initializer(),
|
||||
initializer=initializers.zeros(),
|
||||
trainable=self.trainable,
|
||||
use_resource=True,
|
||||
# TODO(rohanj): Get rid of this hack once we have a mechanism for
|
||||
@ -962,7 +961,7 @@ def embedding_column(categorical_column,
|
||||
'Embedding of column_name: {}'.format(
|
||||
categorical_column.name))
|
||||
if initializer is None:
|
||||
initializer = init_ops.truncated_normal_initializer(
|
||||
initializer = initializers.truncated_normal(
|
||||
mean=0.0, stddev=1 / math.sqrt(dimension))
|
||||
|
||||
return EmbeddingColumn(
|
||||
@ -1104,7 +1103,7 @@ def shared_embedding_columns(categorical_columns,
|
||||
if (initializer is not None) and (not callable(initializer)):
|
||||
raise ValueError('initializer must be callable if specified.')
|
||||
if initializer is None:
|
||||
initializer = init_ops.truncated_normal_initializer(
|
||||
initializer = initializers.truncated_normal(
|
||||
mean=0.0, stddev=1. / math.sqrt(dimension))
|
||||
|
||||
# Sort the columns so the default collection name is deterministic even if the
|
||||
@ -1287,7 +1286,7 @@ def shared_embedding_columns_v2(categorical_columns,
|
||||
if (initializer is not None) and (not callable(initializer)):
|
||||
raise ValueError('initializer must be callable if specified.')
|
||||
if initializer is None:
|
||||
initializer = init_ops.truncated_normal_initializer(
|
||||
initializer = initializers.truncated_normal(
|
||||
mean=0.0, stddev=1. / math.sqrt(dimension))
|
||||
|
||||
# Sort the columns so the default collection name is deterministic even if the
|
||||
|
@ -41,8 +41,8 @@ from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import initializers
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import partitioned_variables
|
||||
@ -6662,6 +6662,9 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual([categorical_column], embedding_column.parents)
|
||||
|
||||
config = embedding_column.get_config()
|
||||
# initializer config contains `dtype` in v1.
|
||||
initializer_config = initializers.serialize(initializers.truncated_normal(
|
||||
mean=0.0, stddev=1 / np.sqrt(2)))
|
||||
self.assertEqual(
|
||||
{
|
||||
'categorical_column': {
|
||||
@ -6675,24 +6678,15 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||
'ckpt_to_load_from': None,
|
||||
'combiner': 'mean',
|
||||
'dimension': 2,
|
||||
'initializer': {
|
||||
'class_name': 'TruncatedNormal',
|
||||
'config': {
|
||||
'dtype': 'float32',
|
||||
'stddev': 0.7071067811865475,
|
||||
'seed': None,
|
||||
'mean': 0.0
|
||||
}
|
||||
},
|
||||
'initializer': initializer_config,
|
||||
'max_norm': None,
|
||||
'tensor_name_in_ckpt': None,
|
||||
'trainable': True,
|
||||
'use_safe_embedding_lookup': True
|
||||
}, config)
|
||||
|
||||
custom_objects = {'TruncatedNormal': init_ops.TruncatedNormal}
|
||||
new_embedding_column = fc.EmbeddingColumn.from_config(
|
||||
config, custom_objects=custom_objects)
|
||||
config, custom_objects=None)
|
||||
self.assertEqual(embedding_column.get_config(),
|
||||
new_embedding_column.get_config())
|
||||
self.assertIsNot(categorical_column,
|
||||
@ -6700,7 +6694,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
new_embedding_column = fc.EmbeddingColumn.from_config(
|
||||
config,
|
||||
custom_objects=custom_objects,
|
||||
custom_objects=None,
|
||||
columns_by_name={
|
||||
serialization._column_name_with_class_name(categorical_column):
|
||||
categorical_column
|
||||
|
Loading…
x
Reference in New Issue
Block a user