Enable model_to_estimator for wide and deep model.
PiperOrigin-RevId: 264260434
This commit is contained in:
parent
8e2ca26f57
commit
372a2dbe96
@ -1,6 +1,6 @@
|
||||
tensorflow/__init__.py
|
||||
tensorflow/api_template.__init__.py
|
||||
tensorflow/api_template_v1.__init__.py
|
||||
tensorflow/api_template.__init__.py
|
||||
tensorflow/compat_template.__init__.py
|
||||
tensorflow/compat_template_v1.__init__.py
|
||||
tensorflow/contrib/mpi/BUILD
|
||||
@ -28,23 +28,23 @@ tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/common.bzl
|
||||
tensorflow/third_party/cub.BUILD
|
||||
tensorflow/third_party/curl.BUILD
|
||||
tensorflow/third_party/double_conversion.BUILD
|
||||
tensorflow/third_party/cython.BUILD
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/eigen3/Eigen/Cholesky
|
||||
tensorflow/third_party/eigen3/Eigen/Eigenvalues
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/eigen3/Eigen/LU
|
||||
tensorflow/third_party/eigen3/Eigen/SVD
|
||||
tensorflow/third_party/eigen3/Eigen/QR
|
||||
tensorflow/third_party/eigen3/Eigen/SVD
|
||||
tensorflow/third_party/eigen3/BUILD
|
||||
tensorflow/third_party/eigen3/LICENSE
|
||||
tensorflow/third_party/eigen3/gpu_packet_math.patch
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
|
||||
@ -52,7 +52,7 @@ tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCasting
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
|
||||
tensorflow/third_party/cython.BUILD
|
||||
tensorflow/third_party/double_conversion.BUILD
|
||||
tensorflow/third_party/eigen.BUILD
|
||||
tensorflow/third_party/fft2d/BUILD
|
||||
tensorflow/third_party/fft2d/LICENSE
|
||||
@ -64,28 +64,28 @@ tensorflow/third_party/farmhash.BUILD
|
||||
tensorflow/third_party/functools32.BUILD
|
||||
tensorflow/third_party/gast.BUILD
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/gif.BUILD
|
||||
tensorflow/third_party/gpus/crosstool/BUILD
|
||||
tensorflow/third_party/gpus/crosstool/BUILD.tpl
|
||||
tensorflow/third_party/gpus/crosstool/LICENSE
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
|
||||
tensorflow/third_party/gpus/BUILD
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
|
||||
tensorflow/third_party/gpus/cuda/LICENSE
|
||||
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/rocm/BUILD
|
||||
tensorflow/third_party/gpus/rocm/BUILD.tpl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/find_cuda_config.py
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/rocm_configure.bzl
|
||||
tensorflow/third_party/googleapis.BUILD
|
||||
tensorflow/third_party/grpc/BUILD
|
||||
@ -99,12 +99,12 @@ tensorflow/third_party/llvm/expand_cmake_vars.py
|
||||
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
|
||||
tensorflow/third_party/llvm/llvm.bzl
|
||||
tensorflow/third_party/linenoise.BUILD
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/mkl/BUILD
|
||||
tensorflow/third_party/mkl/LICENSE
|
||||
tensorflow/third_party/mkl/MKL_LICENSE
|
||||
tensorflow/third_party/mkl/mkl.BUILD
|
||||
tensorflow/third_party/mkl/build_defs.bzl
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/mkl_dnn/LICENSE
|
||||
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
|
||||
tensorflow/third_party/mpi/.gitignore
|
||||
@ -112,19 +112,19 @@ tensorflow/third_party/mpi/BUILD
|
||||
tensorflow/third_party/mpi_collectives/BUILD
|
||||
tensorflow/third_party/nanopb.BUILD
|
||||
tensorflow/third_party/nccl/BUILD
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/archive.patch
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/nccl_configure.bzl
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
tensorflow/third_party/ngraph/BUILD
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/NGRAPH_LICENSE
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/nlohmann_json.BUILD
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/ngraph_tf.BUILD
|
||||
tensorflow/third_party/ngraph/nlohmann_json.BUILD
|
||||
tensorflow/third_party/ngraph/tbb.BUILD
|
||||
tensorflow/third_party/opt_einsum.BUILD
|
||||
tensorflow/third_party/pcre.BUILD
|
||||
@ -133,46 +133,46 @@ tensorflow/third_party/protobuf/BUILD
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/pprof.BUILD
|
||||
tensorflow/third_party/py/numpy/BUILD
|
||||
tensorflow/third_party/py/BUILD
|
||||
tensorflow/third_party/py/BUILD.tpl
|
||||
tensorflow/third_party/py/BUILD
|
||||
tensorflow/third_party/py/python_configure.bzl
|
||||
tensorflow/third_party/python_runtime/BUILD
|
||||
tensorflow/third_party/pybind11.BUILD
|
||||
tensorflow/third_party/repo.bzl
|
||||
tensorflow/third_party/six.BUILD
|
||||
tensorflow/third_party/sycl/crosstool/BUILD
|
||||
tensorflow/third_party/snappy.BUILD
|
||||
tensorflow/third_party/sqlite.BUILD
|
||||
tensorflow/third_party/sycl/crosstool/BUILD
|
||||
tensorflow/third_party/six.BUILD
|
||||
tensorflow/third_party/swig.BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD.tpl
|
||||
tensorflow/third_party/systemlibs/BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD.tpl
|
||||
tensorflow/third_party/systemlibs/absl_py.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
|
||||
tensorflow/third_party/systemlibs/astor.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
|
||||
tensorflow/third_party/systemlibs/boringssl.BUILD
|
||||
tensorflow/third_party/systemlibs/curl.BUILD
|
||||
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
|
||||
tensorflow/third_party/systemlibs/curl.BUILD
|
||||
tensorflow/third_party/systemlibs/cython.BUILD
|
||||
tensorflow/third_party/systemlibs/double_conversion.BUILD
|
||||
tensorflow/third_party/systemlibs/gast.BUILD
|
||||
tensorflow/third_party/systemlibs/gif.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/grpc.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/lmdb.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/nsync.BUILD
|
||||
tensorflow/third_party/systemlibs/opt_einsum.BUILD
|
||||
tensorflow/third_party/systemlibs/pcre.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.bzl
|
||||
tensorflow/third_party/systemlibs/re2.BUILD
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
tensorflow/third_party/systemlibs/sqlite.BUILD
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/swig.BUILD
|
||||
tensorflow/third_party/systemlibs/syslibs_configure.bzl
|
||||
tensorflow/third_party/systemlibs/termcolor.BUILD
|
||||
@ -188,12 +188,12 @@ tensorflow/third_party/tflite_mobilenet.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_float.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/clang6/clang.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/toolchains/clang6/clang.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/repo.bzl
|
||||
tensorflow/third_party/toolchains/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/arm/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
|
||||
tensorflow/third_party/toolchains/cpus/arm/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl
|
||||
tensorflow/third_party/toolchains/cpus/py/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py3/BUILD
|
||||
@ -214,13 +214,13 @@ tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
|
||||
@ -239,44 +239,44 @@ tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3_opt/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
|
||||
tensorflow/third_party/toolchains/remote/BUILD
|
||||
tensorflow/third_party/toolchains/remote/BUILD.tpl
|
||||
tensorflow/third_party/toolchains/remote/configure.bzl
|
||||
tensorflow/third_party/toolchains/remote/BUILD
|
||||
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
|
||||
tensorflow/third_party/tflite_ovic_testdata.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_quant.BUILD
|
||||
tensorflow/third_party/tflite_ovic_testdata.BUILD
|
||||
tensorflow/third_party/tflite_smartreply.BUILD
|
||||
tensorflow/third_party/wrapt.BUILD
|
||||
tensorflow/third_party/zlib.BUILD
|
||||
tensorflow/tools/ci_build/remote/BUILD
|
||||
tensorflow/tools/def_file_filter/BUILD
|
||||
tensorflow/tools/def_file_filter/BUILD.tpl
|
||||
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
|
||||
tensorflow/tools/def_file_filter/BUILD.tpl
|
||||
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
|
||||
tensorflow/tools/lib_package/BUILD
|
||||
tensorflow/tools/lib_package/LibTensorFlowTest.java
|
||||
tensorflow/tools/lib_package/README.md
|
||||
tensorflow/tools/lib_package/concat_licenses.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/libtensorflow_test.sh
|
||||
tensorflow/tools/pip_package/MANIFEST.in
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
tensorflow/tools/pip_package/MANIFEST.in
|
||||
tensorflow/tools/pip_package/README
|
||||
tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/pip_smoke_test.py
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/simple_console.py
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/check_load_py_test.py
|
||||
tensorflow/tools/pip_package/simple_console.py
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/virtual_root_template_v2.__init__.py
|
||||
tensorflow/virtual_root_template_v1.__init__.py
|
||||
llvm/llvm/projects/google_mlir/WORKSPACE
|
@ -32,6 +32,7 @@ from tensorflow.python.keras.engine.input_layer import InputLayer
|
||||
from tensorflow.python.keras.engine.network import Network
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
@ -575,8 +576,8 @@ def clone_and_build_model(
|
||||
Args:
|
||||
model: `tf.keras.Model` object. Can be Functional, Sequential, or
|
||||
sub-classed.
|
||||
input_tensors: Optional list of input tensors to build the model upon. If
|
||||
not provided, placeholders will be created.
|
||||
input_tensors: Optional list or dictionary of input tensors to build the
|
||||
model upon. If not provided, placeholders will be created.
|
||||
target_tensors: Optional list of target tensors for compiling the model. If
|
||||
not provided, placeholders will be created.
|
||||
custom_objects: Optional dictionary mapping string names to custom classes
|
||||
@ -591,10 +592,10 @@ def clone_and_build_model(
|
||||
optimizer if the clone is compiled. This argument is used when a Keras
|
||||
model is cloned into an Estimator model function, because Estimators
|
||||
create their own global step variable.
|
||||
optimizer_config: Optimizer config dictionary returned from `get_config()`.
|
||||
This argument should be defined if `clone_and_build_model` is called in
|
||||
a different graph or session from the original model, and the optimizer is
|
||||
an instance of `OptimizerV2`.
|
||||
optimizer_config: Optimizer config dictionary or list of dictionary
|
||||
returned from `get_config()`. This argument should be defined if
|
||||
`clone_and_build_model` is called in a different graph or session from
|
||||
the original model, and the optimizer is an instance of `OptimizerV2`.
|
||||
|
||||
Returns:
|
||||
Clone of the model.
|
||||
@ -628,16 +629,24 @@ def clone_and_build_model(
|
||||
clone._set_inputs(
|
||||
K.placeholder(model._build_input_shape, dtype=model.inputs[0].dtype))
|
||||
else:
|
||||
if not in_place_reset:
|
||||
raise ValueError(
|
||||
'This model is a subclassed model. '
|
||||
'Such a model cannot be cloned, but there is a workaround where '
|
||||
'the model is reset in-place. To use this, please set the argument '
|
||||
'`in_place_reset` to `True`. This will reset the attributes in the '
|
||||
'original model. To restore the attributes, call '
|
||||
'`in_place_subclassed_model_state_restoration(model)`.')
|
||||
clone = model
|
||||
_in_place_subclassed_model_reset(clone)
|
||||
try:
|
||||
# Prefer clonining the model if serial/deserial logic is implemented for
|
||||
# subclassed model.
|
||||
clone = model.__class__.from_config(model.get_config())
|
||||
except NotImplementedError:
|
||||
logging.warning('This model is a subclassed model. Please implement '
|
||||
'`get_config` and `from_config` to better support '
|
||||
'cloning the model.')
|
||||
if not in_place_reset:
|
||||
raise ValueError(
|
||||
'This model is a subclassed model. '
|
||||
'Such a model cannot be cloned, but there is a workaround where '
|
||||
'the model is reset in-place. To use this, please set the argument '
|
||||
'`in_place_reset` to `True`. This will reset the attributes in the '
|
||||
'original model. To restore the attributes, call '
|
||||
'`in_place_subclassed_model_state_restoration(model)`.')
|
||||
clone = model
|
||||
_in_place_subclassed_model_reset(clone)
|
||||
if input_tensors is not None:
|
||||
if isinstance(input_tensors, (list, tuple)) and len(input_tensors) == 1:
|
||||
input_tensors = input_tensors[0]
|
||||
@ -649,11 +658,27 @@ def clone_and_build_model(
|
||||
orig_optimizer.optimizer, optimizer_iterations)
|
||||
K.track_tf_optimizer(optimizer)
|
||||
else:
|
||||
optimizer_config = optimizer_config or orig_optimizer.get_config()
|
||||
optimizer = orig_optimizer.__class__.from_config(optimizer_config)
|
||||
if not isinstance(orig_optimizer, (tuple, list)):
|
||||
orig_optimizer = [orig_optimizer]
|
||||
if optimizer_config is None:
|
||||
optimizer = [
|
||||
opt.__class__.from_config(opt.get_config())
|
||||
for opt in orig_optimizer
|
||||
]
|
||||
elif isinstance(optimizer_config, dict):
|
||||
optimizer = [orig_optimizer[0].__class__.from_config(optimizer_config)]
|
||||
else:
|
||||
# optimizer config is list of dict, same order as orig_optimizer.
|
||||
optimizer = [
|
||||
opt.__class__.from_config(opt_config)
|
||||
for (opt, opt_config) in zip(orig_optimizer, optimizer_config)
|
||||
]
|
||||
if optimizer_iterations is not None:
|
||||
optimizer.iterations = optimizer_iterations
|
||||
for opt in optimizer:
|
||||
opt.iterations = optimizer_iterations
|
||||
|
||||
if len(optimizer) == 1:
|
||||
optimizer = optimizer[0]
|
||||
clone.compile(
|
||||
optimizer,
|
||||
model.loss,
|
||||
|
Loading…
Reference in New Issue
Block a user