Use initializers instead of init_ops in feature column serialization.

PiperOrigin-RevId: 304910252
Change-Id: I61246eaf716c926ee07ea3c6edafd1eb72b74595
This commit is contained in:
Zhenyu Tan 2020-04-05 12:18:48 -07:00 committed by TensorFlower Gardener
parent 626e41ed1e
commit 248675bf77
2 changed files with 12 additions and 19 deletions

View File

@ -152,7 +152,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops from tensorflow.python.ops import nn_ops
@ -543,7 +542,7 @@ class _LinearModelLayer(Layer):
name='weights', name='weights',
dtype=dtypes.float32, dtype=dtypes.float32,
shape=(first_dim, self._units), shape=(first_dim, self._units),
initializer=init_ops.zeros_initializer(), initializer=initializers.zeros(),
trainable=self.trainable) trainable=self.trainable)
# Create a bias variable. # Create a bias variable.
@ -551,7 +550,7 @@ class _LinearModelLayer(Layer):
name='bias_weights', name='bias_weights',
dtype=dtypes.float32, dtype=dtypes.float32,
shape=[self._units], shape=[self._units],
initializer=init_ops.zeros_initializer(), initializer=initializers.zeros(),
trainable=self.trainable, trainable=self.trainable,
use_resource=True, use_resource=True,
# TODO(rohanj): Get rid of this hack once we have a mechanism for # TODO(rohanj): Get rid of this hack once we have a mechanism for
@ -962,7 +961,7 @@ def embedding_column(categorical_column,
'Embedding of column_name: {}'.format( 'Embedding of column_name: {}'.format(
categorical_column.name)) categorical_column.name))
if initializer is None: if initializer is None:
initializer = init_ops.truncated_normal_initializer( initializer = initializers.truncated_normal(
mean=0.0, stddev=1 / math.sqrt(dimension)) mean=0.0, stddev=1 / math.sqrt(dimension))
return EmbeddingColumn( return EmbeddingColumn(
@ -1104,7 +1103,7 @@ def shared_embedding_columns(categorical_columns,
if (initializer is not None) and (not callable(initializer)): if (initializer is not None) and (not callable(initializer)):
raise ValueError('initializer must be callable if specified.') raise ValueError('initializer must be callable if specified.')
if initializer is None: if initializer is None:
initializer = init_ops.truncated_normal_initializer( initializer = initializers.truncated_normal(
mean=0.0, stddev=1. / math.sqrt(dimension)) mean=0.0, stddev=1. / math.sqrt(dimension))
# Sort the columns so the default collection name is deterministic even if the # Sort the columns so the default collection name is deterministic even if the
@ -1287,7 +1286,7 @@ def shared_embedding_columns_v2(categorical_columns,
if (initializer is not None) and (not callable(initializer)): if (initializer is not None) and (not callable(initializer)):
raise ValueError('initializer must be callable if specified.') raise ValueError('initializer must be callable if specified.')
if initializer is None: if initializer is None:
initializer = init_ops.truncated_normal_initializer( initializer = initializers.truncated_normal(
mean=0.0, stddev=1. / math.sqrt(dimension)) mean=0.0, stddev=1. / math.sqrt(dimension))
# Sort the columns so the default collection name is deterministic even if the # Sort the columns so the default collection name is deterministic even if the

View File

@ -41,8 +41,8 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras import initializers
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import partitioned_variables
@ -6662,6 +6662,9 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
self.assertEqual([categorical_column], embedding_column.parents) self.assertEqual([categorical_column], embedding_column.parents)
config = embedding_column.get_config() config = embedding_column.get_config()
# initializer config contains `dtype` in v1.
initializer_config = initializers.serialize(initializers.truncated_normal(
mean=0.0, stddev=1 / np.sqrt(2)))
self.assertEqual( self.assertEqual(
{ {
'categorical_column': { 'categorical_column': {
@ -6675,24 +6678,15 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
'ckpt_to_load_from': None, 'ckpt_to_load_from': None,
'combiner': 'mean', 'combiner': 'mean',
'dimension': 2, 'dimension': 2,
'initializer': { 'initializer': initializer_config,
'class_name': 'TruncatedNormal',
'config': {
'dtype': 'float32',
'stddev': 0.7071067811865475,
'seed': None,
'mean': 0.0
}
},
'max_norm': None, 'max_norm': None,
'tensor_name_in_ckpt': None, 'tensor_name_in_ckpt': None,
'trainable': True, 'trainable': True,
'use_safe_embedding_lookup': True 'use_safe_embedding_lookup': True
}, config) }, config)
custom_objects = {'TruncatedNormal': init_ops.TruncatedNormal}
new_embedding_column = fc.EmbeddingColumn.from_config( new_embedding_column = fc.EmbeddingColumn.from_config(
config, custom_objects=custom_objects) config, custom_objects=None)
self.assertEqual(embedding_column.get_config(), self.assertEqual(embedding_column.get_config(),
new_embedding_column.get_config()) new_embedding_column.get_config())
self.assertIsNot(categorical_column, self.assertIsNot(categorical_column,
@ -6700,7 +6694,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
new_embedding_column = fc.EmbeddingColumn.from_config( new_embedding_column = fc.EmbeddingColumn.from_config(
config, config,
custom_objects=custom_objects, custom_objects=None,
columns_by_name={ columns_by_name={
serialization._column_name_with_class_name(categorical_column): serialization._column_name_with_class_name(categorical_column):
categorical_column categorical_column