Add a new xla-opt binary tool to exercise only the XLA related passes
This new utility is just like tf-opt but does not bring all the dependencies from TensorFlow and as such is thinner and faster to link. This will be more convenient for iterative development on the MLIR-XLA work. PiperOrigin-RevId: 306362771 Change-Id: I80ad3abeacc7f1eb2af85b54de90c32b3505bc38
This commit is contained in:
parent
43f67e0881
commit
d6c51d18b7
@ -40,7 +40,6 @@ cc_library(
|
|||||||
srcs = ["tf_mlir_opt_main.cc"],
|
srcs = ["tf_mlir_opt_main.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":init_mlir",
|
":init_mlir",
|
||||||
":passes",
|
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/platform:logging",
|
"//tensorflow/core/platform:logging",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
@ -56,6 +55,7 @@ cc_library(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "passes",
|
name = "passes",
|
||||||
visibility = [
|
visibility = [
|
||||||
|
":__subpackages__",
|
||||||
"//tensorflow/python:__subpackages__",
|
"//tensorflow/python:__subpackages__",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
@ -77,24 +77,6 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
|
"//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
|
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
|
||||||
"//tensorflow/compiler/mlir/xla:buffer_assignment",
|
|
||||||
"//tensorflow/compiler/mlir/xla:hlo",
|
|
||||||
"//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
|
|
||||||
"//tensorflow/compiler/mlir/xla:lhlo",
|
|
||||||
"//tensorflow/compiler/mlir/xla:lhlo_copy_removal",
|
|
||||||
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
|
|
||||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
|
|
||||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
|
|
||||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_parallel_loops",
|
|
||||||
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
|
|
||||||
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
|
|
||||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
|
||||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
|
|
||||||
"//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
|
|
||||||
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
|
|
||||||
"//tensorflow/compiler/mlir/xla:xla_lower",
|
|
||||||
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
|
|
||||||
"//tensorflow/compiler/mlir/xla:xla_test_passes",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -142,12 +124,14 @@ cc_library(
|
|||||||
tf_cc_binary(
|
tf_cc_binary(
|
||||||
name = "tf-opt",
|
name = "tf-opt",
|
||||||
deps = [
|
deps = [
|
||||||
|
":passes",
|
||||||
":tf_mlir_opt_main",
|
":tf_mlir_opt_main",
|
||||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
|
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
|
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
|
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
|
||||||
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration",
|
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration",
|
||||||
|
"//tensorflow/compiler/mlir/xla:all_xla_passes_for_testing",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ tool_dirs = config.mlir_tf_tools_dirs + [
|
|||||||
tool_names = [
|
tool_names = [
|
||||||
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
|
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
|
||||||
'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
|
'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
|
||||||
'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer'
|
'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-opt'
|
||||||
]
|
]
|
||||||
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
|
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
|
||||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
load("//third_party/mlir:tblgen.bzl", "gentbl")
|
load("//third_party/mlir:tblgen.bzl", "gentbl")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary")
|
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = [":friends"],
|
default_visibility = [":friends"],
|
||||||
@ -750,7 +750,7 @@ genrule(
|
|||||||
cmd = ("$(location :operator_writer_gen) " +
|
cmd = ("$(location :operator_writer_gen) " +
|
||||||
"-I external/llvm-project/mlir/include " +
|
"-I external/llvm-project/mlir/include " +
|
||||||
"-I external/org_tensorflow " +
|
"-I external/org_tensorflow " +
|
||||||
"$(location //tensorflow/compiler/mlir/xla:ir/hlo_ops.td) " +
|
"$(location :ir/hlo_ops.td) " +
|
||||||
" -o $@"),
|
" -o $@"),
|
||||||
tools = [":operator_writer_gen"],
|
tools = [":operator_writer_gen"],
|
||||||
)
|
)
|
||||||
@ -763,3 +763,38 @@ cc_library(
|
|||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "all_xla_passes_for_testing",
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/compiler/mlir:__subpackages__",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":buffer_assignment",
|
||||||
|
":hlo",
|
||||||
|
":hlo_legalize_to_lhlo",
|
||||||
|
":lhlo",
|
||||||
|
":lhlo_copy_removal",
|
||||||
|
":lhlo_fuse_linalg",
|
||||||
|
":lhlo_legalize_to_affine",
|
||||||
|
":lhlo_legalize_to_gpu",
|
||||||
|
":lhlo_legalize_to_parallel_loops",
|
||||||
|
":xla_dialect_registration",
|
||||||
|
":xla_legalize_control_flow",
|
||||||
|
":xla_legalize_tf",
|
||||||
|
":xla_legalize_tf_with_tf2xla",
|
||||||
|
":xla_legalize_to_linalg",
|
||||||
|
":xla_legalize_to_standard",
|
||||||
|
":xla_lower",
|
||||||
|
":xla_materialize_broadcasts",
|
||||||
|
":xla_test_passes",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_binary(
|
||||||
|
name = "xla-opt",
|
||||||
|
deps = [
|
||||||
|
":all_xla_passes_for_testing",
|
||||||
|
"//tensorflow/compiler/mlir:tf_mlir_opt_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -14,6 +14,7 @@ filegroup(
|
|||||||
testonly = True,
|
testonly = True,
|
||||||
data = [
|
data = [
|
||||||
"//tensorflow/compiler/mlir:tf-opt",
|
"//tensorflow/compiler/mlir:tf-opt",
|
||||||
|
"//tensorflow/compiler/mlir/xla:xla-opt",
|
||||||
"@llvm-project//llvm:FileCheck",
|
"@llvm-project//llvm:FileCheck",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt -test-buffer-assignment -split-input-file %s | FileCheck %s -dump-input-on-failure
|
// RUN: xla-opt -test-buffer-assignment -split-input-file %s | FileCheck %s -dump-input-on-failure
|
||||||
|
|
||||||
// CHECK-LABEL: Testing : condBranch
|
// CHECK-LABEL: Testing : condBranch
|
||||||
func @condBranch(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
|
func @condBranch(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input-on-failure
|
// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input-on-failure
|
||||||
|
|
||||||
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
|
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
|
||||||
// CHECK: "xla_hlo.dynamic-slice"
|
// CHECK: "xla_hlo.dynamic-slice"
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
|
// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @single_operand
|
// CHECK-LABEL: func @single_operand
|
||||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
|
// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure
|
// RUN: xla-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure
|
||||||
|
|
||||||
// CHECK-LABEL: func @attrs
|
// CHECK-LABEL: func @attrs
|
||||||
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
@ -146,14 +146,16 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @external_func() -> tensor<3xi64>
|
||||||
|
|
||||||
// CHECK-LABEL: func @dyn_broadcast
|
// CHECK-LABEL: func @dyn_broadcast
|
||||||
func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
||||||
%tensor_operand = tensor_load %operand : memref<?x?xf32>
|
%tensor_operand = tensor_load %operand : memref<?x?xf32>
|
||||||
%shape = "tf.compute_shape"() : () -> tensor<3xi64>
|
%shape = call @external_func() : () -> tensor<3xi64>
|
||||||
%tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape)
|
%tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape)
|
||||||
{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
|
{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
|
||||||
: (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
: (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||||
// CHECK: %[[SHAPE:.*]] = "tf.compute_shape"()
|
// CHECK: %[[SHAPE:.*]] = call @external_func()
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
|
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
|
||||||
// CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
|
// CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s
|
// RUN: xla-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s
|
||||||
|
|
||||||
// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
|
// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
// CHECK-LABEL: func @float_add
|
// CHECK-LABEL: func @float_add
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -inline | FileCheck %s --dump-input=fail
|
// RUN: xla-opt %s -inline | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
// Test case: Basic test of inlining into xla_hlo.while.
|
// Test case: Basic test of inlining into xla_hlo.while.
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt -xla-legalize-control-flow %s -o - | FileCheck %s
|
// RUN: xla-opt -xla-legalize-control-flow %s -o - | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @while(%arg0: tensor<i64>) -> tensor<i64> {
|
// CHECK-LABEL: func @while(%arg0: tensor<i64>) -> tensor<i64> {
|
||||||
func @while(%arg0: tensor<i64>) -> tensor<i64> {
|
func @while(%arg0: tensor<i64>) -> tensor<i64> {
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
// RUN: tf-opt -xla-legalize-tf-with-tf2xla=device-type=XLA_CPU_JIT %s | FileCheck %s --dump-input-on-failure
|
// RUN: tf-opt -xla-legalize-tf-with-tf2xla=device-type=XLA_CPU_JIT %s | FileCheck %s --dump-input-on-failure
|
||||||
|
|
||||||
// INVALID_DEVICE: tf-opt -xla-legalize-tf-with-tf2xla=device-type=INVALID_DEVICE %s | FileCheck %s --dump-input-on-failure
|
// INVALID_DEVICE: xla-opt -xla-legalize-tf-with-tf2xla=device-type=INVALID_DEVICE %s | FileCheck %s --dump-input-on-failure
|
||||||
|
|
||||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
|
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt -xla-legalize-to-std %s -o - | FileCheck %s
|
// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt -lhlo-copy-removal %s -o - | FileCheck %s --dump-input-on-failure
|
// RUN: xla-opt -lhlo-copy-removal %s -o - | FileCheck %s --dump-input-on-failure
|
||||||
|
|
||||||
// CHECK-LABEL: func @remove_simple
|
// CHECK-LABEL: func @remove_simple
|
||||||
func @remove_simple(%arg0: memref<2x2xf32>) {
|
func @remove_simple(%arg0: memref<2x2xf32>) {
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always
|
// RUN: xla-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always
|
||||||
// RUN: tf-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -o - | FileCheck %s -check-prefix=TILED --dump-input-on-failure
|
// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -o - | FileCheck %s -check-prefix=TILED --dump-input-on-failure
|
||||||
// RUN: tf-opt -lhlo-fuse-linalg=use-parallel-loops %s -o - | FileCheck %s -check-prefix=PLOOP --dump-input-on-failure
|
// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -o - | FileCheck %s -check-prefix=PLOOP --dump-input-on-failure
|
||||||
|
|
||||||
|
|
||||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s
|
// RUN: xla-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s
|
||||||
|
|
||||||
// Smoke test.
|
// Smoke test.
|
||||||
// CHECK-LABEL: func @min_op
|
// CHECK-LABEL: func @min_op
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -lhlo-legalize-to-gpu -split-input-file | FileCheck %s --dump-input=fail
|
// RUN: xla-opt %s -lhlo-legalize-to-gpu -split-input-file | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
func @reduce(%arg: memref<100x10xf32>,
|
func @reduce(%arg: memref<100x10xf32>,
|
||||||
%init: memref<f32>,
|
%init: memref<f32>,
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s --dump-input-on-failure
|
// RUN: xla-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s --dump-input-on-failure
|
||||||
|
|
||||||
// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
|
// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
// CHECK-LABEL: func @element_wise
|
// CHECK-LABEL: func @element_wise
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s --dump-input-on-failure
|
// RUN: xla-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s --dump-input-on-failure
|
||||||
|
|
||||||
func @reduce(%arg: memref<100x10x5xf32>,
|
func @reduce(%arg: memref<100x10x5xf32>,
|
||||||
%init: memref<f32>,
|
%init: memref<f32>,
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -verify-diagnostics -split-input-file | tf-opt | FileCheck %s
|
// RUN: xla-opt %s -verify-diagnostics -split-input-file | xla-opt | FileCheck %s
|
||||||
|
|
||||||
func @enforce_same_shape(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
|
func @enforce_same_shape(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
|
||||||
// expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}}
|
// expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -test-xla-lower-complex | FileCheck %s
|
// RUN: xla-opt %s -test-xla-lower-complex | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: @add
|
// CHECK-LABEL: @add
|
||||||
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s
|
// RUN: xla-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: @testDebatch1
|
// CHECK-LABEL: @testDebatch1
|
||||||
func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {
|
func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck --dump-input=fail %s
|
// RUN: xla-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck --dump-input=fail %s
|
||||||
|
|
||||||
// CHECK-LABEL: @addBroadcastRhs
|
// CHECK-LABEL: @addBroadcastRhs
|
||||||
func @addBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
|
func @addBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -verify-diagnostics -split-input-file | tf-opt | FileCheck %s
|
// RUN: xla-opt %s -verify-diagnostics -split-input-file | xla-opt | FileCheck %s
|
||||||
|
|
||||||
// Tests for types, ops with custom constraints, verifiers, printer or parser
|
// Tests for types, ops with custom constraints, verifiers, printer or parser
|
||||||
// methods.
|
// methods.
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
|
// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @noop
|
// CHECK-LABEL: func @noop
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>)
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
|
// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @const_fold_collapse_to_scalar
|
// CHECK-LABEL: func @const_fold_collapse_to_scalar
|
||||||
func @const_fold_collapse_to_scalar() -> tensor<i32> {
|
func @const_fold_collapse_to_scalar() -> tensor<i32> {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
|
// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @noop
|
// CHECK-LABEL: func @noop
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>)
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input=fail
|
// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
// CHECK-LABEL: func @remove_noop
|
// CHECK-LABEL: func @remove_noop
|
||||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
|
// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @fold_access
|
// CHECK-LABEL: func @fold_access
|
||||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope --dump-input=fail %s
|
// RUN: xla-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope --dump-input=fail %s
|
||||||
|
|
||||||
// CHECK-LABEL: @batchNormInference_2D_inner_features
|
// CHECK-LABEL: @batchNormInference_2D_inner_features
|
||||||
// CHECK-SAME: %[[X:[^:[:space:]]+]]
|
// CHECK-SAME: %[[X:[^:[:space:]]+]]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user