From 471b73c238709fb796929eb412f1dab763b3f8cc Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 19 Aug 2019 14:55:33 -0700 Subject: [PATCH] [Grappler] Skip XlaLaunch functions when optimizing function library Probably a fix for #30580 PiperOrigin-RevId: 264248970 --- tensorflow/core/grappler/op_types.cc | 3 + tensorflow/core/grappler/op_types.h | 1 + .../grappler/optimizers/meta_optimizer.cc | 31 +++++++-- tensorflow/opensource_only.files | 66 +++++++++---------- 4 files changed, 64 insertions(+), 37 deletions(-) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 0c94801580f..b3d53360802 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/op_types.h" + #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/types.h" @@ -959,5 +960,7 @@ bool NeverForwardsInputs(const NodeDef& node) { absl::StartsWith(op_name, "Quantize"); } +bool IsXlaLaunch(const NodeDef& node) { return node.op() == "XlaLaunch"; } + } // namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 4dc8b31a0fc..eee368e0700 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -190,6 +190,7 @@ bool IsUnpack(const NodeDef& node); bool IsVariable(const NodeDef& node); bool IsWhile(const NodeDef& node); bool IsXdivy(const NodeDef& node); +bool IsXlaLaunch(const NodeDef& node); bool IsZerosLike(const NodeDef& node); bool IsZeta(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index f9af4d8ef5d..84b9a1b07a2 100644 --- 
a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -591,13 +591,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // 2. Optimize functions reachable from the optimized graph. FunctionLibraryDefinition flib = minimized_flib(*optimized_graph); + using NodeDefs = protobuf::RepeatedPtrField<NodeDef>; // Find functions for which we might need to compute a gradient at runtime. absl::flat_hash_set<string> differentiable_functions; - using NodeDefs = protobuf::RepeatedPtrField<NodeDef>; const auto find_differentiable_functions = - [&differentiable_functions](const NodeDefs& nodes) -> void { + [&](const NodeDefs& nodes) -> void { for (const NodeDef& node : nodes) { if (IsSymbolicGradient(node)) { const auto* f_attr = gtl::FindOrNull(node.attr(), "f"); @@ -613,6 +613,28 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, find_differentiable_functions(function.node_def()); } + // Find functions that are formed by XLA and will be compiled later. We do it + // by looking for a function attribute in XlaLaunch ops. Grappler rewrites + // potentially can add nodes that are not supported by XLA, so we choose to + // skip such functions when we optimize function library. + absl::flat_hash_set<string> xla_compiled_functions; + + const auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void { + NameAttrList function; + for (const NodeDef& node : nodes) { + if (!IsXlaLaunch(node)) continue; + if (!GetNodeAttr(node, "function", &function).ok()) continue; + xla_compiled_functions.insert(function.name()); + } + }; + + // XlaLaunch ops inside the main graph ... + find_xla_compiled_functions(optimized_graph->node()); + // ... and inside the function library. + for (const FunctionDef& function : optimized_graph->library().function()) { + find_xla_compiled_functions(function.node_def()); + } + // Optimize each function only once. 
absl::flat_hash_set<string> optimized_funcs; bool optimize_function_library = @@ -629,9 +651,10 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Skip functions that are not reachable from the optimized graph. if (!flib.Contains(func_name)) continue; - // Skip already optimized functions. - if (optimized_funcs.find(func_name) != optimized_funcs.end()) continue; + if (optimized_funcs.contains(func_name)) continue; + // Skip functions that will be compiled by XLA. + if (xla_compiled_functions.contains(func_name)) continue; // Skip parametrized functions (function type or body is defined only at // function call time by caller node attributes). diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 75b8ed97226..9185db8571d 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -12,17 +12,17 @@ tensorflow/python/tpu/profiler/pip_package/setup.py tensorflow/stream_executor/build_defs.bzl tensorflow/third_party/BUILD tensorflow/third_party/__init__.py -tensorflow/third_party/android/BUILD tensorflow/third_party/android/android.bzl.tpl +tensorflow/third_party/android/BUILD tensorflow/third_party/android/android_configure.BUILD.tpl tensorflow/third_party/android/android_configure.bzl -tensorflow/third_party/astor.BUILD tensorflow/third_party/arm_neon_2_x86_sse.BUILD +tensorflow/third_party/astor.BUILD tensorflow/third_party/backports_weakref.BUILD tensorflow/third_party/boringssl/BUILD tensorflow/third_party/clang_toolchain/BUILD -tensorflow/third_party/clang_toolchain/download_clang.bzl tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl +tensorflow/third_party/clang_toolchain/download_clang.bzl tensorflow/third_party/codegen.BUILD tensorflow/third_party/com_google_absl.BUILD tensorflow/third_party/common.bzl @@ -30,30 +30,30 @@ tensorflow/third_party/cub.BUILD tensorflow/third_party/curl.BUILD tensorflow/third_party/cython.BUILD tensorflow/third_party/double_conversion.BUILD 
-tensorflow/third_party/eigen.BUILD tensorflow/third_party/eigen3/BUILD tensorflow/third_party/eigen3/Eigen/Cholesky tensorflow/third_party/eigen3/Eigen/Core tensorflow/third_party/eigen3/Eigen/Eigenvalues tensorflow/third_party/eigen3/Eigen/LU -tensorflow/third_party/eigen3/Eigen/QR tensorflow/third_party/eigen3/Eigen/SVD +tensorflow/third_party/eigen3/Eigen/QR tensorflow/third_party/eigen3/LICENSE tensorflow/third_party/eigen3/gpu_packet_math.patch tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions +tensorflow/third_party/eigen.BUILD tensorflow/third_party/enum34.BUILD 
tensorflow/third_party/farmhash.BUILD tensorflow/third_party/fft2d/BUILD @@ -64,10 +64,9 @@ tensorflow/third_party/fft2d/fft2d.h tensorflow/third_party/functools32.BUILD tensorflow/third_party/gast.BUILD tensorflow/third_party/gif.BUILD -tensorflow/third_party/git/BUILD.tpl tensorflow/third_party/git/BUILD +tensorflow/third_party/git/BUILD.tpl tensorflow/third_party/git/git_configure.bzl -tensorflow/third_party/googleapis.BUILD tensorflow/third_party/gpus/BUILD tensorflow/third_party/gpus/crosstool/BUILD tensorflow/third_party/gpus/crosstool/BUILD.tpl @@ -82,12 +81,13 @@ tensorflow/third_party/gpus/cuda/LICENSE tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl tensorflow/third_party/gpus/cuda/cuda_config.h.tpl tensorflow/third_party/gpus/cuda_configure.bzl -tensorflow/third_party/gpus/find_cuda_config.py tensorflow/third_party/gpus/rocm/BUILD tensorflow/third_party/gpus/rocm/BUILD.tpl -tensorflow/third_party/gpus/rocm/rocm_config.h.tpl tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl +tensorflow/third_party/gpus/rocm/rocm_config.h.tpl +tensorflow/third_party/gpus/find_cuda_config.py tensorflow/third_party/gpus/rocm_configure.bzl +tensorflow/third_party/googleapis.BUILD tensorflow/third_party/grpc/BUILD tensorflow/third_party/icu/udata.patch tensorflow/third_party/jsoncpp.BUILD @@ -96,8 +96,8 @@ tensorflow/third_party/kafka/config.patch tensorflow/third_party/libxsmm.BUILD tensorflow/third_party/linenoise.BUILD tensorflow/third_party/llvm/BUILD -tensorflow/third_party/llvm/expand_cmake_vars.py tensorflow/third_party/llvm/llvm.autogenerated.BUILD +tensorflow/third_party/llvm/expand_cmake_vars.py tensorflow/third_party/llvm/llvm.bzl tensorflow/third_party/lmdb.BUILD tensorflow/third_party/mkl/BUILD @@ -112,17 +112,17 @@ tensorflow/third_party/mpi/BUILD tensorflow/third_party/mpi_collectives/BUILD tensorflow/third_party/nanopb.BUILD tensorflow/third_party/nccl/BUILD -tensorflow/third_party/nccl/archive.BUILD tensorflow/third_party/nccl/LICENSE 
-tensorflow/third_party/nccl/archive.patch +tensorflow/third_party/nccl/archive.BUILD tensorflow/third_party/nccl/build_defs.bzl.tpl +tensorflow/third_party/nccl/archive.patch tensorflow/third_party/nccl/nccl_configure.bzl tensorflow/third_party/nccl/system.BUILD.tpl tensorflow/third_party/ngraph/BUILD tensorflow/third_party/ngraph/LICENSE tensorflow/third_party/ngraph/NGRAPH_LICENSE -tensorflow/third_party/ngraph/ngraph.BUILD tensorflow/third_party/ngraph/build_defs.bzl +tensorflow/third_party/ngraph/ngraph.BUILD tensorflow/third_party/ngraph/ngraph_tf.BUILD tensorflow/third_party/ngraph/nlohmann_json.BUILD tensorflow/third_party/ngraph/tbb.BUILD @@ -132,8 +132,8 @@ tensorflow/third_party/png.BUILD tensorflow/third_party/png_fix_rpi.patch tensorflow/third_party/pprof.BUILD tensorflow/third_party/protobuf/BUILD -tensorflow/third_party/py/BUILD.tpl tensorflow/third_party/py/BUILD +tensorflow/third_party/py/BUILD.tpl tensorflow/third_party/py/numpy/BUILD tensorflow/third_party/py/python_configure.bzl tensorflow/third_party/pybind11.BUILD @@ -150,29 +150,29 @@ tensorflow/third_party/systemlibs/absl_py.BUILD tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD tensorflow/third_party/systemlibs/astor.BUILD -tensorflow/third_party/systemlibs/boringssl.BUILD tensorflow/third_party/systemlibs/build_defs.bzl.tpl +tensorflow/third_party/systemlibs/boringssl.BUILD tensorflow/third_party/systemlibs/curl.BUILD tensorflow/third_party/systemlibs/cython.BUILD tensorflow/third_party/systemlibs/double_conversion.BUILD tensorflow/third_party/systemlibs/gast.BUILD tensorflow/third_party/systemlibs/gif.BUILD tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD -tensorflow/third_party/systemlibs/googleapis.BUILD tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD +tensorflow/third_party/systemlibs/googleapis.BUILD tensorflow/third_party/systemlibs/grpc.BUILD 
-tensorflow/third_party/systemlibs/jsoncpp.BUILD tensorflow/third_party/systemlibs/lmdb.BUILD +tensorflow/third_party/systemlibs/jsoncpp.BUILD tensorflow/third_party/systemlibs/nsync.BUILD tensorflow/third_party/systemlibs/opt_einsum.BUILD tensorflow/third_party/systemlibs/pcre.BUILD tensorflow/third_party/systemlibs/png.BUILD tensorflow/third_party/systemlibs/protobuf.BUILD tensorflow/third_party/systemlibs/protobuf.bzl -tensorflow/third_party/systemlibs/six.BUILD tensorflow/third_party/systemlibs/re2.BUILD -tensorflow/third_party/systemlibs/snappy.BUILD +tensorflow/third_party/systemlibs/six.BUILD tensorflow/third_party/systemlibs/sqlite.BUILD +tensorflow/third_party/systemlibs/snappy.BUILD tensorflow/third_party/systemlibs/swig.BUILD tensorflow/third_party/systemlibs/syslibs_configure.bzl tensorflow/third_party/systemlibs/termcolor.BUILD @@ -189,19 +189,19 @@ tensorflow/third_party/tflite_mobilenet_float.BUILD tensorflow/third_party/tflite_mobilenet_quant.BUILD tensorflow/third_party/tflite_ovic_testdata.BUILD tensorflow/third_party/tflite_smartreply.BUILD -tensorflow/third_party/toolchains/BUILD -tensorflow/third_party/toolchains/clang6/BUILD tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl +tensorflow/third_party/toolchains/clang6/BUILD tensorflow/third_party/toolchains/clang6/README.md tensorflow/third_party/toolchains/clang6/clang.BUILD tensorflow/third_party/toolchains/clang6/repo.bzl +tensorflow/third_party/toolchains/BUILD tensorflow/third_party/toolchains/cpus/arm/BUILD -tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl +tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl tensorflow/third_party/toolchains/cpus/py/BUILD tensorflow/third_party/toolchains/cpus/py3/BUILD -tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/build_defs.bzl 
+tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD tensorflow/third_party/toolchains/preconfig/centos6/cuda10.1-cudnn7/cuda/BUILD tensorflow/third_party/toolchains/preconfig/centos6/cuda10.1-cudnn7/cuda/build_defs.bzl tensorflow/third_party/toolchains/preconfig/centos6/gcc7/BUILD @@ -227,13 +227,13 @@ tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl -tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD +tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl -tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/cc_toolchain_config.bzl +tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_toolchain_config.bzl tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/dummy_toolchain.bzl @@ -264,19 +264,19 @@ tensorflow/tools/def_file_filter/def_file_filter_configure.bzl tensorflow/tools/lib_package/BUILD tensorflow/tools/lib_package/LibTensorFlowTest.java tensorflow/tools/lib_package/README.md -tensorflow/tools/lib_package/libtensorflow_java_test.sh tensorflow/tools/lib_package/concat_licenses.sh 
+tensorflow/tools/lib_package/libtensorflow_java_test.sh tensorflow/tools/lib_package/libtensorflow_test.c tensorflow/tools/lib_package/libtensorflow_test.sh -tensorflow/tools/pip_package/BUILD tensorflow/tools/pip_package/MANIFEST.in +tensorflow/tools/pip_package/BUILD tensorflow/tools/pip_package/README tensorflow/tools/pip_package/build_pip_package.sh tensorflow/tools/pip_package/check_load_py_test.py -tensorflow/tools/pip_package/setup.py tensorflow/tools/pip_package/pip_smoke_test.py +tensorflow/tools/pip_package/setup.py tensorflow/tools/pip_package/simple_console.py tensorflow/tools/pip_package/simple_console_for_windows.py -tensorflow/virtual_root_template_v1.__init__.py tensorflow/virtual_root_template_v2.__init__.py +tensorflow/virtual_root_template_v1.__init__.py llvm/llvm/projects/google_mlir/WORKSPACE \ No newline at end of file