Add a simple device assignment pass with default-device option
PiperOrigin-RevId: 293483851 Change-Id: I20f5b65d1e78357de70c28cc278881c7da72281f
This commit is contained in:
parent
fed12f5bdc
commit
bdef2c56b3
@ -255,6 +255,7 @@ cc_library(
|
|||||||
"transforms/shape_inference_pass.cc",
|
"transforms/shape_inference_pass.cc",
|
||||||
"transforms/sink_constant.cc",
|
"transforms/sink_constant.cc",
|
||||||
"transforms/test_side_effect_analysis.cc",
|
"transforms/test_side_effect_analysis.cc",
|
||||||
|
"transforms/tf_device_assignment.cc",
|
||||||
"transforms/tpu_cluster_formation.cc",
|
"transforms/tpu_cluster_formation.cc",
|
||||||
"transforms/tpu_dynamic_layout_pass.cc",
|
"transforms/tpu_dynamic_layout_pass.cc",
|
||||||
"transforms/tpu_dynamic_padding_mapper.cc",
|
"transforms/tpu_dynamic_padding_mapper.cc",
|
||||||
|
@ -0,0 +1,13 @@
|
|||||||
|
// RUN: tf-opt -tf-simple-device-assignment='default-device=gpu' %s | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @device_test
|
||||||
|
func @device_test(%arg0: tensor<3x1xf32>) -> (tensor<3x3xf32>) {
|
||||||
|
|
||||||
|
// CHECK: device = "gpu"
|
||||||
|
%0 = "tf.Const"() {value = dense<[[1.0, 2.0, 3.0]]> : tensor<1x3xf32>} : () -> tensor<1x3xf32>
|
||||||
|
// CHECK: device = "gpu"
|
||||||
|
%1 = "tf.MatMul"(%arg0, %0) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
|
||||||
|
// CHECK: device = "cpu"
|
||||||
|
%2 = "tf.Relu"(%1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "cpu"} : (tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||||
|
return %2 : tensor<3x3xf32>
|
||||||
|
}
|
@ -79,6 +79,10 @@ LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification(
|
|||||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||||
CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass();
|
CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass();
|
||||||
|
|
||||||
|
// Create a simple device assignment pass on TF dialect for CoreRT use case.
|
||||||
|
std::unique_ptr<OpPassBase<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
|
||||||
|
llvm::StringRef default_device);
|
||||||
|
|
||||||
} // namespace TF
|
} // namespace TF
|
||||||
|
|
||||||
namespace TFControlFlow {
|
namespace TFControlFlow {
|
||||||
|
@ -0,0 +1,68 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file implements device assignment in TF dialect.
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace TF {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class SimpleTFDeviceAssignmentPass
|
||||||
|
: public FunctionPass<SimpleTFDeviceAssignmentPass> {
|
||||||
|
public:
|
||||||
|
SimpleTFDeviceAssignmentPass() = default;
|
||||||
|
SimpleTFDeviceAssignmentPass(const SimpleTFDeviceAssignmentPass&) {}
|
||||||
|
explicit SimpleTFDeviceAssignmentPass(llvm::StringRef default_device) {
|
||||||
|
default_device_ = std::string(default_device);
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnFunction() override {
|
||||||
|
Builder builder(&getContext());
|
||||||
|
getFunction().walk([this, &builder](Operation* op) {
|
||||||
|
if (auto device_attr = op->getAttrOfType<StringAttr>("device")) {
|
||||||
|
// We assign default device to ops with device attribute that is empty.
|
||||||
|
if (device_attr.getValue() == "") {
|
||||||
|
op->setAttr("device", builder.getStringAttr(default_device_));
|
||||||
|
}
|
||||||
|
} else if (llvm::isa<ConstOp>(op)) {
|
||||||
|
// tf.Const may sometimes contain no device attribute. In this case, we
|
||||||
|
// assign it the default device.
|
||||||
|
op->setAttr("device", builder.getStringAttr(default_device_));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Option<std::string> default_device_{
|
||||||
|
*this, "default-device", llvm::cl::desc("The default device to assign."),
|
||||||
|
llvm::cl::init("cpu")};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OpPassBase<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
|
||||||
|
llvm::StringRef default_device) {
|
||||||
|
return std::make_unique<SimpleTFDeviceAssignmentPass>(default_device);
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<SimpleTFDeviceAssignmentPass> pass(
|
||||||
|
"tf-simple-device-assignment", "Simple device assignment in TF dialect.");
|
||||||
|
|
||||||
|
} // namespace TF
|
||||||
|
} // namespace mlir
|
Loading…
x
Reference in New Issue
Block a user