Boosted Trees: delay switching to version 2 cond (even when running in v1).
This is breaking some tests. Delaying until we have a fix. PiperOrigin-RevId: 257199240
This commit is contained in:
parent
f425b028ad
commit
8305f9319f
@ -1217,7 +1217,7 @@ class ConditionalAccumulatorBase(object):
|
||||
if name is None:
|
||||
name = "%s_NumAccumulated" % self._name
|
||||
|
||||
if compat.forward_compatible(2019, 7, 8):
|
||||
if compat.forward_compatible(2019, 8, 8):
|
||||
return gen_data_flow_ops.resource_accumulator_num_accumulated(
|
||||
self._accumulator_ref, name=name)
|
||||
|
||||
@ -1237,7 +1237,7 @@ class ConditionalAccumulatorBase(object):
|
||||
Returns:
|
||||
Operation that sets the accumulator's time step.
|
||||
"""
|
||||
if compat.forward_compatible(2019, 7, 8):
|
||||
if compat.forward_compatible(2019, 8, 8):
|
||||
return gen_data_flow_ops.resource_accumulator_set_global_step(
|
||||
self._accumulator_ref,
|
||||
math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
|
||||
@ -1276,7 +1276,7 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
|
||||
name: Optional name for the accumulator.
|
||||
reduction_type: Reduction type to use when taking the gradient.
|
||||
"""
|
||||
if compat.forward_compatible(2019, 7, 8):
|
||||
if compat.forward_compatible(2019, 8, 8):
|
||||
accumulator_ref = gen_data_flow_ops.resource_conditional_accumulator(
|
||||
dtype=dtype,
|
||||
shape=shape,
|
||||
@ -1316,7 +1316,7 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
|
||||
grad.get_shape().assert_is_compatible_with(self._shape)
|
||||
local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)
|
||||
|
||||
if compat.forward_compatible(2019, 7, 8):
|
||||
if compat.forward_compatible(2019, 8, 8):
|
||||
return gen_data_flow_ops.resource_accumulator_apply_gradient(
|
||||
self._accumulator_ref,
|
||||
local_step=local_step,
|
||||
@ -1347,7 +1347,7 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
|
||||
Raises:
|
||||
InvalidArgumentError: If num_required < 1
|
||||
"""
|
||||
if compat.forward_compatible(2019, 7, 8):
|
||||
if compat.forward_compatible(2019, 8, 8):
|
||||
out = gen_data_flow_ops.resource_accumulator_take_gradient(
|
||||
self._accumulator_ref, num_required, dtype=self._dtype, name=name)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user