MOve backend config related functions to a dedicated file.
Use backend epsilon if epsilon argument is None. PiperOrigin-RevId: 229756902
This commit is contained in:
parent
a782c9df3e
commit
c9bd0744ea
@ -82,6 +82,7 @@ py_library(
|
||||
srcs = ["backend.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":backend_config",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
@ -117,6 +118,12 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "backend_config",
|
||||
srcs = ["backend_config.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "engine",
|
||||
srcs = [
|
||||
@ -1108,6 +1115,18 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "backend_config_test",
|
||||
size = "medium",
|
||||
srcs = ["backend_config_test.py"],
|
||||
additional_deps = [
|
||||
":keras",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "keras_parameterized_test",
|
||||
size = "small",
|
||||
|
@ -40,6 +40,7 @@ from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.keras import backend_config
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -97,15 +98,6 @@ _DUMMY_EAGER_GRAPH = _DummyEagerGraph()
|
||||
# Change its value via `manual_variable_initialization(value)`.
|
||||
_MANUAL_VAR_INIT = False
|
||||
|
||||
# The type of float to use throughout a session.
|
||||
_FLOATX = 'float32'
|
||||
|
||||
# Epsilon fuzz factor used throughout the codebase.
|
||||
_EPSILON = 1e-7
|
||||
|
||||
# Default image data format, one of "channels_last", "channels_first".
|
||||
_IMAGE_DATA_FORMAT = 'channels_last'
|
||||
|
||||
# This list holds the available devices.
|
||||
# It is populated when `_get_available_gpus()` is called for the first time.
|
||||
# We assume our devices don't change henceforth.
|
||||
@ -119,6 +111,14 @@ _GRAPH_VARIABLES = weakref.WeakKeyDictionary()
|
||||
# the graph.
|
||||
_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
|
||||
|
||||
# The below functions are kept accessible from backend for compatibility.
|
||||
epsilon = backend_config.epsilon
|
||||
floatx = backend_config.floatx
|
||||
image_data_format = backend_config.image_data_format
|
||||
set_epsilon = backend_config.set_epsilon
|
||||
set_floatx = backend_config.set_floatx
|
||||
set_image_data_format = backend_config.set_image_data_format
|
||||
|
||||
|
||||
@keras_export('keras.backend.backend')
|
||||
def backend():
|
||||
@ -132,87 +132,6 @@ def backend():
|
||||
return 'tensorflow'
|
||||
|
||||
|
||||
@keras_export('keras.backend.epsilon')
|
||||
def epsilon():
|
||||
"""Returns the value of the fuzz factor used in numeric expressions.
|
||||
|
||||
Returns:
|
||||
A float.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> keras.backend.epsilon()
|
||||
1e-07
|
||||
```
|
||||
"""
|
||||
return _EPSILON
|
||||
|
||||
|
||||
@keras_export('keras.backend.set_epsilon')
|
||||
def set_epsilon(value):
|
||||
"""Sets the value of the fuzz factor used in numeric expressions.
|
||||
|
||||
Arguments:
|
||||
value: float. New value of epsilon.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from keras import backend as K
|
||||
>>> K.epsilon()
|
||||
1e-07
|
||||
>>> K.set_epsilon(1e-05)
|
||||
>>> K.epsilon()
|
||||
1e-05
|
||||
```
|
||||
"""
|
||||
global _EPSILON
|
||||
_EPSILON = value
|
||||
|
||||
|
||||
@keras_export('keras.backend.floatx')
|
||||
def floatx():
|
||||
"""Returns the default float type, as a string.
|
||||
|
||||
E.g. 'float16', 'float32', 'float64'.
|
||||
|
||||
Returns:
|
||||
String, the current default float type.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> keras.backend.floatx()
|
||||
'float32'
|
||||
```
|
||||
"""
|
||||
return _FLOATX
|
||||
|
||||
|
||||
@keras_export('keras.backend.set_floatx')
|
||||
def set_floatx(value):
|
||||
"""Sets the default float type.
|
||||
|
||||
Arguments:
|
||||
value: String; 'float16', 'float32', or 'float64'.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from keras import backend as K
|
||||
>>> K.floatx()
|
||||
'float32'
|
||||
>>> K.set_floatx('float16')
|
||||
>>> K.floatx()
|
||||
'float16'
|
||||
```
|
||||
|
||||
Raises:
|
||||
ValueError: In case of invalid value.
|
||||
"""
|
||||
global _FLOATX
|
||||
if value not in {'float16', 'float32', 'float64'}:
|
||||
raise ValueError('Unknown floatx type: ' + str(value))
|
||||
_FLOATX = str(value)
|
||||
|
||||
|
||||
@keras_export('keras.backend.cast_to_floatx')
|
||||
def cast_to_floatx(x):
|
||||
"""Cast a Numpy array to the default Keras float type.
|
||||
@ -238,49 +157,7 @@ def cast_to_floatx(x):
|
||||
dtype('float32')
|
||||
```
|
||||
"""
|
||||
return np.asarray(x, dtype=_FLOATX)
|
||||
|
||||
|
||||
@keras_export('keras.backend.image_data_format')
|
||||
def image_data_format():
|
||||
"""Returns the default image data format convention.
|
||||
|
||||
Returns:
|
||||
A string, either `'channels_first'` or `'channels_last'`
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> keras.backend.image_data_format()
|
||||
'channels_first'
|
||||
```
|
||||
"""
|
||||
return _IMAGE_DATA_FORMAT
|
||||
|
||||
|
||||
@keras_export('keras.backend.set_image_data_format')
|
||||
def set_image_data_format(data_format):
|
||||
"""Sets the value of the image data format convention.
|
||||
|
||||
Arguments:
|
||||
data_format: string. `'channels_first'` or `'channels_last'`.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from keras import backend as K
|
||||
>>> K.image_data_format()
|
||||
'channels_first'
|
||||
>>> K.set_image_data_format('channels_last')
|
||||
>>> K.image_data_format()
|
||||
'channels_last'
|
||||
```
|
||||
|
||||
Raises:
|
||||
ValueError: In case of invalid `data_format` value.
|
||||
"""
|
||||
global _IMAGE_DATA_FORMAT
|
||||
if data_format not in {'channels_last', 'channels_first'}:
|
||||
raise ValueError('Unknown data_format: ' + str(data_format))
|
||||
_IMAGE_DATA_FORMAT = str(data_format)
|
||||
return np.asarray(x, dtype=floatx())
|
||||
|
||||
|
||||
# A global dictionary mapping graph objects to an index of counters used
|
||||
|
126
tensorflow/python/keras/backend_config.py
Normal file
126
tensorflow/python/keras/backend_config.py
Normal file
@ -0,0 +1,126 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Keras backend config API."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
# The type of float to use throughout a session.
|
||||
_FLOATX = 'float32'
|
||||
|
||||
# Epsilon fuzz factor used throughout the codebase.
|
||||
_EPSILON = 1e-7
|
||||
|
||||
# Default image data format, one of "channels_last", "channels_first".
|
||||
_IMAGE_DATA_FORMAT = 'channels_last'
|
||||
|
||||
|
||||
@keras_export('keras.backend.epsilon')
|
||||
def epsilon():
|
||||
"""Returns the value of the fuzz factor used in numeric expressions.
|
||||
|
||||
Returns:
|
||||
A float.
|
||||
|
||||
Example:
|
||||
```python
|
||||
keras.backend.epsilon() >>>1e-07
|
||||
```
|
||||
"""
|
||||
return _EPSILON
|
||||
|
||||
|
||||
@keras_export('keras.backend.set_epsilon')
|
||||
def set_epsilon(value):
|
||||
"""Sets the value of the fuzz factor used in numeric expressions.
|
||||
|
||||
Arguments:
|
||||
value: float. New value of epsilon.
|
||||
Example: ```python from keras import backend as K K.epsilon() >>> 1e-07
|
||||
K.set_epsilon(1e-05) K.epsilon() >>> 1e-05 ```
|
||||
"""
|
||||
global _EPSILON
|
||||
_EPSILON = value
|
||||
|
||||
|
||||
@keras_export('keras.backend.floatx')
|
||||
def floatx():
|
||||
"""Returns the default float type, as a string.
|
||||
|
||||
E.g. 'float16', 'float32', 'float64'.
|
||||
|
||||
Returns:
|
||||
String, the current default float type.
|
||||
|
||||
Example:
|
||||
```python
|
||||
keras.backend.floatx() >>> 'float32'
|
||||
```
|
||||
"""
|
||||
return _FLOATX
|
||||
|
||||
|
||||
@keras_export('keras.backend.set_floatx')
|
||||
def set_floatx(value):
|
||||
"""Sets the default float type.
|
||||
|
||||
Arguments:
|
||||
value: String; 'float16', 'float32', or 'float64'.
|
||||
Example: ```python from keras import backend as K K.floatx() >>> 'float32'
|
||||
K.set_floatx('float16') K.floatx() >>> 'float16' ```
|
||||
|
||||
Raises:
|
||||
ValueError: In case of invalid value.
|
||||
"""
|
||||
global _FLOATX
|
||||
if value not in {'float16', 'float32', 'float64'}:
|
||||
raise ValueError('Unknown floatx type: ' + str(value))
|
||||
_FLOATX = str(value)
|
||||
|
||||
|
||||
@keras_export('keras.backend.image_data_format')
|
||||
def image_data_format():
|
||||
"""Returns the default image data format convention.
|
||||
|
||||
Returns:
|
||||
A string, either `'channels_first'` or `'channels_last'`
|
||||
|
||||
Example:
|
||||
```python
|
||||
keras.backend.image_data_format() >>> 'channels_first'
|
||||
```
|
||||
"""
|
||||
return _IMAGE_DATA_FORMAT
|
||||
|
||||
|
||||
@keras_export('keras.backend.set_image_data_format')
|
||||
def set_image_data_format(data_format):
|
||||
"""Sets the value of the image data format convention.
|
||||
|
||||
Arguments:
|
||||
data_format: string. `'channels_first'` or `'channels_last'`.
|
||||
Example: ```python from keras import backend as K K.image_data_format() >>>
|
||||
'channels_first' K.set_image_data_format('channels_last')
|
||||
K.image_data_format() >>> 'channels_last' ```
|
||||
|
||||
Raises:
|
||||
ValueError: In case of invalid `data_format` value.
|
||||
"""
|
||||
global _IMAGE_DATA_FORMAT
|
||||
if data_format not in {'channels_last', 'channels_first'}:
|
||||
raise ValueError('Unknown data_format: ' + str(data_format))
|
||||
_IMAGE_DATA_FORMAT = str(data_format)
|
55
tensorflow/python/keras/backend_config_test.py
Normal file
55
tensorflow/python/keras/backend_config_test.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for backend_config."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class BackendConfigTest(test.TestCase):
|
||||
|
||||
def test_backend(self):
|
||||
self.assertEqual(keras.backend.backend(), 'tensorflow')
|
||||
|
||||
def test_espilon(self):
|
||||
epsilon = 1e-2
|
||||
keras.backend_config.set_epsilon(epsilon)
|
||||
self.assertEqual(keras.backend_config.epsilon(), epsilon)
|
||||
keras.backend_config.set_epsilon(1e-7)
|
||||
self.assertEqual(keras.backend_config.epsilon(), 1e-7)
|
||||
|
||||
def test_floatx(self):
|
||||
floatx = 'float64'
|
||||
keras.backend_config.set_floatx(floatx)
|
||||
self.assertEqual(keras.backend_config.floatx(), floatx)
|
||||
keras.backend_config.set_floatx('float32')
|
||||
self.assertEqual(keras.backend_config.floatx(), 'float32')
|
||||
|
||||
def test_image_data_format(self):
|
||||
image_data_format = 'channels_first'
|
||||
keras.backend_config.set_image_data_format(image_data_format)
|
||||
self.assertEqual(keras.backend_config.image_data_format(),
|
||||
image_data_format)
|
||||
keras.backend_config.set_image_data_format('channels_last')
|
||||
self.assertEqual(keras.backend_config.image_data_format(), 'channels_last')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -99,24 +99,6 @@ class BackendUtilsTest(test.TestCase):
|
||||
def test_backend(self):
|
||||
self.assertEqual(keras.backend.backend(), 'tensorflow')
|
||||
|
||||
def test_espilon(self):
|
||||
epsilon = 1e-2
|
||||
keras.backend.set_epsilon(epsilon)
|
||||
self.assertEqual(keras.backend.epsilon(), epsilon)
|
||||
keras.backend.set_epsilon(1e-7)
|
||||
|
||||
def test_floatx(self):
|
||||
floatx = 'float64'
|
||||
keras.backend.set_floatx(floatx)
|
||||
self.assertEqual(keras.backend.floatx(), floatx)
|
||||
keras.backend.set_floatx('float32')
|
||||
|
||||
def test_image_data_format(self):
|
||||
image_data_format = 'channels_first'
|
||||
keras.backend.set_image_data_format(image_data_format)
|
||||
self.assertEqual(keras.backend.image_data_format(), image_data_format)
|
||||
keras.backend.set_image_data_format('channels_last')
|
||||
|
||||
def test_get_reset_uids(self):
|
||||
self.assertEqual(keras.backend.get_uid('foo'), 1)
|
||||
self.assertEqual(keras.backend.get_uid('foo'), 2)
|
||||
|
@ -35,6 +35,7 @@ py_library(
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/distribute:reduce_util",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/keras:backend_config",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.keras import backend_config
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.training import training_ops
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
@ -90,6 +91,8 @@ class Adadelta(optimizer_v2.OptimizerV2):
|
||||
invocations of optimizer functions.
|
||||
@end_compatibility
|
||||
"""
|
||||
if epsilon is None:
|
||||
epsilon = backend_config.epsilon()
|
||||
super(Adadelta, self).__init__(name, **kwargs)
|
||||
self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
|
||||
self._set_hyper('decay', self._initial_decay)
|
||||
|
@ -181,6 +181,15 @@ class AdadeltaOptimizerTest(test.TestCase):
|
||||
self.assertAllClose(self.evaluate(opt_2.lr), (1.0))
|
||||
self.assertAllClose(self.evaluate(opt_3.lr), (0.1))
|
||||
|
||||
def testConstructAdadeltaWithEpsilonValues(self):
|
||||
opt = adadelta.Adadelta(epsilon=None)
|
||||
config = opt.get_config()
|
||||
self.assertEqual(config["epsilon"], 1e-7)
|
||||
|
||||
opt = adadelta.Adadelta(epsilon=1e-8)
|
||||
config = opt.get_config()
|
||||
self.assertEqual(config["epsilon"], 1e-8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend_config
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
@ -89,6 +90,8 @@ class Adagrad(optimizer_v2.OptimizerV2):
|
||||
if initial_accumulator_value < 0.0:
|
||||
raise ValueError('initial_accumulator_value must be non-negative: %s' %
|
||||
initial_accumulator_value)
|
||||
if epsilon is None:
|
||||
epsilon = backend_config.epsilon()
|
||||
if epsilon < 1e-7:
|
||||
raise ValueError('epsilon must be larger than 1e-7: %s' % epsilon)
|
||||
super(Adagrad, self).__init__(name, **kwargs)
|
||||
|
@ -411,6 +411,19 @@ class AdagradOptimizerTest(test.TestCase):
|
||||
self.assertAllClose(self.evaluate(opt_2.lr), (1.0))
|
||||
self.assertAllClose(self.evaluate(opt_3.lr), (0.1))
|
||||
|
||||
def testConstructAdagradWithEpsilonValues(self):
|
||||
opt = adagrad.Adagrad(epsilon=None)
|
||||
config = opt.get_config()
|
||||
self.assertEqual(config["epsilon"], 1e-7)
|
||||
|
||||
opt = adagrad.Adagrad(epsilon=1e-6)
|
||||
config = opt.get_config()
|
||||
self.assertEqual(config["epsilon"], 1e-6)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"epsilon must be larger than 1e-7"):
|
||||
opt = adagrad.Adagrad(epsilon=1e-8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -18,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend_config
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -131,6 +132,8 @@ class Adam(optimizer_v2.OptimizerV2):
|
||||
compatibility, recommended to use `learning_rate` instead.
|
||||
"""
|
||||
|
||||
if epsilon is None:
|
||||
epsilon = backend_config.epsilon()
|
||||
super(Adam, self).__init__(name, **kwargs)
|
||||
self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
|
||||
self._set_hyper('decay', self._initial_decay)
|
||||
|
@ -516,6 +516,15 @@ class AdamOptimizerTest(test.TestCase):
|
||||
self.assertAllClose(self.evaluate(opt_2.lr), (1.0))
|
||||
self.assertAllClose(self.evaluate(opt_3.lr), (0.1))
|
||||
|
||||
def testConstructAdamWithEpsilonValues(self):
|
||||
opt = adam.Adam(epsilon=None)
|
||||
config = opt.get_config()
|
||||
self.assertEqual(config["epsilon"], 1e-7)
|
||||
|
||||
opt = adam.Adam(epsilon=1e-8)
|
||||
config = opt.get_config()
|
||||
self.assertEqual(config["epsilon"], 1e-8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend_config
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -95,6 +96,8 @@ class Adamax(optimizer_v2.OptimizerV2):
|
||||
allow time inverse decay of learning rate. `lr` is included for backward
|
||||
compatibility, recommended to use `learning_rate` instead.
|
||||
"""
|
||||
if epsilon is None:
|
||||
epsilon = backend_config.epsilon()
|
||||
super(Adamax, self).__init__(name, **kwargs)
|
||||
self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
|
||||
self._set_hyper('decay', self._initial_decay)
|
||||
|
@ -375,6 +375,15 @@ class AdamaxOptimizerTest(test.TestCase):
|
||||
self.assertAllClose(self.evaluate(opt_2.lr), (1.0))
|
||||
self.assertAllClose(self.evaluate(opt_3.lr), (0.1))
|
||||
|
||||
def testConstructAdamaxWithEpsilonValues(self):
|
||||
opt = adamax.Adamax(epsilon=None)
|
||||
config = opt.get_config()
|
||||
self.assertEqual(config["epsilon"], 1e-7)
|
||||
|
||||
opt = adamax.Adamax(epsilon=1e-8)
|
||||
config = opt.get_config()
|
||||
self.assertEqual(config["epsilon"], 1e-8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -18,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend_config
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -85,6 +86,8 @@ class Nadam(optimizer_v2.OptimizerV2):
|
||||
|
||||
# Backwards compatiblity with keras NAdam optimizer.
|
||||
kwargs['decay'] = kwargs.pop('schedule_decay', 0.004)
|
||||
if epsilon is None:
|
||||
epsilon = backend_config.epsilon()
|
||||
super(Nadam, self).__init__(name, **kwargs)
|
||||
self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
|
||||
self._set_hyper('decay', self._initial_decay)
|
||||
|
@ -178,6 +178,15 @@ class NadamOptimizerTest(test.TestCase):
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllClose(self.evaluate(opt.decay), (0.2))
|
||||
|
||||
def testConstructNAdamWithEpsilonValues(self):
|
||||
opt = nadam.Nadam(epsilon=None)
|
||||
config = opt.get_config()
|
||||
self.assertEqual(config["epsilon"], 1e-7)
|
||||
|
||||
opt = nadam.Nadam(epsilon=1e-8)
|
||||
config = opt.get_config()
|
||||
self.assertEqual(config["epsilon"], 1e-8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend_config
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -102,6 +103,8 @@ class RMSprop(optimizer_v2.OptimizerV2):
|
||||
allow time inverse decay of learning rate. `lr` is included for backward
|
||||
compatibility, recommended to use `learning_rate` instead.
|
||||
"""
|
||||
if epsilon is None:
|
||||
epsilon = backend_config.epsilon()
|
||||
super(RMSprop, self).__init__(name, **kwargs)
|
||||
self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
|
||||
self._set_hyper("decay", self._initial_decay)
|
||||
|
@ -471,6 +471,15 @@ class RMSpropOptimizerTest(test.TestCase):
|
||||
self.assertEqual(
|
||||
self.evaluate(opt.variables()[0]), self.evaluate(opt.iterations))
|
||||
|
||||
def testConstructRMSpropWithEpsilonValues(self):
|
||||
opt = rmsprop.RMSprop(epsilon=None)
|
||||
config = opt.get_config()
|
||||
self.assertEqual(config["epsilon"], 1e-7)
|
||||
|
||||
opt = rmsprop.RMSprop(epsilon=1e-8)
|
||||
config = opt.get_config()
|
||||
self.assertEqual(config["epsilon"], 1e-8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user