Merge branch 'master' into update-array-ops-docstrings
This commit is contained in:
commit
16028e280d
1
.gitignore
vendored
1
.gitignore
vendored
@ -38,6 +38,7 @@ gradleBuild
|
||||
*.pbxproj
|
||||
*.xcworkspace
|
||||
/*.podspec
|
||||
/tensorflow/lite/**/coreml/**/BUILD
|
||||
/tensorflow/lite/**/ios/BUILD
|
||||
/tensorflow/lite/**/objc/BUILD
|
||||
/tensorflow/lite/**/swift/BUILD
|
||||
|
@ -58,6 +58,8 @@ NCCL_LIB_PATHS = [
|
||||
|
||||
# List of files to configure when building Bazel on Apple platforms.
|
||||
APPLE_BAZEL_FILES = [
|
||||
'tensorflow/lite/experimental/delegates/coreml/BUILD',
|
||||
'tensorflow/lite/experimental/delegates/coreml/builders/BUILD',
|
||||
'tensorflow/lite/experimental/ios/BUILD',
|
||||
'tensorflow/lite/experimental/objc/BUILD',
|
||||
'tensorflow/lite/experimental/swift/BUILD',
|
||||
|
@ -639,7 +639,7 @@ tf_cc_shared_object(
|
||||
"//tensorflow/cc/saved_model:loader_lite_impl",
|
||||
"//tensorflow/core:core_cpu_impl",
|
||||
"//tensorflow/core:framework_internal_impl",
|
||||
"//tensorflow/core:gpu_runtime_impl",
|
||||
"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
|
||||
"//tensorflow/core:lib_internal_impl",
|
||||
"//tensorflow/core/profiler:profiler_impl",
|
||||
|
@ -995,9 +995,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::Tensor tensor = tensorflow::TensorFromInterface(t);
|
||||
t->Release();
|
||||
return tensorflow::TF_TensorFromTensor(tensor, &status->status);
|
||||
return new TF_Tensor{t};
|
||||
}
|
||||
|
||||
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
|
||||
|
@ -580,12 +580,6 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||
};
|
||||
}
|
||||
|
||||
void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
h->handle->EnableImplicitMirroring();
|
||||
status->status = tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||
TF_Buffer* buf, TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
|
@ -392,12 +392,6 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
|
||||
// If the TensorHandle is copied to another device as part of an op execution,
|
||||
// the copy is destroyed after the op has executed. Enabling implicit mirroring
|
||||
// causes the copy to be held as a mirror for the lifetime of the TensorHandle.
|
||||
TF_CAPI_EXPORT extern void TFE_TensorHandleEnableImplicitMirroring(
|
||||
TFE_TensorHandle*, TF_Status*);
|
||||
|
||||
// This function will block till the operation that produces `h` has
|
||||
// completed. This is only valid on local TFE_TensorHandles. The pointer
|
||||
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
|
||||
|
@ -168,8 +168,6 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
auto* h1_task2 =
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Handles are on task0 (local), and task2, but op is on task1.
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
||||
|
@ -594,7 +594,6 @@ void ExecuteAdd(bool async, bool forward_input) {
|
||||
TFE_TensorHandle* n_gpu =
|
||||
TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandleEnableImplicitMirroring(n_gpu, status);
|
||||
TFE_DeleteTensorHandle(n);
|
||||
n = n_gpu;
|
||||
}
|
||||
|
@ -59,14 +59,6 @@ class AbstractTensorHandleInterface {
|
||||
// Return a copy of the handle.
|
||||
virtual AbstractTensorHandleInterface* Copy() = 0;
|
||||
|
||||
// Maintain mirror tensors for any implicit copies to local devices. This
|
||||
// setting is offered on a per tensor handle basis to avoid potential memory
|
||||
// over utilization due to holding on to mirrors as well as the original
|
||||
// tensor. Note this setting overrides the context mirroring policy whereby if
|
||||
// the mirroring policy is MIRRORING_NONE, we will still continue to mirror
|
||||
// this tensor.
|
||||
virtual void EnableImplicitMirroring() = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractTensorHandleInterface() {}
|
||||
};
|
||||
|
@ -118,8 +118,8 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:gpu_init",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/common_runtime/gpu:gpu_init",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
|
@ -72,6 +72,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
|
||||
"//tensorflow/compiler/mlir/xla:buffer_assignment",
|
||||
"//tensorflow/compiler/mlir/xla:hlo",
|
||||
"//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo",
|
||||
|
@ -1,6 +1,7 @@
|
||||
book_path: /mlir/_book.yaml
|
||||
project_path: /mlir/_project.yaml
|
||||
description: <!--no description-->
|
||||
description: An intermediate representation and compiler framework, MLIR unifies the
|
||||
infrastructure for high-performance ML models in TensorFlow.
|
||||
landing_page:
|
||||
custom_css_path: /site-assets/css/style.css
|
||||
rows:
|
||||
|
@ -771,6 +771,7 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/lite/tools/optimize:quantize_weights",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
|
@ -36,7 +36,7 @@ struct PassConfig {
|
||||
form_clusters(false),
|
||||
unfold_batch_matmul(true),
|
||||
legalize_tf_while(true),
|
||||
shape_inference(false) {}
|
||||
shape_inference(true) {}
|
||||
|
||||
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
|
||||
// added, which produces TF Lite ops.
|
||||
|
@ -409,10 +409,14 @@ static void GenOperandResultVerifier(raw_ostream &os,
|
||||
os << " (void)v;\n"
|
||||
<< " if (!("
|
||||
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
|
||||
<< " if (failure_on_operand_type_mismatch) {\n"
|
||||
<< formatv(
|
||||
" return op->emitOpError(\"{0} #\") << index "
|
||||
"<< \" must be {1}, but got \" << v.getType();\n",
|
||||
valueKind, desc)
|
||||
<< " } else {\n"
|
||||
<< " return ::mlir::LogicalResult::Failure;\n"
|
||||
<< " }\n"
|
||||
<< " }\n" // if
|
||||
<< " ++index;\n"
|
||||
<< " }\n"; // for
|
||||
@ -437,7 +441,8 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
||||
|
||||
mlir::tblgen::FmtContext verify_ctx;
|
||||
os << "::mlir::LogicalResult " << op.getCppClassName()
|
||||
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n";
|
||||
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
|
||||
"failure_on_operand_type_mismatch) {\n";
|
||||
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
|
||||
verify_ctx.withOp("top");
|
||||
|
||||
|
@ -70,6 +70,19 @@ class TFLiteCostEstimator<Conv2DOp, hardware::GPU> {
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.cos
|
||||
template <>
|
||||
class TFLiteCostEstimator<CosOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.depthwise_conv_2d
|
||||
template <>
|
||||
class TFLiteCostEstimator<DepthwiseConv2DOp, hardware::GPU> {
|
||||
@ -83,6 +96,32 @@ class TFLiteCostEstimator<DepthwiseConv2DOp, hardware::GPU> {
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.div
|
||||
template <>
|
||||
class TFLiteCostEstimator<DivOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.exp
|
||||
template <>
|
||||
class TFLiteCostEstimator<ExpOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.fully_connected
|
||||
template <>
|
||||
class TFLiteCostEstimator<FullyConnectedOp, hardware::GPU> {
|
||||
@ -97,6 +136,19 @@ class TFLiteCostEstimator<FullyConnectedOp, hardware::GPU> {
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.hard_swish
|
||||
template <>
|
||||
class TFLiteCostEstimator<HardSwishOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.logistic
|
||||
template <>
|
||||
class TFLiteCostEstimator<LogisticOp, hardware::GPU> {
|
||||
|
@ -138,6 +138,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
||||
return tflite::TensorType_FLOAT32;
|
||||
case mlir::StandardTypes::F16:
|
||||
return tflite::TensorType_FLOAT16;
|
||||
case mlir::StandardTypes::F64:
|
||||
return tflite::TensorType_FLOAT64;
|
||||
case mlir::TF::TensorFlowTypes::STRING:
|
||||
return tflite::TensorType_STRING;
|
||||
case mlir::TF::TensorFlowTypes::QUINT8:
|
||||
|
@ -353,6 +353,22 @@ StatusOr<mlir::ElementsAttr> ConvertFloatBuffer(
|
||||
}
|
||||
return DenseElementsAttr::get(shaped_type, ArrayRef<float>(values));
|
||||
}
|
||||
case 64: {
|
||||
assert(bytes_len % 8 == 0);
|
||||
size_t elem_count = bytes_len / 8;
|
||||
std::vector<double> values;
|
||||
values.reserve(elem_count);
|
||||
|
||||
const char* data = reinterpret_cast<const char*>(buffer.data());
|
||||
|
||||
for (int i = 0; i < elem_count; i++) {
|
||||
uint64_t bit_repr =
|
||||
llvm::support::endian::readNext<uint64_t, llvm::support::little,
|
||||
llvm::support::unaligned>(data);
|
||||
values.push_back(absl::bit_cast<double>(bit_repr));
|
||||
}
|
||||
return DenseElementsAttr::get(shaped_type, ArrayRef<double>(values));
|
||||
}
|
||||
}
|
||||
return errors::InvalidArgument("unsupported bit width", elem_type.getWidth());
|
||||
}
|
||||
|
@ -86,7 +86,8 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
|
||||
let methods = [
|
||||
StaticInterfaceMethod<
|
||||
[{Returns whether the op's operands/results are supported by runtime.}],
|
||||
"LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op)
|
||||
"LogicalResult", "VerifyTflRuntimeTypes",
|
||||
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
@ -706,7 +706,10 @@ def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
|
||||
}
|
||||
|
||||
def TFL_CosOp: TFL_Op<"cos", [
|
||||
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultType,
|
||||
NoQuantizableResult,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Cosine operator";
|
||||
|
||||
let description = [{
|
||||
@ -827,12 +830,12 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [NoSideEffect]> {
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$params,
|
||||
TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8, TFL_Str]>:$params,
|
||||
TFL_I32OrI64Tensor:$indices
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output
|
||||
TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8, TFL_Str]>:$output
|
||||
);
|
||||
}
|
||||
|
||||
@ -1108,7 +1111,10 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
|
||||
def TFL_DivOp : TFL_Op<"div", [
|
||||
// TODO(fengliuai): NoQuantizableResult is only correct for int8
|
||||
// quantization. update to handle Uint8 quantization.
|
||||
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
|
||||
ResultsBroadcastableShape,
|
||||
NoSideEffect,
|
||||
NoQuantizableResult,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Division operator";
|
||||
|
||||
let description = [{
|
||||
@ -1187,7 +1193,9 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape,
|
||||
let builders = [TFL_ComparisonBinaryBuilder];
|
||||
}
|
||||
|
||||
def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect,
|
||||
SameOperandsAndResultType,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Natural exponentiation operator";
|
||||
|
||||
let description = [{
|
||||
@ -1369,7 +1377,8 @@ def TFL_GreaterOp : TFL_Op<"greater", [
|
||||
}
|
||||
|
||||
def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect,
|
||||
SameOperandsAndResultShape]> {
|
||||
SameOperandsAndResultShape,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Hardswish activation function.";
|
||||
let description = [{
|
||||
Computes hard-swish activation function
|
||||
|
@ -84,8 +84,14 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
|
||||
|
||||
mlir::TFL::PassConfig pass_config(quant_specs);
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
pass_config.lower_tensor_list_ops = true;
|
||||
pass_config.shape_inference = false;
|
||||
|
||||
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
|
||||
quant_specs, result);
|
||||
pass_config, result);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -43,8 +43,6 @@ namespace tensorflow {
|
||||
|
||||
Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
const string& saved_model_dir, bool saved_model_v1,
|
||||
const string& saved_model_tags, const string& saved_model_exported_names,
|
||||
string* result) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::TFL::QuantizationSpecs quant_specs;
|
||||
@ -66,13 +64,28 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
// Register all custom ops, including user-specified custom ops.
|
||||
TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));
|
||||
|
||||
const bool import_saved_model = !saved_model_v1;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto module,
|
||||
ImportSavedModel(import_saved_model, saved_model_v1, saved_model_dir,
|
||||
saved_model_tags, saved_model_exported_names, &context));
|
||||
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
|
||||
quant_specs, result);
|
||||
auto& saved_model_tags = model_flags.saved_model_tags();
|
||||
auto& saved_model_exported_names = model_flags.saved_model_exported_names();
|
||||
std::unordered_set<std::string> tags(saved_model_tags.begin(),
|
||||
saved_model_tags.end());
|
||||
auto exported_names_in_vector = std::vector<std::string>(
|
||||
saved_model_exported_names.begin(), saved_model_exported_names.end());
|
||||
absl::Span<std::string> exported_names(exported_names_in_vector);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto module,
|
||||
ImportSavedModel(model_flags.saved_model_dir(),
|
||||
model_flags.saved_model_version(), tags,
|
||||
exported_names, &context));
|
||||
|
||||
mlir::TFL::PassConfig pass_config(quant_specs);
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
pass_config.lower_tensor_list_ops = true;
|
||||
pass_config.shape_inference = true;
|
||||
|
||||
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
|
||||
toco_flags, std::move(module), pass_config, result);
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -28,8 +28,6 @@ namespace tensorflow {
|
||||
// status if it fails to convert the input.
|
||||
Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
const string& saved_model_dir, bool saved_model_v1,
|
||||
const string& saved_model_tags, const string& saved_model_exported_names,
|
||||
string* result);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -105,6 +105,10 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||
switch (dtype) {
|
||||
case toco::IODataType::FLOAT:
|
||||
return DT_FLOAT;
|
||||
case toco::IODataType::FLOAT16:
|
||||
return DT_HALF;
|
||||
case toco::IODataType::FLOAT64:
|
||||
return DT_DOUBLE;
|
||||
case toco::IODataType::QUANTIZED_UINT8:
|
||||
return DT_QUINT8;
|
||||
case toco::IODataType::INT8:
|
||||
@ -261,7 +265,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
|
||||
|
||||
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
||||
mlir::OwningModuleRef module,
|
||||
mlir::TFL::QuantizationSpecs quant_specs,
|
||||
const mlir::TFL::PassConfig& pass_config,
|
||||
string* result) {
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
|
||||
@ -275,9 +279,6 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
||||
}
|
||||
|
||||
mlir::PassManager pm(module->getContext());
|
||||
mlir::TFL::PassConfig pass_config(quant_specs);
|
||||
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
pass_config.lower_tensor_list_ops = true;
|
||||
|
||||
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
|
||||
// Convert back to outlined while format for export back to flatbuffer.
|
||||
@ -288,7 +289,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
||||
|
||||
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm);
|
||||
emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, result,
|
||||
&pm);
|
||||
if (toco_flags.has_dump_graphviz_dir()) {
|
||||
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
|
||||
// rename once we enable the new converter feature flag.
|
||||
|
@ -47,7 +47,7 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
||||
// This will also run relevant passes as well.
|
||||
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
||||
mlir::OwningModuleRef module,
|
||||
mlir::TFL::QuantizationSpecs quant_specs,
|
||||
const mlir::TFL::PassConfig& pass_config,
|
||||
string* result);
|
||||
|
||||
// Give a warning for any unused flags that have been specified.
|
||||
|
@ -57,8 +57,9 @@ QuantizeContext::QuantizeContext(FuncOp func, const DeviceTarget &spec)
|
||||
});
|
||||
}
|
||||
|
||||
llvm::ArrayRef<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
|
||||
llvm::SmallVector<quant::QuantizeRegionOp, 64> all_ops;
|
||||
std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
|
||||
std::vector<quant::QuantizeRegionOp> all_ops;
|
||||
all_ops.reserve(128);
|
||||
func_.walk([&](quant::QuantizeRegionOp op) { all_ops.push_back(op); });
|
||||
return all_ops;
|
||||
}
|
||||
@ -75,7 +76,7 @@ LogicalResult QuantizeContext::Handle(
|
||||
switch (spec->type) {
|
||||
case ScaleConstraintType::OutputInputFreeScale: {
|
||||
// no propagation.
|
||||
*changed = false;
|
||||
*changed |= false;
|
||||
break;
|
||||
}
|
||||
case ScaleConstraintType::CustomScale: {
|
||||
@ -84,7 +85,20 @@ LogicalResult QuantizeContext::Handle(
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ScaleConstraintType::OutputInputSameScale: {
|
||||
auto params = GetQuantParamsForSameScaleConstraint(op);
|
||||
if (EmptyParams(params)) {
|
||||
*changed |= false;
|
||||
break;
|
||||
}
|
||||
// propagate this params to all the quantizable ports.
|
||||
if (failed(PropagateQuantParams(op, params, new_items, changed))) {
|
||||
return failure();
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
// TODO(fengliuai): implement the other types.
|
||||
llvm_unreachable("no implementation.");
|
||||
return failure();
|
||||
}
|
||||
@ -154,6 +168,102 @@ void QuantizeContext::DumpStates(QuantizeRegionOp current_op) {
|
||||
});
|
||||
}
|
||||
|
||||
// A heuristic to get quantization parameters satisfies the same scale
|
||||
// constraints:
|
||||
// - If there are immutable states,
|
||||
// - use the single input, or,
|
||||
// - use the single output, or,
|
||||
// - use the first one in the collection,
|
||||
// - use the single input if it is ready, or,
|
||||
// - use the single output if it is ready, or,
|
||||
// - use use the first ready one in the collection.
|
||||
QuantParams QuantizeContext::GetQuantParamsForSameScaleConstraint(
|
||||
Operation *op) {
|
||||
// Two vector to collect Non-empty operands and results states.
|
||||
std::vector<quant::QuantState *> mutable_states, immutable_states;
|
||||
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
|
||||
auto &state = states_manager_.GetOperandQuantState(op, i);
|
||||
if (state.immutable) {
|
||||
immutable_states.push_back(&state);
|
||||
} else if (!state.IsEmpty()) {
|
||||
mutable_states.push_back(&state);
|
||||
}
|
||||
}
|
||||
|
||||
int immutable_operands_num = immutable_states.size();
|
||||
int mutable_operands_num = mutable_states.size();
|
||||
// Use the operand's state if it is immutable and it is the only one
|
||||
// operand.
|
||||
if (op->getNumOperands() == 1 && immutable_operands_num == 1) {
|
||||
return immutable_states.front()->params;
|
||||
}
|
||||
|
||||
for (int i = 0, e = op->getNumResults(); i != e; ++i) {
|
||||
auto &state = states_manager_.GetResultQuantState(op, i);
|
||||
if (state.immutable) {
|
||||
immutable_states.push_back(&state);
|
||||
} else if (!state.IsEmpty()) {
|
||||
mutable_states.push_back(&state);
|
||||
}
|
||||
}
|
||||
|
||||
int immutable_results_num = immutable_states.size() - immutable_operands_num;
|
||||
int mutable_results_num = mutable_states.size() - mutable_operands_num;
|
||||
// Use the result's state if it is immutable and it is the only one result.
|
||||
if (op->getNumResults() == 1 && immutable_results_num == 1) {
|
||||
return immutable_states.back()->params;
|
||||
}
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Quantization parameters are not collected in an ideal place. "
|
||||
"Has to fallback values which might introduce errors.\n");
|
||||
|
||||
// Use the first immutable state to quantize the rest operands and results.
|
||||
if (!immutable_states.empty()) return immutable_states.front()->params;
|
||||
|
||||
// If there are no immutable states, use the operand's state if it is the
|
||||
// only one operand and has parameters propagated.
|
||||
if (op->getNumOperands() == 1 && mutable_operands_num == 1) {
|
||||
return mutable_states.front()->params;
|
||||
}
|
||||
|
||||
// If there are no immutable states, use the result's state if it is the
|
||||
// only one result and has parameters propagated.
|
||||
if (op->getNumResults() == 1 && mutable_results_num == 1) {
|
||||
return mutable_states.back()->params;
|
||||
}
|
||||
|
||||
// Use the first propagated state to quantize the rest operands and results.
|
||||
if (!mutable_states.empty()) return mutable_states.front()->params;
|
||||
|
||||
// None operands/results have parameters propagated, skip this node for now.
|
||||
return {};
|
||||
}
|
||||
|
||||
LogicalResult QuantizeContext::PropagateQuantParams(
|
||||
Operation *op, const QuantParams params,
|
||||
quant::AdjacentOperations *new_items, bool *changed) {
|
||||
// Use the final state to set all the operands' parameters.
|
||||
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
|
||||
auto ele = op->getOperand(i).getType().cast<ShapedType>().getElementType();
|
||||
if (ele.isa<FloatType>() && SetOperandParams(op, i, params)) {
|
||||
*changed |= true;
|
||||
new_items->push_back(op->getOperand(i).getDefiningOp());
|
||||
}
|
||||
}
|
||||
|
||||
// Use the final state to set all the results' parameters.
|
||||
for (int res = 0, e = op->getNumResults(); res != e; ++res) {
|
||||
auto ele = op->getResult(res).getType().cast<ShapedType>().getElementType();
|
||||
if (ele.isa<FloatType>() && SetResultParams(op, res, params)) {
|
||||
auto users = op->getResult(res).getUsers();
|
||||
*changed |= !users.empty();
|
||||
new_items->append(users.begin(), users.end());
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
int QuantizeContext::StatesManager::InitializeState(quant::QuantizeRegionOp op,
|
||||
int index, bool as_result) {
|
||||
Attribute params_attr;
|
||||
|
@ -67,7 +67,7 @@ class QuantizeContext {
|
||||
QuantizeContext(FuncOp func, const DeviceTarget &spec);
|
||||
|
||||
// Returns all the quant region ops.
|
||||
ArrayRef<quant::QuantizeRegionOp> GetAllOps();
|
||||
std::vector<quant::QuantizeRegionOp> GetAllOps();
|
||||
|
||||
// For each quant region op, propagates its quantization parameters according
|
||||
// to the kernel specification and also returns the adjcent quant region ops
|
||||
@ -107,6 +107,25 @@ class QuantizeContext {
|
||||
return states_manager_.GetOperandParams(op, index);
|
||||
}
|
||||
|
||||
// A heuristic to get quantization parameters satisfies the same scale
|
||||
// constraints:
|
||||
// - If there are immutable states,
|
||||
// - use the single input, or,
|
||||
// - use the single output, or,
|
||||
// - use the first one in the collection,
|
||||
// - use the single input if it is ready, or,
|
||||
// - use the single output if it is ready, or,
|
||||
// - use use the first ready one in the collection.
|
||||
QuantParams GetQuantParamsForSameScaleConstraint(Operation *op);
|
||||
|
||||
// Propagate `params` to all the quantizable port of the `op`. The adjcent
|
||||
// ops, which have the parameters propagated to, are collected by `new_items`,
|
||||
// so they can be added to the working queue. `changed` is set to true if
|
||||
// there are any new elements being added to `new_items`.
|
||||
LogicalResult PropagateQuantParams(Operation *op, const QuantParams params,
|
||||
AdjacentOperations *new_items,
|
||||
bool *changed);
|
||||
|
||||
private:
|
||||
class StatesManager {
|
||||
public:
|
||||
|
@ -28,6 +28,14 @@ namespace ph = std::placeholders;
|
||||
CpuDeviceTarget::CpuDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) {
|
||||
RegisterKernel("generic.concat", {qi8_, qi8_, qi8_},
|
||||
quant::ScaleConstraintType::OutputInputSameScale);
|
||||
|
||||
// TODO(fengliuai): All the combinations are required to list. We need to
|
||||
// improve this.
|
||||
RegisterKernel("generic.reshape", {qi8_, any_},
|
||||
quant::ScaleConstraintType::OutputInputSameScale);
|
||||
RegisterKernel("generic.reshape", {any_, qi8_},
|
||||
quant::ScaleConstraintType::OutputInputSameScale);
|
||||
|
||||
RegisterKernel("generic.mul", {qi8_, qi8_, qi8_},
|
||||
quant::ScaleConstraintType::OutputInputFreeScale);
|
||||
RegisterKernel("generic.mul_add", {qi8_, qi8n_, any_, qi8_},
|
||||
|
@ -176,7 +176,7 @@ llvm::SmallVector<Value, 0> fuseOps(PatternRewriter* rewriter,
|
||||
auto* body = new Block();
|
||||
region.body().push_back(body);
|
||||
|
||||
OpBuilder builder(body);
|
||||
OpBuilder builder = OpBuilder::atBlockEnd(body);
|
||||
BlockAndValueMapping mapping;
|
||||
|
||||
// Make block arguments and add it to the block value mapping.
|
||||
|
@ -69,7 +69,7 @@ void PropagateQuantPass::runOnFunction() {
|
||||
CpuDeviceTarget spec(&getContext());
|
||||
quant::QuantizeContext ctx(func, spec);
|
||||
|
||||
std::vector<quant::QuantizeRegionOp> work_list(ctx.GetAllOps());
|
||||
std::vector<quant::QuantizeRegionOp> work_list = ctx.GetAllOps();
|
||||
bool changed = false;
|
||||
while (!work_list.empty()) {
|
||||
quant::QuantizeRegionOp op = work_list.back();
|
||||
|
@ -52,3 +52,18 @@ func @mul_add_annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tenso
|
||||
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8<-127:127>:f32, 1.000000e+00:-128>, !quant.uniform<i32:f32, 1.000000e+00>]
|
||||
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @same_scale_1_1
|
||||
func @same_scale_1_1(%arg0: tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>) {
|
||||
%region = "quant.region"(%arg0) ( {
|
||||
^bb0(%arg1: tensor<1x7x7x64xf32>): // no predecessors
|
||||
%r = "xla_hlo.reshape"(%arg1) : (tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>)
|
||||
"quant.return"(%r) : (tensor<1x3136xf32>) -> ()
|
||||
}) {input_specs = [!quant.uniform<i8:f32, 1.0>], logical_kernel = "generic.reshape", output_specs = [f32]} : (tensor<1x7x7x64xf32>) -> tensor<1x3136xf32>
|
||||
return %region : tensor<1x3136xf32>
|
||||
|
||||
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00>]
|
||||
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00>]
|
||||
}
|
||||
|
@ -1,37 +1,53 @@
|
||||
// RUN: tf-opt %s -tfl-identify-dilated-conv | FileCheck %s --dump-input-on-failure
|
||||
|
||||
func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<2> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
return %2 : tensor<1x128x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConv
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithNonConstantPadAndCrops(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
return %2 : tensor<1x128x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConv
|
||||
// CHECK-LABEL: testDilatedConvWithNonConstantPadAndCrops
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithNonZeroSTBPadding(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
func @testDilatedConvWithNonZeroBasePadding(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<2> : tensor<2x2xi32>
|
||||
%cst_1 = constant dense<1> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
return %2 : tensor<1x128x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithNonZeroSTBPadding
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
|
||||
// CHECK-LABEL: testDilatedConvWithNonZeroBasePadding
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
%cst_0 = constant dense<2> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
return %2 : tensor<1x128x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithNonTrivialDilations
|
||||
@ -41,25 +57,27 @@ func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %ar
|
||||
// CHECK-NEXT: return [[RESULT]]
|
||||
}
|
||||
|
||||
func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
%cst_0 = constant dense<2> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.DepthwiseConv2dNative"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
return %2 : tensor<1x128x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedDepthWiseConv
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%cst_0 = constant dense<2> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32>
|
||||
%3 = "tf.BatchToSpaceND"(%2, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
%3 = "tf.BatchToSpaceND"(%2, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
%4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
|
||||
return %4 : tensor<1x128x128x8xf32>
|
||||
|
||||
@ -72,10 +90,11 @@ func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi
|
||||
|
||||
func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%cst_0 = constant dense<2> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32>
|
||||
%3 = "tf.BatchToSpaceND"(%2, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
%3 = "tf.BatchToSpaceND"(%2, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
%4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
|
||||
return %4 : tensor<1x128x128x8xf32>
|
||||
|
||||
@ -86,49 +105,52 @@ func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: ten
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
func @testDilatedConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
%3 = "tf.BiasAdd"(%2, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
|
||||
%cst_0 = constant dense<2> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
%3 = "tf.BiasAdd"(%2, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
|
||||
return %3 : tensor<1x128x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithBiasAdd
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
%3 = "tf.BiasAdd"(%2, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
|
||||
%cst_0 = constant dense<2> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.DepthwiseConv2dNative"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
%3 = "tf.BiasAdd"(%2, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
|
||||
return %3 : tensor<1x128x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedDepthWiseConvWithBiasAdd
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%cst_1 = constant dense<2> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
|
||||
%4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
%5 = "tf.BiasAdd"(%4, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
||||
%4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
%5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
||||
return %5 : tensor<1x128x128xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithExpandSqueeze1
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
@ -137,19 +159,20 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
|
||||
}
|
||||
|
||||
func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%cst_1 = constant dense<2> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
|
||||
%4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
%5 = "tf.BiasAdd"(%4, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
||||
%4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
%5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
||||
return %5 : tensor<1x128x128xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
@ -158,19 +181,20 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
|
||||
func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<?xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
|
||||
%cst_1 = constant dense<2> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32>
|
||||
%4 = "tf.BiasAdd"(%3, %arg3) : (tensor<4x?x?xf32>, tensor<?xf32>) -> tensor<4x?x?xf32>
|
||||
%5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
%4 = "tf.BiasAdd"(%3, %arg2) : (tensor<4x?x?xf32>, tensor<?xf32>) -> tensor<4x?x?xf32>
|
||||
%5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
return %5 : tensor<1x128x128xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithExpandSqueeze2
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
@ -179,19 +203,20 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
|
||||
}
|
||||
|
||||
func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
|
||||
func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<?xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
|
||||
%cst_1 = constant dense<2> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
|
||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
|
||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32>
|
||||
%4 = "tf.BiasAdd"(%3, %arg3) : (tensor<4x?x?xf32>, tensor<?xf32>) -> tensor<4x?x?xf32>
|
||||
%5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
%4 = "tf.BiasAdd"(%3, %arg2) : (tensor<4x?x?xf32>, tensor<?xf32>) -> tensor<4x?x?xf32>
|
||||
%5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
return %5 : tensor<1x128x128xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
@ -203,12 +228,13 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
|
||||
func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%cst_1 = constant dense<2> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
|
||||
%4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32>
|
||||
%5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
%5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
%6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
||||
return %6 : tensor<1x128x128xf32>
|
||||
|
||||
@ -225,12 +251,13 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%cst_1 = constant dense<2> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
|
||||
%4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32>
|
||||
%5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
%5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
%6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
||||
return %6 : tensor<1x128x128xf32>
|
||||
|
||||
@ -244,14 +271,15 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128x1xf32> {
|
||||
func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%cst_1 = constant dense<2> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [2]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32>
|
||||
%4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32>
|
||||
return %4 : tensor<1x128x128x1xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithDifferentExpandSqueezeAxis
|
||||
|
@ -142,7 +142,7 @@ versions {
|
||||
# CHECK-SAME: control_outputs = ""
|
||||
# CHECK-SAME: inputs = "unranked"
|
||||
# CHECK-SAME: outputs = "unranked,static,static_10"
|
||||
# CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<10xi32>
|
||||
# CHECK: [[VAL_2:%.*]] = constant dense<0> : tensor<i32>
|
||||
# CHECK: return [[VAL_0]], [[VAL_2]], [[VAL_1]] : tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>
|
||||
# CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<i32>
|
||||
# CHECK: [[VAL_2:%.*]] = constant dense<0> : tensor<10xi32>
|
||||
# CHECK: return [[VAL_0]], [[VAL_1]], [[VAL_2]] : tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>
|
||||
# CHECK: }
|
||||
|
@ -7788,35 +7788,35 @@ library {
|
||||
# CHECK-SAME: control_outputs = ""
|
||||
# CHECK-SAME: inputs = "INPUT"
|
||||
# CHECK-SAME: outputs = "OUTPUT"
|
||||
# CHECK: [[VAL_1:%.*]] = constant dense<{{\[\[}}-0.400154352, 0.739109992, 0.201825857], [0.678572893, 0.32076478, 0.949867963], [-0.807729483, -5.324750e-01, 0.148033619]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_2:%.*]] = constant dense<{{\[\[}}0.886177539, -0.606141329, -0.451275587], [0.325554609, 0.691527605, -0.676239967], [0.219799042, 0.626042128, -0.597596407]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_3:%.*]] = constant dense<{{\[\[}}-0.493826151, -0.391061306, -0.349843264], [-0.0213134289, 0.558384657, -0.51513052], [0.427886248, 0.618100405, -0.187585592]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_4:%.*]] = constant dense<{{\[\[}}0.444335222, -0.133341789, 0.839591503], [0.445418358, -0.571707964, 0.569707394], [0.465010405, -0.990037918, -0.632481337]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_5:%.*]] = constant dense<{{\[\[}}-0.138204336, -0.10879755, -0.135128736], [0.94797182, -8.713360e-01, -0.792336463], [0.0339827538, -0.539326906, 8.906350e-01]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_6:%.*]] = constant dense<{{\[\[}}0.513064623, -0.692989588, 0.547988653], [0.0653710365, 0.576977491, 0.966733217], [0.0130724907, 0.247342348, 0.317092657]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_7:%.*]] = constant dense<{{\[\[}}0.230039358, -0.182297707, -0.352231741], [-0.805100203, -0.220300436, -0.669503212], [0.278807402, -0.201502323, -0.627609729]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_8:%.*]] = constant dense<{{\[\[}}-0.207589626, -0.756766081, -0.853258133], [-0.269270182, 0.0468223095, -0.353052378], [-0.0702953338, 0.0725159645, -0.817753077]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_9:%.*]] = constant dense<[0.171322107, -0.153412342, 0.591750383]> : tensor<3xf32>
|
||||
# CHECK: [[VAL_10:%.*]] = constant dense<[-0.671292543, 0.411814928, 0.560465336]> : tensor<3xf32>
|
||||
# CHECK: [[VAL_11:%.*]] = constant dense<[0.403919935, -0.882057666, -0.894463062]> : tensor<3xf32>
|
||||
# CHECK: [[VAL_12:%.*]] = constant dense<{{\[\[}}-0.936182261, -0.935433864, 0.288229942], [-0.243383884, -0.628288031, -0.477061749], [-0.514976501, -0.903514862, 6.728170e-01]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_13:%.*]] = constant dense<{{\[\[}}0.18183589, 0.616135359, -0.167827845], [0.734281301, 0.958347797, -0.878054618], [0.369523764, -0.969005823, -0.881014585]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_14:%.*]] = constant dense<{{\[\[}}-5.087240e-01, -0.588907719, 0.471896172], [-0.508019447, -0.0157074928, -0.804120779], [-0.978842973, 0.00160336494, -0.978532075]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_15:%.*]] = constant dense<{{\[\[}}-0.616786718, 0.892614365, 0.671324968], [-0.842380046, -0.358094931, 0.821366549], [0.790347338, 0.71222949, 0.0690443515]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_16:%.*]] = constant dense<1.000000e+00> : tensor<3xf32>
|
||||
# CHECK: [[VAL_17:%.*]] = constant dense<{{\[\[}}0.782244444, -0.0446639061, 0.848498106], [-0.579102755, -0.407756329, 0.442389727], [0.00566458702, 0.5984025, 0.629857302]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_18:%.*]] = constant dense<{{\[\[}}0.891112089, -2.786560e-01, 0.966933965], [-0.789963722, 0.057955265, 0.217499971], [-0.698129416, -0.983400583, -0.834380626]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_19:%.*]] = constant dense<{{\[\[}}-0.125753641, 0.32271719, 0.488939524], [0.36119318, 0.982266664, -0.448646784], [0.966353893, -0.767024993, 0.446366787]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_20:%.*]] = constant dense<{{\[\[}}-0.856678485, -0.800494194, 0.716800689], [0.536404848, 0.541643381, -0.35657692], [-0.794646739, 0.137629032, 0.690013885]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_21:%.*]] = constant dense<0.000000e+00> : tensor<3xf32>
|
||||
# CHECK: [[VAL_22:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[VAL_1:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[VAL_2:%.*]] = constant dense<0.000000e+00> : tensor<3xf32>
|
||||
# CHECK: [[VAL_3:%.*]] = constant dense<{{\[\[}}-0.856678485, -0.800494194, 0.716800689], [0.536404848, 0.541643381, -0.35657692], [-0.794646739, 0.137629032, 0.690013885]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_4:%.*]] = constant dense<{{\[\[}}-0.125753641, 0.32271719, 0.488939524], [0.36119318, 0.982266664, -0.448646784], [0.966353893, -0.767024993, 0.446366787]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_5:%.*]] = constant dense<{{\[\[}}0.891112089, -2.786560e-01, 0.966933965], [-0.789963722, 0.057955265, 0.217499971], [-0.698129416, -0.983400583, -0.834380626]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_6:%.*]] = constant dense<{{\[\[}}0.782244444, -0.0446639061, 0.848498106], [-0.579102755, -0.407756329, 0.442389727], [0.00566458702, 0.5984025, 0.629857302]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_7:%.*]] = constant dense<1.000000e+00> : tensor<3xf32>
|
||||
# CHECK: [[VAL_8:%.*]] = constant dense<{{\[\[}}-0.616786718, 0.892614365, 0.671324968], [-0.842380046, -0.358094931, 0.821366549], [0.790347338, 0.71222949, 0.0690443515]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_9:%.*]] = constant dense<{{\[\[}}-5.087240e-01, -0.588907719, 0.471896172], [-0.508019447, -0.0157074928, -0.804120779], [-0.978842973, 0.00160336494, -0.978532075]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_10:%.*]] = constant dense<{{\[\[}}0.18183589, 0.616135359, -0.167827845], [0.734281301, 0.958347797, -0.878054618], [0.369523764, -0.969005823, -0.881014585]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_11:%.*]] = constant dense<{{\[\[}}-0.936182261, -0.935433864, 0.288229942], [-0.243383884, -0.628288031, -0.477061749], [-0.514976501, -0.903514862, 6.728170e-01]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_12:%.*]] = constant dense<{{\[}}0.403919935, -0.882057666, -0.894463062]> : tensor<3xf32>
|
||||
# CHECK: [[VAL_13:%.*]] = constant dense<{{\[}}-0.671292543, 0.411814928, 0.560465336]> : tensor<3xf32>
|
||||
# CHECK: [[VAL_14:%.*]] = constant dense<{{\[}}0.171322107, -0.153412342, 0.591750383]> : tensor<3xf32>
|
||||
# CHECK: [[VAL_15:%.*]] = constant dense<{{\[\[}}-0.207589626, -0.756766081, -0.853258133], [-0.269270182, 0.0468223095, -0.353052378], [-0.0702953338, 0.0725159645, -0.817753077]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_16:%.*]] = constant dense<{{\[\[}}0.230039358, -0.182297707, -0.352231741], [-0.805100203, -0.220300436, -0.669503212], [0.278807402, -0.201502323, -0.627609729]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_17:%.*]] = constant dense<{{\[\[}}0.513064623, -0.692989588, 0.547988653], [0.0653710365, 0.576977491, 0.966733217], [0.0130724907, 0.247342348, 0.317092657]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_18:%.*]] = constant dense<{{\[\[}}-0.138204336, -0.10879755, -0.135128736], [0.94797182, -8.713360e-01, -0.792336463], [0.0339827538, -0.539326906, 8.906350e-01]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_19:%.*]] = constant dense<{{\[\[}}0.444335222, -0.133341789, 0.839591503], [0.445418358, -0.571707964, 0.569707394], [0.465010405, -0.990037918, -0.632481337]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_20:%.*]] = constant dense<{{\[\[}}-0.493826151, -0.391061306, -0.349843264], [-0.0213134289, 0.558384657, -0.51513052], [0.427886248, 0.618100405, -0.187585592]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_21:%.*]] = constant dense<{{\[\[}}0.886177539, -0.606141329, -0.451275587], [0.325554609, 0.691527605, -0.676239967], [0.219799042, 0.626042128, -0.597596407]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_22:%.*]] = constant dense<{{\[\[}}-0.400154352, 0.739109992, 0.201825857], [0.678572893, 0.32076478, 0.949867963], [-0.807729483, -5.324750e-01, 0.148033619]]> : tensor<3x3xf32>
|
||||
# CHECK: [[VAL_23:%.*]] = constant unit
|
||||
# CHECK: [[VAL_24:%.*]]:3 = "tfl.unpack"(%[[ARG_0]]) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
|
||||
# CHECK: [[VAL_25:%.*]] = "tfl.pack"([[VAL_24]]#0, [[VAL_24]]#1, [[VAL_24]]#2) {axis = 0 : i32, values_count = 3 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x1x3xf32>
|
||||
# CHECK: [[UNPACK:%.*]]:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
|
||||
# CHECK: [[PACK:%.*]] = "tfl.pack"([[UNPACK]]#0, [[UNPACK]]#1, [[UNPACK]]#2) {axis = 0 : i32, values_count = 3 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x1x3xf32>
|
||||
# CHECK: [[VAL_24:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[UNIDIRECTIONAL_SEQUENCE_LSTM_1:%.*]] = "tfl.unidirectional_sequence_lstm"([[PACK]], [[VAL_16]], [[VAL_17]], [[VAL_18]], [[VAL_15]], [[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_19]], [[VAL_13]], [[VAL_14]], [[VAL_12]], [[VAL_2]], [[VAL_7]], [[VAL_2]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_24]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
|
||||
# CHECK: [[VAL_25:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[VAL_26:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[VAL_27:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_25]], [[VAL_7]], [[VAL_6]], [[VAL_5]], [[VAL_8]], [[VAL_3]], [[VAL_2]], [[VAL_1]], [[VAL_4]], [[VAL_10]], [[VAL_9]], [[VAL_11]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_22]], [[VAL_26]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
|
||||
# CHECK: [[VAL_28:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[VAL_29:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
|
||||
# CHECK: [[VAL_30:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_27]], [[VAL_19]], [[VAL_18]], [[VAL_17]], [[VAL_20]], [[VAL_14]], [[VAL_13]], [[VAL_12]], [[VAL_15]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_28]], [[VAL_29]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, none, none, none, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
|
||||
# CHECK: [[VAL_31:%.*]]:3 = "tfl.unpack"([[VAL_30]]) {axis = 0 : i32, num = 3 : i32} : (tensor<3x1x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
|
||||
# CHECK: return [[VAL_31]]#2 : tensor<1x3xf32>
|
||||
# CHECK: [[UNIDIRECTIONAL_SEQUENCE_LSTM_2:%.*]] = "tfl.unidirectional_sequence_lstm"([[UNIDIRECTIONAL_SEQUENCE_LSTM_1]], [[VAL_4]], [[VAL_5]], [[VAL_6]], [[VAL_3]], [[VAL_9]], [[VAL_10]], [[VAL_11]], [[VAL_8]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_2]], [[VAL_7]], [[VAL_2]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_25]], [[VAL_26]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, none, none, none, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
|
||||
# CHECK: [[RESULT:%.*]]:3 = "tfl.unpack"([[UNIDIRECTIONAL_SEQUENCE_LSTM_2]]) {axis = 0 : i32, num = 3 : i32} : (tensor<3x1x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
|
||||
# CHECK: return [[RESULT]]#2 : tensor<1x3xf32>
|
||||
|
@ -28,6 +28,13 @@ func @f32() -> tensor<4xf32> {
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
||||
func @f64() -> tensor<4xf64> {
|
||||
// CHECK-LABEL: @f64
|
||||
// CHECK: value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf64>
|
||||
%0 = "tfl.pseudo_const"() { value = dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64> } : () -> tensor<4xf64>
|
||||
return %0 : tensor<4xf64>
|
||||
}
|
||||
|
||||
func @i8() -> tensor<4xi8> {
|
||||
// CHECK-LABEL: @i8
|
||||
// CHECK: value = dense<[1, 2, 3, 4]> : tensor<4xi8>
|
||||
|
@ -829,6 +829,14 @@ func @pack3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2x
|
||||
// CHECK: "tfl.pack"(%arg0, %arg1, %arg2) {axis = 1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @packStringWithFlex(%arg0: tensor<2x!tf.string>, %arg1: tensor<2x!tf.string>) -> tensor<2x2x!tf.string> {
|
||||
%0 = "tf.Pack"(%arg0, %arg1) : (tensor<2x!tf.string>, tensor<2x!tf.string>) -> tensor<2x2x!tf.string>
|
||||
return %0 : tensor<2x2x!tf.string>
|
||||
|
||||
// CHECK-LABEL: packStringWithFlex
|
||||
// CHECK: "tf.Pack"(%arg0, %arg1) : (tensor<2x!tf.string>, tensor<2x!tf.string>) -> tensor<2x2x!tf.string>
|
||||
}
|
||||
|
||||
func @packNegAxis(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "tf.Pack"(%arg0, %arg1, %arg2) {axis = -1 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
||||
return %0 : tensor<2x3xi32>
|
||||
|
@ -0,0 +1,66 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
func @main(tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> {
|
||||
^bb0(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>):
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "FlexAdd"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: type: FLOAT64,
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: type: FLOAT64,
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: type: FLOAT64,
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "add",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: custom_options: [ 3, 65, 100, 100, 0, 20, 18, 3, 65, 100, 100, 26, 0, 26, 0, 42, 7, 10, 1, 84, 18, 2, 48, 2, 50, 0, 0, 2, 27, 23, 20, 20, 4, 40, 1 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 4
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> loc("add")
|
||||
return %0 : tensor<4xf64>
|
||||
}
|
@ -138,13 +138,24 @@ int main(int argc, char **argv) {
|
||||
// TODO(b/147435528): We need to test the e2e behavior once the graph freezing
|
||||
// inside mlir is done.
|
||||
if (import_saved_model_object_graph || import_saved_model_signature_defs) {
|
||||
int saved_model_version;
|
||||
if (import_saved_model_object_graph) {
|
||||
saved_model_version = 2;
|
||||
} else {
|
||||
saved_model_version = 1;
|
||||
}
|
||||
if (input_mlir)
|
||||
module = tensorflow::errors::InvalidArgument(
|
||||
"Importing saved model should not have input_mlir set");
|
||||
module = tensorflow::ImportSavedModel(import_saved_model_object_graph,
|
||||
import_saved_model_signature_defs,
|
||||
input_file_name, saved_model_tags,
|
||||
saved_model_exported_names, &context);
|
||||
|
||||
std::unordered_set<std::string> tags =
|
||||
absl::StrSplit(saved_model_tags, ',');
|
||||
std::vector<std::string> exported_names_vector =
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
absl::Span<std::string> exported_names(exported_names_vector);
|
||||
|
||||
module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
|
||||
tags, exported_names, &context);
|
||||
} else {
|
||||
module = tensorflow::LoadFromGraphdefOrMlirSource(
|
||||
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
|
||||
@ -197,11 +208,6 @@ int main(int argc, char **argv) {
|
||||
pass_config.lower_tensor_list_ops = lower_tensor_list_ops;
|
||||
pass_config.legalize_tf_while = convert_tf_while_to_tfl_while;
|
||||
|
||||
// Currently we only do shape inference for saved model import.
|
||||
if (import_saved_model_object_graph || import_saved_model_signature_defs) {
|
||||
pass_config.shape_inference = true;
|
||||
}
|
||||
|
||||
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
|
||||
// TODO(b/150901738): Move those into tf_tfl_translate.cc.
|
||||
// Convert back to outlined while format for export back to flatbuffer.
|
||||
|
@ -160,25 +160,17 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
bool import_saved_model, bool import_saved_model_v1,
|
||||
const std::string& input_filename, const std::string& saved_model_tags,
|
||||
const std::string& saved_model_exported_names, mlir::MLIRContext* context) {
|
||||
if (import_saved_model) {
|
||||
std::unordered_set<std::string> tags =
|
||||
absl::StrSplit(saved_model_tags, ',');
|
||||
std::vector<std::string> exported_names =
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
|
||||
const std::string& input_filename, const int saved_model_version,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
|
||||
if (saved_model_version == 2) {
|
||||
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
input_filename, tags, absl::Span<std::string>(exported_names), context);
|
||||
input_filename, tags, exported_names, context);
|
||||
if (!module)
|
||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||
|
||||
return module;
|
||||
} else if (import_saved_model_v1) {
|
||||
std::unordered_set<std::string> tags =
|
||||
absl::StrSplit(saved_model_tags, ',');
|
||||
|
||||
} else if (saved_model_version == 1) {
|
||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
input_filename, tags, context);
|
||||
|
||||
|
@ -16,6 +16,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
@ -42,9 +45,9 @@ LoadFromGraphdefOrMlirSource(
|
||||
|
||||
// Load Saved model (either v1 or v2) into MLIR.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
bool import_saved_model, bool import_saved_model_v1,
|
||||
const std::string& input_filename, const std::string& saved_model_tags,
|
||||
const std::string& saved_model_exported_names, mlir::MLIRContext* context);
|
||||
const std::string& input_filename, const int saved_model_version,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
|
||||
|
||||
// Taking a MLIR module in TF executor dialect and a set of parameters,
|
||||
// applies a set of passes to convert the module to TF Lite dialect and
|
||||
|
@ -152,7 +152,6 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
}
|
||||
|
||||
// BatchToSpaceND + BiasAdd.
|
||||
// TODO(b/149936532): Check the `crops` input, currently ignored.
|
||||
TF::BatchToSpaceNDOp bts_op;
|
||||
TF::BiasAddOp biasadd_op;
|
||||
bool final_op_is_bts = true;
|
||||
@ -179,16 +178,50 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
if (!dilations_attr.hasValue()) return failure();
|
||||
op.setAttr("dilations", dilations_attr.getValue());
|
||||
|
||||
// Padding is set to 'SAME' when `stb_op` has non-zero paddings.
// TODO(b/149936532): This assumption only holds when the input width & height
// is multiple of dilation width & height. We should fix it in order to
// support other use cases.
// TODO(b/149936532): Check that the input width & height are multiples of
// dilation rate.
// TF python library will rewrite dilated conv to
// "SpaceToBatch->Conv->BatchToSpace" pattern, and the Conv in the middle
// always has 'VALID' padding. The padding tensor in `SpaceToBatch` has two
// parts of contributions, one is to reduce padding of CONV from 'SAME' to
// 'VALID', and another is to make input shape multiples of dilation rate. The
// first part of padding, which is also called `base_padding` will be used
// here to determine if the original padding format is 'SAME' or 'VALID'.
// According to the following formula we will compute the `base_padding` if
// it's a constant. Basically, `paddings` tensor in `SpaceToBatch` and `crops`
// tensor in `BatchToSpace` must satisfy the following:
// paddings[i, 0] = base_paddings[i, 0].
// 0 <= paddings[i, 1] - base_paddings[i, 1] < block_shape[i]
// (input_shape[i] + paddings[i, 0] + paddings[i, 1]) % block_shape[i] == 0.
// crops[i, 0] = 0.
// crops[i, 1] = paddings[i, 1] - base_paddings[i, 1].

// If `paddings` - `crops` != 0, this means that `base_paddings` != 0, which
// tells us the original padding is 'SAME' (with one caveat presented below).
// Here we need to reset the padding back to `SAME` if `base_padding`
// != 0.
// TODO(b/149936532): We might not simply rely on `paddings - crops != 0` to
// determine the original padding format. For example, users can build
// arbitrary valid examples of `STB->Conv->BTS` which doesn't represent a
// dilated conv, hence we shouldn't pattern match here. Instead, we need to
// check values of `paddings` and `crops` to make sure it really stands for
// a dilated conv.
auto stb_paddings = stb_op.paddings();
ElementsAttr stb_paddings_attr;
if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr))) {
if (llvm::any_of(stb_paddings_attr.getValues<IntegerAttr>(),
[](IntegerAttr attr) { return attr.getInt() != 0; })) {
op.setAttr("padding", rewriter.getStringAttr("SAME"));
auto bts_crops = bts_op.crops();
ElementsAttr stb_paddings_attr, bts_crops_attr;
if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) &&
matchPattern(bts_crops, m_Constant(&bts_crops_attr))) {
if (stb_paddings_attr.getNumElements() != bts_crops_attr.getNumElements())
return failure();
// padding - crop.
auto paddings = stb_paddings_attr.getValues<IntegerAttr>();
auto crops = bts_crops_attr.getValues<IntegerAttr>();
for (auto it1 = paddings.begin(), it2 = crops.begin();
it1 != paddings.end() && it2 != crops.end(); it1++, it2++) {
if ((*it1).getInt() != (*it2).getInt()) {
op.setAttr("padding", rewriter.getStringAttr("SAME"));
break;
}
}
}
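
To make the `paddings` vs `crops` check above concrete, here is a small standalone C++ sketch (not part of this change; the function name and plain-vector types are illustrative only) of how a non-zero `base_paddings = paddings - crops` implies the original convolution used 'SAME' padding:

#include <string>
#include <vector>

// Illustrative helper: `paddings` and `crops` are the flattened 2x2 integer
// attributes read from SpaceToBatchND and BatchToSpaceND. Any element-wise
// difference means base_paddings != 0, i.e. the original padding was 'SAME'.
std::string InferOriginalPadding(const std::vector<int>& paddings,
                                 const std::vector<int>& crops) {
  if (paddings.size() != crops.size()) return "unknown";
  for (size_t i = 0; i < paddings.size(); ++i) {
    if (paddings[i] != crops[i]) return "SAME";
  }
  return "VALID";
}

// InferOriginalPadding({2, 2, 2, 2}, {0, 0, 0, 0}) returns "SAME";
// InferOriginalPadding({0, 2, 0, 2}, {0, 2, 0, 2}) returns "VALID".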
|
||||
|
||||
|
@ -679,8 +679,8 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
|
||||
return success();
|
||||
}
|
||||
|
||||
struct ExtractOphintPass : public ModulePass<ExtractOphintPass> {
|
||||
void runOnModule() override;
|
||||
struct ExtractOphintPass : public OperationPass<ExtractOphintPass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
void Verify();
|
||||
|
||||
private:
|
||||
@ -689,8 +689,8 @@ struct ExtractOphintPass : public ModulePass<ExtractOphintPass> {
|
||||
|
||||
// TODO(renjieliu): Current ophint extraction does not support inputs/outputs
|
||||
// cross functions, we need to do that.
|
||||
void ExtractOphintPass::runOnModule() {
|
||||
ModuleOp module = getModule();
|
||||
void ExtractOphintPass::runOnOperation() {
|
||||
ModuleOp module = getOperation();
|
||||
for (auto function : module.getOps<FuncOp>()) {
|
||||
// Process block by block.
|
||||
for (auto& bb : function.getBody()) {
|
||||
@ -710,7 +710,7 @@ void ExtractOphintPass::runOnModule() {
|
||||
ophint_composite_ops_count = ophint_composite_ops.size();
|
||||
|
||||
// Convert.
|
||||
OpBuilder builder(&bb);
|
||||
OpBuilder builder = OpBuilder::atBlockEnd(&bb);
|
||||
for (const auto& kv : ophint_composite_ops) {
|
||||
if (failed(ConvertOphintToStub(kv.getKey(), kv.getValue(), &builder,
|
||||
&module))) {
|
||||
@ -724,9 +724,9 @@ void ExtractOphintPass::runOnModule() {
|
||||
}
|
||||
|
||||
void ExtractOphintPass::Verify() {
|
||||
ModuleOp module = getModule();
|
||||
ModuleOp module = getOperation();
|
||||
int ophint_func_op_count = 0;
|
||||
for (FuncOp func : getModule().getOps<FuncOp>()) {
|
||||
for (FuncOp func : getOperation().getOps<FuncOp>()) {
|
||||
for (const NamedAttribute attr : func.getAttrs()) {
|
||||
if (attr.first == kTfLiteFunctionName) {
|
||||
ophint_func_op_count++;
|
||||
|
@ -68,8 +68,9 @@ constexpr char kUnidirectionalSequenceLstm[] = "UnidirectionalSequenceLstm";
|
||||
// |
|
||||
// |
|
||||
// OutputOp1
|
||||
struct LegalizeOphintFuncOpPass : public ModulePass<LegalizeOphintFuncOpPass> {
|
||||
void runOnModule() override;
|
||||
struct LegalizeOphintFuncOpPass
|
||||
: public OperationPass<LegalizeOphintFuncOpPass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
llvm::StringMap<FuncOp> FindCompositeFuncOps(ModuleOp module) {
|
||||
@ -256,8 +257,8 @@ LogicalResult ConvertCallOps(llvm::StringMap<FuncOp>* composite_func_ops,
|
||||
return success();
|
||||
}
|
||||
|
||||
void LegalizeOphintFuncOpPass::runOnModule() {
|
||||
ModuleOp module = getModule();
|
||||
void LegalizeOphintFuncOpPass::runOnOperation() {
|
||||
ModuleOp module = getOperation();
|
||||
|
||||
// Find all composite funcs, then for every call op inside every func op
|
||||
// within the module, we go ahead and replace the callop with the tflite
|
||||
|
@ -745,7 +745,8 @@ void LegalizeTF::runOnFunction() {
|
||||
Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
|
||||
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
|
||||
if (!tfl_op) return false;
|
||||
return succeeded(tfl_op.VerifyTflRuntimeTypes(tfl_op.getOperation()));
|
||||
return succeeded(tfl_op.VerifyTflRuntimeTypes(
|
||||
tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
|
||||
}));
|
||||
// Keep trying to convert.
|
||||
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
|
||||
|
@ -31,11 +31,11 @@ namespace {
|
||||
|
||||
// Legalize TF While to TFL While with calls to the original functions from the
// cond and body regions.
struct LegalizeWhile : public ModulePass<LegalizeWhile> {
struct LegalizeWhile : public OperationPass<LegalizeWhile, ModuleOp> {
void RunOnFunction(FuncOp func);

void runOnModule() override {
for (auto op : getModule().getOps<FuncOp>()) RunOnFunction(op);
void runOnOperation() override {
for (auto op : getOperation().getOps<FuncOp>()) RunOnFunction(op);
}
};
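
Several of the hunks in this change repeat the same mechanical migration from ModulePass to OperationPass. As a reference, here is a minimal sketch of the post-migration skeleton they all converge on (the class name and body are illustrative, not part of the change):

#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/Pass.h"

// Module-level passes now derive from OperationPass<..., ModuleOp>, override
// runOnOperation(), and fetch the module with getOperation() rather than
// getModule().
struct ExampleModulePass
    : public mlir::OperationPass<ExampleModulePass, mlir::ModuleOp> {
  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();
    for (auto func : module.getOps<mlir::FuncOp>()) {
      (void)func;  // per-function rewriting would go here
    }
  }
};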
|
||||
|
||||
|
@ -82,8 +82,8 @@ class TensorListPatternRewriter : public PatternRewriter {
|
||||
|
||||
/// Lower TensorList ops in functions for subsequent legalization.
|
||||
struct LowerStaticTensorListPass
|
||||
: public ModulePass<LowerStaticTensorListPass> {
|
||||
void runOnModule() override;
|
||||
: public OperationPass<LowerStaticTensorListPass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
|
||||
// Apply type and op changes within a function.
|
||||
LogicalResult RewriteFunction(FuncOp func,
|
||||
@ -878,14 +878,14 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
return applyFullConversion(func, target, patterns);
|
||||
}
|
||||
|
||||
void LowerStaticTensorListPass::runOnModule() {
|
||||
void LowerStaticTensorListPass::runOnOperation() {
|
||||
// TODO(haoliang): currently we process the `main` function first, and the
|
||||
// remaining functions may be processed in arbitrary order. However, this will
|
||||
// have a potential issue when one function taking a `DT_VARIANT` is processed
|
||||
// before the function that produces the `DT_VARIANT`. We need to carefully
|
||||
// order the functions to be processed.
|
||||
std::vector<FuncOp> funcs_in_module;
|
||||
for (auto func : getModule().getOps<FuncOp>()) {
|
||||
for (auto func : getOperation().getOps<FuncOp>()) {
|
||||
// Always place the main function to be the first in the list.
|
||||
if (func.getName() == "main") {
|
||||
funcs_in_module.insert(funcs_in_module.begin(), func);
|
||||
|
@ -36,8 +36,8 @@ using FuncSet = llvm::SmallSet<FuncOp, 4>;
|
||||
|
||||
// Module pass to optimize TensorFlow functional ops.
|
||||
struct OptimizeFunctionalOpsPass
|
||||
: public ModulePass<OptimizeFunctionalOpsPass> {
|
||||
void runOnModule() override;
|
||||
: public OperationPass<OptimizeFunctionalOpsPass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Updates function return type of the given functions to match the terminator
|
||||
@ -180,13 +180,13 @@ static void EraseDeadFuncs(const FuncSet& candidate_funcs, ModuleOp module) {
|
||||
}
|
||||
}
|
||||
|
||||
void OptimizeFunctionalOpsPass::runOnModule() {
|
||||
void OptimizeFunctionalOpsPass::runOnOperation() {
|
||||
OwningRewritePatternList patterns;
|
||||
|
||||
FuncSet inlined_funcs;
|
||||
patterns.insert<FoldIfOp>(&getContext(), &inlined_funcs);
|
||||
|
||||
ModuleOp module = getModule();
|
||||
ModuleOp module = getOperation();
|
||||
applyPatternsGreedily(module, patterns);
|
||||
|
||||
// Erase inlined functions that don't have any references.
|
||||
|
@ -94,14 +94,14 @@ class ConvertEmbeddedLookupFunc {
|
||||
// body with the corresponding fused TFLite op. The replacement need not always
|
||||
// be a fused op, though that is the primary use case.
|
||||
class PrepareCompositeFunctionsPass
|
||||
: public ModulePass<PrepareCompositeFunctionsPass> {
|
||||
: public OperationPass<PrepareCompositeFunctionsPass, ModuleOp> {
|
||||
public:
|
||||
explicit PrepareCompositeFunctionsPass() {}
|
||||
|
||||
private:
|
||||
void ConvertTFImplements(FuncOp func, StringAttr attr);
|
||||
void ConvertTFAPIImplements(FuncOp func, StringAttr attr, ModuleOp module);
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
|
||||
@ -189,8 +189,8 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
|
||||
}
|
||||
}
|
||||
|
||||
void PrepareCompositeFunctionsPass::runOnModule() {
|
||||
auto module = getModule();
|
||||
void PrepareCompositeFunctionsPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
// We have two kinds of implements:
|
||||
// 1) tf._implements.
|
||||
|
@ -34,7 +34,9 @@ class RuntimeTypeVerifyPass : public mlir::FunctionPass<RuntimeTypeVerifyPass> {
|
||||
|
||||
void RuntimeTypeVerifyPass::runOnFunction() {
|
||||
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
|
||||
if (failed(op.VerifyTflRuntimeTypes(op.getOperation())))
|
||||
if (failed(op.VerifyTflRuntimeTypes(
|
||||
op.getOperation(),
|
||||
/*failure_on_operand_type_mismatch=*/true)))
|
||||
signalPassFailure();
|
||||
});
|
||||
}
|
||||
|
@ -44,21 +44,22 @@ namespace {
|
||||
|
||||
// The pass to trim functions before we legalize to TFL
|
||||
// dialect using the specified whitelist.
|
||||
class TrimFunctionsPass : public mlir::ModulePass<TrimFunctionsPass> {
|
||||
class TrimFunctionsPass
|
||||
: public mlir::OperationPass<TrimFunctionsPass, ModuleOp> {
|
||||
public:
|
||||
explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {}
|
||||
explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_whitelist)
|
||||
: trim_funcs_whitelist_(trim_funcs_whitelist) {}
|
||||
|
||||
private:
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
bool TrimModule();
|
||||
void Verify();
|
||||
|
||||
llvm::ArrayRef<std::string> trim_funcs_whitelist_;
|
||||
};
|
||||
|
||||
void TrimFunctionsPass::runOnModule() {
|
||||
void TrimFunctionsPass::runOnOperation() {
|
||||
// trim the functions in the module using the trim_funcs_whitelist_
|
||||
// by removing functions not in the whitelist.
|
||||
if (TrimModule()) {
|
||||
@ -73,7 +74,7 @@ bool TrimFunctionsPass::TrimModule() {
|
||||
if (trim_funcs_whitelist_.empty()) return false;
|
||||
|
||||
llvm::SmallVector<FuncOp, 4> funcs_to_trim;
|
||||
for (auto func : getModule().getOps<FuncOp>()) {
|
||||
for (auto func : getOperation().getOps<FuncOp>()) {
|
||||
if (llvm::is_contained(trim_funcs_whitelist_, func.getName())) {
|
||||
// If no main is specified in the whitelist, use the 1st func
|
||||
// in trim_funcs_whitelist as the main.
|
||||
@ -102,12 +103,12 @@ bool TrimFunctionsPass::TrimModule() {
|
||||
void TrimFunctionsPass::Verify() {
|
||||
// TODO(ashwinm): Instead, we should make sure that references to all
|
||||
// SymbolRefAttrs of all ops are present.
|
||||
SymbolTable symbol_table = SymbolTable(getModule());
|
||||
SymbolTable symbol_table = SymbolTable(getOperation());
|
||||
llvm::SetVector<FuncOp> reachable_funcs;
|
||||
for (auto func : getModule().getOps<FuncOp>()) {
|
||||
for (auto func : getOperation().getOps<FuncOp>()) {
|
||||
auto walk_result = func.walk([&](CallOp op) -> WalkResult {
|
||||
if (!symbol_table.lookup<FuncOp>(op.getCallee()))
|
||||
return getModule().emitError()
|
||||
return getOperation().emitError()
|
||||
<< func.getName() << " is not in the funcs whitelist";
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
@ -37,12 +37,13 @@ namespace {
|
||||
|
||||
// This pass outlines the cond/body region of the TFL WhileOp into functions and
|
||||
// replaces the regions with calls to these outlined functions.
|
||||
class WhileOutlinePass : public mlir::ModulePass<WhileOutlinePass> {
|
||||
class WhileOutlinePass
|
||||
: public mlir::OperationPass<WhileOutlinePass, ModuleOp> {
|
||||
public:
|
||||
explicit WhileOutlinePass() {}
|
||||
|
||||
private:
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
|
||||
// Outlines the regions of the WhileOp's cond and body and insert function
|
||||
// calls instead,
|
||||
@ -130,7 +131,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
|
||||
// Create outline function from region. Optional pass extra arguments through
|
||||
// to yield.
|
||||
SymbolTable symbol_table(getModule());
|
||||
SymbolTable symbol_table(getOperation());
|
||||
auto create_outline_func = [&](StringRef name, Region& region,
|
||||
bool passthru_extra_args) {
|
||||
FunctionType type;
|
||||
@ -234,8 +235,8 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
op->erase();
|
||||
}
|
||||
|
||||
void WhileOutlinePass::runOnModule() {
|
||||
getModule().walk(
|
||||
void WhileOutlinePass::runOnOperation() {
|
||||
getOperation().walk(
|
||||
[&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
|
||||
}
|
||||
|
||||
|
@ -32,10 +32,12 @@ namespace errors = tensorflow::errors;
|
||||
|
||||
mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
|
||||
switch (type) {
|
||||
case tflite::TensorType_FLOAT32:
|
||||
return builder.getF32Type();
|
||||
case tflite::TensorType_FLOAT16:
|
||||
return builder.getF16Type();
|
||||
case tflite::TensorType_FLOAT32:
|
||||
return builder.getF32Type();
|
||||
case tflite::TensorType_FLOAT64:
|
||||
return builder.getF64Type();
|
||||
case tflite::TensorType_INT32:
|
||||
return builder.getIntegerType(32);
|
||||
case tflite::TensorType_UINT8:
|
||||
@ -65,6 +67,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
|
||||
return tensorflow::DT_HALF;
|
||||
case tflite::TensorType_FLOAT32:
|
||||
return tensorflow::DT_FLOAT;
|
||||
case tflite::TensorType_FLOAT64:
|
||||
return tensorflow::DT_DOUBLE;
|
||||
case tflite::TensorType_INT8:
|
||||
return tensorflow::DT_INT8;
|
||||
case tflite::TensorType_INT16:
|
||||
|
@ -545,13 +545,44 @@ LogicalResult Verify(SwitchNOp switchn) {
|
||||
<< "expect `num_outs` (" << num_outs.getInt() << ") results but got "
|
||||
<< (switchn.getNumResults() - 1);
|
||||
|
||||
// Check that operand can be broadcasted to each output type.
|
||||
auto operand0_type = switchn.getOperand(0).getType();
|
||||
for (Value result : switchn.outputs())
|
||||
if (operand0_type != result.getType())
|
||||
return switchn.emitOpError()
|
||||
<< "type mismatch between data operand and result: "
|
||||
<< operand0_type << " vs " << result.getType();
|
||||
TensorType operand0_tensor_type = operand0_type.dyn_cast<TensorType>();
|
||||
if (!operand0_tensor_type) {
|
||||
return switchn.emitOpError()
|
||||
<< "expects data operand to have tensor type but got "
|
||||
<< operand0_type;
|
||||
}
|
||||
for (Type output_type : switchn.getResultTypes()) {
|
||||
if (output_type.isa<ControlType>()) break;
|
||||
|
||||
TensorType output_tensor_type = output_type.dyn_cast<TensorType>();
|
||||
if (!output_tensor_type) {
|
||||
return switchn.emitOpError()
|
||||
<< "expects outputs to have tensor type but got " << output_type;
|
||||
}
|
||||
|
||||
// If the output type is a ref type, then the operand type should also be of
|
||||
// the same ref type. However, if the output type is a non-ref type T, then
|
||||
// the operand can be tensor of type T or T_REF.
|
||||
bool is_output_ref =
|
||||
output_tensor_type.getElementType().isa<TF::TensorFlowRefType>();
|
||||
if (is_output_ref &&
|
||||
!operand0_tensor_type.getElementType().isa<TF::TensorFlowRefType>()) {
|
||||
return switchn.emitOpError()
|
||||
<< "expects same operand and output element type but got "
|
||||
<< operand0_tensor_type << " vs " << output_tensor_type;
|
||||
}
|
||||
Type broadcasted_type = OpTrait::util::getBroadcastedType(
|
||||
DropRefType(DropTypeSubTypes(operand0_tensor_type)),
|
||||
DropRefType(DropTypeSubTypes(output_tensor_type)));
|
||||
if (!broadcasted_type) {
|
||||
return switchn.emitOpError()
|
||||
<< "expects data operand to be broadcastable with all output types"
|
||||
<< " but got " << operand0_tensor_type << " vs "
|
||||
<< output_tensor_type;
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -5301,6 +5301,8 @@ tf.pow(x, y) ==> [[256, 65536], [9, 27]]
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_PreventGradientOp : TF_Op<"PreventGradient", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
@ -6006,6 +6008,30 @@ Resize `images` to `size` using nearest neighbor interpolation.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ResourceApplyAdagradV2Op : TF_Op<"ResourceApplyAdagradV2", []> {
let summary = "Update '*var' according to the adagrad scheme.";

let description = [{
accum += grad * grad
var -= lr * grad * (1 / (sqrt(accum) + epsilon))
}];

let arguments = (ins
TF_ResourceTensor:$var,
TF_ResourceTensor:$accum,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,

DefaultValuedAttr<BoolAttr, "false">:$use_locking,
DefaultValuedAttr<BoolAttr, "true">:$update_slots
);

let results = (outs);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
}
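
The update this op describes (and that the decomposition test later in this change checks) is easiest to see as a small standalone sketch; the function below is illustrative C++ over plain vectors, not a TensorFlow API:

#include <cmath>
#include <vector>

// One AdagradV2 step: accum += grad * grad;
// var -= lr * grad / (sqrt(accum) + epsilon).
// `update_slots` mirrors the attribute of the same name: when false, the
// accumulator is left untouched.
void ApplyAdagradV2(std::vector<float>& var, std::vector<float>& accum,
                    float lr, float epsilon, const std::vector<float>& grad,
                    bool update_slots = true) {
  for (size_t i = 0; i < var.size(); ++i) {
    if (update_slots) accum[i] += grad[i] * grad[i];
    var[i] -= lr * grad[i] / (std::sqrt(accum[i]) + epsilon);
  }
}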
def TF_ResourceApplyAdamOp : TF_Op<"ResourceApplyAdam", []> {
|
||||
let summary = "Update '*var' according to the Adam algorithm.";
|
||||
|
||||
@ -7711,6 +7737,28 @@ shape of `StridedSlice`'s `input`.
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_StringFormatOp : TF_Op<"StringFormat", [NoSideEffect]> {
let summary = "Formats a string template using a list of tensors.";

let description = [{
Formats a string template using a list of tensors, pretty-printing tensor summaries.
}];

let arguments = (ins
Variadic<TF_Tensor>:$inputs,

DefaultValuedAttr<StrAttr, "%s">:$strtemplate,
DefaultValuedAttr<StrAttr, "%s">:$placeholder,
DefaultValuedAttr<I64Attr, "3">:$summarize
);

let results = (outs
TF_StrTensor:$output
);

TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
}
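
The three attributes are easiest to read with a behavioral sketch (standalone illustrative C++, not the TensorFlow kernel): `strtemplate` is the format string, each occurrence of `placeholder` is replaced by a summary of the corresponding input tensor, and `summarize` bounds how many leading and trailing elements that summary prints:

#include <sstream>
#include <string>
#include <vector>

// Sketch of StringFormat semantics, restricted to 1-D integer inputs.
std::string StringFormat(std::string strtemplate,
                         const std::vector<std::vector<int>>& inputs,
                         const std::string& placeholder = "%s",
                         int summarize = 3) {
  for (const auto& tensor : inputs) {
    std::ostringstream summary;
    summary << "[";
    const int n = static_cast<int>(tensor.size());
    for (int i = 0; i < n; ++i) {
      if (n > 2 * summarize && i == summarize) {
        summary << "... ";
        i = n - summarize - 1;  // jump to the trailing `summarize` elements
        continue;
      }
      summary << tensor[i] << (i + 1 < n ? " " : "");
    }
    summary << "]";
    const auto pos = strtemplate.find(placeholder);
    if (pos == std::string::npos) break;
    strtemplate.replace(pos, placeholder.size(), summary.str());
  }
  return strtemplate;
}

// StringFormat("x = %s", {{0, 1, 2, 3, 4, 5, 6, 7}}) yields "x = [0 1 2 ... 5 6 7]".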
def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x - y element-wise.";
|
||||
|
@ -2153,6 +2153,27 @@ static LogicalResult VerifyPartitionedCall(OpClass op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PowOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult PowOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto constant_y = operands[1].dyn_cast_or_null<DenseFPElementsAttr>();
|
||||
if (constant_y && constant_y.isSplat()) {
|
||||
APFloat y_value = constant_y.getSplatValue<APFloat>();
|
||||
auto output_type = getType().cast<ShapedType>();
|
||||
if (y_value.isZero() && output_type.hasStaticShape()) {
|
||||
return DenseElementsAttr::get(
|
||||
output_type,
|
||||
FloatAttr::get(output_type.getElementType(), /*value=*/1.0));
|
||||
}
|
||||
if (y_value.isExactlyValue(1.0)) {
|
||||
return x();
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReciprocalOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -18,6 +18,26 @@ func @testShape(tensor<f32>, tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<0
|
||||
return %0, %1, %2 : tensor<0xi32>, tensor<?xi32>, tensor<?xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testPow
|
||||
// CHECK-SAME:(%[[ARG_0:.*]]: tensor<4xf32>, %[[ARG_1:.*]]: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
|
||||
func @testPow(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
|
||||
|
||||
%cst_zero = constant dense<0.0> : tensor<f32>
|
||||
%cst_one = constant dense<1.0> : tensor<f32>
|
||||
|
||||
// CHECK-DAG: %[[RES_NO_FOLD:.*]] = "tf.Pow"(%arg0, %arg1)
|
||||
%0 = "tf.Pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK-DAG: %[[POW_ZERO:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
%1 = "tf.Pow"(%arg0, %cst_zero) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK-NOT: "tf.Pow"
|
||||
%2 = "tf.Pow"(%arg0, %cst_one) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK: return %[[RES_NO_FOLD]], %[[POW_ZERO]], %[[ARG_0]]
|
||||
return %0, %1, %2 : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testShapeN
|
||||
func @testShapeN(%arg0: tensor<f32>, %arg1: tensor<1x32x32x16xf32>, %arg2: tensor<*xf32>) -> (tensor<0xi64>, tensor<4xi64>, tensor<4xi64>, tensor<?xi64>) {
|
||||
|
||||
|
@ -147,6 +147,37 @@ func @decompose_resource_apply_keras_momentum_nesterov(%arg0: tensor<f32>, %arg1
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
// Tests that composite tf.ResourceApplyAdagradV2 operation is decomposed.
|
||||
|
||||
// CHECK-LABEL: func @decompose_resource_apply_adagradv2
|
||||
// CHECK-SAME: ([[LR:%.*]]: tensor<f32>, [[EPSILON:%.*]]: tensor<f32>, [[GRAD:%.*]]: tensor<f32>)
|
||||
func @decompose_resource_apply_adagradv2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> () {
|
||||
|
||||
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"()
|
||||
// CHECK: [[ACC_HANDLE:%.*]] = "tf.VarHandleOp"()
|
||||
// CHECK: [[GRAD_SQUARE:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: [[OLD_ACC:%.*]] = "tf.ReadVariableOp"([[ACC_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
// CHECK: [[NEW_ACC:%.*]] = "tf.AddV2"([[OLD_ACC]], [[GRAD_SQUARE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
|
||||
// CHECK: [[LR_MULTIPLY:%.*]] = "tf.Mul"([[LR]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: [[SQRT:%.*]] = "tf.Sqrt"([[NEW_ACC]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: [[DIVISOR:%.*]] = "tf.AddV2"([[SQRT]], [[EPSILON]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
|
||||
// CHECK: [[VAR_DELTA:%.*]] = "tf.Div"([[LR_MULTIPLY]], [[DIVISOR]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
// CHECK: [[NEW_VAR:%.*]] = "tf.Sub"(%9, %8) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
|
||||
// CHECK: "tf.AssignVariableOp"([[ACC_HANDLE]], [[NEW_ACC]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
|
||||
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
"tf.ResourceApplyAdagradV2"(%0, %1, %arg0, %arg1, %arg2) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that composite tf.ResourceApplyAdam (non-Nesterov) operation is
|
||||
// decomposed.
|
||||
|
||||
|
@ -248,6 +248,40 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||
return %0 : tensor<?x?x?xf32>
|
||||
}
|
||||
|
||||
// Check that supported tf_executor ops can receive data from ops on which
|
||||
// shape inference has inferred the result types, without throwing any errors.
|
||||
// CHECK-LABEL: func @supported_tf_executor_users
|
||||
func @supported_tf_executor_users(%arg0: tensor<32x?x256x4xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<i1>, %arg3: tensor<i32>) -> tensor<?x?x?xf32> {
|
||||
%0 = tf_executor.graph {
|
||||
%island:3 = tf_executor.island {
|
||||
%dims = "tf.Const"() {value = dense<[32, -1, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
%reshape = "tf.Reshape"(%arg0, %dims) : (tensor<32x?x256x4xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
|
||||
%cast = "tf.Cast"(%arg2) : (tensor<i1>) -> tensor<*xi1>
|
||||
tf_executor.yield %reshape, %cast : tensor<?x?x?xf32>, tensor<*xi1>
|
||||
}
|
||||
// CHECK: tf_executor.Merge
|
||||
// CHECK-SAME: : (tensor<32x?x4xf32>, tensor<?x?x?xf32>) ->
|
||||
// CHECK: tf_executor.Switch
|
||||
// CHECK-SAME: : (tensor<32x?x4xf32>, tensor<i1>) ->
|
||||
// CHECK: tf_executor.SwitchN
|
||||
// CHECK-SAME: : tensor<?x?x?xf32>
|
||||
// CHECK: tf_executor.Enter
|
||||
// CHECK-SAME: : (tensor<32x?x4xf32>) ->
|
||||
// CHECK: tf_executor.Exit
|
||||
// CHECK-SAME: : tensor<?x?x?xf32>
|
||||
// CHECK: tf_executor.LoopCond
|
||||
// CHECK-SAME: : tensor<*xi1>
|
||||
%merge:3 = "tf_executor.Merge"(%island#0, %arg1) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<i32>, !tf_executor.control)
|
||||
%switch:3 = "tf_executor.Switch"(%island#0, %arg2) : (tensor<?x?x?xf32>, tensor<i1>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, !tf_executor.control)
|
||||
%switchn:3 = "tf_executor.SwitchN"(%island#0, %arg3) {num_outs = 2} : (tensor<?x?x?xf32>, tensor<i32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, !tf_executor.control)
|
||||
%enter:2 = "tf_executor.Enter"(%island#0) { frame_name = "frame"} : (tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, !tf_executor.control)
|
||||
%exit:2 = "tf_executor.Exit"(%island#0) : (tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, !tf_executor.control)
|
||||
%loop_cond:2 = "tf_executor.LoopCond" (%island#1) : (tensor<*xi1>) -> (tensor<*xi1>, !tf_executor.control)
|
||||
tf_executor.fetch %enter#0 : tensor<?x?x?xf32>
|
||||
}
|
||||
return %0 : tensor<?x?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_cast
|
||||
func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK-NOT: Cast
|
||||
|
@ -7,7 +7,7 @@ func @invalid_type() -> !tf_executor.foobar
|
||||
|
||||
// Check that tf_executor.graph does not accept any operand.
|
||||
func @graph_with_invalid_op(%arg0: tensor<*xf32>) {
|
||||
"tf_executor.graph" (%arg0) : (tensor<*xf32>) -> ()
|
||||
"tf_executor.graph" (%arg0) ({}) : (tensor<*xf32>) -> ()
|
||||
// expected-error@-1 {{'tf_executor.graph' op requires zero operands}}
|
||||
return
|
||||
}
|
||||
@@ -405,12 +405,49 @@ func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32>

// -----

// Check that switchN result type matches the input type.
func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// Check that data operands of SwitchN have tensor type
func @invalid_switchN(%arg0: i32, %arg1: tensor<i32>) -> tensor<*xi32> {
  %result = tf_executor.graph {
    %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (i32, tensor<i32>) -> (tensor<*xi32>, tensor<i32>, !tf_executor.control)
    // expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to have tensor type but got 'i32'}}
    tf_executor.fetch %1#0 : tensor<*xi32>
  }
  return %result : tensor<*xi32>
}

// -----

// Check that result of SwitchN has tensor type
func @invalid_switchN(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> i32 {
  %result = tf_executor.graph {
    %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xi32>, tensor<i32>) -> (i32, tensor<i32>, !tf_executor.control)
    // expected-error@-1 {{'tf_executor.SwitchN' op expects outputs to have tensor type but got 'i32'}}
    tf_executor.fetch %1#0 : i32
  }
  return %result : i32
}

// -----

// Check that if any result is a ref type, then data operand needs to be ref too.
func @invalid_switchN(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> tensor<4x!tf.f32ref> {
  %fetches = tf_executor.graph {
    %1:3 = "tf_executor.SwitchN"(%arg1, %arg0) {num_outs = 2} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, i32, !tf_executor.control)
    // expected-error@-1 {{'tf_executor.SwitchN' op type mismatch between data operand and result: 'tensor<*xf32>' vs 'i32'}}
    %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<4xf32>, tensor<i32>) -> (tensor<4x!tf.f32ref>, tensor<4xf32>, !tf_executor.control)
    // expected-error@-1 {{'tf_executor.SwitchN' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}}
    tf_executor.fetch %1#0 : tensor<4x!tf.f32ref>
  }
  return %fetches : tensor<4x!tf.f32ref>
}

// -----

// Check that switchN data operand is broadcastable with all output types
func @invalid_switchN(%arg0: tensor<*xf32>, %arg1: tensor<i32>) -> tensor<*xf32> {
  %fetches = tf_executor.graph {
    %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control)
    // expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to be broadcastable with all output types but got 'tensor<*xf32>' vs 'tensor<i32>'}}
    tf_executor.fetch %1#0 : tensor<*xf32>
  }
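The expected-error strings above come from the tf_executor.SwitchN verifier, which is not itself part of this diff. As a rough, hypothetical sketch of the checks these cases exercise (illustrative names, simplified logic):

    // Hypothetical condensed sketch; the real verifier lives in the
    // tf_executor dialect and is not part of this diff.
    mlir::LogicalResult VerifySwitchNLikeOp(mlir::Operation* op) {
      mlir::Type data_type = op->getOperand(0).getType();
      if (!data_type.isa<mlir::TensorType>())
        return op->emitOpError("expects data operand to have tensor type");
      for (mlir::Type result_type : op->getResultTypes()) {
        if (result_type.isa<mlir::tf_executor::ControlType>()) continue;
        if (!result_type.isa<mlir::TensorType>())
          return op->emitOpError("expects outputs to have tensor type");
        // Element types must also agree (ref results require a ref operand),
        // and the data operand must be broadcastable with every output type.
      }
      return mlir::success();
    }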
@@ -472,6 +509,30 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {

// -----

// Check that data operands of merge have tensor type
func @invalid_merge(%arg0: tensor<*xi32>, %arg1: i32) -> tensor<*xi32> {
  %result = tf_executor.graph {
    %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xi32>, i32) -> (tensor<*xi32>, tensor<i32>, !tf_executor.control)
    // expected-error@-1 {{'tf_executor.Merge' op expects data operands to have tensor type but got 'i32'}}
    tf_executor.fetch %value : tensor<*xi32>
  }
  return %result : tensor<*xi32>
}

// -----

// Check that result of merge has tensor type
func @invalid_merge(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> i32 {
  %result = tf_executor.graph {
    %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xi32>, tensor<i32>) -> (i32, tensor<i32>, !tf_executor.control)
    // expected-error@-1 {{'tf_executor.Merge' op result #0 must be tensor of any type values, but got 'i32'}}
    tf_executor.fetch %value : i32
  }
  return %result : i32
}

// -----

// Check that merge data inputs are all the same type
func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
  %result = tf_executor.graph {
@@ -0,0 +1,92 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# RUN: %p/multi_arguments_results_v1 | FileCheck -dump-input-on-failure %s

# pylint: disable=missing-docstring,line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
from tensorflow.python.ops import array_ops

# Tests multiple inputs and outputs with index paths.

# CHECK-LABEL: func @key(
# CHECK-SAME: %[[ARG0:.*]]: tensor<3x5xf32> {tf_saved_model.index_path = ["y"]}
# CHECK-SAME: %[[ARG1:.*]]: tensor<5x3xf32> {tf_saved_model.index_path = ["x"]}
# CHECK-SAME: tensor<3x3xf32> {tf_saved_model.index_path = ["t"]}
# CHECK-SAME: tensor<5x5xf32> {tf_saved_model.index_path = ["s"]}
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"]
# CHECK-DAG: %[[MUL0:.*]] = "tf.MatMul"(%[[ARG1]], %[[ARG0]])
# CHECK-DAG: %[[MUL1:.*]] = "tf.MatMul"(%[[ARG0]], %[[ARG1]])
# CHECK: %[[IDENTITY:.*]]:2 = "tf.IdentityN"(%[[MUL1]], %[[MUL0]])
# CHECK: return %[[IDENTITY]]#0, %[[IDENTITY]]#1

# CHECK-LABEL: func @key2(
# CHECK-SAME: %[[ARG1:.*]]: tensor<5x3xf32> {tf_saved_model.index_path = ["b"]}
# CHECK-SAME: %[[ARG0:.*]]: tensor<3x5xf32> {tf_saved_model.index_path = ["a"]}
# CHECK-SAME: tensor<5x5xf32> {tf_saved_model.index_path = ["d"]}
# CHECK-SAME: tensor<3x3xf32> {tf_saved_model.index_path = ["c"]}
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key2"]
# CHECK-DAG: %[[MUL1:.*]] = "tf.MatMul"(%[[ARG0]], %[[ARG1]])
# CHECK-DAG: %[[MUL2:.*]] = "tf.MatMul"(%[[ARG1]], %[[ARG0]])
# CHECK: %[[IDENTITY:.*]]:2 = "tf.IdentityN"(%[[MUL1]], %[[MUL2]])
# CHECK: return %[[IDENTITY]]#1, %[[IDENTITY]]#0


def Test():

  x = tf.constant(1.0, shape=(5, 3))
  y = tf.constant(1.0, shape=(3, 5))

  s = tf.matmul(x, y)
  t = tf.matmul(y, x)
  [t, s] = array_ops.identity_n([t, s])

  tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
  tensor_info_y = tf.compat.v1.saved_model.utils.build_tensor_info(y)
  tensor_info_s = tf.compat.v1.saved_model.utils.build_tensor_info(s)
  tensor_info_t = tf.compat.v1.saved_model.utils.build_tensor_info(t)

  return {
      'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
          inputs={
              'x': tensor_info_x,
              'y': tensor_info_y
          },
          outputs={
              's': tensor_info_s,
              't': tensor_info_t
          },
          method_name='some_function')),
      'key2': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
          inputs={
              'a': tensor_info_y,
              'b': tensor_info_x,
          },
          outputs={
              'c': tensor_info_t,
              'd': tensor_info_s,
          },
          method_name='reverse_arguments'))
  }


if __name__ == '__main__':
  common_v1.set_tf_options()
  common_v1.do_test(Test())
@@ -1,64 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# RUN: %p/multi_arguments_v1 | FileCheck %s

# pylint: disable=missing-docstring,line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1

# Tests multiple inputs with index paths.
# CHECK: func {{@[a-zA-Z_0-9]+}}(
# CHECK-SAME: [[ARG0:%.*]]: tensor<5x3xf32> {tf_saved_model.index_path = ["x"]},
# CHECK-SAME: [[ARG1:%.*]]: tensor<3x5xf32> {tf_saved_model.index_path = ["y"]})
# CHECK-SAME: -> (tensor<5x5xf32> {tf_saved_model.index_path = ["s"]},
# CHECK-SAME: tensor<3x3xf32> {tf_saved_model.index_path = ["t"]})
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"]


def Test():

  x = tf.constant(1.0, shape=(5, 3))
  y = tf.constant(1.0, shape=(3, 5))

  s = tf.matmul(x, y)
  t = tf.matmul(y, x)

  tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
  tensor_info_y = tf.compat.v1.saved_model.utils.build_tensor_info(y)
  tensor_info_s = tf.compat.v1.saved_model.utils.build_tensor_info(s)
  tensor_info_t = tf.compat.v1.saved_model.utils.build_tensor_info(t)

  return {
      'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
          inputs={
              'x': tensor_info_x,
              'y': tensor_info_y
          },
          outputs={
              's': tensor_info_s,
              't': tensor_info_t
          },
          method_name='some_function'))
  }


if __name__ == '__main__':
  common_v1.set_tf_options()
  common_v1.do_test(Test())
@@ -39,8 +39,8 @@ constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";
// Analyzes the inputs to LaunchFuncOps in the module, and annotates their
// invoked functions whether each input has the same data across replicas.
struct AnnotateParameterReplication
    : public ModulePass<AnnotateParameterReplication> {
  void runOnModule() override;
    : public OperationPass<AnnotateParameterReplication, ModuleOp> {
  void runOnOperation() override;
};

// Returns the first value in the chain of operands, which is not defined by a
@@ -53,8 +53,8 @@ Value SkipIdentityAndReadVariable(Value v) {
  return v;
}

void AnnotateParameterReplication::runOnModule() {
  ModuleOp m = getModule();
void AnnotateParameterReplication::runOnOperation() {
  ModuleOp m = getOperation();
  OpBuilder builder(m.getContext());
  m.walk([&](tf_device::LaunchFuncOp launch_func) {
    auto replicate = launch_func.getParentOfType<tf_device::ReplicateOp>();
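Every module-level pass touched by this change follows the same mechanical migration: ModulePass<T> becomes OperationPass<T, ModuleOp>, runOnModule() becomes runOnOperation(), and getModule() becomes getOperation(). A minimal before/after sketch with a hypothetical pass name, assuming the same includes and mlir namespace usage as the passes in this diff:

    // Before: the ModulePass API.
    struct MyPassOld : public ModulePass<MyPassOld> {
      void runOnModule() override {
        ModuleOp m = getModule();
        // ... walk and rewrite m ...
      }
    };

    // After: an OperationPass anchored on ModuleOp, as done throughout this diff.
    struct MyPassNew : public OperationPass<MyPassNew, ModuleOp> {
      void runOnOperation() override {
        ModuleOp m = getOperation();
        // ... walk and rewrite m ...
      }
    };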
@ -38,8 +38,9 @@ namespace {
|
||||
constexpr char kDeviceAttr[] = "device";
|
||||
constexpr char kFuncAttr[] = "func";
|
||||
|
||||
struct ClusterOutliningPass : public ModulePass<ClusterOutliningPass> {
|
||||
void runOnModule() override;
|
||||
struct ClusterOutliningPass
|
||||
: public OperationPass<ClusterOutliningPass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void ReplaceLaunchReturnWithReturn(tf_device::ReturnOp launch_return_op,
|
||||
@ -120,8 +121,8 @@ void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table,
|
||||
launch_op.erase();
|
||||
}
|
||||
|
||||
void ClusterOutliningPass::runOnModule() {
|
||||
ModuleOp m = getModule();
|
||||
void ClusterOutliningPass::runOnOperation() {
|
||||
ModuleOp m = getOperation();
|
||||
SymbolTable symbol_table(m);
|
||||
OpBuilder builder(m.getContext());
|
||||
m.walk([&](tf_device::LaunchOp launch) {
|
||||
|
@@ -22,7 +22,7 @@ class GetScalarOfType<int value> : NativeCodeCall<
  "GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;

// Creates a tf.ReadVariable op that reads a resource `$2` that has the same
// element type as `$1`. The op created will use location of `$1`.
// element type as `$1`. The op created will use location of `$0`.
def CreateTFReadVariableOp: NativeCodeCall<
    "$_builder.create<TF::ReadVariableOp>("
    " $0.getLoc(),"
@@ -118,6 +118,32 @@ def DecomposeResourceApplyKerasMomentumOpNesterov :
    ]
  >;

// Pattern to Decompose ResourceApplyAdagrad.
// This decomposition is only correct inside XLA as it ignores use_locking
// attribute.
// accum <- accum + grad * grad
// variable <- variable - lr * grad / (sqrt(accum) + epsilon)
def DecomposeResourceApplyAdagradV2 :
  Pattern<
    (TF_ResourceApplyAdagradV2Op:$src_op
       $var_resource, $accum_resource, $lr, $epsilon, $grad, BoolAttr:$_,
       ConstBoolAttrTrue:$update_slots),
    [
      (TF_AddV2Op:$new_accum
        (CreateTFReadVariableOp $src_op, $grad, $accum_resource),
        (TF_MulOp $grad, $grad)
      ),
      (TF_AssignSubVariableOp
        $var_resource,
        (TF_DivOp
          (TF_MulOp $lr, $grad),
          (TF_AddV2Op (TF_SqrtOp $new_accum), $epsilon)
        )
      ),
      (TF_AssignVariableOp $accum_resource, $new_accum),
    ]
  >;

// Pattern to Decompose ResourceApplyAdam without Nesterov momentum.
// This decomposition is only correct inside XLA as it ignores use_locking
// attribute.
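The DRR pattern above just spells the standard AdagradV2 update out of AddV2/Mul/Sqrt/Div/AssignSub ops. A scalar restatement of that arithmetic, as a hypothetical standalone helper rather than anything in the TensorFlow code base:

    #include <cmath>

    // Scalar restatement of the decomposition above:
    //   accum    <- accum + grad * grad
    //   variable <- variable - lr * grad / (sqrt(accum) + epsilon)
    void ApplyAdagradV2Step(float& variable, float& accum, float lr,
                            float epsilon, float grad) {
      accum += grad * grad;
      variable -= lr * grad / (std::sqrt(accum) + epsilon);
    }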
@@ -303,7 +303,8 @@ void InsertDummyIslandForFetch(FetchOp fetch) {
      /*control=*/ControlType::get(fetch.getContext()),
      /*controlInputs=*/control_fetches);
  island.body().push_back(new Block);
  OpBuilder(&island.GetBody()).create<YieldOp>(fetch.getLoc(), data_fetches);
  OpBuilder::atBlockEnd(&island.GetBody())
      .create<YieldOp>(fetch.getLoc(), data_fetches);
  const int fetch_control_idx = data_fetches.size();
  for (int i = 0, e = fetch.getNumOperands(); i < e; i++) {
    // The fetch could have multiple control operands (all at the end of its
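Several hunks in this change replace an OpBuilder constructed directly from a block with OpBuilder::atBlockEnd, making the append-at-end insertion point explicit. A small sketch of the resulting usage, assuming an existing island body block (hypothetical helper):

    // Assumes an existing island body block, as in the hunk above; appends
    // the terminator at the end of that block.
    void AppendYield(mlir::Block& body, mlir::Location loc,
                     mlir::ValueRange data_fetches) {
      mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(&body);
      builder.create<mlir::tf_executor::YieldOp>(loc, data_fetches);
    }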
@ -43,17 +43,17 @@ constexpr llvm::StringRef kNestedModule = "_tpu_v1_compat_outlined";
|
||||
// Inlining the islands calling into the nested module that was outlined.
|
||||
// This is the end of the TPU bridge in V1 compatibility mode.
|
||||
struct TPUBridgeExecutorIslandInlining
|
||||
: public ModulePass<TPUBridgeExecutorIslandInlining> {
|
||||
void runOnModule() override;
|
||||
: public OperationPass<TPUBridgeExecutorIslandInlining, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void TPUBridgeExecutorIslandInlining::runOnModule() {
|
||||
SymbolTable symbol_table(getModule());
|
||||
void TPUBridgeExecutorIslandInlining::runOnOperation() {
|
||||
SymbolTable symbol_table(getOperation());
|
||||
Operation *nested_module = symbol_table.lookup(kNestedModule);
|
||||
if (!nested_module) return;
|
||||
|
||||
InlinerInterface inliner(&getContext());
|
||||
auto walk_result = getModule().walk([&](TF::PartitionedCallOp call_op) {
|
||||
auto walk_result = getOperation().walk([&](TF::PartitionedCallOp call_op) {
|
||||
if (!call_op.f().getRootReference().startswith(kNestedModule))
|
||||
return WalkResult::advance();
|
||||
// This is a call we need to inline!
|
||||
@ -61,7 +61,7 @@ void TPUBridgeExecutorIslandInlining::runOnModule() {
|
||||
<< "Found call to inline: " << *call_op.getOperation() << "\n");
|
||||
|
||||
FuncOp called_func = dyn_cast_or_null<FuncOp>(
|
||||
symbol_table.lookupSymbolIn(getModule(), call_op.f()));
|
||||
symbol_table.lookupSymbolIn(getOperation(), call_op.f()));
|
||||
|
||||
if (failed(inlineCall(inliner,
|
||||
cast<CallOpInterface>(call_op.getOperation()),
|
||||
@ -80,7 +80,7 @@ void TPUBridgeExecutorIslandInlining::runOnModule() {
|
||||
Block &nested_block = nested_module->getRegion(0).front();
|
||||
for (FuncOp func_op :
|
||||
llvm::make_early_inc_range(nested_block.getOps<FuncOp>())) {
|
||||
if (!symbol_table.lookupSymbolIn(getModule(), func_op.getName())) {
|
||||
if (!symbol_table.lookupSymbolIn(getOperation(), func_op.getName())) {
|
||||
nested_block.getOperations().remove(func_op.getOperation());
|
||||
symbol_table.insert(func_op.getOperation());
|
||||
}
|
||||
|
@ -59,8 +59,8 @@ constexpr llvm::StringRef kTpuStatusAttr = "_tpu_compilation_status";
|
||||
// TPU-annotated operations and intended to preserve backward compatibility with
|
||||
// TFv1.
|
||||
struct TpuV1BridgeExecutorIslandCoarsening
|
||||
: public ModulePass<TpuV1BridgeExecutorIslandCoarsening> {
|
||||
void runOnModule() override;
|
||||
: public OperationPass<TpuV1BridgeExecutorIslandCoarsening, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Sort the Operations in the provided range to enforce dominance.
|
||||
@ -226,7 +226,8 @@ LogicalResult MergeIsland(llvm::function_ref<bool(StringAttr, Operation*)>
|
||||
yield_operands.push_back(std::get<1>(result));
|
||||
}
|
||||
}
|
||||
OpBuilder(&island_body).create<YieldOp>(new_island.getLoc(), yield_operands);
|
||||
OpBuilder::atBlockEnd(&island_body)
|
||||
.create<YieldOp>(new_island.getLoc(), yield_operands);
|
||||
|
||||
// remap results of the new islands to the user outside of the island.
|
||||
int current_result = 0;
|
||||
@ -257,13 +258,13 @@ LogicalResult MergeIsland(llvm::function_ref<bool(StringAttr, Operation*)>
|
||||
first_op_after);
|
||||
}
|
||||
|
||||
void TpuV1BridgeExecutorIslandCoarsening::runOnModule() {
|
||||
SymbolTable symbol_table(getModule());
|
||||
void TpuV1BridgeExecutorIslandCoarsening::runOnOperation() {
|
||||
SymbolTable symbol_table(getOperation());
|
||||
|
||||
// Map tpu cluster names to the functions that contain operations for this
|
||||
// cluster.
|
||||
DenseMap<StringRef, DenseSet<FuncOp>> tpu_funcs;
|
||||
for (FuncOp func_op : getModule().getOps<FuncOp>()) {
|
||||
for (FuncOp func_op : getOperation().getOps<FuncOp>()) {
|
||||
func_op.walk([&](Operation* op) {
|
||||
StringAttr cluster_name =
|
||||
op->getAttrOfType<StringAttr>(kTpuReplicateAttr);
|
||||
@ -291,7 +292,7 @@ void TpuV1BridgeExecutorIslandCoarsening::runOnModule() {
|
||||
return false;
|
||||
};
|
||||
|
||||
for (FuncOp func_op : getModule().getOps<FuncOp>()) {
|
||||
for (FuncOp func_op : getOperation().getOps<FuncOp>()) {
|
||||
func_op.walk([&](GraphOp graph) {
|
||||
Block& graph_body = graph.GetBody();
|
||||
|
||||
|
@ -44,20 +44,20 @@ constexpr llvm::StringRef kOutlinedFuncPrefix = "_tpu_v1_compat_outlined_func";
|
||||
// This is only intended for V1 compatibility mode where the bridge runs without
|
||||
// feed/fetches on session create/extend.
|
||||
struct TPUBridgeExecutorIslandOutlining
|
||||
: public ModulePass<TPUBridgeExecutorIslandOutlining> {
|
||||
void runOnModule() override;
|
||||
: public OperationPass<TPUBridgeExecutorIslandOutlining, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void TPUBridgeExecutorIslandOutlining::runOnModule() {
|
||||
void TPUBridgeExecutorIslandOutlining::runOnOperation() {
|
||||
MLIRContext *ctx = &getContext();
|
||||
|
||||
SymbolTable symbol_table(getModule());
|
||||
SymbolTable symbol_table(getOperation());
|
||||
if (Operation *nested_module = symbol_table.lookup(kNestedModule)) {
|
||||
nested_module->emitOpError("unexpected already present outlined module.");
|
||||
return signalPassFailure();
|
||||
}
|
||||
ModuleOp outlined_module = ModuleOp::create(getModule().getLoc());
|
||||
outlined_module.setAttrs(getModule().getAttrs());
|
||||
ModuleOp outlined_module = ModuleOp::create(getOperation().getLoc());
|
||||
outlined_module.setAttrs(getOperation().getAttrs());
|
||||
outlined_module.setAttr(SymbolTable::getSymbolAttrName(),
|
||||
StringAttr::get(kNestedModule, ctx));
|
||||
symbol_table.insert(outlined_module);
|
||||
@ -66,7 +66,7 @@ void TPUBridgeExecutorIslandOutlining::runOnModule() {
|
||||
// Find every island that contains a TPUReplicateMetadata node and extract it
|
||||
// in a new module to run the V1 bridge there.
|
||||
SmallVector<IslandOp, 8> islands_to_outline;
|
||||
getModule().walk([&](TF::TPUReplicateMetadataOp replicate_op) {
|
||||
getOperation().walk([&](TF::TPUReplicateMetadataOp replicate_op) {
|
||||
auto island_op = cast<IslandOp>(replicate_op.getParentOp());
|
||||
if (!island_op || island_op.WrapsSingleOp()) return;
|
||||
islands_to_outline.push_back(island_op);
|
||||
@ -123,7 +123,7 @@ void TPUBridgeExecutorIslandOutlining::runOnModule() {
|
||||
|
||||
// The function is in place in the nested module, create a call and yield in
|
||||
// the original island.
|
||||
OpBuilder builder(&island_op.GetBody());
|
||||
OpBuilder builder = OpBuilder::atBlockEnd(&island_op.GetBody());
|
||||
auto call_op = builder.create<mlir::TF::PartitionedCallOp>(
|
||||
island_op.getLoc(), func_result_types, operands.getArrayRef(),
|
||||
builder.getSymbolRefAttr(
|
||||
|
@ -202,7 +202,7 @@ static void MatchSwitchFoldOps(tf_executor::SwitchOp switch_op,
|
||||
static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) {
|
||||
// Create builder for val_index of MergeOp.
|
||||
auto* block = &function.getBlocks().front();
|
||||
OpBuilder builder(block);
|
||||
OpBuilder builder = OpBuilder::atBlockEnd(block);
|
||||
auto type = builder.getIntegerType(32);
|
||||
auto build_index = [&](Location loc, int value) {
|
||||
return builder.create<ConstantOp>(loc, type,
|
||||
|
@ -41,12 +41,13 @@ namespace {
|
||||
// the IR is in correct form for inference backends (like lite) that do not
|
||||
// support resources/variables . Further, this contract also ensures that this
|
||||
// pass lowers from saved model to pure TF. Hence it fails, if it cannot lower.
|
||||
struct FreezeGlobalTensorsPass : public ModulePass<FreezeGlobalTensorsPass> {
|
||||
void runOnModule() override;
|
||||
struct FreezeGlobalTensorsPass
|
||||
: public OperationPass<FreezeGlobalTensorsPass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void FreezeGlobalTensorsPass::runOnModule() {
|
||||
auto module = getModule();
|
||||
void FreezeGlobalTensorsPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
SymbolTable symbol_table(module);
|
||||
DenseSet<Operation*> frozen_global_tensors;
|
||||
|
||||
|
@ -126,7 +126,7 @@ void LayoutAssignmentPass::runOnFunction() {
|
||||
|
||||
mlir::Operation* op = layout_sensitive_interface.getOperation();
|
||||
Location loc = op->getLoc();
|
||||
OpBuilder builder(op->getBlock());
|
||||
OpBuilder builder = OpBuilder::atBlockEnd(op->getBlock());
|
||||
|
||||
auto perm_attr = [&](Permutation permutation) -> DenseIntElementsAttr {
|
||||
auto perm_ty = RankedTensorType::get({4}, builder.getIntegerType(32));
|
||||
|
@ -74,11 +74,11 @@ LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification(
|
||||
|
||||
namespace {
|
||||
struct MarkFunctionVisibilityUsingEntryFunctionSpecificationPass
|
||||
: public ModulePass<
|
||||
MarkFunctionVisibilityUsingEntryFunctionSpecificationPass> {
|
||||
void runOnModule() override {
|
||||
: public OperationPass<
|
||||
MarkFunctionVisibilityUsingEntryFunctionSpecificationPass, ModuleOp> {
|
||||
void runOnOperation() override {
|
||||
if (failed(MarkFunctionVisibilityUsingEntryFunctionSpecification(
|
||||
getModule()))) {
|
||||
getOperation()))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
@ -110,9 +110,10 @@ static LogicalResult MarkFunctionVisibilityUsingSavedModelLinkage(
|
||||
|
||||
namespace {
|
||||
struct MarkFunctionVisibilityUsingSavedModelLinkagePass
|
||||
: public ModulePass<MarkFunctionVisibilityUsingSavedModelLinkagePass> {
|
||||
void runOnModule() override {
|
||||
if (failed(MarkFunctionVisibilityUsingSavedModelLinkage(getModule()))) {
|
||||
: public OperationPass<MarkFunctionVisibilityUsingSavedModelLinkagePass,
|
||||
ModuleOp> {
|
||||
void runOnOperation() override {
|
||||
if (failed(MarkFunctionVisibilityUsingSavedModelLinkage(getOperation()))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
@ -41,8 +41,8 @@ namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
namespace {
|
||||
struct OptimizeGlobalTensorsPass
|
||||
: public ModulePass<OptimizeGlobalTensorsPass> {
|
||||
void runOnModule() override;
|
||||
: public OperationPass<OptimizeGlobalTensorsPass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// A global tensor is bound to arguments of multiple funcs.
|
||||
@ -276,8 +276,8 @@ void EraseUnusedBoundInputs(ModuleOp module) {
|
||||
}
|
||||
}
|
||||
|
||||
void OptimizeGlobalTensorsPass::runOnModule() {
|
||||
auto module = getModule();
|
||||
void OptimizeGlobalTensorsPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
EraseUnusedBoundInputs(module);
|
||||
|
||||
ResourceAnalyzer resource_analyzer(module);
|
||||
|
@ -258,13 +258,13 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) {
|
||||
}
|
||||
|
||||
class PromoteResourcesToArgsPass
|
||||
: public ModulePass<PromoteResourcesToArgsPass> {
|
||||
: public OperationPass<PromoteResourcesToArgsPass, ModuleOp> {
|
||||
public:
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void PromoteResourcesToArgsPass::runOnModule() {
|
||||
ModuleOp module = getModule();
|
||||
void PromoteResourcesToArgsPass::runOnOperation() {
|
||||
ModuleOp module = getOperation();
|
||||
FuncOp main_func = module.lookupSymbol<FuncOp>("main");
|
||||
if (!main_func) return;
|
||||
|
||||
|
@ -53,8 +53,9 @@ constexpr char kFuncDeviceAttr[] = "tf.device";
|
||||
//
|
||||
// This pass changes the module by adding "tf.device" attribute to function
|
||||
// arguments and adding "device" attribute to TF ops.
|
||||
struct ResourceDeviceInference : public ModulePass<ResourceDeviceInference> {
|
||||
void runOnModule() override;
|
||||
struct ResourceDeviceInference
|
||||
: public OperationPass<ResourceDeviceInference, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// A class that records each resource's device assignment in a function.
|
||||
@ -190,8 +191,8 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
|
||||
return failure(walk_res.wasInterrupted());
|
||||
}
|
||||
|
||||
void ResourceDeviceInference::runOnModule() {
|
||||
auto module = getModule();
|
||||
void ResourceDeviceInference::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
llvm::SmallDenseMap<Operation*, PerFunctionResult, 4> per_function_results;
|
||||
llvm::SetVector<FuncOp> worklist;
|
||||
module.walk([&](FuncOp func_op) {
|
||||
|
@ -131,8 +131,9 @@ namespace {
|
||||
// return %arg0
|
||||
// }
|
||||
//
|
||||
struct ResourceOpLiftingPass : public ModulePass<ResourceOpLiftingPass> {
|
||||
void runOnModule() override;
|
||||
struct ResourceOpLiftingPass
|
||||
: public OperationPass<ResourceOpLiftingPass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Removes identity nodes in the block. The device computation does not need
|
||||
@ -1050,13 +1051,13 @@ LogicalResult HoistForFunctionalControlFlow(
|
||||
// Lifts resource operation from tf_device.launch_func ops nested in `op`
|
||||
// outside. Returns failure if there are remaining resource-type values that can
|
||||
// not be lifted.
|
||||
void ResourceOpLiftingPass::runOnModule() {
|
||||
void ResourceOpLiftingPass::runOnOperation() {
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
|
||||
lifted_partitioned_call_callees;
|
||||
auto result = getModule().walk([&](FuncOp func_op) {
|
||||
auto result = getOperation().walk([&](FuncOp func_op) {
|
||||
return func_op.walk([&](tf_device::LaunchOp launch_op) {
|
||||
if (failed(HoistForFunctionalControlFlow(
|
||||
&launch_op.GetBody(), getModule(),
|
||||
&launch_op.GetBody(), getOperation(),
|
||||
&lifted_partitioned_call_callees)) ||
|
||||
failed(HoistResourceOpsFromLaunchOp(launch_op))) {
|
||||
return WalkResult::interrupt();
|
||||
@ -1070,12 +1071,12 @@ void ResourceOpLiftingPass::runOnModule() {
|
||||
}
|
||||
|
||||
struct ResourceOpLiftingForMainFunctionPass
|
||||
: public ModulePass<ResourceOpLiftingForMainFunctionPass> {
|
||||
void runOnModule() override;
|
||||
: public OperationPass<ResourceOpLiftingForMainFunctionPass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void ResourceOpLiftingForMainFunctionPass::runOnModule() {
|
||||
ModuleOp module = getModule();
|
||||
void ResourceOpLiftingForMainFunctionPass::runOnOperation() {
|
||||
ModuleOp module = getOperation();
|
||||
FuncOp main_func = module.lookupSymbol<FuncOp>("main");
|
||||
if (!main_func) {
|
||||
return;
|
||||
|
@ -111,7 +111,9 @@ bool IsSupportedNonTFOp(Operation* op) {
|
||||
return isa<tf_executor::YieldOp>(op) || isa<tf_executor::IslandOp>(op) ||
|
||||
isa<tf_executor::FetchOp>(op) || isa<tf_executor::GraphOp>(op) ||
|
||||
isa<tf_executor::NextIterationSinkOp>(op) || isa<ReturnOp>(op) ||
|
||||
isa<tf_device::ReturnOp>(op);
|
||||
isa<tf_device::ReturnOp>(op) || isa<tf_executor::MergeOp>(op) ||
|
||||
isa<tf_executor::SwitchOp>(op) || isa<tf_executor::SwitchNOp>(op) ||
|
||||
isa<tf_executor::EnterOp>(op) || isa<tf_executor::ExitOp>(op);
|
||||
}
|
||||
|
||||
// Inserts tf.Cast operation when changing the type of a result if the user is
|
||||
@ -224,7 +226,8 @@ GetSubtypes(Type type) {
|
||||
return GetSubtypesHelper<TF::VariantType>(type);
|
||||
}
|
||||
|
||||
// Makes result types match the operand types. Returns if anything is changed.
|
||||
// Makes result types match the operand types (the i-th result type will
|
||||
// match the i-th operand type). Returns true if anything is changed.
|
||||
bool PassThroughOperandTypes(OperandRange operands, ResultRange results) {
|
||||
bool changed = false;
|
||||
for (auto entry : llvm::zip(operands, results)) {
|
||||
|
@ -47,9 +47,9 @@ namespace {
|
||||
|
||||
// This transformation pass propagate shapes on the TensorFlow graph.
|
||||
// It is a ModulePass in order to be able to change function types.
|
||||
struct ShapeInference : public ModulePass<ShapeInference> {
|
||||
void runOnModule() override {
|
||||
auto module = getModule();
|
||||
struct ShapeInference : public OperationPass<ShapeInference, ModuleOp> {
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
auto producer_or = tensorflow::GetTfGraphProducerVersion(module);
|
||||
if (!producer_or.ok()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << producer_or.status().ToString(););
|
||||
|
@ -85,8 +85,8 @@ namespace cutil = TF::collection_ops_util;
|
||||
//
|
||||
// The pass also works across control flow and functional calls.
|
||||
struct StackOpsDecompositionPass
|
||||
: public ModulePass<StackOpsDecompositionPass> {
|
||||
void runOnModule() override;
|
||||
: public OperationPass<StackOpsDecompositionPass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Returns the type of the local variable for the stack size.
|
||||
@ -551,8 +551,8 @@ LogicalResult DecomposeStackOps(Block* block, ModuleOp module) {
|
||||
&decomposed_partitioned_call_callees);
|
||||
}
|
||||
|
||||
void StackOpsDecompositionPass::runOnModule() {
|
||||
auto module = getModule();
|
||||
void StackOpsDecompositionPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
auto main = module.lookupSymbol<FuncOp>("main");
|
||||
if (!main) return;
|
||||
if (failed(DecomposeStackOps(&main.front(), module))) {
|
||||
|
@ -68,8 +68,8 @@ using std::string;
|
||||
// shape.
|
||||
//
|
||||
struct TensorArrayOpsDecompositionPass
|
||||
: public ModulePass<TensorArrayOpsDecompositionPass> {
|
||||
void runOnModule() override;
|
||||
: public OperationPass<TensorArrayOpsDecompositionPass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Infers the element type and count for a TensorArraySplitV3Op. Requires
|
||||
@ -873,8 +873,8 @@ LogicalResult DecomposeTensorArrayOps(
|
||||
return success();
|
||||
}
|
||||
|
||||
void TensorArrayOpsDecompositionPass::runOnModule() {
|
||||
auto module = getModule();
|
||||
void TensorArrayOpsDecompositionPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
auto main = module.lookupSymbol<FuncOp>("main");
|
||||
if (!main) return;
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats> stats;
|
||||
|
@ -62,8 +62,8 @@ namespace cutil = TF::collection_ops_util;
|
||||
//
|
||||
// The pass also works across control flow and functional calls.
|
||||
struct TensorListOpsDecompositionPass
|
||||
: public ModulePass<TensorListOpsDecompositionPass> {
|
||||
void runOnModule() override;
|
||||
: public OperationPass<TensorListOpsDecompositionPass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Updates func's type according to its current arguments and return values.
|
||||
@ -671,8 +671,8 @@ LogicalResult DecomposeTensorListOps(Block* block, ModuleOp module) {
|
||||
&decomposed_partitioned_call_callees);
|
||||
}
|
||||
|
||||
void TensorListOpsDecompositionPass::runOnModule() {
|
||||
auto module = getModule();
|
||||
void TensorListOpsDecompositionPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
auto main = module.lookupSymbol<FuncOp>("main");
|
||||
if (!main) return;
|
||||
if (failed(DecomposeTensorListOps(&main.front(), module))) {
|
||||
|
@ -40,20 +40,20 @@ namespace tensorflow {
|
||||
// Optimization Passes and convert back to MLIR.
|
||||
// Constraints: This pass expects that all operations in the MLIR module either
|
||||
// belong to 'tf' or '_tf' dialect. The output is in '_tf' dialect.
|
||||
class GraphOptPass : public mlir::ModulePass<GraphOptPass> {
|
||||
class GraphOptPass : public mlir::OperationPass<GraphOptPass, mlir::ModuleOp> {
|
||||
public:
|
||||
explicit GraphOptPass(std::vector<tensorflow::GraphOptimizationPass*> passes)
|
||||
: passes_(std::move(passes)) {}
|
||||
|
||||
protected:
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
|
||||
// The passes to run on the module.
|
||||
std::vector<GraphOptimizationPass*> passes_;
|
||||
};
|
||||
|
||||
void GraphOptPass::runOnModule() {
|
||||
mlir::ModuleOp module_in = getModule();
|
||||
void GraphOptPass::runOnOperation() {
|
||||
mlir::ModuleOp module_in = getOperation();
|
||||
mlir::MLIRContext& ctx = getContext();
|
||||
|
||||
// Convert MLIR to Graph
|
||||
@ -151,7 +151,7 @@ class GraphOptByNamePass : public GraphOptPass {
|
||||
: GraphOptPass(FindRegisteredPassesByName(pass_names)) {}
|
||||
|
||||
private:
|
||||
void runOnModule() override {
|
||||
void runOnOperation() override {
|
||||
// Verify all passes requested were registered/found.
|
||||
for (auto pass_it : llvm::enumerate(passes_)) {
|
||||
if (pass_it.value() == nullptr) {
|
||||
@ -160,7 +160,7 @@ class GraphOptByNamePass : public GraphOptPass {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
return GraphOptPass::runOnModule();
|
||||
return GraphOptPass::runOnOperation();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -48,8 +48,9 @@ constexpr char kPaddingMapAttr[] = "padding_map";
|
||||
// (user).
|
||||
|
||||
namespace {
|
||||
struct TPUDynamicPaddingMapper : public ModulePass<TPUDynamicPaddingMapper> {
|
||||
void runOnModule() override;
|
||||
struct TPUDynamicPaddingMapper
|
||||
: public OperationPass<TPUDynamicPaddingMapper, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Creates a mapping from replicated input index (in `tf_device.replicate` op)
|
||||
@ -190,8 +191,8 @@ LogicalResult RemapAndAssignPaddingMaps(tf_device::LaunchFuncOp launch_func,
|
||||
return success();
|
||||
}
|
||||
|
||||
void TPUDynamicPaddingMapper::runOnModule() {
|
||||
ModuleOp module = getModule();
|
||||
void TPUDynamicPaddingMapper::runOnOperation() {
|
||||
ModuleOp module = getOperation();
|
||||
SymbolTable symbol_table(module);
|
||||
module.walk([&](tf_device::LaunchFuncOp launch_func) {
|
||||
RemapAndAssignPaddingMaps(launch_func, &symbol_table);
|
||||
|
@ -98,8 +98,8 @@ constexpr char kBadArrayAttrLengthMsg[] =
|
||||
// %4 = "tf.SomeOp"(%3)
|
||||
|
||||
namespace {
|
||||
struct TPURewritePass : public ModulePass<TPURewritePass> {
|
||||
void runOnModule() override;
|
||||
struct TPURewritePass : public OperationPass<TPURewritePass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Creates a missing attribute error message.
|
||||
@ -747,13 +747,13 @@ LogicalResult Rewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
void TPURewritePass::runOnModule() {
|
||||
void TPURewritePass::runOnOperation() {
|
||||
mlir::TF::RuntimeDevices devices;
|
||||
if (failed(tensorflow::GetDevicesFromOp(getModule(), &devices)))
|
||||
if (failed(tensorflow::GetDevicesFromOp(getOperation(), &devices)))
|
||||
return signalPassFailure();
|
||||
|
||||
OpBuilder builder(&getContext());
|
||||
auto result = getModule().walk([&](tf_device::LaunchFuncOp op) {
|
||||
auto result = getOperation().walk([&](tf_device::LaunchFuncOp op) {
|
||||
if (failed(Rewrite(op, devices.device_names(), &builder)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
@ -763,7 +763,7 @@ void TPURewritePass::runOnModule() {
|
||||
if (result.wasInterrupted()) return signalPassFailure();
|
||||
|
||||
// Eliminate TPUCompilationResultOp now that the rewrite is complete.
|
||||
getModule().walk([&](TF::TPUCompilationResultOp op) { op.erase(); });
|
||||
getOperation().walk([&](TF::TPUCompilationResultOp op) { op.erase(); });
|
||||
|
||||
// TODO(b/139377366): Remove functions that are no longer needed.
|
||||
}
|
||||
|
@ -40,8 +40,8 @@ namespace {
|
||||
constexpr char kShardingAttr[] = "xla_hlo.sharding";
|
||||
|
||||
struct TPUShardingIdentificationPass
|
||||
: public ModulePass<TPUShardingIdentificationPass> {
|
||||
void runOnModule() override;
|
||||
: public OperationPass<TPUShardingIdentificationPass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// XlaSharding op may be direct user of inputs but it may also be followed by
|
||||
@ -176,9 +176,9 @@ void IdentifyXlaShardingForTPUComputation(Builder* builder,
|
||||
builder->getStrArrayAttr(sharding_for_rets));
|
||||
}
|
||||
|
||||
void TPUShardingIdentificationPass::runOnModule() {
|
||||
Builder builder(getModule().getContext());
|
||||
getModule().walk([&](tf_device::LaunchFuncOp launch_func) {
|
||||
void TPUShardingIdentificationPass::runOnOperation() {
|
||||
Builder builder(getOperation().getContext());
|
||||
getOperation().walk([&](tf_device::LaunchFuncOp launch_func) {
|
||||
IdentifyXlaShardingForTPUComputation(&builder, launch_func);
|
||||
});
|
||||
}
|
||||
|
@ -116,8 +116,8 @@ std::string GetRandomStateVariableName() {
|
||||
// tf.TPUReshardVariablesOp(%rvar, %default_format, %rstate)
|
||||
// }
|
||||
struct TPUVariableRuntimeReformattingPass
|
||||
: public ModulePass<TPUVariableRuntimeReformattingPass> {
|
||||
void runOnModule() override;
|
||||
: public OperationPass<TPUVariableRuntimeReformattingPass, ModuleOp> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Returns the earlier value of which `v` is an identity. If `skipped` is
|
||||
@ -318,7 +318,7 @@ TF::WhileOp AddStateVarsToWhileOp(TF::WhileOp while_op, FuncOp body,
|
||||
new_body_return_vals.push_back(inner_arg);
|
||||
new_while_operands.push_back(state_var.resource());
|
||||
}
|
||||
OpBuilder builder(&body.front());
|
||||
OpBuilder builder = OpBuilder::atBlockEnd(&body.front());
|
||||
// Update return values.
|
||||
builder.create<ReturnOp>(body_return.getLoc(), new_body_return_vals);
|
||||
body_return.erase();
|
||||
@ -555,8 +555,8 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
|
||||
builder.create<tf_device::ReturnOp>(while_op.getLoc(), ArrayRef<Value>{});
|
||||
}
|
||||
|
||||
void TPUVariableRuntimeReformattingPass::runOnModule() {
|
||||
auto module = getModule();
|
||||
void TPUVariableRuntimeReformattingPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
module.walk([&](TF::WhileOp while_op) {
|
||||
auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
|
||||
tf_device::ReplicateOp replicate;
|
||||
|
@ -218,7 +218,7 @@ void ControlToExecutorDialectConversion::runOnFunction() {
|
||||
}
|
||||
|
||||
// Create the operation inside the island
|
||||
OpBuilder island_builder(&island.GetBody());
|
||||
OpBuilder island_builder = OpBuilder::atBlockEnd(&island.GetBody());
|
||||
Operation *inner_op = island_builder.createOperation(result);
|
||||
inner_op->setAttrs(op.getAttrList());
|
||||
|
||||
|
@ -68,7 +68,7 @@ void ExecutorToControlDialectConversion::runOnFunction() {
|
||||
|
||||
Block &body = getFunction().front();
|
||||
auto graph = cast<tf_executor::GraphOp>(body.front());
|
||||
OpBuilder builder(&body);
|
||||
OpBuilder builder = OpBuilder::atBlockEnd(&body);
|
||||
SmallString<64> new_op_name;
|
||||
for (auto &op : llvm::make_early_inc_range(llvm::reverse(graph.GetBody()))) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Process: " << op.getName() << "\n");
|
||||
|
@ -1452,7 +1452,8 @@ mlir::Operation* ImporterBase::createOperation(
|
||||
result.location, types, control_operands,
|
||||
mlir::ArrayRef<mlir::NamedAttribute>{});
|
||||
island.body().push_back(new mlir::Block);
|
||||
mlir::OpBuilder island_builder(&island.GetBody());
|
||||
mlir::OpBuilder island_builder =
|
||||
mlir::OpBuilder::atBlockEnd(&island.GetBody());
|
||||
|
||||
// Create the operation inside the island now.
|
||||
mlir::Operation* inner_op;
|
||||
@ -2928,12 +2929,11 @@ class SavedModelSignatureDefImporter {
|
||||
// Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
|
||||
// for each signature.
|
||||
StatusOr<mlir::OwningModuleRef> ConvertSignatures();
|
||||
Status ConvertSignature(
|
||||
const GraphDef& graphdef, const std::string& sig_def_key,
|
||||
const std::map<std::string, TensorInfo>& inputs_sorted,
|
||||
const std::map<std::string, TensorInfo>& outputs_sorted,
|
||||
const GraphDebugInfo& debug_info,
|
||||
const FunctionLibraryDefinition& flib_def);
|
||||
Status ConvertSignature(const GraphDef& graphdef,
|
||||
const std::string& sig_def_key,
|
||||
const SignatureDef& signature_def,
|
||||
const GraphDebugInfo& debug_info,
|
||||
const FunctionLibraryDefinition& flib_def);
|
||||
|
||||
// Creates GlobalTensorOp for each variable and moves each VarHandle op to
|
||||
// the enclosing function's arguments.
|
||||
@ -2948,10 +2948,7 @@ class SavedModelSignatureDefImporter {
|
||||
const llvm::SmallVectorImpl<mlir::TF::VarHandleOp>& ops);
|
||||
|
||||
GraphImportConfig::InputArrays ParseInputArrays(
|
||||
const std::map<std::string, TensorInfo>& inputs);
|
||||
|
||||
std::vector<std::string> ParseOutputArrays(
|
||||
const std::map<std::string, TensorInfo>& outputs);
|
||||
const std::vector<std::pair<std::string, TensorInfo>>& inputs);
|
||||
|
||||
const SavedModelBundle& bundle_;
|
||||
mlir::OwningModuleRef module_;
|
||||
@ -2979,14 +2976,8 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// protobuf::Map doesn't provide stable iteration order so use std::map
|
||||
std::map<std::string, TensorInfo> inputs_sorted(
|
||||
signature_def.inputs().begin(), signature_def.inputs().end());
|
||||
std::map<std::string, TensorInfo> outputs_sorted(
|
||||
signature_def.outputs().begin(), signature_def.outputs().end());
|
||||
|
||||
TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, inputs_sorted,
|
||||
outputs_sorted, debug_info, flib_def));
|
||||
TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def,
|
||||
debug_info, flib_def));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(LiftVariables());
|
||||
|
||||
@@ -2999,13 +2990,26 @@ SavedModelSignatureDefImporter::ConvertSignatures() {

Status SavedModelSignatureDefImporter::ConvertSignature(
    const GraphDef& graphdef, const std::string& sig_def_key,
    const std::map<std::string, TensorInfo>& inputs_sorted,
    const std::map<std::string, TensorInfo>& outputs_sorted,
    const GraphDebugInfo& debug_info,
    const SignatureDef& signature_def, const GraphDebugInfo& debug_info,
    const FunctionLibraryDefinition& flib_def) {
  // Create local vectors for the input and output and sort them to be
  // deterministic. We don't want anyone to really depend on the order, client
  // should lookup argument/result mapping by attribute name.
  // To avoid accidentally depending on the order we use an unintuitive sorting.
  std::vector<std::pair<std::string, TensorInfo>> inputs(
      signature_def.inputs().begin(), signature_def.inputs().end());
  llvm::sort(inputs, [](const auto& lhs, const auto& rhs) {
    return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first;
  });
  std::vector<std::pair<std::string, TensorInfo>> outputs(
      signature_def.outputs().begin(), signature_def.outputs().end());
  llvm::sort(outputs, [](const auto& lhs, const auto& rhs) {
    return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first;
  });

  GraphImportConfig specs;
  specs.inputs = ParseInputArrays(inputs_sorted);
  specs.outputs = ParseOutputArrays(outputs_sorted);
  specs.inputs = ParseInputArrays(inputs);
  for (auto& output : outputs) specs.outputs.push_back(output.second.name());

  // Remove unused nodes and create sub-graphdef.
  GraphDef sub_graph_def;
|
||||
@ -3041,11 +3045,11 @@ Status SavedModelSignatureDefImporter::ConvertSignature(
|
||||
builder.getStrArrayAttr({sig_def_key}));
|
||||
|
||||
// Transfer input and output parameter names to index_path attributes.
|
||||
for (auto input_and_idx : llvm::enumerate(inputs_sorted)) {
|
||||
for (auto input_and_idx : llvm::enumerate(inputs)) {
|
||||
func_op.setArgAttr(input_and_idx.index(), "tf_saved_model.index_path",
|
||||
builder.getStrArrayAttr({input_and_idx.value().first}));
|
||||
}
|
||||
for (auto output_and_idx : llvm::enumerate(outputs_sorted)) {
|
||||
for (auto output_and_idx : llvm::enumerate(outputs)) {
|
||||
func_op.setResultAttr(
|
||||
output_and_idx.index(), "tf_saved_model.index_path",
|
||||
builder.getStrArrayAttr({output_and_idx.value().first}));
|
||||
@ -3180,7 +3184,7 @@ Status SavedModelSignatureDefImporter::ReadVariablesFromSession(
|
||||
}
|
||||
|
||||
GraphImportConfig::InputArrays SavedModelSignatureDefImporter::ParseInputArrays(
|
||||
const std::map<std::string, TensorInfo>& inputs) {
|
||||
const std::vector<std::pair<std::string, TensorInfo>>& inputs) {
|
||||
GraphImportConfig::InputArrays results;
|
||||
for (const auto& iter : inputs) {
|
||||
const auto& tensor_info = iter.second;
|
||||
@ -3192,28 +3196,12 @@ GraphImportConfig::InputArrays SavedModelSignatureDefImporter::ParseInputArrays(
|
||||
array_info.imported_dtype = tensor_info.dtype();
|
||||
array_info.shape = tensor_info.tensor_shape();
|
||||
|
||||
std::vector<std::string> node_names =
|
||||
absl::StrSplit(tensor_info.name(), ':');
|
||||
|
||||
results.insert(std::pair<std::string, ArrayInfo>(node_names.at(0),
|
||||
results.insert(std::pair<std::string, ArrayInfo>(tensor_info.name(),
|
||||
std::move(array_info)));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
std::vector<std::string> SavedModelSignatureDefImporter::ParseOutputArrays(
|
||||
const std::map<std::string, TensorInfo>& outputs) {
|
||||
std::vector<std::string> results;
|
||||
for (const auto& iter : outputs) {
|
||||
const auto& tensor_info = iter.second;
|
||||
|
||||
std::vector<std::string> node_names =
|
||||
absl::StrSplit(tensor_info.name(), ':');
|
||||
results.push_back(node_names.at(0));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def) {
|
||||
|
@ -328,6 +328,24 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "buffer_assignment",
|
||||
srcs = ["transforms/buffer_assignment.cc"],
|
||||
hdrs = ["transforms/buffer_assignment.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":lhlo",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "xla_legalize_to_standard_inc_gen",
|
||||
tbl_outs = [
|
||||
|
@ -140,7 +140,7 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
|
||||
instruction_value_map_[hlo_parameter] = block->getArgument(i);
|
||||
}
|
||||
|
||||
mlir::OpBuilder builder(block);
|
||||
mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
|
||||
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(auto new_operation,
|
||||
ImportInstruction(instruction, &builder));
|
||||
@ -523,6 +523,32 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a));
|
||||
MakeAndReturn(TriangularSolveOp);
|
||||
}
|
||||
case HloOpcode::kReduceWindow: {
|
||||
llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
|
||||
llvm::SmallVector<int64_t, 8> padding;
|
||||
for (const auto& dim : instruction->window().dimensions()) {
|
||||
sizes.push_back(dim.size());
|
||||
strides.push_back(dim.stride());
|
||||
base_dilations.push_back(dim.base_dilation());
|
||||
win_dilations.push_back(dim.window_dilation());
|
||||
padding.push_back(dim.padding_low());
|
||||
padding.push_back(dim.padding_high());
|
||||
}
|
||||
attributes.push_back(builder_->getNamedAttr("window_dimensions",
|
||||
ConvertDimensions(sizes)));
|
||||
attributes.push_back(
|
||||
builder_->getNamedAttr("window_strides", ConvertDimensions(strides)));
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"base_dilations", ConvertDimensions(base_dilations)));
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"window_dilations", ConvertDimensions(win_dilations)));
|
||||
attributes.push_back(ConvertPadding(padding));
|
||||
auto reduce = func_builder->create<mlir::xla_hlo::ReduceWindowOp>(
|
||||
loc, result_type, operands, attributes);
|
||||
TF_RETURN_IF_ERROR(
|
||||
ImportComputation(instruction->to_apply(), &reduce.body()));
|
||||
return reduce.getOperation();
|
||||
}
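The loop above flattens per-dimension window padding into an interleaved [low0, high0, low1, high1, ...] vector before ConvertPadding turns it into an attribute. A hypothetical standalone helper, independent of the importer, that shows the same flattening:

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Per-dimension (low, high) padding becomes [low0, high0, low1, high1, ...],
    // matching the order produced by the loop above.
    std::vector<int64_t> InterleavePadding(
        const std::vector<std::pair<int64_t, int64_t>>& per_dim_padding) {
      std::vector<int64_t> flat;
      flat.reserve(per_dim_padding.size() * 2);
      for (const auto& p : per_dim_padding) {
        flat.push_back(p.first);   // padding_low
        flat.push_back(p.second);  // padding_high
      }
      return flat;
    }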
|
||||
case HloOpcode::kMap: {
|
||||
auto op = func_builder->create<mlir::xla_hlo::MapOp>(
|
||||
loc, result_type, operands,
|
||||
|
tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir (new file, 131 lines)
@@ -0,0 +1,131 @@
// RUN: tf-opt -test-buffer-assignment -split-input-file %s | FileCheck %s -dump-input-on-failure
|
||||
|
||||
// CHECK-LABEL: Testing : condBranch
|
||||
func @condBranch(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
|
||||
// CHECK: Alloc: cond_br
|
||||
cond_br %cond, ^bb1, ^bb2
|
||||
^bb1:
|
||||
br ^exit(%arg0 : tensor<2xf32>)
|
||||
^bb2:
|
||||
%1 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
br ^exit(%1 : tensor<2xf32>)
|
||||
^exit(%arg1: tensor<2xf32>):
|
||||
return %arg1 : tensor<2xf32>
|
||||
// CHECK-NEXT: Dealloc: return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: Testing : criticalEdge
|
||||
func @criticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
|
||||
// CHECK: Alloc: cond_br
|
||||
cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>)
|
||||
^bb1:
|
||||
%0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
br ^exit(%0 : tensor<2xf32>)
|
||||
^exit(%arg1: tensor<2xf32>):
|
||||
return %arg1 : tensor<2xf32>
|
||||
// CHECK-NEXT: Dealloc: return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: Testing : invCriticalEdge
|
||||
func @invCriticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
|
||||
// CHECK: Alloc: %0 = "xla_hlo.exp"
|
||||
%0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>)
|
||||
^bb1:
|
||||
br ^exit(%0 : tensor<2xf32>)
|
||||
^exit(%arg1: tensor<2xf32>):
|
||||
return %arg1 : tensor<2xf32>
|
||||
// CHECK-NEXT: Dealloc: return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: Testing : ifElse
|
||||
func @ifElse(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
|
||||
// CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1)
|
||||
%0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>)
|
||||
^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>):
|
||||
br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>)
|
||||
^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>):
|
||||
br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>)
|
||||
^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>):
|
||||
// CHECK-NEXT: Dealloc: %7 = "xla_hlo.exp"(%5)
|
||||
// CHECK: Alloc: %7 = "xla_hlo.exp"(%5)
|
||||
// CHECK-NEXT: Dealloc: return
|
||||
%1 = "xla_hlo.exp"(%arg5) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
return %1 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: Testing : ifElseNoUsers
|
||||
func @ifElseNoUsers(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
|
||||
// CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1)
|
||||
%0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>)
|
||||
^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>):
|
||||
br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>)
|
||||
^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>):
|
||||
br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>)
|
||||
^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>):
|
||||
// CHECK-NEXT: return
|
||||
return %arg0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: Testing : ifElseNested
|
||||
func @ifElseNested(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
|
||||
// CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1)
|
||||
%0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>)
|
||||
^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>):
|
||||
br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>)
|
||||
^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>):
|
||||
cond_br %cond, ^bb3(%arg3 : tensor<2xf32>), ^bb4(%arg4 : tensor<2xf32>)
|
||||
^bb3(%arg7 : tensor<2xf32>):
|
||||
br ^exit(%arg7, %arg3 : tensor<2xf32>, tensor<2xf32>)
|
||||
^bb4(%arg8 : tensor<2xf32>):
|
||||
br ^exit(%arg3, %arg8 : tensor<2xf32>, tensor<2xf32>)
|
||||
^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>):
|
||||
// CHECK-NEXT: Dealloc: %9 = "xla_hlo.exp"(%7)
|
||||
// CHECK: Alloc: %9 = "xla_hlo.exp"(%7)
|
||||
// CHECK-NEXT: Dealloc: return
|
||||
%1 = "xla_hlo.exp"(%arg5) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
return %1 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: Testing : redundantOperations
|
||||
func @redundantOperations(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) {
|
||||
// CHECK: Alloc: %0 = xla_hlo.maximum
|
||||
// CHECK-NEXT: Dealloc: %1 = xla_hlo.add
|
||||
%1 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK: Alloc: %1 = xla_hlo.add
|
||||
// CHECK-NEXT: Dealloc: %1 = xla_hlo.add
|
||||
%2 = "xla_hlo.add"(%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: Testing : reduce
|
||||
func @reduce(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
|
||||
// CHECK: Alloc: %0 = xla_hlo.constant
|
||||
// CHECK-NEXT: Dealloc: %1 = "xla_hlo.reduce"(%arg0, %0)
|
||||
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: Alloc: %1 = "xla_hlo.reduce"(%arg0, %0)
|
||||
// CHECK: Dealloc: return
|
||||
%2 = "xla_hlo.reduce"(%arg0, %0) ( {
|
||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||
%4 = xla_hlo.add %arg1, %arg2 : tensor<f32>
|
||||
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
|
||||
return %2 : tensor<4x8xf32>
|
||||
}
|
@ -31,11 +31,12 @@ func @reduce(%arg: memref<100x10x5xf32>,
|
||||
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: loop.reduce.return [[ACC_RESULT]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: loop.yield
|
||||
@ -71,11 +72,12 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
|
||||
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: loop.reduce.return [[ACC_RESULT]]
|
||||
// CHECK: }
|
||||
// CHECK: loop.yield
|
||||
@ -114,11 +116,12 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
|
||||
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: loop.reduce.return [[ACC_RESULT]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: loop.yield
|
||||
@ -185,11 +188,12 @@ func @reduce_window(%arg: memref<112x112xf32>,
|
||||
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: loop.reduce.return [[ACC_RESULT]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: loop.yield
|
||||
|
@ -698,6 +698,24 @@ add {
|
||||
ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.1, %sub.5)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_reduce_window
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tensor<2x17x31x7xf32>, [[ARG1:%.*]]: tensor<f32>)
|
||||
%test_reduce_window (Arg_0.1: f32[2,17,31,7], Arg_1.2: f32[]) -> f32[2,5,8,7] {
|
||||
%Arg_0.1 = f32[2,17,31,7] parameter(0)
|
||||
%Arg_1.2 = f32[] parameter(1)
|
||||
|
||||
// CHECK: "xla_hlo.reduce_window"([[ARG0]], [[ARG1]]) ( {
|
||||
// CHECK: xla_hlo.add {{.*}} : tensor<f32>
|
||||
// CHECK: }) {
|
||||
// CHECK-SAME: base_dilations = dense<1> : tensor<4xi64>
|
||||
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>
|
||||
// CHECK-SAME: window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>
|
||||
// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>
|
||||
// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>
|
||||
// CHECK-SAME: }
|
||||
ROOT %reduce-window.1 = f32[2,5,8,7] reduce-window(f32[2,17,31,7] %Arg_0.1, f32[] %Arg_1.2), window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1}, to_apply=%reduce_helper.3
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_remainder
|
||||
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xf32>, [[VAL_1:%.*]]: tensor<4xf32>)
|
||||
%test_remainder (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
|
501
tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc
Normal file
@ -0,0 +1,501 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file implements logic for computing proper alloc and dealloc positions.
|
||||
// The main class is the BufferAssignmentAnalysis class that realizes this analysis.
|
||||
// In order to put allocations and deallocations at safe positions, it is
|
||||
// crucial to put them into the proper blocks. However, the
|
||||
// liveness analysis does not pay attention to aliases, which can occur due to
|
||||
// branches (and their associated block arguments) in general. For this purpose,
|
||||
// BufferAssignment firstly finds all possible aliases for a single value (using
|
||||
// the BufferAssignmentAliasAnalysis class). Consider the following example:
|
||||
//
|
||||
// ^bb0(%arg0):
|
||||
// cond_br %cond, ^bb1, ^bb2
|
||||
// ^bb1:
|
||||
// br ^exit(%arg0)
|
||||
// ^bb2:
|
||||
// %new_value = ...
|
||||
// br ^exit(%new_value)
|
||||
// ^exit(%arg1):
|
||||
// return %arg1;
|
||||
//
|
||||
// Using liveness information on its own would cause us to place the allocs and
|
||||
// deallocs in the wrong block. This is due to the fact that %new_value will not
|
||||
// be liveOut of its block. Instead, we have to place the alloc for %new_value
|
||||
// in bb0 and its associated dealloc in exit. Using the class
|
||||
// BufferAssignmentAliasAnalysis, we will find out that %new_value has a
|
||||
// potential alias %arg1. In order to find the dealloc position we have to find
|
||||
// all potential aliases, iterate over their uses and find the common
|
||||
// post-dominator block. In this block we can safely be sure that %new_value
|
||||
// will die and can use liveness information to determine the exact operation
|
||||
// after which we have to insert the dealloc. Finding the alloc position is
|
||||
// highly similar and non-obvious. Again, we have to consider all potential
|
||||
// aliases and find the common dominator block to place the alloc.
|
||||
//
|
||||
// TODO(dfki):
|
||||
// The current implementation does not support loops. The only thing that
|
||||
// is currently missing is a high-level loop analysis that allows us to move
|
||||
// allocs and deallocs outside of the loop blocks.
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "absl/memory/memory.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferAssignmentAliasAnalysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// A straightforward alias analysis which ensures that all aliases of all
|
||||
/// values will be determined. This is a requirement for the BufferAssignment
|
||||
/// class since you need to determine safe positions to place allocs and
|
||||
/// deallocs.
|
||||
class BufferAssignmentAliasAnalysis {
|
||||
public:
|
||||
using ValueSetT = SmallPtrSet<Value, 16>;
|
||||
|
||||
public:
|
||||
/// Constructs a new alias analysis using the op provided.
|
||||
BufferAssignmentAliasAnalysis(Operation* op) { build(op->getRegions()); }
|
||||
|
||||
/// Finds all immediate and indirect aliases this value could potentially
|
||||
/// have. Note that the resulting set will also contain the value provided as
|
||||
/// it is an alias of itself.
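  /// For the introductory example at the top of this file, resolve(%new_value)
  /// would therefore contain both %new_value itself and the ^exit block
  /// argument %arg1 it is forwarded to.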
|
||||
ValueSetT resolve(Value value) const {
|
||||
ValueSetT result;
|
||||
resolveRecursive(value, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Recursively determines alias information for the given value. It stores
|
||||
/// all newly found potential aliases in the given result set.
|
||||
void resolveRecursive(Value value, ValueSetT& result) const {
|
||||
if (!result.insert(value).second) {
|
||||
return;
|
||||
}
|
||||
auto it = aliases.find(value);
|
||||
if (it == aliases.end()) return;
|
||||
for (auto alias : it->second) {
|
||||
resolveRecursive(alias, result);
|
||||
}
|
||||
}
|
||||
|
||||
  /// This function constructs a mapping from values to their immediate aliases.
|
||||
/// It iterates over all blocks, gets their predecessors, determines the
|
||||
/// values that will be passed to the corresponding block arguments and
|
||||
  /// inserts them into the map.
|
||||
void build(MutableArrayRef<Region> regions) {
|
||||
for (Region& region : regions) {
|
||||
for (Block& block : region) {
|
||||
        // Iterate over all predecessors and get the values that they pass to
        // the corresponding block arguments.
|
||||
for (auto pred : block.getPredecessors()) {
|
||||
// Determine the current successor index of the current predecessor.
|
||||
unsigned successorIndex = std::distance(
|
||||
pred->getSuccessors().begin(),
|
||||
llvm::find_if(pred->getSuccessors(), [&](Block* successor) {
|
||||
return successor == █
|
||||
}));
|
||||
// Get the terminator and the values that will be passed to our block.
|
||||
if (auto branchInterface =
|
||||
dyn_cast<BranchOpInterface>(pred->getTerminator())) {
|
||||
          // Query the branch op interface to get the successor operands.
|
||||
auto successorOps =
|
||||
branchInterface.getSuccessorOperands(successorIndex);
|
||||
if (successorOps.hasValue()) {
|
||||
// Build the actual mapping of values to their immediate aliases.
|
||||
for (auto arg : block.getArguments()) {
|
||||
Value predecessorArgValue =
|
||||
successorOps.getValue()[arg.getArgNumber()];
|
||||
aliases[predecessorArgValue].insert(arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
  /// Maps each value to all immediate aliases it can have.
|
||||
llvm::DenseMap<Value, ValueSetT> aliases;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferAssignmentPositions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Stores proper alloc and dealloc positions to place dialect-specific alloc
|
||||
/// and dealloc operations.
|
||||
struct BufferAssignmentPositions {
|
||||
public:
|
||||
BufferAssignmentPositions()
|
||||
: allocPosition(nullptr), deallocPosition(nullptr) {}
|
||||
|
||||
/// Creates a new positions tuple including alloc and dealloc positions.
|
||||
BufferAssignmentPositions(Operation* allocPosition,
|
||||
Operation* deallocPosition)
|
||||
: allocPosition(allocPosition), deallocPosition(deallocPosition) {}
|
||||
|
||||
/// Returns the alloc position before which the alloc operation has to be
|
||||
/// inserted.
|
||||
Operation* getAllocPosition() const { return allocPosition; }
|
||||
|
||||
/// Returns the dealloc position after which the dealloc operation has to be
|
||||
/// inserted.
|
||||
Operation* getDeallocPosition() const { return deallocPosition; }
|
||||
|
||||
private:
|
||||
Operation* allocPosition;
|
||||
Operation* deallocPosition;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferAssignmentAnalysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// The main buffer assignment analysis used to place allocs and deallocs.
|
||||
class BufferAssignmentAnalysis {
|
||||
public:
|
||||
using DeallocSetT = SmallPtrSet<Operation*, 2>;
|
||||
|
||||
public:
|
||||
BufferAssignmentAnalysis(Operation* op)
|
||||
: operation(op),
|
||||
liveness(op),
|
||||
dominators(op),
|
||||
postDominators(op),
|
||||
aliases(op) {}
|
||||
|
||||
/// Computes the actual positions to place allocs and deallocs for the given
|
||||
/// value.
|
||||
BufferAssignmentPositions computeAllocAndDeallocPositions(Value value) const {
|
||||
if (value.use_empty()) {
|
||||
return BufferAssignmentPositions(value.getDefiningOp(),
|
||||
value.getDefiningOp());
|
||||
}
|
||||
// Get all possible aliases
|
||||
auto possibleValues = aliases.resolve(value);
|
||||
return BufferAssignmentPositions(getAllocPosition(value, possibleValues),
|
||||
getDeallocPosition(value, possibleValues));
|
||||
}
|
||||
|
||||
/// Finds all associated dealloc nodes for the alloc nodes using alias
|
||||
/// information.
|
||||
DeallocSetT findAssociatedDeallocs(AllocOp alloc) const {
|
||||
DeallocSetT result;
|
||||
auto possibleValues = aliases.resolve(alloc);
|
||||
for (auto alias : possibleValues) {
|
||||
for (auto user : alias.getUsers()) {
|
||||
if (isa<DeallocOp>(user)) result.insert(user);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Dumps the buffer assignment information to the given stream.
|
||||
void print(raw_ostream& os) const {
|
||||
os << "// ---- Buffer Assignment -----\n";
|
||||
|
||||
for (Region& region : operation->getRegions())
|
||||
for (Block& block : region)
|
||||
for (Operation& operation : block)
|
||||
for (Value result : operation.getResults()) {
|
||||
BufferAssignmentPositions positions =
|
||||
computeAllocAndDeallocPositions(result);
|
||||
os << "Positions for ";
|
||||
result.print(os);
|
||||
os << "\n Alloc: ";
|
||||
positions.getAllocPosition()->print(os);
|
||||
os << "\n Dealloc: ";
|
||||
positions.getDeallocPosition()->print(os);
|
||||
os << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/// Finds a proper placement block to store alloc/dealloc node according to
|
||||
/// the algorithm described at the top of the file. It supports dominator and
|
||||
/// post-dominator analyses via template arguments.
|
||||
template <typename AliasesT, typename DominatorT>
|
||||
Block* findPlacementBlock(Value value, const AliasesT& aliases,
|
||||
const DominatorT& doms) const {
|
||||
assert(!value.isa<BlockArgument>() && "Cannot place a block argument");
|
||||
// Start with the current block the value is defined in.
|
||||
Block* dom = value.getDefiningOp()->getBlock();
|
||||
// Iterate over all aliases and their uses to find a safe placement block
|
||||
// according to the given dominator information.
|
||||
for (auto alias : aliases) {
|
||||
for (auto user : alias.getUsers()) {
|
||||
// Move upwards in the dominator tree to find an appropriate
|
||||
// dominator block that takes the current use into account.
|
||||
dom = doms.findNearestCommonDominator(dom, user->getBlock());
|
||||
}
|
||||
}
|
||||
return dom;
|
||||
}
|
||||
|
||||
  /// Finds a proper alloc position according to the algorithm described at the
|
||||
/// top of the file.
|
||||
template <typename AliasesT>
|
||||
Operation* getAllocPosition(Value value, const AliasesT& aliases) const {
|
||||
// Determine the actual block to place the alloc and get liveness
|
||||
// information.
|
||||
auto placementBlock = findPlacementBlock(value, aliases, dominators);
|
||||
auto livenessInfo = liveness.getLiveness(placementBlock);
|
||||
|
||||
// We have to ensure that the alloc will be before the first use of all
|
||||
// aliases of the given value. We first assume that there are no uses in the
|
||||
// placementBlock and that we can safely place the alloc before the
|
||||
// terminator at the end of the block.
|
||||
Operation* startOperation = placementBlock->getTerminator();
|
||||
// Iterate over all aliases and ensure that the startOperation will point to
|
||||
// the first operation of all potential aliases in the placementBlock.
|
||||
for (auto alias : aliases) {
|
||||
auto aliasStartOperation = livenessInfo->getStartOperation(alias);
|
||||
// Check whether the aliasStartOperation lies in the desired block and
|
||||
// whether it is before the current startOperation. If yes, this will be
|
||||
// the new startOperation.
|
||||
if (aliasStartOperation->getBlock() == placementBlock &&
|
||||
aliasStartOperation->isBeforeInBlock(startOperation)) {
|
||||
startOperation = aliasStartOperation;
|
||||
}
|
||||
}
|
||||
    // startOperation is the first operation before which we can safely place
|
||||
// the alloc taking all potential aliases into account.
|
||||
return startOperation;
|
||||
}
|
||||
|
||||
  /// Finds a proper dealloc position according to the algorithm described at
|
||||
/// the top of the file.
|
||||
template <typename AliasesT>
|
||||
Operation* getDeallocPosition(Value value, const AliasesT& aliases) const {
|
||||
// Determine the actual block to place the dealloc and get liveness
|
||||
// information.
|
||||
auto placementBlock = findPlacementBlock(value, aliases, postDominators);
|
||||
auto livenessInfo = liveness.getLiveness(placementBlock);
|
||||
|
||||
// We have to ensure that the dealloc will be after the last use of all
|
||||
// aliases of the given value. We first assume that there are no uses in the
|
||||
// placementBlock and that we can safely place the dealloc at the beginning.
|
||||
Operation* endOperation = &placementBlock->front();
|
||||
// Iterate over all aliases and ensure that the endOperation will point to
|
||||
// the last operation of all potential aliases in the placementBlock.
|
||||
for (auto alias : aliases) {
|
||||
auto aliasEndOperation =
|
||||
livenessInfo->getEndOperation(alias, endOperation);
|
||||
// Check whether the aliasEndOperation lies in the desired block and
|
||||
// whether it is behind the current endOperation. If yes, this will be the
|
||||
// new endOperation.
|
||||
if (aliasEndOperation->getBlock() == placementBlock &&
|
||||
endOperation->isBeforeInBlock(aliasEndOperation)) {
|
||||
endOperation = aliasEndOperation;
|
||||
}
|
||||
}
|
||||
    // endOperation is the last operation after which we can safely place the
|
||||
// dealloc taking all potential aliases into account.
|
||||
return endOperation;
|
||||
}
|
||||
|
||||
/// The operation this transformation was constructed from.
|
||||
Operation* operation;
|
||||
|
||||
  /// The underlying liveness analysis to compute fine-grained information about
|
||||
/// alloc and dealloc positions.
|
||||
Liveness liveness;
|
||||
|
||||
/// The dominator analysis to place allocs in the appropriate blocks.
|
||||
DominanceInfo dominators;
|
||||
|
||||
/// The post dominator analysis to place deallocs in the appropriate blocks.
|
||||
PostDominanceInfo postDominators;
|
||||
|
||||
/// The internal alias analysis to ensure that allocs and deallocs take all
|
||||
/// their potential aliases into account.
|
||||
BufferAssignmentAliasAnalysis aliases;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferAssignmentPass
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// The actual buffer assignment pass that moves alloc and dealloc nodes into
|
||||
/// the right positions. It uses the algorithm described at the top of the file.
|
||||
// TODO(dfki): create a templated version that allows matching dialect-specific
// alloc/dealloc nodes and inserting dialect-specific dealloc nodes.
|
||||
struct BufferAssignmentPass : mlir::FunctionPass<BufferAssignmentPass> {
|
||||
void runOnFunction() override {
|
||||
// Get required analysis information first.
|
||||
auto& analysis = getAnalysis<BufferAssignmentAnalysis>();
|
||||
|
||||
// Compute an initial placement of all nodes.
|
||||
llvm::SmallDenseMap<Value, BufferAssignmentPositions, 16> placements;
|
||||
getFunction().walk([&](AllocOp alloc) {
|
||||
placements[alloc] = analysis.computeAllocAndDeallocPositions(alloc);
|
||||
});
|
||||
|
||||
// Move alloc (and dealloc - if any) nodes into the right places
|
||||
// and insert dealloc nodes if necessary.
|
||||
getFunction().walk([&](AllocOp alloc) {
|
||||
// Find already associated dealloc nodes.
|
||||
auto deallocs = analysis.findAssociatedDeallocs(alloc);
|
||||
assert(deallocs.size() < 2 &&
|
||||
"Not supported number of associated dealloc operations");
|
||||
|
||||
// Move alloc node to the right place.
|
||||
BufferAssignmentPositions& positions = placements[alloc];
|
||||
Operation* allocOperation = alloc.getOperation();
|
||||
allocOperation->moveBefore(positions.getAllocPosition());
|
||||
|
||||
// If there is an existing dealloc, move it to the right place.
|
||||
if (deallocs.size()) {
|
||||
Operation* nextOp = positions.getDeallocPosition()->getNextNode();
|
||||
if (!nextOp)
|
||||
nextOp = &positions.getDeallocPosition()->getBlock()->back();
|
||||
(*deallocs.begin())->moveBefore(nextOp);
|
||||
} else {
|
||||
// If there is no dealloc node, insert one in the right place.
|
||||
OpBuilder builder(alloc);
|
||||
builder.setInsertionPointAfter(positions.getDeallocPosition());
|
||||
builder.create<DeallocOp>(allocOperation->getLoc(), alloc);
|
||||
}
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferAssignmentPlacer
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Creates a new assignment placer.
|
||||
BufferAssignmentPlacer::BufferAssignmentPlacer(Operation* op)
|
||||
: operation(op), dominators(op) {}
|
||||
|
||||
/// Computes the actual position to place allocs for the given value.
|
||||
OpBuilder::InsertPoint BufferAssignmentPlacer::computeAllocPosition(
|
||||
Value value) {
|
||||
Operation* insertOp;
|
||||
if (auto arg = value.dyn_cast<BlockArgument>()) {
|
||||
// This is a block argument which has to be allocated in the scope
|
||||
// of its associated terminator.
|
||||
auto domNode = dominators.getNode(arg.getOwner());
|
||||
assert(domNode != nullptr && "Cannot find dominator info");
|
||||
auto idomNode = domNode->getIDom();
|
||||
assert(idomNode != nullptr && "There is no parent dominator");
|
||||
insertOp = idomNode->getBlock()->getTerminator();
|
||||
} else {
|
||||
insertOp = value.getDefiningOp();
|
||||
}
|
||||
OpBuilder opBuilder(insertOp);
|
||||
return opBuilder.saveInsertionPoint();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FunctionAndBlockSignatureConverter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Performs the actual signature rewriting step.
|
||||
LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite(
|
||||
FuncOp funcOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
auto toMemrefConverter = [&](Type t) -> Type {
|
||||
if (auto tensorType = t.dyn_cast<RankedTensorType>()) {
|
||||
return MemRefType::get(tensorType.getShape(),
|
||||
tensorType.getElementType());
|
||||
}
|
||||
return t;
|
||||
};
|
||||
// Converting tensor-type function arguments to memref-type.
|
||||
auto funcType = funcOp.getType();
|
||||
TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
|
||||
for (auto argType : llvm::enumerate(funcType.getInputs())) {
|
||||
conversion.addInputs(argType.index(), toMemrefConverter(argType.value()));
|
||||
}
|
||||
for (auto resType : funcType.getResults()) {
|
||||
conversion.addInputs(toMemrefConverter(resType));
|
||||
}
|
||||
rewriter.updateRootInPlace(funcOp, [&] {
|
||||
funcOp.setType(
|
||||
rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None));
|
||||
rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
|
||||
});
|
||||
  // Converting tensor-type block arguments of all blocks inside the
|
||||
// function region to memref-type except for the entry block.
|
||||
for (auto& block : funcOp.getBlocks()) {
|
||||
if (block.isEntryBlock()) continue;
|
||||
for (int i = 0, e = block.getNumArguments(); i < e; ++i) {
|
||||
auto oldArg = block.getArgument(i);
|
||||
auto newArg =
|
||||
block.insertArgument(i, toMemrefConverter(oldArg.getType()));
|
||||
oldArg.replaceAllUsesWith(newArg);
|
||||
block.eraseArgument(i + 1);
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// Adding functions whose arguments are memref type to the set of legal
|
||||
// operations.
|
||||
void FunctionAndBlockSignatureConverter::addDynamicallyLegalFuncOp(
|
||||
ConversionTarget& target) {
|
||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||
auto inputs = op.getType().getInputs();
|
||||
return std::all_of(inputs.begin(), inputs.end(),
|
||||
[](Type input) { return input.isa<MemRefType>(); });
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Buffer assignment pass registrations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::unique_ptr<OpPassBase<FuncOp>> createBufferAssignmentPass() {
|
||||
return absl::make_unique<BufferAssignmentPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<BufferAssignmentPass> buffer_assignment_pass(
|
||||
"buffer-assignment",
|
||||
"Executes buffer assignment pass to automatically move alloc and dealloc "
|
||||
"operations into their proper positions");
|
||||
|
||||
/// A simple pass to print debug/test information for the buffer assignment
|
||||
/// analysis.
|
||||
struct BufferAssignmentTestPass : mlir::FunctionPass<BufferAssignmentTestPass> {
|
||||
void runOnFunction() override {
|
||||
llvm::outs() << "Testing : " << getFunction().getName() << "\n";
|
||||
getAnalysis<BufferAssignmentAnalysis>().print(llvm::outs());
|
||||
};
|
||||
};
|
||||
|
||||
std::unique_ptr<OpPassBase<FuncOp>> createBufferAssignmentTestPass() {
|
||||
return absl::make_unique<BufferAssignmentTestPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<BufferAssignmentTestPass> buffer_assignment_test_pass(
|
||||
"test-buffer-assignment",
|
||||
"Outputs debug test information for the buffer assignment analysis");
|
||||
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
140
tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h
Normal file
@ -0,0 +1,140 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_
|
||||
|
||||
#include "mlir/Analysis/Dominance.h"
|
||||
#include "mlir/Analysis/Liveness.h"
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace xla {
|
||||
|
||||
/// Prepares a buffer assignment phase. It can place (user-defined) alloc
|
||||
/// nodes. This simplifies the integration of the actual buffer-assignment
|
||||
/// pass. Sample usage:
|
||||
/// BufferAssignmentPlacer baHelper(regionOp);
|
||||
/// -> determine alloc positions
|
||||
/// auto allocPosition = baHelper.computeAllocPosition(value);
|
||||
/// -> place alloc
|
||||
///   allocBuilder.restoreInsertionPoint(allocPosition);
|
||||
/// <create alloc>
|
||||
/// alternatively:
|
||||
/// -> place alloc
|
||||
/// baHelper.insertAlloc<AllocOp>(...);
|
||||
/// Note: this class is intended to be used during legalization. In order
|
||||
/// to move alloc and dealloc nodes into the right places you can use the
|
||||
/// createBufferAssignmentPass() function.
|
||||
class BufferAssignmentPlacer {
|
||||
public:
|
||||
/// Creates a new assignment builder.
|
||||
explicit BufferAssignmentPlacer(Operation* op);
|
||||
|
||||
/// Returns the operation this analysis was constructed from.
|
||||
Operation* getOperation() const { return operation; }
|
||||
|
||||
/// Computes the actual position to place allocs for the given value.
|
||||
OpBuilder::InsertPoint computeAllocPosition(Value value);
|
||||
|
||||
private:
|
||||
/// The operation this analysis was constructed from.
|
||||
Operation* operation;
|
||||
|
||||
/// The dominator analysis to place allocs in the appropriate blocks.
|
||||
DominanceInfo dominators;
|
||||
};
|
||||
|
||||
/// Helper conversion pattern that encapsulates a BufferAssignmentPlacer
|
||||
/// instance.
|
||||
template <typename SourceOp>
|
||||
class BufferAssignmentOpConversionPattern
|
||||
: public OpConversionPattern<SourceOp> {
|
||||
public:
|
||||
explicit BufferAssignmentOpConversionPattern(
|
||||
MLIRContext* context_,
|
||||
xla::BufferAssignmentPlacer* bufferAssignment_ = nullptr,
|
||||
PatternBenefit benefit_ = 1)
|
||||
: OpConversionPattern<SourceOp>(context_, benefit_),
|
||||
bufferAssignment(bufferAssignment_) {}
|
||||
|
||||
protected:
|
||||
xla::BufferAssignmentPlacer* bufferAssignment;
|
||||
};
|
||||
|
||||
// Converts only the tensor-type function and block arguments to memref-type.
|
||||
class FunctionAndBlockSignatureConverter
|
||||
: public BufferAssignmentOpConversionPattern<FuncOp> {
|
||||
public:
|
||||
using BufferAssignmentOpConversionPattern<
|
||||
FuncOp>::BufferAssignmentOpConversionPattern;
|
||||
|
||||
// Adding functions whose arguments are memref type to the set of legal
|
||||
// operations.
|
||||
static void addDynamicallyLegalFuncOp(ConversionTarget& target);
|
||||
|
||||
// Performs the actual signature rewriting step.
|
||||
LogicalResult matchAndRewrite(
|
||||
FuncOp funcOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final;
|
||||
};
|
||||
|
||||
// This pattern converter transforms a non-void ReturnOpSourceTy into a void
|
||||
// return of type ReturnOpTargetTy. It uses a copy operation of type CopyOpTy to
|
||||
// copy the results to the output buffer.
|
||||
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
|
||||
typename CopyOpTy>
|
||||
class NonVoidToVoidReturnOpConverter
|
||||
: public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
|
||||
public:
|
||||
using BufferAssignmentOpConversionPattern<
|
||||
ReturnOpSourceTy>::BufferAssignmentOpConversionPattern;
|
||||
|
||||
// Performs the actual return-op conversion step.
|
||||
LogicalResult matchAndRewrite(
|
||||
ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto numReturnValues = returnOp.getNumOperands();
|
||||
auto funcOp = returnOp.template getParentOfType<FuncOp>();
|
||||
auto numFuncArgs = funcOp.getNumArguments();
|
||||
auto loc = returnOp.getLoc();
|
||||
|
||||
// Find the corresponding output buffer for each operand.
|
||||
for (auto operand : llvm::enumerate(operands)) {
|
||||
auto returnArgNumber = numFuncArgs - numReturnValues + operand.index();
|
||||
auto dstBuffer = funcOp.getArgument(returnArgNumber);
|
||||
if (dstBuffer == operand.value()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Insert the copy operation to copy before the return.
|
||||
rewriter.setInsertionPoint(
|
||||
returnOp.getOperation()->getBlock()->getTerminator());
|
||||
rewriter.create<CopyOpTy>(loc, operand.value(),
|
||||
funcOp.getArgument(returnArgNumber));
|
||||
}
|
||||
// Insert the new target return operation.
|
||||
rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_
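A hedged sketch of how the return-op converter above might be instantiated and registered for an HLO-to-LHLO lowering. The concrete op classes named here are assumptions chosen for illustration (a tensor-based source return, a void target terminator, and a buffer copy op); this change only defines the template itself.

// Hypothetical instantiation; xla_hlo::ReturnOp, xla_lhlo::TerminatorOp and
// xla_lhlo::CopyOp are assumed to be the generated op classes of the
// corresponding dialect ops seen elsewhere in this diff.
using HloToLhloReturnOpConverter = mlir::xla::NonVoidToVoidReturnOpConverter<
    mlir::xla_hlo::ReturnOp, mlir::xla_lhlo::TerminatorOp,
    mlir::xla_lhlo::CopyOp>;

// Registration then forwards the placer through the constructor declared above:
//   patterns.insert<HloToLhloReturnOpConverter>(context, &bufferAssignmentPlacer);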
|
@ -324,8 +324,8 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
|
||||
// "xla_lhlo.terminator"() : () -> ()
|
||||
// }
|
||||
|
||||
struct HloLegalizeToLhlo : public ModulePass<HloLegalizeToLhlo> {
|
||||
void runOnModule() override {
|
||||
struct HloLegalizeToLhlo : public OperationPass<HloLegalizeToLhlo, ModuleOp> {
|
||||
void runOnOperation() override {
|
||||
OwningRewritePatternList patterns;
|
||||
auto& context = getContext();
|
||||
ConversionTarget target(context);
|
||||
@ -344,7 +344,7 @@ struct HloLegalizeToLhlo : public ModulePass<HloLegalizeToLhlo> {
|
||||
[](Type input) { return input.isa<MemRefType>(); });
|
||||
});
|
||||
|
||||
auto module = getModule();
|
||||
auto module = getOperation();
|
||||
populateHLOToLHLOConversionPattern(module.getContext(), &patterns);
|
||||
|
||||
// Do partial conversion so we can have unknown ops in tests.
|
||||
|
@ -51,9 +51,10 @@ using mlir::PassRegistration;
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
namespace {
|
||||
class LegalizeTFControlFlow : public ModulePass<LegalizeTFControlFlow> {
|
||||
class LegalizeTFControlFlow
|
||||
: public OperationPass<LegalizeTFControlFlow, ModuleOp> {
|
||||
public:
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@ -164,8 +165,8 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void LegalizeTFControlFlow::runOnModule() {
|
||||
auto module = getModule();
|
||||
void LegalizeTFControlFlow::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
|
||||
module.walk([&](TF::WhileOp op) -> void { LowerWhile(op, module); });
|
||||
module.walk([&](TF::IfOp op) -> void { LowerIf(op, module); });
|
||||
|
@ -29,48 +29,128 @@ namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace {
|
||||
|
||||
// Clones and adapts the code in `lhlo_block` that works on buffers and has a
|
||||
// single output buffer to make it compatible with `operands` that have element
|
||||
// types of the respective buffers. Returns the computed value.
|
||||
//
|
||||
// Example. For `operands` with (f32, i32) types and a block with LHLO ops and
|
||||
// with signature:
|
||||
// ^bb(%lhs: memref<f32>, %rhs: memref<i32>, %res: memref<i1>):
|
||||
// <LHLO_ops>
|
||||
//
|
||||
// inserts necessary alloc and store ops to compute and return a result of
|
||||
// `i1` type.
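// The last block argument is treated as the output buffer, which is why the
// computed result is read back via a load from arg_bufs.back().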
|
||||
Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
|
||||
Block* lhlo_block, OpBuilder* b) {
|
||||
SmallVector<Value, 2> arg_bufs;
|
||||
for (auto arg_type : lhlo_block->getArgumentTypes()) {
|
||||
arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>()));
|
||||
}
|
||||
for (auto operand : llvm::enumerate(operands)) {
|
||||
b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
|
||||
}
|
||||
// Clone the ops from `lhlo_block`.
|
||||
BlockAndValueMapping mapping;
|
||||
mapping.map(lhlo_block->getArguments(), arg_bufs);
|
||||
for (auto& nested : lhlo_block->without_terminator()) {
|
||||
auto clone = b->clone(nested, mapping);
|
||||
mapping.map(nested.getResults(), clone->getResults());
|
||||
}
|
||||
return b->create<LoadOp>(loc, arg_bufs.back());
|
||||
}
|
||||
|
||||
// Converts a block with LHLO ops and with signature:
|
||||
// ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||
// into a reduction operator of loop.reduce by doing buffer allocation for
|
||||
// scalar arguments and the result of `loop.reduce` to make it compatible with
|
||||
// LHLO ops.
|
||||
void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op,
|
||||
Block* lhlo_block,
|
||||
ConversionPatternRewriter* rewriter) {
|
||||
Block* lhlo_block, OpBuilder* b) {
|
||||
Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
|
||||
rewriter->setInsertionPointToStart(&loop_reduce_op_body);
|
||||
|
||||
// Allocate buffers to hold arguments of reduction operator block to stay
|
||||
// compatible with the LHLO dialect ops in the reduction body.
|
||||
Value elem_arg = lhlo_block->getArgument(0);
|
||||
Value elem_buf =
|
||||
rewriter->create<AllocOp>(loc, elem_arg.getType().cast<MemRefType>());
|
||||
rewriter->create<StoreOp>(loc, loop_reduce_op_body.getArgument(0), elem_buf);
|
||||
Value acc_arg = lhlo_block->getArgument(1);
|
||||
Value acc_buf =
|
||||
rewriter->create<AllocOp>(loc, acc_arg.getType().cast<MemRefType>());
|
||||
rewriter->create<StoreOp>(loc, loop_reduce_op_body.getArgument(1), acc_buf);
|
||||
|
||||
// Clone the ops from `xla_lhlo.reduce` into reduction operator block.
|
||||
BlockAndValueMapping mapping;
|
||||
mapping.map(lhlo_block->getArguments(),
|
||||
ValueRange{elem_buf, acc_buf, acc_buf});
|
||||
for (auto& nested : lhlo_block->without_terminator()) {
|
||||
auto clone = rewriter->clone(nested, mapping);
|
||||
mapping.map(nested.getResults(), clone->getResults());
|
||||
}
|
||||
Value acc_result = rewriter->create<LoadOp>(loc, acc_buf);
|
||||
rewriter->create<loop::ReduceReturnOp>(loc, acc_result);
|
||||
OpBuilder::InsertionGuard guard(*b);
|
||||
b->setInsertionPointToStart(&loop_reduce_op_body);
|
||||
b->create<loop::ReduceReturnOp>(
|
||||
loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
|
||||
lhlo_block, b));
|
||||
}
|
||||
|
||||
// Returns the result of a ConstantOp if `dim` is static, otherwise uses DimOp to
|
||||
// extract dimension at runtime.
|
||||
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
|
||||
size_t dim_index, int64_t dim,
|
||||
ConversionPatternRewriter* rewriter) {
|
||||
size_t dim_index, int64_t dim, OpBuilder* b) {
|
||||
return dim == ShapedType::kDynamicSize
|
||||
? rewriter->create<DimOp>(loc, shaped_value, dim_index).getResult()
|
||||
: rewriter->create<ConstantIndexOp>(loc, dim);
|
||||
? b->create<DimOp>(loc, shaped_value, dim_index).getResult()
|
||||
: b->create<ConstantIndexOp>(loc, dim);
|
||||
}
|
||||
|
||||
struct MappedIvs {
|
||||
// False if the mapped indices are in the padding area, true otherwise.
|
||||
Value in_bounds;
|
||||
// Mapped indices.
|
||||
SmallVector<Value, 2> ivs;
|
||||
};
|
||||
|
||||
MappedIvs MapWindowIvsToInput(ReduceWindowOp op, ValueRange ivs,
|
||||
ValueRange window_ivs, OpBuilder* b) {
|
||||
MappedIvs mapped_ivs;
|
||||
|
||||
if (!op.window_strides().hasValue()) {
|
||||
op.emitOpError("No window strides specified.");
|
||||
}
|
||||
auto window_strides = op.window_strides().getValue();
|
||||
|
||||
if (!op.padding().hasValue()) {
|
||||
op.emitOpError("No padding specified.");
|
||||
}
|
||||
auto padding = op.padding().getValue();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto operand = op.operand();
|
||||
auto operand_shape = operand.getType().cast<MemRefType>().getShape();
|
||||
|
||||
// `in_bounds` is false when the mapped indices are in the padding area.
|
||||
mapped_ivs.in_bounds = b->create<mlir::ConstantOp>(
|
||||
loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
|
||||
for (unsigned i = 0, e = ivs.size(); i < e; ++i) {
|
||||
auto stride = window_strides.getValue<llvm::APInt>(i);
|
||||
auto pad_low = padding.getValue<llvm::APInt>({i, 0});
|
||||
|
||||
Value stride_val = b->create<ConstantIndexOp>(loc, stride.getSExtValue());
|
||||
Value pad_low_val = b->create<ConstantIndexOp>(loc, pad_low.getSExtValue());
|
||||
|
||||
Value center = b->create<MulIOp>(loc, ivs[i], stride_val);
|
||||
Value offset = b->create<SubIOp>(loc, window_ivs[i], pad_low_val);
|
||||
Value index = b->create<AddIOp>(loc, center, offset);
|
||||
Value upper_bound =
|
||||
GetStaticOrDynamicDim(loc, operand, i, operand_shape[i], b);
|
||||
// We must check whether 0 <= index_i < shape_i, as otherwise we are in
|
||||
// the pad and then we have to use the neutral element for reduction.
|
||||
// Equivalently, it can be computed as the unsigned comparison index_i <
|
||||
// shape_i, since a negative value wraps to a large positive value.
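    // For example, with shape_i == 8 an out-of-bounds index of -1 wraps to
    // 2^64 - 1 in the unsigned comparison, so `ult` rejects it just like an
    // index of 8 would be rejected.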
|
||||
mapped_ivs.in_bounds = b->create<mlir::AndOp>(
|
||||
loc, mapped_ivs.in_bounds,
|
||||
b->create<CmpIOp>(loc, CmpIPredicate::ult, index, upper_bound));
|
||||
mapped_ivs.ivs.push_back(index);
|
||||
}
|
||||
return mapped_ivs;
|
||||
}
|
||||
|
||||
// Returns loop::Parallel over a shaped value with static or dynamic shape.
|
||||
loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
|
||||
OpBuilder* b) {
|
||||
Value zero = b->create<ConstantIndexOp>(loc, 0);
|
||||
Value one = b->create<ConstantIndexOp>(loc, 1);
|
||||
|
||||
ArrayRef<int64_t> shape =
|
||||
shaped_value.getType().cast<ShapedType>().getShape();
|
||||
SmallVector<Value, 2> lower, upper, step;
|
||||
for (auto dim : llvm::enumerate(shape)) {
|
||||
upper.push_back(
|
||||
GetStaticOrDynamicDim(loc, shaped_value, dim.index(), dim.value(), b));
|
||||
lower.push_back(zero);
|
||||
step.push_back(one);
|
||||
}
|
||||
return b->create<loop::ParallelOp>(loc, lower, upper, step);
|
||||
}
|
||||
|
||||
// Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp.
|
||||
@ -186,7 +266,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||
SmallVector<Value, 1> out_indices;
|
||||
if (outer != nullptr) {
|
||||
out_indices.reserve(outer.getNumLoops());
|
||||
for (auto& iv : outer.getInductionVars()) {
|
||||
for (Value iv : outer.getInductionVars()) {
|
||||
out_indices.push_back(iv);
|
||||
}
|
||||
} else {
|
||||
@ -198,12 +278,16 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||
// Load the element to reduce.
|
||||
SmallVector<Value, 2> indices;
|
||||
indices.reserve(operand_shape.size());
|
||||
Block::args_iterator outer_ivs_it =
|
||||
outer ? outer.getInductionVars().begin() : nullptr;
|
||||
Block::args_iterator inner_ivs_it = inner.getInductionVars().begin();
|
||||
for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
|
||||
indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
|
||||
: *outer_ivs_it++);
|
||||
|
||||
if (outer) {
|
||||
auto inner_ivs_it = inner.getInductionVars().begin();
|
||||
auto outer_ivs_it = outer.getInductionVars().begin();
|
||||
for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
|
||||
indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
|
||||
: *outer_ivs_it++);
|
||||
}
|
||||
} else {
|
||||
indices = ValueRange(inner.getInductionVars());
|
||||
}
|
||||
|
||||
rewriter->setInsertionPointToStart(inner.getBody());
|
||||
@ -309,20 +393,11 @@ class ReduceWindowOpConverter
|
||||
|
||||
// Create an outer parallel loop that spans the output of ReduceWindowOp.
|
||||
Value xla_output = xla_reduce_window_op.out();
|
||||
auto output_shape = xla_output.getType().cast<MemRefType>().getShape();
|
||||
SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
|
||||
for (auto dim : llvm::enumerate(output_shape)) {
|
||||
parallel_upper.push_back(GetStaticOrDynamicDim(
|
||||
loc, xla_output, dim.index(), dim.value(), rewriter));
|
||||
parallel_lower.push_back(zero);
|
||||
parallel_step.push_back(one);
|
||||
}
|
||||
auto output_loop = rewriter->create<loop::ParallelOp>(
|
||||
loc, parallel_lower, parallel_upper, parallel_step);
|
||||
auto output_loop = MakeLoopOverShape(loc, xla_output, rewriter);
|
||||
|
||||
// Create a nested loop that traverses the window.
|
||||
rewriter->setInsertionPointToStart(output_loop.getBody());
|
||||
SmallVector<Value, 2> window_lower, window_upper, window_step;
|
||||
rewriter->setInsertionPointToStart(output_loop.getBody());
|
||||
for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) {
|
||||
window_step.push_back(one);
|
||||
window_lower.push_back(zero);
|
||||
@ -334,9 +409,8 @@ class ReduceWindowOpConverter
|
||||
|
||||
Value reduction_result = *window_loop.getResults().begin();
|
||||
auto output_ivs = output_loop.getInductionVars();
|
||||
rewriter->create<StoreOp>(
|
||||
loc, reduction_result, xla_output,
|
||||
llvm::makeArrayRef(output_ivs.begin(), output_ivs.end()));
|
||||
rewriter->create<StoreOp>(loc, reduction_result, xla_output,
|
||||
ValueRange{output_ivs});
|
||||
return std::make_pair(output_loop, window_loop);
|
||||
}
|
||||
|
||||
@ -347,12 +421,6 @@ class ReduceWindowOpConverter
|
||||
rewriter->setInsertionPointToStart(window_loop.getBody());
|
||||
auto loc = xla_reduce_window_op.getLoc();
|
||||
|
||||
if (!xla_reduce_window_op.window_strides().hasValue()) {
|
||||
xla_reduce_window_op.emitOpError("No window strides specified.");
|
||||
}
|
||||
if (!xla_reduce_window_op.padding().hasValue()) {
|
||||
xla_reduce_window_op.emitOpError("No padding specified.");
|
||||
}
|
||||
if (xla_reduce_window_op.base_dilations().hasValue() ||
|
||||
xla_reduce_window_op.window_dilations().hasValue()) {
|
||||
xla_reduce_window_op.emitRemark(
|
||||
@ -362,51 +430,18 @@ class ReduceWindowOpConverter
|
||||
|
||||
Value xla_operand = xla_reduce_window_op.operand();
|
||||
auto xla_operand_type = xla_operand.getType().cast<MemRefType>();
|
||||
auto xla_operand_shape = xla_operand_type.getShape();
|
||||
|
||||
auto output_ivs = llvm::to_vector<2>(output_loop.getInductionVars());
|
||||
auto window_ivs = llvm::to_vector<2>(window_loop.getInductionVars());
|
||||
auto window_strides = xla_reduce_window_op.window_strides().getValue();
|
||||
auto padding = xla_reduce_window_op.padding().getValue();
|
||||
MappedIvs mapped_ivs = MapWindowIvsToInput(
|
||||
xla_reduce_window_op, output_loop.getInductionVars(),
|
||||
window_loop.getInductionVars(), rewriter);
|
||||
|
||||
SmallVector<Value, 2> operand_indices;
|
||||
// `in_bounds` is false when the element in the reduce window is in the
|
||||
// padding area, true otherwise.
|
||||
Value in_bounds = rewriter->create<mlir::ConstantOp>(
|
||||
loc, rewriter->getI1Type(),
|
||||
rewriter->getIntegerAttr(rewriter->getI1Type(), 1));
|
||||
for (unsigned i = 0, e = output_loop.getNumLoops(); i < e; ++i) {
|
||||
auto stride = window_strides.getValue<llvm::APInt>(i);
|
||||
auto pad_low = padding.getValue<llvm::APInt>({i, 0});
|
||||
|
||||
Value stride_val =
|
||||
rewriter->create<ConstantIndexOp>(loc, stride.getSExtValue());
|
||||
Value pad_low_val =
|
||||
rewriter->create<ConstantIndexOp>(loc, pad_low.getSExtValue());
|
||||
|
||||
Value center = rewriter->create<MulIOp>(loc, output_ivs[i], stride_val);
|
||||
Value offset = rewriter->create<SubIOp>(loc, window_ivs[i], pad_low_val);
|
||||
Value index = rewriter->create<AddIOp>(loc, center, offset);
|
||||
operand_indices.push_back(index);
|
||||
Value upper_bound = GetStaticOrDynamicDim(loc, xla_operand, i,
|
||||
xla_operand_shape[i], rewriter);
|
||||
// We must check whether 0 <= index_i < shape_i, as otherwise we are in
|
||||
// the pad and then we have to use the neutral element for reduction.
|
||||
// Equivalently, it can be computed as the unsigned comparison index_i <
|
||||
// shape_i, since a negative value wraps to a large positive value.
|
||||
in_bounds = rewriter->create<mlir::AndOp>(
|
||||
loc, in_bounds,
|
||||
rewriter->create<CmpIOp>(loc, CmpIPredicate::ult, index,
|
||||
upper_bound));
|
||||
}
|
||||
|
||||
auto elem_or_init =
|
||||
rewriter->create<loop::IfOp>(loc, xla_operand_type.getElementType(),
|
||||
in_bounds, /*withElseRegion=*/true);
|
||||
auto elem_or_init = rewriter->create<loop::IfOp>(
|
||||
loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds,
|
||||
/*withElseRegion=*/true);
|
||||
|
||||
OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
|
||||
Value elem = then_builder.create<mlir::LoadOp>(
|
||||
loc, xla_reduce_window_op.operand(), operand_indices);
|
||||
loc, xla_reduce_window_op.operand(), mapped_ivs.ivs);
|
||||
then_builder.create<loop::YieldOp>(loc, elem);
|
||||
|
||||
OpBuilder else_builder = elem_or_init.getElseBodyBuilder();
|
||||
@ -423,8 +458,12 @@ struct LhloLegalizeToParallelLoops
|
||||
auto func = getFunction();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<ReduceOpConverter, ReduceWindowOpConverter>(
|
||||
func.getContext());
|
||||
// clang-format off
|
||||
patterns.insert<
|
||||
ReduceOpConverter,
|
||||
ReduceWindowOpConverter
|
||||
>(func.getContext());
|
||||
// clang-format on
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||
|
@ -95,6 +95,24 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass();
|
||||
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeLhloToParallelLoopsPass();
|
||||
|
||||
} // namespace xla_lhlo
|
||||
|
||||
namespace xla {
|
||||
|
||||
/// Moves alloc nodes (and their associated dealloc nodes - if any) into the
|
||||
/// right positions. If there is no associated dealloc node for a given alloc
|
||||
/// node, this pass will automatically insert a proper dealloc node in the right
|
||||
/// place. The intended use case of this pass is to store SSA values into
|
||||
/// buffers using load/store operations. For this purpose, you need to know
|
||||
/// proper positions to place the required allocs and deallocs.
|
||||
/// 1) Note that the function signatures and all types for which buffers should
|
||||
/// be allocated need to be converted in advance.
|
||||
/// 2) All required alloc nodes have to be inserted in advance.
|
||||
/// 3) Note that the current implementation does not support loops.
|
||||
/// Refer to the class mlir::xla::BufferAssignmentPlacer for more
|
||||
/// information.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> createBufferAssignmentPass();
|
||||
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_
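For orientation, a minimal sketch of wiring the new pass into a pass pipeline. This is an illustration only: it assumes the standard MLIR PassManager API of this period together with the createBufferAssignmentPass() factory declared above, and it is not part of the change itself.

#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"

// Runs buffer assignment on every function in `module`.
static mlir::LogicalResult runBufferAssignment(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // The pass operates on functions, so nest it under the module pass manager.
  pm.addNestedPass<mlir::FuncOp>(mlir::xla::createBufferAssignmentPass());
  return pm.run(module);
}

Alternatively, the registered "buffer-assignment" and "test-buffer-assignment" flags can be passed to the project's opt tool, which is how the FileCheck tests earlier in this diff exercise the analysis output.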
|
||||
|