Add a simple device assignment pass with default-device option
PiperOrigin-RevId: 293483851 Change-Id: I20f5b65d1e78357de70c28cc278881c7da72281f
This commit is contained in:
parent
fed12f5bdc
commit
bdef2c56b3
@ -255,6 +255,7 @@ cc_library(
|
||||
"transforms/shape_inference_pass.cc",
|
||||
"transforms/sink_constant.cc",
|
||||
"transforms/test_side_effect_analysis.cc",
|
||||
"transforms/tf_device_assignment.cc",
|
||||
"transforms/tpu_cluster_formation.cc",
|
||||
"transforms/tpu_dynamic_layout_pass.cc",
|
||||
"transforms/tpu_dynamic_padding_mapper.cc",
|
||||
|
@ -0,0 +1,13 @@
|
||||
// RUN: tf-opt -tf-simple-device-assignment='default-device=gpu' %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @device_test
|
||||
func @device_test(%arg0: tensor<3x1xf32>) -> (tensor<3x3xf32>) {
|
||||
|
||||
// CHECK: device = "gpu"
|
||||
%0 = "tf.Const"() {value = dense<[[1.0, 2.0, 3.0]]> : tensor<1x3xf32>} : () -> tensor<1x3xf32>
|
||||
// CHECK: device = "gpu"
|
||||
%1 = "tf.MatMul"(%arg0, %0) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
|
||||
// CHECK: device = "cpu"
|
||||
%2 = "tf.Relu"(%1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "cpu"} : (tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||
return %2 : tensor<3x3xf32>
|
||||
}
|
@ -79,6 +79,10 @@ LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification(
|
||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||
CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass();
|
||||
|
||||
// Create a simple device assignment pass on TF dialect for CoreRT use case.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
|
||||
llvm::StringRef default_device);
|
||||
|
||||
} // namespace TF
|
||||
|
||||
namespace TFControlFlow {
|
||||
|
@ -0,0 +1,68 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file implements device assignment in TF dialect.
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
namespace {
|
||||
|
||||
class SimpleTFDeviceAssignmentPass
|
||||
: public FunctionPass<SimpleTFDeviceAssignmentPass> {
|
||||
public:
|
||||
SimpleTFDeviceAssignmentPass() = default;
|
||||
SimpleTFDeviceAssignmentPass(const SimpleTFDeviceAssignmentPass&) {}
|
||||
explicit SimpleTFDeviceAssignmentPass(llvm::StringRef default_device) {
|
||||
default_device_ = std::string(default_device);
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
Builder builder(&getContext());
|
||||
getFunction().walk([this, &builder](Operation* op) {
|
||||
if (auto device_attr = op->getAttrOfType<StringAttr>("device")) {
|
||||
// We assign default device to ops with device attribute that is empty.
|
||||
if (device_attr.getValue() == "") {
|
||||
op->setAttr("device", builder.getStringAttr(default_device_));
|
||||
}
|
||||
} else if (llvm::isa<ConstOp>(op)) {
|
||||
// tf.Const may sometimes contain no device attribute. In this case, we
|
||||
// assign it the default device.
|
||||
op->setAttr("device", builder.getStringAttr(default_device_));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
Option<std::string> default_device_{
|
||||
*this, "default-device", llvm::cl::desc("The default device to assign."),
|
||||
llvm::cl::init("cpu")};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
|
||||
llvm::StringRef default_device) {
|
||||
return std::make_unique<SimpleTFDeviceAssignmentPass>(default_device);
|
||||
}
|
||||
|
||||
static PassRegistration<SimpleTFDeviceAssignmentPass> pass(
|
||||
"tf-simple-device-assignment", "Simple device assignment in TF dialect.");
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
Loading…
x
Reference in New Issue
Block a user