From 5574be6465472d32f5849e353b150d45eb7738f2 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Mon, 31 Aug 2020 09:45:16 -0700 Subject: [PATCH] Add BUILD rules for python/training and python/training/experimental There were a couple issues around op generation and strict dep checking. - A genrule that needs to be in python/ was adding a file to python/training, apparently not OK across module boundaries. I've just stopped it from adding the file to python/training and added a Python redirect file for now. - I've added rules for files that were globbed together previously, but strict dep checking means we still need to include these as srcs in the rule that previously had them. They're listed explicitly rather than globbed. Otherwise just moving rules, adding aliases, and running build_cleaner. PiperOrigin-RevId: 329320168 Change-Id: I8494424e332c3bc21263ce1f8caaf5bd4d32d26c --- tensorflow/python/BUILD | 1108 ++----------- tensorflow/python/saved_model/BUILD | 4 +- tensorflow/python/training/BUILD | 1424 +++++++++++++++++ tensorflow/python/training/experimental/BUILD | 112 ++ .../python/training/gen_training_ops.py | 29 + tensorflow/python/training/training_ops.py | 4 +- 6 files changed, 1686 insertions(+), 995 deletions(-) create mode 100644 tensorflow/python/training/BUILD create mode 100644 tensorflow/python/training/experimental/BUILD create mode 100644 tensorflow/python/training/gen_training_ops.py diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index f39797f8158..e28baa53908 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -185,7 +185,6 @@ py_library( ":pywrap_tf_session", ":pywrap_tfe", ":rnn_ops_gen", - ":saver_test_utils", ":script_ops", ":sendrecv_ops_gen", ":session_ops", @@ -239,6 +238,7 @@ py_library( "//tensorflow/python/tools:module_util", "//tensorflow/python/tools/api/generator:create_python_api", "//tensorflow/python/tpu:tpu_noestimator", + "//tensorflow/python/training:saver_test_utils", "//tensorflow/python/types", "//third_party/py/numpy", ], @@ -761,7 +761,7 @@ tf_python_pybind_extension( tf_python_pybind_extension( name = "_pywrap_quantize_training", srcs = [ - "training/quantize_training_wrapper.cc", + "//tensorflow/python/training:quantize_training_wrapper.cc", ], hdrs = ["//tensorflow/core/common_runtime:quantize_training_hdrs"], module_name = "_pywrap_quantize_training", @@ -2903,6 +2903,7 @@ tf_gen_op_wrapper_private_py( "//tensorflow/compiler/tests:__pkg__", "//tensorflow/python/kernel_tests:__pkg__", "//tensorflow/python/kernel_tests/v1_compat_tests:__pkg__", + "//tensorflow/python/training:__pkg__", ], deps = [ "//tensorflow/c/kernels:bitcast_op_lib", @@ -2961,7 +2962,10 @@ tf_gen_op_wrapper_private_py( tf_gen_op_wrapper_private_py( name = "checkpoint_ops_gen", - visibility = ["//tensorflow/python/kernel_tests:__pkg__"], + visibility = [ + "//tensorflow/python/kernel_tests:__pkg__", + "//tensorflow/python/training:__pkg__", + ], ) tf_gen_op_wrapper_private_py( @@ -3001,6 +3005,7 @@ tf_gen_op_wrapper_private_py( visibility = [ "//learning/brain/python/ops:__pkg__", "//tensorflow/python/kernel_tests:__pkg__", + "//tensorflow/python/training:__pkg__", ], ) @@ -3032,6 +3037,7 @@ tf_gen_op_wrapper_private_py( visibility = [ "//learning/brain/python/ops:__pkg__", "//tensorflow/python/kernel_tests:__pkg__", + "//tensorflow/python/training:__pkg__", "//tensorflow/python/training/tracking:__pkg__", ], ) @@ -3060,6 +3066,7 @@ tf_gen_op_wrapper_private_py( visibility = [ "//learning/brain/python/ops:__pkg__", "//tensorflow/python/kernel_tests:__pkg__", + "//tensorflow/python/training:__pkg__", ], ) @@ -3197,6 +3204,7 @@ tf_gen_op_wrapper_private_py( visibility = [ "//learning/brain/python/ops:__pkg__", "//tensorflow/python/kernel_tests:__pkg__", + "//tensorflow/python/training:__pkg__", ], ) @@ -3219,7 +3227,9 @@ tf_gen_op_wrapper_private_py( tf_gen_op_wrapper_private_py( name = "training_ops_gen", - out = "training/gen_training_ops.py", + visibility = [ + "//tensorflow/python/training:__pkg__", + ], ) tf_gen_op_wrapper_private_py( @@ -4157,121 +4167,6 @@ py_library( ], ) -py_library( - name = "loss_scale", - srcs = ["training/experimental/loss_scale.py"], - srcs_version = "PY2AND3", - deps = [ - ":framework", - "@absl_py//absl/testing:parameterized", - ], -) - -py_library( - name = "loss_scale_optimizer", - srcs = ["training/experimental/loss_scale_optimizer.py"], - srcs_version = "PY2AND3", - deps = [ - ":loss_scale", - "//tensorflow/python/distribute:distribute_lib", - "@absl_py//absl/testing:parameterized", - ], -) - -py_test( - name = "loss_scale_optimizer_test", - size = "small", - srcs = ["training/experimental/loss_scale_optimizer_test.py"], - python_version = "PY3", - deps = [ - ":client_testlib", - ":loss_scale_optimizer", - "//tensorflow/python/distribute:mirrored_strategy", - "//tensorflow/python/distribute:one_device_strategy", - "//tensorflow/python/keras/mixed_precision/experimental:test_util", - "@absl_py//absl/testing:parameterized", - ], -) - -py_test( - name = "loss_scale_test", - size = "medium", - srcs = ["training/experimental/loss_scale_test.py"], - python_version = "PY3", - deps = [ - ":client_testlib", - ":loss_scale", - "//tensorflow/python/distribute:mirrored_strategy", - "//tensorflow/python/distribute:one_device_strategy", - "@absl_py//absl/testing:parameterized", - ], -) - -py_library( - name = "mixed_precision_global_state", - srcs = ["training/experimental/mixed_precision_global_state.py"], - srcs_version = "PY2AND3", -) - -py_library( - name = "mixed_precision", - srcs = ["training/experimental/mixed_precision.py"], - srcs_version = "PY2AND3", - deps = [ - ":config", - ":loss_scale", - ":loss_scale_optimizer", - ":mixed_precision_global_state", - ":util", - ], -) - -cuda_py_test( - name = "mixed_precision_test", - size = "small", - srcs = ["training/experimental/mixed_precision_test.py"], - python_version = "PY3", - tfrt_enabled = True, - deps = [ - ":client_testlib", - ":mixed_precision", - "@absl_py//absl/testing:parameterized", - ], -) - -py_library( - name = "loss_scaling_gradient_tape", - srcs = ["training/experimental/loss_scaling_gradient_tape.py"], - srcs_version = "PY2AND3", - deps = [ - ":array_ops", - ":loss_scale", - ":unconnected_gradients", - ":util", - "//tensorflow/python/distribute:distribute_lib", - "//tensorflow/python/eager:backprop", - ], -) - -cuda_py_test( - name = "loss_scaling_gradient_tape_test", - size = "medium", - srcs = ["training/experimental/loss_scaling_gradient_tape_test.py"], - shard_count = 2, - deps = [ - ":client_testlib", - ":constant_op", - ":framework_test_combinations_lib", - ":loss_scale", - ":loss_scaling_gradient_tape", - "//tensorflow/python/compat:v2_compat", - "//tensorflow/python/distribute:mirrored_strategy", - "//tensorflow/python/eager:def_function", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - py_library( name = "math_grad", srcs = ["ops/math_grad.py"], @@ -4876,7 +4771,6 @@ py_library( ":linalg_ops", ":logging_ops", ":lookup_ops", - ":loss_scaling_gradient_tape", ":manip_grad", ":manip_ops", ":math_grad", @@ -4911,6 +4805,7 @@ py_library( "//tensorflow/python/ops/linalg/sparse", "//tensorflow/python/ops/ragged", "//tensorflow/python/ops/structured", + "//tensorflow/python/training/experimental:loss_scaling_gradient_tape", ], ) @@ -5514,231 +5409,6 @@ tf_py_test( ], ) -py_library( - name = "training_lib", - srcs = glob( - ["training/**/*.py"], - exclude = [ - "**/*test*", - "training/tracking/**/*.py", - "training/saving/**/*.py", - # The following targets have their own build rules (same name as the - # file): - "training/basic_session_run_hooks.py", - "training/checkpoint_management.py", - "training/distribute.py", - "training/distribution_strategy_context.py", - "training/saver.py", - "training/session_run_hook.py", - "training/training_util.py", - ], - ), - srcs_version = "PY2AND3", - deps = [ - ":array_ops", - ":array_ops_gen", - ":basic_session_run_hooks", - ":checkpoint_management", - ":checkpoint_ops_gen", - ":client", - ":control_flow_ops", - ":data_flow_ops", - ":device", - ":device_spec", - ":distribute", - ":errors", - ":framework", - ":framework_for_generated_wrappers", - ":framework_ops", - ":gradients", - ":init_ops", - ":io_ops", - ":layers_util", - ":lookup_ops", - ":loss_scale", - ":loss_scale_optimizer", - ":math_ops", - ":mixed_precision", - ":platform", - ":py_checkpoint_reader", - ":pywrap_tensorflow", - ":random_ops", - ":resource_variable_ops", - ":resources", - ":saver", - ":sdca_ops", - ":session", - ":session_run_hook", - ":sparse_ops", - ":sparse_tensor", - ":state_ops", - ":summary", - ":training_ops_gen", - ":training_util", - ":util", - ":variable_scope", - ":variables", - "//tensorflow/core:protos_all_py", - "//tensorflow/python/data/experimental/service:server_lib", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/distribute:distribute_coordinator_context", - "//tensorflow/python/distribute:distribute_lib", - "//tensorflow/python/distribute:reduce_util", - "//tensorflow/python/eager:backprop", - "//tensorflow/python/eager:context", - "//tensorflow/python/keras/optimizer_v2:legacy_learning_rate_decay", - "//tensorflow/python/ops/losses", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - -py_library( - name = "training", - srcs_version = "PY2AND3", - deps = [ - ":training_lib", - "//tensorflow/python/training/tracking:base", - "//tensorflow/python/training/tracking:python_state", - "//tensorflow/python/training/tracking:util", - ], -) - -# Dependency added and used by ClusterResolvers to avoid circular dependency between keras, distribute, and training. -py_library( - name = "training_server_lib", - srcs = ["training/server_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":framework", - ":pywrap_tf_session", - ":util", - "//tensorflow/core:protos_all_py", - ], -) - -py_library( - name = "py_checkpoint_reader", - srcs = ["training/py_checkpoint_reader.py"], - deps = [ - ":_pywrap_checkpoint_reader", - ":dtypes", - ":errors", - ":util", - ], -) - -py_library( - name = "checkpoint_management", - srcs = ["training/checkpoint_management.py"], - deps = [ - ":errors", - ":lib", - ":platform", - ":protos_all_py", - ":util", - "//tensorflow/core:protos_all_py", - ], -) - -py_library( - name = "session_run_hook", - srcs = ["training/session_run_hook.py"], - srcs_version = "PY2AND3", - deps = [":util"], -) - -py_library( - name = "basic_session_run_hooks", - srcs = ["training/basic_session_run_hooks.py"], - srcs_version = "PY2AND3", - deps = [ - ":client", - ":framework", - ":platform", - ":protos_all_py", - ":session_run_hook", - ":training_util", - ":util", - ], -) - -py_library( - name = "saver", - srcs = ["training/saver.py"], - srcs_version = "PY2AND3", - deps = [ - ":array_ops", - ":checkpoint_management", - ":constant_op", - ":control_flow_ops", - ":device", - ":errors", - ":framework", - ":framework_ops", - ":io_ops", - ":io_ops_gen", - ":platform", - ":py_checkpoint_reader", - ":resource_variable_ops", - ":session", - ":state_ops", - ":string_ops", - ":training_util", - ":util", - ":variables", - "//tensorflow/core:protos_all_py", - "//tensorflow/python/eager:context", - "//tensorflow/python/training/saving:saveable_object", - "//tensorflow/python/training/saving:saveable_object_util", - "//tensorflow/python/training/tracking:base", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - -py_library( - name = "distribute", - srcs = [ - "training/distribute.py", - "training/distribution_strategy_context.py", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python/distribute:distribute_lib", - ], -) - -tf_py_test( - name = "evaluation_test", - size = "small", - srcs = ["training/evaluation_test.py"], - python_version = "PY3", - shard_count = 3, - tags = [ - "manual", - "notap", # Disabling until b/33000128 and b/33040312 are fixed. - ], - deps = [ - ":array_ops", - ":client", - ":client_testlib", - ":framework", - ":framework_for_generated_wrappers", - ":framework_test_lib", - ":math_ops", - ":metrics", - ":platform", - ":state_ops", - ":summary", - ":training", - ":variables", - "//tensorflow/core:protos_all_py", - "//tensorflow/python/ops/losses", - "//third_party/py/numpy", - ], -) - py_library( name = "client", srcs = [ @@ -6070,6 +5740,7 @@ tf_proto_library( "framework/cpp_shape_inference.proto", ], ), + protodeps = ["//tensorflow/python/training:checkpoint_state"], visibility = visibility, ) @@ -6539,189 +6210,6 @@ py_library( ], ) -tf_py_test( - name = "server_lib_test", - size = "small", - srcs = ["training/server_lib_test.py"], - grpc_enabled = True, - python_version = "PY3", - tags = [ - "noasan", # TODO(b/161236904): flaky timeout in trying to start gRPC server - ], - tfrt_enabled = True, - deps = [ - ":array_ops", - ":client", - ":client_testlib", - ":data_flow_ops", - ":errors", - ":framework_for_generated_wrappers", - ":math_ops", - ":training", - ":variables", - "//tensorflow/core:protos_all_py", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "server_lib_multiple_containers_test", - size = "small", - srcs = ["training/server_lib_multiple_containers_test.py"], - grpc_enabled = True, - python_version = "PY3", - tfrt_enabled = True, - deps = [ - ":array_ops", - ":client", - ":client_testlib", - ":data_flow_ops", - ":errors", - ":framework_for_generated_wrappers", - ":math_ops", - ":training", - ":variables", - "//tensorflow/core:protos_all_py", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "server_lib_same_variables_clear_container_test", - size = "small", - srcs = ["training/server_lib_same_variables_clear_container_test.py"], - grpc_enabled = True, - python_version = "PY3", - tfrt_enabled = True, - deps = [ - ":array_ops", - ":client", - ":client_testlib", - ":data_flow_ops", - ":errors", - ":framework_for_generated_wrappers", - ":math_ops", - ":training", - ":variables", - "//tensorflow/core:protos_all_py", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "server_lib_same_variables_clear_test", - size = "small", - srcs = ["training/server_lib_same_variables_clear_test.py"], - grpc_enabled = True, - python_version = "PY3", - tfrt_enabled = True, - deps = [ - ":array_ops", - ":client", - ":client_testlib", - ":data_flow_ops", - ":errors", - ":framework_for_generated_wrappers", - ":math_ops", - ":training", - ":variables", - "//tensorflow/core:protos_all_py", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "server_lib_same_variables_no_clear_test", - size = "small", - srcs = ["training/server_lib_same_variables_no_clear_test.py"], - grpc_enabled = True, - python_version = "PY3", - tfrt_enabled = True, - deps = [ - ":array_ops", - ":client", - ":client_testlib", - ":data_flow_ops", - ":errors", - ":framework_for_generated_wrappers", - ":math_ops", - ":training", - ":variables", - "//tensorflow/core:protos_all_py", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "server_lib_sparse_job_test", - size = "small", - srcs = ["training/server_lib_sparse_job_test.py"], - grpc_enabled = True, - python_version = "PY3", - tfrt_enabled = True, - deps = [ - ":array_ops", - ":client", - ":client_testlib", - ":data_flow_ops", - ":errors", - ":framework_for_generated_wrappers", - ":math_ops", - ":training", - ":variables", - "//tensorflow/core:protos_all_py", - "//third_party/py/numpy", - ], -) - -cuda_py_test( - name = "localhost_cluster_performance_test", - size = "medium", - srcs = [ - "training/localhost_cluster_performance_test.py", - ], - grpc_enabled = True, - python_version = "PY3", - tags = [ - "no_oss", # Test flaky due to port collisions. - "oss_serial", - ], - tfrt_enabled = True, - deps = [ - ":client", - ":client_testlib", - ":distributed_framework_test_lib", - ":framework_for_generated_wrappers", - ":partitioned_variables", - ":training", - ":variable_scope", - ":variables", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "sync_replicas_optimizer_test", - size = "medium", - srcs = [ - "training/sync_replicas_optimizer_test.py", - ], - grpc_enabled = True, - python_version = "PY3", - tags = [ - "no_oss", # Test flaky due to port collisions. - "notsan", # data race due to b/62910646 - "oss_serial", - ], - tfrt_enabled = True, - deps = [ - ":client_testlib", - ":framework_for_generated_wrappers", - ":training", - ":variables", - ], -) - py_library( name = "timeline", srcs = ["client/timeline.py"], @@ -6987,467 +6475,6 @@ tf_py_test( ], ) -cuda_py_test( - name = "adam_test", - size = "medium", - srcs = ["training/adam_test.py"], - python_version = "PY3", - tfrt_enabled = True, - deps = [ - ":array_ops", - ":client_testlib", - ":framework", - ":math_ops", - ":platform", - ":platform_test", - ":training", - "//third_party/py/numpy", - ], -) - -cuda_py_test( - name = "moving_averages_test", - size = "small", - srcs = [ - "training/moving_averages_test.py", - ], - python_version = "PY3", - tags = [ - "no_windows", # b/139083295: bfloat16 tests fail on Windows - "notsan", - ], - tfrt_enabled = True, - deps = [ - ":array_ops", - ":client_testlib", - ":constant_op", - ":dtypes", - ":framework_for_generated_wrappers", - ":framework_ops", - ":training", - ":variable_scope", - ":variables", - ], -) - -cuda_py_tests( - name = "training_tests", - size = "medium", - srcs = [ - "training/adadelta_test.py", - "training/adagrad_da_test.py", - "training/adagrad_test.py", - "training/basic_loops_test.py", - "training/coordinator_test.py", - "training/device_setter_test.py", - "training/ftrl_test.py", - "training/gradient_descent_test.py", - "training/momentum_test.py", - "training/optimizer_test.py", - "training/proximal_adagrad_test.py", - "training/proximal_gradient_descent_test.py", - "training/quantize_training_test.py", - "training/queue_runner_test.py", - "training/rmsprop_test.py", - "training/slot_creator_test.py", - "training/tensorboard_logging_test.py", - "training/training_ops_test.py", - ], - python_version = "PY3", - deps = [ - ":array_ops", - ":client", - ":client_testlib", - ":control_flow_ops", - ":data_flow_ops", - ":data_flow_ops_gen", - ":embedding_ops", - ":errors", - ":framework", - ":framework_for_generated_wrappers", - ":framework_test_lib", - ":gradients", - ":lookup_ops", - ":math_ops", - ":nn_grad", - ":nn_ops", - ":partitioned_variables", - ":platform", - ":platform_test", - ":pywrap_tensorflow", - ":random_ops", - ":resource_variable_ops", - ":resources", - ":sparse_ops", - ":state_ops", - ":state_ops_gen", - ":summary", - ":training", - ":util", - ":variable_scope", - ":variables", - "//tensorflow/core:protos_all_py", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - -py_library( - name = "saver_test_utils", - srcs = ["training/saver_test_utils.py"], - srcs_version = "PY2AND3", - deps = [ - ":lookup_ops_gen", - ":training", - ], -) - -cuda_py_test( - name = "saver_test", - size = "medium", - srcs = [ - "training/saver_test.py", - ], - python_version = "PY3", - tags = ["multi_gpu"], - deps = [ - ":array_ops", - ":client_testlib", - ":control_flow_ops", - ":data_flow_ops", - ":errors", - ":gradients", - ":math_ops", - ":nn_grad", - ":nn_ops", - ":partitioned_variables", - ":platform", - ":platform_test", - ":py_checkpoint_reader", - ":random_ops", - ":resource_variable_ops", - ":saver_test_utils", - ":sparse_ops", - ":summary", - ":training", - ":util", - ":variable_scope", - ":variables", - "//tensorflow/core:protos_all_py", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - -cuda_py_test( - name = "checkpoint_management_test", - size = "small", - srcs = [ - "training/checkpoint_management_test.py", - ], - python_version = "PY3", - deps = [ - ":array_ops", - ":client_testlib", - ":control_flow_ops", - ":data_flow_ops", - ":errors", - ":gradients", - ":math_ops", - ":nn_grad", - ":nn_ops", - ":partitioned_variables", - ":platform", - ":platform_test", - ":pywrap_tensorflow", - ":random_ops", - ":resource_variable_ops", - ":saver_test_utils", - ":sparse_ops", - ":summary", - ":training", - ":util", - ":variable_scope", - ":variables", - "//tensorflow/core:protos_all_py", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - -tf_py_test( - name = "saver_large_variable_test", - size = "medium", - srcs = ["training/saver_large_variable_test.py"], - python_version = "PY3", - tags = [ - "manual", - "noasan", # http://b/30379628 - "notsan", # http://b/30379628 - ], - tfrt_enabled = True, - deps = [ - ":client", - ":client_testlib", - ":errors", - ":framework_for_generated_wrappers", - ":training", - ":variables", - "//tensorflow/core:protos_all_py", - ], -) - -tf_py_test( - name = "saver_large_partitioned_variable_test", - size = "medium", - srcs = ["training/saver_large_partitioned_variable_test.py"], - python_version = "PY3", - tags = [ - "noasan", # http://b/30782289 - "notsan", # http://b/30782289 - ], - tfrt_enabled = True, - deps = [ - ":client", - ":client_testlib", - ":framework_for_generated_wrappers", - ":partitioned_variables", - ":training", - ":variables", - ], -) - -cuda_py_test( - name = "session_manager_test", - size = "medium", # TODO(irving): Can this be made small? - srcs = ["training/session_manager_test.py"], - grpc_enabled = True, - main = "training/session_manager_test.py", - python_version = "PY3", - tfrt_enabled = True, - deps = [ - ":array_ops", - ":client", - ":client_testlib", - ":control_flow_ops", - ":errors", - ":framework_for_generated_wrappers", - ":platform", - ":training", - ":variables", - ], -) - -tf_py_test( - name = "supervisor_test", - size = "small", - srcs = ["training/supervisor_test.py"], - grpc_enabled = True, - python_version = "PY3", - tags = ["no_windows"], - tfrt_enabled = True, - deps = [ - ":array_ops", - ":checkpoint_management", - ":client_testlib", - ":errors", - ":framework", - ":framework_for_generated_wrappers", - ":io_ops", - ":parsing_ops", - ":platform", - ":saver", - ":summary", - ":training", - ":variables", - "//tensorflow/core:protos_all_py", - ], -) - -tf_py_test( - name = "basic_session_run_hooks_test", - size = "medium", - srcs = ["training/basic_session_run_hooks_test.py"], - python_version = "PY3", - tags = [ - "no_pip", # Relies on contrib - "no_windows", - "notsan", # intermittent races on a few percent of runs - ], - tfrt_enabled = True, - deps = [ - ":client", - ":client_testlib", - ":control_flow_ops", - ":fake_summary_writer", - ":framework", - ":framework_for_generated_wrappers", - ":nn_grad", - ":platform", - ":state_ops", - ":summary", - ":training", - ":variable_scope", - ":variables", - "//tensorflow/core:protos_all_py", - ], -) - -tf_py_test( - name = "checkpoint_utils_test", - size = "small", - srcs = ["training/checkpoint_utils_test.py"], - python_version = "PY3", - tags = [ - "manual", - "no_cuda_on_cpu_tap", - "no_oss", - "no_windows", - "notap", - ], - deps = [ - ":client", - ":client_testlib", - ":framework_for_generated_wrappers", - ":io_ops", - ":partitioned_variables", - ":platform", - ":resource_variable_ops", - ":state_ops", - ":training", - ":variable_scope", - ":variables", - ], -) - -tf_py_test( - name = "checkpoint_ops_test", - size = "small", - srcs = ["training/checkpoint_ops_test.py"], - python_version = "PY3", - tfrt_enabled = True, - deps = [ - ":checkpoint_ops_gen", - ":client", - ":client_testlib", - ":framework_for_generated_wrappers", - ":io_ops", - ":partitioned_variables", - ":platform", - ":pywrap_tensorflow", - ":state_ops", - ":training", - ":variable_scope", - ":variables", - ], -) - -tf_py_test( - name = "warm_starting_util_test", - size = "medium", - srcs = ["training/warm_starting_util_test.py"], - python_version = "PY3", - tfrt_enabled = True, - deps = [ - ":array_ops", - ":client_testlib", - ":dtypes", - ":framework_ops", - ":init_ops", - ":training", - ":variable_scope", - ":variables", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "monitored_session_test", - size = "medium", - srcs = ["training/monitored_session_test.py"], - tags = [ - "no_pip", - "notsan", # b/67945581 - ], - tfrt_enabled = True, - deps = [ - ":array_ops", - ":checkpoint_management", - ":client_testlib", - ":control_flow_ops", - ":errors", - ":framework_for_generated_wrappers", - ":resource_variable_ops", - ":saver", - ":session", - ":state_ops", - ":summary", - ":training", - ":variables", - "//tensorflow/core:protos_all_py", - "//tensorflow/python/distribute:collective_all_reduce_strategy", - "//tensorflow/python/distribute:distribute_coordinator", - ], -) - -py_library( - name = "training_util", - srcs = ["training/training_util.py"], - srcs_version = "PY2AND3", - deps = [ - ":dtypes", - ":framework", - ":framework_ops", - ":init_ops", - ":platform", - ":resource_variable_ops", - ":state_ops", - ":util", - ":variable_scope", - ":variables", - "//tensorflow/python/eager:context", - ], -) - -tf_py_test( - name = "training_util_test", - size = "small", - srcs = ["training/training_util_test.py"], - python_version = "PY3", - tfrt_enabled = True, - deps = [ - ":client_testlib", - ":framework", - ":platform", - ":training", - ":training_util", - ":variables", - ], -) - -tf_py_test( - name = "input_test", - size = "medium", - srcs = ["training/input_test.py"], - python_version = "PY3", - tfrt_enabled = True, - deps = [ - ":array_ops", - ":client_testlib", - ":errors", - ":framework", - ":framework_for_generated_wrappers", - ":math_ops", - ":platform", - ":training", - ":util", - ":variables", - "//third_party/py/numpy", - ], -) - py_library( name = "summary_op_util", srcs = ["ops/summary_op_util.py"], @@ -8499,11 +7526,11 @@ tf_python_pybind_extension( module_name = "_pywrap_parallel_device", visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"], deps = [ + ":pybind11_lib", + ":pybind11_status", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_headers_for_pybind", "//tensorflow/core:protos_all_cc", - "//tensorflow/python:pybind11_lib", - "//tensorflow/python:pybind11_status", "//third_party/python_runtime:headers", "@pybind11", ], @@ -8543,3 +7570,102 @@ cuda_py_test( ":client_testlib", ], ) + +alias( + name = "basic_session_run_hooks", + actual = "//tensorflow/python/training:basic_session_run_hooks", +) + +alias( + name = "checkpoint_management", + actual = "//tensorflow/python/training:checkpoint_management", +) + +alias( + name = "distribute", + actual = "//tensorflow/python/training:distribute", +) + +alias( + name = "py_checkpoint_reader", + actual = "//tensorflow/python/training:py_checkpoint_reader", +) + +alias( + name = "saver", + actual = "//tensorflow/python/training:saver", +) + +alias( + name = "session_run_hook", + actual = "//tensorflow/python/training:session_run_hook", +) + +alias( + name = "training", + actual = "//tensorflow/python/training:training", +) + +alias( + name = "training_lib", + actual = "//tensorflow/python/training:training_lib", +) + +alias( + name = "training_server_lib", + actual = "//tensorflow/python/training:server_lib", +) + +alias( + name = "training_util", + actual = "//tensorflow/python/training:training_util", +) + +alias( + name = "loss_scale", + actual = "//tensorflow/python/training/experimental:loss_scale", +) + +alias( + name = "loss_scale_optimizer", + actual = "//tensorflow/python/training/experimental:loss_scale_optimizer", +) + +alias( + name = "mixed_precision", + actual = "//tensorflow/python/training/experimental:mixed_precision", +) + +alias( + name = "mixed_precision_global_state", + actual = "//tensorflow/python/training/experimental:mixed_precision_global_state", +) + +alias( + name = "loss_scaling_gradient_tape", + actual = "//tensorflow/python/training/experimental:loss_scaling_gradient_tape", +) + +py_library( + name = "learning_rate_decay", + # This rule depends on a target that only python:__pkg__ has visibility for. + srcs = ["//tensorflow/python/training:learning_rate_decay.py"], + srcs_version = "PY2AND3", + deps = ["//tensorflow/python/keras/optimizer_v2:legacy_learning_rate_decay"], +) + +py_test( + name = "loss_scale_optimizer_test", + size = "small", + # This test currently depends on rules only python:__pkg__ has visibility for. + srcs = ["//tensorflow/python/training/experimental:loss_scale_optimizer_test.py"], + python_version = "PY3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/distribute:one_device_strategy", + "//tensorflow/python/keras/mixed_precision/experimental:test_util", + "//tensorflow/python/training/experimental:loss_scale_optimizer", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 4507118c17c..10cf520f4e5 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -180,13 +180,13 @@ tf_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:lib", "//tensorflow/python:math_ops", - "//tensorflow/python:saver_test_utils", "//tensorflow/python:session", "//tensorflow/python:state_ops", "//tensorflow/python:test_ops", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variables", + "//tensorflow/python/training:saver_test_utils", ], ) @@ -449,10 +449,10 @@ py_strict_library( "//tensorflow/python:platform", "//tensorflow/python:saver", "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training_lib", "//tensorflow/python/eager:context", "//tensorflow/python/eager:lift_to_graph", "//tensorflow/python/eager:wrap_function", + "//tensorflow/python/training:monitored_session", "//tensorflow/python/training/tracking", ], ) diff --git a/tensorflow/python/training/BUILD b/tensorflow/python/training/BUILD new file mode 100644 index 00000000000..0e864b176d6 --- /dev/null +++ b/tensorflow/python/training/BUILD @@ -0,0 +1,1424 @@ +load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "cuda_py_tests") + +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files( + # Used in a pybind extension whose rule must be in tensorflow/python + ["quantize_training_wrapper.cc"], + visibility = ["//tensorflow/python:__pkg__"], +) + +exports_files( + # Used in a rule which visibility limits to tensorflow/python + ["learning_rate_decay.py"], + visibility = ["//tensorflow/python:__pkg__"], +) + +# Files which have their own BUILD rules, but which for compatibility with +# strict dep checking need to be direct dependencies of training_lib. Do not add +# any new files to this list. +filegroup( + name = "deprecated_inclusions_in_training_lib", + srcs = [ + "adadelta.py", + "adagrad.py", + "adagrad_da.py", + "adam.py", + "basic_loops.py", + "checkpoint_ops.py", + "checkpoint_utils.py", + "coordinator.py", + "device_setter.py", + "evaluation.py", + "ftrl.py", + "gradient_descent.py", + "input.py", + "learning_rate_decay.py", + "momentum.py", + "monitored_session.py", + "moving_averages.py", + "optimizer.py", + "proximal_adagrad.py", + "proximal_gradient_descent.py", + "py_checkpoint_reader.py", + "quantize_training.py", + "queue_runner.py", + "queue_runner_impl.py", + "rmsprop.py", + "server_lib.py", + "session_manager.py", + "slot_creator.py", + "summary_io.py", + "supervisor.py", + "sync_replicas_optimizer.py", + "tensorboard_logging.py", + "training.py", + "training_ops.py", + "warm_starting_util.py", + ], + visibility = ["//tensorflow/python/training:__pkg__"], +) + +py_library( + name = "training_lib", + srcs = [ + "__init__.py", + "training.py", + ":deprecated_inclusions_in_training_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":adadelta", + ":adagrad", + ":adagrad_da", + ":adam", + ":basic_loops", + ":basic_session_run_hooks", + ":checkpoint_management", + ":checkpoint_utils", + ":coordinator", + ":device_setter", + ":ftrl", + ":gradient_descent", + ":input", + ":momentum", + ":monitored_session", + ":moving_averages", + ":optimizer", + ":proximal_adagrad", + ":proximal_gradient_descent", + ":py_checkpoint_reader", + ":quantize_training", + ":queue_runner", + ":rmsprop", + ":saver", + ":server_lib", + ":session_manager", + ":session_run_hook", + ":summary_io", + ":supervisor", + ":sync_replicas_optimizer", + ":tensorboard_logging", + ":training_util", + ":warm_starting_util", + "//tensorflow/python:learning_rate_decay", + "//tensorflow/python:sdca_ops", + "//tensorflow/python:tf_export", + "//tensorflow/python/training/experimental:loss_scale_optimizer", + "//tensorflow/python/training/experimental:mixed_precision", + ], +) + +py_library( + name = "training", + srcs_version = "PY2AND3", + deps = [ + ":training_lib", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:python_state", + "//tensorflow/python/training/tracking:util", + ], +) + +py_library( + name = "adadelta", + srcs = ["adadelta.py"], + srcs_version = "PY2AND3", + deps = [ + ":optimizer", + ":training_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tf_export", + ], +) + +py_library( + name = "adagrad_da", + srcs = ["adagrad_da.py"], + srcs_version = "PY2AND3", + deps = [ + ":optimizer", + ":training_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tf_export", + ], +) + +py_library( + name = "adagrad", + srcs = ["adagrad.py"], + srcs_version = "PY2AND3", + deps = [ + ":optimizer", + ":training_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:array_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tf_export", + ], +) + +py_library( + name = "adam", + srcs = ["adam.py"], + srcs_version = "PY2AND3", + deps = [ + ":optimizer", + ":training_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:tf_export", + "//tensorflow/python/eager:context", + ], +) + +py_library( + name = "basic_loops", + srcs = ["basic_loops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:errors", + "//tensorflow/python:tf_export", + ], +) + +py_library( + name = "checkpoint_ops", + srcs = ["checkpoint_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:checkpoint_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + ], +) + +py_library( + name = "checkpoint_utils", + srcs = ["checkpoint_utils.py"], + srcs_version = "PY2AND3", + deps = [ + ":checkpoint_management", + ":py_checkpoint_reader", + "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:tf_export", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/training/saving:saveable_object_util", + "@six_archive//:six", + ], +) + +py_library( + name = "coordinator", + srcs = ["coordinator.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:errors", + "//tensorflow/python:platform", + "//tensorflow/python:tf_export", + "//tensorflow/python:util", + "@six_archive//:six", + ], +) + +py_library( + name = "device_setter", + srcs = ["device_setter.py"], + srcs_version = "PY2AND3", + deps = [ + ":server_lib", + "//tensorflow/python:device", + "//tensorflow/python:platform", + "//tensorflow/python:tf_export", + "@six_archive//:six", + ], +) + +py_library( + name = "distribution_strategy_context", + srcs = ["distribution_strategy_context.py"], + srcs_version = "PY2AND3", + deps = ["//tensorflow/python/distribute:distribute_lib"], +) + +py_library( + name = "evaluation", + srcs = ["evaluation.py"], + srcs_version = "PY2AND3", + deps = [ + ":basic_session_run_hooks", + ":monitored_session", + ":session_run_hook", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:variable_scope", + ], +) + +py_library( + name = "ftrl", + srcs = ["ftrl.py"], + srcs_version = "PY2AND3", + deps = [ + ":optimizer", + ":training_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tf_export", + ], +) + +py_library( + name = "gradient_descent", + srcs = ["gradient_descent.py"], + srcs_version = "PY2AND3", + deps = [ + ":optimizer", + ":training_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:tf_export", + ], +) + +py_library( + name = "input", + srcs = ["input.py"], + srcs_version = "PY2AND3", + deps = [ + ":queue_runner", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:layers_util", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:summary", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tf_export", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "@six_archive//:six", + ], +) + +py_library( + name = "momentum", + srcs = ["momentum.py"], + srcs_version = "PY2AND3", + deps = [ + ":optimizer", + ":training_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tf_export", + ], +) + +py_library( + name = "moving_averages", + srcs = ["moving_averages.py"], + srcs_version = "PY2AND3", + deps = [ + ":slot_creator", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:tf_export", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:reduce_util", + ], +) + +py_library( + name = "optimizer", + srcs = ["optimizer.py"], + srcs_version = "PY2AND3", + deps = [ + ":slot_creator", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:tf_export", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/training/tracking:base", + "@six_archive//:six", + ], +) + +py_library( + name = "proximal_adagrad", + srcs = ["proximal_adagrad.py"], + srcs_version = "PY2AND3", + deps = [ + ":optimizer", + ":training_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tf_export", + ], +) + +py_library( + name = "proximal_gradient_descent", + srcs = ["proximal_gradient_descent.py"], + srcs_version = "PY2AND3", + deps = [ + ":optimizer", + ":training_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tf_export", + ], +) + +py_library( + name = "quantize_training", + srcs = ["quantize_training.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:_pywrap_quantize_training", + "//tensorflow/python:tf_export", + "//tensorflow/python:util", + ], +) + +py_library( + name = "queue_runner_impl", + srcs = ["queue_runner_impl.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python:tf_export", + "//tensorflow/python:util", + "//tensorflow/python/eager:context", + ], +) + +py_library( + name = "queue_runner", + srcs = ["queue_runner.py"], + srcs_version = "PY2AND3", + deps = [":queue_runner_impl"], +) + +py_library( + name = "rmsprop", + srcs = ["rmsprop.py"], + srcs_version = "PY2AND3", + deps = [ + ":optimizer", + ":training_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tf_export", + ], +) + +py_library( + name = "session_manager", + srcs = ["session_manager.py"], + srcs_version = "PY2AND3", + deps = [ + ":checkpoint_management", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python:tf_export", + "//tensorflow/python/distribute:distribute_lib", + "//third_party/py/numpy", + ], +) + +py_library( + name = "slot_creator", + srcs = ["slot_creator.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/distribute:distribute_lib", + ], +) + +py_library( + name = "summary_io", + srcs = ["summary_io.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:util", + ], +) + +py_library( + name = "sync_replicas_optimizer", + srcs = ["sync_replicas_optimizer.py"], + srcs_version = "PY2AND3", + deps = [ + ":optimizer", + ":queue_runner", + ":session_manager", + ":session_run_hook", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:tf_export", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/distribute:distribute_lib", + ], +) + +py_library( + name = "tensorboard_logging", + srcs = ["tensorboard_logging.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:platform", + ], +) + +py_library( + name = "training_ops", + srcs = [ + "gen_training_ops.py", + "training_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:training_ops_gen", + ], +) + +py_library( + name = "warm_starting_util", + srcs = ["warm_starting_util.py"], + srcs_version = "PY2AND3", + deps = [ + ":checkpoint_ops", + ":checkpoint_utils", + ":saver", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:tf_export", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/training/saving:saveable_object_util", + "@six_archive//:six", + ], +) + +py_library( + name = "distribute", + srcs = [ + "distribute.py", + "distribution_strategy_context.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/distribute:distribute_lib", + ], +) + +tf_py_test( + name = "server_lib_test", + size = "small", + srcs = ["server_lib_test.py"], + grpc_enabled = True, + python_version = "PY3", + tags = [ + "noasan", # TODO(b/161236904): flaky timeout in trying to start gRPC server + ], + tfrt_enabled = True, + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +tf_py_test( + name = "server_lib_multiple_containers_test", + size = "small", + srcs = ["server_lib_multiple_containers_test.py"], + grpc_enabled = True, + python_version = "PY3", + tfrt_enabled = True, + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +tf_py_test( + name = "server_lib_same_variables_clear_container_test", + size = "small", + srcs = ["server_lib_same_variables_clear_container_test.py"], + grpc_enabled = True, + python_version = "PY3", + tfrt_enabled = True, + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +tf_py_test( + name = "server_lib_same_variables_clear_test", + size = "small", + srcs = ["server_lib_same_variables_clear_test.py"], + grpc_enabled = True, + python_version = "PY3", + tfrt_enabled = True, + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +tf_py_test( + name = "server_lib_same_variables_no_clear_test", + size = "small", + srcs = ["server_lib_same_variables_no_clear_test.py"], + grpc_enabled = True, + python_version = "PY3", + tfrt_enabled = True, + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +tf_py_test( + name = "server_lib_sparse_job_test", + size = "small", + srcs = ["server_lib_sparse_job_test.py"], + grpc_enabled = True, + python_version = "PY3", + tfrt_enabled = True, + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +cuda_py_test( + name = "localhost_cluster_performance_test", + size = "medium", + srcs = [ + "localhost_cluster_performance_test.py", + ], + grpc_enabled = True, + python_version = "PY3", + tags = [ + "no_oss", # Test flaky due to port collisions. + "oss_serial", + ], + tfrt_enabled = True, + deps = [ + ":device_setter", + "//tensorflow/python:client_testlib", + "//tensorflow/python:distributed_framework_test_lib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:session", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +tf_py_test( + name = "sync_replicas_optimizer_test", + size = "medium", + srcs = [ + "sync_replicas_optimizer_test.py", + ], + grpc_enabled = True, + python_version = "PY3", + tags = [ + "no_oss", # Test flaky due to port collisions. + "notsan", # data race due to b/62910646 + "oss_serial", + ], + tfrt_enabled = True, + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:variables", + ], +) + +tf_py_test( + name = "evaluation_test", + size = "small", + srcs = ["evaluation_test.py"], + python_version = "PY3", + shard_count = 3, + tags = [ + "manual", + "notap", # Disabling until b/33000128 and b/33040312 are fixed. + ], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary", + "//tensorflow/python:variables", + "//tensorflow/python/ops/losses", + "//third_party/py/numpy", + ], +) + +py_library( + name = "py_checkpoint_reader", + srcs = ["py_checkpoint_reader.py"], + deps = [ + "//tensorflow/python:_pywrap_checkpoint_reader", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:tf_export", + "//tensorflow/python:util", + ], +) + +tf_proto_library( + name = "checkpoint_state", + srcs = ["checkpoint_state.proto"], + cc_api_version = 2, +) + +py_library( + name = "checkpoint_management", + srcs = ["checkpoint_management.py"], + deps = [ + ":training_util", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lib", + "//tensorflow/python:platform", + "//tensorflow/python:tf_export", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + ], +) + +cuda_py_test( + name = "checkpoint_management_test", + size = "small", + srcs = [ + "checkpoint_management_test.py", + ], + python_version = "PY3", + deps = [ + ":checkpoint_management", + ":saver", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:lib", + "//tensorflow/python:platform", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/training/tracking:util", + ], +) + +py_library( + name = "saver", + srcs = ["saver.py"], + srcs_version = "PY2AND3", + deps = [ + ":checkpoint_management", + ":py_checkpoint_reader", + ":training_util", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:device", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:io_ops_gen", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python:string_ops", + "//tensorflow/python:tf_export", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/training/saving:saveable_object", + "//tensorflow/python/training/saving:saveable_object_util", + "//tensorflow/python/training/tracking:base", + "//third_party/py/numpy", + ], +) + +py_library( + name = "saver_test_utils", + srcs = ["saver_test_utils.py"], + srcs_version = "PY2AND3", + deps = [ + ":saver", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops_gen", + "//tensorflow/python/eager:context", + ], +) + +cuda_py_test( + name = "saver_test", + size = "medium", + srcs = [ + "saver_test.py", + ], + python_version = "PY3", + tags = ["multi_gpu"], + deps = [ + ":adam", + ":checkpoint_management", + ":gradient_descent", + ":py_checkpoint_reader", + ":queue_runner_impl", + ":saver", + ":saver_test_utils", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", + "//tensorflow/python:gradients_impl", + "//tensorflow/python:lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_grad", + "//tensorflow/python:nn_ops", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:platform", + "//tensorflow/python:random_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:session", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:summary", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/training/tracking:base", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +tf_py_test( + name = "saver_large_variable_test", + size = "medium", + srcs = ["saver_large_variable_test.py"], + python_version = "PY3", + tags = [ + "manual", + "noasan", # http://b/30379628 + "notsan", # http://b/30379628 + ], + tfrt_enabled = True, + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:variables", + ], +) + +tf_py_test( + name = "saver_large_partitioned_variable_test", + size = "medium", + srcs = ["saver_large_partitioned_variable_test.py"], + python_version = "PY3", + tags = [ + "noasan", # http://b/30782289 + "notsan", # http://b/30782289 + ], + tfrt_enabled = True, + deps = [ + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:variables", + ], +) + +py_library( + name = "basic_session_run_hooks", + srcs = ["basic_session_run_hooks.py"], + srcs_version = "PY2AND3", + deps = [ + ":session_run_hook", + ":summary_io", + ":training_util", + "//tensorflow/python:client", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:platform", + "//tensorflow/python:tf_export", + "//tensorflow/python:variable_scope", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( + name = "session_run_hook", + srcs = ["session_run_hook.py"], + srcs_version = "PY2AND3", + deps = ["//tensorflow/python:tf_export"], +) + +py_library( + name = "supervisor", + srcs = ["supervisor.py"], + deps = [ + ":coordinator", + ":saver", + ":session_manager", + ":training_util", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python:tf_export", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + ], +) + +tf_py_test( + name = "supervisor_test", + size = "small", + srcs = ["supervisor_test.py"], + grpc_enabled = True, + python_version = "PY3", + tags = ["no_windows"], + tfrt_enabled = True, + deps = [ + ":checkpoint_management", + ":saver", + ":supervisor", + ":training", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:io_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python:variables", + ], +) + +py_library( + name = "server_lib", + srcs = ["server_lib.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:errors", + "//tensorflow/python:pywrap_tf_session", + "//tensorflow/python:tf_export", + "//tensorflow/python:util", + ], +) + +py_library( + name = "training_util", + srcs = ["training_util.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:tf_export", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + ], +) + +tf_py_test( + name = "training_util_test", + size = "small", + srcs = ["training_util_test.py"], + python_version = "PY3", + tfrt_enabled = True, + deps = [ + ":training_util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:platform", + "//tensorflow/python:variables", + ], +) + +cuda_py_test( + name = "adam_test", + size = "medium", + srcs = ["adam_test.py"], + python_version = "PY3", + tfrt_enabled = True, + deps = [ + ":adam", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:session", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//third_party/py/numpy", + ], +) + +cuda_py_test( + name = "moving_averages_test", + size = "small", + srcs = [ + "moving_averages_test.py", + ], + python_version = "PY3", + tags = [ + "no_windows", # b/139083295: bfloat16 tests fail on Windows + "notsan", + ], + tfrt_enabled = True, + deps = [ + ":moving_averages", + ":saver", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:state_ops_gen", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + ], +) + +cuda_py_tests( + name = "training_tests", + size = "medium", + srcs = [ + "adadelta_test.py", + "adagrad_da_test.py", + "adagrad_test.py", + "basic_loops_test.py", + "coordinator_test.py", + "device_setter_test.py", + "ftrl_test.py", + "gradient_descent_test.py", + "momentum_test.py", + "optimizer_test.py", + "proximal_adagrad_test.py", + "proximal_gradient_descent_test.py", + "quantize_training_test.py", + "queue_runner_test.py", + "rmsprop_test.py", + "slot_creator_test.py", + "tensorboard_logging_test.py", + "training_ops_test.py", + ], + python_version = "PY3", + deps = [ + ":training", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:data_flow_ops_gen", + "//tensorflow/python:embedding_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_grad", + "//tensorflow/python:nn_ops", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:random_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:resources", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:state_ops_gen", + "//tensorflow/python:summary", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +cuda_py_test( + name = "session_manager_test", + size = "medium", # TODO(irving): Can this be made small? + srcs = ["session_manager_test.py"], + grpc_enabled = True, + main = "session_manager_test.py", + python_version = "PY3", + tfrt_enabled = True, + deps = [ + ":checkpoint_management", + ":saver", + ":server_lib", + ":session_manager", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + +tf_py_test( + name = "basic_session_run_hooks_test", + size = "medium", + srcs = ["basic_session_run_hooks_test.py"], + python_version = "PY3", + tags = [ + "no_pip", # Relies on contrib + "no_windows", + "notsan", # intermittent races on a few percent of runs + ], + tfrt_enabled = True, + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:fake_summary_writer", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:nn_grad", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + +tf_py_test( + name = "checkpoint_utils_test", + size = "small", + srcs = ["checkpoint_utils_test.py"], + python_version = "PY3", + tags = [ + "manual", + "no_cuda_on_cpu_tap", + "no_oss", + "no_windows", + "notap", + ], + deps = [ + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:io_ops", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + +tf_py_test( + name = "checkpoint_ops_test", + size = "small", + srcs = ["checkpoint_ops_test.py"], + python_version = "PY3", + tfrt_enabled = True, + deps = [ + "//tensorflow/python:checkpoint_ops_gen", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:io_ops", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:platform", + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:state_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + +tf_py_test( + name = "warm_starting_util_test", + size = "medium", + srcs = ["warm_starting_util_test.py"], + python_version = "PY3", + tfrt_enabled = True, + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_library( + name = "monitored_session", + srcs = ["monitored_session.py"], + srcs_version = "PY2AND3", + deps = [ + ":basic_session_run_hooks", + ":coordinator", + ":queue_runner", + ":saver", + ":session_manager", + ":session_run_hook", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:platform", + "//tensorflow/python:resources", + "//tensorflow/python:summary", + "//tensorflow/python:tf_export", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/distribute:distribute_coordinator_context", + "@six_archive//:six", + ], +) + +tf_py_test( + name = "monitored_session_test", + size = "medium", + srcs = ["monitored_session_test.py"], + tags = [ + "no_pip", + "notsan", # b/67945581 + ], + tfrt_enabled = True, + deps = [ + ":checkpoint_management", + ":monitored_session", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:saver", + "//tensorflow/python:session", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary", + "//tensorflow/python:variables", + "//tensorflow/python/distribute:collective_all_reduce_strategy", + "//tensorflow/python/distribute:distribute_coordinator", + ], +) + +tf_py_test( + name = "input_test", + size = "medium", + srcs = ["input_test.py"], + python_version = "PY3", + tfrt_enabled = True, + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/python/training/experimental/BUILD b/tensorflow/python/training/experimental/BUILD new file mode 100644 index 00000000000..4f897881ec6 --- /dev/null +++ b/tensorflow/python/training/experimental/BUILD @@ -0,0 +1,112 @@ +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) + +py_library( + name = "loss_scale", + srcs = ["loss_scale.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework", + "@absl_py//absl/testing:parameterized", + ], +) + +py_library( + name = "loss_scale_optimizer", + srcs = ["loss_scale_optimizer.py"], + srcs_version = "PY2AND3", + deps = [ + ":loss_scale", + "//tensorflow/python/distribute:distribute_lib", + "@absl_py//absl/testing:parameterized", + ], +) + +# The test currently requires visibility only granted to tensorflow/python:__pkg__ +exports_files( + ["loss_scale_optimizer_test.py"], + visibility = ["//tensorflow/python:__pkg__"], +) + +py_test( + name = "loss_scale_test", + size = "medium", + srcs = ["loss_scale_test.py"], + python_version = "PY3", + deps = [ + ":loss_scale", + "//tensorflow/python:client_testlib", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/distribute:one_device_strategy", + "@absl_py//absl/testing:parameterized", + ], +) + +py_library( + name = "mixed_precision_global_state", + srcs = ["mixed_precision_global_state.py"], + srcs_version = "PY2AND3", +) + +py_library( + name = "mixed_precision", + srcs = ["mixed_precision.py"], + srcs_version = "PY2AND3", + deps = [ + ":loss_scale", + ":loss_scale_optimizer", + ":mixed_precision_global_state", + "//tensorflow/python:config", + "//tensorflow/python:util", + ], +) + +cuda_py_test( + name = "mixed_precision_test", + size = "small", + srcs = ["mixed_precision_test.py"], + python_version = "PY3", + tfrt_enabled = True, + deps = [ + ":mixed_precision", + "//tensorflow/python:client_testlib", + "@absl_py//absl/testing:parameterized", + ], +) + +py_library( + name = "loss_scaling_gradient_tape", + srcs = ["loss_scaling_gradient_tape.py"], + srcs_version = "PY2AND3", + deps = [ + ":loss_scale", + "//tensorflow/python:array_ops", + "//tensorflow/python:unconnected_gradients", + "//tensorflow/python:util", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/eager:backprop", + ], +) + +cuda_py_test( + name = "loss_scaling_gradient_tape_test", + size = "medium", + srcs = ["loss_scaling_gradient_tape_test.py"], + shard_count = 2, + deps = [ + ":loss_scale", + ":loss_scaling_gradient_tape", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_test_combinations_lib", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/eager:def_function", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/python/training/gen_training_ops.py b/tensorflow/python/training/gen_training_ops.py new file mode 100644 index 00000000000..5590b5056f8 --- /dev/null +++ b/tensorflow/python/training/gen_training_ops.py @@ -0,0 +1,29 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Python wrappers for training ops.""" +# NOTE(allenl): The generated op wrappers for training ops were originally in +# training/gen_training_ops.py. They moved to ops/gen_training_ops.py when +# training/ became a module, and this is an alias to avoid breaking existing +# imports. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.python.ops.gen_training_ops import * +# pylint: enable=wildcard-import diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py index d7133cfb500..ba53657d6e6 100644 --- a/tensorflow/python/training/training_ops.py +++ b/tensorflow/python/training/training_ops.py @@ -19,8 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.training import gen_training_ops # pylint: disable=unused-import +from tensorflow.python.ops import gen_training_ops # pylint: disable=unused-import # go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.python.training.gen_training_ops import * +from tensorflow.python.ops.gen_training_ops import * # pylint: enable=wildcard-import