[Grappler] Skip XlaLaunch functions when optimizing function library

Probably a fix for #30580

PiperOrigin-RevId: 264248970
This commit is contained in:
Eugene Zhulenev 2019-08-19 14:55:33 -07:00 committed by TensorFlower Gardener
parent 370ea8e1f8
commit 471b73c238
4 changed files with 64 additions and 37 deletions

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
@ -959,5 +960,7 @@ bool NeverForwardsInputs(const NodeDef& node) {
absl::StartsWith(op_name, "Quantize");
}
bool IsXlaLaunch(const NodeDef& node) { return node.op() == "XlaLaunch"; }
} // namespace grappler
} // end namespace tensorflow

View File

@ -190,6 +190,7 @@ bool IsUnpack(const NodeDef& node);
bool IsVariable(const NodeDef& node);
bool IsWhile(const NodeDef& node);
bool IsXdivy(const NodeDef& node);
bool IsXlaLaunch(const NodeDef& node);
bool IsZerosLike(const NodeDef& node);
bool IsZeta(const NodeDef& node);

View File

@ -591,13 +591,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// 2. Optimize functions reachable from the optimized graph.
FunctionLibraryDefinition flib = minimized_flib(*optimized_graph);
using NodeDefs = protobuf::RepeatedPtrField<NodeDef>;
// Find functions for which we might need to compute a gradient at runtime.
absl::flat_hash_set<string> differentiable_functions;
using NodeDefs = protobuf::RepeatedPtrField<NodeDef>;
const auto find_differentiable_functions =
[&differentiable_functions](const NodeDefs& nodes) -> void {
[&](const NodeDefs& nodes) -> void {
for (const NodeDef& node : nodes) {
if (IsSymbolicGradient(node)) {
const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
@ -613,6 +613,28 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
find_differentiable_functions(function.node_def());
}
// Find functions that are formed by XLA and will be compiled later. We do it
// by looking for a function attribute in XlaLaunch ops. Grappler rewrites
// potentially can add nodes that are not supported by XLA, so we choose to
// skip such functions when we optimize function library.
absl::flat_hash_set<string> xla_compiled_functions;
const auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void {
NameAttrList function;
for (const NodeDef& node : nodes) {
if (!IsXlaLaunch(node)) continue;
if (!GetNodeAttr(node, "function", &function).ok()) continue;
xla_compiled_functions.insert(function.name());
}
};
// XlaLaunch ops inside the main graph ...
find_xla_compiled_functions(optimized_graph->node());
// ... and inside the function library.
for (const FunctionDef& function : optimized_graph->library().function()) {
find_xla_compiled_functions(function.node_def());
}
// Optimize each function only once.
absl::flat_hash_set<string> optimized_funcs;
bool optimize_function_library =
@ -629,9 +651,10 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Skip functions that are not reachable from the optimized graph.
if (!flib.Contains(func_name)) continue;
// Skip already optimized functions.
if (optimized_funcs.find(func_name) != optimized_funcs.end()) continue;
if (optimized_funcs.contains(func_name)) continue;
// Skip functions that will be compiled by XLA.
if (xla_compiled_functions.contains(func_name)) continue;
// Skip parametrized functions (function type or body is defined only at
// function call time by caller node attributes).

View File

@ -12,17 +12,17 @@ tensorflow/python/tpu/profiler/pip_package/setup.py
tensorflow/stream_executor/build_defs.bzl
tensorflow/third_party/BUILD
tensorflow/third_party/__init__.py
tensorflow/third_party/android/BUILD
tensorflow/third_party/android/android.bzl.tpl
tensorflow/third_party/android/BUILD
tensorflow/third_party/android/android_configure.BUILD.tpl
tensorflow/third_party/android/android_configure.bzl
tensorflow/third_party/astor.BUILD
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
tensorflow/third_party/astor.BUILD
tensorflow/third_party/backports_weakref.BUILD
tensorflow/third_party/boringssl/BUILD
tensorflow/third_party/clang_toolchain/BUILD
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/codegen.BUILD
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/common.bzl
@ -30,30 +30,30 @@ tensorflow/third_party/cub.BUILD
tensorflow/third_party/curl.BUILD
tensorflow/third_party/cython.BUILD
tensorflow/third_party/double_conversion.BUILD
tensorflow/third_party/eigen.BUILD
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/eigen3/Eigen/Cholesky
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/eigen3/Eigen/Eigenvalues
tensorflow/third_party/eigen3/Eigen/LU
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/LICENSE
tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/eigen.BUILD
tensorflow/third_party/enum34.BUILD
tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/fft2d/BUILD
@ -64,10 +64,9 @@ tensorflow/third_party/fft2d/fft2d.h
tensorflow/third_party/functools32.BUILD
tensorflow/third_party/gast.BUILD
tensorflow/third_party/gif.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/gpus/BUILD
tensorflow/third_party/gpus/crosstool/BUILD
tensorflow/third_party/gpus/crosstool/BUILD.tpl
@ -82,12 +81,13 @@ tensorflow/third_party/gpus/cuda/LICENSE
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/find_cuda_config.py
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/find_cuda_config.py
tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/grpc/BUILD
tensorflow/third_party/icu/udata.patch
tensorflow/third_party/jsoncpp.BUILD
@ -96,8 +96,8 @@ tensorflow/third_party/kafka/config.patch
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/llvm/BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/mkl/BUILD
@ -112,17 +112,17 @@ tensorflow/third_party/mpi/BUILD
tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/archive.patch
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/archive.patch
tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/ngraph/tbb.BUILD
@ -132,8 +132,8 @@ tensorflow/third_party/png.BUILD
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/py/python_configure.bzl
tensorflow/third_party/pybind11.BUILD
@ -150,29 +150,29 @@ tensorflow/third_party/systemlibs/absl_py.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/astor.BUILD
tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/curl.BUILD
tensorflow/third_party/systemlibs/cython.BUILD
tensorflow/third_party/systemlibs/double_conversion.BUILD
tensorflow/third_party/systemlibs/gast.BUILD
tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/opt_einsum.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/sqlite.BUILD
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/swig.BUILD
tensorflow/third_party/systemlibs/syslibs_configure.bzl
tensorflow/third_party/systemlibs/termcolor.BUILD
@ -189,19 +189,19 @@ tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/toolchains/BUILD
tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/BUILD
tensorflow/third_party/toolchains/cpus/arm/BUILD
tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl
tensorflow/third_party/toolchains/cpus/py/BUILD
tensorflow/third_party/toolchains/cpus/py3/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.1-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.1-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/BUILD
@ -227,13 +227,13 @@ tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/dummy_toolchain.bzl
@ -264,19 +264,19 @@ tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/README.md
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/pip_package/BUILD
tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/pip_package/BUILD
tensorflow/tools/pip_package/README
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/virtual_root_template_v1.__init__.py
tensorflow/virtual_root_template_v2.__init__.py
tensorflow/virtual_root_template_v1.__init__.py
llvm/llvm/projects/google_mlir/WORKSPACE