Move tfl-device-index-selector to TF directory.

There's nothing lite-specific about this pass.

PiperOrigin-RevId: 317188038
Change-Id: Iac9799e296e043aabf7aeabec2e8f72d07c77178
This commit is contained in:
Sean Silva 2020-06-18 15:00:53 -07:00 committed by TensorFlower Gardener
parent 9d33f296d1
commit 35b978db57
7 changed files with 12 additions and 12 deletions

View File

@ -314,7 +314,6 @@ tf_cc_test(
cc_library( cc_library(
name = "tensorflow_lite_legalize_tf", name = "tensorflow_lite_legalize_tf",
srcs = [ srcs = [
"transforms/device_index_selector.cc",
"transforms/dilated_conv.cc", "transforms/dilated_conv.cc",
"transforms/generated_legalize_tf.inc", "transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc", "transforms/generated_lower_static_tensor_list.inc",

View File

@ -63,7 +63,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
standard_pipeline_options.enable_inliner = false; standard_pipeline_options.enable_inliner = false;
standard_pipeline_options.form_clusters = pass_config.form_clusters; standard_pipeline_options.form_clusters = pass_config.form_clusters;
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options); mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
pass_manager->addPass(mlir::TFL::CreateDeviceIndexSelectorPass()); pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass());
if (pass_config.shape_inference) { if (pass_config.shape_inference) {
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());

View File

@ -91,9 +91,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
// Verifies runtime constraints. // Verifies runtime constraints.
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass(); std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
// Creates function pass to select device index/fold tf.DeviceIndex.
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
} // namespace TFL } // namespace TFL
} // namespace mlir } // namespace mlir

View File

@ -475,6 +475,7 @@ cc_library(
"transforms/cluster_outlining.cc", "transforms/cluster_outlining.cc",
"transforms/collection_ops_util.cc", "transforms/collection_ops_util.cc",
"transforms/decompose_resource_ops_pass.cc", "transforms/decompose_resource_ops_pass.cc",
"transforms/device_index_selector.cc",
"transforms/einsum.cc", "transforms/einsum.cc",
"transforms/executor_island_coarsening.cc", "transforms/executor_island_coarsening.cc",
"transforms/executor_tpuv1_inline_tpu_island.cc", "transforms/executor_tpuv1_inline_tpu_island.cc",

View File

@ -21,11 +21,11 @@ limitations under the License.
#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir { namespace mlir {
namespace TFL { namespace TF {
namespace { namespace {
// Folds the DeviceIndex op to a constant value. The DeviceIndex return the // Folds the DeviceIndex op to a constant value. The DeviceIndex return the
@ -55,8 +55,8 @@ void DeviceIndexSelector::runOnOperation() {
// Convert all the DeviceIndex ops to constant values. // Convert all the DeviceIndex ops to constant values.
func.getBody().walk([](TF::DeviceIndexOp op) { func.getBody().walk([](TF::DeviceIndexOp op) {
// This just selects the default in all cases where DeviceIndex feeds into // This just selects the default in all cases where DeviceIndex feeds into
// tf.Case. This could be enhanced based on explicit TFLite specification or // tf.Case. This could be enhanced to have some sort of policy in the
// TAC in future. // future.
OpBuilder b(op); OpBuilder b(op);
RankedTensorType type = RankedTensorType::get({}, b.getIntegerType(32)); RankedTensorType type = RankedTensorType::get({}, b.getIntegerType(32));
int index = op.device_names().size(); int index = op.device_names().size();
@ -79,7 +79,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass() {
} }
static PassRegistration<DeviceIndexSelector> pass( static PassRegistration<DeviceIndexSelector> pass(
"tfl-device-index-selector", "Fold tf.DeviceIndex to constant"); "tf-device-index-selector", "Fold tf.DeviceIndex to constant");
} // namespace TFL } // namespace TF
} // namespace mlir } // namespace mlir

View File

@ -147,6 +147,9 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
// generally used beyond exporting to runtimes that supports these ops. In the // generally used beyond exporting to runtimes that supports these ops. In the
// future these fusions may be codegen'd automatically. // future these fusions may be codegen'd automatically.
std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass(); std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass();
// Creates function pass to select device index/fold tf.DeviceIndex.
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
} // namespace TF } // namespace TF
namespace tf_executor { namespace tf_executor {

View File

@ -1,6 +1,6 @@
// Test DeviceIndex selector. // Test DeviceIndex selector.
// RUN: tf-opt --tfl-device-index-selector %s | FileCheck %s // RUN: tf-opt --tf-device-index-selector %s | FileCheck %s
// CHECK-LABEL: func @select // CHECK-LABEL: func @select
func @select(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<i32>, tensor<f32>) { func @select(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<i32>, tensor<f32>) {