Ensure input shape validation always runs in ScopedAllocatorOptimizer.

This change fixes a bug in the optimizer: the check for valid shape should
always happen, irrespective of other checks.

Also enable Adagrad optimizer in CTL correctness check which triggers this bug.

PiperOrigin-RevId: 323648534
Change-Id: I2f785f6ca9e38ce63de12b6ceabfb7c405b341f7
This commit is contained in:
Ayush Dubey 2020-07-28 14:02:59 -07:00 committed by TensorFlower Gardener
parent 69ca56e7f4
commit 9e475aa305
2 changed files with 10 additions and 8 deletions

View File

@ -118,18 +118,18 @@ Status CheckTypesAndGetShapes(const GraphProperties& graph_properties,
*type = props.dtype(); *type = props.dtype();
} else if (*type != props.dtype()) { } else if (*type != props.dtype()) {
return errors::Internal("Group ops don't all have same type"); return errors::Internal("Group ops don't all have same type");
} else if (!TensorShape::IsValid(props.shape()) ||
props.shape().unknown_rank()) {
// TensorShape::IsValid may return true if unknown_rank is True, i.e.
// number of dimensions is unknown. But for ScopedAllocatorOptimizer we
// need to know the shape fully.
return errors::Internal("Complete shape not known for ", n->name());
} }
if (*type != dtype) { if (*type != dtype) {
return errors::Internal( return errors::Internal(
"Type mismatch: type in op attr = ", DataTypeString(dtype), "Type mismatch: type in op attr = ", DataTypeString(dtype),
", type in output props = ", DataTypeString(*type)); ", type in output props = ", DataTypeString(*type));
} }
if (!TensorShape::IsValid(props.shape()) || props.shape().unknown_rank()) {
// TensorShape::IsValid may return true if unknown_rank is True, i.e.
// number of dimensions is unknown. But for ScopedAllocatorOptimizer we
// need to know the shape fully.
return errors::Internal("Complete shape not known for ", n->name());
}
VLOG(2) << "Adding shape " << props.shape().DebugString(); VLOG(2) << "Adding shape " << props.shape().DebugString();
shapes->push_back(TensorShape(props.shape())); shapes->push_back(TensorShape(props.shape()));
} }

View File

@ -234,8 +234,10 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
sync_batchnorm=[True, False]) + sync_batchnorm=[True, False]) +
combinations.combine( combinations.combine(
distribution=strategy_combinations.multiworker_strategies, distribution=strategy_combinations.multiworker_strategies,
optimizer_fn= optimizer_fn=[
optimizer_combinations.gradient_descent_optimizer_keras_v2_fn, optimizer_combinations.gradient_descent_optimizer_keras_v2_fn,
optimizer_combinations.adagrad_optimizer_keras_v2_fn
],
mode=['eager'], mode=['eager'],
iteration_type=['iterator', 'dataset'], iteration_type=['iterator', 'dataset'],
inside_func=[False, True], inside_func=[False, True],