Use mlir::TypeAttr for the type attribute instead of a mangled string representation

All TF types are modeled in either the standard or the TF dialect, so they can be
represented directly without the need for string mangling.

PiperOrigin-RevId: 284248488
Change-Id: Id57c670808da30eaaaa85b0d4a96fd4e813df8f3
Jaesung Chung, 2019-12-06 13:19:53 -08:00, committed by TensorFlower Gardener
parent 6d4c47b632
commit 00c6bb2b7c
6 changed files with 40 additions and 58 deletions
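
To make the change concrete, here is a minimal sketch of the two attribute representations side by side. It is not part of this commit; it only uses the MLIR C++ API as it existed around this revision, and the function name is illustrative.

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"

// Sketch: the dtype/T attribute used to be a StringAttr holding a mangled
// "tfdtype$DT_*" token; after this change it is a TypeAttr wrapping the MLIR
// type itself, so consumers no longer have to demangle strings.
void DtypeAttrSketch(mlir::MLIRContext *context) {
  mlir::Builder builder(context);

  // Old form: an opaque string that had to be demangled into a DataType proto.
  mlir::StringAttr old_attr = builder.getStringAttr("tfdtype$DT_INT32");

  // New form: the element type is carried directly and prints as "i32".
  mlir::TypeAttr new_attr = mlir::TypeAttr::get(builder.getIntegerType(32));

  // The type is immediately usable, e.g. to build a tensor type.
  mlir::Type tensor_ty = mlir::UnrankedTensorType::get(new_attr.getValue());
  (void)old_attr;
  (void)tensor_ty;
}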


@@ -38,6 +38,6 @@ versions {
 # CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
 # CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "output"}} {
-# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
+# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
 # CHECK-NEXT: return %0 : tensor<*xi32>
 # CHECK-NEXT: }


@@ -8,7 +8,7 @@
 # Verify that we can also pull some attributes that are needed to be able to
 # create a Graph in memory, like `T`.
 # CHECK: tf.MaxPool
-# CHECK-SAME: T = "tfdtype$DT_FLOAT"
+# CHECK-SAME: T = f32
 node {
 name: "input"


@@ -2,11 +2,11 @@
 # CHECK: tf_executor.SwitchN
 # CHECK-SAME: of 3 : tensor<i32>
-# CHECK-SAME: T = "tfdtype$DT_INT32"
+# CHECK-SAME: T = i32
 # CHECK-SAME: name = "Case/branch_index/_3"
 # CHECK: tf_executor.SwitchN
 # CHECK-SAME: of 2 : tensor<f32>
-# CHECK-SAME: T = "tfdtype$DT_FLOAT"
+# CHECK-SAME: T = f32
 # CHECK-SAME: name = "Case/Case/input_0/_7"
 node {


@@ -8,7 +8,7 @@ func @only_resource_load() -> tensor<*xi32> {
 // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
 %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
+// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32}
 // CHECK: "tf_device.launch"
 // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
 // CHECK: tf_device.return %[[COMPUTE_RES]]
@@ -16,7 +16,7 @@ func @only_resource_load() -> tensor<*xi32> {
 // CHECK-SAME: () -> tensor<*xi32>
 %1 = "tf_device.launch"() ( {
-%2 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
+%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
 %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
 tf_device.return %3 : tensor<*xi32>
 }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
@@ -39,11 +39,11 @@ func @only_resource_store() -> tensor<*xi32> {
 // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
 // CHECK: {device = "tpu0", launch_attr = "launch_attr"}
 // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
-// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
+// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32}
 %1 = "tf_device.launch"() ( {
 %2 = "tf.SomeComputation"() : () -> (tensor<*xi32>)
-"tf.AssignVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
+"tf.AssignVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
 tf_device.return %2 : tensor<*xi32>
 }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
@@ -61,18 +61,18 @@ func @same_resource_load_and_store() -> tensor<*xi32> {
 // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
 %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
+// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32}
 // CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch"
 // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
 // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
 // CHECK: {device = "tpu0", launch_attr = "launch_attr"}
 // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
-// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
+// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32}
 %1 = "tf_device.launch"() ( {
-%2 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
+%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
 %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
-"tf.AssignVariableOp"(%0, %3) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
+"tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
 tf_device.return %3 : tensor<*xi32>
 }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
@@ -91,19 +91,19 @@ func @decompose_assign_add_variable_op() -> tensor<*xi32> {
 // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
 %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
+// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32}
 // CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch"
 // CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
 // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.AddV2"(%[[RES_READ_VAL]], %[[ONE]])
 // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
 // CHECK: {device = "tpu0", launch_attr = "launch_attr"}
 // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
-// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
+// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32}
 %1 = "tf_device.launch"() ( {
 %2 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-"tf.AssignAddVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
-%3 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
+"tf.AssignAddVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
+%3 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
 tf_device.return %3 : tensor<*xi32>
 }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
@@ -128,8 +128,8 @@ func @decompose_assign_sub_variable_op() -> tensor<*xi32> {
 %1 = "tf_device.launch"() ( {
 %2 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-"tf.AssignSubVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
-%3 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
+"tf.AssignSubVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
+%3 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
 tf_device.return %3 : tensor<*xi32>
 }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
@@ -147,7 +147,7 @@ func @decompose_resource_apply_gradient_descent() -> tensor<*xf32> {
 // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
 %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_FLOAT"}
+// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = f32}
 // CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch"
 // CHECK: %[[ALPHA:[0-9]*]] = "tf.Const"
 // CHECK: %[[DELTA:[0-9]*]] = "tf.Const"
@@ -156,13 +156,13 @@ func @decompose_resource_apply_gradient_descent() -> tensor<*xf32> {
 // CHECK: tf_device.return %[[SUB]], %[[SUB]]
 // CHECK: {device = "tpu0", launch_attr = "launch_attr"}
 // CHECK-SAME: () -> (tensor<*xf32>, tensor<*xf32>)
-// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_FLOAT"}
+// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = f32}
 %1 = "tf_device.launch"() ( {
-%2 = "tf.Const"() {T = "tfdtype$DT_FLOAT", value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<f32>
-%3 = "tf.Const"() {T = "tfdtype$DT_FLOAT", value = dense<[0.5]> : tensor<1xf32>} : () -> tensor<f32>
+%2 = "tf.Const"() {T = f32, value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<f32>
+%3 = "tf.Const"() {T = f32, value = dense<[0.5]> : tensor<1xf32>} : () -> tensor<f32>
 "tf.ResourceApplyGradientDescent"(%0, %2, %3) : (tensor<*x!tf.resource>, tensor<f32>, tensor<f32>) -> ()
-%4 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_FLOAT"} : (tensor<*x!tf.resource>) -> tensor<*xf32>
+%4 = "tf.ReadVariableOp"(%0) {dtype = f32} : (tensor<*x!tf.resource>) -> tensor<*xf32>
 tf_device.return %4 : tensor<*xf32>
 }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xf32>
@@ -184,13 +184,13 @@ func @internal_resource() -> tensor<*xi32> {
 %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
 // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
-%2 = "tf.ReadVariableOp"(%1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
+%2 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
 // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
 %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
 // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[COMPUTE_RES]])
-"tf.AssignVariableOp"(%1, %3) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
+"tf.AssignVariableOp"(%1, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
 // CHECK: tf_device.return %[[COMPUTE_RES]]
 tf_device.return %3 : tensor<*xi32>


@@ -91,28 +91,17 @@ struct ResourceOpLiftingPass : public FunctionPass<ResourceOpLiftingPass> {
 //
 template <typename T>
 LogicalResult RewriteCompositeAssignVariableOp(T src_op, OpBuilder* builder) {
-  // Read mangled dtype, which indicates type of data stored in resource
+  // Read dtype attribute, which indicates type of data stored in resource
   // variable. It can then be used to construct type needed for both
   // ReadVariableOp and AssignVariableOp.
-  StringAttr mangled_dtype_attr =
-      src_op.template getAttrOfType<StringAttr>(kDTypeAttr);
-  std::string type_string = mangled_dtype_attr.getValue();
-  tensorflow::DataType dtype_proto;
-  auto s =
-      tensorflow::mangling_util::DemangleDataType(type_string, &dtype_proto);
-  if (!s.ok()) return src_op.emitError() << s.error_message();
-  Type type;
-  s = tensorflow::ConvertDataType(dtype_proto, *builder, &type);
-  if (!s.ok()) return src_op.emitError() << s.error_message();
-  type = UnrankedTensorType::get(type);
+  TypeAttr dtype_attr = src_op.template getAttrOfType<TypeAttr>(kDTypeAttr);
+  Type type = UnrankedTensorType::get(dtype_attr.getValue());
   builder->setInsertionPoint(src_op);
   auto read_variable_op = builder->create<TF::ReadVariableOp>(
       src_op.getLoc(), type, src_op.resource());
-  read_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
-                           mangled_dtype_attr);
+  read_variable_op.setAttr(builder->getIdentifier(kDTypeAttr), dtype_attr);
   Value* result;
   if (std::is_same<T, TF::AssignAddVariableOp>()) {
@@ -125,8 +114,7 @@ LogicalResult RewriteCompositeAssignVariableOp(T src_op, OpBuilder* builder) {
   auto assign_variable_op = builder->create<TF::AssignVariableOp>(
       src_op.getLoc(), src_op.resource(), result);
-  assign_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
-                             mangled_dtype_attr);
+  assign_variable_op.setAttr(builder->getIdentifier(kDTypeAttr), dtype_attr);
   src_op.erase();
   return success();
@@ -147,22 +135,15 @@ LogicalResult RewriteCompositeAssignVariableOp(T src_op, OpBuilder* builder) {
 // tf.AssignVariableOp(%var, %new_var_val)
 LogicalResult RewriteResourceApplyGradientDescentOp(
     TF::ResourceApplyGradientDescentOp op, OpBuilder* builder) {
-  Type type = op.alpha()->getType();
-  auto t = UnrankedTensorType::get(type.cast<TensorType>().getElementType());
+  Type type = getElementTypeOrSelf(op.alpha());
+  auto t = UnrankedTensorType::get(type);
-  tensorflow::DataType data_type;
-  auto s = tensorflow::ConvertToDataType(type, &data_type);
-  if (!s.ok()) return op.emitError() << s.error_message();
-  std::string mangled_data_type =
-      tensorflow::mangling_util::MangleDataType(data_type);
-  auto mangled_dtype_attr = builder->getStringAttr(mangled_data_type);
+  TypeAttr dtype_attr = TypeAttr::get(type);
   builder->setInsertionPoint(op);
   auto read_variable_op =
       builder->create<TF::ReadVariableOp>(op.getLoc(), t, op.var());
-  read_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
-                           mangled_dtype_attr);
+  read_variable_op.setAttr(builder->getIdentifier(kDTypeAttr), dtype_attr);
   auto mul_op =
       builder->create<TF::MulOp>(op.getLoc(), t, op.alpha(), op.delta());
@@ -170,8 +151,7 @@ LogicalResult RewriteResourceApplyGradientDescentOp(
       op.getLoc(), t, read_variable_op.value(), mul_op.z());
   auto assign_variable_op =
       builder->create<TF::AssignVariableOp>(op.getLoc(), op.var(), sub_op.z());
-  assign_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
-                             mangled_dtype_attr);
+  assign_variable_op.setAttr(builder->getIdentifier(kDTypeAttr), dtype_attr);
   op.erase();


@@ -945,9 +945,11 @@ StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
       return builder_.getFloatAttr(builder_.getF32Type(), value.f());
     case AttrValue::kB:
       return builder_.getBoolAttr(value.b());
-    case AttrValue::kType:
-      return builder_.getStringAttr(
-          mangling_util::MangleDataType(value.type()));
+    case AttrValue::kType: {
+      mlir::Type type;
+      TF_RETURN_IF_ERROR(ConvertDataType(value.type(), builder_, &type));
+      return mlir::TypeAttr::get(type);
+    }
     case AttrValue::kShape:
       return builder_.getStringAttr(mangling_util::MangleShape(value.shape()));
     case AttrValue::kTensor:
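
The importer hunk above covers only the GraphDef-to-MLIR direction; exporting takes the mirror-image path. A hedged sketch of that direction, not part of this commit: the helper name is illustrative, while ConvertToDataType is the existing utility already used elsewhere in the bridge.

#include "mlir/IR/Attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/lib/core/errors.h"

// Turns a TypeAttr such as `i32` back into an AttrValue carrying DT_INT32.
tensorflow::Status TypeAttrToAttrValue(mlir::TypeAttr attr,
                                       tensorflow::AttrValue* value) {
  tensorflow::DataType dtype = tensorflow::DT_INVALID;
  TF_RETURN_IF_ERROR(tensorflow::ConvertToDataType(attr.getValue(), &dtype));
  value->set_type(dtype);
  return tensorflow::Status::OK();
}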