Fixing compat due to conditional accumulator changes.\

PiperOrigin-RevId: 260979763
This commit is contained in:
A. Unique TensorFlower 2019-07-31 12:45:30 -07:00 committed by TensorFlower Gardener
parent 012a1167d2
commit 1d184409ae
4 changed files with 34 additions and 38 deletions

View File

@ -5116,7 +5116,6 @@ tf_py_test(
grpc_enabled = True,
tags = [
"no_oss", # Test flaky due to port collisions.
"nofwdcompat", # b/137641346
"notsan", # data race due to b/62910646
"oss_serial",
],

View File

@ -271,7 +271,6 @@ tf_py_test(
"//tensorflow/python:state_ops",
"//tensorflow/python:variables",
],
tags = ["nofwdcompat"], # b/137641346
)
tf_py_test(
@ -984,7 +983,6 @@ tf_py_test(
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
],
tags = ["nofwdcompat"], # b/137641346
)
tf_py_test(

View File

@ -39,47 +39,12 @@ from tensorflow.python.platform import test
class ConditionalAccumulatorTest(test.TestCase):
def testConstructor(self):
with ops.Graph().as_default():
q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q")
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
self.assertProtoEquals(
"""
name:'Q' op:'ConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { unknown_rank: true} } }
attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } }
attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def)
def testConstructorWithInvalidArg(self):
with ops.Graph().as_default():
with self.assertRaises(ValueError):
data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", reduction_type="Invalid")
def testConstructorWithShape(self):
with ops.Graph().as_default():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32,
name="Q",
shape=tensor_shape.TensorShape([1, 5, 2, 8]))
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
self.assertProtoEquals(
"""
name:'Q' op:'ConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { dim {size: 1 }
dim {size: 5 }
dim {size: 2 }
dim {size: 8 }
} } }
attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } }
attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def)
@test_util.run_deprecated_v1
def testAccumulatorSizeEmpty(self):
with self.cached_session():

View File

@ -1518,6 +1518,40 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase):
values=return_val.values,
dense_shape=return_val.shape)
# SparseConditionalAccumulator is not switched to resource. Use old kernels.
def num_accumulated(self, name=None):
"""Number of gradients that have currently been aggregated in accumulator.
Args:
name: Optional name for the operation.
Returns:
Number of accumulated gradients currently in accumulator.
"""
if name is None:
name = "%s_NumAccumulated" % self._name
return gen_data_flow_ops.accumulator_num_accumulated(
self._accumulator_ref, name=name)
def set_global_step(self, new_global_step, name=None):
"""Sets the global time step of the accumulator.
The operation logs a warning if we attempt to set to a time step that is
lower than the accumulator's own time step.
Args:
new_global_step: Value of new time step. Can be a variable or a constant
name: Optional name for the operation.
Returns:
Operation that sets the accumulator's time step.
"""
return gen_data_flow_ops.accumulator_set_global_step(
self._accumulator_ref,
math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
name=name)
class BaseStagingArea(object):
"""Base class for Staging Areas."""