Add add/mul/relu to gpu target

PiperOrigin-RevId: 303112045
Change-Id: Ieb3880f299453d87c3c3ebc3ab27390ca9d532b3
This commit is contained in:
Renjie Liu 2020-03-26 08:02:19 -07:00 committed by TensorFlower Gardener
parent 410852dbd2
commit 6fb1f04baf
2 changed files with 49 additions and 3 deletions

View File

@ -16,6 +16,19 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
// tfl.add
template <>
class TFLiteCostEstimator<AddOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.average_pool_2d
template <>
class TFLiteCostEstimator<AveragePool2DOp, hardware::GPU> {
@ -69,5 +82,31 @@ class TFLiteCostEstimator<MaxPool2DOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.mul
template <>
class TFLiteCostEstimator<MulOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.relu
template <>
class TFLiteCostEstimator<ReluOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_

View File

@ -360,7 +360,10 @@ an output element, this operation computes \\(y = |x|\\).
let hasFolder = 1;
}
def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape, NoSideEffect, Commutative]> {
def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
NoSideEffect,
Commutative,
TFL_GpuTargetOp]> {
let summary = "Addition operator";
let description = [{
@ -1869,7 +1872,10 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
let hasOptions = 0;
}
def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutative]> {
def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape,
NoSideEffect,
Commutative,
TFL_GpuTargetOp]> {
let summary = "Multiplication operator";
let description = [{
@ -2102,7 +2108,8 @@ def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> {
def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultsScale]> {
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
let summary = "Relu operator";
let description = [{