Implement a pass that converts readonly reference variables to the corresponding

resource variables.

It converts (VariableV2 -> Identity) to (VarHandle -> ReadVariable).

For the background, this pass is a part of hoisting VariableV2 ops by
re-using the pipeline for hoisting (VarHandle -> ReadVariable) cases, which
can be done by the following passes:
  - Capturing resource values into global tensors (importing saved model).
  - Promoting VarHandle ops to function input/outputs.
  - Freezing global tensor pass.

This path assumes that all the VariableV2 ops is read-only via verifying the
heuristic method that assumes that all the users of them is Identity op,
fed directly.

PiperOrigin-RevId: 312704760
Change-Id: I89ac4c0543a7954f6b27d418da63f7f1418490cd
This commit is contained in:
Jaesung Chung 2020-05-21 11:15:27 -07:00 committed by TensorFlower Gardener
parent c9c8ac3cb9
commit c030682e9c
7 changed files with 305 additions and 31 deletions

View File

@ -430,6 +430,7 @@ cc_library(
"transforms/parallel_execute_to_islands.cc", "transforms/parallel_execute_to_islands.cc",
"transforms/promote_resources_to_args.cc", "transforms/promote_resources_to_args.cc",
"transforms/raise_control_flow.cc", "transforms/raise_control_flow.cc",
"transforms/readonly_references_to_resources.cc",
"transforms/replicate_invariant_op_hoisting.cc", "transforms/replicate_invariant_op_hoisting.cc",
"transforms/replicate_to_island.cc", "transforms/replicate_to_island.cc",
"transforms/resource_device_inference.cc", "transforms/resource_device_inference.cc",

View File

@ -47,37 +47,6 @@ limitations under the License.
namespace mlir { namespace mlir {
namespace tf_executor { namespace tf_executor {
namespace {
// If the given tensor has elements of type with subtypes, then returns a new
// type after dropping subtypes info. Otherwise, returns the original type as
// is.
ShapedType DropTypeSubTypes(ShapedType ty) {
Type element_ty = ty.getElementType();
auto subtype_ty = element_ty.dyn_cast<TF::TensorFlowTypeWithSubtype>();
if (!subtype_ty) return ty;
Type default_ty = GetDefaultTypeOf(subtype_ty);
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
return UnrankedTensorType::get(default_ty);
}
// If the given tensor has elements of type ref, then returns a new type
// of the shape, but corresponding non-ref type as element type. Otherwise,
// returns the original type as is.
ShapedType DropRefType(ShapedType ty) {
Type element_ty = ty.getElementType();
auto ref_ty = element_ty.dyn_cast<TF::TensorFlowRefType>();
if (!ref_ty) return ty;
Type default_ty = GetDefaultTypeOf(ref_ty);
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
return UnrankedTensorType::get(default_ty);
}
} // namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TF Executor Dialect // TF Executor Dialect
@ -85,6 +54,9 @@ ShapedType DropRefType(ShapedType ty) {
namespace { namespace {
using TF::DropRefType;
using TF::DropTypeSubTypes;
struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface { struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface; using DialectInlinerInterface::DialectInlinerInterface;

View File

@ -366,5 +366,27 @@ bool AreCastCompatible(ArrayRef<Type> types) {
return true; return true;
} }
ShapedType DropTypeSubTypes(ShapedType ty) {
Type element_ty = ty.getElementType();
auto subtype_ty = element_ty.dyn_cast<TF::TensorFlowTypeWithSubtype>();
if (!subtype_ty) return ty;
Type default_ty = GetDefaultTypeOf(subtype_ty);
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
return UnrankedTensorType::get(default_ty);
}
ShapedType DropRefType(ShapedType ty) {
Type element_ty = ty.getElementType();
TF::TensorFlowRefType ref_ty = element_ty.dyn_cast<TF::TensorFlowRefType>();
if (!ref_ty) return ty;
Type default_ty = TF::GetDefaultTypeOf(ref_ty);
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
return UnrankedTensorType::get(default_ty);
}
} // namespace TF } // namespace TF
} // namespace mlir } // namespace mlir

View File

@ -319,6 +319,16 @@ bool HasCompatibleElementTypes(Type lhs, Type rhs,
// compatible. // compatible.
bool AreCastCompatible(ArrayRef<Type> types); bool AreCastCompatible(ArrayRef<Type> types);
// If the given tensor has elements of type with subtypes, then returns a new
// type after dropping subtypes info. Otherwise, returns the original type as
// is.
ShapedType DropTypeSubTypes(ShapedType ty);
// If the given tensor has elements of type ref, then returns a new type
// of the shape, but corresponding non-ref type as element type. Otherwise,
// returns the original type as is.
ShapedType DropRefType(ShapedType ty);
} // end namespace TF } // end namespace TF
} // end namespace mlir } // end namespace mlir

View File

@ -0,0 +1,85 @@
// RUN: tf-opt -verify-diagnostics -readonly-references-to-resources -split-input-file %s | FileCheck %s --dump-input=fail
// Test case: Basic converting.
func @f() {
// CHECK: "tf.VarHandleOp"
// CHECK: "tf.ReadVariableOp"
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
return
}
// -----
// Test case: Two ReadVariable ops.
func @f() {
// CHECK: "tf.VarHandleOp"
// During lowering to resource variables, this pass will preserve the
// locations of the ReadVariableOps as Identity ops to keep the original graph
// composition and order.
// CHECK: "tf.ReadVariableOp"
// CHECK: "tf.ReadVariableOp"
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
%val2 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
return
}
// -----
// Test case: No follow-up ReadVariable case.
func @f() {
// CHECK-NOT: "tf.VariableV2"
// CHECK-NOT: "tf.VarHandleOp"
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
return
}
// -----
// Test case: No converting when there is another use case.
func @f() {
// expected-error @+1 {{'tf.VariableV2' op expects all users to be 'tf.Identity', but got user tf.CustomIdentity}}
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
%val1 = "tf.CustomIdentity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
return
}
// -----
// Test case: No class attribute on VariableV2 op.
func @f() {
// expected-error @+1 {{'tf.VariableV2' op has no '_class' attribute}}
%val0 = "tf.VariableV2"() {container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
return
}
// -----
// Test case: No named location found on VariableV2 op.
func @f() {
// expected-error @+1 {{'tf.VariableV2' op expects variable name in '_class' attribute, but got ["unrelated_class"]}}
%val0 = "tf.VariableV2"() {_class = ["unrelated_class"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
return
}
// -----
// Test case: Invalid multiple location information in a class attribute on VariableV2 op.
func @f() {
// expected-error @+1 {{'tf.VariableV2' op expects only one named location in '_class' attribute, but got ["loc:@v1", "loc:@v2"]}}
%val0 = "tf.VariableV2"() {_class = ["loc:@v1", "loc:@v2"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
return
}

View File

@ -95,6 +95,11 @@ std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass();
// functions. // functions.
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass(); std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass();
// Creates a pass that converts readonly reference variables to the
// corresponding resource variables.
std::unique_ptr<OperationPass<FuncOp>>
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass();
// Marks function visibility using tf.entry_function specification. That is, // Marks function visibility using tf.entry_function specification. That is,
// functions with tf.entry_function attributes are marked with public // functions with tf.entry_function attributes are marked with public
// visibility while the other functions are marked with private visibility. // visibility while the other functions are marked with private visibility.

View File

@ -0,0 +1,179 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace TF {
namespace {
// Location attribute.
constexpr StringRef kClassAttr = "_class";
constexpr StringRef kLocationPrefix = "loc:@";
// A pass that converts readonly reference variables to the corresponding
// resource variables.
//
// It converts (VariableV2 -> Identity) to (VarHandle -> ReadVariable).
//
// For the background, this pass is a part of hoisting VariableV2 ops by
// re-using the pipeline for hoisting (VarHandle -> ReadVariable) cases, which
// can be done by the following passes:
// - Capturing resource values into global tensors (importing saved model).
// - Promoting VarHandle ops to function input/outputs.
// - Freezing global tensor pass.
//
// This path assumes that all the VariableV2 ops is read-only via verifying the
// heuristic method that assumes that all the users of them is Identity op,
// fed directly.
class ConvertReadonlyReferenceVariablesToResourceVariablesPass
: public PassWrapper<
ConvertReadonlyReferenceVariablesToResourceVariablesPass,
FunctionPass> {
public:
void runOnFunction() override;
};
// Parse node name from "_class" attribute.
StringRef GetNodeNameFromClassAttr(Operation *op) {
ArrayAttr classes_attr = op->getAttrOfType<ArrayAttr>(kClassAttr);
if (!classes_attr) {
op->emitOpError() << "has no '_class' attribute";
return StringRef();
}
StringRef result;
for (Attribute class_attr : classes_attr) {
StringRef node_name = class_attr.cast<StringAttr>().getValue();
if (!node_name.startswith(kLocationPrefix)) {
continue;
}
if (!result.empty()) {
// Invalid case since there are multiple loc:@ attributes.
op->emitOpError()
<< "expects only one named location in '_class' attribute, but got "
<< classes_attr;
return StringRef();
}
result = node_name.drop_front(kLocationPrefix.size());
}
if (result.empty()) {
op->emitOpError() << "expects variable name in '_class' attribute, but got "
<< classes_attr;
}
return result;
}
void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() {
FuncOp func = getFunction();
OpBuilder builder(func.getContext());
SmallVector<VariableV2Op, 4> variable_v2s_to_replace;
// Checks all the VariableV2 ops is read-only via verifying the heuristic
// method that assumes that all the users of them is Identity op, feeded
// directly.
auto read_only_vars_fn = [&variable_v2s_to_replace](
VariableV2Op variable_v2_op) {
if (variable_v2_op.getResult().use_empty()) {
// Erase the op when there is no user.
variable_v2_op.erase();
return mlir::WalkResult::advance();
}
if (!all_of(variable_v2_op.getResult().getUsers(), [&variable_v2_op](
Operation *user) {
if (!isa<IdentityOp>(user)) {
variable_v2_op.emitOpError()
<< "expects all users to be 'tf.Identity', but got user "
<< user->getName();
return false;
}
return true;
})) {
return mlir::WalkResult::interrupt();
}
variable_v2s_to_replace.push_back(variable_v2_op);
return mlir::WalkResult::advance();
};
WalkResult walk_res = func.walk(read_only_vars_fn);
if (walk_res.wasInterrupted()) return signalPassFailure();
for (VariableV2Op variable_v2_op : variable_v2s_to_replace) {
builder.setInsertionPoint(variable_v2_op);
ShapedType shaped_type =
variable_v2_op.getResult().getType().cast<ShapedType>();
TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
StringAttr device_attr = variable_v2_op.getAttrOfType<StringAttr>("device");
if (!device_attr) device_attr = builder.getStringAttr("");
StringRef variable_name = GetNodeNameFromClassAttr(variable_v2_op);
if (variable_name.empty()) {
return signalPassFailure();
}
VarHandleOp var_handle_op = builder.create<VarHandleOp>(
variable_v2_op.getLoc(),
ArrayRef<Type>{RankedTensorType::get(
{}, TF::ResourceType::get(ArrayRef<TensorType>{tensor_type},
builder.getContext()))},
ArrayRef<Value>{},
ArrayRef<NamedAttribute>{
builder.getNamedAttr("device", device_attr),
builder.getNamedAttr("container", variable_v2_op.containerAttr()),
builder.getNamedAttr("shared_name",
builder.getStringAttr(variable_name))});
for (Operation *user :
make_early_inc_range(variable_v2_op.getResult().getUsers())) {
builder.setInsertionPoint(user);
ReadVariableOp read_variable_op = builder.create<ReadVariableOp>(
user->getLoc(), ArrayRef<Type>{tensor_type},
ArrayRef<Value>{var_handle_op}, ArrayRef<NamedAttribute>{});
user->getResult(0).replaceAllUsesWith(read_variable_op.getResult());
user->erase();
}
variable_v2_op.erase();
}
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass() {
return std::make_unique<
ConvertReadonlyReferenceVariablesToResourceVariablesPass>();
}
static PassRegistration<
ConvertReadonlyReferenceVariablesToResourceVariablesPass>
pass("readonly-references-to-resources",
"Convert readonly reference variables to resource variables.");
} // namespace TF
} // namespace mlir