Merge pull request #31178 from benbarsdell:fix-amp-while-loop-var-read
PiperOrigin-RevId: 264247407
This commit is contained in:
commit
370ea8e1f8
@ -943,7 +943,7 @@ class AutoMixedPrecisionImpl {
|
||||
bool IsOnSuitableGPUArch(const NodeDef& node) const;
|
||||
bool ShouldProcess(const NodeDef& node) const;
|
||||
bool NodeHasFP16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
|
||||
bool IsIdentityAfterVariable(const NodeDef& node) const;
|
||||
bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const;
|
||||
void ConvertBatchNormOpsToV2();
|
||||
bool SupportsFloat16(const NodeTypeId& node_type) const;
|
||||
const NodeDef* GetTailOfChain(const NodeDef& node, const string& op) const;
|
||||
@ -1488,10 +1488,11 @@ void AutoMixedPrecisionImpl::PropagateWhiteThroughClear(
|
||||
ShouldProcess(*item.node) && IsFloat32(item) &&
|
||||
SupportsFloat16(item) &&
|
||||
(fp16_clearlist_.count(item.node->op())) &&
|
||||
// We don't propagate (backwards) through Identity nodes when
|
||||
// they immediately follow Variable nodes because otherwise it
|
||||
// breaks TensorBoard visualization.
|
||||
!IsIdentityAfterVariable(*item.node));
|
||||
// We don't propagate (backwards) through nodes that read
|
||||
// Variables because it can break the behavior of TensorBoard
|
||||
// visualization and/or (in the case of Enter nodes) the model
|
||||
// itself. This is only a problem for non-resource variables.
|
||||
!NodeImplicitlyReadsNonResourceVariable(*item.node));
|
||||
}),
|
||||
DfsTypeCallbacks::PreOrder([&](int idx) {
|
||||
clear_prop_set.insert(idx);
|
||||
@ -1641,13 +1642,17 @@ void AutoMixedPrecisionImpl::ForceColorMatchBetweenDataStructureOps(
|
||||
}
|
||||
}
|
||||
|
||||
bool AutoMixedPrecisionImpl::IsIdentityAfterVariable(
|
||||
bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable(
|
||||
const NodeDef& node) const {
|
||||
if (node.op() == "Identity") {
|
||||
if (node.op() == "Identity" || node.op() == "Enter") {
|
||||
GraphView::InputPort node_input(&node, 0);
|
||||
MutableGraphView::OutputPort prev_output =
|
||||
graph_view_.GetRegularFanin(node_input);
|
||||
if (prev_output.node && IsVariable(*prev_output.node)) {
|
||||
const NodeDef* input = prev_output.node;
|
||||
if (input && ((node.op() == "Identity" && (input->op() == "Variable" ||
|
||||
input->op() == "VariableV2")) ||
|
||||
(node.op() == "Enter" &&
|
||||
NodeImplicitlyReadsNonResourceVariable(*input)))) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@ -1753,7 +1758,7 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
|
||||
added_cast_node = graph_view_.AddNode(
|
||||
BuildCastNode(src, to_fp16, src.node->device()));
|
||||
if (to_fp16 && !IsConstant(*node) && !IsVariable(*node) &&
|
||||
!IsIdentityAfterVariable(*node)) {
|
||||
!NodeImplicitlyReadsNonResourceVariable(*node)) {
|
||||
++num_nonvar_casts_to_fp16;
|
||||
}
|
||||
}
|
||||
|
@ -16,13 +16,13 @@ tensorflow/third_party/android/BUILD
|
||||
tensorflow/third_party/android/android.bzl.tpl
|
||||
tensorflow/third_party/android/android_configure.BUILD.tpl
|
||||
tensorflow/third_party/android/android_configure.bzl
|
||||
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
|
||||
tensorflow/third_party/astor.BUILD
|
||||
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
|
||||
tensorflow/third_party/backports_weakref.BUILD
|
||||
tensorflow/third_party/boringssl/BUILD
|
||||
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
|
||||
tensorflow/third_party/clang_toolchain/BUILD
|
||||
tensorflow/third_party/clang_toolchain/download_clang.bzl
|
||||
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
|
||||
tensorflow/third_party/codegen.BUILD
|
||||
tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/common.bzl
|
||||
@ -31,26 +31,26 @@ tensorflow/third_party/curl.BUILD
|
||||
tensorflow/third_party/cython.BUILD
|
||||
tensorflow/third_party/double_conversion.BUILD
|
||||
tensorflow/third_party/eigen.BUILD
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/eigen3/BUILD
|
||||
tensorflow/third_party/eigen3/Eigen/Cholesky
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/eigen3/Eigen/Eigenvalues
|
||||
tensorflow/third_party/eigen3/Eigen/LU
|
||||
tensorflow/third_party/eigen3/Eigen/QR
|
||||
tensorflow/third_party/eigen3/Eigen/SVD
|
||||
tensorflow/third_party/eigen3/BUILD
|
||||
tensorflow/third_party/eigen3/LICENSE
|
||||
tensorflow/third_party/eigen3/gpu_packet_math.patch
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
|
||||
@ -61,22 +61,22 @@ tensorflow/third_party/fft2d/LICENSE
|
||||
tensorflow/third_party/fft2d/fft.h
|
||||
tensorflow/third_party/fft2d/fft2d.BUILD
|
||||
tensorflow/third_party/fft2d/fft2d.h
|
||||
tensorflow/third_party/gast.BUILD
|
||||
tensorflow/third_party/functools32.BUILD
|
||||
tensorflow/third_party/gast.BUILD
|
||||
tensorflow/third_party/gif.BUILD
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/googleapis.BUILD
|
||||
tensorflow/third_party/gpus/BUILD
|
||||
tensorflow/third_party/gpus/crosstool/BUILD
|
||||
tensorflow/third_party/gpus/crosstool/BUILD.tpl
|
||||
tensorflow/third_party/gpus/crosstool/LICENSE
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD
|
||||
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
|
||||
tensorflow/third_party/gpus/cuda/LICENSE
|
||||
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
||||
@ -85,8 +85,8 @@ tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/find_cuda_config.py
|
||||
tensorflow/third_party/gpus/rocm/BUILD
|
||||
tensorflow/third_party/gpus/rocm/BUILD.tpl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/rocm_configure.bzl
|
||||
tensorflow/third_party/grpc/BUILD
|
||||
tensorflow/third_party/icu/udata.patch
|
||||
@ -103,38 +103,38 @@ tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/mkl/BUILD
|
||||
tensorflow/third_party/mkl/LICENSE
|
||||
tensorflow/third_party/mkl/MKL_LICENSE
|
||||
tensorflow/third_party/mkl/mkl.BUILD
|
||||
tensorflow/third_party/mkl/build_defs.bzl
|
||||
tensorflow/third_party/mkl/mkl.BUILD
|
||||
tensorflow/third_party/mkl_dnn/LICENSE
|
||||
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
|
||||
tensorflow/third_party/mpi/.gitignore
|
||||
tensorflow/third_party/mpi/BUILD
|
||||
tensorflow/third_party/mpi_collectives/BUILD
|
||||
tensorflow/third_party/nanopb.BUILD
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/BUILD
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/archive.patch
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/nccl_configure.bzl
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/BUILD
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/NGRAPH_LICENSE
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/ngraph_tf.BUILD
|
||||
tensorflow/third_party/ngraph/nlohmann_json.BUILD
|
||||
tensorflow/third_party/ngraph/tbb.BUILD
|
||||
tensorflow/third_party/opt_einsum.BUILD
|
||||
tensorflow/third_party/png.BUILD
|
||||
tensorflow/third_party/pcre.BUILD
|
||||
tensorflow/third_party/png.BUILD
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/pprof.BUILD
|
||||
tensorflow/third_party/protobuf/BUILD
|
||||
tensorflow/third_party/py/BUILD.tpl
|
||||
tensorflow/third_party/py/BUILD
|
||||
tensorflow/third_party/py/numpy/BUILD
|
||||
tensorflow/third_party/py/BUILD.tpl
|
||||
tensorflow/third_party/py/python_configure.bzl
|
||||
tensorflow/third_party/pybind11.BUILD
|
||||
tensorflow/third_party/python_runtime/BUILD
|
||||
@ -144,11 +144,11 @@ tensorflow/third_party/snappy.BUILD
|
||||
tensorflow/third_party/sqlite.BUILD
|
||||
tensorflow/third_party/swig.BUILD
|
||||
tensorflow/third_party/sycl/crosstool/BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD.tpl
|
||||
tensorflow/third_party/systemlibs/BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
|
||||
tensorflow/third_party/systemlibs/astor.BUILD
|
||||
tensorflow/third_party/systemlibs/boringssl.BUILD
|
||||
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
|
||||
@ -158,19 +158,19 @@ tensorflow/third_party/systemlibs/double_conversion.BUILD
|
||||
tensorflow/third_party/systemlibs/gast.BUILD
|
||||
tensorflow/third_party/systemlibs/gif.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/grpc.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/lmdb.BUILD
|
||||
tensorflow/third_party/systemlibs/nsync.BUILD
|
||||
tensorflow/third_party/systemlibs/opt_einsum.BUILD
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/pcre.BUILD
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.bzl
|
||||
tensorflow/third_party/systemlibs/re2.BUILD
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
tensorflow/third_party/systemlibs/re2.BUILD
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/sqlite.BUILD
|
||||
tensorflow/third_party/systemlibs/swig.BUILD
|
||||
@ -196,8 +196,8 @@ tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/toolchains/clang6/clang.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/repo.bzl
|
||||
tensorflow/third_party/toolchains/cpus/arm/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
|
||||
tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl
|
||||
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
|
||||
tensorflow/third_party/toolchains/cpus/py/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD
|
||||
@ -209,26 +209,26 @@ tensorflow/third_party/toolchains/preconfig/centos6/gcc7/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.1/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.1/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.1/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/py/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
@ -247,8 +247,8 @@ tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
|
||||
tensorflow/third_party/toolchains/remote/BUILD
|
||||
tensorflow/third_party/toolchains/remote/BUILD.tpl
|
||||
@ -264,8 +264,8 @@ tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
|
||||
tensorflow/tools/lib_package/BUILD
|
||||
tensorflow/tools/lib_package/LibTensorFlowTest.java
|
||||
tensorflow/tools/lib_package/README.md
|
||||
tensorflow/tools/lib_package/concat_licenses.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/lib_package/concat_licenses.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/libtensorflow_test.sh
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
@ -273,8 +273,8 @@ tensorflow/tools/pip_package/MANIFEST.in
|
||||
tensorflow/tools/pip_package/README
|
||||
tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/check_load_py_test.py
|
||||
tensorflow/tools/pip_package/pip_smoke_test.py
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/pip_smoke_test.py
|
||||
tensorflow/tools/pip_package/simple_console.py
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/virtual_root_template_v1.__init__.py
|
||||
|
@ -19,18 +19,22 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.layers import layers
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
@ -40,7 +44,9 @@ from tensorflow.python.ops import nn_impl
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.losses import losses
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import adam
|
||||
from tensorflow.python.training import gradient_descent
|
||||
|
||||
|
||||
@ -616,6 +622,52 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||
self._assert_output_fp16(node_map, 'MatMul')
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_ingraph_train_loop(self):
|
||||
"""Tests a graph containing a while loop around a training update.
|
||||
|
||||
This requires the grappler pass to take special care with its handling of
|
||||
Enter ops that appear in front of reads from non-resource variables. See
|
||||
the use of NodeImplicitlyReadsNonResourceVariable in auto_mixed_precision.cc.
|
||||
"""
|
||||
if tf2.enabled():
|
||||
# This test tests non-resource variables, which are only used in TF1.
|
||||
self.skipTest('TensorFlow 1 required')
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
random_seed.set_random_seed(1234)
|
||||
np.random.seed(1234)
|
||||
num_iter, bs, nchan, nclass = 100, 64, 32, 100
|
||||
|
||||
data = np.random.normal(size=(bs * num_iter, nchan)).astype(np.float32)
|
||||
labels = np.random.randint(nclass, size=(bs * num_iter,))
|
||||
ds = dataset_ops.Dataset.from_tensor_slices((data, labels))
|
||||
ds = ds.batch(bs).prefetch(3)
|
||||
it = ds.make_one_shot_iterator()
|
||||
|
||||
def body(_, i):
|
||||
i += 1
|
||||
x, yt = it.get_next()
|
||||
dense = layers.Dense(nclass)
|
||||
y = dense(x)
|
||||
loss = losses.sparse_softmax_cross_entropy(yt, y)
|
||||
opt = adam.AdamOptimizer()
|
||||
train_op = opt.minimize(loss, var_list=dense.trainable_weights)
|
||||
with ops.control_dependencies([train_op]):
|
||||
loss = array_ops.identity(loss)
|
||||
return loss, i
|
||||
|
||||
begin, end = constant_op.constant(0), constant_op.constant(num_iter)
|
||||
loss, _ = control_flow_ops.while_loop(
|
||||
lambda loss, i: math_ops.less(i, end), body, [0.0, begin])
|
||||
|
||||
output_val_ref, output_val, cost_graph = self._run(loss)
|
||||
node_map = _build_node_map(cost_graph.node)
|
||||
|
||||
self._assert_output_fp16(node_map, 'while/dense/MatMul')
|
||||
self._assert_output_fp16(
|
||||
node_map, 'while/gradients/while/dense/MatMul_grad/MatMul_1')
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user