Merge pull request #31178 from benbarsdell:fix-amp-while-loop-var-read

PiperOrigin-RevId: 264247407
This commit is contained in:
TensorFlower Gardener 2019-08-19 15:32:23 -07:00
commit 370ea8e1f8
3 changed files with 101 additions and 44 deletions

View File

@ -943,7 +943,7 @@ class AutoMixedPrecisionImpl {
bool IsOnSuitableGPUArch(const NodeDef& node) const;
bool ShouldProcess(const NodeDef& node) const;
bool NodeHasFP16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
bool IsIdentityAfterVariable(const NodeDef& node) const;
bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const;
void ConvertBatchNormOpsToV2();
bool SupportsFloat16(const NodeTypeId& node_type) const;
const NodeDef* GetTailOfChain(const NodeDef& node, const string& op) const;
@ -1488,10 +1488,11 @@ void AutoMixedPrecisionImpl::PropagateWhiteThroughClear(
ShouldProcess(*item.node) && IsFloat32(item) &&
SupportsFloat16(item) &&
(fp16_clearlist_.count(item.node->op())) &&
// We don't propagate (backwards) through Identity nodes when
// they immediately follow Variable nodes because otherwise it
// breaks TensorBoard visualization.
!IsIdentityAfterVariable(*item.node));
// We don't propagate (backwards) through nodes that read
// Variables because it can break the behavior of TensorBoard
// visualization and/or (in the case of Enter nodes) the model
// itself. This is only a problem for non-resource variables.
!NodeImplicitlyReadsNonResourceVariable(*item.node));
}),
DfsTypeCallbacks::PreOrder([&](int idx) {
clear_prop_set.insert(idx);
@ -1641,13 +1642,17 @@ void AutoMixedPrecisionImpl::ForceColorMatchBetweenDataStructureOps(
}
}
bool AutoMixedPrecisionImpl::IsIdentityAfterVariable(
bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable(
const NodeDef& node) const {
if (node.op() == "Identity") {
if (node.op() == "Identity" || node.op() == "Enter") {
GraphView::InputPort node_input(&node, 0);
MutableGraphView::OutputPort prev_output =
graph_view_.GetRegularFanin(node_input);
if (prev_output.node && IsVariable(*prev_output.node)) {
const NodeDef* input = prev_output.node;
if (input && ((node.op() == "Identity" && (input->op() == "Variable" ||
input->op() == "VariableV2")) ||
(node.op() == "Enter" &&
NodeImplicitlyReadsNonResourceVariable(*input)))) {
return true;
}
}
@ -1753,7 +1758,7 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
added_cast_node = graph_view_.AddNode(
BuildCastNode(src, to_fp16, src.node->device()));
if (to_fp16 && !IsConstant(*node) && !IsVariable(*node) &&
!IsIdentityAfterVariable(*node)) {
!NodeImplicitlyReadsNonResourceVariable(*node)) {
++num_nonvar_casts_to_fp16;
}
}

View File

@ -16,13 +16,13 @@ tensorflow/third_party/android/BUILD
tensorflow/third_party/android/android.bzl.tpl
tensorflow/third_party/android/android_configure.BUILD.tpl
tensorflow/third_party/android/android_configure.bzl
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
tensorflow/third_party/astor.BUILD
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
tensorflow/third_party/backports_weakref.BUILD
tensorflow/third_party/boringssl/BUILD
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/clang_toolchain/BUILD
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/codegen.BUILD
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/common.bzl
@ -31,26 +31,26 @@ tensorflow/third_party/curl.BUILD
tensorflow/third_party/cython.BUILD
tensorflow/third_party/double_conversion.BUILD
tensorflow/third_party/eigen.BUILD
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/eigen3/Eigen/Cholesky
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/eigen3/Eigen/Eigenvalues
tensorflow/third_party/eigen3/Eigen/LU
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/eigen3/LICENSE
tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
@ -61,22 +61,22 @@ tensorflow/third_party/fft2d/LICENSE
tensorflow/third_party/fft2d/fft.h
tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/fft2d/fft2d.h
tensorflow/third_party/gast.BUILD
tensorflow/third_party/functools32.BUILD
tensorflow/third_party/gast.BUILD
tensorflow/third_party/gif.BUILD
tensorflow/third_party/git/BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/BUILD
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/gpus/BUILD
tensorflow/third_party/gpus/crosstool/BUILD
tensorflow/third_party/gpus/crosstool/BUILD.tpl
tensorflow/third_party/gpus/crosstool/LICENSE
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
tensorflow/third_party/gpus/cuda/BUILD
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
tensorflow/third_party/gpus/cuda/LICENSE
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
@ -85,8 +85,8 @@ tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/find_cuda_config.py
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/grpc/BUILD
tensorflow/third_party/icu/udata.patch
@ -103,38 +103,38 @@ tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/mkl/BUILD
tensorflow/third_party/mkl/LICENSE
tensorflow/third_party/mkl/MKL_LICENSE
tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/mkl_dnn/LICENSE
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
tensorflow/third_party/mpi/.gitignore
tensorflow/third_party/mpi/BUILD
tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/archive.patch
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/ngraph/tbb.BUILD
tensorflow/third_party/opt_einsum.BUILD
tensorflow/third_party/png.BUILD
tensorflow/third_party/pcre.BUILD
tensorflow/third_party/png.BUILD
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/python_configure.bzl
tensorflow/third_party/pybind11.BUILD
tensorflow/third_party/python_runtime/BUILD
@ -144,11 +144,11 @@ tensorflow/third_party/snappy.BUILD
tensorflow/third_party/sqlite.BUILD
tensorflow/third_party/swig.BUILD
tensorflow/third_party/sycl/crosstool/BUILD
tensorflow/third_party/systemlibs/BUILD
tensorflow/third_party/systemlibs/BUILD.tpl
tensorflow/third_party/systemlibs/BUILD
tensorflow/third_party/systemlibs/absl_py.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/astor.BUILD
tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
@ -158,19 +158,19 @@ tensorflow/third_party/systemlibs/double_conversion.BUILD
tensorflow/third_party/systemlibs/gast.BUILD
tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/opt_einsum.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/sqlite.BUILD
tensorflow/third_party/systemlibs/swig.BUILD
@ -196,8 +196,8 @@ tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/cpus/arm/BUILD
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
tensorflow/third_party/toolchains/cpus/py/BUILD
tensorflow/third_party/toolchains/cpus/py3/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD
@ -209,26 +209,26 @@ tensorflow/third_party/toolchains/preconfig/centos6/gcc7/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.1/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.1/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.1/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/centos6/py/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/py3/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
tensorflow/third_party/toolchains/preconfig/generate/BUILD
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
@ -247,8 +247,8 @@ tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/BUILD.tpl
@ -264,8 +264,8 @@ tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/README.md
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/pip_package/BUILD
@ -273,8 +273,8 @@ tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/pip_package/README
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/virtual_root_template_v1.__init__.py

View File

@ -19,18 +19,22 @@ from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.layers import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
@ -40,7 +44,9 @@ from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
@ -616,6 +622,52 @@ class AutoMixedPrecisionTest(test.TestCase):
self._assert_output_fp16(node_map, 'MatMul')
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
@test_util.run_deprecated_v1
def test_ingraph_train_loop(self):
"""Tests a graph containing a while loop around a training update.
This requires the grappler pass to take special care with its handling of
Enter ops that appear in front of reads from non-resource variables. See
the use of NodeImplicitlyReadsVariable in auto_mixed_precision.cc.
"""
if tf2.enabled():
# This test tests non-resource variables, which are only used in TF1.
self.skipTest('TensorFlow 1 required')
if test.is_gpu_available(cuda_only=True):
random_seed.set_random_seed(1234)
np.random.seed(1234)
num_iter, bs, nchan, nclass = 100, 64, 32, 100
data = np.random.normal(size=(bs * num_iter, nchan)).astype(np.float32)
labels = np.random.randint(nclass, size=(bs * num_iter,))
ds = dataset_ops.Dataset.from_tensor_slices((data, labels))
ds = ds.batch(bs).prefetch(3)
it = ds.make_one_shot_iterator()
def body(_, i):
i += 1
x, yt = it.get_next()
dense = layers.Dense(nclass)
y = dense(x)
loss = losses.sparse_softmax_cross_entropy(yt, y)
opt = adam.AdamOptimizer()
train_op = opt.minimize(loss, var_list=dense.trainable_weights)
with ops.control_dependencies([train_op]):
loss = array_ops.identity(loss)
return loss, i
begin, end = constant_op.constant(0), constant_op.constant(num_iter)
loss, _ = control_flow_ops.while_loop(
lambda loss, i: math_ops.less(i, end), body, [0.0, begin])
output_val_ref, output_val, cost_graph = self._run(loss)
node_map = _build_node_map(cost_graph.node)
self._assert_output_fp16(node_map, 'while/dense/MatMul')
self._assert_output_fp16(
node_map, 'while/gradients/while/dense/MatMul_grad/MatMul_1')
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
if __name__ == '__main__':
test.main()