Add BUILD rules for python/training and python/training/experimental

There were a couple issues around op generation and strict dep checking.
 - A genrule that needs to be in python/ was adding a file to python/training, apparently not OK across module boundaries. I've just stopped it from adding the file to python/training and added a Python redirect file for now.
 - I've added rules for files that were globbed together previously, but strict dep checking means we still need to include these as srcs in the rule that previously had them. They're listed explicitly rather than globbed.

Otherwise just moving rules, adding aliases, and running build_cleaner.

PiperOrigin-RevId: 329320168
Change-Id: I8494424e332c3bc21263ce1f8caaf5bd4d32d26c
This commit is contained in:
Allen Lavoie 2020-08-31 09:45:16 -07:00 committed by TensorFlower Gardener
parent 76e2e756e5
commit 5574be6465
6 changed files with 1686 additions and 995 deletions

File diff suppressed because it is too large Load Diff

View File

@ -180,13 +180,13 @@ tf_py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:saver_test_utils",
"//tensorflow/python:session",
"//tensorflow/python:state_ops",
"//tensorflow/python:test_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/training:saver_test_utils",
],
)
@ -449,10 +449,10 @@ py_strict_library(
"//tensorflow/python:platform",
"//tensorflow/python:saver",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:training_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:lift_to_graph",
"//tensorflow/python/eager:wrap_function",
"//tensorflow/python/training:monitored_session",
"//tensorflow/python/training/tracking",
],
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,112 @@
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
)
py_library(
name = "loss_scale",
srcs = ["loss_scale.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework",
"@absl_py//absl/testing:parameterized",
],
)
py_library(
name = "loss_scale_optimizer",
srcs = ["loss_scale_optimizer.py"],
srcs_version = "PY2AND3",
deps = [
":loss_scale",
"//tensorflow/python/distribute:distribute_lib",
"@absl_py//absl/testing:parameterized",
],
)
# The test currently requires visibility only granted to tensorflow/python:__pkg__
exports_files(
["loss_scale_optimizer_test.py"],
visibility = ["//tensorflow/python:__pkg__"],
)
py_test(
name = "loss_scale_test",
size = "medium",
srcs = ["loss_scale_test.py"],
python_version = "PY3",
deps = [
":loss_scale",
"//tensorflow/python:client_testlib",
"//tensorflow/python/distribute:mirrored_strategy",
"//tensorflow/python/distribute:one_device_strategy",
"@absl_py//absl/testing:parameterized",
],
)
py_library(
name = "mixed_precision_global_state",
srcs = ["mixed_precision_global_state.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "mixed_precision",
srcs = ["mixed_precision.py"],
srcs_version = "PY2AND3",
deps = [
":loss_scale",
":loss_scale_optimizer",
":mixed_precision_global_state",
"//tensorflow/python:config",
"//tensorflow/python:util",
],
)
cuda_py_test(
name = "mixed_precision_test",
size = "small",
srcs = ["mixed_precision_test.py"],
python_version = "PY3",
tfrt_enabled = True,
deps = [
":mixed_precision",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
py_library(
name = "loss_scaling_gradient_tape",
srcs = ["loss_scaling_gradient_tape.py"],
srcs_version = "PY2AND3",
deps = [
":loss_scale",
"//tensorflow/python:array_ops",
"//tensorflow/python:unconnected_gradients",
"//tensorflow/python:util",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/eager:backprop",
],
)
cuda_py_test(
name = "loss_scaling_gradient_tape_test",
size = "medium",
srcs = ["loss_scaling_gradient_tape_test.py"],
shard_count = 2,
deps = [
":loss_scale",
":loss_scaling_gradient_tape",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/distribute:mirrored_strategy",
"//tensorflow/python/eager:def_function",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)

View File

@ -0,0 +1,29 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for training ops."""
# NOTE(allenl): The generated op wrappers for training ops were originally in
# training/gen_training_ops.py. They moved to ops/gen_training_ops.py when
# training/ became a module, and this is an alias to avoid breaking existing
# imports.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_training_ops import *
# pylint: enable=wildcard-import

View File

@ -19,8 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.training import gen_training_ops # pylint: disable=unused-import
from tensorflow.python.ops import gen_training_ops # pylint: disable=unused-import
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.training.gen_training_ops import *
from tensorflow.python.ops.gen_training_ops import *
# pylint: enable=wildcard-import