Merged commit includes the following changes:

253889825  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Automated rollback of changelist 253879973.

--
253889366  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Internal change.

--
253886384  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Automated g4 rollback of changelist 253834965.

    Temporary rollback to fix breakages.

--
253883891  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Add commonly used multi-worker training utilities in Keras to multi_worker_util.py.

--
253882899  by A. Unique TensorFlower<gardener@tensorflow.org>:

    When lowering While V2 to While V1, add control dependency from body function call to NextIteration node.

--
253880779  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Fixed incorrect removing of Add with scalar broadcast.

--
253879973  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Internal changes

--

PiperOrigin-RevId: 253889825
Authored by A. Unique TensorFlower on 2019-06-18 16:16:25 -07:00; committed by TensorFlower Gardener
parent e985e958b3
commit 56e1e2ad5b
16 changed files with 390 additions and 420 deletions

View File

@ -375,6 +375,7 @@ Status LowerWhileHelper::CreateNextIterationNodes() {
TF_RETURN_IF_ERROR(NodeBuilder(NewName("next_iteration"), "NextIteration",
graph_->op_registry(), &debug_info_)
.Input(NodeOut(body_call_node_, i))
.ControlInput(body_call_node_)
.Device(while_op_->requested_device())
.Finalize(graph_, &next_iteration));
next_iterations_nodes_.emplace_back(next_iteration);

View File

@ -78,7 +78,8 @@ std::unique_ptr<SequenceTransformation> NewRemoveSingleInputAdd() {
auto& attr =
absl::any_cast<const AddAttributes&>(node->operation.attributes);
return absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param) ==
nullptr;
nullptr &&
absl::get_if<float>(&attr.param) == nullptr;
});
}

View File

@ -1,264 +1,264 @@
tensorflow/contrib/tpu/profiler/pip_package/BUILD
tensorflow/contrib/tpu/profiler/pip_package/setup.py
tensorflow/contrib/tpu/profiler/pip_package/README
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
tensorflow/contrib/mpi/BUILD
tensorflow/stream_executor/build_defs.bzl
tensorflow/python/autograph/core/config.py
tensorflow/tools/ci_build/remote/BUILD
tensorflow/tools/pip_package/README
tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/BUILD
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/lib_package/README.md
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
tensorflow/tools/def_file_filter/BUILD
tensorflow/tools/def_file_filter/BUILD.tpl
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
tensorflow/third_party/mkl/MKL_LICENSE
tensorflow/third_party/mkl/LICENSE
tensorflow/third_party/mkl/BUILD
tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/backports_weakref.BUILD
tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
tensorflow/third_party/toolchains/preconfig/generate/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/py/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/py3/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/curl.BUILD
tensorflow/third_party/systemlibs/cython.BUILD
tensorflow/third_party/systemlibs/astor.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/double_conversion.BUILD
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/zlib.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/sqlite.BUILD
tensorflow/third_party/systemlibs/gast.BUILD
tensorflow/third_party/systemlibs/absl_py.BUILD
tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/BUILD.tpl
tensorflow/third_party/systemlibs/BUILD
tensorflow/third_party/systemlibs/termcolor.BUILD
tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/swig.BUILD
tensorflow/third_party/systemlibs/syslibs_configure.bzl
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/cpus/py3/BUILD
tensorflow/third_party/toolchains/cpus/py/BUILD
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
tensorflow/third_party/toolchains/cpus/arm/CROSSTOOL.tpl
tensorflow/third_party/toolchains/cpus/arm/BUILD
tensorflow/third_party/toolchains/cpus/py3/BUILD
tensorflow/third_party/toolchains/cpus/py/BUILD
tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
tensorflow/third_party/toolchains/BUILD
tensorflow/third_party/gpus/BUILD
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/py3/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/py/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/git/BUILD
tensorflow/third_party/cub.BUILD
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/gpus/find_cuda_config.py
tensorflow/third_party/gpus/crosstool/LICENSE
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
tensorflow/third_party/gpus/crosstool/BUILD.tpl
tensorflow/third_party/gpus/crosstool/BUILD
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda/LICENSE
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/find_cuda_config.py
tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/snappy.BUILD
tensorflow/third_party/cython.BUILD
tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/eigen3/Eigen/Cholesky
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/Eigen/LU
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/Eigen/Eigenvalues
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/LICENSE
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/absl_py.BUILD
tensorflow/third_party/systemlibs/curl.BUILD
tensorflow/third_party/systemlibs/termcolor.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/swig.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/astor.BUILD
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/BUILD.tpl
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/gast.BUILD
tensorflow/third_party/systemlibs/cython.BUILD
tensorflow/third_party/systemlibs/double_conversion.BUILD
tensorflow/third_party/systemlibs/zlib.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/syslibs_configure.bzl
tensorflow/third_party/systemlibs/sqlite.BUILD
tensorflow/third_party/python_runtime/BUILD
tensorflow/third_party/sycl/crosstool/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/tbb.BUILD
tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/clang_toolchain/BUILD
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/gast.BUILD
tensorflow/third_party/llvm/BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/icu/udata.patch
tensorflow/third_party/fft2d/fft2d.h
tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/fft2d/fft.h
tensorflow/third_party/fft2d/LICENSE
tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/boringssl/BUILD
tensorflow/third_party/mpi/.gitignore
tensorflow/third_party/mpi/BUILD
tensorflow/third_party/tensorrt/LICENSE
tensorflow/third_party/tensorrt/BUILD
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/BUILD.tpl
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
tensorflow/third_party/kafka/config.patch
tensorflow/third_party/kafka/BUILD
tensorflow/third_party/android/BUILD
tensorflow/third_party/android/android.bzl.tpl
tensorflow/third_party/android/android_configure.bzl
tensorflow/third_party/android/android_configure.BUILD.tpl
tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/gpus/BUILD
tensorflow/third_party/common.bzl
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/curl.BUILD
tensorflow/third_party/mkl_dnn/LICENSE
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
tensorflow/third_party/pcre.BUILD
tensorflow/third_party/pybind11.BUILD
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/sqlite.BUILD
tensorflow/third_party/common.bzl
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/BUILD
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/wrapt.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/BUILD
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/enum34.BUILD
tensorflow/third_party/tflite_mobilenet.BUILD
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/py/python_configure.bzl
tensorflow/third_party/termcolor.BUILD
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/swig.BUILD
tensorflow/third_party/astor.BUILD
tensorflow/third_party/fft2d/LICENSE
tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/fft2d/fft2d.h
tensorflow/third_party/fft2d/fft.h
tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/tbb.BUILD
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/grpc/BUILD
tensorflow/third_party/curl.BUILD
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
tensorflow/third_party/cython.BUILD
tensorflow/third_party/icu/udata.patch
tensorflow/third_party/astor.BUILD
tensorflow/third_party/jsoncpp.BUILD
tensorflow/third_party/sycl/crosstool/BUILD
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/llvm/BUILD
tensorflow/third_party/png.BUILD
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/gif.BUILD
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
tensorflow/third_party/codegen.BUILD
tensorflow/third_party/enum34.BUILD
tensorflow/third_party/kafka/config.patch
tensorflow/third_party/kafka/BUILD
tensorflow/third_party/pcre.BUILD
tensorflow/third_party/mpi/BUILD
tensorflow/third_party/mpi/.gitignore
tensorflow/third_party/clang_toolchain/BUILD
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/repo.bzl
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/py/python_configure.bzl
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/double_conversion.BUILD
tensorflow/third_party/six.BUILD
tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/repo.bzl
tensorflow/third_party/codegen.BUILD
tensorflow/third_party/cub.BUILD
tensorflow/third_party/jsoncpp.BUILD
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/__init__.py
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/zlib.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/pybind11.BUILD
tensorflow/third_party/android/android.bzl.tpl
tensorflow/third_party/android/BUILD
tensorflow/third_party/android/android_configure.BUILD.tpl
tensorflow/third_party/android/android_configure.bzl
tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/sqlite.BUILD
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/LICENSE
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
tensorflow/third_party/tensorrt/BUILD.tpl
tensorflow/third_party/tensorrt/BUILD
tensorflow/third_party/gast.BUILD
tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/eigen.BUILD
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/eigen3/LICENSE
tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/Eigen/LU
tensorflow/third_party/eigen3/Eigen/Cholesky
tensorflow/third_party/eigen3/Eigen/Eigenvalues
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/BUILD
tensorflow/third_party/termcolor.BUILD
tensorflow/third_party/gif.BUILD
tensorflow/third_party/tflite_mobilenet.BUILD
tensorflow/third_party/__init__.py
tensorflow/third_party/mkl/LICENSE
tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/mkl/MKL_LICENSE
tensorflow/third_party/mkl/BUILD
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/snappy.BUILD
tensorflow/third_party/python_runtime/BUILD
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/wrapt.BUILD
tensorflow/third_party/boringssl/BUILD
tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/backports_weakref.BUILD
tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/swig.BUILD
tensorflow/compat_template.__init__.py
tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/README.md
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/BUILD
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/README
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/ci_build/remote/BUILD
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
tensorflow/tools/def_file_filter/BUILD.tpl
tensorflow/tools/def_file_filter/BUILD
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
tensorflow/api_template.__init__.py
tensorflow/contrib/tpu/profiler/pip_package/BUILD
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
tensorflow/contrib/tpu/profiler/pip_package/README
tensorflow/contrib/tpu/profiler/pip_package/setup.py
tensorflow/contrib/mpi/BUILD
tensorflow/python/autograph/core/config.py
tensorflow/virtual_root_template_v2.__init__.py
tensorflow/__init__.py
tensorflow/stream_executor/build_defs.bzl
tensorflow/api_template_v1.__init__.py
tensorflow/compat_template_v1.__init__.py
tensorflow/compat_template.__init__.py
tensorflow/api_template.__init__.py
tensorflow/__init__.py
tensorflow/virtual_root_template_v2.__init__.py
tensorflow/virtual_root_template_v1.__init__.py

View File

@ -18,7 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.training import server_lib
@ -50,6 +54,7 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
"""Validates `cluster_spec`.
It checks:
0) None of `cluster_spec`, `task_type`, and `task_id` is `None`.
1) task type is one of "chief", "worker" or "evaluator".
2) whether there is such a task type as `task_type` in the `cluster_spec`.
3) whether there is at most one "chief" job.
@ -64,6 +69,10 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
Throws:
ValueError: if `cluster_spec` fails any check.
"""
if cluster_spec is None or task_type is None or task_id is None:
raise ValueError(
"None of `cluster_spec`, `task_type`, and `task_id` should be `None`.")
cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
if task_type not in ("chief", "worker", "evaluator", "ps"):
raise ValueError(
@ -84,13 +93,16 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
"The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
def is_chief(cluster_spec, task_type, task_id):
def is_chief(cluster_spec=None, task_type=None, task_id=None):
"""Returns whether the given task is chief in the cluster.
Since there is at most one evaluator and the evaluator itself should be
independent of the training cluster, the evaluator job is also a chief job on
its own.
If this is currently running under a `_WorkerContext` of distribute
coordinator, the arguments can be omitted as the result is already available.
Args:
cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
cluster configurations.
@ -104,6 +116,10 @@ def is_chief(cluster_spec, task_type, task_id):
ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
the maximum id of the `task_type`.
"""
if has_worker_context():
# If a worker context exists, use the value provided by it.
return dc_context.get_current_worker_context().is_chief
_validate_cluster_spec(cluster_spec, task_type, task_id)
cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
@ -208,3 +224,48 @@ def id_in_cluster(cluster_spec, task_type, task_id):
# We currently don't assign ids to other tasks.
raise ValueError("There is no id for task_type %r" % task_type)
def in_multi_worker_mode():
"""Whether the program is operating in Multi-Worker setting."""
# TODO(rchao): Consider a warning if user uses multiple `model` method
# calls in multi-worker setting.
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
cluster_spec = server_lib.ClusterSpec(tf_config.get("cluster", {}))
return tf_config and "master" not in cluster_spec.jobs
def should_save_checkpoint():
"""Returns whether the current worker should save checkpoints.
In multi-worker training, if saving checkpoint is requested by user, or needed
for fault-tolerance, the cluster should save checkpoint but not necessarily
every worker in the cluster should.
Returns:
Whether this particular worker in the cluster should save checkpoints.
"""
return dc_context.get_current_worker_context().should_checkpoint
def should_load_checkpoint():
"""Returns whether the current worker should load checkpoints.
In multi-worker training, if loading checkpoint is requested by user, or
needed for fault-tolerance, the cluster should load checkpoint but not
necessarily every worker in the cluster should.
Returns:
Whether this particular worker in the cluster should load checkpoints.
"""
return dc_context.get_current_worker_context().experimental_should_init
def wait_for_other_workers():
"""Waits for other workers to reach the same call to this method."""
return dc_context.get_current_worker_context().wait_for_other_workers()
def has_worker_context():
"""Returns whether a worker context has been entered."""
return dc_context.get_current_worker_context() is not None
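
For context, a minimal usage sketch of the helpers this hunk adds to multi_worker_util.py. It is not part of the diff: the cluster addresses and the standalone script around the calls are hypothetical, and the import path is the internal one this commit uses.

# Hypothetical TF_CONFIG describing a chief plus two workers; the host:port
# strings are made up for illustration.
import json
import os

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "chief": ["host0:2222"],
        "worker": ["host1:2222", "host2:2222"],
    },
    "task": {"type": "worker", "index": 1},
})

from tensorflow.python.distribute import multi_worker_util

# The cluster has no "master" job, so this process counts as multi-worker.
print(multi_worker_util.in_multi_worker_mode())   # True

# Outside a distribute coordinator no `_WorkerContext` has been entered yet.
print(multi_worker_util.has_worker_context())     # False

# should_save_checkpoint(), should_load_checkpoint() and wait_for_other_workers()
# read the current worker context, so they are meant to be called from inside a
# coordinator-managed worker function.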

View File

@ -681,6 +681,9 @@ tf_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
],
tags = [
"no_windows", # TODO(b/135556470) reenable.
],
)
py_library(

View File

@ -125,6 +125,7 @@ py_library(
"//tensorflow/python:variables",
"//tensorflow/python/distribute:distribute_coordinator",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:multi_worker_util",
],
)

View File

@ -36,6 +36,7 @@ from tensorflow.python.client import session as session_module
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor_utils
from tensorflow.python.eager import function as eager_function
@ -70,7 +71,6 @@ from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-im
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
@ -5649,15 +5649,6 @@ if not os.path.exists(_config_path):
pass
def in_multi_worker_mode():
"""Whether we are operating in a Multi-Worker setting."""
# TODO(rchao): Consider a warning if user uses multiple `model` method
# calls in multi-worker setting.
tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
return tf_config and 'master' not in cluster_spec.jobs
def configure_and_create_distributed_session(distribution_strategy):
"""Configure session config and create a session with it."""
@ -5694,7 +5685,7 @@ def configure_and_create_distributed_session(distribution_strategy):
set_session(session)
if in_multi_worker_mode():
if multi_worker_util.in_multi_worker_mode():
dc.run_distribute_coordinator(
_create_session,
distribution_strategy,

View File

@ -33,7 +33,7 @@ import numpy as np
import six
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
@ -898,7 +898,7 @@ class ModelCheckpoint(Callback):
self.save_weights_only = True
def on_train_begin(self, logs=None):
if K.in_multi_worker_mode():
if multi_worker_util.in_multi_worker_mode():
# pylint: disable=protected-access
# MultiWorkerTrainingState is used to manage the training state needed
# for preemption-recovery of a worker in multi-worker training.
@ -914,11 +914,8 @@ class ModelCheckpoint(Callback):
# If this is not multi worker training, restoring is not needed, or
# restoring failed, check if it should load weights on restart.
if self.load_weights_on_restart:
# In multi worker training, it only should load weights on restart if
# `experimental_should_init` is True.
# TODO(rchao): Reference `experimental_should_init` api from a util file.
if (not K.in_multi_worker_mode() or
dc_context.get_current_worker_context().experimental_should_init):
if (not multi_worker_util.in_multi_worker_mode()
or multi_worker_util.should_load_checkpoint()):
filepath_to_load = (
self._get_most_recently_modified_file_matching_pattern(
self.filepath))
@ -934,7 +931,7 @@ class ModelCheckpoint(Callback):
filepath_to_load, e))
def on_train_end(self, logs=None):
if K.in_multi_worker_mode():
if multi_worker_util.in_multi_worker_mode():
# In multi-worker training, on successful exit of training, delete the
# training state backup file that was saved for the purpose of worker
# recovery.
@ -958,13 +955,13 @@ class ModelCheckpoint(Callback):
def on_epoch_end(self, epoch, logs=None):
self.epochs_since_last_save += 1
if self.save_freq == 'epoch':
if K.in_multi_worker_mode():
if multi_worker_util.in_multi_worker_mode():
# Exclude training state variables in user-requested checkpoint file.
with self._training_state.untrack_vars():
self._save_model(epoch=epoch, logs=logs)
else:
self._save_model(epoch=epoch, logs=logs)
if K.in_multi_worker_mode():
if multi_worker_util.in_multi_worker_mode():
# For multi-worker training, back up the weights and current training
# state for possible future recovery.
# TODO(rchao): Call `back_up` at finer period such as N steps.
@ -1016,11 +1013,8 @@ class ModelCheckpoint(Callback):
def _get_file_path(self, epoch, logs):
"""Returns the file path for checkpoint."""
# TODO(rchao): Replace dc_context reference with
# distributed_training_utils.should_current_worker_checkpoint() once
# distributed_training_utils.py no longer depends on callbacks.py.
if not K.in_multi_worker_mode() or dc_context.get_current_worker_context(
).should_checkpoint:
if not multi_worker_util.in_multi_worker_mode(
) or multi_worker_util.should_save_checkpoint():
return self.filepath.format(epoch=epoch + 1, **logs)
else:
# If this is multi-worker training, and this worker should not
@ -1037,8 +1031,8 @@ class ModelCheckpoint(Callback):
# Remove the checkpoint directory in multi-worker training where this worker
# should not checkpoint. It is a dummy directory previously saved for sync
# distributed training.
if K.in_multi_worker_mode(
) and not dc_context.get_current_worker_context().should_checkpoint:
if multi_worker_util.in_multi_worker_mode(
) and not multi_worker_util.should_save_checkpoint():
file_io.delete_recursively(self._temp_file_dir)
del self._temp_file_dir
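
The checkpoint-path decision that the hunk above rewires can be summarized by the sketch below. It is not part of the diff; the function name and paths are hypothetical, while the multi_worker_util helpers are the ones added in this commit.

import os

from tensorflow.python.distribute import multi_worker_util

def _checkpoint_path_for_this_worker(requested_filepath, temp_file_dir):
    # The chief (and single-worker training) writes to the user-requested path;
    # a non-chief worker in synchronized multi-worker training writes to a
    # temporary directory that _save_model / on_train_end later deletes.
    if (not multi_worker_util.in_multi_worker_mode()
        or multi_worker_util.should_save_checkpoint()):
        return requested_filepath
    return os.path.join(temp_file_dir, 'temp_checkpoint')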

View File

@ -55,6 +55,9 @@ py_library(
"multi_worker_training_state.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python/distribute:multi_worker_util",
],
)
cuda_py_test(

View File

@ -26,6 +26,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@ -373,8 +374,8 @@ def _wait_for_variable_initialization(session):
def init_restore_or_wait_for_variables():
"""Initialize or restore variables or wait for variables to be initialized."""
session = K._get_session() # pylint: disable=protected-access
worker_context = dc_context.get_current_worker_context()
if not worker_context or worker_context.experimental_should_init:
if not multi_worker_util.has_worker_context(
) or multi_worker_util.should_load_checkpoint():
# TODO(yuefengz): if checkpoints exist, restore from checkpoint.
K._initialize_variables(session) # pylint: disable=protected-access
else:
@ -1104,7 +1105,7 @@ def filter_distributed_callbacks(callbacks_list):
The list of `Callback` instances that should be run on this worker.
"""
if not K.in_multi_worker_mode():
if not multi_worker_util.in_multi_worker_mode():
raise ValueError(
'filter_distributed_callbacks() should only be called when Keras '
'is in multi worker mode.')

View File

@ -29,9 +29,9 @@ from tensorflow.python import keras
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
from tensorflow.python.keras import testing_utils
@ -130,7 +130,7 @@ class KerasMultiWorkerCallbackTest(test_base.IndependentWorkerTestBase,
self.filtered_correctly = True
def on_train_begin(self, logs):
if not dc_context.get_current_worker_context().is_chief:
if not multi_worker_util.is_chief():
# Non-chief workers shouldn't run this callback.
self.filtered_correctly = False

View File

@ -32,9 +32,9 @@ from tensorflow.python import keras
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.framework import ops
@ -222,11 +222,11 @@ def _run_standalone_client(test_obj, strategy, cluster_spec):
# Workaround for the metrics issue (b/122928955) in async training. This
# can only be used in standalone client mode.
dc_context.get_current_worker_context().wait_for_other_workers()
multi_worker_util.wait_for_other_workers()
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
dc_context.get_current_worker_context().wait_for_other_workers()
multi_worker_util.wait_for_other_workers()
trained_loss, trained_acc = model.evaluate(train_ds, steps=steps)

View File

@ -20,8 +20,7 @@ from __future__ import print_function
import contextlib
import os
import tempfile
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend as K
@ -76,7 +75,7 @@ class MultiWorkerTrainingState(object):
# For those who should not checkpoint (e.g. non-chief worker in sync
# training), create a temporary directory to write to (that will be
# removed later).
if not dc_context.get_current_worker_context().should_checkpoint:
if not multi_worker_util.should_save_checkpoint():
self._temp_dir, self._temp_filepath = self._get_temp_filepath()
# The epoch at which the checkpoint is saved. Used for fault-tolerance.
@ -114,7 +113,7 @@ class MultiWorkerTrainingState(object):
# call. This is because the SyncOnReadVariable needs to be synced across
# all the workers in order to be read, and all workers need to initiate
# that.
if dc_context.get_current_worker_context().should_checkpoint:
if multi_worker_util.should_save_checkpoint():
save_filepath = self._backup_filepath
else:
save_filepath = self._temp_filepath
@ -122,7 +121,7 @@ class MultiWorkerTrainingState(object):
# Save the weights plus CKPT_SAVED_EPOCH variable.
self._model.save_weights(save_filepath, overwrite=True)
if not dc_context.get_current_worker_context().should_checkpoint:
if not multi_worker_util.should_save_checkpoint():
# Remove the file in multi-worker training where this worker should
# not checkpoint. It is a dummy file previously saved for sync distributed
# training.
@ -136,7 +135,7 @@ class MultiWorkerTrainingState(object):
state doesn't need to be restored, or error occurred so it can't.
"""
self._assert_in_multi_worker_mode()
if not dc_context.get_current_worker_context().experimental_should_init:
if not multi_worker_util.should_load_checkpoint():
# For multi-worker training, it should not restore a model in certain
# worker setting (e.g. non-chief worker in ParameterServerStrategy).
return False
@ -159,7 +158,7 @@ class MultiWorkerTrainingState(object):
"""
self._assert_in_multi_worker_mode()
tracking.AutoTrackable.__delattr__(self._model, CKPT_SAVED_EPOCH)
if dc_context.get_current_worker_context().should_checkpoint:
if multi_worker_util.should_save_checkpoint():
_remove_dir(self._backup_dir)
else:
assert not file_io.file_exists(self._temp_dir)
@ -218,7 +217,7 @@ class MultiWorkerTrainingState(object):
return temp_dir, os.path.join(temp_dir, 'temp_training_state')
def _assert_in_multi_worker_mode(self):
if not K.in_multi_worker_mode():
if not multi_worker_util.in_multi_worker_mode():
raise ValueError('MultiWorkerTrainingState is only supposed to be used '
'in multi-worker training. This indicates some error '
'that needs to be fixed. Please submit a bug issue to '

View File

@ -26,6 +26,7 @@ from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import monitoring
@ -435,7 +436,7 @@ class Model(network.Network):
"""Select training loop for fit/eval/predict based on the inputs."""
# Case 1: distribution strategy.
if self._distribution_strategy:
if K.in_multi_worker_mode():
if multi_worker_util.in_multi_worker_mode():
return training_distributed.DistributionMultiWorkerTrainingLoop()
else:
return training_distributed.DistributionSingleWorkerTrainingLoop()

View File

@ -406,28 +406,10 @@ class BatchNormalizationBase(Layer):
experimental_autocast=False)
if self.renorm:
# In batch renormalization we track the inference moving stddev instead
# of the moving variance to more closely align with the paper.
def moving_stddev_initializer(*args, **kwargs):
return math_ops.sqrt(
self.moving_variance_initializer(*args, **kwargs))
with distribution_strategy_context.get_strategy(
).extended.colocate_vars_with(self.moving_variance):
self.moving_stddev = self.add_weight(
name='moving_stddev',
shape=param_shape,
dtype=self._param_dtype,
initializer=moving_stddev_initializer,
synchronization=tf_variables.VariableSynchronization.ON_READ,
trainable=False,
aggregation=tf_variables.VariableAggregation.MEAN,
experimental_autocast=False)
# Create variables to maintain the moving mean and standard deviation.
# These are used in training and thus are different from the moving
# averages above. The renorm variables are colocated with moving_mean
# and moving_stddev.
# and moving_variance.
# NOTE: below, the outer `with device` block causes the current device
# stack to be cleared. The nested ones use a `lambda` to set the desired
# device and ignore any devices that may be set by the custom getter.
@ -450,10 +432,17 @@ class BatchNormalizationBase(Layer):
).extended.colocate_vars_with(self.moving_mean):
self.renorm_mean = _renorm_variable('renorm_mean', param_shape,
self.moving_mean_initializer)
self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
# We initialize renorm_stddev to 0, and maintain the (0-initialized)
# renorm_stddev_weight. This allows us to (1) mix the average
# stddev with the minibatch stddev early in training, and (2) compute
# the unbiased average stddev by dividing renorm_stddev by the weight.
with distribution_strategy_context.get_strategy(
).extended.colocate_vars_with(self.moving_stddev):
self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape,
moving_stddev_initializer)
).extended.colocate_vars_with(self.moving_variance):
self.renorm_stddev = _renorm_variable(
'renorm_stddev', param_shape, self.moving_variance_initializer)
self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
())
finally:
if partitioner:
self._scope.set_partitioner(partitioner)
@ -472,11 +461,6 @@ class BatchNormalizationBase(Layer):
K.zeros_like(update_delta))
return state_ops.assign_sub(variable, update_delta, name=scope)
def _assign_new_value(self, variable, value):
with K.name_scope('AssignNewValue') as scope:
with ops.colocate_with(variable):
return state_ops.assign(variable, value, name=scope)
def _fused_batch_norm(self, inputs, training):
"""Returns the output of fused batch norm."""
beta = self.beta if self.center else self._beta_const
@ -534,21 +518,8 @@ class BatchNormalizationBase(Layer):
inputs_size)
def variance_update():
"""Update self.moving_variance with the most recent data point."""
if self.renorm:
# We apply epsilon as part of the moving_stddev to mirror the training
# code path.
moving_stddev = self._assign_moving_average(
self.moving_stddev, math_ops.sqrt(variance + self.epsilon),
momentum, inputs_size)
return self._assign_new_value(
self.moving_variance,
# Apply relu in case floating point rounding causes it to go
# negative.
K.relu(moving_stddev * moving_stddev - self.epsilon))
else:
return self._assign_moving_average(self.moving_variance, variance,
momentum, inputs_size)
return self._assign_moving_average(self.moving_variance, variance,
momentum, inputs_size)
self.add_update(mean_update)
self.add_update(variance_update)
@ -563,8 +534,7 @@ class BatchNormalizationBase(Layer):
# initialized with this batch's moments.
renorm_mean = self.renorm_mean
# Avoid divide by zero early on in training.
renorm_stddev = math_ops.maximum(self.renorm_stddev,
math_ops.sqrt(self.epsilon))
renorm_stddev = math_ops.maximum(self.renorm_stddev, self.epsilon)
# Compute the corrections for batch renorm.
r = stddev / renorm_stddev
d = (mean - renorm_mean) / renorm_stddev
@ -587,24 +557,40 @@ class BatchNormalizationBase(Layer):
lambda: d,
lambda: array_ops.zeros_like(d))
def _update_renorm_variable(var, value, inputs_size):
def _update_renorm_variable(var, weight, value, inputs_size):
"""Updates a moving average and weight, returns the unbiased value."""
value = array_ops.identity(value)
def _do_update():
"""Updates the var, returns the updated value."""
"""Updates the var and weight, returns their updated ratio."""
# Update the variables without zero debiasing. The debiasing will be
# accomplished by dividing the exponential moving average by the weight.
# For example, after a single update, the moving average would be
# (1-decay) * value. and the weight will be 1-decay, with their ratio
# giving the value.
# Make sure the weight is not updated until before r and d computation.
with ops.control_dependencies([value]):
weight_value = array_ops.constant(1., dtype=weight.dtype)
new_var = self._assign_moving_average(var, value, self.renorm_momentum,
inputs_size)
return new_var
new_weight = self._assign_moving_average(weight, weight_value,
self.renorm_momentum,
inputs_size)
# TODO(yuefengz): the updates to var and weighted can not be batched
# together if we fetch their updated values here. Consider calculating
# new values and delaying the updates.
return new_var / new_weight
def _fake_update():
return array_ops.identity(var)
return tf_utils.smart_cond(training, _do_update, _fake_update)
# TODO(yuefengz): colocate the operations
update_new_mean = _update_renorm_variable(self.renorm_mean, mean,
update_new_mean = _update_renorm_variable(self.renorm_mean,
self.renorm_mean_weight, mean,
inputs_size)
update_new_stddev = _update_renorm_variable(self.renorm_stddev, stddev,
inputs_size)
update_new_stddev = _update_renorm_variable(self.renorm_stddev,
self.renorm_stddev_weight,
stddev, inputs_size)
# Update the inference mode moving averages with the batch value.
with ops.control_dependencies([update_new_mean, update_new_stddev]):
@ -761,24 +747,7 @@ class BatchNormalizationBase(Layer):
return tf_utils.smart_cond(training, true_branch, false_branch)
def variance_update():
"""Update the moving variance."""
def true_branch_renorm():
# We apply epsilon as part of the moving_stddev to mirror the training
# code path.
moving_stddev = _do_update(self.moving_stddev,
math_ops.sqrt(new_variance + self.epsilon))
return self._assign_new_value(
self.moving_variance,
# Apply relu in case floating point rounding causes it to go
# negative.
K.relu(moving_stddev * moving_stddev - self.epsilon))
if self.renorm:
true_branch = true_branch_renorm
else:
true_branch = lambda: _do_update(self.moving_variance, new_variance)
true_branch = lambda: _do_update(self.moving_variance, new_variance)
false_branch = lambda: self.moving_variance
return tf_utils.smart_cond(training, true_branch, false_branch)
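
The comment block inside _do_update above describes the zero-debiasing trick in words; the toy calculation below (plain Python, not from the diff, with a made-up decay and batch statistic) shows why dividing the moving average by the moving weight recovers the underlying value.

# Both the renorm statistic and its weight start at zero and follow the same
# update rule as _assign_moving_average: x -= (x - value) * (1 - decay).
decay = 0.9            # stands in for renorm_momentum
stat, weight = 0.0, 0.0
for batch_stddev in [2.0, 2.0, 2.0]:    # constant made-up batch statistic
    stat -= (stat - batch_stddev) * (1 - decay)
    weight -= (weight - 1.0) * (1 - decay)
    print(stat, weight, stat / weight)  # the ratio is 2.0 at every step, i.e. unbiased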

View File

@ -894,9 +894,10 @@ class BNTest(test.TestCase):
yt = bn.apply(xt, training=training)
moving_mean = 0.
moving_stddev = 1.
moving_variance = 1.
renorm_mean = 0.
renorm_stddev = 1.
renorm_weight = 0.
with self.session(use_gpu=True) as sess:
self.evaluate(variables.global_variables_initializer())
for _ in range(5):
@ -911,10 +912,10 @@ class BNTest(test.TestCase):
renorm_mean += (mean - renorm_mean) * (1. - renorm_momentum)
renorm_stddev += (stddev - renorm_stddev) * (1. - renorm_momentum)
moving_mean += (mean - moving_mean) * (1. - momentum)
moving_stddev += (stddev - moving_stddev) * (1. - momentum)
moving_variance += (variance - moving_variance) * (1. - momentum)
y_test = ((x - moving_mean) /
(moving_stddev * moving_stddev)**0.5 * gamma) + beta
y_test = ((x - moving_mean) / (moving_variance + epsilon) ** 0.5 *
gamma) + beta
yt_val_train, _, _ = sess.run([yt] + bn.updates,
feed_dict={xt: x, training: True})
@ -924,62 +925,6 @@ class BNTest(test.TestCase):
self.assertAllClose(y_train, yt_val_train, atol=1e-5)
self.assertAllClose(y_test, yt_val_test, atol=1e-5)
def testRenormNoClippingSameMomentumGivesSameTestTrain(self):
shape = (4, 3)
xt = array_ops.placeholder(dtypes.float32, shape)
momentum = 0.9
renorm_momentum = 0.9
gamma = 2.
beta = 3.
epsilon = 0.001
bn = normalization_layers.BatchNormalization(
axis=1,
gamma_initializer=init_ops.constant_initializer(gamma),
beta_initializer=init_ops.constant_initializer(beta),
epsilon=epsilon,
momentum=momentum,
renorm=True,
renorm_clipping=None,
renorm_momentum=momentum)
training = array_ops.placeholder(dtypes.bool)
yt = bn.apply(xt, training=training)
moving_mean = 0.
moving_stddev = 1.
renorm_mean = 0.
renorm_stddev = 1.
with self.session(use_gpu=True) as sess:
self.evaluate(variables.global_variables_initializer())
for step in range(6):
x = np.random.random(shape)
mean = x.mean(0)
variance = x.var(0)
stddev = np.sqrt(variance + epsilon)
r = (stddev / renorm_stddev)
d = ((mean - renorm_mean) / renorm_stddev)
y_test = ((x - moving_mean) /
(moving_stddev * moving_stddev)**0.5 * gamma) + beta
y_train = ((x - mean) / stddev * r + d) * gamma + beta
renorm_mean += (mean - renorm_mean) * (1. - renorm_momentum)
renorm_stddev += (stddev - renorm_stddev) * (1. - renorm_momentum)
moving_mean += (mean - moving_mean) * (1. - momentum)
moving_stddev += (stddev - moving_stddev) * (1. - momentum)
# Compute test values first, before the train mode updates the moving
# averages.
yt_val_test, _, _ = sess.run([yt] + bn.updates,
feed_dict={xt: x, training: False})
yt_val_train, _, _ = sess.run([yt] + bn.updates,
feed_dict={xt: x, training: True})
# Due to initialization inconsistencies, values may not be identical
# on the first iteration (but shouldn't be different by much more than
# epsilon). After the first iteration they should be identical.
atol = epsilon * 1.5 if step == 0 else 1e-5
self.assertAllClose(y_train, yt_val_train, atol=atol)
self.assertAllClose(y_test, yt_val_test, atol=atol)
self.assertAllClose(yt_val_train, yt_val_test, atol=atol)
def testAdjustment(self):
shape = (4, 3)
xt = array_ops.placeholder(dtypes.float32, shape)
@ -1051,7 +996,7 @@ class BNTest(test.TestCase):
yt = bn.apply(xt, training=training)
moving_mean = 0.
moving_stddev = 1.
moving_variance = 1.
renorm_mean = 0.
renorm_stddev = 1.
with self.session(use_gpu=True) as sess:
@ -1074,10 +1019,10 @@ class BNTest(test.TestCase):
renorm_mean += (mean - renorm_mean) * (1. - renorm_momentum)
renorm_stddev += (stddev - renorm_stddev) * (1. - renorm_momentum)
moving_mean += (mean - moving_mean) * (1. - momentum)
moving_stddev += (stddev - moving_stddev) * (1. - momentum)
moving_variance += (variance - moving_variance) * (1. - momentum)
y_test = ((x - moving_mean) /
(moving_stddev * moving_stddev)**0.5 * gamma) + beta
y_test = ((x - moving_mean) / (moving_variance + epsilon) ** 0.5 *
gamma) + beta
self.assertAllClose(y_train, yt_val_train, atol=1e-5)
self.assertAllClose(y_test, yt_val_test, atol=1e-5)
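
As a standalone illustration (not part of the test), the reference value the updated assertions use can be recomputed with plain NumPy; the constants mirror the ones in the surrounding test but are otherwise arbitrary.

import numpy as np

momentum, epsilon, gamma, beta = 0.9, 0.001, 2.0, 3.0
x = np.random.random((4, 3))
mean, variance = x.mean(0), x.var(0)

# Inference statistics are now tracked as a moving variance (initialized to 1)
# rather than a moving stddev, matching the rolled-back layer code.
moving_mean, moving_variance = 0.0, 1.0
moving_mean += (mean - moving_mean) * (1. - momentum)
moving_variance += (variance - moving_variance) * (1. - momentum)

y_test = (x - moving_mean) / (moving_variance + epsilon) ** 0.5 * gamma + beta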