From d6c51d18b767634c41150d5b81042f06f6b642f5 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 13 Apr 2020 20:02:11 -0700 Subject: [PATCH] Add a new xla-opt binary tool to exercise only the XLA related passes This new utility is just like tf-opt but does not bring all the dependencies from TensorFlow and as such is thinner and faster to link. This will be more convenient for iterative development on the MLIR-XLA work. PiperOrigin-RevId: 306362771 Change-Id: I80ad3abeacc7f1eb2af85b54de90c32b3505bc38 --- tensorflow/compiler/mlir/BUILD | 22 ++--------- tensorflow/compiler/mlir/runlit.cfg.py | 2 +- tensorflow/compiler/mlir/xla/BUILD | 39 ++++++++++++++++++- tensorflow/compiler/mlir/xla/tests/BUILD | 1 + .../mlir/xla/tests/buffer-assignment.mlir | 2 +- .../compiler/mlir/xla/tests/canonicalize.mlir | 2 +- .../compiler/mlir/xla/tests/concatenate.mlir | 2 +- .../compiler/mlir/xla/tests/convert.mlir | 2 +- .../mlir/xla/tests/hlo-legalize-to-lhlo.mlir | 8 ++-- .../xla/tests/hlo-legalize-to-linalg.mlir | 2 +- .../compiler/mlir/xla/tests/inlining.mlir | 2 +- .../mlir/xla/tests/legalize-control-flow.mlir | 2 +- .../xla/tests/legalize-tf-with-tf2xla.mlir | 2 +- .../mlir/xla/tests/legalize-to-std.mlir | 2 +- .../mlir/xla/tests/lhlo-copy-removal.mlir | 2 +- .../mlir/xla/tests/lhlo-fuse-linalg.mlir | 6 +-- .../xla/tests/lhlo-legalize-to-affine.mlir | 2 +- .../mlir/xla/tests/lhlo-legalize-to-gpu.mlir | 2 +- .../xla/tests/lhlo-legalize-to-linalg.mlir | 2 +- .../lhlo-legalize-to-parallel-loops.mlir | 2 +- .../compiler/mlir/xla/tests/lhlo_ops.mlir | 2 +- .../mlir/xla/tests/lower-complex.mlir | 2 +- .../mlir/xla/tests/lower-general-dot.mlir | 2 +- .../xla/tests/materialize-broadcasts.mlir | 2 +- tensorflow/compiler/mlir/xla/tests/ops.mlir | 2 +- .../compiler/mlir/xla/tests/reduce.mlir | 2 +- .../compiler/mlir/xla/tests/reshape.mlir | 2 +- .../compiler/mlir/xla/tests/reverse.mlir | 2 +- .../compiler/mlir/xla/tests/transpose.mlir | 2 +- tensorflow/compiler/mlir/xla/tests/tuple.mlir | 2 +- .../mlir/xla/tests/unfuse_batch_norm.mlir | 2 +- 31 files changed, 75 insertions(+), 53 deletions(-) diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 7398fb2f305..bc4094bbad1 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -40,7 +40,6 @@ cc_library( srcs = ["tf_mlir_opt_main.cc"], deps = [ ":init_mlir", - ":passes", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", "@llvm-project//llvm:support", @@ -56,6 +55,7 @@ cc_library( cc_library( name = "passes", visibility = [ + ":__subpackages__", "//tensorflow/python:__subpackages__", ], deps = [ @@ -77,24 +77,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", - "//tensorflow/compiler/mlir/xla:buffer_assignment", - "//tensorflow/compiler/mlir/xla:hlo", - "//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo", - "//tensorflow/compiler/mlir/xla:lhlo", - "//tensorflow/compiler/mlir/xla:lhlo_copy_removal", - "//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg", - "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine", - "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu", - "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_parallel_loops", - "//tensorflow/compiler/mlir/xla:xla_dialect_registration", - "//tensorflow/compiler/mlir/xla:xla_legalize_control_flow", - "//tensorflow/compiler/mlir/xla:xla_legalize_tf", - "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla", - "//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg", - "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard", - "//tensorflow/compiler/mlir/xla:xla_lower", - "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts", - "//tensorflow/compiler/mlir/xla:xla_test_passes", ], ) @@ -142,12 +124,14 @@ cc_library( tf_cc_binary( name = "tf-opt", deps = [ + ":passes", ":tf_mlir_opt_main", "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration", + "//tensorflow/compiler/mlir/xla:all_xla_passes_for_testing", ], ) diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index ab8c1107fc8..ddb968434c4 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -71,7 +71,7 @@ tool_dirs = config.mlir_tf_tools_dirs + [ tool_names = [ 'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate', 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', - 'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer' + 'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-opt' ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index de2dfec1cf4..122692059bf 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -1,5 +1,5 @@ load("//third_party/mlir:tblgen.bzl", "gentbl") -load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary") package( default_visibility = [":friends"], @@ -750,7 +750,7 @@ genrule( cmd = ("$(location :operator_writer_gen) " + "-I external/llvm-project/mlir/include " + "-I external/org_tensorflow " + - "$(location //tensorflow/compiler/mlir/xla:ir/hlo_ops.td) " + + "$(location :ir/hlo_ops.td) " + " -o $@"), tools = [":operator_writer_gen"], ) @@ -763,3 +763,38 @@ cc_library( "@llvm-project//mlir:IR", ], ) + +cc_library( + name = "all_xla_passes_for_testing", + visibility = [ + "//tensorflow/compiler/mlir:__subpackages__", + ], + deps = [ + ":buffer_assignment", + ":hlo", + ":hlo_legalize_to_lhlo", + ":lhlo", + ":lhlo_copy_removal", + ":lhlo_fuse_linalg", + ":lhlo_legalize_to_affine", + ":lhlo_legalize_to_gpu", + ":lhlo_legalize_to_parallel_loops", + ":xla_dialect_registration", + ":xla_legalize_control_flow", + ":xla_legalize_tf", + ":xla_legalize_tf_with_tf2xla", + ":xla_legalize_to_linalg", + ":xla_legalize_to_standard", + ":xla_lower", + ":xla_materialize_broadcasts", + ":xla_test_passes", + ], +) + +tf_cc_binary( + name = "xla-opt", + deps = [ + ":all_xla_passes_for_testing", + "//tensorflow/compiler/mlir:tf_mlir_opt_main", + ], +) diff --git a/tensorflow/compiler/mlir/xla/tests/BUILD b/tensorflow/compiler/mlir/xla/tests/BUILD index 4faa8d2efe8..989b846f561 100644 --- a/tensorflow/compiler/mlir/xla/tests/BUILD +++ b/tensorflow/compiler/mlir/xla/tests/BUILD @@ -14,6 +14,7 @@ filegroup( testonly = True, data = [ "//tensorflow/compiler/mlir:tf-opt", + "//tensorflow/compiler/mlir/xla:xla-opt", "@llvm-project//llvm:FileCheck", ], ) diff --git a/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir index 7bcf477d45e..866e7218de0 100644 --- a/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir +++ b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -test-buffer-assignment -split-input-file %s | FileCheck %s -dump-input-on-failure +// RUN: xla-opt -test-buffer-assignment -split-input-file %s | FileCheck %s -dump-input-on-failure // CHECK-LABEL: Testing : condBranch func @condBranch(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir index a045e1f9d07..1b60745b20c 100644 --- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input-on-failure +// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input-on-failure func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { // CHECK: "xla_hlo.dynamic-slice" diff --git a/tensorflow/compiler/mlir/xla/tests/concatenate.mlir b/tensorflow/compiler/mlir/xla/tests/concatenate.mlir index d22079d4942..5b1225e1e87 100644 --- a/tensorflow/compiler/mlir/xla/tests/concatenate.mlir +++ b/tensorflow/compiler/mlir/xla/tests/concatenate.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s +// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s // CHECK-LABEL: func @single_operand // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] diff --git a/tensorflow/compiler/mlir/xla/tests/convert.mlir b/tensorflow/compiler/mlir/xla/tests/convert.mlir index c9c2f384662..63ce724adb7 100644 --- a/tensorflow/compiler/mlir/xla/tests/convert.mlir +++ b/tensorflow/compiler/mlir/xla/tests/convert.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s +// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s // ----- diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir index 4d3fa3b8276..c457f3d5506 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure +// RUN: xla-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure // CHECK-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -146,14 +146,16 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { // ----- +func @external_func() -> tensor<3xi64> + // CHECK-LABEL: func @dyn_broadcast func @dyn_broadcast(%operand: memref) { %tensor_operand = tensor_load %operand : memref - %shape = "tf.compute_shape"() : () -> tensor<3xi64> + %shape = call @external_func() : () -> tensor<3xi64> %tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor<3xi64>) -> tensor - // CHECK: %[[SHAPE:.*]] = "tf.compute_shape"() + // CHECK: %[[SHAPE:.*]] = call @external_func() // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64> // CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir index 67c59ba10c5..ecee1d681d6 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s +// RUN: xla-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s // CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @float_add diff --git a/tensorflow/compiler/mlir/xla/tests/inlining.mlir b/tensorflow/compiler/mlir/xla/tests/inlining.mlir index 3e447f7ff11..2f20386b83f 100644 --- a/tensorflow/compiler/mlir/xla/tests/inlining.mlir +++ b/tensorflow/compiler/mlir/xla/tests/inlining.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -inline | FileCheck %s --dump-input=fail +// RUN: xla-opt %s -inline | FileCheck %s --dump-input=fail // Test case: Basic test of inlining into xla_hlo.while. diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir index e611b4419c9..83c3f765dc3 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -xla-legalize-control-flow %s -o - | FileCheck %s +// RUN: xla-opt -xla-legalize-control-flow %s -o - | FileCheck %s // CHECK-LABEL: func @while(%arg0: tensor) -> tensor { func @while(%arg0: tensor) -> tensor { diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index aebb58dc299..2fed18cb917 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -1,6 +1,6 @@ // RUN: tf-opt -xla-legalize-tf-with-tf2xla=device-type=XLA_CPU_JIT %s | FileCheck %s --dump-input-on-failure -// INVALID_DEVICE: tf-opt -xla-legalize-tf-with-tf2xla=device-type=INVALID_DEVICE %s | FileCheck %s --dump-input-on-failure +// INVALID_DEVICE: xla-opt -xla-legalize-tf-with-tf2xla=device-type=INVALID_DEVICE %s | FileCheck %s --dump-input-on-failure module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir index 5bb965fa320..d25a84d0e25 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -xla-legalize-to-std %s -o - | FileCheck %s +// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s // CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir index fab1389262d..3fc1dbe2b97 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-copy-removal.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -lhlo-copy-removal %s -o - | FileCheck %s --dump-input-on-failure +// RUN: xla-opt -lhlo-copy-removal %s -o - | FileCheck %s --dump-input-on-failure // CHECK-LABEL: func @remove_simple func @remove_simple(%arg0: memref<2x2xf32>) { diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir index 0a48cbd372f..013748fea28 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir @@ -1,6 +1,6 @@ -// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always -// RUN: tf-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -o - | FileCheck %s -check-prefix=TILED --dump-input-on-failure -// RUN: tf-opt -lhlo-fuse-linalg=use-parallel-loops %s -o - | FileCheck %s -check-prefix=PLOOP --dump-input-on-failure +// RUN: xla-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always +// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -o - | FileCheck %s -check-prefix=TILED --dump-input-on-failure +// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -o - | FileCheck %s -check-prefix=PLOOP --dump-input-on-failure #map0 = affine_map<(d0, d1) -> (d0, d1)> diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir index 0aa7834b4fb..08ba9f02f3e 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s +// RUN: xla-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s // Smoke test. // CHECK-LABEL: func @min_op diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir index 15414767c83..4d878cee6f4 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -lhlo-legalize-to-gpu -split-input-file | FileCheck %s --dump-input=fail +// RUN: xla-opt %s -lhlo-legalize-to-gpu -split-input-file | FileCheck %s --dump-input=fail func @reduce(%arg: memref<100x10xf32>, %init: memref, diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir index d1c3db32b93..a070dac9836 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s --dump-input-on-failure +// RUN: xla-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s --dump-input-on-failure // CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @element_wise diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir index ff4f1d940bf..e1f0d5c8682 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s --dump-input-on-failure +// RUN: xla-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s --dump-input-on-failure func @reduce(%arg: memref<100x10x5xf32>, %init: memref, diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir index 7e831eadc2f..23e9d9b68e0 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -verify-diagnostics -split-input-file | tf-opt | FileCheck %s +// RUN: xla-opt %s -verify-diagnostics -split-input-file | xla-opt | FileCheck %s func @enforce_same_shape(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () { // expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}} diff --git a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir index 74d175109d3..35a5ae549d5 100644 --- a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -test-xla-lower-complex | FileCheck %s +// RUN: xla-opt %s -test-xla-lower-complex | FileCheck %s // CHECK-LABEL: @add func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { diff --git a/tensorflow/compiler/mlir/xla/tests/lower-general-dot.mlir b/tensorflow/compiler/mlir/xla/tests/lower-general-dot.mlir index cde55b05c04..7250fd4cc94 100644 --- a/tensorflow/compiler/mlir/xla/tests/lower-general-dot.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lower-general-dot.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s +// RUN: xla-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s // CHECK-LABEL: @testDebatch1 func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> { diff --git a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir index 07ff6d17091..fde5c12c1c6 100644 --- a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir +++ b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck --dump-input=fail %s +// RUN: xla-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck --dump-input=fail %s // CHECK-LABEL: @addBroadcastRhs func @addBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index a1cddab54c9..aa38ccd3c30 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -verify-diagnostics -split-input-file | tf-opt | FileCheck %s +// RUN: xla-opt %s -verify-diagnostics -split-input-file | xla-opt | FileCheck %s // Tests for types, ops with custom constraints, verifiers, printer or parser // methods. diff --git a/tensorflow/compiler/mlir/xla/tests/reduce.mlir b/tensorflow/compiler/mlir/xla/tests/reduce.mlir index 53dfca5ec08..d49b34d6f74 100644 --- a/tensorflow/compiler/mlir/xla/tests/reduce.mlir +++ b/tensorflow/compiler/mlir/xla/tests/reduce.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s +// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s // CHECK-LABEL: func @noop // CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>) diff --git a/tensorflow/compiler/mlir/xla/tests/reshape.mlir b/tensorflow/compiler/mlir/xla/tests/reshape.mlir index 7dbec638528..fe16e8c1c99 100644 --- a/tensorflow/compiler/mlir/xla/tests/reshape.mlir +++ b/tensorflow/compiler/mlir/xla/tests/reshape.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s +// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s // CHECK-LABEL: func @const_fold_collapse_to_scalar func @const_fold_collapse_to_scalar() -> tensor { diff --git a/tensorflow/compiler/mlir/xla/tests/reverse.mlir b/tensorflow/compiler/mlir/xla/tests/reverse.mlir index a4e70e339d0..e0e80400b81 100644 --- a/tensorflow/compiler/mlir/xla/tests/reverse.mlir +++ b/tensorflow/compiler/mlir/xla/tests/reverse.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s +// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s // CHECK-LABEL: func @noop // CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>) diff --git a/tensorflow/compiler/mlir/xla/tests/transpose.mlir b/tensorflow/compiler/mlir/xla/tests/transpose.mlir index 0ee56724da5..7942fadcd60 100644 --- a/tensorflow/compiler/mlir/xla/tests/transpose.mlir +++ b/tensorflow/compiler/mlir/xla/tests/transpose.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input=fail +// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input=fail // CHECK-LABEL: func @remove_noop // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] diff --git a/tensorflow/compiler/mlir/xla/tests/tuple.mlir b/tensorflow/compiler/mlir/xla/tests/tuple.mlir index a06aa912806..f22bc210c57 100644 --- a/tensorflow/compiler/mlir/xla/tests/tuple.mlir +++ b/tensorflow/compiler/mlir/xla/tests/tuple.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s +// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s // CHECK-LABEL: func @fold_access // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] diff --git a/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir b/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir index 28a4d0589b0..9778772e250 100644 --- a/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir +++ b/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope --dump-input=fail %s +// RUN: xla-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope --dump-input=fail %s // CHECK-LABEL: @batchNormInference_2D_inner_features // CHECK-SAME: %[[X:[^:[:space:]]+]]