[NFC] Fix typos and adopt Google style variable names

PiperOrigin-RevId: 312723375
Change-Id: I4eb23a8b34de55fb35960af7fcca8350cfb8e1c7
Authored by A. Unique TensorFlower on 2020-05-21 12:53:37 -07:00; committed by TensorFlower Gardener
parent 49fd845a78
commit 97528c3175
4 changed files with 65 additions and 64 deletions
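
As a quick orientation (illustrative only, not taken from this commit): a minimal sketch of the Google C++ Style Guide naming rules the renames below follow.

void NamingSketch() {
  // Locals use snake_case rather than lowerCamelCase.
  int expected_num_inputs = 2;      // was: int expectedNumInputs = 2;
  (void)expected_num_inputs;

  // Compile-time constants use a leading 'k' plus CamelCase and are marked
  // constexpr, so they are usable where a constant expression is required.
  constexpr int kNumTypeLists = 5;  // was: int numTypeLists = 5;
  static_assert(kNumTypeLists == 5, "usable at compile time");
}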


@@ -1821,71 +1821,71 @@ static LogicalResult Verify(GatherV2Op op) {
 static LogicalResult Verify(IfOp op) {
   auto module = op.getParentOfType<ModuleOp>();
-  auto thenFn = module.lookupSymbol<FuncOp>(op.then_branch());
-  if (!thenFn)
+  auto then_fn = module.lookupSymbol<FuncOp>(op.then_branch());
+  if (!then_fn)
     return op.emitOpError("then_branch refers to an undefined function : ")
            << op.then_branch();
-  auto elseFn = module.lookupSymbol<FuncOp>(op.else_branch());
-  if (!elseFn)
+  auto else_fn = module.lookupSymbol<FuncOp>(op.else_branch());
+  if (!else_fn)
     return op.emitOpError("else_branch refers to an undefined function : ")
            << op.else_branch();
-  auto thenFuncType = thenFn.getType();
-  auto elseFuncType = elseFn.getType();
+  auto then_fn_type = then_fn.getType();
+  auto else_fn_type = else_fn.getType();
 
   // Non-conditional operands starting with the second operand are passed to
   // branches and should be pair-wise compatible with branches' inputs.
-  unsigned expectedNumInputs = op.getNumOperands() - 1;
-  if (thenFuncType.getNumInputs() != expectedNumInputs ||
-      elseFuncType.getNumInputs() != expectedNumInputs)
-    return op.emitError("branches should have " + Twine(expectedNumInputs) +
+  unsigned expected_num_inputs = op.getNumOperands() - 1;
+  if (then_fn_type.getNumInputs() != expected_num_inputs ||
+      else_fn_type.getNumInputs() != expected_num_inputs)
+    return op.emitError("branches should have " + Twine(expected_num_inputs) +
                         " inputs");
 
-  for (unsigned i = 0; i < expectedNumInputs; ++i) {
-    auto operandType = op.getOperand(i + 1).getType().cast<TensorType>();
-    auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
-    if (!AreCastCompatible({operandType, thenInputType}))
+  for (unsigned i = 0; i < expected_num_inputs; ++i) {
+    auto operand_type = op.getOperand(i + 1).getType().cast<TensorType>();
+    auto then_input_type = then_fn_type.getInput(i).cast<TensorType>();
+    if (!AreCastCompatible({operand_type, then_input_type}))
       return op.emitError(
           llvm::formatv("then branch input type {0} is incompatible with "
                         "operand type {1} at index {2}",
-                        thenInputType, operandType, i));
+                        then_input_type, operand_type, i));
 
-    auto elseInputType = elseFuncType.getInput(i).cast<TensorType>();
-    if (!AreCastCompatible({operandType, elseInputType}))
+    auto else_input_type = else_fn_type.getInput(i).cast<TensorType>();
+    if (!AreCastCompatible({operand_type, else_input_type}))
       return op.emitError(
           llvm::formatv("else branch input type {0} is incompatible with "
                         "operand type {1} at index {2}",
-                        elseInputType, operandType, i));
+                        else_input_type, operand_type, i));
 
     // If branches have incompatible input types that means that no tensor can
     // serve as input to both the functions. Hence, the op is invalid.
-    if (!AreCastCompatible({thenInputType, elseInputType}))
+    if (!AreCastCompatible({then_input_type, else_input_type}))
       return op.emitError(llvm::formatv(
           "branches inputs have incompatible types {0} and {1} at index {2}",
-          thenInputType, elseInputType, i));
+          then_input_type, else_input_type, i));
   }
 
   // Branches' results should be pair-wise compatible with the op results.
-  unsigned expectedNumResults = op.getNumResults();
-  if (thenFuncType.getNumResults() != expectedNumResults ||
-      elseFuncType.getNumResults() != expectedNumResults)
-    return op.emitError("branches should have " + Twine(expectedNumResults) +
+  unsigned expected_num_results = op.getNumResults();
+  if (then_fn_type.getNumResults() != expected_num_results ||
+      else_fn_type.getNumResults() != expected_num_results)
+    return op.emitError("branches should have " + Twine(expected_num_results) +
                         " results");
 
-  for (unsigned i = 0; i < expectedNumResults; ++i) {
-    auto resultType = op.getResult(i).getType().cast<TensorType>();
-    auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
-    if (!AreCastCompatible({thenResultType, resultType}))
+  for (unsigned i = 0; i < expected_num_results; ++i) {
+    auto result_type = op.getResult(i).getType().cast<TensorType>();
+    auto then_result_type = then_fn_type.getResult(i).cast<TensorType>();
+    if (!AreCastCompatible({then_result_type, result_type}))
       return op.emitError(
           llvm::formatv("then branch result type {0} is incompatible with op "
                         "result type {1} at index {2}",
-                        thenResultType, resultType, i));
+                        then_result_type, result_type, i));
 
-    auto elseResultType = elseFuncType.getResult(i).cast<TensorType>();
-    if (!AreCastCompatible({elseResultType, resultType}))
+    auto else_result_type = else_fn_type.getResult(i).cast<TensorType>();
+    if (!AreCastCompatible({else_result_type, result_type}))
       return op.emitError(
           llvm::formatv("else branch result type {0} is incompatible with op "
                         "result type {1} at index {2}",
-                        elseResultType, resultType, i));
+                        else_result_type, result_type, i));
   }
 
   return success();
 }
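
An aside on the diagnostic helper used throughout this verifier: llvm::formatv substitutes {0}, {1}, ... placeholders positionally, much like Python's str.format. A self-contained sketch (the type strings are made up, not from the diff):

#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // Prints: then branch input type tensor<2xf32> is incompatible with
  // operand type tensor<3xf32> at index 1
  llvm::outs() << llvm::formatv("then branch input type {0} is incompatible "
                                "with operand type {1} at index {2}",
                                "tensor<2xf32>", "tensor<3xf32>", 1)
               << "\n";
  return 0;
}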
@@ -3887,36 +3887,37 @@ OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
 static LogicalResult Verify(WhileOp op) {
   auto module = op.getParentOfType<ModuleOp>();
-  auto condFn = module.lookupSymbol<FuncOp>(op.cond());
-  auto bodyFn = module.lookupSymbol<FuncOp>(op.body());
-  if (!condFn) {
+  auto cond_fn = module.lookupSymbol<FuncOp>(op.cond());
+  auto body_fn = module.lookupSymbol<FuncOp>(op.body());
+  if (!cond_fn) {
     return op.emitOpError("cond refers to an undefined function : ")
            << op.cond();
   }
-  if (!bodyFn) {
+  if (!body_fn) {
     return op.emitOpError("body refers to an undefined function : ")
            << op.body();
   }
 
-  auto condFuncType = condFn.getType();
-  auto bodyFuncType = bodyFn.getType();
+  auto cond_fn_type = cond_fn.getType();
+  auto body_fn_type = body_fn.getType();
 
   // Verify that the cond function has exactly one result.
-  if (condFuncType.getNumResults() != 1)
+  if (cond_fn_type.getNumResults() != 1)
     return op.emitOpError("requires cond function to have exactly one result");
 
   SmallVector<Type, 4> operands(op.getOperandTypes());
 
   // Collect all the type lists for the op so that different pairs of type lists
   // can be compared for the compatibility.
-  int numTypeLists = 5;
-  std::pair<std::string, ArrayRef<Type>> typeLists[] = {
-      {"operand", operands},
-      {"body function result", bodyFuncType.getResults()},
-      {"result", op.getResultTypes()},
-      {"cond function input", condFuncType.getInputs()},
-      {"body function input", bodyFuncType.getInputs()},
-  };
+  constexpr int kNumTypeLists = 5;
+  const std::array<std::pair<std::string, ArrayRef<Type>>, kNumTypeLists>
+      type_lists = {{
+          {"operand", operands},
+          {"body function result", body_fn_type.getResults()},
+          {"result", op.getResultTypes()},
+          {"cond function input", cond_fn_type.getInputs()},
+          {"body function input", body_fn_type.getInputs()},
+      }};
 
   // A pair of type lists should be cast compatible with each other if one is
   // converted to the another for a function call or assignment or there is a
@@ -3940,28 +3941,28 @@ static LogicalResult Verify(WhileOp op) {
   // never converted from one to the another nor there is a common source
   // tensors. Compatibility requirement is not transitive.
 
-  for (int i = 0; i < numTypeLists; ++i) {
+  for (int i = 0; i < kNumTypeLists; ++i) {
     // Skip the first pair as the While op operands and body function results
     // does not need to be compatible with each other.
-    for (int j = std::max(2, i + 1); j < numTypeLists; ++j) {
-      auto &a = typeLists[i];
-      auto &b = typeLists[j];
+    for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) {
+      auto &a = type_lists[i];
+      auto &b = type_lists[j];
 
-      int aSize = a.second.size();
-      if (aSize != b.second.size())
+      int a_size = a.second.size();
+      if (a_size != b.second.size())
         return op.emitOpError(
             llvm::formatv("requires the number of {0}s to be equal to the "
                           "number of {1}s. Found {2} and {3}, respectively",
-                          a.first, b.first, aSize, b.second.size()));
+                          a.first, b.first, a_size, b.second.size()));
 
-      for (int idx = 0; idx < aSize; ++idx) {
-        auto aType = a.second[idx];
-        auto bType = b.second[idx];
+      for (int idx = 0; idx < a_size; ++idx) {
+        auto a_type = a.second[idx];
+        auto b_type = b.second[idx];
 
-        if (!AreCastCompatible({aType, bType}))
+        if (!AreCastCompatible({a_type, b_type}))
           return op.emitError(llvm::formatv(
               "{0} type {1} is incompatible with {2} type {3} at index {4}",
-              a.first, aType, b.first, bType, idx));
+              a.first, a_type, b.first, b_type, idx));
       }
     }
   }
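
One non-mechanical detail in the hunk above: the rewritten type_lists is a std::array, and its initializer takes doubled braces ({{ ... }}) because std::array is an aggregate that wraps a built-in array. A self-contained sketch with made-up names (not from the diff):

#include <array>
#include <string>
#include <utility>

int main() {
  constexpr int kNumLists = 2;
  // Outer braces initialize the std::array, inner braces its wrapped
  // built-in array; each innermost brace pair builds one std::pair.
  const std::array<std::pair<std::string, int>, kNumLists> lists = {{
      {"operand", 3},
      {"result", 3},
  }};
  return lists[0].second == lists[1].second ? 0 : 1;
}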


@@ -149,7 +149,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
   }
   auto walk_res = func_op.walk([&](Operation* op) {
     if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
-      // Record VarHanldeOp's device attribute.
+      // Record VarHandleOp's device attribute.
       auto device_attr =
           var_handle.getAttrOfType<mlir::StringAttr>(kDeviceAttr);
       if (!device_attr || device_attr.getValue().empty()) {


@@ -571,7 +571,7 @@ void AddLoadsStoresOutsideControlFlowOp(
 }
 
 // Lifts loads/stores from while loop's body and cond functions.
-LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
+LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
   // Remove identity nodes to avoid aliasing.
   RemoveIdentity(&body.front());
   RemoveIdentity(&cond.front());
@@ -985,7 +985,7 @@ LogicalResult HoistForFunctionalControlFlow(
                                   lifted_partitioned_call_callees);
     HoistForFunctionalControlFlow(&cond.front(), module,
                                   lifted_partitioned_call_callees);
-    if (failed(HanldeWhileLoop(while_op, body, cond))) return failure();
+    if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
   } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
     auto then_branch =
         llvm::cast<FuncOp>(module.lookupSymbol(if_op.then_branch()));


@@ -562,7 +562,7 @@ class ShapeInference {
  private:
   // Mapping between ValuePort (which corresponds to an OpResult or smaller,
-  // e.g., first element of OpResult produded) to an Attribute if the ValuePort
+  // e.g., first element of OpResult produced) to an Attribute if the ValuePort
   // corresponds to a constant value.
   ValuePortResultMap results_;
 
   int64_t graph_version_;
@@ -1144,7 +1144,7 @@ LogicalResult InferShapeForFunction(FuncOp func,
     ArrayRef<int64_t> shape = arg_shapes[i];
     Type element_type;
     if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
-      if (!input_ty || input_ty.getShape().size() != shape.size()) {
+      if (input_ty.getRank() != shape.size()) {
         return failure();
       }
       element_type = input_ty.getElementType();
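
Beyond renames and typo fixes, the final hunk also drops a genuinely redundant check: inside if (auto input_ty = ...dyn_cast<RankedTensorType>()), the body only runs when the cast succeeded, so !input_ty can never be true there, and getRank() is the direct spelling of getShape().size() for a ranked type. A runnable analogy using plain C++ RTTI in place of MLIR's dyn_cast (hypothetical Base/Derived types):

#include <iostream>

struct Base { virtual ~Base() = default; };
struct Derived : Base { int rank() const { return 2; } };

// In `if (auto* d = dynamic_cast<Derived*>(b))` the body is entered only
// when the cast yields a non-null pointer -- the same guarantee the deleted
// `!input_ty` check was re-testing.
void Inspect(Base* b) {
  if (auto* d = dynamic_cast<Derived*>(b)) {
    std::cout << "rank = " << d->rank() << "\n";  // d is non-null here.
  } else {
    std::cout << "not a Derived\n";
  }
}

int main() {
  Derived derived;
  Base base;
  Inspect(&derived);
  Inspect(&base);
  return 0;
}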