Remove __init__ entry for keras/wrappers.
This is the first of many cls that remove __init__ file content from keras to prevent hourglass imports. PiperOrigin-RevId: 271414127
This commit is contained in:
parent
52a9ff02c5
commit
49151a0661
@ -42,7 +42,6 @@ from tensorflow.python.keras import premade
|
||||
from tensorflow.python.keras import preprocessing
|
||||
from tensorflow.python.keras import regularizers
|
||||
from tensorflow.python.keras import utils
|
||||
from tensorflow.python.keras import wrappers
|
||||
from tensorflow.python.keras.layers import Input
|
||||
from tensorflow.python.keras.models import Model
|
||||
from tensorflow.python.keras.models import Sequential
|
||||
|
@ -10,6 +10,12 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0 License
|
||||
)
|
||||
|
||||
keras_packages = [
|
||||
"tensorflow.python",
|
||||
"tensorflow.python.keras",
|
||||
"tensorflow.python.keras.wrappers.scikit_learn",
|
||||
]
|
||||
|
||||
gen_api_init_files(
|
||||
name = "keras_python_api_gen",
|
||||
api_name = "keras",
|
||||
@ -20,10 +26,7 @@ gen_api_init_files(
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python:no_contrib",
|
||||
],
|
||||
packages = [
|
||||
"tensorflow.python",
|
||||
"tensorflow.python.keras",
|
||||
],
|
||||
packages = keras_packages,
|
||||
)
|
||||
|
||||
gen_api_init_files(
|
||||
@ -37,10 +40,7 @@ gen_api_init_files(
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python:no_contrib",
|
||||
],
|
||||
packages = [
|
||||
"tensorflow.python",
|
||||
"tensorflow.python.keras",
|
||||
],
|
||||
packages = keras_packages,
|
||||
)
|
||||
|
||||
gen_api_init_files(
|
||||
@ -54,8 +54,5 @@ gen_api_init_files(
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python:no_contrib",
|
||||
],
|
||||
packages = [
|
||||
"tensorflow.python",
|
||||
"tensorflow.python.keras",
|
||||
],
|
||||
packages = keras_packages,
|
||||
)
|
||||
|
@ -1,25 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Wrappers for Keras models, providing compatibility with other frameworks."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.keras.wrappers import scikit_learn
|
||||
|
||||
del absolute_import
|
||||
del division
|
||||
del print_function
|
@ -22,6 +22,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.wrappers import scikit_learn
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
INPUT_DIM = 5
|
||||
@ -103,7 +104,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
|
||||
|
||||
def test_classify_build_fn(self):
|
||||
with self.cached_session():
|
||||
clf = keras.wrappers.scikit_learn.KerasClassifier(
|
||||
clf = scikit_learn.KerasClassifier(
|
||||
build_fn=build_fn_clf,
|
||||
hidden_dim=HIDDEN_DIM,
|
||||
batch_size=BATCH_SIZE,
|
||||
@ -119,7 +120,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
|
||||
return build_fn_clf(hidden_dim)
|
||||
|
||||
with self.cached_session():
|
||||
clf = keras.wrappers.scikit_learn.KerasClassifier(
|
||||
clf = scikit_learn.KerasClassifier(
|
||||
build_fn=ClassBuildFnClf(),
|
||||
hidden_dim=HIDDEN_DIM,
|
||||
batch_size=BATCH_SIZE,
|
||||
@ -129,7 +130,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
|
||||
|
||||
def test_classify_inherit_class_build_fn(self):
|
||||
|
||||
class InheritClassBuildFnClf(keras.wrappers.scikit_learn.KerasClassifier):
|
||||
class InheritClassBuildFnClf(scikit_learn.KerasClassifier):
|
||||
|
||||
def __call__(self, hidden_dim):
|
||||
return build_fn_clf(hidden_dim)
|
||||
@ -145,7 +146,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
|
||||
|
||||
def test_regression_build_fn(self):
|
||||
with self.cached_session():
|
||||
reg = keras.wrappers.scikit_learn.KerasRegressor(
|
||||
reg = scikit_learn.KerasRegressor(
|
||||
build_fn=build_fn_reg,
|
||||
hidden_dim=HIDDEN_DIM,
|
||||
batch_size=BATCH_SIZE,
|
||||
@ -161,7 +162,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
|
||||
return build_fn_reg(hidden_dim)
|
||||
|
||||
with self.cached_session():
|
||||
reg = keras.wrappers.scikit_learn.KerasRegressor(
|
||||
reg = scikit_learn.KerasRegressor(
|
||||
build_fn=ClassBuildFnReg(),
|
||||
hidden_dim=HIDDEN_DIM,
|
||||
batch_size=BATCH_SIZE,
|
||||
@ -171,7 +172,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
|
||||
|
||||
def test_regression_inherit_class_build_fn(self):
|
||||
|
||||
class InheritClassBuildFnReg(keras.wrappers.scikit_learn.KerasRegressor):
|
||||
class InheritClassBuildFnReg(scikit_learn.KerasRegressor):
|
||||
|
||||
def __call__(self, hidden_dim):
|
||||
return build_fn_reg(hidden_dim)
|
||||
|
Loading…
Reference in New Issue
Block a user