[NFC] Fix typos and adopt Google style variable names

PiperOrigin-RevId: 312723375
Change-Id: I4eb23a8b34de55fb35960af7fcca8350cfb8e1c7
This commit is contained in:
A. Unique TensorFlower 2020-05-21 12:53:37 -07:00 committed by TensorFlower Gardener
parent 49fd845a78
commit 97528c3175
4 changed files with 65 additions and 64 deletions

View File

@ -1821,71 +1821,71 @@ static LogicalResult Verify(GatherV2Op op) {
static LogicalResult Verify(IfOp op) {
auto module = op.getParentOfType<ModuleOp>();
auto thenFn = module.lookupSymbol<FuncOp>(op.then_branch());
if (!thenFn)
auto then_fn = module.lookupSymbol<FuncOp>(op.then_branch());
if (!then_fn)
return op.emitOpError("then_branch refers to an undefined function : ")
<< op.then_branch();
auto elseFn = module.lookupSymbol<FuncOp>(op.else_branch());
if (!elseFn)
auto else_fn = module.lookupSymbol<FuncOp>(op.else_branch());
if (!else_fn)
return op.emitOpError("else_branch refers to an undefined function : ")
<< op.else_branch();
auto thenFuncType = thenFn.getType();
auto elseFuncType = elseFn.getType();
auto then_fn_type = then_fn.getType();
auto else_fn_type = else_fn.getType();
// Non-conditional operands starting with the second operand are passed to
// branches and should be pair-wise compatible with branches' inputs.
unsigned expectedNumInputs = op.getNumOperands() - 1;
if (thenFuncType.getNumInputs() != expectedNumInputs ||
elseFuncType.getNumInputs() != expectedNumInputs)
return op.emitError("branches should have " + Twine(expectedNumInputs) +
unsigned expected_num_inputs = op.getNumOperands() - 1;
if (then_fn_type.getNumInputs() != expected_num_inputs ||
else_fn_type.getNumInputs() != expected_num_inputs)
return op.emitError("branches should have " + Twine(expected_num_inputs) +
" inputs");
for (unsigned i = 0; i < expectedNumInputs; ++i) {
auto operandType = op.getOperand(i + 1).getType().cast<TensorType>();
auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
if (!AreCastCompatible({operandType, thenInputType}))
for (unsigned i = 0; i < expected_num_inputs; ++i) {
auto operand_type = op.getOperand(i + 1).getType().cast<TensorType>();
auto then_input_type = then_fn_type.getInput(i).cast<TensorType>();
if (!AreCastCompatible({operand_type, then_input_type}))
return op.emitError(
llvm::formatv("then branch input type {0} is incompatible with "
"operand type {1} at index {2}",
thenInputType, operandType, i));
then_input_type, operand_type, i));
auto elseInputType = elseFuncType.getInput(i).cast<TensorType>();
if (!AreCastCompatible({operandType, elseInputType}))
auto else_input_type = else_fn_type.getInput(i).cast<TensorType>();
if (!AreCastCompatible({operand_type, else_input_type}))
return op.emitError(
llvm::formatv("else branch input type {0} is incompatible with "
"operand type {1} at index {2}",
elseInputType, operandType, i));
else_input_type, operand_type, i));
// If branches have incompatible input types that means that no tensor can
// serve as input to both the functions. Hence, the op is invalid.
if (!AreCastCompatible({thenInputType, elseInputType}))
if (!AreCastCompatible({then_input_type, else_input_type}))
return op.emitError(llvm::formatv(
"branches inputs have incompatible types {0} and {1} at index {2}",
thenInputType, elseInputType, i));
then_input_type, else_input_type, i));
}
// Branches' results should be pair-wise compatible with the op results.
unsigned expectedNumResults = op.getNumResults();
if (thenFuncType.getNumResults() != expectedNumResults ||
elseFuncType.getNumResults() != expectedNumResults)
return op.emitError("branches should have " + Twine(expectedNumResults) +
unsigned expected_num_results = op.getNumResults();
if (then_fn_type.getNumResults() != expected_num_results ||
else_fn_type.getNumResults() != expected_num_results)
return op.emitError("branches should have " + Twine(expected_num_results) +
" results");
for (unsigned i = 0; i < expectedNumResults; ++i) {
auto resultType = op.getResult(i).getType().cast<TensorType>();
auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
if (!AreCastCompatible({thenResultType, resultType}))
for (unsigned i = 0; i < expected_num_results; ++i) {
auto result_type = op.getResult(i).getType().cast<TensorType>();
auto then_result_type = then_fn_type.getResult(i).cast<TensorType>();
if (!AreCastCompatible({then_result_type, result_type}))
return op.emitError(
llvm::formatv("then branch result type {0} is incompatible with op "
"result type {1} at index {2}",
thenResultType, resultType, i));
then_result_type, result_type, i));
auto elseResultType = elseFuncType.getResult(i).cast<TensorType>();
if (!AreCastCompatible({elseResultType, resultType}))
auto else_result_type = else_fn_type.getResult(i).cast<TensorType>();
if (!AreCastCompatible({else_result_type, result_type}))
return op.emitError(
llvm::formatv("else branch result type {0} is incompatible with op "
"result type {1} at index {2}",
elseResultType, resultType, i));
else_result_type, result_type, i));
}
return success();
}
@ -3887,36 +3887,37 @@ OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
static LogicalResult Verify(WhileOp op) {
auto module = op.getParentOfType<ModuleOp>();
auto condFn = module.lookupSymbol<FuncOp>(op.cond());
auto bodyFn = module.lookupSymbol<FuncOp>(op.body());
if (!condFn) {
auto cond_fn = module.lookupSymbol<FuncOp>(op.cond());
auto body_fn = module.lookupSymbol<FuncOp>(op.body());
if (!cond_fn) {
return op.emitOpError("cond refers to an undefined function : ")
<< op.cond();
}
if (!bodyFn) {
if (!body_fn) {
return op.emitOpError("body refers to an undefined function : ")
<< op.body();
}
auto condFuncType = condFn.getType();
auto bodyFuncType = bodyFn.getType();
auto cond_fn_type = cond_fn.getType();
auto body_fn_type = body_fn.getType();
// Verify that the cond function has exactly one result.
if (condFuncType.getNumResults() != 1)
if (cond_fn_type.getNumResults() != 1)
return op.emitOpError("requires cond function to have exactly one result");
SmallVector<Type, 4> operands(op.getOperandTypes());
// Collect all the type lists for the op so that different pairs of type lists
// can be compared for the compatibility.
int numTypeLists = 5;
std::pair<std::string, ArrayRef<Type>> typeLists[] = {
{"operand", operands},
{"body function result", bodyFuncType.getResults()},
{"result", op.getResultTypes()},
{"cond function input", condFuncType.getInputs()},
{"body function input", bodyFuncType.getInputs()},
};
constexpr int kNumTypeLists = 5;
const std::array<std::pair<std::string, ArrayRef<Type>>, kNumTypeLists>
type_lists = {{
{"operand", operands},
{"body function result", body_fn_type.getResults()},
{"result", op.getResultTypes()},
{"cond function input", cond_fn_type.getInputs()},
{"body function input", body_fn_type.getInputs()},
}};
// A pair of type lists should be cast compatible with each other if one is
// converted to the another for a function call or assignment or there is a
@ -3940,28 +3941,28 @@ static LogicalResult Verify(WhileOp op) {
// never converted from one to the another nor there is a common source
// tensors. Compatibility requirement is not transitive.
for (int i = 0; i < numTypeLists; ++i) {
for (int i = 0; i < kNumTypeLists; ++i) {
// Skip the first pair as the While op operands and body function results
// does not need to be compatible with each other.
for (int j = std::max(2, i + 1); j < numTypeLists; ++j) {
auto &a = typeLists[i];
auto &b = typeLists[j];
for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) {
auto &a = type_lists[i];
auto &b = type_lists[j];
int aSize = a.second.size();
if (aSize != b.second.size())
int a_size = a.second.size();
if (a_size != b.second.size())
return op.emitOpError(
llvm::formatv("requires the number of {0}s to be equal to the "
"number of {1}s. Found {2} and {3}, respectively",
a.first, b.first, aSize, b.second.size()));
a.first, b.first, a_size, b.second.size()));
for (int idx = 0; idx < aSize; ++idx) {
auto aType = a.second[idx];
auto bType = b.second[idx];
for (int idx = 0; idx < a_size; ++idx) {
auto a_type = a.second[idx];
auto b_type = b.second[idx];
if (!AreCastCompatible({aType, bType}))
if (!AreCastCompatible({a_type, b_type}))
return op.emitError(llvm::formatv(
"{0} type {1} is incompatible with {2} type {3} at index {4}",
a.first, aType, b.first, bType, idx));
a.first, a_type, b.first, b_type, idx));
}
}
}

View File

@ -149,7 +149,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
}
auto walk_res = func_op.walk([&](Operation* op) {
if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
// Record VarHanldeOp's device attribute.
// Record VarHandleOp's device attribute.
auto device_attr =
var_handle.getAttrOfType<mlir::StringAttr>(kDeviceAttr);
if (!device_attr || device_attr.getValue().empty()) {

View File

@ -571,7 +571,7 @@ void AddLoadsStoresOutsideControlFlowOp(
}
// Lifts loads/stores from while loop's body and cond functions.
LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
// Remove identity nodes to avoid aliasing.
RemoveIdentity(&body.front());
RemoveIdentity(&cond.front());
@ -985,7 +985,7 @@ LogicalResult HoistForFunctionalControlFlow(
lifted_partitioned_call_callees);
HoistForFunctionalControlFlow(&cond.front(), module,
lifted_partitioned_call_callees);
if (failed(HanldeWhileLoop(while_op, body, cond))) return failure();
if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
auto then_branch =
llvm::cast<FuncOp>(module.lookupSymbol(if_op.then_branch()));

View File

@ -562,7 +562,7 @@ class ShapeInference {
private:
// Mapping between ValuePort (which corresponds to an OpResult or smaller,
// e.g., first element of OpResult produded) to an Attribute if the ValuePort
// e.g., first element of OpResult produced) to an Attribute if the ValuePort
// corresponds to a constant value.
ValuePortResultMap results_;
int64_t graph_version_;
@ -1144,7 +1144,7 @@ LogicalResult InferShapeForFunction(FuncOp func,
ArrayRef<int64_t> shape = arg_shapes[i];
Type element_type;
if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
if (!input_ty || input_ty.getShape().size() != shape.size()) {
if (input_ty.getRank() != shape.size()) {
return failure();
}
element_type = input_ty.getElementType();