Move TF_FusedBatchNormV2Op and TF_FusedBatchNormV3Op to tf_ops.td.

These ops are not based on the autogenerated form from the TensorFlow op registry.

PiperOrigin-RevId: 321390079
Change-Id: I9cc5baad159fee7e97813cf560839c1b61189a58
This commit is contained in:
Andy Ly 2020-07-15 10:37:46 -07:00 committed by TensorFlower Gardener
parent e7dca4ea0e
commit 3822d5f114
2 changed files with 57 additions and 57 deletions

View File

@ -3540,63 +3540,6 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
}]; }];
} }
class TF_FusedBatchNormOpBase<string Name> : TF_Op<Name, [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
let summary = "Batch normalization.";
let description = [{
Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
The size of 1D Tensors matches the dimension C of the 4D Tensors.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32]>:$x,
F32Tensor:$scale,
F32Tensor:$offset,
F32Tensor:$mean,
F32Tensor:$variance,
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
DefaultValuedAttr<BoolAttr, "true">:$is_training
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
let extraClassDeclaration = [{
// TF_FoldOperandsTransposeInterface:
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
// TF_LayoutSensitiveInterface:
StringRef GetOptimalLayout(const RuntimeDevices& devices);
LogicalResult UpdateDataFormat(StringRef data_format);
}];
}
def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> {
let results = (outs
TensorOf<[BF16, F16, F32]>:$y,
F32Tensor:$batch_mean,
F32Tensor:$batch_variance,
F32Tensor:$reserve_space_1,
F32Tensor:$reserve_space_2
);
}
def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> {
let results = (outs
TensorOf<[BF16, F16, F32]>:$y,
F32Tensor:$batch_mean,
F32Tensor:$batch_variance,
F32Tensor:$reserve_space_1,
F32Tensor:$reserve_space_2,
F32Tensor:$reserve_space_3
);
}
def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> { def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
let summary = "Gather slices from `params` according to `indices`."; let summary = "Gather slices from `params` according to `indices`.";

View File

@ -1196,4 +1196,61 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> {
let verifier = [{ return VerifyPartitionedCall(*this); }]; let verifier = [{ return VerifyPartitionedCall(*this); }];
} }
class TF_FusedBatchNormOpBase<string Name> : TF_Op<Name, [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
let summary = "Batch normalization.";
let description = [{
Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
The size of 1D Tensors matches the dimension C of the 4D Tensors.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32]>:$x,
F32Tensor:$scale,
F32Tensor:$offset,
F32Tensor:$mean,
F32Tensor:$variance,
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
DefaultValuedAttr<BoolAttr, "true">:$is_training
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
let extraClassDeclaration = [{
// TF_FoldOperandsTransposeInterface:
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
// TF_LayoutSensitiveInterface:
StringRef GetOptimalLayout(const RuntimeDevices& devices);
LogicalResult UpdateDataFormat(StringRef data_format);
}];
}
def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> {
let results = (outs
TensorOf<[BF16, F16, F32]>:$y,
F32Tensor:$batch_mean,
F32Tensor:$batch_variance,
F32Tensor:$reserve_space_1,
F32Tensor:$reserve_space_2
);
}
def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> {
let results = (outs
TensorOf<[BF16, F16, F32]>:$y,
F32Tensor:$batch_mean,
F32Tensor:$batch_variance,
F32Tensor:$reserve_space_1,
F32Tensor:$reserve_space_2,
F32Tensor:$reserve_space_3
);
}
#endif // TF_OPS #endif // TF_OPS