Merged commit includes the following changes:
253889825 by A. Unique TensorFlower<gardener@tensorflow.org>: Automated rollback of changelist 253879973. 253889366 by A. Unique TensorFlower<gardener@tensorflow.org>: Internal change. -- 253886384 by A. Unique TensorFlower<gardener@tensorflow.org>: Automated Temporary rollback to fix breakages END_PUBLIC Automated g4 rollback of changelist 253834965. 253883891 by A. Unique TensorFlower<gardener@tensorflow.org>: Add commonly used multi-worker training utilities in Keras to multi_worker_util.py. -- 253882899 by A. Unique TensorFlower<gardener@tensorflow.org>: When lowering While V2 to While V1, add control dependency from body function call to NextIteration node. -- 253880779 by A. Unique TensorFlower<gardener@tensorflow.org>: Fixed incorrect removing of Add with scalar broadcast. -- 253879973 by A. Unique TensorFlower<gardener@tensorflow.org>: Internal changes -- PiperOrigin-RevId: 253889825
This commit is contained in:
parent
e985e958b3
commit
56e1e2ad5b
@ -375,6 +375,7 @@ Status LowerWhileHelper::CreateNextIterationNodes() {
|
||||
TF_RETURN_IF_ERROR(NodeBuilder(NewName("next_iteration"), "NextIteration",
|
||||
graph_->op_registry(), &debug_info_)
|
||||
.Input(NodeOut(body_call_node_, i))
|
||||
.ControlInput(body_call_node_)
|
||||
.Device(while_op_->requested_device())
|
||||
.Finalize(graph_, &next_iteration));
|
||||
next_iterations_nodes_.emplace_back(next_iteration);
|
||||
|
@ -78,7 +78,8 @@ std::unique_ptr<SequenceTransformation> NewRemoveSingleInputAdd() {
|
||||
auto& attr =
|
||||
absl::any_cast<const AddAttributes&>(node->operation.attributes);
|
||||
return absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param) ==
|
||||
nullptr;
|
||||
nullptr &&
|
||||
absl::get_if<float>(&attr.param) == nullptr;
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -1,264 +1,264 @@
|
||||
tensorflow/contrib/tpu/profiler/pip_package/BUILD
|
||||
tensorflow/contrib/tpu/profiler/pip_package/setup.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/README
|
||||
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
|
||||
tensorflow/contrib/mpi/BUILD
|
||||
tensorflow/stream_executor/build_defs.bzl
|
||||
tensorflow/python/autograph/core/config.py
|
||||
tensorflow/tools/ci_build/remote/BUILD
|
||||
tensorflow/tools/pip_package/README
|
||||
tensorflow/tools/pip_package/MANIFEST.in
|
||||
tensorflow/tools/pip_package/simple_console.py
|
||||
tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/check_load_py_test.py
|
||||
tensorflow/tools/pip_package/pip_smoke_test.py
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
tensorflow/tools/lib_package/concat_licenses.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/LibTensorFlowTest.java
|
||||
tensorflow/tools/lib_package/BUILD
|
||||
tensorflow/tools/lib_package/libtensorflow_test.sh
|
||||
tensorflow/tools/lib_package/README.md
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
|
||||
tensorflow/tools/def_file_filter/BUILD
|
||||
tensorflow/tools/def_file_filter/BUILD.tpl
|
||||
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
|
||||
tensorflow/third_party/mkl/MKL_LICENSE
|
||||
tensorflow/third_party/mkl/LICENSE
|
||||
tensorflow/third_party/mkl/BUILD
|
||||
tensorflow/third_party/mkl/mkl.BUILD
|
||||
tensorflow/third_party/mkl/build_defs.bzl
|
||||
tensorflow/third_party/backports_weakref.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/toolchains/clang6/repo.bzl
|
||||
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/clang6/clang.BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/py/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
|
||||
tensorflow/third_party/systemlibs/nsync.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
|
||||
tensorflow/third_party/systemlibs/curl.BUILD
|
||||
tensorflow/third_party/systemlibs/cython.BUILD
|
||||
tensorflow/third_party/systemlibs/astor.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/pcre.BUILD
|
||||
tensorflow/third_party/systemlibs/grpc.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
tensorflow/third_party/systemlibs/double_conversion.BUILD
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
tensorflow/third_party/systemlibs/zlib.BUILD
|
||||
tensorflow/third_party/systemlibs/lmdb.BUILD
|
||||
tensorflow/third_party/systemlibs/sqlite.BUILD
|
||||
tensorflow/third_party/systemlibs/gast.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.BUILD
|
||||
tensorflow/third_party/systemlibs/boringssl.BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD.tpl
|
||||
tensorflow/third_party/systemlibs/BUILD
|
||||
tensorflow/third_party/systemlibs/termcolor.BUILD
|
||||
tensorflow/third_party/systemlibs/gif.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.bzl
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
|
||||
tensorflow/third_party/systemlibs/re2.BUILD
|
||||
tensorflow/third_party/systemlibs/swig.BUILD
|
||||
tensorflow/third_party/systemlibs/syslibs_configure.bzl
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
|
||||
tensorflow/third_party/pprof.BUILD
|
||||
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
|
||||
tensorflow/third_party/toolchains/remote/BUILD.tpl
|
||||
tensorflow/third_party/toolchains/remote/BUILD
|
||||
tensorflow/third_party/toolchains/remote/configure.bzl
|
||||
tensorflow/third_party/toolchains/cpus/py3/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
|
||||
tensorflow/third_party/toolchains/cpus/arm/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/cpus/arm/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py3/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py/BUILD
|
||||
tensorflow/third_party/toolchains/remote/configure.bzl
|
||||
tensorflow/third_party/toolchains/remote/BUILD.tpl
|
||||
tensorflow/third_party/toolchains/remote/BUILD
|
||||
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
|
||||
tensorflow/third_party/toolchains/BUILD
|
||||
tensorflow/third_party/gpus/BUILD
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/py/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/repo.bzl
|
||||
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/clang6/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/clang.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/farmhash.BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/cub.BUILD
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/rocm/BUILD.tpl
|
||||
tensorflow/third_party/gpus/rocm/BUILD
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/rocm_configure.bzl
|
||||
tensorflow/third_party/gpus/find_cuda_config.py
|
||||
tensorflow/third_party/gpus/crosstool/LICENSE
|
||||
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
|
||||
tensorflow/third_party/gpus/crosstool/BUILD.tpl
|
||||
tensorflow/third_party/gpus/crosstool/BUILD
|
||||
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/cuda/LICENSE
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
|
||||
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD
|
||||
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/rocm/BUILD
|
||||
tensorflow/third_party/gpus/rocm/BUILD.tpl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/find_cuda_config.py
|
||||
tensorflow/third_party/gpus/rocm_configure.bzl
|
||||
tensorflow/third_party/snappy.BUILD
|
||||
tensorflow/third_party/cython.BUILD
|
||||
tensorflow/third_party/farmhash.BUILD
|
||||
tensorflow/third_party/eigen3/Eigen/Cholesky
|
||||
tensorflow/third_party/eigen3/Eigen/QR
|
||||
tensorflow/third_party/eigen3/Eigen/LU
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/eigen3/Eigen/SVD
|
||||
tensorflow/third_party/eigen3/Eigen/Eigenvalues
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
|
||||
tensorflow/third_party/eigen3/gpu_packet_math.patch
|
||||
tensorflow/third_party/eigen3/LICENSE
|
||||
tensorflow/third_party/eigen3/BUILD
|
||||
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
|
||||
tensorflow/third_party/systemlibs/absl_py.BUILD
|
||||
tensorflow/third_party/systemlibs/curl.BUILD
|
||||
tensorflow/third_party/systemlibs/termcolor.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
|
||||
tensorflow/third_party/systemlibs/grpc.BUILD
|
||||
tensorflow/third_party/systemlibs/swig.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.bzl
|
||||
tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
|
||||
tensorflow/third_party/systemlibs/astor.BUILD
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
|
||||
tensorflow/third_party/systemlibs/boringssl.BUILD
|
||||
tensorflow/third_party/systemlibs/nsync.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/gif.BUILD
|
||||
tensorflow/third_party/systemlibs/pcre.BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD.tpl
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/gast.BUILD
|
||||
tensorflow/third_party/systemlibs/cython.BUILD
|
||||
tensorflow/third_party/systemlibs/double_conversion.BUILD
|
||||
tensorflow/third_party/systemlibs/zlib.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/re2.BUILD
|
||||
tensorflow/third_party/systemlibs/lmdb.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/syslibs_configure.bzl
|
||||
tensorflow/third_party/systemlibs/sqlite.BUILD
|
||||
tensorflow/third_party/python_runtime/BUILD
|
||||
tensorflow/third_party/sycl/crosstool/BUILD
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/tbb.BUILD
|
||||
tensorflow/third_party/ngraph/BUILD
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/NGRAPH_LICENSE
|
||||
tensorflow/third_party/ngraph/ngraph_tf.BUILD
|
||||
tensorflow/third_party/ngraph/nlohmann_json.BUILD
|
||||
tensorflow/third_party/clang_toolchain/download_clang.bzl
|
||||
tensorflow/third_party/clang_toolchain/BUILD
|
||||
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
|
||||
tensorflow/third_party/gast.BUILD
|
||||
tensorflow/third_party/llvm/BUILD
|
||||
tensorflow/third_party/llvm/expand_cmake_vars.py
|
||||
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
|
||||
tensorflow/third_party/llvm/llvm.bzl
|
||||
tensorflow/third_party/icu/udata.patch
|
||||
tensorflow/third_party/fft2d/fft2d.h
|
||||
tensorflow/third_party/fft2d/BUILD
|
||||
tensorflow/third_party/fft2d/fft.h
|
||||
tensorflow/third_party/fft2d/LICENSE
|
||||
tensorflow/third_party/fft2d/fft2d.BUILD
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
tensorflow/third_party/nccl/nccl_configure.bzl
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/BUILD
|
||||
tensorflow/third_party/boringssl/BUILD
|
||||
tensorflow/third_party/mpi/.gitignore
|
||||
tensorflow/third_party/mpi/BUILD
|
||||
tensorflow/third_party/tensorrt/LICENSE
|
||||
tensorflow/third_party/tensorrt/BUILD
|
||||
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
||||
tensorflow/third_party/tensorrt/BUILD.tpl
|
||||
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
|
||||
tensorflow/third_party/kafka/config.patch
|
||||
tensorflow/third_party/kafka/BUILD
|
||||
tensorflow/third_party/android/BUILD
|
||||
tensorflow/third_party/android/android.bzl.tpl
|
||||
tensorflow/third_party/android/android_configure.bzl
|
||||
tensorflow/third_party/android/android_configure.BUILD.tpl
|
||||
tensorflow/third_party/tflite_smartreply.BUILD
|
||||
tensorflow/third_party/gpus/BUILD
|
||||
tensorflow/third_party/common.bzl
|
||||
tensorflow/third_party/tflite_mobilenet_quant.BUILD
|
||||
tensorflow/third_party/linenoise.BUILD
|
||||
tensorflow/third_party/curl.BUILD
|
||||
tensorflow/third_party/mkl_dnn/LICENSE
|
||||
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
|
||||
tensorflow/third_party/pcre.BUILD
|
||||
tensorflow/third_party/pybind11.BUILD
|
||||
tensorflow/third_party/linenoise.BUILD
|
||||
tensorflow/third_party/sqlite.BUILD
|
||||
tensorflow/third_party/common.bzl
|
||||
tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/pprof.BUILD
|
||||
tensorflow/third_party/BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_quant.BUILD
|
||||
tensorflow/third_party/wrapt.BUILD
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/protobuf/BUILD
|
||||
tensorflow/third_party/enum34.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet.BUILD
|
||||
tensorflow/third_party/py/BUILD
|
||||
tensorflow/third_party/py/BUILD.tpl
|
||||
tensorflow/third_party/py/numpy/BUILD
|
||||
tensorflow/third_party/py/python_configure.bzl
|
||||
tensorflow/third_party/termcolor.BUILD
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/swig.BUILD
|
||||
tensorflow/third_party/astor.BUILD
|
||||
tensorflow/third_party/fft2d/LICENSE
|
||||
tensorflow/third_party/fft2d/fft2d.BUILD
|
||||
tensorflow/third_party/fft2d/fft2d.h
|
||||
tensorflow/third_party/fft2d/fft.h
|
||||
tensorflow/third_party/fft2d/BUILD
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/tbb.BUILD
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/nlohmann_json.BUILD
|
||||
tensorflow/third_party/ngraph/BUILD
|
||||
tensorflow/third_party/ngraph/ngraph_tf.BUILD
|
||||
tensorflow/third_party/ngraph/NGRAPH_LICENSE
|
||||
tensorflow/third_party/grpc/BUILD
|
||||
tensorflow/third_party/curl.BUILD
|
||||
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
|
||||
tensorflow/third_party/cython.BUILD
|
||||
tensorflow/third_party/icu/udata.patch
|
||||
tensorflow/third_party/astor.BUILD
|
||||
tensorflow/third_party/jsoncpp.BUILD
|
||||
tensorflow/third_party/sycl/crosstool/BUILD
|
||||
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
|
||||
tensorflow/third_party/llvm/expand_cmake_vars.py
|
||||
tensorflow/third_party/llvm/llvm.bzl
|
||||
tensorflow/third_party/llvm/BUILD
|
||||
tensorflow/third_party/png.BUILD
|
||||
tensorflow/third_party/googleapis.BUILD
|
||||
tensorflow/third_party/mpi_collectives/BUILD
|
||||
tensorflow/third_party/nanopb.BUILD
|
||||
tensorflow/third_party/gif.BUILD
|
||||
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
|
||||
tensorflow/third_party/codegen.BUILD
|
||||
tensorflow/third_party/enum34.BUILD
|
||||
tensorflow/third_party/kafka/config.patch
|
||||
tensorflow/third_party/kafka/BUILD
|
||||
tensorflow/third_party/pcre.BUILD
|
||||
tensorflow/third_party/mpi/BUILD
|
||||
tensorflow/third_party/mpi/.gitignore
|
||||
tensorflow/third_party/clang_toolchain/BUILD
|
||||
tensorflow/third_party/clang_toolchain/download_clang.bzl
|
||||
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
|
||||
tensorflow/third_party/tflite_ovic_testdata.BUILD
|
||||
tensorflow/third_party/repo.bzl
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/py/python_configure.bzl
|
||||
tensorflow/third_party/py/BUILD.tpl
|
||||
tensorflow/third_party/py/BUILD
|
||||
tensorflow/third_party/py/numpy/BUILD
|
||||
tensorflow/third_party/double_conversion.BUILD
|
||||
tensorflow/third_party/six.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_float.BUILD
|
||||
tensorflow/third_party/repo.bzl
|
||||
tensorflow/third_party/codegen.BUILD
|
||||
tensorflow/third_party/cub.BUILD
|
||||
tensorflow/third_party/jsoncpp.BUILD
|
||||
tensorflow/third_party/tflite_ovic_testdata.BUILD
|
||||
tensorflow/third_party/__init__.py
|
||||
tensorflow/third_party/libxsmm.BUILD
|
||||
tensorflow/third_party/zlib.BUILD
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/nanopb.BUILD
|
||||
tensorflow/third_party/pybind11.BUILD
|
||||
tensorflow/third_party/android/android.bzl.tpl
|
||||
tensorflow/third_party/android/BUILD
|
||||
tensorflow/third_party/android/android_configure.BUILD.tpl
|
||||
tensorflow/third_party/android/android_configure.bzl
|
||||
tensorflow/third_party/tflite_mobilenet_float.BUILD
|
||||
tensorflow/third_party/sqlite.BUILD
|
||||
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
||||
tensorflow/third_party/tensorrt/LICENSE
|
||||
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
|
||||
tensorflow/third_party/tensorrt/BUILD.tpl
|
||||
tensorflow/third_party/tensorrt/BUILD
|
||||
tensorflow/third_party/gast.BUILD
|
||||
tensorflow/third_party/mpi_collectives/BUILD
|
||||
tensorflow/third_party/libxsmm.BUILD
|
||||
tensorflow/third_party/eigen.BUILD
|
||||
tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/eigen3/LICENSE
|
||||
tensorflow/third_party/eigen3/gpu_packet_math.patch
|
||||
tensorflow/third_party/eigen3/BUILD
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
|
||||
tensorflow/third_party/eigen3/Eigen/QR
|
||||
tensorflow/third_party/eigen3/Eigen/SVD
|
||||
tensorflow/third_party/eigen3/Eigen/LU
|
||||
tensorflow/third_party/eigen3/Eigen/Cholesky
|
||||
tensorflow/third_party/eigen3/Eigen/Eigenvalues
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/BUILD
|
||||
tensorflow/third_party/termcolor.BUILD
|
||||
tensorflow/third_party/gif.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet.BUILD
|
||||
tensorflow/third_party/__init__.py
|
||||
tensorflow/third_party/mkl/LICENSE
|
||||
tensorflow/third_party/mkl/build_defs.bzl
|
||||
tensorflow/third_party/mkl/mkl.BUILD
|
||||
tensorflow/third_party/mkl/MKL_LICENSE
|
||||
tensorflow/third_party/mkl/BUILD
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/nccl_configure.bzl
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/BUILD
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
tensorflow/third_party/snappy.BUILD
|
||||
tensorflow/third_party/python_runtime/BUILD
|
||||
tensorflow/third_party/googleapis.BUILD
|
||||
tensorflow/third_party/wrapt.BUILD
|
||||
tensorflow/third_party/boringssl/BUILD
|
||||
tensorflow/third_party/protobuf/BUILD
|
||||
tensorflow/third_party/backports_weakref.BUILD
|
||||
tensorflow/third_party/tflite_smartreply.BUILD
|
||||
tensorflow/third_party/swig.BUILD
|
||||
tensorflow/compat_template.__init__.py
|
||||
tensorflow/tools/lib_package/libtensorflow_test.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/concat_licenses.sh
|
||||
tensorflow/tools/lib_package/LibTensorFlowTest.java
|
||||
tensorflow/tools/lib_package/BUILD
|
||||
tensorflow/tools/lib_package/README.md
|
||||
tensorflow/tools/pip_package/check_load_py_test.py
|
||||
tensorflow/tools/pip_package/simple_console.py
|
||||
tensorflow/tools/pip_package/pip_smoke_test.py
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/README
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/MANIFEST.in
|
||||
tensorflow/tools/ci_build/remote/BUILD
|
||||
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
|
||||
tensorflow/tools/def_file_filter/BUILD.tpl
|
||||
tensorflow/tools/def_file_filter/BUILD
|
||||
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
|
||||
tensorflow/api_template.__init__.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/BUILD
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
|
||||
tensorflow/contrib/tpu/profiler/pip_package/README
|
||||
tensorflow/contrib/tpu/profiler/pip_package/setup.py
|
||||
tensorflow/contrib/mpi/BUILD
|
||||
tensorflow/python/autograph/core/config.py
|
||||
tensorflow/virtual_root_template_v2.__init__.py
|
||||
tensorflow/__init__.py
|
||||
tensorflow/stream_executor/build_defs.bzl
|
||||
tensorflow/api_template_v1.__init__.py
|
||||
tensorflow/compat_template_v1.__init__.py
|
||||
tensorflow/compat_template.__init__.py
|
||||
tensorflow/api_template.__init__.py
|
||||
tensorflow/__init__.py
|
||||
tensorflow/virtual_root_template_v2.__init__.py
|
||||
tensorflow/virtual_root_template_v1.__init__.py
|
@ -18,7 +18,11 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from tensorflow.core.protobuf import cluster_pb2
|
||||
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
|
||||
@ -50,6 +54,7 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
|
||||
"""Validates `cluster_spec`.
|
||||
|
||||
It checks:
|
||||
0) None of `cluster_spec`, `task_type`, and `task_id` is `None`.
|
||||
1) task type is one of "chief", "worker" or "evaluator".
|
||||
2) whether there is such a task type as `task_type` in the `cluster_spec`.
|
||||
3) whether there is at most one "chief" job.
|
||||
@ -64,6 +69,10 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
|
||||
Throws:
|
||||
ValueError: if `cluster_spec` fails any check.
|
||||
"""
|
||||
if cluster_spec is None or task_type is None or task_id is None:
|
||||
raise ValueError(
|
||||
"None of `cluster_spec`, `task_type`, and `task_id` should be `None`.")
|
||||
|
||||
cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
|
||||
if task_type not in ("chief", "worker", "evaluator", "ps"):
|
||||
raise ValueError(
|
||||
@ -84,13 +93,16 @@ def _validate_cluster_spec(cluster_spec, task_type, task_id):
|
||||
"The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
|
||||
|
||||
|
||||
def is_chief(cluster_spec, task_type, task_id):
|
||||
def is_chief(cluster_spec=None, task_type=None, task_id=None):
|
||||
"""Returns whether the given task is chief in the cluster.
|
||||
|
||||
Since there is at most one evaluator and the evaluator itself should be
|
||||
independent of the training cluster, the evaluator job is also a chief job on
|
||||
its own.
|
||||
|
||||
If this is currently running under a `_WorkerContext` of distribute
|
||||
coordinator, the arguments can be omitted as the result is already available.
|
||||
|
||||
Args:
|
||||
cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
|
||||
cluster configurations.
|
||||
@ -104,6 +116,10 @@ def is_chief(cluster_spec, task_type, task_id):
|
||||
ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
|
||||
the maximum id of the `task_type`.
|
||||
"""
|
||||
if has_worker_context():
|
||||
# If a worker context exists, use the value provided by it.
|
||||
return dc_context.get_current_worker_context().is_chief
|
||||
|
||||
_validate_cluster_spec(cluster_spec, task_type, task_id)
|
||||
cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
|
||||
|
||||
@ -208,3 +224,48 @@ def id_in_cluster(cluster_spec, task_type, task_id):
|
||||
|
||||
# We currently don't assign ids to other tasks.
|
||||
raise ValueError("There is no id for task_type %r" % task_type)
|
||||
|
||||
|
||||
def in_multi_worker_mode():
|
||||
"""Whether the program is operating in Multi-Worker setting."""
|
||||
# TODO(rchao): Consider a warning if user uses multiple `model` method
|
||||
# calls in multi-worker setting.
|
||||
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
|
||||
cluster_spec = server_lib.ClusterSpec(tf_config.get("cluster", {}))
|
||||
return tf_config and "master" not in cluster_spec.jobs
|
||||
|
||||
|
||||
def should_save_checkpoint():
|
||||
"""Returns whether the current worker should save checkpoints.
|
||||
|
||||
In multi-worker training, if saving checkpoint is requested by user, or needed
|
||||
for fault-tolerance, the cluster should save checkpoint but not necessarily
|
||||
every worker in the cluster should.
|
||||
|
||||
Returns:
|
||||
Whether this particular worker in the cluster should save checkpoints.
|
||||
"""
|
||||
return dc_context.get_current_worker_context().should_checkpoint
|
||||
|
||||
|
||||
def should_load_checkpoint():
|
||||
"""Returns whether the current worker should load checkpoints.
|
||||
|
||||
In multi-worker training, if loading checkpoint is requested by user, or
|
||||
needed for fault-tolerance, the cluster should load checkpoint but not
|
||||
necessarily every worker in the cluster should.
|
||||
|
||||
Returns:
|
||||
Whether this particular worker in the cluster should load checkpoints.
|
||||
"""
|
||||
return dc_context.get_current_worker_context().experimental_should_init
|
||||
|
||||
|
||||
def wait_for_other_workers():
|
||||
"""Waits for other workers to reach the same call to this method."""
|
||||
return dc_context.get_current_worker_context().wait_for_other_workers()
|
||||
|
||||
|
||||
def has_worker_context():
|
||||
"""Returns whether a worker context has been entered."""
|
||||
return dc_context.get_current_worker_context() is not None
|
||||
|
@ -681,6 +681,9 @@ tf_py_test(
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
],
|
||||
tags = [
|
||||
"no_windows", # TODO(b/135556470) reenable.
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -125,6 +125,7 @@ py_library(
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/distribute:distribute_coordinator",
|
||||
"//tensorflow/python/distribute:distribute_lib",
|
||||
"//tensorflow/python/distribute:multi_worker_util",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -36,6 +36,7 @@ from tensorflow.python.client import session as session_module
|
||||
from tensorflow.python.distribute import distribute_coordinator as dc
|
||||
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import composite_tensor_utils
|
||||
from tensorflow.python.eager import function as eager_function
|
||||
@ -70,7 +71,6 @@ from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-im
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variables as variables_module
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util import tf_inspect
|
||||
@ -5649,15 +5649,6 @@ if not os.path.exists(_config_path):
|
||||
pass
|
||||
|
||||
|
||||
def in_multi_worker_mode():
|
||||
"""Whether we are operating in a Multi-Worker setting."""
|
||||
# TODO(rchao): Consider a warning if user uses multiple `model` method
|
||||
# calls in multi-worker setting.
|
||||
tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
|
||||
cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
|
||||
return tf_config and 'master' not in cluster_spec.jobs
|
||||
|
||||
|
||||
def configure_and_create_distributed_session(distribution_strategy):
|
||||
"""Configure session config and create a session with it."""
|
||||
|
||||
@ -5694,7 +5685,7 @@ def configure_and_create_distributed_session(distribution_strategy):
|
||||
|
||||
set_session(session)
|
||||
|
||||
if in_multi_worker_mode():
|
||||
if multi_worker_util.in_multi_worker_mode():
|
||||
dc.run_distribute_coordinator(
|
||||
_create_session,
|
||||
distribution_strategy,
|
||||
|
@ -33,7 +33,7 @@ import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend as K
|
||||
@ -898,7 +898,7 @@ class ModelCheckpoint(Callback):
|
||||
self.save_weights_only = True
|
||||
|
||||
def on_train_begin(self, logs=None):
|
||||
if K.in_multi_worker_mode():
|
||||
if multi_worker_util.in_multi_worker_mode():
|
||||
# pylint: disable=protected-access
|
||||
# MultiWorkerTrainingState is used to manage the training state needed
|
||||
# for preemption-recovery of a worker in multi-worker training.
|
||||
@ -914,11 +914,8 @@ class ModelCheckpoint(Callback):
|
||||
# If this is not multi worker training, restoring is not needed, or
|
||||
# restoring failed, check if it should load weights on restart.
|
||||
if self.load_weights_on_restart:
|
||||
# In multi worker training, it only should load weights on restart if
|
||||
# `experimental_should_init` is True.
|
||||
# TODO(rchao): Reference `experimental_should_init` api from a util file.
|
||||
if (not K.in_multi_worker_mode() or
|
||||
dc_context.get_current_worker_context().experimental_should_init):
|
||||
if (not multi_worker_util.in_multi_worker_mode()
|
||||
or multi_worker_util.should_load_checkpoint()):
|
||||
filepath_to_load = (
|
||||
self._get_most_recently_modified_file_matching_pattern(
|
||||
self.filepath))
|
||||
@ -934,7 +931,7 @@ class ModelCheckpoint(Callback):
|
||||
filepath_to_load, e))
|
||||
|
||||
def on_train_end(self, logs=None):
|
||||
if K.in_multi_worker_mode():
|
||||
if multi_worker_util.in_multi_worker_mode():
|
||||
# In multi-worker training, on successful exit of training, delete the
|
||||
# training state backup file that was saved for the purpose of worker
|
||||
# recovery.
|
||||
@ -958,13 +955,13 @@ class ModelCheckpoint(Callback):
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
self.epochs_since_last_save += 1
|
||||
if self.save_freq == 'epoch':
|
||||
if K.in_multi_worker_mode():
|
||||
if multi_worker_util.in_multi_worker_mode():
|
||||
# Exclude training state variables in user-requested checkpoint file.
|
||||
with self._training_state.untrack_vars():
|
||||
self._save_model(epoch=epoch, logs=logs)
|
||||
else:
|
||||
self._save_model(epoch=epoch, logs=logs)
|
||||
if K.in_multi_worker_mode():
|
||||
if multi_worker_util.in_multi_worker_mode():
|
||||
# For multi-worker training, back up the weights and current training
|
||||
# state for possible future recovery.
|
||||
# TODO(rchao): Call `back_up` at finer period such as N steps.
|
||||
@ -1016,11 +1013,8 @@ class ModelCheckpoint(Callback):
|
||||
|
||||
def _get_file_path(self, epoch, logs):
|
||||
"""Returns the file path for checkpoint."""
|
||||
# TODO(rchao): Replace dc_context reference with
|
||||
# distributed_training_utils.should_current_worker_checkpoint() once
|
||||
# distributed_training_utils.py no longer depends on callbacks.py.
|
||||
if not K.in_multi_worker_mode() or dc_context.get_current_worker_context(
|
||||
).should_checkpoint:
|
||||
if not multi_worker_util.in_multi_worker_mode(
|
||||
) or multi_worker_util.should_save_checkpoint():
|
||||
return self.filepath.format(epoch=epoch + 1, **logs)
|
||||
else:
|
||||
# If this is multi-worker training, and this worker should not
|
||||
@ -1037,8 +1031,8 @@ class ModelCheckpoint(Callback):
|
||||
# Remove the checkpoint directory in multi-worker training where this worker
|
||||
# should not checkpoint. It is a dummy directory previously saved for sync
|
||||
# distributed training.
|
||||
if K.in_multi_worker_mode(
|
||||
) and not dc_context.get_current_worker_context().should_checkpoint:
|
||||
if multi_worker_util.in_multi_worker_mode(
|
||||
) and not multi_worker_util.should_save_checkpoint():
|
||||
file_io.delete_recursively(self._temp_file_dir)
|
||||
del self._temp_file_dir
|
||||
|
||||
|
@ -55,6 +55,9 @@ py_library(
|
||||
"multi_worker_training_state.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python/distribute:multi_worker_util",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -373,8 +374,8 @@ def _wait_for_variable_initialization(session):
|
||||
def init_restore_or_wait_for_variables():
|
||||
"""Initialize or restore variables or wait for variables to be initialized."""
|
||||
session = K._get_session() # pylint: disable=protected-access
|
||||
worker_context = dc_context.get_current_worker_context()
|
||||
if not worker_context or worker_context.experimental_should_init:
|
||||
if not multi_worker_util.has_worker_context(
|
||||
) or multi_worker_util.should_load_checkpoint():
|
||||
# TODO(yuefengz): if checkpoints exist, restore from checkpoint.
|
||||
K._initialize_variables(session) # pylint: disable=protected-access
|
||||
else:
|
||||
@ -1104,7 +1105,7 @@ def filter_distributed_callbacks(callbacks_list):
|
||||
The list of `Callback` instances that should be run on this worker.
|
||||
"""
|
||||
|
||||
if not K.in_multi_worker_mode():
|
||||
if not multi_worker_util.in_multi_worker_mode():
|
||||
raise ValueError(
|
||||
'filter_distributed_callbacks() should only be called when Keras '
|
||||
'is in multi worker mode.')
|
||||
|
@ -29,9 +29,9 @@ from tensorflow.python import keras
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribute_coordinator as dc
|
||||
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
from tensorflow.python.distribute import multi_worker_test_base as test_base
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import callbacks
|
||||
from tensorflow.python.keras import testing_utils
|
||||
@ -130,7 +130,7 @@ class KerasMultiWorkerCallbackTest(test_base.IndependentWorkerTestBase,
|
||||
self.filtered_correctly = True
|
||||
|
||||
def on_train_begin(self, logs):
|
||||
if not dc_context.get_current_worker_context().is_chief:
|
||||
if not multi_worker_util.is_chief():
|
||||
# Non-chief workers shouldn't run this callback.
|
||||
self.filtered_correctly = False
|
||||
|
||||
|
@ -32,9 +32,9 @@ from tensorflow.python import keras
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribute_coordinator as dc
|
||||
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import multi_worker_test_base as test_base
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.distribute import parameter_server_strategy
|
||||
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.python.framework import ops
|
||||
@ -222,11 +222,11 @@ def _run_standalone_client(test_obj, strategy, cluster_spec):
|
||||
|
||||
# Workaround for the metrics issue (b/122928955) in async training. This
|
||||
# can only be used in standalone client mode.
|
||||
dc_context.get_current_worker_context().wait_for_other_workers()
|
||||
multi_worker_util.wait_for_other_workers()
|
||||
|
||||
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
|
||||
|
||||
dc_context.get_current_worker_context().wait_for_other_workers()
|
||||
multi_worker_util.wait_for_other_workers()
|
||||
|
||||
trained_loss, trained_acc = model.evaluate(train_ds, steps=steps)
|
||||
|
||||
|
@ -20,8 +20,7 @@ from __future__ import print_function
|
||||
import contextlib
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.keras import backend as K
|
||||
@ -76,7 +75,7 @@ class MultiWorkerTrainingState(object):
|
||||
# For those who should not checkpoint (e.g. non-chief worker in sync
|
||||
# training), create a temporary directory to write to (that will be
|
||||
# removed later).
|
||||
if not dc_context.get_current_worker_context().should_checkpoint:
|
||||
if not multi_worker_util.should_save_checkpoint():
|
||||
self._temp_dir, self._temp_filepath = self._get_temp_filepath()
|
||||
|
||||
# The epoch at which the checkpoint is saved. Used for fault-tolerance.
|
||||
@ -114,7 +113,7 @@ class MultiWorkerTrainingState(object):
|
||||
# call. This is because the SyncOnReadVariable needs to be synced across
|
||||
# all the workers in order to be read, and all workers need to initiate
|
||||
# that.
|
||||
if dc_context.get_current_worker_context().should_checkpoint:
|
||||
if multi_worker_util.should_save_checkpoint():
|
||||
save_filepath = self._backup_filepath
|
||||
else:
|
||||
save_filepath = self._temp_filepath
|
||||
@ -122,7 +121,7 @@ class MultiWorkerTrainingState(object):
|
||||
# Save the weights plus CKPT_SAVED_EPOCH variable.
|
||||
self._model.save_weights(save_filepath, overwrite=True)
|
||||
|
||||
if not dc_context.get_current_worker_context().should_checkpoint:
|
||||
if not multi_worker_util.should_save_checkpoint():
|
||||
# Remove the file in multi-worker training where this worker should
|
||||
# not checkpoint. It is a dummy file previously saved for sync distributed
|
||||
# training.
|
||||
@ -136,7 +135,7 @@ class MultiWorkerTrainingState(object):
|
||||
state doesn't need to be restored, or error occurred so it can't.
|
||||
"""
|
||||
self._assert_in_multi_worker_mode()
|
||||
if not dc_context.get_current_worker_context().experimental_should_init:
|
||||
if not multi_worker_util.should_load_checkpoint():
|
||||
# For multi-worker training, it should not restore a model in certain
|
||||
# worker setting (e.g. non-chief worker in ParameterServerStrategy).
|
||||
return False
|
||||
@ -159,7 +158,7 @@ class MultiWorkerTrainingState(object):
|
||||
"""
|
||||
self._assert_in_multi_worker_mode()
|
||||
tracking.AutoTrackable.__delattr__(self._model, CKPT_SAVED_EPOCH)
|
||||
if dc_context.get_current_worker_context().should_checkpoint:
|
||||
if multi_worker_util.should_save_checkpoint():
|
||||
_remove_dir(self._backup_dir)
|
||||
else:
|
||||
assert not file_io.file_exists(self._temp_dir)
|
||||
@ -218,7 +217,7 @@ class MultiWorkerTrainingState(object):
|
||||
return temp_dir, os.path.join(temp_dir, 'temp_training_state')
|
||||
|
||||
def _assert_in_multi_worker_mode(self):
|
||||
if not K.in_multi_worker_mode():
|
||||
if not multi_worker_util.in_multi_worker_mode():
|
||||
raise ValueError('MultiWorkerTrainingState is only supposed to be used '
|
||||
'in multi-worker training. This indicates some error '
|
||||
'that needs to be fixed. Please submit a bug issue to '
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.python import tf2
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import monitoring
|
||||
@ -435,7 +436,7 @@ class Model(network.Network):
|
||||
"""Select training loop for fit/eval/predict based on the inputs."""
|
||||
# Case 1: distribution strategy.
|
||||
if self._distribution_strategy:
|
||||
if K.in_multi_worker_mode():
|
||||
if multi_worker_util.in_multi_worker_mode():
|
||||
return training_distributed.DistributionMultiWorkerTrainingLoop()
|
||||
else:
|
||||
return training_distributed.DistributionSingleWorkerTrainingLoop()
|
||||
|
@ -406,28 +406,10 @@ class BatchNormalizationBase(Layer):
|
||||
experimental_autocast=False)
|
||||
|
||||
if self.renorm:
|
||||
# In batch renormalization we track the inference moving stddev instead
|
||||
# of the moving variance to more closely align with the paper.
|
||||
def moving_stddev_initializer(*args, **kwargs):
|
||||
return math_ops.sqrt(
|
||||
self.moving_variance_initializer(*args, **kwargs))
|
||||
|
||||
with distribution_strategy_context.get_strategy(
|
||||
).extended.colocate_vars_with(self.moving_variance):
|
||||
self.moving_stddev = self.add_weight(
|
||||
name='moving_stddev',
|
||||
shape=param_shape,
|
||||
dtype=self._param_dtype,
|
||||
initializer=moving_stddev_initializer,
|
||||
synchronization=tf_variables.VariableSynchronization.ON_READ,
|
||||
trainable=False,
|
||||
aggregation=tf_variables.VariableAggregation.MEAN,
|
||||
experimental_autocast=False)
|
||||
|
||||
# Create variables to maintain the moving mean and standard deviation.
|
||||
# These are used in training and thus are different from the moving
|
||||
# averages above. The renorm variables are colocated with moving_mean
|
||||
# and moving_stddev.
|
||||
# and moving_variance.
|
||||
# NOTE: below, the outer `with device` block causes the current device
|
||||
# stack to be cleared. The nested ones use a `lambda` to set the desired
|
||||
# device and ignore any devices that may be set by the custom getter.
|
||||
@ -450,10 +432,17 @@ class BatchNormalizationBase(Layer):
|
||||
).extended.colocate_vars_with(self.moving_mean):
|
||||
self.renorm_mean = _renorm_variable('renorm_mean', param_shape,
|
||||
self.moving_mean_initializer)
|
||||
self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
|
||||
# We initialize renorm_stddev to 0, and maintain the (0-initialized)
|
||||
# renorm_stddev_weight. This allows us to (1) mix the average
|
||||
# stddev with the minibatch stddev early in training, and (2) compute
|
||||
# the unbiased average stddev by dividing renorm_stddev by the weight.
|
||||
with distribution_strategy_context.get_strategy(
|
||||
).extended.colocate_vars_with(self.moving_stddev):
|
||||
self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape,
|
||||
moving_stddev_initializer)
|
||||
).extended.colocate_vars_with(self.moving_variance):
|
||||
self.renorm_stddev = _renorm_variable(
|
||||
'renorm_stddev', param_shape, self.moving_variance_initializer)
|
||||
self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
|
||||
())
|
||||
finally:
|
||||
if partitioner:
|
||||
self._scope.set_partitioner(partitioner)
|
||||
@ -472,11 +461,6 @@ class BatchNormalizationBase(Layer):
|
||||
K.zeros_like(update_delta))
|
||||
return state_ops.assign_sub(variable, update_delta, name=scope)
|
||||
|
||||
def _assign_new_value(self, variable, value):
|
||||
with K.name_scope('AssignNewValue') as scope:
|
||||
with ops.colocate_with(variable):
|
||||
return state_ops.assign(variable, value, name=scope)
|
||||
|
||||
def _fused_batch_norm(self, inputs, training):
|
||||
"""Returns the output of fused batch norm."""
|
||||
beta = self.beta if self.center else self._beta_const
|
||||
@ -534,21 +518,8 @@ class BatchNormalizationBase(Layer):
|
||||
inputs_size)
|
||||
|
||||
def variance_update():
|
||||
"""Update self.moving_variance with the most recent data point."""
|
||||
if self.renorm:
|
||||
# We apply epsilon as part of the moving_stddev to mirror the training
|
||||
# code path.
|
||||
moving_stddev = self._assign_moving_average(
|
||||
self.moving_stddev, math_ops.sqrt(variance + self.epsilon),
|
||||
momentum, inputs_size)
|
||||
return self._assign_new_value(
|
||||
self.moving_variance,
|
||||
# Apply relu in case floating point rounding causes it to go
|
||||
# negative.
|
||||
K.relu(moving_stddev * moving_stddev - self.epsilon))
|
||||
else:
|
||||
return self._assign_moving_average(self.moving_variance, variance,
|
||||
momentum, inputs_size)
|
||||
return self._assign_moving_average(self.moving_variance, variance,
|
||||
momentum, inputs_size)
|
||||
|
||||
self.add_update(mean_update)
|
||||
self.add_update(variance_update)
|
||||
@ -563,8 +534,7 @@ class BatchNormalizationBase(Layer):
|
||||
# initialized with this batch's moments.
|
||||
renorm_mean = self.renorm_mean
|
||||
# Avoid divide by zero early on in training.
|
||||
renorm_stddev = math_ops.maximum(self.renorm_stddev,
|
||||
math_ops.sqrt(self.epsilon))
|
||||
renorm_stddev = math_ops.maximum(self.renorm_stddev, self.epsilon)
|
||||
# Compute the corrections for batch renorm.
|
||||
r = stddev / renorm_stddev
|
||||
d = (mean - renorm_mean) / renorm_stddev
|
||||
@ -587,24 +557,40 @@ class BatchNormalizationBase(Layer):
|
||||
lambda: d,
|
||||
lambda: array_ops.zeros_like(d))
|
||||
|
||||
def _update_renorm_variable(var, value, inputs_size):
|
||||
def _update_renorm_variable(var, weight, value, inputs_size):
|
||||
"""Updates a moving average and weight, returns the unbiased value."""
|
||||
value = array_ops.identity(value)
|
||||
def _do_update():
|
||||
"""Updates the var, returns the updated value."""
|
||||
"""Updates the var and weight, returns their updated ratio."""
|
||||
# Update the variables without zero debiasing. The debiasing will be
|
||||
# accomplished by dividing the exponential moving average by the weight.
|
||||
# For example, after a single update, the moving average would be
|
||||
# (1-decay) * value. and the weight will be 1-decay, with their ratio
|
||||
# giving the value.
|
||||
# Make sure the weight is not updated until before r and d computation.
|
||||
with ops.control_dependencies([value]):
|
||||
weight_value = array_ops.constant(1., dtype=weight.dtype)
|
||||
new_var = self._assign_moving_average(var, value, self.renorm_momentum,
|
||||
inputs_size)
|
||||
return new_var
|
||||
new_weight = self._assign_moving_average(weight, weight_value,
|
||||
self.renorm_momentum,
|
||||
inputs_size)
|
||||
# TODO(yuefengz): the updates to var and weighted can not be batched
|
||||
# together if we fetch their updated values here. Consider calculating
|
||||
# new values and delaying the updates.
|
||||
return new_var / new_weight
|
||||
|
||||
def _fake_update():
|
||||
return array_ops.identity(var)
|
||||
return tf_utils.smart_cond(training, _do_update, _fake_update)
|
||||
|
||||
# TODO(yuefengz): colocate the operations
|
||||
update_new_mean = _update_renorm_variable(self.renorm_mean, mean,
|
||||
update_new_mean = _update_renorm_variable(self.renorm_mean,
|
||||
self.renorm_mean_weight, mean,
|
||||
inputs_size)
|
||||
update_new_stddev = _update_renorm_variable(self.renorm_stddev, stddev,
|
||||
inputs_size)
|
||||
update_new_stddev = _update_renorm_variable(self.renorm_stddev,
|
||||
self.renorm_stddev_weight,
|
||||
stddev, inputs_size)
|
||||
|
||||
# Update the inference mode moving averages with the batch value.
|
||||
with ops.control_dependencies([update_new_mean, update_new_stddev]):
|
||||
@ -761,24 +747,7 @@ class BatchNormalizationBase(Layer):
|
||||
return tf_utils.smart_cond(training, true_branch, false_branch)
|
||||
|
||||
def variance_update():
|
||||
"""Update the moving variance."""
|
||||
|
||||
def true_branch_renorm():
|
||||
# We apply epsilon as part of the moving_stddev to mirror the training
|
||||
# code path.
|
||||
moving_stddev = _do_update(self.moving_stddev,
|
||||
math_ops.sqrt(new_variance + self.epsilon))
|
||||
return self._assign_new_value(
|
||||
self.moving_variance,
|
||||
# Apply relu in case floating point rounding causes it to go
|
||||
# negative.
|
||||
K.relu(moving_stddev * moving_stddev - self.epsilon))
|
||||
|
||||
if self.renorm:
|
||||
true_branch = true_branch_renorm
|
||||
else:
|
||||
true_branch = lambda: _do_update(self.moving_variance, new_variance)
|
||||
|
||||
true_branch = lambda: _do_update(self.moving_variance, new_variance)
|
||||
false_branch = lambda: self.moving_variance
|
||||
return tf_utils.smart_cond(training, true_branch, false_branch)
|
||||
|
||||
|
@ -894,9 +894,10 @@ class BNTest(test.TestCase):
|
||||
yt = bn.apply(xt, training=training)
|
||||
|
||||
moving_mean = 0.
|
||||
moving_stddev = 1.
|
||||
moving_variance = 1.
|
||||
renorm_mean = 0.
|
||||
renorm_stddev = 1.
|
||||
renorm_weight = 0.
|
||||
with self.session(use_gpu=True) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
for _ in range(5):
|
||||
@ -911,10 +912,10 @@ class BNTest(test.TestCase):
|
||||
renorm_mean += (mean - renorm_mean) * (1. - renorm_momentum)
|
||||
renorm_stddev += (stddev - renorm_stddev) * (1. - renorm_momentum)
|
||||
moving_mean += (mean - moving_mean) * (1. - momentum)
|
||||
moving_stddev += (stddev - moving_stddev) * (1. - momentum)
|
||||
moving_variance += (variance - moving_variance) * (1. - momentum)
|
||||
|
||||
y_test = ((x - moving_mean) /
|
||||
(moving_stddev * moving_stddev)**0.5 * gamma) + beta
|
||||
y_test = ((x - moving_mean) / (moving_variance + epsilon) ** 0.5 *
|
||||
gamma) + beta
|
||||
|
||||
yt_val_train, _, _ = sess.run([yt] + bn.updates,
|
||||
feed_dict={xt: x, training: True})
|
||||
@ -924,62 +925,6 @@ class BNTest(test.TestCase):
|
||||
self.assertAllClose(y_train, yt_val_train, atol=1e-5)
|
||||
self.assertAllClose(y_test, yt_val_test, atol=1e-5)
|
||||
|
||||
def testRenormNoClippingSameMomentumGivesSameTestTrain(self):
|
||||
shape = (4, 3)
|
||||
xt = array_ops.placeholder(dtypes.float32, shape)
|
||||
momentum = 0.9
|
||||
renorm_momentum = 0.9
|
||||
gamma = 2.
|
||||
beta = 3.
|
||||
epsilon = 0.001
|
||||
bn = normalization_layers.BatchNormalization(
|
||||
axis=1,
|
||||
gamma_initializer=init_ops.constant_initializer(gamma),
|
||||
beta_initializer=init_ops.constant_initializer(beta),
|
||||
epsilon=epsilon,
|
||||
momentum=momentum,
|
||||
renorm=True,
|
||||
renorm_clipping=None,
|
||||
renorm_momentum=momentum)
|
||||
training = array_ops.placeholder(dtypes.bool)
|
||||
yt = bn.apply(xt, training=training)
|
||||
moving_mean = 0.
|
||||
moving_stddev = 1.
|
||||
renorm_mean = 0.
|
||||
renorm_stddev = 1.
|
||||
with self.session(use_gpu=True) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
for step in range(6):
|
||||
x = np.random.random(shape)
|
||||
|
||||
mean = x.mean(0)
|
||||
variance = x.var(0)
|
||||
stddev = np.sqrt(variance + epsilon)
|
||||
r = (stddev / renorm_stddev)
|
||||
d = ((mean - renorm_mean) / renorm_stddev)
|
||||
y_test = ((x - moving_mean) /
|
||||
(moving_stddev * moving_stddev)**0.5 * gamma) + beta
|
||||
y_train = ((x - mean) / stddev * r + d) * gamma + beta
|
||||
renorm_mean += (mean - renorm_mean) * (1. - renorm_momentum)
|
||||
renorm_stddev += (stddev - renorm_stddev) * (1. - renorm_momentum)
|
||||
moving_mean += (mean - moving_mean) * (1. - momentum)
|
||||
moving_stddev += (stddev - moving_stddev) * (1. - momentum)
|
||||
|
||||
# Compute test values first, before the train mode updates the moving
|
||||
# averages.
|
||||
yt_val_test, _, _ = sess.run([yt] + bn.updates,
|
||||
feed_dict={xt: x, training: False})
|
||||
yt_val_train, _, _ = sess.run([yt] + bn.updates,
|
||||
feed_dict={xt: x, training: True})
|
||||
|
||||
# Due to initialization inconsistencies, values may not be identical
|
||||
# on the first iteration (but shouldn't be different by much more than
|
||||
# epsilon). After the first iteration they should be identical.
|
||||
atol = epsilon * 1.5 if step == 0 else 1e-5
|
||||
self.assertAllClose(y_train, yt_val_train, atol=atol)
|
||||
self.assertAllClose(y_test, yt_val_test, atol=atol)
|
||||
self.assertAllClose(yt_val_train, yt_val_test, atol=atol)
|
||||
|
||||
def testAdjustment(self):
|
||||
shape = (4, 3)
|
||||
xt = array_ops.placeholder(dtypes.float32, shape)
|
||||
@ -1051,7 +996,7 @@ class BNTest(test.TestCase):
|
||||
yt = bn.apply(xt, training=training)
|
||||
|
||||
moving_mean = 0.
|
||||
moving_stddev = 1.
|
||||
moving_variance = 1.
|
||||
renorm_mean = 0.
|
||||
renorm_stddev = 1.
|
||||
with self.session(use_gpu=True) as sess:
|
||||
@ -1074,10 +1019,10 @@ class BNTest(test.TestCase):
|
||||
renorm_mean += (mean - renorm_mean) * (1. - renorm_momentum)
|
||||
renorm_stddev += (stddev - renorm_stddev) * (1. - renorm_momentum)
|
||||
moving_mean += (mean - moving_mean) * (1. - momentum)
|
||||
moving_stddev += (stddev - moving_stddev) * (1. - momentum)
|
||||
moving_variance += (variance - moving_variance) * (1. - momentum)
|
||||
|
||||
y_test = ((x - moving_mean) /
|
||||
(moving_stddev * moving_stddev)**0.5 * gamma) + beta
|
||||
y_test = ((x - moving_mean) / (moving_variance + epsilon) ** 0.5 *
|
||||
gamma) + beta
|
||||
|
||||
self.assertAllClose(y_train, yt_val_train, atol=1e-5)
|
||||
self.assertAllClose(y_test, yt_val_test, atol=1e-5)
|
||||
|
Loading…
Reference in New Issue
Block a user