From bf9c196f37b9cbb3109b2891aaf9da85bf5f712a Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Fri, 1 Nov 2019 10:27:42 -0700 Subject: [PATCH] Add complex support to optimizers We do not support complex with certain optimizers such as Ftrl, FtrlV2, AdamWithAmsgrad, AdaMax, AddSign & PowerSign since they may use missing operations on complex values such as sqrt. Fixes #32774 PiperOrigin-RevId: 277953548 Change-Id: Ia075aa5c3f944de932d71b9741d626f7ebe5416f --- tensorflow/compiler/tests/adadelta_test.py | 18 +- tensorflow/compiler/tests/adagrad_test.py | 6 +- tensorflow/compiler/tests/adam_test.py | 6 +- tensorflow/compiler/tests/rmsprop_test.py | 10 +- .../compiler/tf2xla/kernels/training_ops.cc | 51 +- tensorflow/compiler/tf2xla/xla_op_registry.h | 2 + tensorflow/core/kernels/training_ops.cc | 100 +++ .../core/kernels/training_ops_gpu.cu.cc | 38 + .../keras/optimizer_v2/adadelta_test.py | 37 +- .../python/keras/optimizer_v2/adagrad_test.py | 734 +++++++++--------- .../python/keras/optimizer_v2/optimizer_v2.py | 6 +- .../keras/optimizer_v2/optimizer_v2_test.py | 69 +- .../python/keras/optimizer_v2/rmsprop_test.py | 111 ++- 13 files changed, 672 insertions(+), 516 deletions(-) diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py index 548dbe53f2a..9afabcc7467 100644 --- a/tensorflow/compiler/tests/adadelta_test.py +++ b/tensorflow/compiler/tests/adadelta_test.py @@ -40,7 +40,7 @@ class AdadeltaOptimizerTest(xla_test.XLATestCase): all_grad = [0.2, 0.1, 0.01] all_lr = [1.0, 0.5, 0.1] - for dtype in self.float_types: + for dtype in self.float_types | self.complex_types: with self.session(), self.test_scope(): for grad in all_grad: for lr in all_lr: @@ -76,20 +76,20 @@ class AdadeltaOptimizerTest(xla_test.XLATestCase): self.assertEqual(["accum", "accum_update"], adadelta_opt.get_slot_names()) slot[0] = adadelta_opt.get_slot(var0, "accum") - self.assertEquals(slot[0].get_shape(), var0.get_shape()) - self.assertFalse(slot[0] in variables.trainable_variables()) + self.assertEqual(slot[0].get_shape(), var0.get_shape()) + self.assertNotIn(slot[0], variables.trainable_variables()) slot_update[0] = adadelta_opt.get_slot(var0, "accum_update") - self.assertEquals(slot_update[0].get_shape(), var0.get_shape()) - self.assertFalse(slot_update[0] in variables.trainable_variables()) + self.assertEqual(slot_update[0].get_shape(), var0.get_shape()) + self.assertNotIn(slot_update[0], variables.trainable_variables()) slot[1] = adadelta_opt.get_slot(var1, "accum") - self.assertEquals(slot[1].get_shape(), var1.get_shape()) - self.assertFalse(slot[1] in variables.trainable_variables()) + self.assertEqual(slot[1].get_shape(), var1.get_shape()) + self.assertNotIn(slot[1], variables.trainable_variables()) slot_update[1] = adadelta_opt.get_slot(var1, "accum_update") - self.assertEquals(slot_update[1].get_shape(), var1.get_shape()) - self.assertFalse(slot_update[1] in variables.trainable_variables()) + self.assertEqual(slot_update[1].get_shape(), var1.get_shape()) + self.assertNotIn(slot_update[1], variables.trainable_variables()) # Fetch params to validate initial values self.assertAllClose(var0_init, self.evaluate(var0)) diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index 844e5dfd831..9f7a940019e 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -31,7 +31,7 @@ from tensorflow.python.training import adagrad class AdagradOptimizerTest(xla_test.XLATestCase): def 
testBasic(self): - for dtype in self.float_types: + for dtype in self.float_types | self.complex_types: with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) @@ -101,9 +101,9 @@ class AdagradOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1])) self.assertEqual(["accumulator"], ada_opt.get_slot_names()) slot0 = ada_opt.get_slot(var0, "accumulator") - self.assertEquals(slot0.get_shape(), var0.get_shape()) + self.assertEqual(slot0.get_shape(), var0.get_shape()) slot1 = ada_opt.get_slot(var1, "accumulator") - self.assertEquals(slot1.get_shape(), var1.get_shape()) + self.assertEqual(slot1.get_shape(), var1.get_shape()) variables.global_variables_initializer().run() # Fetch params to validate initial values. diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index bf22b756074..2a5b809e288 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -52,7 +52,7 @@ def adam_update_numpy(param, class AdamOptimizerTest(xla_test.XLATestCase): def testBasic(self): - for dtype in self.float_types: + for dtype in self.float_types | self.complex_types: # TODO: test fails for float16 due to excessive precision requirements. if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue @@ -95,7 +95,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testTensorLearningRate(self): - for dtype in self.float_types: + for dtype in self.float_types | self.complex_types: # TODO: test fails for float16 due to excessive precision requirements. if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue @@ -138,7 +138,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testSharing(self): - for dtype in self.float_types: + for dtype in self.float_types | self.complex_types: # TODO: test fails for float16 due to excessive precision requirements. if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py index 961103e83f2..a8c449175be 100644 --- a/tensorflow/compiler/tests/rmsprop_test.py +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -53,7 +53,7 @@ class RmspropTest(xla_test.XLATestCase): return var_t, mg_t, rms_t, mom_t def testBasic(self): - for dtype in self.float_types: + for dtype in self.float_types | self.complex_types: for centered in [False, True]: with self.session(), self.test_scope(): # Initialize variables for numpy implementation. 
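The commit message at the top of this patch notes that complex64/complex128 support is added to most gradient-based optimizers, while Ftrl, FtrlV2, AdamWithAmsgrad, AdaMax, AddSign and PowerSign stay real-only because they depend on operations, such as sqrt, that are not yet available for complex values in every kernel. As a hedged illustration only (not part of the patch), the sketch below shows the user-facing behavior the change enables, assuming a TF 2.x CPU build in which these kernels are registered; the variable, loss, and learning rate are arbitrary examples:

import tensorflow as tf

# A complex-valued resource variable and a real-valued loss built from it.
var = tf.Variable([1.0 + 2.0j, 3.0 - 1.0j], dtype=tf.complex64)
opt = tf.keras.optimizers.Adagrad(learning_rate=0.1)

with tf.GradientTape() as tape:
  # Sum of |v|^2 over elements; real-valued, so it can be minimized.
  loss = tf.reduce_sum(tf.math.real(var * tf.math.conj(var)))

grads = tape.gradient(loss, [var])      # complex64 gradient
opt.apply_gradients(zip(grads, [var]))  # accepted once complex dtypes are valid for the optimizer

On builds compiled with nvcc, the GPU specializations that would need Eigen complex sqrt are skipped (the TENSORFLOW_USE_NVCC guards later in this patch), which is why the Keras tests only append complex dtypes to _DATA_TYPES when test_util.IsBuiltWithNvcc() is false.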
@@ -83,13 +83,13 @@ class RmspropTest(xla_test.XLATestCase): mg1 = rms_opt.get_slot(var1, "mg") self.assertEqual(mg1 is not None, centered) rms0 = rms_opt.get_slot(var0, "rms") - self.assertTrue(rms0 is not None) + self.assertIsNotNone(rms0) rms1 = rms_opt.get_slot(var1, "rms") - self.assertTrue(rms1 is not None) + self.assertIsNotNone(rms1) mom0 = rms_opt.get_slot(var0, "momentum") - self.assertTrue(mom0 is not None) + self.assertIsNotNone(mom0) mom1 = rms_opt.get_slot(var1, "momentum") - self.assertTrue(mom1 is not None) + self.assertIsNotNone(mom1) # Fetch params to validate initial values self.assertAllClose([1.0, 2.0], self.evaluate(var0)) diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index ff4833ab802..3816cabc282 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -52,9 +52,9 @@ class ResourceApplyGradientDescent : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; -REGISTER_XLA_OP( - Name("ResourceApplyGradientDescent").TypeConstraint("T", kFloatTypes), - ResourceApplyGradientDescent); +REGISTER_XLA_OP(Name("ResourceApplyGradientDescent") + .TypeConstraint("T", kFloatAndComplexTypes), + ResourceApplyGradientDescent); xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr, xla::XlaOp l1, xla::XlaOp l2, @@ -111,7 +111,7 @@ class ResourceApplyProximalGradientDescent : public XlaOpKernel { DataType dtype_; }; REGISTER_XLA_OP(Name("ResourceApplyProximalGradientDescent") - .TypeConstraint("T", kFloatTypes), + .TypeConstraint("T", kFloatAndComplexTypes), ResourceApplyProximalGradientDescent); class ResourceApplyMomentum : public XlaOpKernel { @@ -226,9 +226,9 @@ class ResourceApplyKerasMomentum : public XlaOpKernel { private: bool use_nesterov_; }; -REGISTER_XLA_OP( - Name("ResourceApplyKerasMomentum").TypeConstraint("T", kFloatTypes), - ResourceApplyKerasMomentum); +REGISTER_XLA_OP(Name("ResourceApplyKerasMomentum") + .TypeConstraint("T", kFloatAndComplexTypes), + ResourceApplyKerasMomentum); class ResourceApplyAdagrad : public XlaOpKernel { public: @@ -274,8 +274,9 @@ class ResourceApplyAdagrad : public XlaOpKernel { private: bool update_slots_; }; -REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes), - ResourceApplyAdagrad); +REGISTER_XLA_OP( + Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatAndComplexTypes), + ResourceApplyAdagrad); class ResourceApplyAdagradV2 : public XlaOpKernel { public: @@ -328,8 +329,9 @@ class ResourceApplyAdagradV2 : public XlaOpKernel { private: bool update_slots_; }; -REGISTER_XLA_OP(Name("ResourceApplyAdagradV2").TypeConstraint("T", kFloatTypes), - ResourceApplyAdagradV2); +REGISTER_XLA_OP( + Name("ResourceApplyAdagradV2").TypeConstraint("T", kFloatAndComplexTypes), + ResourceApplyAdagradV2); class ResourceApplyProximalAdagrad : public XlaOpKernel { public: @@ -383,9 +385,9 @@ class ResourceApplyProximalAdagrad : public XlaOpKernel { private: DataType dtype_; }; -REGISTER_XLA_OP( - Name("ResourceApplyProximalAdagrad").TypeConstraint("T", kFloatTypes), - ResourceApplyProximalAdagrad); +REGISTER_XLA_OP(Name("ResourceApplyProximalAdagrad") + .TypeConstraint("T", kFloatAndComplexTypes), + ResourceApplyProximalAdagrad); class ResourceApplyAdagradDA : public XlaOpKernel { public: @@ -556,8 +558,9 @@ class ResourceApplyAdam : public XlaOpKernel { DataType dtype_; bool use_nesterov_; }; 
-REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes), - ResourceApplyAdam); +REGISTER_XLA_OP( + Name("ResourceApplyAdam").TypeConstraint("T", kFloatAndComplexTypes), + ResourceApplyAdam); class ResourceApplyAdaMax : public XlaOpKernel { public: @@ -729,8 +732,9 @@ class ResourceApplyRMSProp : public XlaOpKernel { private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes), - ResourceApplyRMSProp); +REGISTER_XLA_OP( + Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatAndComplexTypes), + ResourceApplyRMSProp); class ResourceApplyCenteredRMSProp : public ResourceApplyRMSProp { public: @@ -739,9 +743,9 @@ class ResourceApplyCenteredRMSProp : public ResourceApplyRMSProp { centered_ = true; } }; -REGISTER_XLA_OP( - Name("ResourceApplyCenteredRMSProp").TypeConstraint("T", kFloatTypes), - ResourceApplyCenteredRMSProp); +REGISTER_XLA_OP(Name("ResourceApplyCenteredRMSProp") + .TypeConstraint("T", kFloatAndComplexTypes), + ResourceApplyCenteredRMSProp); void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage) { @@ -942,8 +946,9 @@ class ResourceApplyAdadelta : public XlaOpKernel { private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatTypes), - ResourceApplyAdadelta); +REGISTER_XLA_OP( + Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatAndComplexTypes), + ResourceApplyAdadelta); class ResourceApplySignBase : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index fa51753aa45..af08790e02e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -47,6 +47,8 @@ extern const char* const DEVICE_XLA_GPU; constexpr std::array kFloatTypes = { {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; +constexpr std::array kFloatAndComplexTypes = { + {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16, DT_COMPLEX64, DT_COMPLEX128}}; constexpr std::array kNumericTypes = { {DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 7aa71d635c8..cba3d4fbdc2 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -652,6 +652,8 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +TF_CALL_complex64(REGISTER_CPU_KERNELS); +TF_CALL_complex128(REGISTER_CPU_KERNELS); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. 
@@ -666,12 +668,16 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); +DECLARE_GPU_SPEC(complex64); +DECLARE_GPU_SPEC(complex128); #undef DECLARE_GPU_SPEC } // namespace functor REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); +REGISTER_KERNELS(GPU, complex64); +REGISTER_KERNELS(GPU, complex128); #endif #ifdef TENSORFLOW_USE_SYCL @@ -813,6 +819,8 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +TF_CALL_complex64(REGISTER_CPU_KERNELS); +TF_CALL_complex128(REGISTER_CPU_KERNELS); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. @@ -829,12 +837,20 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +DECLARE_GPU_SPEC(complex64); +DECLARE_GPU_SPEC(complex128); +#endif #undef DECLARE_GPU_SPEC } // namespace functor REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +REGISTER_KERNELS(GPU, complex64); +REGISTER_KERNELS(GPU, complex128); +#endif #endif #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS @@ -994,6 +1010,8 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +TF_CALL_complex64(REGISTER_CPU_KERNELS); +TF_CALL_complex128(REGISTER_CPU_KERNELS); #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS @@ -1286,6 +1304,8 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +TF_CALL_complex64(REGISTER_CPU_KERNELS); +TF_CALL_complex128(REGISTER_CPU_KERNELS); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. @@ -1300,12 +1320,20 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +DECLARE_GPU_SPEC(complex64); +DECLARE_GPU_SPEC(complex128); +#endif #undef DECLARE_GPU_SPEC } // namespace functor REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +REGISTER_KERNELS(GPU, complex64); +REGISTER_KERNELS(GPU, complex128); +#endif #endif #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS @@ -1385,6 +1413,8 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +TF_CALL_complex64(REGISTER_CPU_KERNELS); +TF_CALL_complex128(REGISTER_CPU_KERNELS); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. 
@@ -1400,12 +1430,20 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +DECLARE_GPU_SPEC(complex64); +DECLARE_GPU_SPEC(complex128); +#endif #undef DECLARE_GPU_SPEC } // namespace functor REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +REGISTER_KERNELS(GPU, complex64); +REGISTER_KERNELS(GPU, complex128); +#endif #endif #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS @@ -1672,6 +1710,8 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +TF_CALL_complex64(REGISTER_CPU_KERNELS); +TF_CALL_complex128(REGISTER_CPU_KERNELS); #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS @@ -1845,6 +1885,8 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +TF_CALL_complex64(REGISTER_CPU_KERNELS); +TF_CALL_complex128(REGISTER_CPU_KERNELS); #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS @@ -2862,6 +2904,8 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +TF_CALL_complex64(REGISTER_CPU_KERNELS); +TF_CALL_complex128(REGISTER_CPU_KERNELS); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. @@ -2877,12 +2921,20 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +DECLARE_GPU_SPEC(complex64); +DECLARE_GPU_SPEC(complex128); +#endif #undef DECLARE_GPU_SPEC } // namespace functor REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +REGISTER_KERNELS(GPU, complex64); +REGISTER_KERNELS(GPU, complex128); +#endif #endif #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS @@ -3003,6 +3055,8 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +TF_CALL_complex64(REGISTER_CPU_KERNELS); +TF_CALL_complex128(REGISTER_CPU_KERNELS); #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS @@ -3080,6 +3134,8 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +TF_CALL_complex64(REGISTER_CPU_KERNELS); +TF_CALL_complex128(REGISTER_CPU_KERNELS); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. 
@@ -3095,12 +3151,20 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +DECLARE_GPU_SPEC(complex64); +DECLARE_GPU_SPEC(complex128); +#endif #undef DECLARE_GPU_SPEC } // namespace functor REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +REGISTER_KERNELS(GPU, complex64); +REGISTER_KERNELS(GPU, complex128); +#endif #endif #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS @@ -3201,6 +3265,8 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +TF_CALL_complex64(REGISTER_CPU_KERNELS); +TF_CALL_complex128(REGISTER_CPU_KERNELS); #undef REGISTER_CPU_KERNELS #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -3221,6 +3287,12 @@ DECLARE_GPU_SPEC(float, int32); DECLARE_GPU_SPEC(float, int64); DECLARE_GPU_SPEC(double, int32); DECLARE_GPU_SPEC(double, int64); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +DECLARE_GPU_SPEC(complex64, int32); +DECLARE_GPU_SPEC(complex64, int64); +DECLARE_GPU_SPEC(complex128, int32); +DECLARE_GPU_SPEC(complex128, int64); +#endif #undef DECLARE_GPU_SPEC } // namespace functor @@ -3231,6 +3303,10 @@ DECLARE_GPU_SPEC(double, int64); REGISTER_GPU_KERNELS(Eigen::half); REGISTER_GPU_KERNELS(float); REGISTER_GPU_KERNELS(double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +REGISTER_GPU_KERNELS(complex64); +REGISTER_GPU_KERNELS(complex128); +#endif #undef REGISTER_GPU_KERNELS #endif #undef REGISTER_KERNELS @@ -3461,6 +3537,8 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +TF_CALL_complex64(REGISTER_CPU_KERNELS); +TF_CALL_complex128(REGISTER_CPU_KERNELS); #ifdef TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T); @@ -3488,12 +3566,20 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +DECLARE_GPU_SPEC(complex64); +DECLARE_GPU_SPEC(complex128); +#endif #undef DECLARE_GPU_SPEC } // namespace functor REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +REGISTER_KERNELS(GPU, complex64); +REGISTER_KERNELS(GPU, complex128); +#endif #endif #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS @@ -3974,6 +4060,8 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +TF_CALL_complex64(REGISTER_CPU_KERNELS); +TF_CALL_complex128(REGISTER_CPU_KERNELS); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. 
@@ -4001,12 +4089,20 @@ namespace functor { DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +DECLARE_GPU_SPEC(complex64); +DECLARE_GPU_SPEC(complex128); +#endif #undef DECLARE_GPU_SPEC } // namespace functor REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +REGISTER_KERNELS(GPU, complex64); +REGISTER_KERNELS(GPU, complex128); +#endif #endif #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS @@ -4310,6 +4406,10 @@ REGISTER_KERNELS(float, int32); REGISTER_KERNELS(float, int64); REGISTER_KERNELS(double, int32); REGISTER_KERNELS(double, int64); +REGISTER_KERNELS(complex64, int32); +REGISTER_KERNELS(complex64, int64); +REGISTER_KERNELS(complex128, int32); +REGISTER_KERNELS(complex128, int64); #undef REGISTER_KERNELS diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc index 17ab7c59d12..9b28dac9316 100644 --- a/tensorflow/core/kernels/training_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc @@ -524,18 +524,30 @@ struct ApplyPowerSign { template struct functor::ApplyGradientDescent; template struct functor::ApplyGradientDescent; template struct functor::ApplyGradientDescent; +template struct functor::ApplyGradientDescent; +template struct functor::ApplyGradientDescent; template struct functor::ApplyAdagrad; template struct functor::ApplyAdagrad; template struct functor::ApplyAdagrad; +template struct functor::ApplyAdagrad; +template struct functor::ApplyAdagrad; template struct functor::ApplyAdagradV2; template struct functor::ApplyAdagradV2; template struct functor::ApplyAdagradV2; +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +template struct functor::ApplyAdagradV2; +template struct functor::ApplyAdagradV2; +#endif template struct functor::ApplyAdadelta; template struct functor::ApplyAdadelta; template struct functor::ApplyAdadelta; +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +template struct functor::ApplyAdadelta; +template struct functor::ApplyAdadelta; +#endif template struct functor::ApplyFtrl; template struct functor::ApplyFtrl; @@ -548,10 +560,18 @@ template struct functor::ApplyFtrlV2; template struct functor::ApplyMomentum; template struct functor::ApplyMomentum; template struct functor::ApplyMomentum; +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +template struct functor::ApplyMomentum; +template struct functor::ApplyMomentum; +#endif template struct functor::ApplyKerasMomentum; template struct functor::ApplyKerasMomentum; template struct functor::ApplyKerasMomentum; +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +template struct functor::ApplyKerasMomentum; +template struct functor::ApplyKerasMomentum; +#endif template struct functor::SparseApplyKerasMomentum; @@ -561,10 +581,20 @@ template struct functor::SparseApplyKerasMomentum; template struct functor::SparseApplyKerasMomentum; template struct functor::SparseApplyKerasMomentum; template struct functor::SparseApplyKerasMomentum; +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +template struct functor::SparseApplyKerasMomentum; +template struct functor::SparseApplyKerasMomentum; +template struct functor::SparseApplyKerasMomentum; +template 
struct functor::SparseApplyKerasMomentum; +#endif template struct functor::ApplyAdam; template struct functor::ApplyAdam; template struct functor::ApplyAdam; +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +template struct functor::ApplyAdam; +template struct functor::ApplyAdam; +#endif template struct functor::ApplyAdamWithAmsgrad; template struct functor::ApplyAdamWithAmsgrad; @@ -577,10 +607,18 @@ template struct functor::ApplyAdaMax; template struct functor::ApplyRMSProp; template struct functor::ApplyRMSProp; template struct functor::ApplyRMSProp; +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +template struct functor::ApplyRMSProp; +template struct functor::ApplyRMSProp; +#endif template struct functor::ApplyCenteredRMSProp; template struct functor::ApplyCenteredRMSProp; template struct functor::ApplyCenteredRMSProp; +#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt +template struct functor::ApplyCenteredRMSProp; +template struct functor::ApplyCenteredRMSProp; +#endif template struct functor::ApplyAddSign; template struct functor::ApplyAddSign; diff --git a/tensorflow/python/keras/optimizer_v2/adadelta_test.py b/tensorflow/python/keras/optimizer_v2/adadelta_test.py index fd4b7ae4544..606917cc542 100644 --- a/tensorflow/python/keras/optimizer_v2/adadelta_test.py +++ b/tensorflow/python/keras/optimizer_v2/adadelta_test.py @@ -31,12 +31,17 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test +_DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64] +# TODO(b/143684500): Eigen to support complex sqrt +if not test_util.IsBuiltWithNvcc(): + _DATA_TYPES += [dtypes.complex64, dtypes.complex128] + class AdadeltaOptimizerTest(test.TestCase): def doTestBasic(self, use_resource=False, use_callable_params=False): num_updates = 4 # number of ADADELTA steps to perform - for dtype in [dtypes.half, dtypes.float32]: + for dtype in _DATA_TYPES: for grad in [0.2, 0.1, 0.01]: for lr in [1.0, 0.5, 0.1]: var0_init = [1.0, 2.0] @@ -149,24 +154,22 @@ class AdadeltaOptimizerTest(test.TestCase): @test_util.run_deprecated_v1 def testMinimizeSparseResourceVariable(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.cached_session(): - var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) - x = constant_op.constant([[4.0], [5.0]], dtype=dtype) + for dtype in _DATA_TYPES: + var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) + x = constant_op.constant([[4.0], [5.0]], dtype=dtype) - def loss(): - pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) # pylint: disable=cell-var-from-loop - return pred * pred + def loss(): + pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) # pylint: disable=cell-var-from-loop + return pred * pred - sgd_op = adadelta.Adadelta(1.0, 1.0, 1.0).minimize( - loss, var_list=[var0]) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0)) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType([[-111, -138]], self.evaluate(var0)) + sgd_op = adadelta.Adadelta(1.0, 1.0, 1.0).minimize(loss, var_list=[var0]) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + 
self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0)) + # Run 1 step of sgd + self.evaluate(sgd_op) + # Validate updated params + self.assertAllCloseAccordingToType([[-111, -138]], self.evaluate(var0)) def testConstructAdadeltaWithLR(self): opt = adadelta.Adadelta(lr=1.0, rho=0.9, epsilon=1.) diff --git a/tensorflow/python/keras/optimizer_v2/adagrad_test.py b/tensorflow/python/keras/optimizer_v2/adagrad_test.py index d3a2ac8b5ab..03a85780f25 100644 --- a/tensorflow/python/keras/optimizer_v2/adagrad_test.py +++ b/tensorflow/python/keras/optimizer_v2/adagrad_test.py @@ -35,6 +35,11 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test +_DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64] +# TODO(b/143684500): Eigen to support complex sqrt +if not test_util.IsBuiltWithNvcc(): + _DATA_TYPES += [dtypes.complex64, dtypes.complex128] + def adagrad_update_numpy(param, accum, g_t, lr=0.001, epsilon=1e-7): accum_t = accum + g_t * g_t @@ -66,118 +71,24 @@ def sparse_adagrad_update_numpy(param, class AdagradOptimizerTest(test.TestCase): def doTestBasic(self, use_callable_params=False): - for dtype in [dtypes.float32, dtypes.float64]: - with self.cached_session(): - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - var0 = resource_variable_ops.ResourceVariable(var0_np) - var1 = resource_variable_ops.ResourceVariable(var1_np) - grads0 = constant_op.constant(grads0_np) - grads1 = constant_op.constant(grads1_np) - - learning_rate = lambda: 3.0 - if not use_callable_params: - learning_rate = learning_rate() - - ada_opt = adagrad.Adagrad(learning_rate) - - accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - - if not context.executing_eagerly(): - ada_update = ada_opt.apply_gradients( - zip([grads0, grads1], [var0, var1])) - self.evaluate(variables.global_variables_initializer()) - - # Fetch params to validate initial values - v0_val, v1_val = self.evaluate([var0, var1]) - self.assertAllClose([1.0, 2.0], v0_val) - self.assertAllClose([3.0, 4.0], v1_val) - - # Run 3 steps of adagrad - for _ in range(3): - if not context.executing_eagerly(): - self.evaluate(ada_update) - else: - ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np, - grads0_np, 3.0) - var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np, - grads1_np, 3.0) - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) - - @test_util.run_in_graph_and_eager_modes(reset_test=True) - def testBasic(self): - self.doTestBasic() - - def testBasicCallableParams(self): - with context.eager_mode(): - self.doTestBasic(use_callable_params=True) - - def testBasicWithLearningRateDecay(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.cached_session(): - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - var0 = resource_variable_ops.ResourceVariable(var0_np) - var1 = resource_variable_ops.ResourceVariable(var1_np) - 
grads0 = constant_op.constant(grads0_np) - grads1 = constant_op.constant(grads1_np) - - learning_rate = 3.0 - decay = 0.5 - - ada_opt = adagrad.Adagrad(learning_rate, decay=decay) - - accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - - if not context.executing_eagerly(): - ada_update = ada_opt.apply_gradients( - zip([grads0, grads1], [var0, var1])) - self.evaluate(variables.global_variables_initializer()) - - # Fetch params to validate initial values - v0_val, v1_val = self.evaluate([var0, var1]) - self.assertAllClose([1.0, 2.0], v0_val) - self.assertAllClose([3.0, 4.0], v1_val) - - # Run 3 steps of adagrad - for t in range(3): - if not context.executing_eagerly(): - self.evaluate(ada_update) - else: - ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - lr_np = learning_rate / (1 + decay * t) - var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np, - grads0_np, lr_np) - var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np, - grads1_np, lr_np) - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) - - def testBasicWithLargeEpsilon(self): - with self.cached_session(): - var0_np = np.array([1.0, 2.0]) - var1_np = np.array([3.0, 4.0]) - grads0_np = np.array([0.1, 0.1]) - grads1_np = np.array([0.01, 0.01]) + for dtype in _DATA_TYPES: + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) var0 = resource_variable_ops.ResourceVariable(var0_np) var1 = resource_variable_ops.ResourceVariable(var1_np) grads0 = constant_op.constant(grads0_np) grads1 = constant_op.constant(grads1_np) - learning_rate = 3.0 + learning_rate = lambda: 3.0 + if not use_callable_params: + learning_rate = learning_rate() - ada_opt = adagrad.Adagrad(learning_rate, epsilon=1.0) + ada_opt = adagrad.Adagrad(learning_rate) - accum0_np = np.array([0.1, 0.1]) - accum1_np = np.array([0.1, 0.1]) + accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) if not context.executing_eagerly(): ada_update = ada_opt.apply_gradients( @@ -196,330 +107,407 @@ class AdagradOptimizerTest(test.TestCase): else: ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np, grads0_np, - 3.0, 1.0) + 3.0) var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np, grads1_np, - 3.0, 1.0) + 3.0) self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) - def testBasicWithLearningRateInverseTimeDecay(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.cached_session(): - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - var0 = resource_variable_ops.ResourceVariable(var0_np) - var1 = resource_variable_ops.ResourceVariable(var1_np) - grads0 = constant_op.constant(grads0_np) - grads1 = constant_op.constant(grads1_np) + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testBasic(self): + self.doTestBasic() - learning_rate = 3.0 - decay = 0.5 - lr_schedule = 
learning_rate_schedule.InverseTimeDecay( - learning_rate, decay_steps=1.0, decay_rate=decay) + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_callable_params=True) - ada_opt = adagrad.Adagrad(lr_schedule) + def testBasicWithLearningRateDecay(self): + for dtype in _DATA_TYPES: + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) - accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + learning_rate = 3.0 + decay = 0.5 + ada_opt = adagrad.Adagrad(learning_rate, decay=decay) + + accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + + if not context.executing_eagerly(): + ada_update = ada_opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + + # Fetch params to validate initial values + v0_val, v1_val = self.evaluate([var0, var1]) + self.assertAllClose([1.0, 2.0], v0_val) + self.assertAllClose([3.0, 4.0], v1_val) + + # Run 3 steps of adagrad + for t in range(3): if not context.executing_eagerly(): - ada_update = ada_opt.apply_gradients( - zip([grads0, grads1], [var0, var1])) - self.evaluate(variables.global_variables_initializer()) + self.evaluate(ada_update) + else: + ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + lr_np = learning_rate / (1 + decay * t) + var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np, grads0_np, + lr_np) + var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np, grads1_np, + lr_np) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) - # Fetch params to validate initial values - v0_val, v1_val = self.evaluate([var0, var1]) - self.assertAllClose([1.0, 2.0], v0_val) - self.assertAllClose([3.0, 4.0], v1_val) + def testBasicWithLargeEpsilon(self): + var0_np = np.array([1.0, 2.0]) + var1_np = np.array([3.0, 4.0]) + grads0_np = np.array([0.1, 0.1]) + grads1_np = np.array([0.01, 0.01]) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) - # Run 3 steps of adagrad - for t in range(3): - if not context.executing_eagerly(): - self.evaluate(ada_update) - else: - ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - lr_np = learning_rate / (1 + decay * t) - var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np, - grads0_np, lr_np) - var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np, - grads1_np, lr_np) - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + learning_rate = 3.0 + + ada_opt = adagrad.Adagrad(learning_rate, epsilon=1.0) + + accum0_np = np.array([0.1, 0.1]) + accum1_np = np.array([0.1, 0.1]) + + if not context.executing_eagerly(): + ada_update = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + + # Fetch 
params to validate initial values + v0_val, v1_val = self.evaluate([var0, var1]) + self.assertAllClose([1.0, 2.0], v0_val) + self.assertAllClose([3.0, 4.0], v1_val) + + # Run 3 steps of adagrad + for _ in range(3): + if not context.executing_eagerly(): + self.evaluate(ada_update) + else: + ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np, grads0_np, + 3.0, 1.0) + var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np, grads1_np, + 3.0, 1.0) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testBasicWithLearningRateInverseTimeDecay(self): + for dtype in _DATA_TYPES: + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = 3.0 + decay = 0.5 + lr_schedule = learning_rate_schedule.InverseTimeDecay( + learning_rate, decay_steps=1.0, decay_rate=decay) + + ada_opt = adagrad.Adagrad(lr_schedule) + + accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + + if not context.executing_eagerly(): + ada_update = ada_opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + + # Fetch params to validate initial values + v0_val, v1_val = self.evaluate([var0, var1]) + self.assertAllClose([1.0, 2.0], v0_val) + self.assertAllClose([3.0, 4.0], v1_val) + + # Run 3 steps of adagrad + for t in range(3): + if not context.executing_eagerly(): + self.evaluate(ada_update) + else: + ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + lr_np = learning_rate / (1 + decay * t) + var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np, grads0_np, + lr_np) + var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np, grads1_np, + lr_np) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) @test_util.run_deprecated_v1 def testMinimizeSparseResourceVariable(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.cached_session(): - var0 = resource_variable_ops.ResourceVariable( - [[1.0, 2.0], [3.0, 4.0]], dtype=dtype) - x = constant_op.constant([[4.0], [5.0]], dtype=dtype) + for dtype in _DATA_TYPES: + var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]], + dtype=dtype) + x = constant_op.constant([[4.0], [5.0]], dtype=dtype) - def loss(): - pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) # pylint: disable=cell-var-from-loop - return pred * pred + def loss(): + pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) # pylint: disable=cell-var-from-loop + return pred * pred - sgd_op = adagrad.Adagrad(1.0).minimize(loss, var_list=[var0]) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType( - [[1.0, 2.0], [3.0, 4.0]], var0.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType( - [[0, 1], [3, 4]], var0.eval(), 
atol=0.01) + sgd_op = adagrad.Adagrad(1.0).minimize(loss, var_list=[var0]) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([[1.0, 2.0], [3.0, 4.0]], + self.evaluate(var0)) + # Run 1 step of sgd + self.evaluate(sgd_op) + # Validate updated params + self.assertAllCloseAccordingToType([[0, 1], [3, 4]], + self.evaluate(var0), + atol=0.01) @test_util.run_deprecated_v1 def testTensorLearningRate(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.cached_session(): - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - var0 = resource_variable_ops.ResourceVariable(var0_np) - var1 = resource_variable_ops.ResourceVariable(var1_np) - grads0 = constant_op.constant(grads0_np) - grads1 = constant_op.constant(grads1_np) + for dtype in _DATA_TYPES: + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) - learning_rate = constant_op.constant(3.0) - ada_opt = adagrad.Adagrad(learning_rate) - ada_update = ada_opt.apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - # Run 3 steps of adagrad - for _ in range(3): - ada_update.run() - var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np, - grads0_np, learning_rate) - var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np, - grads1_np, learning_rate) - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + learning_rate = constant_op.constant(3.0) + ada_opt = adagrad.Adagrad(learning_rate) + ada_update = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + # Run 3 steps of adagrad + for _ in range(3): + self.evaluate(ada_update) + var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np, grads0_np, + learning_rate) + var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np, grads1_np, + learning_rate) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) @test_util.run_deprecated_v1 def testSparseBasic(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.cached_session(): - var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0, 0.1], dtype=dtype.as_numpy_dtype) - 
var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0, 0.01], dtype=dtype.as_numpy_dtype) + for dtype in _DATA_TYPES: + var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0, 0.01], dtype=dtype.as_numpy_dtype) - var0 = resource_variable_ops.ResourceVariable(var0_np) - var1 = resource_variable_ops.ResourceVariable(var1_np) - grads0_np_indices = np.array([0, 2], dtype=np.int32) - grads0 = ops.IndexedSlices( - constant_op.constant(grads0_np[grads0_np_indices]), - constant_op.constant(grads0_np_indices), constant_op.constant([3])) - grads1_np_indices = np.array([0, 2], dtype=np.int32) - grads1 = ops.IndexedSlices( - constant_op.constant(grads1_np[grads1_np_indices]), - constant_op.constant(grads1_np_indices), constant_op.constant([3])) - learning_rate = 3.0 - ada_opt = adagrad.Adagrad(learning_rate) - ada_update = ada_opt.apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0_np_indices = np.array([0, 2], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np[grads0_np_indices]), + constant_op.constant(grads0_np_indices), constant_op.constant([3])) + grads1_np_indices = np.array([0, 2], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np[grads1_np_indices]), + constant_op.constant(grads1_np_indices), constant_op.constant([3])) + learning_rate = 3.0 + ada_opt = adagrad.Adagrad(learning_rate) + ada_update = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) - # Fetch params to validate initial values - self.assertAllClose([1.0, 1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 3.0, 4.0], var1.eval()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 3.0, 4.0], self.evaluate(var1)) - accum0_np = np.array([0.1, 0.1, 0.1], dtype=dtype.as_numpy_dtype) - accum1_np = np.array([0.1, 0.1, 0.1], dtype=dtype.as_numpy_dtype) + accum0_np = np.array([0.1, 0.1, 0.1], dtype=dtype.as_numpy_dtype) + accum1_np = np.array([0.1, 0.1, 0.1], dtype=dtype.as_numpy_dtype) - # Run 3 step of sgd - for _ in range(3): - ada_update.run() + # Run 3 step of sgd + for _ in range(3): + self.evaluate(ada_update) - var0_np, accum0_np = sparse_adagrad_update_numpy( - var0_np, accum0_np, grads0_np_indices, - grads0_np[grads0_np_indices], learning_rate) - var1_np, accum1_np = sparse_adagrad_update_numpy( - var1_np, accum1_np, grads1_np_indices, - grads1_np[grads1_np_indices], learning_rate) - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + var0_np, accum0_np = sparse_adagrad_update_numpy( + var0_np, accum0_np, grads0_np_indices, grads0_np[grads0_np_indices], + learning_rate) + var1_np, accum1_np = sparse_adagrad_update_numpy( + var1_np, accum1_np, grads1_np_indices, grads1_np[grads1_np_indices], + learning_rate) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) @test_util.run_deprecated_v1 def testSparseSingleVarDim(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - 
with self.cached_session(): - var0_np = np.array([1.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype) + for dtype in _DATA_TYPES: + var0_np = np.array([1.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype) - var0 = resource_variable_ops.ResourceVariable(var0_np) - grads0_np_indices = np.array([0], dtype=np.int32) - grads0 = ops.IndexedSlices( - constant_op.constant(grads0_np[grads0_np_indices]), - constant_op.constant(grads0_np_indices), constant_op.constant([3])) - learning_rate = 3.0 - ada_opt = adagrad.Adagrad(learning_rate, epsilon=1.) - ada_update = ada_opt.apply_gradients(zip([grads0], [var0])) - variables.global_variables_initializer().run() + var0 = resource_variable_ops.ResourceVariable(var0_np) + grads0_np_indices = np.array([0], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np[grads0_np_indices]), + constant_op.constant(grads0_np_indices), constant_op.constant([3])) + learning_rate = 3.0 + ada_opt = adagrad.Adagrad(learning_rate, epsilon=1.) + ada_update = ada_opt.apply_gradients(zip([grads0], [var0])) + self.evaluate(variables.global_variables_initializer()) - # Fetch params to validate initial values - self.assertAllClose([1.0], var0.eval()) + # Fetch params to validate initial values + self.assertAllClose([1.0], self.evaluate(var0)) - accum0_np = np.array([0.1], dtype=dtype.as_numpy_dtype) + accum0_np = np.array([0.1], dtype=dtype.as_numpy_dtype) - # Run 3 step of sgd - for _ in range(3): - ada_update.run() + # Run 3 step of sgd + for _ in range(3): + self.evaluate(ada_update) - var0_np, accum0_np = sparse_adagrad_update_numpy( - var0_np, - accum0_np, - grads0_np_indices, - grads0_np[grads0_np_indices], - learning_rate, - epsilon=1.) - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + var0_np, accum0_np = sparse_adagrad_update_numpy( + var0_np, + accum0_np, + grads0_np_indices, + grads0_np[grads0_np_indices], + learning_rate, + epsilon=1.) 
+        self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
 
   @test_util.run_deprecated_v1
   def testSparseRepeatedIndices(self):
-    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
-        var_np = np.array([[1.0], [2.0]], dtype=dtype.as_numpy_dtype)
+    for dtype in _DATA_TYPES:
+      var_np = np.array([[1.0], [2.0]], dtype=dtype.as_numpy_dtype)
 
-        repeated_index_update_var = resource_variable_ops.ResourceVariable(
-            var_np, dtype=dtype)
-        aggregated_update_var = resource_variable_ops.ResourceVariable(
-            var_np, dtype=dtype)
-        grad_repeated_index = ops.IndexedSlices(
-            constant_op.constant(
-                [0.1, 0.1], shape=[2, 1], dtype=dtype),
-            constant_op.constant([1, 1]),
-            constant_op.constant([2, 1]))
-        grad_aggregated = ops.IndexedSlices(
-            constant_op.constant(
-                [0.2], shape=[1, 1], dtype=dtype),
-            constant_op.constant([1]),
-            constant_op.constant([2, 1]))
-        repeated_update = adagrad.Adagrad(3.0).apply_gradients(
-            [(grad_repeated_index, repeated_index_update_var)])
-        aggregated_update = adagrad.Adagrad(3.0).apply_gradients(
-            [(grad_aggregated, aggregated_update_var)])
-        variables.global_variables_initializer().run()
-        self.assertAllClose(aggregated_update_var.eval(),
-                            repeated_index_update_var.eval())
-        for _ in range(3):
-          repeated_update.run()
-          aggregated_update.run()
-        self.assertAllClose(aggregated_update_var.eval(),
-                            repeated_index_update_var.eval())
+      repeated_index_update_var = resource_variable_ops.ResourceVariable(
+          var_np, dtype=dtype)
+      aggregated_update_var = resource_variable_ops.ResourceVariable(
+          var_np, dtype=dtype)
+      grad_repeated_index = ops.IndexedSlices(
+          constant_op.constant([0.1, 0.1], shape=[2, 1], dtype=dtype),
+          constant_op.constant([1, 1]), constant_op.constant([2, 1]))
+      grad_aggregated = ops.IndexedSlices(
+          constant_op.constant([0.2], shape=[1, 1], dtype=dtype),
+          constant_op.constant([1]), constant_op.constant([2, 1]))
+      repeated_update = adagrad.Adagrad(3.0).apply_gradients([
+          (grad_repeated_index, repeated_index_update_var)
+      ])
+      aggregated_update = adagrad.Adagrad(3.0).apply_gradients([
+          (grad_aggregated, aggregated_update_var)
+      ])
+      self.evaluate(variables.global_variables_initializer())
+      self.assertAllClose(
+          self.evaluate(aggregated_update_var),
+          self.evaluate(repeated_index_update_var))
+      for _ in range(3):
+        self.evaluate(repeated_update)
+        self.evaluate(aggregated_update)
+      self.assertAllClose(
+          self.evaluate(aggregated_update_var),
+          self.evaluate(repeated_index_update_var))
 
   @test_util.run_deprecated_v1
   def testSparseRepeatedIndicesByEmbeddingLookUp(self):
-    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
-        var_repeated = resource_variable_ops.ResourceVariable(
-            [1.0, 2.0], dtype=dtype)
-        loss_repeated = lambda: math_ops.reduce_sum(  # pylint: disable=g-long-lambda
-            embedding_ops.embedding_lookup(var_repeated, [0, 0]))  # pylint: disable=cell-var-from-loop
-        var_aggregated = resource_variable_ops.ResourceVariable(
-            [1.0, 2.0], dtype=dtype)
-        loss_aggregated = lambda: 2 * math_ops.reduce_sum(  # pylint: disable=g-long-lambda
-            embedding_ops.embedding_lookup(var_aggregated, [0]))  # pylint: disable=cell-var-from-loop
-        update_op_repeated = adagrad.Adagrad(2.0).minimize(
-            loss_repeated, var_list=[var_repeated])
-        update_op_aggregated = adagrad.Adagrad(2.0).minimize(
-            loss_aggregated, var_list=[var_aggregated])
-        variables.global_variables_initializer().run()
+    for dtype in _DATA_TYPES:
+      var_repeated = resource_variable_ops.ResourceVariable([1.0, 2.0],
+                                                            dtype=dtype)
+      loss_repeated = lambda: math_ops.reduce_sum(  # pylint: disable=g-long-lambda
+          embedding_ops.embedding_lookup(var_repeated, [0, 0]))  # pylint: disable=cell-var-from-loop
+      var_aggregated = resource_variable_ops.ResourceVariable([1.0, 2.0],
+                                                              dtype=dtype)
+      loss_aggregated = lambda: 2 * math_ops.reduce_sum(  # pylint: disable=g-long-lambda
+          embedding_ops.embedding_lookup(var_aggregated, [0]))  # pylint: disable=cell-var-from-loop
+      update_op_repeated = adagrad.Adagrad(2.0).minimize(
+          loss_repeated, var_list=[var_repeated])
+      update_op_aggregated = adagrad.Adagrad(2.0).minimize(
+          loss_aggregated, var_list=[var_aggregated])
+      self.evaluate(variables.global_variables_initializer())
+      self.assertAllCloseAccordingToType(
+          self.evaluate(var_repeated), self.evaluate(var_aggregated))
+      for _ in range(3):
+        self.evaluate(update_op_repeated)
+        self.evaluate(update_op_aggregated)
         self.assertAllCloseAccordingToType(
-            var_repeated.eval(), var_aggregated.eval())
-        for _ in range(3):
-          update_op_repeated.run()
-          update_op_aggregated.run()
-        self.assertAllCloseAccordingToType(
-            var_repeated.eval(), var_aggregated.eval())
+            self.evaluate(var_repeated), self.evaluate(var_aggregated))
 
   @test_util.run_deprecated_v1
   def testSparseStability(self):
     for dtype in [dtypes.half]:
-      with self.cached_session():
-        shape = [1, 6]
-        var0_np = np.array([[
-            0.00872496, -0.106952, 0.110467, 0.226505, -0.0147257, -0.0105945
-        ]],
+      shape = [1, 6]
+      var0_np = np.array(
+          [[0.00872496, -0.106952, 0.110467, 0.226505, -0.0147257, -0.0105945]],
+          dtype=dtype.as_numpy_dtype)
+      var0 = resource_variable_ops.ResourceVariable(var0_np)
+      grads0_np = np.array([[
+          -5.91278e-05, 5.31673e-05, -2.5779e-06, 4.29153e-05, -8.4877e-05,
+          -9.48906e-05
+      ]],
                            dtype=dtype.as_numpy_dtype)
-        var0 = resource_variable_ops.ResourceVariable(var0_np)
-        grads0_np = np.array([[
-            -5.91278e-05, 5.31673e-05, -2.5779e-06, 4.29153e-05, -8.4877e-05,
-            -9.48906e-05
-        ]],
-                             dtype=dtype.as_numpy_dtype)
-        grads0 = ops.IndexedSlices(
-            constant_op.constant(grads0_np), constant_op.constant([0]),
-            constant_op.constant(shape))
-        ada_opt = adagrad.Adagrad(1.0)
-        ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
-        slot0 = ada_opt.get_slot(var0, "accumulator")
-        init = variables.global_variables_initializer()
-        for _ in range(100):
-          init.run()
-          ada_update.run()
-        self.assertAllCloseAccordingToType(
-            np.array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]), slot0.eval())
-        self.assertAllCloseAccordingToType(
-            np.array([[
-                0.00891194, -0.10712013, 0.11047515, 0.22636929, -0.0144573,
-                -0.01029443
-            ]]), var0.eval())
+      grads0 = ops.IndexedSlices(
+          constant_op.constant(grads0_np), constant_op.constant([0]),
+          constant_op.constant(shape))
+      ada_opt = adagrad.Adagrad(1.0)
+      ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
+      slot0 = ada_opt.get_slot(var0, "accumulator")
+      init = variables.global_variables_initializer()
+      for _ in range(100):
+        self.evaluate(init)
+        self.evaluate(ada_update)
+      self.assertAllCloseAccordingToType(
+          np.array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]), self.evaluate(slot0))
+      self.assertAllCloseAccordingToType(
+          np.array([[
+              0.00891194, -0.10712013, 0.11047515, 0.22636929, -0.0144573,
+              -0.01029443
+          ]]), self.evaluate(var0))
 
   @test_util.run_deprecated_v1
   def testSharing(self):
-    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
-        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
-        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
-        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
-        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+    for dtype in _DATA_TYPES:
+      var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+      grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+      var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+      grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
 
-        var0 = resource_variable_ops.ResourceVariable(var0_np)
-        var1 = resource_variable_ops.ResourceVariable(var1_np)
-        grads0 = constant_op.constant(grads0_np)
-        grads1 = constant_op.constant(grads1_np)
+      var0 = resource_variable_ops.ResourceVariable(var0_np)
+      var1 = resource_variable_ops.ResourceVariable(var1_np)
+      grads0 = constant_op.constant(grads0_np)
+      grads1 = constant_op.constant(grads1_np)
 
-        learning_rate = 3.0
-        ada_opt = adagrad.Adagrad(learning_rate)
-        # Apply the optimizer twice.  Both applications will use
-        # the same accums.
-        ada_update1 = ada_opt.apply_gradients(
-            zip([grads0, grads1], [var0, var1]))
-        ada_update2 = ada_opt.apply_gradients(
-            zip([grads0, grads1], [var0, var1]))
-        slot0 = ada_opt.get_slot(var0, "accumulator")
-        self.assertEqual(slot0.shape, var0.shape)
-        slot1 = ada_opt.get_slot(var1, "accumulator")
-        self.assertEqual(slot1.shape, var1.shape)
-        variables.global_variables_initializer().run()
+      learning_rate = 3.0
+      ada_opt = adagrad.Adagrad(learning_rate)
+      # Apply the optimizer twice.  Both applications will use
+      # the same accums.
+      ada_update1 = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      ada_update2 = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      slot0 = ada_opt.get_slot(var0, "accumulator")
+      self.assertEqual(slot0.shape, var0.shape)
+      slot1 = ada_opt.get_slot(var1, "accumulator")
+      self.assertEqual(slot1.shape, var1.shape)
+      self.evaluate(variables.global_variables_initializer())
 
-        # Fetch params to validate initial values.
-        self.assertAllClose([1.0, 2.0], var0.eval())
-        self.assertAllClose([3.0, 4.0], var1.eval())
-        # Mix the first and the second adagrad for 3 steps.
-        ada_update1.run()
-        ada_update2.run()
-        ada_update1.run()
+      # Fetch params to validate initial values.
+      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+      self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+      # Mix the first and the second adagrad for 3 steps.
+      self.evaluate(ada_update1)
+      self.evaluate(ada_update2)
+      self.evaluate(ada_update1)
 
-        accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
-        accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
-        for _ in range(3):
-          var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np,
-                                                    grads0_np, learning_rate)
-          var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np,
-                                                    grads1_np, learning_rate)
-          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
-          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+      accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+      accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+      for _ in range(3):
+        var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np, grads0_np,
+                                                  learning_rate)
+        var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np, grads1_np,
+                                                  learning_rate)
+        self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+        self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
 
   def testConstructAdagradWithLR(self):
     opt = adagrad.Adagrad(lr=1.0)
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index 6547aa8244b..1e97ae469bb 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -851,8 +851,10 @@ class OptimizerV2(trackable.Trackable):
     Returns:
       Valid types for loss, variables and gradients.
     """
-    return set(
-        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])
+    return set([
+        dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64,
+        dtypes.complex64, dtypes.complex128
+    ])
 
   def _call_if_callable(self, param):
     """Call the function if param is callable."""
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
index 1caa4dc85de..501d05e453f 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
@@ -65,8 +65,11 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testBasic(self):
-    for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session(use_gpu=True):
+    for _, dtype in enumerate([
+        dtypes.half, dtypes.float32, dtypes.float64, dtypes.complex64,
+        dtypes.complex128
+    ]):
+      with test_util.use_gpu():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
@@ -86,7 +89,10 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testAdaptiveLearningRate(self):
-    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+    for dtype in [
+        dtypes.half, dtypes.float32, dtypes.float64, dtypes.complex64,
+        dtypes.complex128
+    ]:
       var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
       var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
 
@@ -129,8 +135,11 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testPrecomputedGradient(self):
-    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session(use_gpu=True):
+    for dtype in [
+        dtypes.half, dtypes.float32, dtypes.float64, dtypes.complex64,
+        dtypes.complex128
+    ]:
+      with test_util.use_gpu():
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
@@ -153,8 +162,11 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testNoGradients(self):
-    for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session(use_gpu=True):
+    for _, dtype in enumerate([
+        dtypes.half, dtypes.float32, dtypes.float64, dtypes.complex64,
+        dtypes.complex128
+    ]):
+      with test_util.use_gpu():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0  # pylint: disable=cell-var-from-loop
@@ -165,8 +177,11 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testNoGradientsForAnyVariables_Minimize(self):
-    for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session(use_gpu=True):
+    for _, dtype in enumerate([
+        dtypes.half, dtypes.float32, dtypes.float64, dtypes.complex64,
+        dtypes.complex128
+    ]):
+      with test_util.use_gpu():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: constant_op.constant(5.0)
@@ -178,8 +193,11 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testNoGradientsForAnyVariables_ApplyGradients(self):
-    for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session(use_gpu=True):
+    for _, dtype in enumerate([
+        dtypes.half, dtypes.float32, dtypes.float64, dtypes.complex64,
+        dtypes.complex128
+    ]):
+      with test_util.use_gpu():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         sgd_op = gradient_descent.SGD(3.0)
@@ -189,8 +207,11 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testGradientsAsVariables(self):
-    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session(use_gpu=True):
+    for i, dtype in enumerate([
+        dtypes.half, dtypes.float32, dtypes.float64, dtypes.complex64,
+        dtypes.complex128
+    ]):
+      with test_util.use_gpu():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
@@ -228,7 +249,7 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testComputeGradientsWithTensors(self):
-    with self.cached_session(use_gpu=True):
+    with test_util.use_gpu():
       x = ops.convert_to_tensor(1.0)
 
       def f():
@@ -248,7 +269,7 @@ class OptimizerTest(test.TestCase):
   def testConstraint(self):
     constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
     constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
-    with self.cached_session(use_gpu=True):
+    with test_util.use_gpu():
       var0 = variables.Variable([1.0, 2.0],
                                 constraint=constraint_01)
       var1 = variables.Variable([3.0, 4.0],
@@ -270,14 +291,14 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testIterationWithoutMinimize(self):
-    with self.cached_session(use_gpu=True):
+    with test_util.use_gpu():
       sgd = gradient_descent.SGD(3.0)
       self.evaluate(sgd.iterations.initializer)
       self.assertEqual(0, self.evaluate(sgd.iterations))
 
   @test_util.run_in_graph_and_eager_modes
   def testConfig(self):
-    with self.cached_session(use_gpu=True):
+    with test_util.use_gpu():
      opt = gradient_descent.SGD(learning_rate=1.0)
      config = opt.get_config()
      opt2 = gradient_descent.SGD.from_config(config)
@@ -297,7 +318,7 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testConfigWithLearningRateDecay(self):
-    with self.cached_session(use_gpu=True):
+    with test_util.use_gpu():
       var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32)
       for decay_schedule in [
           learning_rate_schedule.InverseTimeDecay(
@@ -328,7 +349,7 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testGradClipValue(self):
-    with self.cached_session(use_gpu=True):
+    with test_util.use_gpu():
       var = resource_variable_ops.ResourceVariable([1.0, 2.0])
       loss = lambda: 3 * var
       opt = gradient_descent.SGD(learning_rate=1.0, clipvalue=1.0)
@@ -339,7 +360,7 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testGradClipNorm(self):
-    with self.cached_session(use_gpu=True):
+    with test_util.use_gpu():
       var = resource_variable_ops.ResourceVariable([1.0])
       loss = lambda: 3 * var
       opt = gradient_descent.SGD(learning_rate=1.0, clipnorm=1.0)
@@ -360,7 +381,7 @@ class OptimizerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testWeights(self):
-    with self.cached_session(use_gpu=True):
+    with test_util.use_gpu():
       opt1 = adam.Adam(learning_rate=1.0)
       var1 = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                     dtype=dtypes.float32)
@@ -627,7 +648,7 @@ class OptimizersCompatibilityTest(keras_parameterized.TestCase):
           'v1 optimizer does not run in experimental_run_tf_function mode or '
           'eager mode')
     np.random.seed(1331)
-    with self.cached_session(use_gpu=True):
+    with test_util.use_gpu():
       train_samples = 20
      input_dim = 3
      num_classes = 2
@@ -715,7 +736,7 @@ class OptimizersCompatibilityTest(keras_parameterized.TestCase):
           'v1 optimizer does not run in experimental_run_tf_function mode or '
           'eager mode')
     np.random.seed(1331)
-    with self.cached_session(use_gpu=True):
+    with test_util.use_gpu():
       train_samples = 20
       input_dim = 3
       num_classes = 2
@@ -776,7 +797,7 @@ class OptimizersCompatibilityTest(keras_parameterized.TestCase):
           'v1 optimizer does not run in experimental_run_tf_function mode or '
           'eager mode')
     np.random.seed(1331)
-    with self.cached_session(use_gpu=True):
+    with test_util.use_gpu():
       train_samples = 20
       input_dim = 3
       num_classes = 2
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
index 87c1e56bd7c..d4de0b5b7e9 100644
--- a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
@@ -38,7 +38,10 @@ from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
-_DATA_TYPES = [dtypes.half, dtypes.float32]
+_DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
+# TODO(b/143684500): Eigen to support complex sqrt
+if not test_util.IsBuiltWithNvcc():
+  _DATA_TYPES += [dtypes.complex64, dtypes.complex128]
 
 _TEST_PARAM_VALUES = [
     # learning_rate, rho, momentum, epsilon, centered
@@ -137,9 +140,9 @@ class RMSpropOptimizerTest(test.TestCase):
         mom1 = None
 
       rms0 = opt.get_slot(var0, "rms")
-      self.assertTrue(rms0 is not None)
+      self.assertIsNotNone(rms0)
       rms1 = opt.get_slot(var1, "rms")
-      self.assertTrue(rms1 is not None)
+      self.assertIsNotNone(rms1)
 
       mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
       mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
@@ -204,9 +207,9 @@ class RMSpropOptimizerTest(test.TestCase):
       self.evaluate(variables.global_variables_initializer())
 
       rms0 = opt.get_slot(var0, "rms")
-      self.assertTrue(rms0 is not None)
+      self.assertIsNotNone(rms0)
       rms1 = opt.get_slot(var1, "rms")
-      self.assertTrue(rms1 is not None)
+      self.assertIsNotNone(rms1)
       if momentum > 0.:
         mom0 = opt.get_slot(var0, "momentum")
         mom1 = opt.get_slot(var1, "momentum")
@@ -276,9 +279,9 @@ class RMSpropOptimizerTest(test.TestCase):
       self.evaluate(variables.global_variables_initializer())
 
       rms0 = opt.get_slot(var0, "rms")
-      self.assertTrue(rms0 is not None)
+      self.assertIsNotNone(rms0)
       rms1 = opt.get_slot(var1, "rms")
-      self.assertTrue(rms1 is not None)
+      self.assertIsNotNone(rms1)
       if momentum > 0.:
         mom0 = opt.get_slot(var0, "momentum")
         mom1 = opt.get_slot(var1, "momentum")
@@ -320,60 +323,54 @@ class RMSpropOptimizerTest(test.TestCase):
 
   @test_util.run_deprecated_v1
   def testMinimizeSparseResourceVariable(self):
-    for dtype in [dtypes.float32, dtypes.float64]:
-      with self.cached_session():
-        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
-        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+    for dtype in _DATA_TYPES:
+      var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+      x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
 
-        def loss():
-          pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
-          return pred * pred
+      def loss():
+        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
+        return pred * pred
 
-        sgd_op = rmsprop.RMSprop(
-            learning_rate=1.0,
-            rho=0.0,
-            momentum=0.0,
-            epsilon=0.0,
-            centered=False).minimize(
-                loss, var_list=[var0])
-        self.evaluate(variables.global_variables_initializer())
-        # Fetch params to validate initial values
-        self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0))
-        # Run 1 step of sgd
-        self.evaluate(sgd_op)
-        # Validate updated params
-        self.assertAllCloseAccordingToType([[0., 1.]],
-                                           self.evaluate(var0),
-                                           atol=0.01)
+      sgd_op = rmsprop.RMSprop(
+          learning_rate=1.0, rho=0.0, momentum=0.0, epsilon=0.0,
+          centered=False).minimize(
+              loss, var_list=[var0])
+      self.evaluate(variables.global_variables_initializer())
+      # Fetch params to validate initial values
+      self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0))
+      # Run 1 step of sgd
+      self.evaluate(sgd_op)
+      # Validate updated params
+      self.assertAllCloseAccordingToType([[0., 1.]],
+                                         self.evaluate(var0),
+                                         atol=0.01)
 
   @test_util.run_deprecated_v1
   def testMinimizeSparseResourceVariableCentered(self):
-    for dtype in [dtypes.float32, dtypes.float64]:
-      with self.cached_session():
-        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
-        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+    for dtype in _DATA_TYPES:
+      if test_util.is_xla_enabled() and dtype.is_complex:
+        self.skipTest("b/143578550")
+      var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+      x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
 
-        def loss():
-          pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
-          return pred * pred
+      def loss():
+        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
+        return pred * pred
 
-        # loss = lambda: pred * pred  # pylint: disable=cell-var-from-loop
-        sgd_op = rmsprop.RMSprop(
-            learning_rate=1.0,
-            rho=0.0,
-            momentum=0.0,
-            epsilon=1.0,
-            centered=True).minimize(
-                loss, var_list=[var0])
-        self.evaluate(variables.global_variables_initializer())
-        # Fetch params to validate initial values
-        self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0))
-        # Run 1 step of sgd
-        self.evaluate(sgd_op)
-        # Validate updated params
-        self.assertAllCloseAccordingToType([[-111, -138]],
-                                           self.evaluate(var0),
-                                           atol=0.01)
+      # loss = lambda: pred * pred  # pylint: disable=cell-var-from-loop
+      sgd_op = rmsprop.RMSprop(
+          learning_rate=1.0, rho=0.0, momentum=0.0, epsilon=1.0,
+          centered=True).minimize(
+              loss, var_list=[var0])
+      self.evaluate(variables.global_variables_initializer())
+      # Fetch params to validate initial values
+      self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0))
+      # Run 1 step of sgd
+      self.evaluate(sgd_op)
+      # Validate updated params
+      self.assertAllCloseAccordingToType([[-111, -138]],
+                                         self.evaluate(var0),
+                                         atol=0.01)
 
   @test_util.run_deprecated_v1
   def testSparse(self):
@@ -413,9 +410,9 @@ class RMSpropOptimizerTest(test.TestCase):
         mg0 = None
         mg1 = None
       rms0 = opt.get_slot(var0, "rms")
-      self.assertTrue(rms0 is not None)
+      self.assertIsNotNone(rms0)
       rms1 = opt.get_slot(var1, "rms")
-      self.assertTrue(rms1 is not None)
+      self.assertIsNotNone(rms1)
       if momentum > 0.:
         mom0 = opt.get_slot(var0, "momentum")
         mom1 = opt.get_slot(var1, "momentum")
@@ -459,7 +456,7 @@ class RMSpropOptimizerTest(test.TestCase):
 
   def testCallableParams(self):
     with context.eager_mode():
-      for dtype in [dtypes.half, dtypes.float32]:
+      for dtype in _DATA_TYPES:
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)