[Grappler] Skip XlaLaunch functions when optimizing function library
Probably a fix for #30580 PiperOrigin-RevId: 264248970
This commit is contained in:
parent
370ea8e1f8
commit
471b73c238
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
@ -959,5 +960,7 @@ bool NeverForwardsInputs(const NodeDef& node) {
|
||||
absl::StartsWith(op_name, "Quantize");
|
||||
}
|
||||
|
||||
bool IsXlaLaunch(const NodeDef& node) { return node.op() == "XlaLaunch"; }
|
||||
|
||||
} // namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -190,6 +190,7 @@ bool IsUnpack(const NodeDef& node);
|
||||
bool IsVariable(const NodeDef& node);
|
||||
bool IsWhile(const NodeDef& node);
|
||||
bool IsXdivy(const NodeDef& node);
|
||||
bool IsXlaLaunch(const NodeDef& node);
|
||||
bool IsZerosLike(const NodeDef& node);
|
||||
bool IsZeta(const NodeDef& node);
|
||||
|
||||
|
@ -591,13 +591,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
|
||||
// 2. Optimize functions reachable from the optimized graph.
|
||||
FunctionLibraryDefinition flib = minimized_flib(*optimized_graph);
|
||||
using NodeDefs = protobuf::RepeatedPtrField<NodeDef>;
|
||||
|
||||
// Find functions for which we might need to compute a gradient at runtime.
|
||||
absl::flat_hash_set<string> differentiable_functions;
|
||||
|
||||
using NodeDefs = protobuf::RepeatedPtrField<NodeDef>;
|
||||
const auto find_differentiable_functions =
|
||||
[&differentiable_functions](const NodeDefs& nodes) -> void {
|
||||
[&](const NodeDefs& nodes) -> void {
|
||||
for (const NodeDef& node : nodes) {
|
||||
if (IsSymbolicGradient(node)) {
|
||||
const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
|
||||
@ -613,6 +613,28 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
find_differentiable_functions(function.node_def());
|
||||
}
|
||||
|
||||
// Find functions that are formed by XLA and will be compiled later. We do it
|
||||
// by looking for a function attribute in XlaLaunch ops. Grappler rewrites
|
||||
// potentially can add nodes that are not supported by XLA, so we choose to
|
||||
// skip such functions when we optimize function library.
|
||||
absl::flat_hash_set<string> xla_compiled_functions;
|
||||
|
||||
const auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void {
|
||||
NameAttrList function;
|
||||
for (const NodeDef& node : nodes) {
|
||||
if (!IsXlaLaunch(node)) continue;
|
||||
if (!GetNodeAttr(node, "function", &function).ok()) continue;
|
||||
xla_compiled_functions.insert(function.name());
|
||||
}
|
||||
};
|
||||
|
||||
// XlaLaunch ops inside the main graph ...
|
||||
find_xla_compiled_functions(optimized_graph->node());
|
||||
// ... and inside the function library.
|
||||
for (const FunctionDef& function : optimized_graph->library().function()) {
|
||||
find_xla_compiled_functions(function.node_def());
|
||||
}
|
||||
|
||||
// Optimize each function only once.
|
||||
absl::flat_hash_set<string> optimized_funcs;
|
||||
bool optimize_function_library =
|
||||
@ -629,9 +651,10 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
|
||||
// Skip functions that are not reachable from the optimized graph.
|
||||
if (!flib.Contains(func_name)) continue;
|
||||
|
||||
// Skip already optimized functions.
|
||||
if (optimized_funcs.find(func_name) != optimized_funcs.end()) continue;
|
||||
if (optimized_funcs.contains(func_name)) continue;
|
||||
// Skip functions that will be compiled by XLA.
|
||||
if (xla_compiled_functions.contains(func_name)) continue;
|
||||
|
||||
// Skip parametrized functions (function type or body is defined only at
|
||||
// function call time by caller node attributes).
|
||||
|
@ -12,17 +12,17 @@ tensorflow/python/tpu/profiler/pip_package/setup.py
|
||||
tensorflow/stream_executor/build_defs.bzl
|
||||
tensorflow/third_party/BUILD
|
||||
tensorflow/third_party/__init__.py
|
||||
tensorflow/third_party/android/BUILD
|
||||
tensorflow/third_party/android/android.bzl.tpl
|
||||
tensorflow/third_party/android/BUILD
|
||||
tensorflow/third_party/android/android_configure.BUILD.tpl
|
||||
tensorflow/third_party/android/android_configure.bzl
|
||||
tensorflow/third_party/astor.BUILD
|
||||
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
|
||||
tensorflow/third_party/astor.BUILD
|
||||
tensorflow/third_party/backports_weakref.BUILD
|
||||
tensorflow/third_party/boringssl/BUILD
|
||||
tensorflow/third_party/clang_toolchain/BUILD
|
||||
tensorflow/third_party/clang_toolchain/download_clang.bzl
|
||||
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
|
||||
tensorflow/third_party/clang_toolchain/download_clang.bzl
|
||||
tensorflow/third_party/codegen.BUILD
|
||||
tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/common.bzl
|
||||
@ -30,30 +30,30 @@ tensorflow/third_party/cub.BUILD
|
||||
tensorflow/third_party/curl.BUILD
|
||||
tensorflow/third_party/cython.BUILD
|
||||
tensorflow/third_party/double_conversion.BUILD
|
||||
tensorflow/third_party/eigen.BUILD
|
||||
tensorflow/third_party/eigen3/BUILD
|
||||
tensorflow/third_party/eigen3/Eigen/Cholesky
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/eigen3/Eigen/Eigenvalues
|
||||
tensorflow/third_party/eigen3/Eigen/LU
|
||||
tensorflow/third_party/eigen3/Eigen/QR
|
||||
tensorflow/third_party/eigen3/Eigen/SVD
|
||||
tensorflow/third_party/eigen3/Eigen/QR
|
||||
tensorflow/third_party/eigen3/LICENSE
|
||||
tensorflow/third_party/eigen3/gpu_packet_math.patch
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
|
||||
tensorflow/third_party/eigen.BUILD
|
||||
tensorflow/third_party/enum34.BUILD
|
||||
tensorflow/third_party/farmhash.BUILD
|
||||
tensorflow/third_party/fft2d/BUILD
|
||||
@ -64,10 +64,9 @@ tensorflow/third_party/fft2d/fft2d.h
|
||||
tensorflow/third_party/functools32.BUILD
|
||||
tensorflow/third_party/gast.BUILD
|
||||
tensorflow/third_party/gif.BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/googleapis.BUILD
|
||||
tensorflow/third_party/gpus/BUILD
|
||||
tensorflow/third_party/gpus/crosstool/BUILD
|
||||
tensorflow/third_party/gpus/crosstool/BUILD.tpl
|
||||
@ -82,12 +81,13 @@ tensorflow/third_party/gpus/cuda/LICENSE
|
||||
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/find_cuda_config.py
|
||||
tensorflow/third_party/gpus/rocm/BUILD
|
||||
tensorflow/third_party/gpus/rocm/BUILD.tpl
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/find_cuda_config.py
|
||||
tensorflow/third_party/gpus/rocm_configure.bzl
|
||||
tensorflow/third_party/googleapis.BUILD
|
||||
tensorflow/third_party/grpc/BUILD
|
||||
tensorflow/third_party/icu/udata.patch
|
||||
tensorflow/third_party/jsoncpp.BUILD
|
||||
@ -96,8 +96,8 @@ tensorflow/third_party/kafka/config.patch
|
||||
tensorflow/third_party/libxsmm.BUILD
|
||||
tensorflow/third_party/linenoise.BUILD
|
||||
tensorflow/third_party/llvm/BUILD
|
||||
tensorflow/third_party/llvm/expand_cmake_vars.py
|
||||
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
|
||||
tensorflow/third_party/llvm/expand_cmake_vars.py
|
||||
tensorflow/third_party/llvm/llvm.bzl
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/mkl/BUILD
|
||||
@ -112,17 +112,17 @@ tensorflow/third_party/mpi/BUILD
|
||||
tensorflow/third_party/mpi_collectives/BUILD
|
||||
tensorflow/third_party/nanopb.BUILD
|
||||
tensorflow/third_party/nccl/BUILD
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/archive.patch
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/archive.patch
|
||||
tensorflow/third_party/nccl/nccl_configure.bzl
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
tensorflow/third_party/ngraph/BUILD
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/NGRAPH_LICENSE
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/ngraph_tf.BUILD
|
||||
tensorflow/third_party/ngraph/nlohmann_json.BUILD
|
||||
tensorflow/third_party/ngraph/tbb.BUILD
|
||||
@ -132,8 +132,8 @@ tensorflow/third_party/png.BUILD
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/pprof.BUILD
|
||||
tensorflow/third_party/protobuf/BUILD
|
||||
tensorflow/third_party/py/BUILD.tpl
|
||||
tensorflow/third_party/py/BUILD
|
||||
tensorflow/third_party/py/BUILD.tpl
|
||||
tensorflow/third_party/py/numpy/BUILD
|
||||
tensorflow/third_party/py/python_configure.bzl
|
||||
tensorflow/third_party/pybind11.BUILD
|
||||
@ -150,29 +150,29 @@ tensorflow/third_party/systemlibs/absl_py.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
|
||||
tensorflow/third_party/systemlibs/astor.BUILD
|
||||
tensorflow/third_party/systemlibs/boringssl.BUILD
|
||||
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
|
||||
tensorflow/third_party/systemlibs/boringssl.BUILD
|
||||
tensorflow/third_party/systemlibs/curl.BUILD
|
||||
tensorflow/third_party/systemlibs/cython.BUILD
|
||||
tensorflow/third_party/systemlibs/double_conversion.BUILD
|
||||
tensorflow/third_party/systemlibs/gast.BUILD
|
||||
tensorflow/third_party/systemlibs/gif.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/grpc.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/lmdb.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/nsync.BUILD
|
||||
tensorflow/third_party/systemlibs/opt_einsum.BUILD
|
||||
tensorflow/third_party/systemlibs/pcre.BUILD
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.bzl
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
tensorflow/third_party/systemlibs/re2.BUILD
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
tensorflow/third_party/systemlibs/sqlite.BUILD
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/swig.BUILD
|
||||
tensorflow/third_party/systemlibs/syslibs_configure.bzl
|
||||
tensorflow/third_party/systemlibs/termcolor.BUILD
|
||||
@ -189,19 +189,19 @@ tensorflow/third_party/tflite_mobilenet_float.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_quant.BUILD
|
||||
tensorflow/third_party/tflite_ovic_testdata.BUILD
|
||||
tensorflow/third_party/tflite_smartreply.BUILD
|
||||
tensorflow/third_party/toolchains/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/clang6/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/toolchains/clang6/clang.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/repo.bzl
|
||||
tensorflow/third_party/toolchains/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/arm/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl
|
||||
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
|
||||
tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl
|
||||
tensorflow/third_party/toolchains/cpus/py/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.1-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.1-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/BUILD
|
||||
@ -227,13 +227,13 @@ tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/dummy_toolchain.bzl
|
||||
@ -264,19 +264,19 @@ tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
|
||||
tensorflow/tools/lib_package/BUILD
|
||||
tensorflow/tools/lib_package/LibTensorFlowTest.java
|
||||
tensorflow/tools/lib_package/README.md
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/lib_package/concat_licenses.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/libtensorflow_test.sh
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
tensorflow/tools/pip_package/MANIFEST.in
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
tensorflow/tools/pip_package/README
|
||||
tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/check_load_py_test.py
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/pip_smoke_test.py
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/simple_console.py
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/virtual_root_template_v1.__init__.py
|
||||
tensorflow/virtual_root_template_v2.__init__.py
|
||||
tensorflow/virtual_root_template_v1.__init__.py
|
||||
llvm/llvm/projects/google_mlir/WORKSPACE
|
Loading…
Reference in New Issue
Block a user