:
Fixing compat due to conditional accumulator changes.\ PiperOrigin-RevId: 260979763
This commit is contained in:
parent
012a1167d2
commit
1d184409ae
@ -5116,7 +5116,6 @@ tf_py_test(
|
||||
grpc_enabled = True,
|
||||
tags = [
|
||||
"no_oss", # Test flaky due to port collisions.
|
||||
"nofwdcompat", # b/137641346
|
||||
"notsan", # data race due to b/62910646
|
||||
"oss_serial",
|
||||
],
|
||||
|
@ -271,7 +271,6 @@ tf_py_test(
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
tags = ["nofwdcompat"], # b/137641346
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
@ -984,7 +983,6 @@ tf_py_test(
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
],
|
||||
tags = ["nofwdcompat"], # b/137641346
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
|
@ -39,47 +39,12 @@ from tensorflow.python.platform import test
|
||||
|
||||
class ConditionalAccumulatorTest(test.TestCase):
|
||||
|
||||
def testConstructor(self):
|
||||
with ops.Graph().as_default():
|
||||
q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q")
|
||||
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
|
||||
self.assertProtoEquals(
|
||||
"""
|
||||
name:'Q' op:'ConditionalAccumulator'
|
||||
attr { key: 'dtype' value { type: DT_FLOAT } }
|
||||
attr { key: 'shape' value { shape { unknown_rank: true} } }
|
||||
attr { key: 'container' value { s: '' } }
|
||||
attr { key: 'shared_name' value { s: '' } }
|
||||
attr { key: 'reduction_type' value {s: 'MEAN'} }
|
||||
""", q.accumulator_ref.op.node_def)
|
||||
|
||||
def testConstructorWithInvalidArg(self):
|
||||
with ops.Graph().as_default():
|
||||
with self.assertRaises(ValueError):
|
||||
data_flow_ops.ConditionalAccumulator(
|
||||
dtypes_lib.float32, name="Q", reduction_type="Invalid")
|
||||
|
||||
def testConstructorWithShape(self):
|
||||
with ops.Graph().as_default():
|
||||
q = data_flow_ops.ConditionalAccumulator(
|
||||
dtypes_lib.float32,
|
||||
name="Q",
|
||||
shape=tensor_shape.TensorShape([1, 5, 2, 8]))
|
||||
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
|
||||
self.assertProtoEquals(
|
||||
"""
|
||||
name:'Q' op:'ConditionalAccumulator'
|
||||
attr { key: 'dtype' value { type: DT_FLOAT } }
|
||||
attr { key: 'shape' value { shape { dim {size: 1 }
|
||||
dim {size: 5 }
|
||||
dim {size: 2 }
|
||||
dim {size: 8 }
|
||||
} } }
|
||||
attr { key: 'container' value { s: '' } }
|
||||
attr { key: 'shared_name' value { s: '' } }
|
||||
attr { key: 'reduction_type' value {s: 'MEAN'} }
|
||||
""", q.accumulator_ref.op.node_def)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAccumulatorSizeEmpty(self):
|
||||
with self.cached_session():
|
||||
|
@ -1518,6 +1518,40 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase):
|
||||
values=return_val.values,
|
||||
dense_shape=return_val.shape)
|
||||
|
||||
# SparseConditionalAccumulator is not switched to resource. Use old kernels.
|
||||
def num_accumulated(self, name=None):
|
||||
"""Number of gradients that have currently been aggregated in accumulator.
|
||||
|
||||
Args:
|
||||
name: Optional name for the operation.
|
||||
|
||||
Returns:
|
||||
Number of accumulated gradients currently in accumulator.
|
||||
"""
|
||||
if name is None:
|
||||
name = "%s_NumAccumulated" % self._name
|
||||
|
||||
return gen_data_flow_ops.accumulator_num_accumulated(
|
||||
self._accumulator_ref, name=name)
|
||||
|
||||
def set_global_step(self, new_global_step, name=None):
|
||||
"""Sets the global time step of the accumulator.
|
||||
|
||||
The operation logs a warning if we attempt to set to a time step that is
|
||||
lower than the accumulator's own time step.
|
||||
|
||||
Args:
|
||||
new_global_step: Value of new time step. Can be a variable or a constant
|
||||
name: Optional name for the operation.
|
||||
|
||||
Returns:
|
||||
Operation that sets the accumulator's time step.
|
||||
"""
|
||||
return gen_data_flow_ops.accumulator_set_global_step(
|
||||
self._accumulator_ref,
|
||||
math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
|
||||
name=name)
|
||||
|
||||
|
||||
class BaseStagingArea(object):
|
||||
"""Base class for Staging Areas."""
|
||||
|
Loading…
Reference in New Issue
Block a user