Check for Any/All reduction explicitly
This commit is contained in:
parent
bb196f949f
commit
a3b0b87aa2
@ -2474,9 +2474,10 @@ bool ConstantFolding::ReplaceReductionWithIdentity(NodeDef* node) const {
|
|||||||
DataType output_type;
|
DataType output_type;
|
||||||
if (node->attr().count("T") != 0) {
|
if (node->attr().count("T") != 0) {
|
||||||
output_type = node->attr().at("T").type();
|
output_type = node->attr().at("T").type();
|
||||||
} else {
|
} else if (IsAny(*node) || IsAll(*node)) {
|
||||||
// This is an 'any' or 'all' reduction. The output is always boolean.
|
|
||||||
output_type = DT_BOOL;
|
output_type = DT_BOOL;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
node->set_op("Identity");
|
node->set_op("Identity");
|
||||||
node->clear_attr();
|
node->clear_attr();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user