Add pad/slice/strided_slice/transpose to gpu target.
PiperOrigin-RevId: 305169485 Change-Id: Ib71c972cf179841e87a20f23f450c024c2d21af3
This commit is contained in:
parent
014dd3f5dc
commit
ef5bd97017
@ -227,6 +227,19 @@ class TFLiteCostEstimator<MulOp, hardware::GPU> {
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.pad
|
||||
template <>
|
||||
class TFLiteCostEstimator<PadOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.relu
|
||||
template <>
|
||||
class TFLiteCostEstimator<ReluOp, hardware::GPU> {
|
||||
@ -266,6 +279,19 @@ class TFLiteCostEstimator<ReshapeOp, hardware::GPU> {
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.slice
|
||||
template <>
|
||||
class TFLiteCostEstimator<SliceOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.softmax
|
||||
template <>
|
||||
class TFLiteCostEstimator<SoftmaxOp, hardware::GPU> {
|
||||
@ -279,5 +305,31 @@ class TFLiteCostEstimator<SoftmaxOp, hardware::GPU> {
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.strided_slice
|
||||
template <>
|
||||
class TFLiteCostEstimator<StridedSliceOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.transpose
|
||||
template <>
|
||||
class TFLiteCostEstimator<TransposeOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
|
||||
|
||||
|
@ -1759,7 +1759,7 @@ Rounds the values of a tensor to the nearest integer, element-wise.
|
||||
}
|
||||
|
||||
def TFL_SliceOp : TFL_Op<"slice", [
|
||||
NoSideEffect, SameOperandsAndResultsScale]> {
|
||||
NoSideEffect, SameOperandsAndResultsScale, TFL_GpuTargetOp]> {
|
||||
let summary = "Return a slice from 'input'.";
|
||||
|
||||
let description = [{
|
||||
@ -1988,7 +1988,8 @@ def TFL_PadOp : TFL_Op<"pad", [
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultsScale,
|
||||
TFL_OperandHasRank<1, 2>,
|
||||
TFL_OperandRankEquals1DimOfOperand<0, 1>]> {
|
||||
TFL_OperandRankEquals1DimOfOperand<0, 1>,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Padding operator";
|
||||
|
||||
let description = [{
|
||||
@ -2571,7 +2572,8 @@ def TFL_TransposeOp : TFL_Op<"transpose",
|
||||
// TFL_OperandRankEquals1DimOfOperand<0, 1>,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TCresVTEtIsSameAsOp<0, 0>>,
|
||||
SameOperandsAndResultsScale]> {
|
||||
SameOperandsAndResultsScale,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Transpose operator";
|
||||
|
||||
let description = [{
|
||||
@ -2883,7 +2885,8 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
|
||||
NoSideEffect,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
SameOperandsAndResultsScale
|
||||
SameOperandsAndResultsScale,
|
||||
TFL_GpuTargetOp
|
||||
]> {
|
||||
let summary = "StridedSlice Op";
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user