diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 8e9d615053c..8d4efeb3d60 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -314,7 +314,6 @@ tf_cc_test( cc_library( name = "tensorflow_lite_legalize_tf", srcs = [ - "transforms/device_index_selector.cc", "transforms/dilated_conv.cc", "transforms/generated_legalize_tf.inc", "transforms/generated_lower_static_tensor_list.inc", diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 008098f62ba..fed2896035b 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -63,7 +63,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, standard_pipeline_options.enable_inliner = false; standard_pipeline_options.form_clusters = pass_config.form_clusters; mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options); - pass_manager->addPass(mlir::TFL::CreateDeviceIndexSelectorPass()); + pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass()); if (pass_config.shape_inference) { pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 01e5eb1cb68..105c9394fb4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -91,9 +91,6 @@ std::unique_ptr> CreateWhileOutlinePass(); // Verifies runtime constraints. std::unique_ptr> CreateRuntimeVerifyPass(); -// Creates function pass to select device index/fold tf.DeviceIndex. -std::unique_ptr> CreateDeviceIndexSelectorPass(); - } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 54e57512c32..7c0d427e87b 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -475,6 +475,7 @@ cc_library( "transforms/cluster_outlining.cc", "transforms/collection_ops_util.cc", "transforms/decompose_resource_ops_pass.cc", + "transforms/device_index_selector.cc", "transforms/einsum.cc", "transforms/executor_island_coarsening.cc", "transforms/executor_tpuv1_inline_tpu_island.cc", diff --git a/tensorflow/compiler/mlir/lite/transforms/device_index_selector.cc b/tensorflow/compiler/mlir/tensorflow/transforms/device_index_selector.cc similarity index 92% rename from tensorflow/compiler/mlir/lite/transforms/device_index_selector.cc rename to tensorflow/compiler/mlir/tensorflow/transforms/device_index_selector.cc index d4aed750dc8..550647a915a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/device_index_selector.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/device_index_selector.cc @@ -21,11 +21,11 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" namespace mlir { -namespace TFL { +namespace TF { namespace { // Folds the DeviceIndex op to a constant value. The DeviceIndex return the @@ -55,8 +55,8 @@ void DeviceIndexSelector::runOnOperation() { // Convert all the DeviceIndex ops to constant values. func.getBody().walk([](TF::DeviceIndexOp op) { // This just selects the default in all cases where DeviceIndex feeds into - // tf.Case. This could be enhanced based on explicit TFLite specification or - // TAC in future. + // tf.Case. This could be enhanced to have some sort of policy in the + // future. OpBuilder b(op); RankedTensorType type = RankedTensorType::get({}, b.getIntegerType(32)); int index = op.device_names().size(); @@ -79,7 +79,7 @@ std::unique_ptr> CreateDeviceIndexSelectorPass() { } static PassRegistration pass( - "tfl-device-index-selector", "Fold tf.DeviceIndex to constant"); + "tf-device-index-selector", "Fold tf.DeviceIndex to constant"); -} // namespace TFL +} // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index a34be28c809..168b317641d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -147,6 +147,9 @@ std::unique_ptr> CreateLegalizeHloToTfPass(); // generally used beyond exporting to runtimes that supports these ops. In the // future these fusions may be codegen'd automatically. std::unique_ptr> CreateFusedKernelMatcherPass(); + +// Creates function pass to select device index/fold tf.DeviceIndex. +std::unique_ptr> CreateDeviceIndexSelectorPass(); } // namespace TF namespace tf_executor { diff --git a/tensorflow/compiler/mlir/lite/tests/tf_device_index_selector.mlir b/tensorflow/compiler/tensorflow/tests/tf_device_index_selector.mlir similarity index 94% rename from tensorflow/compiler/mlir/lite/tests/tf_device_index_selector.mlir rename to tensorflow/compiler/tensorflow/tests/tf_device_index_selector.mlir index 1ac7f30d644..7fc2b210f91 100644 --- a/tensorflow/compiler/mlir/lite/tests/tf_device_index_selector.mlir +++ b/tensorflow/compiler/tensorflow/tests/tf_device_index_selector.mlir @@ -1,6 +1,6 @@ // Test DeviceIndex selector. -// RUN: tf-opt --tfl-device-index-selector %s | FileCheck %s +// RUN: tf-opt --tf-device-index-selector %s | FileCheck %s // CHECK-LABEL: func @select func @select(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) {