Remove the private API usage of vriables.dtypes.
PiperOrigin-RevId: 321645728 Change-Id: I6a91363a43bceb4c0954f48fdf287cdceffebcb7
This commit is contained in:
parent
44ce8c4417
commit
77039288bb
@ -25,6 +25,7 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import combinations
|
||||
@ -143,7 +144,7 @@ class LRDecayTestV2(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
def testPiecewiseConstantEdgeCases(self, serialize):
|
||||
# Test casting boundaries from int32 to int64.
|
||||
x_int64 = variables.Variable(0, dtype=variables.dtypes.int64)
|
||||
x_int64 = variables.Variable(0, dtype=dtypes.int64)
|
||||
boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7]
|
||||
decayed_lr = learning_rate_schedule.PiecewiseConstantDecay(
|
||||
boundaries, values)
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import math
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras.optimizer_v2 import legacy_learning_rate_decay as learning_rate_decay
|
||||
@ -101,7 +102,7 @@ class LRDecayTest(keras_parameterized.TestCase):
|
||||
self.assertAllClose(self.evaluate(decayed_lr), 0.001, 1e-6)
|
||||
|
||||
def testPiecewiseConstantEdgeCases(self):
|
||||
x_int = variables.Variable(0, dtype=variables.dtypes.int32)
|
||||
x_int = variables.Variable(0, dtype=dtypes.int32)
|
||||
boundaries, values = [-1.0, 1.0], [1, 2, 3]
|
||||
with self.assertRaises(ValueError):
|
||||
decayed_lr = learning_rate_decay.piecewise_constant(
|
||||
@ -125,7 +126,7 @@ class LRDecayTest(keras_parameterized.TestCase):
|
||||
learning_rate_decay.piecewise_constant(x_ref, boundaries, values)
|
||||
|
||||
# Test casting boundaries from int32 to int64.
|
||||
x_int64 = variables.Variable(0, dtype=variables.dtypes.int64)
|
||||
x_int64 = variables.Variable(0, dtype=dtypes.int64)
|
||||
boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7]
|
||||
decayed_lr = learning_rate_decay.piecewise_constant(
|
||||
x_int64, boundaries, values)
|
||||
|
Loading…
Reference in New Issue
Block a user