Merge pull request #32433 from tanzhenyu/cherrypicks_WIQA6
Fix major Adamax gpu bug.
Commit: f03fe1bf79
@@ -306,15 +306,16 @@ struct ApplyAdaMax<GPUDevice, T> {
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
     const auto one = static_cast<T>(1.0);
-    m.device(d) =
-        m + (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
-                (grad - m);
+    m.device(d) +=
+        (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
+        (grad - m);
     v.device(d) =
         (beta2.reshape(single).broadcast(bcast) * v).cwiseMax(grad.abs());
-    var.device(d) -=
-        lr / (beta1_power.constant(one) -
-              beta1_power).reshape(single).broadcast(bcast) *
-        (m / (v + epsilon));
+    var.device(d) -= lr.reshape(single).broadcast(bcast) /
+                     (beta1_power.constant(one) - beta1_power)
+                         .reshape(single)
+                         .broadcast(bcast) *
+                     (m / (v + epsilon.reshape(single).broadcast(bcast)));
   }
 };
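The hunk above rewrites the Eigen expressions in the ApplyAdaMax GPU functor so that the scalar inputs (lr and epsilon, like the beta terms) are explicitly reshaped and broadcast across the variable's first dimension before the elementwise arithmetic, and the first-moment update becomes an in-place +=. For reference, the dense AdaMax step the kernel computes looks roughly like the following NumPy sketch; this is illustrative only, not code from the TensorFlow sources, and the hyperparameter defaults here are assumptions:

import numpy as np

def adamax_step(var, m, v, grad, lr=0.001, beta1=0.9, beta2=0.999,
                epsilon=1e-7, t=1):
  """One dense AdaMax step, mirroring the math of the kernel above."""
  # m <- beta1 * m + (1 - beta1) * grad, written as an in-place-style update.
  m = m + (1.0 - beta1) * (grad - m)
  # v is the exponentially weighted infinity norm of the gradients.
  v = np.maximum(beta2 * v, np.abs(grad))
  # Bias-correct the step size by (1 - beta1**t). lr and epsilon are scalars
  # that NumPy broadcasts automatically; the explicit
  # reshape(single).broadcast(bcast) calls do the same job in the Eigen code.
  var = var - lr / (1.0 - beta1 ** t) * (m / (v + epsilon))
  return var, m, v

# Example: one step on a 2-element variable, as in the tests below.
var, m, v = adamax_step(np.array([1.0, 2.0]), np.zeros(2), np.zeros(2),
                        np.array([0.1, 0.1]))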
@@ -201,20 +201,13 @@ cuda_py_test(
     xla_enable_strict_auto_jit = True,
 )
 
-py_test(
+cuda_py_test(
     name = "optimizer_v2_test",
     size = "medium",
     srcs = ["optimizer_v2_test.py"],
-    python_version = "PY2",
-    shard_count = 8,
-    tags = [
-        "no_gpu",  # b/127001953
-        "no_windows",
-        # TODO(b/127092862): Re-enable this test in Kokoro.
-        "no_oss",
-    ],
-    deps = [
+    additional_deps = [
         ":optimizer_v2",
+        "@absl_py//absl/testing:parameterized",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:clip_ops",
@@ -226,8 +219,12 @@ py_test(
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/keras",
-        "@absl_py//absl/testing:parameterized",
     ],
+    shard_count = 8,
+    tags = [
+        "no_windows",
+    ],
+    xla_enable_strict_auto_jit = True,
 )
 
 cuda_py_test(
@@ -80,7 +80,7 @@ class AdamaxOptimizerTest(test.TestCase):
 
   def doTestSparse(self, use_resource=False):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         # Initialize variables for numpy implementation.
         zero_slots = lambda: np.zeros((3), dtype=dtype.as_numpy_dtype)  # pylint: disable=cell-var-from-loop
         m0, v0, m1, v1 = zero_slots(), zero_slots(), zero_slots(), zero_slots()
@@ -176,9 +176,12 @@ class AdamaxOptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes(reset_test=True)
   def testBasic(self):
     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.session(graph=ops.Graph()):
+      with self.session(graph=ops.Graph(), use_gpu=True):
         # Initialize variables for numpy implementation.
-        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+        m0 = np.array([0.0, 0.0])
+        v0 = np.array([0.0, 0.0])
+        m1 = np.array([0.0, 0.0])
+        v1 = np.array([0.0, 0.0])
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
         grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
         var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
@@ -224,7 +227,7 @@ class AdamaxOptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes(reset_test=True)
   def testBasicWithLearningRateDecay(self):
     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.session(graph=ops.Graph()):
+      with self.session(graph=ops.Graph(), use_gpu=True):
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -278,7 +281,7 @@ class AdamaxOptimizerTest(test.TestCase):
   @test_util.run_deprecated_v1
   def testTensorLearningRate(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -315,7 +318,7 @@ class AdamaxOptimizerTest(test.TestCase):
   @test_util.run_deprecated_v1
   def testSharing(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
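The test changes above and below all follow one pattern: the BUILD target switches from py_test to cuda_py_test, and every session is opened with use_gpu=True, so the optimizer ops actually run on a GPU when one is available and exercise the fixed ApplyAdaMax kernel instead of silently staying on CPU. A minimal sketch of that pattern follows; it assumes the TF 1.x-style graph/session test setup these files use, and the class, variable, and assertion names are illustrative, not taken from the real tests:

import tensorflow as tf

class GpuSessionPatternTest(tf.test.TestCase):

  def testRunsOnGpuWhenAvailable(self):
    # use_gpu=True lets ops in this block be placed on a GPU if one is
    # present; the default pins everything to the CPU, so GPU kernels such
    # as ApplyAdaMax would never be exercised by the test.
    with self.cached_session(use_gpu=True):
      var = tf.Variable([1.0, 2.0])
      self.evaluate(tf.compat.v1.global_variables_initializer())
      self.assertAllClose([1.0, 2.0], self.evaluate(var))

if __name__ == "__main__":
  tf.test.main()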
@@ -65,7 +65,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testBasic(self):
     for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
@@ -129,7 +129,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testPrecomputedGradient(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
@@ -153,7 +153,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testNoGradients(self):
     for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0  # pylint: disable=cell-var-from-loop
@@ -165,7 +165,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testNoGradientsForAnyVariables_Minimize(self):
     for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: constant_op.constant(5.0)
@@ -178,7 +178,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testNoGradientsForAnyVariables_ApplyGradients(self):
     for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         sgd_op = gradient_descent.SGD(3.0)
@@ -189,7 +189,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testGradientsAsVariables(self):
     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
@@ -227,7 +227,7 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testComputeGradientsWithTensors(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       x = ops.convert_to_tensor(1.0)
 
       def f():
@@ -247,7 +247,7 @@ class OptimizerTest(test.TestCase):
   def testConstraint(self):
     constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
     constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       var0 = variables.Variable([1.0, 2.0],
                                 constraint=constraint_01)
       var1 = variables.Variable([3.0, 4.0],
@@ -269,14 +269,14 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testIterationWithoutMinimize(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       sgd = gradient_descent.SGD(3.0)
       self.evaluate(sgd.iterations.initializer)
       self.assertEqual(0, self.evaluate(sgd.iterations))
 
   @test_util.run_in_graph_and_eager_modes
   def testConfig(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       opt = gradient_descent.SGD(learning_rate=1.0)
       config = opt.get_config()
       opt2 = gradient_descent.SGD.from_config(config)
@@ -296,7 +296,7 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testConfigWithLearningRateDecay(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32)
       for decay_schedule in [
           learning_rate_schedule.InverseTimeDecay(
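The testConfigWithLearningRateDecay hunk above exercises serializing an optimizer whose learning rate is a schedule object. Roughly, the behaviour under test looks like the sketch below, written against the public tf.keras API rather than the test's internal imports, so the exact names and arguments are an approximation:

import tensorflow as tf

# A decaying learning rate expressed as a schedule object rather than a float.
schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=1.0, decay_steps=1, decay_rate=0.5)
opt = tf.keras.optimizers.SGD(learning_rate=schedule)

# get_config()/from_config() round-trip the schedule along with the optimizer,
# which is the property checked for each schedule type in the test.
config = opt.get_config()
opt2 = tf.keras.optimizers.SGD.from_config(config)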
@@ -327,7 +327,7 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testGradClipValue(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       var = resource_variable_ops.ResourceVariable([1.0, 2.0])
       loss = lambda: 3 * var
       opt = gradient_descent.SGD(learning_rate=1.0, clipvalue=1.0)
@@ -338,7 +338,7 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testGradClipNorm(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       var = resource_variable_ops.ResourceVariable([1.0])
       loss = lambda: 3 * var
       opt = gradient_descent.SGD(learning_rate=1.0, clipnorm=1.0)
@@ -359,7 +359,7 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testWeights(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       opt1 = adam.Adam(learning_rate=1.0)
       var1 = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                     dtype=dtypes.float32)
@@ -620,7 +620,7 @@ class OptimizersCompatibilityTest(keras_parameterized.TestCase):
           'v1 optimizer does not run in experimental_run_tf_function mode or '
           'eager mode')
     np.random.seed(1331)
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       train_samples = 20
       input_dim = 3
       num_classes = 2
@@ -708,7 +708,7 @@ class OptimizersCompatibilityTest(keras_parameterized.TestCase):
           'v1 optimizer does not run in experimental_run_tf_function mode or '
          'eager mode')
     np.random.seed(1331)
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       train_samples = 20
       input_dim = 3
       num_classes = 2
@@ -769,7 +769,7 @@ class OptimizersCompatibilityTest(keras_parameterized.TestCase):
          'v1 optimizer does not run in experimental_run_tf_function mode or '
          'eager mode')
     np.random.seed(1331)
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       train_samples = 20
       input_dim = 3
       num_classes = 2