Merge branch 'master' into toupstream/16x8_resize_bilinear
commit 5bc12f3435
@@ -36,6 +36,9 @@
* Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
  to control how external state should be handled during dataset
  serialization or iterator checkpointing.
* XLA compilation:
  * `tf.function(experimental_compile=True)` has become a stable API,
    renamed `tf.function(jit_compile=True)`.

* `tf.lite`:
  * NNAPI

@@ -1205,7 +1205,7 @@ typedef struct TF_Session TF_Session;
// Return a new execution session with the associated graph, or NULL on
// error. Does not take ownership of any input parameters.
//
-// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be be
+// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be
// kept alive for the lifetime of the returned TF_Session. New nodes can still
// be added to `graph` after this call.
TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph,
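For orientation, here is a minimal sketch of how the documented contract is typically exercised through the C API; the setup and cleanup calls around `TF_NewSession` are standard `c_api.h` functions, and the error handling is only illustrative.

```cpp
#include <cstdio>

#include "tensorflow/c/c_api.h"

int main() {
  TF_Graph* graph = TF_NewGraph();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_Status* status = TF_NewStatus();

  // TF_NewSession does not take ownership of `graph` or `opts`; `graph` must
  // outlive the returned session.
  TF_Session* session = TF_NewSession(graph, opts, status);
  if (TF_GetCode(status) != TF_OK) {
    std::fprintf(stderr, "TF_NewSession failed: %s\n", TF_Message(status));
    return 1;
  }

  // Nodes may still be added to `graph` after the session has been created.

  TF_CloseSession(session, status);
  TF_DeleteSession(session, status);
  TF_DeleteSessionOptions(opts);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
  return 0;
}
```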
@@ -86,7 +86,7 @@ TF_CAPI_EXPORT void TF_SetXlaConstantFoldingDisabled(

// Create a serialized tensorflow.ConfigProto proto, where:
//
-// a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if
+// a) ConfigProto.optimizer_options.global_jit_level is set to ON_1 if
// `enable_xla_compilation` is non-zero, and OFF otherwise.
// b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`.
// c) ConfigProto.device_count is set to `num_cpu_devices`.
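A short sketch of how this helper is typically combined with the core C API. The parameter order follows the `enable_xla_compilation`, `gpu_memory_allow_growth`, and `num_cpu_devices` names used in the comment and should be treated as an assumption rather than a verbatim signature.

```cpp
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

// Build a ConfigProto with the XLA JIT at ON_1, GPU memory growth enabled,
// and a single CPU device, then install it on the session options.
void ConfigureSession(TF_SessionOptions* opts, TF_Status* status) {
  TF_Buffer* config = TF_CreateConfig(/*enable_xla_compilation=*/1,
                                      /*gpu_memory_allow_growth=*/1,
                                      /*num_cpu_devices=*/1);
  TF_SetConfig(opts, config->data, config->length, status);
  TF_DeleteBuffer(config);
}
```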
@@ -481,7 +481,7 @@ typedef struct TFE_CustomDevice {
// "/job:localhost/replica:0/task:0/device:CUSTOM:0".
//
// The custom device defines copy operations for moving TensorHandles on and
-// off, and an an execution operation for named operations. Often execution will
+// off, and an execution operation for named operations. Often execution will
// simply wrap op execution on one or more physical devices.
//
// device_info is an opaque caller-defined type stored with the custom device
@@ -16,9 +16,9 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_

// Run a function containing a MatMul op and check its output.
-// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
-// which creates a remote remote input, to simulate a scenario that the remote
-// input is not ready when we start running an op or a function.
+// If heavy_load_on_streaming_rpc is true, send some rpc requests before the one
+// which creates a remote input, to simulate a scenario that the remote input
+// is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
                                   bool heavy_load_on_streaming_rpc,
                                   bool remote_func_outputs = false);
@@ -696,7 +696,7 @@ TEST(CAPI, ExecuteAddForwardAsync) {
             /*tfrt*/ false);
}
#ifdef PLATFORM_GOOGLE
-// TODO(b/153349425): Add add forwarding tests for TFRT
+// TODO(b/153349425): Add forwarding tests for TFRT
TEST(CAPI, ExecuteAddTfrt) {
  ExecuteAdd(
      /*async=*/false,
@@ -46,8 +46,6 @@ class SavedModelAPI {
  virtual Status GetSignatureDefFunction(const std::string& signature_def_key,
                                         SignatureDefFunction** function) = 0;

  virtual std::vector<ConcreteFunction*> ListFunctions() = 0;

  virtual ~SavedModelAPI() = default;
};
@@ -211,15 +211,6 @@ Status TFSavedModelAPI::GetSignatureDefFunction(
  return Status();
}

std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
  std::vector<ConcreteFunction*> result;
  result.reserve(revived_objects_.concrete_functions.size());
  for (auto& index_and_function : revived_objects_.concrete_functions) {
    result.push_back(index_and_function.second.get());
  }
  return result;
}

Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
                                    Variable** variable) {
  absl::optional<int> node =
@@ -66,8 +66,6 @@ class TFSavedModelAPI : public SavedModelAPI {
                        ImmediateExecutionContext* context,
                        std::unique_ptr<TFSavedModelAPI>* out);

  std::vector<ConcreteFunction*> ListFunctions() override;

  ~TFSavedModelAPI() override = default;

  Status GetVariable(const std::string& variable_path, Variable** variable);
@@ -122,9 +122,4 @@ TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model,
  return tensorflow::wrap(result);
}

TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
  return new TF_ConcreteFunctionList{
      tensorflow::unwrap(model)->ListFunctions()};
}

}  // end extern "C"
@@ -100,11 +100,6 @@ TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model,
                                     const char* signature_def_key,
                                     TF_Status* status);

// Returns a list of all ConcreteFunctions stored in this SavedModel.
// The lifetime of the returned list is bound to `model`.
TF_CAPI_EXPORT extern TF_ConcreteFunctionList* TF_ListSavedModelFunctions(
    TF_SavedModel* model);

#ifdef __cplusplus
}  // end extern "C"
#endif  // __cplusplus
@@ -44,7 +44,7 @@ class DummyDevice : public DeviceBase {
  }
};

-// Helper for comparing ouput and expected output
+// Helper for comparing output and expected output
void ExpectSummaryMatches(const Summary& actual, const string& expected_str) {
  Summary expected;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(expected_str, &expected));
@@ -76,7 +76,7 @@ class Tensor {
  // unknown rank.
  int dims() const;

-  // Returns the number of elements in in demension `d`.
+  // Returns the number of elements in dimension `d`.
  // REQUIRES: `0 <= d < dims()`
  int64_t dim_size(int d) const;
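As a small illustration of these accessors, the hypothetical helper below computes the total element count of a tensor. The `tensorflow::experimental::cc` namespace is taken from the surrounding headers in this diff; the header path and the helper itself are assumptions, not part of the API.

```cpp
#include <cstdint>

#include "tensorflow/cc/experimental/base/public/tensor.h"

// Hypothetical helper: total number of elements, or -1 if the rank is unknown.
int64_t NumElements(const tensorflow::experimental::cc::Tensor& tensor) {
  if (tensor.dims() == -1) return -1;  // unknown rank
  int64_t count = 1;
  for (int d = 0; d < tensor.dims(); ++d) {
    count *= tensor.dim_size(d);  // valid because 0 <= d < dims()
  }
  return count;
}
```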
@@ -154,7 +154,7 @@ inline Tensor Tensor::FromBuffer(TF_DataType dtype,
  // 1. Only a function pointer is sent across the C API (&DeleterFunction)
  // 2. DeleterFunction is defined in the same build artifact that constructed
  //    the std::function (so there isn't confusion about std::function ABI).
-  // Note that 2. is satisifed by the fact that this is a header-only API, where
+  // Note that 2. is satisfied by the fact that this is a header-only API, where
  // the function implementations are inline.

  DeleterStruct* deleter_struct = new DeleterStruct{deleter};
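For context, a sketch of how a caller might hand an externally owned buffer to `Tensor::FromBuffer` with a custom deleter. Only the `dtype` parameter and the deleter are confirmed by this hunk; the rest of the parameter list (shape, data pointer, length, status) and the Status header path are assumptions.

```cpp
#include <cstdint>
#include <cstdlib>
#include <vector>

#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"

namespace cc = tensorflow::experimental::cc;

// Wraps a malloc'd float buffer in a Tensor; the deleter runs when the last
// reference to the underlying TF_Tensor is released.
cc::Tensor WrapBuffer(float* data, size_t num_floats, cc::Status* status) {
  std::vector<int64_t> shape = {static_cast<int64_t>(num_floats)};
  return cc::Tensor::FromBuffer(
      TF_FLOAT, shape, data, num_floats * sizeof(float),
      /*deleter=*/[](void* ptr, size_t /*len*/) { std::free(ptr); }, status);
}
```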
@@ -67,7 +67,7 @@ bool IsZero(const Scope& scope, const Output& grad) {
//   mat: A 2-D tensor of dimension [D0, D1]
//
// Returns:
-//   A tensor of dimension [D0, D1], the result fo vec * mat.
+//   A tensor of dimension [D0, D1], the result for vec * mat.
Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) {
  auto reshaped = ExpandDims(scope, vec, -1);
  return Multiply(scope, reshaped, mat);
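To make the broadcasting behaviour concrete, here is an illustrative sketch; `BroadcastMul` itself is an internal helper in the gradient code, so the surrounding Scope setup and constants are assumptions for demonstration only.

```cpp
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

using namespace tensorflow;  // for Scope, Output, ops::*

void Example() {
  Scope scope = Scope::NewRootScope();
  // vec has shape [2], mat has shape [2, 3].
  auto vec = ops::Const(scope, {2.f, 3.f});
  auto mat = ops::Const(scope, {{1.f, 1.f, 1.f}, {1.f, 1.f, 1.f}});
  // ExpandDims(vec, -1) gives shape [2, 1], which broadcasts against [2, 3],
  // so row i of the result is row i of `mat` scaled by vec[i].
  auto reshaped = ops::ExpandDims(scope, vec, -1);
  auto scaled = ops::Multiply(scope, reshaped, mat);
  (void)scaled;
}
```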
@@ -84,9 +84,6 @@ class SavedModelAPI {
  SignatureDefFunction* GetSignatureDefFunction(
      const std::string& function_path, Status* status);

  // Lists all Conrete Functions available from the SavedModel.
  std::vector<ConcreteFunction*> ListFunctions();

  // SavedModelAPI is movable, but not copyable.
  SavedModelAPI(SavedModelAPI&&) = default;
  SavedModelAPI& operator=(SavedModelAPI&&) = default;
@@ -151,11 +148,6 @@ inline SignatureDefFunction* SavedModelAPI::GetSignatureDefFunction(
  return SignatureDefFunction::wrap(function);
}

inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
  ConcreteFunctionList list(TF_ListSavedModelFunctions(saved_model_.get()));
  return list.ToVector();
}

}  // namespace cc
}  // namespace experimental
}  // namespace tensorflow
@@ -138,7 +138,7 @@ class FreezeTest : public ::testing::Test {
    }

    TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
-    // "c" isnt dependent on the variable, so nothing should be frozen.
+    // "c" isn't dependent on the variable, so nothing should be frozen.
    TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
        graph_def, {"c:0"}, "assign", &saved_model_bundle));
@@ -183,7 +183,7 @@ class FreezeTest : public ::testing::Test {
    }
    Output c = ops::Mul(scope.WithOpName("c"), a, read_var);
    TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
-    // "c" isnt dependent on the variable, so nothing should be frozen.
+    // "c" isn't dependent on the variable, so nothing should be frozen.
    TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
        graph_def, {"c:0"}, "assign", &saved_model_bundle));
@@ -244,7 +244,7 @@ class FreezeTest : public ::testing::Test {

    Output c = ops::Mul(scope.WithOpName("c"), a, read_var);
    TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
-    // "c" isnt dependent on the variable, so nothing should be frozen.
+    // "c" isn't dependent on the variable, so nothing should be frozen.
    TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
        graph_def, {"c:0"}, "assign", &saved_model_bundle));
@@ -173,7 +173,8 @@ Status XlaCompilationCache::BuildExecutable(
  build_options.set_result_layout(result.xla_output_shape);
  build_options.set_device_allocator(options.device_allocator);
  build_options.set_alias_passthrough_params(options.alias_passthrough_params);

  build_options.mutable_debug_options()->set_xla_detailed_logging(
      options.detailed_logging);
  TF_ASSIGN_OR_RETURN(
      auto executables,
      client_->Compile(*result.computation, argument_layouts, build_options));
@@ -132,7 +132,8 @@ Status XlaCompileOnDemandOp::Compile(
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
      platform_info_,
      /*has_ref_vars=*/true, &tf_allocator_adapter);

  // No detailed logging from on demand op.
  options.detailed_logging = false;
  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;
  // Optimization: where possible, have the computation return a naked array
@@ -398,7 +398,7 @@ static void ShowXlaDeviceDeprecationWarning(
  absl::call_once(once, [] {
    LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
                 "removed in subsequent releases. Instead, use either "
-                "@tf.function(experimental_compile=True) for must-compile "
+                "@tf.function(jit_compile=True) for must-compile "
                 "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
                 "for auto-clustering best-effort compilation.";
  });
@ -568,22 +568,6 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lhlo_legalize_to_llvm",
|
||||
srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc"],
|
||||
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"],
|
||||
deps = [
|
||||
":lhlo",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LLVMDialect",
|
||||
"@llvm-project//mlir:LLVMTransforms",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "legalize_to_linalg",
|
||||
srcs = ["lib/Dialect/mhlo/transforms/legalize_to_linalg.cc"],
|
||||
@ -952,7 +936,6 @@ cc_library(
|
||||
srcs = [
|
||||
"include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
|
||||
"lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc",
|
||||
"lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc",
|
||||
"lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc",
|
||||
"lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc",
|
||||
"lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc",
|
||||
@ -962,7 +945,6 @@ cc_library(
|
||||
":chlo_legalize_to_hlo", # build-cleaner: keep
|
||||
":hlo",
|
||||
":lhlo",
|
||||
":lhlo_legalize_to_llvm", # build-cleaner: keep
|
||||
":materialize_broadcasts", # build-cleaner: keep
|
||||
":pass_details",
|
||||
":unfuse_batch_norm", # build-cleaner: keep
|
||||
|
@ -314,169 +314,6 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StaticMemRefCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
|
||||
let summary = [{
|
||||
modifies the offset, sizes and strides of a statically shaped memref
|
||||
}];
|
||||
let description = [{
|
||||
Casts the statically shaped memref operand to a memref with optionally
|
||||
modified offsets, sizes and strides.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
%buf_transformed =
|
||||
lmhlo.static_memref_cast %buf
|
||||
: memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>
|
||||
|
||||
// The result of the op is a rank-1 memref with `[5]` shape, stride 1 and
|
||||
// offset 2.
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
|
||||
let results = (outs Res<LHLO_Buffer, "", []>:$result);
|
||||
|
||||
let builders = [
|
||||
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$operand),
|
||||
[{
|
||||
$_state.addOperands(operand);
|
||||
$_state.types.push_back(resultType);
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
|
||||
}];
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DynamicMemRefCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
|
||||
[SameVariadicOperandSize, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
|
||||
let summary = "dynamic memref cast operation";
|
||||
let description = [{
|
||||
Change sizes and strides of a memref using the values computed in runtime.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
%buf_transformed =
|
||||
lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
|
||||
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
|
||||
// The result of the op is a type-erased memref with `[%size_X, %size_Y]`
|
||||
// shape and `[%step_X, %step_Y]` strides. The offset will be inherited
|
||||
// from the input.
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", []>:$operand,
|
||||
Variadic<Index>:$sizes,
|
||||
Variadic<Index>:$strides
|
||||
);
|
||||
let results = (outs Res<LHLO_Buffer, "", []>:$result);
|
||||
|
||||
let builders = [
|
||||
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$operand,
|
||||
"ValueRange":$sizes, "ValueRange":$strides),
|
||||
[{
|
||||
$_state.addOperands(operand);
|
||||
$_state.addOperands(sizes);
|
||||
$_state.addOperands(strides);
|
||||
$_state.types.push_back(resultType);
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
|
||||
}];
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
let assemblyFormat = [{
|
||||
$operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->`
|
||||
type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReshapeMemRefCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
|
||||
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
|
||||
NoSideEffect]> {
|
||||
let summary = "reshape memref cast operation";
|
||||
let description = [{
|
||||
The `reshape_memref_cast` operation converts a memref from one type to an
|
||||
equivalent type with a provided shape. The data is never copied or moved.
|
||||
The source and destination types are compatible if both have the same
|
||||
element type, address space and identity layout map. The following
|
||||
combinations are possible:
|
||||
|
||||
a. Both are ranked memref types.
|
||||
|
||||
```mlir
|
||||
// Reshape statically-shaped memref.
|
||||
%dst = reshape_memref_cast %src(%shape)
|
||||
: (memref<4x1xf32>, memref<1xi32>) to memref<4xf32>
|
||||
%dst0 = reshape_memref_cast %src(%shape0)
|
||||
: (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32>
|
||||
```
|
||||
|
||||
b. Source type is ranked, destination type is unranked.
|
||||
|
||||
```mlir
|
||||
// Reshape dynamically-shaped 1D memref.
|
||||
%dst = reshape_memref_cast %src(%shape)
|
||||
: (memref<?xf32>, memref<?xi32>) to memref<*xf32>
|
||||
```
|
||||
|
||||
c. Source type is unranked, destination type is ranked.
|
||||
|
||||
```mlir
|
||||
// Flatten unranked memref.
|
||||
%dst = reshape_memref_cast %src(%shape)
|
||||
: (memref<*xf32>, memref<1xi32>) to memref<?xf32>
|
||||
```
|
||||
|
||||
d. Both are unranked memref types.
|
||||
|
||||
```mlir
|
||||
// Reshape unranked memref.
|
||||
%dst = reshape_memref_cast %src(%shape)
|
||||
: (memref<*xf32>, memref<?xi32>) to memref<*xf32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyRankedOrUnrankedMemRef:$operand,
|
||||
LHLO_ExtentBuffer:$shape
|
||||
);
|
||||
let results = (outs AnyRankedOrUnrankedMemRef:$result);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
BaseMemRefType getType() {
|
||||
return getResult().getType().cast<BaseMemRefType>(); }
|
||||
}];
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
let assemblyFormat = [{
|
||||
$operand `(` $shape `)` attr-dict `:` `(` type($operand) `,` type($shape)
|
||||
`)` `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LMHLO Other op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -46,12 +46,6 @@ def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> {
|
||||
}
|
||||
|
||||
|
||||
def TestLhloToLLVMPass : Pass<"test-lhlo-legalize-to-llvm", "FuncOp"> {
|
||||
let summary = "Legalize from LHLO dialect to LLVM.";
|
||||
let constructor = "createTestLhloToLLVMPass()";
|
||||
}
|
||||
|
||||
|
||||
def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> {
|
||||
let summary = "Legalize from LHLO dialect to parallel loops.";
|
||||
let constructor = "createLegalizeLhloToParallelLoopsPass()";
|
||||
|
@@ -48,12 +48,8 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();

/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
-/// buffers if necessary. If `results_escape_functions` is set to true,
-/// allocated buffers for function results will be returned and escape the
-/// function. Otherwise, the signature is rewritten with extra arguments for the
-/// buffers that are to be used for results.
-std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
-    bool results_escape_functions = false);
+/// buffers if necessary.
+std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass();

// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
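A minimal sketch of how the simplified entry point might be used, assuming the usual MLIR pass-manager setup; the header path and the wrapper function shown are illustrative and not taken from this change.

```cpp
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"  // assumed header path
#include "mlir/Pass/PassManager.h"

// After this change createLegalizeToLhloPass() takes no options; the
// results-escape-function flag has been removed and there is a single
// lowering behaviour for function results.
void AddHloToLhloLowering(mlir::PassManager& pm) {
  pm.addPass(mlir::mhlo::createLegalizeToLhloPass());
}
```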
@ -35,8 +35,6 @@ inline void registerAllMhloPasses() { registerMHLOPasses(); }
|
||||
|
||||
namespace lmhlo {
|
||||
|
||||
std::unique_ptr<Pass> createTestLhloToLLVMPass();
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"
|
||||
|
||||
|
@ -24,8 +24,6 @@ limitations under the License.
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
class LLVMTypeConverter;
|
||||
class LowerToLLVMOptions;
|
||||
class OwningRewritePatternList;
|
||||
|
||||
// Populates a collection of rewrite patterns to realize element-wise operations
|
||||
@ -94,14 +92,6 @@ void PopulateTrigonometricToApproximationPatterns(
|
||||
|
||||
} // namespace mhlo
|
||||
|
||||
namespace lmhlo {
|
||||
|
||||
/// Collect a set of patterns to convert from the LHLO dialect to LLVM.
|
||||
void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
} // namespace lmhlo
|
||||
|
||||
namespace chlo {
|
||||
|
||||
// Populates a collection of conversion patterns for legalizing client-HLO to
|
||||
|
@ -2173,10 +2173,21 @@ LogicalResult SliceOp::inferReturnTypes(
|
||||
return success();
|
||||
}
|
||||
|
||||
int64_t rank = ranked_ty.getRank();
|
||||
ShapedType attr_ty = slice.start_indices().getType();
|
||||
if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank ||
|
||||
!attr_ty.getElementType().isSignlessInteger(64) ||
|
||||
if (attr_ty.getRank() != 1) {
|
||||
return emitOptionalError(location, "start_indices has rank ",
|
||||
attr_ty.getRank(), " instead of required rank 1");
|
||||
}
|
||||
|
||||
int64_t rank = ranked_ty.getRank();
|
||||
if (attr_ty.getNumElements() != rank) {
|
||||
return emitOptionalError(
|
||||
location, "the number of elements in start_indices (",
|
||||
attr_ty.getNumElements(), ") does not match the rank of the operand (",
|
||||
rank, ")");
|
||||
}
|
||||
|
||||
if (!attr_ty.getElementType().isSignlessInteger(64) ||
|
||||
slice.limit_indices().getType() != attr_ty ||
|
||||
slice.strides().getType() != attr_ty) {
|
||||
// Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp
|
||||
|
@ -88,76 +88,6 @@ void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||
results.insert<EraseConstOp>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StaticMemRefCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Value StaticMemRefCastOp::getViewSource() { return *getODSOperands(0).begin(); }
|
||||
|
||||
static LogicalResult Verify(StaticMemRefCastOp op) {
|
||||
if (!op.operand().getType().cast<ShapedType>().hasStaticShape())
|
||||
return op.emitOpError("operand must have static shape");
|
||||
if (!op.getType().hasStaticShape())
|
||||
return op.emitOpError("result must have static shape");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DynamicMemRefCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Value DynamicMemRefCastOp::getViewSource() {
|
||||
return *getODSOperands(0).begin();
|
||||
}
|
||||
|
||||
static LogicalResult Verify(DynamicMemRefCastOp op) {
|
||||
// Check if `sizes` and `strides` args are compatible with the result type.
|
||||
if (op.sizes().size() != op.getType().getRank())
|
||||
return op.emitOpError(
|
||||
"`sizes` args count must be equal to the rank of the output memref");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReshapeMemrefCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Value ReshapeMemRefCastOp::getViewSource() { return operand(); }
|
||||
|
||||
static LogicalResult Verify(ReshapeMemRefCastOp op) {
|
||||
Type operandType = op.operand().getType();
|
||||
Type resultType = op.result().getType();
|
||||
|
||||
Type operandElementType = operandType.cast<ShapedType>().getElementType();
|
||||
Type resultElementType = resultType.cast<ShapedType>().getElementType();
|
||||
if (operandElementType != resultElementType)
|
||||
return op.emitOpError(
|
||||
"element types of source and destination memref "
|
||||
"types should be the same");
|
||||
|
||||
if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
|
||||
if (!operandMemRefType.getAffineMaps().empty())
|
||||
return op.emitOpError(
|
||||
"operand memref type should have identity affine map");
|
||||
|
||||
int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
|
||||
auto resultMemRefType = resultType.dyn_cast<MemRefType>();
|
||||
if (resultMemRefType) {
|
||||
if (shapeSize == ShapedType::kDynamicSize)
|
||||
return op.emitOpError(
|
||||
"cannot use shape operand with dynamic length to "
|
||||
"cast statically-ranked memref type");
|
||||
if (shapeSize != resultMemRefType.getRank())
|
||||
return op.emitOpError(
|
||||
"length of shape operand differs from the result's memref rank");
|
||||
|
||||
if (!resultMemRefType.getAffineMaps().empty())
|
||||
return op.emitOpError(
|
||||
"result memref type should have identity affine map");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -67,6 +67,7 @@ add_mlir_library(MhloPasses
|
||||
DEPENDS
|
||||
MLIRhlo_opsIncGen
|
||||
MLIRMhloLowerComplexIncGen
|
||||
MLIRMhloPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
@ -133,8 +134,6 @@ add_mlir_library(LmhloPasses
|
||||
lhlo_fuse_linalg.cc
|
||||
lhlo_legalize_to_affine.cc
|
||||
lhlo_legalize_to_gpu.cc
|
||||
lhlo_legalize_to_llvm.cc
|
||||
lhlo_legalize_to_llvm_pass.cc
|
||||
lhlo_legalize_to_parallel_loops.cc
|
||||
|
||||
DEPENDS
|
||||
|
@ -206,7 +206,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||
// Inserts dynamic memref to change the layout of the memref to put 0-stride
|
||||
// and size of the target dimension if size-1 dimension expansion is
|
||||
// necessary.
|
||||
lmhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
|
||||
MemRefReinterpretCastOp InsertDynamicMemrefCastOp(
|
||||
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
|
||||
auto loc = op.getLoc();
|
||||
auto operand_type = operand.getType().cast<MemRefType>();
|
||||
@ -259,8 +259,13 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||
makeStridedLinearLayoutMap(dynamic_layout,
|
||||
/*offset=*/0, b->getContext()));
|
||||
|
||||
auto transformed_operand = b->create<lmhlo::DynamicMemRefCastOp>(
|
||||
loc, type_erased_memref_type, operand, sizes, strides);
|
||||
SmallVector<int64_t, 2> static_sizes(sizes.size(),
|
||||
ShapedType::kDynamicSize);
|
||||
SmallVector<int64_t, 2> static_strides(strides.size(),
|
||||
ShapedType::kDynamicStrideOrOffset);
|
||||
auto transformed_operand = b->create<MemRefReinterpretCastOp>(
|
||||
loc, type_erased_memref_type, operand, /*offset=*/0, static_sizes,
|
||||
static_strides, llvm::None, sizes, strides);
|
||||
return transformed_operand;
|
||||
}
|
||||
};
|
||||
@ -284,7 +289,7 @@ struct HloToLhloDynamicReshapeConverter
|
||||
return failure();
|
||||
}
|
||||
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<lmhlo::ReshapeMemRefCastOp>(
|
||||
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
|
||||
op, result_type, adaptor.operand(), adaptor.output_shape());
|
||||
return success();
|
||||
}
|
||||
@ -504,12 +509,7 @@ struct HloLegalizeToLhlo
|
||||
|
||||
public:
|
||||
HloLegalizeToLhlo() = default;
|
||||
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
|
||||
this->results_escape_function = o.results_escape_function.getValue();
|
||||
}
|
||||
explicit HloLegalizeToLhlo(bool results_escape_function) {
|
||||
this->results_escape_function.setValue(results_escape_function);
|
||||
}
|
||||
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
OwningRewritePatternList patterns;
|
||||
@ -542,13 +542,6 @@ struct HloLegalizeToLhlo
|
||||
isMemRefType);
|
||||
});
|
||||
|
||||
auto kind = results_escape_function
|
||||
? BufferizeTypeConverter::KeepAsFunctionResult
|
||||
: BufferizeTypeConverter::AppendToArgumentsList;
|
||||
converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
|
||||
kind);
|
||||
converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
|
||||
|
||||
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
|
||||
populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
|
||||
lmhlo::CopyOp>(
|
||||
@ -559,13 +552,6 @@ struct HloLegalizeToLhlo
|
||||
std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
private:
|
||||
Option<bool> results_escape_function{
|
||||
*this, "results-escape-function",
|
||||
llvm::cl::desc(
|
||||
"Allocate the results of functions within the functions body"),
|
||||
llvm::cl::init(false)};
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@ -625,9 +611,8 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
|
||||
bool results_escape_function) {
|
||||
return std::make_unique<HloLegalizeToLhlo>(results_escape_function);
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
|
||||
return std::make_unique<HloLegalizeToLhlo>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
|
@ -1,370 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
struct StaticMemRefCastOpConverter
|
||||
: public ConvertOpToLLVMPattern<StaticMemRefCastOp> {
|
||||
using ConvertOpToLLVMPattern<StaticMemRefCastOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto cast_op = cast<StaticMemRefCastOp>(op);
|
||||
|
||||
StaticMemRefCastOp::Adaptor operands_adaptor(operands);
|
||||
MemRefDescriptor sourceMemRef(operands_adaptor.operand());
|
||||
|
||||
MemRefType targetMemRefType =
|
||||
cast_op.getResult().getType().cast<MemRefType>();
|
||||
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
|
||||
.dyn_cast_or_null<LLVM::LLVMType>();
|
||||
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
|
||||
return failure();
|
||||
// Create descriptor.
|
||||
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
|
||||
Type llvmTargetElementTy = desc.getElementPtrType();
|
||||
// Set allocated ptr.
|
||||
Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
|
||||
allocated =
|
||||
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
|
||||
desc.setAllocatedPtr(rewriter, loc, allocated);
|
||||
// Set aligned ptr.
|
||||
Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
|
||||
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
|
||||
desc.setAlignedPtr(rewriter, loc, ptr);
|
||||
|
||||
// Fill size and stride descriptors in memref.
|
||||
auto target_sizes = targetMemRefType.getShape();
|
||||
int64_t target_offset;
|
||||
llvm::SmallVector<int64_t, 4> target_strides;
|
||||
if (failed((getStridesAndOffset(targetMemRefType, target_strides,
|
||||
target_offset))))
|
||||
return failure();
|
||||
|
||||
// Copy offset of `targetMemRef`.
|
||||
desc.setConstantOffset(rewriter, loc, target_offset);
|
||||
for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
|
||||
desc.setConstantSize(rewriter, loc, i, target_sizes[i]);
|
||||
desc.setConstantStride(rewriter, loc, i, target_strides[i]);
|
||||
}
|
||||
rewriter.replaceOp(op, {desc});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct DynamicMemRefCastOpConverter
|
||||
: public ConvertOpToLLVMPattern<DynamicMemRefCastOp> {
|
||||
using ConvertOpToLLVMPattern<DynamicMemRefCastOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto cast_op = cast<DynamicMemRefCastOp>(op);
|
||||
|
||||
DynamicMemRefCastOp::Adaptor operands_adaptor(operands);
|
||||
MemRefDescriptor sourceMemRef(operands_adaptor.operand());
|
||||
|
||||
MemRefType targetMemRefType =
|
||||
cast_op.getResult().getType().cast<MemRefType>();
|
||||
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
|
||||
.dyn_cast_or_null<LLVM::LLVMType>();
|
||||
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
|
||||
return failure();
|
||||
// Create descriptor.
|
||||
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
|
||||
Type llvmTargetElementTy = desc.getElementPtrType();
|
||||
// Set allocated ptr.
|
||||
Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
|
||||
allocated =
|
||||
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
|
||||
desc.setAllocatedPtr(rewriter, loc, allocated);
|
||||
// Set aligned ptr.
|
||||
Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
|
||||
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
|
||||
desc.setAlignedPtr(rewriter, loc, ptr);
|
||||
// Copy offset of `sourceMemRef`.
|
||||
desc.setOffset(rewriter, loc, sourceMemRef.offset(rewriter, loc));
|
||||
|
||||
// Fill size and stride descriptors in memref.
|
||||
if (!cast_op.sizes().empty()) {
|
||||
auto sizes = operands_adaptor.sizes();
|
||||
auto strides = operands_adaptor.strides();
|
||||
for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
|
||||
desc.setSize(rewriter, loc, i, sizes[i]);
|
||||
desc.setStride(rewriter, loc, i, strides[i]);
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(op, {desc});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ReshapeMemRefCastOpConverter
|
||||
: public ConvertOpToLLVMPattern<ReshapeMemRefCastOp> {
|
||||
using ConvertOpToLLVMPattern<ReshapeMemRefCastOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
|
||||
auto reshape_op = cast<ReshapeMemRefCastOp>(op);
|
||||
auto dst_type = reshape_op.getResult().getType().cast<BaseMemRefType>();
|
||||
auto element_type = dst_type.getElementType();
|
||||
|
||||
auto shape = reshape_op.shape();
|
||||
|
||||
ReshapeMemRefCastOp::Adaptor operands_adaptor(operands);
|
||||
PtrsAndOffset ptrs_n_offset = ExtractMemRefPtrsAndOffset(
|
||||
loc, reshape_op.operand(), operands_adaptor.operand(), &rewriter);
|
||||
|
||||
MemRefDescriptor shape_desc(operands_adaptor.shape());
|
||||
|
||||
auto shape_memref_type = shape.getType().cast<MemRefType>();
|
||||
|
||||
if (shape_memref_type.hasStaticShape()) {
|
||||
auto shape_length = shape_memref_type.getDimSize(0);
|
||||
|
||||
MemRefType targetMemRefType = MemRefType::get(
|
||||
SmallVector<int64_t, 1>(shape_length, 1), element_type);
|
||||
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
|
||||
.dyn_cast_or_null<LLVM::LLVMType>();
|
||||
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
|
||||
return failure();
|
||||
// Create descriptor.
|
||||
auto desc =
|
||||
MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
|
||||
desc.setAllocatedPtr(rewriter, loc, ptrs_n_offset.allocated_ptr);
|
||||
desc.setAlignedPtr(rewriter, loc, ptrs_n_offset.aligned_ptr);
|
||||
desc.setOffset(rewriter, loc, ptrs_n_offset.offset);
|
||||
|
||||
auto llvm_index_type = typeConverter.getIndexType();
|
||||
auto llvm_index_ptr_type = llvm_index_type.getPointerTo();
|
||||
Value stride_carried = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvm_index_type,
|
||||
rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
|
||||
for (int i = shape_length - 1; i >= 0; --i) {
|
||||
Value pos = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvm_index_type,
|
||||
rewriter.getIntegerAttr(rewriter.getIndexType(), i));
|
||||
Value ptr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, llvm_index_ptr_type, shape_desc.alignedPtr(rewriter, loc),
|
||||
ValueRange{pos});
|
||||
Value extracted_size = rewriter.create<LLVM::LoadOp>(loc, ptr);
|
||||
desc.setSize(rewriter, loc, i, extracted_size);
|
||||
desc.setStride(rewriter, loc, i, stride_carried);
|
||||
// Update stride
|
||||
if (i > 0) {
|
||||
stride_carried =
|
||||
rewriter.create<LLVM::MulOp>(loc, stride_carried, extracted_size);
|
||||
}
|
||||
}
|
||||
if (dst_type.isa<MemRefType>()) {
|
||||
rewriter.replaceOp(op, {desc});
|
||||
} else {
|
||||
Value rank = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvm_index_type,
|
||||
rewriter.getIntegerAttr(rewriter.getIndexType(), shape_length));
|
||||
Value alloca =
|
||||
typeConverter.promoteOneMemRefDescriptor(loc, desc, rewriter);
|
||||
Value void_ptr =
|
||||
rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), alloca);
|
||||
auto unranked_desc = UnrankedMemRefDescriptor::pack(
|
||||
rewriter, loc, typeConverter, dst_type.cast<UnrankedMemRefType>(),
|
||||
{rank, void_ptr});
|
||||
rewriter.replaceOp(op, {unranked_desc});
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// The shape is a rank-1 tensor with unknown length.
|
||||
Value result_rank = shape_desc.size(rewriter, loc, 0);
|
||||
// TODO(herhut): Propely handle address spaces.
|
||||
unsigned address_space = 0;
|
||||
auto target_type =
|
||||
typeConverter
|
||||
.convertType(UnrankedMemRefType::get(element_type, address_space))
|
||||
.cast<LLVM::LLVMType>();
|
||||
// Create the unranked memref descriptor that holds the ranked one. The
|
||||
// inner descriptor is allocated on stack.
|
||||
UnrankedMemRefDescriptor target_desc =
|
||||
UnrankedMemRefDescriptor::undef(rewriter, loc, target_type);
|
||||
target_desc.setRank(rewriter, loc, result_rank);
|
||||
SmallVector<Value, 1> sizes;
|
||||
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter,
|
||||
{target_desc}, sizes);
|
||||
auto void_ptr_type = LLVM::LLVMType::getInt8PtrTy(rewriter.getContext());
|
||||
Value ranked_desc_mem = rewriter.create<LLVM::AllocaOp>(
|
||||
loc, void_ptr_type, sizes.front(), llvm::None);
|
||||
target_desc.setMemRefDescPtr(rewriter, loc, ranked_desc_mem);
|
||||
|
||||
// Fill the fixed parts. For this, we cast to a 0-D memref.
|
||||
auto zero_d_memref_type = MemRefType::get({}, element_type);
|
||||
Value as_zero_d = rewriter.create<LLVM::BitcastOp>(
|
||||
loc,
|
||||
typeConverter.convertType(zero_d_memref_type)
|
||||
.cast<LLVM::LLVMType>()
|
||||
.getPointerTo(address_space),
|
||||
ranked_desc_mem);
|
||||
// Some common constants. Use 32 bit where required by gep struct indexes.
|
||||
auto int32_type = typeConverter.convertType(rewriter.getI32Type());
|
||||
Value zero_index = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, typeConverter.getIndexType(), rewriter.getIndexAttr(0));
|
||||
Value zero = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32_type, rewriter.getI32IntegerAttr(0));
|
||||
Value one = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32_type, rewriter.getI32IntegerAttr(1));
|
||||
Value two = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32_type, rewriter.getI32IntegerAttr(2));
|
||||
// Set base_pointer and aligned pointer.
|
||||
auto element_ptr_ptr_type = typeConverter.convertType(element_type)
|
||||
.cast<LLVM::LLVMType>()
|
||||
.getPointerTo(address_space)
|
||||
.getPointerTo(address_space);
|
||||
auto base_gep = rewriter.create<LLVM::GEPOp>(
|
||||
loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, zero}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.allocated_ptr, base_gep);
|
||||
auto aligned_gep = rewriter.create<LLVM::GEPOp>(
|
||||
loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, one}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.aligned_ptr, aligned_gep);
|
||||
// Set offset.
|
||||
auto index_ptr_type =
|
||||
typeConverter.getIndexType().getPointerTo(address_space);
|
||||
auto offset_gep = rewriter.create<LLVM::GEPOp>(
|
||||
loc, index_ptr_type, as_zero_d, ValueRange({zero_index, two}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.offset, offset_gep);
|
||||
|
||||
// Use the offset pointer as base for further addressing. Copy over the
|
||||
// new shape and compute strides. For this, we need to create a loop from
|
||||
// rank - 1 to 0.
|
||||
Value one_index = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, typeConverter.getIndexType(), rewriter.getIndexAttr(1));
|
||||
auto target_shape_base = rewriter.create<LLVM::GEPOp>(
|
||||
loc, index_ptr_type, offset_gep, ValueRange({one}));
|
||||
auto target_strides_base = rewriter.create<LLVM::GEPOp>(
|
||||
loc, index_ptr_type, target_shape_base, ValueRange({result_rank}));
|
||||
auto shape_ptr = shape_desc.alignedPtr(rewriter, loc);
|
||||
auto result_rank_minus_one =
|
||||
rewriter.create<LLVM::SubOp>(loc, result_rank, one_index);
|
||||
|
||||
Block *init_block = rewriter.getInsertionBlock();
|
||||
Block *cond_block =
|
||||
rewriter.splitBlock(init_block, rewriter.getInsertionPoint());
|
||||
rewriter.setInsertionPointToEnd(init_block);
|
||||
rewriter.create<LLVM::BrOp>(
|
||||
loc, ValueRange({result_rank_minus_one, one_index}), cond_block);
|
||||
rewriter.setInsertionPointToStart(cond_block);
|
||||
auto index_arg = cond_block->addArgument(typeConverter.getIndexType());
|
||||
auto stride_arg = cond_block->addArgument(typeConverter.getIndexType());
|
||||
auto pred = rewriter.create<LLVM::ICmpOp>(
|
||||
loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()),
|
||||
LLVM::ICmpPredicate::sge, index_arg, zero_index);
|
||||
|
||||
Block *body_block =
|
||||
rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
|
||||
rewriter.setInsertionPointToStart(body_block);
|
||||
|
||||
// Copy size from shape to descriptor.
|
||||
auto size_load_gep = rewriter.create<LLVM::GEPOp>(
|
||||
loc, index_ptr_type, shape_ptr, ValueRange{index_arg});
|
||||
auto extracted_size = rewriter.create<LLVM::LoadOp>(loc, size_load_gep);
|
||||
auto size_store_gep = rewriter.create<LLVM::GEPOp>(
|
||||
loc, index_ptr_type, target_shape_base, ValueRange({index_arg}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, extracted_size, size_store_gep);
|
||||
// Write stride value and compute next one.
|
||||
auto stride_store_gep = rewriter.create<LLVM::GEPOp>(
|
||||
loc, index_ptr_type, target_strides_base, ValueRange({index_arg}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, stride_arg, stride_store_gep);
|
||||
auto next_stride =
|
||||
rewriter.create<LLVM::MulOp>(loc, stride_arg, extracted_size);
|
||||
|
||||
// Decrement loop counter and branch back.
|
||||
auto decrement = rewriter.create<LLVM::SubOp>(loc, index_arg, one_index);
|
||||
rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, next_stride}),
|
||||
cond_block);
|
||||
|
||||
Block *remainder =
|
||||
rewriter.splitBlock(body_block, rewriter.getInsertionPoint());
|
||||
|
||||
// Hook up the cond exit to the remainder.
|
||||
rewriter.setInsertionPointToEnd(cond_block);
|
||||
rewriter.create<LLVM::CondBrOp>(loc, pred, body_block, ValueRange(),
|
||||
remainder, ValueRange());
|
||||
|
||||
// Reset position to beginning of new remainder block.
|
||||
rewriter.setInsertionPointToStart(remainder);
|
||||
rewriter.replaceOp(op, {target_desc});
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
struct PtrsAndOffset {
|
||||
Value allocated_ptr;
|
||||
Value aligned_ptr;
|
||||
Value offset;
|
||||
};
|
||||
|
||||
PtrsAndOffset ExtractMemRefPtrsAndOffset(
|
||||
Location loc, Value originalOperand, Value convertedOperand,
|
||||
ConversionPatternRewriter *rewriter) const {
|
||||
Type operandType = originalOperand.getType();
|
||||
Value descriptor_ptr;
|
||||
if (operandType.isa<MemRefType>()) {
|
||||
descriptor_ptr = convertedOperand;
|
||||
} else {
|
||||
UnrankedMemRefDescriptor unranked_descriptor(convertedOperand);
|
||||
Value underlying_desc_ptr =
|
||||
unranked_descriptor.memRefDescPtr(*rewriter, loc);
|
||||
|
||||
Type element_type =
|
||||
operandType.cast<UnrankedMemRefType>().getElementType();
|
||||
LLVM::LLVMType memref_type_0d =
|
||||
typeConverter.convertType(MemRefType::get(/*shape=*/{}, element_type))
|
||||
.cast<LLVM::LLVMType>();
|
||||
descriptor_ptr = rewriter->create<LLVM::BitcastOp>(
|
||||
loc, memref_type_0d.getPointerTo(), underlying_desc_ptr);
|
||||
descriptor_ptr = rewriter->create<LLVM::LoadOp>(loc, descriptor_ptr);
|
||||
}
|
||||
MemRefDescriptor descriptor(descriptor_ptr);
|
||||
PtrsAndOffset result;
|
||||
result.allocated_ptr = descriptor.allocatedPtr(*rewriter, loc);
|
||||
result.aligned_ptr = descriptor.alignedPtr(*rewriter, loc);
|
||||
result.offset = descriptor.offset(*rewriter, loc);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter,
|
||||
OwningRewritePatternList *patterns) {
|
||||
patterns->insert<DynamicMemRefCastOpConverter, ReshapeMemRefCastOpConverter,
|
||||
StaticMemRefCastOpConverter>(*converter);
|
||||
}
|
||||
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
@ -1,63 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
class TestLhloToLLVMPass
|
||||
: public ::mlir::PassWrapper<TestLhloToLLVMPass,
|
||||
::mlir::OperationPass<::mlir::ModuleOp>> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<LLVM::LLVMDialect>();
|
||||
}
|
||||
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
ModuleOp m = getOperation();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
LLVMTypeConverter converter(&getContext());
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
PopulateLhloToLLVMConversionPatterns(&converter, &patterns);
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target.addIllegalDialect<LmhloDialect>();
|
||||
|
||||
if (failed(applyFullConversion(m, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createTestLhloToLLVMPass() {
|
||||
return std::make_unique<TestLhloToLLVMPass>();
|
||||
}
|
||||
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s
|
||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @func_op_unranked_arg_result
|
||||
func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
@ -17,7 +17,7 @@ func @dynamic_reshape_from_unranked(
|
||||
return %reshaped : tensor<?xf32>
|
||||
}
|
||||
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
|
||||
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
|
||||
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
|
||||
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
|
||||
|
||||
// -----
|
||||
@ -30,5 +30,5 @@ func @dynamic_reshape_to_unranked(
|
||||
return %reshaped : tensor<*xf32>
|
||||
}
|
||||
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
|
||||
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
|
||||
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
|
||||
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
|
||||
|
@ -1,13 +1,12 @@
|
||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s
|
||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s
|
||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck %s
|
||||
|
||||
// BOTH-LABEL: func @attrs
|
||||
// CHECK-LABEL: func @attrs
|
||||
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.exponential"(%tensor_operand)
|
||||
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
|
||||
// CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -17,16 +16,13 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
||||
return %arg0 : tensor<4xf32>
|
||||
}
|
||||
// PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
|
||||
// PRE-NEXT: "lmhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
|
||||
// PRE-NEXT: return
|
||||
// ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
||||
// ESC-NOT: "lmhlo.copy"
|
||||
// ESC-NEXT: return %[[ARG0]]
|
||||
// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
||||
// CHECK-NOT: "lmhlo.copy"
|
||||
// CHECK-NEXT: return %[[ARG0]]
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @func_op_long
|
||||
// CHECK-LABEL: func @func_op_long
|
||||
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
%1 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
|
||||
%2 = mhlo.add %arg0, %1 : tensor<4xf32>
|
||||
@ -35,91 +31,87 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
|
||||
%5 = mhlo.multiply %2, %4 : tensor<4xf32>
|
||||
return %5 : tensor<4xf32>
|
||||
}
|
||||
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
|
||||
// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
|
||||
// BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
|
||||
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
|
||||
// BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
|
||||
// BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
|
||||
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
|
||||
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
|
||||
// PRE-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||
// PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
|
||||
// PRE-NEXT: return
|
||||
// ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
|
||||
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
|
||||
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
|
||||
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
|
||||
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
|
||||
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
|
||||
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
|
||||
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
|
||||
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
|
||||
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
|
||||
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
|
||||
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @fusion
|
||||
// CHECK-LABEL: func @fusion
|
||||
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
|
||||
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
// BOTH: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
|
||||
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
|
||||
// CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
|
||||
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
|
||||
%tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
|
||||
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
|
||||
%sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
|
||||
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
|
||||
// CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
|
||||
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
|
||||
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
|
||||
// BOTH-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
|
||||
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
|
||||
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
|
||||
// CHECK-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
// BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
|
||||
// BOTH-NEXT: return
|
||||
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
|
||||
// CHECK-NEXT: return
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @copy
|
||||
// CHECK-LABEL: func @copy
|
||||
func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.copy"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @exp
|
||||
// CHECK-LABEL: func @exp
|
||||
func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.exponential"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @log
|
||||
// CHECK-LABEL: func @log
|
||||
func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.log"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.log"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.log"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @select
|
||||
// CHECK-LABEL: func @select
|
||||
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
||||
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_pred = tensor_load %pred : memref<2x2xi1>
|
||||
@ -127,34 +119,34 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
|
||||
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @compare
|
||||
// CHECK-LABEL: func @compare
|
||||
func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
|
||||
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
|
||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
|
||||
{comparison_direction = "EQ"}
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||
// BOTH: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
|
||||
// CHECK: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
|
||||
tensor_store %tensor_result, %result : memref<2x2xi1>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @broadcast
|
||||
// CHECK-LABEL: func @broadcast
|
||||
func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<5xf32>
|
||||
%tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
|
||||
{broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
: (tensor<5xf32>) -> tensor<10x5xf32>
|
||||
// BOTH: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
tensor_store %tensor_result, %result : memref<10x5xf32>
|
||||
return
|
||||
}
|
||||
@ -163,56 +155,57 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
|
||||
|
||||
func @external_func() -> tensor<3xi64>
|
||||
|
||||
// BOTH: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
|
||||
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
|
||||
|
||||
// BOTH-LABEL: func @dyn_broadcast
|
||||
// CHECK-LABEL: func @dyn_broadcast
|
||||
func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
||||
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
|
||||
// CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
|
||||
%tensor_operand = tensor_load %operand : memref<?x?xf32>
|
||||
%c1 = constant 1 : i64
|
||||
%shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64>
|
||||
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
|
||||
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
// BOTH: %[[SHAPE:.*]] = tensor_from_elements
|
||||
// BOTH: %[[C0:.*]] = constant 0 : index
|
||||
// BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
|
||||
// BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
|
||||
// BOTH: %[[C1:.*]] = constant 1 : index
|
||||
// BOTH: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
|
||||
// BOTH: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
|
||||
// BOTH: %[[C2:.*]] = constant 2 : index
|
||||
// BOTH: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
|
||||
// BOTH: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
|
||||
// BOTH: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
|
||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
|
||||
// CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
|
||||
// CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
|
||||
// CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
|
||||
// CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
|
||||
|
||||
// BOTH: %[[C0_:.*]] = constant 0 : index
|
||||
// BOTH: %[[C1_:.*]] = constant 1 : index
|
||||
// CHECK: %[[C0_:.*]] = constant 0 : index
|
||||
// CHECK: %[[C1_:.*]] = constant 1 : index
|
||||
|
||||
// BOTH: %[[C1__:.*]] = constant 1 : index
|
||||
// BOTH: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
|
||||
// BOTH: %[[C0___:.*]] = constant 0 : index
|
||||
// BOTH: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
|
||||
// BOTH: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
|
||||
// BOTH: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
|
||||
// BOTH: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
|
||||
// CHECK: %[[C1__:.*]] = constant 1 : index
|
||||
// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
|
||||
// CHECK: %[[C0___:.*]] = constant 0 : index
|
||||
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
|
||||
// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
|
||||
// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
|
||||
// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
|
||||
|
||||
// BOTH: %[[C2_:.*]] = constant 2 : index
|
||||
// BOTH: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
|
||||
// BOTH: %[[C1___:.*]] = constant 1 : index
|
||||
// BOTH: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
|
||||
// BOTH: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
|
||||
// BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
|
||||
// BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
|
||||
// CHECK: %[[C2_:.*]] = constant 2 : index
|
||||
// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
|
||||
// CHECK: %[[C1___:.*]] = constant 1 : index
|
||||
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
|
||||
// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
|
||||
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
|
||||
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
|
||||
|
||||
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast
|
||||
// BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
|
||||
// BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
|
||||
// BOTH-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map>
|
||||
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to
|
||||
// CHECK-SAME: offset: [0],
|
||||
// CHECK-SAME: sizes: {{\[}}%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]]
|
||||
// CHECK-SAME: strides: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
|
||||
// CHECK-SAME: : memref<?x?xf32> to memref<?x?xf32, #map>
|
||||
|
||||
// BOTH: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
|
||||
// BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
// BOTH-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
|
||||
// CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
|
||||
// CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
// CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
|
||||
|
||||
// Do not store the value back to avoid the tensor-store being rewritten to
|
||||
// a copy into the pre-allocated argument.
|
||||
@ -221,7 +214,7 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @complex
|
||||
// CHECK-LABEL: func @complex
|
||||
func @complex(%real: memref<2x2xf32>,
|
||||
%imag: memref<2x2xf32>,
|
||||
%result: memref<2x2xcomplex<f32>>) {
|
||||
@ -229,14 +222,14 @@ func @complex(%real: memref<2x2xf32>,
|
||||
%tensor_imag = tensor_load %imag : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
|
||||
// BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @complex_dyn
|
||||
// CHECK-LABEL: func @complex_dyn
|
||||
func @complex_dyn(%real: memref<?xf32>,
|
||||
%imag: memref<?xf32>,
|
||||
%result: memref<?xcomplex<f32>>) {
|
||||
@ -244,50 +237,50 @@ func @complex_dyn(%real: memref<?xf32>,
|
||||
%tensor_imag = tensor_load %imag : memref<?xf32>
|
||||
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
|
||||
: (tensor<?xf32>, tensor<?xf32>) -> tensor<?xcomplex<f32>>
|
||||
// BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<?xcomplex<f32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @real
|
||||
// CHECK-LABEL: func @real
|
||||
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
|
||||
%tensor_result = "mhlo.real"(%tensor_operand)
|
||||
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.real"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.real"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @real_dyn
|
||||
// CHECK-LABEL: func @real_dyn
|
||||
func @real_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<?xcomplex<f32>>
|
||||
%tensor_result = "mhlo.real"(%tensor_operand)
|
||||
: (tensor<?xcomplex<f32>>) -> tensor<?xf32>
|
||||
// BOTH: "lmhlo.real"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.real"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @imag
|
||||
// CHECK-LABEL: func @imag
|
||||
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
|
||||
%tensor_result = "mhlo.imag"(%tensor_operand)
|
||||
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @gather
|
||||
// CHECK-LABEL: func @gather
|
||||
func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5x7xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<13x7xf32>
|
||||
%tensor_idxs = tensor_load %idxs : memref<5xi32>
|
||||
@ -302,176 +295,176 @@ func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5
|
||||
, name = "gather.71"
|
||||
, slice_sizes = dense<[1, 7]> : tensor<2xi64> }
|
||||
: (tensor<13x7xf32>, tensor<5xi32>) -> tensor<5x7xf32>
|
||||
// BOTH: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<5x7xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @imag_dyn
|
||||
// CHECK-LABEL: func @imag_dyn
|
||||
func @imag_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<?xcomplex<f32>>
|
||||
%tensor_result = "mhlo.imag"(%tensor_operand)
|
||||
: (tensor<?xcomplex<f32>>) -> tensor<?xf32>
|
||||
// BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @iota
|
||||
// CHECK-LABEL: func @iota
|
||||
func @iota(%result: memref<10xi32>) {
|
||||
%tensor_result = "mhlo.iota"()
|
||||
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
|
||||
// BOTH: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
|
||||
// CHECK: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
|
||||
tensor_store %tensor_result, %result : memref<10xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @abs
|
||||
// CHECK-LABEL: func @abs
|
||||
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.abs"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.abs"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.abs"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @ceil
|
||||
// CHECK-LABEL: func @ceil
|
||||
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.ceil"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.ceil"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.ceil"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @convert
|
||||
// CHECK-LABEL: func @convert
|
||||
func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.convert"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
|
||||
// BOTH-NOT: tensor_store
|
||||
// CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}})
|
||||
// CHECK-NOT: tensor_store
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @cos
|
||||
// CHECK-LABEL: func @cos
|
||||
func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.cosine"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.cosine"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.cosine"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @floor
|
||||
// CHECK-LABEL: func @floor
|
||||
func @floor(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.floor"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.floor"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.floor"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @neg
|
||||
// CHECK-LABEL: func @neg
|
||||
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.negate"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.negate"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.negate"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @not
|
||||
// CHECK-LABEL: func @not
|
||||
func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xi32>
|
||||
%tensor_result = "mhlo.not"(%tensor_operand)
|
||||
: (tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
// BOTH: "lmhlo.not"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.not"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @rsqrt
|
||||
// CHECK-LABEL: func @rsqrt
|
||||
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.rsqrt"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @sign
|
||||
// CHECK-LABEL: func @sign
|
||||
func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.sign"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.sign"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.sign"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @sqrt
|
||||
// CHECK-LABEL: func @sqrt
|
||||
func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.sqrt"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @tanh
|
||||
// CHECK-LABEL: func @tanh
|
||||
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.tanh"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.tanh"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.tanh"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @remainder
|
||||
// CHECK-LABEL: func @remainder
|
||||
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
|
||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -479,61 +472,60 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
|
||||
// -----
|
||||
|
||||
// Dynamic shape binary element-wise operation.
|
||||
// BOTH-LABEL: func @add_dyn
|
||||
// CHECK-LABEL: func @add_dyn
|
||||
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
|
||||
%result = "mhlo.add"(%lhs, %rhs)
|
||||
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// BOTH: %[[C0:.*]] = constant 0 : index
|
||||
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
||||
// BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
|
||||
// BOTH: %[[C1:.*]] = constant 1 : index
|
||||
// BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
||||
// BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
|
||||
// BOTH: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
|
||||
// BOTH: %[[C0_:.*]] = constant 0 : index
|
||||
// BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
|
||||
// BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
|
||||
// BOTH: %[[C1_:.*]] = constant 1 : index
|
||||
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
|
||||
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
||||
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
||||
// BOTH: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
||||
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
||||
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
|
||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
|
||||
// CHECK: %[[C0_:.*]] = constant 0 : index
|
||||
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
|
||||
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
|
||||
// CHECK: %[[C1_:.*]] = constant 1 : index
|
||||
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
|
||||
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
||||
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
||||
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Dynamic shape unary element-wise operation.
|
||||
// BOTH-LABEL: func @tanh_dyn
|
||||
// CHECK-LABEL: func @tanh_dyn
|
||||
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
|
||||
%result = "mhlo.tanh"(%arg0)
|
||||
: (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// BOTH: %[[C0:.*]] = constant 0 : index
|
||||
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
||||
// BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
|
||||
// BOTH: %[[C1:.*]] = constant 1 : index
|
||||
// BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
||||
// BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
|
||||
// BOTH: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
|
||||
// BOTH: %[[C0_:.*]] = constant 0 : index
|
||||
// BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
|
||||
// BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
|
||||
// BOTH: %[[C1_:.*]] = constant 1 : index
|
||||
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
|
||||
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
||||
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
||||
// BOTH: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
||||
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
||||
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
|
||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
|
||||
// CHECK: %[[C0_:.*]] = constant 0 : index
|
||||
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
|
||||
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
|
||||
// CHECK: %[[C1_:.*]] = constant 1 : index
|
||||
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
|
||||
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
||||
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
||||
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @dot
|
||||
// CHECK-LABEL: func @dot
|
||||
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
||||
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
|
||||
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
||||
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
|
||||
// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
||||
// CHECK-NEXT: %[[ALLOC:.*]] = alloc
|
||||
// CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
|
||||
// dot_dimension_numbers = {
|
||||
// lhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||
// lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
|
||||
@ -542,22 +534,21 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
||||
// : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
|
||||
%dot = "mhlo.dot"(%arg0, %arg0)
|
||||
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
|
||||
// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]])
|
||||
// ESC: return %[[ALLOC]]
|
||||
// CHECK: return %[[ALLOC]]
|
||||
return %dot : tensor<1024x1024xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @conv
|
||||
// CHECK-LABEL: func @conv
|
||||
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
|
||||
%c0 = constant 0 : index
|
||||
// BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
|
||||
// BOTH: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
|
||||
// BOTH-SAME: padding = dense<[
|
||||
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||
// BOTH-SAME: rhs_dilation = dense<[1, 2]>
|
||||
// BOTH-SAME: window_strides = dense<[2, 1]>
|
||||
// CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
|
||||
// CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
|
||||
// CHECK-SAME: padding = dense<[
|
||||
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||
// CHECK-SAME: rhs_dilation = dense<[1, 2]>
|
||||
// CHECK-SAME: window_strides = dense<[2, 1]>
|
||||
%out = "mhlo.convolution"(%filter, %input) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
@ -581,18 +572,18 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @reduce
|
||||
// CHECK-LABEL: func @reduce
|
||||
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
|
||||
// BOTH: %[[OUT:.*]] = alloc() : memref<1xf32>
|
||||
// BOTH: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
|
||||
// BOTH: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
|
||||
// BOTH-SAME: %[[ARG3:.*]]: memref<f32>):
|
||||
// BOTH: %[[TMP:.*]] = alloc() : memref<f32>
|
||||
// BOTH: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
|
||||
// BOTH: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
|
||||
// BOTH: "lmhlo.terminator"() : () -> ()
|
||||
// BOTH: }) {dimensions = dense<1> : tensor<1xi64>}
|
||||
// BOTH-SAME: : (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
|
||||
// CHECK: %[[OUT:.*]] = alloc() : memref<1xf32>
|
||||
// CHECK: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
|
||||
// CHECK: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
|
||||
// CHECK-SAME: %[[ARG3:.*]]: memref<f32>):
|
||||
// CHECK: %[[TMP:.*]] = alloc() : memref<f32>
|
||||
// CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
|
||||
// CHECK: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
|
||||
// CHECK: "lmhlo.terminator"() : () -> ()
|
||||
// CHECK: }) {dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-SAME: : (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
|
||||
%0 = "mhlo.reduce"(%arg0, %arg1) ( {
|
||||
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
|
||||
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
|
||||
@ -604,25 +595,25 @@ func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @transpose
|
||||
// CHECK-LABEL: func @transpose
|
||||
func @transpose(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.transpose"(%tensor_operand) {permutation = dense<[1, 0]> : tensor<2xi64>}
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.transpose"(%{{.*}}, %{{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>}
|
||||
// BOTH-NOT: tensor_store
|
||||
// CHECK: "lmhlo.transpose"(%{{.*}}, %{{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>}
|
||||
// CHECK-NOT: tensor_store
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @custom_call
|
||||
// BOTH-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>)
|
||||
// CHECK-LABEL: func @custom_call
|
||||
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>)
|
||||
func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
|
||||
%arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
|
||||
%arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
|
||||
// BOTH: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false}
|
||||
// CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false}
|
||||
%result_tensor = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
|
||||
{backend_config = "", call_target_name = "foo", has_side_effect = false}
|
||||
: (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16>
|
||||
@ -632,10 +623,10 @@ func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memre
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @isfinite
|
||||
// CHECK-LABEL: func @isfinite
|
||||
func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
|
||||
%arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
|
||||
// BOTH: "lmhlo.is_finite"(%{{.*}}, %{{.*}})
|
||||
// CHECK: "lmhlo.is_finite"(%{{.*}}, %{{.*}})
|
||||
%result_tensor = "mhlo.is_finite"(%arg0_tensor) : (tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||
tensor_store %result_tensor, %result: memref<2x2xi1>
|
||||
return
|
||||
@ -644,19 +635,19 @@ func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
|
||||
// -----
|
||||
|
||||
// Test that assuming ops propagate memref types.
|
||||
// BOTH-LABEL: func @shape_assuming_memref
|
||||
// CHECK-LABEL: func @shape_assuming_memref
|
||||
func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
|
||||
%0 = mhlo.constant dense<0.000000e+00> : tensor<f16>
|
||||
%1 = shape.const_witness true
|
||||
// BOTH: shape.assuming %{{.*}} -> (memref<?xf16>)
|
||||
// CHECK: shape.assuming %{{.*}} -> (memref<?xf16>)
|
||||
%2 = shape.assuming %1 -> (tensor<?xf16>) {
|
||||
%3 = shape.shape_of %arg0 : tensor<?xf16> -> tensor<?xindex>
|
||||
%4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
|
||||
%5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
|
||||
%6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
|
||||
// BOTH: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
|
||||
// CHECK: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
|
||||
%7 = mhlo.maximum %5, %6 : tensor<?xf16>
|
||||
// BOTH: shape.assuming_yield %{{.*}} : memref<?xf16>
|
||||
// CHECK: shape.assuming_yield %{{.*}} : memref<?xf16>
|
||||
shape.assuming_yield %7 : tensor<?xf16>
|
||||
}
|
||||
return %2 : tensor<?xf16>
|
||||
|
@ -267,7 +267,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||
%13 = absf %arg3 : f32
|
||||
linalg.yield %13 : f32
|
||||
}
|
||||
%2 = lmhlo.reshape_memref_cast %1(%arg1)
|
||||
%2 = memref_reshape %1(%arg1)
|
||||
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||
return %2 : memref<*xf32>
|
||||
}
|
||||
@ -279,7 +279,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: absf
|
||||
// CHECK: reshape_memref_cast
|
||||
// CHECK: memref_reshape
|
||||
|
||||
// TILED-LABEL: func @view_result
|
||||
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||
@ -288,7 +288,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||
// TILED-NOT: scf.for
|
||||
// TILED: linalg.generic
|
||||
// TILED: absf
|
||||
// TILED: reshape_memref_cast
|
||||
// TILED: memref_reshape
|
||||
|
||||
|
||||
// PLOOP-LABEL: func @view_result
|
||||
@ -297,5 +297,5 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||
// PLOOP-NOT: scf.parallel
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: absf
|
||||
// PLOOP: reshape_memref_cast
|
||||
// PLOOP: memref_reshape
|
||||
|
||||
|
@ -1,65 +0,0 @@
|
||||
// RUN: mlir-hlo-opt %s -lower-affine -convert-scf-to-std -test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @static_memref_cast
|
||||
func @static_memref_cast(%buf : memref<10x1x5xf32>) {
|
||||
%0 = lmhlo.static_memref_cast %buf
|
||||
: memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]>
|
||||
return
|
||||
}
|
||||
// CHECK: %[[INPUT_MEMREF_BLDR:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_3D:!.*]]
|
||||
// CHECK: llvm.insertvalue
|
||||
// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_2D:!.*]]
|
||||
|
||||
// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE_3D]]
|
||||
// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
|
||||
// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE_2D]]
|
||||
|
||||
// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE_3D]]
|
||||
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
|
||||
// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE_2D]]
|
||||
|
||||
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
|
||||
// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[C2]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE_2D]]
|
||||
|
||||
// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64
|
||||
// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE_2D]]
|
||||
// CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64
|
||||
// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C5]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE_2D]]
|
||||
// CHECK: %[[C5_:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64
|
||||
// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C5_]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE_2D]]
|
||||
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE_2D]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @dynamic_memref_cast
|
||||
func @dynamic_memref_cast(%buf : memref<?x?xf32>) {
|
||||
%size_X = constant 10 : index
|
||||
%size_Y = constant 50 : index
|
||||
%stride_X = constant 1 : index
|
||||
%stride_Y = constant 0 : index
|
||||
%0 = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
|
||||
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
|
||||
return
|
||||
}
|
||||
// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64
|
||||
// CHECK: %[[C50:.*]] = llvm.mlir.constant(50 : index) : !llvm.i64
|
||||
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
|
||||
// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE:!.*]]
|
||||
|
||||
// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE]]
|
||||
// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
|
||||
// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE]]
|
||||
|
||||
// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE]]
|
||||
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm.ptr<float> to !llvm.ptr<float>
|
||||
// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE]]
|
||||
|
||||
// CHECK: %[[SRC_OFFSET:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][2] : [[DESCRIPTOR_TYPE]]
|
||||
// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[SRC_OFFSET]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE]]
|
||||
// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE]]
|
||||
// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE]]
|
||||
// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C50]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE]]
|
||||
// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C0]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE]]
|
@ -429,120 +429,6 @@ func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memr
|
||||
|
||||
// -----
|
||||
|
||||
func @static_memref_cast(%in: memref<10x1xf32>) {
|
||||
%out = lmhlo.static_memref_cast %in
|
||||
: memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @static_memref_cast
|
||||
|
||||
// -----
|
||||
|
||||
func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) {
|
||||
// expected-error @+1 {{operand must have static shape}}
|
||||
%out = lmhlo.static_memref_cast %in
|
||||
: memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) {
|
||||
// expected-error @+1 {{result must have static shape}}
|
||||
%out = lmhlo.static_memref_cast %in
|
||||
: memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dynamic_memref_cast(%in: memref<?xf32>) {
|
||||
%size = constant 10 : index
|
||||
%step = constant 1 : index
|
||||
%out = lmhlo.dynamic_memref_cast %in(%size)[%step]
|
||||
: memref<?xf32> -> memref<?xf32, offset: 0, strides: [?]>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @dynamic_memref_cast
|
||||
|
||||
// -----
|
||||
|
||||
func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) {
|
||||
// expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}}
|
||||
%size = constant 10 : index
|
||||
%step = constant 1 : index
|
||||
%out = lmhlo.dynamic_memref_cast %in(%size)[%step]
|
||||
: memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
|
||||
return
|
||||
}
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @reshape_memref_cast(
|
||||
func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
|
||||
%shape2: memref<2xi32>, %shape3: memref<?xi32>) {
|
||||
// CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>,
|
||||
// CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref<?xi32>
|
||||
|
||||
// CHECK-NEXT: [[DYN_VEC:%.*]] = lmhlo.reshape_memref_cast [[UNRANKED]]
|
||||
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
|
||||
%dyn_vec = lmhlo.reshape_memref_cast %unranked(%shape1)
|
||||
: (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
|
||||
|
||||
// CHECK-NEXT: [[DYN_MAT:%.*]] = lmhlo.reshape_memref_cast [[DYN_VEC]]
|
||||
// CHECK-SAME: : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
|
||||
%dyn_mat = lmhlo.reshape_memref_cast %dyn_vec(%shape2)
|
||||
: (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
|
||||
|
||||
// CHECK-NEXT: {{%.*}} = lmhlo.reshape_memref_cast [[DYN_MAT]]
|
||||
// CHECK-SAME: : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
|
||||
%new_unranked = lmhlo.reshape_memref_cast %dyn_mat(%shape3)
|
||||
: (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @reshape_memref_cast_element_type_mismatch(
|
||||
%buf: memref<*xf32>, %shape: memref<1xi32>) {
|
||||
// expected-error @+1 {{element types of source and destination memref types should be the same}}
|
||||
lmhlo.reshape_memref_cast %buf(%shape)
|
||||
: (memref<*xf32>, memref<1xi32>) -> memref<?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @reshape_memref_cast_dst_ranked_shape_unranked(
|
||||
%buf: memref<*xf32>, %shape: memref<?xi32>) {
|
||||
// expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}}
|
||||
lmhlo.reshape_memref_cast %buf(%shape)
|
||||
: (memref<*xf32>, memref<?xi32>) -> memref<?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @reshape_memref_cast_dst_shape_rank_mismatch(
|
||||
%buf: memref<*xf32>, %shape: memref<1xi32>) {
|
||||
// expected-error @+1 {{length of shape operand differs from the result's memref rank}}
|
||||
lmhlo.reshape_memref_cast %buf(%shape)
|
||||
: (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @reshape_memref_cast_affine_map_is_not_identity(
|
||||
%buf: memref<4x4xf32, offset: 0, strides: [3, 2]>,
|
||||
%shape: memref<1xi32>) {
|
||||
// expected-error @+1 {{operand memref type should have identity affine map}}
|
||||
lmhlo.reshape_memref_cast %buf(%shape)
|
||||
: (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>)
|
||||
-> memref<8xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @atan2_memrefs
|
||||
func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
|
@ -700,7 +700,7 @@ func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
|
||||
|
||||
func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
|
||||
// expected-error@+1 {{failed to verify that all of {start_indices, limit_indices, strides} have same type}}
|
||||
%0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
|
||||
%0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 2]> : tensor<2xi64>, limit_indices = dense<[2, 4, 1]> : tensor<3xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
|
||||
return %0 : tensor<1x4xi32>
|
||||
}
|
||||
|
||||
@ -714,6 +714,30 @@ func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
func @slice_indices_not_rank_1(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
|
||||
// expected-error@+1 {{start_indices has rank 2 instead of required rank 1}}
|
||||
%0 = "mhlo.slice"(%arg0) {
|
||||
start_indices = dense<[[1, 0]]> : tensor<1x2xi64>,
|
||||
limit_indices = dense<[[2, 4]]> : tensor<1x2xi64>,
|
||||
strides = dense<[[1, 2]]> : tensor<1x2xi64>
|
||||
} : (tensor<3x4xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @slice_indices_wrong_size(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
|
||||
// expected-error@+1 {{the number of elements in start_indices (3) does not match the rank of the operand (2)}}
|
||||
%0 = "mhlo.slice"(%arg0) {
|
||||
start_indices = dense<[1, 0, 0]> : tensor<3xi64>,
|
||||
limit_indices = dense<[2, 4, 0]> : tensor<3xi64>,
|
||||
strides = dense<[1, 2, 0]> : tensor<3xi64>
|
||||
} : (tensor<3x4xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @dynamic_slice
|
||||
func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
|
||||
%0 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
|
||||
|
@ -265,6 +265,7 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:status",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
|
@ -148,18 +148,10 @@ bool IsI64Type(Type element_type) {
|
||||
bool VerifyAddOpShapeConstraints(AddOp op) {
|
||||
auto element_type = getElementTypeOrSelf(op.output().getType());
|
||||
|
||||
// Allows F32, QI8, and QUI8 outputs when the operands have valid shapes,
|
||||
// Allows F32, QI8, QUI8 and I32 outputs when the operands have valid shapes,
|
||||
// which are broadcastable shapes up to five dimensions or have the same shapes.
|
||||
if (element_type.isF32() || IsQI8Type(element_type) ||
|
||||
IsQUI8Type(element_type)) {
|
||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||
/*max_bcast_rank=*/5);
|
||||
}
|
||||
|
||||
// Allows I32 output when the operands have valid shapes, which are
|
||||
// broadcastable shapes up to four dimension or have same shapes.
|
||||
if (IsI32Type(element_type)) {
|
||||
IsQUI8Type(element_type) || IsI32Type(element_type)) {
|
||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||
/*max_bcast_rank=*/4);
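
For context on the verifier hunks above, the rule the comments describe can be sketched as follows. This is an illustrative stand-in, not the actual VerifyOperandsHaveSameShapesOrBroadcastableShape implementation; the helper name ShapesCompatibleUpToRank and its signature are invented for the example.

#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative check: two shapes are accepted when they are identical, or
// when they broadcast against each other dimension-by-dimension from the
// trailing end and the broadcasted rank stays within max_bcast_rank.
bool ShapesCompatibleUpToRank(const std::vector<int64_t>& lhs,
                              const std::vector<int64_t>& rhs,
                              int64_t max_bcast_rank) {
  if (lhs == rhs) return true;  // same shapes are always allowed
  const size_t rank = std::max(lhs.size(), rhs.size());
  if (rank > static_cast<size_t>(max_bcast_rank)) return false;
  for (size_t i = 0; i < rank; ++i) {
    const int64_t l = i < lhs.size() ? lhs[lhs.size() - 1 - i] : 1;
    const int64_t r = i < rhs.size() ? rhs[rhs.size() - 1 - i] : 1;
    if (l != r && l != 1 && r != 1) return false;  // not broadcastable
  }
  return true;
}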
|
||||
@ -211,20 +203,13 @@ bool VerifyMulOpShapeConstraints(MulOp op) {
|
||||
}
|
||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||
/*max_bcast_rank=*/5);
|
||||
/*max_bcast_rank=*/4);
|
||||
}
|
||||
|
||||
// Allows F32 output when the operands have valid shapes, which are
|
||||
// broadcastable shapes up to five dimension or have same shapes.
|
||||
if (element_type.isF32()) {
|
||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||
/*max_bcast_rank=*/5);
|
||||
}
|
||||
|
||||
// Allows I32 and QI16 outputs when the operands have valid shapes, which are
|
||||
// broadcastable shapes up to four dimension or have same shapes.
|
||||
if (IsI32Type(element_type) || IsQI16Type(element_type)) {
|
||||
// Allows I32, QI16 and F32 outputs when the operands have valid shapes, which
|
||||
// are broadcastable shapes up to four dimensions or have the same shapes.
|
||||
if (IsI32Type(element_type) || IsQI16Type(element_type) ||
|
||||
element_type.isF32()) {
|
||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||
/*max_bcast_rank=*/4);
|
||||
|
@ -4407,4 +4407,55 @@ def TFL_CustomTfOp : Op<TFL_Dialect, "custom_tf", [
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
}
|
||||
|
||||
def TFL_BroadcastToOp : TFL_Op<"broadcast_to", [
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
TFL_OperandHasRankAtMost<0, 8>,
|
||||
TFL_OperandHasRank<1, 1>,
|
||||
PredOpTrait<"output dimension count must be at most 8",
|
||||
Or<[TFL_OperandIsUnrankedPred<1>,
|
||||
TFL_OperandDimIsAtMost<1, 0, 8>]>>,
|
||||
NoSideEffect]> {
|
||||
let summary = "Broadcast an array for a compatible shape.";
|
||||
|
||||
let description = [{
|
||||
Broadcasting is the process of making arrays have compatible shapes
|
||||
for arithmetic operations. Two shapes are compatible if for each
|
||||
dimension pair they are either equal or one of them is one. When trying
|
||||
to broadcast a Tensor to a shape, it starts with the trailing dimensions,
|
||||
and works its way forward.
|
||||
|
||||
For example,
|
||||
|
||||
>>> x = tf.constant([1, 2, 3])
|
||||
>>> y = tf.broadcast_to(x, [3, 3])
|
||||
>>> print(y)
|
||||
tf.Tensor(
|
||||
[[1 2 3]
|
||||
[1 2 3]
|
||||
[1 2 3]], shape=(3, 3), dtype=int32)
|
||||
|
||||
In the above example, the input Tensor with the shape of `[1, 3]`
|
||||
is broadcasted to an output Tensor with the shape of `[3, 3]`.
|
||||
|
||||
When doing broadcasted operations such as multiplying a tensor
|
||||
by a scalar, broadcasting (usually) confers some time or space
|
||||
benefit, as the broadcasted tensor is never materialized.
|
||||
|
||||
However, `broadcast_to` does not carry with it any such benefits.
|
||||
The newly-created tensor takes the full memory of the broadcasted
|
||||
shape. (In a graph context, `broadcast_to` might be fused to
|
||||
the subsequent operation and then be optimized away, however.)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$input,
|
||||
TFL_I32OrI64Tensor:$shape
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$output
|
||||
);
|
||||
}
|
||||
|
||||
#endif // TFL_OPS
|
||||
|
@ -289,7 +289,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(
|
||||
absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
|
||||
}
|
||||
|
||||
mlir::PassManager pm(module->getContext());
|
||||
mlir::PassManager pm(module->getContext(),
|
||||
mlir::OpPassManager::Nesting::Implicit);
|
||||
|
||||
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm, session);
|
||||
// Convert back to outlined while format for export back to flatbuffer.
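
For readers unfamiliar with the pass-manager change above: OpPassManager::Nesting::Implicit lets a top-level mlir::PassManager create nested pass managers on demand, instead of requiring every op-specific pass to be added through an explicit addNestedPass<>() call. A minimal sketch under those assumptions follows; the canonicalizer pass and header paths are illustrative and not part of this commit, while the PassManager construction mirrors the lines above. The quantize and sparsify entry points below make the same switch.

#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"

mlir::LogicalResult RunPipeline(mlir::ModuleOp module) {
  // Implicit nesting: a pass anchored on a narrower op (e.g. a function op)
  // can be added directly to this module-level manager and the nested pass
  // manager is created automatically.
  mlir::PassManager pm(module.getContext(),
                       mlir::OpPassManager::Nesting::Implicit);
  pm.addPass(mlir::createCanonicalizerPass());  // illustrative pass only
  return pm.run(module);
}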
|
||||
|
@ -75,7 +75,7 @@ TfLiteStatus QuantizeModel(
|
||||
}
|
||||
|
||||
// Apply quantization passes
|
||||
PassManager pm(module->getContext());
|
||||
PassManager pm(module->getContext(), OpPassManager::Nesting::Implicit);
|
||||
TFL::QuantizationSpecs quant_specs;
|
||||
quant_specs.inference_type = tflite::TflTypeToTfType(inference_type);
|
||||
quant_specs.post_training_quantization = true;
|
||||
|
@ -57,7 +57,7 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
PassManager pm(module->getContext());
|
||||
PassManager pm(module->getContext(), OpPassManager::Nesting::Implicit);
|
||||
pm.addPass(TFL::CreateDenseToSparsePass());
|
||||
|
||||
if (failed(pm.run(module.get()))) {
|
||||
|
@ -25,13 +25,6 @@ func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: te
|
||||
return %0 : tensor<1x2x3x4x5x6x7x8xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testAddTooHighBroadcastableDims
|
||||
func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
|
||||
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
|
||||
return %2: tensor<1xf32>
|
||||
@ -1520,6 +1513,24 @@ func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
|
||||
// CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
|
||||
// CHECK: }
|
||||
|
||||
func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
|
||||
return %0: tensor<3x3xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_f32
|
||||
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
|
||||
// CHECK: return [[BCT]] : tensor<3x3xf32>
|
||||
}
|
||||
|
||||
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
|
||||
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
|
||||
return %0: tensor<3x3xi32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_i32
|
||||
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return [[BCT]] : tensor<3x3xi32>
|
||||
}
|
||||
|
||||
func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {
|
||||
%0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
|
||||
(tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
|
||||
@ -1550,7 +1561,11 @@ func @select_v2_with_6d_broadcasting(%arg0: tensor<1x1x1x1x3x1xi1>, %arg1 : tens
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2): (tensor<1x1x1x1x3x1xi1>, tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32>
|
||||
return %0 : tensor<1x1x1x2x3x4xf32>
|
||||
// CHECK-LABEL: select_v2_with_6d_broadcasting
|
||||
// CHECK: "tf.SelectV2"(%arg0, %arg1, %arg2)
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 2, 3, 4]> : tensor<6xi64>
|
||||
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: [[BCT_1:%.*]] = "tfl.broadcast_to"(%arg2, [[CST]])
|
||||
// CHECK: "tfl.select"([[BCT]], [[BCT_0]], [[BCT_1]])
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -1560,7 +1575,9 @@ func @maximum_with_6d_broadcasting(%arg0: tensor<1x1x1x1x8x16xf32>, %arg1: tenso
|
||||
return %0 : tensor<1x1x1x1x8x16xf32>
|
||||
|
||||
// CHECK-LABEL: maximum_with_6d_broadcasting
|
||||
// CHECK: "tf.Maximum"(%arg0, %arg1)
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 1, 8, 16]> : tensor<6xi64>
|
||||
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.maximum"(%arg0, [[BCT]])
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -1569,7 +1586,171 @@ func @add_with_int32_5d_inputs(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x
|
||||
%0 = "tf.Add"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
|
||||
return %0 : tensor<1x1x1x3x4xi32>
|
||||
// CHECK-LABEL: add_with_int32_5d_inputs
|
||||
// CHECK: "tf.Add"(%arg0, %arg1)
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 3, 4]> : tensor<5xi64>
|
||||
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: tfl.add [[BCT]], [[BCT_0]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testAddWithBroadcastToOps
|
||||
func @testAddWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: tfl.add [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSubWithBroadcastToOps
|
||||
func @testSubWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: tfl.sub [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testMulWithBroadcastToOps
|
||||
func @testMulWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: tfl.mul [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.Mul"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testDivWithBroadcastToOps
|
||||
func @testDivWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: tfl.div [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFloorDivWithBroadcastToOps
|
||||
func @testFloorDivWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: tfl.floor_div [[BCAST]], [[BCAST_1]] : tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFloorModWithBroadcastToOps
|
||||
func @testFloorModWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.floor_mod"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testPowWithBroadcastToOps
|
||||
func @testPowWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: tfl.pow [[BCAST]], [[BCAST_1]] : tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.Pow"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testMaximumWithBroadcastToOps
|
||||
func @testMaximumWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.maximum"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testMinimumWithBroadcastToOps
|
||||
func @testMinimumWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.minimum"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.Minimum"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSelectV2WithBroadcastToOps
|
||||
func @testSelectV2WithBroadcastToOps(%arg0: tensor<1x2x1x4x1x6xi1>, %arg1: tensor<1x2x3x4x1x1xi32>, %arg2: tensor<1x2x1x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: [[BCAST_2:%.*]] = "tfl.broadcast_to"(%arg2, [[CST]])
|
||||
// CHECK: "tfl.select"([[BCAST]], [[BCAST_1]], [[BCAST_2]])
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1x2x1x4x1x6xi1>, tensor<1x2x3x4x1x1xi32>, tensor<1x2x1x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testLessEqualWithBroadcastToOps
|
||||
func @testLessEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.less_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
%0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testGreaterEqualWithBroadcastToOps
|
||||
func @testGreaterEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.greater_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
%0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testEqualWithBroadcastToOps
|
||||
func @testEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testNotEqualWithBroadcastToOps
|
||||
func @testNotEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.not_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
%0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testLessWithBroadcastToOps
|
||||
func @testLessWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.less"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
%0 = "tf.Less"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testGreaterWithBroadcastToOps
|
||||
func @testGreaterWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.greater"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
||||
}
|
||||
|
||||
func @tranpose_int32_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
|
||||
|
@ -2419,3 +2419,21 @@ func @invalid_two_dynamic_dims_on_reshape(%arg0: tensor<3x4xi32>, %arg1: tensor<
|
||||
%0 = "tfl.reshape"(%arg0, %arg1) : (tensor<3x4xi32>, tensor<?x?x4xi32>) -> tensor<1x3x4xi32>
|
||||
return %0 : tensor<1x3x4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testBroadcastToWithI32ShapeTensor
|
||||
func @testBroadcastToWithI32ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32> {
|
||||
^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi32>):
|
||||
// CHECK: "tfl.broadcast_to"(%arg0, %arg1)
|
||||
%0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?x?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testBroadcastToWithI64ShapeTensor
|
||||
func @testBroadcastToWithI64ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32> {
|
||||
^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi64>):
|
||||
// CHECK: "tfl.broadcast_to"(%arg0, %arg1)
|
||||
%0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?x?x?x?x?xf32>
|
||||
}
|
||||
|
@ -1382,3 +1382,18 @@ func @fuseScalarAddIntoConv2dHalf(%arg0: tensor<256x32x32x3xf16>, %arg1: tensor<
|
||||
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf16>
|
||||
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fuseExpanded1DMulIntoConv2d
|
||||
func @fuseExpanded1DMulIntoConv2d(%arg0: tensor<1x8x8x207xf32>) -> tensor<1x8x8x256xf32> {
|
||||
%cst_0 = constant dense<1.4> : tensor<256x3x3x207xf32>
|
||||
%cst_1 = constant dense<1.5> : tensor<256xf32>
|
||||
%cst_2 = constant dense<2.0> : tensor<1x1x1x256xf32>
|
||||
%0 = "tfl.conv_2d"(%arg0, %cst_0, %cst_1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x8x8x207xf32>, tensor<256x3x3x207xf32>, tensor<256xf32>) -> tensor<1x8x8x256xf32>
|
||||
%1 = "tfl.mul"(%0, %cst_2) {fused_activation_function = "NONE"} : (tensor<1x8x8x256xf32>, tensor<1x1x1x256xf32>) -> tensor<1x8x8x256xf32>
|
||||
return %1 : tensor<1x8x8x256xf32>
|
||||
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<2.800000e+00> : tensor<256x3x3x207xf32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<3.000000e+00> : tensor<1x1x1x256xf32>
|
||||
// CHECK: "tfl.conv_2d"(%arg0, %[[CST_0]], %[[CST_1]])
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,26 @@
|
||||
// RUN: tf-opt -tfl-prepare-tf=tfl-allow-bf16-type-legalization=true %s | FileCheck %s
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
|
||||
|
||||
// CHECK-LABEL: conv_2d_bf16
|
||||
func @conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x3x3x16xbf16>) -> tensor<256x30x30x16xbf16> {
|
||||
%0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xbf16>, tensor<3x3x3x16xbf16>) -> tensor<256x30x30x16xbf16>
|
||||
return %0 : tensor<256x30x30x16xbf16>
|
||||
// CHECK: "tfl.conv_2d"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fused_batch_norm_v3_bf16
|
||||
func @fused_batch_norm_v3_bf16(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8x8x8x8xbf16> {
|
||||
%0, %1, %2 ,%3, %4, %5 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_BFLOAT16", U = "tfdtype$DT_BFLOAT16", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
|
||||
return %0 : tensor<8x8x8x8xbf16>
|
||||
// CHECK: "tf.FusedBatchNormV3"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: depthwise_conv_2d_bf16
|
||||
func @depthwise_conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x3x3x4xf32>, %arg2 : tensor<256x3x32x32xf32>) -> tensor<256x30x30x12xbf16> {
|
||||
%0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xbf16>, tensor<3x3x3x4xf32>) -> tensor<256x30x30x12xbf16>
|
||||
return %0 : tensor<256x30x30x12xbf16>
|
||||
// CHECK: "tfl.depthwise_conv_2d"
|
||||
}
|
||||
|
||||
}
|
@ -571,6 +571,73 @@ func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) ->
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
|
||||
return %0: tensor<3x3xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_f32_low_dim
|
||||
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
|
||||
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xf32>
|
||||
}
|
||||
|
||||
func @broadcast_to_i32_low_dim(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
|
||||
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
|
||||
return %0: tensor<3x3xi32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_i32_low_dim
|
||||
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32>
|
||||
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xi32>
|
||||
}
|
||||
|
||||
func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32>
|
||||
return %0: tensor<3x3xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_low_dim_with_unknown_shape
|
||||
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
|
||||
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xf32>
|
||||
}
|
||||
|
||||
func @broadcast_to_i32_low_dim_with_unknown_output(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<*xi32> {
|
||||
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32>
|
||||
return %0: tensor<*xi32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_i32_low_dim_with_unknown_output
|
||||
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: [[FILL:%.*]] = "tf.Fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[FILL]]) : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32>
|
||||
// CHECK: return [[MUL]] : tensor<*xi32>
|
||||
}
|
||||
|
||||
func @broadcast_to_high_dim_with_unknown_shape(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
|
||||
return %0: tensor<7x8x1x2x3x4x5x6xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_high_dim_with_unknown_shape
|
||||
// CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
|
||||
// CHECK: return [[BCT]] : tensor<7x8x1x2x3x4x5x6xf32>
|
||||
}
|
||||
|
||||
func @broadcast_to_high_dim_with_unknown_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<8xi32>) -> tensor<*xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
|
||||
return %0: tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_high_dim_with_unknown_output
|
||||
// CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
|
||||
// CHECK: return [[BCT]] : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @broadcast_to_with_unknown_shape_and_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<*xf32>
|
||||
return %0: tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_with_unknown_shape_and_output
|
||||
// CHECK: "tf.BroadcastTo"(%arg0, %arg1)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: xla_conv
|
||||
func @xla_conv(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> {
|
||||
%0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<3x3x16x16xf32>} : () -> tensor<3x3x16x16xf32> loc("Const_1")
|
||||
@ -646,4 +713,46 @@ func @DoNotConvertConv2DWhenFilterTypeDimIsNotDecided(%arg0 : tensor<?x?x?x96xf3
|
||||
// CHECK: tf.Conv2D
|
||||
}
|
||||
|
||||
// CHECK-LABEL: conv2d_f16
|
||||
func @conv2d_f16(%arg0 : tensor<?x224x224x3xf16>, %arg1 : tensor<3x3x3x16xf16>) -> tensor<?x112x112x16xf16> {
|
||||
%0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<?x224x224x3xf16>, tensor<3x3x3x16xf16>) -> tensor<?x112x112x16xf16>
|
||||
return %0 : tensor<?x112x112x16xf16>
|
||||
// CHECK: "tf.Conv2D"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fused_batch_norm_v3_f16
|
||||
func @fused_batch_norm_v3_f16(%arg0 : tensor<?x112x112x16xf16>, %arg1 : tensor<16xf32>, %arg2 : tensor<16xf32>, %arg3 : tensor<16xf32>, %arg4 : tensor<16xf32>) -> tensor<?x112x112x16xf16> {
|
||||
%0, %1, %2, %3, %4, %5 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {data_format = "NHWC", device = "", epsilon = 1.000000e-03 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<?x112x112x16xf16>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<?x112x112x16xf16>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<*xf32>)
|
||||
return %0 : tensor<?x112x112x16xf16>
|
||||
// CHECK: "tf.FusedBatchNormV3"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: depthwise_conv2d_native_f16
|
||||
func @depthwise_conv2d_native_f16(%arg0 : tensor<?x112x112x16xf16>, %arg1 : tensor<3x3x16x1xf16>) -> tensor<?x112x112x16xf16> {
|
||||
%0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<?x112x112x16xf16>, tensor<3x3x16x1xf16>) -> tensor<?x112x112x16xf16>
|
||||
return %0 : tensor<?x112x112x16xf16>
|
||||
// CHECK: "tf.DepthwiseConv2dNative"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: conv_2d_bf16
|
||||
func @conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x3x3x16xbf16>) -> tensor<256x30x30x16xbf16> {
|
||||
%0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xbf16>, tensor<3x3x3x16xbf16>) -> tensor<256x30x30x16xbf16>
|
||||
return %0 : tensor<256x30x30x16xbf16>
|
||||
// CHECK: "tf.Conv2D"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fused_batch_norm_v3_bf16
|
||||
func @fused_batch_norm_v3_bf16(%arg0 : tensor<?x112x112x16xbf16>, %arg1 : tensor<16xf32>, %arg2 : tensor<16xf32>, %arg3 : tensor<16xf32>, %arg4 : tensor<16xf32>) -> tensor<?x112x112x16xbf16> {
|
||||
%0, %1, %2, %3, %4, %5 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {data_format = "NHWC", device = "", epsilon = 1.000000e-03 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<?x112x112x16xbf16>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<?x112x112x16xbf16>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<*xf32>)
|
||||
return %0 : tensor<?x112x112x16xbf16>
|
||||
// CHECK: "tf.FusedBatchNormV3"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: depthwise_conv_2d_bf16
|
||||
func @depthwise_conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x3x3x4xf32>, %arg2 : tensor<256x3x32x32xf32>) -> tensor<256x30x30x12xbf16> {
|
||||
%0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xbf16>, tensor<3x3x3x4xf32>) -> tensor<256x30x30x12xbf16>
|
||||
return %0 : tensor<256x30x30x12xbf16>
|
||||
// CHECK: "tf.DepthwiseConv2dNative"
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -45,18 +45,20 @@ const char kTFLiteDataLayout[] = "NHWC";
|
||||
|
||||
void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
|
||||
mlir::OpPassManager* pass_manager) {
|
||||
pass_manager->addPass(mlir::TFL::CreatePrepareQuantizePass(quant_specs));
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(
|
||||
mlir::TFL::CreatePrepareQuantizePass(quant_specs));
|
||||
if (quant_specs.default_ranges.first.hasValue() ||
|
||||
quant_specs.default_ranges.second.hasValue()) {
|
||||
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
|
||||
quant_specs.default_ranges.first.getValueOr(0.0),
|
||||
quant_specs.default_ranges.second.getValueOr(0.0),
|
||||
quant_specs.IsSignedInferenceType()));
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(
|
||||
mlir::TFL::CreateDefaultQuantParamsPass(
|
||||
quant_specs.default_ranges.first.getValueOr(0.0),
|
||||
quant_specs.default_ranges.second.getValueOr(0.0),
|
||||
quant_specs.IsSignedInferenceType()));
|
||||
}
|
||||
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateQuantizePass());
|
||||
bool emit_quant_adaptor_ops =
|
||||
quant_specs.inference_type != quant_specs.inference_input_type;
|
||||
pass_manager->addPass(
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(
|
||||
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
|
||||
}
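The converter pipeline makes the same migration explicitly: function-level passes now go through addNestedPass<FuncOp> on the module-level OpPassManager, while module-wide passes stay at the top. A hedged sketch of the convention using pass factories that appear in this diff (the wrapper name is hypothetical, include paths approximate for the MLIR of this era):

// Sketch only; mirrors the nesting convention adopted in this file.
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"  // mlir::createCanonicalizerPass, createSymbolDCEPass
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

void AddFunctionLevelPasses(mlir::OpPassManager* pass_manager) {
  // These run on every FuncOp nested under the module.
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateQuantizePass());
  // Module-wide symbol DCE stays at the module level.
  pass_manager->addPass(mlir::createSymbolDCEPass());
}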
|
||||
|
||||
@ -67,7 +69,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
standard_pipeline_options.enable_inliner = false;
|
||||
standard_pipeline_options.form_clusters = pass_config.form_clusters;
|
||||
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
|
||||
pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass());
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(
|
||||
mlir::TF::CreateDeviceIndexSelectorPass());
|
||||
|
||||
// Add canonicalize pass to remove no-op session initializer pass.
|
||||
pass_manager->addPass(mlir::createCanonicalizerPass());
|
||||
@ -155,7 +158,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
pass_manager->addPass(mlir::createInlinerPass());
|
||||
|
||||
// TODO(jpienaar): Revise post dialect constants.
|
||||
pass_manager->addPass(mlir::TF::CreateDecodeConstantPass());
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(
|
||||
mlir::TF::CreateDecodeConstantPass());
|
||||
// Canonicalization includes const folding, which is utilized here to optimize
|
||||
// away ops that can't get constant folded after PrepareTF pass. For example,
|
||||
// tf.Conv2D is split into tf.Transpose and tfl.Conv2D.
|
||||
@ -178,12 +182,13 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
// to match 'kTFLiteDataLayout'
|
||||
mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
|
||||
layout_optimization_options.force_data_format = kTFLiteDataLayout;
|
||||
mlir::TF::CreateLayoutOptimizationPipeline(*pass_manager,
|
||||
layout_optimization_options);
|
||||
mlir::TF::CreateLayoutOptimizationPipeline(
|
||||
pass_manager->nest<mlir::FuncOp>(), layout_optimization_options);
|
||||
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to
|
||||
// the TFLite dialect.
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreatePrepareTFPass(
|
||||
pass_config.unfold_batch_matmul,
|
||||
/*allow_bf16_type_legalization=*/!pass_config.runtime_verification));
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
if (pass_config.shape_inference) {
|
||||
// Add a shape inference pass to optimize away the unnecessary casts.
|
||||
@ -198,16 +203,18 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
pass_manager->addPass(mlir::createInlinerPass());
|
||||
|
||||
// This pass removes the asset file dependencies in hash table use cases.
|
||||
pass_manager->addPass(mlir::TF::CreateInitTextFileToImportPass());
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(
|
||||
mlir::TF::CreateInitTextFileToImportPass());
|
||||
|
||||
pass_manager->addPass(
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(
|
||||
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
|
||||
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateOptimizePass());
|
||||
// This pass operates on TensorFlow ops but is triggered after legalization
|
||||
// so that it can target constants introduced once TensorFlow Identity ops
|
||||
// are removed during legalization.
|
||||
pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
|
||||
pass_manager->addPass(mlir::TFL::CreateRaiseCustomOpsPass());
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(
|
||||
mlir::TFL::CreateRaiseCustomOpsPass());
|
||||
pass_manager->addPass(mlir::createSymbolDCEPass());
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
|
||||
@ -225,7 +232,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
// it's not desired by TFL. This pass serves as a "fix" pass to split the
|
||||
// merged inputs until we have 1st class variable support or reuse
|
||||
// tf.variable to model this.
|
||||
pass_manager->addPass(mlir::TFL::CreateSplitMergedOperandsPass());
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(
|
||||
mlir::TFL::CreateSplitMergedOperandsPass());
|
||||
}
|
||||
}
|
||||
|
||||
@ -276,7 +284,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
|
||||
pm.addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
|
||||
|
||||
// TFLite dialect passes.
|
||||
pm.addPass(mlir::TFL::CreatePrepareTFPass(true));
|
||||
pm.addPass(mlir::TFL::CreatePrepareTFPass(
|
||||
/*unfold_batch_matmul=*/true, /*allow_bf16_type_legalization=*/false));
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
pm.addPass(
|
||||
mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
|
||||
@ -295,7 +304,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
|
||||
|
||||
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
|
||||
|
||||
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateRuntimeVerifyPass());
|
||||
}
|
||||
|
||||
// Registers a pass pipeline for the standard TFL passes.
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||
#include "tensorflow/cc/saved_model/loader.h"
|
||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||
@ -187,7 +188,7 @@ int main(int argc, char **argv) {
|
||||
// message. So we can just return here.
|
||||
if (!module.ok()) return kTrFailure;
|
||||
|
||||
mlir::PassManager pm(&context);
|
||||
mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit);
|
||||
mlir::applyPassManagerCLOptions(pm);
|
||||
|
||||
// Set the quantization specifications from the command line flags.
|
||||
|
@ -92,7 +92,7 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (!TFTypeIsFloatTensor(op.input()) || !TFDataFormatIsNHWC(op))
|
||||
if (!TFTypeIsFloat32Tensor(op.input()) || !TFDataFormatIsNHWC(op))
|
||||
return failure();
|
||||
|
||||
// Allow dynamic width and height dimensions only.
|
||||
|
@ -116,6 +116,9 @@ def LegalizeArgMax : Pat<(TF_ArgMaxOp $input, $dim),
|
||||
def LegalizeArgMin : Pat<(TF_ArgMinOp $input, $dim),
|
||||
(TFL_ArgMinOp $input, $dim)>;
|
||||
|
||||
def LegalizeBroadcastTo : Pat<(TF_BroadcastToOp $input, $dim),
|
||||
(TFL_BroadcastToOp $input, $dim)>;
|
||||
|
||||
def LegalizeCeil : Pat<(TF_CeilOp $arg), (TFL_CeilOp $arg)>;
|
||||
|
||||
def LegalizeCos : Pat<(TF_CosOp $arg), (TFL_CosOp $arg)>;
|
||||
@ -264,7 +267,7 @@ def LegalizeAddv2 : Pat<(TF_AddV2Op $lhs, $rhs),
|
||||
(TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
|
||||
def LegalizeBiasAdd : Pat<
|
||||
(TF_BiasAddOp F32Tensor:$l, F32Tensor:$r, IsDataFormatNHWC:$data_format),
|
||||
(TFL_AddOp $l, $r, TFL_AF_None)>;
|
||||
(TF_AddV2Op $l, $r)>;
|
||||
def LegalizeSub : Pat<(TF_SubOp $lhs, $rhs),
|
||||
(TFL_SubOp $lhs, $rhs, TFL_AF_None)>;
|
||||
def LegalizeMul : Pat<(TF_MulOp $lhs, $rhs),
|
||||
|
@ -636,11 +636,155 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
|
||||
}
|
||||
};
|
||||
|
||||
void LegalizeTF::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto* context = &getContext();
|
||||
auto func = getFunction();
|
||||
// Put two TFL BroadcastTo ops in front of the given TF binary broadcast op so
// that the binary broadcast-able op conversion always succeeds and does not
// require the flex delegate.
|
||||
template <typename SourceOp>
|
||||
class ApplyExplicitBroadcasting : public OpRewritePattern<SourceOp> {
|
||||
public:
|
||||
using OpRewritePattern<SourceOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(SourceOp src_op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
Operation* op = static_cast<Operation*>(src_op);
|
||||
auto lhs = op->getOperand(0);
|
||||
auto rhs = op->getOperand(1);
|
||||
|
||||
// Should have static shapes to calculate the broadcasted shape.
|
||||
if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
|
||||
!rhs.getType().cast<ShapedType>().hasStaticShape()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
|
||||
auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
|
||||
|
||||
if (lhs_shape == rhs_shape) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Calculate the broadcasted shape.
|
||||
SmallVector<int64_t, 4> result_shape;
|
||||
if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
|
||||
result_shape)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
RankedTensorType result_type = RankedTensorType::get(
|
||||
result_shape, getElementTypeOrSelf(op->getResult(0).getType()));
|
||||
|
||||
// Create a const op, that stores the above broadcasted shape.
|
||||
auto new_shape_attr = mlir::DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)),
|
||||
result_shape);
|
||||
auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
|
||||
|
||||
// Apply BroadcastTo ops to each input.
|
||||
auto broadcast_type = RankedTensorType::get(
|
||||
result_shape, getElementTypeOrSelf(lhs.getType()));
|
||||
|
||||
if (result_type.getShape() != lhs_shape) {
|
||||
lhs = rewriter
|
||||
.create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
|
||||
new_shape)
|
||||
.output();
|
||||
}
|
||||
if (result_type.getShape() != rhs_shape) {
|
||||
rhs = rewriter
|
||||
.create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
|
||||
new_shape)
|
||||
.output();
|
||||
}
|
||||
|
||||
// Recreate an op with the above Broadcast op results.
|
||||
rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, lhs, rhs);
|
||||
return success();
|
||||
}
|
||||
};
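The pattern above hinges on OpTrait::util::getBroadcastedShape: an operand only gets a tf.BroadcastTo when its shape differs from the computed broadcast shape. A small illustrative check, not part of the commit, using shapes borrowed from the tests earlier in this diff (include paths approximate):

// Illustrative only: decide whether explicit broadcasting would be inserted.
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Traits.h"  // mlir::OpTrait::util::getBroadcastedShape

bool BothOperandsNeedBroadcast() {
  llvm::SmallVector<int64_t, 6> lhs_shape{1, 2, 1, 4, 5, 6};
  llvm::SmallVector<int64_t, 6> rhs_shape{1, 2, 3, 4, 5, 1};
  llvm::SmallVector<int64_t, 6> result_shape;
  if (!mlir::OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
                                                result_shape))
    return false;  // Not broadcast-compatible; the pattern bails out.
  // result_shape is {1, 2, 3, 4, 5, 6}, which differs from both inputs, so the
  // pattern wraps each operand in tf.BroadcastTo before recreating the op.
  return llvm::makeArrayRef(result_shape) != llvm::makeArrayRef(lhs_shape) &&
         llvm::makeArrayRef(result_shape) != llvm::makeArrayRef(rhs_shape);
}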
|
||||
|
||||
// This specialization is for the TF SelectV2 op. The SelectV2 op has three
// inputs and they should have broadcastable shapes.
|
||||
template <>
|
||||
class ApplyExplicitBroadcasting<TF::SelectV2Op>
|
||||
: public OpRewritePattern<TF::SelectV2Op> {
|
||||
public:
|
||||
using OpRewritePattern<TF::SelectV2Op>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TF::SelectV2Op src_op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
Operation* op = static_cast<Operation*>(src_op);
|
||||
auto cond = op->getOperand(0);
|
||||
auto lhs = op->getOperand(1);
|
||||
auto rhs = op->getOperand(2);
|
||||
|
||||
// Should have static shapes to calculate the broadcasted shape.
|
||||
if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
|
||||
!rhs.getType().cast<ShapedType>().hasStaticShape() ||
|
||||
!cond.getType().cast<ShapedType>().hasStaticShape()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
|
||||
auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
|
||||
auto cond_shape = cond.getType().cast<ShapedType>().getShape();
|
||||
|
||||
if (lhs_shape == rhs_shape && cond_shape == lhs_shape) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Calculate the broadcasted shape.
|
||||
SmallVector<int64_t, 4> broadcasted_shape;
|
||||
if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
|
||||
broadcasted_shape)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 4> result_shape;
|
||||
if (!OpTrait::util::getBroadcastedShape(broadcasted_shape, cond_shape,
|
||||
result_shape)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Create a const op, that stores the above broadcasted shape.
|
||||
auto shape_type =
|
||||
RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64));
|
||||
auto new_shape_attr =
|
||||
mlir::DenseIntElementsAttr::get(shape_type, result_shape);
|
||||
auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
|
||||
|
||||
// Apply BroadcastTo ops to each input.
|
||||
auto cond_result_type =
|
||||
RankedTensorType::get(result_shape, rewriter.getIntegerType(1));
|
||||
auto result_type = RankedTensorType::get(
|
||||
result_shape, getElementTypeOrSelf(lhs.getType()));
|
||||
|
||||
if (result_shape != cond_shape) {
|
||||
cond = rewriter
|
||||
.create<TF::BroadcastToOp>(op->getLoc(), cond_result_type,
|
||||
cond, new_shape)
|
||||
.output();
|
||||
}
|
||||
if (result_shape != lhs_shape) {
|
||||
lhs = rewriter
|
||||
.create<TF::BroadcastToOp>(op->getLoc(), result_type, lhs,
|
||||
new_shape)
|
||||
.output();
|
||||
}
|
||||
if (result_shape != rhs_shape) {
|
||||
rhs = rewriter
|
||||
.create<TF::BroadcastToOp>(op->getLoc(), result_type, rhs,
|
||||
new_shape)
|
||||
.output();
|
||||
}
|
||||
|
||||
// Recreate an op with the above Broadcast op results.
|
||||
rewriter.replaceOpWithNewOp<TF::SelectV2Op>(op, result_type, cond, lhs,
|
||||
rhs);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
|
||||
// Add TF->TF lowering patterns.
|
||||
TF::PopulateLoweringTFPatterns(context, &patterns);
|
||||
|
||||
@ -656,7 +800,25 @@ void LegalizeTF::runOnFunction() {
|
||||
// Ophint python converter converted tf node pattern.
|
||||
patterns.insert<LegalizeUnidirectionalSequenceLstm,
|
||||
LegalizeUnidirectionalSequenceRnn>(context);
|
||||
FrozenRewritePatternList frozenPatterns(std::move(patterns));
|
||||
}
|
||||
|
||||
void applyPatterns(FuncOp func, ConversionTarget& target,
|
||||
FrozenRewritePatternList& frozenPatterns) {
|
||||
// Keep trying to convert.
|
||||
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
|
||||
// Look if there is a function that tries until it converges.
|
||||
// Currently unit-test doesn't do multiple tries, so we need this.
|
||||
const int max_iterations = 15;
|
||||
for (int i = 0; i < max_iterations; ++i) {
|
||||
if (failed(applyPartialConversion(func, target, frozenPatterns))) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LegalizeTF::runOnFunction() {
|
||||
auto* context = &getContext();
|
||||
auto func = getFunction();
|
||||
|
||||
ConversionTarget target(*context);
|
||||
// It is legal to have TF ops in the graph still which can be
|
||||
@ -690,16 +852,42 @@ void LegalizeTF::runOnFunction() {
|
||||
return success(current_thread_id == llvm::get_threadid());
|
||||
});
|
||||
|
||||
// Keep trying to convert.
|
||||
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
|
||||
// Look if there is a function that tries until it converge.
|
||||
// Currently unit-test doesn't do multiple tries, so we need this.
|
||||
const int max_iterations = 15;
|
||||
for (int i = 0; i < max_iterations; ++i) {
|
||||
if (failed(applyPartialConversion(func, target, frozenPatterns))) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
OwningRewritePatternList stage1Patterns;
|
||||
|
||||
addPatterns(context, stage1Patterns);
|
||||
|
||||
FrozenRewritePatternList stage1FrozenPatterns(std::move(stage1Patterns));
|
||||
applyPatterns(func, target, stage1FrozenPatterns);
|
||||
|
||||
// Explicit BroadcastTo addition for left-over broadcast-able ops.
// The following pattern matching should be done after the other legalization
// rules in order not to add unnecessary BroadcastTo ops.
|
||||
OwningRewritePatternList stage2Patterns;
|
||||
|
||||
addPatterns(context, stage2Patterns);
|
||||
|
||||
stage2Patterns.insert<ApplyExplicitBroadcasting<TF::LessEqualOp>,
|
||||
ApplyExplicitBroadcasting<TF::GreaterEqualOp>,
|
||||
ApplyExplicitBroadcasting<TF::NotEqualOp>,
|
||||
ApplyExplicitBroadcasting<TF::GreaterOp>,
|
||||
ApplyExplicitBroadcasting<TF::LessOp>,
|
||||
ApplyExplicitBroadcasting<TF::EqualOp>,
|
||||
ApplyExplicitBroadcasting<TF::AddOp>,
|
||||
ApplyExplicitBroadcasting<TF::AddV2Op>,
|
||||
ApplyExplicitBroadcasting<TF::MulOp>,
|
||||
ApplyExplicitBroadcasting<TF::DivOp>,
|
||||
ApplyExplicitBroadcasting<TF::RealDivOp>,
|
||||
ApplyExplicitBroadcasting<TF::SubOp>,
|
||||
ApplyExplicitBroadcasting<TF::FloorDivOp>,
|
||||
ApplyExplicitBroadcasting<TF::FloorModOp>,
|
||||
ApplyExplicitBroadcasting<TF::PowOp>,
|
||||
ApplyExplicitBroadcasting<TF::MaximumOp>,
|
||||
ApplyExplicitBroadcasting<TF::MinimumOp>,
|
||||
ApplyExplicitBroadcasting<TF::SquaredDifferenceOp>,
|
||||
ApplyExplicitBroadcasting<TF::SelectV2Op>>(context);
|
||||
|
||||
FrozenRewritePatternList stage2FrozenPatterns(std::move(stage2Patterns));
|
||||
applyPatterns(func, target, stage2FrozenPatterns);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -200,16 +200,16 @@ bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params,
|
||||
ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {
|
||||
auto elements = a.dyn_cast<DenseElementsAttr>();
|
||||
auto shape = elements.getType().getShape();
|
||||
if (shape.size() == 4) {
|
||||
return elements;
|
||||
if (!shape.empty()) {
|
||||
// Checks that elements are essentially 1d.
|
||||
assert(elements.getNumElements() == shape.back());
|
||||
}
|
||||
std::vector<int64_t> shape_data = {1, 1, 1, 1};
|
||||
if (shape.size() == 1 || shape.empty()) {
|
||||
if (is_depthwise)
|
||||
shape_data[3] = shape.empty() ? 1 : shape[0];
|
||||
else
|
||||
shape_data[0] = shape.empty() ? 1 : shape[0];
|
||||
}
|
||||
const int vector_length = elements.getNumElements();
|
||||
if (is_depthwise)
|
||||
shape_data[3] = vector_length;
|
||||
else
|
||||
shape_data[0] = vector_length;
|
||||
auto new_shape =
|
||||
RankedTensorType::get(shape_data, elements.getType().getElementType());
|
||||
return elements.reshape(new_shape);
|
||||
|
@ -41,7 +41,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
|
||||
bool unfold_batch_matmul);
|
||||
bool unfold_batch_matmul, bool allow_bf16_type_legalization);
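Callers pass the new flag alongside the existing one; a short usage sketch consistent with the pipeline changes earlier in this diff (the wrapper function name is hypothetical):

// Hypothetical wrapper showing the two knobs CreatePrepareTFPass now takes.
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

void AddPrepareTF(mlir::OpPassManager& pm, bool unfold_batch_matmul,
                  bool allow_bf16_type_legalization) {
  pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreatePrepareTFPass(
      unfold_batch_matmul, allow_bf16_type_legalization));
}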
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
|
||||
// pass.
|
||||
|
@ -257,9 +257,9 @@ void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
|
||||
// (Quant $in, $qA) would introduce less quantization noise) the likley cause
|
||||
// is an minor error in constructing the original network model that
|
||||
// introduced back-to-back Fake Quantization operations. Hence: emit a
|
||||
// warning. N.b. at this point weŕe (teporarility) in the quantization dialect
|
||||
// (presuambly enalbe re-use in xla etc) quant::*QuantizeCastOp weŕe matching
|
||||
// here.
|
||||
// warning. N.b. at this point we're (teporarility) in the quantization
|
||||
// dialect (presuambly enable re-use in xla etc) quant::*QuantizeCastOp
|
||||
// we're matching here.
|
||||
//
|
||||
func.walk([&](quant::QuantizeCastOp q_op) {
|
||||
// If up with end up with
|
||||
|
@ -82,8 +82,10 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
|
||||
public:
|
||||
PrepareTFPass() = default;
|
||||
PrepareTFPass(const PrepareTFPass &) {}
|
||||
explicit PrepareTFPass(bool unfold_batch_matmul) {
|
||||
explicit PrepareTFPass(bool unfold_batch_matmul,
|
||||
bool allow_bf16_type_legalization) {
|
||||
unfold_batch_matmul_ = unfold_batch_matmul;
|
||||
allow_bf16_type_legalization_ = allow_bf16_type_legalization;
|
||||
}
|
||||
void runOnFunction() override;
|
||||
|
||||
@ -97,6 +99,10 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
|
||||
*this, "tfl-unfold-batch-matmul",
|
||||
llvm::cl::desc("Unfold BatchMatMul into individual MatMul ops."),
|
||||
llvm::cl::init(true)};
|
||||
|
||||
Option<bool> allow_bf16_type_legalization_{
|
||||
*this, "tfl-allow-bf16-type-legalization",
|
||||
llvm::cl::desc("Allow bf16 type legalization."), llvm::cl::init(false)};
|
||||
};
|
||||
|
||||
template <class TFFakeQuantOp>
|
||||
@ -258,6 +264,15 @@ using PreparePerTensorFakeQuantWithMinMaxArgs =
|
||||
TF::FakeQuantWithMinMaxArgsOp, /*PerAxis=*/false,
|
||||
FetchMinMaxAttrs<TF::FakeQuantWithMinMaxArgsOp>>;
|
||||
|
||||
// Transient state for preserving data from match to rewrite
|
||||
struct ConvertTFConvOpMatchState {
|
||||
IntegerAttr dilation_height_factor;
|
||||
IntegerAttr dilation_width_factor;
|
||||
StringAttr padding;
|
||||
IntegerAttr stride_height;
|
||||
IntegerAttr stride_width;
|
||||
};
|
||||
|
||||
// Templated class for declaring a converter from some TensorFlow convolution
|
||||
// op into its counterpart in TensorFlow Lite.
|
||||
//
|
||||
@ -273,19 +288,12 @@ using PreparePerTensorFakeQuantWithMinMaxArgs =
|
||||
//
|
||||
// int64_t getBiasDim(ArrayRef<int64_t> filterShape) const;
|
||||
template <typename ConcreteType, typename TFConvOpType>
|
||||
struct ConvertTFConvOp : public RewritePattern {
|
||||
// Transient state for preserving data from match to rewrite
|
||||
struct ConvertTFConvOpMatchState {
|
||||
IntegerAttr dilation_height_factor;
|
||||
IntegerAttr dilation_width_factor;
|
||||
StringAttr padding;
|
||||
IntegerAttr stride_height;
|
||||
IntegerAttr stride_width;
|
||||
};
|
||||
|
||||
ConvertTFConvOp(MLIRContext *context)
|
||||
class ConvertTFConvOp : public RewritePattern {
|
||||
public:
|
||||
ConvertTFConvOp(MLIRContext *context, bool allow_bf16_type_legalization)
|
||||
: RewritePattern(TFConvOpType::getOperationName(), 1, context),
|
||||
intAttrOne(Builder(context).getI32IntegerAttr(1)) {}
|
||||
intAttrOne(Builder(context).getI32IntegerAttr(1)),
|
||||
allow_bf16_type_legalization_(allow_bf16_type_legalization) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
@ -301,9 +309,13 @@ struct ConvertTFConvOp : public RewritePattern {
|
||||
|
||||
TFConvOpType tf_op = cast<TFConvOpType>(op);
|
||||
|
||||
if (!TFTypeIsFloatTensor(tf_op.input()) || !TFDataFormatIsNHWC(op))
|
||||
if (!TFTypeIsFloat32Tensor(tf_op.input()) &&
|
||||
!(allow_bf16_type_legalization_ &&
|
||||
TFTypeIsBFloat16Tensor(tf_op.input())))
|
||||
return failure();
|
||||
|
||||
if (!TFDataFormatIsNHWC(op)) return failure();
|
||||
|
||||
IntegerAttr height, width;
|
||||
if (!TFIntListIs1XY1(op, "strides", &height, &width)) return failure();
|
||||
|
||||
@ -359,13 +371,17 @@ struct ConvertTFConvOp : public RewritePattern {
|
||||
}
|
||||
|
||||
const IntegerAttr intAttrOne;
|
||||
|
||||
private:
|
||||
bool allow_bf16_type_legalization_;
|
||||
};
|
||||
|
||||
class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
|
||||
public:
|
||||
using BaseType = ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp>;
|
||||
|
||||
ConvertTFConv2D(MLIRContext *context) : BaseType(context) {}
|
||||
ConvertTFConv2D(MLIRContext *context, bool allow_bf16_type_legalization)
|
||||
: BaseType(context, allow_bf16_type_legalization) {}
|
||||
|
||||
int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
|
||||
return filterShape.back();
|
||||
@ -421,7 +437,9 @@ class ConvertTFDepthwiseConv2dNative
|
||||
using BaseType = ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
|
||||
TF::DepthwiseConv2dNativeOp>;
|
||||
|
||||
ConvertTFDepthwiseConv2dNative(MLIRContext *context) : BaseType(context) {}
|
||||
ConvertTFDepthwiseConv2dNative(MLIRContext *context,
|
||||
bool allow_bf16_type_legalization)
|
||||
: BaseType(context, allow_bf16_type_legalization) {}
|
||||
|
||||
int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
|
||||
return filterShape[2] * filterShape[3];
|
||||
@ -716,9 +734,9 @@ struct ConvertTFBroadcastTo : public RewritePattern {
|
||||
|
||||
// Allow lowering when low dimension inputs are given and its type is F32 or
|
||||
// I32.
|
||||
if (!((output_type.hasRank() && output_type.getRank() <= 5) ||
|
||||
if (!((output_type.hasRank() && output_type.getRank() <= 4) ||
|
||||
(shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
|
||||
shape_type.getDimSize(0) <= 5)))
|
||||
shape_type.getDimSize(0) <= 4)))
|
||||
return failure();
|
||||
|
||||
if (!(element_type.isa<BFloat16Type, Float32Type>() ||
|
||||
@ -815,6 +833,9 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
|
||||
offset = fused_batch_norm_op.getODSOperands(2);
|
||||
mean = fused_batch_norm_op.getODSOperands(3);
|
||||
variance = fused_batch_norm_op.getODSOperands(4);
|
||||
|
||||
if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.x())) return failure();
|
||||
|
||||
{
|
||||
epsilon = fused_batch_norm_op.getAttrOfType<::mlir::FloatAttr>("epsilon");
|
||||
if (!epsilon)
|
||||
@ -1196,8 +1217,10 @@ void PrepareTFPass::runOnFunction() {
|
||||
ctx);
|
||||
}
|
||||
phase_2_patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo,
|
||||
ConvertTFConv2D, ConvertTFDepthwiseConv2dNative,
|
||||
ConvertTFStridedSlice, ConvertRfftToRfft2d>(ctx);
|
||||
phase_2_patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
|
||||
ctx, allow_bf16_type_legalization_);
|
||||
|
||||
applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
|
||||
}
|
||||
|
||||
@ -1205,8 +1228,9 @@ void PrepareTFPass::runOnFunction() {
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
|
||||
bool unfold_batch_matmul) {
|
||||
return std::make_unique<PrepareTFPass>(unfold_batch_matmul);
|
||||
bool unfold_batch_matmul, bool allow_bf16_type_legalization) {
|
||||
return std::make_unique<PrepareTFPass>(unfold_batch_matmul,
|
||||
allow_bf16_type_legalization);
|
||||
}
|
||||
|
||||
static PassRegistration<PrepareTFPass> pass(
|
||||
|
@ -132,7 +132,7 @@ bool NotFromQuantOpOrSameQuantType(mlir::Value val, mlir::TypeAttr qtype_attr) {
|
||||
llvm::dyn_cast_or_null<mlir::TFL::QuantizeOp>(val_defn_op);
|
||||
if (!q_op) return true;
|
||||
|
||||
// Ignore shape details - weŕe really only trying to
|
||||
// Ignore shape details - we're really only trying to
|
||||
// check if quantization is the same.
|
||||
auto stripped_src_qtype = GetShapeStrippedType(q_op.qtypeAttr());
|
||||
auto stripped_qtype = GetShapeStrippedType(qtype_attr);
|
||||
|
@ -49,12 +49,19 @@ bool TFIntListIs1XY1(const ArrayAttr &attr);
|
||||
// must be `IntegerAttr`.
|
||||
bool TFIntListIsAllOnes(const ArrayAttr &attr);
|
||||
|
||||
// Returns true iff the given value is a float tensor.
|
||||
// Returns true iff the given value is a float32 tensor.
|
||||
// is "DT_FLOAT".
|
||||
inline bool TFTypeIsFloatTensor(Value value) {
|
||||
inline bool TFTypeIsFloat32Tensor(Value value) {
|
||||
auto tensorType = value.getType().dyn_cast<TensorType>();
|
||||
if (!tensorType) return false;
|
||||
return tensorType.getElementType().isa<FloatType>();
|
||||
return tensorType.getElementType().isF32();
|
||||
}
|
||||
|
||||
// Returns true iff the given value is a bf16 tensor.
|
||||
inline bool TFTypeIsBFloat16Tensor(Value value) {
|
||||
auto tensorType = value.getType().dyn_cast<TensorType>();
|
||||
if (!tensorType) return false;
|
||||
return tensorType.getElementType().isBF16();
|
||||
}
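These two predicates are what the convolution converters earlier in this diff use to gate bf16 legalization; a hedged sketch of the guard in the same header context (the helper name is illustrative):

// Illustrative guard, mirroring ConvertTFConvOp::matchAndRewrite above.
inline bool InputTypeIsLegal(Value input, bool allow_bf16_type_legalization) {
  // Always accept float32; accept bfloat16 only when legalization is enabled.
  return TFTypeIsFloat32Tensor(input) ||
         (allow_bf16_type_legalization && TFTypeIsBFloat16Tensor(input));
}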
|
||||
|
||||
// Returns true iff the given TensorFlow op has a `padding` attribute whose
|
||||
|
@ -882,9 +882,11 @@ cc_library(
|
||||
"transforms/test_visitor_util.cc",
|
||||
"transforms/tf_data_optimization_pass.cc",
|
||||
"transforms/tf_device_assignment.cc",
|
||||
"transforms/tf_device_replication_pass.cc",
|
||||
"transforms/tpu_cluster_cleanup_attributes.cc",
|
||||
"transforms/tpu_cluster_formation.cc",
|
||||
"transforms/tpu_colocate_composite_resource_ops.cc",
|
||||
"transforms/tpu_compile_op_replication_pass.cc",
|
||||
"transforms/tpu_device_propagation.cc",
|
||||
"transforms/tpu_dynamic_layout_pass.cc",
|
||||
"transforms/tpu_dynamic_padding_mapper.cc",
|
||||
|
@ -1002,6 +1002,10 @@ reverse of SpaceToBatch. See below for a precise description.
|
||||
TF_Tensor:$output
|
||||
);
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tcrops = TF_DerivedOperandTypeAttr<2>;
|
||||
TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>;
|
||||
|
@ -398,6 +398,27 @@ void BatchToSpaceOp::getCanonicalizationPatterns(
|
||||
results.insert<BatchToSpaceToBatchToSpaceND>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BatchToSpaceNDOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(BatchToSpaceNDOp op) {
|
||||
auto block_shape_ty = op.block_shape().getType().cast<ShapedType>();
|
||||
auto crops_ty = op.crops().getType().cast<ShapedType>();
|
||||
|
||||
if (block_shape_ty.hasStaticShape() && crops_ty.hasStaticShape()) {
|
||||
const int block_rank = block_shape_ty.getShape().front();
|
||||
if (crops_ty.getRank() != 2 || crops_ty.getShape().front() != block_rank ||
|
||||
crops_ty.getShape()[1] != 2) {
|
||||
op.emitOpError() << "crops should have shape [" << block_rank
|
||||
<< ", 2] instead of " << crops_ty.getShape();
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BiasAddOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -290,6 +290,25 @@ func @sixdim_space_to_batch_nd(%input: tensor<3x5x7x9x10x11xf32>, %block_shape:
|
||||
return %0 : tensor<?x?x?x?x10x11xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @batchToSpace
|
||||
func @batchToSpace(%arg0: tensor<3x5x2xf32>) -> (tensor<1x8x2xf32>) {
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<[3, 1, 5, 2]> : tensor<4xi64>}
|
||||
// CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<[1, 2, 0, 3]> : tensor<4xi64>}
|
||||
// CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<[1, 15, 2]> : tensor<3xi64>}
|
||||
// CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<[0, 3, 0]> : tensor<3xi64>}
|
||||
// CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() {value = dense<[1, 8, 2]> : tensor<3xi64>}
|
||||
// CHECK-DAG: [[VAL5:%.+]] = "tf.Reshape"(%arg0, [[VAL0]])
|
||||
// CHECK-DAG: [[VAL6:%.+]] = "tf.Transpose"([[VAL5]], [[VAL1]])
|
||||
// CHECK-DAG: [[VAL7:%.+]] = "tf.Reshape"([[VAL6]], [[VAL2]])
|
||||
// CHECK-DAG: [[VAL8:%.+]] = "tf.Slice"([[VAL7]], [[VAL3]], [[VAL4]])
|
||||
%0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<[[3, 4]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
|
||||
%2 = "tf.BatchToSpaceND"(%arg0, %0, %1) {device = ""} : (tensor<3x5x2xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x8x2xf32>
|
||||
|
||||
// CHECK: return [[VAL8]] : tensor<1x8x2xf32>
|
||||
return %2 : tensor<1x8x2xf32>
|
||||
}
|
||||
|
||||
func @fake_quant_with_min_max_args(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<1.275000e+02> : tensor<f32>}
|
||||
// CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<1.00392163> : tensor<f32>}
|
||||
|
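The constants in the @batchToSpace test above follow directly from the BatchToSpaceND shape rules: with input 3x5x2, block_shape = [3] and crops = [[3, 4]], the lowering reshapes to [3, 1, 5, 2], transposes with permutation [1, 2, 0, 3] to [1, 5, 3, 2], collapses to [1, 15, 2], and then slices away 3 leading and 4 trailing elements of the middle dimension, leaving 1x8x2. A minimal standalone C++ sketch of that shape arithmetic (illustrative only, not part of this change; the helper name is hypothetical):

// Computes the BatchToSpaceND result shape for static inputs, mirroring the
// reshape/transpose/reshape/slice lowering exercised by the test above.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

std::vector<int64_t> BatchToSpaceShape(
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& block_shape,
    const std::vector<std::pair<int64_t, int64_t>>& crops) {
  int64_t block_num_elems = 1;
  for (int64_t b : block_shape) block_num_elems *= b;

  std::vector<int64_t> out;
  out.push_back(input_shape[0] / block_num_elems);  // batch / prod(block_shape)
  for (size_t i = 0; i < block_shape.size(); ++i) {
    int64_t expanded = input_shape[1 + i] * block_shape[i];
    out.push_back(expanded - crops[i].first - crops[i].second);  // cropped dim
  }
  for (size_t i = 1 + block_shape.size(); i < input_shape.size(); ++i)
    out.push_back(input_shape[i]);  // remaining dims are untouched
  return out;
}

int main() {
  // Matches the test: tensor<3x5x2xf32>, block_shape = [3], crops = [[3, 4]].
  std::vector<int64_t> out = BatchToSpaceShape({3, 5, 2}, {3}, {{3, 4}});
  assert((out == std::vector<int64_t>{1, 8, 2}));
  for (int64_t d : out) std::cout << d << ' ';  // prints: 1 8 2
  std::cout << '\n';
}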
@ -0,0 +1,25 @@
|
||||
// RUN: tf-opt --tf-device-replication %s | FileCheck %s
|
||||
|
||||
// CHECK: func @test_1(%[[ARG_0:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}, %[[ARG_1:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:1/device:CPU:0"})
|
||||
func @test_1(%arg0: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<i32> {tf.device = "/job:worker/replica:0/task:1/device:CPU:0"}) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
|
||||
// CHECK-NEXT: %[[RESULT_0:.*]] = "tf.AddV2"(%[[ARG_0]], %[[ARG_0]]) {device = "/job:worker/replica:0/task:0/device:CPU:0"}
|
||||
// CHECK-NEXT: %[[RESULT_1:.*]] = "tf.AddV2"(%[[ARG_0]], %[[ARG_0]]) {device = "/job:worker/replica:0/task:0/device:CPU:1"}
|
||||
// CHECK-NEXT: %[[RESULT_2:.*]] = "tf.AddV2"(%[[ARG_1]], %[[ARG_1]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"}
|
||||
// CHECK-NEXT: %[[RESULT_3:.*]] = "tf.AddV2"(%[[ARG_1]], %[[ARG_1]]) {device = "/job:worker/replica:0/task:1/device:CPU:1"}
|
||||
%0:4 = tf_device.replicate([%arg0, %arg0, %arg1, %arg1] as %arg2: tensor<i32>) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:1", "/job:worker/replica:0/task:1/device:CPU:0", "/job:worker/replica:0/task:1/device:CPU:1"]}, n = 4 : i32} {
|
||||
%1 = "tf.AddV2"(%arg2, %arg2) {device = "TPU_REPLICATED_CORE_0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
tf_device.return %1 : tensor<i32>
|
||||
}
|
||||
// CHECK-NEXT: return %[[RESULT_0]], %[[RESULT_1]], %[[RESULT_2]], %[[RESULT_3]]
|
||||
return %0#0, %0#1, %0#2, %0#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK: func @test_2(%[[ARG_0:.*]]: tensor<i32>
|
||||
func @test_2(%arg0: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
|
||||
%0:4 = tf_device.replicate() {n = 4 : i32} {
|
||||
tf_device.return %arg0 : tensor<i32>
|
||||
}
|
||||
// CHECK-NEXT: return %[[ARG_0]], %[[ARG_0]], %[[ARG_0]], %[[ARG_0]]
|
||||
return %0#0, %0#1, %0#2, %0#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}
|
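The RUN line above drives the new pass through tf-opt's --tf-device-replication flag. The same pass can also be scheduled in-process through the factory function this change declares in transforms/passes.h; a hedged sketch (the wrapper function is illustrative and the header layout reflects the MLIR of this era):

// Sketch: running the device replication pass over a module in-process.
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

mlir::LogicalResult ReplicatePerDeviceOps(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // Factory declared by this change in the TFDevice namespace.
  pm.addPass(mlir::TFDevice::CreateTFDeviceReplicationPass());
  return pm.run(module);
}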
@ -0,0 +1,22 @@
|
||||
// RUN: tf-opt --tf-tpu-compile-replication %s | FileCheck %s
|
||||
|
||||
// CHECK: func @test(%[[ARG_0:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}, %[[ARG_1:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:1/device:CPU:0"})
|
||||
func @test(%arg0: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<i32> {tf.device = "/job:worker/replica:0/task:1/device:CPU:0"}) -> (tensor<i32>, tensor<i32>) {
|
||||
// CHECK-NEXT: %[[STATUS_0:.*]], %[[PROGRAM_0:.*]] = "tf._TPUCompileMlir"() {device = "/job:worker/replica:0/task:1/device:CPU:0", metadata = "metadata", mlir_module = "mlir_module"}
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[STATUS_0]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"}
|
||||
// CHECK-NEXT: %[[STATUS_1:.*]], %[[PROGRAM_1:.*]] = "tf._TPUCompileMlir"() {device = "/job:worker/replica:0/task:0/device:CPU:0", metadata = "metadata", mlir_module = "mlir_module"}
|
||||
%compilation_status, %program = "tf._TPUCompileMlir"() {device = "/job:worker/replica:0/task:0/device:CPU:0", metadata = "metadata", mlir_module = "mlir_module"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[STATUS_1]]) {device = "/job:worker/replica:0/task:0/device:CPU:0"}
|
||||
"tf.TPUCompileSucceededAssert"(%compilation_status) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : (tensor<!tf.string>) -> ()
|
||||
// CHECK-NEXT: %[[ADD_0:.*]] = "tf.AddV2"(%[[ARG_0]], %[[ARG_0]]) {device = "/job:worker/replica:0/task:0/device:TPU:0"}
|
||||
%0 = "tf.AddV2"(%arg0, %arg0) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
// CHECK-NEXT: %[[EXECUTE_0:.*]] = "tf.TPUExecute"(%[[ADD_0]], %[[PROGRAM_1]]) {device = "/job:worker/replica:0/task:0/device:TPU:0"}
|
||||
%1 = "tf.TPUExecute"(%0, %program) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : (tensor<i32>, tensor<2x!tf.string>) -> tensor<i32>
|
||||
// CHECK-NEXT: %[[ADD_1:.*]] = "tf.AddV2"(%[[ARG_1]], %[[ARG_1]]) {device = "/job:worker/replica:0/task:1/device:TPU:0"}
|
||||
%2 = "tf.AddV2"(%arg1, %arg1) {device = "/job:worker/replica:0/task:1/device:TPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
// CHECK-NEXT: %[[EXECUTE_1:.*]] = "tf.TPUExecute"(%[[ADD_1]], %[[PROGRAM_0]]) {device = "/job:worker/replica:0/task:1/device:TPU:0"}
|
||||
%3 = "tf.TPUExecute"(%2, %program) {device = "/job:worker/replica:0/task:1/device:TPU:0"} : (tensor<i32>, tensor<2x!tf.string>) -> tensor<i32>
|
||||
// CHECK-NEXT: return %[[EXECUTE_0]], %[[EXECUTE_1]]
|
||||
return %1, %3 : tensor<i32>, tensor<i32>
|
||||
}
|
||||
|
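In the test above, the cloned tf._TPUCompileMlir ends up on task 1's CPU because the pass groups ops by the job/replica/task prefix of their device attribute. That prefix is what the GetHost helper added later in this diff extracts; in isolation the extraction looks roughly like this (a sketch, with a hypothetical wrapper name):

// Sketch: reducing a full device name to its host address space, the key the
// TPU compile replication pass uses to decide where a clone is needed.
#include <string>
#include "tensorflow/core/util/device_name_utils.h"

std::string HostOf(const std::string& device) {
  using tensorflow::DeviceNameUtils;
  DeviceNameUtils::ParsedName parsed;
  if (!DeviceNameUtils::ParseFullName(device, &parsed)) return "";
  return DeviceNameUtils::ParsedNameToString(
      DeviceNameUtils::AddressSpace(parsed));
}

// HostOf("/job:worker/replica:0/task:1/device:TPU:0") yields
// "/job:worker/replica:0/task:1", so the two tf.TPUExecute ops in the test
// resolve to different hosts and each host gets its own compile op.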
@ -121,7 +121,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
  pm.addPass(CreateTPURewritePass());
  pm.addPass(createSymbolDCEPass());
  pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateInvariantOpHoistingPass());
  pm.addNestedPass<FuncOp>(CreateTPUDynamicLayoutPass());
  pm.addPass(CreateTPUDynamicLayoutPass());
  pm.addNestedPass<FuncOp>(CreateTPUMergeVariablesWithExecutePass());
  pm.addNestedPass<FuncOp>(CreateTPUColocateCompositeResourceOps());
  pm.addPass(CreateTPUVariableReformattingPass());
@ -50,7 +50,8 @@ Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto,
  // Assign optimal data layout to layout sensitive operations and delete
  // redundant transposes from the IR.
  LayoutOptimizationPipelineOptions layout_optimization_options;
  CreateLayoutOptimizationPipeline(pm, layout_optimization_options);
  CreateLayoutOptimizationPipeline(pm.nest<FuncOp>(),
                                   layout_optimization_options);

  // Prepare IR for exporting.
  pm.addPass(CreateBreakUpIslandsPass());
@ -85,7 +85,7 @@ void InitTextFileToImportTestPass::runOnOperation() {

  // Run the lowering pass.
  PassManager pm(context);
  pm.addPass(CreateInitTextFileToImportPass());
  pm.addNestedPass<FuncOp>(CreateInitTextFileToImportPass());
  if (failed(pm.run(module))) return signalPassFailure();
}
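A recurring change in this commit, visible in the hunks above and in several more below, is moving function-scoped passes from pm.addPass(...) to pm.addNestedPass<FuncOp>(...). The distinction is where the pass is anchored: addPass runs it on the pass manager's root op (here the module), while addNestedPass<FuncOp> runs one instance per FuncOp nested underneath, which states the pass's real scope and lets independent functions be processed in parallel. A minimal sketch of the two spellings (standalone, not taken from any file in this diff):

// Sketch: module-anchored vs. function-anchored scheduling of the same pass.
#include "mlir/IR/Function.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

void BuildPipelines(mlir::MLIRContext* context) {
  // Module-anchored: one canonicalizer instance with the whole module as root.
  mlir::PassManager coarse(context);
  coarse.addPass(mlir::createCanonicalizerPass());

  // Function-anchored: one canonicalizer instance per FuncOp, which is the
  // spelling most call sites in this commit switch to.
  mlir::PassManager per_function(context);
  per_function.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
}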
@ -988,6 +988,177 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
|
||||
}
|
||||
};
|
||||
|
||||
class LowerBatchToSpaceND : public RewritePattern {
|
||||
public:
|
||||
explicit LowerBatchToSpaceND(MLIRContext *context)
|
||||
: RewritePattern(BatchToSpaceNDOp::getOperationName(),
|
||||
{
|
||||
ConstOp::getOperationName(),
|
||||
ReshapeOp::getOperationName(),
|
||||
SliceOp::getOperationName(),
|
||||
TransposeOp::getOperationName(),
|
||||
},
|
||||
1, context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *src_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto op = cast<BatchToSpaceNDOp>(src_op);
|
||||
auto input = op.input();
|
||||
auto input_ty = input.getType().cast<ShapedType>();
|
||||
auto element_ty = input_ty.getElementType();
|
||||
if (!input_ty.hasStaticShape()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int input_rank = input_ty.getRank();
|
||||
auto input_shape = input_ty.getShape();
|
||||
|
||||
DenseIntElementsAttr block_shape;
|
||||
DenseIntElementsAttr crops;
|
||||
if (!matchPattern(op.block_shape(), m_Constant(&block_shape)) ||
|
||||
!matchPattern(op.crops(), m_Constant(&crops))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto block_shape_ty = block_shape.getType();
|
||||
if (!block_shape_ty.hasRank() || block_shape_ty.getRank() != 1) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int block_rank = block_shape_ty.getShape().front();
|
||||
auto remainder_shape = input_shape.drop_front(1 + block_rank);
|
||||
|
||||
const int64_t batch_size = input_shape[0];
|
||||
|
||||
// Compute the product of the block_shape values.
|
||||
int64_t block_num_elems = 1;
|
||||
|
||||
for (auto val : block_shape.getIntValues()) {
|
||||
block_num_elems *= val.getSExtValue();
|
||||
}
|
||||
|
||||
if (block_num_elems <= 0) {
|
||||
op.emitOpError()
|
||||
<< "The product of the block dimensions must be positive";
|
||||
return failure();
|
||||
}
|
||||
|
||||
// 1. Reshape `input` to `reshaped` of shape:
|
||||
// [block_shape[0], ..., block_shape[M-1],
|
||||
// batch / prod(block_shape),
|
||||
// input_shape[1], ..., input_shape[N-1]]
|
||||
std::vector<int64_t> reshaped_shape;
|
||||
for (auto val : block_shape) {
|
||||
reshaped_shape.push_back(val.getSExtValue());
|
||||
}
|
||||
reshaped_shape.resize(input_rank + block_rank);
|
||||
|
||||
reshaped_shape[block_rank] = batch_size / block_num_elems;
|
||||
std::copy(input_shape.begin() + 1, input_shape.end(),
|
||||
reshaped_shape.begin() + block_rank + 1);
|
||||
|
||||
auto reshaped = rewriter.create<TF::ReshapeOp>(
|
||||
op.getLoc(), RankedTensorType::get(reshaped_shape, element_ty), input,
|
||||
rewriter.create<ConstOp>(op.getLoc(),
|
||||
rewriter.getI64TensorAttr(reshaped_shape)));
|
||||
|
||||
// 2. Permute dimensions of `reshaped` to produce `permuted` of shape
|
||||
// [batch / prod(block_shape),
|
||||
//
|
||||
// input_shape[1], block_shape[0],
|
||||
// ...,
|
||||
// input_shape[M], block_shape[M-1],
|
||||
//
|
||||
// input_shape[M+1], ..., input_shape[N-1]]
|
||||
std::vector<int64_t> permutation(reshaped_shape.size());
|
||||
permutation[0] = block_rank;
|
||||
for (int i = 0; i < block_rank; ++i) {
|
||||
permutation[1 + 2 * i] = block_rank + 1 + i;
|
||||
permutation[1 + 2 * i + 1] = i;
|
||||
}
|
||||
std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
|
||||
1 + block_rank * 2);
|
||||
|
||||
std::vector<int64_t> transpose_shape(permutation.size());
|
||||
for (auto it : llvm::enumerate(permutation)) {
|
||||
transpose_shape[it.index()] = reshaped_shape[it.value()];
|
||||
}
|
||||
|
||||
auto permuted = rewriter.create<TF::TransposeOp>(
|
||||
op.getLoc(), RankedTensorType::get(transpose_shape, element_ty),
|
||||
reshaped,
|
||||
rewriter.create<ConstOp>(op.getLoc(),
|
||||
rewriter.getI64TensorAttr(permutation)));
|
||||
|
||||
// 3. Reshape `permuted` to produce `reshaped_permuted` of shape
|
||||
// [batch / prod(block_shape),
|
||||
//
|
||||
// input_shape[1] * block_shape[0],
|
||||
// ...,
|
||||
// input_shape[M] * block_shape[M-1],
|
||||
//
|
||||
// input_shape[M+1],
|
||||
// ...,
|
||||
// input_shape[N-1]]
|
||||
std::vector<int64_t> reshaped_permuted_shape(input_rank);
|
||||
auto block_shape_values = llvm::to_vector<4>(block_shape.getIntValues());
|
||||
reshaped_permuted_shape[0] = batch_size / block_num_elems;
|
||||
for (int i = 0; i < block_rank; ++i) {
|
||||
reshaped_permuted_shape[1 + i] =
|
||||
block_shape_values[i].getSExtValue() * input_shape[1 + i];
|
||||
}
|
||||
std::copy(remainder_shape.begin(), remainder_shape.end(),
|
||||
reshaped_permuted_shape.begin() + 1 + block_rank);
|
||||
|
||||
auto reshaped_permuted = rewriter.create<TF::ReshapeOp>(
|
||||
op.getLoc(), RankedTensorType::get(reshaped_permuted_shape, element_ty),
|
||||
permuted,
|
||||
rewriter.create<ConstOp>(
|
||||
op.getLoc(), rewriter.getI64TensorAttr(reshaped_permuted_shape)));
|
||||
|
||||
// 4. Crop the start and end of dimensions `[1, ..., M]` of
|
||||
// `reshaped_permuted` according to `crops` to produce the output of
|
||||
// shape:
|
||||
// [batch / prod(block_shape),
|
||||
//
|
||||
// input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
|
||||
// ...,
|
||||
// input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
|
||||
//
|
||||
// input_shape[M+1], ..., input_shape[N-1]]
|
||||
std::vector<int64_t> start_indices(input_rank, 0);
|
||||
std::vector<int64_t> slice_sizes = reshaped_permuted_shape;
|
||||
std::vector<int64_t> strides(input_rank, 1);
|
||||
auto crop_values = llvm::to_vector<4>(crops.getIntValues());
|
||||
for (int i = 0; i < block_rank; ++i) {
|
||||
int64_t crop_start = crop_values[i * 2].getSExtValue();
|
||||
int64_t crop_end = crop_values[i * 2 + 1].getSExtValue();
|
||||
|
||||
if (crop_start < 0 || crop_end < 0) {
|
||||
op.emitOpError() << "Crops must be non-negative";
|
||||
return failure();
|
||||
}
|
||||
|
||||
start_indices[i + 1] = crop_start;
|
||||
slice_sizes[i + 1] -= crop_start + crop_end;
|
||||
|
||||
if (slice_sizes[i + 1] < 0) {
|
||||
op.emitOpError() << "Cropped size must be non-negative: start: "
|
||||
<< crop_start << " end: " << crop_end << " size "
|
||||
<< reshaped_permuted_shape[1 + i];
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<TF::SliceOp>(
|
||||
op, RankedTensorType::get(slice_sizes, element_ty), reshaped_permuted,
|
||||
rewriter.create<ConstOp>(op.getLoc(),
|
||||
rewriter.getI64TensorAttr(start_indices)),
|
||||
rewriter.create<ConstOp>(op.getLoc(),
|
||||
rewriter.getI64TensorAttr(slice_sizes)));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Lowers `SparseMatMulOp` to `MatMulOp`, ignoring the sparseness hints,
|
||||
// since we currently don't have an implementation that can use this
|
||||
// information. Adds appropriate casts where necessary to align element types
|
||||
@ -1065,10 +1236,11 @@ class Lower_UnaryOpsComposition
|
||||
|
||||
void PopulateLoweringTFPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns) {
|
||||
patterns->insert<LowerAddNOp, ConvertFakeQuantWithMinMaxVarsOp,
|
||||
LowerDynamicStitchOp, LowerInvertPermutationOp,
|
||||
LowerLgammaOp, LowerPackOp, LowerSpaceToBatchNDOp,
|
||||
LowerSparseMatMulOp, Lower_UnaryOpsComposition>(context);
|
||||
patterns
|
||||
->insert<LowerAddNOp, ConvertFakeQuantWithMinMaxVarsOp,
|
||||
LowerDynamicStitchOp, LowerInvertPermutationOp, LowerLgammaOp,
|
||||
LowerPackOp, LowerBatchToSpaceND, LowerSpaceToBatchNDOp,
|
||||
LowerSparseMatMulOp, Lower_UnaryOpsComposition>(context);
|
||||
populateWithGenerated(context, *patterns);
|
||||
}
|
||||
|
||||
|
@ -279,6 +279,11 @@ CreateMarkOpsForOutsideCompilationPass();
// attribute to each TensorFlow dialect op in the body based on the `device`
// attribute on the `tf_device.launch`.
std::unique_ptr<OperationPass<ModuleOp>> CreateLaunchToDeviceAttributePass();

// Creates a pass that hoists a `tf_device.replicate` body and replicates each
// TensorFlow dialect op in the body based on its `device` attribute and the
// `devices` attribute on the `tf_device.replicate`.
std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateTFDeviceReplicationPass();
} // namespace TFDevice

namespace TFTPU {
@ -369,6 +374,12 @@ void CreateTPUBridgePipeline(OpPassManager& pm);
// bridge in V1 mode.
void CreateTPUBridgePipelineV1(OpPassManager& pm);

// Creates a pass that replicates the tf._TPUCompileMlir op on each host that
// needs the compiled program. It helps avoid transferring the compiled binary
// between hosts.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
CreateTPUCompileOpReplicationPass();

} // namespace TFTPU

} // namespace mlir
@ -0,0 +1,127 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This pass hoists a `tf_device.replicate` body and replicates each TensorFlow
|
||||
// dialect op in the body based on its `device` attribute and the `devices`
|
||||
// attribute on the `tf_device.replicate`.
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFDevice {
|
||||
namespace {
|
||||
|
||||
constexpr char kDeviceAttr[] = "device";
|
||||
|
||||
class TFDeviceReplicationPass
|
||||
: public PassWrapper<TFDeviceReplicationPass, OperationPass<ModuleOp>> {
|
||||
public:
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<TF::TensorFlowDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
const Dialect *tf_dialect = getContext().getLoadedDialect("tf");
|
||||
module.walk([&](tf_device::ReplicateOp replicate_op) {
|
||||
OpBuilder builder(replicate_op);
|
||||
// Map from the existing operation in ReplicateOp's region to a list of
|
||||
// its replicated operations.
|
||||
llvm::DenseMap<Operation *, llvm::SmallVector<Operation *, 4>>
|
||||
operation_map;
|
||||
llvm::Optional<DictionaryAttr> devices = replicate_op.devices();
|
||||
const int replicate_num = replicate_op.n();
|
||||
|
||||
// Replicates every operation in the region of the ReplicateOp to match
|
||||
// the number of devices.
|
||||
for (int i : llvm::seq<int>(0, replicate_num)) {
|
||||
// Gets the mapping from the packed and replicated block arguments to
|
||||
// the actual value. This mapping is used to replace the arguments used
|
||||
// by the cloned operations.
|
||||
BlockAndValueMapping mapping;
|
||||
for (BlockArgument &arg : replicate_op.GetBody().getArguments()) {
|
||||
Value new_arg =
|
||||
replicate_op.GetReplicaOperandForBlockArgument(arg, i);
|
||||
mapping.map(arg, new_arg);
|
||||
}
|
||||
for (Operation &op : replicate_op.GetBody().without_terminator()) {
|
||||
// Clones the operation and places it outside the replicate_op's body.
|
||||
llvm::SmallVector<Operation *, 4> &new_ops = operation_map[&op];
|
||||
Operation *new_op = builder.clone(op, mapping);
|
||||
new_ops.push_back(new_op);
|
||||
// If the op is a TF op with a string-valued device attribute whose value is
// an alias in the replicate_op's `devices` map, update the cloned op's
// device attribute to the i-th concrete device for that alias.
|
||||
if (!devices) continue;
|
||||
|
||||
if (op.getDialect() != tf_dialect) continue;
|
||||
|
||||
StringAttr device_alias =
|
||||
new_op->getAttrOfType<StringAttr>(kDeviceAttr);
|
||||
if (!device_alias) continue;
|
||||
|
||||
Attribute new_devices = devices->get(device_alias.getValue());
|
||||
if (!new_devices) continue;
|
||||
|
||||
ArrayAttr new_devices_array = new_devices.cast<ArrayAttr>();
|
||||
new_op->setAttr(kDeviceAttr, new_devices_array[i].cast<StringAttr>());
|
||||
}
|
||||
}
|
||||
// Replaces usages of the existing results of the tf_device.replicate
|
||||
// op with the results of the newly replicated operations.
|
||||
llvm::SmallVector<Value, 4> new_results;
|
||||
for (Value v : replicate_op.GetBody().getTerminator()->getOperands()) {
|
||||
OpResult result = v.dyn_cast<OpResult>();
|
||||
// Uses the original value if the value is not an OpResult.
|
||||
if (!result) {
|
||||
for (int i = 0; i < replicate_num; ++i) new_results.push_back(v);
|
||||
continue;
|
||||
}
|
||||
// Uses the original value if the value is defined by an op outside the
|
||||
// tf_device.replicate's body.
|
||||
Operation *op = result.getDefiningOp();
|
||||
if (operation_map.find(op) == operation_map.end()) {
|
||||
for (int i = 0; i < replicate_num; ++i) new_results.push_back(v);
|
||||
continue;
|
||||
}
|
||||
// Uses the values defined by the newly replicated operations.
|
||||
int result_num = result.getResultNumber();
|
||||
for (Operation *new_op : operation_map[op]) {
|
||||
new_results.push_back(new_op->getResult(result_num));
|
||||
}
|
||||
}
|
||||
replicate_op.replaceAllUsesWith(new_results);
|
||||
replicate_op.erase();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTFDeviceReplicationPass() {
|
||||
return std::make_unique<TFDeviceReplicationPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<TFDeviceReplicationPass> pass(
|
||||
"tf-device-replication",
|
||||
"Hoists and replicates the tf_device.replicate "
|
||||
"inner ops once for each associated device.");
|
||||
|
||||
} // namespace TFDevice
|
||||
} // namespace mlir
|
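The core of the pass above is the clone-with-mapping step: before the body of the tf_device.replicate is cloned for replica i, each block argument is mapped to the operand that replica actually receives, so operand references inside the clones are rewritten automatically. Stripped of the pass plumbing, that mechanism is roughly the following (a sketch; the free function is illustrative, not part of the patch):

// Sketch: cloning a replicate body once for a single replica, with block
// arguments remapped to that replica's operands.
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

void CloneBodyForReplica(mlir::Block& body, mlir::ValueRange replica_operands,
                         mlir::OpBuilder& builder) {
  mlir::BlockAndValueMapping mapping;
  // Point each block argument at the value this replica receives.
  for (auto pair : llvm::zip(body.getArguments(), replica_operands))
    mapping.map(std::get<0>(pair), std::get<1>(pair));
  // Clone everything except the terminator; the mapping rewrites operands
  // (and records result substitutions) as each op is cloned.
  for (mlir::Operation& op : body.without_terminator())
    builder.clone(op, mapping);
}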
@ -0,0 +1,103 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This pass replicates the tf._TPUCompileMlir op on each host that needs the
|
||||
// compiled program. It helps avoid transferring the compiled binary between
|
||||
// hosts.
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFTPU {
|
||||
namespace {
|
||||
|
||||
using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
|
||||
using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName;
|
||||
|
||||
constexpr char kDeviceAttr[] = "device";
|
||||
constexpr int kStatusResultIndex = 0;
|
||||
constexpr int kProgramResultIndex = 1;
|
||||
|
||||
static std::string GetHost(Operation *op) {
|
||||
if (StringAttr device = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
|
||||
ParsedName parsed_name;
|
||||
DeviceNameUtils::ParseFullName(device.getValue().str(), &parsed_name);
|
||||
return DeviceNameUtils::ParsedNameToString(
|
||||
DeviceNameUtils::AddressSpace(parsed_name));
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
class TPUCompileOpReplicationPass
|
||||
: public PassWrapper<TPUCompileOpReplicationPass, OperationPass<ModuleOp>> {
|
||||
void runOnOperation() override {
|
||||
getOperation().walk([&](TF::_TPUCompileMlirOp tpu_compile_op) {
|
||||
Value compiled_program = tpu_compile_op.getResult(kProgramResultIndex);
|
||||
std::string tpu_compile_op_host = GetHost(tpu_compile_op.getOperation());
|
||||
llvm::StringMap<Operation *> compile_op_by_host;
|
||||
llvm::SmallVector<OpOperand *, 4> usages;
|
||||
|
||||
for (OpOperand &usage : compiled_program.getUses()) {
|
||||
usages.push_back(&usage);
|
||||
}
|
||||
|
||||
// For any op which uses the program compiled on a different host than the
|
||||
// original tf._TPUCompileMlir op, replicate the tf._TPUCompileMlir op on
|
||||
// that host and update the op to use the program compiled on the same
|
||||
// host.
|
||||
for (OpOperand *usage : usages) {
|
||||
std::string usage_op_host = GetHost(usage->getOwner());
|
||||
if (usage_op_host == tpu_compile_op_host) continue;
|
||||
|
||||
Operation *&new_compile_op = compile_op_by_host[usage_op_host];
|
||||
// If it is not already created, create a tf._TPUCompileMlir op and a
|
||||
// tf.TPUCompileSucceededAssert op on the first CPU of the target host.
|
||||
if (!new_compile_op) {
|
||||
std::string device_name = usage_op_host + "/device:CPU:0";
|
||||
OpBuilder builder(tpu_compile_op);
|
||||
new_compile_op = builder.clone(*tpu_compile_op.getOperation());
|
||||
new_compile_op->setAttr(kDeviceAttr,
|
||||
StringAttr::get(device_name, &getContext()));
|
||||
TF::TPUCompileSucceededAssertOp new_assert_op =
|
||||
builder.create<TF::TPUCompileSucceededAssertOp>(
|
||||
new_compile_op->getLoc(),
|
||||
new_compile_op->getResult(kStatusResultIndex));
|
||||
new_assert_op.setAttr(kDeviceAttr,
|
||||
new_compile_op->getAttr(kDeviceAttr));
|
||||
}
|
||||
// Updates the operand to use the result of the newly created
|
||||
// tf._TPUCompileMlir op.
|
||||
usage->set(new_compile_op->getResult(kProgramResultIndex));
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUCompileOpReplicationPass() {
|
||||
return std::make_unique<TPUCompileOpReplicationPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<TPUCompileOpReplicationPass> pass(
|
||||
"tf-tpu-compile-replication",
|
||||
"Replicate the TPU compile op to avoid sending the compiled binary between "
|
||||
"hosts.");
|
||||
|
||||
} // namespace TFTPU
|
||||
} // namespace mlir
|
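One detail worth noting in the pass above: the uses of the compiled program are copied into the usages vector before any of them are rewritten, because redirecting an operand moves it off the use list that getUses() is walking. The same collect-then-rewrite idiom in isolation (a sketch; the helper and its predicate are illustrative):

// Sketch: redirecting a filtered subset of a Value's uses without mutating
// the use list while it is being iterated.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Operation.h"

void RedirectUses(mlir::Value from, mlir::Value to,
                  llvm::function_ref<bool(mlir::OpOperand&)> should_redirect) {
  llvm::SmallVector<mlir::OpOperand*, 4> uses;
  for (mlir::OpOperand& use : from.getUses()) uses.push_back(&use);
  for (mlir::OpOperand* use : uses)
    if (should_redirect(*use)) use->set(to);
}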
@ -3569,17 +3569,20 @@ Status SavedModelSignatureDefImporter::LiftVariables() {

  mlir::PassManager pm(module_->getContext());
  SetCrashReproducer(pm);
  pm.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass());
  pm.addPass(mlir::CreateExecutorDialectToFunctionalConversionPass());
  pm.addNestedPass<mlir::FuncOp>(
      mlir::tf_executor::CreateTFExecutorGraphPruningPass());
  pm.addNestedPass<mlir::FuncOp>(
      mlir::CreateExecutorDialectToFunctionalConversionPass());
  pm.addPass(
      mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass());
  pm.addPass(
  pm.addNestedPass<mlir::FuncOp>(
      mlir::TF::
          CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());
  pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
  pm.addPass(
      mlir::tf_saved_model::CreateLiftVariablesPass(bundle_.GetSession()));
  pm.addPass(mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
  pm.addNestedPass<mlir::FuncOp>(
      mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
  if (mlir::failed(pm.run(*module_)))
    return diag_handler.Combine(errors::Internal("Failed to lift variables."));

@ -279,7 +279,8 @@ void CreateConvertMlirToXlaHloPipeline(
  pm.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
  pm.addPass(mlir::TF::CreateStackOpsDecompositionPass());
  pm.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass());
  pm.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass());
  pm.addNestedPass<mlir::FuncOp>(
      mlir::TFDevice::CreateDecomposeResourceOpsPass());
  pm.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
  pm.addPass(mlir::createSymbolDCEPass());
  // Guarantee all functions have one use, which enables shape inference.
@ -36,7 +36,7 @@ void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {
  pm->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());

  // TFJS dialect passes.
  pm->addPass(mlir::tfjs::CreateOptimizePass());
  pm->addNestedPass<mlir::FuncOp>(mlir::tfjs::CreateOptimizePass());

  // Canonicalize, CSE etc.
  pm->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
@ -16,9 +16,11 @@ limitations under the License.

#include <string>

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {
@ -34,7 +36,14 @@ namespace tfr {
Status CompositeOpExpansion::Run(EagerOperation* orig_op,
                                 std::unique_ptr<EagerOperation>* out_op) {
  if (!IsEnabled()) return Status::OK();
  // This can be the default cpu device.
  if (orig_op->Device() != kVariantDeviceNull) return Status::OK();
  // TODO(fengliuai): We need a better condition to skip the rewrite. Currently
  // the rewrite is enabled for all TF ops and is a no-op if the op isn't a
  // composite op. "VarHandleOp" is explicitly skipped here because its
  // roundtrip fails for reasons that are not yet understood.
  if (orig_op->is_function()) return Status::OK();
  if (absl::StartsWith(orig_op->op_name(), "VarHandleOp")) return Status::OK();

  tf_core_op_expansion_node_counter->GetCell()->IncrementBy(1);

@ -103,6 +103,7 @@ std::unique_ptr<TFRDecomposeContext> TFRDecomposeContext::GetFromText(
      mlir::tf_executor::TensorFlowExecutorDialect,
      mlir::TFR::TFRDialect>();
  // clang-format on
  registry.loadAll(mlir_ctx);

  // Load the TFR functions into a mlir::ModuleOp.
  auto memory_buffer = llvm::MemoryBuffer::getMemBuffer(
@ -75,32 +75,32 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
applyTensorflowAndCLOptions(pm);
|
||||
|
||||
if (gpu_binary_only) {
|
||||
pm.addPass(mlir::mhlo::createLegalizeTFPass(
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
|
||||
/*allow_partial_conversion=*/false, /*legalize_chlo=*/true));
|
||||
pm.addNestedPass<mlir::FuncOp>(
|
||||
mlir::kernel_gen::transforms::CreateMaterializeBroadcastsPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(
|
||||
mlir::kernel_gen::transforms::CreateUnfuseBatchNormPass());
|
||||
pm.addPass(mlir::mhlo::createLegalizeToLhloPass(
|
||||
/*results_escape_functions=*/true));
|
||||
pm.addPass(mlir::mhlo::createLegalizeToLhloPass());
|
||||
// Moving `AllocOp`s and inserting missing `DeallocOp`s
|
||||
pm.addPass(::mlir::createBufferHoistingPass());
|
||||
pm.addPass(::mlir::createBufferDeallocationPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferHoistingPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferDeallocationPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::createCopyRemovalPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass());
|
||||
} else {
|
||||
pm.addPass(mlir::mhlo::createLegalizeTFPass(
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
|
||||
/*allow_partial_conversion=*/false, /*legalize_chlo=*/false));
|
||||
pm.addPass(mlir::createTransformUnrankedHloPass());
|
||||
pm.addPass(mlir::mhlo::createChloLegalizeToHloPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::createTransformUnrankedHloPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createChloLegalizeToHloPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass());
|
||||
// Clean up the IR created above. In particular, operations on descriptors
|
||||
// are simplified here.
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass());
|
||||
pm.addPass(mlir::kernel_gen::transforms::CreateParallelLoopsToSequential());
|
||||
pm.addNestedPass<mlir::FuncOp>(
|
||||
mlir::kernel_gen::transforms::CreateParallelLoopsToSequential());
|
||||
}
|
||||
|
||||
// Clean up the IR for further processing.
|
||||
@ -120,36 +120,41 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end());
|
||||
}
|
||||
// Transform LHLO operations to LinAlg.
|
||||
pm.addPass(::mlir::lmhlo::createLegalizeLhloToLinalgPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(
|
||||
::mlir::lmhlo::createLegalizeLhloToLinalgPass());
|
||||
// Fuse linalg operations.
|
||||
pm.addPass(::mlir::lmhlo::createLhloFuseLinalgPass(
|
||||
pm.addNestedPass<mlir::FuncOp>(::mlir::lmhlo::createLhloFuseLinalgPass(
|
||||
/*use_parallel_loops=*/true, tiling_for_unrolling));
|
||||
// Transform the Linalg operations inside of the loop nest into parallel
|
||||
// loops.
|
||||
pm.addPass(::mlir::createConvertLinalgToParallelLoopsPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(
|
||||
::mlir::createConvertLinalgToParallelLoopsPass());
|
||||
// Canonicalize the code to simplify index computations. This is needed so
|
||||
// that loop bounds have the same value.
|
||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
|
||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
|
||||
// Fuse the inner-most loops.
|
||||
pm.addPass(xla::mlir_gpu::createFuseInnerParallelLoopsPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(
|
||||
xla::mlir_gpu::createFuseInnerParallelLoopsPass());
|
||||
// Run CSE to ensure that loads and stores to the same subview get
|
||||
// recognized as such.
|
||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
|
||||
// Forward stores to buffers to loads.
|
||||
pm.addPass(xla::mlir_gpu::createStoreForwardingPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(xla::mlir_gpu::createStoreForwardingPass());
|
||||
// Remove now unused temporary buffers.
|
||||
pm.addPass(xla::mlir_gpu::createDeadTempBufferRemovalPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(
|
||||
xla::mlir_gpu::createDeadTempBufferRemovalPass());
|
||||
if (!unroll_factors.empty()) {
|
||||
pm.addPass(::mlir::createParallelLoopTilingPass(as_int64));
|
||||
pm.addNestedPass<mlir::FuncOp>(
|
||||
::mlir::createParallelLoopTilingPass(as_int64));
|
||||
}
|
||||
// Some basic cleanup.
|
||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
|
||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
|
||||
// Greedily map the remaining loop to GPU hardware dimensions.
|
||||
pm.addPass(xla::mlir_gpu::createMapParallelLoopsPass());
|
||||
pm.addNestedPass<::mlir::FuncOp>(xla::mlir_gpu::createMapParallelLoopsPass());
|
||||
// Apply the mapping.
|
||||
pm.addPass(mlir::createParallelLoopToGpuPass());
|
||||
pm.addNestedPass<::mlir::FuncOp>(mlir::createParallelLoopToGpuPass());
|
||||
|
||||
// Embed TF Framework ops.
|
||||
if (!gpu_binary_only) {
|
||||
@ -172,12 +177,13 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
|
||||
if (gpu_binary_only) {
|
||||
// Make kernel signature deterministic so that we can call it externally.
|
||||
pm.addPass(xla::mlir_gpu::createRewriteKernelSignaturePass());
|
||||
pm.addNestedPass<::mlir::FuncOp>(
|
||||
xla::mlir_gpu::createRewriteKernelSignaturePass());
|
||||
}
|
||||
pm.addPass(::mlir::createLowerAffinePass());
|
||||
|
||||
// Constraints are removed as late as possible and before lowering to CFG.
|
||||
pm.addPass(::mlir::createConvertShapeConstraintsPass());
|
||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createConvertShapeConstraintsPass());
|
||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
|
||||
|
||||
pm.addPass(::mlir::createLowerToCFGPass());
|
||||
|
@ -308,7 +308,7 @@ func @abs_unranked_i64(%arg : memref<*xi64>,
|
||||
%flat_shape : memref<1xindex>,
|
||||
%arg_size : index) -> memref<*xi64>
|
||||
attributes {tf_entry} {
|
||||
%flat_arg = lmhlo.reshape_memref_cast %arg(%flat_shape)
|
||||
%flat_arg = memref_reshape %arg(%flat_shape)
|
||||
: (memref<*xi64>, memref<1xindex>) -> memref<?xi64>
|
||||
// CHECK: alloc
|
||||
// CHECK-SAME: reuse_input_candidates = [0 : index], reuse_output = 0 : index
|
||||
@ -324,7 +324,7 @@ func @abs_unranked_i64(%arg : memref<*xi64>,
|
||||
%a_abs = select %a_pos, %a, %a_neg : i64
|
||||
linalg.yield %a_abs : i64
|
||||
}
|
||||
%result = lmhlo.reshape_memref_cast %flat_result(%arg_shape)
|
||||
%result = memref_reshape %flat_result(%arg_shape)
|
||||
: (memref<?xi64>, memref<?xindex>) -> memref<*xi64>
|
||||
return %result : memref<*xi64>
|
||||
}
|
||||
|
@ -76,3 +76,39 @@ func @assuming(%witness: !shape.witness, %arg : memref<?xf32>)
|
||||
}
|
||||
return %assuming_result : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @const
|
||||
// CHECK-SAME: -> memref<3xf32>
|
||||
func @const() -> tensor<3xf32> {
|
||||
// CHECK: %[[MEM:.*]] = alloca() : memref<3xf32>
|
||||
// CHECK: %[[C4:.*]] = constant 4.000000e+00 : f32
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: store %[[C4]], %[[MEM]][%[[C0]]] : memref<3xf32>
|
||||
// CHECK: %[[C5:.*]] = constant 5.000000e+00 : f32
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: store %[[C5]], %[[MEM]][%[[C1]]] : memref<3xf32>
|
||||
// CHECK: %[[C6:.*]] = constant 6.000000e+00 : f32
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: store %[[C6]], %[[MEM]][%[[C2]]] : memref<3xf32>
|
||||
// CHECK-NEXT: return %[[MEM]] : memref<3xf32>
|
||||
%result = constant dense<[4.0, 5.0, 6.0]> : tensor<3xf32>
|
||||
return %result : tensor<3xf32>
|
||||
}
|
||||
|
||||
|
||||
|
||||
// CHECK-LABEL: @const_splat
|
||||
// CHECK-SAME: -> memref<3xf32>
|
||||
func @const_splat() -> tensor<3xf32> {
|
||||
// CHECK: %[[MEM:.*]] = alloca() : memref<3xf32>
|
||||
// CHECK: %[[C4:.*]] = constant 4.000000e+00 : f32
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: store %[[C4]], %[[MEM]][%[[C0]]] : memref<3xf32>
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: store %[[C4]], %[[MEM]][%[[C1]]] : memref<3xf32>
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: store %[[C4]], %[[MEM]][%[[C2]]] : memref<3xf32>
|
||||
// CHECK-NEXT: return %[[MEM]] : memref<3xf32>
|
||||
%result = constant dense<4.0> : tensor<3xf32>
|
||||
return %result : tensor<3xf32>
|
||||
}
|
||||
|
@ -12,9 +12,9 @@ func @tanh(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK-NOT: tensor_load
|
||||
// CHECK: scf.for
|
||||
// CHECK-NOT: tensor_from_elements
|
||||
// CHECK: mhlo.reshape_memref_cast
|
||||
// CHECK: memref_reshape
|
||||
// CHECK: lmhlo.tanh
|
||||
// CHECK: mhlo.reshape_memref_cast
|
||||
// CHECK: memref_reshape
|
||||
%0 = "tf.Tanh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ module attributes {gpu.container_module} {
|
||||
^bb6: // pred: ^bb4
|
||||
%13 = alloca() : memref<1xindex>
|
||||
store %8, %13[%c0] : memref<1xindex>
|
||||
%14 = lmhlo.reshape_memref_cast %arg0(%13) : (memref<*xf32>, memref<1xindex>) -> memref<?xf32>
|
||||
%14 = memref_reshape %arg0(%13) : (memref<*xf32>, memref<1xindex>) -> memref<?xf32>
|
||||
%15 = dim %14, %c0 : memref<?xf32>
|
||||
%16 = tf_framework.alloc(%ctx, %15) : memref<?xf32>
|
||||
%17 = cmpi "sle", %15, %c0 : index
|
||||
@ -53,7 +53,7 @@ module attributes {gpu.container_module} {
|
||||
gpu.launch_func @abs_kernel::@abs_kernel
|
||||
blocks in (%24, %c1, %c1) threads in (%c256, %c1, %c1)
|
||||
args(%14 : memref<?xf32>, %16 : memref<?xf32>)
|
||||
%25 = lmhlo.reshape_memref_cast %16(%1) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||
%25 = memref_reshape %16(%1) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||
return %25 : memref<*xf32>
|
||||
}
|
||||
|
||||
|
@ -35,11 +35,9 @@ cc_library(
|
||||
srcs = ["bufferize.cc"],
|
||||
hdrs = ["rewriters.h"],
|
||||
deps = [
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
@ -128,7 +126,6 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo",
|
||||
"//tensorflow/compiler/xla/service/gpu:stream_executor_util",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_llvm",
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
|
||||
] + if_cuda_is_configured([
|
||||
"//tensorflow/stream_executor/gpu:asm_compiler",
|
||||
|
@ -110,8 +110,8 @@ class BufferSizeAnalysis {
|
||||
});
|
||||
|
||||
// Operand and result of `reshape_memref_cast` must be of same size.
|
||||
f.walk([&](lmhlo::ReshapeMemRefCastOp reshapeOp) {
|
||||
ecs_.unionSets(reshapeOp.result(), reshapeOp.operand());
|
||||
f.walk([&](MemRefReshapeOp reshapeOp) {
|
||||
ecs_.unionSets(reshapeOp.result(), reshapeOp.source());
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -17,28 +17,49 @@ limitations under the License.
|
||||
|
||||
#include "mlir/Transforms/Bufferize.h" // from @llvm-project
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace kernel_gen {
|
||||
namespace transforms {
|
||||
|
||||
namespace {
|
||||
|
||||
class ConstantOpConverter : public OpConversionPattern<ConstantOp> {
|
||||
public:
|
||||
using OpConversionPattern<ConstantOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
ConstantOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// We only need to bufferize tensor constants.
|
||||
Location loc = op.getLoc();
|
||||
auto result_type = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!result_type || !result_type.hasStaticShape() ||
|
||||
result_type.getRank() != 1)
|
||||
return failure();
|
||||
|
||||
auto memref_type = MemRefType::get({result_type.getNumElements()},
|
||||
result_type.getElementType());
|
||||
Value buffer = rewriter.create<AllocaOp>(loc, memref_type);
|
||||
|
||||
auto elements_attr = op.getValue().dyn_cast<DenseElementsAttr>();
|
||||
bool all_same_elems = elements_attr.isSplat();
|
||||
Value value;
|
||||
if (all_same_elems)
|
||||
value = rewriter.create<ConstantOp>(loc, elements_attr.getSplatValue());
|
||||
for (auto en : llvm::enumerate(elements_attr.getAttributeValues())) {
|
||||
if (!all_same_elems) value = rewriter.create<ConstantOp>(loc, en.value());
|
||||
Value index = rewriter.create<ConstantIndexOp>(loc, en.index());
|
||||
rewriter.create<StoreOp>(loc, value, buffer, index);
|
||||
}
|
||||
rewriter.replaceOp(op, {buffer});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class TensorFromElementsOpConverter
|
||||
: public OpConversionPattern<TensorFromElementsOp> {
|
||||
public:
|
||||
@ -48,14 +69,14 @@ class TensorFromElementsOpConverter
|
||||
TensorFromElementsOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
Location loc = op.getLoc();
|
||||
ShapedType result_type = op.getType().cast<ShapedType>();
|
||||
auto result_type = op.getType().cast<ShapedType>();
|
||||
int number_of_elements = op.elements().size();
|
||||
MemRefType memref_type =
|
||||
MemRefType::get({number_of_elements}, result_type.getElementType());
|
||||
Value result = rewriter.create<AllocaOp>(loc, memref_type);
|
||||
for (auto operand : llvm::enumerate(operands)) {
|
||||
Value index = rewriter.create<ConstantIndexOp>(loc, operand.index());
|
||||
rewriter.create<StoreOp>(loc, operand.value(), result, index);
|
||||
for (auto en : llvm::enumerate(operands)) {
|
||||
Value index = rewriter.create<ConstantIndexOp>(loc, en.index());
|
||||
rewriter.create<StoreOp>(loc, en.value(), result, index);
|
||||
}
|
||||
rewriter.replaceOp(op, {result});
|
||||
return success();
|
||||
@ -73,7 +94,7 @@ class DynamicTensorFromElementsOpConverter
|
||||
// Allocate memory on stack.
|
||||
Location loc = op.getLoc();
|
||||
DynamicTensorFromElementsOp::Adaptor transformed(operands);
|
||||
RankedTensorType tensor_ty = op.getType().cast<RankedTensorType>();
|
||||
auto tensor_ty = op.getType().cast<RankedTensorType>();
|
||||
MemRefType memref_type =
|
||||
MemRefType::get(tensor_ty.getShape(), tensor_ty.getElementType());
|
||||
Value result = rewriter.create<AllocaOp>(loc, memref_type,
|
||||
@ -87,7 +108,7 @@ class DynamicTensorFromElementsOpConverter
|
||||
SmallVector<Value, 4> steps(rank, one);
|
||||
SmallVector<Value, 4> upper_bounds;
|
||||
int next_dynamic_index = 0;
|
||||
for (int i = 0; i < rank; i++) {
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
Value ub = tensor_ty.isDynamicDim(i)
|
||||
? transformed.dynamicExtents()[next_dynamic_index++]
|
||||
: rewriter.create<ConstantIndexOp>(
|
||||
@ -139,7 +160,6 @@ class ExtractElementOpConversion
|
||||
if (!adaptor.aggregate().getType().isa<BaseMemRefType>()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.aggregate(),
|
||||
adaptor.indices());
|
||||
return success();
|
||||
@ -182,7 +202,8 @@ class TensorCastOpConverter : public OpConversionPattern<TensorCastOp> {
|
||||
void populateStandardBufferizePattern(MLIRContext *context,
|
||||
BufferizeTypeConverter *converter,
|
||||
OwningRewritePatternList *patterns) {
|
||||
patterns->insert<ExtractElementOpConversion, TensorFromElementsOpConverter,
|
||||
patterns->insert<ConstantOpConverter, ExtractElementOpConversion,
|
||||
TensorFromElementsOpConverter,
|
||||
DynamicTensorFromElementsOpConverter,
|
||||
SimpleOpResultConversion<SelectOp>, TensorLoadOpConversion,
|
||||
TensorCastOpConverter>(*converter, context);
|
||||
|
@ -98,7 +98,8 @@ struct BufferizePass : public BufferizePassBase<BufferizePass> {
      return converter.isLegal(inputs) && converter.isLegal(results) &&
             converter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<CallOp, ReturnOp, SelectOp>(typesAreLegal);
    target.addDynamicallyLegalOp<CallOp, ReturnOp, SelectOp, ConstantOp>(
        typesAreLegal);

    OwningRewritePatternList patterns;
    mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
@ -162,8 +162,8 @@ class ShapeEqualityKnowledge {
  /// results.
  void build(FuncOp function) {
    function.walk([&](Operation *op) {
      if (auto reshape = dyn_cast<lmhlo::ReshapeMemRefCastOp>(op)) {
        registerAssociation(ShapeValue{reshape.operand()}, reshape.result());
      if (auto reshape = dyn_cast<MemRefReshapeOp>(op)) {
        registerAssociation(ShapeValue{reshape.source()}, reshape.result());
        return;
      }
      if (auto alloc = dyn_cast<AllocOp>(op)) {
@ -133,7 +133,7 @@ struct PropagateTfAbiKnowledgeToKernelsPass
    while (!worklist.empty()) {
      Value candidate = worklist.pop_back_val();
      for (auto user : candidate.getUsers()) {
        if (auto reshape = dyn_cast<lmhlo::ReshapeMemRefCastOp>(user)) {
        if (auto reshape = dyn_cast<MemRefReshapeOp>(user)) {
          // Reshape propagates alignment and offset.
          // TODO(herhut): This should be a trait.
          if (allocated_by_runtime.insert(reshape.result()).second) {
@ -19,7 +19,6 @@ limitations under the License.
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
|
||||
@ -53,13 +52,12 @@ class TFKernelToLLVMPass : public TFKernelToLLVMPassBase<TFKernelToLLVMPass> {
|
||||
tf_framework::PopulateTFFrameworkToLLVMConversionPatterns(&type_converter,
|
||||
&patterns);
|
||||
populateGpuToLLVMConversionPatterns(type_converter, patterns, "gpu.binary");
|
||||
lmhlo::PopulateLhloToLLVMConversionPatterns(&type_converter, &patterns);
|
||||
|
||||
// Set target.
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
target
|
||||
.addIllegalDialect<gpu::GPUDialect, tf_framework::TFFrameworkDialect>();
|
||||
target.addIllegalDialect<gpu::GPUDialect, StandardOpsDialect,
|
||||
tf_framework::TFFrameworkDialect>();
|
||||
target.addIllegalOp<LLVM::DialectCastOp>();
|
||||
|
||||
if (failed(applyPartialConversion(m, target, std::move(patterns)))) {
|
||||
|
@ -2527,6 +2527,27 @@ func @expand_dims_dynamic(%arg0: tensor<?x?xf32>) -> tensor<?x1x?xf32> {
|
||||
return %0 : tensor<?x1x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: expand_dynamic_dims_rank1_axis
|
||||
func @expand_dynamic_dims_rank1_axis(%arg0: tensor<?x?x4xf32>) -> tensor<?x1x?x4xf32> {
|
||||
%axis = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
|
||||
// CHECK-DAG: [[SHAPEOF:%.+]] = shape.shape_of %arg0
|
||||
// CHECK-DAG: [[CST0:%.+]] = constant 0
|
||||
// CHECK-DAG: [[CST1:%.+]] = constant 1
|
||||
// CHECK-DAG: [[GETEXTENT0:%.+]] = shape.get_extent [[SHAPEOF]], [[CST0]]
|
||||
// CHECK-DAG: [[CST1_0:%.+]] = constant 1
|
||||
// CHECK-DAG: [[GETEXTENT1:%.+]] = shape.get_extent [[SHAPEOF]], [[CST1_0]]
|
||||
// CHECK-DAG: [[CST2:%.+]] = constant 2
|
||||
// CHECK-DAG: [[GETEXTENT2:%.+]] = shape.get_extent [[SHAPEOF]], [[CST2]]
|
||||
// CHECK-DAG: [[FROMEXTENTS:%.+]] = shape.from_extents [[GETEXTENT0]], [[CST1]], [[GETEXTENT1]], [[GETEXTENT2]]
|
||||
// CHECK-DAG: [[TOEXTENTS:%.+]] = shape.to_extent_tensor [[FROMEXTENTS]]
|
||||
// CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.dynamic_reshape"(%arg0, [[TOEXTENTS]])
|
||||
%0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?x4xf32>, tensor<1xi32>) -> tensor<?x1x?x4xf32>
|
||||
|
||||
// CHECK: return [[RESHAPE]]
|
||||
return %0 : tensor<?x1x?x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @sign
|
||||
// CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32>
|
||||
func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
|
||||
|
@ -5535,10 +5535,7 @@ class ConvertDynamicExpandDimsOp : public OpRewritePattern<TF::ExpandDimsOp> {
    llvm::SmallVector<Value, 4> dims;
    dims.resize(result_ty.getRank());

    auto inserted_dim = expand_dims_attr.getValue({})
                            .cast<IntegerAttr>()
                            .getValue()
                            .getSExtValue();
    auto inserted_dim = expand_dims[0].getSExtValue();

    // Handle the negative value use case.
    if (inserted_dim < 0) {
@ -103,7 +103,6 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
    TypeID::get<TF::AtanhOp>(),
    TypeID::get<TF::AtanOp>(),
    TypeID::get<TF::BatchMatMulV2Op>(),
    TypeID::get<TF::BatchToSpaceNDOp>(),
    TypeID::get<TF::BatchToSpaceOp>(),
    TypeID::get<TF::BesselI0eOp>(),
    TypeID::get<TF::BesselI1eOp>(),
@ -519,14 +519,21 @@ StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
|
||||
|
||||
// ViewOp only takes memrefs without affine maps (layouts). Let ViewOp produce
|
||||
// the physical shape (where dimensions are ordered in major to minor) first,
|
||||
// then follow up with a StaticMemRefCastOp to cast the resulting memref to
|
||||
// then follow up with a MemRefReinterpretCast to cast the resulting memref to
|
||||
// the original layout.
|
||||
Value result =
|
||||
builder_.create<ViewOp>(loc, physical_out_type, alloc, byte_shift,
|
||||
/*sizes=*/ValueRange{});
|
||||
if (physical_out_type != out_type)
|
||||
result = builder_.create<lmhlo::StaticMemRefCastOp>(loc, out_memref_type,
|
||||
result);
|
||||
if (physical_out_type != out_type) {
|
||||
int64_t out_offset;
|
||||
SmallVector<int64_t, 4> out_strides;
|
||||
if (failed(getStridesAndOffset(out_memref_type, out_strides, out_offset)))
|
||||
return tensorflow::errors::Internal(
|
||||
"Failed to get strides and offset from the output type.");
|
||||
result = builder_.create<MemRefReinterpretCastOp>(
|
||||
loc, out_memref_type, result, out_offset, out_memref_type.getShape(),
|
||||
out_strides, llvm::None, llvm::None, llvm::None);
|
||||
}
|
||||
return cached_value = result;
|
||||
}
|
||||
|
||||
|
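The replacement above swaps the old lmhlo::StaticMemRefCastOp for the upstream MemRefReinterpretCastOp, which needs the target layout spelled out as an offset plus per-dimension sizes and strides; getStridesAndOffset is what recovers those from the memref type, and it can fail for layouts that are not in strided form, hence the new error path. A small sketch of that recovery step (standalone; header names reflect the MLIR layout of this era):

// Sketch: extracting the static offset and strides that feed
// MemRefReinterpretCastOp from a memref type's affine layout.
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LogicalResult.h"

bool GetStridedLayout(mlir::MemRefType type,
                      llvm::SmallVectorImpl<int64_t>& strides,
                      int64_t& offset) {
  // Fails for layouts that cannot be expressed as offset + strides.
  return mlir::succeeded(mlir::getStridesAndOffset(type, strides, offset));
}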
@ -1481,7 +1481,7 @@ tf_xla_py_test(
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    use_xla_device = False,  # Uses tf.function(experimental_compile=True)
    use_xla_device = False,  # Uses tf.function(jit_compile=True)
    deps = [
        ":xla_test",
        "//tensorflow/compiler/tf2xla/python:xla",
@ -32,7 +32,7 @@ class CaseTest(xla_test.XLATestCase):
|
||||
|
||||
def testCaseBasic(self):
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
@def_function.function(jit_compile=True)
|
||||
def switch_case_test(branch_index):
|
||||
|
||||
def f1():
|
||||
@ -58,7 +58,7 @@ class CaseTest(xla_test.XLATestCase):
|
||||
|
||||
def testBranchIsPruned(self):
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
@def_function.function(jit_compile=True)
|
||||
def switch_case_test():
|
||||
branch_index = array_ops.constant(0)
|
||||
|
||||
|
@ -693,7 +693,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
|
||||
return x, y
|
||||
|
||||
wholly_compiled_f = def_function.function(f)
|
||||
op_by_op_f = def_function.function(f, experimental_compile=False)
|
||||
op_by_op_f = def_function.function(f, jit_compile=False)
|
||||
|
||||
x = array_ops.identity([0.0, 2.0], name='data')
|
||||
|
||||
|
@ -45,22 +45,22 @@ flags.DEFINE_bool('vary_seed', False,
|
||||
NUM_SAMPLES = int(1e3)
|
||||
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
@def_function.function(jit_compile=True)
|
||||
def _igamma(a, x):
|
||||
return math_ops.igamma(a, x)
|
||||
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
@def_function.function(jit_compile=True)
|
||||
def _igammac(a, x):
|
||||
return math_ops.igammac(a, x)
|
||||
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
@def_function.function(jit_compile=True)
|
||||
def _polygamma(n, x):
|
||||
return math_ops.polygamma(n, x)
|
||||
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
@def_function.function(jit_compile=True)
|
||||
def _zeta(a, q):
|
||||
return math_ops.zeta(a, q)
|
||||
|
||||
@ -72,7 +72,7 @@ def implicit_reparameterization_grad(a, x):
|
||||
return -gen_math_ops.igamma_grad_a(a, x) / prob
|
||||
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
@def_function.function(jit_compile=True)
|
||||
def _log1p(x):
|
||||
return math_ops.log1p(x)
|
||||
|
||||
|
@ -51,7 +51,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
    scenarios (e.g. TPU). The new version of stateless_random_* requires the
    intermediate tensor `alg` to be compile-time constant, so we need to check
    that this requirement is met. We use xla.compile instead of tf.function's
    experimental_compile because the latter doesn't throw an error even if the
    jit_compile because the latter doesn't throw an error even if the
    compile-time-constant constraint is not met.
    """
    if config.list_logical_devices('TPU'):
@ -223,12 +223,10 @@ class StridedSliceOp : public XlaOpKernel {
      input_elements_sliced *= input_shape.dim_size(d);
    }

    OP_REQUIRES(
        ctx, output_elements == input_elements_sliced,
        errors::InvalidArgument(
            "The number of output elements ", output_elements,
            " has to equal to number of input elements that are sliced ",
            input_elements_sliced, " when input indices are not constant."));
    OP_REQUIRES(ctx, output_elements == input_elements_sliced,
                errors::InvalidArgument(
                    "Dynamic indices of strided_slice_op have to be leading "
                    "dimensions in the indices list."));

    for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) {
      OP_REQUIRES(
@ -197,6 +197,9 @@ class XlaCompiler {
    // Alias input and output buffers for parameters that are passed-through XLA
    // modules without being changed.
    bool alias_passthrough_params = false;

    // Enable detailed logging of compilation metadata.
    bool detailed_logging = true;
  };

  explicit XlaCompiler(Options options);