Only insert TPUCopyWithLayout when the resource generator is on CPU.

The placement of an iterator is determined not by the device attribute of tf.IteratorGetNext but by the device attribute of the aliased resource generator op. Generator ops are either function arguments or tf.VarHandleOp ops. The TPU dynamic layout pass is therefore modified to check the device attribute of the generator ops.
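
To illustrate the rule, a minimal sketch of the check (simplified, hypothetical code: `IsGeneratorOnCpu` does a bare substring test on the device string, whereas the actual pass parses it with DeviceNameUtils and follows aliases through intermediate ops such as tf.Identity via Resource Alias Analysis):

    #include "mlir/IR/Attributes.h"
    #include "mlir/IR/Function.h"
    #include "mlir/IR/Value.h"

    // Sketch: an iterator resource counts as host-placed when its generator (a
    // function argument or a 0-operand op such as tf.VarHandleOp) is on CPU.
    bool IsGeneratorOnCpu(mlir::Value resource, mlir::FuncOp func) {
      // Function arguments carry placement in the "tf.device" arg attribute.
      if (auto arg = resource.dyn_cast<mlir::BlockArgument>()) {
        auto device = func.getArgAttrOfType<mlir::StringAttr>(
            arg.getArgNumber(), "tf.device");
        return device && device.getValue().contains("CPU");
      }
      // Generator ops such as tf.VarHandleOp carry a "device" attribute.
      auto device =
          resource.getDefiningOp()->getAttrOfType<mlir::StringAttr>("device");
      return device && device.getValue().contains("CPU");
    }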

This CL modifies Resource Alias Analysis to return the set of aliases of a given value. For efficiency, it now keeps the mapping in both directions: resource value --> unique resource ID, and unique resource ID --> resource values.
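
In essence, the new data structure and alias query look like this (an abridged sketch of the header change below; `AliasMapSketch` is a stand-in name, and the real class uses SmallDenseMap):

    #include <cstdint>
    #include "llvm/ADT/DenseMap.h"
    #include "llvm/ADT/SetVector.h"
    #include "llvm/ADT/SmallSet.h"
    #include "mlir/IR/Value.h"

    class AliasMapSketch {
     public:
      // Record every (value, ID) pair in both directions.
      void AddValueUniqueIDMapping(mlir::Value value, int64_t id) {
        resource_value_to_ids_[value].insert(id);
        id_to_resource_values_[id].insert(value);
      }

      // Aliases of `resource`: the union, over each ID assigned to it, of
      // all values carrying that ID.
      llvm::SmallSetVector<mlir::Value, 8> GetResourceAliases(
          mlir::Value resource) const {
        llvm::SmallSetVector<mlir::Value, 8> aliases;
        for (int64_t id : resource_value_to_ids_.lookup(resource)) {
          auto values = id_to_resource_values_.lookup(id);
          aliases.insert(values.begin(), values.end());
        }
        return aliases;
      }

     private:
      llvm::DenseMap<mlir::Value, llvm::SmallSet<int64_t, 8>>
          resource_value_to_ids_;
      llvm::DenseMap<int64_t, llvm::SmallSetVector<mlir::Value, 8>>
          id_to_resource_values_;
    };

Without the inverse map, each GetResourceAliases query would have to scan every entry of resource_value_to_ids_; the second map makes a lookup proportional to the size of the answer.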

This improves the performance of a few models running with the MLIR TF/XLA bridge. For example, ResNet50 step time is reduced from 180 ms to 125 ms (on par with the old bridge).

PiperOrigin-RevId: 319255334
Change-Id: I4df9f26f480b580b8277caae981f06c3189e7bf4
Prakalp Srivastava 2020-07-01 10:49:32 -07:00 committed by TensorFlower Gardener
parent 278cbe34ea
commit 153947b5c5
4 changed files with 214 additions and 38 deletions


@@ -118,7 +118,7 @@ ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) {
 }
 
 void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
-  // This function populates resource_value_to_ids_.
+  // This function populates resource_value_to_ids_ and id_to_resource_values_.
 
   // If the "tf.resource_arg_unique_id" argument attributes are present for
   // resource-type arguments, respect them when choosing IDs; otherwise, they
@@ -142,9 +142,9 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
                  "or all arguments.");
       auto emplace_res = attr_id_to_internal_id.try_emplace(id_attr.getInt(),
                                                             next_unique_id++);
-      resource_value_to_ids_[arg].insert(emplace_res.first->getSecond());
+      AddValueUniqueIDMapping(arg, emplace_res.first->getSecond());
     } else {
-      resource_value_to_ids_[arg].insert(next_unique_id++);
+      AddValueUniqueIDMapping(arg, next_unique_id++);
     }
   }
   llvm::StringMap<int64_t> var_handle_name_id_map;
@@ -164,7 +164,8 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
 
   func_op.walk([&](Operation* op) {
     if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
-      resource_value_to_ids_[var_handle.resource()].insert(
+      AddValueUniqueIDMapping(
+          var_handle.resource(),
           GetOrCreateIdForVarHandle(var_handle, &next_unique_id,
                                     &var_handle_name_id_map));
     } else if (llvm::isa<TF::IdentityNOp>(op) ||
@@ -180,7 +181,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
       // different resources.
       for (auto arg : replicate.GetBody().getArguments()) {
         if (mlir::getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) {
-          resource_value_to_ids_[arg].insert(next_unique_id++);
+          AddValueUniqueIDMapping(arg, next_unique_id++);
         }
       }
     } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
@@ -198,7 +199,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
           forward_input_to_output(while_op.getOperand(passthrough_operand),
                                   result.value());
         } else {
-          resource_value_to_ids_[result.value()].insert(kUnknownResourceId);
+          AddValueUniqueIDMapping(result.value(), kUnknownResourceId);
         }
       }
     } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) {
@@ -223,7 +224,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
           forward_input_to_output(if_op.getOperand(passthrough_else_arg + 1),
                                   result.value());
         } else {
-          resource_value_to_ids_[result.value()].insert(kUnknownResourceId);
+          AddValueUniqueIDMapping(result.value(), kUnknownResourceId);
         }
       }
     } else {
@@ -231,7 +232,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
         if (!mlir::getElementTypeOrSelf(result.getType())
                  .isa<TF::ResourceType>())
           continue;
-        resource_value_to_ids_[result].insert(kUnknownResourceId);
+        AddValueUniqueIDMapping(result, kUnknownResourceId);
       }
     }
   });
@@ -254,6 +255,24 @@ const llvm::SmallSet<int64_t, 8>& ResourceAliasAnalysis::GetResourceUniqueIds(
   return it->getSecond();
 }
 
+const llvm::SmallSetVector<Value, 8>&
+ResourceAliasAnalysis::GetUniqueIdResources(const int64_t id) const {
+  auto it = id_to_resource_values_.find(id);
+  assert(it != id_to_resource_values_.end() && "Unseen id was queried");
+  return it->getSecond();
+}
+
+llvm::SmallSetVector<Value, 8> ResourceAliasAnalysis::GetResourceAliases(
+    const Value resource) const {
+  assert(!IsUnknownResource(resource) && "Unseen resource was queried");
+  llvm::SmallSetVector<Value, 8> aliases;
+  for (int64_t id : GetResourceUniqueIds(resource)) {
+    const llvm::SmallSetVector<Value, 8>& resources_aliasing_id =
+        GetUniqueIdResources(id);
+    aliases.insert(resources_aliasing_id.begin(), resources_aliasing_id.end());
+  }
+  return aliases;
+}
+
 namespace {
 
 // Returns a set that contains only kUnknownResourceId.
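
A minimal sketch of how a client pass might use the new API (hypothetical code; only the constructor and the two query methods come from this change, declared in the side_effect_analysis.h header that the pass below includes):

    #include "mlir/IR/Function.h"
    #include "mlir/IR/Value.h"
    #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"

    // Hypothetical client: dump every value that may alias `resource`.
    void DumpAliases(mlir::FuncOp func, mlir::Value resource) {
      mlir::TF::ResourceAliasAnalysis analysis(func);  // runs the analysis
      // Unknown resources have no precise alias set; bail out early.
      if (analysis.IsUnknownResource(resource)) return;
      for (mlir::Value alias : analysis.GetResourceAliases(resource))
        alias.dump();
    }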


@@ -19,6 +19,7 @@ limitations under the License.
 #include <cstdint>
 #include <memory>
 
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringMap.h"
@@ -49,15 +50,34 @@ class ResourceAliasAnalysis {
   const llvm::SmallSet<int64_t, 8>& GetResourceUniqueIds(
       const Value resource) const;
 
+  // Returns the set of values that are potentially aliases of `resource`.
+  // Requires that IsUnknownResource(resource) == false.
+  llvm::SmallSetVector<Value, 8> GetResourceAliases(const Value resource) const;
+
  private:
   ResourceAliasAnalysis() = default;
 
-  // Runs the analysis on `func_op` and populates resource_value_to_ids_.
+  // Runs the analysis on `func_op` and populates the two-way mapping between
+  // resource values and unique IDs.
   void AnalyzeFunction(FuncOp func_op);
 
+  // Maps a resource value to a unique ID and vice versa.
+  void AddValueUniqueIDMapping(Value value, int64_t id) {
+    resource_value_to_ids_[value].insert(id);
+    id_to_resource_values_[id].insert(value);
+  }
+
+  // Returns the set of unique Values which map to `id`.
+  const llvm::SmallSetVector<Value, 8>& GetUniqueIdResources(int64_t id) const;
+
+  // Maps each resource-type value to a set of unique IDs that it could alias.
   llvm::SmallDenseMap<Value, llvm::SmallSet<int64_t, 8>, 8>
       resource_value_to_ids_;
+
+  // Maps each unique ID to a set of resource-type values that could alias it.
+  // This is the inverse of the `resource_value_to_ids_` map.
+  llvm::SmallDenseMap<int64_t, llvm::SmallSetVector<Value, 8>, 8>
+      id_to_resource_values_;
 };
 
 // An analysis that runs on a function and infers the control predecessors and


@@ -106,10 +106,107 @@ func @on_tpu_iter(%arg0: tensor<*x!tf.resource> {tf.device = "/device:TPU:0"}) -
 // -----
 
+// Tests that the pass does not transform when tf.IteratorGetNext is on CPU
+// but the generator is on TPU.
+// CHECK-LABEL: func @arg_on_tpu_iter_on_cpu
+func @arg_on_tpu_iter_on_cpu(%arg0: tensor<*x!tf.resource> {tf.device = "/device:TPU:0"}) -> tensor<i32> {
+  %compile:2 = "tf_device.launch"() ( {
+    %1:2 = "tf._TPUCompileMlir"() {
+      NumDynamicShapes = 0 : i64,
+      // The metadata encodes two parameters and two return values.
+      metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
+      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
+  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+  // CHECK-NOT: "tf.TPUGetLayoutOp"
+  // CHECK-NOT: "tf.TPUCopyWithLayout"
+  %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
+    : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
+  "tf_device.launch"() ( {
+    "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
+    tf_device.return
+  }) {device = "/device:CPU:0"} : () -> ()
+  %execute = "tf_device.launch"() ( {
+    %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
+      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
+    tf_device.return %3 : tensor<i32>
+  }) {device = "/device:TPU:0"} : () -> tensor<i32>
+  return %execute : tensor<i32>
+}
+
+// -----
+
+// Tests that the pass does not transform when tf.IteratorGetNext is on CPU but
+// the generator is on TPU. All intermediate nodes, like the tf.Identity ops
+// between the generator and IteratorGetNext, are on CPU too.
+// CHECK-LABEL: func @arg_on_tpu_intermediate_ops_on_cpu
+func @arg_on_tpu_intermediate_ops_on_cpu(%arg0: tensor<*x!tf.resource> {tf.device = "/device:TPU:0"}) -> tensor<i32> {
+  %compile:2 = "tf_device.launch"() ( {
+    %1:2 = "tf._TPUCompileMlir"() {
+      NumDynamicShapes = 0 : i64,
+      // The metadata encodes two parameters and two return values.
+      metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
+      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
+  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+  %id1 = "tf.Identity"(%arg0) {device = "/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<*x!tf.resource>)
+  %id2 = "tf.Identity"(%id1) {device = "/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<*x!tf.resource>)
+  // CHECK-NOT: "tf.TPUGetLayoutOp"
+  // CHECK-NOT: "tf.TPUCopyWithLayout"
+  %2:2 = "tf.IteratorGetNext"(%id2) {device = "/device:CPU:0"}
+    : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
+  "tf_device.launch"() ( {
+    "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
+    tf_device.return
+  }) {device = "/device:CPU:0"} : () -> ()
+  %execute = "tf_device.launch"() ( {
+    %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
+      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
+    tf_device.return %3 : tensor<i32>
+  }) {device = "/device:TPU:0"} : () -> tensor<i32>
+  return %execute : tensor<i32>
+}
+
+// -----
+
+// Tests that the pass does not transform when tf.IteratorGetNext is on CPU but
+// the generator is on TPU.
+// CHECK-LABEL: func @var_handle_on_tpu_iter_on_cpu
+func @var_handle_on_tpu_iter_on_cpu() -> tensor<i32> {
+  %compile:2 = "tf_device.launch"() ( {
+    %1:2 = "tf._TPUCompileMlir"() {
+      NumDynamicShapes = 0 : i64,
+      // The metadata encodes two parameters and two return values.
+      metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
+      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
+  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+  %var = "tf.VarHandleOp"() {container = "c", shared_name = "v", device = "/device:TPU:0"} : () -> tensor<*x!tf.resource>
+  // CHECK-NOT: "tf.TPUGetLayoutOp"
+  // CHECK-NOT: "tf.TPUCopyWithLayout"
+  %2:2 = "tf.IteratorGetNext"(%var) {device = "/device:CPU:0"}
+    : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
+  "tf_device.launch"() ( {
+    "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
+    tf_device.return
+  }) {device = "/device:CPU:0"} : () -> ()
+  %execute = "tf_device.launch"() ( {
+    %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
+      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
+    tf_device.return %3 : tensor<i32>
+  }) {device = "/device:TPU:0"} : () -> tensor<i32>
+  return %execute : tensor<i32>
+}
+
+// -----
+
 // Tests that the pass does not change unsupported input ops.
 // CHECK-LABEL: func @unsupported_ops
-func @unsupported_ops(%arg0: tensor<3x3x1x32xf32>) -> tensor<i32> {
+func @unsupported_ops(%arg0: tensor<3x3x1x32xf32> {tf.device = "/device:CPU:0"}) -> tensor<i32> {
   %compile:2 = "tf_device.launch"() ( {
     %1:2 = "tf._TPUCompileMlir"() {
       NumDynamicShapes = 0 : i64,
@@ -183,7 +280,7 @@ func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) ->
 
 // Tests that the pass does not change inputs inside replicate.
 // CHECK-LABEL: func @inside_replicated
-func @inside_replicated(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) -> tensor<i32> {
+func @inside_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}, %arg1: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32> {
   %compile:2 = "tf_device.launch"() ( {
     %1:2 = "tf._TPUCompileMlir"() {
       NumDynamicShapes = 0 : i64,
@@ -229,7 +326,7 @@ func @inside_replicated(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resou
 // num_cores_per_replica: 2
 // CHECK-LABEL: func @parallel_execute
-func @parallel_execute(%arg0: tensor<*x!tf.resource>) {
+func @parallel_execute(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) {
   // CHECK: %[[COMPILE:.*]]:3 = "tf_device.launch"
   // CHECK-NEXT: "tf._TPUCompileMlir"()
   %compile:3 = "tf_device.launch"() ( {
@@ -293,8 +390,9 @@ func @parallel_execute(%arg0: tensor<*x!tf.resource>) {
 // num_cores_per_replica: 2
 // CHECK-LABEL: func @replicated_parallel_execute
-// CHECK-SAME: (%[[ARG0:[a-z0-9]+]]: tensor<*x!tf.resource>, %[[ARG1:[a-z0-9]+]]: tensor<*x!tf.resource>)
-func @replicated_parallel_execute(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) {
+// CHECK-SAME: %[[ARG0:[a-z0-9]+]]: tensor<*x!tf.resource>
+// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: tensor<*x!tf.resource>
+func @replicated_parallel_execute(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}, %arg1: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) {
   // CHECK: %[[COMPILE:.*]]:3 = "tf_device.launch"
   // CHECK-NEXT: "tf._TPUCompileMlir"()
   %compile:3 = "tf_device.launch"() ( {


@@ -32,6 +32,7 @@ limitations under the License.
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
@@ -46,6 +47,8 @@ namespace TFTPU {
 namespace {
 
 constexpr char kDeviceAttr[] = "device";
+constexpr char kDeviceCPU[] = "CPU";
+constexpr char kFuncDeviceAttr[] = "tf.device";
 
 // A pass that allows TPU input layout to be determined after JIT compilation.
 // This is done by adding run-time ops that interpret compilation result and
@@ -79,17 +82,49 @@ struct TPUDynamicLayoutPass
 };
 
 // Checks if the input producer op is supported in this transform. Right now, we
-// only check if it is a host tf.IteratorGetNext.
-bool IsSupportedInputOp(Operation* op) {
-  if (!llvm::isa<TF::IteratorGetNextOp>(op)) return false;
-  auto device = op->getAttrOfType<StringAttr>(kDeviceAttr);
-  if (!device) return false;
-  tensorflow::DeviceNameUtils::ParsedName parsed_device;
-  if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
-                                                  &parsed_device)) {
+// only check if it is a tf.IteratorGetNext whose resource input comes from a
+// VarHandle on CPU or a function argument assigned to CPU.
+bool IsSupportedInputOp(Operation* op,
+                        TF::ResourceAliasAnalysis* resource_alias_analysis) {
+  TF::IteratorGetNextOp iterator_op = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
+  if (!iterator_op) return false;
+
+  Value resource_iterator = iterator_op.iterator();
+
+  if (resource_alias_analysis->IsUnknownResource(resource_iterator))
     return false;
-  }
-  return parsed_device.type == "CPU";
+  llvm::SmallSetVector<Value, 8> aliases =
+      resource_alias_analysis->GetResourceAliases(resource_iterator);
+
+  auto is_generator = [](Value val) {
+    if (val.isa<BlockArgument>()) return true;
+    Operation* definition = val.getDefiningOp();
+    return definition->getNumOperands() == 0 &&
+           definition->getNumResults() == 1;
+  };
+
+  // Check that all generator aliases (ops or function arguments) are on CPU.
+  FuncOp func = iterator_op.getParentOfType<FuncOp>();
+  return llvm::all_of(aliases, [&](Value alias) {
+    // Ignore non-generator aliases.
+    if (!is_generator(alias)) return true;
+
+    StringAttr device;
+    if (auto arg = alias.dyn_cast<BlockArgument>()) {
+      device = func.getArgAttrOfType<mlir::StringAttr>(arg.getArgNumber(),
+                                                       kFuncDeviceAttr);
+    } else {
+      device = alias.getDefiningOp()->getAttrOfType<StringAttr>(kDeviceAttr);
+    }
+    if (!device) return false;
+
+    tensorflow::DeviceNameUtils::ParsedName parsed_device;
+    if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
+                                                    &parsed_device)) {
+      return false;
+    }
+    return parsed_device.has_type && parsed_device.type == kDeviceCPU;
+  });
 }
 
 OpBuilder CreateBuilderAfterOp(Operation* op) {
@@ -139,12 +174,11 @@ void HandleInput(Value input, const int64_t execute_arg_index,
 
 // Performs transformation for replicated inputs. Returns true if this is a
 // supported case (thus transform happened).
-bool HandleReplicatedInputs(const int64_t execute_arg_index,
-                            Value compilation_key,
-                            tf_device::LaunchOp execute_launch,
-                            tf_device::LaunchOp compile_launch,
-                            const int64_t replicate_arg_index,
-                            tf_device::ReplicateOp replicate) {
+bool HandleReplicatedInputs(
+    const int64_t execute_arg_index, Value compilation_key,
+    tf_device::LaunchOp execute_launch, tf_device::LaunchOp compile_launch,
+    const int64_t replicate_arg_index, tf_device::ReplicateOp replicate,
+    TF::ResourceAliasAnalysis* resource_alias_analysis) {
   // We need to know the devices to copy to.
   if (!replicate.devices()) return false;
   int64_t num_replicas = replicate.n().getZExtValue();
@@ -153,7 +187,8 @@
           .take_front(num_replicas);
   for (auto entry : llvm::enumerate(inputs)) {
     auto input_op = entry.value().getDefiningOp();
-    if (!input_op || !IsSupportedInputOp(input_op)) return false;
+    if (!input_op || !IsSupportedInputOp(input_op, resource_alias_analysis))
+      return false;
   }
   OpBuilder builder = CreateBuilderAfterOp(compile_launch);
   auto get_layout = BuildGetLayout(execute_arg_index, compilation_key,
@@ -180,7 +215,8 @@ bool HandleReplicatedInputs(const int64_t execute_arg_index,
 // compile should not have other uses.
 void HandleCompileAndExecutes(
     tf_device::LaunchOp compile_launch,
-    llvm::MutableArrayRef<tf_device::LaunchOp> execute_launches) {
+    llvm::MutableArrayRef<tf_device::LaunchOp> execute_launches,
+    TF::ResourceAliasAnalysis* resource_alias_analysis) {
   auto compile =
       llvm::cast<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front());
   tensorflow::tpu::TPUCompileMetadataProto metadata;
@@ -206,9 +242,10 @@
         // For a block argument, consider transforms only when it is a
         // replicated input (defining ops will be outside the replicate node).
         if (maybe_replicate != block_arg.getParentRegion()->getParentOp() ||
-            !HandleReplicatedInputs(
-                execute_arg_index, execute.key(), execute_launch,
-                compile_launch, block_arg.getArgNumber(), maybe_replicate)) {
+            !HandleReplicatedInputs(execute_arg_index, execute.key(),
+                                    execute_launch, compile_launch,
+                                    block_arg.getArgNumber(), maybe_replicate,
+                                    resource_alias_analysis)) {
           continue;
         }
       } else {
@@ -221,7 +258,7 @@
             maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
           continue;
         }
-        if (!IsSupportedInputOp(input_op)) continue;
+        if (!IsSupportedInputOp(input_op, resource_alias_analysis)) continue;
         HandleInput(input, execute_arg_index, execute, execute_launch,
                     compile_launch);
       }
@@ -238,7 +275,8 @@
 }
 
 void TPUDynamicLayoutPass::runOnFunction() {
-  getFunction().walk([](TF::_TPUCompileMlirOp compile) {
+  TF::ResourceAliasAnalysis resource_alias_analysis(getFunction());
+  getFunction().walk([&](TF::_TPUCompileMlirOp compile) {
     // Detect tf._TPUCompileMlir -> tf.TPUExecute(s).
     auto compile_launch =
         llvm::dyn_cast<tf_device::LaunchOp>(compile.getParentOp());
@@ -257,7 +295,8 @@
       execute_launches.push_back(execute_launch);
     }
 
-    HandleCompileAndExecutes(compile_launch, execute_launches);
+    HandleCompileAndExecutes(compile_launch, execute_launches,
+                             &resource_alias_analysis);
  });
 }