Add a new xla-opt binary tool to exercise only the XLA-related passes

This new utility is just like tf-opt, but it does not pull in all of the TensorFlow
dependencies and is therefore thinner and faster to link. This makes it more convenient
for iterative development on the MLIR-XLA work.
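A typical workflow with the new tool might look like the sketch below; the Bazel target matches the rule added in this change, while the pass flag and the input file name are only illustrative examples taken from the updated tests:

    # Build only the XLA opt tool; it links far fewer dependencies than tf-opt.
    bazel build //tensorflow/compiler/mlir/xla:xla-opt

    # Run an XLA pass over a standalone MLIR file (input.mlir is a placeholder).
    bazel-bin/tensorflow/compiler/mlir/xla/xla-opt -hlo-legalize-to-lhlo input.mlir -o -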

PiperOrigin-RevId: 306362771
Change-Id: I80ad3abeacc7f1eb2af85b54de90c32b3505bc38
Author: Mehdi Amini, 2020-04-13 20:02:11 -07:00 (committed by TensorFlower Gardener)
parent 43f67e0881
commit d6c51d18b7
31 changed files with 75 additions and 53 deletions


@@ -40,7 +40,6 @@ cc_library(
     srcs = ["tf_mlir_opt_main.cc"],
     deps = [
         ":init_mlir",
-        ":passes",
         "//tensorflow/core:lib",
         "//tensorflow/core/platform:logging",
         "@llvm-project//llvm:support",
@@ -56,6 +55,7 @@ cc_library(
 cc_library(
     name = "passes",
     visibility = [
+        ":__subpackages__",
         "//tensorflow/python:__subpackages__",
     ],
     deps = [
@@ -77,24 +77,6 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
         "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
         "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
-        "//tensorflow/compiler/mlir/xla:buffer_assignment",
-        "//tensorflow/compiler/mlir/xla:hlo",
-        "//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
-        "//tensorflow/compiler/mlir/xla:lhlo",
-        "//tensorflow/compiler/mlir/xla:lhlo_copy_removal",
-        "//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
-        "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
-        "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
-        "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_parallel_loops",
-        "//tensorflow/compiler/mlir/xla:xla_dialect_registration",
-        "//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
-        "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
-        "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
-        "//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
-        "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
-        "//tensorflow/compiler/mlir/xla:xla_lower",
-        "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
-        "//tensorflow/compiler/mlir/xla:xla_test_passes",
     ],
 )
@@ -142,12 +124,14 @@ cc_library(
 tf_cc_binary(
     name = "tf-opt",
     deps = [
+        ":passes",
         ":tf_mlir_opt_main",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
         "//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration",
+        "//tensorflow/compiler/mlir/xla:all_xla_passes_for_testing",
     ],
 )


@@ -71,7 +71,7 @@ tool_dirs = config.mlir_tf_tools_dirs + [
 tool_names = [
     'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
     'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
-    'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer'
+    'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-opt'
 ]
 tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
 llvm_config.add_tool_substitutions(tools, tool_dirs)


@@ -1,5 +1,5 @@
 load("//third_party/mlir:tblgen.bzl", "gentbl")
-load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary")
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary")
 
 package(
     default_visibility = [":friends"],
@@ -750,7 +750,7 @@ genrule(
     cmd = ("$(location :operator_writer_gen) " +
            "-I external/llvm-project/mlir/include " +
            "-I external/org_tensorflow " +
-           "$(location //tensorflow/compiler/mlir/xla:ir/hlo_ops.td) " +
+           "$(location :ir/hlo_ops.td) " +
            " -o $@"),
     tools = [":operator_writer_gen"],
 )
@@ -763,3 +763,38 @@ cc_library(
         "@llvm-project//mlir:IR",
     ],
 )
+
+cc_library(
+    name = "all_xla_passes_for_testing",
+    visibility = [
+        "//tensorflow/compiler/mlir:__subpackages__",
+    ],
+    deps = [
+        ":buffer_assignment",
+        ":hlo",
+        ":hlo_legalize_to_lhlo",
+        ":lhlo",
+        ":lhlo_copy_removal",
+        ":lhlo_fuse_linalg",
+        ":lhlo_legalize_to_affine",
+        ":lhlo_legalize_to_gpu",
+        ":lhlo_legalize_to_parallel_loops",
+        ":xla_dialect_registration",
+        ":xla_legalize_control_flow",
+        ":xla_legalize_tf",
+        ":xla_legalize_tf_with_tf2xla",
+        ":xla_legalize_to_linalg",
+        ":xla_legalize_to_standard",
+        ":xla_lower",
+        ":xla_materialize_broadcasts",
+        ":xla_test_passes",
+    ],
+)
+
+tf_cc_binary(
+    name = "xla-opt",
+    deps = [
+        ":all_xla_passes_for_testing",
+        "//tensorflow/compiler/mlir:tf_mlir_opt_main",
+    ],
+)


@@ -14,6 +14,7 @@ filegroup(
     testonly = True,
     data = [
         "//tensorflow/compiler/mlir:tf-opt",
+        "//tensorflow/compiler/mlir/xla:xla-opt",
         "@llvm-project//llvm:FileCheck",
     ],
 )


@@ -1,4 +1,4 @@
-// RUN: tf-opt -test-buffer-assignment -split-input-file %s | FileCheck %s -dump-input-on-failure
+// RUN: xla-opt -test-buffer-assignment -split-input-file %s | FileCheck %s -dump-input-on-failure
 
 // CHECK-LABEL: Testing : condBranch
 func @condBranch(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input-on-failure
+// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input-on-failure
 
 func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
   // CHECK: "xla_hlo.dynamic-slice"


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
+// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
 
 // CHECK-LABEL: func @single_operand
 // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
+// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
 
 // -----


@@ -1,4 +1,4 @@
-// RUN: tf-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure
+// RUN: xla-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure
 
 // CHECK-LABEL: func @attrs
 func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -146,14 +146,16 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
 
 // -----
 
+func @external_func() -> tensor<3xi64>
+
 // CHECK-LABEL: func @dyn_broadcast
 func @dyn_broadcast(%operand: memref<?x?xf32>) {
   %tensor_operand = tensor_load %operand : memref<?x?xf32>
-  %shape = "tf.compute_shape"() : () -> tensor<3xi64>
+  %shape = call @external_func() : () -> tensor<3xi64>
   %tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape)
       {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
       : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
-  // CHECK: %[[SHAPE:.*]] = "tf.compute_shape"()
+  // CHECK: %[[SHAPE:.*]] = call @external_func()
   // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
   // CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s
+// RUN: xla-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s
 
 // CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @float_add


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -inline | FileCheck %s --dump-input=fail
+// RUN: xla-opt %s -inline | FileCheck %s --dump-input=fail
 
 // Test case: Basic test of inlining into xla_hlo.while.


@@ -1,4 +1,4 @@
-// RUN: tf-opt -xla-legalize-control-flow %s -o - | FileCheck %s
+// RUN: xla-opt -xla-legalize-control-flow %s -o - | FileCheck %s
 
 // CHECK-LABEL: func @while(%arg0: tensor<i64>) -> tensor<i64> {
 func @while(%arg0: tensor<i64>) -> tensor<i64> {


@@ -1,6 +1,6 @@
 // RUN: tf-opt -xla-legalize-tf-with-tf2xla=device-type=XLA_CPU_JIT %s | FileCheck %s --dump-input-on-failure
-// INVALID_DEVICE: tf-opt -xla-legalize-tf-with-tf2xla=device-type=INVALID_DEVICE %s | FileCheck %s --dump-input-on-failure
+// INVALID_DEVICE: xla-opt -xla-legalize-tf-with-tf2xla=device-type=INVALID_DEVICE %s | FileCheck %s --dump-input-on-failure
 
 module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {


@@ -1,4 +1,4 @@
-// RUN: tf-opt -xla-legalize-to-std %s -o - | FileCheck %s
+// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s
 
 // CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
 func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {


@@ -1,4 +1,4 @@
-// RUN: tf-opt -lhlo-copy-removal %s -o - | FileCheck %s --dump-input-on-failure
+// RUN: xla-opt -lhlo-copy-removal %s -o - | FileCheck %s --dump-input-on-failure
 
 // CHECK-LABEL: func @remove_simple
 func @remove_simple(%arg0: memref<2x2xf32>) {


@@ -1,6 +1,6 @@
-// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always
-// RUN: tf-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -o - | FileCheck %s -check-prefix=TILED --dump-input-on-failure
-// RUN: tf-opt -lhlo-fuse-linalg=use-parallel-loops %s -o - | FileCheck %s -check-prefix=PLOOP --dump-input-on-failure
+// RUN: xla-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always
+// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -o - | FileCheck %s -check-prefix=TILED --dump-input-on-failure
+// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -o - | FileCheck %s -check-prefix=PLOOP --dump-input-on-failure
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>


@@ -1,4 +1,4 @@
-// RUN: tf-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s
+// RUN: xla-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s
 
 // Smoke test.
 // CHECK-LABEL: func @min_op


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -lhlo-legalize-to-gpu -split-input-file | FileCheck %s --dump-input=fail
+// RUN: xla-opt %s -lhlo-legalize-to-gpu -split-input-file | FileCheck %s --dump-input=fail
 
 func @reduce(%arg: memref<100x10xf32>,
              %init: memref<f32>,


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s --dump-input-on-failure
+// RUN: xla-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s --dump-input-on-failure
 
 // CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @element_wise


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s --dump-input-on-failure
+// RUN: xla-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s --dump-input-on-failure
 
 func @reduce(%arg: memref<100x10x5xf32>,
              %init: memref<f32>,


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -verify-diagnostics -split-input-file | tf-opt | FileCheck %s
+// RUN: xla-opt %s -verify-diagnostics -split-input-file | xla-opt | FileCheck %s
 
 func @enforce_same_shape(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
   // expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}}


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -test-xla-lower-complex | FileCheck %s
+// RUN: xla-opt %s -test-xla-lower-complex | FileCheck %s
 
 // CHECK-LABEL: @add
 func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {


@@ -1,4 +1,4 @@
-// RUN: tf-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s
+// RUN: xla-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s
 
 // CHECK-LABEL: @testDebatch1
 func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {


@@ -1,4 +1,4 @@
-// RUN: tf-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck --dump-input=fail %s
+// RUN: xla-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck --dump-input=fail %s
 
 // CHECK-LABEL: @addBroadcastRhs
 func @addBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -verify-diagnostics -split-input-file | tf-opt | FileCheck %s
+// RUN: xla-opt %s -verify-diagnostics -split-input-file | xla-opt | FileCheck %s
 
 // Tests for types, ops with custom constraints, verifiers, printer or parser
 // methods.


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
+// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
 
 // CHECK-LABEL: func @noop
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>)


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
+// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
 
 // CHECK-LABEL: func @const_fold_collapse_to_scalar
 func @const_fold_collapse_to_scalar() -> tensor<i32> {


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
+// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
 
 // CHECK-LABEL: func @noop
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>)


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input=fail
+// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input=fail
 
 // CHECK-LABEL: func @remove_noop
 // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]


@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
+// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
 
 // CHECK-LABEL: func @fold_access
 // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]


@@ -1,4 +1,4 @@
-// RUN: tf-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope --dump-input=fail %s
+// RUN: xla-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope --dump-input=fail %s
 
 // CHECK-LABEL: @batchNormInference_2D_inner_features
 // CHECK-SAME: %[[X:[^:[:space:]]+]]