Expose clear_attr in the python graph API. Use it to clear hanging nodes with no input/output from being TPU compiled.
PiperOrigin-RevId: 236171147
This commit is contained in:
parent
e59d6b67ad
commit
4059344a81
@ -41,6 +41,15 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
|
||||
RecordMutation(graph, *op, "setting attribute");
|
||||
}
|
||||
|
||||
void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
|
||||
TF_Status* status) {
|
||||
AttrValue attr_val;
|
||||
|
||||
mutex_lock l(graph->mu);
|
||||
op->node.ClearAttr(attr_name);
|
||||
RecordMutation(graph, *op, "clearing attribute");
|
||||
}
|
||||
|
||||
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
|
||||
mutex_lock l(graph->mu);
|
||||
op->node.set_requested_device(device);
|
||||
|
@ -32,6 +32,11 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
|
||||
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
|
||||
TF_Buffer* attr_value_proto, TF_Status* status);
|
||||
|
||||
// Clears the attr in the node_def Protocol Buffer and sets a status upon
|
||||
// completion.
|
||||
void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
|
||||
TF_Status* status);
|
||||
|
||||
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
|
||||
|
||||
// Updates 'dst' to consume 'new_src'.
|
||||
|
@ -1,245 +1,245 @@
|
||||
tensorflow/contrib/tpu/profiler/pip_package/BUILD
|
||||
tensorflow/contrib/tpu/profiler/pip_package/setup.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/README
|
||||
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
|
||||
tensorflow/contrib/mpi/BUILD
|
||||
tensorflow/tools/ci_build/remote/BUILD
|
||||
tensorflow/tools/pip_package/README
|
||||
tensorflow/tools/pip_package/MANIFEST.in
|
||||
tensorflow/tools/pip_package/simple_console.py
|
||||
tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/check_load_py_test.py
|
||||
tensorflow/tools/pip_package/pip_smoke_test.py
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
tensorflow/tools/lib_package/concat_licenses.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/LibTensorFlowTest.java
|
||||
tensorflow/tools/lib_package/BUILD
|
||||
tensorflow/tools/lib_package/libtensorflow_test.sh
|
||||
tensorflow/tools/lib_package/README.md
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
|
||||
tensorflow/tools/def_file_filter/BUILD
|
||||
tensorflow/tools/def_file_filter/BUILD.tpl
|
||||
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
|
||||
tensorflow/third_party/mkl/MKL_LICENSE
|
||||
tensorflow/third_party/mkl/LICENSE
|
||||
tensorflow/third_party/mkl/BUILD
|
||||
tensorflow/third_party/mkl/mkl.BUILD
|
||||
tensorflow/third_party/mkl/build_defs.bzl
|
||||
tensorflow/third_party/backports_weakref.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/toolchains/clang6/repo.bzl
|
||||
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/clang6/clang.BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc7-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda9.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
|
||||
tensorflow/third_party/systemlibs/nsync.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
|
||||
tensorflow/third_party/systemlibs/curl.BUILD
|
||||
tensorflow/third_party/systemlibs/cython.BUILD
|
||||
tensorflow/third_party/systemlibs/astor.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/pcre.BUILD
|
||||
tensorflow/third_party/systemlibs/grpc.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
tensorflow/third_party/systemlibs/double_conversion.BUILD
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
tensorflow/third_party/systemlibs/zlib.BUILD
|
||||
tensorflow/third_party/systemlibs/lmdb.BUILD
|
||||
tensorflow/third_party/systemlibs/sqlite.BUILD
|
||||
tensorflow/third_party/systemlibs/gast.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.BUILD
|
||||
tensorflow/third_party/systemlibs/boringssl.BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD.tpl
|
||||
tensorflow/third_party/systemlibs/BUILD
|
||||
tensorflow/third_party/systemlibs/termcolor.BUILD
|
||||
tensorflow/third_party/systemlibs/gif.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.bzl
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
|
||||
tensorflow/third_party/systemlibs/re2.BUILD
|
||||
tensorflow/third_party/systemlibs/swig.BUILD
|
||||
tensorflow/third_party/systemlibs/syslibs_configure.bzl
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
|
||||
tensorflow/third_party/pprof.BUILD
|
||||
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
|
||||
tensorflow/third_party/toolchains/remote/BUILD.tpl
|
||||
tensorflow/third_party/toolchains/remote/BUILD
|
||||
tensorflow/third_party/toolchains/remote/configure.bzl
|
||||
tensorflow/third_party/toolchains/cpus/py3/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
|
||||
tensorflow/third_party/toolchains/cpus/arm/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/cpus/arm/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py3/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py/BUILD
|
||||
tensorflow/third_party/toolchains/remote/configure.bzl
|
||||
tensorflow/third_party/toolchains/remote/BUILD.tpl
|
||||
tensorflow/third_party/toolchains/remote/BUILD
|
||||
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
|
||||
tensorflow/third_party/toolchains/BUILD
|
||||
tensorflow/third_party/gpus/BUILD
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
tensorflow/third_party/gpus/crosstool/CROSSTOOL.tpl
|
||||
tensorflow/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda9.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc7-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/repo.bzl
|
||||
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/clang6/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/clang.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/farmhash.BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/cub.BUILD
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/rocm/BUILD.tpl
|
||||
tensorflow/third_party/gpus/rocm/BUILD
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/rocm_configure.bzl
|
||||
tensorflow/third_party/gpus/crosstool/LICENSE
|
||||
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
|
||||
tensorflow/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
|
||||
tensorflow/third_party/gpus/crosstool/CROSSTOOL.tpl
|
||||
tensorflow/third_party/gpus/crosstool/BUILD.tpl
|
||||
tensorflow/third_party/gpus/crosstool/BUILD
|
||||
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/cuda/LICENSE
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
|
||||
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD
|
||||
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/rocm/BUILD
|
||||
tensorflow/third_party/gpus/rocm/BUILD.tpl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/rocm_configure.bzl
|
||||
tensorflow/third_party/snappy.BUILD
|
||||
tensorflow/third_party/cython.BUILD
|
||||
tensorflow/third_party/farmhash.BUILD
|
||||
tensorflow/third_party/eigen3/Eigen/Cholesky
|
||||
tensorflow/third_party/eigen3/Eigen/QR
|
||||
tensorflow/third_party/eigen3/Eigen/LU
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/eigen3/Eigen/SVD
|
||||
tensorflow/third_party/eigen3/Eigen/Eigenvalues
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
|
||||
tensorflow/third_party/eigen3/gpu_packet_math.patch
|
||||
tensorflow/third_party/eigen3/LICENSE
|
||||
tensorflow/third_party/eigen3/BUILD
|
||||
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
|
||||
tensorflow/third_party/systemlibs/absl_py.BUILD
|
||||
tensorflow/third_party/systemlibs/curl.BUILD
|
||||
tensorflow/third_party/systemlibs/termcolor.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
|
||||
tensorflow/third_party/systemlibs/grpc.BUILD
|
||||
tensorflow/third_party/systemlibs/swig.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.bzl
|
||||
tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
|
||||
tensorflow/third_party/systemlibs/astor.BUILD
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
|
||||
tensorflow/third_party/systemlibs/boringssl.BUILD
|
||||
tensorflow/third_party/systemlibs/nsync.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/gif.BUILD
|
||||
tensorflow/third_party/systemlibs/pcre.BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD.tpl
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/gast.BUILD
|
||||
tensorflow/third_party/systemlibs/cython.BUILD
|
||||
tensorflow/third_party/systemlibs/double_conversion.BUILD
|
||||
tensorflow/third_party/systemlibs/zlib.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/re2.BUILD
|
||||
tensorflow/third_party/systemlibs/lmdb.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/syslibs_configure.bzl
|
||||
tensorflow/third_party/systemlibs/sqlite.BUILD
|
||||
tensorflow/third_party/python_runtime/BUILD
|
||||
tensorflow/third_party/sycl/crosstool/BUILD
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/tbb.BUILD
|
||||
tensorflow/third_party/ngraph/BUILD
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/NGRAPH_LICENSE
|
||||
tensorflow/third_party/ngraph/ngraph_tf.BUILD
|
||||
tensorflow/third_party/ngraph/nlohmann_json.BUILD
|
||||
tensorflow/third_party/clang_toolchain/download_clang.bzl
|
||||
tensorflow/third_party/clang_toolchain/BUILD
|
||||
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
|
||||
tensorflow/third_party/gast.BUILD
|
||||
tensorflow/third_party/llvm/BUILD
|
||||
tensorflow/third_party/llvm/expand_cmake_vars.py
|
||||
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
|
||||
tensorflow/third_party/llvm/llvm.bzl
|
||||
tensorflow/third_party/icu/udata.patch
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
tensorflow/third_party/nccl/nccl_configure.bzl
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/BUILD
|
||||
tensorflow/third_party/fft2d/BUILD
|
||||
tensorflow/third_party/fft2d/fft.h
|
||||
tensorflow/third_party/fft2d/LICENSE
|
||||
tensorflow/third_party/fft2d/fft2d.BUILD
|
||||
tensorflow/third_party/boringssl/BUILD
|
||||
tensorflow/third_party/mpi/.gitignore
|
||||
tensorflow/third_party/mpi/BUILD
|
||||
tensorflow/third_party/tensorrt/LICENSE
|
||||
tensorflow/third_party/tensorrt/BUILD
|
||||
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
||||
tensorflow/third_party/tensorrt/BUILD.tpl
|
||||
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
|
||||
tensorflow/third_party/kafka/config.patch
|
||||
tensorflow/third_party/kafka/BUILD
|
||||
tensorflow/third_party/android/BUILD
|
||||
tensorflow/third_party/android/android.bzl.tpl
|
||||
tensorflow/third_party/android/android_configure.bzl
|
||||
tensorflow/third_party/android/android_configure.BUILD.tpl
|
||||
tensorflow/third_party/tflite_smartreply.BUILD
|
||||
tensorflow/third_party/gpus/BUILD
|
||||
tensorflow/third_party/common.bzl
|
||||
tensorflow/third_party/tflite_mobilenet_quant.BUILD
|
||||
tensorflow/third_party/linenoise.BUILD
|
||||
tensorflow/third_party/curl.BUILD
|
||||
tensorflow/third_party/mkl_dnn/LICENSE
|
||||
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
|
||||
tensorflow/third_party/pcre.BUILD
|
||||
tensorflow/third_party/linenoise.BUILD
|
||||
tensorflow/third_party/sqlite.BUILD
|
||||
tensorflow/third_party/common.bzl
|
||||
tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/pprof.BUILD
|
||||
tensorflow/third_party/BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_quant.BUILD
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/protobuf/BUILD
|
||||
tensorflow/third_party/enum34.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet.BUILD
|
||||
tensorflow/third_party/py/BUILD
|
||||
tensorflow/third_party/py/BUILD.tpl
|
||||
tensorflow/third_party/py/numpy/BUILD
|
||||
tensorflow/third_party/py/python_configure.bzl
|
||||
tensorflow/third_party/termcolor.BUILD
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/swig.BUILD
|
||||
tensorflow/third_party/astor.BUILD
|
||||
tensorflow/third_party/fft2d/LICENSE
|
||||
tensorflow/third_party/fft2d/fft2d.BUILD
|
||||
tensorflow/third_party/fft2d/fft.h
|
||||
tensorflow/third_party/fft2d/BUILD
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/tbb.BUILD
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/nlohmann_json.BUILD
|
||||
tensorflow/third_party/ngraph/BUILD
|
||||
tensorflow/third_party/ngraph/ngraph_tf.BUILD
|
||||
tensorflow/third_party/ngraph/NGRAPH_LICENSE
|
||||
tensorflow/third_party/grpc/BUILD
|
||||
tensorflow/third_party/curl.BUILD
|
||||
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
|
||||
tensorflow/third_party/cython.BUILD
|
||||
tensorflow/third_party/icu/udata.patch
|
||||
tensorflow/third_party/astor.BUILD
|
||||
tensorflow/third_party/jsoncpp.BUILD
|
||||
tensorflow/third_party/sycl/crosstool/BUILD
|
||||
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
|
||||
tensorflow/third_party/llvm/expand_cmake_vars.py
|
||||
tensorflow/third_party/llvm/llvm.bzl
|
||||
tensorflow/third_party/llvm/BUILD
|
||||
tensorflow/third_party/png.BUILD
|
||||
tensorflow/third_party/googleapis.BUILD
|
||||
tensorflow/third_party/mpi_collectives/BUILD
|
||||
tensorflow/third_party/nanopb.BUILD
|
||||
tensorflow/third_party/gif.BUILD
|
||||
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
|
||||
tensorflow/third_party/codegen.BUILD
|
||||
tensorflow/third_party/enum34.BUILD
|
||||
tensorflow/third_party/kafka/config.patch
|
||||
tensorflow/third_party/kafka/BUILD
|
||||
tensorflow/third_party/pcre.BUILD
|
||||
tensorflow/third_party/mpi/BUILD
|
||||
tensorflow/third_party/mpi/.gitignore
|
||||
tensorflow/third_party/clang_toolchain/BUILD
|
||||
tensorflow/third_party/clang_toolchain/download_clang.bzl
|
||||
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
|
||||
tensorflow/third_party/tflite_ovic_testdata.BUILD
|
||||
tensorflow/third_party/repo.bzl
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/py/python_configure.bzl
|
||||
tensorflow/third_party/py/BUILD.tpl
|
||||
tensorflow/third_party/py/BUILD
|
||||
tensorflow/third_party/py/numpy/BUILD
|
||||
tensorflow/third_party/double_conversion.BUILD
|
||||
tensorflow/third_party/six.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_float.BUILD
|
||||
tensorflow/third_party/repo.bzl
|
||||
tensorflow/third_party/codegen.BUILD
|
||||
tensorflow/third_party/cub.BUILD
|
||||
tensorflow/third_party/jsoncpp.BUILD
|
||||
tensorflow/third_party/tflite_ovic_testdata.BUILD
|
||||
tensorflow/third_party/libxsmm.BUILD
|
||||
tensorflow/third_party/zlib.BUILD
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/nanopb.BUILD
|
||||
tensorflow/third_party/android/android.bzl.tpl
|
||||
tensorflow/third_party/android/BUILD
|
||||
tensorflow/third_party/android/android_configure.BUILD.tpl
|
||||
tensorflow/third_party/android/android_configure.bzl
|
||||
tensorflow/third_party/tflite_mobilenet_float.BUILD
|
||||
tensorflow/third_party/sqlite.BUILD
|
||||
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
||||
tensorflow/third_party/tensorrt/LICENSE
|
||||
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
|
||||
tensorflow/third_party/tensorrt/BUILD.tpl
|
||||
tensorflow/third_party/tensorrt/BUILD
|
||||
tensorflow/third_party/gast.BUILD
|
||||
tensorflow/third_party/mpi_collectives/BUILD
|
||||
tensorflow/third_party/libxsmm.BUILD
|
||||
tensorflow/third_party/eigen.BUILD
|
||||
tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/eigen3/LICENSE
|
||||
tensorflow/third_party/eigen3/gpu_packet_math.patch
|
||||
tensorflow/third_party/eigen3/BUILD
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
|
||||
tensorflow/third_party/eigen3/Eigen/QR
|
||||
tensorflow/third_party/eigen3/Eigen/SVD
|
||||
tensorflow/third_party/eigen3/Eigen/LU
|
||||
tensorflow/third_party/eigen3/Eigen/Cholesky
|
||||
tensorflow/third_party/eigen3/Eigen/Eigenvalues
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/BUILD
|
||||
tensorflow/third_party/termcolor.BUILD
|
||||
tensorflow/third_party/gif.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet.BUILD
|
||||
tensorflow/third_party/mkl/LICENSE
|
||||
tensorflow/third_party/mkl/build_defs.bzl
|
||||
tensorflow/third_party/mkl/mkl.BUILD
|
||||
tensorflow/third_party/mkl/MKL_LICENSE
|
||||
tensorflow/third_party/mkl/BUILD
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/nccl_configure.bzl
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/BUILD
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
tensorflow/third_party/snappy.BUILD
|
||||
tensorflow/third_party/python_runtime/BUILD
|
||||
tensorflow/third_party/googleapis.BUILD
|
||||
tensorflow/third_party/boringssl/BUILD
|
||||
tensorflow/third_party/protobuf/BUILD
|
||||
tensorflow/third_party/backports_weakref.BUILD
|
||||
tensorflow/third_party/tflite_smartreply.BUILD
|
||||
tensorflow/third_party/swig.BUILD
|
||||
tensorflow/compat_template.__init__.py
|
||||
tensorflow/tools/lib_package/libtensorflow_test.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/concat_licenses.sh
|
||||
tensorflow/tools/lib_package/LibTensorFlowTest.java
|
||||
tensorflow/tools/lib_package/BUILD
|
||||
tensorflow/tools/lib_package/README.md
|
||||
tensorflow/tools/pip_package/check_load_py_test.py
|
||||
tensorflow/tools/pip_package/simple_console.py
|
||||
tensorflow/tools/pip_package/pip_smoke_test.py
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/README
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/MANIFEST.in
|
||||
tensorflow/tools/ci_build/remote/BUILD
|
||||
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
|
||||
tensorflow/tools/def_file_filter/BUILD.tpl
|
||||
tensorflow/tools/def_file_filter/BUILD
|
||||
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
|
||||
tensorflow/api_template.__init__.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/BUILD
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
|
||||
tensorflow/contrib/tpu/profiler/pip_package/README
|
||||
tensorflow/contrib/tpu/profiler/pip_package/setup.py
|
||||
tensorflow/contrib/mpi/BUILD
|
||||
tensorflow/__init__.py
|
||||
tensorflow/stream_executor/build_defs.bzl
|
||||
tensorflow/api_template_v1.__init__.py
|
||||
tensorflow/compat_template_v1.__init__.py
|
||||
tensorflow/compat_template.__init__.py
|
||||
tensorflow/api_template.__init__.py
|
||||
tensorflow/__init__.py
|
@ -2551,6 +2551,12 @@ class Operation(object):
|
||||
shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes)
|
||||
self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list))
|
||||
|
||||
def _clear_attr(self, attr_name):
|
||||
"""Private method used to clear an attribute in the node_def."""
|
||||
# pylint: disable=protected-access
|
||||
c_api.ClearAttr(self._graph._c_graph, self._c_op, attr_name)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def get_attr(self, name):
|
||||
"""Returns the value of the attr of this op with the given `name`.
|
||||
|
||||
|
@ -62,6 +62,10 @@ _UNSUPPORTED_OPS = set([
|
||||
"TensorSummaryV2",
|
||||
])
|
||||
|
||||
# Ops which can be safely pruned from XLA compile if they have no consumers.
|
||||
# These ops should also have no inputs.
|
||||
_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"])
|
||||
|
||||
_MAX_WARNING_LINES = 5
|
||||
|
||||
_TPU_REPLICATE_ATTR = "_tpu_replicate"
|
||||
@ -1574,3 +1578,31 @@ def rewrite_for_inference(computation,
|
||||
device_assignment=device_assignment,
|
||||
name=name)
|
||||
# pylint: enable=undefined-variable
|
||||
|
||||
|
||||
def prune_unconnected_ops_from_xla(prune_graph):
|
||||
"""Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE.
|
||||
|
||||
Args:
|
||||
prune_graph: A tensorflow graph from which we wish to prune unconnected ops
|
||||
as listed in _UNCONNECTED_OPS_TO_PRUNE. In general, these ops should have
|
||||
no inputs and no consumers. These can often be left behind due to graph
|
||||
construction rewiring (for instance TF-Hub). While they never execute,
|
||||
they will cause XLA compile to fail so we strip them from XLA compile by
|
||||
removing the tpu_replicate attribute.
|
||||
"""
|
||||
# Scan over the top level graph and all function graphs.
|
||||
for graph in [prune_graph] + list(prune_graph._functions.values()): # pylint: disable=protected-access
|
||||
for op in graph.get_operations():
|
||||
if op.type not in _UNCONNECTED_OPS_TO_PRUNE:
|
||||
continue
|
||||
outputs_consumed = False
|
||||
for output in op.outputs:
|
||||
if output.consumers():
|
||||
outputs_consumed = True
|
||||
break
|
||||
if not outputs_consumed:
|
||||
logging.info(
|
||||
"Pruning OP %s of type %s from XLA Compile due to "
|
||||
"it being disconnected.", op.name, op.type)
|
||||
op._clear_attr(_TPU_REPLICATE_ATTR) # pylint: disable=protected-access
|
||||
|
@ -2477,6 +2477,7 @@ class TPUEstimator(estimator_lib.Estimator):
|
||||
if self._experimental_exported_model_uses_all_cores:
|
||||
tensors_on_cpu = tpu.rewrite(
|
||||
tpu_computation, device_assignment=device_assignment)
|
||||
tpu.prune_unconnected_ops_from_xla(ops.get_default_graph())
|
||||
else:
|
||||
tensors_on_cpu = tpu.rewrite_for_inference(
|
||||
tpu_computation, device_assignment=device_assignment)
|
||||
|
@ -19,12 +19,17 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import importer
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.layers import convolutional
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu import tpu
|
||||
from tensorflow.python.tpu import tpu_feed
|
||||
@ -76,5 +81,63 @@ class TPULayerRewriteTest(test.TestCase):
|
||||
tpu.rewrite(loop)
|
||||
|
||||
|
||||
class TPUGraphPruneTest(test.TestCase):
|
||||
|
||||
def test_prune_unconnected_ops(self):
|
||||
with ops.Graph().as_default():
|
||||
a = array_ops.placeholder(dtype=dtypes.float32, name="a")
|
||||
b = array_ops.placeholder(dtype=dtypes.float32, name="b")
|
||||
constant_op.constant(1.0, name="constant")
|
||||
x = variable_scope.get_variable(
|
||||
name="x",
|
||||
dtype=dtypes.float32,
|
||||
shape=[],
|
||||
use_resource=True,
|
||||
initializer=init_ops.constant_initializer(2.0))
|
||||
y = variable_scope.get_variable(
|
||||
name="y",
|
||||
dtype=dtypes.float32,
|
||||
shape=[],
|
||||
use_resource=True,
|
||||
initializer=init_ops.constant_initializer(3.0))
|
||||
math_ops.add(a, b)
|
||||
math_ops.add(x, y)
|
||||
graph_def = ops.get_default_graph().as_graph_def()
|
||||
|
||||
for node in graph_def.node:
|
||||
# Attach a TPU_REPLICATE_ATTR to each node.
|
||||
node.attr[tpu._TPU_REPLICATE_ATTR].s = b"0"
|
||||
# Rewire placeholder "a" and variable "y" leaving them unconnected.
|
||||
for (input_index, node_input) in enumerate(node.input):
|
||||
if node_input == "b":
|
||||
node.input[input_index] = "constant"
|
||||
if node_input == "y":
|
||||
node.input[input_index] = "x"
|
||||
|
||||
with ops.Graph().as_default() as graph:
|
||||
# Reimport the graph and prune unconnected ops.
|
||||
importer.import_graph_def(graph_def)
|
||||
tpu.prune_unconnected_ops_from_xla(ops.get_default_graph())
|
||||
|
||||
# Verify that ops "a" and "x" still have TPU_REPLICATE_ATTR.
|
||||
a = graph.get_operation_by_name("import/a").get_attr(
|
||||
tpu._TPU_REPLICATE_ATTR)
|
||||
self.assertEqual(b"0", a)
|
||||
x = graph.get_operation_by_name("import/x").get_attr(
|
||||
tpu._TPU_REPLICATE_ATTR)
|
||||
self.assertEqual(b"0", x)
|
||||
# Verify that ops "b" and "y" have TPU_REPLICATE_ATTR removed.
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"Operation \'import/b\' has no attr named \'_tpu_replicate\'"):
|
||||
graph.get_operation_by_name("import/b").get_attr(
|
||||
tpu._TPU_REPLICATE_ATTR)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"Operation \'import/y\' has no attr named \'_tpu_replicate\'"):
|
||||
graph.get_operation_by_name("import/y").get_attr(
|
||||
tpu._TPU_REPLICATE_ATTR)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user