Merge pull request #42102 from WindQAQ:utilize-tensor-format
PiperOrigin-RevId: 325358681 Change-Id: I402ac2cef4d12d2f160f2025520470796dce29dd
This commit is contained in:
commit
c415c8345c
@ -339,15 +339,18 @@ void BatchToSpaceOp::getCanonicalizationPatterns(
|
||||
// are not unknown.
|
||||
//
|
||||
static LogicalResult Verify(BiasAddOp op) {
|
||||
StringRef format = op.data_format();
|
||||
if (format == "NHWC") {
|
||||
std::string data_format = op.data_format().str();
|
||||
tensorflow::TensorFormat format;
|
||||
bool is_valid = FormatFromString(data_format, &format);
|
||||
DCHECK(is_valid) << data_format;
|
||||
if (format == tensorflow::TensorFormat::FORMAT_NHWC) {
|
||||
if (!HasRankAtLeast(op.value(), 2))
|
||||
return op.emitOpError(
|
||||
"requires value operand to have rank at least two with `NHWC` data "
|
||||
"format");
|
||||
} else {
|
||||
// Op definition requires data_format to be either NHWC or NCHW.
|
||||
DCHECK_EQ(format.str(), "NCHW");
|
||||
DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW);
|
||||
if (!HasRankAtLeast(op.value(), 3))
|
||||
return op.emitOpError(
|
||||
"requires value operand to have rank at least three with `NCHW` data "
|
||||
@ -361,9 +364,8 @@ static LogicalResult Verify(BiasAddOp op) {
|
||||
RankedTensorType bias_ty = op.bias().getType().dyn_cast<RankedTensorType>();
|
||||
if (!bias_ty || !value_ty) return success();
|
||||
|
||||
// TODO(hinsu): Leverage tensor_format.h utility in TensorFlow to compute
|
||||
// dimension indices based on format.
|
||||
int64_t feature_dim_idx = format == "NHWC" ? value_ty.getRank() - 1 : 1;
|
||||
int64_t feature_dim_idx =
|
||||
tensorflow::GetTensorFeatureDimIndex(value_ty.getRank(), format);
|
||||
int64_t feature_dim = value_ty.getDimSize(feature_dim_idx);
|
||||
int64_t bias_len = bias_ty.getDimSize(0);
|
||||
if (feature_dim != -1 && bias_len != -1 && feature_dim != bias_len) {
|
||||
@ -383,15 +385,18 @@ static LogicalResult Verify(BiasAddOp op) {
|
||||
// * the out_backprop operands have valid ranks or are unranked.
|
||||
//
|
||||
static LogicalResult Verify(BiasAddGradOp op) {
|
||||
StringRef format = op.data_format();
|
||||
if (format == "NHWC") {
|
||||
std::string data_format = op.data_format().str();
|
||||
tensorflow::TensorFormat format;
|
||||
bool is_valid = FormatFromString(data_format, &format);
|
||||
DCHECK(is_valid) << data_format;
|
||||
if (format == tensorflow::TensorFormat::FORMAT_NHWC) {
|
||||
if (!HasRankAtLeast(op.out_backprop(), 2))
|
||||
return op.emitOpError(
|
||||
"requires out_backprop operand to have rank at least two with `NHWC` "
|
||||
"data format");
|
||||
} else {
|
||||
// Op definition requires data_format to be either NHWC or NCHW.
|
||||
DCHECK_EQ(format.str(), "NCHW");
|
||||
DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW);
|
||||
if (!HasRankAtLeast(op.out_backprop(), 3))
|
||||
return op.emitOpError(
|
||||
"requires out_backprop operand to have rank at least three with "
|
||||
|
Loading…
x
Reference in New Issue
Block a user