Remove __init__ entry for keras/wrappers.

This is the first of many cls that remove __init__ file content from keras to
prevent hourglass imports.

PiperOrigin-RevId: 271414127
This commit is contained in:
Scott Zhu 2019-09-26 13:06:55 -07:00 committed by TensorFlower Gardener
parent 52a9ff02c5
commit 49151a0661
4 changed files with 16 additions and 44 deletions

View File

@ -42,7 +42,6 @@ from tensorflow.python.keras import premade
from tensorflow.python.keras import preprocessing
from tensorflow.python.keras import regularizers
from tensorflow.python.keras import utils
from tensorflow.python.keras import wrappers
from tensorflow.python.keras.layers import Input
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.models import Sequential

View File

@ -10,6 +10,12 @@ package(
licenses = ["notice"], # Apache 2.0 License
)
keras_packages = [
"tensorflow.python",
"tensorflow.python.keras",
"tensorflow.python.keras.wrappers.scikit_learn",
]
gen_api_init_files(
name = "keras_python_api_gen",
api_name = "keras",
@ -20,10 +26,7 @@ gen_api_init_files(
"//tensorflow/python/keras",
"//tensorflow/python:no_contrib",
],
packages = [
"tensorflow.python",
"tensorflow.python.keras",
],
packages = keras_packages,
)
gen_api_init_files(
@ -37,10 +40,7 @@ gen_api_init_files(
"//tensorflow/python/keras",
"//tensorflow/python:no_contrib",
],
packages = [
"tensorflow.python",
"tensorflow.python.keras",
],
packages = keras_packages,
)
gen_api_init_files(
@ -54,8 +54,5 @@ gen_api_init_files(
"//tensorflow/python/keras",
"//tensorflow/python:no_contrib",
],
packages = [
"tensorflow.python",
"tensorflow.python.keras",
],
packages = keras_packages,
)

View File

@ -1,25 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrappers for Keras models, providing compatibility with other frameworks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras.wrappers import scikit_learn
del absolute_import
del division
del print_function

View File

@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.wrappers import scikit_learn
from tensorflow.python.platform import test
INPUT_DIM = 5
@ -103,7 +104,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
def test_classify_build_fn(self):
with self.cached_session():
clf = keras.wrappers.scikit_learn.KerasClassifier(
clf = scikit_learn.KerasClassifier(
build_fn=build_fn_clf,
hidden_dim=HIDDEN_DIM,
batch_size=BATCH_SIZE,
@ -119,7 +120,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
return build_fn_clf(hidden_dim)
with self.cached_session():
clf = keras.wrappers.scikit_learn.KerasClassifier(
clf = scikit_learn.KerasClassifier(
build_fn=ClassBuildFnClf(),
hidden_dim=HIDDEN_DIM,
batch_size=BATCH_SIZE,
@ -129,7 +130,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
def test_classify_inherit_class_build_fn(self):
class InheritClassBuildFnClf(keras.wrappers.scikit_learn.KerasClassifier):
class InheritClassBuildFnClf(scikit_learn.KerasClassifier):
def __call__(self, hidden_dim):
return build_fn_clf(hidden_dim)
@ -145,7 +146,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
def test_regression_build_fn(self):
with self.cached_session():
reg = keras.wrappers.scikit_learn.KerasRegressor(
reg = scikit_learn.KerasRegressor(
build_fn=build_fn_reg,
hidden_dim=HIDDEN_DIM,
batch_size=BATCH_SIZE,
@ -161,7 +162,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
return build_fn_reg(hidden_dim)
with self.cached_session():
reg = keras.wrappers.scikit_learn.KerasRegressor(
reg = scikit_learn.KerasRegressor(
build_fn=ClassBuildFnReg(),
hidden_dim=HIDDEN_DIM,
batch_size=BATCH_SIZE,
@ -171,7 +172,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
def test_regression_inherit_class_build_fn(self):
class InheritClassBuildFnReg(keras.wrappers.scikit_learn.KerasRegressor):
class InheritClassBuildFnReg(scikit_learn.KerasRegressor):
def __call__(self, hidden_dim):
return build_fn_reg(hidden_dim)