Add a simple device assignment pass with default-device option

PiperOrigin-RevId: 293483851
Change-Id: I20f5b65d1e78357de70c28cc278881c7da72281f
This commit is contained in:
A. Unique TensorFlower 2020-02-05 16:52:42 -08:00 committed by TensorFlower Gardener
parent fed12f5bdc
commit bdef2c56b3
4 changed files with 86 additions and 0 deletions

View File

@ -255,6 +255,7 @@ cc_library(
"transforms/shape_inference_pass.cc",
"transforms/sink_constant.cc",
"transforms/test_side_effect_analysis.cc",
"transforms/tf_device_assignment.cc",
"transforms/tpu_cluster_formation.cc",
"transforms/tpu_dynamic_layout_pass.cc",
"transforms/tpu_dynamic_padding_mapper.cc",

View File

@ -0,0 +1,13 @@
// RUN: tf-opt -tf-simple-device-assignment='default-device=gpu' %s | FileCheck %s
// CHECK-LABEL: func @device_test
func @device_test(%arg0: tensor<3x1xf32>) -> (tensor<3x3xf32>) {
// CHECK: device = "gpu"
%0 = "tf.Const"() {value = dense<[[1.0, 2.0, 3.0]]> : tensor<1x3xf32>} : () -> tensor<1x3xf32>
// CHECK: device = "gpu"
%1 = "tf.MatMul"(%arg0, %0) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
// CHECK: device = "cpu"
%2 = "tf.Relu"(%1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "cpu"} : (tensor<3x3xf32>) -> tensor<3x3xf32>
return %2 : tensor<3x3xf32>
}

View File

@ -79,6 +79,10 @@ LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification(
std::unique_ptr<OpPassBase<ModuleOp>>
CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass();
// Create a simple device assignment pass on TF dialect for CoreRT use case.
std::unique_ptr<OpPassBase<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
llvm::StringRef default_device);
} // namespace TF
namespace TFControlFlow {

View File

@ -0,0 +1,68 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements device assignment in TF dialect.
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace TF {
namespace {
class SimpleTFDeviceAssignmentPass
: public FunctionPass<SimpleTFDeviceAssignmentPass> {
public:
SimpleTFDeviceAssignmentPass() = default;
SimpleTFDeviceAssignmentPass(const SimpleTFDeviceAssignmentPass&) {}
explicit SimpleTFDeviceAssignmentPass(llvm::StringRef default_device) {
default_device_ = std::string(default_device);
}
void runOnFunction() override {
Builder builder(&getContext());
getFunction().walk([this, &builder](Operation* op) {
if (auto device_attr = op->getAttrOfType<StringAttr>("device")) {
// We assign default device to ops with device attribute that is empty.
if (device_attr.getValue() == "") {
op->setAttr("device", builder.getStringAttr(default_device_));
}
} else if (llvm::isa<ConstOp>(op)) {
// tf.Const may sometimes contain no device attribute. In this case, we
// assign it the default device.
op->setAttr("device", builder.getStringAttr(default_device_));
}
});
}
private:
Option<std::string> default_device_{
*this, "default-device", llvm::cl::desc("The default device to assign."),
llvm::cl::init("cpu")};
};
} // namespace
std::unique_ptr<OpPassBase<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
llvm::StringRef default_device) {
return std::make_unique<SimpleTFDeviceAssignmentPass>(default_device);
}
static PassRegistration<SimpleTFDeviceAssignmentPass> pass(
"tf-simple-device-assignment", "Simple device assignment in TF dialect.");
} // namespace TF
} // namespace mlir