Implement a pass that converts readonly reference variables to the corresponding
resource variables. It converts (VariableV2 -> Identity) to (VarHandle -> ReadVariable). As background, this pass is part of hoisting VariableV2 ops by reusing the pipeline for hoisting (VarHandle -> ReadVariable) cases, which can be done by the following passes: - Capturing resource values into global tensors (importing saved model). - Promoting VarHandle ops to function input/outputs. - Freezing global tensor pass. This pass assumes that all VariableV2 ops are read-only, which it verifies with a heuristic: every user of a VariableV2 op must be an Identity op fed directly by it. PiperOrigin-RevId: 312704760 Change-Id: I89ac4c0543a7954f6b27d418da63f7f1418490cd
This commit is contained in:
parent
c9c8ac3cb9
commit
c030682e9c
|
@ -430,6 +430,7 @@ cc_library(
|
||||||
"transforms/parallel_execute_to_islands.cc",
|
"transforms/parallel_execute_to_islands.cc",
|
||||||
"transforms/promote_resources_to_args.cc",
|
"transforms/promote_resources_to_args.cc",
|
||||||
"transforms/raise_control_flow.cc",
|
"transforms/raise_control_flow.cc",
|
||||||
|
"transforms/readonly_references_to_resources.cc",
|
||||||
"transforms/replicate_invariant_op_hoisting.cc",
|
"transforms/replicate_invariant_op_hoisting.cc",
|
||||||
"transforms/replicate_to_island.cc",
|
"transforms/replicate_to_island.cc",
|
||||||
"transforms/resource_device_inference.cc",
|
"transforms/resource_device_inference.cc",
|
||||||
|
|
|
@ -47,37 +47,6 @@ limitations under the License.
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace tf_executor {
|
namespace tf_executor {
|
||||||
namespace {
|
|
||||||
|
|
||||||
// If the given tensor has elements of type with subtypes, then returns a new
|
|
||||||
// type after dropping subtypes info. Otherwise, returns the original type as
|
|
||||||
// is.
|
|
||||||
ShapedType DropTypeSubTypes(ShapedType ty) {
|
|
||||||
Type element_ty = ty.getElementType();
|
|
||||||
auto subtype_ty = element_ty.dyn_cast<TF::TensorFlowTypeWithSubtype>();
|
|
||||||
if (!subtype_ty) return ty;
|
|
||||||
|
|
||||||
Type default_ty = GetDefaultTypeOf(subtype_ty);
|
|
||||||
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
|
|
||||||
|
|
||||||
return UnrankedTensorType::get(default_ty);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the given tensor has elements of type ref, then returns a new type
|
|
||||||
// of the shape, but corresponding non-ref type as element type. Otherwise,
|
|
||||||
// returns the original type as is.
|
|
||||||
ShapedType DropRefType(ShapedType ty) {
|
|
||||||
Type element_ty = ty.getElementType();
|
|
||||||
auto ref_ty = element_ty.dyn_cast<TF::TensorFlowRefType>();
|
|
||||||
if (!ref_ty) return ty;
|
|
||||||
|
|
||||||
Type default_ty = GetDefaultTypeOf(ref_ty);
|
|
||||||
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
|
|
||||||
|
|
||||||
return UnrankedTensorType::get(default_ty);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TF Executor Dialect
|
// TF Executor Dialect
|
||||||
|
@ -85,6 +54,9 @@ ShapedType DropRefType(ShapedType ty) {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using TF::DropRefType;
|
||||||
|
using TF::DropTypeSubTypes;
|
||||||
|
|
||||||
struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
|
struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
|
||||||
using DialectInlinerInterface::DialectInlinerInterface;
|
using DialectInlinerInterface::DialectInlinerInterface;
|
||||||
|
|
||||||
|
|
|
@ -366,5 +366,27 @@ bool AreCastCompatible(ArrayRef<Type> types) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Strips subtype information from the element type of `ty`, keeping its
// shape. A type whose element type has no subtypes is returned as is.
ShapedType DropTypeSubTypes(ShapedType ty) {
  if (auto with_subtypes =
          ty.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>()) {
    Type stripped_element_ty = GetDefaultTypeOf(with_subtypes);
    // Unranked inputs stay unranked; ranked inputs keep their exact shape.
    if (!ty.hasRank()) return UnrankedTensorType::get(stripped_element_ty);
    return RankedTensorType::get(ty.getShape(), stripped_element_ty);
  }
  return ty;
}
|
||||||
|
|
||||||
|
// Replaces a ref element type of `ty` with its non-ref counterpart, keeping
// the shape. A type whose element type is not a ref type is returned as is.
ShapedType DropRefType(ShapedType ty) {
  if (TF::TensorFlowRefType ref_element_ty =
          ty.getElementType().dyn_cast<TF::TensorFlowRefType>()) {
    Type deref_element_ty = TF::GetDefaultTypeOf(ref_element_ty);
    // Unranked inputs stay unranked; ranked inputs keep their exact shape.
    if (!ty.hasRank()) return UnrankedTensorType::get(deref_element_ty);
    return RankedTensorType::get(ty.getShape(), deref_element_ty);
  }
  return ty;
}
|
||||||
|
|
||||||
} // namespace TF
|
} // namespace TF
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -319,6 +319,16 @@ bool HasCompatibleElementTypes(Type lhs, Type rhs,
|
||||||
// compatible.
|
// compatible.
|
||||||
bool AreCastCompatible(ArrayRef<Type> types);
|
bool AreCastCompatible(ArrayRef<Type> types);
|
||||||
|
|
||||||
|
// If the given tensor has elements of type with subtypes, then returns a new
|
||||||
|
// type after dropping subtypes info. Otherwise, returns the original type as
|
||||||
|
// is.
|
||||||
|
ShapedType DropTypeSubTypes(ShapedType ty);
|
||||||
|
|
||||||
|
// If the given tensor has elements of type ref, then returns a new type
|
||||||
|
// of the shape, but corresponding non-ref type as element type. Otherwise,
|
||||||
|
// returns the original type as is.
|
||||||
|
ShapedType DropRefType(ShapedType ty);
|
||||||
|
|
||||||
} // end namespace TF
|
} // end namespace TF
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,85 @@
|
||||||
|
// RUN: tf-opt -verify-diagnostics -readonly-references-to-resources -split-input-file %s | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// Test case: Basic conversion.
|
||||||
|
|
||||||
|
func @f() {
|
||||||
|
// CHECK: "tf.VarHandleOp"
|
||||||
|
// CHECK: "tf.ReadVariableOp"
|
||||||
|
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
|
||||||
|
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test case: Two ReadVariable ops.
|
||||||
|
|
||||||
|
func @f() {
|
||||||
|
// CHECK: "tf.VarHandleOp"
|
||||||
|
|
||||||
|
// While lowering to resource variables, this pass creates each ReadVariableOp
|
||||||
|
// at the location of the Identity op it replaces, which keeps the original
|
||||||
|
// graph composition and order.
|
||||||
|
|
||||||
|
// CHECK: "tf.ReadVariableOp"
|
||||||
|
// CHECK: "tf.ReadVariableOp"
|
||||||
|
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
|
||||||
|
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
|
||||||
|
%val2 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test case: A VariableV2 op with no users is erased and not converted.
|
||||||
|
|
||||||
|
func @f() {
|
||||||
|
// CHECK-NOT: "tf.VariableV2"
|
||||||
|
// CHECK-NOT: "tf.VarHandleOp"
|
||||||
|
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test case: No conversion when the VariableV2 op has a non-Identity user.
|
||||||
|
|
||||||
|
func @f() {
|
||||||
|
// expected-error @+1 {{'tf.VariableV2' op expects all users to be 'tf.Identity', but got user tf.CustomIdentity}}
|
||||||
|
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
|
||||||
|
%val1 = "tf.CustomIdentity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test case: No class attribute on VariableV2 op.
|
||||||
|
|
||||||
|
func @f() {
|
||||||
|
// expected-error @+1 {{'tf.VariableV2' op has no '_class' attribute}}
|
||||||
|
%val0 = "tf.VariableV2"() {container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
|
||||||
|
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test case: No named location found on VariableV2 op.
|
||||||
|
|
||||||
|
func @f() {
|
||||||
|
// expected-error @+1 {{'tf.VariableV2' op expects variable name in '_class' attribute, but got ["unrelated_class"]}}
|
||||||
|
%val0 = "tf.VariableV2"() {_class = ["unrelated_class"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
|
||||||
|
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test case: Invalid multiple location information in a class attribute on VariableV2 op.
|
||||||
|
|
||||||
|
func @f() {
|
||||||
|
// expected-error @+1 {{'tf.VariableV2' op expects only one named location in '_class' attribute, but got ["loc:@v1", "loc:@v2"]}}
|
||||||
|
%val0 = "tf.VariableV2"() {_class = ["loc:@v1", "loc:@v2"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
|
||||||
|
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
|
||||||
|
return
|
||||||
|
}
|
|
@ -95,6 +95,11 @@ std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass();
|
||||||
// functions.
|
// functions.
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass();
|
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass();
|
||||||
|
|
||||||
|
// Creates a pass that converts readonly reference variables to the
|
||||||
|
// corresponding resource variables.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
|
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass();
|
||||||
|
|
||||||
// Marks function visibility using tf.entry_function specification. That is,
|
// Marks function visibility using tf.entry_function specification. That is,
|
||||||
// functions with tf.entry_function attributes are marked with public
|
// functions with tf.entry_function attributes are marked with public
|
||||||
// visibility while the other functions are marked with private visibility.
|
// visibility while the other functions are marked with private visibility.
|
||||||
|
|
|
@ -0,0 +1,179 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Function.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Types.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Value.h" // from @llvm-project
|
||||||
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace TF {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Location attribute.
|
||||||
|
constexpr StringRef kClassAttr = "_class";
|
||||||
|
constexpr StringRef kLocationPrefix = "loc:@";
|
||||||
|
|
||||||
|
// A pass that converts readonly reference variables to the corresponding
|
||||||
|
// resource variables.
|
||||||
|
//
|
||||||
|
// It converts (VariableV2 -> Identity) to (VarHandle -> ReadVariable).
|
||||||
|
//
|
||||||
|
// For the background, this pass is a part of hoisting VariableV2 ops by
|
||||||
|
// re-using the pipeline for hoisting (VarHandle -> ReadVariable) cases, which
|
||||||
|
// can be done by the following passes:
|
||||||
|
// - Capturing resource values into global tensors (importing saved model).
|
||||||
|
// - Promoting VarHandle ops to function input/outputs.
|
||||||
|
// - Freezing global tensor pass.
|
||||||
|
//
|
||||||
|
// This path assumes that all the VariableV2 ops is read-only via verifying the
|
||||||
|
// heuristic method that assumes that all the users of them is Identity op,
|
||||||
|
// fed directly.
|
||||||
|
class ConvertReadonlyReferenceVariablesToResourceVariablesPass
|
||||||
|
: public PassWrapper<
|
||||||
|
ConvertReadonlyReferenceVariablesToResourceVariablesPass,
|
||||||
|
FunctionPass> {
|
||||||
|
public:
|
||||||
|
void runOnFunction() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse node name from "_class" attribute.
|
||||||
|
StringRef GetNodeNameFromClassAttr(Operation *op) {
|
||||||
|
ArrayAttr classes_attr = op->getAttrOfType<ArrayAttr>(kClassAttr);
|
||||||
|
if (!classes_attr) {
|
||||||
|
op->emitOpError() << "has no '_class' attribute";
|
||||||
|
return StringRef();
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef result;
|
||||||
|
for (Attribute class_attr : classes_attr) {
|
||||||
|
StringRef node_name = class_attr.cast<StringAttr>().getValue();
|
||||||
|
if (!node_name.startswith(kLocationPrefix)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!result.empty()) {
|
||||||
|
// Invalid case since there are multiple loc:@ attributes.
|
||||||
|
op->emitOpError()
|
||||||
|
<< "expects only one named location in '_class' attribute, but got "
|
||||||
|
<< classes_attr;
|
||||||
|
return StringRef();
|
||||||
|
}
|
||||||
|
result = node_name.drop_front(kLocationPrefix.size());
|
||||||
|
}
|
||||||
|
if (result.empty()) {
|
||||||
|
op->emitOpError() << "expects variable name in '_class' attribute, but got "
|
||||||
|
<< classes_attr;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() {
|
||||||
|
FuncOp func = getFunction();
|
||||||
|
|
||||||
|
OpBuilder builder(func.getContext());
|
||||||
|
SmallVector<VariableV2Op, 4> variable_v2s_to_replace;
|
||||||
|
|
||||||
|
// Checks all the VariableV2 ops is read-only via verifying the heuristic
|
||||||
|
// method that assumes that all the users of them is Identity op, feeded
|
||||||
|
// directly.
|
||||||
|
auto read_only_vars_fn = [&variable_v2s_to_replace](
|
||||||
|
VariableV2Op variable_v2_op) {
|
||||||
|
if (variable_v2_op.getResult().use_empty()) {
|
||||||
|
// Erase the op when there is no user.
|
||||||
|
variable_v2_op.erase();
|
||||||
|
return mlir::WalkResult::advance();
|
||||||
|
}
|
||||||
|
if (!all_of(variable_v2_op.getResult().getUsers(), [&variable_v2_op](
|
||||||
|
Operation *user) {
|
||||||
|
if (!isa<IdentityOp>(user)) {
|
||||||
|
variable_v2_op.emitOpError()
|
||||||
|
<< "expects all users to be 'tf.Identity', but got user "
|
||||||
|
<< user->getName();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
})) {
|
||||||
|
return mlir::WalkResult::interrupt();
|
||||||
|
}
|
||||||
|
variable_v2s_to_replace.push_back(variable_v2_op);
|
||||||
|
return mlir::WalkResult::advance();
|
||||||
|
};
|
||||||
|
|
||||||
|
WalkResult walk_res = func.walk(read_only_vars_fn);
|
||||||
|
if (walk_res.wasInterrupted()) return signalPassFailure();
|
||||||
|
|
||||||
|
for (VariableV2Op variable_v2_op : variable_v2s_to_replace) {
|
||||||
|
builder.setInsertionPoint(variable_v2_op);
|
||||||
|
ShapedType shaped_type =
|
||||||
|
variable_v2_op.getResult().getType().cast<ShapedType>();
|
||||||
|
TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
|
||||||
|
StringAttr device_attr = variable_v2_op.getAttrOfType<StringAttr>("device");
|
||||||
|
if (!device_attr) device_attr = builder.getStringAttr("");
|
||||||
|
StringRef variable_name = GetNodeNameFromClassAttr(variable_v2_op);
|
||||||
|
if (variable_name.empty()) {
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
VarHandleOp var_handle_op = builder.create<VarHandleOp>(
|
||||||
|
variable_v2_op.getLoc(),
|
||||||
|
ArrayRef<Type>{RankedTensorType::get(
|
||||||
|
{}, TF::ResourceType::get(ArrayRef<TensorType>{tensor_type},
|
||||||
|
builder.getContext()))},
|
||||||
|
ArrayRef<Value>{},
|
||||||
|
ArrayRef<NamedAttribute>{
|
||||||
|
builder.getNamedAttr("device", device_attr),
|
||||||
|
builder.getNamedAttr("container", variable_v2_op.containerAttr()),
|
||||||
|
builder.getNamedAttr("shared_name",
|
||||||
|
builder.getStringAttr(variable_name))});
|
||||||
|
for (Operation *user :
|
||||||
|
make_early_inc_range(variable_v2_op.getResult().getUsers())) {
|
||||||
|
builder.setInsertionPoint(user);
|
||||||
|
ReadVariableOp read_variable_op = builder.create<ReadVariableOp>(
|
||||||
|
user->getLoc(), ArrayRef<Type>{tensor_type},
|
||||||
|
ArrayRef<Value>{var_handle_op}, ArrayRef<NamedAttribute>{});
|
||||||
|
user->getResult(0).replaceAllUsesWith(read_variable_op.getResult());
|
||||||
|
user->erase();
|
||||||
|
}
|
||||||
|
variable_v2_op.erase();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Returns a new instance of the pass that converts readonly reference
// variables to the corresponding resource variables.
std::unique_ptr<OperationPass<FuncOp>>
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass() {
  using PassT = ConvertReadonlyReferenceVariablesToResourceVariablesPass;
  return std::make_unique<PassT>();
}
|
||||||
|
|
||||||
|
// Registers the pass under the "readonly-references-to-resources" argument,
// with the second string as its description.
static PassRegistration<
|
||||||
|
ConvertReadonlyReferenceVariablesToResourceVariablesPass>
|
||||||
|
pass("readonly-references-to-resources",
|
||||||
|
"Convert readonly reference variables to resource variables.");
|
||||||
|
|
||||||
|
} // namespace TF
|
||||||
|
|
||||||
|
} // namespace mlir
|
Loading…
Reference in New Issue