From 97528c31757797f97a8b57b1d0e024a4ffd42422 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 21 May 2020 12:53:37 -0700
Subject: [PATCH] [NFC] Fix typos and adopt Google style variable names

PiperOrigin-RevId: 312723375
Change-Id: I4eb23a8b34de55fb35960af7fcca8350cfb8e1c7
---
 .../compiler/mlir/tensorflow/ir/tf_ops.cc     | 119 +++++++++---------
 .../transforms/resource_device_inference.cc   |   2 +-
 .../transforms/resource_op_lifting.cc         |   4 +-
 .../tensorflow/transforms/shape_inference.cc  |   4 +-
 4 files changed, 65 insertions(+), 64 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index cbbb9fd5db3..389be0d3b2b 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -1821,71 +1821,71 @@ static LogicalResult Verify(GatherV2Op op) {
 
 static LogicalResult Verify(IfOp op) {
   auto module = op.getParentOfType<ModuleOp>();
-  auto thenFn = module.lookupSymbol<FuncOp>(op.then_branch());
-  if (!thenFn)
+  auto then_fn = module.lookupSymbol<FuncOp>(op.then_branch());
+  if (!then_fn)
     return op.emitOpError("then_branch refers to an undefined function : ")
            << op.then_branch();
-  auto elseFn = module.lookupSymbol<FuncOp>(op.else_branch());
-  if (!elseFn)
+  auto else_fn = module.lookupSymbol<FuncOp>(op.else_branch());
+  if (!else_fn)
     return op.emitOpError("else_branch refers to an undefined function : ")
            << op.else_branch();
-  auto thenFuncType = thenFn.getType();
-  auto elseFuncType = elseFn.getType();
+  auto then_fn_type = then_fn.getType();
+  auto else_fn_type = else_fn.getType();
 
   // Non-conditional operands starting with the second operand are passed to
   // branches and should be pair-wise compatible with branches' inputs.
-  unsigned expectedNumInputs = op.getNumOperands() - 1;
-  if (thenFuncType.getNumInputs() != expectedNumInputs ||
-      elseFuncType.getNumInputs() != expectedNumInputs)
-    return op.emitError("branches should have " + Twine(expectedNumInputs) +
+  unsigned expected_num_inputs = op.getNumOperands() - 1;
+  if (then_fn_type.getNumInputs() != expected_num_inputs ||
+      else_fn_type.getNumInputs() != expected_num_inputs)
+    return op.emitError("branches should have " + Twine(expected_num_inputs) +
                         " inputs");
 
-  for (unsigned i = 0; i < expectedNumInputs; ++i) {
-    auto operandType = op.getOperand(i + 1).getType().cast<TensorType>();
-    auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
-    if (!AreCastCompatible({operandType, thenInputType}))
+  for (unsigned i = 0; i < expected_num_inputs; ++i) {
+    auto operand_type = op.getOperand(i + 1).getType().cast<TensorType>();
+    auto then_input_type = then_fn_type.getInput(i).cast<TensorType>();
+    if (!AreCastCompatible({operand_type, then_input_type}))
       return op.emitError(
           llvm::formatv("then branch input type {0} is incompatible with "
                         "operand type {1} at index {2}",
-                        thenInputType, operandType, i));
+                        then_input_type, operand_type, i));
 
-    auto elseInputType = elseFuncType.getInput(i).cast<TensorType>();
-    if (!AreCastCompatible({operandType, elseInputType}))
+    auto else_input_type = else_fn_type.getInput(i).cast<TensorType>();
+    if (!AreCastCompatible({operand_type, else_input_type}))
       return op.emitError(
           llvm::formatv("else branch input type {0} is incompatible with "
                         "operand type {1} at index {2}",
-                        elseInputType, operandType, i));
+                        else_input_type, operand_type, i));
 
     // If branches have incompatible input types that means that no tensor can
     // serve as input to both the functions. Hence, the op is invalid.
- if (!AreCastCompatible({thenInputType, elseInputType})) + if (!AreCastCompatible({then_input_type, else_input_type})) return op.emitError(llvm::formatv( "branches inputs have incompatible types {0} and {1} at index {2}", - thenInputType, elseInputType, i)); + then_input_type, else_input_type, i)); } // Branches' results should be pair-wise compatible with the op results. - unsigned expectedNumResults = op.getNumResults(); - if (thenFuncType.getNumResults() != expectedNumResults || - elseFuncType.getNumResults() != expectedNumResults) - return op.emitError("branches should have " + Twine(expectedNumResults) + + unsigned expected_num_results = op.getNumResults(); + if (then_fn_type.getNumResults() != expected_num_results || + else_fn_type.getNumResults() != expected_num_results) + return op.emitError("branches should have " + Twine(expected_num_results) + " results"); - for (unsigned i = 0; i < expectedNumResults; ++i) { - auto resultType = op.getResult(i).getType().cast(); - auto thenResultType = thenFuncType.getResult(i).cast(); - if (!AreCastCompatible({thenResultType, resultType})) + for (unsigned i = 0; i < expected_num_results; ++i) { + auto result_type = op.getResult(i).getType().cast(); + auto then_result_type = then_fn_type.getResult(i).cast(); + if (!AreCastCompatible({then_result_type, result_type})) return op.emitError( llvm::formatv("then branch result type {0} is incompatible with op " "result type {1} at index {2}", - thenResultType, resultType, i)); + then_result_type, result_type, i)); - auto elseResultType = elseFuncType.getResult(i).cast(); - if (!AreCastCompatible({elseResultType, resultType})) + auto else_result_type = else_fn_type.getResult(i).cast(); + if (!AreCastCompatible({else_result_type, result_type})) return op.emitError( llvm::formatv("else branch result type {0} is incompatible with op " "result type {1} at index {2}", - elseResultType, resultType, i)); + else_result_type, result_type, i)); } return success(); } @@ -3887,36 +3887,37 @@ OpFoldResult VariableShapeOp::fold(ArrayRef operands) { static LogicalResult Verify(WhileOp op) { auto module = op.getParentOfType(); - auto condFn = module.lookupSymbol(op.cond()); - auto bodyFn = module.lookupSymbol(op.body()); - if (!condFn) { + auto cond_fn = module.lookupSymbol(op.cond()); + auto body_fn = module.lookupSymbol(op.body()); + if (!cond_fn) { return op.emitOpError("cond refers to an undefined function : ") << op.cond(); } - if (!bodyFn) { + if (!body_fn) { return op.emitOpError("body refers to an undefined function : ") << op.body(); } - auto condFuncType = condFn.getType(); - auto bodyFuncType = bodyFn.getType(); + auto cond_fn_type = cond_fn.getType(); + auto body_fn_type = body_fn.getType(); // Verify that the cond function has exactly one result. - if (condFuncType.getNumResults() != 1) + if (cond_fn_type.getNumResults() != 1) return op.emitOpError("requires cond function to have exactly one result"); SmallVector operands(op.getOperandTypes()); // Collect all the type lists for the op so that different pairs of type lists // can be compared for the compatibility. 
-  int numTypeLists = 5;
-  std::pair<std::string, ArrayRef<Type>> typeLists[] = {
-      {"operand", operands},
-      {"body function result", bodyFuncType.getResults()},
-      {"result", op.getResultTypes()},
-      {"cond function input", condFuncType.getInputs()},
-      {"body function input", bodyFuncType.getInputs()},
-  };
+  constexpr int kNumTypeLists = 5;
+  const std::array<std::pair<std::string, ArrayRef<Type>>, kNumTypeLists>
+      type_lists = {{
+          {"operand", operands},
+          {"body function result", body_fn_type.getResults()},
+          {"result", op.getResultTypes()},
+          {"cond function input", cond_fn_type.getInputs()},
+          {"body function input", body_fn_type.getInputs()},
+      }};
 
   // A pair of type lists should be cast compatible with each other if one is
   // converted to the another for a function call or assignment or there is a
@@ -3940,28 +3941,28 @@ static LogicalResult Verify(WhileOp op) {
   // never converted from one to the another nor there is a common source
   // tensors. Compatibility requirement is not transitive.
 
-  for (int i = 0; i < numTypeLists; ++i) {
+  for (int i = 0; i < kNumTypeLists; ++i) {
     // Skip the first pair as the While op operands and body function results
     // does not need to be compatible with each other.
-    for (int j = std::max(2, i + 1); j < numTypeLists; ++j) {
-      auto &a = typeLists[i];
-      auto &b = typeLists[j];
+    for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) {
+      auto &a = type_lists[i];
+      auto &b = type_lists[j];
 
-      int aSize = a.second.size();
-      if (aSize != b.second.size())
+      int a_size = a.second.size();
+      if (a_size != b.second.size())
         return op.emitOpError(
             llvm::formatv("requires the number of {0}s to be equal to the "
                           "number of {1}s. Found {2} and {3}, respectively",
-                          a.first, b.first, aSize, b.second.size()));
+                          a.first, b.first, a_size, b.second.size()));
 
-      for (int idx = 0; idx < aSize; ++idx) {
-        auto aType = a.second[idx];
-        auto bType = b.second[idx];
+      for (int idx = 0; idx < a_size; ++idx) {
+        auto a_type = a.second[idx];
+        auto b_type = b.second[idx];
 
-        if (!AreCastCompatible({aType, bType}))
+        if (!AreCastCompatible({a_type, b_type}))
           return op.emitError(llvm::formatv(
               "{0} type {1} is incompatible with {2} type {3} at index {4}",
-              a.first, aType, b.first, bType, idx));
+              a.first, a_type, b.first, b_type, idx));
       }
     }
   }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc
index d37dfd14590..21d74d81b20 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc
@@ -149,7 +149,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
   }
   auto walk_res = func_op.walk([&](Operation* op) {
     if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
-      // Record VarHanldeOp's device attribute.
+      // Record VarHandleOp's device attribute.
       auto device_attr =
           var_handle.getAttrOfType<mlir::StringAttr>(kDeviceAttr);
       if (!device_attr || device_attr.getValue().empty()) {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
index 611c4d2725a..82bc612b1f8 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
@@ -571,7 +571,7 @@ void AddLoadsStoresOutsideControlFlowOp(
 }
 
 // Lifts loads/stores from while loop's body and cond functions.
-LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
+LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
   // Remove identity nodes to avoid aliasing.
   RemoveIdentity(&body.front());
   RemoveIdentity(&cond.front());
@@ -985,7 +985,7 @@ LogicalResult HoistForFunctionalControlFlow(
                                   lifted_partitioned_call_callees);
     HoistForFunctionalControlFlow(&cond.front(), module,
                                   lifted_partitioned_call_callees);
-    if (failed(HanldeWhileLoop(while_op, body, cond))) return failure();
+    if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
   } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
     auto then_branch =
         llvm::cast<FuncOp>(module.lookupSymbol(if_op.then_branch()));
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 5fa810eea33..1e9be76aa66 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -562,7 +562,7 @@ class ShapeInference {
 
  private:
   // Mapping between ValuePort (which corresponds to an OpResult or smaller,
-  // e.g., first element of OpResult produded) to an Attribute if the ValuePort
+  // e.g., first element of OpResult produced) to an Attribute if the ValuePort
   // corresponds to a constant value.
   ValuePortResultMap results_;
   int64_t graph_version_;
@@ -1144,7 +1144,7 @@ LogicalResult InferShapeForFunction(FuncOp func,
     ArrayRef<int64_t> shape = arg_shapes[i];
     Type element_type;
     if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
-      if (!input_ty || input_ty.getShape().size() != shape.size()) {
+      if (input_ty.getRank() != shape.size()) {
         return failure();
       }
       element_type = input_ty.getElementType();
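
Reviewer note, not part of the patch: the WhileOp verifier touched above checks
five labeled type lists pair-wise, and the std::max(2, i + 1) lower bound on j
skips exactly one pairing, (operand, body function result), since a While op's
operands never need to be compatible with its body's results. Below is a
minimal standalone C++ sketch of just that pair-selection logic. Every name in
it is a toy stand-in invented for illustration; ToyCastCompatible,
VerifyTypeLists, and the "f32"/"i32"/"*" type strings are not the MLIR API,
where the real check is AreCastCompatible over MLIR types.

// sketch_while_verifier.cc -- illustrative only; build: g++ -std=c++17
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

namespace {

// Toy stand-in for AreCastCompatible: "*" models an unranked type that is
// compatible with anything; otherwise require an exact match.
bool ToyCastCompatible(const std::string& a, const std::string& b) {
  return a == "*" || b == "*" || a == b;
}

// Mirrors the verifier's loop structure: check every pair (i, j) with
// j >= max(2, i + 1), which skips exactly the (operand, body function
// result) pair at (i, j) == (0, 1).
bool VerifyTypeLists(
    const std::array<std::pair<std::string, std::vector<std::string>>, 5>&
        type_lists) {
  constexpr int kNumTypeLists = 5;
  for (int i = 0; i < kNumTypeLists; ++i) {
    for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) {
      const auto& a = type_lists[i];
      const auto& b = type_lists[j];
      // Paired lists must have the same length before slot-wise checks.
      if (a.second.size() != b.second.size()) {
        std::printf("number of %ss (%zu) != number of %ss (%zu)\n",
                    a.first.c_str(), a.second.size(), b.first.c_str(),
                    b.second.size());
        return false;
      }
      for (std::size_t idx = 0; idx < a.second.size(); ++idx) {
        if (!ToyCastCompatible(a.second[idx], b.second[idx])) {
          std::printf("%s type %s is incompatible with %s type %s at %zu\n",
                      a.first.c_str(), a.second[idx].c_str(), b.first.c_str(),
                      b.second[idx].c_str(), idx);
          return false;
        }
      }
    }
  }
  return true;
}

}  // namespace

int main() {
  // Same five labels the verifier collects, with made-up per-slot types.
  const std::array<std::pair<std::string, std::vector<std::string>>, 5> lists =
      {{
          {"operand", {"f32", "*"}},
          {"body function result", {"f32", "i32"}},
          {"result", {"f32", "i32"}},
          {"cond function input", {"f32", "*"}},
          {"body function input", {"f32", "i32"}},
      }};
  return VerifyTypeLists(lists) ? 0 : 1;
}

On the compatible lists above the sketch exits 0; changing any checked slot
(say, "i32" to "f64" in the "result" list) takes the same "{0} type {1} is
incompatible with {2} type {3}" error path the verifier emits, while a
mismatch between "operand" and "body function result" alone goes undetected,
exactly as intended by the skipped pair.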