Move TF_FusedBatchNormV2Op and TF_FusedBatchNormV3Op to tf_ops.td.
These ops are not based on the autogenerated form from the TensorFlow op registry.

PiperOrigin-RevId: 321390079
Change-Id: I9cc5baad159fee7e97813cf560839c1b61189a58
parent e7dca4ea0e
commit 3822d5f114
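For context, the moved definitions keep the same operands, attributes, and results. A minimal sketch of what a tf.FusedBatchNormV3 op accepted by these definitions could look like in MLIR's generic textual form (the tensor shapes, SSA value names, and attribute values below are illustrative, not taken from this commit):

    // Five operands (x, scale, offset, mean, variance) and six results;
    // the last result is the V3-only reserve_space_3 output.
    %y, %batch_mean, %batch_variance, %reserve_1, %reserve_2, %reserve_3 =
        "tf.FusedBatchNormV3"(%x, %scale, %offset, %mean, %variance)
          {data_format = "NHWC", epsilon = 1.0e-4 : f32,
           exponential_avg_factor = 1.0 : f32, is_training = true}
        : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>,
           tensor<8xf32>, tensor<8xf32>)
       -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>,
           tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)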
@@ -3540,63 +3540,6 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
   }];
 }
 
-class TF_FusedBatchNormOpBase<string Name> : TF_Op<Name, [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
-  let summary = "Batch normalization.";
-
-  let description = [{
-Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
-The size of 1D Tensors matches the dimension C of the 4D Tensors.
-  }];
-
-  let arguments = (ins
-    TensorOf<[BF16, F16, F32]>:$x,
-    F32Tensor:$scale,
-    F32Tensor:$offset,
-    F32Tensor:$mean,
-    F32Tensor:$variance,
-
-    DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
-    DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
-    DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
-    DefaultValuedAttr<BoolAttr, "true">:$is_training
-  );
-
-  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
-  TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
-
-  let extraClassDeclaration = [{
-    // TF_FoldOperandsTransposeInterface:
-    SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
-    SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
-    LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
-
-    // TF_LayoutSensitiveInterface:
-    StringRef GetOptimalLayout(const RuntimeDevices& devices);
-    LogicalResult UpdateDataFormat(StringRef data_format);
-  }];
-}
-
-def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> {
-  let results = (outs
-    TensorOf<[BF16, F16, F32]>:$y,
-    F32Tensor:$batch_mean,
-    F32Tensor:$batch_variance,
-    F32Tensor:$reserve_space_1,
-    F32Tensor:$reserve_space_2
-  );
-}
-
-def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> {
-  let results = (outs
-    TensorOf<[BF16, F16, F32]>:$y,
-    F32Tensor:$batch_mean,
-    F32Tensor:$batch_variance,
-    F32Tensor:$reserve_space_1,
-    F32Tensor:$reserve_space_2,
-    F32Tensor:$reserve_space_3
-  );
-}
-
 def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
   let summary = "Gather slices from `params` according to `indices`.";
 
@@ -1196,4 +1196,61 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> {
   let verifier = [{ return VerifyPartitionedCall(*this); }];
 }
 
+class TF_FusedBatchNormOpBase<string Name> : TF_Op<Name, [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
+  let summary = "Batch normalization.";
+
+  let description = [{
+Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
+The size of 1D Tensors matches the dimension C of the 4D Tensors.
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F16, F32]>:$x,
+    F32Tensor:$scale,
+    F32Tensor:$offset,
+    F32Tensor:$mean,
+    F32Tensor:$variance,
+
+    DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
+    DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
+    DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
+    DefaultValuedAttr<BoolAttr, "true">:$is_training
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
+
+  let extraClassDeclaration = [{
+    // TF_FoldOperandsTransposeInterface:
+    SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
+    SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
+    LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
+
+    // TF_LayoutSensitiveInterface:
+    StringRef GetOptimalLayout(const RuntimeDevices& devices);
+    LogicalResult UpdateDataFormat(StringRef data_format);
+  }];
+}
+
+def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> {
+  let results = (outs
+    TensorOf<[BF16, F16, F32]>:$y,
+    F32Tensor:$batch_mean,
+    F32Tensor:$batch_variance,
+    F32Tensor:$reserve_space_1,
+    F32Tensor:$reserve_space_2
+  );
+}
+
+def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> {
+  let results = (outs
+    TensorOf<[BF16, F16, F32]>:$y,
+    F32Tensor:$batch_mean,
+    F32Tensor:$batch_variance,
+    F32Tensor:$reserve_space_1,
+    F32Tensor:$reserve_space_2,
+    F32Tensor:$reserve_space_3
+  );
+}
+
 #endif // TF_OPS