Merge branch 'master' into toupstream/16x8_resize_bilinear

Thibaut Goetghebuer-Planchon 2020-11-05 08:15:03 +00:00
commit 5bc12f3435
380 changed files with 7187 additions and 4121 deletions

View File

@ -36,6 +36,9 @@
* Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
to control how external state should be handled during dataset
serialization or iterator checkpointing.
* XLA compilation:
* `tf.function(experimental_compile=True)` has become a stable API,
renamed `tf.function(jit_compile=True)`.
* `tf.lite`:
* NNAPI

View File

@ -1205,7 +1205,7 @@ typedef struct TF_Session TF_Session;
// Return a new execution session with the associated graph, or NULL on
// error. Does not take ownership of any input parameters.
//
// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be be
// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be
// kept alive for the lifetime of the returned TF_Session. New nodes can still
// be added to `graph` after this call.
TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph,
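For orientation, a minimal sketch of how the function documented above is typically called from C++; this is illustrative, not code from this change. The caller owns `graph`, `opts`, and `status`, and, per the comment, `graph` must outlive the returned session.

```cpp
#include <cstdio>

#include "tensorflow/c/c_api.h"

int main() {
  TF_Graph* graph = TF_NewGraph();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_Status* status = TF_NewStatus();

  TF_Session* session = TF_NewSession(graph, opts, status);
  if (TF_GetCode(status) != TF_OK) {
    std::fprintf(stderr, "TF_NewSession failed: %s\n", TF_Message(status));
  } else {
    // ... add more nodes to `graph` (still allowed) and run them ...
    TF_CloseSession(session, status);
    TF_DeleteSession(session, status);
  }

  TF_DeleteStatus(status);
  TF_DeleteSessionOptions(opts);
  TF_DeleteGraph(graph);
  return 0;
}
```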

View File

@ -86,7 +86,7 @@ TF_CAPI_EXPORT void TF_SetXlaConstantFoldingDisabled(
// Create a serialized tensorflow.ConfigProto proto, where:
//
// a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if
// a) ConfigProto.optimizer_options.global_jit_level is set to ON_1 if
// `enable_xla_compilation` is non-zero, and OFF otherwise.
// b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`.
// c) ConfigProto.device_count is set to `num_cpu_devices`.
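A small sketch of how this helper can feed `TF_SetConfig` on a `TF_SessionOptions`; illustrative only, and the exact `TF_CreateConfig` signature (returning a `TF_Buffer*` with `unsigned char`/`unsigned int` arguments) is assumed from `c_api_experimental.h` rather than shown in this hunk.

```cpp
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

// Build a ConfigProto with XLA JIT enabled (a), GPU memory growth on (b),
// and one CPU device (c), then attach it to freshly created session options.
TF_SessionOptions* MakeXlaSessionOptions(TF_Status* status) {
  TF_Buffer* config = TF_CreateConfig(/*enable_xla_compilation=*/1,
                                      /*gpu_memory_allow_growth=*/1,
                                      /*num_cpu_devices=*/1);
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_SetConfig(opts, config->data, config->length, status);  // serialized proto
  TF_DeleteBuffer(config);
  return opts;
}
```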

View File

@ -481,7 +481,7 @@ typedef struct TFE_CustomDevice {
// "/job:localhost/replica:0/task:0/device:CUSTOM:0".
//
// The custom device defines copy operations for moving TensorHandles on and
// off, and an an execution operation for named operations. Often execution will
// off, and an execution operation for named operations. Often execution will
// simply wrap op execution on one or more physical devices.
//
// device_info is an opaque caller-defined type stored with the custom device

View File

@ -16,9 +16,9 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
// Run a function containing a MatMul op and check its output.
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
// which creates a remote remote input, to simulate a scenario that the remote
// input is not ready when we start running an op or a function.
// If heavy_load_on_streaming_rpc is true, send some rpc requests before the one
// which creates a remote input, to simulate a scenario that the remote input
// is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false);

View File

@ -696,7 +696,7 @@ TEST(CAPI, ExecuteAddForwardAsync) {
/*tfrt*/ false);
}
#ifdef PLATFORM_GOOGLE
// TODO(b/153349425): Add add forwarding tests for TFRT
// TODO(b/153349425): Add forwarding tests for TFRT
TEST(CAPI, ExecuteAddTfrt) {
ExecuteAdd(
/*async=*/false,

View File

@ -46,8 +46,6 @@ class SavedModelAPI {
virtual Status GetSignatureDefFunction(const std::string& signature_def_key,
SignatureDefFunction** function) = 0;
virtual std::vector<ConcreteFunction*> ListFunctions() = 0;
virtual ~SavedModelAPI() = default;
};

View File

@ -211,15 +211,6 @@ Status TFSavedModelAPI::GetSignatureDefFunction(
return Status();
}
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
std::vector<ConcreteFunction*> result;
result.reserve(revived_objects_.concrete_functions.size());
for (auto& index_and_function : revived_objects_.concrete_functions) {
result.push_back(index_and_function.second.get());
}
return result;
}
Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
Variable** variable) {
absl::optional<int> node =

View File

@ -66,8 +66,6 @@ class TFSavedModelAPI : public SavedModelAPI {
ImmediateExecutionContext* context,
std::unique_ptr<TFSavedModelAPI>* out);
std::vector<ConcreteFunction*> ListFunctions() override;
~TFSavedModelAPI() override = default;
Status GetVariable(const std::string& variable_path, Variable** variable);

View File

@ -122,9 +122,4 @@ TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model,
return tensorflow::wrap(result);
}
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
return new TF_ConcreteFunctionList{
tensorflow::unwrap(model)->ListFunctions()};
}
} // end extern "C"

View File

@ -100,11 +100,6 @@ TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model,
const char* signature_def_key,
TF_Status* status);
// Returns a list of all ConcreteFunctions stored in this SavedModel.
// The lifetime of the returned list is bound to `model`.
TF_CAPI_EXPORT extern TF_ConcreteFunctionList* TF_ListSavedModelFunctions(
TF_SavedModel* model);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus

View File

@ -44,7 +44,7 @@ class DummyDevice : public DeviceBase {
}
};
// Helper for comparing ouput and expected output
// Helper for comparing output and expected output
void ExpectSummaryMatches(const Summary& actual, const string& expected_str) {
Summary expected;
ASSERT_TRUE(protobuf::TextFormat::ParseFromString(expected_str, &expected));

View File

@ -76,7 +76,7 @@ class Tensor {
// unknown rank.
int dims() const;
// Returns the number of elements in in demension `d`.
// Returns the number of elements in dimension `d`.
// REQUIRES: `0 <= d < dims()`
int64_t dim_size(int d) const;
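As a quick illustration (a sketch, not part of this change) of how the two accessors above compose; `TensorT` stands in for the Tensor type declared in this header:

```cpp
#include <cstdint>

// Total element count computed from the documented accessors.
template <typename TensorT>
int64_t NumElements(const TensorT& t) {
  int64_t total = 1;
  for (int d = 0; d < t.dims(); ++d) {
    total *= t.dim_size(d);  // Valid: 0 <= d < dims(), as required above.
  }
  return total;
}
```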
@ -154,7 +154,7 @@ inline Tensor Tensor::FromBuffer(TF_DataType dtype,
// 1. Only a function pointer is sent across the C API (&DeleterFunction)
// 2. DeleterFunction is defined in the same build artifact that constructed
// the std::function (so there isn't confusion about std::function ABI).
// Note that 2. is satisifed by the fact that this is a header-only API, where
// Note that 2. is satisfied by the fact that this is a header-only API, where
// the function implementations are inline.
DeleterStruct* deleter_struct = new DeleterStruct{deleter};
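The pattern described in the comment, as a generic sketch (the names below are illustrative, not the actual TensorFlow declarations): the `std::function` is boxed on the heap, and only a plain function pointer plus the opaque box pointer ever cross the C boundary.

```cpp
#include <cstddef>
#include <functional>

// Heap-allocated box holding the user's deleter.
struct DeleterStruct {
  std::function<void(void* memory, size_t size)> deleter;
};

// Plain function whose address is the only thing handed to the C API; the
// box pointer travels alongside it as an opaque argument.
extern "C" void DeleterFunction(void* memory, size_t size, void* arg) {
  DeleterStruct* box = static_cast<DeleterStruct*>(arg);
  box->deleter(memory, size);  // Invoke the wrapped std::function.
  delete box;                  // Single-use: release the box afterwards.
}
```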

View File

@ -67,7 +67,7 @@ bool IsZero(const Scope& scope, const Output& grad) {
// mat: A 2-D tensor of dimension [D0, D1]
//
// Returns:
// A tensor of dimension [D0, D1], the result fo vec * mat.
// A tensor of dimension [D0, D1], the result for vec * mat.
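// For example, vec of shape [3] and mat of shape [3, 4]: ExpandDims yields
// shape [3, 1], which broadcasts against [3, 4] in Multiply to give [3, 4].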
Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) {
auto reshaped = ExpandDims(scope, vec, -1);
return Multiply(scope, reshaped, mat);

View File

@ -84,9 +84,6 @@ class SavedModelAPI {
SignatureDefFunction* GetSignatureDefFunction(
const std::string& function_path, Status* status);
// Lists all Conrete Functions available from the SavedModel.
std::vector<ConcreteFunction*> ListFunctions();
// SavedModelAPI is movable, but not copyable.
SavedModelAPI(SavedModelAPI&&) = default;
SavedModelAPI& operator=(SavedModelAPI&&) = default;
@ -151,11 +148,6 @@ inline SignatureDefFunction* SavedModelAPI::GetSignatureDefFunction(
return SignatureDefFunction::wrap(function);
}
inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
ConcreteFunctionList list(TF_ListSavedModelFunctions(saved_model_.get()));
return list.ToVector();
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow

View File

@ -138,7 +138,7 @@ class FreezeTest : public ::testing::Test {
}
TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
// "c" isnt dependent on the variable, so nothing should be frozen.
// "c" isn't dependent on the variable, so nothing should be frozen.
TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
graph_def, {"c:0"}, "assign", &saved_model_bundle));
@ -183,7 +183,7 @@ class FreezeTest : public ::testing::Test {
}
Output c = ops::Mul(scope.WithOpName("c"), a, read_var);
TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
// "c" isnt dependent on the variable, so nothing should be frozen.
// "c" isn't dependent on the variable, so nothing should be frozen.
TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
graph_def, {"c:0"}, "assign", &saved_model_bundle));
@ -244,7 +244,7 @@ class FreezeTest : public ::testing::Test {
Output c = ops::Mul(scope.WithOpName("c"), a, read_var);
TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
// "c" isnt dependent on the variable, so nothing should be frozen.
// "c" isn't dependent on the variable, so nothing should be frozen.
TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
graph_def, {"c:0"}, "assign", &saved_model_bundle));

View File

@ -173,7 +173,8 @@ Status XlaCompilationCache::BuildExecutable(
build_options.set_result_layout(result.xla_output_shape);
build_options.set_device_allocator(options.device_allocator);
build_options.set_alias_passthrough_params(options.alias_passthrough_params);
build_options.mutable_debug_options()->set_xla_detailed_logging(
options.detailed_logging);
TF_ASSIGN_OR_RETURN(
auto executables,
client_->Compile(*result.computation, argument_layouts, build_options));

View File

@ -132,7 +132,8 @@ Status XlaCompileOnDemandOp::Compile(
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_,
/*has_ref_vars=*/true, &tf_allocator_adapter);
// No detailed logging from on demand op.
options.detailed_logging = false;
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
// Optimization: where possible, have the computation return a naked array

View File

@ -398,7 +398,7 @@ static void ShowXlaDeviceDeprecationWarning(
absl::call_once(once, [] {
LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
"removed in subsequent releases. Instead, use either "
"@tf.function(experimental_compile=True) for must-compile "
"@tf.function(jit_compile=True) for must-compile "
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
"for auto-clustering best-effort compilation.";
});

View File

@ -568,22 +568,6 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "lhlo_legalize_to_llvm",
srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc"],
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"],
deps = [
":lhlo",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMTransforms",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)
cc_library(
name = "legalize_to_linalg",
srcs = ["lib/Dialect/mhlo/transforms/legalize_to_linalg.cc"],
@ -952,7 +936,6 @@ cc_library(
srcs = [
"include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
"lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc",
"lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc",
"lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc",
"lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc",
"lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc",
@ -962,7 +945,6 @@ cc_library(
":chlo_legalize_to_hlo", # build-cleaner: keep
":hlo",
":lhlo",
":lhlo_legalize_to_llvm", # build-cleaner: keep
":materialize_broadcasts", # build-cleaner: keep
":pass_details",
":unfuse_batch_norm", # build-cleaner: keep

View File

@ -314,169 +314,6 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
);
}
//===----------------------------------------------------------------------===//
// StaticMemRefCastOp
//===----------------------------------------------------------------------===//
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let summary = [{
modifies the offset, sizes and strides of a statically shaped memref
}];
let description = [{
Casts the statically shaped memref operand to a memref with optionally
modified offsets, sizes and strides.
Example:
```mlir
%buf_transformed =
lmhlo.static_memref_cast %buf
: memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>
// The result of the op is a rank-1 memref with `[5]` shape, stride 1 and
// offset 2.
```
}];
let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
let results = (outs Res<LHLO_Buffer, "", []>:$result);
let builders = [
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$operand),
[{
$_state.addOperands(operand);
$_state.types.push_back(resultType);
}]>];
let extraClassDeclaration = [{
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
}];
let verifier = [{ return Verify(*this); }];
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
//===----------------------------------------------------------------------===//
// DynamicMemRefCastOp
//===----------------------------------------------------------------------===//
def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
[SameVariadicOperandSize, NoSideEffect,
DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let summary = "dynamic memref cast operation";
let description = [{
Change sizes and strides of a memref using the values computed in runtime.
Example:
```mlir
%buf_transformed =
lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
// The result of the op is a type-erased memref with `[%size_X, %size_Y]`
// shape and `[%step_X, %step_Y]` strides. The offset will be inherited
// from the input.
```
}];
let arguments = (ins
Arg<LHLO_Buffer, "", []>:$operand,
Variadic<Index>:$sizes,
Variadic<Index>:$strides
);
let results = (outs Res<LHLO_Buffer, "", []>:$result);
let builders = [
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$operand,
"ValueRange":$sizes, "ValueRange":$strides),
[{
$_state.addOperands(operand);
$_state.addOperands(sizes);
$_state.addOperands(strides);
$_state.types.push_back(resultType);
}]>];
let extraClassDeclaration = [{
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
}];
let verifier = [{ return Verify(*this); }];
let assemblyFormat = [{
$operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->`
type($result)
}];
}
//===----------------------------------------------------------------------===//
// ReshapeMemRefCastOp
//===----------------------------------------------------------------------===//
def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
NoSideEffect]> {
let summary = "reshape memref cast operation";
let description = [{
The `reshape_memref_cast` operation converts a memref from one type to an
equivalent type with a provided shape. The data is never copied or moved.
The source and destination types are compatible if both have the same
element type, address space and identity layout map. The following
combinations are possible:
a. Both are ranked memref types.
```mlir
// Reshape statically-shaped memref.
%dst = reshape_memref_cast %src(%shape)
: (memref<4x1xf32>, memref<1xi32>) to memref<4xf32>
%dst0 = reshape_memref_cast %src(%shape0)
: (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32>
```
b. Source type is ranked, destination type is unranked.
```mlir
// Reshape dynamically-shaped 1D memref.
%dst = reshape_memref_cast %src(%shape)
: (memref<?xf32>, memref<?xi32>) to memref<*xf32>
```
c. Source type is unranked, destination type is ranked.
```mlir
// Flatten unranked memref.
%dst = reshape_memref_cast %src(%shape)
: (memref<*xf32>, memref<1xi32>) to memref<?xf32>
```
d. Both are unranked memref types.
```mlir
// Reshape unranked memref.
%dst = reshape_memref_cast %src(%shape)
: (memref<*xf32>, memref<?xi32>) to memref<*xf32>
```
}];
let arguments = (ins
AnyRankedOrUnrankedMemRef:$operand,
LHLO_ExtentBuffer:$shape
);
let results = (outs AnyRankedOrUnrankedMemRef:$result);
let extraClassDeclaration = [{
BaseMemRefType getType() {
return getResult().getType().cast<BaseMemRefType>(); }
}];
let verifier = [{ return Verify(*this); }];
let assemblyFormat = [{
$operand `(` $shape `)` attr-dict `:` `(` type($operand) `,` type($shape)
`)` `->` type($result)
}];
}
//===----------------------------------------------------------------------===//
// LMHLO Other op definitions.
//===----------------------------------------------------------------------===//

View File

@ -46,12 +46,6 @@ def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> {
}
def TestLhloToLLVMPass : Pass<"test-lhlo-legalize-to-llvm", "FuncOp"> {
let summary = "Legalize from LHLO dialect to LLVM.";
let constructor = "createTestLhloToLLVMPass()";
}
def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> {
let summary = "Legalize from LHLO dialect to parallel loops.";
let constructor = "createLegalizeLhloToParallelLoopsPass()";

View File

@ -48,12 +48,8 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary. If `results_escape_functions` is set to true,
/// allocated buffers for function results will be returned and escape the
/// function. Otherwise, the signature is rewritten with extra arguments for the
/// buffers that are to be used for results.
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
bool results_escape_functions = false);
/// buffers if necessary.
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass();
// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
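A sketch of wiring these factories into a pass pipeline (not part of this change; the `passes.h` include path and the `FuncOp` nesting follow the usual mlir-hlo layout and are assumptions):

```cpp
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"  // assumed header name
#include "mlir/Pass/PassManager.h"

void BuildBufferizationPipeline(mlir::PassManager& pm) {
  // Module-level pass from the parameterless factory above.
  pm.addPass(mlir::mhlo::createLegalizeToLhloPass());
}

void BuildLinalgLoweringPipeline(mlir::PassManager& pm) {
  // Function-level pass; anchored on FuncOp, hence addNestedPass.
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeHloToLinalgPass());
}
```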

View File

@ -35,8 +35,6 @@ inline void registerAllMhloPasses() { registerMHLOPasses(); }
namespace lmhlo {
std::unique_ptr<Pass> createTestLhloToLLVMPass();
#define GEN_PASS_REGISTRATION
#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"

View File

@ -24,8 +24,6 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
class LLVMTypeConverter;
class LowerToLLVMOptions;
class OwningRewritePatternList;
// Populates a collection of rewrite patterns to realize element-wise operations
@ -94,14 +92,6 @@ void PopulateTrigonometricToApproximationPatterns(
} // namespace mhlo
namespace lmhlo {
/// Collect a set of patterns to convert from the LHLO dialect to LLVM.
void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter,
OwningRewritePatternList *patterns);
} // namespace lmhlo
namespace chlo {
// Populates a collection of conversion patterns for legalizing client-HLO to

View File

@ -2173,10 +2173,21 @@ LogicalResult SliceOp::inferReturnTypes(
return success();
}
int64_t rank = ranked_ty.getRank();
ShapedType attr_ty = slice.start_indices().getType();
if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank ||
!attr_ty.getElementType().isSignlessInteger(64) ||
if (attr_ty.getRank() != 1) {
return emitOptionalError(location, "start_indices has rank ",
attr_ty.getRank(), " instead of required rank 1");
}
int64_t rank = ranked_ty.getRank();
if (attr_ty.getNumElements() != rank) {
return emitOptionalError(
location, "the number of elements in start_indices (",
attr_ty.getNumElements(), ") does not match the rank of the operand (",
rank, ")");
}
if (!attr_ty.getElementType().isSignlessInteger(64) ||
slice.limit_indices().getType() != attr_ty ||
slice.strides().getType() != attr_ty) {
// Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp

View File

@ -88,76 +88,6 @@ void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
results.insert<EraseConstOp>(context);
}
//===----------------------------------------------------------------------===//
// StaticMemRefCastOp
//===----------------------------------------------------------------------===//
Value StaticMemRefCastOp::getViewSource() { return *getODSOperands(0).begin(); }
static LogicalResult Verify(StaticMemRefCastOp op) {
if (!op.operand().getType().cast<ShapedType>().hasStaticShape())
return op.emitOpError("operand must have static shape");
if (!op.getType().hasStaticShape())
return op.emitOpError("result must have static shape");
return success();
}
//===----------------------------------------------------------------------===//
// DynamicMemRefCastOp
//===----------------------------------------------------------------------===//
Value DynamicMemRefCastOp::getViewSource() {
return *getODSOperands(0).begin();
}
static LogicalResult Verify(DynamicMemRefCastOp op) {
// Check if `sizes` and `strides` args are compatible with the result type.
if (op.sizes().size() != op.getType().getRank())
return op.emitOpError(
"`sizes` args count must be equal to the rank of the output memref");
return success();
}
//===----------------------------------------------------------------------===//
// ReshapeMemrefCastOp
//===----------------------------------------------------------------------===//
Value ReshapeMemRefCastOp::getViewSource() { return operand(); }
static LogicalResult Verify(ReshapeMemRefCastOp op) {
Type operandType = op.operand().getType();
Type resultType = op.result().getType();
Type operandElementType = operandType.cast<ShapedType>().getElementType();
Type resultElementType = resultType.cast<ShapedType>().getElementType();
if (operandElementType != resultElementType)
return op.emitOpError(
"element types of source and destination memref "
"types should be the same");
if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
if (!operandMemRefType.getAffineMaps().empty())
return op.emitOpError(
"operand memref type should have identity affine map");
int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
auto resultMemRefType = resultType.dyn_cast<MemRefType>();
if (resultMemRefType) {
if (shapeSize == ShapedType::kDynamicSize)
return op.emitOpError(
"cannot use shape operand with dynamic length to "
"cast statically-ranked memref type");
if (shapeSize != resultMemRefType.getRank())
return op.emitOpError(
"length of shape operand differs from the result's memref rank");
if (!resultMemRefType.getAffineMaps().empty())
return op.emitOpError(
"result memref type should have identity affine map");
}
return success();
}
} // namespace lmhlo
} // namespace mlir

View File

@ -67,6 +67,7 @@ add_mlir_library(MhloPasses
DEPENDS
MLIRhlo_opsIncGen
MLIRMhloLowerComplexIncGen
MLIRMhloPassIncGen
LINK_COMPONENTS
Core
@ -133,8 +134,6 @@ add_mlir_library(LmhloPasses
lhlo_fuse_linalg.cc
lhlo_legalize_to_affine.cc
lhlo_legalize_to_gpu.cc
lhlo_legalize_to_llvm.cc
lhlo_legalize_to_llvm_pass.cc
lhlo_legalize_to_parallel_loops.cc
DEPENDS

View File

@ -206,7 +206,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
// Inserts dynamic memref to change the layout of the memref to put 0-stride
// and size of the target dimension if size-1 dimension expansion is
// necessary.
lmhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
MemRefReinterpretCastOp InsertDynamicMemrefCastOp(
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>();
@ -259,8 +259,13 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
makeStridedLinearLayoutMap(dynamic_layout,
/*offset=*/0, b->getContext()));
auto transformed_operand = b->create<lmhlo::DynamicMemRefCastOp>(
loc, type_erased_memref_type, operand, sizes, strides);
SmallVector<int64_t, 2> static_sizes(sizes.size(),
ShapedType::kDynamicSize);
SmallVector<int64_t, 2> static_strides(strides.size(),
ShapedType::kDynamicStrideOrOffset);
auto transformed_operand = b->create<MemRefReinterpretCastOp>(
loc, type_erased_memref_type, operand, /*offset=*/0, static_sizes,
static_strides, llvm::None, sizes, strides);
return transformed_operand;
}
};
@ -284,7 +289,7 @@ struct HloToLhloDynamicReshapeConverter
return failure();
}
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<lmhlo::ReshapeMemRefCastOp>(
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
op, result_type, adaptor.operand(), adaptor.output_shape());
return success();
}
@ -504,12 +509,7 @@ struct HloLegalizeToLhlo
public:
HloLegalizeToLhlo() = default;
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
this->results_escape_function = o.results_escape_function.getValue();
}
explicit HloLegalizeToLhlo(bool results_escape_function) {
this->results_escape_function.setValue(results_escape_function);
}
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {}
void runOnOperation() override {
OwningRewritePatternList patterns;
@ -542,13 +542,6 @@ struct HloLegalizeToLhlo
isMemRefType);
});
auto kind = results_escape_function
? BufferizeTypeConverter::KeepAsFunctionResult
: BufferizeTypeConverter::AppendToArgumentsList;
converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
kind);
converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
lmhlo::CopyOp>(
@ -559,13 +552,6 @@ struct HloLegalizeToLhlo
std::move(patterns))))
signalPassFailure();
}
private:
Option<bool> results_escape_function{
*this, "results-escape-function",
llvm::cl::desc(
"Allocate the results of functions within the functions body"),
llvm::cl::init(false)};
};
} // namespace
@ -625,9 +611,8 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
// clang-format on
}
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
bool results_escape_function) {
return std::make_unique<HloLegalizeToLhlo>(results_escape_function);
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
return std::make_unique<HloLegalizeToLhlo>();
}
} // namespace mhlo

View File

@ -1,370 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace lmhlo {
namespace {
struct StaticMemRefCastOpConverter
: public ConvertOpToLLVMPattern<StaticMemRefCastOp> {
using ConvertOpToLLVMPattern<StaticMemRefCastOp>::ConvertOpToLLVMPattern;
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto cast_op = cast<StaticMemRefCastOp>(op);
StaticMemRefCastOp::Adaptor operands_adaptor(operands);
MemRefDescriptor sourceMemRef(operands_adaptor.operand());
MemRefType targetMemRefType =
cast_op.getResult().getType().cast<MemRefType>();
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return failure();
// Create descriptor.
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
Type llvmTargetElementTy = desc.getElementPtrType();
// Set allocated ptr.
Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
allocated =
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
desc.setAllocatedPtr(rewriter, loc, allocated);
// Set aligned ptr.
Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
desc.setAlignedPtr(rewriter, loc, ptr);
// Fill size and stride descriptors in memref.
auto target_sizes = targetMemRefType.getShape();
int64_t target_offset;
llvm::SmallVector<int64_t, 4> target_strides;
if (failed((getStridesAndOffset(targetMemRefType, target_strides,
target_offset))))
return failure();
// Copy offset of `targetMemRef`.
desc.setConstantOffset(rewriter, loc, target_offset);
for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
desc.setConstantSize(rewriter, loc, i, target_sizes[i]);
desc.setConstantStride(rewriter, loc, i, target_strides[i]);
}
rewriter.replaceOp(op, {desc});
return success();
}
};
struct DynamicMemRefCastOpConverter
: public ConvertOpToLLVMPattern<DynamicMemRefCastOp> {
using ConvertOpToLLVMPattern<DynamicMemRefCastOp>::ConvertOpToLLVMPattern;
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto cast_op = cast<DynamicMemRefCastOp>(op);
DynamicMemRefCastOp::Adaptor operands_adaptor(operands);
MemRefDescriptor sourceMemRef(operands_adaptor.operand());
MemRefType targetMemRefType =
cast_op.getResult().getType().cast<MemRefType>();
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return failure();
// Create descriptor.
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
Type llvmTargetElementTy = desc.getElementPtrType();
// Set allocated ptr.
Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
allocated =
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
desc.setAllocatedPtr(rewriter, loc, allocated);
// Set aligned ptr.
Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
desc.setAlignedPtr(rewriter, loc, ptr);
// Copy offset of `sourceMemRef`.
desc.setOffset(rewriter, loc, sourceMemRef.offset(rewriter, loc));
// Fill size and stride descriptors in memref.
if (!cast_op.sizes().empty()) {
auto sizes = operands_adaptor.sizes();
auto strides = operands_adaptor.strides();
for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
desc.setSize(rewriter, loc, i, sizes[i]);
desc.setStride(rewriter, loc, i, strides[i]);
}
}
rewriter.replaceOp(op, {desc});
return success();
}
};
struct ReshapeMemRefCastOpConverter
: public ConvertOpToLLVMPattern<ReshapeMemRefCastOp> {
using ConvertOpToLLVMPattern<ReshapeMemRefCastOp>::ConvertOpToLLVMPattern;
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto reshape_op = cast<ReshapeMemRefCastOp>(op);
auto dst_type = reshape_op.getResult().getType().cast<BaseMemRefType>();
auto element_type = dst_type.getElementType();
auto shape = reshape_op.shape();
ReshapeMemRefCastOp::Adaptor operands_adaptor(operands);
PtrsAndOffset ptrs_n_offset = ExtractMemRefPtrsAndOffset(
loc, reshape_op.operand(), operands_adaptor.operand(), &rewriter);
MemRefDescriptor shape_desc(operands_adaptor.shape());
auto shape_memref_type = shape.getType().cast<MemRefType>();
if (shape_memref_type.hasStaticShape()) {
auto shape_length = shape_memref_type.getDimSize(0);
MemRefType targetMemRefType = MemRefType::get(
SmallVector<int64_t, 1>(shape_length, 1), element_type);
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return failure();
// Create descriptor.
auto desc =
MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
desc.setAllocatedPtr(rewriter, loc, ptrs_n_offset.allocated_ptr);
desc.setAlignedPtr(rewriter, loc, ptrs_n_offset.aligned_ptr);
desc.setOffset(rewriter, loc, ptrs_n_offset.offset);
auto llvm_index_type = typeConverter.getIndexType();
auto llvm_index_ptr_type = llvm_index_type.getPointerTo();
Value stride_carried = rewriter.create<LLVM::ConstantOp>(
loc, llvm_index_type,
rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
for (int i = shape_length - 1; i >= 0; --i) {
Value pos = rewriter.create<LLVM::ConstantOp>(
loc, llvm_index_type,
rewriter.getIntegerAttr(rewriter.getIndexType(), i));
Value ptr = rewriter.create<LLVM::GEPOp>(
loc, llvm_index_ptr_type, shape_desc.alignedPtr(rewriter, loc),
ValueRange{pos});
Value extracted_size = rewriter.create<LLVM::LoadOp>(loc, ptr);
desc.setSize(rewriter, loc, i, extracted_size);
desc.setStride(rewriter, loc, i, stride_carried);
// Update stride
if (i > 0) {
stride_carried =
rewriter.create<LLVM::MulOp>(loc, stride_carried, extracted_size);
}
}
if (dst_type.isa<MemRefType>()) {
rewriter.replaceOp(op, {desc});
} else {
Value rank = rewriter.create<LLVM::ConstantOp>(
loc, llvm_index_type,
rewriter.getIntegerAttr(rewriter.getIndexType(), shape_length));
Value alloca =
typeConverter.promoteOneMemRefDescriptor(loc, desc, rewriter);
Value void_ptr =
rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), alloca);
auto unranked_desc = UnrankedMemRefDescriptor::pack(
rewriter, loc, typeConverter, dst_type.cast<UnrankedMemRefType>(),
{rank, void_ptr});
rewriter.replaceOp(op, {unranked_desc});
}
return success();
}
// The shape is a rank-1 tensor with unknown length.
Value result_rank = shape_desc.size(rewriter, loc, 0);
// TODO(herhut): Propely handle address spaces.
unsigned address_space = 0;
auto target_type =
typeConverter
.convertType(UnrankedMemRefType::get(element_type, address_space))
.cast<LLVM::LLVMType>();
// Create the unranked memref descriptor that holds the ranked one. The
// inner descriptor is allocated on stack.
UnrankedMemRefDescriptor target_desc =
UnrankedMemRefDescriptor::undef(rewriter, loc, target_type);
target_desc.setRank(rewriter, loc, result_rank);
SmallVector<Value, 1> sizes;
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter,
{target_desc}, sizes);
auto void_ptr_type = LLVM::LLVMType::getInt8PtrTy(rewriter.getContext());
Value ranked_desc_mem = rewriter.create<LLVM::AllocaOp>(
loc, void_ptr_type, sizes.front(), llvm::None);
target_desc.setMemRefDescPtr(rewriter, loc, ranked_desc_mem);
// Fill the fixed parts. For this, we cast to a 0-D memref.
auto zero_d_memref_type = MemRefType::get({}, element_type);
Value as_zero_d = rewriter.create<LLVM::BitcastOp>(
loc,
typeConverter.convertType(zero_d_memref_type)
.cast<LLVM::LLVMType>()
.getPointerTo(address_space),
ranked_desc_mem);
// Some common constants. Use 32 bit where required by gep struct indexes.
auto int32_type = typeConverter.convertType(rewriter.getI32Type());
Value zero_index = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter.getIndexType(), rewriter.getIndexAttr(0));
Value zero = rewriter.create<LLVM::ConstantOp>(
loc, int32_type, rewriter.getI32IntegerAttr(0));
Value one = rewriter.create<LLVM::ConstantOp>(
loc, int32_type, rewriter.getI32IntegerAttr(1));
Value two = rewriter.create<LLVM::ConstantOp>(
loc, int32_type, rewriter.getI32IntegerAttr(2));
// Set base_pointer and aligned pointer.
auto element_ptr_ptr_type = typeConverter.convertType(element_type)
.cast<LLVM::LLVMType>()
.getPointerTo(address_space)
.getPointerTo(address_space);
auto base_gep = rewriter.create<LLVM::GEPOp>(
loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, zero}));
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.allocated_ptr, base_gep);
auto aligned_gep = rewriter.create<LLVM::GEPOp>(
loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, one}));
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.aligned_ptr, aligned_gep);
// Set offset.
auto index_ptr_type =
typeConverter.getIndexType().getPointerTo(address_space);
auto offset_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, as_zero_d, ValueRange({zero_index, two}));
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.offset, offset_gep);
// Use the offset pointer as base for further addressing. Copy over the
// new shape and compute strides. For this, we need to create a loop from
// rank - 1 to 0.
Value one_index = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter.getIndexType(), rewriter.getIndexAttr(1));
auto target_shape_base = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, offset_gep, ValueRange({one}));
auto target_strides_base = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, target_shape_base, ValueRange({result_rank}));
auto shape_ptr = shape_desc.alignedPtr(rewriter, loc);
auto result_rank_minus_one =
rewriter.create<LLVM::SubOp>(loc, result_rank, one_index);
Block *init_block = rewriter.getInsertionBlock();
Block *cond_block =
rewriter.splitBlock(init_block, rewriter.getInsertionPoint());
rewriter.setInsertionPointToEnd(init_block);
rewriter.create<LLVM::BrOp>(
loc, ValueRange({result_rank_minus_one, one_index}), cond_block);
rewriter.setInsertionPointToStart(cond_block);
auto index_arg = cond_block->addArgument(typeConverter.getIndexType());
auto stride_arg = cond_block->addArgument(typeConverter.getIndexType());
auto pred = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()),
LLVM::ICmpPredicate::sge, index_arg, zero_index);
Block *body_block =
rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
rewriter.setInsertionPointToStart(body_block);
// Copy size from shape to descriptor.
auto size_load_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, shape_ptr, ValueRange{index_arg});
auto extracted_size = rewriter.create<LLVM::LoadOp>(loc, size_load_gep);
auto size_store_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, target_shape_base, ValueRange({index_arg}));
rewriter.create<LLVM::StoreOp>(loc, extracted_size, size_store_gep);
// Write stride value and compute next one.
auto stride_store_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, target_strides_base, ValueRange({index_arg}));
rewriter.create<LLVM::StoreOp>(loc, stride_arg, stride_store_gep);
auto next_stride =
rewriter.create<LLVM::MulOp>(loc, stride_arg, extracted_size);
// Decrement loop counter and branch back.
auto decrement = rewriter.create<LLVM::SubOp>(loc, index_arg, one_index);
rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, next_stride}),
cond_block);
Block *remainder =
rewriter.splitBlock(body_block, rewriter.getInsertionPoint());
// Hook up the cond exit to the remainder.
rewriter.setInsertionPointToEnd(cond_block);
rewriter.create<LLVM::CondBrOp>(loc, pred, body_block, ValueRange(),
remainder, ValueRange());
// Reset position to beginning of new remainder block.
rewriter.setInsertionPointToStart(remainder);
rewriter.replaceOp(op, {target_desc});
return success();
}
private:
struct PtrsAndOffset {
Value allocated_ptr;
Value aligned_ptr;
Value offset;
};
PtrsAndOffset ExtractMemRefPtrsAndOffset(
Location loc, Value originalOperand, Value convertedOperand,
ConversionPatternRewriter *rewriter) const {
Type operandType = originalOperand.getType();
Value descriptor_ptr;
if (operandType.isa<MemRefType>()) {
descriptor_ptr = convertedOperand;
} else {
UnrankedMemRefDescriptor unranked_descriptor(convertedOperand);
Value underlying_desc_ptr =
unranked_descriptor.memRefDescPtr(*rewriter, loc);
Type element_type =
operandType.cast<UnrankedMemRefType>().getElementType();
LLVM::LLVMType memref_type_0d =
typeConverter.convertType(MemRefType::get(/*shape=*/{}, element_type))
.cast<LLVM::LLVMType>();
descriptor_ptr = rewriter->create<LLVM::BitcastOp>(
loc, memref_type_0d.getPointerTo(), underlying_desc_ptr);
descriptor_ptr = rewriter->create<LLVM::LoadOp>(loc, descriptor_ptr);
}
MemRefDescriptor descriptor(descriptor_ptr);
PtrsAndOffset result;
result.allocated_ptr = descriptor.allocatedPtr(*rewriter, loc);
result.aligned_ptr = descriptor.alignedPtr(*rewriter, loc);
result.offset = descriptor.offset(*rewriter, loc);
return result;
}
};
} // namespace
void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter,
OwningRewritePatternList *patterns) {
patterns->insert<DynamicMemRefCastOpConverter, ReshapeMemRefCastOpConverter,
StaticMemRefCastOpConverter>(*converter);
}
} // namespace lmhlo
} // namespace mlir

View File

@ -1,63 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace lmhlo {
namespace {
class TestLhloToLLVMPass
: public ::mlir::PassWrapper<TestLhloToLLVMPass,
::mlir::OperationPass<::mlir::ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect>();
}
public:
void runOnOperation() override {
ModuleOp m = getOperation();
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
PopulateLhloToLLVMConversionPatterns(&converter, &patterns);
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addIllegalDialect<LmhloDialect>();
if (failed(applyFullConversion(m, target, std::move(patterns)))) {
signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<Pass> createTestLhloToLLVMPass() {
return std::make_unique<TestLhloToLLVMPass>();
}
} // namespace lmhlo
} // namespace mlir

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s
// CHECK-LABEL: func @func_op_unranked_arg_result
func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
@ -17,7 +17,7 @@ func @dynamic_reshape_from_unranked(
return %reshaped : tensor<?xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
// -----
@ -30,5 +30,5 @@ func @dynamic_reshape_to_unranked(
return %reshaped : tensor<*xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>

View File

@ -1,13 +1,12 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck %s
// BOTH-LABEL: func @attrs
// CHECK-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.exponential"(%tensor_operand)
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
// CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -17,16 +16,13 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
return %arg0 : tensor<4xf32>
}
// PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// PRE-NEXT: "lmhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
// PRE-NEXT: return
// ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// ESC-NOT: "lmhlo.copy"
// ESC-NEXT: return %[[ARG0]]
// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// CHECK-NOT: "lmhlo.copy"
// CHECK-NEXT: return %[[ARG0]]
// -----
// BOTH-LABEL: func @func_op_long
// CHECK-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%1 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
%2 = mhlo.add %arg0, %1 : tensor<4xf32>
@ -35,91 +31,87 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%5 = mhlo.multiply %2, %4 : tensor<4xf32>
return %5 : tensor<4xf32>
}
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
// BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// PRE-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
// PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
// PRE-NEXT: return
// ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
// -----
// BOTH-LABEL: func @fusion
// CHECK-LABEL: func @fusion
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
// CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
%sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
// CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
%tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// BOTH-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
tensor_store %tensor_result, %result : memref<2x2xf32>
// BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
// BOTH-NEXT: return
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: return
return
}
// -----
// BOTH-LABEL: func @copy
// CHECK-LABEL: func @copy
func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.copy"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @exp
// CHECK-LABEL: func @exp
func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.exponential"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @log
// CHECK-LABEL: func @log
func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.log"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.log"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.log"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @select
// CHECK-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_pred = tensor_load %pred : memref<2x2xi1>
@ -127,34 +119,34 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
// CHECK: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @compare
// CHECK-LABEL: func @compare
func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
{comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
// BOTH: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
// CHECK: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
tensor_store %tensor_result, %result : memref<2x2xi1>
return
}
// -----
// BOTH-LABEL: func @broadcast
// CHECK-LABEL: func @broadcast
func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
%tensor_operand = tensor_load %operand : memref<5xf32>
%tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
{broadcast_dimensions = dense<1> : tensor<1xi64>}
: (tensor<5xf32>) -> tensor<10x5xf32>
// BOTH: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<10x5xf32>
return
}
@ -163,56 +155,57 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
func @external_func() -> tensor<3xi64>
// BOTH: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
// BOTH-LABEL: func @dyn_broadcast
// CHECK-LABEL: func @dyn_broadcast
func @dyn_broadcast(%operand: memref<?x?xf32>) {
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
// CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
%tensor_operand = tensor_load %operand : memref<?x?xf32>
%c1 = constant 1 : i64
%shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64>
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
// BOTH: %[[SHAPE:.*]] = tensor_from_elements
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
// BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
// BOTH: %[[C1:.*]] = constant 1 : index
// BOTH: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
// BOTH: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
// BOTH: %[[C2:.*]] = constant 2 : index
// BOTH: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
// BOTH: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
// CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
// CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
// CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
// BOTH: %[[C0_:.*]] = constant 0 : index
// BOTH: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[C1_:.*]] = constant 1 : index
// BOTH: %[[C1__:.*]] = constant 1 : index
// BOTH: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// BOTH: %[[C0___:.*]] = constant 0 : index
// BOTH: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
// BOTH: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// BOTH: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// BOTH: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[C1__:.*]] = constant 1 : index
// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// CHECK: %[[C0___:.*]] = constant 0 : index
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
// BOTH: %[[C2_:.*]] = constant 2 : index
// BOTH: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// BOTH: %[[C1___:.*]] = constant 1 : index
// BOTH: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
// BOTH: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[C2_:.*]] = constant 2 : index
// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// CHECK: %[[C1___:.*]] = constant 1 : index
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast
// BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
// BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// BOTH-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map>
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to
// CHECK-SAME: offset: [0],
// CHECK-SAME: sizes: {{\[}}%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]]
// CHECK-SAME: strides: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// CHECK-SAME: : memref<?x?xf32> to memref<?x?xf32, #map>
// BOTH: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// BOTH-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
// CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
// Do not store the value back to avoid the tensor-store being rewritten to
// a copy into the pre-allocated argument.
@ -221,7 +214,7 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// -----
// BOTH-LABEL: func @complex
// CHECK-LABEL: func @complex
func @complex(%real: memref<2x2xf32>,
%imag: memref<2x2xf32>,
%result: memref<2x2xcomplex<f32>>) {
@ -229,14 +222,14 @@ func @complex(%real: memref<2x2xf32>,
%tensor_imag = tensor_load %imag : memref<2x2xf32>
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
// BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
return
}
// -----
// BOTH-LABEL: func @complex_dyn
// CHECK-LABEL: func @complex_dyn
func @complex_dyn(%real: memref<?xf32>,
%imag: memref<?xf32>,
%result: memref<?xcomplex<f32>>) {
@ -244,50 +237,50 @@ func @complex_dyn(%real: memref<?xf32>,
%tensor_imag = tensor_load %imag : memref<?xf32>
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
: (tensor<?xf32>, tensor<?xf32>) -> tensor<?xcomplex<f32>>
// BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<?xcomplex<f32>>
return
}
// -----
// BOTH-LABEL: func @real
// CHECK-LABEL: func @real
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "mhlo.real"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "lmhlo.real"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.real"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @real_dyn
// CHECK-LABEL: func @real_dyn
func @real_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) {
%tensor_operand = tensor_load %operand : memref<?xcomplex<f32>>
%tensor_result = "mhlo.real"(%tensor_operand)
: (tensor<?xcomplex<f32>>) -> tensor<?xf32>
// BOTH: "lmhlo.real"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.real"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<?xf32>
return
}
// -----
// BOTH-LABEL: func @imag
// CHECK-LABEL: func @imag
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "mhlo.imag"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @gather
// CHECK-LABEL: func @gather
func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5x7xf32>) {
%tensor_operand = tensor_load %operand : memref<13x7xf32>
%tensor_idxs = tensor_load %idxs : memref<5xi32>
@ -302,176 +295,176 @@ func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5
, name = "gather.71"
, slice_sizes = dense<[1, 7]> : tensor<2xi64> }
: (tensor<13x7xf32>, tensor<5xi32>) -> tensor<5x7xf32>
// BOTH: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}})
// CHECK: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<5x7xf32>
return
}
// -----
// BOTH-LABEL: func @imag_dyn
// CHECK-LABEL: func @imag_dyn
func @imag_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) {
%tensor_operand = tensor_load %operand : memref<?xcomplex<f32>>
%tensor_result = "mhlo.imag"(%tensor_operand)
: (tensor<?xcomplex<f32>>) -> tensor<?xf32>
// BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<?xf32>
return
}
// -----
// BOTH-LABEL: func @iota
// CHECK-LABEL: func @iota
func @iota(%result: memref<10xi32>) {
%tensor_result = "mhlo.iota"()
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
// BOTH: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
// CHECK: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
tensor_store %tensor_result, %result : memref<10xi32>
return
}
// -----
// BOTH-LABEL: func @abs
// CHECK-LABEL: func @abs
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.abs"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.abs"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.abs"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @ceil
// CHECK-LABEL: func @ceil
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.ceil"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.ceil"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.ceil"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @convert
// CHECK-LABEL: func @convert
func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.convert"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
// BOTH-NOT: tensor_store
// CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}})
// CHECK-NOT: tensor_store
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @cos
// CHECK-LABEL: func @cos
func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.cosine"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.cosine"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.cosine"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @floor
// CHECK-LABEL: func @floor
func @floor(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.floor"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.floor"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.floor"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @neg
// CHECK-LABEL: func @neg
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.negate"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.negate"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.negate"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @not
// CHECK-LABEL: func @not
func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) {
%tensor_operand = tensor_load %operand : memref<2x2xi32>
%tensor_result = "mhlo.not"(%tensor_operand)
: (tensor<2x2xi32>) -> tensor<2x2xi32>
// BOTH: "lmhlo.not"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.not"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// BOTH-LABEL: func @rsqrt
// CHECK-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.rsqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @sign
// CHECK-LABEL: func @sign
func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.sign"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.sign"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.sign"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @sqrt
// CHECK-LABEL: func @sqrt
func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.sqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @tanh
// CHECK-LABEL: func @tanh
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.tanh"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.tanh"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.tanh"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @remainder
// CHECK-LABEL: func @remainder
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
// CHECK: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -479,61 +472,60 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// -----
// Dynamic shape binary element-wise operation.
// BOTH-LABEL: func @add_dyn
// CHECK-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
%result = "mhlo.add"(%lhs, %rhs)
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// BOTH: %[[C1:.*]] = constant 1 : index
// BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// BOTH: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// BOTH: %[[C0_:.*]] = constant 0 : index
// BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// BOTH: %[[C1_:.*]] = constant 1 : index
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// BOTH: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
// -----
// Dynamic shape unary element-wise operation.
// BOTH-LABEL: func @tanh_dyn
// CHECK-LABEL: func @tanh_dyn
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
%result = "mhlo.tanh"(%arg0)
: (tensor<?x?xf32>) -> tensor<?x?xf32>
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// BOTH: %[[C1:.*]] = constant 1 : index
// BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// BOTH: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// BOTH: %[[C0_:.*]] = constant 0 : index
// BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// BOTH: %[[C1_:.*]] = constant 1 : index
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// BOTH: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
// -----
// BOTH-LABEL: func @dot
// CHECK-LABEL: func @dot
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// CHECK-NEXT: %[[ALLOC:.*]] = alloc
// CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
// dot_dimension_numbers = {
// lhs_batching_dimensions = dense<> : tensor<0xi64>,
// lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
@ -542,22 +534,21 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
%dot = "mhlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]])
// ESC: return %[[ALLOC]]
// CHECK: return %[[ALLOC]]
return %dot : tensor<1024x1024xf32>
}
// -----
// BOTH-LABEL: func @conv
// CHECK-LABEL: func @conv
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
%c0 = constant 0 : index
// BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
// BOTH: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// BOTH-SAME: padding = dense<[
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// BOTH-SAME: rhs_dilation = dense<[1, 2]>
// BOTH-SAME: window_strides = dense<[2, 1]>
// CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
// CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// CHECK-SAME: padding = dense<[
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// CHECK-SAME: rhs_dilation = dense<[1, 2]>
// CHECK-SAME: window_strides = dense<[2, 1]>
%out = "mhlo.convolution"(%filter, %input) {
batch_group_count = 1 : i64,
dimension_numbers = {
@ -581,18 +572,18 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor
// -----
// BOTH-LABEL: func @reduce
// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
// BOTH: %[[OUT:.*]] = alloc() : memref<1xf32>
// BOTH: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
// BOTH: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
// BOTH-SAME: %[[ARG3:.*]]: memref<f32>):
// BOTH: %[[TMP:.*]] = alloc() : memref<f32>
// BOTH: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
// BOTH: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
// BOTH: "lmhlo.terminator"() : () -> ()
// BOTH: }) {dimensions = dense<1> : tensor<1xi64>}
// BOTH-SAME: : (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
// CHECK: %[[OUT:.*]] = alloc() : memref<1xf32>
// CHECK: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
// CHECK-SAME: %[[ARG3:.*]]: memref<f32>):
// CHECK: %[[TMP:.*]] = alloc() : memref<f32>
// CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
// CHECK: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
// CHECK: "lmhlo.terminator"() : () -> ()
// CHECK: }) {dimensions = dense<1> : tensor<1xi64>}
// CHECK-SAME: : (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
%0 = "mhlo.reduce"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
@ -604,25 +595,25 @@ func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
// -----
// BOTH-LABEL: func @transpose
// CHECK-LABEL: func @transpose
func @transpose(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.transpose"(%tensor_operand) {permutation = dense<[1, 0]> : tensor<2xi64>}
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.transpose"(%{{.*}}, %{{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>}
// BOTH-NOT: tensor_store
// CHECK: "lmhlo.transpose"(%{{.*}}, %{{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>}
// CHECK-NOT: tensor_store
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @custom_call
// BOTH-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>)
// CHECK-LABEL: func @custom_call
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>)
func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
%arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
%arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
// BOTH: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false}
// CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false}
%result_tensor = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
{backend_config = "", call_target_name = "foo", has_side_effect = false}
: (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16>
@ -632,10 +623,10 @@ func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memre
// ----
// BOTH-LABEL: func @isfinite
// CHECK-LABEL: func @isfinite
func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
%arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
// BOTH: "lmhlo.is_finite"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.is_finite"(%{{.*}}, %{{.*}})
%result_tensor = "mhlo.is_finite"(%arg0_tensor) : (tensor<2x2xf32>) -> tensor<2x2xi1>
tensor_store %result_tensor, %result: memref<2x2xi1>
return
@ -644,19 +635,19 @@ func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
// -----
// Test that assuming ops propagate memref types.
// BOTH-LABEL: func @shape_assuming_memref
// CHECK-LABEL: func @shape_assuming_memref
func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f16>
%1 = shape.const_witness true
// BOTH: shape.assuming %{{.*}} -> (memref<?xf16>)
// CHECK: shape.assuming %{{.*}} -> (memref<?xf16>)
%2 = shape.assuming %1 -> (tensor<?xf16>) {
%3 = shape.shape_of %arg0 : tensor<?xf16> -> tensor<?xindex>
%4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
%5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
%6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
// BOTH: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
// CHECK: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
%7 = mhlo.maximum %5, %6 : tensor<?xf16>
// BOTH: shape.assuming_yield %{{.*}} : memref<?xf16>
// CHECK: shape.assuming_yield %{{.*}} : memref<?xf16>
shape.assuming_yield %7 : tensor<?xf16>
}
return %2 : tensor<?xf16>

View File

@ -267,7 +267,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
%13 = absf %arg3 : f32
linalg.yield %13 : f32
}
%2 = lmhlo.reshape_memref_cast %1(%arg1)
%2 = memref_reshape %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
return %2 : memref<*xf32>
}
@ -279,7 +279,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: reshape_memref_cast
// CHECK: memref_reshape
// TILED-LABEL: func @view_result
// TILED-DAG: %[[C2:.*]] = constant 2
@ -288,7 +288,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: reshape_memref_cast
// TILED: memref_reshape
// PLOOP-LABEL: func @view_result
@ -297,5 +297,5 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: reshape_memref_cast
// PLOOP: memref_reshape

View File

@ -1,65 +0,0 @@
// RUN: mlir-hlo-opt %s -lower-affine -convert-scf-to-std -test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s
// CHECK-LABEL: func @static_memref_cast
func @static_memref_cast(%buf : memref<10x1x5xf32>) {
%0 = lmhlo.static_memref_cast %buf
: memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]>
return
}
// CHECK: %[[INPUT_MEMREF_BLDR:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_3D:!.*]]
// CHECK: llvm.insertvalue
// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_2D:!.*]]
// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE_3D]]
// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE_3D]]
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[C2]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C5]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C5_:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C5_]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE_2D]]
// -----
// CHECK-LABEL: func @dynamic_memref_cast
func @dynamic_memref_cast(%buf : memref<?x?xf32>) {
%size_X = constant 10 : index
%size_Y = constant 50 : index
%stride_X = constant 1 : index
%stride_Y = constant 0 : index
%0 = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return
}
// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: %[[C50:.*]] = llvm.mlir.constant(50 : index) : !llvm.i64
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE:!.*]]
// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[SRC_OFFSET:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][2] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[SRC_OFFSET]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C50]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C0]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE]]

View File

@ -429,120 +429,6 @@ func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memr
// -----
func @static_memref_cast(%in: memref<10x1xf32>) {
%out = lmhlo.static_memref_cast %in
: memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]>
return
}
// CHECK-LABEL: func @static_memref_cast
// -----
func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) {
// expected-error @+1 {{operand must have static shape}}
%out = lmhlo.static_memref_cast %in
: memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]>
return
}
// -----
func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) {
// expected-error @+1 {{result must have static shape}}
%out = lmhlo.static_memref_cast %in
: memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]>
return
}
// -----
func @dynamic_memref_cast(%in: memref<?xf32>) {
%size = constant 10 : index
%step = constant 1 : index
%out = lmhlo.dynamic_memref_cast %in(%size)[%step]
: memref<?xf32> -> memref<?xf32, offset: 0, strides: [?]>
return
}
// CHECK-LABEL: func @dynamic_memref_cast
// -----
func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) {
// expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}}
%size = constant 10 : index
%step = constant 1 : index
%out = lmhlo.dynamic_memref_cast %in(%size)[%step]
: memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return
}
// -----
// CHECK-LABEL: func @reshape_memref_cast(
func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
%shape2: memref<2xi32>, %shape3: memref<?xi32>) {
// CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>,
// CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref<?xi32>
// CHECK-NEXT: [[DYN_VEC:%.*]] = lmhlo.reshape_memref_cast [[UNRANKED]]
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
%dyn_vec = lmhlo.reshape_memref_cast %unranked(%shape1)
: (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
// CHECK-NEXT: [[DYN_MAT:%.*]] = lmhlo.reshape_memref_cast [[DYN_VEC]]
// CHECK-SAME: : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
%dyn_mat = lmhlo.reshape_memref_cast %dyn_vec(%shape2)
: (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
// CHECK-NEXT: {{%.*}} = lmhlo.reshape_memref_cast [[DYN_MAT]]
// CHECK-SAME: : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
%new_unranked = lmhlo.reshape_memref_cast %dyn_mat(%shape3)
: (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
return
}
// -----
func @reshape_memref_cast_element_type_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{element types of source and destination memref types should be the same}}
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<1xi32>) -> memref<?xi32>
}
// -----
func @reshape_memref_cast_dst_ranked_shape_unranked(
%buf: memref<*xf32>, %shape: memref<?xi32>) {
// expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}}
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<?xi32>) -> memref<?xf32>
return
}
// -----
func @reshape_memref_cast_dst_shape_rank_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{length of shape operand differs from the result's memref rank}}
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32>
return
}
// -----
func @reshape_memref_cast_affine_map_is_not_identity(
%buf: memref<4x4xf32, offset: 0, strides: [3, 2]>,
%shape: memref<1xi32>) {
// expected-error @+1 {{operand memref type should have identity affine map}}
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>)
-> memref<8xf32>
return
}
// -----
// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()

View File

@ -700,7 +700,7 @@ func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
// expected-error@+1 {{failed to verify that all of {start_indices, limit_indices, strides} have same type}}
%0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
%0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 2]> : tensor<2xi64>, limit_indices = dense<[2, 4, 1]> : tensor<3xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
@ -714,6 +714,30 @@ func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> {
// -----
func @slice_indices_not_rank_1(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
// expected-error@+1 {{start_indices has rank 2 instead of required rank 1}}
%0 = "mhlo.slice"(%arg0) {
start_indices = dense<[[1, 0]]> : tensor<1x2xi64>,
limit_indices = dense<[[2, 4]]> : tensor<1x2xi64>,
strides = dense<[[1, 2]]> : tensor<1x2xi64>
} : (tensor<3x4xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
// -----
func @slice_indices_wrong_size(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
// expected-error@+1 {{the number of elements in start_indices (3) does not match the rank of the operand (2)}}
%0 = "mhlo.slice"(%arg0) {
start_indices = dense<[1, 0, 0]> : tensor<3xi64>,
limit_indices = dense<[2, 4, 0]> : tensor<3xi64>,
strides = dense<[1, 2, 0]> : tensor<3xi64>
} : (tensor<3x4xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
// -----
// CHECK-LABEL: func @dynamic_slice
func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
%0 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>

View File

@ -265,6 +265,7 @@ cc_library(
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:status",
"//tensorflow/stream_executor/lib",

View File

@ -148,18 +148,10 @@ bool IsI64Type(Type element_type) {
bool VerifyAddOpShapeConstraints(AddOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows F32, QI8, and QUI8 outputs when the operands have valid shapes,
// Allows F32, QI8, QUI8 and I32 outputs when the operands have valid shapes,
// which are broadcastable shapes up to five dimension or have same shapes.
if (element_type.isF32() || IsQI8Type(element_type) ||
IsQUI8Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows I32 output when the operands have valid shapes, which are
// broadcastable shapes up to four dimension or have same shapes.
if (IsI32Type(element_type)) {
IsQUI8Type(element_type) || IsI32Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
@ -211,20 +203,13 @@ bool VerifyMulOpShapeConstraints(MulOp op) {
}
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
/*max_bcast_rank=*/4);
}
// Allows F32 output when the operands have valid shapes, which are
// broadcastable shapes up to five dimension or have same shapes.
if (element_type.isF32()) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows I32 and QI16 outputs when the operands have valid shapes, which are
// broadcastable shapes up to four dimension or have same shapes.
if (IsI32Type(element_type) || IsQI16Type(element_type)) {
// Allows I32, QI16 and F32 outputs when the operands have valid shapes, which
// are broadcastable shapes up to four dimension or have same shapes.
if (IsI32Type(element_type) || IsQI16Type(element_type) ||
element_type.isF32()) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);

View File

@ -4407,4 +4407,55 @@ def TFL_CustomTfOp : Op<TFL_Dialect, "custom_tf", [
let regions = (region SizedRegion<1>:$body);
}
def TFL_BroadcastToOp : TFL_Op<"broadcast_to", [
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_OperandHasRankAtMost<0, 8>,
TFL_OperandHasRank<1, 1>,
PredOpTrait<"output dimension count must be at most 8",
Or<[TFL_OperandIsUnrankedPred<1>,
TFL_OperandDimIsAtMost<1, 0, 8>]>>,
NoSideEffect]> {
let summary = "Broadcast an array for a compatible shape.";
let description = [{
Broadcasting is the process of making arrays have compatible shapes
for arithmetic operations. Two shapes are compatible if, for each
dimension pair, they are either equal or one of them is one. When
broadcasting a Tensor to a shape, it starts with the trailing dimensions
and works its way forward.
For example,
>>> x = tf.constant([1, 2, 3])
>>> y = tf.broadcast_to(x, [3, 3])
>>> print(y)
tf.Tensor(
[[1 2 3]
[1 2 3]
[1 2 3]], shape=(3, 3), dtype=int32)
In the above example, the input Tensor with the shape of `[1, 3]`
is broadcast to an output Tensor with the shape of `[3, 3]`.
When performing broadcasted operations such as multiplying a tensor
by a scalar, broadcasting (usually) confers some time or space
benefit, as the broadcasted tensor is never materialized.
However, `broadcast_to` does not carry any such benefit: the
newly-created tensor takes the full memory of the broadcasted
shape. (In a graph context, `broadcast_to` might be fused into the
subsequent operation and then be optimized away, however.) See the
sketch after this definition for an illustration of the difference.
}];
let arguments = (ins
TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$input,
TFL_I32OrI64Tensor:$shape
);
let results = (outs
TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$output
);
}
#endif // TFL_OPS
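
The broadcasting rule and the memory note in the `TFL_BroadcastToOp` description above can be made concrete with a minimal Python sketch. It assumes only standard TensorFlow eager APIs (`tf.constant`, `tf.ones`, `tf.broadcast_to`); the `broadcast_shape` helper is a hypothetical illustration of the rule, not part of TFLite.

import tensorflow as tf

def broadcast_shape(a, b):
    # Hypothetical helper illustrating the rule above: left-pad the shorter
    # shape with 1s (equivalent to aligning from the trailing dimensions),
    # then require each dimension pair to be equal or to contain a 1.
    rank = max(len(a), len(b))
    a = [1] * (rank - len(a)) + list(a)
    b = [1] * (rank - len(b)) + list(b)
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError("incompatible dimensions %d and %d" % (x, y))
        out.append(max(x, y))
    return out

print(broadcast_shape([3], [3, 3]))           # [3, 3]
print(broadcast_shape([1, 2, 1, 4], [5, 4]))  # [1, 2, 5, 4]

# Implicit broadcasting never materializes the broadcast operand, while
# tf.broadcast_to produces a full [3, 3] tensor before the multiply.
x = tf.constant([1, 2, 3])                    # shape [3]
ones = tf.ones([3, 3], dtype=tf.int32)
implicit = x * ones                           # x broadcast on the fly
explicit = tf.broadcast_to(x, [3, 3]) * ones  # broadcast copy materialized first

The legalization tests later in this commit rely on the explicit form: element-wise ops whose operands broadcast beyond the rank the TFLite kernels accept are rewritten to `tfl.broadcast_to` ops followed by the element-wise op on equal shapes.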

View File

@ -289,7 +289,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(
absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
}
mlir::PassManager pm(module->getContext());
mlir::PassManager pm(module->getContext(),
mlir::OpPassManager::Nesting::Implicit);
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm, session);
// Convert back to outlined while format for export back to flatbuffer.

View File

@ -75,7 +75,7 @@ TfLiteStatus QuantizeModel(
}
// Apply quantization passes
PassManager pm(module->getContext());
PassManager pm(module->getContext(), OpPassManager::Nesting::Implicit);
TFL::QuantizationSpecs quant_specs;
quant_specs.inference_type = tflite::TflTypeToTfType(inference_type);
quant_specs.post_training_quantization = true;

View File

@ -57,7 +57,7 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
return kTfLiteError;
}
PassManager pm(module->getContext());
PassManager pm(module->getContext(), OpPassManager::Nesting::Implicit);
pm.addPass(TFL::CreateDenseToSparsePass());
if (failed(pm.run(module.get()))) {

View File

@ -25,13 +25,6 @@ func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: te
return %0 : tensor<1x2x3x4x5x6x7x8xi32>
}
// CHECK-LABEL: testAddTooHighBroadcastableDims
func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
return %2: tensor<1xf32>
@ -1520,6 +1513,24 @@ func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
// CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
// CHECK: }
func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
return %0: tensor<3x3xf32>
// CHECK-LABEL: broadcast_to_f32
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
// CHECK: return [[BCT]] : tensor<3x3xf32>
}
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
return %0: tensor<3x3xi32>
// CHECK-LABEL: broadcast_to_i32
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
// CHECK: return [[BCT]] : tensor<3x3xi32>
}
func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
(tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
@ -1550,7 +1561,11 @@ func @select_v2_with_6d_broadcasting(%arg0: tensor<1x1x1x1x3x1xi1>, %arg1 : tens
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2): (tensor<1x1x1x1x3x1xi1>, tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32>
return %0 : tensor<1x1x1x2x3x4xf32>
// CHECK-LABEL: select_v2_with_6d_broadcasting
// CHECK: "tf.SelectV2"(%arg0, %arg1, %arg2)
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 2, 3, 4]> : tensor<6xi64>
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: [[BCT_1:%.*]] = "tfl.broadcast_to"(%arg2, [[CST]])
// CHECK: "tfl.select"([[BCT]], [[BCT_0]], [[BCT_1]])
}
// -----
@ -1560,7 +1575,9 @@ func @maximum_with_6d_broadcasting(%arg0: tensor<1x1x1x1x8x16xf32>, %arg1: tenso
return %0 : tensor<1x1x1x1x8x16xf32>
// CHECK-LABEL: maximum_with_6d_broadcasting
// CHECK: "tf.Maximum"(%arg0, %arg1)
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 1, 8, 16]> : tensor<6xi64>
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.maximum"(%arg0, [[BCT]])
}
// -----
@ -1569,7 +1586,171 @@ func @add_with_int32_5d_inputs(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x
%0 = "tf.Add"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
return %0 : tensor<1x1x1x3x4xi32>
// CHECK-LABEL: add_with_int32_5d_inputs
// CHECK: "tf.Add"(%arg0, %arg1)
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 3, 4]> : tensor<5xi64>
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.add [[BCT]], [[BCT_0]]
}
// CHECK-LABEL: testAddWithBroadcastToOps
func @testAddWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.add [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testSubWithBroadcastToOps
func @testSubWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.sub [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testMulWithBroadcastToOps
func @testMulWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.mul [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Mul"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testDivWithBroadcastToOps
func @testDivWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.div [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testFloorDivWithBroadcastToOps
func @testFloorDivWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.floor_div [[BCAST]], [[BCAST_1]] : tensor<1x2x3x4x5x6xi32>
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testFloorModWithBroadcastToOps
func @testFloorModWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.floor_mod"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testPowWithBroadcastToOps
func @testPowWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.pow [[BCAST]], [[BCAST_1]] : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Pow"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testMaximumWithBroadcastToOps
func @testMaximumWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.maximum"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testMinimumWithBroadcastToOps
func @testMinimumWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.minimum"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
%0 = "tf.Minimum"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testSelectV2WithBroadcastToOps
func @testSelectV2WithBroadcastToOps(%arg0: tensor<1x2x1x4x1x6xi1>, %arg1: tensor<1x2x3x4x1x1xi32>, %arg2: tensor<1x2x1x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: [[BCAST_2:%.*]] = "tfl.broadcast_to"(%arg2, [[CST]])
// CHECK: "tfl.select"([[BCAST]], [[BCAST_1]], [[BCAST_2]])
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1x2x1x4x1x6xi1>, tensor<1x2x3x4x1x1xi32>, tensor<1x2x1x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testLessEqualWithBroadcastToOps
func @testLessEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.less_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testGreaterEqualWithBroadcastToOps
func @testGreaterEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.greater_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testEqualWithBroadcastToOps
func @testEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testNotEqualWithBroadcastToOps
func @testNotEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.not_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testLessWithBroadcastToOps
func @testLessWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.less"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.Less"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testGreaterWithBroadcastToOps
func @testGreaterWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.greater"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
func @tranpose_int32_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {

View File

@ -2419,3 +2419,21 @@ func @invalid_two_dynamic_dims_on_reshape(%arg0: tensor<3x4xi32>, %arg1: tensor<
%0 = "tfl.reshape"(%arg0, %arg1) : (tensor<3x4xi32>, tensor<?x?x4xi32>) -> tensor<1x3x4xi32>
return %0 : tensor<1x3x4xi32>
}
// -----
// CHECK-LABEL: testBroadcastToWithI32ShapeTensor
func @testBroadcastToWithI32ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32> {
^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi32>):
// CHECK: "tfl.broadcast_to"(%arg0, %arg1)
%0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?x?x?x?xf32>
}
// CHECK-LABEL: testBroadcastToWithI64ShapeTensor
func @testBroadcastToWithI64ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32> {
^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi64>):
// CHECK: "tfl.broadcast_to"(%arg0, %arg1)
%0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?x?x?x?xf32>
}

View File

@ -1382,3 +1382,18 @@ func @fuseScalarAddIntoConv2dHalf(%arg0: tensor<256x32x32x3xf16>, %arg1: tensor<
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf16>
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
}
// CHECK-LABEL: fuseExpanded1DMulIntoConv2d
func @fuseExpanded1DMulIntoConv2d(%arg0: tensor<1x8x8x207xf32>) -> tensor<1x8x8x256xf32> {
%cst_0 = constant dense<1.4> : tensor<256x3x3x207xf32>
%cst_1 = constant dense<1.5> : tensor<256xf32>
%cst_2 = constant dense<2.0> : tensor<1x1x1x256xf32>
%0 = "tfl.conv_2d"(%arg0, %cst_0, %cst_1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x8x8x207xf32>, tensor<256x3x3x207xf32>, tensor<256xf32>) -> tensor<1x8x8x256xf32>
%1 = "tfl.mul"(%0, %cst_2) {fused_activation_function = "NONE"} : (tensor<1x8x8x256xf32>, tensor<1x1x1x256xf32>) -> tensor<1x8x8x256xf32>
return %1 : tensor<1x8x8x256xf32>
// CHECK: %[[CST_0:.*]] = constant dense<2.800000e+00> : tensor<256x3x3x207xf32>
// CHECK: %[[CST_1:.*]] = constant dense<3.000000e+00> : tensor<1x1x1x256xf32>
// CHECK: "tfl.conv_2d"(%arg0, %[[CST_0]], %[[CST_1]])
}

View File

@ -0,0 +1,26 @@
// RUN: tf-opt -tfl-prepare-tf=tfl-allow-bf16-type-legalization=true %s | FileCheck %s
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
// CHECK-LABEL: conv_2d_bf16
func @conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x3x3x16xbf16>) -> tensor<256x30x30x16xbf16> {
%0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xbf16>, tensor<3x3x3x16xbf16>) -> tensor<256x30x30x16xbf16>
return %0 : tensor<256x30x30x16xbf16>
// CHECK: "tfl.conv_2d"
}
// CHECK-LABEL: fused_batch_norm_v3_bf16
func @fused_batch_norm_v3_bf16(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8x8x8x8xbf16> {
%0, %1, %2 ,%3, %4, %5 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_BFLOAT16", U = "tfdtype$DT_BFLOAT16", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
return %0 : tensor<8x8x8x8xbf16>
// CHECK: "tf.FusedBatchNormV3"
}
// CHECK-LABEL: depthwise_conv_2d_bf16
func @depthwise_conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x3x3x4xf32>, %arg2 : tensor<256x3x32x32xf32>) -> tensor<256x30x30x12xbf16> {
%0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xbf16>, tensor<3x3x3x4xf32>) -> tensor<256x30x30x12xbf16>
return %0 : tensor<256x30x30x12xbf16>
// CHECK: "tfl.depthwise_conv_2d"
}
}

View File

@ -571,6 +571,73 @@ func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) ->
// CHECK: return %[[RES]]
}
func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
return %0: tensor<3x3xf32>
// CHECK-LABEL: broadcast_to_f32_low_dim
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}
func @broadcast_to_i32_low_dim(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
return %0: tensor<3x3xi32>
// CHECK-LABEL: broadcast_to_i32_low_dim
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
}
func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32>
return %0: tensor<3x3xf32>
// CHECK-LABEL: broadcast_to_low_dim_with_unknown_shape
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}
func @broadcast_to_i32_low_dim_with_unknown_output(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<*xi32> {
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32>
return %0: tensor<*xi32>
// CHECK-LABEL: broadcast_to_i32_low_dim_with_unknown_output
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
// CHECK: [[FILL:%.*]] = "tf.Fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<*xi32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[FILL]]) : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32>
// CHECK: return [[MUL]] : tensor<*xi32>
}
func @broadcast_to_high_dim_with_unknown_shape(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
return %0: tensor<7x8x1x2x3x4x5x6xf32>
// CHECK-LABEL: broadcast_to_high_dim_with_unknown_shape
// CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
// CHECK: return [[BCT]] : tensor<7x8x1x2x3x4x5x6xf32>
}
func @broadcast_to_high_dim_with_unknown_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<8xi32>) -> tensor<*xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
return %0: tensor<*xf32>
// CHECK-LABEL: broadcast_to_high_dim_with_unknown_output
// CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
// CHECK: return [[BCT]] : tensor<*xf32>
}
func @broadcast_to_with_unknown_shape_and_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<*xf32>
return %0: tensor<*xf32>
// CHECK-LABEL: broadcast_to_with_unknown_shape_and_output
// CHECK: "tf.BroadcastTo"(%arg0, %arg1)
}
// CHECK-LABEL: xla_conv
func @xla_conv(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> {
%0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<3x3x16x16xf32>} : () -> tensor<3x3x16x16xf32> loc("Const_1")
@ -646,4 +713,46 @@ func @DoNotConvertConv2DWhenFilterTypeDimIsNotDecided(%arg0 : tensor<?x?x?x96xf3
// CHECK: tf.Conv2D
}
// CHECK-LABEL: conv2d_f16
func @conv2d_f16(%arg0 : tensor<?x224x224x3xf16>, %arg1 : tensor<3x3x3x16xf16>) -> tensor<?x112x112x16xf16> {
%0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<?x224x224x3xf16>, tensor<3x3x3x16xf16>) -> tensor<?x112x112x16xf16>
return %0 : tensor<?x112x112x16xf16>
// CHECK: "tf.Conv2D"
}
// CHECK-LABEL: fused_batch_norm_v3_f16
func @fused_batch_norm_v3_f16(%arg0 : tensor<?x112x112x16xf16>, %arg1 : tensor<16xf32>, %arg2 : tensor<16xf32>, %arg3 : tensor<16xf32>, %arg4 : tensor<16xf32>) -> tensor<?x112x112x16xf16> {
%0, %1, %2, %3, %4, %5 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {data_format = "NHWC", device = "", epsilon = 1.000000e-03 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<?x112x112x16xf16>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<?x112x112x16xf16>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<*xf32>)
return %0 : tensor<?x112x112x16xf16>
// CHECK: "tf.FusedBatchNormV3"
}
// CHECK-LABEL: depthwise_conv2d_native_f16
func @depthwise_conv2d_native_f16(%arg0 : tensor<?x112x112x16xf16>, %arg1 : tensor<3x3x16x1xf16>) -> tensor<?x112x112x16xf16> {
%0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<?x112x112x16xf16>, tensor<3x3x16x1xf16>) -> tensor<?x112x112x16xf16>
return %0 : tensor<?x112x112x16xf16>
// CHECK: "tf.DepthwiseConv2dNative"
}
// CHECK-LABEL: conv_2d_bf16
func @conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x3x3x16xbf16>) -> tensor<256x30x30x16xbf16> {
%0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xbf16>, tensor<3x3x3x16xbf16>) -> tensor<256x30x30x16xbf16>
return %0 : tensor<256x30x30x16xbf16>
// CHECK: "tf.Conv2D"
}
// CHECK-LABEL: fused_batch_norm_v3_bf16
func @fused_batch_norm_v3_bf16(%arg0 : tensor<?x112x112x16xbf16>, %arg1 : tensor<16xf32>, %arg2 : tensor<16xf32>, %arg3 : tensor<16xf32>, %arg4 : tensor<16xf32>) -> tensor<?x112x112x16xbf16> {
%0, %1, %2, %3, %4, %5 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {data_format = "NHWC", device = "", epsilon = 1.000000e-03 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<?x112x112x16xbf16>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<?x112x112x16xbf16>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<*xf32>)
return %0 : tensor<?x112x112x16xbf16>
// CHECK: "tf.FusedBatchNormV3"
}
// CHECK-LABEL: depthwise_conv_2d_bf16
func @depthwise_conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x3x3x4xf32>, %arg2 : tensor<256x3x32x32xf32>) -> tensor<256x30x30x12xbf16> {
%0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xbf16>, tensor<3x3x3x4xf32>) -> tensor<256x30x30x12xbf16>
return %0 : tensor<256x30x30x12xbf16>
// CHECK: "tf.DepthwiseConv2dNative"
}
}

View File

@ -45,18 +45,20 @@ const char kTFLiteDataLayout[] = "NHWC";
void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager) {
pass_manager->addPass(mlir::TFL::CreatePrepareQuantizePass(quant_specs));
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreatePrepareQuantizePass(quant_specs));
if (quant_specs.default_ranges.first.hasValue() ||
quant_specs.default_ranges.second.hasValue()) {
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
quant_specs.default_ranges.first.getValueOr(0.0),
quant_specs.default_ranges.second.getValueOr(0.0),
quant_specs.IsSignedInferenceType()));
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateDefaultQuantParamsPass(
quant_specs.default_ranges.first.getValueOr(0.0),
quant_specs.default_ranges.second.getValueOr(0.0),
quant_specs.IsSignedInferenceType()));
}
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateQuantizePass());
bool emit_quant_adaptor_ops =
quant_specs.inference_type != quant_specs.inference_input_type;
pass_manager->addPass(
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
}
@ -67,7 +69,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
standard_pipeline_options.enable_inliner = false;
standard_pipeline_options.form_clusters = pass_config.form_clusters;
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass());
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TF::CreateDeviceIndexSelectorPass());
// Add canonicalize pass to remove no-op session initializer pass.
pass_manager->addPass(mlir::createCanonicalizerPass());
@ -155,7 +158,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(mlir::createInlinerPass());
// TODO(jpienaar): Revise post dialect constants.
pass_manager->addPass(mlir::TF::CreateDecodeConstantPass());
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TF::CreateDecodeConstantPass());
// Canonicalization includes const folding, which is utilized here to optimize
// away ops that can't get constant folded after PrepareTF pass. For example,
// tf.Conv2D is split into tf.Transpose and tfl.Conv2D.
@ -178,12 +182,13 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// to match 'kTFLiteDataLayout'
mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
layout_optimization_options.force_data_format = kTFLiteDataLayout;
mlir::TF::CreateLayoutOptimizationPipeline(*pass_manager,
layout_optimization_options);
mlir::TF::CreateLayoutOptimizationPipeline(
pass_manager->nest<mlir::FuncOp>(), layout_optimization_options);
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to
// the TFLite dialect.
pass_manager->addPass(
mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreatePrepareTFPass(
pass_config.unfold_batch_matmul,
/*allow_bf16_type_legalization=*/!pass_config.runtime_verification));
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
if (pass_config.shape_inference) {
// Add a shape inference pass to optimize away the unnecessary casts.
@ -198,16 +203,18 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(mlir::createInlinerPass());
// This pass removes the asset file dependencies in hash table use cases.
pass_manager->addPass(mlir::TF::CreateInitTextFileToImportPass());
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TF::CreateInitTextFileToImportPass());
pass_manager->addPass(
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateOptimizePass());
// This pass operates on TensorFlow ops but is triggered after legalization
// so that it can target constants introduced once TensorFlow Identity ops
// are removed during legalization.
pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
pass_manager->addPass(mlir::TFL::CreateRaiseCustomOpsPass());
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateRaiseCustomOpsPass());
pass_manager->addPass(mlir::createSymbolDCEPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
@ -225,7 +232,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// it's not desired by TFL. This pass serves as a "fix" pass to split the
// merged inputs until we have 1st class variable support or reuse
// tf.variable to model this.
pass_manager->addPass(mlir::TFL::CreateSplitMergedOperandsPass());
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateSplitMergedOperandsPass());
}
}
@ -276,7 +284,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
pm.addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
// TFLite dialect passes.
pm.addPass(mlir::TFL::CreatePrepareTFPass(true));
pm.addPass(mlir::TFL::CreatePrepareTFPass(
/*unfold_batch_matmul=*/true, /*allow_bf16_type_legalization=*/false));
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pm.addPass(
mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
@ -295,7 +304,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateRuntimeVerifyPass());
}
// Registers a pass pipeline for the standard TFL passes.

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/init_mlir.h"
@ -187,7 +188,7 @@ int main(int argc, char **argv) {
// message. So we can just return here.
if (!module.ok()) return kTrFailure;
mlir::PassManager pm(&context);
mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit);
mlir::applyPassManagerCLOptions(pm);
// Set the quantization specifications from the command line flags.

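The pipeline hunks above move function-level passes from addPass to addNestedPass<mlir::FuncOp>, and the translate driver above switches its top-level PassManager to implicit nesting, presumably so any remaining function-level addPass calls still land on the right op. A minimal sketch of the two styles (not part of this commit), using the TFL Optimize pass declared elsewhere in this change as the example function pass; include paths follow the MLIR snapshot used here and may differ in newer revisions:

// Sketch only: contrasts explicit and implicit pass nesting in MLIR.
#include "mlir/IR/Function.h"  // FuncOp, at this MLIR snapshot
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

void BuildExamplePipelines(mlir::MLIRContext &context) {
  // Explicit nesting (the default here): the caller states which op the
  // function pass runs on.
  mlir::PassManager explicit_pm(&context);
  explicit_pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateOptimizePass());

  // Implicit nesting: the pass manager infers the nesting level from the
  // pass's declared operation type, so a plain addPass lands it on FuncOps.
  mlir::PassManager implicit_pm(&context,
                                mlir::OpPassManager::Nesting::Implicit);
  implicit_pm.addPass(mlir::TFL::CreateOptimizePass());
}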
View File

@ -92,7 +92,7 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
return failure();
}
if (!TFTypeIsFloatTensor(op.input()) || !TFDataFormatIsNHWC(op))
if (!TFTypeIsFloat32Tensor(op.input()) || !TFDataFormatIsNHWC(op))
return failure();
// Allow dynamic width and height dimensions only.

View File

@ -116,6 +116,9 @@ def LegalizeArgMax : Pat<(TF_ArgMaxOp $input, $dim),
def LegalizeArgMin : Pat<(TF_ArgMinOp $input, $dim),
(TFL_ArgMinOp $input, $dim)>;
def LegalizeBroadcastTo : Pat<(TF_BroadcastToOp $input, $dim),
(TFL_BroadcastToOp $input, $dim)>;
def LegalizeCeil : Pat<(TF_CeilOp $arg), (TFL_CeilOp $arg)>;
def LegalizeCos : Pat<(TF_CosOp $arg), (TFL_CosOp $arg)>;
@ -264,7 +267,7 @@ def LegalizeAddv2 : Pat<(TF_AddV2Op $lhs, $rhs),
(TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
def LegalizeBiasAdd : Pat<
(TF_BiasAddOp F32Tensor:$l, F32Tensor:$r, IsDataFormatNHWC:$data_format),
(TFL_AddOp $l, $r, TFL_AF_None)>;
(TF_AddV2Op $l, $r)>;
def LegalizeSub : Pat<(TF_SubOp $lhs, $rhs),
(TFL_SubOp $lhs, $rhs, TFL_AF_None)>;
def LegalizeMul : Pat<(TF_MulOp $lhs, $rhs),

View File

@ -636,11 +636,155 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
}
};
void LegalizeTF::runOnFunction() {
OwningRewritePatternList patterns;
auto* context = &getContext();
auto func = getFunction();
// Put two TFL BroadcastTo ops in front of the given TF binary broadcast op to
// make the binary broadcast-able op conversion always succeed and not require
// the flex delegate.
template <typename SourceOp>
class ApplyExplicitBroadcasting : public OpRewritePattern<SourceOp> {
public:
using OpRewritePattern<SourceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SourceOp src_op,
PatternRewriter& rewriter) const override {
Operation* op = static_cast<Operation*>(src_op);
auto lhs = op->getOperand(0);
auto rhs = op->getOperand(1);
// Should have static shapes to calculate the broadcasted shape.
if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
!rhs.getType().cast<ShapedType>().hasStaticShape()) {
return failure();
}
auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
if (lhs_shape == rhs_shape) {
return failure();
}
// Calculate the broadcasted shape.
SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
result_shape)) {
return failure();
}
RankedTensorType result_type = RankedTensorType::get(
result_shape, getElementTypeOrSelf(op->getResult(0).getType()));
// Create a const op, that stores the above broadcasted shape.
auto new_shape_attr = mlir::DenseIntElementsAttr::get(
RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)),
result_shape);
auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
// Apply BroadcastTo ops to each input.
auto broadcast_type = RankedTensorType::get(
result_shape, getElementTypeOrSelf(lhs.getType()));
if (result_type.getShape() != lhs_shape) {
lhs = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
new_shape)
.output();
}
if (result_type.getShape() != rhs_shape) {
rhs = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
new_shape)
.output();
}
// Recreate an op with the above Broadcast op results.
rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, lhs, rhs);
return success();
}
};
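The pattern above delegates the shape computation to OpTrait::util::getBroadcastedShape and bails out (returns failure()) when the shapes are incompatible, leaving the op untouched. As a standalone illustration (plain C++, not the MLIR API and not part of the commit), the NumPy-style rule being applied is: align the shapes from the trailing dimension, and each dimension pair must be equal or contain a 1.

// Standalone sketch of the broadcast-shape rule used above.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

static bool BroadcastShapes(const std::vector<int64_t>& a,
                            const std::vector<int64_t>& b,
                            std::vector<int64_t>& out) {
  const size_t rank = std::max(a.size(), b.size());
  out.assign(rank, 1);
  for (size_t i = 0; i < rank; ++i) {
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) return false;  // incompatible dims
    out[rank - 1 - i] = std::max(da, db);
  }
  return true;
}

int main() {
  std::vector<int64_t> result;
  // [1, 8, 8, 256] vs. [1, 1, 1, 256] broadcasts to [1, 8, 8, 256].
  if (BroadcastShapes({1, 8, 8, 256}, {1, 1, 1, 256}, result)) {
    for (int64_t d : result) std::cout << d << ' ';
    std::cout << '\n';
  }
  return 0;
}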
// This specialization is for the TF SelectV2 op. SelectV2 has three inputs
// and they should have broadcastable shapes.
template <>
class ApplyExplicitBroadcasting<TF::SelectV2Op>
: public OpRewritePattern<TF::SelectV2Op> {
public:
using OpRewritePattern<TF::SelectV2Op>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::SelectV2Op src_op,
PatternRewriter& rewriter) const override {
Operation* op = static_cast<Operation*>(src_op);
auto cond = op->getOperand(0);
auto lhs = op->getOperand(1);
auto rhs = op->getOperand(2);
// Should have static shapes to calculate the broadcasted shape.
if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
!rhs.getType().cast<ShapedType>().hasStaticShape() ||
!cond.getType().cast<ShapedType>().hasStaticShape()) {
return failure();
}
auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
auto cond_shape = cond.getType().cast<ShapedType>().getShape();
if (lhs_shape == rhs_shape && cond_shape == lhs_shape) {
return failure();
}
// Calculate the broadcasted shape.
SmallVector<int64_t, 4> broadcasted_shape;
if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
broadcasted_shape)) {
return failure();
}
SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(broadcasted_shape, cond_shape,
result_shape)) {
return failure();
}
// Create a const op, that stores the above broadcasted shape.
auto shape_type =
RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64));
auto new_shape_attr =
mlir::DenseIntElementsAttr::get(shape_type, result_shape);
auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
// Apply BroadcastTo ops to each input.
auto cond_result_type =
RankedTensorType::get(result_shape, rewriter.getIntegerType(1));
auto result_type = RankedTensorType::get(
result_shape, getElementTypeOrSelf(lhs.getType()));
if (result_shape != cond_shape) {
cond = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), cond_result_type,
cond, new_shape)
.output();
}
if (result_shape != lhs_shape) {
lhs = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), result_type, lhs,
new_shape)
.output();
}
if (result_shape != rhs_shape) {
rhs = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), result_type, rhs,
new_shape)
.output();
}
// Recreate an op with the above Broadcast op results.
rewriter.replaceOpWithNewOp<TF::SelectV2Op>(op, result_type, cond, lhs,
rhs);
return success();
}
};
void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
// Add TF->TF lowering patterns.
TF::PopulateLoweringTFPatterns(context, &patterns);
@ -656,7 +800,25 @@ void LegalizeTF::runOnFunction() {
// Ophint python converter converted tf node pattern.
patterns.insert<LegalizeUnidirectionalSequenceLstm,
LegalizeUnidirectionalSequenceRnn>(context);
FrozenRewritePatternList frozenPatterns(std::move(patterns));
}
void applyPatterns(FuncOp func, ConversionTarget& target,
FrozenRewritePatternList& frozenPatterns) {
// Keep trying to convert.
// TODO(karimnosseir): This is similar to what applying patterns greedily does.
// Check whether there is a function that retries until it converges.
// Currently the unit test doesn't do multiple tries, so we need this.
const int max_iterations = 15;
for (int i = 0; i < max_iterations; ++i) {
if (failed(applyPartialConversion(func, target, frozenPatterns))) {
return;
}
}
}
void LegalizeTF::runOnFunction() {
auto* context = &getContext();
auto func = getFunction();
ConversionTarget target(*context);
// It is legal to have TF ops in the graph still which can be
@ -690,16 +852,42 @@ void LegalizeTF::runOnFunction() {
return success(current_thread_id == llvm::get_threadid());
});
// Keep trying to convert.
// TODO(karimnosseir): This is similar to what applying patterns greedily does.
// Check whether there is a function that retries until it converges.
// Currently the unit test doesn't do multiple tries, so we need this.
const int max_iterations = 15;
for (int i = 0; i < max_iterations; ++i) {
if (failed(applyPartialConversion(func, target, frozenPatterns))) {
return;
}
}
OwningRewritePatternList stage1Patterns;
addPatterns(context, stage1Patterns);
FrozenRewritePatternList stage1FrozenPatterns(std::move(stage1Patterns));
applyPatterns(func, target, stage1FrozenPatterns);
// Explicit BroadcastTo addition for left-over broadcast-able ops.
// The following patterns should be applied after the other legalization rules
// so that unnecessary BroadcastTo ops are not added.
OwningRewritePatternList stage2Patterns;
addPatterns(context, stage2Patterns);
stage2Patterns.insert<ApplyExplicitBroadcasting<TF::LessEqualOp>,
ApplyExplicitBroadcasting<TF::GreaterEqualOp>,
ApplyExplicitBroadcasting<TF::NotEqualOp>,
ApplyExplicitBroadcasting<TF::GreaterOp>,
ApplyExplicitBroadcasting<TF::LessOp>,
ApplyExplicitBroadcasting<TF::EqualOp>,
ApplyExplicitBroadcasting<TF::AddOp>,
ApplyExplicitBroadcasting<TF::AddV2Op>,
ApplyExplicitBroadcasting<TF::MulOp>,
ApplyExplicitBroadcasting<TF::DivOp>,
ApplyExplicitBroadcasting<TF::RealDivOp>,
ApplyExplicitBroadcasting<TF::SubOp>,
ApplyExplicitBroadcasting<TF::FloorDivOp>,
ApplyExplicitBroadcasting<TF::FloorModOp>,
ApplyExplicitBroadcasting<TF::PowOp>,
ApplyExplicitBroadcasting<TF::MaximumOp>,
ApplyExplicitBroadcasting<TF::MinimumOp>,
ApplyExplicitBroadcasting<TF::SquaredDifferenceOp>,
ApplyExplicitBroadcasting<TF::SelectV2Op>>(context);
FrozenRewritePatternList stage2FrozenPatterns(std::move(stage2Patterns));
applyPatterns(func, target, stage2FrozenPatterns);
}
} // namespace

View File

@ -200,16 +200,16 @@ bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params,
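// Reshapes an effectively 1-D (or scalar) constant into the 4-D shape needed
// when fusing it into a convolution: the element count goes to the last
// dimension for depthwise filters and to the first dimension otherwise.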
ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {
auto elements = a.dyn_cast<DenseElementsAttr>();
auto shape = elements.getType().getShape();
if (shape.size() == 4) {
return elements;
if (!shape.empty()) {
// Checks that elements are essentially 1d.
assert(elements.getNumElements() == shape.back());
}
std::vector<int64_t> shape_data = {1, 1, 1, 1};
if (shape.size() == 1 || shape.empty()) {
if (is_depthwise)
shape_data[3] = shape.empty() ? 1 : shape[0];
else
shape_data[0] = shape.empty() ? 1 : shape[0];
}
const int vector_length = elements.getNumElements();
if (is_depthwise)
shape_data[3] = vector_length;
else
shape_data[0] = vector_length;
auto new_shape =
RankedTensorType::get(shape_data, elements.getType().getElementType());
return elements.reshape(new_shape);

View File

@ -41,7 +41,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
bool unfold_batch_matmul);
bool unfold_batch_matmul, bool allow_bf16_type_legalization);
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
// pass.

View File

@ -257,9 +257,9 @@ void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
// (Quant $in, $qA) would introduce less quantization noise) the likely cause
// is a minor error in constructing the original network model that
// introduced back-to-back Fake Quantization operations. Hence: emit a
// warning. N.b. at this point weŕe (teporarility) in the quantization dialect
// (presuambly enalbe re-use in xla etc) quant::*QuantizeCastOp weŕe matching
// here.
// warning. N.b. at this point we're (temporarily) in the quantization dialect
// (presumably to enable re-use in XLA etc.), so it is quant::*QuantizeCastOp
// we're matching here.
//
func.walk([&](quant::QuantizeCastOp q_op) {
// If we end up with

View File

@ -82,8 +82,10 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
public:
PrepareTFPass() = default;
PrepareTFPass(const PrepareTFPass &) {}
explicit PrepareTFPass(bool unfold_batch_matmul) {
explicit PrepareTFPass(bool unfold_batch_matmul,
bool allow_bf16_type_legalization) {
unfold_batch_matmul_ = unfold_batch_matmul;
allow_bf16_type_legalization_ = allow_bf16_type_legalization;
}
void runOnFunction() override;
@ -97,6 +99,10 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
*this, "tfl-unfold-batch-matmul",
llvm::cl::desc("Unfold BatchMatMul into individual MatMul ops."),
llvm::cl::init(true)};
Option<bool> allow_bf16_type_legalization_{
*this, "tfl-allow-bf16-type-legalization",
llvm::cl::desc("Allow bf16 type legalization."), llvm::cl::init(false)};
};
template <class TFFakeQuantOp>
@ -258,6 +264,15 @@ using PreparePerTensorFakeQuantWithMinMaxArgs =
TF::FakeQuantWithMinMaxArgsOp, /*PerAxis=*/false,
FetchMinMaxAttrs<TF::FakeQuantWithMinMaxArgsOp>>;
// Transient state for preserving data from match to rewrite
struct ConvertTFConvOpMatchState {
IntegerAttr dilation_height_factor;
IntegerAttr dilation_width_factor;
StringAttr padding;
IntegerAttr stride_height;
IntegerAttr stride_width;
};
// Templated class for declaring a converter from some TensorFlow convolution
// op into its counterpart in TensorFlow Lite.
//
@ -273,19 +288,12 @@ using PreparePerTensorFakeQuantWithMinMaxArgs =
//
// int64_t getBiasDim(ArrayRef<int64_t> filterShape) const;
template <typename ConcreteType, typename TFConvOpType>
struct ConvertTFConvOp : public RewritePattern {
// Transient state for preserving data from match to rewrite
struct ConvertTFConvOpMatchState {
IntegerAttr dilation_height_factor;
IntegerAttr dilation_width_factor;
StringAttr padding;
IntegerAttr stride_height;
IntegerAttr stride_width;
};
ConvertTFConvOp(MLIRContext *context)
class ConvertTFConvOp : public RewritePattern {
public:
ConvertTFConvOp(MLIRContext *context, bool allow_bf16_type_legalization)
: RewritePattern(TFConvOpType::getOperationName(), 1, context),
intAttrOne(Builder(context).getI32IntegerAttr(1)) {}
intAttrOne(Builder(context).getI32IntegerAttr(1)),
allow_bf16_type_legalization_(allow_bf16_type_legalization) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
@ -301,9 +309,13 @@ struct ConvertTFConvOp : public RewritePattern {
TFConvOpType tf_op = cast<TFConvOpType>(op);
if (!TFTypeIsFloatTensor(tf_op.input()) || !TFDataFormatIsNHWC(op))
if (!TFTypeIsFloat32Tensor(tf_op.input()) &&
!(allow_bf16_type_legalization_ &&
TFTypeIsBFloat16Tensor(tf_op.input())))
return failure();
if (!TFDataFormatIsNHWC(op)) return failure();
IntegerAttr height, width;
if (!TFIntListIs1XY1(op, "strides", &height, &width)) return failure();
@ -359,13 +371,17 @@ struct ConvertTFConvOp : public RewritePattern {
}
const IntegerAttr intAttrOne;
private:
bool allow_bf16_type_legalization_;
};
class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
public:
using BaseType = ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp>;
ConvertTFConv2D(MLIRContext *context) : BaseType(context) {}
ConvertTFConv2D(MLIRContext *context, bool allow_bf16_type_legalization)
: BaseType(context, allow_bf16_type_legalization) {}
int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
return filterShape.back();
@ -421,7 +437,9 @@ class ConvertTFDepthwiseConv2dNative
using BaseType = ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
TF::DepthwiseConv2dNativeOp>;
ConvertTFDepthwiseConv2dNative(MLIRContext *context) : BaseType(context) {}
ConvertTFDepthwiseConv2dNative(MLIRContext *context,
bool allow_bf16_type_legalization)
: BaseType(context, allow_bf16_type_legalization) {}
int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
return filterShape[2] * filterShape[3];
@ -716,9 +734,9 @@ struct ConvertTFBroadcastTo : public RewritePattern {
// Allow lowering when low-dimensional inputs are given and their element type
// is F32 or I32.
if (!((output_type.hasRank() && output_type.getRank() <= 5) ||
if (!((output_type.hasRank() && output_type.getRank() <= 4) ||
(shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
shape_type.getDimSize(0) <= 5)))
shape_type.getDimSize(0) <= 4)))
return failure();
if (!(element_type.isa<BFloat16Type, Float32Type>() ||
@ -815,6 +833,9 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
offset = fused_batch_norm_op.getODSOperands(2);
mean = fused_batch_norm_op.getODSOperands(3);
variance = fused_batch_norm_op.getODSOperands(4);
if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.x())) return failure();
{
epsilon = fused_batch_norm_op.getAttrOfType<::mlir::FloatAttr>("epsilon");
if (!epsilon)
@ -1196,8 +1217,10 @@ void PrepareTFPass::runOnFunction() {
ctx);
}
phase_2_patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo,
ConvertTFConv2D, ConvertTFDepthwiseConv2dNative,
ConvertTFStridedSlice, ConvertRfftToRfft2d>(ctx);
phase_2_patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
ctx, allow_bf16_type_legalization_);
applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
}
@ -1205,8 +1228,9 @@ void PrepareTFPass::runOnFunction() {
// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
bool unfold_batch_matmul) {
return std::make_unique<PrepareTFPass>(unfold_batch_matmul);
bool unfold_batch_matmul, bool allow_bf16_type_legalization) {
return std::make_unique<PrepareTFPass>(unfold_batch_matmul,
allow_bf16_type_legalization);
}
static PassRegistration<PrepareTFPass> pass(

View File

@ -132,7 +132,7 @@ bool NotFromQuantOpOrSameQuantType(mlir::Value val, mlir::TypeAttr qtype_attr) {
llvm::dyn_cast_or_null<mlir::TFL::QuantizeOp>(val_defn_op);
if (!q_op) return true;
// Ignore shape details - weŕe really only trying to
// Ignore shape details - we're really only trying to
// check if quantization is the same.
auto stripped_src_qtype = GetShapeStrippedType(q_op.qtypeAttr());
auto stripped_qtype = GetShapeStrippedType(qtype_attr);

View File

@ -49,12 +49,19 @@ bool TFIntListIs1XY1(const ArrayAttr &attr);
// must be `IntegerAttr`.
bool TFIntListIsAllOnes(const ArrayAttr &attr);
// Returns true iff the given value is a float tensor.
// Returns true iff the given value is a float32 tensor, i.e. its element type
// is "DT_FLOAT".
inline bool TFTypeIsFloatTensor(Value value) {
inline bool TFTypeIsFloat32Tensor(Value value) {
auto tensorType = value.getType().dyn_cast<TensorType>();
if (!tensorType) return false;
return tensorType.getElementType().isa<FloatType>();
return tensorType.getElementType().isF32();
}
// Returns true iff the given value is a bf16 tensor.
inline bool TFTypeIsBFloat16Tensor(Value value) {
auto tensorType = value.getType().dyn_cast<TensorType>();
if (!tensorType) return false;
return tensorType.getElementType().isBF16();
}
// Returns true iff the given TensorFlow op has a `padding` attribute whose

View File

@ -882,9 +882,11 @@ cc_library(
"transforms/test_visitor_util.cc",
"transforms/tf_data_optimization_pass.cc",
"transforms/tf_device_assignment.cc",
"transforms/tf_device_replication_pass.cc",
"transforms/tpu_cluster_cleanup_attributes.cc",
"transforms/tpu_cluster_formation.cc",
"transforms/tpu_colocate_composite_resource_ops.cc",
"transforms/tpu_compile_op_replication_pass.cc",
"transforms/tpu_device_propagation.cc",
"transforms/tpu_dynamic_layout_pass.cc",
"transforms/tpu_dynamic_padding_mapper.cc",

View File

@ -1002,6 +1002,10 @@ reverse of SpaceToBatch. See below for a precise description.
TF_Tensor:$output
);
let verifier = [{
return Verify(*this);
}];
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tcrops = TF_DerivedOperandTypeAttr<2>;
TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>;

View File

@ -398,6 +398,27 @@ void BatchToSpaceOp::getCanonicalizationPatterns(
results.insert<BatchToSpaceToBatchToSpaceND>(context);
}
//===----------------------------------------------------------------------===//
// BatchToSpaceNDOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(BatchToSpaceNDOp op) {
auto block_shape_ty = op.block_shape().getType().cast<ShapedType>();
auto crops_ty = op.crops().getType().cast<ShapedType>();
if (block_shape_ty.hasStaticShape() && crops_ty.hasStaticShape()) {
const int block_rank = block_shape_ty.getShape().front();
if (crops_ty.getRank() != 2 || crops_ty.getShape().front() != block_rank ||
crops_ty.getShape()[1] != 2) {
op.emitOpError() << "crops should have shape [" << block_rank
<< ", 2] instead of " << crops_ty.getShape();
return failure();
}
}
return success();
}
//===----------------------------------------------------------------------===//
// BiasAddOp
//===----------------------------------------------------------------------===//

View File

@ -290,6 +290,25 @@ func @sixdim_space_to_batch_nd(%input: tensor<3x5x7x9x10x11xf32>, %block_shape:
return %0 : tensor<?x?x?x?x10x11xf32>
}
// CHECK-LABEL: func @batchToSpace
func @batchToSpace(%arg0: tensor<3x5x2xf32>) -> (tensor<1x8x2xf32>) {
// CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<[3, 1, 5, 2]> : tensor<4xi64>}
// CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<[1, 2, 0, 3]> : tensor<4xi64>}
// CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<[1, 15, 2]> : tensor<3xi64>}
// CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<[0, 3, 0]> : tensor<3xi64>}
// CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() {value = dense<[1, 8, 2]> : tensor<3xi64>}
// CHECK-DAG: [[VAL5:%.+]] = "tf.Reshape"(%arg0, [[VAL0]])
// CHECK-DAG: [[VAL6:%.+]] = "tf.Transpose"([[VAL5]], [[VAL1]])
// CHECK-DAG: [[VAL7:%.+]] = "tf.Reshape"([[VAL6]], [[VAL2]])
// CHECK-DAG: [[VAL8:%.+]] = "tf.Slice"([[VAL7]], [[VAL3]], [[VAL4]])
%0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<[[3, 4]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
%2 = "tf.BatchToSpaceND"(%arg0, %0, %1) {device = ""} : (tensor<3x5x2xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x8x2xf32>
// CHECK: return [[VAL8]] : tensor<1x8x2xf32>
return %2 : tensor<1x8x2xf32>
}
func @fake_quant_with_min_max_args(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<1.275000e+02> : tensor<f32>}
// CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<1.00392163> : tensor<f32>}

View File

@ -0,0 +1,25 @@
// RUN: tf-opt --tf-device-replication %s | FileCheck %s
// CHECK: func @test_1(%[[ARG_0:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}, %[[ARG_1:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:1/device:CPU:0"})
func @test_1(%arg0: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<i32> {tf.device = "/job:worker/replica:0/task:1/device:CPU:0"}) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
// CHECK-NEXT: %[[RESULT_0:.*]] = "tf.AddV2"(%[[ARG_0]], %[[ARG_0]]) {device = "/job:worker/replica:0/task:0/device:CPU:0"}
// CHECK-NEXT: %[[RESULT_1:.*]] = "tf.AddV2"(%[[ARG_0]], %[[ARG_0]]) {device = "/job:worker/replica:0/task:0/device:CPU:1"}
// CHECK-NEXT: %[[RESULT_2:.*]] = "tf.AddV2"(%[[ARG_1]], %[[ARG_1]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"}
// CHECK-NEXT: %[[RESULT_3:.*]] = "tf.AddV2"(%[[ARG_1]], %[[ARG_1]]) {device = "/job:worker/replica:0/task:1/device:CPU:1"}
%0:4 = tf_device.replicate([%arg0, %arg0, %arg1, %arg1] as %arg2: tensor<i32>) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:1", "/job:worker/replica:0/task:1/device:CPU:0", "/job:worker/replica:0/task:1/device:CPU:1"]}, n = 4 : i32} {
%1 = "tf.AddV2"(%arg2, %arg2) {device = "TPU_REPLICATED_CORE_0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_device.return %1 : tensor<i32>
}
// CHECK-NEXT: return %[[RESULT_0]], %[[RESULT_1]], %[[RESULT_2]], %[[RESULT_3]]
return %0#0, %0#1, %0#2, %0#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
// CHECK: func @test_2(%[[ARG_0:.*]]: tensor<i32>
func @test_2(%arg0: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
%0:4 = tf_device.replicate() {n = 4 : i32} {
tf_device.return %arg0 : tensor<i32>
}
// CHECK-NEXT: return %[[ARG_0]], %[[ARG_0]], %[[ARG_0]], %[[ARG_0]]
return %0#0, %0#1, %0#2, %0#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}

View File

@ -0,0 +1,22 @@
// RUN: tf-opt --tf-tpu-compile-replication %s | FileCheck %s
// CHECK: func @test(%[[ARG_0:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}, %[[ARG_1:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:1/device:CPU:0"})
func @test(%arg0: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<i32> {tf.device = "/job:worker/replica:0/task:1/device:CPU:0"}) -> (tensor<i32>, tensor<i32>) {
// CHECK-NEXT: %[[STATUS_0:.*]], %[[PROGRAM_0:.*]] = "tf._TPUCompileMlir"() {device = "/job:worker/replica:0/task:1/device:CPU:0", metadata = "metadata", mlir_module = "mlir_module"}
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[STATUS_0]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"}
// CHECK-NEXT: %[[STATUS_1:.*]], %[[PROGRAM_1:.*]] = "tf._TPUCompileMlir"() {device = "/job:worker/replica:0/task:0/device:CPU:0", metadata = "metadata", mlir_module = "mlir_module"}
%compilation_status, %program = "tf._TPUCompileMlir"() {device = "/job:worker/replica:0/task:0/device:CPU:0", metadata = "metadata", mlir_module = "mlir_module"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[STATUS_1]]) {device = "/job:worker/replica:0/task:0/device:CPU:0"}
"tf.TPUCompileSucceededAssert"(%compilation_status) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : (tensor<!tf.string>) -> ()
// CHECK-NEXT: %[[ADD_0:.*]] = "tf.AddV2"(%[[ARG_0]], %[[ARG_0]]) {device = "/job:worker/replica:0/task:0/device:TPU:0"}
%0 = "tf.AddV2"(%arg0, %arg0) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK-NEXT: %[[EXECUTE_0:.*]] = "tf.TPUExecute"(%[[ADD_0]], %[[PROGRAM_1]]) {device = "/job:worker/replica:0/task:0/device:TPU:0"}
%1 = "tf.TPUExecute"(%0, %program) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : (tensor<i32>, tensor<2x!tf.string>) -> tensor<i32>
// CHECK-NEXT: %[[ADD_1:.*]] = "tf.AddV2"(%[[ARG_1]], %[[ARG_1]]) {device = "/job:worker/replica:0/task:1/device:TPU:0"}
%2 = "tf.AddV2"(%arg1, %arg1) {device = "/job:worker/replica:0/task:1/device:TPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK-NEXT: %[[EXECUTE_1:.*]] = "tf.TPUExecute"(%[[ADD_1]], %[[PROGRAM_0]]) {device = "/job:worker/replica:0/task:1/device:TPU:0"}
%3 = "tf.TPUExecute"(%2, %program) {device = "/job:worker/replica:0/task:1/device:TPU:0"} : (tensor<i32>, tensor<2x!tf.string>) -> tensor<i32>
// CHECK-NEXT: return %[[EXECUTE_0]], %[[EXECUTE_1]]
return %1, %3 : tensor<i32>, tensor<i32>
}

View File

@ -121,7 +121,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
pm.addPass(CreateTPURewritePass());
pm.addPass(createSymbolDCEPass());
pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateInvariantOpHoistingPass());
pm.addNestedPass<FuncOp>(CreateTPUDynamicLayoutPass());
pm.addPass(CreateTPUDynamicLayoutPass());
pm.addNestedPass<FuncOp>(CreateTPUMergeVariablesWithExecutePass());
pm.addNestedPass<FuncOp>(CreateTPUColocateCompositeResourceOps());
pm.addPass(CreateTPUVariableReformattingPass());

View File

@ -50,7 +50,8 @@ Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto,
// Assign optimal data layout to layout sensitive operations and delete
// redundant transposes from the IR.
LayoutOptimizationPipelineOptions layout_optimization_options;
CreateLayoutOptimizationPipeline(pm, layout_optimization_options);
CreateLayoutOptimizationPipeline(pm.nest<FuncOp>(),
layout_optimization_options);
// Prepare IR for exporting.
pm.addPass(CreateBreakUpIslandsPass());

View File

@ -85,7 +85,7 @@ void InitTextFileToImportTestPass::runOnOperation() {
// Run the lowering pass.
PassManager pm(context);
pm.addPass(CreateInitTextFileToImportPass());
pm.addNestedPass<FuncOp>(CreateInitTextFileToImportPass());
if (failed(pm.run(module))) return signalPassFailure();
}

View File

@ -988,6 +988,177 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
}
};
class LowerBatchToSpaceND : public RewritePattern {
public:
explicit LowerBatchToSpaceND(MLIRContext *context)
: RewritePattern(BatchToSpaceNDOp::getOperationName(),
{
ConstOp::getOperationName(),
ReshapeOp::getOperationName(),
SliceOp::getOperationName(),
TransposeOp::getOperationName(),
},
1, context) {}
LogicalResult matchAndRewrite(Operation *src_op,
PatternRewriter &rewriter) const override {
auto op = cast<BatchToSpaceNDOp>(src_op);
auto input = op.input();
auto input_ty = input.getType().cast<ShapedType>();
auto element_ty = input_ty.getElementType();
if (!input_ty.hasStaticShape()) {
return failure();
}
const int input_rank = input_ty.getRank();
auto input_shape = input_ty.getShape();
DenseIntElementsAttr block_shape;
DenseIntElementsAttr crops;
if (!matchPattern(op.block_shape(), m_Constant(&block_shape)) ||
!matchPattern(op.crops(), m_Constant(&crops))) {
return failure();
}
auto block_shape_ty = block_shape.getType();
if (!block_shape_ty.hasRank() || block_shape_ty.getRank() != 1) {
return failure();
}
const int block_rank = block_shape_ty.getShape().front();
auto remainder_shape = input_shape.drop_front(1 + block_rank);
const int64_t batch_size = input_shape[0];
// Compute the product of the block_shape values.
int64_t block_num_elems = 1;
for (auto val : block_shape.getIntValues()) {
block_num_elems *= val.getSExtValue();
}
if (block_num_elems <= 0) {
op.emitOpError()
<< "The product of the block dimensions must be positive";
return failure();
}
// 1. Reshape `input` to `reshaped` of shape:
// [block_shape[0], ..., block_shape[M-1],
// batch / prod(block_shape),
// input_shape[1], ..., input_shape[N-1]]
std::vector<int64_t> reshaped_shape;
for (auto val : block_shape) {
reshaped_shape.push_back(val.getSExtValue());
}
reshaped_shape.resize(input_rank + block_rank);
reshaped_shape[block_rank] = batch_size / block_num_elems;
std::copy(input_shape.begin() + 1, input_shape.end(),
reshaped_shape.begin() + block_rank + 1);
auto reshaped = rewriter.create<TF::ReshapeOp>(
op.getLoc(), RankedTensorType::get(reshaped_shape, element_ty), input,
rewriter.create<ConstOp>(op.getLoc(),
rewriter.getI64TensorAttr(reshaped_shape)));
// 2. Permute dimensions of `reshaped` to produce `permuted` of shape
// [batch / prod(block_shape),
//
// input_shape[1], block_shape[0],
// ...,
// input_shape[M], block_shape[M-1],
//
// input_shape[M+1], ..., input_shape[N-1]]
std::vector<int64_t> permutation(reshaped_shape.size());
permutation[0] = block_rank;
for (int i = 0; i < block_rank; ++i) {
permutation[1 + 2 * i] = block_rank + 1 + i;
permutation[1 + 2 * i + 1] = i;
}
std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1 + block_rank * 2);
std::vector<int64_t> transpose_shape(permutation.size());
for (auto it : llvm::enumerate(permutation)) {
transpose_shape[it.index()] = reshaped_shape[it.value()];
}
auto permuted = rewriter.create<TF::TransposeOp>(
op.getLoc(), RankedTensorType::get(transpose_shape, element_ty),
reshaped,
rewriter.create<ConstOp>(op.getLoc(),
rewriter.getI64TensorAttr(permutation)));
// 3. Reshape `permuted` to produce `reshaped_permuted` of shape
// [batch / prod(block_shape),
//
// input_shape[1] * block_shape[0],
// ...,
// input_shape[M] * block_shape[M-1],
//
// input_shape[M+1],
// ...,
// input_shape[N-1]]
std::vector<int64_t> reshaped_permuted_shape(input_rank);
auto block_shape_values = llvm::to_vector<4>(block_shape.getIntValues());
reshaped_permuted_shape[0] = batch_size / block_num_elems;
for (int i = 0; i < block_rank; ++i) {
reshaped_permuted_shape[1 + i] =
block_shape_values[i].getSExtValue() * input_shape[1 + i];
}
std::copy(remainder_shape.begin(), remainder_shape.end(),
reshaped_permuted_shape.begin() + 1 + block_rank);
auto reshaped_permuted = rewriter.create<TF::ReshapeOp>(
op.getLoc(), RankedTensorType::get(reshaped_permuted_shape, element_ty),
permuted,
rewriter.create<ConstOp>(
op.getLoc(), rewriter.getI64TensorAttr(reshaped_permuted_shape)));
// 4. Crop the start and end of dimensions `[1, ..., M]` of
// `reshaped_permuted` according to `crops` to produce the output of
// shape:
// [batch / prod(block_shape),
//
// input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
// ...,
// input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
//
// input_shape[M+1], ..., input_shape[N-1]]
std::vector<int64_t> start_indices(input_rank, 0);
std::vector<int64_t> slice_sizes = reshaped_permuted_shape;
std::vector<int64_t> strides(input_rank, 1);
auto crop_values = llvm::to_vector<4>(crops.getIntValues());
for (int i = 0; i < block_rank; ++i) {
int64_t crop_start = crop_values[i * 2].getSExtValue();
int64_t crop_end = crop_values[i * 2 + 1].getSExtValue();
if (crop_start < 0 || crop_end < 0) {
op.emitOpError() << "Crops must be non-negative";
return failure();
}
start_indices[i + 1] = crop_start;
slice_sizes[i + 1] -= crop_start + crop_end;
if (slice_sizes[i + 1] < 0) {
op.emitOpError() << "Cropped size must be non-negative: start: "
<< crop_start << " end: " << crop_end << " size "
<< reshaped_permuted_shape[1 + i];
return failure();
}
}
rewriter.replaceOpWithNewOp<TF::SliceOp>(
op, RankedTensorType::get(slice_sizes, element_ty), reshaped_permuted,
rewriter.create<ConstOp>(op.getLoc(),
rewriter.getI64TensorAttr(start_indices)),
rewriter.create<ConstOp>(op.getLoc(),
rewriter.getI64TensorAttr(slice_sizes)));
return success();
}
};
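The lowering above is a reshape -> transpose -> reshape -> slice sequence. The following standalone sketch (plain C++, independent of MLIR and not part of the commit) traces the intermediate shapes for the 3x5x2 input with block_shape = [3] and crops = [[3, 4]] exercised by the @batchToSpace test earlier in this diff; it reproduces the constants [3,1,5,2], [1,2,0,3], [1,15,2], [0,3,0], and [1,8,2] that the CHECK lines expect.

// Shape-only trace of the BatchToSpaceND lowering for one concrete case.
#include <cstdint>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

int main() {
  const std::vector<int64_t> input_shape = {3, 5, 2};
  const std::vector<int64_t> block_shape = {3};
  const std::vector<std::pair<int64_t, int64_t>> crops = {{3, 4}};

  const int block_rank = static_cast<int>(block_shape.size());
  const int input_rank = static_cast<int>(input_shape.size());
  const int64_t batch = input_shape[0];
  int64_t block_elems = 1;
  for (int64_t b : block_shape) block_elems *= b;

  // Step 1: reshape to [block..., batch / prod(block), input[1:]...].
  std::vector<int64_t> reshaped(block_shape);
  reshaped.push_back(batch / block_elems);
  reshaped.insert(reshaped.end(), input_shape.begin() + 1, input_shape.end());

  // Step 2: permutation interleaving spatial and block dimensions.
  std::vector<int64_t> perm(input_rank + block_rank);
  perm[0] = block_rank;
  for (int i = 0; i < block_rank; ++i) {
    perm[1 + 2 * i] = block_rank + 1 + i;
    perm[1 + 2 * i + 1] = i;
  }
  std::iota(perm.begin() + 1 + 2 * block_rank, perm.end(), 1 + 2 * block_rank);

  // Step 3: merge each block factor into its spatial dimension.
  std::vector<int64_t> merged = {batch / block_elems};
  for (int i = 0; i < block_rank; ++i)
    merged.push_back(input_shape[1 + i] * block_shape[i]);
  merged.insert(merged.end(), input_shape.begin() + 1 + block_rank,
                input_shape.end());

  // Step 4: crop the merged spatial dimensions.
  std::vector<int64_t> start(input_rank, 0);
  std::vector<int64_t> sizes = merged;
  for (int i = 0; i < block_rank; ++i) {
    start[1 + i] = crops[i].first;
    sizes[1 + i] -= crops[i].first + crops[i].second;
  }

  auto print = [](const char* name, const std::vector<int64_t>& v) {
    std::cout << name << ":";
    for (int64_t d : v) std::cout << ' ' << d;
    std::cout << '\n';
  };
  print("reshaped", reshaped);  // 3 1 5 2
  print("permutation", perm);   // 1 2 0 3
  print("merged", merged);      // 1 15 2
  print("slice start", start);  // 0 3 0
  print("slice sizes", sizes);  // 1 8 2, the final 1x8x2 output
  return 0;
}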
// Lowers `SparseMatMulOp` to `MatMulOp`, ignoring the sparseness hints,
// since we currently don't have an implementation that can use this
// information. Adds appropriate casts where necessary to align element types
@ -1065,10 +1236,11 @@ class Lower_UnaryOpsComposition
void PopulateLoweringTFPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
patterns->insert<LowerAddNOp, ConvertFakeQuantWithMinMaxVarsOp,
LowerDynamicStitchOp, LowerInvertPermutationOp,
LowerLgammaOp, LowerPackOp, LowerSpaceToBatchNDOp,
LowerSparseMatMulOp, Lower_UnaryOpsComposition>(context);
patterns
->insert<LowerAddNOp, ConvertFakeQuantWithMinMaxVarsOp,
LowerDynamicStitchOp, LowerInvertPermutationOp, LowerLgammaOp,
LowerPackOp, LowerBatchToSpaceND, LowerSpaceToBatchNDOp,
LowerSparseMatMulOp, Lower_UnaryOpsComposition>(context);
populateWithGenerated(context, *patterns);
}

View File

@ -279,6 +279,11 @@ CreateMarkOpsForOutsideCompilationPass();
// attribute to each TensorFlow dialect op in the body based on the `device`
// attribute on the `tf_device.launch`.
std::unique_ptr<OperationPass<ModuleOp>> CreateLaunchToDeviceAttributePass();
// Creates a pass that hoists a `tf_device.replicate` body and replicates each
// TensorFlow dialect op in the body based on its `device` attribute and the
// `devices` attribute on the `tf_device.replicate`.
std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateTFDeviceReplicationPass();
} // namespace TFDevice
namespace TFTPU {
@ -369,6 +374,12 @@ void CreateTPUBridgePipeline(OpPassManager& pm);
// bridge in V1 mode.
void CreateTPUBridgePipelineV1(OpPassManager& pm);
// Creates a pass that replicates the tf._TPUCompileMlir op on each host that
// needs the compiled program. It helps avoid transferring the compiled binary
// between hosts.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
CreateTPUCompileOpReplicationPass();
} // namespace TFTPU
} // namespace mlir

View File

@ -0,0 +1,127 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass hoists a `tf_device.replicate` body and replicates each TensorFlow
// dialect op in the body based on its `device` attribute and the `devices`
// attribute on the `tf_device.replicate`.
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFDevice {
namespace {
constexpr char kDeviceAttr[] = "device";
class TFDeviceReplicationPass
: public PassWrapper<TFDeviceReplicationPass, OperationPass<ModuleOp>> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TF::TensorFlowDialect>();
}
void runOnOperation() override {
ModuleOp module = getOperation();
const Dialect *tf_dialect = getContext().getLoadedDialect("tf");
module.walk([&](tf_device::ReplicateOp replicate_op) {
OpBuilder builder(replicate_op);
// Map from the existing operation in ReplicateOp's region to a list of
// its replicated operations.
llvm::DenseMap<Operation *, llvm::SmallVector<Operation *, 4>>
operation_map;
llvm::Optional<DictionaryAttr> devices = replicate_op.devices();
const int replicate_num = replicate_op.n();
// Replicates every operation in the region of the ReplicateOp to match
// the number of devices.
for (int i : llvm::seq<int>(0, replicate_num)) {
// Gets the mapping from the packed and replicated block arguments to
// the actual value. This mapping is used to replace the arguments used
// by the cloned operations.
BlockAndValueMapping mapping;
for (BlockArgument &arg : replicate_op.GetBody().getArguments()) {
Value new_arg =
replicate_op.GetReplicaOperandForBlockArgument(arg, i);
mapping.map(arg, new_arg);
}
for (Operation &op : replicate_op.GetBody().without_terminator()) {
// Clones the operation and places it outside the replicate_op's body.
llvm::SmallVector<Operation *, 4> &new_ops = operation_map[&op];
Operation *new_op = builder.clone(op, mapping);
new_ops.push_back(new_op);
// If the op is a TF op, has a string-valued device attribute, and the
// replicate_op has a list of devices corresponding to that attribute's
// value, update the device attribute for this op.
if (!devices) continue;
if (op.getDialect() != tf_dialect) continue;
StringAttr device_alias =
new_op->getAttrOfType<StringAttr>(kDeviceAttr);
if (!device_alias) continue;
Attribute new_devices = devices->get(device_alias.getValue());
if (!new_devices) continue;
ArrayAttr new_devices_array = new_devices.cast<ArrayAttr>();
new_op->setAttr(kDeviceAttr, new_devices_array[i].cast<StringAttr>());
}
}
// Replaces usages of the existing results of the tf_device.replicate
// op with the results of the newly replicated operations.
llvm::SmallVector<Value, 4> new_results;
for (Value v : replicate_op.GetBody().getTerminator()->getOperands()) {
OpResult result = v.dyn_cast<OpResult>();
// Uses the original value if the value is not an OpResult.
if (!result) {
for (int i = 0; i < replicate_num; ++i) new_results.push_back(v);
continue;
}
// Uses the original value if the value is defined by an op outside the
// tf_device.replicate's body.
Operation *op = result.getDefiningOp();
if (operation_map.find(op) == operation_map.end()) {
for (int i = 0; i < replicate_num; ++i) new_results.push_back(v);
continue;
}
// Uses the values defined by the newly replicated operations.
int result_num = result.getResultNumber();
for (Operation *new_op : operation_map[op]) {
new_results.push_back(new_op->getResult(result_num));
}
}
replicate_op.replaceAllUsesWith(new_results);
replicate_op.erase();
});
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateTFDeviceReplicationPass() {
return std::make_unique<TFDeviceReplicationPass>();
}
static PassRegistration<TFDeviceReplicationPass> pass(
"tf-device-replication",
"Hoists and replicates the tf_device.replicate "
"inner ops once for each associated device.");
} // namespace TFDevice
} // namespace mlir

View File

@ -0,0 +1,103 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass replicates the tf._TPUCompileMlir op on each host that needs the
// compiled program. It helps avoid transferring the compiled binary between
// hosts.
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFTPU {
namespace {
using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName;
constexpr char kDeviceAttr[] = "device";
constexpr int kStatusResultIndex = 0;
constexpr int kProgramResultIndex = 1;
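// Returns the host address space of the op's `device` attribute, e.g.
// "/job:worker/replica:0/task:1" for "/job:worker/replica:0/task:1/device:CPU:0",
// or the empty string when no device is set.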
static std::string GetHost(Operation *op) {
if (StringAttr device = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
ParsedName parsed_name;
DeviceNameUtils::ParseFullName(device.getValue().str(), &parsed_name);
return DeviceNameUtils::ParsedNameToString(
DeviceNameUtils::AddressSpace(parsed_name));
}
return "";
}
class TPUCompileOpReplicationPass
: public PassWrapper<TPUCompileOpReplicationPass, OperationPass<ModuleOp>> {
void runOnOperation() override {
getOperation().walk([&](TF::_TPUCompileMlirOp tpu_compile_op) {
Value compiled_program = tpu_compile_op.getResult(kProgramResultIndex);
std::string tpu_compile_op_host = GetHost(tpu_compile_op.getOperation());
llvm::StringMap<Operation *> compile_op_by_host;
llvm::SmallVector<OpOperand *, 4> usages;
for (OpOperand &usage : compiled_program.getUses()) {
usages.push_back(&usage);
}
// For any op which uses the program compiled on a different host than the
// original tf._TPUCompileMlir op, replicate the tf._TPUCompileMlir op on
// that host and update the op to use the program compiled on the same
// host.
for (OpOperand *usage : usages) {
std::string usage_op_host = GetHost(usage->getOwner());
if (usage_op_host == tpu_compile_op_host) continue;
Operation *&new_compile_op = compile_op_by_host[usage_op_host];
// If it is not already created, create a tf._TPUCompileMlir op and a
// tf.TPUCompileSucceededAssert op on the first CPU of the target host.
if (!new_compile_op) {
std::string device_name = usage_op_host + "/device:CPU:0";
OpBuilder builder(tpu_compile_op);
new_compile_op = builder.clone(*tpu_compile_op.getOperation());
new_compile_op->setAttr(kDeviceAttr,
StringAttr::get(device_name, &getContext()));
TF::TPUCompileSucceededAssertOp new_assert_op =
builder.create<TF::TPUCompileSucceededAssertOp>(
new_compile_op->getLoc(),
new_compile_op->getResult(kStatusResultIndex));
new_assert_op.setAttr(kDeviceAttr,
new_compile_op->getAttr(kDeviceAttr));
}
// Updates the operand to use the result of the newly created
// tf._TPUCompileMlir op.
usage->set(new_compile_op->getResult(kProgramResultIndex));
}
return WalkResult::advance();
});
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUCompileOpReplicationPass() {
return std::make_unique<TPUCompileOpReplicationPass>();
}
static PassRegistration<TPUCompileOpReplicationPass> pass(
"tf-tpu-compile-replication",
"Replicate the TPU compile op to avoid sending the compiled binary between "
"hosts.");
} // namespace TFTPU
} // namespace mlir

View File

@ -3569,17 +3569,20 @@ Status SavedModelSignatureDefImporter::LiftVariables() {
mlir::PassManager pm(module_->getContext());
SetCrashReproducer(pm);
pm.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass());
pm.addPass(mlir::CreateExecutorDialectToFunctionalConversionPass());
pm.addNestedPass<mlir::FuncOp>(
mlir::tf_executor::CreateTFExecutorGraphPruningPass());
pm.addNestedPass<mlir::FuncOp>(
mlir::CreateExecutorDialectToFunctionalConversionPass());
pm.addPass(
mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass());
pm.addPass(
pm.addNestedPass<mlir::FuncOp>(
mlir::TF::
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());
pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
pm.addPass(
mlir::tf_saved_model::CreateLiftVariablesPass(bundle_.GetSession()));
pm.addPass(mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
pm.addNestedPass<mlir::FuncOp>(
mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
if (mlir::failed(pm.run(*module_)))
return diag_handler.Combine(errors::Internal("Failed to lift variables."));

View File

@ -279,7 +279,8 @@ void CreateConvertMlirToXlaHloPipeline(
pm.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
pm.addPass(mlir::TF::CreateStackOpsDecompositionPass());
pm.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass());
pm.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass());
pm.addNestedPass<mlir::FuncOp>(
mlir::TFDevice::CreateDecomposeResourceOpsPass());
pm.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
pm.addPass(mlir::createSymbolDCEPass());
// Guarantee all functions have one use, which enables shape inference.

View File

@ -36,7 +36,7 @@ void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {
pm->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
// TFJS dialect passes.
pm->addPass(mlir::tfjs::CreateOptimizePass());
pm->addNestedPass<mlir::FuncOp>(mlir::tfjs::CreateOptimizePass());
// Canonicalize, CSE etc.
pm->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());

View File

@ -16,9 +16,11 @@ limitations under the License.
#include <string>
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
@ -34,7 +36,14 @@ namespace tfr {
Status CompositeOpExpansion::Run(EagerOperation* orig_op,
std::unique_ptr<EagerOperation>* out_op) {
if (!IsEnabled()) return Status::OK();
// This can be the default cpu device.
if (orig_op->Device() != kVariantDeviceNull) return Status::OK();
// TODO(fengliuai): We need a better condition to skip the rewrite. Currently,
// the rewrite is enabled for all the tf ops and it is a no-op if the tf op
// isn't a composite op. "VarHandleOp" is explicitly skipped here because its
// roundtrip fails for unknown reasons.
if (orig_op->is_function()) return Status::OK();
if (absl::StartsWith(orig_op->op_name(), "VarHandleOp")) return Status::OK();
tf_core_op_expansion_node_counter->GetCell()->IncrementBy(1);

View File

@ -103,6 +103,7 @@ std::unique_ptr<TFRDecomposeContext> TFRDecomposeContext::GetFromText(
mlir::tf_executor::TensorFlowExecutorDialect,
mlir::TFR::TFRDialect>();
// clang-format on
registry.loadAll(mlir_ctx);
// Load the TFR functions in a mlir::ModuleOp
auto memory_buffer = llvm::MemoryBuffer::getMemBuffer(

View File

@ -75,32 +75,32 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
applyTensorflowAndCLOptions(pm);
if (gpu_binary_only) {
pm.addPass(mlir::mhlo::createLegalizeTFPass(
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
/*allow_partial_conversion=*/false, /*legalize_chlo=*/true));
pm.addNestedPass<mlir::FuncOp>(
mlir::kernel_gen::transforms::CreateMaterializeBroadcastsPass());
pm.addNestedPass<mlir::FuncOp>(
mlir::kernel_gen::transforms::CreateUnfuseBatchNormPass());
pm.addPass(mlir::mhlo::createLegalizeToLhloPass(
/*results_escape_functions=*/true));
pm.addPass(mlir::mhlo::createLegalizeToLhloPass());
// Moving `AllocOp`s and inserting missing `DeallocOp`s
pm.addPass(::mlir::createBufferHoistingPass());
pm.addPass(::mlir::createBufferDeallocationPass());
pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferHoistingPass());
pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferDeallocationPass());
pm.addNestedPass<mlir::FuncOp>(mlir::createCopyRemovalPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass());
} else {
pm.addPass(mlir::mhlo::createLegalizeTFPass(
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
/*allow_partial_conversion=*/false, /*legalize_chlo=*/false));
pm.addPass(mlir::createTransformUnrankedHloPass());
pm.addPass(mlir::mhlo::createChloLegalizeToHloPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addNestedPass<mlir::FuncOp>(mlir::createTransformUnrankedHloPass());
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createChloLegalizeToHloPass());
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass());
// Clean up the IR created above. In particular, operations on descriptors
// are simplified here.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass());
pm.addPass(mlir::kernel_gen::transforms::CreateParallelLoopsToSequential());
pm.addNestedPass<mlir::FuncOp>(
mlir::kernel_gen::transforms::CreateParallelLoopsToSequential());
}
// Clean up the IR for further processing.
@ -120,36 +120,41 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end());
}
// Transform LHLO operations to LinAlg.
pm.addPass(::mlir::lmhlo::createLegalizeLhloToLinalgPass());
pm.addNestedPass<mlir::FuncOp>(
::mlir::lmhlo::createLegalizeLhloToLinalgPass());
// Fuse linalg operations.
pm.addPass(::mlir::lmhlo::createLhloFuseLinalgPass(
pm.addNestedPass<mlir::FuncOp>(::mlir::lmhlo::createLhloFuseLinalgPass(
/*use_parallel_loops=*/true, tiling_for_unrolling));
// Transform the Linalg operations inside of the loop nest into parallel
// loops.
pm.addPass(::mlir::createConvertLinalgToParallelLoopsPass());
pm.addNestedPass<mlir::FuncOp>(
::mlir::createConvertLinalgToParallelLoopsPass());
// Canonicalize the code to simplify index computations. This is needed so
// that loop bounds have the same value.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
// Fuse the inner-most loops.
pm.addPass(xla::mlir_gpu::createFuseInnerParallelLoopsPass());
pm.addNestedPass<mlir::FuncOp>(
xla::mlir_gpu::createFuseInnerParallelLoopsPass());
// Run CSE to ensure that loads and stores to the same subview get
// recognized as such.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
// Forward stores to buffers to loads.
pm.addPass(xla::mlir_gpu::createStoreForwardingPass());
pm.addNestedPass<mlir::FuncOp>(xla::mlir_gpu::createStoreForwardingPass());
// Remove now unused temporary buffers.
pm.addPass(xla::mlir_gpu::createDeadTempBufferRemovalPass());
pm.addNestedPass<mlir::FuncOp>(
xla::mlir_gpu::createDeadTempBufferRemovalPass());
if (!unroll_factors.empty()) {
pm.addPass(::mlir::createParallelLoopTilingPass(as_int64));
pm.addNestedPass<mlir::FuncOp>(
::mlir::createParallelLoopTilingPass(as_int64));
}
// Some basic cleanup.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
// Greedily map the remaining loop to GPU hardware dimensions.
pm.addPass(xla::mlir_gpu::createMapParallelLoopsPass());
pm.addNestedPass<::mlir::FuncOp>(xla::mlir_gpu::createMapParallelLoopsPass());
// Apply the mapping.
pm.addPass(mlir::createParallelLoopToGpuPass());
pm.addNestedPass<::mlir::FuncOp>(mlir::createParallelLoopToGpuPass());
// Embed TF Framework ops.
if (!gpu_binary_only) {
@ -172,12 +177,13 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
if (gpu_binary_only) {
// Make kernel signature deterministic so that we can call it externally.
pm.addPass(xla::mlir_gpu::createRewriteKernelSignaturePass());
pm.addNestedPass<::mlir::FuncOp>(
xla::mlir_gpu::createRewriteKernelSignaturePass());
}
pm.addPass(::mlir::createLowerAffinePass());
// Constraints are removed as late as possible and before lowering to CFG.
pm.addPass(::mlir::createConvertShapeConstraintsPass());
pm.addNestedPass<::mlir::FuncOp>(::mlir::createConvertShapeConstraintsPass());
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
pm.addPass(::mlir::createLowerToCFGPass());

View File

@ -308,7 +308,7 @@ func @abs_unranked_i64(%arg : memref<*xi64>,
%flat_shape : memref<1xindex>,
%arg_size : index) -> memref<*xi64>
attributes {tf_entry} {
%flat_arg = lmhlo.reshape_memref_cast %arg(%flat_shape)
%flat_arg = memref_reshape %arg(%flat_shape)
: (memref<*xi64>, memref<1xindex>) -> memref<?xi64>
// CHECK: alloc
// CHECK-SAME: reuse_input_candidates = [0 : index], reuse_output = 0 : index
@ -324,7 +324,7 @@ func @abs_unranked_i64(%arg : memref<*xi64>,
%a_abs = select %a_pos, %a, %a_neg : i64
linalg.yield %a_abs : i64
}
%result = lmhlo.reshape_memref_cast %flat_result(%arg_shape)
%result = memref_reshape %flat_result(%arg_shape)
: (memref<?xi64>, memref<?xindex>) -> memref<*xi64>
return %result : memref<*xi64>
}

View File

@ -76,3 +76,39 @@ func @assuming(%witness: !shape.witness, %arg : memref<?xf32>)
}
return %assuming_result : tensor<?xf32>
}
// CHECK-LABEL: @const
// CHECK-SAME: -> memref<3xf32>
func @const() -> tensor<3xf32> {
// CHECK: %[[MEM:.*]] = alloca() : memref<3xf32>
// CHECK: %[[C4:.*]] = constant 4.000000e+00 : f32
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: store %[[C4]], %[[MEM]][%[[C0]]] : memref<3xf32>
// CHECK: %[[C5:.*]] = constant 5.000000e+00 : f32
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: store %[[C5]], %[[MEM]][%[[C1]]] : memref<3xf32>
// CHECK: %[[C6:.*]] = constant 6.000000e+00 : f32
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: store %[[C6]], %[[MEM]][%[[C2]]] : memref<3xf32>
// CHECK-NEXT: return %[[MEM]] : memref<3xf32>
%result = constant dense<[4.0, 5.0, 6.0]> : tensor<3xf32>
return %result : tensor<3xf32>
}
// CHECK-LABEL: @const_splat
// CHECK-SAME: -> memref<3xf32>
func @const_splat() -> tensor<3xf32> {
// CHECK: %[[MEM:.*]] = alloca() : memref<3xf32>
// CHECK: %[[C4:.*]] = constant 4.000000e+00 : f32
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: store %[[C4]], %[[MEM]][%[[C0]]] : memref<3xf32>
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: store %[[C4]], %[[MEM]][%[[C1]]] : memref<3xf32>
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: store %[[C4]], %[[MEM]][%[[C2]]] : memref<3xf32>
// CHECK-NEXT: return %[[MEM]] : memref<3xf32>
%result = constant dense<4.0> : tensor<3xf32>
return %result : tensor<3xf32>
}

View File

@ -12,9 +12,9 @@ func @tanh(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NOT: tensor_load
// CHECK: scf.for
// CHECK-NOT: tensor_from_elements
// CHECK: mhlo.reshape_memref_cast
// CHECK: memref_reshape
// CHECK: lmhlo.tanh
// CHECK: mhlo.reshape_memref_cast
// CHECK: memref_reshape
%0 = "tf.Tanh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

View File

@ -39,7 +39,7 @@ module attributes {gpu.container_module} {
^bb6: // pred: ^bb4
%13 = alloca() : memref<1xindex>
store %8, %13[%c0] : memref<1xindex>
%14 = lmhlo.reshape_memref_cast %arg0(%13) : (memref<*xf32>, memref<1xindex>) -> memref<?xf32>
%14 = memref_reshape %arg0(%13) : (memref<*xf32>, memref<1xindex>) -> memref<?xf32>
%15 = dim %14, %c0 : memref<?xf32>
%16 = tf_framework.alloc(%ctx, %15) : memref<?xf32>
%17 = cmpi "sle", %15, %c0 : index
@ -53,7 +53,7 @@ module attributes {gpu.container_module} {
gpu.launch_func @abs_kernel::@abs_kernel
blocks in (%24, %c1, %c1) threads in (%c256, %c1, %c1)
args(%14 : memref<?xf32>, %16 : memref<?xf32>)
%25 = lmhlo.reshape_memref_cast %16(%1) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
%25 = memref_reshape %16(%1) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
return %25 : memref<*xf32>
}

View File

@ -35,11 +35,9 @@ cc_library(
srcs = ["bufferize.cc"],
hdrs = ["rewriters.h"],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
@ -128,7 +126,6 @@ cc_library(
"//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
"//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/xla/service/gpu:stream_executor_util",
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_llvm",
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
] + if_cuda_is_configured([
"//tensorflow/stream_executor/gpu:asm_compiler",

View File

@ -110,8 +110,8 @@ class BufferSizeAnalysis {
});
// Operand and result of `reshape_memref_cast` must be of the same size.
f.walk([&](lmhlo::ReshapeMemRefCastOp reshapeOp) {
ecs_.unionSets(reshapeOp.result(), reshapeOp.operand());
f.walk([&](MemRefReshapeOp reshapeOp) {
ecs_.unionSets(reshapeOp.result(), reshapeOp.source());
});
}
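The union-find bookkeeping above relies on the fact that a reshape does not change the number of elements, so the source and result buffers of a reshape belong to the same size class. A small stand-alone sketch of that idea, using integers in place of mlir::Value:

#include <cassert>
#include "llvm/ADT/EquivalenceClasses.h"

void IllustrateBufferSizeClasses() {
  // Integers stand in for buffer values here; the analysis itself keys the
  // classes on mlir::Value.
  llvm::EquivalenceClasses<int> ecs;
  ecs.unionSets(/*reshape result=*/1, /*reshape source=*/2);
  ecs.unionSets(/*reshape source=*/2, /*another alias=*/3);
  // Buffers connected through reshapes end up in one class => equal size.
  assert(ecs.isEquivalent(1, 3));
}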

View File

@ -17,28 +17,49 @@ limitations under the License.
#include "mlir/Transforms/Bufferize.h" // from @llvm-project
#include <cstddef>
#include <memory>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {
class ConstantOpConverter : public OpConversionPattern<ConstantOp> {
public:
using OpConversionPattern<ConstantOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ConstantOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// We only need to bufferize tensor constants.
Location loc = op.getLoc();
auto result_type = op.getType().dyn_cast<RankedTensorType>();
if (!result_type || !result_type.hasStaticShape() ||
result_type.getRank() != 1)
return failure();
auto memref_type = MemRefType::get({result_type.getNumElements()},
result_type.getElementType());
Value buffer = rewriter.create<AllocaOp>(loc, memref_type);
auto elements_attr = op.getValue().dyn_cast<DenseElementsAttr>();
bool all_same_elems = elements_attr.isSplat();
Value value;
if (all_same_elems)
value = rewriter.create<ConstantOp>(loc, elements_attr.getSplatValue());
for (auto en : llvm::enumerate(elements_attr.getAttributeValues())) {
if (!all_same_elems) value = rewriter.create<ConstantOp>(loc, en.value());
Value index = rewriter.create<ConstantIndexOp>(loc, en.index());
rewriter.create<StoreOp>(loc, value, buffer, index);
}
rewriter.replaceOp(op, {buffer});
return success();
}
};
class TensorFromElementsOpConverter
: public OpConversionPattern<TensorFromElementsOp> {
public:
@ -48,14 +69,14 @@ class TensorFromElementsOpConverter
TensorFromElementsOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
ShapedType result_type = op.getType().cast<ShapedType>();
auto result_type = op.getType().cast<ShapedType>();
int number_of_elements = op.elements().size();
MemRefType memref_type =
MemRefType::get({number_of_elements}, result_type.getElementType());
Value result = rewriter.create<AllocaOp>(loc, memref_type);
for (auto operand : llvm::enumerate(operands)) {
Value index = rewriter.create<ConstantIndexOp>(loc, operand.index());
rewriter.create<StoreOp>(loc, operand.value(), result, index);
for (auto en : llvm::enumerate(operands)) {
Value index = rewriter.create<ConstantIndexOp>(loc, en.index());
rewriter.create<StoreOp>(loc, en.value(), result, index);
}
rewriter.replaceOp(op, {result});
return success();
@ -73,7 +94,7 @@ class DynamicTensorFromElementsOpConverter
// Allocate memory on stack.
Location loc = op.getLoc();
DynamicTensorFromElementsOp::Adaptor transformed(operands);
RankedTensorType tensor_ty = op.getType().cast<RankedTensorType>();
auto tensor_ty = op.getType().cast<RankedTensorType>();
MemRefType memref_type =
MemRefType::get(tensor_ty.getShape(), tensor_ty.getElementType());
Value result = rewriter.create<AllocaOp>(loc, memref_type,
@ -87,7 +108,7 @@ class DynamicTensorFromElementsOpConverter
SmallVector<Value, 4> steps(rank, one);
SmallVector<Value, 4> upper_bounds;
int next_dynamic_index = 0;
for (int i = 0; i < rank; i++) {
for (int i = 0; i < rank; ++i) {
Value ub = tensor_ty.isDynamicDim(i)
? transformed.dynamicExtents()[next_dynamic_index++]
: rewriter.create<ConstantIndexOp>(
@ -139,7 +160,6 @@ class ExtractElementOpConversion
if (!adaptor.aggregate().getType().isa<BaseMemRefType>()) {
return failure();
}
rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.aggregate(),
adaptor.indices());
return success();
@ -182,7 +202,8 @@ class TensorCastOpConverter : public OpConversionPattern<TensorCastOp> {
void populateStandardBufferizePattern(MLIRContext *context,
BufferizeTypeConverter *converter,
OwningRewritePatternList *patterns) {
patterns->insert<ExtractElementOpConversion, TensorFromElementsOpConverter,
patterns->insert<ConstantOpConverter, ExtractElementOpConversion,
TensorFromElementsOpConverter,
DynamicTensorFromElementsOpConverter,
SimpleOpResultConversion<SelectOp>, TensorLoadOpConversion,
TensorCastOpConverter>(*converter, context);

View File

@ -98,7 +98,8 @@ struct BufferizePass : public BufferizePassBase<BufferizePass> {
return converter.isLegal(inputs) && converter.isLegal(results) &&
converter.isLegal(&op.getBody());
});
target.addDynamicallyLegalOp<CallOp, ReturnOp, SelectOp>(typesAreLegal);
target.addDynamicallyLegalOp<CallOp, ReturnOp, SelectOp, ConstantOp>(
typesAreLegal);
OwningRewritePatternList patterns;
mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns);

View File

@ -162,8 +162,8 @@ class ShapeEqualityKnowledge {
/// results.
void build(FuncOp function) {
function.walk([&](Operation *op) {
if (auto reshape = dyn_cast<lmhlo::ReshapeMemRefCastOp>(op)) {
registerAssociation(ShapeValue{reshape.operand()}, reshape.result());
if (auto reshape = dyn_cast<MemRefReshapeOp>(op)) {
registerAssociation(ShapeValue{reshape.source()}, reshape.result());
return;
}
if (auto alloc = dyn_cast<AllocOp>(op)) {

View File

@ -133,7 +133,7 @@ struct PropagateTfAbiKnowledgeToKernelsPass
while (!worklist.empty()) {
Value candidate = worklist.pop_back_val();
for (auto user : candidate.getUsers()) {
if (auto reshape = dyn_cast<lmhlo::ReshapeMemRefCastOp>(user)) {
if (auto reshape = dyn_cast<MemRefReshapeOp>(user)) {
// Reshape propagates alignment and offset.
// TODO(herhut): This should be a trait.
if (allocated_by_runtime.insert(reshape.result()).second) {

View File

@ -19,7 +19,6 @@ limitations under the License.
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
@ -53,13 +52,12 @@ class TFKernelToLLVMPass : public TFKernelToLLVMPassBase<TFKernelToLLVMPass> {
tf_framework::PopulateTFFrameworkToLLVMConversionPatterns(&type_converter,
&patterns);
populateGpuToLLVMConversionPatterns(type_converter, patterns, "gpu.binary");
lmhlo::PopulateLhloToLLVMConversionPatterns(&type_converter, &patterns);
// Set target.
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target
.addIllegalDialect<gpu::GPUDialect, tf_framework::TFFrameworkDialect>();
target.addIllegalDialect<gpu::GPUDialect, StandardOpsDialect,
tf_framework::TFFrameworkDialect>();
target.addIllegalOp<LLVM::DialectCastOp>();
if (failed(applyPartialConversion(m, target, std::move(patterns)))) {

View File

@ -2527,6 +2527,27 @@ func @expand_dims_dynamic(%arg0: tensor<?x?xf32>) -> tensor<?x1x?xf32> {
return %0 : tensor<?x1x?xf32>
}
// CHECK-LABEL: expand_dynamic_dims_rank1_axis
func @expand_dynamic_dims_rank1_axis(%arg0: tensor<?x?x4xf32>) -> tensor<?x1x?x4xf32> {
%axis = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-DAG: [[SHAPEOF:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[CST0:%.+]] = constant 0
// CHECK-DAG: [[CST1:%.+]] = constant 1
// CHECK-DAG: [[GETEXTENT0:%.+]] = shape.get_extent [[SHAPEOF]], [[CST0]]
// CHECK-DAG: [[CST1_0:%.+]] = constant 1
// CHECK-DAG: [[GETEXTENT1:%.+]] = shape.get_extent [[SHAPEOF]], [[CST1_0]]
// CHECK-DAG: [[CST2:%.+]] = constant 2
// CHECK-DAG: [[GETEXTENT2:%.+]] = shape.get_extent [[SHAPEOF]], [[CST2]]
// CHECK-DAG: [[FROMEXTENTS:%.+]] = shape.from_extents [[GETEXTENT0]], [[CST1]], [[GETEXTENT1]], [[GETEXTENT2]]
// CHECK-DAG: [[TOEXTENTS:%.+]] = shape.to_extent_tensor [[FROMEXTENTS]]
// CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.dynamic_reshape"(%arg0, [[TOEXTENTS]])
%0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?x4xf32>, tensor<1xi32>) -> tensor<?x1x?x4xf32>
// CHECK: return [[RESHAPE]]
return %0 : tensor<?x1x?x4xf32>
}
// CHECK-LABEL: func @sign
// CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32>
func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {

View File

@ -5535,10 +5535,7 @@ class ConvertDynamicExpandDimsOp : public OpRewritePattern<TF::ExpandDimsOp> {
llvm::SmallVector<Value, 4> dims;
dims.resize(result_ty.getRank());
auto inserted_dim = expand_dims_attr.getValue({})
.cast<IntegerAttr>()
.getValue()
.getSExtValue();
auto inserted_dim = expand_dims[0].getSExtValue();
// Handle the negative value use case.
if (inserted_dim < 0) {

View File

@ -103,7 +103,6 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::AtanhOp>(),
TypeID::get<TF::AtanOp>(),
TypeID::get<TF::BatchMatMulV2Op>(),
TypeID::get<TF::BatchToSpaceNDOp>(),
TypeID::get<TF::BatchToSpaceOp>(),
TypeID::get<TF::BesselI0eOp>(),
TypeID::get<TF::BesselI1eOp>(),

View File

@ -519,14 +519,21 @@ StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
// ViewOp only takes memrefs without affine maps (layouts). Let ViewOp produce
// the physical shape (where dimensions are ordered in major to minor) first,
// then follow up with a StaticMemRefCastOp to cast the resulting memref to
// then follow up with a MemRefReinterpretCast to cast the resulting memref to
// the original layout.
Value result =
builder_.create<ViewOp>(loc, physical_out_type, alloc, byte_shift,
/*sizes=*/ValueRange{});
if (physical_out_type != out_type)
result = builder_.create<lmhlo::StaticMemRefCastOp>(loc, out_memref_type,
result);
if (physical_out_type != out_type) {
int64_t out_offset;
SmallVector<int64_t, 4> out_strides;
if (failed(getStridesAndOffset(out_memref_type, out_strides, out_offset)))
return tensorflow::errors::Internal(
"Failed to get strides and offset from the output type.");
result = builder_.create<MemRefReinterpretCastOp>(
loc, out_memref_type, result, out_offset, out_memref_type.getShape(),
out_strides, llvm::None, llvm::None, llvm::None);
}
return cached_value = result;
}
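A hedged worked example of the strides/offset computation above. Only getStridesAndOffset, MemRefReinterpretCastOp, and the header paths already seen in this diff are taken as given; the helper function and the concrete layout are illustrative.

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/StandardTypes.h"  // MemRefType, getStridesAndOffset (path as of this revision)

void IllustrateStridesAndOffset(mlir::MemRefType memref_type) {
  // For a column-major 2x3 buffer such as
  //   memref<2x3xf32, affine_map<(d0, d1) -> (d1 * 2 + d0)>>
  // this reports strides {1, 2} and offset 0, which is exactly the layout the
  // MemRefReinterpretCastOp above re-attaches to the physical buffer that
  // ViewOp produced in major-to-minor order.
  int64_t offset;
  llvm::SmallVector<int64_t, 4> strides;
  if (mlir::failed(mlir::getStridesAndOffset(memref_type, strides, offset))) {
    return;  // Not a strided layout.
  }
}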

View File

@ -1481,7 +1481,7 @@ tf_xla_py_test(
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
use_xla_device = False, # Uses tf.function(experimental_compile=True)
use_xla_device = False, # Uses tf.function(jit_compile=True)
deps = [
":xla_test",
"//tensorflow/compiler/tf2xla/python:xla",

View File

@ -32,7 +32,7 @@ class CaseTest(xla_test.XLATestCase):
def testCaseBasic(self):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def switch_case_test(branch_index):
def f1():
@ -58,7 +58,7 @@ class CaseTest(xla_test.XLATestCase):
def testBranchIsPruned(self):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def switch_case_test():
branch_index = array_ops.constant(0)

View File

@ -693,7 +693,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
return x, y
wholly_compiled_f = def_function.function(f)
op_by_op_f = def_function.function(f, experimental_compile=False)
op_by_op_f = def_function.function(f, jit_compile=False)
x = array_ops.identity([0.0, 2.0], name='data')

View File

@ -45,22 +45,22 @@ flags.DEFINE_bool('vary_seed', False,
NUM_SAMPLES = int(1e3)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def _igamma(a, x):
return math_ops.igamma(a, x)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def _igammac(a, x):
return math_ops.igammac(a, x)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def _polygamma(n, x):
return math_ops.polygamma(n, x)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def _zeta(a, q):
return math_ops.zeta(a, q)
@ -72,7 +72,7 @@ def implicit_reparameterization_grad(a, x):
return -gen_math_ops.igamma_grad_a(a, x) / prob
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def _log1p(x):
return math_ops.log1p(x)

View File

@ -51,7 +51,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
scenarios (e.g. TPU). The new version of stateless_random_* requires the
intermediate tensor `alg` to be compile-time constant, so we need to check
that this requirement is met. We use xla.compile instead of tf.function's
experimental_compile because the latter doesn't throw an error even if the
jit_compile because the latter doesn't throw an error even if the
compile-time-constant constraint is not met.
"""
if config.list_logical_devices('TPU'):

View File

@ -223,12 +223,10 @@ class StridedSliceOp : public XlaOpKernel {
input_elements_sliced *= input_shape.dim_size(d);
}
OP_REQUIRES(
ctx, output_elements == input_elements_sliced,
errors::InvalidArgument(
"The number of output elements ", output_elements,
" has to equal to number of input elements that are sliced ",
input_elements_sliced, " when input indices are not constant."));
OP_REQUIRES(ctx, output_elements == input_elements_sliced,
errors::InvalidArgument(
"Dynamic indices of strided_slice_op have to be leading "
"dimensions in the indices list."));
for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) {
OP_REQUIRES(

View File

@ -197,6 +197,9 @@ class XlaCompiler {
// Alias input and output buffers for parameters that are passed-through XLA
// modules without being changed.
bool alias_passthrough_params = false;
// Enable detailed logging of compilation metadata.
bool detailed_logging = true;
};
explicit XlaCompiler(Options options);
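A short sketch of how the new option might be set. The field names come from the hunk above; the header path and the rest of the setup are assumptions about this revision.

#include "tensorflow/compiler/tf2xla/xla_compiler.h"  // assumed header location

tensorflow::XlaCompiler::Options MakeQuietOptions() {
  tensorflow::XlaCompiler::Options options;
  options.alias_passthrough_params = false;
  // New in this change: suppress the detailed per-compilation metadata logs.
  options.detailed_logging = false;
  return options;
}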

Some files were not shown because too many files have changed in this diff.