Remove the usage of TF private API ops.convert_n_to_tensor from Keras

PiperOrigin-RevId: 320426550
Change-Id: Ia2d82b85193e191e4452a39d091a9789729fc639
This commit is contained in:
Pavithra Vijay 2020-07-09 10:39:11 -07:00 committed by TensorFlower Gardener
parent 1cb7ec92a2
commit 73758fc98e
3 changed files with 7 additions and 5 deletions

View File

@ -71,7 +71,7 @@ class KerasSumTest(test.TestCase, parameterized.TestCase):
self.assertEqual(self.evaluate(m.total), 100)
# check update_state() and result() + state accumulation + tensor input
update_op = m.update_state(ops.convert_n_to_tensor([1, 5]))
update_op = m.update_state(ops.convert_to_tensor_v2([1, 5]))
self.evaluate(update_op)
self.assertAlmostEqual(self.evaluate(m.result()), 106)
self.assertEqual(self.evaluate(m.total), 106) # 100 + 1 + 5

View File

@ -26,6 +26,7 @@ from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
@ -236,8 +237,8 @@ class PiecewiseConstantDecay(LearningRateSchedule):
def __call__(self, step):
with ops.name_scope_v2(self.name or "PiecewiseConstant"):
boundaries = ops.convert_n_to_tensor(self.boundaries)
values = ops.convert_n_to_tensor(self.values)
boundaries = nest.map_structure(ops.convert_to_tensor_v2, self.boundaries)
values = nest.map_structure(ops.convert_to_tensor_v2, self.values)
x_recomp = ops.convert_to_tensor_v2(step)
for i, b in enumerate(boundaries):
if b.dtype.base_dtype != x_recomp.dtype.base_dtype:

View File

@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@ -147,8 +148,8 @@ def piecewise_constant(x, boundaries, values, name=None):
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
boundaries = ops.convert_n_to_tensor(boundaries)
values = ops.convert_n_to_tensor(values)
boundaries = nest.map_structure(ops.convert_to_tensor_v2, boundaries)
values = nest.map_structure(ops.convert_to_tensor_v2, values)
x_recomp = ops.convert_to_tensor(x)
# Avoid explicit conversion to x's dtype. This could result in faulty
# comparisons, for example if floats are converted to integers.