Initial checkin for TFJS Dialect and op definition
PiperOrigin-RevId: 306144532 Change-Id: I62391d55f0ab5026805cbd0eb7a5722145433ed5
This commit is contained in:
parent
11b8e457e4
commit
dd519b931a
tensorflow/compiler/mlir
@ -146,6 +146,7 @@ tf_cc_binary(
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
|
||||
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration",
|
||||
],
|
||||
)
|
||||
|
||||
|
79
tensorflow/compiler/mlir/tfjs/BUILD
Normal file
79
tensorflow/compiler/mlir/tfjs/BUILD
Normal file
@ -0,0 +1,79 @@
|
||||
load("//third_party/mlir:tblgen.bzl", "gentbl")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tfjs_ops_td_files",
|
||||
srcs = [
|
||||
"ir/tfjs_ops.td",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "tfjs_inc_gen",
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-op-decls",
|
||||
"ir/tfjs_ops.h.inc",
|
||||
),
|
||||
(
|
||||
"-gen-op-defs",
|
||||
"ir/tfjs_ops.cc.inc",
|
||||
),
|
||||
(
|
||||
"-gen-dialect-decls",
|
||||
"ir/tfjs_dialect.h.inc",
|
||||
),
|
||||
(
|
||||
"-gen-dialect-doc",
|
||||
"g3doc/tfjs_dialect.md",
|
||||
),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "ir/tfjs_ops.td",
|
||||
td_srcs = [
|
||||
"ir/tfjs_ops.td",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorflow_js",
|
||||
srcs = [
|
||||
"ir/tfjs_dialect.h.inc",
|
||||
"ir/tfjs_ops.cc",
|
||||
"ir/tfjs_ops.cc.inc",
|
||||
"ir/tfjs_ops.h.inc",
|
||||
],
|
||||
hdrs = [
|
||||
"ir/tfjs_ops.h",
|
||||
],
|
||||
deps = [
|
||||
":tfjs_inc_gen",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:SideEffects",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorflow_js_dialect_registration",
|
||||
srcs = [
|
||||
"ir/dialect_registration.cc",
|
||||
],
|
||||
deps = [
|
||||
":tensorflow_js",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
19
tensorflow/compiler/mlir/tfjs/ir/dialect_registration.cc
Normal file
19
tensorflow/compiler/mlir/tfjs/ir/dialect_registration.cc
Normal file
@ -0,0 +1,19 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h"
|
||||
|
||||
// Static initialization for TensorFlow.js op registration.
|
||||
static mlir::DialectRegistration<mlir::tfjs::TFJSDialect> tfjs_ops;
|
36
tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc
Normal file
36
tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tfjs {
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFJSDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TFJSDialect::TFJSDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc.inc"
|
||||
>();
|
||||
}
|
||||
} // namespace tfjs
|
||||
} // namespace mlir
|
43
tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h
Normal file
43
tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h
Normal file
@ -0,0 +1,43 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the dialect for TensorFlow.js
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_IR_TFJS_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TFJS_IR_TFJS_OPS_H_
|
||||
|
||||
#include "mlir/Dialect/Traits.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
|
||||
#include "mlir/Support/Functional.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
namespace mlir {
|
||||
namespace tfjs {
|
||||
|
||||
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_dialect.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h.inc"
|
||||
|
||||
} // namespace tfjs
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TFJS_IR_TFJS_OPS_H_
|
66
tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td
Normal file
66
tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td
Normal file
@ -0,0 +1,66 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the operation definition file for TensorFlow.js dialect operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TFJS_DIALECT
|
||||
#define TFJS_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow.js dialect definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TFJSDialect : Dialect {
|
||||
let name = "tfjs";
|
||||
|
||||
let summary = "Types and operations for TensorFlow.js dialect";
|
||||
let description = [{
|
||||
This dialect contains operations for TensorFlow.js. This dialect will be
|
||||
used in conjunction with the TensorFlow dialects for converting & optimizing
|
||||
TF graphs to be deployed on TFJS.
|
||||
}];
|
||||
|
||||
let cppNamespace = "tfjs";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow.js op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Base class for the operation in this dialect
|
||||
class TFJS_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<TFJSDialect, mnemonic, traits>;
|
||||
|
||||
def TFJS_PReluOp : TFJS_Op<"Prelu", [NoSideEffect, ResultsBroadcastableShape,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "Parametric Rectified Linear Unit operator";
|
||||
let description = [{
|
||||
Element-wise PReLU operator
|
||||
x -> x >= 0 ? x : (alpha * x)
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTensor:$input, AnyTensor:$alpha);
|
||||
let results = (outs AnyTensor:$output);
|
||||
let assemblyFormat =
|
||||
" operands attr-dict `:` `(` type(operands) `)` `->` type($output)";
|
||||
}
|
||||
#endif // TFJS_DIALECT
|
19
tensorflow/compiler/mlir/tfjs/tests/BUILD
Normal file
19
tensorflow/compiler/mlir/tfjs/tests/BUILD
Normal file
@ -0,0 +1,19 @@
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
package(licenses = ["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [":test_utilities"],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests.
|
||||
filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir:tf-opt",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
],
|
||||
)
|
54
tensorflow/compiler/mlir/tfjs/tests/ops.mlir
Normal file
54
tensorflow/compiler/mlir/tfjs/tests/ops.mlir
Normal file
@ -0,0 +1,54 @@
|
||||
// RUN: tf-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// -----
|
||||
|
||||
func @testPReluWrongArgumentAndResultTypes(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<1x1x10xi32>) -> tensor<10x10x10xf32> {
|
||||
// expected-error @+1 {{requires the same element type for all operands and results}}
|
||||
%0 = tfjs.Prelu %arg0, %arg1 : (tensor<10x10x10x10xf32>, tensor<1x1x10xi32>) -> tensor<10x10x10x10xi32>
|
||||
return %0 : tensor<10x10x10x10xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testPReluWrongOutputShape(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32> {
|
||||
// expected-error @+1 {{op result type '1x2x3x5' not broadcast compatible with broadcasted operands's shapes '1x2x3x4'}}
|
||||
%0 = tfjs.Prelu %arg0, %arg1 : (tensor<1x2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32>
|
||||
return %0 : tensor<1x2x3x5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32> {
|
||||
// expected-error @+1 {{result type '7x3x2x14' not broadcast compatible with broadcasted operands's shapes '2x7x3x2x14'}}
|
||||
%0 = tfjs.Prelu %arg0, %arg1 : (tensor<7x3x2x14xf32>, tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32>
|
||||
return %0 : tensor<7x3x2x14xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testPReluInvalidBroadcast(%arg0: tensor<15x14x2x14xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32> {
|
||||
// expected-error @+1 {{op operands don't have broadcast-compatible shapes}}
|
||||
%0 = tfjs.Prelu %arg0, %arg1 : (tensor<15x14x2x14xf32>, tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32>
|
||||
return %0 : tensor<15x14x2x14xf32>
|
||||
}
|
||||
// -----
|
||||
// CHECK-LABEL: func @testPReluValidSameSize
|
||||
func @testPReluValidSameSize(%arg0: tensor<16x20x20x13xf32>, %arg1: tensor<20x20x13xf32>) -> tensor<16x20x20x13xf32> {
|
||||
%0 = tfjs.Prelu %arg0, %arg1 : (tensor<16x20x20x13xf32>, tensor<20x20x13xf32>) -> tensor<16x20x20x13xf32>
|
||||
return %0 : tensor<16x20x20x13xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @testPReluValidBroadcast
|
||||
func @testPReluValidBroadcast(%arg0: tensor<19x7x12x14xf32>, %arg1: tensor<1x1x14xf32>) -> tensor<19x7x12x14xf32> {
|
||||
%0 = tfjs.Prelu %arg0, %arg1 : (tensor<19x7x12x14xf32>, tensor<1x1x14xf32>) -> tensor<19x7x12x14xf32>
|
||||
return %0 : tensor<19x7x12x14xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @testPReluValidFullBroadcast
|
||||
func @testPReluValidFullBroadcast(%arg0: tensor<7x8x9x10xf32>, %arg1: tensor<1x1x1xf32>) -> tensor<7x8x9x10xf32> {
|
||||
%0 = tfjs.Prelu %arg0, %arg1 : (tensor<7x8x9x10xf32>, tensor<1x1x1xf32>) -> tensor<7x8x9x10xf32>
|
||||
return %0 : tensor<7x8x9x10xf32>
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user