Initial checkin for TFJS Dialect and op definition

PiperOrigin-RevId: 306144532
Change-Id: I62391d55f0ab5026805cbd0eb7a5722145433ed5
This commit is contained in:
A. Unique TensorFlower 2020-04-12 14:01:44 -07:00 committed by TensorFlower Gardener
parent 11b8e457e4
commit dd519b931a
8 changed files with 317 additions and 0 deletions

View File

@ -146,6 +146,7 @@ tf_cc_binary(
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration",
],
)

View File

@ -0,0 +1,79 @@
load("//third_party/mlir:tblgen.bzl", "gentbl")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "tfjs_ops_td_files",
srcs = [
"ir/tfjs_ops.td",
"@llvm-project//mlir:OpBaseTdFiles",
],
)
gentbl(
name = "tfjs_inc_gen",
tbl_outs = [
(
"-gen-op-decls",
"ir/tfjs_ops.h.inc",
),
(
"-gen-op-defs",
"ir/tfjs_ops.cc.inc",
),
(
"-gen-dialect-decls",
"ir/tfjs_dialect.h.inc",
),
(
"-gen-dialect-doc",
"g3doc/tfjs_dialect.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tfjs_ops.td",
td_srcs = [
"ir/tfjs_ops.td",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
],
)
cc_library(
name = "tensorflow_js",
srcs = [
"ir/tfjs_dialect.h.inc",
"ir/tfjs_ops.cc",
"ir/tfjs_ops.cc.inc",
"ir/tfjs_ops.h.inc",
],
hdrs = [
"ir/tfjs_ops.h",
],
deps = [
":tfjs_inc_gen",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
cc_library(
name = "tensorflow_js_dialect_registration",
srcs = [
"ir/dialect_registration.cc",
],
deps = [
":tensorflow_js",
],
alwayslink = 1,
)

View File

@ -0,0 +1,19 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h"
// Static initialization for TensorFlow.js op registration.
static mlir::DialectRegistration<mlir::tfjs::TFJSDialect> tfjs_ops;

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h"
namespace mlir {
namespace tfjs {
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc.inc"
//===----------------------------------------------------------------------===//
// TFJSDialect
//===----------------------------------------------------------------------===//
TFJSDialect::TFJSDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc.inc"
>();
}
} // namespace tfjs
} // namespace mlir

View File

@ -0,0 +1,43 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
//===----------------------------------------------------------------------===//
//
// This file defines the dialect for TensorFlow.js
//
//===----------------------------------------------------------------------===//
#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_IR_TFJS_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_TFJS_IR_TFJS_OPS_H_
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
namespace mlir {
namespace tfjs {
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_dialect.h.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h.inc"
} // namespace tfjs
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TFJS_IR_TFJS_OPS_H_

View File

@ -0,0 +1,66 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
//===----------------------------------------------------------------------===//
//
// This is the operation definition file for TensorFlow.js dialect operations.
//
//===----------------------------------------------------------------------===//
#ifndef TFJS_DIALECT
#define TFJS_DIALECT
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffects.td"
//===----------------------------------------------------------------------===//
// TensorFlow.js dialect definitions
//===----------------------------------------------------------------------===//
def TFJSDialect : Dialect {
let name = "tfjs";
let summary = "Types and operations for TensorFlow.js dialect";
let description = [{
This dialect contains operations for TensorFlow.js. This dialect will be
used in conjunction with the TensorFlow dialects for converting & optimizing
TF graphs to be deployed on TFJS.
}];
let cppNamespace = "tfjs";
}
//===----------------------------------------------------------------------===//
// TensorFlow.js op definitions
//===----------------------------------------------------------------------===//
// Base class for the operation in this dialect
class TFJS_Op<string mnemonic, list<OpTrait> traits = []> :
Op<TFJSDialect, mnemonic, traits>;
def TFJS_PReluOp : TFJS_Op<"Prelu", [NoSideEffect, ResultsBroadcastableShape,
SameOperandsAndResultElementType]> {
let summary = "Parametric Rectified Linear Unit operator";
let description = [{
Element-wise PReLU operator
x -> x >= 0 ? x : (alpha * x)
}];
let arguments = (ins AnyTensor:$input, AnyTensor:$alpha);
let results = (outs AnyTensor:$output);
let assemblyFormat =
" operands attr-dict `:` `(` type(operands) `)` `->` type($output)";
}
#endif // TFJS_DIALECT

View File

@ -0,0 +1,19 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-opt",
"@llvm-project//llvm:FileCheck",
],
)

View File

@ -0,0 +1,54 @@
// RUN: tf-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s --dump-input-on-failure
// -----
func @testPReluWrongArgumentAndResultTypes(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<1x1x10xi32>) -> tensor<10x10x10xf32> {
// expected-error @+1 {{requires the same element type for all operands and results}}
%0 = tfjs.Prelu %arg0, %arg1 : (tensor<10x10x10x10xf32>, tensor<1x1x10xi32>) -> tensor<10x10x10x10xi32>
return %0 : tensor<10x10x10x10xi32>
}
// -----
func @testPReluWrongOutputShape(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32> {
// expected-error @+1 {{op result type '1x2x3x5' not broadcast compatible with broadcasted operands's shapes '1x2x3x4'}}
%0 = tfjs.Prelu %arg0, %arg1 : (tensor<1x2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32>
return %0 : tensor<1x2x3x5xf32>
}
// -----
func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32> {
// expected-error @+1 {{result type '7x3x2x14' not broadcast compatible with broadcasted operands's shapes '2x7x3x2x14'}}
%0 = tfjs.Prelu %arg0, %arg1 : (tensor<7x3x2x14xf32>, tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32>
return %0 : tensor<7x3x2x14xf32>
}
// -----
func @testPReluInvalidBroadcast(%arg0: tensor<15x14x2x14xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32> {
// expected-error @+1 {{op operands don't have broadcast-compatible shapes}}
%0 = tfjs.Prelu %arg0, %arg1 : (tensor<15x14x2x14xf32>, tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32>
return %0 : tensor<15x14x2x14xf32>
}
// -----
// CHECK-LABEL: func @testPReluValidSameSize
func @testPReluValidSameSize(%arg0: tensor<16x20x20x13xf32>, %arg1: tensor<20x20x13xf32>) -> tensor<16x20x20x13xf32> {
%0 = tfjs.Prelu %arg0, %arg1 : (tensor<16x20x20x13xf32>, tensor<20x20x13xf32>) -> tensor<16x20x20x13xf32>
return %0 : tensor<16x20x20x13xf32>
}
// -----
// CHECK-LABEL: func @testPReluValidBroadcast
func @testPReluValidBroadcast(%arg0: tensor<19x7x12x14xf32>, %arg1: tensor<1x1x14xf32>) -> tensor<19x7x12x14xf32> {
%0 = tfjs.Prelu %arg0, %arg1 : (tensor<19x7x12x14xf32>, tensor<1x1x14xf32>) -> tensor<19x7x12x14xf32>
return %0 : tensor<19x7x12x14xf32>
}
// -----
// CHECK-LABEL: func @testPReluValidFullBroadcast
func @testPReluValidFullBroadcast(%arg0: tensor<7x8x9x10xf32>, %arg1: tensor<1x1x1xf32>) -> tensor<7x8x9x10xf32> {
%0 = tfjs.Prelu %arg0, %arg1 : (tensor<7x8x9x10xf32>, tensor<1x1x1xf32>) -> tensor<7x8x9x10xf32>
return %0 : tensor<7x8x9x10xf32>
}