diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index dbe68a10099..4b01c8d0655 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -255,6 +255,7 @@ cc_library( "transforms/shape_inference_pass.cc", "transforms/sink_constant.cc", "transforms/test_side_effect_analysis.cc", + "transforms/tf_device_assignment.cc", "transforms/tpu_cluster_formation.cc", "transforms/tpu_dynamic_layout_pass.cc", "transforms/tpu_dynamic_padding_mapper.cc", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/device_assignment.mlir b/tensorflow/compiler/mlir/tensorflow/tests/device_assignment.mlir new file mode 100644 index 00000000000..1f1e6c63f30 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/device_assignment.mlir @@ -0,0 +1,13 @@ +// RUN: tf-opt -tf-simple-device-assignment='default-device=gpu' %s | FileCheck %s + +// CHECK-LABEL: func @device_test +func @device_test(%arg0: tensor<3x1xf32>) -> (tensor<3x3xf32>) { + + // CHECK: device = "gpu" + %0 = "tf.Const"() {value = dense<[[1.0, 2.0, 3.0]]> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + // CHECK: device = "gpu" + %1 = "tf.MatMul"(%arg0, %0) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> + // CHECK: device = "cpu" + %2 = "tf.Relu"(%1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "cpu"} : (tensor<3x3xf32>) -> tensor<3x3xf32> + return %2 : tensor<3x3xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 740dca71710..ff8f571f7d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -79,6 +79,10 @@ LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification( std::unique_ptr> CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass(); +// Create a simple device assignment pass on TF dialect for CoreRT use case. +std::unique_ptr> CreateSimpleTFDeviceAssignmentPass( + llvm::StringRef default_device); + } // namespace TF namespace TFControlFlow { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc new file mode 100644 index 00000000000..a4a8c1ab95f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc @@ -0,0 +1,68 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements device assignment in TF dialect. +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +namespace mlir { +namespace TF { +namespace { + +class SimpleTFDeviceAssignmentPass + : public FunctionPass { + public: + SimpleTFDeviceAssignmentPass() = default; + SimpleTFDeviceAssignmentPass(const SimpleTFDeviceAssignmentPass&) {} + explicit SimpleTFDeviceAssignmentPass(llvm::StringRef default_device) { + default_device_ = std::string(default_device); + } + + void runOnFunction() override { + Builder builder(&getContext()); + getFunction().walk([this, &builder](Operation* op) { + if (auto device_attr = op->getAttrOfType("device")) { + // We assign default device to ops with device attribute that is empty. + if (device_attr.getValue() == "") { + op->setAttr("device", builder.getStringAttr(default_device_)); + } + } else if (llvm::isa(op)) { + // tf.Const may sometimes contain no device attribute. In this case, we + // assign it the default device. + op->setAttr("device", builder.getStringAttr(default_device_)); + } + }); + } + + private: + Option default_device_{ + *this, "default-device", llvm::cl::desc("The default device to assign."), + llvm::cl::init("cpu")}; +}; + +} // namespace + +std::unique_ptr> CreateSimpleTFDeviceAssignmentPass( + llvm::StringRef default_device) { + return std::make_unique(default_device); +} + +static PassRegistration pass( + "tf-simple-device-assignment", "Simple device assignment in TF dialect."); + +} // namespace TF +} // namespace mlir