Use mlir::TypeAttr for the type attribute instead of mangled string repr
All TF types are modeled in either standard or TF dialect and can be directly represented without need for string mangling. PiperOrigin-RevId: 284248488 Change-Id: Id57c670808da30eaaaa85b0d4a96fd4e813df8f3
This commit is contained in:
parent
6d4c47b632
commit
00c6bb2b7c
@ -38,6 +38,6 @@ versions {
|
||||
|
||||
# CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
|
||||
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "output"}} {
|
||||
# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
|
||||
# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
|
||||
# CHECK-NEXT: return %0 : tensor<*xi32>
|
||||
# CHECK-NEXT: }
|
||||
|
@ -8,7 +8,7 @@
|
||||
# Verify that we can also pull some attributes that are needed to be able to
|
||||
# create a Graph in memory, like `T`.
|
||||
# CHECK: tf.MaxPool
|
||||
# CHECK-SAME: T = "tfdtype$DT_FLOAT"
|
||||
# CHECK-SAME: T = f32
|
||||
|
||||
node {
|
||||
name: "input"
|
||||
|
@ -2,11 +2,11 @@
|
||||
|
||||
# CHECK: tf_executor.SwitchN
|
||||
# CHECK-SAME: of 3 : tensor<i32>
|
||||
# CHECK-SAME: T = "tfdtype$DT_INT32"
|
||||
# CHECK-SAME: T = i32
|
||||
# CHECK-SAME: name = "Case/branch_index/_3"
|
||||
# CHECK: tf_executor.SwitchN
|
||||
# CHECK-SAME: of 2 : tensor<f32>
|
||||
# CHECK-SAME: T = "tfdtype$DT_FLOAT"
|
||||
# CHECK-SAME: T = f32
|
||||
# CHECK-SAME: name = "Case/Case/input_0/_7"
|
||||
|
||||
node {
|
||||
|
@ -8,7 +8,7 @@ func @only_resource_load() -> tensor<*xi32> {
|
||||
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32}
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
|
||||
// CHECK: tf_device.return %[[COMPUTE_RES]]
|
||||
@ -16,7 +16,7 @@ func @only_resource_load() -> tensor<*xi32> {
|
||||
// CHECK-SAME: () -> tensor<*xi32>
|
||||
|
||||
%1 = "tf_device.launch"() ( {
|
||||
%2 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
tf_device.return %3 : tensor<*xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
|
||||
@ -39,11 +39,11 @@ func @only_resource_store() -> tensor<*xi32> {
|
||||
// CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
|
||||
// CHECK: {device = "tpu0", launch_attr = "launch_attr"}
|
||||
// CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32}
|
||||
|
||||
%1 = "tf_device.launch"() ( {
|
||||
%2 = "tf.SomeComputation"() : () -> (tensor<*xi32>)
|
||||
"tf.AssignVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
|
||||
"tf.AssignVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
|
||||
tf_device.return %2 : tensor<*xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
|
||||
|
||||
@ -61,18 +61,18 @@ func @same_resource_load_and_store() -> tensor<*xi32> {
|
||||
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32}
|
||||
// CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
|
||||
// CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
|
||||
// CHECK: {device = "tpu0", launch_attr = "launch_attr"}
|
||||
// CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32}
|
||||
|
||||
%1 = "tf_device.launch"() ( {
|
||||
%2 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
"tf.AssignVariableOp"(%0, %3) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
|
||||
"tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
|
||||
tf_device.return %3 : tensor<*xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
|
||||
|
||||
@ -91,19 +91,19 @@ func @decompose_assign_add_variable_op() -> tensor<*xi32> {
|
||||
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32}
|
||||
// CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
|
||||
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.AddV2"(%[[RES_READ_VAL]], %[[ONE]])
|
||||
// CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
|
||||
// CHECK: {device = "tpu0", launch_attr = "launch_attr"}
|
||||
// CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32}
|
||||
|
||||
%1 = "tf_device.launch"() ( {
|
||||
%2 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
"tf.AssignAddVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
|
||||
%3 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
"tf.AssignAddVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
|
||||
%3 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
tf_device.return %3 : tensor<*xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
|
||||
|
||||
@ -128,8 +128,8 @@ func @decompose_assign_sub_variable_op() -> tensor<*xi32> {
|
||||
|
||||
%1 = "tf_device.launch"() ( {
|
||||
%2 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
"tf.AssignSubVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
|
||||
%3 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
"tf.AssignSubVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
|
||||
%3 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
tf_device.return %3 : tensor<*xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
|
||||
|
||||
@ -147,7 +147,7 @@ func @decompose_resource_apply_gradient_descent() -> tensor<*xf32> {
|
||||
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_FLOAT"}
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = f32}
|
||||
// CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK: %[[ALPHA:[0-9]*]] = "tf.Const"
|
||||
// CHECK: %[[DELTA:[0-9]*]] = "tf.Const"
|
||||
@ -156,13 +156,13 @@ func @decompose_resource_apply_gradient_descent() -> tensor<*xf32> {
|
||||
// CHECK: tf_device.return %[[SUB]], %[[SUB]]
|
||||
// CHECK: {device = "tpu0", launch_attr = "launch_attr"}
|
||||
// CHECK-SAME: () -> (tensor<*xf32>, tensor<*xf32>)
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_FLOAT"}
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = f32}
|
||||
|
||||
%1 = "tf_device.launch"() ( {
|
||||
%2 = "tf.Const"() {T = "tfdtype$DT_FLOAT", value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<f32>
|
||||
%3 = "tf.Const"() {T = "tfdtype$DT_FLOAT", value = dense<[0.5]> : tensor<1xf32>} : () -> tensor<f32>
|
||||
%2 = "tf.Const"() {T = f32, value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<f32>
|
||||
%3 = "tf.Const"() {T = f32, value = dense<[0.5]> : tensor<1xf32>} : () -> tensor<f32>
|
||||
"tf.ResourceApplyGradientDescent"(%0, %2, %3) : (tensor<*x!tf.resource>, tensor<f32>, tensor<f32>) -> ()
|
||||
%4 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_FLOAT"} : (tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
%4 = "tf.ReadVariableOp"(%0) {dtype = f32} : (tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
tf_device.return %4 : tensor<*xf32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xf32>
|
||||
|
||||
@ -184,13 +184,13 @@ func @internal_resource() -> tensor<*xi32> {
|
||||
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
|
||||
%2 = "tf.ReadVariableOp"(%1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
%2 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
|
||||
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
|
||||
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[COMPUTE_RES]])
|
||||
"tf.AssignVariableOp"(%1, %3) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
|
||||
"tf.AssignVariableOp"(%1, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
|
||||
|
||||
// CHECK: tf_device.return %[[COMPUTE_RES]]
|
||||
tf_device.return %3 : tensor<*xi32>
|
||||
|
@ -91,28 +91,17 @@ struct ResourceOpLiftingPass : public FunctionPass<ResourceOpLiftingPass> {
|
||||
//
|
||||
template <typename T>
|
||||
LogicalResult RewriteCompositeAssignVariableOp(T src_op, OpBuilder* builder) {
|
||||
// Read mangled dtype, which indicates type of data stored in resource
|
||||
// Read dtype attribute, which indicates type of data stored in resource
|
||||
// variable. It can then be used to construct type needed for both
|
||||
// ReadVariableOp and AssignVariableOp.
|
||||
StringAttr mangled_dtype_attr =
|
||||
src_op.template getAttrOfType<StringAttr>(kDTypeAttr);
|
||||
std::string type_string = mangled_dtype_attr.getValue();
|
||||
tensorflow::DataType dtype_proto;
|
||||
auto s =
|
||||
tensorflow::mangling_util::DemangleDataType(type_string, &dtype_proto);
|
||||
if (!s.ok()) return src_op.emitError() << s.error_message();
|
||||
|
||||
Type type;
|
||||
s = tensorflow::ConvertDataType(dtype_proto, *builder, &type);
|
||||
if (!s.ok()) return src_op.emitError() << s.error_message();
|
||||
type = UnrankedTensorType::get(type);
|
||||
TypeAttr dtype_attr = src_op.template getAttrOfType<TypeAttr>(kDTypeAttr);
|
||||
Type type = UnrankedTensorType::get(dtype_attr.getValue());
|
||||
|
||||
builder->setInsertionPoint(src_op);
|
||||
|
||||
auto read_variable_op = builder->create<TF::ReadVariableOp>(
|
||||
src_op.getLoc(), type, src_op.resource());
|
||||
read_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
|
||||
mangled_dtype_attr);
|
||||
read_variable_op.setAttr(builder->getIdentifier(kDTypeAttr), dtype_attr);
|
||||
|
||||
Value* result;
|
||||
if (std::is_same<T, TF::AssignAddVariableOp>()) {
|
||||
@ -125,8 +114,7 @@ LogicalResult RewriteCompositeAssignVariableOp(T src_op, OpBuilder* builder) {
|
||||
|
||||
auto assign_variable_op = builder->create<TF::AssignVariableOp>(
|
||||
src_op.getLoc(), src_op.resource(), result);
|
||||
assign_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
|
||||
mangled_dtype_attr);
|
||||
assign_variable_op.setAttr(builder->getIdentifier(kDTypeAttr), dtype_attr);
|
||||
|
||||
src_op.erase();
|
||||
return success();
|
||||
@ -147,22 +135,15 @@ LogicalResult RewriteCompositeAssignVariableOp(T src_op, OpBuilder* builder) {
|
||||
// tf.AssignVariableOp(%var, %new_var_val)
|
||||
LogicalResult RewriteResourceApplyGradientDescentOp(
|
||||
TF::ResourceApplyGradientDescentOp op, OpBuilder* builder) {
|
||||
Type type = op.alpha()->getType();
|
||||
auto t = UnrankedTensorType::get(type.cast<TensorType>().getElementType());
|
||||
Type type = getElementTypeOrSelf(op.alpha());
|
||||
auto t = UnrankedTensorType::get(type);
|
||||
|
||||
tensorflow::DataType data_type;
|
||||
auto s = tensorflow::ConvertToDataType(type, &data_type);
|
||||
if (!s.ok()) return op.emitError() << s.error_message();
|
||||
|
||||
std::string mangled_data_type =
|
||||
tensorflow::mangling_util::MangleDataType(data_type);
|
||||
auto mangled_dtype_attr = builder->getStringAttr(mangled_data_type);
|
||||
TypeAttr dtype_attr = TypeAttr::get(type);
|
||||
|
||||
builder->setInsertionPoint(op);
|
||||
auto read_variable_op =
|
||||
builder->create<TF::ReadVariableOp>(op.getLoc(), t, op.var());
|
||||
read_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
|
||||
mangled_dtype_attr);
|
||||
read_variable_op.setAttr(builder->getIdentifier(kDTypeAttr), dtype_attr);
|
||||
|
||||
auto mul_op =
|
||||
builder->create<TF::MulOp>(op.getLoc(), t, op.alpha(), op.delta());
|
||||
@ -170,8 +151,7 @@ LogicalResult RewriteResourceApplyGradientDescentOp(
|
||||
op.getLoc(), t, read_variable_op.value(), mul_op.z());
|
||||
auto assign_variable_op =
|
||||
builder->create<TF::AssignVariableOp>(op.getLoc(), op.var(), sub_op.z());
|
||||
assign_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
|
||||
mangled_dtype_attr);
|
||||
assign_variable_op.setAttr(builder->getIdentifier(kDTypeAttr), dtype_attr);
|
||||
|
||||
op.erase();
|
||||
|
||||
|
@ -945,9 +945,11 @@ StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
|
||||
return builder_.getFloatAttr(builder_.getF32Type(), value.f());
|
||||
case AttrValue::kB:
|
||||
return builder_.getBoolAttr(value.b());
|
||||
case AttrValue::kType:
|
||||
return builder_.getStringAttr(
|
||||
mangling_util::MangleDataType(value.type()));
|
||||
case AttrValue::kType: {
|
||||
mlir::Type type;
|
||||
TF_RETURN_IF_ERROR(ConvertDataType(value.type(), builder_, &type));
|
||||
return mlir::TypeAttr::get(type);
|
||||
}
|
||||
case AttrValue::kShape:
|
||||
return builder_.getStringAttr(mangling_util::MangleShape(value.shape()));
|
||||
case AttrValue::kTensor:
|
||||
|
Loading…
Reference in New Issue
Block a user