Merge pull request #42102 from WindQAQ:utilize-tensor-format

PiperOrigin-RevId: 325358681
Change-Id: I402ac2cef4d12d2f160f2025520470796dce29dd
This commit is contained in:
TensorFlower Gardener 2020-08-06 19:22:05 -07:00
commit c415c8345c

View File

@ -339,15 +339,18 @@ void BatchToSpaceOp::getCanonicalizationPatterns(
// are not unknown.
//
static LogicalResult Verify(BiasAddOp op) {
StringRef format = op.data_format();
if (format == "NHWC") {
std::string data_format = op.data_format().str();
tensorflow::TensorFormat format;
bool is_valid = FormatFromString(data_format, &format);
DCHECK(is_valid) << data_format;
if (format == tensorflow::TensorFormat::FORMAT_NHWC) {
if (!HasRankAtLeast(op.value(), 2))
return op.emitOpError(
"requires value operand to have rank at least two with `NHWC` data "
"format");
} else {
// Op definition requires data_format to be either NHWC or NCHW.
DCHECK_EQ(format.str(), "NCHW");
DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW);
if (!HasRankAtLeast(op.value(), 3))
return op.emitOpError(
"requires value operand to have rank at least three with `NCHW` data "
@ -361,9 +364,8 @@ static LogicalResult Verify(BiasAddOp op) {
RankedTensorType bias_ty = op.bias().getType().dyn_cast<RankedTensorType>();
if (!bias_ty || !value_ty) return success();
// TODO(hinsu): Leverage tensor_format.h utility in TensorFlow to compute
// dimension indices based on format.
int64_t feature_dim_idx = format == "NHWC" ? value_ty.getRank() - 1 : 1;
int64_t feature_dim_idx =
tensorflow::GetTensorFeatureDimIndex(value_ty.getRank(), format);
int64_t feature_dim = value_ty.getDimSize(feature_dim_idx);
int64_t bias_len = bias_ty.getDimSize(0);
if (feature_dim != -1 && bias_len != -1 && feature_dim != bias_len) {
@ -383,15 +385,18 @@ static LogicalResult Verify(BiasAddOp op) {
// * the out_backprop operands have valid ranks or are unranked.
//
static LogicalResult Verify(BiasAddGradOp op) {
StringRef format = op.data_format();
if (format == "NHWC") {
std::string data_format = op.data_format().str();
tensorflow::TensorFormat format;
bool is_valid = FormatFromString(data_format, &format);
DCHECK(is_valid) << data_format;
if (format == tensorflow::TensorFormat::FORMAT_NHWC) {
if (!HasRankAtLeast(op.out_backprop(), 2))
return op.emitOpError(
"requires out_backprop operand to have rank at least two with `NHWC` "
"data format");
} else {
// Op definition requires data_format to be either NHWC or NCHW.
DCHECK_EQ(format.str(), "NCHW");
DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW);
if (!HasRankAtLeast(op.out_backprop(), 3))
return op.emitOpError(
"requires out_backprop operand to have rank at least three with "