[NFC] Fix typos and adopt Google style variable names
PiperOrigin-RevId: 312723375
Change-Id: I4eb23a8b34de55fb35960af7fcca8350cfb8e1c7
This commit is contained in:
parent 49fd845a78
commit 97528c3175
@@ -1821,71 +1821,71 @@ static LogicalResult Verify(GatherV2Op op) {

 static LogicalResult Verify(IfOp op) {
   auto module = op.getParentOfType<ModuleOp>();
-  auto thenFn = module.lookupSymbol<FuncOp>(op.then_branch());
-  if (!thenFn)
+  auto then_fn = module.lookupSymbol<FuncOp>(op.then_branch());
+  if (!then_fn)
     return op.emitOpError("then_branch refers to an undefined function : ")
            << op.then_branch();
-  auto elseFn = module.lookupSymbol<FuncOp>(op.else_branch());
-  if (!elseFn)
+  auto else_fn = module.lookupSymbol<FuncOp>(op.else_branch());
+  if (!else_fn)
     return op.emitOpError("else_branch refers to an undefined function : ")
            << op.else_branch();
-  auto thenFuncType = thenFn.getType();
-  auto elseFuncType = elseFn.getType();
+  auto then_fn_type = then_fn.getType();
+  auto else_fn_type = else_fn.getType();

   // Non-conditional operands starting with the second operand are passed to
   // branches and should be pair-wise compatible with branches' inputs.
-  unsigned expectedNumInputs = op.getNumOperands() - 1;
-  if (thenFuncType.getNumInputs() != expectedNumInputs ||
-      elseFuncType.getNumInputs() != expectedNumInputs)
-    return op.emitError("branches should have " + Twine(expectedNumInputs) +
+  unsigned expected_num_inputs = op.getNumOperands() - 1;
+  if (then_fn_type.getNumInputs() != expected_num_inputs ||
+      else_fn_type.getNumInputs() != expected_num_inputs)
+    return op.emitError("branches should have " + Twine(expected_num_inputs) +
                         " inputs");

-  for (unsigned i = 0; i < expectedNumInputs; ++i) {
-    auto operandType = op.getOperand(i + 1).getType().cast<TensorType>();
-    auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
-    if (!AreCastCompatible({operandType, thenInputType}))
+  for (unsigned i = 0; i < expected_num_inputs; ++i) {
+    auto operand_type = op.getOperand(i + 1).getType().cast<TensorType>();
+    auto then_input_type = then_fn_type.getInput(i).cast<TensorType>();
+    if (!AreCastCompatible({operand_type, then_input_type}))
       return op.emitError(
           llvm::formatv("then branch input type {0} is incompatible with "
                         "operand type {1} at index {2}",
-                        thenInputType, operandType, i));
+                        then_input_type, operand_type, i));

-    auto elseInputType = elseFuncType.getInput(i).cast<TensorType>();
-    if (!AreCastCompatible({operandType, elseInputType}))
+    auto else_input_type = else_fn_type.getInput(i).cast<TensorType>();
+    if (!AreCastCompatible({operand_type, else_input_type}))
       return op.emitError(
           llvm::formatv("else branch input type {0} is incompatible with "
                         "operand type {1} at index {2}",
-                        elseInputType, operandType, i));
+                        else_input_type, operand_type, i));

     // If branches have incompatible input types that means that no tensor can
     // serve as input to both the functions. Hence, the op is invalid.
-    if (!AreCastCompatible({thenInputType, elseInputType}))
+    if (!AreCastCompatible({then_input_type, else_input_type}))
       return op.emitError(llvm::formatv(
           "branches inputs have incompatible types {0} and {1} at index {2}",
-          thenInputType, elseInputType, i));
+          then_input_type, else_input_type, i));
   }

   // Branches' results should be pair-wise compatible with the op results.
-  unsigned expectedNumResults = op.getNumResults();
-  if (thenFuncType.getNumResults() != expectedNumResults ||
-      elseFuncType.getNumResults() != expectedNumResults)
-    return op.emitError("branches should have " + Twine(expectedNumResults) +
+  unsigned expected_num_results = op.getNumResults();
+  if (then_fn_type.getNumResults() != expected_num_results ||
+      else_fn_type.getNumResults() != expected_num_results)
+    return op.emitError("branches should have " + Twine(expected_num_results) +
                         " results");

-  for (unsigned i = 0; i < expectedNumResults; ++i) {
-    auto resultType = op.getResult(i).getType().cast<TensorType>();
-    auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
-    if (!AreCastCompatible({thenResultType, resultType}))
+  for (unsigned i = 0; i < expected_num_results; ++i) {
+    auto result_type = op.getResult(i).getType().cast<TensorType>();
+    auto then_result_type = then_fn_type.getResult(i).cast<TensorType>();
+    if (!AreCastCompatible({then_result_type, result_type}))
       return op.emitError(
           llvm::formatv("then branch result type {0} is incompatible with op "
                         "result type {1} at index {2}",
-                        thenResultType, resultType, i));
+                        then_result_type, result_type, i));

-    auto elseResultType = elseFuncType.getResult(i).cast<TensorType>();
-    if (!AreCastCompatible({elseResultType, resultType}))
+    auto else_result_type = else_fn_type.getResult(i).cast<TensorType>();
+    if (!AreCastCompatible({else_result_type, result_type}))
       return op.emitError(
           llvm::formatv("else branch result type {0} is incompatible with op "
                         "result type {1} at index {2}",
-                        elseResultType, resultType, i));
+                        else_result_type, result_type, i));
   }
   return success();
 }
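For readers who don't have the Google C++ Style Guide in their head: the rename in this hunk is mechanical. LLVM-style camelCase locals become snake_case, matching the style rule that ordinary variables use lower_case_with_underscores and constants use a leading `k` with MixedCase. A tiny illustrative sketch of the convention (names here are made up, not lines from the patch):

    // LLVM style (before):   unsigned expectedNumInputs;
    // Google style (after):  unsigned expected_num_inputs;
    unsigned expected_num_inputs = 0;  // locals: snake_case
    constexpr int kMaxBranches = 2;    // constants: leading 'k' + MixedCase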
@@ -3887,36 +3887,37 @@ OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {

 static LogicalResult Verify(WhileOp op) {
   auto module = op.getParentOfType<ModuleOp>();
-  auto condFn = module.lookupSymbol<FuncOp>(op.cond());
-  auto bodyFn = module.lookupSymbol<FuncOp>(op.body());
-  if (!condFn) {
+  auto cond_fn = module.lookupSymbol<FuncOp>(op.cond());
+  auto body_fn = module.lookupSymbol<FuncOp>(op.body());
+  if (!cond_fn) {
     return op.emitOpError("cond refers to an undefined function : ")
            << op.cond();
   }
-  if (!bodyFn) {
+  if (!body_fn) {
     return op.emitOpError("body refers to an undefined function : ")
            << op.body();
   }

-  auto condFuncType = condFn.getType();
-  auto bodyFuncType = bodyFn.getType();
+  auto cond_fn_type = cond_fn.getType();
+  auto body_fn_type = body_fn.getType();

   // Verify that the cond function has exactly one result.
-  if (condFuncType.getNumResults() != 1)
+  if (cond_fn_type.getNumResults() != 1)
     return op.emitOpError("requires cond function to have exactly one result");

   SmallVector<Type, 4> operands(op.getOperandTypes());

   // Collect all the type lists for the op so that different pairs of type lists
   // can be compared for the compatibility.
-  int numTypeLists = 5;
-  std::pair<std::string, ArrayRef<Type>> typeLists[] = {
-      {"operand", operands},
-      {"body function result", bodyFuncType.getResults()},
-      {"result", op.getResultTypes()},
-      {"cond function input", condFuncType.getInputs()},
-      {"body function input", bodyFuncType.getInputs()},
-  };
+  constexpr int kNumTypeLists = 5;
+  const std::array<std::pair<std::string, ArrayRef<Type>>, kNumTypeLists>
+      type_lists = {{
+          {"operand", operands},
+          {"body function result", body_fn_type.getResults()},
+          {"result", op.getResultTypes()},
+          {"cond function input", cond_fn_type.getInputs()},
+          {"body function input", body_fn_type.getInputs()},
+      }};

   // A pair of type lists should be cast compatible with each other if one is
   // converted to the another for a function call or assignment or there is a
@@ -3940,28 +3941,28 @@ static LogicalResult Verify(WhileOp op) {
   // never converted from one to the another nor there is a common source
   // tensors. Compatibility requirement is not transitive.

-  for (int i = 0; i < numTypeLists; ++i) {
+  for (int i = 0; i < kNumTypeLists; ++i) {
     // Skip the first pair as the While op operands and body function results
     // does not need to be compatible with each other.
-    for (int j = std::max(2, i + 1); j < numTypeLists; ++j) {
-      auto &a = typeLists[i];
-      auto &b = typeLists[j];
+    for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) {
+      auto &a = type_lists[i];
+      auto &b = type_lists[j];

-      int aSize = a.second.size();
-      if (aSize != b.second.size())
+      int a_size = a.second.size();
+      if (a_size != b.second.size())
         return op.emitOpError(
             llvm::formatv("requires the number of {0}s to be equal to the "
                           "number of {1}s. Found {2} and {3}, respectively",
-                          a.first, b.first, aSize, b.second.size()));
+                          a.first, b.first, a_size, b.second.size()));

-      for (int idx = 0; idx < aSize; ++idx) {
-        auto aType = a.second[idx];
-        auto bType = b.second[idx];
+      for (int idx = 0; idx < a_size; ++idx) {
+        auto a_type = a.second[idx];
+        auto b_type = b.second[idx];

-        if (!AreCastCompatible({aType, bType}))
+        if (!AreCastCompatible({a_type, b_type}))
           return op.emitError(llvm::formatv(
               "{0} type {1} is incompatible with {2} type {3} at index {4}",
-              a.first, aType, b.first, bType, idx));
+              a.first, a_type, b.first, b_type, idx));
       }
     }
   }
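One change in this function goes slightly beyond renaming: `int numTypeLists` plus a raw C array becomes `constexpr int kNumTypeLists` plus `std::array`, so the element count is a compile-time constant and the constant picks up the Google-style `k` prefix. A small self-contained sketch of the same pattern (element types simplified; nothing TensorFlow-specific assumed):

    #include <array>
    #include <string>
    #include <utility>

    // The size is a compile-time constant, and std::array aggregate
    // initialization takes the extra set of braces, exactly as in the patch.
    constexpr int kNumTypeLists = 2;
    const std::array<std::pair<std::string, int>, kNumTypeLists> type_lists = {{
        {"operand", 0},
        {"result", 1},
    }};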
@@ -149,7 +149,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
   }
   auto walk_res = func_op.walk([&](Operation* op) {
     if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
-      // Record VarHanldeOp's device attribute.
+      // Record VarHandleOp's device attribute.
       auto device_attr =
           var_handle.getAttrOfType<mlir::StringAttr>(kDeviceAttr);
       if (!device_attr || device_attr.getValue().empty()) {
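Context for the hunk above: `func_op.walk(...)` is MLIR's standard way to visit every operation nested under an op, and when the callback returns a `WalkResult` the traversal can be interrupted early, which is why the result is captured in `walk_res`. A rough sketch of the idiom under those assumptions (helper name and threshold invented for illustration):

    #include "mlir/IR/Operation.h"
    #include "mlir/IR/Visitors.h"  // WalkResult; header path as of this vintage of MLIR.

    // Walks all nested ops and stops early at the first op with many results.
    mlir::WalkResult FindWideOp(mlir::Operation *root) {
      return root->walk([](mlir::Operation *op) {
        if (op->getNumResults() > 4) return mlir::WalkResult::interrupt();
        return mlir::WalkResult::advance();
      });
    }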
@@ -571,7 +571,7 @@ void AddLoadsStoresOutsideControlFlowOp(
 }

 // Lifts loads/stores from while loop's body and cond functions.
-LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
+LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
   // Remove identity nodes to avoid aliasing.
   RemoveIdentity(&body.front());
   RemoveIdentity(&cond.front());
@@ -985,7 +985,7 @@ LogicalResult HoistForFunctionalControlFlow(
                                   lifted_partitioned_call_callees);
     HoistForFunctionalControlFlow(&cond.front(), module,
                                   lifted_partitioned_call_callees);
-    if (failed(HanldeWhileLoop(while_op, body, cond))) return failure();
+    if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
   } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
     auto then_branch =
         llvm::cast<FuncOp>(module.lookupSymbol(if_op.then_branch()));
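The `if (failed(HandleWhileLoop(...))) return failure();` line is the usual MLIR `LogicalResult` idiom: passes and helpers report errors by returning `failure()` rather than throwing, and each caller propagates explicitly. A minimal sketch with invented helper names:

    #include "mlir/Support/LogicalResult.h"

    using mlir::failed;
    using mlir::failure;
    using mlir::LogicalResult;
    using mlir::success;

    // Hypothetical helper: rejects negative counts.
    LogicalResult ValidateCount(int count) {
      return count < 0 ? failure() : success();
    }

    LogicalResult RunPipeline(int count) {
      // Propagate the error instead of continuing with bad state.
      if (failed(ValidateCount(count))) return failure();
      return success();
    }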
@@ -562,7 +562,7 @@ class ShapeInference {

  private:
   // Mapping between ValuePort (which corresponds to an OpResult or smaller,
-  // e.g., first element of OpResult produded) to an Attribute if the ValuePort
+  // e.g., first element of OpResult produced) to an Attribute if the ValuePort
   // corresponds to a constant value.
   ValuePortResultMap results_;
   int64_t graph_version_;
@@ -1144,7 +1144,7 @@ LogicalResult InferShapeForFunction(FuncOp func,
     ArrayRef<int64_t> shape = arg_shapes[i];
     Type element_type;
     if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
-      if (!input_ty || input_ty.getShape().size() != shape.size()) {
+      if (input_ty.getRank() != shape.size()) {
        return failure();
      }
      element_type = input_ty.getElementType();
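The last hunk is behavior-preserving as well: inside `if (auto input_ty = ...dyn_cast<RankedTensorType>())`, `input_ty` is already known to be non-null, so the old `!input_ty` test was dead code, and for a ranked tensor `getRank()` equals `getShape().size()`. A short sketch of the `dyn_cast` pattern (function name invented; header path as of this era of MLIR):

    #include <cstdint>
    #include "mlir/IR/StandardTypes.h"

    // Returns the rank if `ty` is a ranked tensor, or -1 for other types.
    int64_t RankOrMinusOne(mlir::Type ty) {
      // dyn_cast yields a null value on failure, so the if-with-initializer
      // both casts and null-checks; no separate !ranked_ty test is needed.
      if (auto ranked_ty = ty.dyn_cast<mlir::RankedTensorType>()) {
        return ranked_ty.getRank();  // == ranked_ty.getShape().size()
      }
      return -1;
    }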