From 372a2dbe96dd62ab966c81a82c2aa72222cc2265 Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Mon, 19 Aug 2019 15:47:21 -0700 Subject: [PATCH] Enable model_to_estimator for wide and deep model. PiperOrigin-RevId: 264260434 --- tensorflow/opensource_only.files | 82 +++++++++++++++---------------- tensorflow/python/keras/models.py | 63 +++++++++++++++++------- 2 files changed, 85 insertions(+), 60 deletions(-) diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index d328d34c758..9e6ebf68d93 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -1,6 +1,6 @@ tensorflow/__init__.py -tensorflow/api_template.__init__.py tensorflow/api_template_v1.__init__.py +tensorflow/api_template.__init__.py tensorflow/compat_template.__init__.py tensorflow/compat_template_v1.__init__.py tensorflow/contrib/mpi/BUILD @@ -28,23 +28,23 @@ tensorflow/third_party/com_google_absl.BUILD tensorflow/third_party/common.bzl tensorflow/third_party/cub.BUILD tensorflow/third_party/curl.BUILD -tensorflow/third_party/double_conversion.BUILD +tensorflow/third_party/cython.BUILD +tensorflow/third_party/eigen3/Eigen/Core tensorflow/third_party/eigen3/Eigen/Cholesky tensorflow/third_party/eigen3/Eigen/Eigenvalues -tensorflow/third_party/eigen3/Eigen/Core tensorflow/third_party/eigen3/Eigen/LU -tensorflow/third_party/eigen3/Eigen/SVD tensorflow/third_party/eigen3/Eigen/QR +tensorflow/third_party/eigen3/Eigen/SVD tensorflow/third_party/eigen3/BUILD tensorflow/third_party/eigen3/LICENSE tensorflow/third_party/eigen3/gpu_packet_math.patch -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint +tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h -tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h @@ -52,7 +52,7 @@ tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCasting tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions -tensorflow/third_party/cython.BUILD +tensorflow/third_party/double_conversion.BUILD tensorflow/third_party/eigen.BUILD tensorflow/third_party/fft2d/BUILD tensorflow/third_party/fft2d/LICENSE @@ -64,28 +64,28 @@ tensorflow/third_party/farmhash.BUILD tensorflow/third_party/functools32.BUILD tensorflow/third_party/gast.BUILD tensorflow/third_party/git/BUILD -tensorflow/third_party/git/git_configure.bzl tensorflow/third_party/git/BUILD.tpl +tensorflow/third_party/git/git_configure.bzl tensorflow/third_party/gif.BUILD tensorflow/third_party/gpus/crosstool/BUILD tensorflow/third_party/gpus/crosstool/BUILD.tpl tensorflow/third_party/gpus/crosstool/LICENSE -tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl tensorflow/third_party/gpus/BUILD -tensorflow/third_party/gpus/cuda/BUILD.tpl tensorflow/third_party/gpus/cuda/BUILD +tensorflow/third_party/gpus/cuda/BUILD.tpl tensorflow/third_party/gpus/cuda/BUILD.windows.tpl tensorflow/third_party/gpus/cuda/LICENSE tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl tensorflow/third_party/gpus/cuda/cuda_config.h.tpl +tensorflow/third_party/gpus/cuda_configure.bzl tensorflow/third_party/gpus/rocm/BUILD tensorflow/third_party/gpus/rocm/BUILD.tpl tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl tensorflow/third_party/gpus/rocm/rocm_config.h.tpl tensorflow/third_party/gpus/find_cuda_config.py -tensorflow/third_party/gpus/cuda_configure.bzl tensorflow/third_party/gpus/rocm_configure.bzl tensorflow/third_party/googleapis.BUILD tensorflow/third_party/grpc/BUILD @@ -99,12 +99,12 @@ tensorflow/third_party/llvm/expand_cmake_vars.py tensorflow/third_party/llvm/llvm.autogenerated.BUILD tensorflow/third_party/llvm/llvm.bzl tensorflow/third_party/linenoise.BUILD +tensorflow/third_party/lmdb.BUILD tensorflow/third_party/mkl/BUILD tensorflow/third_party/mkl/LICENSE tensorflow/third_party/mkl/MKL_LICENSE tensorflow/third_party/mkl/mkl.BUILD tensorflow/third_party/mkl/build_defs.bzl -tensorflow/third_party/lmdb.BUILD tensorflow/third_party/mkl_dnn/LICENSE tensorflow/third_party/mkl_dnn/mkldnn.BUILD tensorflow/third_party/mpi/.gitignore @@ -112,19 +112,19 @@ tensorflow/third_party/mpi/BUILD tensorflow/third_party/mpi_collectives/BUILD tensorflow/third_party/nanopb.BUILD tensorflow/third_party/nccl/BUILD -tensorflow/third_party/nccl/archive.BUILD tensorflow/third_party/nccl/LICENSE -tensorflow/third_party/nccl/build_defs.bzl.tpl +tensorflow/third_party/nccl/archive.BUILD tensorflow/third_party/nccl/archive.patch +tensorflow/third_party/nccl/build_defs.bzl.tpl tensorflow/third_party/nccl/nccl_configure.bzl tensorflow/third_party/nccl/system.BUILD.tpl tensorflow/third_party/ngraph/BUILD -tensorflow/third_party/ngraph/LICENSE tensorflow/third_party/ngraph/NGRAPH_LICENSE -tensorflow/third_party/ngraph/build_defs.bzl +tensorflow/third_party/ngraph/LICENSE tensorflow/third_party/ngraph/ngraph.BUILD -tensorflow/third_party/ngraph/nlohmann_json.BUILD +tensorflow/third_party/ngraph/build_defs.bzl tensorflow/third_party/ngraph/ngraph_tf.BUILD +tensorflow/third_party/ngraph/nlohmann_json.BUILD tensorflow/third_party/ngraph/tbb.BUILD tensorflow/third_party/opt_einsum.BUILD tensorflow/third_party/pcre.BUILD @@ -133,46 +133,46 @@ tensorflow/third_party/protobuf/BUILD tensorflow/third_party/png_fix_rpi.patch tensorflow/third_party/pprof.BUILD tensorflow/third_party/py/numpy/BUILD -tensorflow/third_party/py/BUILD tensorflow/third_party/py/BUILD.tpl +tensorflow/third_party/py/BUILD tensorflow/third_party/py/python_configure.bzl tensorflow/third_party/python_runtime/BUILD tensorflow/third_party/pybind11.BUILD tensorflow/third_party/repo.bzl -tensorflow/third_party/six.BUILD +tensorflow/third_party/sycl/crosstool/BUILD tensorflow/third_party/snappy.BUILD tensorflow/third_party/sqlite.BUILD -tensorflow/third_party/sycl/crosstool/BUILD +tensorflow/third_party/six.BUILD tensorflow/third_party/swig.BUILD -tensorflow/third_party/systemlibs/BUILD.tpl tensorflow/third_party/systemlibs/BUILD +tensorflow/third_party/systemlibs/BUILD.tpl tensorflow/third_party/systemlibs/absl_py.BUILD tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD tensorflow/third_party/systemlibs/astor.BUILD tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD tensorflow/third_party/systemlibs/boringssl.BUILD -tensorflow/third_party/systemlibs/curl.BUILD tensorflow/third_party/systemlibs/build_defs.bzl.tpl +tensorflow/third_party/systemlibs/curl.BUILD tensorflow/third_party/systemlibs/cython.BUILD tensorflow/third_party/systemlibs/double_conversion.BUILD tensorflow/third_party/systemlibs/gast.BUILD tensorflow/third_party/systemlibs/gif.BUILD tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD -tensorflow/third_party/systemlibs/googleapis.BUILD tensorflow/third_party/systemlibs/grpc.BUILD +tensorflow/third_party/systemlibs/googleapis.BUILD tensorflow/third_party/systemlibs/lmdb.BUILD tensorflow/third_party/systemlibs/jsoncpp.BUILD tensorflow/third_party/systemlibs/nsync.BUILD tensorflow/third_party/systemlibs/opt_einsum.BUILD tensorflow/third_party/systemlibs/pcre.BUILD -tensorflow/third_party/systemlibs/protobuf.BUILD tensorflow/third_party/systemlibs/png.BUILD +tensorflow/third_party/systemlibs/protobuf.BUILD tensorflow/third_party/systemlibs/protobuf.bzl tensorflow/third_party/systemlibs/re2.BUILD -tensorflow/third_party/systemlibs/snappy.BUILD tensorflow/third_party/systemlibs/six.BUILD tensorflow/third_party/systemlibs/sqlite.BUILD +tensorflow/third_party/systemlibs/snappy.BUILD tensorflow/third_party/systemlibs/swig.BUILD tensorflow/third_party/systemlibs/syslibs_configure.bzl tensorflow/third_party/systemlibs/termcolor.BUILD @@ -188,12 +188,12 @@ tensorflow/third_party/tflite_mobilenet.BUILD tensorflow/third_party/tflite_mobilenet_float.BUILD tensorflow/third_party/toolchains/clang6/BUILD tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl -tensorflow/third_party/toolchains/clang6/clang.BUILD tensorflow/third_party/toolchains/clang6/README.md +tensorflow/third_party/toolchains/clang6/clang.BUILD tensorflow/third_party/toolchains/clang6/repo.bzl tensorflow/third_party/toolchains/BUILD -tensorflow/third_party/toolchains/cpus/arm/BUILD tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl +tensorflow/third_party/toolchains/cpus/arm/BUILD tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl tensorflow/third_party/toolchains/cpus/py/BUILD tensorflow/third_party/toolchains/cpus/py3/BUILD @@ -214,13 +214,13 @@ tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl tensorflow/third_party/toolchains/preconfig/generate/BUILD tensorflow/third_party/toolchains/preconfig/generate/archives.bzl -tensorflow/third_party/toolchains/preconfig/generate/generate.bzl tensorflow/third_party/toolchains/preconfig/generate/containers.bzl +tensorflow/third_party/toolchains/preconfig/generate/generate.bzl tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl -tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl -tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD +tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl +tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl @@ -239,44 +239,44 @@ tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc- tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3_opt/BUILD -tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/BUILD +tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/BUILD tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/build_defs.bzl tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD tensorflow/third_party/toolchains/preconfig/win_1803/BUILD tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD +tensorflow/third_party/toolchains/remote/BUILD tensorflow/third_party/toolchains/remote/BUILD.tpl tensorflow/third_party/toolchains/remote/configure.bzl -tensorflow/third_party/toolchains/remote/BUILD tensorflow/third_party/toolchains/remote/execution.bzl.tpl -tensorflow/third_party/tflite_ovic_testdata.BUILD tensorflow/third_party/tflite_mobilenet_quant.BUILD +tensorflow/third_party/tflite_ovic_testdata.BUILD tensorflow/third_party/tflite_smartreply.BUILD tensorflow/third_party/wrapt.BUILD tensorflow/third_party/zlib.BUILD tensorflow/tools/ci_build/remote/BUILD tensorflow/tools/def_file_filter/BUILD -tensorflow/tools/def_file_filter/BUILD.tpl tensorflow/tools/def_file_filter/def_file_filter.py.tpl +tensorflow/tools/def_file_filter/BUILD.tpl tensorflow/tools/def_file_filter/def_file_filter_configure.bzl tensorflow/tools/lib_package/BUILD tensorflow/tools/lib_package/LibTensorFlowTest.java tensorflow/tools/lib_package/README.md tensorflow/tools/lib_package/concat_licenses.sh -tensorflow/tools/lib_package/libtensorflow_test.c tensorflow/tools/lib_package/libtensorflow_java_test.sh +tensorflow/tools/lib_package/libtensorflow_test.c tensorflow/tools/lib_package/libtensorflow_test.sh -tensorflow/tools/pip_package/MANIFEST.in tensorflow/tools/pip_package/BUILD +tensorflow/tools/pip_package/MANIFEST.in tensorflow/tools/pip_package/README -tensorflow/tools/pip_package/build_pip_package.sh tensorflow/tools/pip_package/pip_smoke_test.py -tensorflow/tools/pip_package/setup.py -tensorflow/tools/pip_package/simple_console.py -tensorflow/tools/pip_package/simple_console_for_windows.py +tensorflow/tools/pip_package/build_pip_package.sh tensorflow/tools/pip_package/check_load_py_test.py +tensorflow/tools/pip_package/simple_console.py +tensorflow/tools/pip_package/setup.py +tensorflow/tools/pip_package/simple_console_for_windows.py tensorflow/virtual_root_template_v2.__init__.py tensorflow/virtual_root_template_v1.__init__.py llvm/llvm/projects/google_mlir/WORKSPACE \ No newline at end of file diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py index 888b306f6fe..bf157da6878 100644 --- a/tensorflow/python/keras/models.py +++ b/tensorflow/python/keras/models.py @@ -32,6 +32,7 @@ from tensorflow.python.keras.engine.input_layer import InputLayer from tensorflow.python.keras.engine.network import Network from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils.generic_utils import CustomObjectScope +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest from tensorflow.python.util import object_identity from tensorflow.python.util.tf_export import keras_export @@ -575,8 +576,8 @@ def clone_and_build_model( Args: model: `tf.keras.Model` object. Can be Functional, Sequential, or sub-classed. - input_tensors: Optional list of input tensors to build the model upon. If - not provided, placeholders will be created. + input_tensors: Optional list or dictionary of input tensors to build the + model upon. If not provided, placeholders will be created. target_tensors: Optional list of target tensors for compiling the model. If not provided, placeholders will be created. custom_objects: Optional dictionary mapping string names to custom classes @@ -591,10 +592,10 @@ def clone_and_build_model( optimizer if the clone is compiled. This argument is used when a Keras model is cloned into an Estimator model function, because Estimators create their own global step variable. - optimizer_config: Optimizer config dictionary returned from `get_config()`. - This argument should be defined if `clone_and_build_model` is called in - a different graph or session from the original model, and the optimizer is - an instance of `OptimizerV2`. + optimizer_config: Optimizer config dictionary or list of dictionary + returned from `get_config()`. This argument should be defined if + `clone_and_build_model` is called in a different graph or session from + the original model, and the optimizer is an instance of `OptimizerV2`. Returns: Clone of the model. @@ -628,16 +629,24 @@ def clone_and_build_model( clone._set_inputs( K.placeholder(model._build_input_shape, dtype=model.inputs[0].dtype)) else: - if not in_place_reset: - raise ValueError( - 'This model is a subclassed model. ' - 'Such a model cannot be cloned, but there is a workaround where ' - 'the model is reset in-place. To use this, please set the argument ' - '`in_place_reset` to `True`. This will reset the attributes in the ' - 'original model. To restore the attributes, call ' - '`in_place_subclassed_model_state_restoration(model)`.') - clone = model - _in_place_subclassed_model_reset(clone) + try: + # Prefer clonining the model if serial/deserial logic is implemented for + # subclassed model. + clone = model.__class__.from_config(model.get_config()) + except NotImplementedError: + logging.warning('This model is a subclassed model. Please implement ' + '`get_config` and `from_config` to better support ' + 'cloning the model.') + if not in_place_reset: + raise ValueError( + 'This model is a subclassed model. ' + 'Such a model cannot be cloned, but there is a workaround where ' + 'the model is reset in-place. To use this, please set the argument ' + '`in_place_reset` to `True`. This will reset the attributes in the ' + 'original model. To restore the attributes, call ' + '`in_place_subclassed_model_state_restoration(model)`.') + clone = model + _in_place_subclassed_model_reset(clone) if input_tensors is not None: if isinstance(input_tensors, (list, tuple)) and len(input_tensors) == 1: input_tensors = input_tensors[0] @@ -649,11 +658,27 @@ def clone_and_build_model( orig_optimizer.optimizer, optimizer_iterations) K.track_tf_optimizer(optimizer) else: - optimizer_config = optimizer_config or orig_optimizer.get_config() - optimizer = orig_optimizer.__class__.from_config(optimizer_config) + if not isinstance(orig_optimizer, (tuple, list)): + orig_optimizer = [orig_optimizer] + if optimizer_config is None: + optimizer = [ + opt.__class__.from_config(opt.get_config()) + for opt in orig_optimizer + ] + elif isinstance(optimizer_config, dict): + optimizer = [orig_optimizer[0].__class__.from_config(optimizer_config)] + else: + # optimizer config is list of dict, same order as orig_optimizer. + optimizer = [ + opt.__class__.from_config(opt_config) + for (opt, opt_config) in zip(orig_optimizer, optimizer_config) + ] if optimizer_iterations is not None: - optimizer.iterations = optimizer_iterations + for opt in optimizer: + opt.iterations = optimizer_iterations + if len(optimizer) == 1: + optimizer = optimizer[0] clone.compile( optimizer, model.loss,