diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc index ae7e203ff1d..e896e0f98ee 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc @@ -943,7 +943,7 @@ class AutoMixedPrecisionImpl { bool IsOnSuitableGPUArch(const NodeDef& node) const; bool ShouldProcess(const NodeDef& node) const; bool NodeHasFP16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const; - bool IsIdentityAfterVariable(const NodeDef& node) const; + bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const; void ConvertBatchNormOpsToV2(); bool SupportsFloat16(const NodeTypeId& node_type) const; const NodeDef* GetTailOfChain(const NodeDef& node, const string& op) const; @@ -1488,10 +1488,11 @@ void AutoMixedPrecisionImpl::PropagateWhiteThroughClear( ShouldProcess(*item.node) && IsFloat32(item) && SupportsFloat16(item) && (fp16_clearlist_.count(item.node->op())) && - // We don't propagate (backwards) through Identity nodes when - // they immediately follow Variable nodes because otherwise it - // breaks TensorBoard visualization. - !IsIdentityAfterVariable(*item.node)); + // We don't propagate (backwards) through nodes that read + // Variables because it can break the behavior of TensorBoard + // visualization and/or (in the case of Enter nodes) the model + // itself. This is only a problem for non-resource variables. 
+ !NodeImplicitlyReadsNonResourceVariable(*item.node)); }), DfsTypeCallbacks::PreOrder([&](int idx) { clear_prop_set.insert(idx); @@ -1641,13 +1642,17 @@ void AutoMixedPrecisionImpl::ForceColorMatchBetweenDataStructureOps( } } -bool AutoMixedPrecisionImpl::IsIdentityAfterVariable( +bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable( const NodeDef& node) const { - if (node.op() == "Identity") { + if (node.op() == "Identity" || node.op() == "Enter") { GraphView::InputPort node_input(&node, 0); MutableGraphView::OutputPort prev_output = graph_view_.GetRegularFanin(node_input); - if (prev_output.node && IsVariable(*prev_output.node)) { + const NodeDef* input = prev_output.node; + if (input && ((node.op() == "Identity" && (input->op() == "Variable" || + input->op() == "VariableV2")) || + (node.op() == "Enter" && + NodeImplicitlyReadsNonResourceVariable(*input)))) { return true; } } @@ -1753,7 +1758,7 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts( added_cast_node = graph_view_.AddNode( BuildCastNode(src, to_fp16, src.node->device())); if (to_fp16 && !IsConstant(*node) && !IsVariable(*node) && - !IsIdentityAfterVariable(*node)) { + !NodeImplicitlyReadsNonResourceVariable(*node)) { ++num_nonvar_casts_to_fp16; } } diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 6ce7d96b064..75b8ed97226 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -16,13 +16,13 @@ tensorflow/third_party/android/BUILD tensorflow/third_party/android/android.bzl.tpl tensorflow/third_party/android/android_configure.BUILD.tpl tensorflow/third_party/android/android_configure.bzl -tensorflow/third_party/arm_neon_2_x86_sse.BUILD tensorflow/third_party/astor.BUILD +tensorflow/third_party/arm_neon_2_x86_sse.BUILD tensorflow/third_party/backports_weakref.BUILD tensorflow/third_party/boringssl/BUILD -tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl tensorflow/third_party/clang_toolchain/BUILD 
tensorflow/third_party/clang_toolchain/download_clang.bzl +tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl tensorflow/third_party/codegen.BUILD tensorflow/third_party/com_google_absl.BUILD tensorflow/third_party/common.bzl @@ -31,26 +31,26 @@ tensorflow/third_party/curl.BUILD tensorflow/third_party/cython.BUILD tensorflow/third_party/double_conversion.BUILD tensorflow/third_party/eigen.BUILD -tensorflow/third_party/eigen3/Eigen/Core +tensorflow/third_party/eigen3/BUILD tensorflow/third_party/eigen3/Eigen/Cholesky +tensorflow/third_party/eigen3/Eigen/Core tensorflow/third_party/eigen3/Eigen/Eigenvalues tensorflow/third_party/eigen3/Eigen/LU tensorflow/third_party/eigen3/Eigen/QR tensorflow/third_party/eigen3/Eigen/SVD -tensorflow/third_party/eigen3/BUILD tensorflow/third_party/eigen3/LICENSE tensorflow/third_party/eigen3/gpu_packet_math.patch tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h 
+tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions @@ -61,22 +61,22 @@ tensorflow/third_party/fft2d/LICENSE tensorflow/third_party/fft2d/fft.h tensorflow/third_party/fft2d/fft2d.BUILD tensorflow/third_party/fft2d/fft2d.h -tensorflow/third_party/gast.BUILD tensorflow/third_party/functools32.BUILD +tensorflow/third_party/gast.BUILD tensorflow/third_party/gif.BUILD -tensorflow/third_party/git/BUILD tensorflow/third_party/git/BUILD.tpl +tensorflow/third_party/git/BUILD tensorflow/third_party/git/git_configure.bzl tensorflow/third_party/googleapis.BUILD tensorflow/third_party/gpus/BUILD tensorflow/third_party/gpus/crosstool/BUILD tensorflow/third_party/gpus/crosstool/BUILD.tpl tensorflow/third_party/gpus/crosstool/LICENSE -tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl -tensorflow/third_party/gpus/cuda/BUILD tensorflow/third_party/gpus/cuda/BUILD.tpl +tensorflow/third_party/gpus/cuda/BUILD tensorflow/third_party/gpus/cuda/BUILD.windows.tpl tensorflow/third_party/gpus/cuda/LICENSE tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl @@ -85,8 +85,8 @@ tensorflow/third_party/gpus/cuda_configure.bzl tensorflow/third_party/gpus/find_cuda_config.py tensorflow/third_party/gpus/rocm/BUILD tensorflow/third_party/gpus/rocm/BUILD.tpl -tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl tensorflow/third_party/gpus/rocm/rocm_config.h.tpl 
+tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl tensorflow/third_party/gpus/rocm_configure.bzl tensorflow/third_party/grpc/BUILD tensorflow/third_party/icu/udata.patch @@ -103,38 +103,38 @@ tensorflow/third_party/lmdb.BUILD tensorflow/third_party/mkl/BUILD tensorflow/third_party/mkl/LICENSE tensorflow/third_party/mkl/MKL_LICENSE -tensorflow/third_party/mkl/mkl.BUILD tensorflow/third_party/mkl/build_defs.bzl +tensorflow/third_party/mkl/mkl.BUILD tensorflow/third_party/mkl_dnn/LICENSE tensorflow/third_party/mkl_dnn/mkldnn.BUILD tensorflow/third_party/mpi/.gitignore tensorflow/third_party/mpi/BUILD tensorflow/third_party/mpi_collectives/BUILD tensorflow/third_party/nanopb.BUILD -tensorflow/third_party/nccl/LICENSE tensorflow/third_party/nccl/BUILD tensorflow/third_party/nccl/archive.BUILD -tensorflow/third_party/nccl/build_defs.bzl.tpl +tensorflow/third_party/nccl/LICENSE tensorflow/third_party/nccl/archive.patch +tensorflow/third_party/nccl/build_defs.bzl.tpl tensorflow/third_party/nccl/nccl_configure.bzl tensorflow/third_party/nccl/system.BUILD.tpl -tensorflow/third_party/ngraph/LICENSE tensorflow/third_party/ngraph/BUILD +tensorflow/third_party/ngraph/LICENSE tensorflow/third_party/ngraph/NGRAPH_LICENSE -tensorflow/third_party/ngraph/build_defs.bzl tensorflow/third_party/ngraph/ngraph.BUILD +tensorflow/third_party/ngraph/build_defs.bzl tensorflow/third_party/ngraph/ngraph_tf.BUILD tensorflow/third_party/ngraph/nlohmann_json.BUILD tensorflow/third_party/ngraph/tbb.BUILD tensorflow/third_party/opt_einsum.BUILD -tensorflow/third_party/png.BUILD tensorflow/third_party/pcre.BUILD +tensorflow/third_party/png.BUILD tensorflow/third_party/png_fix_rpi.patch tensorflow/third_party/pprof.BUILD tensorflow/third_party/protobuf/BUILD +tensorflow/third_party/py/BUILD.tpl tensorflow/third_party/py/BUILD tensorflow/third_party/py/numpy/BUILD -tensorflow/third_party/py/BUILD.tpl tensorflow/third_party/py/python_configure.bzl tensorflow/third_party/pybind11.BUILD 
tensorflow/third_party/python_runtime/BUILD @@ -144,11 +144,11 @@ tensorflow/third_party/snappy.BUILD tensorflow/third_party/sqlite.BUILD tensorflow/third_party/swig.BUILD tensorflow/third_party/sycl/crosstool/BUILD -tensorflow/third_party/systemlibs/BUILD tensorflow/third_party/systemlibs/BUILD.tpl +tensorflow/third_party/systemlibs/BUILD tensorflow/third_party/systemlibs/absl_py.BUILD -tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD +tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD tensorflow/third_party/systemlibs/astor.BUILD tensorflow/third_party/systemlibs/boringssl.BUILD tensorflow/third_party/systemlibs/build_defs.bzl.tpl @@ -158,19 +158,19 @@ tensorflow/third_party/systemlibs/double_conversion.BUILD tensorflow/third_party/systemlibs/gast.BUILD tensorflow/third_party/systemlibs/gif.BUILD tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD -tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD tensorflow/third_party/systemlibs/googleapis.BUILD +tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD tensorflow/third_party/systemlibs/grpc.BUILD tensorflow/third_party/systemlibs/jsoncpp.BUILD tensorflow/third_party/systemlibs/lmdb.BUILD tensorflow/third_party/systemlibs/nsync.BUILD tensorflow/third_party/systemlibs/opt_einsum.BUILD -tensorflow/third_party/systemlibs/png.BUILD tensorflow/third_party/systemlibs/pcre.BUILD +tensorflow/third_party/systemlibs/png.BUILD tensorflow/third_party/systemlibs/protobuf.BUILD tensorflow/third_party/systemlibs/protobuf.bzl -tensorflow/third_party/systemlibs/re2.BUILD tensorflow/third_party/systemlibs/six.BUILD +tensorflow/third_party/systemlibs/re2.BUILD tensorflow/third_party/systemlibs/snappy.BUILD tensorflow/third_party/systemlibs/sqlite.BUILD tensorflow/third_party/systemlibs/swig.BUILD @@ -196,8 +196,8 @@ tensorflow/third_party/toolchains/clang6/README.md 
tensorflow/third_party/toolchains/clang6/clang.BUILD tensorflow/third_party/toolchains/clang6/repo.bzl tensorflow/third_party/toolchains/cpus/arm/BUILD -tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl +tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl tensorflow/third_party/toolchains/cpus/py/BUILD tensorflow/third_party/toolchains/cpus/py3/BUILD tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD @@ -209,26 +209,26 @@ tensorflow/third_party/toolchains/preconfig/centos6/gcc7/cc_toolchain_config.bzl tensorflow/third_party/toolchains/preconfig/centos6/gcc7/dummy_toolchain.bzl tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/BUILD tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/cc_toolchain_config.bzl -tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.1/cc_toolchain_config.bzl tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.1/BUILD +tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.1/cc_toolchain_config.bzl tensorflow/third_party/toolchains/preconfig/centos6/py/BUILD tensorflow/third_party/toolchains/preconfig/centos6/py3/BUILD tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl -tensorflow/third_party/toolchains/preconfig/generate/archives.bzl tensorflow/third_party/toolchains/preconfig/generate/BUILD -tensorflow/third_party/toolchains/preconfig/generate/generate.bzl +tensorflow/third_party/toolchains/preconfig/generate/archives.bzl tensorflow/third_party/toolchains/preconfig/generate/containers.bzl +tensorflow/third_party/toolchains/preconfig/generate/generate.bzl tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD 
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl -tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl +tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl -tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl +tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl @@ -247,8 +247,8 @@ tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/build_defs.bzl -tensorflow/third_party/toolchains/preconfig/win_1803/BUILD tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD +tensorflow/third_party/toolchains/preconfig/win_1803/BUILD tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD tensorflow/third_party/toolchains/remote/BUILD tensorflow/third_party/toolchains/remote/BUILD.tpl @@ -264,8 +264,8 @@ tensorflow/tools/def_file_filter/def_file_filter_configure.bzl tensorflow/tools/lib_package/BUILD tensorflow/tools/lib_package/LibTensorFlowTest.java tensorflow/tools/lib_package/README.md -tensorflow/tools/lib_package/concat_licenses.sh 
tensorflow/tools/lib_package/libtensorflow_java_test.sh +tensorflow/tools/lib_package/concat_licenses.sh tensorflow/tools/lib_package/libtensorflow_test.c tensorflow/tools/lib_package/libtensorflow_test.sh tensorflow/tools/pip_package/BUILD @@ -273,8 +273,8 @@ tensorflow/tools/pip_package/MANIFEST.in tensorflow/tools/pip_package/README tensorflow/tools/pip_package/build_pip_package.sh tensorflow/tools/pip_package/check_load_py_test.py -tensorflow/tools/pip_package/pip_smoke_test.py tensorflow/tools/pip_package/setup.py +tensorflow/tools/pip_package/pip_smoke_test.py tensorflow/tools/pip_package/simple_console.py tensorflow/tools/pip_package/simple_console_for_windows.py tensorflow/virtual_root_template_v1.__init__.py diff --git a/tensorflow/python/grappler/auto_mixed_precision_test.py b/tensorflow/python/grappler/auto_mixed_precision_test.py index dc020d76b5e..84ebcf2b882 100644 --- a/tensorflow/python/grappler/auto_mixed_precision_test.py +++ b/tensorflow/python/grappler/auto_mixed_precision_test.py @@ -19,18 +19,22 @@ from __future__ import division from __future__ import print_function import os +import numpy as np from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 - +from tensorflow.python import tf2 from tensorflow.python.client import session from tensorflow.python.compat import compat +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function +from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util +from tensorflow.python.layers import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops @@ -40,7 +44,9 @@ from tensorflow.python.ops import 
nn_impl from tensorflow.python.ops import random_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test +from tensorflow.python.training import adam from tensorflow.python.training import gradient_descent @@ -616,6 +622,52 @@ class AutoMixedPrecisionTest(test.TestCase): self._assert_output_fp16(node_map, 'MatMul') self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) + @test_util.run_deprecated_v1 + def test_ingraph_train_loop(self): + """Tests a graph containing a while loop around a training update. + + This requires the grappler pass to take special care with its handling of + Enter ops that appear in front of reads from non-resource variables. See + the use of NodeImplicitlyReadsNonResourceVariable in auto_mixed_precision.cc. + """ + if tf2.enabled(): + # This test tests non-resource variables, which are only used in TF1. + self.skipTest('TensorFlow 1 required') + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(1234) + np.random.seed(1234) + num_iter, bs, nchan, nclass = 100, 64, 32, 100 + + data = np.random.normal(size=(bs * num_iter, nchan)).astype(np.float32) + labels = np.random.randint(nclass, size=(bs * num_iter,)) + ds = dataset_ops.Dataset.from_tensor_slices((data, labels)) + ds = ds.batch(bs).prefetch(3) + it = ds.make_one_shot_iterator() + + def body(_, i): + i += 1 + x, yt = it.get_next() + dense = layers.Dense(nclass) + y = dense(x) + loss = losses.sparse_softmax_cross_entropy(yt, y) + opt = adam.AdamOptimizer() + train_op = opt.minimize(loss, var_list=dense.trainable_weights) + with ops.control_dependencies([train_op]): + loss = array_ops.identity(loss) + return loss, i + + begin, end = constant_op.constant(0), constant_op.constant(num_iter) + loss, _ = control_flow_ops.while_loop( + lambda loss, i: math_ops.less(i, end), body, [0.0, begin]) + + output_val_ref, output_val, cost_graph = 
self._run(loss) + node_map = _build_node_map(cost_graph.node) + + self._assert_output_fp16(node_map, 'while/dense/MatMul') + self._assert_output_fp16( + node_map, 'while/gradients/while/dense/MatMul_grad/MatMul_1') + self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) + if __name__ == '__main__': test.main()