Enable model_to_estimator for wide and deep model.

PiperOrigin-RevId: 264260434
This commit is contained in:
Zhenyu Tan 2019-08-19 15:47:21 -07:00 committed by TensorFlower Gardener
parent 8e2ca26f57
commit 372a2dbe96
2 changed files with 85 additions and 60 deletions

View File

@ -1,6 +1,6 @@
tensorflow/__init__.py tensorflow/__init__.py
tensorflow/api_template.__init__.py
tensorflow/api_template_v1.__init__.py tensorflow/api_template_v1.__init__.py
tensorflow/api_template.__init__.py
tensorflow/compat_template.__init__.py tensorflow/compat_template.__init__.py
tensorflow/compat_template_v1.__init__.py tensorflow/compat_template_v1.__init__.py
tensorflow/contrib/mpi/BUILD tensorflow/contrib/mpi/BUILD
@ -28,23 +28,23 @@ tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/common.bzl tensorflow/third_party/common.bzl
tensorflow/third_party/cub.BUILD tensorflow/third_party/cub.BUILD
tensorflow/third_party/curl.BUILD tensorflow/third_party/curl.BUILD
tensorflow/third_party/double_conversion.BUILD tensorflow/third_party/cython.BUILD
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/eigen3/Eigen/Cholesky tensorflow/third_party/eigen3/Eigen/Cholesky
tensorflow/third_party/eigen3/Eigen/Eigenvalues tensorflow/third_party/eigen3/Eigen/Eigenvalues
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/eigen3/Eigen/LU tensorflow/third_party/eigen3/Eigen/LU
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/Eigen/QR tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/BUILD tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/eigen3/LICENSE tensorflow/third_party/eigen3/LICENSE
tensorflow/third_party/eigen3/gpu_packet_math.patch tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
@ -52,7 +52,7 @@ tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCasting
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/cython.BUILD tensorflow/third_party/double_conversion.BUILD
tensorflow/third_party/eigen.BUILD tensorflow/third_party/eigen.BUILD
tensorflow/third_party/fft2d/BUILD tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/fft2d/LICENSE tensorflow/third_party/fft2d/LICENSE
@ -64,28 +64,28 @@ tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/functools32.BUILD tensorflow/third_party/functools32.BUILD
tensorflow/third_party/gast.BUILD tensorflow/third_party/gast.BUILD
tensorflow/third_party/git/BUILD tensorflow/third_party/git/BUILD
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/git/BUILD.tpl tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/gif.BUILD tensorflow/third_party/gif.BUILD
tensorflow/third_party/gpus/crosstool/BUILD tensorflow/third_party/gpus/crosstool/BUILD
tensorflow/third_party/gpus/crosstool/BUILD.tpl tensorflow/third_party/gpus/crosstool/BUILD.tpl
tensorflow/third_party/gpus/crosstool/LICENSE tensorflow/third_party/gpus/crosstool/LICENSE
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
tensorflow/third_party/gpus/BUILD tensorflow/third_party/gpus/BUILD
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD tensorflow/third_party/gpus/cuda/BUILD
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
tensorflow/third_party/gpus/cuda/LICENSE tensorflow/third_party/gpus/cuda/LICENSE
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/rocm/BUILD tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/BUILD.tpl tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/find_cuda_config.py tensorflow/third_party/gpus/find_cuda_config.py
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/rocm_configure.bzl tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/googleapis.BUILD tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/grpc/BUILD tensorflow/third_party/grpc/BUILD
@ -99,12 +99,12 @@ tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.autogenerated.BUILD tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/llvm.bzl tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/linenoise.BUILD tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/mkl/BUILD tensorflow/third_party/mkl/BUILD
tensorflow/third_party/mkl/LICENSE tensorflow/third_party/mkl/LICENSE
tensorflow/third_party/mkl/MKL_LICENSE tensorflow/third_party/mkl/MKL_LICENSE
tensorflow/third_party/mkl/mkl.BUILD tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/mkl/build_defs.bzl tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/mkl_dnn/LICENSE tensorflow/third_party/mkl_dnn/LICENSE
tensorflow/third_party/mkl_dnn/mkldnn.BUILD tensorflow/third_party/mkl_dnn/mkldnn.BUILD
tensorflow/third_party/mpi/.gitignore tensorflow/third_party/mpi/.gitignore
@ -112,19 +112,19 @@ tensorflow/third_party/mpi/BUILD
tensorflow/third_party/mpi_collectives/BUILD tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/nanopb.BUILD tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/nccl/BUILD tensorflow/third_party/nccl/BUILD
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/LICENSE tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/build_defs.bzl.tpl tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/archive.patch tensorflow/third_party/nccl/archive.patch
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/nccl_configure.bzl tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/system.BUILD.tpl tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/ngraph/BUILD tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/NGRAPH_LICENSE tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/ngraph/build_defs.bzl tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/ngraph.BUILD tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/ngraph_tf.BUILD tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/ngraph/tbb.BUILD tensorflow/third_party/ngraph/tbb.BUILD
tensorflow/third_party/opt_einsum.BUILD tensorflow/third_party/opt_einsum.BUILD
tensorflow/third_party/pcre.BUILD tensorflow/third_party/pcre.BUILD
@ -133,46 +133,46 @@ tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/png_fix_rpi.patch tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/pprof.BUILD tensorflow/third_party/pprof.BUILD
tensorflow/third_party/py/numpy/BUILD tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/BUILD.tpl tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/python_configure.bzl tensorflow/third_party/py/python_configure.bzl
tensorflow/third_party/python_runtime/BUILD tensorflow/third_party/python_runtime/BUILD
tensorflow/third_party/pybind11.BUILD tensorflow/third_party/pybind11.BUILD
tensorflow/third_party/repo.bzl tensorflow/third_party/repo.bzl
tensorflow/third_party/six.BUILD tensorflow/third_party/sycl/crosstool/BUILD
tensorflow/third_party/snappy.BUILD tensorflow/third_party/snappy.BUILD
tensorflow/third_party/sqlite.BUILD tensorflow/third_party/sqlite.BUILD
tensorflow/third_party/sycl/crosstool/BUILD tensorflow/third_party/six.BUILD
tensorflow/third_party/swig.BUILD tensorflow/third_party/swig.BUILD
tensorflow/third_party/systemlibs/BUILD.tpl
tensorflow/third_party/systemlibs/BUILD tensorflow/third_party/systemlibs/BUILD
tensorflow/third_party/systemlibs/BUILD.tpl
tensorflow/third_party/systemlibs/absl_py.BUILD tensorflow/third_party/systemlibs/absl_py.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/astor.BUILD tensorflow/third_party/systemlibs/astor.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/systemlibs/boringssl.BUILD tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/curl.BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/curl.BUILD
tensorflow/third_party/systemlibs/cython.BUILD tensorflow/third_party/systemlibs/cython.BUILD
tensorflow/third_party/systemlibs/double_conversion.BUILD tensorflow/third_party/systemlibs/double_conversion.BUILD
tensorflow/third_party/systemlibs/gast.BUILD tensorflow/third_party/systemlibs/gast.BUILD
tensorflow/third_party/systemlibs/gif.BUILD tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/nsync.BUILD tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/opt_einsum.BUILD tensorflow/third_party/systemlibs/opt_einsum.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/png.BUILD tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/re2.BUILD tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/six.BUILD tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/sqlite.BUILD tensorflow/third_party/systemlibs/sqlite.BUILD
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/swig.BUILD tensorflow/third_party/systemlibs/swig.BUILD
tensorflow/third_party/systemlibs/syslibs_configure.bzl tensorflow/third_party/systemlibs/syslibs_configure.bzl
tensorflow/third_party/systemlibs/termcolor.BUILD tensorflow/third_party/systemlibs/termcolor.BUILD
@ -188,12 +188,12 @@ tensorflow/third_party/tflite_mobilenet.BUILD
tensorflow/third_party/tflite_mobilenet_float.BUILD tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/toolchains/clang6/BUILD tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/clang6/README.md tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/clang6/repo.bzl tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/BUILD tensorflow/third_party/toolchains/BUILD
tensorflow/third_party/toolchains/cpus/arm/BUILD
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
tensorflow/third_party/toolchains/cpus/arm/BUILD
tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl
tensorflow/third_party/toolchains/cpus/py/BUILD tensorflow/third_party/toolchains/cpus/py/BUILD
tensorflow/third_party/toolchains/cpus/py3/BUILD tensorflow/third_party/toolchains/cpus/py3/BUILD
@ -214,13 +214,13 @@ tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/generate/BUILD tensorflow/third_party/toolchains/preconfig/generate/BUILD
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
@ -239,44 +239,44 @@ tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3_opt/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3_opt/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/build_defs.bzl tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/BUILD.tpl tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/toolchains/remote/configure.bzl tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/execution.bzl.tpl tensorflow/third_party/toolchains/remote/execution.bzl.tpl
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/tflite_mobilenet_quant.BUILD tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/tflite_smartreply.BUILD tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/wrapt.BUILD tensorflow/third_party/wrapt.BUILD
tensorflow/third_party/zlib.BUILD tensorflow/third_party/zlib.BUILD
tensorflow/tools/ci_build/remote/BUILD tensorflow/tools/ci_build/remote/BUILD
tensorflow/tools/def_file_filter/BUILD tensorflow/tools/def_file_filter/BUILD
tensorflow/tools/def_file_filter/BUILD.tpl
tensorflow/tools/def_file_filter/def_file_filter.py.tpl tensorflow/tools/def_file_filter/def_file_filter.py.tpl
tensorflow/tools/def_file_filter/BUILD.tpl
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
tensorflow/tools/lib_package/BUILD tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/LibTensorFlowTest.java tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/README.md tensorflow/tools/lib_package/README.md
tensorflow/tools/lib_package/concat_licenses.sh tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/libtensorflow_java_test.sh tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/libtensorflow_test.sh tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/pip_package/BUILD tensorflow/tools/pip_package/BUILD
tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/pip_package/README tensorflow/tools/pip_package/README
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/pip_smoke_test.py tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/setup.py tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/tools/pip_package/check_load_py_test.py tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/virtual_root_template_v2.__init__.py tensorflow/virtual_root_template_v2.__init__.py
tensorflow/virtual_root_template_v1.__init__.py tensorflow/virtual_root_template_v1.__init__.py
llvm/llvm/projects/google_mlir/WORKSPACE llvm/llvm/projects/google_mlir/WORKSPACE

View File

@ -32,6 +32,7 @@ from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.engine.network import Network from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import object_identity from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -575,8 +576,8 @@ def clone_and_build_model(
Args: Args:
model: `tf.keras.Model` object. Can be Functional, Sequential, or model: `tf.keras.Model` object. Can be Functional, Sequential, or
sub-classed. sub-classed.
input_tensors: Optional list of input tensors to build the model upon. If input_tensors: Optional list or dictionary of input tensors to build the
not provided, placeholders will be created. model upon. If not provided, placeholders will be created.
target_tensors: Optional list of target tensors for compiling the model. If target_tensors: Optional list of target tensors for compiling the model. If
not provided, placeholders will be created. not provided, placeholders will be created.
custom_objects: Optional dictionary mapping string names to custom classes custom_objects: Optional dictionary mapping string names to custom classes
@ -591,10 +592,10 @@ def clone_and_build_model(
optimizer if the clone is compiled. This argument is used when a Keras optimizer if the clone is compiled. This argument is used when a Keras
model is cloned into an Estimator model function, because Estimators model is cloned into an Estimator model function, because Estimators
create their own global step variable. create their own global step variable.
optimizer_config: Optimizer config dictionary returned from `get_config()`. optimizer_config: Optimizer config dictionary or list of dictionary
This argument should be defined if `clone_and_build_model` is called in returned from `get_config()`. This argument should be defined if
a different graph or session from the original model, and the optimizer is `clone_and_build_model` is called in a different graph or session from
an instance of `OptimizerV2`. the original model, and the optimizer is an instance of `OptimizerV2`.
Returns: Returns:
Clone of the model. Clone of the model.
@ -628,6 +629,14 @@ def clone_and_build_model(
clone._set_inputs( clone._set_inputs(
K.placeholder(model._build_input_shape, dtype=model.inputs[0].dtype)) K.placeholder(model._build_input_shape, dtype=model.inputs[0].dtype))
else: else:
try:
# Prefer clonining the model if serial/deserial logic is implemented for
# subclassed model.
clone = model.__class__.from_config(model.get_config())
except NotImplementedError:
logging.warning('This model is a subclassed model. Please implement '
'`get_config` and `from_config` to better support '
'cloning the model.')
if not in_place_reset: if not in_place_reset:
raise ValueError( raise ValueError(
'This model is a subclassed model. ' 'This model is a subclassed model. '
@ -649,11 +658,27 @@ def clone_and_build_model(
orig_optimizer.optimizer, optimizer_iterations) orig_optimizer.optimizer, optimizer_iterations)
K.track_tf_optimizer(optimizer) K.track_tf_optimizer(optimizer)
else: else:
optimizer_config = optimizer_config or orig_optimizer.get_config() if not isinstance(orig_optimizer, (tuple, list)):
optimizer = orig_optimizer.__class__.from_config(optimizer_config) orig_optimizer = [orig_optimizer]
if optimizer_config is None:
optimizer = [
opt.__class__.from_config(opt.get_config())
for opt in orig_optimizer
]
elif isinstance(optimizer_config, dict):
optimizer = [orig_optimizer[0].__class__.from_config(optimizer_config)]
else:
# optimizer config is list of dict, same order as orig_optimizer.
optimizer = [
opt.__class__.from_config(opt_config)
for (opt, opt_config) in zip(orig_optimizer, optimizer_config)
]
if optimizer_iterations is not None: if optimizer_iterations is not None:
optimizer.iterations = optimizer_iterations for opt in optimizer:
opt.iterations = optimizer_iterations
if len(optimizer) == 1:
optimizer = optimizer[0]
clone.compile( clone.compile(
optimizer, optimizer,
model.loss, model.loss,