Use initializers instead of init_ops in feature column serialization.

PiperOrigin-RevId: 304910252
Change-Id: I61246eaf716c926ee07ea3c6edafd1eb72b74595
This commit is contained in:
Zhenyu Tan 2020-04-05 12:18:48 -07:00 committed by TensorFlower Gardener
parent 626e41ed1e
commit 248675bf77
2 changed files with 12 additions and 19 deletions

View File

@ -152,7 +152,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@ -543,7 +542,7 @@ class _LinearModelLayer(Layer):
name='weights',
dtype=dtypes.float32,
shape=(first_dim, self._units),
initializer=init_ops.zeros_initializer(),
initializer=initializers.zeros(),
trainable=self.trainable)
# Create a bias variable.
@ -551,7 +550,7 @@ class _LinearModelLayer(Layer):
name='bias_weights',
dtype=dtypes.float32,
shape=[self._units],
initializer=init_ops.zeros_initializer(),
initializer=initializers.zeros(),
trainable=self.trainable,
use_resource=True,
# TODO(rohanj): Get rid of this hack once we have a mechanism for
@ -962,7 +961,7 @@ def embedding_column(categorical_column,
'Embedding of column_name: {}'.format(
categorical_column.name))
if initializer is None:
initializer = init_ops.truncated_normal_initializer(
initializer = initializers.truncated_normal(
mean=0.0, stddev=1 / math.sqrt(dimension))
return EmbeddingColumn(
@ -1104,7 +1103,7 @@ def shared_embedding_columns(categorical_columns,
if (initializer is not None) and (not callable(initializer)):
raise ValueError('initializer must be callable if specified.')
if initializer is None:
initializer = init_ops.truncated_normal_initializer(
initializer = initializers.truncated_normal(
mean=0.0, stddev=1. / math.sqrt(dimension))
# Sort the columns so the default collection name is deterministic even if the
@ -1287,7 +1286,7 @@ def shared_embedding_columns_v2(categorical_columns,
if (initializer is not None) and (not callable(initializer)):
raise ValueError('initializer must be callable if specified.')
if initializer is None:
initializer = init_ops.truncated_normal_initializer(
initializer = initializers.truncated_normal(
mean=0.0, stddev=1. / math.sqrt(dimension))
# Sort the columns so the default collection name is deterministic even if the

View File

@ -41,8 +41,8 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.keras import initializers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import partitioned_variables
@ -6662,6 +6662,9 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
self.assertEqual([categorical_column], embedding_column.parents)
config = embedding_column.get_config()
# initializer config contains `dtype` in v1.
initializer_config = initializers.serialize(initializers.truncated_normal(
mean=0.0, stddev=1 / np.sqrt(2)))
self.assertEqual(
{
'categorical_column': {
@ -6675,24 +6678,15 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
'ckpt_to_load_from': None,
'combiner': 'mean',
'dimension': 2,
'initializer': {
'class_name': 'TruncatedNormal',
'config': {
'dtype': 'float32',
'stddev': 0.7071067811865475,
'seed': None,
'mean': 0.0
}
},
'initializer': initializer_config,
'max_norm': None,
'tensor_name_in_ckpt': None,
'trainable': True,
'use_safe_embedding_lookup': True
}, config)
custom_objects = {'TruncatedNormal': init_ops.TruncatedNormal}
new_embedding_column = fc.EmbeddingColumn.from_config(
config, custom_objects=custom_objects)
config, custom_objects=None)
self.assertEqual(embedding_column.get_config(),
new_embedding_column.get_config())
self.assertIsNot(categorical_column,
@ -6700,7 +6694,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
new_embedding_column = fc.EmbeddingColumn.from_config(
config,
custom_objects=custom_objects,
custom_objects=None,
columns_by_name={
serialization._column_name_with_class_name(categorical_column):
categorical_column