Move TF_FusedBatchNormV2Op and TF_FusedBatchNormV3Op to tf_ops.td.
These ops are not based on the autogenerated form from the TensorFlow op registry. PiperOrigin-RevId: 321390079 Change-Id: I9cc5baad159fee7e97813cf560839c1b61189a58
This commit is contained in:
parent
e7dca4ea0e
commit
3822d5f114
@ -3540,63 +3540,6 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
class TF_FusedBatchNormOpBase<string Name> : TF_Op<Name, [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
|
|
||||||
let summary = "Batch normalization.";
|
|
||||||
|
|
||||||
let description = [{
|
|
||||||
Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
|
|
||||||
The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
TensorOf<[BF16, F16, F32]>:$x,
|
|
||||||
F32Tensor:$scale,
|
|
||||||
F32Tensor:$offset,
|
|
||||||
F32Tensor:$mean,
|
|
||||||
F32Tensor:$variance,
|
|
||||||
|
|
||||||
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
|
|
||||||
DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
|
|
||||||
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
|
|
||||||
DefaultValuedAttr<BoolAttr, "true">:$is_training
|
|
||||||
);
|
|
||||||
|
|
||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
|
||||||
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
// TF_FoldOperandsTransposeInterface:
|
|
||||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
|
||||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
|
||||||
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
|
||||||
|
|
||||||
// TF_LayoutSensitiveInterface:
|
|
||||||
StringRef GetOptimalLayout(const RuntimeDevices& devices);
|
|
||||||
LogicalResult UpdateDataFormat(StringRef data_format);
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> {
|
|
||||||
let results = (outs
|
|
||||||
TensorOf<[BF16, F16, F32]>:$y,
|
|
||||||
F32Tensor:$batch_mean,
|
|
||||||
F32Tensor:$batch_variance,
|
|
||||||
F32Tensor:$reserve_space_1,
|
|
||||||
F32Tensor:$reserve_space_2
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> {
|
|
||||||
let results = (outs
|
|
||||||
TensorOf<[BF16, F16, F32]>:$y,
|
|
||||||
F32Tensor:$batch_mean,
|
|
||||||
F32Tensor:$batch_variance,
|
|
||||||
F32Tensor:$reserve_space_1,
|
|
||||||
F32Tensor:$reserve_space_2,
|
|
||||||
F32Tensor:$reserve_space_3
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
|
def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
|
||||||
let summary = "Gather slices from `params` according to `indices`.";
|
let summary = "Gather slices from `params` according to `indices`.";
|
||||||
|
|
||||||
|
@ -1196,4 +1196,61 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> {
|
|||||||
let verifier = [{ return VerifyPartitionedCall(*this); }];
|
let verifier = [{ return VerifyPartitionedCall(*this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class TF_FusedBatchNormOpBase<string Name> : TF_Op<Name, [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
|
||||||
|
let summary = "Batch normalization.";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
|
||||||
|
The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TensorOf<[BF16, F16, F32]>:$x,
|
||||||
|
F32Tensor:$scale,
|
||||||
|
F32Tensor:$offset,
|
||||||
|
F32Tensor:$mean,
|
||||||
|
F32Tensor:$variance,
|
||||||
|
|
||||||
|
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
|
||||||
|
DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
|
||||||
|
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
|
||||||
|
DefaultValuedAttr<BoolAttr, "true">:$is_training
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// TF_FoldOperandsTransposeInterface:
|
||||||
|
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||||
|
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||||
|
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||||
|
|
||||||
|
// TF_LayoutSensitiveInterface:
|
||||||
|
StringRef GetOptimalLayout(const RuntimeDevices& devices);
|
||||||
|
LogicalResult UpdateDataFormat(StringRef data_format);
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> {
|
||||||
|
let results = (outs
|
||||||
|
TensorOf<[BF16, F16, F32]>:$y,
|
||||||
|
F32Tensor:$batch_mean,
|
||||||
|
F32Tensor:$batch_variance,
|
||||||
|
F32Tensor:$reserve_space_1,
|
||||||
|
F32Tensor:$reserve_space_2
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> {
|
||||||
|
let results = (outs
|
||||||
|
TensorOf<[BF16, F16, F32]>:$y,
|
||||||
|
F32Tensor:$batch_mean,
|
||||||
|
F32Tensor:$batch_variance,
|
||||||
|
F32Tensor:$reserve_space_1,
|
||||||
|
F32Tensor:$reserve_space_2,
|
||||||
|
F32Tensor:$reserve_space_3
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#endif // TF_OPS
|
#endif // TF_OPS
|
||||||
|
Loading…
Reference in New Issue
Block a user