PiperOrigin-RevId: 306837009
Change-Id: I430ad66bada18eba9b12fb99d6bfc3229393f91b
This commit is contained in:
Benjamin Kramer 2020-04-16 06:53:42 -07:00 committed by TensorFlower Gardener
parent 5bf9037542
commit d61d643dcc
5 changed files with 29 additions and 9 deletions

View File

@ -269,7 +269,8 @@ class BroadcastInDimConverter
// The input is a scalar, i.e. this is a scalar broadcast op.
inputMap = AffineMap::get(nloops, /*symbolCount=*/0, b->getContext());
} else {
inputMap = AffineMap::get(nloops, /*symbolCount=*/0, dimExprs);
inputMap = AffineMap::get(nloops, /*symbolCount=*/0, dimExprs,
b->getContext());
}
}
return b->getAffineMapArrayAttr(
@ -295,7 +296,7 @@ class TransposeConverter
b->getAffineDimExpr(permutation.index());
}
return b->getAffineMapArrayAttr(
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs),
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)});
}
};
@ -367,7 +368,7 @@ class ReshapeAddRemoveDimConverter
return nullptr;
inputExprs.resize(operandShape.size(), b->getAffineConstantExpr(0));
return b->getAffineMapArrayAttr(
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs),
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)});
}
};

View File

@ -93,7 +93,8 @@ ShapeInfo GetShapeInfo(
}
shape_info.affine_map = mlir::AffineMap::get(
/*dimCount=*/2 + spatial_dims.size(), /*symbolCount=*/0, affine_exprs);
/*dimCount=*/2 + spatial_dims.size(), /*symbolCount=*/0, affine_exprs,
builder.getContext());
shape_info.element_type = [&] {
switch (shape.element_type()) {
@ -315,9 +316,9 @@ StatusOr<InitialMlirConvAnchors> CreateNaiveMlirConv(
builder.createOrFold<mlir::AffineLoadOp>(
location, input,
mlir::AffineMap(input_shape_info.affine_map)
.compose(
mlir::AffineMap::get(/*dimCount=*/2 + num_spatial_dims * 2,
/*symbolCount=*/0, input_indices)),
.compose(mlir::AffineMap::get(
/*dimCount=*/2 + num_spatial_dims * 2,
/*symbolCount=*/0, input_indices, builder.getContext())),
input_vars),
builder.getF32Type());
}();

View File

@ -658,8 +658,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
)
# Check out LLVM and MLIR from llvm-project.
LLVM_COMMIT = "129cf84e69537ae5c184550f94be18da738d9261"
LLVM_SHA256 = "4b56ff735e973ee982c4f24ab5b22d56ad0734f2d12d9c57fe1c21530a5840a0"
LLVM_COMMIT = "3ee1ec0b9dd6ee2350f39ae8a418bf3ce28d06cf"
LLVM_SHA256 = "473eae82b2c9f8bace1b86018b809a1565723d0f9affaeb88d3d95b685062746"
LLVM_URLS = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
"https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),

View File

@ -449,6 +449,23 @@ cc_library(
],
)
cc_library(
name = "AffineUtils",
srcs = glob(
[
"lib/Dialect/Affine/Utils/*.cpp",
"lib/Dialect/Affine/Utils/*.h",
],
),
hdrs = ["include/mlir/Dialect/Affine/Utils.h"],
includes = ["include"],
deps = [
":Affine",
":IR",
"@llvm-project//llvm:support",
],
)
gentbl(
name = "AffinePassIncGen",
strip_include_prefix = "include",

View File

@ -207,6 +207,7 @@ cc_library(
"@llvm-project//llvm:support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:AffineTransforms",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",