From 4059344a81a6242ac437e8dca4b34aeb7fe8d6bd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Feb 2019 12:13:54 -0800 Subject: [PATCH] Expose clear_attr in the python graph API. Use it to clear hanging nodes with no input/output from being TPU compiled. PiperOrigin-RevId: 236171147 --- tensorflow/c/python_api.cc | 9 + tensorflow/c/python_api.h | 5 + tensorflow/opensource_only.files | 446 ++++++++++++------------- tensorflow/python/framework/ops.py | 6 + tensorflow/python/tpu/tpu.py | 32 ++ tensorflow/python/tpu/tpu_estimator.py | 1 + tensorflow/python/tpu/tpu_test.py | 63 ++++ 7 files changed, 339 insertions(+), 223 deletions(-) diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 98d83933322..6449e7f44f7 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -41,6 +41,15 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, RecordMutation(graph, *op, "setting attribute"); } +void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Status* status) { + AttrValue attr_val; + + mutex_lock l(graph->mu); + op->node.ClearAttr(attr_name); + RecordMutation(graph, *op, "clearing attribute"); +} + void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { mutex_lock l(graph->mu); op->node.set_requested_device(device); diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index 44779ca6561..f26c0cb2ae2 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -32,6 +32,11 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, TF_Buffer* attr_value_proto, TF_Status* status); +// Clears the attr in the node_def Protocol Buffer and sets a status upon +// completion. +void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Status* status); + void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); // Updates 'dst' to consume 'new_src'. diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 3be09b70f17..41f02fedadd 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -1,245 +1,245 @@ -tensorflow/contrib/tpu/profiler/pip_package/BUILD -tensorflow/contrib/tpu/profiler/pip_package/setup.py -tensorflow/contrib/tpu/profiler/pip_package/README -tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh -tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py -tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py -tensorflow/contrib/mpi/BUILD -tensorflow/tools/ci_build/remote/BUILD -tensorflow/tools/pip_package/README -tensorflow/tools/pip_package/MANIFEST.in -tensorflow/tools/pip_package/simple_console.py -tensorflow/tools/pip_package/build_pip_package.sh -tensorflow/tools/pip_package/check_load_py_test.py -tensorflow/tools/pip_package/pip_smoke_test.py -tensorflow/tools/pip_package/simple_console_for_windows.py -tensorflow/tools/pip_package/setup.py -tensorflow/tools/pip_package/BUILD -tensorflow/tools/lib_package/concat_licenses.sh -tensorflow/tools/lib_package/libtensorflow_test.c -tensorflow/tools/lib_package/LibTensorFlowTest.java -tensorflow/tools/lib_package/BUILD -tensorflow/tools/lib_package/libtensorflow_test.sh -tensorflow/tools/lib_package/README.md -tensorflow/tools/lib_package/libtensorflow_java_test.sh -tensorflow/tools/def_file_filter/def_file_filter_configure.bzl -tensorflow/tools/def_file_filter/BUILD -tensorflow/tools/def_file_filter/BUILD.tpl -tensorflow/tools/def_file_filter/def_file_filter.py.tpl -tensorflow/third_party/mkl/MKL_LICENSE -tensorflow/third_party/mkl/LICENSE -tensorflow/third_party/mkl/BUILD -tensorflow/third_party/mkl/mkl.BUILD -tensorflow/third_party/mkl/build_defs.bzl -tensorflow/third_party/backports_weakref.BUILD -tensorflow/third_party/toolchains/clang6/BUILD -tensorflow/third_party/toolchains/clang6/README.md -tensorflow/third_party/toolchains/clang6/repo.bzl -tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl -tensorflow/third_party/toolchains/clang6/clang.BUILD -tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc7-nvcc-cuda10.0/BUILD -tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD -tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl -tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD -tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda9.0/BUILD -tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD -tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl -tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD -tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD -tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl -tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD -tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl -tensorflow/third_party/toolchains/preconfig/generate/containers.bzl -tensorflow/third_party/toolchains/preconfig/generate/generate.bzl -tensorflow/third_party/toolchains/preconfig/generate/archives.bzl -tensorflow/third_party/toolchains/preconfig/generate/BUILD -tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD -tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl -tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD -tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/BUILD -tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/dummy_toolchain.bzl -tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD -tensorflow/third_party/toolchains/preconfig/win_1803/BUILD +tensorflow/third_party/systemlibs/nsync.BUILD +tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD +tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD +tensorflow/third_party/systemlibs/build_defs.bzl.tpl +tensorflow/third_party/systemlibs/curl.BUILD +tensorflow/third_party/systemlibs/cython.BUILD +tensorflow/third_party/systemlibs/astor.BUILD +tensorflow/third_party/systemlibs/jsoncpp.BUILD +tensorflow/third_party/systemlibs/png.BUILD +tensorflow/third_party/systemlibs/pcre.BUILD +tensorflow/third_party/systemlibs/grpc.BUILD +tensorflow/third_party/systemlibs/protobuf.BUILD +tensorflow/third_party/systemlibs/double_conversion.BUILD +tensorflow/third_party/systemlibs/six.BUILD +tensorflow/third_party/systemlibs/zlib.BUILD +tensorflow/third_party/systemlibs/lmdb.BUILD +tensorflow/third_party/systemlibs/sqlite.BUILD +tensorflow/third_party/systemlibs/gast.BUILD +tensorflow/third_party/systemlibs/absl_py.BUILD +tensorflow/third_party/systemlibs/boringssl.BUILD +tensorflow/third_party/systemlibs/BUILD.tpl +tensorflow/third_party/systemlibs/BUILD +tensorflow/third_party/systemlibs/termcolor.BUILD +tensorflow/third_party/systemlibs/gif.BUILD +tensorflow/third_party/systemlibs/protobuf.bzl +tensorflow/third_party/systemlibs/snappy.BUILD +tensorflow/third_party/systemlibs/googleapis.BUILD +tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD +tensorflow/third_party/systemlibs/re2.BUILD +tensorflow/third_party/systemlibs/swig.BUILD +tensorflow/third_party/systemlibs/syslibs_configure.bzl +tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD +tensorflow/third_party/pprof.BUILD +tensorflow/third_party/toolchains/remote/execution.bzl.tpl +tensorflow/third_party/toolchains/remote/BUILD.tpl +tensorflow/third_party/toolchains/remote/BUILD +tensorflow/third_party/toolchains/remote/configure.bzl +tensorflow/third_party/toolchains/cpus/py3/BUILD +tensorflow/third_party/toolchains/cpus/py/BUILD tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl tensorflow/third_party/toolchains/cpus/arm/CROSSTOOL.tpl tensorflow/third_party/toolchains/cpus/arm/BUILD -tensorflow/third_party/toolchains/cpus/py3/BUILD -tensorflow/third_party/toolchains/cpus/py/BUILD -tensorflow/third_party/toolchains/remote/configure.bzl -tensorflow/third_party/toolchains/remote/BUILD.tpl -tensorflow/third_party/toolchains/remote/BUILD -tensorflow/third_party/toolchains/remote/execution.bzl.tpl tensorflow/third_party/toolchains/BUILD -tensorflow/third_party/gpus/BUILD -tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl -tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl -tensorflow/third_party/gpus/crosstool/CROSSTOOL.tpl -tensorflow/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl +tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD +tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD +tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl +tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/BUILD +tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/dummy_toolchain.bzl +tensorflow/third_party/toolchains/preconfig/win_1803/BUILD +tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD +tensorflow/third_party/toolchains/preconfig/generate/containers.bzl +tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl +tensorflow/third_party/toolchains/preconfig/generate/archives.bzl +tensorflow/third_party/toolchains/preconfig/generate/generate.bzl +tensorflow/third_party/toolchains/preconfig/generate/BUILD +tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD +tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda9.0/BUILD +tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc7-nvcc-cuda10.0/BUILD +tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD +tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD +tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl +tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD +tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl +tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD +tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl +tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD +tensorflow/third_party/toolchains/clang6/repo.bzl +tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl +tensorflow/third_party/toolchains/clang6/BUILD +tensorflow/third_party/toolchains/clang6/clang.BUILD +tensorflow/third_party/toolchains/clang6/README.md +tensorflow/third_party/farmhash.BUILD +tensorflow/third_party/git/BUILD.tpl +tensorflow/third_party/git/git_configure.bzl +tensorflow/third_party/git/BUILD +tensorflow/third_party/cub.BUILD +tensorflow/third_party/gpus/cuda_configure.bzl +tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl +tensorflow/third_party/gpus/rocm/BUILD.tpl +tensorflow/third_party/gpus/rocm/BUILD +tensorflow/third_party/gpus/rocm/rocm_config.h.tpl +tensorflow/third_party/gpus/rocm_configure.bzl tensorflow/third_party/gpus/crosstool/LICENSE tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl +tensorflow/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl +tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl +tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +tensorflow/third_party/gpus/crosstool/CROSSTOOL.tpl tensorflow/third_party/gpus/crosstool/BUILD.tpl tensorflow/third_party/gpus/crosstool/BUILD +tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl tensorflow/third_party/gpus/cuda/LICENSE -tensorflow/third_party/gpus/cuda/BUILD.tpl tensorflow/third_party/gpus/cuda/BUILD.windows.tpl tensorflow/third_party/gpus/cuda/cuda_config.h.tpl +tensorflow/third_party/gpus/cuda/BUILD.tpl tensorflow/third_party/gpus/cuda/BUILD -tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl -tensorflow/third_party/gpus/rocm/rocm_config.h.tpl -tensorflow/third_party/gpus/rocm/BUILD -tensorflow/third_party/gpus/rocm/BUILD.tpl -tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl -tensorflow/third_party/gpus/cuda_configure.bzl -tensorflow/third_party/gpus/rocm_configure.bzl -tensorflow/third_party/snappy.BUILD -tensorflow/third_party/cython.BUILD -tensorflow/third_party/farmhash.BUILD -tensorflow/third_party/eigen3/Eigen/Cholesky -tensorflow/third_party/eigen3/Eigen/QR -tensorflow/third_party/eigen3/Eigen/LU -tensorflow/third_party/eigen3/Eigen/Core -tensorflow/third_party/eigen3/Eigen/SVD -tensorflow/third_party/eigen3/Eigen/Eigenvalues -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool -tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions -tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions -tensorflow/third_party/eigen3/gpu_packet_math.patch -tensorflow/third_party/eigen3/LICENSE -tensorflow/third_party/eigen3/BUILD -tensorflow/third_party/systemlibs/build_defs.bzl.tpl -tensorflow/third_party/systemlibs/absl_py.BUILD -tensorflow/third_party/systemlibs/curl.BUILD -tensorflow/third_party/systemlibs/termcolor.BUILD -tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD -tensorflow/third_party/systemlibs/grpc.BUILD -tensorflow/third_party/systemlibs/swig.BUILD -tensorflow/third_party/systemlibs/protobuf.bzl -tensorflow/third_party/systemlibs/protobuf.BUILD -tensorflow/third_party/systemlibs/BUILD -tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD -tensorflow/third_party/systemlibs/astor.BUILD -tensorflow/third_party/systemlibs/six.BUILD -tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD -tensorflow/third_party/systemlibs/boringssl.BUILD -tensorflow/third_party/systemlibs/nsync.BUILD -tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD -tensorflow/third_party/systemlibs/gif.BUILD -tensorflow/third_party/systemlibs/pcre.BUILD -tensorflow/third_party/systemlibs/BUILD.tpl -tensorflow/third_party/systemlibs/snappy.BUILD -tensorflow/third_party/systemlibs/gast.BUILD -tensorflow/third_party/systemlibs/cython.BUILD -tensorflow/third_party/systemlibs/double_conversion.BUILD -tensorflow/third_party/systemlibs/zlib.BUILD -tensorflow/third_party/systemlibs/jsoncpp.BUILD -tensorflow/third_party/systemlibs/re2.BUILD -tensorflow/third_party/systemlibs/lmdb.BUILD -tensorflow/third_party/systemlibs/googleapis.BUILD -tensorflow/third_party/systemlibs/png.BUILD -tensorflow/third_party/systemlibs/syslibs_configure.bzl -tensorflow/third_party/systemlibs/sqlite.BUILD -tensorflow/third_party/python_runtime/BUILD -tensorflow/third_party/sycl/crosstool/BUILD -tensorflow/third_party/ngraph/LICENSE -tensorflow/third_party/ngraph/tbb.BUILD -tensorflow/third_party/ngraph/BUILD -tensorflow/third_party/ngraph/ngraph.BUILD -tensorflow/third_party/ngraph/build_defs.bzl -tensorflow/third_party/ngraph/NGRAPH_LICENSE -tensorflow/third_party/ngraph/ngraph_tf.BUILD -tensorflow/third_party/ngraph/nlohmann_json.BUILD -tensorflow/third_party/clang_toolchain/download_clang.bzl -tensorflow/third_party/clang_toolchain/BUILD -tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl -tensorflow/third_party/gast.BUILD -tensorflow/third_party/llvm/BUILD -tensorflow/third_party/llvm/expand_cmake_vars.py -tensorflow/third_party/llvm/llvm.autogenerated.BUILD -tensorflow/third_party/llvm/llvm.bzl -tensorflow/third_party/icu/udata.patch -tensorflow/third_party/nccl/archive.BUILD -tensorflow/third_party/nccl/LICENSE -tensorflow/third_party/nccl/system.BUILD.tpl -tensorflow/third_party/nccl/nccl_configure.bzl -tensorflow/third_party/nccl/build_defs.bzl.tpl -tensorflow/third_party/nccl/BUILD -tensorflow/third_party/fft2d/BUILD -tensorflow/third_party/fft2d/fft.h -tensorflow/third_party/fft2d/LICENSE -tensorflow/third_party/fft2d/fft2d.BUILD -tensorflow/third_party/boringssl/BUILD -tensorflow/third_party/mpi/.gitignore -tensorflow/third_party/mpi/BUILD -tensorflow/third_party/tensorrt/LICENSE -tensorflow/third_party/tensorrt/BUILD -tensorflow/third_party/tensorrt/build_defs.bzl.tpl -tensorflow/third_party/tensorrt/BUILD.tpl -tensorflow/third_party/tensorrt/tensorrt_configure.bzl -tensorflow/third_party/kafka/config.patch -tensorflow/third_party/kafka/BUILD -tensorflow/third_party/android/BUILD -tensorflow/third_party/android/android.bzl.tpl -tensorflow/third_party/android/android_configure.bzl -tensorflow/third_party/android/android_configure.BUILD.tpl -tensorflow/third_party/tflite_smartreply.BUILD +tensorflow/third_party/gpus/BUILD +tensorflow/third_party/common.bzl +tensorflow/third_party/tflite_mobilenet_quant.BUILD +tensorflow/third_party/linenoise.BUILD +tensorflow/third_party/curl.BUILD tensorflow/third_party/mkl_dnn/LICENSE tensorflow/third_party/mkl_dnn/mkldnn.BUILD -tensorflow/third_party/pcre.BUILD -tensorflow/third_party/linenoise.BUILD -tensorflow/third_party/sqlite.BUILD -tensorflow/third_party/common.bzl -tensorflow/third_party/com_google_absl.BUILD -tensorflow/third_party/pprof.BUILD -tensorflow/third_party/BUILD -tensorflow/third_party/tflite_mobilenet_quant.BUILD -tensorflow/third_party/lmdb.BUILD -tensorflow/third_party/git/BUILD.tpl -tensorflow/third_party/git/BUILD -tensorflow/third_party/git/git_configure.bzl -tensorflow/third_party/protobuf/BUILD -tensorflow/third_party/enum34.BUILD -tensorflow/third_party/tflite_mobilenet.BUILD -tensorflow/third_party/py/BUILD -tensorflow/third_party/py/BUILD.tpl -tensorflow/third_party/py/numpy/BUILD -tensorflow/third_party/py/python_configure.bzl -tensorflow/third_party/termcolor.BUILD -tensorflow/third_party/png_fix_rpi.patch -tensorflow/third_party/swig.BUILD -tensorflow/third_party/astor.BUILD +tensorflow/third_party/fft2d/LICENSE +tensorflow/third_party/fft2d/fft2d.BUILD +tensorflow/third_party/fft2d/fft.h +tensorflow/third_party/fft2d/BUILD +tensorflow/third_party/ngraph/LICENSE +tensorflow/third_party/ngraph/build_defs.bzl +tensorflow/third_party/ngraph/tbb.BUILD +tensorflow/third_party/ngraph/ngraph.BUILD +tensorflow/third_party/ngraph/nlohmann_json.BUILD +tensorflow/third_party/ngraph/BUILD +tensorflow/third_party/ngraph/ngraph_tf.BUILD +tensorflow/third_party/ngraph/NGRAPH_LICENSE tensorflow/third_party/grpc/BUILD -tensorflow/third_party/curl.BUILD -tensorflow/third_party/arm_neon_2_x86_sse.BUILD +tensorflow/third_party/cython.BUILD +tensorflow/third_party/icu/udata.patch +tensorflow/third_party/astor.BUILD +tensorflow/third_party/jsoncpp.BUILD +tensorflow/third_party/sycl/crosstool/BUILD +tensorflow/third_party/llvm/llvm.autogenerated.BUILD +tensorflow/third_party/llvm/expand_cmake_vars.py +tensorflow/third_party/llvm/llvm.bzl +tensorflow/third_party/llvm/BUILD tensorflow/third_party/png.BUILD -tensorflow/third_party/googleapis.BUILD -tensorflow/third_party/mpi_collectives/BUILD -tensorflow/third_party/nanopb.BUILD -tensorflow/third_party/gif.BUILD +tensorflow/third_party/arm_neon_2_x86_sse.BUILD +tensorflow/third_party/codegen.BUILD +tensorflow/third_party/enum34.BUILD +tensorflow/third_party/kafka/config.patch +tensorflow/third_party/kafka/BUILD +tensorflow/third_party/pcre.BUILD +tensorflow/third_party/mpi/BUILD +tensorflow/third_party/mpi/.gitignore +tensorflow/third_party/clang_toolchain/BUILD +tensorflow/third_party/clang_toolchain/download_clang.bzl +tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl +tensorflow/third_party/tflite_ovic_testdata.BUILD +tensorflow/third_party/repo.bzl +tensorflow/third_party/png_fix_rpi.patch +tensorflow/third_party/py/python_configure.bzl +tensorflow/third_party/py/BUILD.tpl +tensorflow/third_party/py/BUILD +tensorflow/third_party/py/numpy/BUILD tensorflow/third_party/double_conversion.BUILD tensorflow/third_party/six.BUILD -tensorflow/third_party/tflite_mobilenet_float.BUILD -tensorflow/third_party/repo.bzl -tensorflow/third_party/codegen.BUILD -tensorflow/third_party/cub.BUILD -tensorflow/third_party/jsoncpp.BUILD -tensorflow/third_party/tflite_ovic_testdata.BUILD -tensorflow/third_party/libxsmm.BUILD tensorflow/third_party/zlib.BUILD +tensorflow/third_party/lmdb.BUILD +tensorflow/third_party/nanopb.BUILD +tensorflow/third_party/android/android.bzl.tpl +tensorflow/third_party/android/BUILD +tensorflow/third_party/android/android_configure.BUILD.tpl +tensorflow/third_party/android/android_configure.bzl +tensorflow/third_party/tflite_mobilenet_float.BUILD +tensorflow/third_party/sqlite.BUILD +tensorflow/third_party/tensorrt/build_defs.bzl.tpl +tensorflow/third_party/tensorrt/LICENSE +tensorflow/third_party/tensorrt/tensorrt_configure.bzl +tensorflow/third_party/tensorrt/BUILD.tpl +tensorflow/third_party/tensorrt/BUILD +tensorflow/third_party/gast.BUILD +tensorflow/third_party/mpi_collectives/BUILD +tensorflow/third_party/libxsmm.BUILD tensorflow/third_party/eigen.BUILD +tensorflow/third_party/com_google_absl.BUILD +tensorflow/third_party/eigen3/LICENSE +tensorflow/third_party/eigen3/gpu_packet_math.patch +tensorflow/third_party/eigen3/BUILD +tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions +tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool +tensorflow/third_party/eigen3/Eigen/QR +tensorflow/third_party/eigen3/Eigen/SVD +tensorflow/third_party/eigen3/Eigen/LU +tensorflow/third_party/eigen3/Eigen/Cholesky +tensorflow/third_party/eigen3/Eigen/Eigenvalues +tensorflow/third_party/eigen3/Eigen/Core +tensorflow/third_party/BUILD +tensorflow/third_party/termcolor.BUILD +tensorflow/third_party/gif.BUILD +tensorflow/third_party/tflite_mobilenet.BUILD +tensorflow/third_party/mkl/LICENSE +tensorflow/third_party/mkl/build_defs.bzl +tensorflow/third_party/mkl/mkl.BUILD +tensorflow/third_party/mkl/MKL_LICENSE +tensorflow/third_party/mkl/BUILD +tensorflow/third_party/nccl/build_defs.bzl.tpl +tensorflow/third_party/nccl/LICENSE +tensorflow/third_party/nccl/nccl_configure.bzl +tensorflow/third_party/nccl/archive.BUILD +tensorflow/third_party/nccl/BUILD +tensorflow/third_party/nccl/system.BUILD.tpl +tensorflow/third_party/snappy.BUILD +tensorflow/third_party/python_runtime/BUILD +tensorflow/third_party/googleapis.BUILD +tensorflow/third_party/boringssl/BUILD +tensorflow/third_party/protobuf/BUILD +tensorflow/third_party/backports_weakref.BUILD +tensorflow/third_party/tflite_smartreply.BUILD +tensorflow/third_party/swig.BUILD +tensorflow/compat_template.__init__.py +tensorflow/tools/lib_package/libtensorflow_test.sh +tensorflow/tools/lib_package/libtensorflow_java_test.sh +tensorflow/tools/lib_package/libtensorflow_test.c +tensorflow/tools/lib_package/concat_licenses.sh +tensorflow/tools/lib_package/LibTensorFlowTest.java +tensorflow/tools/lib_package/BUILD +tensorflow/tools/lib_package/README.md +tensorflow/tools/pip_package/check_load_py_test.py +tensorflow/tools/pip_package/simple_console.py +tensorflow/tools/pip_package/pip_smoke_test.py +tensorflow/tools/pip_package/BUILD +tensorflow/tools/pip_package/simple_console_for_windows.py +tensorflow/tools/pip_package/build_pip_package.sh +tensorflow/tools/pip_package/README +tensorflow/tools/pip_package/setup.py +tensorflow/tools/pip_package/MANIFEST.in +tensorflow/tools/ci_build/remote/BUILD +tensorflow/tools/def_file_filter/def_file_filter.py.tpl +tensorflow/tools/def_file_filter/BUILD.tpl +tensorflow/tools/def_file_filter/BUILD +tensorflow/tools/def_file_filter/def_file_filter_configure.bzl +tensorflow/api_template.__init__.py +tensorflow/contrib/tpu/profiler/pip_package/BUILD +tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py +tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh +tensorflow/contrib/tpu/profiler/pip_package/README +tensorflow/contrib/tpu/profiler/pip_package/setup.py +tensorflow/contrib/mpi/BUILD +tensorflow/__init__.py tensorflow/stream_executor/build_defs.bzl tensorflow/api_template_v1.__init__.py -tensorflow/compat_template_v1.__init__.py -tensorflow/compat_template.__init__.py -tensorflow/api_template.__init__.py -tensorflow/__init__.py \ No newline at end of file +tensorflow/compat_template_v1.__init__.py \ No newline at end of file diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index f66d851dedf..6106f3598e1 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2551,6 +2551,12 @@ class Operation(object): shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes) self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list)) + def _clear_attr(self, attr_name): + """Private method used to clear an attribute in the node_def.""" + # pylint: disable=protected-access + c_api.ClearAttr(self._graph._c_graph, self._c_op, attr_name) + # pylint: enable=protected-access + def get_attr(self, name): """Returns the value of the attr of this op with the given `name`. diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py index 02489a9b10e..55273a5203e 100644 --- a/tensorflow/python/tpu/tpu.py +++ b/tensorflow/python/tpu/tpu.py @@ -62,6 +62,10 @@ _UNSUPPORTED_OPS = set([ "TensorSummaryV2", ]) +# Ops which can be safely pruned from XLA compile if they have no consumers. +# These ops should also have no inputs. +_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"]) + _MAX_WARNING_LINES = 5 _TPU_REPLICATE_ATTR = "_tpu_replicate" @@ -1574,3 +1578,31 @@ def rewrite_for_inference(computation, device_assignment=device_assignment, name=name) # pylint: enable=undefined-variable + + +def prune_unconnected_ops_from_xla(prune_graph): + """Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE. + + Args: + prune_graph: A tensorflow graph from which we wish to prune unconnected ops + as listed in _UNCONNECTED_OPS_TO_PRUNE. In general, these ops should have + no inputs and no consumers. These can often be left behind due to graph + construction rewiring (for instance TF-Hub). While they never execute, + they will cause XLA compile to fail so we strip them from XLA compile by + removing the tpu_replicate attribute. + """ + # Scan over the top level graph and all function graphs. + for graph in [prune_graph] + list(prune_graph._functions.values()): # pylint: disable=protected-access + for op in graph.get_operations(): + if op.type not in _UNCONNECTED_OPS_TO_PRUNE: + continue + outputs_consumed = False + for output in op.outputs: + if output.consumers(): + outputs_consumed = True + break + if not outputs_consumed: + logging.info( + "Pruning OP %s of type %s from XLA Compile due to " + "it being disconnected.", op.name, op.type) + op._clear_attr(_TPU_REPLICATE_ATTR) # pylint: disable=protected-access diff --git a/tensorflow/python/tpu/tpu_estimator.py b/tensorflow/python/tpu/tpu_estimator.py index 83e86407662..9c47727772e 100644 --- a/tensorflow/python/tpu/tpu_estimator.py +++ b/tensorflow/python/tpu/tpu_estimator.py @@ -2477,6 +2477,7 @@ class TPUEstimator(estimator_lib.Estimator): if self._experimental_exported_model_uses_all_cores: tensors_on_cpu = tpu.rewrite( tpu_computation, device_assignment=device_assignment) + tpu.prune_unconnected_ops_from_xla(ops.get_default_graph()) else: tensors_on_cpu = tpu.rewrite_for_inference( tpu_computation, device_assignment=device_assignment) diff --git a/tensorflow/python/tpu/tpu_test.py b/tensorflow/python/tpu/tpu_test.py index 69b03811daa..914f7322707 100644 --- a/tensorflow/python/tpu/tpu_test.py +++ b/tensorflow/python/tpu/tpu_test.py @@ -19,12 +19,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.tpu import tpu from tensorflow.python.tpu import tpu_feed @@ -76,5 +81,63 @@ class TPULayerRewriteTest(test.TestCase): tpu.rewrite(loop) +class TPUGraphPruneTest(test.TestCase): + + def test_prune_unconnected_ops(self): + with ops.Graph().as_default(): + a = array_ops.placeholder(dtype=dtypes.float32, name="a") + b = array_ops.placeholder(dtype=dtypes.float32, name="b") + constant_op.constant(1.0, name="constant") + x = variable_scope.get_variable( + name="x", + dtype=dtypes.float32, + shape=[], + use_resource=True, + initializer=init_ops.constant_initializer(2.0)) + y = variable_scope.get_variable( + name="y", + dtype=dtypes.float32, + shape=[], + use_resource=True, + initializer=init_ops.constant_initializer(3.0)) + math_ops.add(a, b) + math_ops.add(x, y) + graph_def = ops.get_default_graph().as_graph_def() + + for node in graph_def.node: + # Attach a TPU_REPLICATE_ATTR to each node. + node.attr[tpu._TPU_REPLICATE_ATTR].s = b"0" + # Rewire placeholder "a" and variable "y" leaving them unconnected. + for (input_index, node_input) in enumerate(node.input): + if node_input == "b": + node.input[input_index] = "constant" + if node_input == "y": + node.input[input_index] = "x" + + with ops.Graph().as_default() as graph: + # Reimport the graph and prune unconnected ops. + importer.import_graph_def(graph_def) + tpu.prune_unconnected_ops_from_xla(ops.get_default_graph()) + + # Verify that ops "a" and "x" still have TPU_REPLICATE_ATTR. + a = graph.get_operation_by_name("import/a").get_attr( + tpu._TPU_REPLICATE_ATTR) + self.assertEqual(b"0", a) + x = graph.get_operation_by_name("import/x").get_attr( + tpu._TPU_REPLICATE_ATTR) + self.assertEqual(b"0", x) + # Verify that ops "b" and "y" have TPU_REPLICATE_ATTR removed. + with self.assertRaisesRegexp( + ValueError, + "Operation \'import/b\' has no attr named \'_tpu_replicate\'"): + graph.get_operation_by_name("import/b").get_attr( + tpu._TPU_REPLICATE_ATTR) + with self.assertRaisesRegexp( + ValueError, + "Operation \'import/y\' has no attr named \'_tpu_replicate\'"): + graph.get_operation_by_name("import/y").get_attr( + tpu._TPU_REPLICATE_ATTR) + + if __name__ == "__main__": test.main()