From 3d53fd687526c1ea982f597497123e1c39ad9cef Mon Sep 17 00:00:00 2001 From: Scott Zhu <scottzhu@google.com> Date: Tue, 9 Jun 2020 20:55:37 -0700 Subject: [PATCH] Update feature_column to not rely on Keras initializer. This is trying to remove the deps from Tensorflow to Keras. PiperOrigin-RevId: 315619009 Change-Id: I0f39881eb91ab2003aa5a4f600fc95b53333c0bc --- tensorflow/python/feature_column/BUILD | 1 - .../feature_column/feature_column_v2.py | 19 ++++++++++++------- .../feature_column/feature_column_v2_test.py | 16 +++++++++++----- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index 52f1186c5d9..bd4152c6d42 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -90,7 +90,6 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/eager:context", - "//tensorflow/python/keras:initializers", "//tensorflow/python/keras/utils:generic_utils", "//tensorflow/python/training/tracking", "//tensorflow/python/training/tracking:data_structures", diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index a03e4da0fae..73d33c1e0e6 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -144,12 +144,12 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape # TODO(b/118385027): Dependency on keras can be problematic if Keras moves out # of the main repo. -from tensorflow.python.keras import initializers from tensorflow.python.keras.utils import generic_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops @@ -165,6 +165,7 @@ from tensorflow.python.training.tracking import data_structures from tensorflow.python.training.tracking import tracking from tensorflow.python.util import deprecation from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect from tensorflow.python.util.compat import collections_abc from tensorflow.python.util.tf_export import tf_export @@ -588,7 +589,7 @@ def embedding_column(categorical_column, 'Embedding of column_name: {}'.format( categorical_column.name)) if initializer is None: - initializer = initializers.truncated_normal( + initializer = init_ops.truncated_normal_initializer( mean=0.0, stddev=1 / math.sqrt(dimension)) return EmbeddingColumn( @@ -730,7 +731,7 @@ def shared_embedding_columns(categorical_columns, if (initializer is not None) and (not callable(initializer)): raise ValueError('initializer must be callable if specified.') if initializer is None: - initializer = initializers.truncated_normal( + initializer = init_ops.truncated_normal_initializer( mean=0.0, stddev=1. / math.sqrt(dimension)) # Sort the columns so the default collection name is deterministic even if the @@ -913,7 +914,7 @@ def shared_embedding_columns_v2(categorical_columns, if (initializer is not None) and (not callable(initializer)): raise ValueError('initializer must be callable if specified.') if initializer is None: - initializer = initializers.truncated_normal( + initializer = init_ops.truncated_normal_initializer( mean=0.0, stddev=1. / math.sqrt(dimension)) # Sort the columns so the default collection name is deterministic even if the @@ -3030,7 +3031,8 @@ class EmbeddingColumn( config = dict(zip(self._fields, self)) config['categorical_column'] = serialize_feature_column( self.categorical_column) - config['initializer'] = initializers.serialize(self.initializer) + config['initializer'] = generic_utils.serialize_keras_object( + self.initializer) return config @classmethod @@ -3043,8 +3045,11 @@ class EmbeddingColumn( kwargs = _standardize_and_copy_config(config) kwargs['categorical_column'] = deserialize_feature_column( config['categorical_column'], custom_objects, columns_by_name) - kwargs['initializer'] = initializers.deserialize( - config['initializer'], custom_objects=custom_objects) + all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass)) + kwargs['initializer'] = generic_utils.deserialize_keras_object( + config['initializer'], + module_objects=all_initializers, + custom_objects=custom_objects) return cls(**kwargs) diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py index 844478c879b..dda1af8a00e 100644 --- a/tensorflow/python/feature_column/feature_column_v2_test.py +++ b/tensorflow/python/feature_column/feature_column_v2_test.py @@ -40,7 +40,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util -from tensorflow.python.keras import initializers from tensorflow.python.ops import array_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import parsing_ops @@ -117,6 +116,7 @@ class LazyColumnTest(test.TestCase): class TransformCounter(BaseFeatureColumnForTests): def __init__(self): + super(TransformCounter, self).__init__() self.num_transform = 0 @property @@ -4285,6 +4285,7 @@ class TransformFeaturesTest(test.TestCase): class _LoggerColumn(BaseFeatureColumnForTests): def __init__(self, name): + super(_LoggerColumn, self).__init__() self._name = name @property @@ -5362,9 +5363,6 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): self.assertEqual([categorical_column], embedding_column.parents) config = embedding_column.get_config() - # initializer config contains `dtype` in v1. - initializer_config = initializers.serialize(initializers.truncated_normal( - mean=0.0, stddev=1 / np.sqrt(2))) self.assertEqual( { 'categorical_column': { @@ -5378,7 +5376,15 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): 'ckpt_to_load_from': None, 'combiner': 'mean', 'dimension': 2, - 'initializer': initializer_config, + 'initializer': { + 'class_name': 'TruncatedNormal', + 'config': { + 'dtype': 'float32', + 'stddev': 0.7071067811865475, + 'seed': None, + 'mean': 0.0 + } + }, 'max_norm': None, 'tensor_name_in_ckpt': None, 'trainable': True,