Use initializers instead of init_ops in feature column serialization.
PiperOrigin-RevId: 304910252 Change-Id: I61246eaf716c926ee07ea3c6edafd1eb72b74595
This commit is contained in:
parent
626e41ed1e
commit
248675bf77
@ -152,7 +152,6 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import embedding_ops
|
from tensorflow.python.ops import embedding_ops
|
||||||
from tensorflow.python.ops import init_ops
|
|
||||||
from tensorflow.python.ops import lookup_ops
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn_ops
|
from tensorflow.python.ops import nn_ops
|
||||||
@ -543,7 +542,7 @@ class _LinearModelLayer(Layer):
|
|||||||
name='weights',
|
name='weights',
|
||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
shape=(first_dim, self._units),
|
shape=(first_dim, self._units),
|
||||||
initializer=init_ops.zeros_initializer(),
|
initializer=initializers.zeros(),
|
||||||
trainable=self.trainable)
|
trainable=self.trainable)
|
||||||
|
|
||||||
# Create a bias variable.
|
# Create a bias variable.
|
||||||
@ -551,7 +550,7 @@ class _LinearModelLayer(Layer):
|
|||||||
name='bias_weights',
|
name='bias_weights',
|
||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
shape=[self._units],
|
shape=[self._units],
|
||||||
initializer=init_ops.zeros_initializer(),
|
initializer=initializers.zeros(),
|
||||||
trainable=self.trainable,
|
trainable=self.trainable,
|
||||||
use_resource=True,
|
use_resource=True,
|
||||||
# TODO(rohanj): Get rid of this hack once we have a mechanism for
|
# TODO(rohanj): Get rid of this hack once we have a mechanism for
|
||||||
@ -962,7 +961,7 @@ def embedding_column(categorical_column,
|
|||||||
'Embedding of column_name: {}'.format(
|
'Embedding of column_name: {}'.format(
|
||||||
categorical_column.name))
|
categorical_column.name))
|
||||||
if initializer is None:
|
if initializer is None:
|
||||||
initializer = init_ops.truncated_normal_initializer(
|
initializer = initializers.truncated_normal(
|
||||||
mean=0.0, stddev=1 / math.sqrt(dimension))
|
mean=0.0, stddev=1 / math.sqrt(dimension))
|
||||||
|
|
||||||
return EmbeddingColumn(
|
return EmbeddingColumn(
|
||||||
@ -1104,7 +1103,7 @@ def shared_embedding_columns(categorical_columns,
|
|||||||
if (initializer is not None) and (not callable(initializer)):
|
if (initializer is not None) and (not callable(initializer)):
|
||||||
raise ValueError('initializer must be callable if specified.')
|
raise ValueError('initializer must be callable if specified.')
|
||||||
if initializer is None:
|
if initializer is None:
|
||||||
initializer = init_ops.truncated_normal_initializer(
|
initializer = initializers.truncated_normal(
|
||||||
mean=0.0, stddev=1. / math.sqrt(dimension))
|
mean=0.0, stddev=1. / math.sqrt(dimension))
|
||||||
|
|
||||||
# Sort the columns so the default collection name is deterministic even if the
|
# Sort the columns so the default collection name is deterministic even if the
|
||||||
@ -1287,7 +1286,7 @@ def shared_embedding_columns_v2(categorical_columns,
|
|||||||
if (initializer is not None) and (not callable(initializer)):
|
if (initializer is not None) and (not callable(initializer)):
|
||||||
raise ValueError('initializer must be callable if specified.')
|
raise ValueError('initializer must be callable if specified.')
|
||||||
if initializer is None:
|
if initializer is None:
|
||||||
initializer = init_ops.truncated_normal_initializer(
|
initializer = initializers.truncated_normal(
|
||||||
mean=0.0, stddev=1. / math.sqrt(dimension))
|
mean=0.0, stddev=1. / math.sqrt(dimension))
|
||||||
|
|
||||||
# Sort the columns so the default collection name is deterministic even if the
|
# Sort the columns so the default collection name is deterministic even if the
|
||||||
|
@ -41,8 +41,8 @@ from tensorflow.python.framework import errors
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.keras import initializers
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import init_ops
|
|
||||||
from tensorflow.python.ops import lookup_ops
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import parsing_ops
|
from tensorflow.python.ops import parsing_ops
|
||||||
from tensorflow.python.ops import partitioned_variables
|
from tensorflow.python.ops import partitioned_variables
|
||||||
@ -6662,6 +6662,9 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual([categorical_column], embedding_column.parents)
|
self.assertEqual([categorical_column], embedding_column.parents)
|
||||||
|
|
||||||
config = embedding_column.get_config()
|
config = embedding_column.get_config()
|
||||||
|
# initializer config contains `dtype` in v1.
|
||||||
|
initializer_config = initializers.serialize(initializers.truncated_normal(
|
||||||
|
mean=0.0, stddev=1 / np.sqrt(2)))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
{
|
{
|
||||||
'categorical_column': {
|
'categorical_column': {
|
||||||
@ -6675,24 +6678,15 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
|||||||
'ckpt_to_load_from': None,
|
'ckpt_to_load_from': None,
|
||||||
'combiner': 'mean',
|
'combiner': 'mean',
|
||||||
'dimension': 2,
|
'dimension': 2,
|
||||||
'initializer': {
|
'initializer': initializer_config,
|
||||||
'class_name': 'TruncatedNormal',
|
|
||||||
'config': {
|
|
||||||
'dtype': 'float32',
|
|
||||||
'stddev': 0.7071067811865475,
|
|
||||||
'seed': None,
|
|
||||||
'mean': 0.0
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'max_norm': None,
|
'max_norm': None,
|
||||||
'tensor_name_in_ckpt': None,
|
'tensor_name_in_ckpt': None,
|
||||||
'trainable': True,
|
'trainable': True,
|
||||||
'use_safe_embedding_lookup': True
|
'use_safe_embedding_lookup': True
|
||||||
}, config)
|
}, config)
|
||||||
|
|
||||||
custom_objects = {'TruncatedNormal': init_ops.TruncatedNormal}
|
|
||||||
new_embedding_column = fc.EmbeddingColumn.from_config(
|
new_embedding_column = fc.EmbeddingColumn.from_config(
|
||||||
config, custom_objects=custom_objects)
|
config, custom_objects=None)
|
||||||
self.assertEqual(embedding_column.get_config(),
|
self.assertEqual(embedding_column.get_config(),
|
||||||
new_embedding_column.get_config())
|
new_embedding_column.get_config())
|
||||||
self.assertIsNot(categorical_column,
|
self.assertIsNot(categorical_column,
|
||||||
@ -6700,7 +6694,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
new_embedding_column = fc.EmbeddingColumn.from_config(
|
new_embedding_column = fc.EmbeddingColumn.from_config(
|
||||||
config,
|
config,
|
||||||
custom_objects=custom_objects,
|
custom_objects=None,
|
||||||
columns_by_name={
|
columns_by_name={
|
||||||
serialization._column_name_with_class_name(categorical_column):
|
serialization._column_name_with_class_name(categorical_column):
|
||||||
categorical_column
|
categorical_column
|
||||||
|
Loading…
x
Reference in New Issue
Block a user