diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 9c958588d9d..3c22176dce9 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -28,8 +28,6 @@ py_library( "utils/multi_gpu_utils.py", "utils/np_utils.py", "utils/vis_utils.py", - "wrappers/__init__.py", - "wrappers/scikit_learn.py", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], @@ -45,6 +43,7 @@ py_library( "//tensorflow/python/keras/mixed_precision/experimental:mixed_precision_experimental", "//tensorflow/python/keras/optimizer_v2", "//tensorflow/python/keras/premade", + "//tensorflow/python/keras/wrappers", "//tensorflow/python/saved_model", ], ) @@ -1275,20 +1274,6 @@ tf_py_test( ], ) -tf_py_test( - name = "scikit_learn_test", - size = "small", - srcs = ["wrappers/scikit_learn_test.py"], - python_version = "PY3", - tags = ["notsan"], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - tf_py_test( name = "data_utils_test", size = "medium", diff --git a/tensorflow/python/keras/wrappers/BUILD b/tensorflow/python/keras/wrappers/BUILD new file mode 100644 index 00000000000..9020140d9ec --- /dev/null +++ b/tensorflow/python/keras/wrappers/BUILD @@ -0,0 +1,41 @@ +# Description: +# Contains the Keras wrapper API (internal TensorFlow version). + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["LICENSE"]) + +py_library( + name = "wrappers", + srcs = [ + "__init__.py", + "scikit_learn.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:util", + "//tensorflow/python/keras:engine", + "//tensorflow/python/keras:generic_utils", + "//tensorflow/python/keras:losses", + "//third_party/py/numpy", + ], +) + +tf_py_test( + name = "scikit_learn_test", + size = "small", + srcs = ["scikit_learn_test.py"], + python_version = "PY3", + tags = ["notsan"], + deps = [ + ":wrappers", + "//tensorflow/python:client_testlib", + "//tensorflow/python:extra_py_tests_deps", + "//third_party/py/numpy", + ], +)