From 3d53fd687526c1ea982f597497123e1c39ad9cef Mon Sep 17 00:00:00 2001
From: Scott Zhu <scottzhu@google.com>
Date: Tue, 9 Jun 2020 20:55:37 -0700
Subject: [PATCH] Update feature_column to not rely on Keras initializer.

This is trying to remove the deps from Tensorflow to Keras.

PiperOrigin-RevId: 315619009
Change-Id: I0f39881eb91ab2003aa5a4f600fc95b53333c0bc
---
 tensorflow/python/feature_column/BUILD        |  1 -
 .../feature_column/feature_column_v2.py       | 19 ++++++++++++-------
 .../feature_column/feature_column_v2_test.py  | 16 +++++++++++-----
 3 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 52f1186c5d9..bd4152c6d42 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -90,7 +90,6 @@ py_library(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/keras:initializers",
         "//tensorflow/python/keras/utils:generic_utils",
         "//tensorflow/python/training/tracking",
         "//tensorflow/python/training/tracking:data_structures",
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index a03e4da0fae..73d33c1e0e6 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -144,12 +144,12 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
 from tensorflow.python.framework import tensor_shape
 # TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
 # of the main repo.
-from tensorflow.python.keras import initializers
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
@@ -165,6 +165,7 @@ from tensorflow.python.training.tracking import data_structures
 from tensorflow.python.training.tracking import tracking
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import nest
+from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.tf_export import tf_export
 
@@ -588,7 +589,7 @@ def embedding_column(categorical_column,
                      'Embedding of column_name: {}'.format(
                          categorical_column.name))
   if initializer is None:
-    initializer = initializers.truncated_normal(
+    initializer = init_ops.truncated_normal_initializer(
         mean=0.0, stddev=1 / math.sqrt(dimension))
 
   return EmbeddingColumn(
@@ -730,7 +731,7 @@ def shared_embedding_columns(categorical_columns,
   if (initializer is not None) and (not callable(initializer)):
     raise ValueError('initializer must be callable if specified.')
   if initializer is None:
-    initializer = initializers.truncated_normal(
+    initializer = init_ops.truncated_normal_initializer(
         mean=0.0, stddev=1. / math.sqrt(dimension))
 
   # Sort the columns so the default collection name is deterministic even if the
@@ -913,7 +914,7 @@ def shared_embedding_columns_v2(categorical_columns,
   if (initializer is not None) and (not callable(initializer)):
     raise ValueError('initializer must be callable if specified.')
   if initializer is None:
-    initializer = initializers.truncated_normal(
+    initializer = init_ops.truncated_normal_initializer(
         mean=0.0, stddev=1. / math.sqrt(dimension))
 
   # Sort the columns so the default collection name is deterministic even if the
@@ -3030,7 +3031,8 @@ class EmbeddingColumn(
     config = dict(zip(self._fields, self))
     config['categorical_column'] = serialize_feature_column(
         self.categorical_column)
-    config['initializer'] = initializers.serialize(self.initializer)
+    config['initializer'] = generic_utils.serialize_keras_object(
+        self.initializer)
     return config
 
   @classmethod
@@ -3043,8 +3045,11 @@ class EmbeddingColumn(
     kwargs = _standardize_and_copy_config(config)
     kwargs['categorical_column'] = deserialize_feature_column(
         config['categorical_column'], custom_objects, columns_by_name)
-    kwargs['initializer'] = initializers.deserialize(
-        config['initializer'], custom_objects=custom_objects)
+    all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass))
+    kwargs['initializer'] = generic_utils.deserialize_keras_object(
+        config['initializer'],
+        module_objects=all_initializers,
+        custom_objects=custom_objects)
     return cls(**kwargs)
 
 
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 844478c879b..dda1af8a00e 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -40,7 +40,6 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
-from tensorflow.python.keras import initializers
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import parsing_ops
@@ -117,6 +116,7 @@ class LazyColumnTest(test.TestCase):
     class TransformCounter(BaseFeatureColumnForTests):
 
       def __init__(self):
+        super(TransformCounter, self).__init__()
         self.num_transform = 0
 
       @property
@@ -4285,6 +4285,7 @@ class TransformFeaturesTest(test.TestCase):
     class _LoggerColumn(BaseFeatureColumnForTests):
 
       def __init__(self, name):
+        super(_LoggerColumn, self).__init__()
         self._name = name
 
       @property
@@ -5362,9 +5363,6 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
     self.assertEqual([categorical_column], embedding_column.parents)
 
     config = embedding_column.get_config()
-    # initializer config contains `dtype` in v1.
-    initializer_config = initializers.serialize(initializers.truncated_normal(
-        mean=0.0, stddev=1 / np.sqrt(2)))
     self.assertEqual(
         {
             'categorical_column': {
@@ -5378,7 +5376,15 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase):
             'ckpt_to_load_from': None,
             'combiner': 'mean',
             'dimension': 2,
-            'initializer': initializer_config,
+            'initializer': {
+                'class_name': 'TruncatedNormal',
+                'config': {
+                    'dtype': 'float32',
+                    'stddev': 0.7071067811865475,
+                    'seed': None,
+                    'mean': 0.0
+                }
+            },
             'max_norm': None,
             'tensor_name_in_ckpt': None,
             'trainable': True,