diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index 3f334ce1bf1..f981909aef1 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -152,7 +152,6 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import embedding_ops
-from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
@@ -543,7 +542,7 @@ class _LinearModelLayer(Layer):
             name='weights',
             dtype=dtypes.float32,
             shape=(first_dim, self._units),
-            initializer=init_ops.zeros_initializer(),
+            initializer=initializers.zeros(),
             trainable=self.trainable)
 
     # Create a bias variable.
@@ -551,7 +550,7 @@ class _LinearModelLayer(Layer):
         name='bias_weights',
         dtype=dtypes.float32,
         shape=[self._units],
-        initializer=init_ops.zeros_initializer(),
+        initializer=initializers.zeros(),
         trainable=self.trainable,
         use_resource=True,
         # TODO(rohanj): Get rid of this hack once we have a mechanism for
@@ -962,7 +961,7 @@ def embedding_column(categorical_column,
                      'Embedding of column_name: {}'.format(
                          categorical_column.name))
   if initializer is None:
-    initializer = init_ops.truncated_normal_initializer(
+    initializer = initializers.truncated_normal(
         mean=0.0, stddev=1 / math.sqrt(dimension))
 
   return EmbeddingColumn(
@@ -1104,7 +1103,7 @@ def shared_embedding_columns(categorical_columns,
   if (initializer is not None) and (not callable(initializer)):
     raise ValueError('initializer must be callable if specified.')
   if initializer is None:
-    initializer = init_ops.truncated_normal_initializer(
+    initializer = initializers.truncated_normal(
         mean=0.0, stddev=1. / math.sqrt(dimension))
 
   # Sort the columns so the default collection name is deterministic even if the
@@ -1287,7 +1286,7 @@ def shared_embedding_columns_v2(categorical_columns,
   if (initializer is not None) and (not callable(initializer)):
     raise ValueError('initializer must be callable if specified.')
   if initializer is None:
-    initializer = init_ops.truncated_normal_initializer(
+    initializer = initializers.truncated_normal(
         mean=0.0, stddev=1. / math.sqrt(dimension))
 
   # Sort the columns so the default collection name is deterministic even if the
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 8c9aee722a6..fe769850fb0 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -41,8 +41,8 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras import initializers
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import partitioned_variables
@@ -6662,6 +6662,9 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
     self.assertEqual([categorical_column], embedding_column.parents)
 
     config = embedding_column.get_config()
+    # initializer config contains `dtype` in v1.
+    initializer_config = initializers.serialize(initializers.truncated_normal(
+        mean=0.0, stddev=1 / np.sqrt(2)))
     self.assertEqual(
         {
             'categorical_column': {
@@ -6675,24 +6678,15 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
             'ckpt_to_load_from': None,
             'combiner': 'mean',
             'dimension': 2,
-            'initializer': {
-                'class_name': 'TruncatedNormal',
-                'config': {
-                    'dtype': 'float32',
-                    'stddev': 0.7071067811865475,
-                    'seed': None,
-                    'mean': 0.0
-                }
-            },
+            'initializer': initializer_config,
             'max_norm': None,
             'tensor_name_in_ckpt': None,
             'trainable': True,
             'use_safe_embedding_lookup': True
         }, config)
 
-    custom_objects = {'TruncatedNormal': init_ops.TruncatedNormal}
     new_embedding_column = fc.EmbeddingColumn.from_config(
-        config, custom_objects=custom_objects)
+        config, custom_objects=None)
     self.assertEqual(embedding_column.get_config(),
                      new_embedding_column.get_config())
     self.assertIsNot(categorical_column,
@@ -6700,7 +6694,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
 
     new_embedding_column = fc.EmbeddingColumn.from_config(
         config,
-        custom_objects=custom_objects,
+        custom_objects=None,
         columns_by_name={
             serialization._column_name_with_class_name(categorical_column):
                 categorical_column
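Note on the test change (not part of the patch): the expected initializer config is now computed with initializers.serialize() instead of being hard-coded, because the serialized form of a Keras initializer differs across TF versions (per the added comment, v1 includes `dtype`). The snippet below is a minimal sketch, assuming the same internal tensorflow.python.keras.initializers module the patch imports; it only illustrates how that expected value is derived.

```python
# Illustrative sketch, not part of the patch.
import math

from tensorflow.python.keras import initializers

# Default initializer an embedding column of dimension 2 falls back to:
# truncated normal with stddev = 1 / sqrt(dimension).
initializer = initializers.truncated_normal(mean=0.0, stddev=1 / math.sqrt(2))

# serialize() returns {'class_name': 'TruncatedNormal', 'config': {...}};
# the exact keys inside 'config' (e.g. whether `dtype` appears) depend on
# the TF/Keras version, which is why the test no longer hard-codes them.
print(initializers.serialize(initializer))
```

Because the initializer is a Keras object that the Keras registry can resolve, deserialization no longer needs a custom_objects mapping, which is why the test passes custom_objects=None to EmbeddingColumn.from_config.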