diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 60250f0d056..144d3f9b3f8 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -120,7 +120,9 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
 
 class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
       Type TensorType>: HLO_Op<mnemonic,
-        !listconcat(traits, [InferShapedTypeOpInterface, InferFusibilityOpInterface])> {
+        !listconcat(traits,
+        [InferShapedTypeOpInterface, InferFusibilityOpInterface,
+        SameOperandsAndResultShape])> {
     let arguments = (ins TensorType:$operand);
     let results = (outs TensorType);
     let extraClassDeclaration = [{
@@ -146,7 +148,7 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
 
 // Abs supports complex to real, so element type is not guaranteed to match.
 def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
-    [NoSideEffect, SameOperandsAndResultShape,
+    [NoSideEffect,
      DeclareOpInterfaceMethods<InferTypeOpInterface>],
      TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
 }
@@ -157,10 +159,8 @@ def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
 def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil",
     [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp;
 
-def HLO_ConvertOp : HLO_UnaryElementwiseOp<
-    "convert", [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>,
-    BASE_HLO_ConvertOp {
-
+def HLO_ConvertOp : HLO_UnaryElementwiseOp<"convert",
+    [NoSideEffect], HLO_Tensor>, BASE_HLO_ConvertOp {
   let builders = [
     OpBuilderDAG<(ins "Value":$operand, "Type":$result_element_ty)>];
 
@@ -189,15 +189,14 @@ def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor",
     [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp;
 
 def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag",
-    [NoSideEffect, SameOperandsAndResultShape,
-     DeclareOpInterfaceMethods<InferTypeOpInterface>],
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>],
     HLO_ComplexTensor>, BASE_HLO_ImagOp {
   let results = (outs HLO_FpTensor);
   let hasFolder = 1;
 }
 
 def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite",
-    [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>,
+    [NoSideEffect], HLO_Tensor>,
     BASE_HLO_IsFiniteOp {
   let arguments = (ins HLO_FpTensor:$x);
   let results = (outs HLO_PredTensor:$y);
@@ -231,8 +230,7 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt",
     BASE_HLO_PopulationCountOp;
 
 def HLO_RealOp: HLO_UnaryElementwiseOp<"real",
-    [NoSideEffect, SameOperandsAndResultShape,
-     DeclareOpInterfaceMethods<InferTypeOpInterface>],
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>],
     HLO_ComplexTensor>, BASE_HLO_RealOp {
   let results = (outs HLO_FpTensor);
   let hasFolder = 1;
@@ -272,7 +270,9 @@ def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
 
 class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
-        HLO_Op<mnemonic, !listconcat(traits, [InferShapedTypeOpInterface, InferFusibilityOpInterface])> {
+        HLO_Op<mnemonic, !listconcat(traits,
+        [InferShapedTypeOpInterface, InferFusibilityOpInterface,
+        SameOperandsAndResultShape])> {
   let arguments = (ins
     HLO_Tensor:$lhs,
     HLO_Tensor:$rhs
@@ -315,8 +315,7 @@ def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
       [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op;
 
 def HLO_ComplexOp: HLO_BinaryElementwiseOp<"complex",
-    [NoSideEffect, SameOperandsAndResultShape,
-     DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     BASE_HLO_ComplexOp {
   let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs);
   let results = (outs HLO_ComplexTensor);