Merge branch 'master' into update-array-ops-docstrings

commit 16028e280d
Author: Mihai Maruseac
Date: 2020-04-06 16:47:37 +00:00 (committed by GitHub)
906 changed files with 22378 additions and 42445 deletions

.gitignore (vendored): 1 change

@@ -38,6 +38,7 @@ gradleBuild
 *.pbxproj
 *.xcworkspace
 /*.podspec
+/tensorflow/lite/**/coreml/**/BUILD
 /tensorflow/lite/**/ios/BUILD
 /tensorflow/lite/**/objc/BUILD
 /tensorflow/lite/**/swift/BUILD


@@ -58,6 +58,8 @@ NCCL_LIB_PATHS = [
 # List of files to configure when building Bazel on Apple platforms.
 APPLE_BAZEL_FILES = [
+'tensorflow/lite/experimental/delegates/coreml/BUILD',
+'tensorflow/lite/experimental/delegates/coreml/builders/BUILD',
 'tensorflow/lite/experimental/ios/BUILD',
 'tensorflow/lite/experimental/objc/BUILD',
 'tensorflow/lite/experimental/swift/BUILD',


@@ -639,7 +639,7 @@ tf_cc_shared_object(
 "//tensorflow/cc/saved_model:loader_lite_impl",
 "//tensorflow/core:core_cpu_impl",
 "//tensorflow/core:framework_internal_impl",
-"//tensorflow/core:gpu_runtime_impl",
+"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
 "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
 "//tensorflow/core:lib_internal_impl",
 "//tensorflow/core/profiler:profiler_impl",


@@ -995,9 +995,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
 return nullptr;
 }
-tensorflow::Tensor tensor = tensorflow::TensorFromInterface(t);
-t->Release();
-return tensorflow::TF_TensorFromTensor(tensor, &status->status);
+return new TF_Tensor{t};
 }
 void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
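Note: callers of the C API should see no behavioral change from this hunk; the result is still an owned TF_Tensor*, it now just wraps the handle's tensor interface directly instead of copying through tensorflow::Tensor. A minimal caller-side sketch, assuming an existing TFE_TensorHandle* h and TF_Status* status created elsewhere:

  // Sketch only; error handling trimmed to the essentials.
  TF_Tensor* resolved = TFE_TensorHandleResolve(h, status);
  if (TF_GetCode(status) == TF_OK) {
    void* data = TF_TensorData(resolved);        // raw buffer, as before
    size_t nbytes = TF_TensorByteSize(resolved); // payload size in bytes
    // ... read data/nbytes ...
    TF_DeleteTensor(resolved);                   // caller still owns the result
  }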


@@ -580,12 +580,6 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
 };
 }
-void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
-TF_Status* status) {
-h->handle->EnableImplicitMirroring();
-status->status = tensorflow::Status::OK();
-}
 void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
 TF_Buffer* buf, TF_Status* status) {
 tensorflow::EagerContext* context =


@@ -392,12 +392,6 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
 TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
 TF_Status* status);
-// If the TensorHandle is copied to another device as part of an op execution,
-// the copy is destroyed after the op has executed. Enabling implicit mirroring
-// causes the copy to be held as a mirror for the lifetime of the TensorHandle.
-TF_CAPI_EXPORT extern void TFE_TensorHandleEnableImplicitMirroring(
-TFE_TensorHandle*, TF_Status*);
 // This function will block till the operation that produces `h` has
 // completed. This is only valid on local TFE_TensorHandles. The pointer
 // returned will be on the device in which the TFE_TensorHandle resides (so e.g.


@@ -168,8 +168,6 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
 auto* h1_task2 =
 TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
-EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 // Handles are on task0 (local), and task2, but op is on task1.
 TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);


@@ -594,7 +594,6 @@ void ExecuteAdd(bool async, bool forward_input) {
 TFE_TensorHandle* n_gpu =
 TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-TFE_TensorHandleEnableImplicitMirroring(n_gpu, status);
 TFE_DeleteTensorHandle(n);
 n = n_gpu;
 }


@@ -59,14 +59,6 @@ class AbstractTensorHandleInterface {
 // Return a copy of the handle.
 virtual AbstractTensorHandleInterface* Copy() = 0;
-// Maintain mirror tensors for any implicit copies to local devices. This
-// setting is offered on a per tensor handle basis to avoid potential memory
-// over utilization due to holding on to mirrors as well as the original
-// tensor. Note this setting overrides the context mirroring policy whereby if
-// the mirroring policy is MIRRORING_NONE, we will still continue to mirror
-// this tensor.
-virtual void EnableImplicitMirroring() = 0;
 protected:
 virtual ~AbstractTensorHandleInterface() {}
 };


@@ -118,8 +118,8 @@ cc_library(
 "//tensorflow/compiler/tf2xla/kernels:xla_ops",
 "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
 "//tensorflow/core:core_cpu_internal",
-"//tensorflow/core:gpu_init",
 "//tensorflow/core:lib",
+"//tensorflow/core/common_runtime/gpu:gpu_init",
 "@com_google_absl//absl/memory",
 "@com_google_absl//absl/strings",
 ],


@@ -72,6 +72,7 @@ cc_library(
 "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
 "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
 "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
+"//tensorflow/compiler/mlir/xla:buffer_assignment",
 "//tensorflow/compiler/mlir/xla:hlo",
 "//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
 "//tensorflow/compiler/mlir/xla:lhlo",


@@ -1,6 +1,7 @@
 book_path: /mlir/_book.yaml
 project_path: /mlir/_project.yaml
-description: <!--no description-->
+description: An intermediate representation and compiler framework, MLIR unifies the
+infrastructure for high-performance ML models in TensorFlow.
 landing_page:
 custom_css_path: /site-assets/css/style.css
 rows:


@@ -771,6 +771,7 @@ cc_library(
 "//tensorflow/core:protos_all_cc",
 "//tensorflow/lite/tools/optimize:quantize_weights",
 "//tensorflow/stream_executor/lib",
+"@com_google_absl//absl/types:span",
 "@llvm-project//llvm:support",
 "@llvm-project//mlir:AllPassesAndDialects",
 "@llvm-project//mlir:IR",


@@ -36,7 +36,7 @@ struct PassConfig {
 form_clusters(false),
 unfold_batch_matmul(true),
 legalize_tf_while(true),
-shape_inference(false) {}
+shape_inference(true) {}
 // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
 // added, which produces TF Lite ops.


@@ -409,10 +409,14 @@ static void GenOperandResultVerifier(raw_ostream &os,
 os << " (void)v;\n"
 << " if (!("
 << tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
+<< " if (failure_on_operand_type_mismatch) {\n"
 << formatv(
 " return op->emitOpError(\"{0} #\") << index "
 "<< \" must be {1}, but got \" << v.getType();\n",
 valueKind, desc)
+<< " } else {\n"
+<< " return ::mlir::LogicalResult::Failure;\n"
+<< " }\n"
 << " }\n" // if
 << " ++index;\n"
 << " }\n"; // for
@@ -437,7 +441,8 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
 mlir::tblgen::FmtContext verify_ctx;
 os << "::mlir::LogicalResult " << op.getCppClassName()
-<< "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n";
+<< "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
+"failure_on_operand_type_mismatch) {\n";
 os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
 verify_ctx.withOp("top");
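Note: for reviewers who do not want to mentally concatenate the emitted format strings, the generated per-op verifier now has roughly this shape. This is a reconstruction from the fragments above, not actual tblgen output; SomeOp and the predicate comment are placeholders:

  ::mlir::LogicalResult SomeOp::VerifyTflRuntimeTypes(
      ::mlir::Operation *op, bool failure_on_operand_type_mismatch) {
    auto top = cast<SomeOp>(op); (void)top;
    int index = 0;
    for (::mlir::Value v : op->getOperands()) {   // same loop is emitted for results
      (void)v;
      if (!(/* TFL runtime type predicate for this operand */ true)) {
        if (failure_on_operand_type_mismatch) {
          return op->emitOpError("operand #") << index
                 << " must be <expected type>, but got " << v.getType();
        } else {
          return ::mlir::LogicalResult::Failure;
        }
      }
      ++index;
    }
    return ::mlir::success();
  }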


@@ -70,6 +70,19 @@ class TFLiteCostEstimator<Conv2DOp, hardware::GPU> {
 static bool IsSupported(mlir::Operation* op) { return true; }
 };
+// tfl.cos
+template <>
+class TFLiteCostEstimator<CosOp, hardware::GPU> {
+public:
+static double GetCost(mlir::Operation* op) {
+llvm::errs() << "No defined cost function for op: "
+<< op->getName().getStringRef().str();
+return 0.0;
+}
+static bool IsSupported(mlir::Operation* op) { return true; }
+};
 // tfl.depthwise_conv_2d
 template <>
 class TFLiteCostEstimator<DepthwiseConv2DOp, hardware::GPU> {
@@ -83,6 +96,32 @@ class TFLiteCostEstimator<DepthwiseConv2DOp, hardware::GPU> {
 static bool IsSupported(mlir::Operation* op) { return true; }
 };
+// tfl.div
+template <>
+class TFLiteCostEstimator<DivOp, hardware::GPU> {
+public:
+static double GetCost(mlir::Operation* op) {
+llvm::errs() << "No defined cost function for op: "
+<< op->getName().getStringRef().str();
+return 0.0;
+}
+static bool IsSupported(mlir::Operation* op) { return true; }
+};
+// tfl.exp
+template <>
+class TFLiteCostEstimator<ExpOp, hardware::GPU> {
+public:
+static double GetCost(mlir::Operation* op) {
+llvm::errs() << "No defined cost function for op: "
+<< op->getName().getStringRef().str();
+return 0.0;
+}
+static bool IsSupported(mlir::Operation* op) { return true; }
+};
 // tfl.fully_connected
 template <>
 class TFLiteCostEstimator<FullyConnectedOp, hardware::GPU> {
@@ -97,6 +136,19 @@ class TFLiteCostEstimator<FullyConnectedOp, hardware::GPU> {
 static bool IsSupported(mlir::Operation* op) { return true; }
 };
+// tfl.hard_swish
+template <>
+class TFLiteCostEstimator<HardSwishOp, hardware::GPU> {
+public:
+static double GetCost(mlir::Operation* op) {
+llvm::errs() << "No defined cost function for op: "
+<< op->getName().getStringRef().str();
+return 0.0;
+}
+static bool IsSupported(mlir::Operation* op) { return true; }
+};
 // tfl.logistic
 template <>
 class TFLiteCostEstimator<LogisticOp, hardware::GPU> {
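Note: the new specializations all share the same placeholder behavior (no cost model yet, but the op is marked as GPU-supported). A minimal sketch of how such a specialization is queried, with namespace qualifiers elided and `op` assumed to be an mlir::Operation* already matched as the corresponding TFL op:

  bool supported = TFLiteCostEstimator<CosOp, hardware::GPU>::IsSupported(op);
  // GetCost currently just logs "No defined cost function for op" and returns 0.0.
  double cost = supported ? TFLiteCostEstimator<CosOp, hardware::GPU>::GetCost(op) : 0.0;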


@@ -138,6 +138,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
 return tflite::TensorType_FLOAT32;
 case mlir::StandardTypes::F16:
 return tflite::TensorType_FLOAT16;
+case mlir::StandardTypes::F64:
+return tflite::TensorType_FLOAT64;
 case mlir::TF::TensorFlowTypes::STRING:
 return tflite::TensorType_STRING;
 case mlir::TF::TensorFlowTypes::QUINT8:


@@ -353,6 +353,22 @@ StatusOr<mlir::ElementsAttr> ConvertFloatBuffer(
 }
 return DenseElementsAttr::get(shaped_type, ArrayRef<float>(values));
 }
+case 64: {
+assert(bytes_len % 8 == 0);
+size_t elem_count = bytes_len / 8;
+std::vector<double> values;
+values.reserve(elem_count);
+const char* data = reinterpret_cast<const char*>(buffer.data());
+for (int i = 0; i < elem_count; i++) {
+uint64_t bit_repr =
+llvm::support::endian::readNext<uint64_t, llvm::support::little,
+llvm::support::unaligned>(data);
+values.push_back(absl::bit_cast<double>(bit_repr));
+}
+return DenseElementsAttr::get(shaped_type, ArrayRef<double>(values));
+}
 }
 return errors::InvalidArgument("unsupported bit width", elem_type.getWidth());
 }
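Note: the new 64-bit case decodes raw little-endian bytes into doubles via a uint64_t bit pattern. A standalone sketch of just that decoding step, for anyone checking the endianness handling; the wrapper function and its name are illustrative, not part of the change:

  #include <cstdint>
  #include <vector>
  #include "absl/base/casts.h"
  #include "llvm/Support/Endian.h"

  // Mirrors the F64 branch above: read 8 bytes at a time, little-endian,
  // and reinterpret the bit pattern as an IEEE-754 double.
  std::vector<double> DecodeF64LittleEndian(const char* data, size_t bytes_len) {
    std::vector<double> values;
    values.reserve(bytes_len / 8);
    for (size_t i = 0; i < bytes_len / 8; ++i) {
      uint64_t bit_repr =
          llvm::support::endian::readNext<uint64_t, llvm::support::little,
                                          llvm::support::unaligned>(data);
      values.push_back(absl::bit_cast<double>(bit_repr));
    }
    return values;
  }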


@@ -86,7 +86,8 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
 let methods = [
 StaticInterfaceMethod<
 [{Returns whether the op's operands/results are supported by runtime.}],
-"LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op)
+"LogicalResult", "VerifyTflRuntimeTypes",
+(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
 >,
 ];
 }
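Note: with the extra boolean on the interface method, a verification pass can choose between emitting a diagnostic and failing quietly. A rough sketch of a call site; the pass context and the `strict` option are illustrative, not taken from this commit:

  // Inside some FunctionPass::runOnFunction(); `strict` is a pass option.
  getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
    if (failed(op.VerifyTflRuntimeTypes(
            op.getOperation(), /*failure_on_operand_type_mismatch=*/strict)))
      signalPassFailure();
  });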


@@ -706,7 +706,10 @@ def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
 }
 def TFL_CosOp: TFL_Op<"cos", [
-NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
+NoSideEffect,
+SameOperandsAndResultType,
+NoQuantizableResult,
+TFL_GpuTargetOp]> {
 let summary = "Cosine operator";
 let description = [{
@@ -827,12 +830,12 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [NoSideEffect]> {
 }];
 let arguments = (ins
-TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$params,
+TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8, TFL_Str]>:$params,
 TFL_I32OrI64Tensor:$indices
 );
 let results = (outs
-TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output
+TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8, TFL_Str]>:$output
 );
 }
@@ -1108,7 +1111,10 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
 def TFL_DivOp : TFL_Op<"div", [
 // TODO(fengliuai): NoQuantizableResult is only correct for int8
 // quantization. update to handle Uint8 quantization.
-ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
+ResultsBroadcastableShape,
+NoSideEffect,
+NoQuantizableResult,
+TFL_GpuTargetOp]> {
 let summary = "Division operator";
 let description = [{
@@ -1187,7 +1193,9 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape,
 let builders = [TFL_ComparisonBinaryBuilder];
 }
-def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> {
+def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect,
+SameOperandsAndResultType,
+TFL_GpuTargetOp]> {
 let summary = "Natural exponentiation operator";
 let description = [{
@@ -1369,7 +1377,8 @@ def TFL_GreaterOp : TFL_Op<"greater", [
 }
 def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect,
-SameOperandsAndResultShape]> {
+SameOperandsAndResultShape,
+TFL_GpuTargetOp]> {
 let summary = "Hardswish activation function.";
 let description = [{
 Computes hard-swish activation function


@@ -84,8 +84,14 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
 TF_ASSIGN_OR_RETURN(
 auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
+mlir::TFL::PassConfig pass_config(quant_specs);
+bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
+pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
+pass_config.lower_tensor_list_ops = true;
+pass_config.shape_inference = false;
 return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
-quant_specs, result);
+pass_config, result);
 }
 } // namespace tensorflow
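Note: both converter entry points now build the mlir::TFL::PassConfig themselves instead of handing raw QuantizationSpecs down. The GraphDef path above keeps shape_inference off, the SavedModel path below turns it on, and the struct's default is now true (per the PassConfig hunk earlier). A condensed sketch of the caller-side pattern, limited to the fields this commit touches:

  mlir::TFL::PassConfig pass_config(quant_specs);
  pass_config.emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
  pass_config.lower_tensor_list_ops = true;
  pass_config.shape_inference = false;  // GraphDef path; the SavedModel path uses true
  return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
                                                 pass_config, result);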


@@ -43,8 +43,6 @@ namespace tensorflow {
 Status ConvertSavedModelToTFLiteFlatBuffer(
 const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
-const string& saved_model_dir, bool saved_model_v1,
-const string& saved_model_tags, const string& saved_model_exported_names,
 string* result) {
 mlir::MLIRContext context;
 mlir::TFL::QuantizationSpecs quant_specs;
@@ -66,13 +64,28 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
 // Register all custom ops, including user-specified custom ops.
 TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));
-const bool import_saved_model = !saved_model_v1;
-TF_ASSIGN_OR_RETURN(
-auto module,
-ImportSavedModel(import_saved_model, saved_model_v1, saved_model_dir,
-saved_model_tags, saved_model_exported_names, &context));
-return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
-quant_specs, result);
+auto& saved_model_tags = model_flags.saved_model_tags();
+auto& saved_model_exported_names = model_flags.saved_model_exported_names();
+std::unordered_set<std::string> tags(saved_model_tags.begin(),
+saved_model_tags.end());
+auto exported_names_in_vector = std::vector<std::string>(
+saved_model_exported_names.begin(), saved_model_exported_names.end());
+absl::Span<std::string> exported_names(exported_names_in_vector);
+TF_ASSIGN_OR_RETURN(auto module,
+ImportSavedModel(model_flags.saved_model_dir(),
+model_flags.saved_model_version(), tags,
+exported_names, &context));
+mlir::TFL::PassConfig pass_config(quant_specs);
+bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
+pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
+pass_config.lower_tensor_list_ops = true;
+pass_config.shape_inference = true;
+auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
+toco_flags, std::move(module), pass_config, result);
+return status;
 }
 } // namespace tensorflow


@@ -28,8 +28,6 @@ namespace tensorflow {
 // status if it fails to convert the input.
 Status ConvertSavedModelToTFLiteFlatBuffer(
 const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
-const string& saved_model_dir, bool saved_model_v1,
-const string& saved_model_tags, const string& saved_model_exported_names,
 string* result);
 } // namespace tensorflow


@@ -105,6 +105,10 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
 switch (dtype) {
 case toco::IODataType::FLOAT:
 return DT_FLOAT;
+case toco::IODataType::FLOAT16:
+return DT_HALF;
+case toco::IODataType::FLOAT64:
+return DT_DOUBLE;
 case toco::IODataType::QUANTIZED_UINT8:
 return DT_QUINT8;
 case toco::IODataType::INT8:
@@ -261,7 +265,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
 Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
 mlir::OwningModuleRef module,
-mlir::TFL::QuantizationSpecs quant_specs,
+const mlir::TFL::PassConfig& pass_config,
 string* result) {
 bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
 bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
@@ -275,9 +279,6 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
 }
 mlir::PassManager pm(module->getContext());
-mlir::TFL::PassConfig pass_config(quant_specs);
-pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
-pass_config.lower_tensor_list_ops = true;
 tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
 // Convert back to outlined while format for export back to flatbuffer.
@@ -288,7 +289,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
 auto status = ConvertTFExecutorToTFLOrFlatbuffer(
 module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
-emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm);
+emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, result,
+&pm);
 if (toco_flags.has_dump_graphviz_dir()) {
 TF_RETURN_IF_ERROR(DumpOpGraphToFile(
 // rename once we enable the new converter feature flag.


@@ -47,7 +47,7 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
 // This will also run relevant passes as well.
 Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
 mlir::OwningModuleRef module,
-mlir::TFL::QuantizationSpecs quant_specs,
+const mlir::TFL::PassConfig& pass_config,
 string* result);
 // Give a warning for any unused flags that have been specified.


@@ -57,8 +57,9 @@ QuantizeContext::QuantizeContext(FuncOp func, const DeviceTarget &spec)
 });
 }
-llvm::ArrayRef<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
-llvm::SmallVector<quant::QuantizeRegionOp, 64> all_ops;
+std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
+std::vector<quant::QuantizeRegionOp> all_ops;
+all_ops.reserve(128);
 func_.walk([&](quant::QuantizeRegionOp op) { all_ops.push_back(op); });
 return all_ops;
 }
@@ -75,7 +76,7 @@ LogicalResult QuantizeContext::Handle(
 switch (spec->type) {
 case ScaleConstraintType::OutputInputFreeScale: {
 // no propagation.
-*changed = false;
+*changed |= false;
 break;
 }
 case ScaleConstraintType::CustomScale: {
@@ -84,7 +85,20 @@ LogicalResult QuantizeContext::Handle(
 }
 break;
 }
+case ScaleConstraintType::OutputInputSameScale: {
+auto params = GetQuantParamsForSameScaleConstraint(op);
+if (EmptyParams(params)) {
+*changed |= false;
+break;
+}
+// propagate this params to all the quantizable ports.
+if (failed(PropagateQuantParams(op, params, new_items, changed))) {
+return failure();
+}
+break;
+}
 default: {
+// TODO(fengliuai): implement the other types.
 llvm_unreachable("no implementation.");
 return failure();
 }
@@ -154,6 +168,102 @@ void QuantizeContext::DumpStates(QuantizeRegionOp current_op) {
 });
 }
+// A heuristic to get quantization parameters satisfies the same scale
+// constraints:
+// - If there are immutable states,
+// - use the single input, or,
+// - use the single output, or,
+// - use the first one in the collection,
+// - use the single input if it is ready, or,
+// - use the single output if it is ready, or,
+// - use use the first ready one in the collection.
+QuantParams QuantizeContext::GetQuantParamsForSameScaleConstraint(
+Operation *op) {
+// Two vector to collect Non-empty operands and results states.
+std::vector<quant::QuantState *> mutable_states, immutable_states;
+for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
+auto &state = states_manager_.GetOperandQuantState(op, i);
+if (state.immutable) {
+immutable_states.push_back(&state);
+} else if (!state.IsEmpty()) {
+mutable_states.push_back(&state);
+}
+}
+int immutable_operands_num = immutable_states.size();
+int mutable_operands_num = mutable_states.size();
+// Use the operand's state if it is immutable and it is the only one
+// operand.
+if (op->getNumOperands() == 1 && immutable_operands_num == 1) {
+return immutable_states.front()->params;
+}
+for (int i = 0, e = op->getNumResults(); i != e; ++i) {
+auto &state = states_manager_.GetResultQuantState(op, i);
+if (state.immutable) {
+immutable_states.push_back(&state);
+} else if (!state.IsEmpty()) {
+mutable_states.push_back(&state);
+}
+}
+int immutable_results_num = immutable_states.size() - immutable_operands_num;
+int mutable_results_num = mutable_states.size() - mutable_operands_num;
+// Use the result's state if it is immutable and it is the only one result.
+if (op->getNumResults() == 1 && immutable_results_num == 1) {
+return immutable_states.back()->params;
+}
+LLVM_DEBUG(llvm::dbgs()
+<< "Quantization parameters are not collected in an ideal place. "
+"Has to fallback values which might introduce errors.\n");
+// Use the first immutable state to quantize the rest operands and results.
+if (!immutable_states.empty()) return immutable_states.front()->params;
+// If there are no immutable states, use the operand's state if it is the
+// only one operand and has parameters propagated.
+if (op->getNumOperands() == 1 && mutable_operands_num == 1) {
+return mutable_states.front()->params;
+}
+// If there are no immutable states, use the result's state if it is the
+// only one result and has parameters propagated.
+if (op->getNumResults() == 1 && mutable_results_num == 1) {
+return mutable_states.back()->params;
+}
+// Use the first propagated state to quantize the rest operands and results.
+if (!mutable_states.empty()) return mutable_states.front()->params;
+// None operands/results have parameters propagated, skip this node for now.
+return {};
+}
+LogicalResult QuantizeContext::PropagateQuantParams(
+Operation *op, const QuantParams params,
+quant::AdjacentOperations *new_items, bool *changed) {
+// Use the final state to set all the operands' parameters.
+for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
+auto ele = op->getOperand(i).getType().cast<ShapedType>().getElementType();
+if (ele.isa<FloatType>() && SetOperandParams(op, i, params)) {
+*changed |= true;
+new_items->push_back(op->getOperand(i).getDefiningOp());
+}
+}
+// Use the final state to set all the results' parameters.
+for (int res = 0, e = op->getNumResults(); res != e; ++res) {
+auto ele = op->getResult(res).getType().cast<ShapedType>().getElementType();
+if (ele.isa<FloatType>() && SetResultParams(op, res, params)) {
+auto users = op->getResult(res).getUsers();
+*changed |= !users.empty();
+new_items->append(users.begin(), users.end());
+}
+}
+return success();
+}
 int QuantizeContext::StatesManager::InitializeState(quant::QuantizeRegionOp op,
 int index, bool as_result) {
 Attribute params_attr;


@@ -67,7 +67,7 @@ class QuantizeContext {
 QuantizeContext(FuncOp func, const DeviceTarget &spec);
 // Returns all the quant region ops.
-ArrayRef<quant::QuantizeRegionOp> GetAllOps();
+std::vector<quant::QuantizeRegionOp> GetAllOps();
 // For each quant region op, propagates its quantization parameters according
 // to the kernel specification and also returns the adjcent quant region ops
@@ -107,6 +107,25 @@ class QuantizeContext {
 return states_manager_.GetOperandParams(op, index);
 }
+// A heuristic to get quantization parameters satisfies the same scale
+// constraints:
+// - If there are immutable states,
+// - use the single input, or,
+// - use the single output, or,
+// - use the first one in the collection,
+// - use the single input if it is ready, or,
+// - use the single output if it is ready, or,
+// - use use the first ready one in the collection.
+QuantParams GetQuantParamsForSameScaleConstraint(Operation *op);
+// Propagate `params` to all the quantizable port of the `op`. The adjcent
+// ops, which have the parameters propagated to, are collected by `new_items`,
+// so they can be added to the working queue. `changed` is set to true if
+// there are any new elements being added to `new_items`.
+LogicalResult PropagateQuantParams(Operation *op, const QuantParams params,
+AdjacentOperations *new_items,
+bool *changed);
 private:
 class StatesManager {
 public:
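Note: the two new public methods are intended to be used together, as the OutputInputSameScale case in the .cc hunk above already does. A condensed restatement of that call pattern, with `ctx` standing in for the QuantizeContext that owns the op's states and `new_items`/`changed` being the same out-parameters Handle() receives:

  // For an op whose kernel spec is OutputInputSameScale:
  auto params = ctx.GetQuantParamsForSameScaleConstraint(op);
  if (!EmptyParams(params)) {
    // Writes `params` to every float-typed operand/result that accepts it and
    // appends the affected neighbor ops to `new_items` for the work list.
    if (failed(ctx.PropagateQuantParams(op, params, new_items, changed)))
      return failure();
  }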


@@ -28,6 +28,14 @@ namespace ph = std::placeholders;
 CpuDeviceTarget::CpuDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) {
 RegisterKernel("generic.concat", {qi8_, qi8_, qi8_},
 quant::ScaleConstraintType::OutputInputSameScale);
+// TODO(fengliuai): All the combinations are required to list. We need to
+// improve this.
+RegisterKernel("generic.reshape", {qi8_, any_},
+quant::ScaleConstraintType::OutputInputSameScale);
+RegisterKernel("generic.reshape", {any_, qi8_},
+quant::ScaleConstraintType::OutputInputSameScale);
 RegisterKernel("generic.mul", {qi8_, qi8_, qi8_},
 quant::ScaleConstraintType::OutputInputFreeScale);
 RegisterKernel("generic.mul_add", {qi8_, qi8n_, any_, qi8_},


@@ -176,7 +176,7 @@ llvm::SmallVector<Value, 0> fuseOps(PatternRewriter* rewriter,
 auto* body = new Block();
 region.body().push_back(body);
-OpBuilder builder(body);
+OpBuilder builder = OpBuilder::atBlockEnd(body);
 BlockAndValueMapping mapping;
 // Make block arguments and add it to the block value mapping.
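Note: the switch to OpBuilder::atBlockEnd pins the insertion point explicitly to the end of the freshly created block before ops are cloned into it. A minimal sketch of the idiom, assuming the standard MLIR OpBuilder API (atBlockBegin would instead insert before any existing ops, which is not what block-building code like this wants):

  mlir::Block* body = new mlir::Block();
  region.body().push_back(body);
  // Every op created through `builder` from here on is appended at the end of `body`.
  mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(body);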


@@ -69,7 +69,7 @@ void PropagateQuantPass::runOnFunction() {
 CpuDeviceTarget spec(&getContext());
 quant::QuantizeContext ctx(func, spec);
-std::vector<quant::QuantizeRegionOp> work_list(ctx.GetAllOps());
+std::vector<quant::QuantizeRegionOp> work_list = ctx.GetAllOps();
 bool changed = false;
 while (!work_list.empty()) {
 quant::QuantizeRegionOp op = work_list.back();


@@ -52,3 +52,18 @@ func @mul_add_annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tenso
 // CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8<-127:127>:f32, 1.000000e+00:-128>, !quant.uniform<i32:f32, 1.000000e+00>]
 // CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
 }
+// -----
+// CHECK-LABEL: @same_scale_1_1
+func @same_scale_1_1(%arg0: tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>) {
+%region = "quant.region"(%arg0) ( {
+^bb0(%arg1: tensor<1x7x7x64xf32>): // no predecessors
+%r = "xla_hlo.reshape"(%arg1) : (tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>)
+"quant.return"(%r) : (tensor<1x3136xf32>) -> ()
+}) {input_specs = [!quant.uniform<i8:f32, 1.0>], logical_kernel = "generic.reshape", output_specs = [f32]} : (tensor<1x7x7x64xf32>) -> tensor<1x3136xf32>
+return %region : tensor<1x3136xf32>
+// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00>]
+// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00>]
+}


@@ -1,37 +1,53 @@
 // RUN: tf-opt %s -tfl-identify-dilated-conv | FileCheck %s --dump-input-on-failure
-func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
+func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
+%cst = constant dense<[2, 2]> : tensor<2xi32>
+%cst_0 = constant dense<2> : tensor<2x2xi32>
+%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
+%1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
+%2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
+return %2 : tensor<1x128x128x8xf32>
+// CHECK-LABEL: testDilatedConv
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
+// CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
+// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
+}
+func @testDilatedConvWithNonConstantPadAndCrops(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
 %cst = constant dense<[2, 2]> : tensor<2xi32>
 %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
 %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
 %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
 return %2 : tensor<1x128x128x8xf32>
-// CHECK-LABEL: testDilatedConv
+// CHECK-LABEL: testDilatedConvWithNonConstantPadAndCrops
 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
 // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
 // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
 }
-func @testDilatedConvWithNonZeroSTBPadding(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
+func @testDilatedConvWithNonZeroBasePadding(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
 %cst = constant dense<[2, 2]> : tensor<2xi32>
 %cst_0 = constant dense<2> : tensor<2x2xi32>
+%cst_1 = constant dense<1> : tensor<2x2xi32>
 %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
-%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
-%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
+%1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
+%2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
 return %2 : tensor<1x128x128x8xf32>
-// CHECK-LABEL: testDilatedConvWithNonZeroSTBPadding
-// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
+// CHECK-LABEL: testDilatedConvWithNonZeroBasePadding
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
 // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
 // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
 }
-func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
+func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
 %cst = constant dense<[2, 2]> : tensor<2xi32>
-%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
-%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
-%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
+%cst_0 = constant dense<2> : tensor<2x2xi32>
+%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
+%1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
+%2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
 return %2 : tensor<1x128x128x8xf32>
 // CHECK-LABEL: testDilatedConvWithNonTrivialDilations
@@ -41,25 +57,27 @@ func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %ar
 // CHECK-NEXT: return [[RESULT]]
 }
-func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
+func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
 %cst = constant dense<[2, 2]> : tensor<2xi32>
-%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
-%1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
-%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
+%cst_0 = constant dense<2> : tensor<2x2xi32>
+%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
+%1 = "tf.DepthwiseConv2dNative"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
+%2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
 return %2 : tensor<1x128x128x8xf32>
 // CHECK-LABEL: testDilatedDepthWiseConv
-// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
 // CHECK-NEXT: [[RESULT:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
 // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
 }
 func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
 %cst = constant dense<[2, 2]> : tensor<2xi32>
-%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
+%cst_0 = constant dense<2> : tensor<2x2xi32>
+%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
 %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
 %2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32>
-%3 = "tf.BatchToSpaceND"(%2, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
+%3 = "tf.BatchToSpaceND"(%2, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
 %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
 return %4 : tensor<1x128x128x8xf32>
@@ -72,10 +90,11 @@ func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi
 func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
 %cst = constant dense<[2, 2]> : tensor<2xi32>
-%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
+%cst_0 = constant dense<2> : tensor<2x2xi32>
+%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
 %1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
 %2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32>
-%3 = "tf.BatchToSpaceND"(%2, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
+%3 = "tf.BatchToSpaceND"(%2, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
 %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
 return %4 : tensor<1x128x128x8xf32>
@@ -86,49 +105,52 @@ func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: ten
 // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
 }
-func @testDilatedConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
+func @testDilatedConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
 %cst = constant dense<[2, 2]> : tensor<2xi32>
-%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
-%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
-%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
-%3 = "tf.BiasAdd"(%2, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
+%cst_0 = constant dense<2> : tensor<2x2xi32>
+%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
+%1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
+%2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
+%3 = "tf.BiasAdd"(%2, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
 return %3 : tensor<1x128x128x8xf32>
 // CHECK-LABEL: testDilatedConvWithBiasAdd
-// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
 // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
 // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
 // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
 }
-func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
+func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
 %cst = constant dense<[2, 2]> : tensor<2xi32>
-%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
-%1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
-%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
-%3 = "tf.BiasAdd"(%2, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
+%cst_0 = constant dense<2> : tensor<2x2xi32>
+%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
+%1 = "tf.DepthwiseConv2dNative"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
+%2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
+%3 = "tf.BiasAdd"(%2, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
 return %3 : tensor<1x128x128x8xf32>
 // CHECK-LABEL: testDilatedDepthWiseConvWithBiasAdd
-// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
 // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
 // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
 // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
 }
-func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
+func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> {
 %cst = constant dense<[2, 2]> : tensor<2xi32>
 %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
-%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
+%cst_1 = constant dense<2> : tensor<2x2xi32>
+%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
 %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
-%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
+%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
 %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
-%4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
-%5 = "tf.BiasAdd"(%4, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
+%4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
+%5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
 return %5 : tensor<1x128x128xf32>
 // CHECK-LABEL: testDilatedConvWithExpandSqueeze1
-// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
 // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
 // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
 // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
@@ -137,19 +159,20 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten
 // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
 }
-func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
+func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> {
 %cst = constant dense<[2, 2]> : tensor<2xi32>
 %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
-%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
+%cst_1 = constant dense<2> : tensor<2x2xi32>
+%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
 %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
-%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
+%2 = "tf.DepthwiseConv2dNative"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
 %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
-%4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
-%5 = "tf.BiasAdd"(%4, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
+%4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
+%5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
 return %5 : tensor<1x128x128xf32>
 // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1
-// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
 // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
 // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
 // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
@ -158,19 +181,20 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
} }
func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> { func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<?xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32> %cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32> %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32> %cst_1 = constant dense<2> : tensor<2x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32> %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32>
%4 = "tf.BiasAdd"(%3, %arg3) : (tensor<4x?x?xf32>, tensor<?xf32>) -> tensor<4x?x?xf32> %4 = "tf.BiasAdd"(%3, %arg2) : (tensor<4x?x?xf32>, tensor<?xf32>) -> tensor<4x?x?xf32>
%5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
return %5 : tensor<1x128x128xf32> return %5 : tensor<1x128x128xf32>
// CHECK-LABEL: testDilatedConvWithExpandSqueeze2 // CHECK-LABEL: testDilatedConvWithExpandSqueeze2
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>) // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32> // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
@ -179,19 +203,20 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
} }
func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> { func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<?xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32> %cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32> %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32> %cst_1 = constant dense<2> : tensor<2x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32> %2 = "tf.DepthwiseConv2dNative"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32>
%4 = "tf.BiasAdd"(%3, %arg3) : (tensor<4x?x?xf32>, tensor<?xf32>) -> tensor<4x?x?xf32> %4 = "tf.BiasAdd"(%3, %arg2) : (tensor<4x?x?xf32>, tensor<?xf32>) -> tensor<4x?x?xf32>
%5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
return %5 : tensor<1x128x128xf32> return %5 : tensor<1x128x128xf32>
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2 // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>) // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32> // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
@ -203,12 +228,13 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32> %cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32> %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %cst_1 = constant dense<2> : tensor<2x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
%4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32> %4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32>
%5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
%6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> %6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
return %6 : tensor<1x128x128xf32> return %6 : tensor<1x128x128xf32>
@ -225,12 +251,13 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32> %cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32> %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %cst_1 = constant dense<2> : tensor<2x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
%4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32> %4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32>
%5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
%6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> %6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
return %6 : tensor<1x128x128xf32> return %6 : tensor<1x128x128xf32>
@ -244,14 +271,15 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
} }
func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128x1xf32> { func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32> %cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32> %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %cst_1 = constant dense<2> : tensor<2x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [2]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [2]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64x1xf32>
%4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32> %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32>
return %4 : tensor<1x128x128x1xf32> return %4 : tensor<1x128x128x1xf32>
// CHECK-LABEL: testDilatedConvWithDifferentExpandSqueezeAxis // CHECK-LABEL: testDilatedConvWithDifferentExpandSqueezeAxis
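For readers following these FileCheck expectations, here is a minimal stand-alone sketch (not part of the change) of how the block_shape constant dense<[2, 2]> fed to tf.SpaceToBatchND corresponds to the dilations = [1, 2, 2, 1] attribute in the rewritten convolutions. MakeNhwcDilations is a hypothetical helper used only for illustration.

#include <cstdint>
#include <vector>

// Hypothetical helper, for illustration only: in an NHWC convolution the
// batch and channel dimensions keep a dilation of 1, while the two spatial
// dimensions take the SpaceToBatchND block sizes.
std::vector<int64_t> MakeNhwcDilations(const std::vector<int64_t>& block_shape) {
  return {1, block_shape[0], block_shape[1], 1};
}

int main() {
  std::vector<int64_t> dilations = MakeNhwcDilations({2, 2});
  // dilations == {1, 2, 2, 1}, matching the CHECK lines above.
  return (dilations[1] == 2 && dilations[2] == 2) ? 0 : 1;
}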

View File

@ -142,7 +142,7 @@ versions {
# CHECK-SAME: control_outputs = "" # CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "unranked" # CHECK-SAME: inputs = "unranked"
# CHECK-SAME: outputs = "unranked,static,static_10" # CHECK-SAME: outputs = "unranked,static,static_10"
# CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<10xi32> # CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<i32>
# CHECK: [[VAL_2:%.*]] = constant dense<0> : tensor<i32> # CHECK: [[VAL_2:%.*]] = constant dense<0> : tensor<10xi32>
# CHECK: return [[VAL_0]], [[VAL_2]], [[VAL_1]] : tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32> # CHECK: return [[VAL_0]], [[VAL_1]], [[VAL_2]] : tensor<1x8x8x2xi32>, tensor<i32>, tensor<10xi32>
# CHECK: } # CHECK: }

View File

@ -7788,35 +7788,35 @@ library {
# CHECK-SAME: control_outputs = "" # CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "INPUT" # CHECK-SAME: inputs = "INPUT"
# CHECK-SAME: outputs = "OUTPUT" # CHECK-SAME: outputs = "OUTPUT"
# CHECK: [[VAL_1:%.*]] = constant dense<{{\[\[}}-0.400154352, 0.739109992, 0.201825857], [0.678572893, 0.32076478, 0.949867963], [-0.807729483, -5.324750e-01, 0.148033619]]> : tensor<3x3xf32> # CHECK: [[VAL_1:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
# CHECK: [[VAL_2:%.*]] = constant dense<{{\[\[}}0.886177539, -0.606141329, -0.451275587], [0.325554609, 0.691527605, -0.676239967], [0.219799042, 0.626042128, -0.597596407]]> : tensor<3x3xf32> # CHECK: [[VAL_2:%.*]] = constant dense<0.000000e+00> : tensor<3xf32>
# CHECK: [[VAL_3:%.*]] = constant dense<{{\[\[}}-0.493826151, -0.391061306, -0.349843264], [-0.0213134289, 0.558384657, -0.51513052], [0.427886248, 0.618100405, -0.187585592]]> : tensor<3x3xf32> # CHECK: [[VAL_3:%.*]] = constant dense<{{\[\[}}-0.856678485, -0.800494194, 0.716800689], [0.536404848, 0.541643381, -0.35657692], [-0.794646739, 0.137629032, 0.690013885]]> : tensor<3x3xf32>
# CHECK: [[VAL_4:%.*]] = constant dense<{{\[\[}}0.444335222, -0.133341789, 0.839591503], [0.445418358, -0.571707964, 0.569707394], [0.465010405, -0.990037918, -0.632481337]]> : tensor<3x3xf32> # CHECK: [[VAL_4:%.*]] = constant dense<{{\[\[}}-0.125753641, 0.32271719, 0.488939524], [0.36119318, 0.982266664, -0.448646784], [0.966353893, -0.767024993, 0.446366787]]> : tensor<3x3xf32>
# CHECK: [[VAL_5:%.*]] = constant dense<{{\[\[}}-0.138204336, -0.10879755, -0.135128736], [0.94797182, -8.713360e-01, -0.792336463], [0.0339827538, -0.539326906, 8.906350e-01]]> : tensor<3x3xf32> # CHECK: [[VAL_5:%.*]] = constant dense<{{\[\[}}0.891112089, -2.786560e-01, 0.966933965], [-0.789963722, 0.057955265, 0.217499971], [-0.698129416, -0.983400583, -0.834380626]]> : tensor<3x3xf32>
# CHECK: [[VAL_6:%.*]] = constant dense<{{\[\[}}0.513064623, -0.692989588, 0.547988653], [0.0653710365, 0.576977491, 0.966733217], [0.0130724907, 0.247342348, 0.317092657]]> : tensor<3x3xf32> # CHECK: [[VAL_6:%.*]] = constant dense<{{\[\[}}0.782244444, -0.0446639061, 0.848498106], [-0.579102755, -0.407756329, 0.442389727], [0.00566458702, 0.5984025, 0.629857302]]> : tensor<3x3xf32>
# CHECK: [[VAL_7:%.*]] = constant dense<{{\[\[}}0.230039358, -0.182297707, -0.352231741], [-0.805100203, -0.220300436, -0.669503212], [0.278807402, -0.201502323, -0.627609729]]> : tensor<3x3xf32> # CHECK: [[VAL_7:%.*]] = constant dense<1.000000e+00> : tensor<3xf32>
# CHECK: [[VAL_8:%.*]] = constant dense<{{\[\[}}-0.207589626, -0.756766081, -0.853258133], [-0.269270182, 0.0468223095, -0.353052378], [-0.0702953338, 0.0725159645, -0.817753077]]> : tensor<3x3xf32> # CHECK: [[VAL_8:%.*]] = constant dense<{{\[\[}}-0.616786718, 0.892614365, 0.671324968], [-0.842380046, -0.358094931, 0.821366549], [0.790347338, 0.71222949, 0.0690443515]]> : tensor<3x3xf32>
# CHECK: [[VAL_9:%.*]] = constant dense<[0.171322107, -0.153412342, 0.591750383]> : tensor<3xf32> # CHECK: [[VAL_9:%.*]] = constant dense<{{\[\[}}-5.087240e-01, -0.588907719, 0.471896172], [-0.508019447, -0.0157074928, -0.804120779], [-0.978842973, 0.00160336494, -0.978532075]]> : tensor<3x3xf32>
# CHECK: [[VAL_10:%.*]] = constant dense<[-0.671292543, 0.411814928, 0.560465336]> : tensor<3xf32> # CHECK: [[VAL_10:%.*]] = constant dense<{{\[\[}}0.18183589, 0.616135359, -0.167827845], [0.734281301, 0.958347797, -0.878054618], [0.369523764, -0.969005823, -0.881014585]]> : tensor<3x3xf32>
# CHECK: [[VAL_11:%.*]] = constant dense<[0.403919935, -0.882057666, -0.894463062]> : tensor<3xf32> # CHECK: [[VAL_11:%.*]] = constant dense<{{\[\[}}-0.936182261, -0.935433864, 0.288229942], [-0.243383884, -0.628288031, -0.477061749], [-0.514976501, -0.903514862, 6.728170e-01]]> : tensor<3x3xf32>
# CHECK: [[VAL_12:%.*]] = constant dense<{{\[\[}}-0.936182261, -0.935433864, 0.288229942], [-0.243383884, -0.628288031, -0.477061749], [-0.514976501, -0.903514862, 6.728170e-01]]> : tensor<3x3xf32> # CHECK: [[VAL_12:%.*]] = constant dense<{{\[}}0.403919935, -0.882057666, -0.894463062]> : tensor<3xf32>
# CHECK: [[VAL_13:%.*]] = constant dense<{{\[\[}}0.18183589, 0.616135359, -0.167827845], [0.734281301, 0.958347797, -0.878054618], [0.369523764, -0.969005823, -0.881014585]]> : tensor<3x3xf32> # CHECK: [[VAL_13:%.*]] = constant dense<{{\[}}-0.671292543, 0.411814928, 0.560465336]> : tensor<3xf32>
# CHECK: [[VAL_14:%.*]] = constant dense<{{\[\[}}-5.087240e-01, -0.588907719, 0.471896172], [-0.508019447, -0.0157074928, -0.804120779], [-0.978842973, 0.00160336494, -0.978532075]]> : tensor<3x3xf32> # CHECK: [[VAL_14:%.*]] = constant dense<{{\[}}0.171322107, -0.153412342, 0.591750383]> : tensor<3xf32>
# CHECK: [[VAL_15:%.*]] = constant dense<{{\[\[}}-0.616786718, 0.892614365, 0.671324968], [-0.842380046, -0.358094931, 0.821366549], [0.790347338, 0.71222949, 0.0690443515]]> : tensor<3x3xf32> # CHECK: [[VAL_15:%.*]] = constant dense<{{\[\[}}-0.207589626, -0.756766081, -0.853258133], [-0.269270182, 0.0468223095, -0.353052378], [-0.0702953338, 0.0725159645, -0.817753077]]> : tensor<3x3xf32>
# CHECK: [[VAL_16:%.*]] = constant dense<1.000000e+00> : tensor<3xf32> # CHECK: [[VAL_16:%.*]] = constant dense<{{\[\[}}0.230039358, -0.182297707, -0.352231741], [-0.805100203, -0.220300436, -0.669503212], [0.278807402, -0.201502323, -0.627609729]]> : tensor<3x3xf32>
# CHECK: [[VAL_17:%.*]] = constant dense<{{\[\[}}0.782244444, -0.0446639061, 0.848498106], [-0.579102755, -0.407756329, 0.442389727], [0.00566458702, 0.5984025, 0.629857302]]> : tensor<3x3xf32> # CHECK: [[VAL_17:%.*]] = constant dense<{{\[\[}}0.513064623, -0.692989588, 0.547988653], [0.0653710365, 0.576977491, 0.966733217], [0.0130724907, 0.247342348, 0.317092657]]> : tensor<3x3xf32>
# CHECK: [[VAL_18:%.*]] = constant dense<{{\[\[}}0.891112089, -2.786560e-01, 0.966933965], [-0.789963722, 0.057955265, 0.217499971], [-0.698129416, -0.983400583, -0.834380626]]> : tensor<3x3xf32> # CHECK: [[VAL_18:%.*]] = constant dense<{{\[\[}}-0.138204336, -0.10879755, -0.135128736], [0.94797182, -8.713360e-01, -0.792336463], [0.0339827538, -0.539326906, 8.906350e-01]]> : tensor<3x3xf32>
# CHECK: [[VAL_19:%.*]] = constant dense<{{\[\[}}-0.125753641, 0.32271719, 0.488939524], [0.36119318, 0.982266664, -0.448646784], [0.966353893, -0.767024993, 0.446366787]]> : tensor<3x3xf32> # CHECK: [[VAL_19:%.*]] = constant dense<{{\[\[}}0.444335222, -0.133341789, 0.839591503], [0.445418358, -0.571707964, 0.569707394], [0.465010405, -0.990037918, -0.632481337]]> : tensor<3x3xf32>
# CHECK: [[VAL_20:%.*]] = constant dense<{{\[\[}}-0.856678485, -0.800494194, 0.716800689], [0.536404848, 0.541643381, -0.35657692], [-0.794646739, 0.137629032, 0.690013885]]> : tensor<3x3xf32> # CHECK: [[VAL_20:%.*]] = constant dense<{{\[\[}}-0.493826151, -0.391061306, -0.349843264], [-0.0213134289, 0.558384657, -0.51513052], [0.427886248, 0.618100405, -0.187585592]]> : tensor<3x3xf32>
# CHECK: [[VAL_21:%.*]] = constant dense<0.000000e+00> : tensor<3xf32> # CHECK: [[VAL_21:%.*]] = constant dense<{{\[\[}}0.886177539, -0.606141329, -0.451275587], [0.325554609, 0.691527605, -0.676239967], [0.219799042, 0.626042128, -0.597596407]]> : tensor<3x3xf32>
# CHECK: [[VAL_22:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> # CHECK: [[VAL_22:%.*]] = constant dense<{{\[\[}}-0.400154352, 0.739109992, 0.201825857], [0.678572893, 0.32076478, 0.949867963], [-0.807729483, -5.324750e-01, 0.148033619]]> : tensor<3x3xf32>
# CHECK: [[VAL_23:%.*]] = constant unit # CHECK: [[VAL_23:%.*]] = constant unit
# CHECK: [[VAL_24:%.*]]:3 = "tfl.unpack"(%[[ARG_0]]) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) # CHECK: [[UNPACK:%.*]]:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
# CHECK: [[VAL_25:%.*]] = "tfl.pack"([[VAL_24]]#0, [[VAL_24]]#1, [[VAL_24]]#2) {axis = 0 : i32, values_count = 3 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x1x3xf32> # CHECK: [[PACK:%.*]] = "tfl.pack"([[UNPACK]]#0, [[UNPACK]]#1, [[UNPACK]]#2) {axis = 0 : i32, values_count = 3 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x1x3xf32>
# CHECK: [[VAL_24:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
# CHECK: [[UNIDIRECTIONAL_SEQUENCE_LSTM_1:%.*]] = "tfl.unidirectional_sequence_lstm"([[PACK]], [[VAL_16]], [[VAL_17]], [[VAL_18]], [[VAL_15]], [[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_19]], [[VAL_13]], [[VAL_14]], [[VAL_12]], [[VAL_2]], [[VAL_7]], [[VAL_2]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_24]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
# CHECK: [[VAL_25:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
# CHECK: [[VAL_26:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> # CHECK: [[VAL_26:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32>
# CHECK: [[VAL_27:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_25]], [[VAL_7]], [[VAL_6]], [[VAL_5]], [[VAL_8]], [[VAL_3]], [[VAL_2]], [[VAL_1]], [[VAL_4]], [[VAL_10]], [[VAL_9]], [[VAL_11]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_22]], [[VAL_26]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32> # CHECK: [[UNIDIRECTIONAL_SEQUENCE_LSTM_2:%.*]] = "tfl.unidirectional_sequence_lstm"([[UNIDIRECTIONAL_SEQUENCE_LSTM_1]], [[VAL_4]], [[VAL_5]], [[VAL_6]], [[VAL_3]], [[VAL_9]], [[VAL_10]], [[VAL_11]], [[VAL_8]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_2]], [[VAL_7]], [[VAL_2]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_25]], [[VAL_26]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, none, none, none, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
# CHECK: [[VAL_28:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> # CHECK: [[RESULT:%.*]]:3 = "tfl.unpack"([[UNIDIRECTIONAL_SEQUENCE_LSTM_2]]) {axis = 0 : i32, num = 3 : i32} : (tensor<3x1x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
# CHECK: [[VAL_29:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> # CHECK: return [[RESULT]]#2 : tensor<1x3xf32>
# CHECK: [[VAL_30:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_27]], [[VAL_19]], [[VAL_18]], [[VAL_17]], [[VAL_20]], [[VAL_14]], [[VAL_13]], [[VAL_12]], [[VAL_15]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_21]], [[VAL_16]], [[VAL_21]], [[VAL_21]], [[VAL_23]], [[VAL_23]], [[VAL_28]], [[VAL_29]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, none, none, none, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32>
# CHECK: [[VAL_31:%.*]]:3 = "tfl.unpack"([[VAL_30]]) {axis = 0 : i32, num = 3 : i32} : (tensor<3x1x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
# CHECK: return [[VAL_31]]#2 : tensor<1x3xf32>

View File

@ -28,6 +28,13 @@ func @f32() -> tensor<4xf32> {
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
func @f64() -> tensor<4xf64> {
// CHECK-LABEL: @f64
// CHECK: value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf64>
%0 = "tfl.pseudo_const"() { value = dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64> } : () -> tensor<4xf64>
return %0 : tensor<4xf64>
}
func @i8() -> tensor<4xi8> { func @i8() -> tensor<4xi8> {
// CHECK-LABEL: @i8 // CHECK-LABEL: @i8
// CHECK: value = dense<[1, 2, 3, 4]> : tensor<4xi8> // CHECK: value = dense<[1, 2, 3, 4]> : tensor<4xi8>

View File

@ -829,6 +829,14 @@ func @pack3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2x
// CHECK: "tfl.pack"(%arg0, %arg1, %arg2) {axis = 1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32> // CHECK: "tfl.pack"(%arg0, %arg1, %arg2) {axis = 1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
} }
func @packStringWithFlex(%arg0: tensor<2x!tf.string>, %arg1: tensor<2x!tf.string>) -> tensor<2x2x!tf.string> {
%0 = "tf.Pack"(%arg0, %arg1) : (tensor<2x!tf.string>, tensor<2x!tf.string>) -> tensor<2x2x!tf.string>
return %0 : tensor<2x2x!tf.string>
// CHECK-LABEL: packStringWithFlex
// CHECK: "tf.Pack"(%arg0, %arg1) : (tensor<2x!tf.string>, tensor<2x!tf.string>) -> tensor<2x2x!tf.string>
}
func @packNegAxis(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2xi32>) -> tensor<2x3xi32> { func @packNegAxis(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2xi32>) -> tensor<2x3xi32> {
%0 = "tf.Pack"(%arg0, %arg1, %arg2) {axis = -1 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32> %0 = "tf.Pack"(%arg0, %arg1, %arg2) {axis = -1 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32> return %0 : tensor<2x3xi32>

View File

@ -0,0 +1,66 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> {
^bb0(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>):
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "FlexAdd"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: FLOAT64,
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: FLOAT64,
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: FLOAT64,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "add",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: custom_options: [ 3, 65, 100, 100, 0, 20, 18, 3, 65, 100, 100, 26, 0, 26, 0, 42, 7, 10, 1, 84, 18, 2, 48, 2, 50, 0, 0, 2, 27, 23, 20, 20, 4, 40, 1 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ]
// CHECK-NEXT:}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> loc("add")
return %0 : tensor<4xf64>
}

View File

@ -138,13 +138,24 @@ int main(int argc, char **argv) {
// TODO(b/147435528): We need to test the e2e behavior once the graph freezing // TODO(b/147435528): We need to test the e2e behavior once the graph freezing
// inside mlir is done. // inside mlir is done.
if (import_saved_model_object_graph || import_saved_model_signature_defs) { if (import_saved_model_object_graph || import_saved_model_signature_defs) {
int saved_model_version;
if (import_saved_model_object_graph) {
saved_model_version = 2;
} else {
saved_model_version = 1;
}
if (input_mlir) if (input_mlir)
module = tensorflow::errors::InvalidArgument( module = tensorflow::errors::InvalidArgument(
"Importing saved model should not have input_mlir set"); "Importing saved model should not have input_mlir set");
module = tensorflow::ImportSavedModel(import_saved_model_object_graph,
import_saved_model_signature_defs, std::unordered_set<std::string> tags =
input_file_name, saved_model_tags, absl::StrSplit(saved_model_tags, ',');
saved_model_exported_names, &context); std::vector<std::string> exported_names_vector =
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
absl::Span<std::string> exported_names(exported_names_vector);
module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
tags, exported_names, &context);
} else { } else {
module = tensorflow::LoadFromGraphdefOrMlirSource( module = tensorflow::LoadFromGraphdefOrMlirSource(
input_file_name, input_mlir, use_splatted_constant, custom_opdefs, input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
@ -197,11 +208,6 @@ int main(int argc, char **argv) {
pass_config.lower_tensor_list_ops = lower_tensor_list_ops; pass_config.lower_tensor_list_ops = lower_tensor_list_ops;
pass_config.legalize_tf_while = convert_tf_while_to_tfl_while; pass_config.legalize_tf_while = convert_tf_while_to_tfl_while;
// Currently we only do shape inference for saved model import.
if (import_saved_model_object_graph || import_saved_model_signature_defs) {
pass_config.shape_inference = true;
}
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm); tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
// TODO(b/150901738): Move those into tf_tfl_translate.cc. // TODO(b/150901738): Move those into tf_tfl_translate.cc.
// Convert back to outlined while format for export back to flatbuffer. // Convert back to outlined while format for export back to flatbuffer.
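A rough, self-contained sketch of how the new call site assembles its arguments from the comma-separated flags. The flag values shown here ("serve,gpu" and an empty exported-names string) are placeholders, and the actual ImportSavedModel call is left commented out because it needs an MLIR context and a real SavedModel on disk.

#include <string>
#include <unordered_set>
#include <vector>

#include "absl/strings/str_split.h"
#include "absl/types/span.h"

int main() {
  std::string saved_model_tags = "serve,gpu";          // placeholder flag value
  std::string saved_model_exported_names = "";         // empty -> export everything

  // 2 selects the object-graph (v2) importer, 1 the signature-defs (v1) importer.
  int saved_model_version = 2;

  std::unordered_set<std::string> tags =
      absl::StrSplit(saved_model_tags, ',');
  std::vector<std::string> exported_names_vector =
      absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
  absl::Span<std::string> exported_names(exported_names_vector);

  // tensorflow::ImportSavedModel(input_file_name, saved_model_version,
  //                              tags, exported_names, &context);
  (void)saved_model_version;
  return 0;
}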

View File

@ -160,25 +160,17 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
} }
StatusOr<mlir::OwningModuleRef> ImportSavedModel( StatusOr<mlir::OwningModuleRef> ImportSavedModel(
bool import_saved_model, bool import_saved_model_v1, const std::string& input_filename, const int saved_model_version,
const std::string& input_filename, const std::string& saved_model_tags, const std::unordered_set<std::string>& tags,
const std::string& saved_model_exported_names, mlir::MLIRContext* context) { absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
if (import_saved_model) { if (saved_model_version == 2) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
std::vector<std::string> exported_names =
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
auto module = tensorflow::SavedModelObjectGraphToMlirImport( auto module = tensorflow::SavedModelObjectGraphToMlirImport(
input_filename, tags, absl::Span<std::string>(exported_names), context); input_filename, tags, exported_names, context);
if (!module) if (!module)
return tensorflow::errors::InvalidArgument("fail to open input file"); return tensorflow::errors::InvalidArgument("fail to open input file");
return module; return module;
} else if (import_saved_model_v1) { } else if (saved_model_version == 1) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
auto module = tensorflow::SavedModelSignatureDefsToMlirImport( auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
input_filename, tags, context); input_filename, tags, context);

View File

@ -16,6 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ #ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
#include <unordered_set>
#include "absl/types/span.h"
#include "llvm/Support/SourceMgr.h" #include "llvm/Support/SourceMgr.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project
@ -42,9 +45,9 @@ LoadFromGraphdefOrMlirSource(
// Load Saved model (either v1 or v2) into MLIR. // Load Saved model (either v1 or v2) into MLIR.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel( stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
bool import_saved_model, bool import_saved_model_v1, const std::string& input_filename, const int saved_model_version,
const std::string& input_filename, const std::string& saved_model_tags, const std::unordered_set<std::string>& tags,
const std::string& saved_model_exported_names, mlir::MLIRContext* context); absl::Span<std::string> exported_names, mlir::MLIRContext* context);
// Taking a MLIR module in TF executor dialect and a set of parameters, // Taking a MLIR module in TF executor dialect and a set of parameters,
// applies a set of passes to convert the module to TF Lite dialect and // applies a set of passes to convert the module to TF Lite dialect and

View File

@ -152,7 +152,6 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
} }
// BatchToSpaceND + BiasAdd. // BatchToSpaceND + BiasAdd.
// TODO(b/149936532): Check the `crops` input, currently ignored.
TF::BatchToSpaceNDOp bts_op; TF::BatchToSpaceNDOp bts_op;
TF::BiasAddOp biasadd_op; TF::BiasAddOp biasadd_op;
bool final_op_is_bts = true; bool final_op_is_bts = true;
@ -179,16 +178,50 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
if (!dilations_attr.hasValue()) return failure(); if (!dilations_attr.hasValue()) return failure();
op.setAttr("dilations", dilations_attr.getValue()); op.setAttr("dilations", dilations_attr.getValue());
// Padding is set to 'SAME' when `stb_op` has non-zero paddings. // TODO(b/149936532): Check that the input width & height are multiples of
// TODO(b/149936532): This assumption only holds when the input width & height // dilation rate.
// is multiple of dilation width & height. We should fix it in order to // TF python library will rewrite dilated conv to
// support other use cases. // "SpaceToBatch->Conv->BatchToSpace" pattern, and the Conv in the middle
// always has 'VALID' padding. The padding tensor in `SpaceToBatch` has two
// parts of contributions, one is to reduce padding of CONV from 'SAME' to
// 'VALID', and another is to make input shape multiples of dilation rate. The
// first part of padding, which is also called `base_padding` will be used
// here to determine if the original padding format is 'SAME' or 'VALID'.
// According to the following formula we will compute the `base_padding` if
// it's a constant. Basically, `paddings` tensor in `SpaceToBatch` and `crops`
// tensor in `BatchToSpace` must satisfy the following:
// paddings[i, 0] = base_paddings[i, 0].
// 0 <= paddings[i, 1] - base_paddings[i, 1] < block_shape[i]
// (input_shape[i] + paddings[i, 0] + paddings[i, 1]) % block_shape[i] == 0.
// crops[i, 0] = 0.
// crops[i, 1] = paddings[i, 1] - base_paddings[i, 1].
// If `paddings` - `crops` != 0, this means that `base_paddings` != 0, which
// tells us the original padding is 'SAME' (with one caveat presented below).
// Here we need to reset the padding back to `SAME` if `base_padding`
// != 0.
// TODO(b/149936532): We might not simply rely on `paddings - crops != 0` to
// determine the original padding format. For example, users can build
// arbitrary valid examples of `STB->Conv->BTS` which doesn't represent a
// dilated conv, hence we shouldn't pattern match here. Instead, we need to
// check values of `paddings` and `crops` to make sure it really stands for
// a dilated conv.
auto stb_paddings = stb_op.paddings(); auto stb_paddings = stb_op.paddings();
ElementsAttr stb_paddings_attr; auto bts_crops = bts_op.crops();
if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr))) { ElementsAttr stb_paddings_attr, bts_crops_attr;
if (llvm::any_of(stb_paddings_attr.getValues<IntegerAttr>(), if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) &&
[](IntegerAttr attr) { return attr.getInt() != 0; })) { matchPattern(bts_crops, m_Constant(&bts_crops_attr))) {
op.setAttr("padding", rewriter.getStringAttr("SAME")); if (stb_paddings_attr.getNumElements() != bts_crops_attr.getNumElements())
return failure();
// padding - crop.
auto paddings = stb_paddings_attr.getValues<IntegerAttr>();
auto crops = bts_crops_attr.getValues<IntegerAttr>();
for (auto it1 = paddings.begin(), it2 = crops.begin();
it1 != paddings.end() && it2 != crops.end(); it1++, it2++) {
if ((*it1).getInt() != (*it2).getInt()) {
op.setAttr("padding", rewriter.getStringAttr("SAME"));
break;
}
} }
} }
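A small worked example of the paddings-versus-crops test described in the comment above, using the values from the tests earlier in this commit (paddings and crops both dense<2>, i.e. [[2, 2], [2, 2]]); the alternative crops value is made up to show the 'SAME' case.

#include <cstdio>

// Returns true if any element of `paddings - crops` is nonzero, i.e. the
// base padding is nonzero and the original convolution used 'SAME' padding.
bool HasNonZeroBasePadding(const int paddings[2][2], const int crops[2][2]) {
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      if (paddings[i][j] != crops[i][j]) return true;
  return false;
}

int main() {
  const int paddings[2][2] = {{2, 2}, {2, 2}};
  // As in the dilated-conv tests above: crops equal to paddings, so the base
  // padding is zero and the rewritten conv keeps padding = "VALID".
  const int valid_crops[2][2] = {{2, 2}, {2, 2}};
  // Made-up alternative: zero crops leave base_paddings = [[2, 2], [2, 2]],
  // so the pass would reset the conv to padding = "SAME".
  const int same_crops[2][2] = {{0, 0}, {0, 0}};
  std::printf("crops == paddings : %s\n",
              HasNonZeroBasePadding(paddings, valid_crops) ? "SAME" : "VALID");
  std::printf("crops == 0        : %s\n",
              HasNonZeroBasePadding(paddings, same_crops) ? "SAME" : "VALID");
  return 0;
}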

View File

@ -679,8 +679,8 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
return success(); return success();
} }
struct ExtractOphintPass : public ModulePass<ExtractOphintPass> { struct ExtractOphintPass : public OperationPass<ExtractOphintPass, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
void Verify(); void Verify();
private: private:
@ -689,8 +689,8 @@ struct ExtractOphintPass : public ModulePass<ExtractOphintPass> {
// TODO(renjieliu): Current ophint extraction does not support inputs/outputs // TODO(renjieliu): Current ophint extraction does not support inputs/outputs
// cross functions, we need to do that. // cross functions, we need to do that.
void ExtractOphintPass::runOnModule() { void ExtractOphintPass::runOnOperation() {
ModuleOp module = getModule(); ModuleOp module = getOperation();
for (auto function : module.getOps<FuncOp>()) { for (auto function : module.getOps<FuncOp>()) {
// Process block by block. // Process block by block.
for (auto& bb : function.getBody()) { for (auto& bb : function.getBody()) {
@ -710,7 +710,7 @@ void ExtractOphintPass::runOnModule() {
ophint_composite_ops_count = ophint_composite_ops.size(); ophint_composite_ops_count = ophint_composite_ops.size();
// Convert. // Convert.
OpBuilder builder(&bb); OpBuilder builder = OpBuilder::atBlockEnd(&bb);
for (const auto& kv : ophint_composite_ops) { for (const auto& kv : ophint_composite_ops) {
if (failed(ConvertOphintToStub(kv.getKey(), kv.getValue(), &builder, if (failed(ConvertOphintToStub(kv.getKey(), kv.getValue(), &builder,
&module))) { &module))) {
@ -724,9 +724,9 @@ void ExtractOphintPass::runOnModule() {
} }
void ExtractOphintPass::Verify() { void ExtractOphintPass::Verify() {
ModuleOp module = getModule(); ModuleOp module = getOperation();
int ophint_func_op_count = 0; int ophint_func_op_count = 0;
for (FuncOp func : getModule().getOps<FuncOp>()) { for (FuncOp func : getOperation().getOps<FuncOp>()) {
for (const NamedAttribute attr : func.getAttrs()) { for (const NamedAttribute attr : func.getAttrs()) {
if (attr.first == kTfLiteFunctionName) { if (attr.first == kTfLiteFunctionName) {
ophint_func_op_count++; ophint_func_op_count++;
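The same pass-API migration recurs in the files below. As a rough sketch of the pattern, assuming a stand-alone skeleton is enough to show it: the old ModulePass<T> / runOnModule() / getModule() trio becomes the following, with ExamplePass as a made-up name used only for illustration.

#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project

namespace {

// Module-level pass written against the updated API: OperationPass anchored
// on ModuleOp, with runOnOperation()/getOperation() replacing the old
// runOnModule()/getModule() pair.
struct ExamplePass : public mlir::OperationPass<ExamplePass, mlir::ModuleOp> {
  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();  // was getModule()
    // ... walk `module` and rewrite it, as the passes below do ...
    (void)module;
  }
};

}  // namespace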

View File

@ -68,8 +68,9 @@ constexpr char kUnidirectionalSequenceLstm[] = "UnidirectionalSequenceLstm";
// | // |
// | // |
// OutputOp1 // OutputOp1
struct LegalizeOphintFuncOpPass : public ModulePass<LegalizeOphintFuncOpPass> { struct LegalizeOphintFuncOpPass
void runOnModule() override; : public OperationPass<LegalizeOphintFuncOpPass, ModuleOp> {
void runOnOperation() override;
}; };
llvm::StringMap<FuncOp> FindCompositeFuncOps(ModuleOp module) { llvm::StringMap<FuncOp> FindCompositeFuncOps(ModuleOp module) {
@ -256,8 +257,8 @@ LogicalResult ConvertCallOps(llvm::StringMap<FuncOp>* composite_func_ops,
return success(); return success();
} }
void LegalizeOphintFuncOpPass::runOnModule() { void LegalizeOphintFuncOpPass::runOnOperation() {
ModuleOp module = getModule(); ModuleOp module = getOperation();
// Find all composite funcs, then for every call op inside every func op // Find all composite funcs, then for every call op inside every func op
// within the module, we go ahead and replace the callop with the tflite // within the module, we go ahead and replace the callop with the tflite

View File

@ -745,7 +745,8 @@ void LegalizeTF::runOnFunction() {
Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) { Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op); auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
if (!tfl_op) return false; if (!tfl_op) return false;
return succeeded(tfl_op.VerifyTflRuntimeTypes(tfl_op.getOperation())); return succeeded(tfl_op.VerifyTflRuntimeTypes(
tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
})); }));
// Keep trying to convert. // Keep trying to convert.
// TODO(karimnosseir): This is similar to what apply greedy patterns does. // TODO(karimnosseir): This is similar to what apply greedy patterns does.

View File

@ -31,11 +31,11 @@ namespace {
// Legalize TF While to TFL While with calls to the original functions from the // Legalize TF While to TFL While with calls to the original functions from the
// cond and body regions. // cond and body regions.
struct LegalizeWhile : public ModulePass<LegalizeWhile> { struct LegalizeWhile : public OperationPass<LegalizeWhile, ModuleOp> {
void RunOnFunction(FuncOp func); void RunOnFunction(FuncOp func);
void runOnModule() override { void runOnOperation() override {
for (auto op : getModule().getOps<FuncOp>()) RunOnFunction(op); for (auto op : getOperation().getOps<FuncOp>()) RunOnFunction(op);
} }
}; };

View File

@ -82,8 +82,8 @@ class TensorListPatternRewriter : public PatternRewriter {
/// Lower TensorList ops in functions for subsequent legalization. /// Lower TensorList ops in functions for subsequent legalization.
struct LowerStaticTensorListPass struct LowerStaticTensorListPass
: public ModulePass<LowerStaticTensorListPass> { : public OperationPass<LowerStaticTensorListPass, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
// Apply type and op changes within a function. // Apply type and op changes within a function.
LogicalResult RewriteFunction(FuncOp func, LogicalResult RewriteFunction(FuncOp func,
@ -878,14 +878,14 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
return applyFullConversion(func, target, patterns); return applyFullConversion(func, target, patterns);
} }
void LowerStaticTensorListPass::runOnModule() { void LowerStaticTensorListPass::runOnOperation() {
// TODO(haoliang): currently we process the `main` function first, and the // TODO(haoliang): currently we process the `main` function first, and the
// remaining functions may be processed in arbitrary order. However, this will // remaining functions may be processed in arbitrary order. However, this will
// have a potential issue when one function taking a `DT_VARIANT` is processed // have a potential issue when one function taking a `DT_VARIANT` is processed
// before the function that produces the `DT_VARIANT`. We need to carefully // before the function that produces the `DT_VARIANT`. We need to carefully
// order the functions to be processed. // order the functions to be processed.
std::vector<FuncOp> funcs_in_module; std::vector<FuncOp> funcs_in_module;
for (auto func : getModule().getOps<FuncOp>()) { for (auto func : getOperation().getOps<FuncOp>()) {
// Always place the main function to be the first in the list. // Always place the main function to be the first in the list.
if (func.getName() == "main") { if (func.getName() == "main") {
funcs_in_module.insert(funcs_in_module.begin(), func); funcs_in_module.insert(funcs_in_module.begin(), func);

View File

@ -36,8 +36,8 @@ using FuncSet = llvm::SmallSet<FuncOp, 4>;
// Module pass to optimize TensorFlow functional ops. // Module pass to optimize TensorFlow functional ops.
struct OptimizeFunctionalOpsPass struct OptimizeFunctionalOpsPass
: public ModulePass<OptimizeFunctionalOpsPass> { : public OperationPass<OptimizeFunctionalOpsPass, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
}; };
// Updates function return type of the given functions to match the terminator // Updates function return type of the given functions to match the terminator
@ -180,13 +180,13 @@ static void EraseDeadFuncs(const FuncSet& candidate_funcs, ModuleOp module) {
} }
} }
void OptimizeFunctionalOpsPass::runOnModule() { void OptimizeFunctionalOpsPass::runOnOperation() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
FuncSet inlined_funcs; FuncSet inlined_funcs;
patterns.insert<FoldIfOp>(&getContext(), &inlined_funcs); patterns.insert<FoldIfOp>(&getContext(), &inlined_funcs);
ModuleOp module = getModule(); ModuleOp module = getOperation();
applyPatternsGreedily(module, patterns); applyPatternsGreedily(module, patterns);
// Erase inlined functions that don't have any references. // Erase inlined functions that don't have any references.

View File

@ -94,14 +94,14 @@ class ConvertEmbeddedLookupFunc {
// body with the corresponding fused TFLite op. The replacement need not always // body with the corresponding fused TFLite op. The replacement need not always
// be a fused op, though that is the primary use case. // be a fused op, though that is the primary use case.
class PrepareCompositeFunctionsPass class PrepareCompositeFunctionsPass
: public ModulePass<PrepareCompositeFunctionsPass> { : public OperationPass<PrepareCompositeFunctionsPass, ModuleOp> {
public: public:
explicit PrepareCompositeFunctionsPass() {} explicit PrepareCompositeFunctionsPass() {}
private: private:
void ConvertTFImplements(FuncOp func, StringAttr attr); void ConvertTFImplements(FuncOp func, StringAttr attr);
void ConvertTFAPIImplements(FuncOp func, StringAttr attr, ModuleOp module); void ConvertTFAPIImplements(FuncOp func, StringAttr attr, ModuleOp module);
void runOnModule() override; void runOnOperation() override;
}; };
void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func, void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
@ -189,8 +189,8 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
} }
} }
void PrepareCompositeFunctionsPass::runOnModule() { void PrepareCompositeFunctionsPass::runOnOperation() {
auto module = getModule(); auto module = getOperation();
for (auto func : module.getOps<FuncOp>()) { for (auto func : module.getOps<FuncOp>()) {
// We have two kinds of implements: // We have two kinds of implements:
// 1) tf._implements. // 1) tf._implements.

View File

@ -34,7 +34,9 @@ class RuntimeTypeVerifyPass : public mlir::FunctionPass<RuntimeTypeVerifyPass> {
void RuntimeTypeVerifyPass::runOnFunction() { void RuntimeTypeVerifyPass::runOnFunction() {
getFunction().walk([&](TflRuntimeVerifyOpInterface op) { getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
if (failed(op.VerifyTflRuntimeTypes(op.getOperation()))) if (failed(op.VerifyTflRuntimeTypes(
op.getOperation(),
/*failure_on_operand_type_mismatch=*/true)))
signalPassFailure(); signalPassFailure();
}); });
} }

View File

@ -44,21 +44,22 @@ namespace {
// The pass to trim functions before we legalize to TFL // The pass to trim functions before we legalize to TFL
// dialect using the specified whitelist. // dialect using the specified whitelist.
class TrimFunctionsPass : public mlir::ModulePass<TrimFunctionsPass> { class TrimFunctionsPass
: public mlir::OperationPass<TrimFunctionsPass, ModuleOp> {
public: public:
explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {} explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {}
explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_whitelist) explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_whitelist)
: trim_funcs_whitelist_(trim_funcs_whitelist) {} : trim_funcs_whitelist_(trim_funcs_whitelist) {}
private: private:
void runOnModule() override; void runOnOperation() override;
bool TrimModule(); bool TrimModule();
void Verify(); void Verify();
llvm::ArrayRef<std::string> trim_funcs_whitelist_; llvm::ArrayRef<std::string> trim_funcs_whitelist_;
}; };
void TrimFunctionsPass::runOnModule() { void TrimFunctionsPass::runOnOperation() {
// trim the functions in the module using the trim_funcs_whitelist_ // trim the functions in the module using the trim_funcs_whitelist_
// by removing functions not in the whitelist. // by removing functions not in the whitelist.
if (TrimModule()) { if (TrimModule()) {
@ -73,7 +74,7 @@ bool TrimFunctionsPass::TrimModule() {
if (trim_funcs_whitelist_.empty()) return false; if (trim_funcs_whitelist_.empty()) return false;
llvm::SmallVector<FuncOp, 4> funcs_to_trim; llvm::SmallVector<FuncOp, 4> funcs_to_trim;
for (auto func : getModule().getOps<FuncOp>()) { for (auto func : getOperation().getOps<FuncOp>()) {
if (llvm::is_contained(trim_funcs_whitelist_, func.getName())) { if (llvm::is_contained(trim_funcs_whitelist_, func.getName())) {
// If no main is specified in the whitelist, use the 1st func // If no main is specified in the whitelist, use the 1st func
// in trim_funcs_whitelist as the main. // in trim_funcs_whitelist as the main.
@ -102,12 +103,12 @@ bool TrimFunctionsPass::TrimModule() {
void TrimFunctionsPass::Verify() { void TrimFunctionsPass::Verify() {
// TODO(ashwinm): Instead, we should make sure that references to all // TODO(ashwinm): Instead, we should make sure that references to all
// SymbolRefAttrs of all ops are present. // SymbolRefAttrs of all ops are present.
SymbolTable symbol_table = SymbolTable(getModule()); SymbolTable symbol_table = SymbolTable(getOperation());
llvm::SetVector<FuncOp> reachable_funcs; llvm::SetVector<FuncOp> reachable_funcs;
for (auto func : getModule().getOps<FuncOp>()) { for (auto func : getOperation().getOps<FuncOp>()) {
auto walk_result = func.walk([&](CallOp op) -> WalkResult { auto walk_result = func.walk([&](CallOp op) -> WalkResult {
if (!symbol_table.lookup<FuncOp>(op.getCallee())) if (!symbol_table.lookup<FuncOp>(op.getCallee()))
return getModule().emitError() return getOperation().emitError()
<< func.getName() << " is not in the funcs whitelist"; << func.getName() << " is not in the funcs whitelist";
return WalkResult::advance(); return WalkResult::advance();
}); });

View File

@ -37,12 +37,13 @@ namespace {
// This pass outlines the cond/body region of the TFL WhileOp into functions and // This pass outlines the cond/body region of the TFL WhileOp into functions and
// replaces the regions with calls to these outlined functions. // replaces the regions with calls to these outlined functions.
class WhileOutlinePass : public mlir::ModulePass<WhileOutlinePass> { class WhileOutlinePass
: public mlir::OperationPass<WhileOutlinePass, ModuleOp> {
public: public:
explicit WhileOutlinePass() {} explicit WhileOutlinePass() {}
private: private:
void runOnModule() override; void runOnOperation() override;
// Outlines the regions of the WhileOp's cond and body and insert function // Outlines the regions of the WhileOp's cond and body and insert function
// calls instead, // calls instead,
@ -130,7 +131,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
// Create outline function from region. Optional pass extra arguments through // Create outline function from region. Optional pass extra arguments through
// to yield. // to yield.
SymbolTable symbol_table(getModule()); SymbolTable symbol_table(getOperation());
auto create_outline_func = [&](StringRef name, Region& region, auto create_outline_func = [&](StringRef name, Region& region,
bool passthru_extra_args) { bool passthru_extra_args) {
FunctionType type; FunctionType type;
@ -234,8 +235,8 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
op->erase(); op->erase();
} }
void WhileOutlinePass::runOnModule() { void WhileOutlinePass::runOnOperation() {
getModule().walk( getOperation().walk(
[&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); }); [&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
} }

View File

@ -32,10 +32,12 @@ namespace errors = tensorflow::errors;
mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) { mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
switch (type) { switch (type) {
case tflite::TensorType_FLOAT32:
return builder.getF32Type();
case tflite::TensorType_FLOAT16: case tflite::TensorType_FLOAT16:
return builder.getF16Type(); return builder.getF16Type();
case tflite::TensorType_FLOAT32:
return builder.getF32Type();
case tflite::TensorType_FLOAT64:
return builder.getF64Type();
case tflite::TensorType_INT32: case tflite::TensorType_INT32:
return builder.getIntegerType(32); return builder.getIntegerType(32);
case tflite::TensorType_UINT8: case tflite::TensorType_UINT8:
@ -65,6 +67,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
return tensorflow::DT_HALF; return tensorflow::DT_HALF;
case tflite::TensorType_FLOAT32: case tflite::TensorType_FLOAT32:
return tensorflow::DT_FLOAT; return tensorflow::DT_FLOAT;
case tflite::TensorType_FLOAT64:
return tensorflow::DT_DOUBLE;
case tflite::TensorType_INT8: case tflite::TensorType_INT8:
return tensorflow::DT_INT8; return tensorflow::DT_INT8;
case tflite::TensorType_INT16: case tflite::TensorType_INT16:

View File

@ -545,13 +545,44 @@ LogicalResult Verify(SwitchNOp switchn) {
<< "expect `num_outs` (" << num_outs.getInt() << ") results but got " << "expect `num_outs` (" << num_outs.getInt() << ") results but got "
<< (switchn.getNumResults() - 1); << (switchn.getNumResults() - 1);
// Check that operand can be broadcasted to each output type.
auto operand0_type = switchn.getOperand(0).getType(); auto operand0_type = switchn.getOperand(0).getType();
for (Value result : switchn.outputs()) TensorType operand0_tensor_type = operand0_type.dyn_cast<TensorType>();
if (operand0_type != result.getType()) if (!operand0_tensor_type) {
return switchn.emitOpError() return switchn.emitOpError()
<< "type mismatch between data operand and result: " << "expects data operand to have tensor type but got "
<< operand0_type << " vs " << result.getType(); << operand0_type;
}
for (Type output_type : switchn.getResultTypes()) {
if (output_type.isa<ControlType>()) break;
TensorType output_tensor_type = output_type.dyn_cast<TensorType>();
if (!output_tensor_type) {
return switchn.emitOpError()
<< "expects outputs to have tensor type but got " << output_type;
}
// If the output type is a ref type, then the operand type should also be of
// the same ref type. However, if the output type is a non-ref type T, then
// the operand can be tensor of type T or T_REF.
bool is_output_ref =
output_tensor_type.getElementType().isa<TF::TensorFlowRefType>();
if (is_output_ref &&
!operand0_tensor_type.getElementType().isa<TF::TensorFlowRefType>()) {
return switchn.emitOpError()
<< "expects same operand and output element type but got "
<< operand0_tensor_type << " vs " << output_tensor_type;
}
Type broadcasted_type = OpTrait::util::getBroadcastedType(
DropRefType(DropTypeSubTypes(operand0_tensor_type)),
DropRefType(DropTypeSubTypes(output_tensor_type)));
if (!broadcasted_type) {
return switchn.emitOpError()
<< "expects data operand to be broadcastable with all output types"
<< " but got " << operand0_tensor_type << " vs "
<< output_tensor_type;
}
}
return success(); return success();
} }
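The rewritten verifier above relaxes the old exact-type check: after dropping ref types and subtypes, the data operand only has to be broadcast-compatible with each non-control output, and element types must still agree (ref outputs additionally require a ref operand). As a rough standalone model of the shape half of that compatibility test (an illustration, not the MLIR OpTrait::util helper; -1 marks a dynamic dimension):

// Simplified standalone model of the broadcast-compatibility check the
// verifier above delegates to OpTrait::util::getBroadcastedType. Shapes only;
// the real check also requires matching element types. -1 marks a dynamic
// dimension and an empty vector stands in for an unranked tensor.
#include <cstdio>
#include <vector>

bool BroadcastCompatible(const std::vector<long>& a,
                         const std::vector<long>& b) {
  if (a.empty() || b.empty()) return true;  // treat unranked as compatible
  auto ai = a.rbegin(), bi = b.rbegin();
  for (; ai != a.rend() && bi != b.rend(); ++ai, ++bi) {
    // Trailing dims must be equal, be 1, or be dynamic on either side.
    if (*ai != *bi && *ai != 1 && *bi != 1 && *ai != -1 && *bi != -1)
      return false;
  }
  return true;
}

int main() {
  // tensor<*xf32> vs tensor<4xf32>: compatible (unranked operand).
  std::printf("%d\n", BroadcastCompatible({}, {4}));                    // 1
  // tensor<?x?x?xf32> vs tensor<32x?x4xf32>: compatible (dynamic dims).
  std::printf("%d\n", BroadcastCompatible({-1, -1, -1}, {32, -1, 4}));  // 1
  // tensor<2x3xf32> vs tensor<4x3xf32>: not compatible (2 vs 4).
  std::printf("%d\n", BroadcastCompatible({2, 3}, {4, 3}));             // 0
  return 0;
}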

View File

@ -5301,6 +5301,8 @@ tf.pow(x, y) ==> [[256, 65536], [9, 27]]
); );
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasFolder = 1;
} }
def TF_PreventGradientOp : TF_Op<"PreventGradient", [NoSideEffect, SameOperandsAndResultType]> { def TF_PreventGradientOp : TF_Op<"PreventGradient", [NoSideEffect, SameOperandsAndResultType]> {
@ -6006,6 +6008,30 @@ Resize `images` to `size` using nearest neighbor interpolation.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
def TF_ResourceApplyAdagradV2Op : TF_Op<"ResourceApplyAdagradV2", []> {
let summary = "Update '*var' according to the adagrad scheme.";
let description = [{
accum += grad * grad
var -= lr * grad * (1 / (sqrt(accum) + epsilon))
}];
let arguments = (ins
TF_ResourceTensor:$var,
TF_ResourceTensor:$accum,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,
DefaultValuedAttr<BoolAttr, "false">:$use_locking,
DefaultValuedAttr<BoolAttr, "true">:$update_slots
);
let results = (outs);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
}
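The update in the description is plain elementwise arithmetic, which the pattern added later in this change expands into primitive TF ops. A scalar sketch with made-up values, just to pin down the formula:

// Scalar illustration of the AdagradV2 update described above:
//   accum += grad * grad
//   var   -= lr * grad * (1 / (sqrt(accum) + epsilon))
// Values are hypothetical; this is not TensorFlow code.
#include <cmath>
#include <cstdio>

int main() {
  double var = 1.0, accum = 0.1;  // current variable and accumulator
  const double lr = 0.01, epsilon = 1e-7, grad = 0.5;

  accum += grad * grad;  // 0.1 + 0.25 = 0.35
  var -= lr * grad / (std::sqrt(accum) + epsilon);

  std::printf("var=%.6f accum=%.6f\n", var, accum);  // var is roughly 0.991548
  return 0;
}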
def TF_ResourceApplyAdamOp : TF_Op<"ResourceApplyAdam", []> { def TF_ResourceApplyAdamOp : TF_Op<"ResourceApplyAdam", []> {
let summary = "Update '*var' according to the Adam algorithm."; let summary = "Update '*var' according to the Adam algorithm.";
@ -7711,6 +7737,28 @@ shape of `StridedSlice`'s `input`.
}]; }];
} }
def TF_StringFormatOp : TF_Op<"StringFormat", [NoSideEffect]> {
let summary = "Formats a string template using a list of tensors.";
let description = [{
Formats a string template using a list of tensors, pretty-printing tensor summaries.
}];
let arguments = (ins
Variadic<TF_Tensor>:$inputs,
DefaultValuedAttr<StrAttr, "%s">:$strtemplate,
DefaultValuedAttr<StrAttr, "%s">:$placeholder,
DefaultValuedAttr<I64Attr, "3">:$summarize
);
let results = (outs
TF_StrTensor:$output
);
TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
}
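The summary above is terse: the op substitutes each occurrence of `placeholder` in `strtemplate` with a pretty-printed summary of the corresponding input, where `summarize` bounds how many leading and trailing entries of each dimension are shown (the usual tf.Print-style convention; treat the details here as an assumption rather than a spec). A standalone illustration of the template/placeholder mechanics only, not TensorFlow code:

// Illustration of template/placeholder substitution in the spirit of
// tf.StringFormat. Not the TensorFlow kernel; the tensor "summary" string is
// hand-written to show what summarize = 3 conventionally looks like.
#include <cstdio>
#include <string>
#include <vector>

std::string FormatTemplate(std::string tmpl, const std::string& placeholder,
                           const std::vector<std::string>& summaries) {
  std::size_t pos = 0;
  for (const std::string& summary : summaries) {
    pos = tmpl.find(placeholder, pos);
    if (pos == std::string::npos) break;
    tmpl.replace(pos, placeholder.size(), summary);
    pos += summary.size();
  }
  return tmpl;
}

int main() {
  // A 1-D tensor with 10 elements, summarized to its first/last 3 entries.
  std::string summary = "[1 2 3 ... 8 9 10]";
  std::printf("%s\n", FormatTemplate("tensor is: %s", "%s", {summary}).c_str());
  // Prints: tensor is: [1 2 3 ... 8 9 10]
  return 0;
}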
def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape]>, def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape]>,
WithBroadcastableBinOpBuilder { WithBroadcastableBinOpBuilder {
let summary = "Returns x - y element-wise."; let summary = "Returns x - y element-wise.";

View File

@ -2153,6 +2153,27 @@ static LogicalResult VerifyPartitionedCall(OpClass op) {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// PowOp
//===----------------------------------------------------------------------===//
OpFoldResult PowOp::fold(ArrayRef<Attribute> operands) {
auto constant_y = operands[1].dyn_cast_or_null<DenseFPElementsAttr>();
if (constant_y && constant_y.isSplat()) {
APFloat y_value = constant_y.getSplatValue<APFloat>();
auto output_type = getType().cast<ShapedType>();
if (y_value.isZero() && output_type.hasStaticShape()) {
return DenseElementsAttr::get(
output_type,
FloatAttr::get(output_type.getElementType(), /*value=*/1.0));
}
if (y_value.isExactlyValue(1.0)) {
return x();
}
}
return {};
}
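The folder above relies on two identities: x^0 == 1 (handled only for a splat constant exponent and a statically shaped result, so a dense all-ones constant can be materialized) and x^1 == x, which simply forwards the first operand. A standalone check of that arithmetic, separate from the MLIR code itself:

// Standalone check of the identities used by the PowOp folder above:
//   pow(x, 0) == 1 for any finite x (including x == 0 in IEEE semantics),
//   pow(x, 1) == x.
// Not TensorFlow code; values are arbitrary.
#include <cassert>
#include <cmath>

int main() {
  const double xs[] = {-2.5, 0.0, 1.0, 3.75};
  for (double x : xs) {
    assert(std::pow(x, 0.0) == 1.0);  // y == 0 folds to an all-ones constant
    assert(std::pow(x, 1.0) == x);    // y == 1 folds to the first operand
  }
  return 0;
}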
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReciprocalOp // ReciprocalOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -18,6 +18,26 @@ func @testShape(tensor<f32>, tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<0
return %0, %1, %2 : tensor<0xi32>, tensor<?xi32>, tensor<?xi32> return %0, %1, %2 : tensor<0xi32>, tensor<?xi32>, tensor<?xi32>
} }
// CHECK-LABEL: func @testPow
// CHECK-SAME:(%[[ARG_0:.*]]: tensor<4xf32>, %[[ARG_1:.*]]: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
func @testPow(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
%cst_zero = constant dense<0.0> : tensor<f32>
%cst_one = constant dense<1.0> : tensor<f32>
// CHECK-DAG: %[[RES_NO_FOLD:.*]] = "tf.Pow"(%arg0, %arg1)
%0 = "tf.Pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[POW_ZERO:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
%1 = "tf.Pow"(%arg0, %cst_zero) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-NOT: "tf.Pow"
%2 = "tf.Pow"(%arg0, %cst_one) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK: return %[[RES_NO_FOLD]], %[[POW_ZERO]], %[[ARG_0]]
return %0, %1, %2 : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>
}
// CHECK-LABEL: func @testShapeN // CHECK-LABEL: func @testShapeN
func @testShapeN(%arg0: tensor<f32>, %arg1: tensor<1x32x32x16xf32>, %arg2: tensor<*xf32>) -> (tensor<0xi64>, tensor<4xi64>, tensor<4xi64>, tensor<?xi64>) { func @testShapeN(%arg0: tensor<f32>, %arg1: tensor<1x32x32x16xf32>, %arg2: tensor<*xf32>) -> (tensor<0xi64>, tensor<4xi64>, tensor<4xi64>, tensor<?xi64>) {

View File

@ -147,6 +147,37 @@ func @decompose_resource_apply_keras_momentum_nesterov(%arg0: tensor<f32>, %arg1
// ----- // -----
// Tests that composite tf.ResourceApplyAdagradV2 operation is decomposed.
// CHECK-LABEL: func @decompose_resource_apply_adagradv2
// CHECK-SAME: ([[LR:%.*]]: tensor<f32>, [[EPSILON:%.*]]: tensor<f32>, [[GRAD:%.*]]: tensor<f32>)
func @decompose_resource_apply_adagradv2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> () {
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"()
// CHECK: [[ACC_HANDLE:%.*]] = "tf.VarHandleOp"()
// CHECK: [[GRAD_SQUARE:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: [[OLD_ACC:%.*]] = "tf.ReadVariableOp"([[ACC_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: [[NEW_ACC:%.*]] = "tf.AddV2"([[OLD_ACC]], [[GRAD_SQUARE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: [[LR_MULTIPLY:%.*]] = "tf.Mul"([[LR]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: [[SQRT:%.*]] = "tf.Sqrt"([[NEW_ACC]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[DIVISOR:%.*]] = "tf.AddV2"([[SQRT]], [[EPSILON]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: [[VAR_DELTA:%.*]] = "tf.Div"([[LR_MULTIPLY]], [[DIVISOR]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: [[NEW_VAR:%.*]] = "tf.Sub"([[OLD_VAR]], [[VAR_DELTA]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
// CHECK: "tf.AssignVariableOp"([[ACC_HANDLE]], [[NEW_ACC]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
"tf.ResourceApplyAdagradV2"(%0, %1, %arg0, %arg1, %arg2) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
// -----
// Tests that composite tf.ResourceApplyAdam (non-Nesterov) operation is // Tests that composite tf.ResourceApplyAdam (non-Nesterov) operation is
// decomposed. // decomposed.

View File

@ -248,6 +248,40 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
return %0 : tensor<?x?x?xf32> return %0 : tensor<?x?x?xf32>
} }
// Check that supported tf_executor ops can receive data from ops on which
// shape inference has inferred the result types, without throwing any errors.
// CHECK-LABEL: func @supported_tf_executor_users
func @supported_tf_executor_users(%arg0: tensor<32x?x256x4xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<i1>, %arg3: tensor<i32>) -> tensor<?x?x?xf32> {
%0 = tf_executor.graph {
%island:3 = tf_executor.island {
%dims = "tf.Const"() {value = dense<[32, -1, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
%reshape = "tf.Reshape"(%arg0, %dims) : (tensor<32x?x256x4xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
%cast = "tf.Cast"(%arg2) : (tensor<i1>) -> tensor<*xi1>
tf_executor.yield %reshape, %cast : tensor<?x?x?xf32>, tensor<*xi1>
}
// CHECK: tf_executor.Merge
// CHECK-SAME: : (tensor<32x?x4xf32>, tensor<?x?x?xf32>) ->
// CHECK: tf_executor.Switch
// CHECK-SAME: : (tensor<32x?x4xf32>, tensor<i1>) ->
// CHECK: tf_executor.SwitchN
// CHECK-SAME: : tensor<?x?x?xf32>
// CHECK: tf_executor.Enter
// CHECK-SAME: : (tensor<32x?x4xf32>) ->
// CHECK: tf_executor.Exit
// CHECK-SAME: : tensor<?x?x?xf32>
// CHECK: tf_executor.LoopCond
// CHECK-SAME: : tensor<*xi1>
%merge:3 = "tf_executor.Merge"(%island#0, %arg1) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<i32>, !tf_executor.control)
%switch:3 = "tf_executor.Switch"(%island#0, %arg2) : (tensor<?x?x?xf32>, tensor<i1>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, !tf_executor.control)
%switchn:3 = "tf_executor.SwitchN"(%island#0, %arg3) {num_outs = 2} : (tensor<?x?x?xf32>, tensor<i32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, !tf_executor.control)
%enter:2 = "tf_executor.Enter"(%island#0) { frame_name = "frame"} : (tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, !tf_executor.control)
%exit:2 = "tf_executor.Exit"(%island#0) : (tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, !tf_executor.control)
%loop_cond:2 = "tf_executor.LoopCond" (%island#1) : (tensor<*xi1>) -> (tensor<*xi1>, !tf_executor.control)
tf_executor.fetch %enter#0 : tensor<?x?x?xf32>
}
return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @fold_cast // CHECK-LABEL: func @fold_cast
func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NOT: Cast // CHECK-NOT: Cast

View File

@ -7,7 +7,7 @@ func @invalid_type() -> !tf_executor.foobar
// Check that tf_executor.graph does not accept any operand. // Check that tf_executor.graph does not accept any operand.
func @graph_with_invalid_op(%arg0: tensor<*xf32>) { func @graph_with_invalid_op(%arg0: tensor<*xf32>) {
"tf_executor.graph" (%arg0) : (tensor<*xf32>) -> () "tf_executor.graph" (%arg0) ({}) : (tensor<*xf32>) -> ()
// expected-error@-1 {{'tf_executor.graph' op requires zero operands}} // expected-error@-1 {{'tf_executor.graph' op requires zero operands}}
return return
} }
@ -405,12 +405,49 @@ func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
// ----- // -----
// Check that switchN result type matches the input type. // Check that data operands of SwitchN have tensor type
func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { func @invalid_switchN(%arg0: i32, %arg1: tensor<i32>) -> tensor<*xi32> {
%result = tf_executor.graph {
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (i32, tensor<i32>) -> (tensor<*xi32>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to have tensor type but got 'i32'}}
tf_executor.fetch %1#0 : tensor<*xi32>
}
return %result : tensor<*xi32>
}
// -----
// Check that result of SwitchN has tensor type
func @invalid_switchN(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> i32 {
%result = tf_executor.graph {
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xi32>, tensor<i32>) -> (i32, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.SwitchN' op expects outputs to have tensor type but got 'i32'}}
tf_executor.fetch %1#0 : i32
}
return %result : i32
}
// -----
// Check that if any result is a ref type, then data operand needs to be ref too.
func @invalid_switchN(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> tensor<4x!tf.f32ref> {
%fetches = tf_executor.graph { %fetches = tf_executor.graph {
%1:3 = "tf_executor.SwitchN"(%arg1, %arg0) {num_outs = 2} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, i32, !tf_executor.control) %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<4xf32>, tensor<i32>) -> (tensor<4x!tf.f32ref>, tensor<4xf32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.SwitchN' op type mismatch between data operand and result: 'tensor<*xf32>' vs 'i32'}} // expected-error@-1 {{'tf_executor.SwitchN' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}}
tf_executor.fetch %1#0 : tensor<4x!tf.f32ref>
}
return %fetches : tensor<4x!tf.f32ref>
}
// -----
// Check that switchN data operand is broadcastable with all output types
func @invalid_switchN(%arg0: tensor<*xf32>, %arg1: tensor<i32>) -> tensor<*xf32> {
%fetches = tf_executor.graph {
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to be broadcastable with all output types but got 'tensor<*xf32>' vs 'tensor<i32>'}}
tf_executor.fetch %1#0 : tensor<*xf32> tf_executor.fetch %1#0 : tensor<*xf32>
} }
@ -472,6 +509,30 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
// ----- // -----
// Check that data operands of merge have tensor type
func @invalid_merge(%arg0: tensor<*xi32>, %arg1: i32) -> tensor<*xi32> {
%result = tf_executor.graph {
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xi32>, i32) -> (tensor<*xi32>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.Merge' op expects data operands to have tensor type but got 'i32'}}
tf_executor.fetch %value : tensor<*xi32>
}
return %result : tensor<*xi32>
}
// -----
// Check that result of merge has tensor type
func @invalid_merge(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> i32 {
%result = tf_executor.graph {
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xi32>, tensor<i32>) -> (i32, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.Merge' op result #0 must be tensor of any type values, but got 'i32'}}
tf_executor.fetch %value : i32
}
return %result : i32
}
// -----
// Check that merge data inputs are all the same type // Check that merge data inputs are all the same type
func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> { func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
%result = tf_executor.graph { %result = tf_executor.graph {

View File

@ -0,0 +1,92 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# RUN: %p/multi_arguments_results_v1 | FileCheck -dump-input-on-failure %s
# pylint: disable=missing-docstring,line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
from tensorflow.python.ops import array_ops
# Tests multiple inputs and outputs with index paths.
# CHECK-LABEL: func @key(
# CHECK-SAME: %[[ARG0:.*]]: tensor<3x5xf32> {tf_saved_model.index_path = ["y"]}
# CHECK-SAME: %[[ARG1:.*]]: tensor<5x3xf32> {tf_saved_model.index_path = ["x"]}
# CHECK-SAME: tensor<3x3xf32> {tf_saved_model.index_path = ["t"]}
# CHECK-SAME: tensor<5x5xf32> {tf_saved_model.index_path = ["s"]}
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"]
# CHECK-DAG: %[[MUL0:.*]] = "tf.MatMul"(%[[ARG1]], %[[ARG0]])
# CHECK-DAG: %[[MUL1:.*]] = "tf.MatMul"(%[[ARG0]], %[[ARG1]])
# CHECK: %[[IDENTITY:.*]]:2 = "tf.IdentityN"(%[[MUL1]], %[[MUL0]])
# CHECK: return %[[IDENTITY]]#0, %[[IDENTITY]]#1
# CHECK-LABEL: func @key2(
# CHECK-SAME: %[[ARG1:.*]]: tensor<5x3xf32> {tf_saved_model.index_path = ["b"]}
# CHECK-SAME: %[[ARG0:.*]]: tensor<3x5xf32> {tf_saved_model.index_path = ["a"]}
# CHECK-SAME: tensor<5x5xf32> {tf_saved_model.index_path = ["d"]}
# CHECK-SAME: tensor<3x3xf32> {tf_saved_model.index_path = ["c"]}
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key2"]
# CHECK-DAG: %[[MUL1:.*]] = "tf.MatMul"(%[[ARG0]], %[[ARG1]])
# CHECK-DAG: %[[MUL2:.*]] = "tf.MatMul"(%[[ARG1]], %[[ARG0]])
# CHECK: %[[IDENTITY:.*]]:2 = "tf.IdentityN"(%[[MUL1]], %[[MUL2]])
# CHECK: return %[[IDENTITY]]#1, %[[IDENTITY]]#0
def Test():
x = tf.constant(1.0, shape=(5, 3))
y = tf.constant(1.0, shape=(3, 5))
s = tf.matmul(x, y)
t = tf.matmul(y, x)
[t, s] = array_ops.identity_n([t, s])
tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.compat.v1.saved_model.utils.build_tensor_info(y)
tensor_info_s = tf.compat.v1.saved_model.utils.build_tensor_info(s)
tensor_info_t = tf.compat.v1.saved_model.utils.build_tensor_info(t)
return {
'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs={
'x': tensor_info_x,
'y': tensor_info_y
},
outputs={
's': tensor_info_s,
't': tensor_info_t
},
method_name='some_function')),
'key2': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs={
'a': tensor_info_y,
'b': tensor_info_x,
},
outputs={
'c': tensor_info_t,
'd': tensor_info_s,
},
method_name='reverse_arguments'))
}
if __name__ == '__main__':
common_v1.set_tf_options()
common_v1.do_test(Test())

View File

@ -1,64 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# RUN: %p/multi_arguments_v1 | FileCheck %s
# pylint: disable=missing-docstring,line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
# Tests multiple inputs with index paths.
# CHECK: func {{@[a-zA-Z_0-9]+}}(
# CHECK-SAME: [[ARG0:%.*]]: tensor<5x3xf32> {tf_saved_model.index_path = ["x"]},
# CHECK-SAME: [[ARG1:%.*]]: tensor<3x5xf32> {tf_saved_model.index_path = ["y"]})
# CHECK-SAME: -> (tensor<5x5xf32> {tf_saved_model.index_path = ["s"]},
# CHECK-SAME: tensor<3x3xf32> {tf_saved_model.index_path = ["t"]})
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"]
def Test():
x = tf.constant(1.0, shape=(5, 3))
y = tf.constant(1.0, shape=(3, 5))
s = tf.matmul(x, y)
t = tf.matmul(y, x)
tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.compat.v1.saved_model.utils.build_tensor_info(y)
tensor_info_s = tf.compat.v1.saved_model.utils.build_tensor_info(s)
tensor_info_t = tf.compat.v1.saved_model.utils.build_tensor_info(t)
return {
'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs={
'x': tensor_info_x,
'y': tensor_info_y
},
outputs={
's': tensor_info_s,
't': tensor_info_t
},
method_name='some_function'))
}
if __name__ == '__main__':
common_v1.set_tf_options()
common_v1.do_test(Test())

View File

@ -39,8 +39,8 @@ constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";
// Analyzes the inputs to LaunchFuncOps in the module, and annotates their // Analyzes the inputs to LaunchFuncOps in the module, and annotates their
// invoked functions whether each input has the same data across replicas. // invoked functions whether each input has the same data across replicas.
struct AnnotateParameterReplication struct AnnotateParameterReplication
: public ModulePass<AnnotateParameterReplication> { : public OperationPass<AnnotateParameterReplication, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
}; };
// Returns the first value in the chain of operands, which is not defined by a // Returns the first value in the chain of operands, which is not defined by a
@ -53,8 +53,8 @@ Value SkipIdentityAndReadVariable(Value v) {
return v; return v;
} }
void AnnotateParameterReplication::runOnModule() { void AnnotateParameterReplication::runOnOperation() {
ModuleOp m = getModule(); ModuleOp m = getOperation();
OpBuilder builder(m.getContext()); OpBuilder builder(m.getContext());
m.walk([&](tf_device::LaunchFuncOp launch_func) { m.walk([&](tf_device::LaunchFuncOp launch_func) {
auto replicate = launch_func.getParentOfType<tf_device::ReplicateOp>(); auto replicate = launch_func.getParentOfType<tf_device::ReplicateOp>();

View File

@ -38,8 +38,9 @@ namespace {
constexpr char kDeviceAttr[] = "device"; constexpr char kDeviceAttr[] = "device";
constexpr char kFuncAttr[] = "func"; constexpr char kFuncAttr[] = "func";
struct ClusterOutliningPass : public ModulePass<ClusterOutliningPass> { struct ClusterOutliningPass
void runOnModule() override; : public OperationPass<ClusterOutliningPass, ModuleOp> {
void runOnOperation() override;
}; };
void ReplaceLaunchReturnWithReturn(tf_device::ReturnOp launch_return_op, void ReplaceLaunchReturnWithReturn(tf_device::ReturnOp launch_return_op,
@ -120,8 +121,8 @@ void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table,
launch_op.erase(); launch_op.erase();
} }
void ClusterOutliningPass::runOnModule() { void ClusterOutliningPass::runOnOperation() {
ModuleOp m = getModule(); ModuleOp m = getOperation();
SymbolTable symbol_table(m); SymbolTable symbol_table(m);
OpBuilder builder(m.getContext()); OpBuilder builder(m.getContext());
m.walk([&](tf_device::LaunchOp launch) { m.walk([&](tf_device::LaunchOp launch) {

View File

@ -22,7 +22,7 @@ class GetScalarOfType<int value> : NativeCodeCall<
"GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">; "GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
// Creates a tf.ReadVariable op that reads a resource `$2` that has the same // Creates a tf.ReadVariable op that reads a resource `$2` that has the same
// element type as `$1`. The op created will use location of `$1`. // element type as `$1`. The op created will use location of `$0`.
def CreateTFReadVariableOp: NativeCodeCall< def CreateTFReadVariableOp: NativeCodeCall<
"$_builder.create<TF::ReadVariableOp>(" "$_builder.create<TF::ReadVariableOp>("
" $0.getLoc()," " $0.getLoc(),"
@ -118,6 +118,32 @@ def DecomposeResourceApplyKerasMomentumOpNesterov :
] ]
>; >;
// Pattern to Decompose ResourceApplyAdagradV2.
// This decomposition is only correct inside XLA as it ignores use_locking
// attribute.
// accum <- accum + grad * grad
// variable <- variable - lr * grad / (sqrt(accum) + epsilon)
def DecomposeResourceApplyAdagradV2 :
Pattern<
(TF_ResourceApplyAdagradV2Op:$src_op
$var_resource, $accum_resource, $lr, $epsilon, $grad, BoolAttr:$_,
ConstBoolAttrTrue:$update_slots),
[
(TF_AddV2Op:$new_accum
(CreateTFReadVariableOp $src_op, $grad, $accum_resource),
(TF_MulOp $grad, $grad)
),
(TF_AssignSubVariableOp
$var_resource,
(TF_DivOp
(TF_MulOp $lr, $grad),
(TF_AddV2Op (TF_SqrtOp $new_accum), $epsilon)
)
),
(TF_AssignVariableOp $accum_resource, $new_accum),
]
>;
// Pattern to Decompose ResourceApplyAdam without Nesterov momentum. // Pattern to Decompose ResourceApplyAdam without Nesterov momentum.
// This decomposition is only correct inside XLA as it ignores use_locking // This decomposition is only correct inside XLA as it ignores use_locking
// attribute. // attribute.

View File

@ -303,7 +303,8 @@ void InsertDummyIslandForFetch(FetchOp fetch) {
/*control=*/ControlType::get(fetch.getContext()), /*control=*/ControlType::get(fetch.getContext()),
/*controlInputs=*/control_fetches); /*controlInputs=*/control_fetches);
island.body().push_back(new Block); island.body().push_back(new Block);
OpBuilder(&island.GetBody()).create<YieldOp>(fetch.getLoc(), data_fetches); OpBuilder::atBlockEnd(&island.GetBody())
.create<YieldOp>(fetch.getLoc(), data_fetches);
const int fetch_control_idx = data_fetches.size(); const int fetch_control_idx = data_fetches.size();
for (int i = 0, e = fetch.getNumOperands(); i < e; i++) { for (int i = 0, e = fetch.getNumOperands(); i < e; i++) {
// The fetch could have multiple control operands (all at the end of its // The fetch could have multiple control operands (all at the end of its

View File

@ -43,17 +43,17 @@ constexpr llvm::StringRef kNestedModule = "_tpu_v1_compat_outlined";
// Inlining the islands calling into the nested module that was outlined. // Inlining the islands calling into the nested module that was outlined.
// This is the end of the TPU bridge in V1 compatibility mode. // This is the end of the TPU bridge in V1 compatibility mode.
struct TPUBridgeExecutorIslandInlining struct TPUBridgeExecutorIslandInlining
: public ModulePass<TPUBridgeExecutorIslandInlining> { : public OperationPass<TPUBridgeExecutorIslandInlining, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
}; };
void TPUBridgeExecutorIslandInlining::runOnModule() { void TPUBridgeExecutorIslandInlining::runOnOperation() {
SymbolTable symbol_table(getModule()); SymbolTable symbol_table(getOperation());
Operation *nested_module = symbol_table.lookup(kNestedModule); Operation *nested_module = symbol_table.lookup(kNestedModule);
if (!nested_module) return; if (!nested_module) return;
InlinerInterface inliner(&getContext()); InlinerInterface inliner(&getContext());
auto walk_result = getModule().walk([&](TF::PartitionedCallOp call_op) { auto walk_result = getOperation().walk([&](TF::PartitionedCallOp call_op) {
if (!call_op.f().getRootReference().startswith(kNestedModule)) if (!call_op.f().getRootReference().startswith(kNestedModule))
return WalkResult::advance(); return WalkResult::advance();
// This is a call we need to inline! // This is a call we need to inline!
@ -61,7 +61,7 @@ void TPUBridgeExecutorIslandInlining::runOnModule() {
<< "Found call to inline: " << *call_op.getOperation() << "\n"); << "Found call to inline: " << *call_op.getOperation() << "\n");
FuncOp called_func = dyn_cast_or_null<FuncOp>( FuncOp called_func = dyn_cast_or_null<FuncOp>(
symbol_table.lookupSymbolIn(getModule(), call_op.f())); symbol_table.lookupSymbolIn(getOperation(), call_op.f()));
if (failed(inlineCall(inliner, if (failed(inlineCall(inliner,
cast<CallOpInterface>(call_op.getOperation()), cast<CallOpInterface>(call_op.getOperation()),
@ -80,7 +80,7 @@ void TPUBridgeExecutorIslandInlining::runOnModule() {
Block &nested_block = nested_module->getRegion(0).front(); Block &nested_block = nested_module->getRegion(0).front();
for (FuncOp func_op : for (FuncOp func_op :
llvm::make_early_inc_range(nested_block.getOps<FuncOp>())) { llvm::make_early_inc_range(nested_block.getOps<FuncOp>())) {
if (!symbol_table.lookupSymbolIn(getModule(), func_op.getName())) { if (!symbol_table.lookupSymbolIn(getOperation(), func_op.getName())) {
nested_block.getOperations().remove(func_op.getOperation()); nested_block.getOperations().remove(func_op.getOperation());
symbol_table.insert(func_op.getOperation()); symbol_table.insert(func_op.getOperation());
} }

View File

@ -59,8 +59,8 @@ constexpr llvm::StringRef kTpuStatusAttr = "_tpu_compilation_status";
// TPU-annotated operations and intended to preserve backward compatibility with // TPU-annotated operations and intended to preserve backward compatibility with
// TFv1. // TFv1.
struct TpuV1BridgeExecutorIslandCoarsening struct TpuV1BridgeExecutorIslandCoarsening
: public ModulePass<TpuV1BridgeExecutorIslandCoarsening> { : public OperationPass<TpuV1BridgeExecutorIslandCoarsening, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
}; };
// Sort the Operations in the provided range to enforce dominance. // Sort the Operations in the provided range to enforce dominance.
@ -226,7 +226,8 @@ LogicalResult MergeIsland(llvm::function_ref<bool(StringAttr, Operation*)>
yield_operands.push_back(std::get<1>(result)); yield_operands.push_back(std::get<1>(result));
} }
} }
OpBuilder(&island_body).create<YieldOp>(new_island.getLoc(), yield_operands); OpBuilder::atBlockEnd(&island_body)
.create<YieldOp>(new_island.getLoc(), yield_operands);
// remap results of the new islands to the user outside of the island. // remap results of the new islands to the user outside of the island.
int current_result = 0; int current_result = 0;
@ -257,13 +258,13 @@ LogicalResult MergeIsland(llvm::function_ref<bool(StringAttr, Operation*)>
first_op_after); first_op_after);
} }
void TpuV1BridgeExecutorIslandCoarsening::runOnModule() { void TpuV1BridgeExecutorIslandCoarsening::runOnOperation() {
SymbolTable symbol_table(getModule()); SymbolTable symbol_table(getOperation());
// Map tpu cluster names to the functions that contain operations for this // Map tpu cluster names to the functions that contain operations for this
// cluster. // cluster.
DenseMap<StringRef, DenseSet<FuncOp>> tpu_funcs; DenseMap<StringRef, DenseSet<FuncOp>> tpu_funcs;
for (FuncOp func_op : getModule().getOps<FuncOp>()) { for (FuncOp func_op : getOperation().getOps<FuncOp>()) {
func_op.walk([&](Operation* op) { func_op.walk([&](Operation* op) {
StringAttr cluster_name = StringAttr cluster_name =
op->getAttrOfType<StringAttr>(kTpuReplicateAttr); op->getAttrOfType<StringAttr>(kTpuReplicateAttr);
@ -291,7 +292,7 @@ void TpuV1BridgeExecutorIslandCoarsening::runOnModule() {
return false; return false;
}; };
for (FuncOp func_op : getModule().getOps<FuncOp>()) { for (FuncOp func_op : getOperation().getOps<FuncOp>()) {
func_op.walk([&](GraphOp graph) { func_op.walk([&](GraphOp graph) {
Block& graph_body = graph.GetBody(); Block& graph_body = graph.GetBody();

View File

@ -44,20 +44,20 @@ constexpr llvm::StringRef kOutlinedFuncPrefix = "_tpu_v1_compat_outlined_func";
// This is only intended for V1 compatibility mode where the bridge runs without // This is only intended for V1 compatibility mode where the bridge runs without
// feed/fetches on session create/extend. // feed/fetches on session create/extend.
struct TPUBridgeExecutorIslandOutlining struct TPUBridgeExecutorIslandOutlining
: public ModulePass<TPUBridgeExecutorIslandOutlining> { : public OperationPass<TPUBridgeExecutorIslandOutlining, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
}; };
void TPUBridgeExecutorIslandOutlining::runOnModule() { void TPUBridgeExecutorIslandOutlining::runOnOperation() {
MLIRContext *ctx = &getContext(); MLIRContext *ctx = &getContext();
SymbolTable symbol_table(getModule()); SymbolTable symbol_table(getOperation());
if (Operation *nested_module = symbol_table.lookup(kNestedModule)) { if (Operation *nested_module = symbol_table.lookup(kNestedModule)) {
nested_module->emitOpError("unexpected already present outlined module."); nested_module->emitOpError("unexpected already present outlined module.");
return signalPassFailure(); return signalPassFailure();
} }
ModuleOp outlined_module = ModuleOp::create(getModule().getLoc()); ModuleOp outlined_module = ModuleOp::create(getOperation().getLoc());
outlined_module.setAttrs(getModule().getAttrs()); outlined_module.setAttrs(getOperation().getAttrs());
outlined_module.setAttr(SymbolTable::getSymbolAttrName(), outlined_module.setAttr(SymbolTable::getSymbolAttrName(),
StringAttr::get(kNestedModule, ctx)); StringAttr::get(kNestedModule, ctx));
symbol_table.insert(outlined_module); symbol_table.insert(outlined_module);
@ -66,7 +66,7 @@ void TPUBridgeExecutorIslandOutlining::runOnModule() {
// Find every island that contains a TPUReplicateMetadata node and extract it // Find every island that contains a TPUReplicateMetadata node and extract it
// in a new module to run the V1 bridge there. // in a new module to run the V1 bridge there.
SmallVector<IslandOp, 8> islands_to_outline; SmallVector<IslandOp, 8> islands_to_outline;
getModule().walk([&](TF::TPUReplicateMetadataOp replicate_op) { getOperation().walk([&](TF::TPUReplicateMetadataOp replicate_op) {
auto island_op = cast<IslandOp>(replicate_op.getParentOp()); auto island_op = cast<IslandOp>(replicate_op.getParentOp());
if (!island_op || island_op.WrapsSingleOp()) return; if (!island_op || island_op.WrapsSingleOp()) return;
islands_to_outline.push_back(island_op); islands_to_outline.push_back(island_op);
@ -123,7 +123,7 @@ void TPUBridgeExecutorIslandOutlining::runOnModule() {
// The function is in place in the nested module, create a call and yield in // The function is in place in the nested module, create a call and yield in
// the original island. // the original island.
OpBuilder builder(&island_op.GetBody()); OpBuilder builder = OpBuilder::atBlockEnd(&island_op.GetBody());
auto call_op = builder.create<mlir::TF::PartitionedCallOp>( auto call_op = builder.create<mlir::TF::PartitionedCallOp>(
island_op.getLoc(), func_result_types, operands.getArrayRef(), island_op.getLoc(), func_result_types, operands.getArrayRef(),
builder.getSymbolRefAttr( builder.getSymbolRefAttr(

View File

@ -202,7 +202,7 @@ static void MatchSwitchFoldOps(tf_executor::SwitchOp switch_op,
static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) { static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) {
// Create builder for val_index of MergeOp. // Create builder for val_index of MergeOp.
auto* block = &function.getBlocks().front(); auto* block = &function.getBlocks().front();
OpBuilder builder(block); OpBuilder builder = OpBuilder::atBlockEnd(block);
auto type = builder.getIntegerType(32); auto type = builder.getIntegerType(32);
auto build_index = [&](Location loc, int value) { auto build_index = [&](Location loc, int value) {
return builder.create<ConstantOp>(loc, type, return builder.create<ConstantOp>(loc, type,

View File

@ -41,12 +41,13 @@ namespace {
// the IR is in correct form for inference backends (like lite) that do not // the IR is in correct form for inference backends (like lite) that do not
// support resources/variables . Further, this contract also ensures that this // support resources/variables . Further, this contract also ensures that this
// pass lowers from saved model to pure TF. Hence it fails, if it cannot lower. // pass lowers from saved model to pure TF. Hence it fails, if it cannot lower.
struct FreezeGlobalTensorsPass : public ModulePass<FreezeGlobalTensorsPass> { struct FreezeGlobalTensorsPass
void runOnModule() override; : public OperationPass<FreezeGlobalTensorsPass, ModuleOp> {
void runOnOperation() override;
}; };
void FreezeGlobalTensorsPass::runOnModule() { void FreezeGlobalTensorsPass::runOnOperation() {
auto module = getModule(); auto module = getOperation();
SymbolTable symbol_table(module); SymbolTable symbol_table(module);
DenseSet<Operation*> frozen_global_tensors; DenseSet<Operation*> frozen_global_tensors;

View File

@ -126,7 +126,7 @@ void LayoutAssignmentPass::runOnFunction() {
mlir::Operation* op = layout_sensitive_interface.getOperation(); mlir::Operation* op = layout_sensitive_interface.getOperation();
Location loc = op->getLoc(); Location loc = op->getLoc();
OpBuilder builder(op->getBlock()); OpBuilder builder = OpBuilder::atBlockEnd(op->getBlock());
auto perm_attr = [&](Permutation permutation) -> DenseIntElementsAttr { auto perm_attr = [&](Permutation permutation) -> DenseIntElementsAttr {
auto perm_ty = RankedTensorType::get({4}, builder.getIntegerType(32)); auto perm_ty = RankedTensorType::get({4}, builder.getIntegerType(32));

View File

@ -74,11 +74,11 @@ LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification(
namespace { namespace {
struct MarkFunctionVisibilityUsingEntryFunctionSpecificationPass struct MarkFunctionVisibilityUsingEntryFunctionSpecificationPass
: public ModulePass< : public OperationPass<
MarkFunctionVisibilityUsingEntryFunctionSpecificationPass> { MarkFunctionVisibilityUsingEntryFunctionSpecificationPass, ModuleOp> {
void runOnModule() override { void runOnOperation() override {
if (failed(MarkFunctionVisibilityUsingEntryFunctionSpecification( if (failed(MarkFunctionVisibilityUsingEntryFunctionSpecification(
getModule()))) { getOperation()))) {
signalPassFailure(); signalPassFailure();
} }
} }
@ -110,9 +110,10 @@ static LogicalResult MarkFunctionVisibilityUsingSavedModelLinkage(
namespace { namespace {
struct MarkFunctionVisibilityUsingSavedModelLinkagePass struct MarkFunctionVisibilityUsingSavedModelLinkagePass
: public ModulePass<MarkFunctionVisibilityUsingSavedModelLinkagePass> { : public OperationPass<MarkFunctionVisibilityUsingSavedModelLinkagePass,
void runOnModule() override { ModuleOp> {
if (failed(MarkFunctionVisibilityUsingSavedModelLinkage(getModule()))) { void runOnOperation() override {
if (failed(MarkFunctionVisibilityUsingSavedModelLinkage(getOperation()))) {
signalPassFailure(); signalPassFailure();
} }
} }

View File

@ -41,8 +41,8 @@ namespace mlir {
namespace tf_saved_model { namespace tf_saved_model {
namespace { namespace {
struct OptimizeGlobalTensorsPass struct OptimizeGlobalTensorsPass
: public ModulePass<OptimizeGlobalTensorsPass> { : public OperationPass<OptimizeGlobalTensorsPass, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
}; };
// A global tensor is bound to arguments of multiple funcs. // A global tensor is bound to arguments of multiple funcs.
@ -276,8 +276,8 @@ void EraseUnusedBoundInputs(ModuleOp module) {
} }
} }
void OptimizeGlobalTensorsPass::runOnModule() { void OptimizeGlobalTensorsPass::runOnOperation() {
auto module = getModule(); auto module = getOperation();
EraseUnusedBoundInputs(module); EraseUnusedBoundInputs(module);
ResourceAnalyzer resource_analyzer(module); ResourceAnalyzer resource_analyzer(module);

View File

@ -258,13 +258,13 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) {
} }
class PromoteResourcesToArgsPass class PromoteResourcesToArgsPass
: public ModulePass<PromoteResourcesToArgsPass> { : public OperationPass<PromoteResourcesToArgsPass, ModuleOp> {
public: public:
void runOnModule() override; void runOnOperation() override;
}; };
void PromoteResourcesToArgsPass::runOnModule() { void PromoteResourcesToArgsPass::runOnOperation() {
ModuleOp module = getModule(); ModuleOp module = getOperation();
FuncOp main_func = module.lookupSymbol<FuncOp>("main"); FuncOp main_func = module.lookupSymbol<FuncOp>("main");
if (!main_func) return; if (!main_func) return;

View File

@ -53,8 +53,9 @@ constexpr char kFuncDeviceAttr[] = "tf.device";
// //
// This pass changes the module by adding "tf.device" attribute to function // This pass changes the module by adding "tf.device" attribute to function
// arguments and adding "device" attribute to TF ops. // arguments and adding "device" attribute to TF ops.
struct ResourceDeviceInference : public ModulePass<ResourceDeviceInference> { struct ResourceDeviceInference
void runOnModule() override; : public OperationPass<ResourceDeviceInference, ModuleOp> {
void runOnOperation() override;
}; };
// A class that records each resource's device assignment in a function. // A class that records each resource's device assignment in a function.
@ -190,8 +191,8 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
return failure(walk_res.wasInterrupted()); return failure(walk_res.wasInterrupted());
} }
void ResourceDeviceInference::runOnModule() { void ResourceDeviceInference::runOnOperation() {
auto module = getModule(); auto module = getOperation();
llvm::SmallDenseMap<Operation*, PerFunctionResult, 4> per_function_results; llvm::SmallDenseMap<Operation*, PerFunctionResult, 4> per_function_results;
llvm::SetVector<FuncOp> worklist; llvm::SetVector<FuncOp> worklist;
module.walk([&](FuncOp func_op) { module.walk([&](FuncOp func_op) {

View File

@ -131,8 +131,9 @@ namespace {
// return %arg0 // return %arg0
// } // }
// //
struct ResourceOpLiftingPass : public ModulePass<ResourceOpLiftingPass> { struct ResourceOpLiftingPass
void runOnModule() override; : public OperationPass<ResourceOpLiftingPass, ModuleOp> {
void runOnOperation() override;
}; };
// Removes identity nodes in the block. The device computation does not need // Removes identity nodes in the block. The device computation does not need
@ -1050,13 +1051,13 @@ LogicalResult HoistForFunctionalControlFlow(
// Lifts resource operation from tf_device.launch_func ops nested in `op` // Lifts resource operation from tf_device.launch_func ops nested in `op`
// outside. Returns failure if there are remaining resource-type values that can // outside. Returns failure if there are remaining resource-type values that can
// not be lifted. // not be lifted.
void ResourceOpLiftingPass::runOnModule() { void ResourceOpLiftingPass::runOnOperation() {
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo> llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
lifted_partitioned_call_callees; lifted_partitioned_call_callees;
auto result = getModule().walk([&](FuncOp func_op) { auto result = getOperation().walk([&](FuncOp func_op) {
return func_op.walk([&](tf_device::LaunchOp launch_op) { return func_op.walk([&](tf_device::LaunchOp launch_op) {
if (failed(HoistForFunctionalControlFlow( if (failed(HoistForFunctionalControlFlow(
&launch_op.GetBody(), getModule(), &launch_op.GetBody(), getOperation(),
&lifted_partitioned_call_callees)) || &lifted_partitioned_call_callees)) ||
failed(HoistResourceOpsFromLaunchOp(launch_op))) { failed(HoistResourceOpsFromLaunchOp(launch_op))) {
return WalkResult::interrupt(); return WalkResult::interrupt();
@ -1070,12 +1071,12 @@ void ResourceOpLiftingPass::runOnModule() {
} }
struct ResourceOpLiftingForMainFunctionPass struct ResourceOpLiftingForMainFunctionPass
: public ModulePass<ResourceOpLiftingForMainFunctionPass> { : public OperationPass<ResourceOpLiftingForMainFunctionPass, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
}; };
void ResourceOpLiftingForMainFunctionPass::runOnModule() { void ResourceOpLiftingForMainFunctionPass::runOnOperation() {
ModuleOp module = getModule(); ModuleOp module = getOperation();
FuncOp main_func = module.lookupSymbol<FuncOp>("main"); FuncOp main_func = module.lookupSymbol<FuncOp>("main");
if (!main_func) { if (!main_func) {
return; return;

View File

@ -111,7 +111,9 @@ bool IsSupportedNonTFOp(Operation* op) {
return isa<tf_executor::YieldOp>(op) || isa<tf_executor::IslandOp>(op) || return isa<tf_executor::YieldOp>(op) || isa<tf_executor::IslandOp>(op) ||
isa<tf_executor::FetchOp>(op) || isa<tf_executor::GraphOp>(op) || isa<tf_executor::FetchOp>(op) || isa<tf_executor::GraphOp>(op) ||
isa<tf_executor::NextIterationSinkOp>(op) || isa<ReturnOp>(op) || isa<tf_executor::NextIterationSinkOp>(op) || isa<ReturnOp>(op) ||
isa<tf_device::ReturnOp>(op); isa<tf_device::ReturnOp>(op) || isa<tf_executor::MergeOp>(op) ||
isa<tf_executor::SwitchOp>(op) || isa<tf_executor::SwitchNOp>(op) ||
isa<tf_executor::EnterOp>(op) || isa<tf_executor::ExitOp>(op);
} }
// Inserts tf.Cast operation when changing the type of a result if the user is // Inserts tf.Cast operation when changing the type of a result if the user is
@ -224,7 +226,8 @@ GetSubtypes(Type type) {
return GetSubtypesHelper<TF::VariantType>(type); return GetSubtypesHelper<TF::VariantType>(type);
} }
// Makes result types match the operand types. Returns if anything is changed. // Makes result types match the operand types (the i-th result type will
// match the i-th operand type). Returns true if anything is changed.
bool PassThroughOperandTypes(OperandRange operands, ResultRange results) { bool PassThroughOperandTypes(OperandRange operands, ResultRange results) {
bool changed = false; bool changed = false;
for (auto entry : llvm::zip(operands, results)) { for (auto entry : llvm::zip(operands, results)) {

View File

@ -47,9 +47,9 @@ namespace {
// This transformation pass propagate shapes on the TensorFlow graph. // This transformation pass propagate shapes on the TensorFlow graph.
// It is a ModulePass in order to be able to change function types. // It is a ModulePass in order to be able to change function types.
struct ShapeInference : public ModulePass<ShapeInference> { struct ShapeInference : public OperationPass<ShapeInference, ModuleOp> {
void runOnModule() override { void runOnOperation() override {
auto module = getModule(); auto module = getOperation();
auto producer_or = tensorflow::GetTfGraphProducerVersion(module); auto producer_or = tensorflow::GetTfGraphProducerVersion(module);
if (!producer_or.ok()) { if (!producer_or.ok()) {
LLVM_DEBUG(llvm::dbgs() << producer_or.status().ToString();); LLVM_DEBUG(llvm::dbgs() << producer_or.status().ToString(););

View File

@ -85,8 +85,8 @@ namespace cutil = TF::collection_ops_util;
// //
// The pass also works across control flow and functional calls. // The pass also works across control flow and functional calls.
struct StackOpsDecompositionPass struct StackOpsDecompositionPass
: public ModulePass<StackOpsDecompositionPass> { : public OperationPass<StackOpsDecompositionPass, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
}; };
// Returns the type of the local variable for the stack size. // Returns the type of the local variable for the stack size.
@ -551,8 +551,8 @@ LogicalResult DecomposeStackOps(Block* block, ModuleOp module) {
&decomposed_partitioned_call_callees); &decomposed_partitioned_call_callees);
} }
void StackOpsDecompositionPass::runOnModule() { void StackOpsDecompositionPass::runOnOperation() {
auto module = getModule(); auto module = getOperation();
auto main = module.lookupSymbol<FuncOp>("main"); auto main = module.lookupSymbol<FuncOp>("main");
if (!main) return; if (!main) return;
if (failed(DecomposeStackOps(&main.front(), module))) { if (failed(DecomposeStackOps(&main.front(), module))) {

View File

@ -68,8 +68,8 @@ using std::string;
// shape. // shape.
// //
struct TensorArrayOpsDecompositionPass struct TensorArrayOpsDecompositionPass
: public ModulePass<TensorArrayOpsDecompositionPass> { : public OperationPass<TensorArrayOpsDecompositionPass, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
}; };
// Infers the element type and count for a TensorArraySplitV3Op. Requires // Infers the element type and count for a TensorArraySplitV3Op. Requires
@ -873,8 +873,8 @@ LogicalResult DecomposeTensorArrayOps(
return success(); return success();
} }
void TensorArrayOpsDecompositionPass::runOnModule() { void TensorArrayOpsDecompositionPass::runOnOperation() {
auto module = getModule(); auto module = getOperation();
auto main = module.lookupSymbol<FuncOp>("main"); auto main = module.lookupSymbol<FuncOp>("main");
if (!main) return; if (!main) return;
llvm::SmallDenseMap<Value, TensorArrayStats> stats; llvm::SmallDenseMap<Value, TensorArrayStats> stats;

View File

@ -62,8 +62,8 @@ namespace cutil = TF::collection_ops_util;
// //
// The pass also works across control flow and functional calls. // The pass also works across control flow and functional calls.
struct TensorListOpsDecompositionPass struct TensorListOpsDecompositionPass
: public ModulePass<TensorListOpsDecompositionPass> { : public OperationPass<TensorListOpsDecompositionPass, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
}; };
// Updates func's type according to its current arguments and return values. // Updates func's type according to its current arguments and return values.
@ -671,8 +671,8 @@ LogicalResult DecomposeTensorListOps(Block* block, ModuleOp module) {
&decomposed_partitioned_call_callees); &decomposed_partitioned_call_callees);
} }
void TensorListOpsDecompositionPass::runOnModule() { void TensorListOpsDecompositionPass::runOnOperation() {
auto module = getModule(); auto module = getOperation();
auto main = module.lookupSymbol<FuncOp>("main"); auto main = module.lookupSymbol<FuncOp>("main");
if (!main) return; if (!main) return;
if (failed(DecomposeTensorListOps(&main.front(), module))) { if (failed(DecomposeTensorListOps(&main.front(), module))) {

View File

@ -40,20 +40,20 @@ namespace tensorflow {
// Optimization Passes and convert back to MLIR. // Optimization Passes and convert back to MLIR.
// Constraints: This pass expects that all operations in the MLIR module either // Constraints: This pass expects that all operations in the MLIR module either
// belong to 'tf' or '_tf' dialect. The output is in '_tf' dialect. // belong to 'tf' or '_tf' dialect. The output is in '_tf' dialect.
class GraphOptPass : public mlir::ModulePass<GraphOptPass> { class GraphOptPass : public mlir::OperationPass<GraphOptPass, mlir::ModuleOp> {
public: public:
explicit GraphOptPass(std::vector<tensorflow::GraphOptimizationPass*> passes) explicit GraphOptPass(std::vector<tensorflow::GraphOptimizationPass*> passes)
: passes_(std::move(passes)) {} : passes_(std::move(passes)) {}
protected: protected:
void runOnModule() override; void runOnOperation() override;
// The passes to run on the module. // The passes to run on the module.
std::vector<GraphOptimizationPass*> passes_; std::vector<GraphOptimizationPass*> passes_;
}; };
void GraphOptPass::runOnModule() { void GraphOptPass::runOnOperation() {
mlir::ModuleOp module_in = getModule(); mlir::ModuleOp module_in = getOperation();
mlir::MLIRContext& ctx = getContext(); mlir::MLIRContext& ctx = getContext();
// Convert MLIR to Graph // Convert MLIR to Graph
@ -151,7 +151,7 @@ class GraphOptByNamePass : public GraphOptPass {
: GraphOptPass(FindRegisteredPassesByName(pass_names)) {} : GraphOptPass(FindRegisteredPassesByName(pass_names)) {}
private: private:
void runOnModule() override { void runOnOperation() override {
// Verify all passes requested were registered/found. // Verify all passes requested were registered/found.
for (auto pass_it : llvm::enumerate(passes_)) { for (auto pass_it : llvm::enumerate(passes_)) {
if (pass_it.value() == nullptr) { if (pass_it.value() == nullptr) {
@ -160,7 +160,7 @@ class GraphOptByNamePass : public GraphOptPass {
return signalPassFailure(); return signalPassFailure();
} }
} }
return GraphOptPass::runOnModule(); return GraphOptPass::runOnOperation();
} }
}; };

View File

@ -48,8 +48,9 @@ constexpr char kPaddingMapAttr[] = "padding_map";
// (user). // (user).
namespace { namespace {
struct TPUDynamicPaddingMapper : public ModulePass<TPUDynamicPaddingMapper> { struct TPUDynamicPaddingMapper
void runOnModule() override; : public OperationPass<TPUDynamicPaddingMapper, ModuleOp> {
void runOnOperation() override;
}; };
// Creates a mapping from replicated input index (in `tf_device.replicate` op) // Creates a mapping from replicated input index (in `tf_device.replicate` op)
@ -190,8 +191,8 @@ LogicalResult RemapAndAssignPaddingMaps(tf_device::LaunchFuncOp launch_func,
return success(); return success();
} }
void TPUDynamicPaddingMapper::runOnModule() { void TPUDynamicPaddingMapper::runOnOperation() {
ModuleOp module = getModule(); ModuleOp module = getOperation();
SymbolTable symbol_table(module); SymbolTable symbol_table(module);
module.walk([&](tf_device::LaunchFuncOp launch_func) { module.walk([&](tf_device::LaunchFuncOp launch_func) {
RemapAndAssignPaddingMaps(launch_func, &symbol_table); RemapAndAssignPaddingMaps(launch_func, &symbol_table);


@ -98,8 +98,8 @@ constexpr char kBadArrayAttrLengthMsg[] =
// %4 = "tf.SomeOp"(%3) // %4 = "tf.SomeOp"(%3)
namespace { namespace {
struct TPURewritePass : public ModulePass<TPURewritePass> { struct TPURewritePass : public OperationPass<TPURewritePass, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
}; };
// Creates a missing attribute error message. // Creates a missing attribute error message.
@ -747,13 +747,13 @@ LogicalResult Rewrite(
return success(); return success();
} }
void TPURewritePass::runOnModule() { void TPURewritePass::runOnOperation() {
mlir::TF::RuntimeDevices devices; mlir::TF::RuntimeDevices devices;
if (failed(tensorflow::GetDevicesFromOp(getModule(), &devices))) if (failed(tensorflow::GetDevicesFromOp(getOperation(), &devices)))
return signalPassFailure(); return signalPassFailure();
OpBuilder builder(&getContext()); OpBuilder builder(&getContext());
auto result = getModule().walk([&](tf_device::LaunchFuncOp op) { auto result = getOperation().walk([&](tf_device::LaunchFuncOp op) {
if (failed(Rewrite(op, devices.device_names(), &builder))) if (failed(Rewrite(op, devices.device_names(), &builder)))
return WalkResult::interrupt(); return WalkResult::interrupt();
@ -763,7 +763,7 @@ void TPURewritePass::runOnModule() {
if (result.wasInterrupted()) return signalPassFailure(); if (result.wasInterrupted()) return signalPassFailure();
// Eliminate TPUCompilationResultOp now that the rewrite is complete. // Eliminate TPUCompilationResultOp now that the rewrite is complete.
getModule().walk([&](TF::TPUCompilationResultOp op) { op.erase(); }); getOperation().walk([&](TF::TPUCompilationResultOp op) { op.erase(); });
// TODO(b/139377366): Remove functions that are no longer needed. // TODO(b/139377366): Remove functions that are no longer needed.
} }


@ -40,8 +40,8 @@ namespace {
constexpr char kShardingAttr[] = "xla_hlo.sharding"; constexpr char kShardingAttr[] = "xla_hlo.sharding";
struct TPUShardingIdentificationPass struct TPUShardingIdentificationPass
: public ModulePass<TPUShardingIdentificationPass> { : public OperationPass<TPUShardingIdentificationPass, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
}; };
// XlaSharding op may be direct user of inputs but it may also be followed by // XlaSharding op may be direct user of inputs but it may also be followed by
@ -176,9 +176,9 @@ void IdentifyXlaShardingForTPUComputation(Builder* builder,
builder->getStrArrayAttr(sharding_for_rets)); builder->getStrArrayAttr(sharding_for_rets));
} }
void TPUShardingIdentificationPass::runOnModule() { void TPUShardingIdentificationPass::runOnOperation() {
Builder builder(getModule().getContext()); Builder builder(getOperation().getContext());
getModule().walk([&](tf_device::LaunchFuncOp launch_func) { getOperation().walk([&](tf_device::LaunchFuncOp launch_func) {
IdentifyXlaShardingForTPUComputation(&builder, launch_func); IdentifyXlaShardingForTPUComputation(&builder, launch_func);
}); });
} }


@ -116,8 +116,8 @@ std::string GetRandomStateVariableName() {
// tf.TPUReshardVariablesOp(%rvar, %default_format, %rstate) // tf.TPUReshardVariablesOp(%rvar, %default_format, %rstate)
// } // }
struct TPUVariableRuntimeReformattingPass struct TPUVariableRuntimeReformattingPass
: public ModulePass<TPUVariableRuntimeReformattingPass> { : public OperationPass<TPUVariableRuntimeReformattingPass, ModuleOp> {
void runOnModule() override; void runOnOperation() override;
}; };
// Returns the earlier value of which `v` is an identity. If `skipped` is // Returns the earlier value of which `v` is an identity. If `skipped` is
@ -318,7 +318,7 @@ TF::WhileOp AddStateVarsToWhileOp(TF::WhileOp while_op, FuncOp body,
new_body_return_vals.push_back(inner_arg); new_body_return_vals.push_back(inner_arg);
new_while_operands.push_back(state_var.resource()); new_while_operands.push_back(state_var.resource());
} }
OpBuilder builder(&body.front()); OpBuilder builder = OpBuilder::atBlockEnd(&body.front());
// Update return values. // Update return values.
builder.create<ReturnOp>(body_return.getLoc(), new_body_return_vals); builder.create<ReturnOp>(body_return.getLoc(), new_body_return_vals);
body_return.erase(); body_return.erase();
@ -555,8 +555,8 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
builder.create<tf_device::ReturnOp>(while_op.getLoc(), ArrayRef<Value>{}); builder.create<tf_device::ReturnOp>(while_op.getLoc(), ArrayRef<Value>{});
} }
void TPUVariableRuntimeReformattingPass::runOnModule() { void TPUVariableRuntimeReformattingPass::runOnOperation() {
auto module = getModule(); auto module = getOperation();
module.walk([&](TF::WhileOp while_op) { module.walk([&](TF::WhileOp while_op) {
auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body())); auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
tf_device::ReplicateOp replicate; tf_device::ReplicateOp replicate;


@ -218,7 +218,7 @@ void ControlToExecutorDialectConversion::runOnFunction() {
} }
// Create the operation inside the island // Create the operation inside the island
OpBuilder island_builder(&island.GetBody()); OpBuilder island_builder = OpBuilder::atBlockEnd(&island.GetBody());
Operation *inner_op = island_builder.createOperation(result); Operation *inner_op = island_builder.createOperation(result);
inner_op->setAttrs(op.getAttrList()); inner_op->setAttrs(op.getAttrList());
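Several hunks here and below replace the OpBuilder constructor taking a Block* with the explicit OpBuilder::atBlockEnd factory, so the insertion point is stated rather than implied. A small sketch of the explicit factories, assuming only the upstream mlir::OpBuilder API; the function and variable names are illustrative:

#include "mlir/IR/Builders.h"  // TF:llvm-project

mlir::Operation* BuildAtIslandEnd(mlir::Block& island_body,
                                  const mlir::OperationState& state) {
  // Insert at the end of the block, after any operations already present;
  // this is the behavior the diff opts into explicitly.
  mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(&island_body);
  return builder.createOperation(state);
  // Other explicit insertion points offered by OpBuilder:
  //   mlir::OpBuilder::atBlockBegin(&island_body);
  //   mlir::OpBuilder::atBlockTerminator(&island_body);
}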


@ -68,7 +68,7 @@ void ExecutorToControlDialectConversion::runOnFunction() {
Block &body = getFunction().front(); Block &body = getFunction().front();
auto graph = cast<tf_executor::GraphOp>(body.front()); auto graph = cast<tf_executor::GraphOp>(body.front());
OpBuilder builder(&body); OpBuilder builder = OpBuilder::atBlockEnd(&body);
SmallString<64> new_op_name; SmallString<64> new_op_name;
for (auto &op : llvm::make_early_inc_range(llvm::reverse(graph.GetBody()))) { for (auto &op : llvm::make_early_inc_range(llvm::reverse(graph.GetBody()))) {
LLVM_DEBUG(llvm::dbgs() << "Process: " << op.getName() << "\n"); LLVM_DEBUG(llvm::dbgs() << "Process: " << op.getName() << "\n");


@ -1452,7 +1452,8 @@ mlir::Operation* ImporterBase::createOperation(
result.location, types, control_operands, result.location, types, control_operands,
mlir::ArrayRef<mlir::NamedAttribute>{}); mlir::ArrayRef<mlir::NamedAttribute>{});
island.body().push_back(new mlir::Block); island.body().push_back(new mlir::Block);
mlir::OpBuilder island_builder(&island.GetBody()); mlir::OpBuilder island_builder =
mlir::OpBuilder::atBlockEnd(&island.GetBody());
// Create the operation inside the island now. // Create the operation inside the island now.
mlir::Operation* inner_op; mlir::Operation* inner_op;
@ -2928,12 +2929,11 @@ class SavedModelSignatureDefImporter {
// Converts the SavedModel to the SavedModel dialect. Creates an MLIR function // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
// for each signature. // for each signature.
StatusOr<mlir::OwningModuleRef> ConvertSignatures(); StatusOr<mlir::OwningModuleRef> ConvertSignatures();
Status ConvertSignature( Status ConvertSignature(const GraphDef& graphdef,
const GraphDef& graphdef, const std::string& sig_def_key, const std::string& sig_def_key,
const std::map<std::string, TensorInfo>& inputs_sorted, const SignatureDef& signature_def,
const std::map<std::string, TensorInfo>& outputs_sorted, const GraphDebugInfo& debug_info,
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def);
const FunctionLibraryDefinition& flib_def);
// Creates GlobalTensorOp for each variable and moves each VarHandle op to // Creates GlobalTensorOp for each variable and moves each VarHandle op to
// the enclosing function's arguments. // the enclosing function's arguments.
@ -2948,10 +2948,7 @@ class SavedModelSignatureDefImporter {
const llvm::SmallVectorImpl<mlir::TF::VarHandleOp>& ops); const llvm::SmallVectorImpl<mlir::TF::VarHandleOp>& ops);
GraphImportConfig::InputArrays ParseInputArrays( GraphImportConfig::InputArrays ParseInputArrays(
const std::map<std::string, TensorInfo>& inputs); const std::vector<std::pair<std::string, TensorInfo>>& inputs);
std::vector<std::string> ParseOutputArrays(
const std::map<std::string, TensorInfo>& outputs);
const SavedModelBundle& bundle_; const SavedModelBundle& bundle_;
mlir::OwningModuleRef module_; mlir::OwningModuleRef module_;
@ -2979,14 +2976,8 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
continue; continue;
} }
// protobuf::Map doesn't provide stable iteration order so use std::map TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def,
std::map<std::string, TensorInfo> inputs_sorted( debug_info, flib_def));
signature_def.inputs().begin(), signature_def.inputs().end());
std::map<std::string, TensorInfo> outputs_sorted(
signature_def.outputs().begin(), signature_def.outputs().end());
TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, inputs_sorted,
outputs_sorted, debug_info, flib_def));
} }
TF_RETURN_IF_ERROR(LiftVariables()); TF_RETURN_IF_ERROR(LiftVariables());
@ -2999,13 +2990,26 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
Status SavedModelSignatureDefImporter::ConvertSignature( Status SavedModelSignatureDefImporter::ConvertSignature(
const GraphDef& graphdef, const std::string& sig_def_key, const GraphDef& graphdef, const std::string& sig_def_key,
const std::map<std::string, TensorInfo>& inputs_sorted, const SignatureDef& signature_def, const GraphDebugInfo& debug_info,
const std::map<std::string, TensorInfo>& outputs_sorted,
const GraphDebugInfo& debug_info,
const FunctionLibraryDefinition& flib_def) { const FunctionLibraryDefinition& flib_def) {
// Create local vectors for the inputs and outputs and sort them to be
// deterministic. We don't want anyone to really depend on the order; clients
// should look up the argument/result mapping by attribute name.
// To avoid accidental dependence on the order, we use an unintuitive sorting.
std::vector<std::pair<std::string, TensorInfo>> inputs(
signature_def.inputs().begin(), signature_def.inputs().end());
llvm::sort(inputs, [](const auto& lhs, const auto& rhs) {
return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first;
});
std::vector<std::pair<std::string, TensorInfo>> outputs(
signature_def.outputs().begin(), signature_def.outputs().end());
llvm::sort(outputs, [](const auto& lhs, const auto& rhs) {
return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first;
});
GraphImportConfig specs; GraphImportConfig specs;
specs.inputs = ParseInputArrays(inputs_sorted); specs.inputs = ParseInputArrays(inputs);
specs.outputs = ParseOutputArrays(outputs_sorted); for (auto& output : outputs) specs.outputs.push_back(output.second.name());
// Remove unused nodes and create sub-graphdef. // Remove unused nodes and create sub-graphdef.
GraphDef sub_graph_def; GraphDef sub_graph_def;
@ -3041,11 +3045,11 @@ Status SavedModelSignatureDefImporter::ConvertSignature(
builder.getStrArrayAttr({sig_def_key})); builder.getStrArrayAttr({sig_def_key}));
// Transfer input and output parameter names to index_path attributes. // Transfer input and output parameter names to index_path attributes.
for (auto input_and_idx : llvm::enumerate(inputs_sorted)) { for (auto input_and_idx : llvm::enumerate(inputs)) {
func_op.setArgAttr(input_and_idx.index(), "tf_saved_model.index_path", func_op.setArgAttr(input_and_idx.index(), "tf_saved_model.index_path",
builder.getStrArrayAttr({input_and_idx.value().first})); builder.getStrArrayAttr({input_and_idx.value().first}));
} }
for (auto output_and_idx : llvm::enumerate(outputs_sorted)) { for (auto output_and_idx : llvm::enumerate(outputs)) {
func_op.setResultAttr( func_op.setResultAttr(
output_and_idx.index(), "tf_saved_model.index_path", output_and_idx.index(), "tf_saved_model.index_path",
builder.getStrArrayAttr({output_and_idx.value().first})); builder.getStrArrayAttr({output_and_idx.value().first}));
@ -3180,7 +3184,7 @@ Status SavedModelSignatureDefImporter::ReadVariablesFromSession(
} }
GraphImportConfig::InputArrays SavedModelSignatureDefImporter::ParseInputArrays( GraphImportConfig::InputArrays SavedModelSignatureDefImporter::ParseInputArrays(
const std::map<std::string, TensorInfo>& inputs) { const std::vector<std::pair<std::string, TensorInfo>>& inputs) {
GraphImportConfig::InputArrays results; GraphImportConfig::InputArrays results;
for (const auto& iter : inputs) { for (const auto& iter : inputs) {
const auto& tensor_info = iter.second; const auto& tensor_info = iter.second;
@ -3192,28 +3196,12 @@ GraphImportConfig::InputArrays SavedModelSignatureDefImporter::ParseInputArrays(
array_info.imported_dtype = tensor_info.dtype(); array_info.imported_dtype = tensor_info.dtype();
array_info.shape = tensor_info.tensor_shape(); array_info.shape = tensor_info.tensor_shape();
std::vector<std::string> node_names = results.insert(std::pair<std::string, ArrayInfo>(tensor_info.name(),
absl::StrSplit(tensor_info.name(), ':');
results.insert(std::pair<std::string, ArrayInfo>(node_names.at(0),
std::move(array_info))); std::move(array_info)));
} }
return results; return results;
} }
std::vector<std::string> SavedModelSignatureDefImporter::ParseOutputArrays(
const std::map<std::string, TensorInfo>& outputs) {
std::vector<std::string> results;
for (const auto& iter : outputs) {
const auto& tensor_info = iter.second;
std::vector<std::string> node_names =
absl::StrSplit(tensor_info.name(), ':');
results.push_back(node_names.at(0));
}
return results;
}
} // namespace } // namespace
Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def) { Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def) {


@ -328,6 +328,24 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "buffer_assignment",
srcs = ["transforms/buffer_assignment.cc"],
hdrs = ["transforms/buffer_assignment.h"],
deps = [
":hlo",
":lhlo",
"@com_google_absl//absl/memory",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)
gentbl( gentbl(
name = "xla_legalize_to_standard_inc_gen", name = "xla_legalize_to_standard_inc_gen",
tbl_outs = [ tbl_outs = [


@ -140,7 +140,7 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
instruction_value_map_[hlo_parameter] = block->getArgument(i); instruction_value_map_[hlo_parameter] = block->getArgument(i);
} }
mlir::OpBuilder builder(block); mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
for (auto instruction : computation->MakeInstructionPostOrder()) { for (auto instruction : computation->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(auto new_operation, TF_ASSIGN_OR_RETURN(auto new_operation,
ImportInstruction(instruction, &builder)); ImportInstruction(instruction, &builder));
@ -523,6 +523,32 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a)); attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a));
MakeAndReturn(TriangularSolveOp); MakeAndReturn(TriangularSolveOp);
} }
case HloOpcode::kReduceWindow: {
llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
llvm::SmallVector<int64_t, 8> padding;
for (const auto& dim : instruction->window().dimensions()) {
sizes.push_back(dim.size());
strides.push_back(dim.stride());
base_dilations.push_back(dim.base_dilation());
win_dilations.push_back(dim.window_dilation());
padding.push_back(dim.padding_low());
padding.push_back(dim.padding_high());
}
attributes.push_back(builder_->getNamedAttr("window_dimensions",
ConvertDimensions(sizes)));
attributes.push_back(
builder_->getNamedAttr("window_strides", ConvertDimensions(strides)));
attributes.push_back(builder_->getNamedAttr(
"base_dilations", ConvertDimensions(base_dilations)));
attributes.push_back(builder_->getNamedAttr(
"window_dilations", ConvertDimensions(win_dilations)));
attributes.push_back(ConvertPadding(padding));
auto reduce = func_builder->create<mlir::xla_hlo::ReduceWindowOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(
ImportComputation(instruction->to_apply(), &reduce.body()));
return reduce.getOperation();
}
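// Worked example of the attribute mapping above, taken from the reduce-window
// test added in this commit: an HLO window of
//   size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1
// becomes
//   window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>
//   window_strides    = dense<[1, 4, 4, 1]> : tensor<4xi64>
//   window_dilations  = dense<[1, 2, 2, 1]> : tensor<4xi64>
//   base_dilations    = dense<1> : tensor<4xi64>
//   padding           = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>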
case HloOpcode::kMap: { case HloOpcode::kMap: {
auto op = func_builder->create<mlir::xla_hlo::MapOp>( auto op = func_builder->create<mlir::xla_hlo::MapOp>(
loc, result_type, operands, loc, result_type, operands,


@ -0,0 +1,131 @@
// RUN: tf-opt -test-buffer-assignment -split-input-file %s | FileCheck %s -dump-input-on-failure
// CHECK-LABEL: Testing : condBranch
func @condBranch(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
// CHECK: Alloc: cond_br
cond_br %cond, ^bb1, ^bb2
^bb1:
br ^exit(%arg0 : tensor<2xf32>)
^bb2:
%1 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
br ^exit(%1 : tensor<2xf32>)
^exit(%arg1: tensor<2xf32>):
return %arg1 : tensor<2xf32>
// CHECK-NEXT: Dealloc: return
}
// -----
// CHECK-LABEL: Testing : criticalEdge
func @criticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
// CHECK: Alloc: cond_br
cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>)
^bb1:
%0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
br ^exit(%0 : tensor<2xf32>)
^exit(%arg1: tensor<2xf32>):
return %arg1 : tensor<2xf32>
// CHECK-NEXT: Dealloc: return
}
// -----
// CHECK-LABEL: Testing : invCriticalEdge
func @invCriticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
// CHECK: Alloc: %0 = "xla_hlo.exp"
%0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>)
^bb1:
br ^exit(%0 : tensor<2xf32>)
^exit(%arg1: tensor<2xf32>):
return %arg1 : tensor<2xf32>
// CHECK-NEXT: Dealloc: return
}
// -----
// CHECK-LABEL: Testing : ifElse
func @ifElse(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
// CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1)
%0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>)
^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>):
br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>)
^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>):
br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>)
^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>):
// CHECK-NEXT: Dealloc: %7 = "xla_hlo.exp"(%5)
// CHECK: Alloc: %7 = "xla_hlo.exp"(%5)
// CHECK-NEXT: Dealloc: return
%1 = "xla_hlo.exp"(%arg5) : (tensor<2xf32>) -> tensor<2xf32>
return %1 : tensor<2xf32>
}
// -----
// CHECK-LABEL: Testing : ifElseNoUsers
func @ifElseNoUsers(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
// CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1)
%0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>)
^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>):
br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>)
^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>):
br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>)
^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>):
// CHECK-NEXT: return
return %arg0 : tensor<2xf32>
}
// -----
// CHECK-LABEL: Testing : ifElseNested
func @ifElseNested(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
// CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1)
%0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>)
^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>):
br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>)
^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>):
cond_br %cond, ^bb3(%arg3 : tensor<2xf32>), ^bb4(%arg4 : tensor<2xf32>)
^bb3(%arg7 : tensor<2xf32>):
br ^exit(%arg7, %arg3 : tensor<2xf32>, tensor<2xf32>)
^bb4(%arg8 : tensor<2xf32>):
br ^exit(%arg3, %arg8 : tensor<2xf32>, tensor<2xf32>)
^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>):
// CHECK-NEXT: Dealloc: %9 = "xla_hlo.exp"(%7)
// CHECK: Alloc: %9 = "xla_hlo.exp"(%7)
// CHECK-NEXT: Dealloc: return
%1 = "xla_hlo.exp"(%arg5) : (tensor<2xf32>) -> tensor<2xf32>
return %1 : tensor<2xf32>
}
// -----
// CHECK-LABEL: Testing : redundantOperations
func @redundantOperations(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) {
// CHECK: Alloc: %0 = xla_hlo.maximum
// CHECK-NEXT: Dealloc: %1 = xla_hlo.add
%1 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: Alloc: %1 = xla_hlo.add
// CHECK-NEXT: Dealloc: %1 = xla_hlo.add
%2 = "xla_hlo.add"(%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return
}
// -----
// CHECK-LABEL: Testing : reduce
func @reduce(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
// CHECK: Alloc: %0 = xla_hlo.constant
// CHECK-NEXT: Dealloc: %1 = "xla_hlo.reduce"(%arg0, %0)
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: Alloc: %1 = "xla_hlo.reduce"(%arg0, %0)
// CHECK: Dealloc: return
%2 = "xla_hlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%4 = xla_hlo.add %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
return %2 : tensor<4x8xf32>
}


@ -31,11 +31,12 @@ func @reduce(%arg: memref<100x10x5xf32>,
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]]) // CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32> // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 // CHECK: loop.reduce.return [[ACC_RESULT]] : f32
// CHECK: } // CHECK: }
// CHECK: loop.yield // CHECK: loop.yield
@ -71,11 +72,12 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]]) // CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32> // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: loop.reduce.return [[ACC_RESULT]] // CHECK: loop.reduce.return [[ACC_RESULT]]
// CHECK: } // CHECK: }
// CHECK: loop.yield // CHECK: loop.yield
@ -114,11 +116,12 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]]) // CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32> // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 // CHECK: loop.reduce.return [[ACC_RESULT]] : f32
// CHECK: } // CHECK: }
// CHECK: loop.yield // CHECK: loop.yield
@ -185,11 +188,12 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32> // CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32> // CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_BUF]]) // CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_BUF]][] : memref<f32> // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 // CHECK: loop.reduce.return [[ACC_RESULT]] : f32
// CHECK: } // CHECK: }
// CHECK: loop.yield // CHECK: loop.yield


@ -698,6 +698,24 @@ add {
ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.1, %sub.5) ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.1, %sub.5)
} }
// CHECK-LABEL: func @test_reduce_window
// CHECK-SAME: ([[ARG0:%.*]]: tensor<2x17x31x7xf32>, [[ARG1:%.*]]: tensor<f32>)
%test_reduce_window (Arg_0.1: f32[2,17,31,7], Arg_1.2: f32[]) -> f32[2,5,8,7] {
%Arg_0.1 = f32[2,17,31,7] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK: "xla_hlo.reduce_window"([[ARG0]], [[ARG1]]) ( {
// CHECK: xla_hlo.add {{.*}} : tensor<f32>
// CHECK: }) {
// CHECK-SAME: base_dilations = dense<1> : tensor<4xi64>
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>
// CHECK-SAME: window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>
// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>
// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>
// CHECK_SAME: }
ROOT %reduce-window.1 = f32[2,5,8,7] reduce-window(f32[2,17,31,7] %Arg_0.1, f32[] %Arg_1.2), window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1}, to_apply=%reduce_helper.3
}
// CHECK-LABEL: func @test_remainder // CHECK-LABEL: func @test_remainder
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xf32>, [[VAL_1:%.*]]: tensor<4xf32>) // CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xf32>, [[VAL_1:%.*]]: tensor<4xf32>)
%test_remainder (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { %test_remainder (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {


@ -0,0 +1,501 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for computing proper alloc and dealloc positions.
// The main class is the BufferAssignment class that realizes this analysis.
// In order to put allocations and deallocations at safe positions, it is
// critically important to place them in the proper blocks. However, the
// liveness analysis does not pay attention to aliases, which can occur due to
// branches (and their associated block arguments) in general. For this purpose,
// BufferAssignment first finds all possible aliases for a single value (using
// the BufferAssignmentAliasAnalysis class). Consider the following example:
//
// ^bb0(%arg0):
// cond_br %cond, ^bb1, ^bb2
// ^bb1:
// br ^exit(%arg0)
// ^bb2:
// %new_value = ...
// br ^exit(%new_value)
// ^exit(%arg1):
// return %arg1;
//
// Using liveness information on its own would cause us to place the allocs and
// deallocs in the wrong block. This is due to the fact that %new_value will not
// be liveOut of its block. Instead, we have to place the alloc for %new_value
// in bb0 and its associated dealloc in exit. Using the class
// BufferAssignmentAliasAnalysis, we will find out that %new_value has a
// potential alias %arg1. In order to find the dealloc position we have to find
// all potential aliases, iterate over their uses and find the common
// post-dominator block. In this block we can safely be sure that %new_value
// will die and can use liveness information to determine the exact operation
// after which we have to insert the dealloc. Finding the alloc position is
// highly similar and non-obvious. Again, we have to consider all potential
// aliases and find the common dominator block to place the alloc.
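// Applying this to the example above: the aliases of %new_value are
// {%new_value, %arg1}, whose uses live in ^bb2 (the branch) and ^exit (the
// return). The nearest common dominator of the defining block and all blocks
// with uses is ^bb0, so the alloc goes there; the common post-dominator is
// ^exit, so the dealloc is inserted after the last use in that block.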
//
// TODO(dfki):
// The current implementation does not support loops. The only thing that
// is currently missing is a high-level loop analysis that allows us to move
// allocs and deallocs outside of the loop blocks.
#include "tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "absl/memory/memory.h"
namespace mlir {
namespace xla {
namespace {
//===----------------------------------------------------------------------===//
// BufferAssignmentAliasAnalysis
//===----------------------------------------------------------------------===//
/// A straightforward alias analysis which ensures that all aliases of all
/// values will be determined. This is a requirement for the BufferAssignment
/// class since you need to determine safe positions to place allocs and
/// deallocs.
class BufferAssignmentAliasAnalysis {
public:
using ValueSetT = SmallPtrSet<Value, 16>;
public:
/// Constructs a new alias analysis using the op provided.
BufferAssignmentAliasAnalysis(Operation* op) { build(op->getRegions()); }
/// Finds all immediate and indirect aliases this value could potentially
/// have. Note that the resulting set will also contain the value provided as
/// it is an alias of itself.
ValueSetT resolve(Value value) const {
ValueSetT result;
resolveRecursive(value, result);
return result;
}
private:
/// Recursively determines alias information for the given value. It stores
/// all newly found potential aliases in the given result set.
void resolveRecursive(Value value, ValueSetT& result) const {
if (!result.insert(value).second) {
return;
}
auto it = aliases.find(value);
if (it == aliases.end()) return;
for (auto alias : it->second) {
resolveRecursive(alias, result);
}
}
/// This function constructs a mapping from values to their immediate aliases.
/// It iterates over all blocks, gets their predecessors, determines the
/// values that will be passed to the corresponding block arguments and
/// inserts them into the map.
void build(MutableArrayRef<Region> regions) {
for (Region& region : regions) {
for (Block& block : region) {
// Iterate over all predecessors and determine the values they pass to the
// corresponding block arguments.
for (auto pred : block.getPredecessors()) {
// Determine the successor index of the current predecessor.
unsigned successorIndex = std::distance(
pred->getSuccessors().begin(),
llvm::find_if(pred->getSuccessors(), [&](Block* successor) {
return successor == &block;
}));
// Get the terminator and the values that will be passed to our block.
if (auto branchInterface =
dyn_cast<BranchOpInterface>(pred->getTerminator())) {
// Query the branch op interface to get the successor operands.
auto successorOps =
branchInterface.getSuccessorOperands(successorIndex);
if (successorOps.hasValue()) {
// Build the actual mapping of values to their immediate aliases.
for (auto arg : block.getArguments()) {
Value predecessorArgValue =
successorOps.getValue()[arg.getArgNumber()];
aliases[predecessorArgValue].insert(arg);
}
}
}
}
}
}
}
/// Maps values to all immediate aliases this value can have.
llvm::DenseMap<Value, ValueSetT> aliases;
};
//===----------------------------------------------------------------------===//
// BufferAssignmentPositions
//===----------------------------------------------------------------------===//
/// Stores proper alloc and dealloc positions to place dialect-specific alloc
/// and dealloc operations.
struct BufferAssignmentPositions {
public:
BufferAssignmentPositions()
: allocPosition(nullptr), deallocPosition(nullptr) {}
/// Creates a new positions tuple including alloc and dealloc positions.
BufferAssignmentPositions(Operation* allocPosition,
Operation* deallocPosition)
: allocPosition(allocPosition), deallocPosition(deallocPosition) {}
/// Returns the alloc position before which the alloc operation has to be
/// inserted.
Operation* getAllocPosition() const { return allocPosition; }
/// Returns the dealloc position after which the dealloc operation has to be
/// inserted.
Operation* getDeallocPosition() const { return deallocPosition; }
private:
Operation* allocPosition;
Operation* deallocPosition;
};
//===----------------------------------------------------------------------===//
// BufferAssignmentAnalysis
//===----------------------------------------------------------------------===//
// The main buffer assignment analysis used to place allocs and deallocs.
class BufferAssignmentAnalysis {
public:
using DeallocSetT = SmallPtrSet<Operation*, 2>;
public:
BufferAssignmentAnalysis(Operation* op)
: operation(op),
liveness(op),
dominators(op),
postDominators(op),
aliases(op) {}
/// Computes the actual positions to place allocs and deallocs for the given
/// value.
BufferAssignmentPositions computeAllocAndDeallocPositions(Value value) const {
if (value.use_empty()) {
return BufferAssignmentPositions(value.getDefiningOp(),
value.getDefiningOp());
}
// Get all possible aliases
auto possibleValues = aliases.resolve(value);
return BufferAssignmentPositions(getAllocPosition(value, possibleValues),
getDeallocPosition(value, possibleValues));
}
/// Finds all associated dealloc nodes for the given alloc node using alias
/// information.
DeallocSetT findAssociatedDeallocs(AllocOp alloc) const {
DeallocSetT result;
auto possibleValues = aliases.resolve(alloc);
for (auto alias : possibleValues) {
for (auto user : alias.getUsers()) {
if (isa<DeallocOp>(user)) result.insert(user);
}
}
return result;
}
/// Dumps the buffer assignment information to the given stream.
void print(raw_ostream& os) const {
os << "// ---- Buffer Assignment -----\n";
for (Region& region : operation->getRegions())
for (Block& block : region)
for (Operation& operation : block)
for (Value result : operation.getResults()) {
BufferAssignmentPositions positions =
computeAllocAndDeallocPositions(result);
os << "Positions for ";
result.print(os);
os << "\n Alloc: ";
positions.getAllocPosition()->print(os);
os << "\n Dealloc: ";
positions.getDeallocPosition()->print(os);
os << "\n";
}
}
private:
/// Finds a proper block in which to place an alloc/dealloc node according to
/// the algorithm described at the top of the file. It supports dominator and
/// post-dominator analyses via template arguments.
template <typename AliasesT, typename DominatorT>
Block* findPlacementBlock(Value value, const AliasesT& aliases,
const DominatorT& doms) const {
assert(!value.isa<BlockArgument>() && "Cannot place a block argument");
// Start with the current block the value is defined in.
Block* dom = value.getDefiningOp()->getBlock();
// Iterate over all aliases and their uses to find a safe placement block
// according to the given dominator information.
for (auto alias : aliases) {
for (auto user : alias.getUsers()) {
// Move upwards in the dominator tree to find an appropriate
// dominator block that takes the current use into account.
dom = doms.findNearestCommonDominator(dom, user->getBlock());
}
}
return dom;
}
/// Finds a proper alloc position according to the algorithm described at the
/// top of the file.
template <typename AliasesT>
Operation* getAllocPosition(Value value, const AliasesT& aliases) const {
// Determine the actual block to place the alloc and get liveness
// information.
auto placementBlock = findPlacementBlock(value, aliases, dominators);
auto livenessInfo = liveness.getLiveness(placementBlock);
// We have to ensure that the alloc will be before the first use of all
// aliases of the given value. We first assume that there are no uses in the
// placementBlock and that we can safely place the alloc before the
// terminator at the end of the block.
Operation* startOperation = placementBlock->getTerminator();
// Iterate over all aliases and ensure that the startOperation will point to
// the first operation of all potential aliases in the placementBlock.
for (auto alias : aliases) {
auto aliasStartOperation = livenessInfo->getStartOperation(alias);
// Check whether the aliasStartOperation lies in the desired block and
// whether it is before the current startOperation. If yes, this will be
// the new startOperation.
if (aliasStartOperation->getBlock() == placementBlock &&
aliasStartOperation->isBeforeInBlock(startOperation)) {
startOperation = aliasStartOperation;
}
}
// startOperation is the first operation before which we can safely insert
// the alloc, taking all potential aliases into account.
return startOperation;
}
/// Finds a proper dealloc position according to the algorithm described at
/// the top of the file.
template <typename AliasesT>
Operation* getDeallocPosition(Value value, const AliasesT& aliases) const {
// Determine the actual block to place the dealloc and get liveness
// information.
auto placementBlock = findPlacementBlock(value, aliases, postDominators);
auto livenessInfo = liveness.getLiveness(placementBlock);
// We have to ensure that the dealloc will be after the last use of all
// aliases of the given value. We first assume that there are no uses in the
// placementBlock and that we can safely place the dealloc at the beginning.
Operation* endOperation = &placementBlock->front();
// Iterate over all aliases and ensure that the endOperation will point to
// the last operation of all potential aliases in the placementBlock.
for (auto alias : aliases) {
auto aliasEndOperation =
livenessInfo->getEndOperation(alias, endOperation);
// Check whether the aliasEndOperation lies in the desired block and
// whether it is behind the current endOperation. If yes, this will be the
// new endOperation.
if (aliasEndOperation->getBlock() == placementBlock &&
endOperation->isBeforeInBlock(aliasEndOperation)) {
endOperation = aliasEndOperation;
}
}
// endOperation is the last operation after which we can safely insert the
// dealloc, taking all potential aliases into account.
return endOperation;
}
/// The operation this transformation was constructed from.
Operation* operation;
/// The underlying liveness analysis to compute fine grained information about
/// alloc and dealloc positions.
Liveness liveness;
/// The dominator analysis to place allocs in the appropriate blocks.
DominanceInfo dominators;
/// The post dominator analysis to place deallocs in the appropriate blocks.
PostDominanceInfo postDominators;
/// The internal alias analysis to ensure that allocs and deallocs take all
/// their potential aliases into account.
BufferAssignmentAliasAnalysis aliases;
};
//===----------------------------------------------------------------------===//
// BufferAssignmentPass
//===----------------------------------------------------------------------===//
/// The actual buffer assignment pass that moves alloc and dealloc nodes into
/// the right positions. It uses the algorithm described at the top of the file.
// TODO(dfki): create a templated version that allows matching dialect-specific
// alloc/dealloc nodes and inserting dialect-specific dealloc nodes.
struct BufferAssignmentPass : mlir::FunctionPass<BufferAssignmentPass> {
void runOnFunction() override {
// Get required analysis information first.
auto& analysis = getAnalysis<BufferAssignmentAnalysis>();
// Compute an initial placement of all nodes.
llvm::SmallDenseMap<Value, BufferAssignmentPositions, 16> placements;
getFunction().walk([&](AllocOp alloc) {
placements[alloc] = analysis.computeAllocAndDeallocPositions(alloc);
});
// Move alloc (and dealloc - if any) nodes into the right places
// and insert dealloc nodes if necessary.
getFunction().walk([&](AllocOp alloc) {
// Find already associated dealloc nodes.
auto deallocs = analysis.findAssociatedDeallocs(alloc);
assert(deallocs.size() < 2 &&
"Not supported number of associated dealloc operations");
// Move alloc node to the right place.
BufferAssignmentPositions& positions = placements[alloc];
Operation* allocOperation = alloc.getOperation();
allocOperation->moveBefore(positions.getAllocPosition());
// If there is an existing dealloc, move it to the right place.
if (deallocs.size()) {
Operation* nextOp = positions.getDeallocPosition()->getNextNode();
if (!nextOp)
nextOp = &positions.getDeallocPosition()->getBlock()->back();
(*deallocs.begin())->moveBefore(nextOp);
} else {
// If there is no dealloc node, insert one in the right place.
OpBuilder builder(alloc);
builder.setInsertionPointAfter(positions.getDeallocPosition());
builder.create<DeallocOp>(allocOperation->getLoc(), alloc);
}
});
};
};
} // namespace
//===----------------------------------------------------------------------===//
// BufferAssignmentPlacer
//===----------------------------------------------------------------------===//
/// Creates a new assignment placer.
BufferAssignmentPlacer::BufferAssignmentPlacer(Operation* op)
: operation(op), dominators(op) {}
/// Computes the actual position to place allocs for the given value.
OpBuilder::InsertPoint BufferAssignmentPlacer::computeAllocPosition(
Value value) {
Operation* insertOp;
if (auto arg = value.dyn_cast<BlockArgument>()) {
// This is a block argument which has to be allocated in the scope
// of its associated terminator.
auto domNode = dominators.getNode(arg.getOwner());
assert(domNode != nullptr && "Cannot find dominator info");
auto idomNode = domNode->getIDom();
assert(idomNode != nullptr && "There is no parent dominator");
insertOp = idomNode->getBlock()->getTerminator();
} else {
insertOp = value.getDefiningOp();
}
OpBuilder opBuilder(insertOp);
return opBuilder.saveInsertionPoint();
}
//===----------------------------------------------------------------------===//
// FunctionAndBlockSignatureConverter
//===----------------------------------------------------------------------===//
// Performs the actual signature rewriting step.
LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite(
FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const {
auto toMemrefConverter = [&](Type t) -> Type {
if (auto tensorType = t.dyn_cast<RankedTensorType>()) {
return MemRefType::get(tensorType.getShape(),
tensorType.getElementType());
}
return t;
};
// Converting tensor-type function arguments to memref-type.
auto funcType = funcOp.getType();
TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
for (auto argType : llvm::enumerate(funcType.getInputs())) {
conversion.addInputs(argType.index(), toMemrefConverter(argType.value()));
}
for (auto resType : funcType.getResults()) {
conversion.addInputs(toMemrefConverter(resType));
}
rewriter.updateRootInPlace(funcOp, [&] {
funcOp.setType(
rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None));
rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
});
// Converting tensor-type block arguments of all blocks inside the
// function region to memref-type except for the entry block.
for (auto& block : funcOp.getBlocks()) {
if (block.isEntryBlock()) continue;
for (int i = 0, e = block.getNumArguments(); i < e; ++i) {
auto oldArg = block.getArgument(i);
auto newArg =
block.insertArgument(i, toMemrefConverter(oldArg.getType()));
oldArg.replaceAllUsesWith(newArg);
block.eraseArgument(i + 1);
}
}
return success();
}
// Adding functions whose arguments are memref type to the set of legal
// operations.
void FunctionAndBlockSignatureConverter::addDynamicallyLegalFuncOp(
ConversionTarget& target) {
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
auto inputs = op.getType().getInputs();
return std::all_of(inputs.begin(), inputs.end(),
[](Type input) { return input.isa<MemRefType>(); });
});
}
//===----------------------------------------------------------------------===//
// Buffer assignment pass registrations
//===----------------------------------------------------------------------===//
std::unique_ptr<OpPassBase<FuncOp>> createBufferAssignmentPass() {
return absl::make_unique<BufferAssignmentPass>();
}
static PassRegistration<BufferAssignmentPass> buffer_assignment_pass(
"buffer-assignment",
"Executes buffer assignment pass to automatically move alloc and dealloc "
"operations into their proper positions");
/// A simple pass to print debug/test information for the buffer assignment
/// analysis.
struct BufferAssignmentTestPass : mlir::FunctionPass<BufferAssignmentTestPass> {
void runOnFunction() override {
llvm::outs() << "Testing : " << getFunction().getName() << "\n";
getAnalysis<BufferAssignmentAnalysis>().print(llvm::outs());
};
};
std::unique_ptr<OpPassBase<FuncOp>> createBufferAssignmentTestPass() {
return absl::make_unique<BufferAssignmentTestPass>();
}
static PassRegistration<BufferAssignmentTestPass> buffer_assignment_test_pass(
"test-buffer-assignment",
"Outputs debug test information for the buffer assignment analysis");
} // namespace xla
} // namespace mlir


@ -0,0 +1,140 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_
#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project
namespace mlir {
namespace xla {
/// Prepares a buffer assignment phase. It can place (user-defined) alloc
/// nodes. This simplifies the integration of the actual buffer-assignment
/// pass. Sample usage:
/// BufferAssignmentPlacer baHelper(regionOp);
/// -> determine alloc positions
/// auto allocPosition = baHelper.computeAllocPosition(value);
/// -> place alloc
/// allocBuilder.setInsertionPoint(positions.getAllocPosition());
/// <create alloc>
/// alternatively:
/// -> place alloc
/// baHelper.insertAlloc<AllocOp>(...);
/// Note: this class is intended to be used during legalization. In order
/// to move alloc and dealloc nodes into the right places you can use the
/// createBufferAssignmentPass() function.
class BufferAssignmentPlacer {
public:
/// Creates a new assignment placer.
explicit BufferAssignmentPlacer(Operation* op);
/// Returns the operation this analysis was constructed from.
Operation* getOperation() const { return operation; }
/// Computes the actual position to place allocs for the given value.
OpBuilder::InsertPoint computeAllocPosition(Value value);
private:
/// The operation this analysis was constructed from.
Operation* operation;
/// The dominator analysis to place allocs in the appropriate blocks.
DominanceInfo dominators;
};
/// Helper conversion pattern that encapsulates a BufferAssignmentPlacer
/// instance.
template <typename SourceOp>
class BufferAssignmentOpConversionPattern
: public OpConversionPattern<SourceOp> {
public:
explicit BufferAssignmentOpConversionPattern(
MLIRContext* context_,
xla::BufferAssignmentPlacer* bufferAssignment_ = nullptr,
PatternBenefit benefit_ = 1)
: OpConversionPattern<SourceOp>(context_, benefit_),
bufferAssignment(bufferAssignment_) {}
protected:
xla::BufferAssignmentPlacer* bufferAssignment;
};
// Converts only the tensor-type function and block arguments to memref-type.
class FunctionAndBlockSignatureConverter
: public BufferAssignmentOpConversionPattern<FuncOp> {
public:
using BufferAssignmentOpConversionPattern<
FuncOp>::BufferAssignmentOpConversionPattern;
// Adding functions whose arguments are memref type to the set of legal
// operations.
static void addDynamicallyLegalFuncOp(ConversionTarget& target);
// Performs the actual signature rewriting step.
LogicalResult matchAndRewrite(
FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final;
};
// This pattern converter transforms a non-void ReturnOpSourceTy into a void
// return of type ReturnOpTargetTy. It uses a copy operation of type CopyOpTy to
// copy the results to the output buffer.
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
typename CopyOpTy>
class NonVoidToVoidReturnOpConverter
: public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
public:
using BufferAssignmentOpConversionPattern<
ReturnOpSourceTy>::BufferAssignmentOpConversionPattern;
// Performs the actual return-op conversion step.
LogicalResult matchAndRewrite(
ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto numReturnValues = returnOp.getNumOperands();
auto funcOp = returnOp.template getParentOfType<FuncOp>();
auto numFuncArgs = funcOp.getNumArguments();
auto loc = returnOp.getLoc();
// Find the corresponding output buffer for each operand.
for (auto operand : llvm::enumerate(operands)) {
auto returnArgNumber = numFuncArgs - numReturnValues + operand.index();
auto dstBuffer = funcOp.getArgument(returnArgNumber);
if (dstBuffer == operand.value()) {
continue;
}
// Insert a copy into the output buffer before the return.
rewriter.setInsertionPoint(
returnOp.getOperation()->getBlock()->getTerminator());
rewriter.create<CopyOpTy>(loc, operand.value(),
funcOp.getArgument(returnArgNumber));
}
// Insert the new target return operation.
rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp);
return success();
}
};
} // namespace xla
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_
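For orientation, a sketch of how a lowering pattern built on the BufferAssignmentOpConversionPattern above might use the placer. The op name SomeHloOp, the scalar f32 buffer, and the pattern itself are illustrative assumptions; only computeAllocPosition, the standard-dialect AllocOp, and the usual OpBuilder/ConversionPatternRewriter insertion-point API are relied on, and the sketch is assumed to live in the same mlir::xla namespace and headers as above:

struct ExampleLoweringPattern
    : public BufferAssignmentOpConversionPattern<SomeHloOp> {
  using BufferAssignmentOpConversionPattern<
      SomeHloOp>::BufferAssignmentOpConversionPattern;

  LogicalResult matchAndRewrite(
      SomeHloOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    // Ask the placer where the result buffer has to live.
    OpBuilder::InsertPoint alloc_position =
        bufferAssignment->computeAllocPosition(op.getResult());
    Value buffer;
    {
      // Create the buffer at the position chosen by the placer, then restore
      // the rewriter to where the op is being rewritten.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.restoreInsertionPoint(alloc_position);
      buffer = rewriter
                   .create<AllocOp>(op.getLoc(),
                                    MemRefType::get({}, rewriter.getF32Type()))
                   .getResult();
    }
    // A real pattern would now emit the buffer-based op; here the result is
    // simply replaced by the buffer.
    rewriter.replaceOp(op, {buffer});
    return success();
  }
};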


@ -324,8 +324,8 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
// "xla_lhlo.terminator"() : () -> () // "xla_lhlo.terminator"() : () -> ()
// } // }
struct HloLegalizeToLhlo : public ModulePass<HloLegalizeToLhlo> { struct HloLegalizeToLhlo : public OperationPass<HloLegalizeToLhlo, ModuleOp> {
void runOnModule() override { void runOnOperation() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto& context = getContext(); auto& context = getContext();
ConversionTarget target(context); ConversionTarget target(context);
@ -344,7 +344,7 @@ struct HloLegalizeToLhlo : public ModulePass<HloLegalizeToLhlo> {
[](Type input) { return input.isa<MemRefType>(); }); [](Type input) { return input.isa<MemRefType>(); });
}); });
auto module = getModule(); auto module = getOperation();
populateHLOToLHLOConversionPattern(module.getContext(), &patterns); populateHLOToLHLOConversionPattern(module.getContext(), &patterns);
// Do partial conversion so we can have unknown ops in tests. // Do partial conversion so we can have unknown ops in tests.


@ -51,9 +51,10 @@ using mlir::PassRegistration;
namespace mlir { namespace mlir {
namespace xla_hlo { namespace xla_hlo {
namespace { namespace {
class LegalizeTFControlFlow : public ModulePass<LegalizeTFControlFlow> { class LegalizeTFControlFlow
: public OperationPass<LegalizeTFControlFlow, ModuleOp> {
public: public:
void runOnModule() override; void runOnOperation() override;
}; };
} // namespace } // namespace
@ -164,8 +165,8 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) {
} }
} // namespace } // namespace
void LegalizeTFControlFlow::runOnModule() { void LegalizeTFControlFlow::runOnOperation() {
auto module = getModule(); auto module = getOperation();
module.walk([&](TF::WhileOp op) -> void { LowerWhile(op, module); }); module.walk([&](TF::WhileOp op) -> void { LowerWhile(op, module); });
module.walk([&](TF::IfOp op) -> void { LowerIf(op, module); }); module.walk([&](TF::IfOp op) -> void { LowerIf(op, module); });


@ -29,48 +29,128 @@ namespace mlir {
namespace xla_lhlo { namespace xla_lhlo {
namespace { namespace {
// Clones and adapts the code in `lhlo_block` that works on buffers and has a
// single output buffer to make it compatible with `operands` that have element
// types of the respective buffers. Returns the computed value.
//
// Example. For `operands` with (f32, i32) types and a block with LHLO ops and
// with signature:
// ^bb(%lhs: memref<f32>, %rhs: memref<i32>, %res: memref<i1>):
// <LHLO_ops>
//
// inserts necessary alloc and store ops to compute and return result that has
// `i1` type.
Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
Block* lhlo_block, OpBuilder* b) {
SmallVector<Value, 2> arg_bufs;
for (auto arg_type : lhlo_block->getArgumentTypes()) {
arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>()));
}
for (auto operand : llvm::enumerate(operands)) {
b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
}
// Clone the ops from `lhlo_block`.
BlockAndValueMapping mapping;
mapping.map(lhlo_block->getArguments(), arg_bufs);
for (auto& nested : lhlo_block->without_terminator()) {
auto clone = b->clone(nested, mapping);
mapping.map(nested.getResults(), clone->getResults());
}
return b->create<LoadOp>(loc, arg_bufs.back());
}
// Converts a block with LHLO ops and with signature: // Converts a block with LHLO ops and with signature:
// ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>): // ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// into a reduction operator of loop.reduce by doing buffer allocation for // into a reduction operator of loop.reduce by doing buffer allocation for
// scalar arguments and the result of `loop.reduce` to make it compatible with // scalar arguments and the result of `loop.reduce` to make it compatible with
// LHLO ops. // LHLO ops.
 void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op,
-                                Block* lhlo_block,
-                                ConversionPatternRewriter* rewriter) {
+                                Block* lhlo_block, OpBuilder* b) {
   Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
-  rewriter->setInsertionPointToStart(&loop_reduce_op_body);
-  // Allocate buffers to hold arguments of reduction operator block to stay
-  // compatible with the LHLO dialect ops in the reduction body.
-  Value elem_arg = lhlo_block->getArgument(0);
-  Value elem_buf =
-      rewriter->create<AllocOp>(loc, elem_arg.getType().cast<MemRefType>());
-  rewriter->create<StoreOp>(loc, loop_reduce_op_body.getArgument(0), elem_buf);
-  Value acc_arg = lhlo_block->getArgument(1);
-  Value acc_buf =
-      rewriter->create<AllocOp>(loc, acc_arg.getType().cast<MemRefType>());
-  rewriter->create<StoreOp>(loc, loop_reduce_op_body.getArgument(1), acc_buf);
-  // Clone the ops from `xla_lhlo.reduce` into reduction operator block.
-  BlockAndValueMapping mapping;
-  mapping.map(lhlo_block->getArguments(),
-              ValueRange{elem_buf, acc_buf, acc_buf});
-  for (auto& nested : lhlo_block->without_terminator()) {
-    auto clone = rewriter->clone(nested, mapping);
-    mapping.map(nested.getResults(), clone->getResults());
-  }
-  Value acc_result = rewriter->create<LoadOp>(loc, acc_buf);
-  rewriter->create<loop::ReduceReturnOp>(loc, acc_result);
+  OpBuilder::InsertionGuard guard(*b);
+  b->setInsertionPointToStart(&loop_reduce_op_body);
+  b->create<loop::ReduceReturnOp>(
+      loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
+                                     lhlo_block, b));
 }
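A hypothetical call site for ConvertToReductionOperator, sketched under the assumption that `loop_reduce_op` was just created with a suitable init value (so its reduction region already carries the two scalar block arguments) and `reduce_body` is the body block of the original xla_lhlo.reduce:

// Rewrites loop_reduce_op's reduction region in place: the scalar block
// arguments are spilled to scratch buffers, the LHLO body is cloned onto
// them, and the reloaded result feeds loop.reduce.return.
ConvertToReductionOperator(loc, loop_reduce_op, reduce_body, &b);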
// Returns result of ConstantOp if `dim` is static, otherwise uses DimOp to // Returns result of ConstantOp if `dim` is static, otherwise uses DimOp to
// extract dimension at runtime. // extract dimension at runtime.
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value, Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
size_t dim_index, int64_t dim, size_t dim_index, int64_t dim, OpBuilder* b) {
ConversionPatternRewriter* rewriter) {
return dim == ShapedType::kDynamicSize return dim == ShapedType::kDynamicSize
? rewriter->create<DimOp>(loc, shaped_value, dim_index).getResult() ? b->create<DimOp>(loc, shaped_value, dim_index).getResult()
: rewriter->create<ConstantIndexOp>(loc, dim); : b->create<ConstantIndexOp>(loc, dim);
}
struct MappedIvs {
// False if the mapped indices are in the padding area, true otherwise.
Value in_bounds;
// Mapped indices.
SmallVector<Value, 2> ivs;
};
MappedIvs MapWindowIvsToInput(ReduceWindowOp op, ValueRange ivs,
ValueRange window_ivs, OpBuilder* b) {
MappedIvs mapped_ivs;
if (!op.window_strides().hasValue()) {
op.emitOpError("No window strides specified.");
}
auto window_strides = op.window_strides().getValue();
if (!op.padding().hasValue()) {
op.emitOpError("No padding specified.");
}
auto padding = op.padding().getValue();
auto loc = op.getLoc();
auto operand = op.operand();
auto operand_shape = operand.getType().cast<MemRefType>().getShape();
// `in_bounds` is false when the mapped indices are in the padding area.
mapped_ivs.in_bounds = b->create<mlir::ConstantOp>(
loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
for (unsigned i = 0, e = ivs.size(); i < e; ++i) {
auto stride = window_strides.getValue<llvm::APInt>(i);
auto pad_low = padding.getValue<llvm::APInt>({i, 0});
Value stride_val = b->create<ConstantIndexOp>(loc, stride.getSExtValue());
Value pad_low_val = b->create<ConstantIndexOp>(loc, pad_low.getSExtValue());
Value center = b->create<MulIOp>(loc, ivs[i], stride_val);
Value offset = b->create<SubIOp>(loc, window_ivs[i], pad_low_val);
Value index = b->create<AddIOp>(loc, center, offset);
Value upper_bound =
GetStaticOrDynamicDim(loc, operand, i, operand_shape[i], b);
// We must check whether 0 <= index_i < shape_i, as otherwise we are in
// the pad and then we have to use the neutral element for reduction.
// Equivalently, it can be computed as the unsigned comparison index_i <
// shape_i, since a negative value wraps to a large positive value.
mapped_ivs.in_bounds = b->create<mlir::AndOp>(
loc, mapped_ivs.in_bounds,
b->create<CmpIOp>(loc, CmpIPredicate::ult, index, upper_bound));
mapped_ivs.ivs.push_back(index);
}
return mapped_ivs;
}
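The mapping above is just `index = iv * stride + window_iv - pad_low`, guarded by a single unsigned comparison. A tiny standalone check of that trick, with made-up numbers that are not taken from this pass:

#include <cassert>
#include <cstdint>

int main() {
  const int64_t stride = 2, pad_low = 1, dim_size = 8;
  // Output iv 0 and window iv 0 land in the low padding: index is -1.
  const int64_t index = 0 * stride + 0 - pad_low;
  // One unsigned compare covers both 0 <= index and index < dim_size,
  // because a negative index wraps to a huge unsigned value.
  const bool in_bounds =
      static_cast<uint64_t>(index) < static_cast<uint64_t>(dim_size);
  assert(!in_bounds);
  return 0;
}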
// Returns a loop::ParallelOp over a shaped value with static or dynamic shape.
loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
OpBuilder* b) {
Value zero = b->create<ConstantIndexOp>(loc, 0);
Value one = b->create<ConstantIndexOp>(loc, 1);
ArrayRef<int64_t> shape =
shaped_value.getType().cast<ShapedType>().getShape();
SmallVector<Value, 2> lower, upper, step;
for (auto dim : llvm::enumerate(shape)) {
upper.push_back(
GetStaticOrDynamicDim(loc, shaped_value, dim.index(), dim.value(), b));
lower.push_back(zero);
step.push_back(one);
}
return b->create<loop::ParallelOp>(loc, lower, upper, step);
} }
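A hypothetical use of MakeLoopOverShape, assuming `out_buf` is a memref-typed Value and `b` is an OpBuilder positioned where the nest should be emitted; this mirrors how the reduce-window lowering below drives it:

// One parallel dimension per memref dimension; dynamic extents are recovered
// with DimOp, static ones become ConstantIndexOp.
loop::ParallelOp nest = MakeLoopOverShape(loc, out_buf, &b);
b.setInsertionPointToStart(nest.getBody());
// Per-element code indexes `out_buf` with nest.getInductionVars() here.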
// Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp. // Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp.
@ -186,7 +266,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
SmallVector<Value, 1> out_indices; SmallVector<Value, 1> out_indices;
if (outer != nullptr) { if (outer != nullptr) {
out_indices.reserve(outer.getNumLoops()); out_indices.reserve(outer.getNumLoops());
for (auto& iv : outer.getInductionVars()) { for (Value iv : outer.getInductionVars()) {
out_indices.push_back(iv); out_indices.push_back(iv);
} }
} else { } else {
@ -198,12 +278,16 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// Load the element to reduce. // Load the element to reduce.
     SmallVector<Value, 2> indices;
     indices.reserve(operand_shape.size());
-    Block::args_iterator outer_ivs_it =
-        outer ? outer.getInductionVars().begin() : nullptr;
-    Block::args_iterator inner_ivs_it = inner.getInductionVars().begin();
-    for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
-      indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
-                                               : *outer_ivs_it++);
+    if (outer) {
+      auto inner_ivs_it = inner.getInductionVars().begin();
+      auto outer_ivs_it = outer.getInductionVars().begin();
+      for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
+        indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
+                                                 : *outer_ivs_it++);
+      }
+    } else {
+      indices = ValueRange(inner.getInductionVars());
     }
rewriter->setInsertionPointToStart(inner.getBody()); rewriter->setInsertionPointToStart(inner.getBody());
@ -309,20 +393,11 @@ class ReduceWindowOpConverter
// Create an outer parallel loop that spans the output of ReduceWindowOp. // Create an outer parallel loop that spans the output of ReduceWindowOp.
Value xla_output = xla_reduce_window_op.out(); Value xla_output = xla_reduce_window_op.out();
auto output_shape = xla_output.getType().cast<MemRefType>().getShape(); auto output_loop = MakeLoopOverShape(loc, xla_output, rewriter);
SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
for (auto dim : llvm::enumerate(output_shape)) {
parallel_upper.push_back(GetStaticOrDynamicDim(
loc, xla_output, dim.index(), dim.value(), rewriter));
parallel_lower.push_back(zero);
parallel_step.push_back(one);
}
auto output_loop = rewriter->create<loop::ParallelOp>(
loc, parallel_lower, parallel_upper, parallel_step);
// Create a nested loop that traverses the window. // Create a nested loop that traverses the window.
rewriter->setInsertionPointToStart(output_loop.getBody());
SmallVector<Value, 2> window_lower, window_upper, window_step; SmallVector<Value, 2> window_lower, window_upper, window_step;
rewriter->setInsertionPointToStart(output_loop.getBody());
for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) { for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) {
window_step.push_back(one); window_step.push_back(one);
window_lower.push_back(zero); window_lower.push_back(zero);
@ -334,9 +409,8 @@ class ReduceWindowOpConverter
Value reduction_result = *window_loop.getResults().begin(); Value reduction_result = *window_loop.getResults().begin();
auto output_ivs = output_loop.getInductionVars(); auto output_ivs = output_loop.getInductionVars();
rewriter->create<StoreOp>( rewriter->create<StoreOp>(loc, reduction_result, xla_output,
loc, reduction_result, xla_output, ValueRange{output_ivs});
llvm::makeArrayRef(output_ivs.begin(), output_ivs.end()));
return std::make_pair(output_loop, window_loop); return std::make_pair(output_loop, window_loop);
} }
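As a plain-C++ restatement of what the two loops compute (not the converter itself), here is a 1-D reduce-window with low padding only; the function name, scalar type, and shapes are purely illustrative:

#include <cstdint>
#include <functional>
#include <vector>

std::vector<float> ReduceWindow1D(
    const std::vector<float>& in, float init, int64_t window_size,
    int64_t stride, int64_t pad_low, int64_t out_size,
    const std::function<float(float, float)>& reduce) {
  std::vector<float> out(out_size);
  for (int64_t o = 0; o < out_size; ++o) {         // outer loop.parallel (output)
    float acc = init;                              // loop.reduce init value
    for (int64_t w = 0; w < window_size; ++w) {    // inner loop over the window
      const int64_t i = o * stride + w - pad_low;  // MapWindowIvsToInput
      const float elem =
          (i >= 0 && i < static_cast<int64_t>(in.size())) ? in[i] : init;
      acc = reduce(acc, elem);                     // reduction operator block
    }
    out[o] = acc;                                  // store to the output buffer
  }
  return out;
}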
@ -347,12 +421,6 @@ class ReduceWindowOpConverter
rewriter->setInsertionPointToStart(window_loop.getBody()); rewriter->setInsertionPointToStart(window_loop.getBody());
auto loc = xla_reduce_window_op.getLoc(); auto loc = xla_reduce_window_op.getLoc();
if (!xla_reduce_window_op.window_strides().hasValue()) {
xla_reduce_window_op.emitOpError("No window strides specified.");
}
if (!xla_reduce_window_op.padding().hasValue()) {
xla_reduce_window_op.emitOpError("No padding specified.");
}
if (xla_reduce_window_op.base_dilations().hasValue() || if (xla_reduce_window_op.base_dilations().hasValue() ||
xla_reduce_window_op.window_dilations().hasValue()) { xla_reduce_window_op.window_dilations().hasValue()) {
xla_reduce_window_op.emitRemark( xla_reduce_window_op.emitRemark(
@ -362,51 +430,18 @@ class ReduceWindowOpConverter
Value xla_operand = xla_reduce_window_op.operand(); Value xla_operand = xla_reduce_window_op.operand();
auto xla_operand_type = xla_operand.getType().cast<MemRefType>(); auto xla_operand_type = xla_operand.getType().cast<MemRefType>();
-    auto xla_operand_shape = xla_operand_type.getShape();
-    auto output_ivs = llvm::to_vector<2>(output_loop.getInductionVars());
-    auto window_ivs = llvm::to_vector<2>(window_loop.getInductionVars());
-    auto window_strides = xla_reduce_window_op.window_strides().getValue();
-    auto padding = xla_reduce_window_op.padding().getValue();
-    SmallVector<Value, 2> operand_indices;
-    // `in_bounds` is false when the element in the reduce window is in the
-    // padding area, true otherwise.
-    Value in_bounds = rewriter->create<mlir::ConstantOp>(
-        loc, rewriter->getI1Type(),
-        rewriter->getIntegerAttr(rewriter->getI1Type(), 1));
-    for (unsigned i = 0, e = output_loop.getNumLoops(); i < e; ++i) {
-      auto stride = window_strides.getValue<llvm::APInt>(i);
-      auto pad_low = padding.getValue<llvm::APInt>({i, 0});
-      Value stride_val =
-          rewriter->create<ConstantIndexOp>(loc, stride.getSExtValue());
-      Value pad_low_val =
-          rewriter->create<ConstantIndexOp>(loc, pad_low.getSExtValue());
-      Value center = rewriter->create<MulIOp>(loc, output_ivs[i], stride_val);
-      Value offset = rewriter->create<SubIOp>(loc, window_ivs[i], pad_low_val);
-      Value index = rewriter->create<AddIOp>(loc, center, offset);
-      operand_indices.push_back(index);
-      Value upper_bound = GetStaticOrDynamicDim(loc, xla_operand, i,
-                                                xla_operand_shape[i], rewriter);
-      // We must check whether 0 <= index_i < shape_i, as otherwise we are in
-      // the pad and then we have to use the neutral element for reduction.
-      // Equivalently, it can be computed as the unsigned comparison index_i <
-      // shape_i, since a negative value wraps to a large positive value.
-      in_bounds = rewriter->create<mlir::AndOp>(
-          loc, in_bounds,
-          rewriter->create<CmpIOp>(loc, CmpIPredicate::ult, index,
-                                   upper_bound));
-    }
-    auto elem_or_init =
-        rewriter->create<loop::IfOp>(loc, xla_operand_type.getElementType(),
-                                     in_bounds, /*withElseRegion=*/true);
+    MappedIvs mapped_ivs = MapWindowIvsToInput(
+        xla_reduce_window_op, output_loop.getInductionVars(),
+        window_loop.getInductionVars(), rewriter);
+    auto elem_or_init = rewriter->create<loop::IfOp>(
+        loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds,
+        /*withElseRegion=*/true);
OpBuilder then_builder = elem_or_init.getThenBodyBuilder(); OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
Value elem = then_builder.create<mlir::LoadOp>( Value elem = then_builder.create<mlir::LoadOp>(
loc, xla_reduce_window_op.operand(), operand_indices); loc, xla_reduce_window_op.operand(), mapped_ivs.ivs);
then_builder.create<loop::YieldOp>(loc, elem); then_builder.create<loop::YieldOp>(loc, elem);
OpBuilder else_builder = elem_or_init.getElseBodyBuilder(); OpBuilder else_builder = elem_or_init.getElseBodyBuilder();
@ -423,8 +458,12 @@ struct LhloLegalizeToParallelLoops
auto func = getFunction(); auto func = getFunction();
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
patterns.insert<ReduceOpConverter, ReduceWindowOpConverter>( // clang-format off
func.getContext()); patterns.insert<
ReduceOpConverter,
ReduceWindowOpConverter
>(func.getContext());
// clang-format on
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect, target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,


@ -95,6 +95,24 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass();
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeLhloToParallelLoopsPass(); std::unique_ptr<OpPassBase<FuncOp>> createLegalizeLhloToParallelLoopsPass();
} // namespace xla_lhlo } // namespace xla_lhlo
namespace xla {
/// Moves alloc nodes (and their associated dealloc nodes - if any) into the
/// right positions. If there is no associated dealloc node for a given alloc
/// node, this pass will automatically insert a proper dealloc node in the right
/// place. The intended use case of this pass is to store SSA values into
/// buffers using load/store operations; the pass determines the proper
/// positions for the required allocs and deallocs.
/// 1) Note that the function signatures and all types for which buffers should
/// be allocated need to be converted in advance.
/// 2) All required alloc nodes have to be inserted in advance.
/// 3) Note that the current implementation does not support loops.
/// Refer to the class mlir::xla::BufferAssignmentLegalizer for more
/// information.
std::unique_ptr<OpPassBase<FuncOp>> createBufferAssignmentPass();
} // namespace xla
} // namespace mlir } // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_ #endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_
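A hypothetical way to run the new pass from C++, assuming a module-level PassManager and the include paths of this snapshot; only createBufferAssignmentPass() itself comes from the header above, the surrounding pipeline is an assumption:

#include "mlir/IR/Function.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"

// Nests the function pass under the module-level manager; per the comment
// above, signatures, types, and allocs must already have been prepared by
// earlier passes.
static void AddBufferAssignment(mlir::PassManager& pm) {
  pm.nest<mlir::FuncOp>().addPass(mlir::xla::createBufferAssignmentPass());
}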
