Only generate xla_hlo.constant where result and attribute types match exactly.
xla_hlo.constant no longer allows the result type to differ from the attribute type. Some of the TF -> XLA legalizations relied on this relaxed property to generate constants with dynamic/unranked result types; they have been modified to use an explicit std.tensor_cast op for the type conversion.

PiperOrigin-RevId: 308648605
Change-Id: I52839f9d6f5248e4447936f402886646efeb6e4b
commit 7c977c938e (parent 5ab3af7a7b)
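To make the new contract concrete, here is a minimal before/after sketch of the lowering for an unranked constant. It mirrors the const_dynamic_output test in the diff below; the SSA names are illustrative, not lines of the diff itself:

    // Before: xla_hlo.constant could carry a result type that differed
    // from its value attribute's type.
    %cst = xla_hlo.constant {value = dense<0> : tensor<2xi32>} : tensor<*xi32>

    // After: the constant keeps the attribute's exact type, and an explicit
    // std.tensor_cast performs the static-to-unranked conversion.
    %cst = xla_hlo.constant dense<0> : tensor<2xi32>
    %cast = tensor_cast %cst : tensor<2xi32> to tensor<*xi32>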
@@ -33,8 +33,9 @@ func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %a
   // CHECK: [[Y:%.*]] = "xla_hlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
   %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>)
   // CHECK: [[Y_CONVERT:%.*]] = "xla_hlo.convert"([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
-  // CHECK: [[DUMMY:%.*]] = xla_hlo.constant {value = dense<0.000000e+00> : tensor<0xf32>} : tensor<*xf32>
-  // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY]]
+  // CHECK: [[DUMMY:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<0xf32>
+  // CHECK: [[DUMMY_CAST:%.*]] = tensor_cast [[DUMMY]] : tensor<0xf32> to tensor<*xf32>
+  // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY_CAST]]
   return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>
 }
@@ -821,7 +822,8 @@ func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> {
   %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32>

   // CHECK: [[CST:%.+]] = xla_hlo.constant
-  // CHECK: "xla_hlo.dynamic_broadcast_in_dim"(%arg0, [[CST]])
+  // CHECK: [[CAST:%.+]] = tensor_cast [[CST]] : tensor<4xi32> to tensor<4xi32>
+  // CHECK: "xla_hlo.dynamic_broadcast_in_dim"(%arg0, [[CAST]])
   // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
   %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32>
   return %0 : tensor<16x16x16x16xf32>
@@ -1202,8 +1204,10 @@ func @const() -> tensor<2xi32> {

 // CHECK-LABEL: @const_dynamic_output
 func @const_dynamic_output() -> tensor<*xi32> {
-  // CHECK: xla_hlo.constant {value = dense<0> : tensor<2xi32>} : tensor<*xi32>
+  // CHECK: [[CONST:%.*]] = xla_hlo.constant dense<0> : tensor<2xi32>
+  // CHECK: [[CAST:%.*]] = tensor_cast [[CONST]] : tensor<2xi32> to tensor<*xi32>
   %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>)
+  // CHECK: return [[CAST]]
   return %0: tensor<*xi32>
 }
@@ -2339,7 +2343,8 @@ func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
 // CHECK-LABEL: slice_constant_start
 func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
   // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
-  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64>
+  // CHECK: %[[CAST:.*]] = tensor_cast %[[START]] : tensor<1xi64> to tensor<1xi64>
+  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[CAST]]) : (tensor<1xi64>) -> tensor<1xi64>
   // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
   // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
   // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
@@ -2360,7 +2365,8 @@ func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
 // CHECK-LABEL: slice_i32_consts
 func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
   // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi32>
-  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi32>) -> tensor<1xi64>
+  // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<1xi32> to tensor<1xi32>
+  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<1xi32>) -> tensor<1xi64>
   // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
   // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
   // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
@@ -2376,7 +2382,8 @@ func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
 // CHECK-LABEL: slice_constant_start_negative_one_size
 func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> {
   // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
-  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64>
+  // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<1xi64> to tensor<1xi64>
+  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<1xi64>) -> tensor<1xi64>
   // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
   // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
   // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
@@ -2393,7 +2400,8 @@ func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi
 // CHECK-LABEL: slice_constant_start_dynamic_shape
 func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
   // CHECK: %[[START:.*]] = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
-  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<2xi64>) -> tensor<2xi64>
+  // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<2xi64> to tensor<2xi64>
+  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<2xi64>) -> tensor<2xi64>
   // CHECK: %[[SLICED_START1:.*]] = "xla_hlo.slice"(%[[START_I64]])
   // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
   // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
@@ -3127,7 +3135,8 @@ func @range(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5xf32> {
 // CHECK-SAME: [[START:%.*]]: tensor<f32>, [[STOP:%.*]]: tensor<f32>
 func @linspace_static(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<4xf32> {
   // CHECK-DAG: [[NUM:%.*]] = xla_hlo.constant dense<4>
-  // CHECK-DAG: [[NUM_F32:%.*]] = "xla_hlo.convert"([[NUM]])
+  // CHECK-DAG: [[NUM_CAST:%.*]] = tensor_cast [[NUM]]
+  // CHECK-DAG: [[NUM_F32:%.*]] = "xla_hlo.convert"([[NUM_CAST]])
   // CHECK-DAG: [[ONE:%.*]] = xla_hlo.constant dense<1.000000e+00>
   // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_hlo.subtract [[NUM_F32]], [[ONE]]
   // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_hlo.subtract [[STOP]], [[START]]
@@ -3150,10 +3159,10 @@ func @linspace_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32

 // CHECK-LABEL: func @linspace_invalid_num
 func @linspace_invalid_num(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<?xf32> {
-  // CHECK: xla_hlo.constant {value = dense<[]> : tensor<0xi32>} : tensor<i32>
+  // CHECK: xla_hlo.constant dense<[]> : tensor<0xi32>
   // CHECK: "tf.LinSpace"
-  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<[]> : tensor<0xi32>} : () -> tensor<i32>
-  %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<?xf32>
+  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
+  %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<0xi32>) -> tensor<?xf32>
   return %1 : tensor<?xf32>
 }
@@ -3976,16 +3985,18 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
 // CHECK-LABLE: @variable_shape32
 func @variable_shape32(%input: tensor<!tf.resource<tensor<2x4x8xf32>>>) -> tensor<3xi32> {
   // CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi32>
+  // CHECK: [[CST_CAST:%.*]] = tensor_cast [[CST]]
   %0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<2x4x8xf32>>>) -> (tensor<3xi32>)
-  // CHECK: return [[CST]]
+  // CHECK: return [[CST_CAST]]
   return %0: tensor<3xi32>
 }

 // CHECK-LABLE: @variable_shape64
 func @variable_shape64(%input: tensor<!tf.resource<tensor<2x4x8xf32>>>) -> tensor<3xi64> {
   // CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi64>
+  // CHECK: [[CST_CAST:%.*]] = tensor_cast [[CST]]
   %0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<2x4x8xf32>>>) -> (tensor<3xi64>)
-  // CHECK: return [[CST]]
+  // CHECK: return [[CST_CAST]]
   return %0: tensor<3xi64>
 }
@@ -1432,15 +1432,18 @@ class ConvertFusedBatchNormV3Op
             : 0;
     auto const_attr_type = RankedTensorType::get(
         {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
-    auto dummy_const = rewriter.create<ConstOp>(
-        op.getLoc(), reserve_space_3_type,
-        DenseElementsAttr::get<float>(const_attr_type, 0.0));
+    Value dummy_const = rewriter.create<ConstOp>(
+        op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
+    if (const_attr_type != reserve_space_3_type)
+      dummy_const = rewriter.create<TensorCastOp>(
+          op.getLoc(), reserve_space_3_type, dummy_const);
     rewriter.replaceOp(op, {/*y=*/y_out,
                             /*batch_mean=*/op.mean(),
                             /*batch_variance=*/op.variance(),
                             /*reserve_space_1=*/op.mean(),
                             /*reserve_space_2=*/op.variance(),
-                            /*reserve_space_3=*/dummy_const.getResult()});
+                            /*reserve_space_3=*/dummy_const});
   }
   return success();
 }
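For reference, the rewritten ConvertFusedBatchNormV3Op produces IR along these lines when reserve_space_3 is unranked (a sketch assembled from the updated CHECK lines in the first test hunk; SSA names are illustrative):

    // Constant at the attribute's exact type, then an explicit cast to the
    // unranked result type expected by the op's users.
    %dummy = xla_hlo.constant dense<0.000000e+00> : tensor<0xf32>
    %dummy_cast = tensor_cast %dummy : tensor<0xf32> to tensor<*xf32>

Note that the guard if (const_attr_type != reserve_space_3_type) emits the cast only when the result type actually differs, so a fully ranked reserve_space_3 gets the bare constant with no cast.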
@@ -4751,6 +4754,7 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
   ConversionTarget target(*context);
   target.addLegalDialect<XlaHloDialect>();
   target.addLegalOp<CallOp>();
+  target.addLegalOp<TensorCastOp>();

   if (!allow_partial_conversion) {
     // Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp.
@@ -434,7 +434,8 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_
 // Nullary op patterns.
 //===----------------------------------------------------------------------===//

-def : Pat<(TF_ConstOp:$res ElementsAttr:$value), (HLO_ConstOp $value),
+def : Pat<(TF_ConstOp:$res ElementsAttr:$value),
+          (TensorCastOp (HLO_ConstOp $value)),
           [(HLO_Tensor $res)]>;

 //===----------------------------------------------------------------------===//
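The revised DRR pattern wraps every lowered tf.Const in a TensorCastOp, even when no conversion is needed, which is why the broadcast_to test above now checks for a same-type cast. A sketch of the expansion for a ranked constant (SSA names illustrative):

    // Same-type cast produced by (TensorCastOp (HLO_ConstOp $value)) when
    // the tf.Const result type already matches the value attribute's type.
    %cst = xla_hlo.constant dense<16> : tensor<4xi32>
    %cast = tensor_cast %cst : tensor<4xi32> to tensor<4xi32>

This unconditional cast is also why legalizeTF marks TensorCastOp legal via target.addLegalOp<TensorCastOp>(): the conversion target would otherwise reject the very ops these patterns introduce.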