From 3b2109f7de7689a33c6f94251fe2bd74a1055046 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Mon, 8 Jun 2020 08:46:51 -0700 Subject: [PATCH] Update the base class for _LinearModel in Feature column v1. PiperOrigin-RevId: 315282469 Change-Id: I86a793c89b8098723750c631d858bc360746942d --- tensorflow/python/feature_column/BUILD | 1 - tensorflow/python/feature_column/feature_column.py | 11 ++++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index 8f62fc2d1be..52f1186c5d9 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -48,7 +48,6 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/eager:context", - "//tensorflow/python/keras/engine", # TODO(scottzhu): Remove metrics after we cleanup the keras internal cyclar dependency. # //third_party/tensorflow/python/feature_column:feature_column # //third_party/tensorflow/python/keras/engine:engine diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 07df4e914c9..3207fd550b4 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -144,7 +144,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras.engine import training from tensorflow.python.layers import base from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -165,8 +164,8 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_utils from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.compat import collections_abc +from tensorflow.python.util.tf_export import tf_export def _internal_input_layer(features, @@ -616,7 +615,7 @@ def _strip_leading_slashes(name): return name.rsplit('/', 1)[-1] -class _LinearModel(training.Model): +class _LinearModel(base.Layer): """Creates a linear model using feature columns. See `linear_model` for details. @@ -631,6 +630,12 @@ class _LinearModel(training.Model): name=None, **kwargs): super(_LinearModel, self).__init__(name=name, **kwargs) + # We force the keras_style to be True here, as a workaround to not being + # able to inherit keras.layers.Layer as base class. Setting this will let + # us skip all the legacy behavior for base.Layer. + # Also note that we use Layer as base class, instead of Model, since there + # isn't any Model specific behavior gets used, eg compile/fit. + self._keras_style = True self._feature_columns = _normalize_feature_columns( feature_columns) self._weight_collections = list(weight_collections or [])