[NFC] Fix typos and adopt Google style variable names
PiperOrigin-RevId: 312723375 Change-Id: I4eb23a8b34de55fb35960af7fcca8350cfb8e1c7
This commit is contained in:
parent
49fd845a78
commit
97528c3175
@ -1821,71 +1821,71 @@ static LogicalResult Verify(GatherV2Op op) {
|
|||||||
|
|
||||||
static LogicalResult Verify(IfOp op) {
|
static LogicalResult Verify(IfOp op) {
|
||||||
auto module = op.getParentOfType<ModuleOp>();
|
auto module = op.getParentOfType<ModuleOp>();
|
||||||
auto thenFn = module.lookupSymbol<FuncOp>(op.then_branch());
|
auto then_fn = module.lookupSymbol<FuncOp>(op.then_branch());
|
||||||
if (!thenFn)
|
if (!then_fn)
|
||||||
return op.emitOpError("then_branch refers to an undefined function : ")
|
return op.emitOpError("then_branch refers to an undefined function : ")
|
||||||
<< op.then_branch();
|
<< op.then_branch();
|
||||||
auto elseFn = module.lookupSymbol<FuncOp>(op.else_branch());
|
auto else_fn = module.lookupSymbol<FuncOp>(op.else_branch());
|
||||||
if (!elseFn)
|
if (!else_fn)
|
||||||
return op.emitOpError("else_branch refers to an undefined function : ")
|
return op.emitOpError("else_branch refers to an undefined function : ")
|
||||||
<< op.else_branch();
|
<< op.else_branch();
|
||||||
auto thenFuncType = thenFn.getType();
|
auto then_fn_type = then_fn.getType();
|
||||||
auto elseFuncType = elseFn.getType();
|
auto else_fn_type = else_fn.getType();
|
||||||
|
|
||||||
// Non-conditional operands starting with the second operand are passed to
|
// Non-conditional operands starting with the second operand are passed to
|
||||||
// branches and should be pair-wise compatible with branches' inputs.
|
// branches and should be pair-wise compatible with branches' inputs.
|
||||||
unsigned expectedNumInputs = op.getNumOperands() - 1;
|
unsigned expected_num_inputs = op.getNumOperands() - 1;
|
||||||
if (thenFuncType.getNumInputs() != expectedNumInputs ||
|
if (then_fn_type.getNumInputs() != expected_num_inputs ||
|
||||||
elseFuncType.getNumInputs() != expectedNumInputs)
|
else_fn_type.getNumInputs() != expected_num_inputs)
|
||||||
return op.emitError("branches should have " + Twine(expectedNumInputs) +
|
return op.emitError("branches should have " + Twine(expected_num_inputs) +
|
||||||
" inputs");
|
" inputs");
|
||||||
|
|
||||||
for (unsigned i = 0; i < expectedNumInputs; ++i) {
|
for (unsigned i = 0; i < expected_num_inputs; ++i) {
|
||||||
auto operandType = op.getOperand(i + 1).getType().cast<TensorType>();
|
auto operand_type = op.getOperand(i + 1).getType().cast<TensorType>();
|
||||||
auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
|
auto then_input_type = then_fn_type.getInput(i).cast<TensorType>();
|
||||||
if (!AreCastCompatible({operandType, thenInputType}))
|
if (!AreCastCompatible({operand_type, then_input_type}))
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
llvm::formatv("then branch input type {0} is incompatible with "
|
llvm::formatv("then branch input type {0} is incompatible with "
|
||||||
"operand type {1} at index {2}",
|
"operand type {1} at index {2}",
|
||||||
thenInputType, operandType, i));
|
then_input_type, operand_type, i));
|
||||||
|
|
||||||
auto elseInputType = elseFuncType.getInput(i).cast<TensorType>();
|
auto else_input_type = else_fn_type.getInput(i).cast<TensorType>();
|
||||||
if (!AreCastCompatible({operandType, elseInputType}))
|
if (!AreCastCompatible({operand_type, else_input_type}))
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
llvm::formatv("else branch input type {0} is incompatible with "
|
llvm::formatv("else branch input type {0} is incompatible with "
|
||||||
"operand type {1} at index {2}",
|
"operand type {1} at index {2}",
|
||||||
elseInputType, operandType, i));
|
else_input_type, operand_type, i));
|
||||||
|
|
||||||
// If branches have incompatible input types that means that no tensor can
|
// If branches have incompatible input types that means that no tensor can
|
||||||
// serve as input to both the functions. Hence, the op is invalid.
|
// serve as input to both the functions. Hence, the op is invalid.
|
||||||
if (!AreCastCompatible({thenInputType, elseInputType}))
|
if (!AreCastCompatible({then_input_type, else_input_type}))
|
||||||
return op.emitError(llvm::formatv(
|
return op.emitError(llvm::formatv(
|
||||||
"branches inputs have incompatible types {0} and {1} at index {2}",
|
"branches inputs have incompatible types {0} and {1} at index {2}",
|
||||||
thenInputType, elseInputType, i));
|
then_input_type, else_input_type, i));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Branches' results should be pair-wise compatible with the op results.
|
// Branches' results should be pair-wise compatible with the op results.
|
||||||
unsigned expectedNumResults = op.getNumResults();
|
unsigned expected_num_results = op.getNumResults();
|
||||||
if (thenFuncType.getNumResults() != expectedNumResults ||
|
if (then_fn_type.getNumResults() != expected_num_results ||
|
||||||
elseFuncType.getNumResults() != expectedNumResults)
|
else_fn_type.getNumResults() != expected_num_results)
|
||||||
return op.emitError("branches should have " + Twine(expectedNumResults) +
|
return op.emitError("branches should have " + Twine(expected_num_results) +
|
||||||
" results");
|
" results");
|
||||||
|
|
||||||
for (unsigned i = 0; i < expectedNumResults; ++i) {
|
for (unsigned i = 0; i < expected_num_results; ++i) {
|
||||||
auto resultType = op.getResult(i).getType().cast<TensorType>();
|
auto result_type = op.getResult(i).getType().cast<TensorType>();
|
||||||
auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
|
auto then_result_type = then_fn_type.getResult(i).cast<TensorType>();
|
||||||
if (!AreCastCompatible({thenResultType, resultType}))
|
if (!AreCastCompatible({then_result_type, result_type}))
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
llvm::formatv("then branch result type {0} is incompatible with op "
|
llvm::formatv("then branch result type {0} is incompatible with op "
|
||||||
"result type {1} at index {2}",
|
"result type {1} at index {2}",
|
||||||
thenResultType, resultType, i));
|
then_result_type, result_type, i));
|
||||||
|
|
||||||
auto elseResultType = elseFuncType.getResult(i).cast<TensorType>();
|
auto else_result_type = else_fn_type.getResult(i).cast<TensorType>();
|
||||||
if (!AreCastCompatible({elseResultType, resultType}))
|
if (!AreCastCompatible({else_result_type, result_type}))
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
llvm::formatv("else branch result type {0} is incompatible with op "
|
llvm::formatv("else branch result type {0} is incompatible with op "
|
||||||
"result type {1} at index {2}",
|
"result type {1} at index {2}",
|
||||||
elseResultType, resultType, i));
|
else_result_type, result_type, i));
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@ -3887,36 +3887,37 @@ OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
|
|
||||||
static LogicalResult Verify(WhileOp op) {
|
static LogicalResult Verify(WhileOp op) {
|
||||||
auto module = op.getParentOfType<ModuleOp>();
|
auto module = op.getParentOfType<ModuleOp>();
|
||||||
auto condFn = module.lookupSymbol<FuncOp>(op.cond());
|
auto cond_fn = module.lookupSymbol<FuncOp>(op.cond());
|
||||||
auto bodyFn = module.lookupSymbol<FuncOp>(op.body());
|
auto body_fn = module.lookupSymbol<FuncOp>(op.body());
|
||||||
if (!condFn) {
|
if (!cond_fn) {
|
||||||
return op.emitOpError("cond refers to an undefined function : ")
|
return op.emitOpError("cond refers to an undefined function : ")
|
||||||
<< op.cond();
|
<< op.cond();
|
||||||
}
|
}
|
||||||
if (!bodyFn) {
|
if (!body_fn) {
|
||||||
return op.emitOpError("body refers to an undefined function : ")
|
return op.emitOpError("body refers to an undefined function : ")
|
||||||
<< op.body();
|
<< op.body();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto condFuncType = condFn.getType();
|
auto cond_fn_type = cond_fn.getType();
|
||||||
auto bodyFuncType = bodyFn.getType();
|
auto body_fn_type = body_fn.getType();
|
||||||
|
|
||||||
// Verify that the cond function has exactly one result.
|
// Verify that the cond function has exactly one result.
|
||||||
if (condFuncType.getNumResults() != 1)
|
if (cond_fn_type.getNumResults() != 1)
|
||||||
return op.emitOpError("requires cond function to have exactly one result");
|
return op.emitOpError("requires cond function to have exactly one result");
|
||||||
|
|
||||||
SmallVector<Type, 4> operands(op.getOperandTypes());
|
SmallVector<Type, 4> operands(op.getOperandTypes());
|
||||||
|
|
||||||
// Collect all the type lists for the op so that different pairs of type lists
|
// Collect all the type lists for the op so that different pairs of type lists
|
||||||
// can be compared for the compatibility.
|
// can be compared for the compatibility.
|
||||||
int numTypeLists = 5;
|
constexpr int kNumTypeLists = 5;
|
||||||
std::pair<std::string, ArrayRef<Type>> typeLists[] = {
|
const std::array<std::pair<std::string, ArrayRef<Type>>, kNumTypeLists>
|
||||||
|
type_lists = {{
|
||||||
{"operand", operands},
|
{"operand", operands},
|
||||||
{"body function result", bodyFuncType.getResults()},
|
{"body function result", body_fn_type.getResults()},
|
||||||
{"result", op.getResultTypes()},
|
{"result", op.getResultTypes()},
|
||||||
{"cond function input", condFuncType.getInputs()},
|
{"cond function input", cond_fn_type.getInputs()},
|
||||||
{"body function input", bodyFuncType.getInputs()},
|
{"body function input", body_fn_type.getInputs()},
|
||||||
};
|
}};
|
||||||
|
|
||||||
// A pair of type lists should be cast compatible with each other if one is
|
// A pair of type lists should be cast compatible with each other if one is
|
||||||
// converted to the another for a function call or assignment or there is a
|
// converted to the another for a function call or assignment or there is a
|
||||||
@ -3940,28 +3941,28 @@ static LogicalResult Verify(WhileOp op) {
|
|||||||
// never converted from one to the another nor there is a common source
|
// never converted from one to the another nor there is a common source
|
||||||
// tensors. Compatibility requirement is not transitive.
|
// tensors. Compatibility requirement is not transitive.
|
||||||
|
|
||||||
for (int i = 0; i < numTypeLists; ++i) {
|
for (int i = 0; i < kNumTypeLists; ++i) {
|
||||||
// Skip the first pair as the While op operands and body function results
|
// Skip the first pair as the While op operands and body function results
|
||||||
// does not need to be compatible with each other.
|
// does not need to be compatible with each other.
|
||||||
for (int j = std::max(2, i + 1); j < numTypeLists; ++j) {
|
for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) {
|
||||||
auto &a = typeLists[i];
|
auto &a = type_lists[i];
|
||||||
auto &b = typeLists[j];
|
auto &b = type_lists[j];
|
||||||
|
|
||||||
int aSize = a.second.size();
|
int a_size = a.second.size();
|
||||||
if (aSize != b.second.size())
|
if (a_size != b.second.size())
|
||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
llvm::formatv("requires the number of {0}s to be equal to the "
|
llvm::formatv("requires the number of {0}s to be equal to the "
|
||||||
"number of {1}s. Found {2} and {3}, respectively",
|
"number of {1}s. Found {2} and {3}, respectively",
|
||||||
a.first, b.first, aSize, b.second.size()));
|
a.first, b.first, a_size, b.second.size()));
|
||||||
|
|
||||||
for (int idx = 0; idx < aSize; ++idx) {
|
for (int idx = 0; idx < a_size; ++idx) {
|
||||||
auto aType = a.second[idx];
|
auto a_type = a.second[idx];
|
||||||
auto bType = b.second[idx];
|
auto b_type = b.second[idx];
|
||||||
|
|
||||||
if (!AreCastCompatible({aType, bType}))
|
if (!AreCastCompatible({a_type, b_type}))
|
||||||
return op.emitError(llvm::formatv(
|
return op.emitError(llvm::formatv(
|
||||||
"{0} type {1} is incompatible with {2} type {3} at index {4}",
|
"{0} type {1} is incompatible with {2} type {3} at index {4}",
|
||||||
a.first, aType, b.first, bType, idx));
|
a.first, a_type, b.first, b_type, idx));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -149,7 +149,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
|
|||||||
}
|
}
|
||||||
auto walk_res = func_op.walk([&](Operation* op) {
|
auto walk_res = func_op.walk([&](Operation* op) {
|
||||||
if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
|
if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
|
||||||
// Record VarHanldeOp's device attribute.
|
// Record VarHandleOp's device attribute.
|
||||||
auto device_attr =
|
auto device_attr =
|
||||||
var_handle.getAttrOfType<mlir::StringAttr>(kDeviceAttr);
|
var_handle.getAttrOfType<mlir::StringAttr>(kDeviceAttr);
|
||||||
if (!device_attr || device_attr.getValue().empty()) {
|
if (!device_attr || device_attr.getValue().empty()) {
|
||||||
|
@ -571,7 +571,7 @@ void AddLoadsStoresOutsideControlFlowOp(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Lifts loads/stores from while loop's body and cond functions.
|
// Lifts loads/stores from while loop's body and cond functions.
|
||||||
LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
|
LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
|
||||||
// Remove identity nodes to avoid aliasing.
|
// Remove identity nodes to avoid aliasing.
|
||||||
RemoveIdentity(&body.front());
|
RemoveIdentity(&body.front());
|
||||||
RemoveIdentity(&cond.front());
|
RemoveIdentity(&cond.front());
|
||||||
@ -985,7 +985,7 @@ LogicalResult HoistForFunctionalControlFlow(
|
|||||||
lifted_partitioned_call_callees);
|
lifted_partitioned_call_callees);
|
||||||
HoistForFunctionalControlFlow(&cond.front(), module,
|
HoistForFunctionalControlFlow(&cond.front(), module,
|
||||||
lifted_partitioned_call_callees);
|
lifted_partitioned_call_callees);
|
||||||
if (failed(HanldeWhileLoop(while_op, body, cond))) return failure();
|
if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
|
||||||
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
|
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
|
||||||
auto then_branch =
|
auto then_branch =
|
||||||
llvm::cast<FuncOp>(module.lookupSymbol(if_op.then_branch()));
|
llvm::cast<FuncOp>(module.lookupSymbol(if_op.then_branch()));
|
||||||
|
@ -562,7 +562,7 @@ class ShapeInference {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
// Mapping between ValuePort (which corresponds to an OpResult or smaller,
|
// Mapping between ValuePort (which corresponds to an OpResult or smaller,
|
||||||
// e.g., first element of OpResult produded) to an Attribute if the ValuePort
|
// e.g., first element of OpResult produced) to an Attribute if the ValuePort
|
||||||
// corresponds to a constant value.
|
// corresponds to a constant value.
|
||||||
ValuePortResultMap results_;
|
ValuePortResultMap results_;
|
||||||
int64_t graph_version_;
|
int64_t graph_version_;
|
||||||
@ -1144,7 +1144,7 @@ LogicalResult InferShapeForFunction(FuncOp func,
|
|||||||
ArrayRef<int64_t> shape = arg_shapes[i];
|
ArrayRef<int64_t> shape = arg_shapes[i];
|
||||||
Type element_type;
|
Type element_type;
|
||||||
if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
|
if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
|
||||||
if (!input_ty || input_ty.getShape().size() != shape.size()) {
|
if (input_ty.getRank() != shape.size()) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
element_type = input_ty.getElementType();
|
element_type = input_ty.getElementType();
|
||||||
|
Loading…
Reference in New Issue
Block a user