From 49151a0661405a6ef1c957743228c6ed69edbbb0 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Thu, 26 Sep 2019 13:06:55 -0700 Subject: [PATCH] Remove __init__ entry for keras/wrappers. This is the first of many cls that remove __init__ file content from keras to prevent hourglass imports. PiperOrigin-RevId: 271414127 --- tensorflow/python/keras/__init__.py | 1 - tensorflow/python/keras/api/BUILD | 21 +++++++--------- tensorflow/python/keras/wrappers/__init__.py | 25 ------------------- .../keras/wrappers/scikit_learn_test.py | 13 +++++----- 4 files changed, 16 insertions(+), 44 deletions(-) diff --git a/tensorflow/python/keras/__init__.py b/tensorflow/python/keras/__init__.py index 96552549a27..7321876b2ad 100644 --- a/tensorflow/python/keras/__init__.py +++ b/tensorflow/python/keras/__init__.py @@ -42,7 +42,6 @@ from tensorflow.python.keras import premade from tensorflow.python.keras import preprocessing from tensorflow.python.keras import regularizers from tensorflow.python.keras import utils -from tensorflow.python.keras import wrappers from tensorflow.python.keras.layers import Input from tensorflow.python.keras.models import Model from tensorflow.python.keras.models import Sequential diff --git a/tensorflow/python/keras/api/BUILD b/tensorflow/python/keras/api/BUILD index 0bf58b35f15..879aeabdc24 100644 --- a/tensorflow/python/keras/api/BUILD +++ b/tensorflow/python/keras/api/BUILD @@ -10,6 +10,12 @@ package( licenses = ["notice"], # Apache 2.0 License ) +keras_packages = [ + "tensorflow.python", + "tensorflow.python.keras", + "tensorflow.python.keras.wrappers.scikit_learn", +] + gen_api_init_files( name = "keras_python_api_gen", api_name = "keras", @@ -20,10 +26,7 @@ gen_api_init_files( "//tensorflow/python/keras", "//tensorflow/python:no_contrib", ], - packages = [ - "tensorflow.python", - "tensorflow.python.keras", - ], + packages = keras_packages, ) gen_api_init_files( @@ -37,10 +40,7 @@ gen_api_init_files( "//tensorflow/python/keras", "//tensorflow/python:no_contrib", ], - packages = [ - "tensorflow.python", - "tensorflow.python.keras", - ], + packages = keras_packages, ) gen_api_init_files( @@ -54,8 +54,5 @@ gen_api_init_files( "//tensorflow/python/keras", "//tensorflow/python:no_contrib", ], - packages = [ - "tensorflow.python", - "tensorflow.python.keras", - ], + packages = keras_packages, ) diff --git a/tensorflow/python/keras/wrappers/__init__.py b/tensorflow/python/keras/wrappers/__init__.py index da579a7ab58..e69de29bb2d 100644 --- a/tensorflow/python/keras/wrappers/__init__.py +++ b/tensorflow/python/keras/wrappers/__init__.py @@ -1,25 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Wrappers for Keras models, providing compatibility with other frameworks.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.keras.wrappers import scikit_learn - -del absolute_import -del division -del print_function diff --git a/tensorflow/python/keras/wrappers/scikit_learn_test.py b/tensorflow/python/keras/wrappers/scikit_learn_test.py index f9042908033..30b8d75a561 100644 --- a/tensorflow/python/keras/wrappers/scikit_learn_test.py +++ b/tensorflow/python/keras/wrappers/scikit_learn_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.wrappers import scikit_learn from tensorflow.python.platform import test INPUT_DIM = 5 @@ -103,7 +104,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase): def test_classify_build_fn(self): with self.cached_session(): - clf = keras.wrappers.scikit_learn.KerasClassifier( + clf = scikit_learn.KerasClassifier( build_fn=build_fn_clf, hidden_dim=HIDDEN_DIM, batch_size=BATCH_SIZE, @@ -119,7 +120,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase): return build_fn_clf(hidden_dim) with self.cached_session(): - clf = keras.wrappers.scikit_learn.KerasClassifier( + clf = scikit_learn.KerasClassifier( build_fn=ClassBuildFnClf(), hidden_dim=HIDDEN_DIM, batch_size=BATCH_SIZE, @@ -129,7 +130,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase): def test_classify_inherit_class_build_fn(self): - class InheritClassBuildFnClf(keras.wrappers.scikit_learn.KerasClassifier): + class InheritClassBuildFnClf(scikit_learn.KerasClassifier): def __call__(self, hidden_dim): return build_fn_clf(hidden_dim) @@ -145,7 +146,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase): def test_regression_build_fn(self): with self.cached_session(): - reg = keras.wrappers.scikit_learn.KerasRegressor( + reg = scikit_learn.KerasRegressor( build_fn=build_fn_reg, hidden_dim=HIDDEN_DIM, batch_size=BATCH_SIZE, @@ -161,7 +162,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase): return build_fn_reg(hidden_dim) with self.cached_session(): - reg = keras.wrappers.scikit_learn.KerasRegressor( + reg = scikit_learn.KerasRegressor( build_fn=ClassBuildFnReg(), hidden_dim=HIDDEN_DIM, batch_size=BATCH_SIZE, @@ -171,7 +172,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase): def test_regression_inherit_class_build_fn(self): - class InheritClassBuildFnReg(keras.wrappers.scikit_learn.KerasRegressor): + class InheritClassBuildFnReg(scikit_learn.KerasRegressor): def __call__(self, hidden_dim): return build_fn_reg(hidden_dim)