Ensure input shape validation always runs in ScopedAllocatorOptimizer.
This change fixes a bug in the optimizer: the shape-validity check must always run, irrespective of the other checks. Previously it sat at the end of an if/else-if chain, so it was skipped whenever an earlier branch was taken. Also enable the Adagrad optimizer in the CTL correctness test, which triggers this bug.

PiperOrigin-RevId: 323648534
Change-Id: I2f785f6ca9e38ce63de12b6ceabfb7c405b341f7
parent 69ca56e7f4
commit 9e475aa305
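To make the control-flow bug concrete, below is a minimal, self-contained C++ sketch of the pattern the first hunk fixes. The types and names here are hypothetical stand-ins, not the actual TensorFlow code: a check placed in the trailing branch of an if/else-if chain is silently skipped whenever an earlier branch runs, which in CheckTypesAndGetShapes happens for the first op in a group (it takes the branch that records the group's dtype). Moving the check out of the chain makes it unconditional.

#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the real op/type machinery.
struct Op {
  std::string dtype;
  bool shape_fully_known;
};

// Buggy: mirrors the old else-if chain in CheckTypesAndGetShapes.
bool ValidateGroupBuggy(const std::vector<Op>& ops) {
  std::optional<std::string> type;
  for (const Op& op : ops) {
    if (!type.has_value()) {
      type = op.dtype;  // First op: record the group's type...
    } else if (*type != op.dtype) {
      return false;     // ...later ops: reject type mismatches.
    } else if (!op.shape_fully_known) {
      // BUG: never reached for the first op, because the chain already
      // took the first branch above.
      return false;
    }
  }
  return true;
}

// Fixed: the shape check is its own statement, as in this commit.
bool ValidateGroupFixed(const std::vector<Op>& ops) {
  std::optional<std::string> type;
  for (const Op& op : ops) {
    if (!type.has_value()) {
      type = op.dtype;
    } else if (*type != op.dtype) {
      return false;
    }
    if (!op.shape_fully_known) return false;  // Always runs.
  }
  return true;
}

int main() {
  // A one-op group with an unknown shape slips through the buggy version.
  std::vector<Op> group = {{"float32", /*shape_fully_known=*/false}};
  std::cout << "buggy accepts: " << ValidateGroupBuggy(group)    // prints 1
            << ", fixed accepts: " << ValidateGroupFixed(group)  // prints 0
            << "\n";
  return 0;
}

For a single-op group with an unknown shape, the buggy version accepts it while the fixed version rejects it, matching the "Complete shape not known" error path in the real check.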
@@ -118,18 +118,18 @@ Status CheckTypesAndGetShapes(const GraphProperties& graph_properties,
       *type = props.dtype();
     } else if (*type != props.dtype()) {
       return errors::Internal("Group ops don't all have same type");
-    } else if (!TensorShape::IsValid(props.shape()) ||
-               props.shape().unknown_rank()) {
-      // TensorShape::IsValid may return true if unknown_rank is True, i.e.
-      // number of dimensions is unknown. But for ScopedAllocatorOptimizer we
-      // need to know the shape fully.
-      return errors::Internal("Complete shape not known for ", n->name());
     }
     if (*type != dtype) {
       return errors::Internal(
           "Type mismatch: type in op attr = ", DataTypeString(dtype),
           ", type in output props = ", DataTypeString(*type));
     }
+    if (!TensorShape::IsValid(props.shape()) || props.shape().unknown_rank()) {
+      // TensorShape::IsValid may return true if unknown_rank is True, i.e.
+      // number of dimensions is unknown. But for ScopedAllocatorOptimizer we
+      // need to know the shape fully.
+      return errors::Internal("Complete shape not known for ", n->name());
+    }
     VLOG(2) << "Adding shape " << props.shape().DebugString();
     shapes->push_back(TensorShape(props.shape()));
   }
@@ -234,8 +234,10 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
           sync_batchnorm=[True, False]) +
       combinations.combine(
           distribution=strategy_combinations.multiworker_strategies,
-          optimizer_fn=
-          optimizer_combinations.gradient_descent_optimizer_keras_v2_fn,
+          optimizer_fn=[
+              optimizer_combinations.gradient_descent_optimizer_keras_v2_fn,
+              optimizer_combinations.adagrad_optimizer_keras_v2_fn
+          ],
           mode=['eager'],
           iteration_type=['iterator', 'dataset'],
           inside_func=[False, True],