Expose clear_attr in the python graph API. Use it to clear hanging nodes with no input/output from being TPU compiled.

PiperOrigin-RevId: 236171147
This commit is contained in:
A. Unique TensorFlower 2019-02-28 12:13:54 -08:00 committed by TensorFlower Gardener
parent e59d6b67ad
commit 4059344a81
7 changed files with 339 additions and 223 deletions

View File

@ -41,6 +41,15 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
RecordMutation(graph, *op, "setting attribute");
}
void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
TF_Status* status) {
AttrValue attr_val;
mutex_lock l(graph->mu);
op->node.ClearAttr(attr_name);
RecordMutation(graph, *op, "clearing attribute");
}
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
mutex_lock l(graph->mu);
op->node.set_requested_device(device);

View File

@ -32,6 +32,11 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
TF_Buffer* attr_value_proto, TF_Status* status);
// Clears the attr in the node_def Protocol Buffer and sets a status upon
// completion.
void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
TF_Status* status);
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
// Updates 'dst' to consume 'new_src'.

View File

@ -1,245 +1,245 @@
tensorflow/contrib/tpu/profiler/pip_package/BUILD
tensorflow/contrib/tpu/profiler/pip_package/setup.py
tensorflow/contrib/tpu/profiler/pip_package/README
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
tensorflow/contrib/mpi/BUILD
tensorflow/tools/ci_build/remote/BUILD
tensorflow/tools/pip_package/README
tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/BUILD
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/lib_package/README.md
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
tensorflow/tools/def_file_filter/BUILD
tensorflow/tools/def_file_filter/BUILD.tpl
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
tensorflow/third_party/mkl/MKL_LICENSE
tensorflow/third_party/mkl/LICENSE
tensorflow/third_party/mkl/BUILD
tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/backports_weakref.BUILD
tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc7-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda9.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
tensorflow/third_party/toolchains/preconfig/generate/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/curl.BUILD
tensorflow/third_party/systemlibs/cython.BUILD
tensorflow/third_party/systemlibs/astor.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/double_conversion.BUILD
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/zlib.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/sqlite.BUILD
tensorflow/third_party/systemlibs/gast.BUILD
tensorflow/third_party/systemlibs/absl_py.BUILD
tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/BUILD.tpl
tensorflow/third_party/systemlibs/BUILD
tensorflow/third_party/systemlibs/termcolor.BUILD
tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/swig.BUILD
tensorflow/third_party/systemlibs/syslibs_configure.bzl
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/cpus/py3/BUILD
tensorflow/third_party/toolchains/cpus/py/BUILD
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
tensorflow/third_party/toolchains/cpus/arm/CROSSTOOL.tpl
tensorflow/third_party/toolchains/cpus/arm/BUILD
tensorflow/third_party/toolchains/cpus/py3/BUILD
tensorflow/third_party/toolchains/cpus/py/BUILD
tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
tensorflow/third_party/toolchains/BUILD
tensorflow/third_party/gpus/BUILD
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/gpus/crosstool/CROSSTOOL.tpl
tensorflow/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda9.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc7-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/git/BUILD
tensorflow/third_party/cub.BUILD
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/gpus/crosstool/LICENSE
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
tensorflow/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
tensorflow/third_party/gpus/crosstool/CROSSTOOL.tpl
tensorflow/third_party/gpus/crosstool/BUILD.tpl
tensorflow/third_party/gpus/crosstool/BUILD
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda/LICENSE
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/snappy.BUILD
tensorflow/third_party/cython.BUILD
tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/eigen3/Eigen/Cholesky
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/Eigen/LU
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/Eigen/Eigenvalues
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/LICENSE
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/absl_py.BUILD
tensorflow/third_party/systemlibs/curl.BUILD
tensorflow/third_party/systemlibs/termcolor.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/swig.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/astor.BUILD
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/BUILD.tpl
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/gast.BUILD
tensorflow/third_party/systemlibs/cython.BUILD
tensorflow/third_party/systemlibs/double_conversion.BUILD
tensorflow/third_party/systemlibs/zlib.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/syslibs_configure.bzl
tensorflow/third_party/systemlibs/sqlite.BUILD
tensorflow/third_party/python_runtime/BUILD
tensorflow/third_party/sycl/crosstool/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/tbb.BUILD
tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/clang_toolchain/BUILD
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/gast.BUILD
tensorflow/third_party/llvm/BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/icu/udata.patch
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/fft2d/fft.h
tensorflow/third_party/fft2d/LICENSE
tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/boringssl/BUILD
tensorflow/third_party/mpi/.gitignore
tensorflow/third_party/mpi/BUILD
tensorflow/third_party/tensorrt/LICENSE
tensorflow/third_party/tensorrt/BUILD
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/BUILD.tpl
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
tensorflow/third_party/kafka/config.patch
tensorflow/third_party/kafka/BUILD
tensorflow/third_party/android/BUILD
tensorflow/third_party/android/android.bzl.tpl
tensorflow/third_party/android/android_configure.bzl
tensorflow/third_party/android/android_configure.BUILD.tpl
tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/gpus/BUILD
tensorflow/third_party/common.bzl
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/curl.BUILD
tensorflow/third_party/mkl_dnn/LICENSE
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
tensorflow/third_party/pcre.BUILD
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/sqlite.BUILD
tensorflow/third_party/common.bzl
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/BUILD
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/BUILD
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/enum34.BUILD
tensorflow/third_party/tflite_mobilenet.BUILD
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/py/python_configure.bzl
tensorflow/third_party/termcolor.BUILD
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/swig.BUILD
tensorflow/third_party/astor.BUILD
tensorflow/third_party/fft2d/LICENSE
tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/fft2d/fft.h
tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/tbb.BUILD
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/grpc/BUILD
tensorflow/third_party/curl.BUILD
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
tensorflow/third_party/cython.BUILD
tensorflow/third_party/icu/udata.patch
tensorflow/third_party/astor.BUILD
tensorflow/third_party/jsoncpp.BUILD
tensorflow/third_party/sycl/crosstool/BUILD
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/llvm/BUILD
tensorflow/third_party/png.BUILD
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/gif.BUILD
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
tensorflow/third_party/codegen.BUILD
tensorflow/third_party/enum34.BUILD
tensorflow/third_party/kafka/config.patch
tensorflow/third_party/kafka/BUILD
tensorflow/third_party/pcre.BUILD
tensorflow/third_party/mpi/BUILD
tensorflow/third_party/mpi/.gitignore
tensorflow/third_party/clang_toolchain/BUILD
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/repo.bzl
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/py/python_configure.bzl
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/double_conversion.BUILD
tensorflow/third_party/six.BUILD
tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/repo.bzl
tensorflow/third_party/codegen.BUILD
tensorflow/third_party/cub.BUILD
tensorflow/third_party/jsoncpp.BUILD
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/zlib.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/android/android.bzl.tpl
tensorflow/third_party/android/BUILD
tensorflow/third_party/android/android_configure.BUILD.tpl
tensorflow/third_party/android/android_configure.bzl
tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/sqlite.BUILD
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/LICENSE
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
tensorflow/third_party/tensorrt/BUILD.tpl
tensorflow/third_party/tensorrt/BUILD
tensorflow/third_party/gast.BUILD
tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/eigen.BUILD
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/eigen3/LICENSE
tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/Eigen/LU
tensorflow/third_party/eigen3/Eigen/Cholesky
tensorflow/third_party/eigen3/Eigen/Eigenvalues
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/BUILD
tensorflow/third_party/termcolor.BUILD
tensorflow/third_party/gif.BUILD
tensorflow/third_party/tflite_mobilenet.BUILD
tensorflow/third_party/mkl/LICENSE
tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/mkl/MKL_LICENSE
tensorflow/third_party/mkl/BUILD
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/snappy.BUILD
tensorflow/third_party/python_runtime/BUILD
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/boringssl/BUILD
tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/backports_weakref.BUILD
tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/swig.BUILD
tensorflow/compat_template.__init__.py
tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/README.md
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/BUILD
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/README
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/ci_build/remote/BUILD
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
tensorflow/tools/def_file_filter/BUILD.tpl
tensorflow/tools/def_file_filter/BUILD
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
tensorflow/api_template.__init__.py
tensorflow/contrib/tpu/profiler/pip_package/BUILD
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
tensorflow/contrib/tpu/profiler/pip_package/README
tensorflow/contrib/tpu/profiler/pip_package/setup.py
tensorflow/contrib/mpi/BUILD
tensorflow/__init__.py
tensorflow/stream_executor/build_defs.bzl
tensorflow/api_template_v1.__init__.py
tensorflow/compat_template_v1.__init__.py
tensorflow/compat_template.__init__.py
tensorflow/api_template.__init__.py
tensorflow/__init__.py

View File

@ -2551,6 +2551,12 @@ class Operation(object):
shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes)
self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list))
def _clear_attr(self, attr_name):
"""Private method used to clear an attribute in the node_def."""
# pylint: disable=protected-access
c_api.ClearAttr(self._graph._c_graph, self._c_op, attr_name)
# pylint: enable=protected-access
def get_attr(self, name):
"""Returns the value of the attr of this op with the given `name`.

View File

@ -62,6 +62,10 @@ _UNSUPPORTED_OPS = set([
"TensorSummaryV2",
])
# Ops which can be safely pruned from XLA compile if they have no consumers.
# These ops should also have no inputs.
_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"])
_MAX_WARNING_LINES = 5
_TPU_REPLICATE_ATTR = "_tpu_replicate"
@ -1574,3 +1578,31 @@ def rewrite_for_inference(computation,
device_assignment=device_assignment,
name=name)
# pylint: enable=undefined-variable
def prune_unconnected_ops_from_xla(prune_graph):
"""Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE.
Args:
prune_graph: A tensorflow graph from which we wish to prune unconnected ops
as listed in _UNCONNECTED_OPS_TO_PRUNE. In general, these ops should have
no inputs and no consumers. These can often be left behind due to graph
construction rewiring (for instance TF-Hub). While they never execute,
they will cause XLA compile to fail so we strip them from XLA compile by
removing the tpu_replicate attribute.
"""
# Scan over the top level graph and all function graphs.
for graph in [prune_graph] + list(prune_graph._functions.values()): # pylint: disable=protected-access
for op in graph.get_operations():
if op.type not in _UNCONNECTED_OPS_TO_PRUNE:
continue
outputs_consumed = False
for output in op.outputs:
if output.consumers():
outputs_consumed = True
break
if not outputs_consumed:
logging.info(
"Pruning OP %s of type %s from XLA Compile due to "
"it being disconnected.", op.name, op.type)
op._clear_attr(_TPU_REPLICATE_ATTR) # pylint: disable=protected-access

View File

@ -2477,6 +2477,7 @@ class TPUEstimator(estimator_lib.Estimator):
if self._experimental_exported_model_uses_all_cores:
tensors_on_cpu = tpu.rewrite(
tpu_computation, device_assignment=device_assignment)
tpu.prune_unconnected_ops_from_xla(ops.get_default_graph())
else:
tensors_on_cpu = tpu.rewrite_for_inference(
tpu_computation, device_assignment=device_assignment)

View File

@ -19,12 +19,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_feed
@ -76,5 +81,63 @@ class TPULayerRewriteTest(test.TestCase):
tpu.rewrite(loop)
class TPUGraphPruneTest(test.TestCase):
def test_prune_unconnected_ops(self):
with ops.Graph().as_default():
a = array_ops.placeholder(dtype=dtypes.float32, name="a")
b = array_ops.placeholder(dtype=dtypes.float32, name="b")
constant_op.constant(1.0, name="constant")
x = variable_scope.get_variable(
name="x",
dtype=dtypes.float32,
shape=[],
use_resource=True,
initializer=init_ops.constant_initializer(2.0))
y = variable_scope.get_variable(
name="y",
dtype=dtypes.float32,
shape=[],
use_resource=True,
initializer=init_ops.constant_initializer(3.0))
math_ops.add(a, b)
math_ops.add(x, y)
graph_def = ops.get_default_graph().as_graph_def()
for node in graph_def.node:
# Attach a TPU_REPLICATE_ATTR to each node.
node.attr[tpu._TPU_REPLICATE_ATTR].s = b"0"
# Rewire placeholder "a" and variable "y" leaving them unconnected.
for (input_index, node_input) in enumerate(node.input):
if node_input == "b":
node.input[input_index] = "constant"
if node_input == "y":
node.input[input_index] = "x"
with ops.Graph().as_default() as graph:
# Reimport the graph and prune unconnected ops.
importer.import_graph_def(graph_def)
tpu.prune_unconnected_ops_from_xla(ops.get_default_graph())
# Verify that ops "a" and "x" still have TPU_REPLICATE_ATTR.
a = graph.get_operation_by_name("import/a").get_attr(
tpu._TPU_REPLICATE_ATTR)
self.assertEqual(b"0", a)
x = graph.get_operation_by_name("import/x").get_attr(
tpu._TPU_REPLICATE_ATTR)
self.assertEqual(b"0", x)
# Verify that ops "b" and "y" have TPU_REPLICATE_ATTR removed.
with self.assertRaisesRegexp(
ValueError,
"Operation \'import/b\' has no attr named \'_tpu_replicate\'"):
graph.get_operation_by_name("import/b").get_attr(
tpu._TPU_REPLICATE_ATTR)
with self.assertRaisesRegexp(
ValueError,
"Operation \'import/y\' has no attr named \'_tpu_replicate\'"):
graph.get_operation_by_name("import/y").get_attr(
tpu._TPU_REPLICATE_ATTR)
if __name__ == "__main__":
test.main()