Move tfl-device-index-selector to TF directory.
There's nothing lite-specific about this pass. PiperOrigin-RevId: 317188038 Change-Id: Iac9799e296e043aabf7aeabec2e8f72d07c77178
This commit is contained in:
parent
9d33f296d1
commit
35b978db57
|
@ -314,7 +314,6 @@ tf_cc_test(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tensorflow_lite_legalize_tf",
|
name = "tensorflow_lite_legalize_tf",
|
||||||
srcs = [
|
srcs = [
|
||||||
"transforms/device_index_selector.cc",
|
|
||||||
"transforms/dilated_conv.cc",
|
"transforms/dilated_conv.cc",
|
||||||
"transforms/generated_legalize_tf.inc",
|
"transforms/generated_legalize_tf.inc",
|
||||||
"transforms/generated_lower_static_tensor_list.inc",
|
"transforms/generated_lower_static_tensor_list.inc",
|
||||||
|
|
|
@ -63,7 +63,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||||
standard_pipeline_options.enable_inliner = false;
|
standard_pipeline_options.enable_inliner = false;
|
||||||
standard_pipeline_options.form_clusters = pass_config.form_clusters;
|
standard_pipeline_options.form_clusters = pass_config.form_clusters;
|
||||||
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
|
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
|
||||||
pass_manager->addPass(mlir::TFL::CreateDeviceIndexSelectorPass());
|
pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass());
|
||||||
|
|
||||||
if (pass_config.shape_inference) {
|
if (pass_config.shape_inference) {
|
||||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||||
|
|
|
@ -91,9 +91,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
|
||||||
// Verifies runtime constraints.
|
// Verifies runtime constraints.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
|
||||||
|
|
||||||
// Creates function pass to select device index/fold tf.DeviceIndex.
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
|
|
||||||
|
|
||||||
} // namespace TFL
|
} // namespace TFL
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -475,6 +475,7 @@ cc_library(
|
||||||
"transforms/cluster_outlining.cc",
|
"transforms/cluster_outlining.cc",
|
||||||
"transforms/collection_ops_util.cc",
|
"transforms/collection_ops_util.cc",
|
||||||
"transforms/decompose_resource_ops_pass.cc",
|
"transforms/decompose_resource_ops_pass.cc",
|
||||||
|
"transforms/device_index_selector.cc",
|
||||||
"transforms/einsum.cc",
|
"transforms/einsum.cc",
|
||||||
"transforms/executor_island_coarsening.cc",
|
"transforms/executor_island_coarsening.cc",
|
||||||
"transforms/executor_tpuv1_inline_tpu_island.cc",
|
"transforms/executor_tpuv1_inline_tpu_island.cc",
|
||||||
|
|
|
@ -21,11 +21,11 @@ limitations under the License.
|
||||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace TFL {
|
namespace TF {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Folds the DeviceIndex op to a constant value. The DeviceIndex return the
|
// Folds the DeviceIndex op to a constant value. The DeviceIndex return the
|
||||||
|
@ -55,8 +55,8 @@ void DeviceIndexSelector::runOnOperation() {
|
||||||
// Convert all the DeviceIndex ops to constant values.
|
// Convert all the DeviceIndex ops to constant values.
|
||||||
func.getBody().walk([](TF::DeviceIndexOp op) {
|
func.getBody().walk([](TF::DeviceIndexOp op) {
|
||||||
// This just selects the default in all cases where DeviceIndex feeds into
|
// This just selects the default in all cases where DeviceIndex feeds into
|
||||||
// tf.Case. This could be enhanced based on explicit TFLite specification or
|
// tf.Case. This could be enhanced to have some sort of policy in the
|
||||||
// TAC in future.
|
// future.
|
||||||
OpBuilder b(op);
|
OpBuilder b(op);
|
||||||
RankedTensorType type = RankedTensorType::get({}, b.getIntegerType(32));
|
RankedTensorType type = RankedTensorType::get({}, b.getIntegerType(32));
|
||||||
int index = op.device_names().size();
|
int index = op.device_names().size();
|
||||||
|
@ -79,7 +79,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass() {
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<DeviceIndexSelector> pass(
|
static PassRegistration<DeviceIndexSelector> pass(
|
||||||
"tfl-device-index-selector", "Fold tf.DeviceIndex to constant");
|
"tf-device-index-selector", "Fold tf.DeviceIndex to constant");
|
||||||
|
|
||||||
} // namespace TFL
|
} // namespace TF
|
||||||
} // namespace mlir
|
} // namespace mlir
|
|
@ -147,6 +147,9 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
|
||||||
// generally used beyond exporting to runtimes that supports these ops. In the
|
// generally used beyond exporting to runtimes that supports these ops. In the
|
||||||
// future these fusions may be codegen'd automatically.
|
// future these fusions may be codegen'd automatically.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass();
|
||||||
|
|
||||||
|
// Creates function pass to select device index/fold tf.DeviceIndex.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
|
||||||
} // namespace TF
|
} // namespace TF
|
||||||
|
|
||||||
namespace tf_executor {
|
namespace tf_executor {
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
// Test DeviceIndex selector.
|
// Test DeviceIndex selector.
|
||||||
|
|
||||||
// RUN: tf-opt --tfl-device-index-selector %s | FileCheck %s
|
// RUN: tf-opt --tf-device-index-selector %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @select
|
// CHECK-LABEL: func @select
|
||||||
func @select(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<i32>, tensor<f32>) {
|
func @select(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<i32>, tensor<f32>) {
|
Loading…
Reference in New Issue