diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 9b2e6f0292b..b2b4c09df3b 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -430,6 +430,7 @@ cc_library(
         "transforms/parallel_execute_to_islands.cc",
         "transforms/promote_resources_to_args.cc",
         "transforms/raise_control_flow.cc",
+        "transforms/readonly_references_to_resources.cc",
         "transforms/replicate_invariant_op_hoisting.cc",
         "transforms/replicate_to_island.cc",
         "transforms/resource_device_inference.cc",
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
index d5ecbf3e292..9daebc22ba1 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
@@ -47,37 +47,6 @@ limitations under the License.
 namespace mlir {
 namespace tf_executor {
 
-namespace {
-
-// If the given tensor has elements of type with subtypes, then returns a new
-// type after dropping subtypes info. Otherwise, returns the original type as
-// is.
-ShapedType DropTypeSubTypes(ShapedType ty) {
-  Type element_ty = ty.getElementType();
-  auto subtype_ty = element_ty.dyn_cast<TF::TensorFlowTypeWithSubtype>();
-  if (!subtype_ty) return ty;
-
-  Type default_ty = GetDefaultTypeOf(subtype_ty);
-  if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
-
-  return UnrankedTensorType::get(default_ty);
-}
-
-// If the given tensor has elements of type ref, then returns a new type
-// of the shape, but corresponding non-ref type as element type. Otherwise,
-// returns the original type as is.
-ShapedType DropRefType(ShapedType ty) {
-  Type element_ty = ty.getElementType();
-  auto ref_ty = element_ty.dyn_cast<TF::TensorFlowRefType>();
-  if (!ref_ty) return ty;
-
-  Type default_ty = GetDefaultTypeOf(ref_ty);
-  if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
-
-  return UnrankedTensorType::get(default_ty);
-}
-
-}  // namespace
 
 //===----------------------------------------------------------------------===//
 // TF Executor Dialect
@@ -85,6 +54,9 @@ ShapedType DropRefType(ShapedType ty) {
 
 namespace {
 
+using TF::DropRefType;
+using TF::DropTypeSubTypes;
+
 struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
index d312e5e409b..994378ea1cf 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
@@ -366,5 +366,27 @@ bool AreCastCompatible(ArrayRef<Type> types) {
   return true;
 }
 
+ShapedType DropTypeSubTypes(ShapedType ty) {
+  Type element_ty = ty.getElementType();
+  auto subtype_ty = element_ty.dyn_cast<TF::TensorFlowTypeWithSubtype>();
+  if (!subtype_ty) return ty;
+
+  Type default_ty = GetDefaultTypeOf(subtype_ty);
+  if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
+
+  return UnrankedTensorType::get(default_ty);
+}
+
+ShapedType DropRefType(ShapedType ty) {
+  Type element_ty = ty.getElementType();
+  TF::TensorFlowRefType ref_ty = element_ty.dyn_cast<TF::TensorFlowRefType>();
+  if (!ref_ty) return ty;
+
+  Type default_ty = TF::GetDefaultTypeOf(ref_ty);
+  if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
+
+  return UnrankedTensorType::get(default_ty);
+}
+
 }  // namespace TF
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
index 4c99aae4706..5f108e834a9 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
@@ -319,6 +319,16 @@ bool HasCompatibleElementTypes(Type lhs, Type rhs,
 // compatible.
 bool AreCastCompatible(ArrayRef<Type> types);
 
+// If the given tensor has elements of type with subtypes, then returns a new
+// type after dropping subtypes info. Otherwise, returns the original type as
+// is.
+ShapedType DropTypeSubTypes(ShapedType ty);
+
+// If the given tensor has elements of type ref, then returns a new type
+// of the shape, but corresponding non-ref type as element type. Otherwise,
+// returns the original type as is.
+ShapedType DropRefType(ShapedType ty);
+
 }  // end namespace TF
 }  // end namespace mlir
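For intuition, the type rewrites performed by DropRefType and DropTypeSubTypes look like this (illustrative MLIR types inferred from the definitions above, not taken from this patch):

    // DropRefType: replace a ref element type with its non-ref counterpart.
    tensor<96x!tf.f32ref>                 ->  tensor<96xf32>

    // DropTypeSubTypes: erase subtype information from the element type.
    tensor<!tf.resource<tensor<96xf32>>>  ->  tensor<!tf.resource>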
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/readonly_references_to_resources.mlir b/tensorflow/compiler/mlir/tensorflow/tests/readonly_references_to_resources.mlir
new file mode 100644
index 00000000000..2970e31c3c9
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/readonly_references_to_resources.mlir
@@ -0,0 +1,85 @@
+// RUN: tf-opt -verify-diagnostics -readonly-references-to-resources -split-input-file %s | FileCheck %s --dump-input=fail
+
+// Test case: Basic conversion.
+
+func @f() {
+  // CHECK: "tf.VarHandleOp"
+  // CHECK: "tf.ReadVariableOp"
+  %val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
+  %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
+  return
+}
+
+// -----
+
+// Test case: Two ReadVariable ops.
+
+func @f() {
+  // CHECK: "tf.VarHandleOp"
+
+  // During lowering to resource variables, the pass creates each
+  // ReadVariableOp at the location of the Identity op it replaces, to keep
+  // the original graph composition and order.
+
+  // CHECK: "tf.ReadVariableOp"
+  // CHECK: "tf.ReadVariableOp"
+  %val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
+  %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
+  %val2 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
+  return
+}
+
+// -----
+
+// Test case: Variable with no follow-up read; the unused op is erased.
+
+func @f() {
+  // CHECK-NOT: "tf.VariableV2"
+  // CHECK-NOT: "tf.VarHandleOp"
+  %val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
+  return
+}
+
+// -----
+
+// Test case: No conversion when there is a non-Identity user.
+
+func @f() {
+  // expected-error @+1 {{'tf.VariableV2' op expects all users to be 'tf.Identity', but got user tf.CustomIdentity}}
+  %val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
+  %val1 = "tf.CustomIdentity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
+  return
+}
+
+// -----
+
+// Test case: No '_class' attribute on the VariableV2 op.
+
+func @f() {
+  // expected-error @+1 {{'tf.VariableV2' op has no '_class' attribute}}
+  %val0 = "tf.VariableV2"() {container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
+  %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
+  return
+}
+
+// -----
+
+// Test case: No named location found in the '_class' attribute.
+
+func @f() {
+  // expected-error @+1 {{'tf.VariableV2' op expects variable name in '_class' attribute, but got ["unrelated_class"]}}
+  %val0 = "tf.VariableV2"() {_class = ["unrelated_class"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
+  %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
+  return
+}
+
+// -----
+
+// Test case: Multiple named locations in the '_class' attribute are invalid.
+
+func @f() {
+  // expected-error @+1 {{'tf.VariableV2' op expects only one named location in '_class' attribute, but got ["loc:@v1", "loc:@v2"]}}
+  %val0 = "tf.VariableV2"() {_class = ["loc:@v1", "loc:@v2"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
+  %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
+  return
+}
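For the basic test case above, the expected rewrite of the (VariableV2 -> Identity) pair looks roughly like this (a sketch inferred from the pass implementation below; the exact textual output of tf-opt may differ):

    func @f() {
      %handle = "tf.VarHandleOp"() {container = "", device = "", shared_name = "v"} : () -> tensor<!tf.resource<tensor<96xf32>>>
      %val = "tf.ReadVariableOp"(%handle) : (tensor<!tf.resource<tensor<96xf32>>>) -> tensor<96xf32>
      return
    }

The shared_name "v" is recovered from the `_class = ["loc:@v"]` attribute, and the resource subtype records the non-ref tensor type computed by DropRefType.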
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index 81d0259d2d6..5c140ddd6aa 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -95,6 +95,11 @@ std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass();
 // functions.
 std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass();
 
+// Creates a pass that converts readonly reference variables to the
+// corresponding resource variables.
+std::unique_ptr<OperationPass<FuncOp>>
+CreateConvertReadonlyReferenceVariablesToResourceVariablesPass();
+
 // Marks function visibility using tf.entry_function specification. That is,
 // functions with tf.entry_function attributes are marked with public
 // visibility while the other functions are marked with private visibility.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc
new file mode 100644
index 00000000000..a80b84ddeda
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc
@@ -0,0 +1,179 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+
+namespace mlir {
+namespace TF {
+namespace {
+
+// Location attribute.
+constexpr StringRef kClassAttr = "_class";
+constexpr StringRef kLocationPrefix = "loc:@";
+
+// A pass that converts readonly reference variables to the corresponding
+// resource variables.
+//
+// It converts (VariableV2 -> Identity) to (VarHandle -> ReadVariable).
+//
+// For background, this pass is part of hoisting VariableV2 ops by reusing
+// the pipeline for hoisting (VarHandle -> ReadVariable) cases, which is done
+// by the following passes:
+//  - Capturing resource values into global tensors (importing saved model).
+//  - Promoting VarHandle ops to function input/outputs.
+//  - Freezing global tensor pass.
+//
+// This pass assumes that all the VariableV2 ops are read-only; this is
+// verified with a heuristic that requires every user of a VariableV2 op to
+// be a directly fed Identity op.
+class ConvertReadonlyReferenceVariablesToResourceVariablesPass
+    : public PassWrapper<
+          ConvertReadonlyReferenceVariablesToResourceVariablesPass,
+          FunctionPass> {
+ public:
+  void runOnFunction() override;
+};
+
+// Parse the node name from the '_class' attribute.
+StringRef GetNodeNameFromClassAttr(Operation *op) {
+  ArrayAttr classes_attr = op->getAttrOfType<ArrayAttr>(kClassAttr);
+  if (!classes_attr) {
+    op->emitOpError() << "has no '_class' attribute";
+    return StringRef();
+  }
+
+  StringRef result;
+  for (Attribute class_attr : classes_attr) {
+    StringRef node_name = class_attr.cast<StringAttr>().getValue();
+    if (!node_name.startswith(kLocationPrefix)) {
+      continue;
+    }
+    if (!result.empty()) {
+      // Invalid case since there are multiple loc:@ attributes.
+      op->emitOpError()
+          << "expects only one named location in '_class' attribute, but got "
+          << classes_attr;
+      return StringRef();
+    }
+    result = node_name.drop_front(kLocationPrefix.size());
+  }
+  if (result.empty()) {
+    op->emitOpError()
+        << "expects variable name in '_class' attribute, but got "
+        << classes_attr;
+  }
+  return result;
+}
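+
+// Example: for a VariableV2 op carrying the attribute
+//   _class = ["loc:@v"]
+// GetNodeNameFromClassAttr returns "v", which runOnFunction below installs
+// as the shared_name of the VarHandleOp it creates.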
+
+void ConvertReadonlyReferenceVariablesToResourceVariablesPass::
+    runOnFunction() {
+  FuncOp func = getFunction();
+
+  OpBuilder builder(func.getContext());
+  SmallVector<VariableV2Op, 4> variable_v2s_to_replace;
+
+  // Check that all the VariableV2 ops are read-only, using the heuristic
+  // that every user must be an Identity op fed directly by the variable.
+  auto read_only_vars_fn = [&variable_v2s_to_replace](
+                               VariableV2Op variable_v2_op) {
+    if (variable_v2_op.getResult().use_empty()) {
+      // Erase the op when there is no user.
+      variable_v2_op.erase();
+      return mlir::WalkResult::advance();
+    }
+    if (!all_of(variable_v2_op.getResult().getUsers(), [&variable_v2_op](
+                                                           Operation *user) {
+          if (!isa<IdentityOp>(user)) {
+            variable_v2_op.emitOpError()
+                << "expects all users to be 'tf.Identity', but got user "
+                << user->getName();
+            return false;
+          }
+          return true;
+        })) {
+      return mlir::WalkResult::interrupt();
+    }
+    variable_v2s_to_replace.push_back(variable_v2_op);
+    return mlir::WalkResult::advance();
+  };
+
+  WalkResult walk_res = func.walk(read_only_vars_fn);
+  if (walk_res.wasInterrupted()) return signalPassFailure();
+
+  for (VariableV2Op variable_v2_op : variable_v2s_to_replace) {
+    builder.setInsertionPoint(variable_v2_op);
+    ShapedType shaped_type =
+        variable_v2_op.getResult().getType().cast<ShapedType>();
+    TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
+    StringAttr device_attr =
+        variable_v2_op.getAttrOfType<StringAttr>("device");
+    if (!device_attr) device_attr = builder.getStringAttr("");
+    StringRef variable_name = GetNodeNameFromClassAttr(variable_v2_op);
+    if (variable_name.empty()) {
+      return signalPassFailure();
+    }
+    VarHandleOp var_handle_op = builder.create<VarHandleOp>(
+        variable_v2_op.getLoc(),
+        ArrayRef<Type>{RankedTensorType::get(
+            {}, TF::ResourceType::get(ArrayRef<TensorType>{tensor_type},
+                                      builder.getContext()))},
+        ArrayRef<Value>{},
+        ArrayRef<NamedAttribute>{
+            builder.getNamedAttr("device", device_attr),
+            builder.getNamedAttr("container", variable_v2_op.containerAttr()),
+            builder.getNamedAttr("shared_name",
+                                 builder.getStringAttr(variable_name))});
+    for (Operation *user :
+         make_early_inc_range(variable_v2_op.getResult().getUsers())) {
+      builder.setInsertionPoint(user);
+      ReadVariableOp read_variable_op = builder.create<ReadVariableOp>(
+          user->getLoc(), ArrayRef<Type>{tensor_type},
+          ArrayRef<Value>{var_handle_op}, ArrayRef<NamedAttribute>{});
+      user->getResult(0).replaceAllUsesWith(read_variable_op.getResult());
+      user->erase();
+    }
+    variable_v2_op.erase();
+  }
+}
+
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+CreateConvertReadonlyReferenceVariablesToResourceVariablesPass() {
+  return std::make_unique<
+      ConvertReadonlyReferenceVariablesToResourceVariablesPass>();
+}
+
+static PassRegistration<
+    ConvertReadonlyReferenceVariablesToResourceVariablesPass>
+    pass("readonly-references-to-resources",
+         "Convert readonly reference variables to resource variables.");
+
+}  // namespace TF
+}  // namespace mlir
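To illustrate the ordering behavior exercised by the two-reads test case: a variable with two Identity readers lowers to a single VarHandleOp feeding one ReadVariableOp per former Identity op, each created at that Identity's insertion point (again a sketch, not verbatim tf-opt output):

    func @f() {
      %handle = "tf.VarHandleOp"() {container = "", device = "", shared_name = "v"} : () -> tensor<!tf.resource<tensor<96xf32>>>
      %val1 = "tf.ReadVariableOp"(%handle) : (tensor<!tf.resource<tensor<96xf32>>>) -> tensor<96xf32>
      %val2 = "tf.ReadVariableOp"(%handle) : (tensor<!tf.resource<tensor<96xf32>>>) -> tensor<96xf32>
      return
    }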