Move tfl-device-index-selector to TF directory.
There's nothing lite-specific about this pass. PiperOrigin-RevId: 317188038 Change-Id: Iac9799e296e043aabf7aeabec2e8f72d07c77178
This commit is contained in:
parent
9d33f296d1
commit
35b978db57
|
@ -314,7 +314,6 @@ tf_cc_test(
|
|||
cc_library(
|
||||
name = "tensorflow_lite_legalize_tf",
|
||||
srcs = [
|
||||
"transforms/device_index_selector.cc",
|
||||
"transforms/dilated_conv.cc",
|
||||
"transforms/generated_legalize_tf.inc",
|
||||
"transforms/generated_lower_static_tensor_list.inc",
|
||||
|
|
|
@ -63,7 +63,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
|||
standard_pipeline_options.enable_inliner = false;
|
||||
standard_pipeline_options.form_clusters = pass_config.form_clusters;
|
||||
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
|
||||
pass_manager->addPass(mlir::TFL::CreateDeviceIndexSelectorPass());
|
||||
pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass());
|
||||
|
||||
if (pass_config.shape_inference) {
|
||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
|
|
|
@ -91,9 +91,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
|
|||
// Verifies runtime constraints.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
|
||||
|
||||
// Creates function pass to select device index/fold tf.DeviceIndex.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
|
||||
|
||||
} // namespace TFL
|
||||
|
||||
} // namespace mlir
|
||||
|
|
|
@ -475,6 +475,7 @@ cc_library(
|
|||
"transforms/cluster_outlining.cc",
|
||||
"transforms/collection_ops_util.cc",
|
||||
"transforms/decompose_resource_ops_pass.cc",
|
||||
"transforms/device_index_selector.cc",
|
||||
"transforms/einsum.cc",
|
||||
"transforms/executor_island_coarsening.cc",
|
||||
"transforms/executor_tpuv1_inline_tpu_island.cc",
|
||||
|
|
|
@ -21,11 +21,11 @@ limitations under the License.
|
|||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
namespace TF {
|
||||
namespace {
|
||||
|
||||
// Folds the DeviceIndex op to a constant value. The DeviceIndex return the
|
||||
|
@ -55,8 +55,8 @@ void DeviceIndexSelector::runOnOperation() {
|
|||
// Convert all the DeviceIndex ops to constant values.
|
||||
func.getBody().walk([](TF::DeviceIndexOp op) {
|
||||
// This just selects the default in all cases where DeviceIndex feeds into
|
||||
// tf.Case. This could be enhanced based on explicit TFLite specification or
|
||||
// TAC in future.
|
||||
// tf.Case. This could be enhanced to have some sort of policy in the
|
||||
// future.
|
||||
OpBuilder b(op);
|
||||
RankedTensorType type = RankedTensorType::get({}, b.getIntegerType(32));
|
||||
int index = op.device_names().size();
|
||||
|
@ -79,7 +79,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass() {
|
|||
}
|
||||
|
||||
static PassRegistration<DeviceIndexSelector> pass(
|
||||
"tfl-device-index-selector", "Fold tf.DeviceIndex to constant");
|
||||
"tf-device-index-selector", "Fold tf.DeviceIndex to constant");
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
|
@ -147,6 +147,9 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
|
|||
// generally used beyond exporting to runtimes that supports these ops. In the
|
||||
// future these fusions may be codegen'd automatically.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass();
|
||||
|
||||
// Creates function pass to select device index/fold tf.DeviceIndex.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
|
||||
} // namespace TF
|
||||
|
||||
namespace tf_executor {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Test DeviceIndex selector.
|
||||
|
||||
// RUN: tf-opt --tfl-device-index-selector %s | FileCheck %s
|
||||
// RUN: tf-opt --tf-device-index-selector %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @select
|
||||
func @select(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<i32>, tensor<f32>) {
|
Loading…
Reference in New Issue