Merge remote-tracking branch 'upstream/master' into offline_memory_planner

Jens Elofsson 2020-06-10 09:17:13 +02:00
commit dc3c76758e
462 changed files with 24588 additions and 55920 deletions


@ -1 +1 @@
3.0.0
3.1.0


@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '2.0.0'
_TF_MIN_BAZEL_VERSION = '3.1.0'
_TF_MAX_BAZEL_VERSION = '3.99.0'
NCCL_LIB_PATHS = [


@ -40,6 +40,9 @@ using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
using MaybeParallelTensorOwned =
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
// A ParallelDevice on its own is not registered with a TFE_Context, and so has
// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
@ -141,9 +144,32 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
result.emplace(std::move(result_content));
return result;
}
std::vector<ParallelTensor*> parallel_inputs;
std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
parallel_inputs.reserve(inputs.size());
implicitly_broadcast_tensors.reserve(inputs.size());  // upper bound; only non-parallel inputs are broadcast
for (const auto& input : inputs) {
if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
std::unique_ptr<ParallelTensor> parallel_tensor(
parallel_device.CopyToParallelDevice(
context, absl::get<TFE_TensorHandle*>(input), status));
if (TF_GetCode(status) != TF_OK) return result;
parallel_inputs.push_back(parallel_tensor.get());
implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
} else {
parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
}
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(
parallel_device.Execute(context, std::move(inputs), operation_name,
parallel_device.Execute(context, parallel_inputs, operation_name,
attributes, expected_max_outputs, status));
if (!maybe_parallel_results.has_value()) return result;
std::vector<std::unique_ptr<ParallelTensor>> parallel_results(


@ -100,7 +100,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs,
const std::vector<ParallelTensor*>& inputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int expected_max_outputs,
TF_Status* status) const {
@ -129,26 +129,10 @@ ParallelDevice::Execute(TFE_Context* context,
status);
TFE_OpAddAttrs(op.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
TFE_OpAddInput(op.get(),
absl::get<TFE_TensorHandle*>(inputs[input_index]),
status);
if (TF_GetCode(status) != TF_OK) return result;
} else {
// Parallel tensors are divided between operations by device.
TFE_OpAddInput(op.get(),
absl::get<ParallelTensor*>(inputs[input_index])
->tensor(device_index),
status);
if (TF_GetCode(status) != TF_OK) return result;
}
// Parallel tensors are divided between operations by device.
TFE_OpAddInput(op.get(), inputs[input_index]->tensor(device_index),
status);
if (TF_GetCode(status) != TF_OK) return result;
}
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
int real_num_outputs = expected_max_outputs;


@ -52,9 +52,6 @@ using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class ParallelTensor;
using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
class ParallelDevice {
@ -79,10 +76,9 @@ class ParallelDevice {
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
// implicitly-mirrored tensors on other devices). Wraps the resulting
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
// output of the original operation.
// its corresponding inputs from the input ParallelTensors. Wraps the
// resulting per-device and per-output TFE_TensorHandles into one
// ParallelTensor per output of the original operation.
//
// Attributes are forwarded to executed operations unmodified.
//
@ -90,7 +86,7 @@ class ParallelDevice {
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
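For illustration (editor's sketch, not part of this commit): with the new signature, callers broadcast plain TFE_TensorHandle* inputs themselves before calling Execute, mirroring the ExecuteWithSpecialOps change above. `parallel_device`, `context`, `handle`, `attributes`, and `status` are assumed to be in scope inside a void helper.

// Non-parallel handles must now be broadcast by the caller before Execute.
std::vector<ParallelTensor*> parallel_inputs;
std::vector<std::unique_ptr<ParallelTensor>> owned_broadcasts;
std::unique_ptr<ParallelTensor> broadcast(
    parallel_device.CopyToParallelDevice(context, handle, status));
if (TF_GetCode(status) != TF_OK) return;
parallel_inputs.push_back(broadcast.get());
owned_broadcasts.push_back(std::move(broadcast));  // keeps components alive
auto maybe_results =
    parallel_device.Execute(context, parallel_inputs, "Identity", attributes,
                            /*expected_max_outputs=*/1, status);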


@ -468,10 +468,6 @@ Status XlaComputationLaunchContext::PopulateOutputs(
<< "Invalid input for outputs " << i << ": " << input_index;
ctx->set_output(i, ctx->input(input_index));
} else {
if (MustAliasOutput(input_output_alias, output_num)) {
DCHECK(output.buffer({output_num}).is_null())
<< "Expected output buffer to be aliased, but it is not nil.";
}
if (allocate_xla_tensors_) {
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
input_output_alias, output_num, ctx, i, shape, &output,


@ -32,7 +32,6 @@ struct PassConfig {
lower_tensor_list_ops(false),
trim_functions_whitelist({}),
quant_specs(std::move(specs)),
skip_control_dialect(false),
form_clusters(false),
unfold_batch_matmul(true),
legalize_tf_while(true),
@ -49,13 +48,8 @@ struct PassConfig {
llvm::ArrayRef<std::string> trim_functions_whitelist;
// All information about quantization.
QuantizationSpecs quant_specs;
// If `skip_control_dialect` is true, TF executor dialect is not converted to
// TF control dialect prior to legalization to TF Lite.
// TODO(b/142911013): Remove flag once control dialect is removed.
bool skip_control_dialect;
// If `form_clusters` is true (and `skip_control_dialect` is true), clusters
// are formed by grouping consecutive ops of the same device, under a
// `tf_device.launch` op.
// If `form_clusters` is true, clusters are formed by grouping consecutive
// ops of the same device, under a `tf_device.launch` op.
bool form_clusters;
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
// of tfl.fully_connected ops.
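For illustration (editor's sketch, not part of this commit): with `skip_control_dialect` removed, callers request cluster formation directly on the config. The `QuantizationSpecs` value here is a placeholder.

mlir::TFL::QuantizationSpecs quant_specs;       // placeholder; real callers fill this in
mlir::TFL::PassConfig pass_config(quant_specs);
pass_config.form_clusters = true;               // wrap same-device ops in tf_device.launch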


@ -803,9 +803,17 @@ StatusOr<FuncOp> ConvertSubgraph(
}
for (auto output : func_outputs) {
bool is_constant = !is_op_output[output];
const bool is_func_input = input_index_set.contains(output);
bool is_constant = !is_op_output[output] && !is_func_input;
// There are two cases in which a tensor without a shape in the flatbuffer is
// treated as a scalar:
// 1. `is_constant` is true, meaning the tensor is created from a constant op.
// 2. `is_func_input` and `is_entry_point` are both true, meaning the tensor is
// a function input and the function's input type is a scalar tensor.
const bool shapeless_is_scalar =
is_constant || (is_func_input && is_entry_point);
auto type_or_err = GetTensorType(*subgraph.tensors.at(output), builder,
/*shapeless_are_scalars=*/is_constant,
shapeless_is_scalar,
/*is_constant=*/is_constant);
if (!type_or_err.ok()) {
emitError(func_loc, "error reading return types")


@ -0,0 +1,22 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// This test checks an unranked function output derived from an input; the output type should be compatible with the input type.
// CHECK: func @main(%arg0: tensor<1xf32>) -> tensor<*xf32>
// CHECK: %0 = "tf.While"(%arg0) {body = @body, cond = @cond, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
// CHECK: return %0 : tensor<*xf32>
// CHECK: func @cond(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK: func @body(%arg0: tensor<*xf32>) -> tensor<*xf32>
func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> {
%0 = "tf.While"(%arg0) {cond = @cond, body = @body, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @cond(%arg1: tensor<*xf32>) -> tensor<*xf32> {
return %arg1: tensor<*xf32>
}
func @body(%arg1: tensor<*xf32>) -> tensor<*xf32> {
return %arg1: tensor<*xf32>
}

File diff suppressed because it is too large.


@ -58,21 +58,10 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
mlir::OpPassManager* pass_manager) {
pass_manager->addPass(mlir::tf_executor::CreateSwitchFoldPass());
if (pass_config.skip_control_dialect) {
// Merge islands.
pass_manager->addPass(
mlir::tf_executor::CreateTFExecutorIslandCoarseningPass());
// Assuming island coarsening above results in a graph with a single island,
// a canonicalization can be ran to hoist the ops of the single island out.
pass_manager->addPass(mlir::createCanonicalizerPass());
if (pass_config.form_clusters)
pass_manager->addPass(mlir::TFDevice::CreateClusterFormationPass());
} else {
pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion());
pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
}
mlir::TF::StandardPipelineOptions standard_pipeline_options;
standard_pipeline_options.enable_inliner = false;
standard_pipeline_options.form_clusters = pass_config.form_clusters;
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
if (pass_config.shape_inference) {
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
@ -213,13 +202,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
OpPassManager& func_pm = pm.nest<FuncOp>();
// tf_executor dialect passes - Cleaning up the IR.
func_pm.addPass(tf_executor::CreateSwitchFoldPass());
func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass());
func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
// more cleanup of executor dialect and raise to control flow.
pm.addPass(mlir::CreateTFExecutorToControlDialectConversion());
pm.addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
mlir::TF::StandardPipelineOptions standard_pipeline_options;
mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options);
// This is needed for control flow support with TF TensorList.
pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass());


@ -38,12 +38,6 @@ limitations under the License.
#include "tensorflow/lite/tools/optimize/quantize_weights.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace mlir {
/// Create a pass to convert from the TFExecutor to the TF control dialect.
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion();
} // namespace mlir
namespace tensorflow {
using mlir::MLIRContext;


@ -54,7 +54,6 @@ class WhileOutlinePass
tensorflow::OpOrArgLocNameMapper mapper_;
};
} // namespace
std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
return (mapper_.GetUniqueName(op) + suffix).str();
@ -62,7 +61,7 @@ std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
// Returns whether the WhileOp is already outlined (e.g., only consists of calls
// to functions).
static bool IsAlreadyOutlinedd(WhileOp while_op) {
bool IsAlreadyOutlined(WhileOp while_op) {
auto just_call = [](Region& region) {
auto it = region.front().begin();
if (!isa<CallOp>(*it)) return false;
@ -120,7 +119,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
}
// Skip if already just calls.
if (extra_operands.empty() && IsAlreadyOutlinedd(while_op)) return;
if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return;
// Collect new types.
SmallVector<Type, 4> types;
@ -238,6 +237,7 @@ void WhileOutlinePass::runOnOperation() {
getOperation().walk(
[&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
}
} // namespace
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {


@ -67,27 +67,57 @@ inline RankedTensorType getResultType(mlir::FuncOp func, int idx) {
}
LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
if (func.getNumResults() != 2) {
return failure();
}
if (func.getNumArguments() != 1) {
return failure();
}
// In the case of a rank-0 input tensor, the whitespace tokenizer generates
// 1 output:
// * String tensor for tokens.
//
// In the case of a 1-D input tensor, the whitespace tokenizer generates
// 2 outputs that make up a ragged tensor:
// * 1st output is the values of the ragged tensor;
// * 2nd output is the offsets.
//
// In the case of a batched input tensor, the whitespace tokenizer generates
// 3 outputs that make up a nested ragged tensor:
// * 1st output is the values of the ragged tensor;
// * 2nd output is the inner offsets;
// * 3rd output is the outer offsets.
auto input_type = getInputType(func, 0);
if (!input_type || input_type.getRank() != 1 ||
!input_type.getElementType().isa<mlir::TF::StringType>()) {
return failure();
if (!input_type || !input_type.getElementType().isa<mlir::TF::StringType>() ||
!input_type.hasRank()) {
return func.emitError() << "Input should be a string tensor";
}
const std::vector<int> kValidNumOfOutput = {1, 2, 3};
if (input_type.getRank() >= kValidNumOfOutput.size()) {
return func.emitError()
<< "Unrecognized input rank: " << input_type.getRank();
}
if (func.getNumResults() != kValidNumOfOutput[input_type.getRank()]) {
return func.emitError()
<< "Expect " << kValidNumOfOutput[input_type.getRank()]
<< "output(s) when input has rank " << input_type.getRank();
}
auto value_type = getResultType(func, 0);
if (!value_type || value_type.getRank() != 1 ||
if (!value_type || !value_type.hasRank() || value_type.getRank() != 1 ||
!value_type.getElementType().isa<mlir::TF::StringType>()) {
return failure();
return func.emitError() << "1st output should be string tensor";
}
auto offset_type = getResultType(func, 1);
if (offset_type.getRank() != 1 ||
!offset_type.getElementType().isInteger(64)) {
return failure();
if (func.getNumResults() > 1) {
auto offset_type = getResultType(func, 1);
if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 ||
!offset_type.getElementType().isInteger(64)) {
return func.emitError() << "2nd output should be int64 tensor";
}
}
if (func.getNumResults() > 2) {
auto offset_type = getResultType(func, 2);
if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 ||
!offset_type.getElementType().isInteger(64)) {
return func.emitError() << "3rd output should be int64 tensor";
}
}
return success();
}
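As a concrete reading of the checks above (editor's note, not part of this commit): for a rank-1 input the verifier requires kValidNumOfOutput[1] == 2 results, for example:

// input:     tensor<?x!tf.string>   (rank 1)
// result 0:  tensor<?x!tf.string>   (token values; must be a rank-1 string tensor)
// result 1:  tensor<?xi64>          (offsets; must be a rank-1 int64 tensor)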
@ -96,19 +126,12 @@ LogicalResult ConvertWhitespaceTokenizer(mlir::FuncOp func,
func.eraseBody();
func.addEntryBlock();
func.setAttr(kTFAPIImplements, StringAttr::get(api, func.getContext()));
Value text = func.getArgument(0);
auto output_type = func.getType().getResult(0);
auto offset_type = func.getType().getResult(1);
SmallVector<Type, 2> shape = {output_type, offset_type};
ArrayRef<Type> output_types(shape);
OpBuilder builder(func.getBody());
auto op = builder.create<mlir::TFL::CustomOp>(func.getLoc(), output_types,
ValueRange(text), api,
emptyCustomOption(&builder));
auto op = builder.create<mlir::TFL::CustomOp>(
func.getLoc(), func.getType().getResults(), ValueRange(text), api,
emptyCustomOption(&builder));
builder.create<mlir::ReturnOp>(func.getLoc(), op.getResults());
return success();
}


@ -19,6 +19,7 @@ from __future__ import print_function
import os
import platform
import sys
import lit.formats
from lit.llvm import llvm_config
from lit.llvm.subst import ToolSubst


@ -431,7 +431,6 @@ cc_library(
"transforms/optimize_global_tensors.cc",
"transforms/parallel_execute_to_islands.cc",
"transforms/promote_resources_to_args.cc",
"transforms/raise_control_flow.cc",
"transforms/readonly_references_to_resources.cc",
"transforms/replicate_invariant_op_hoisting.cc",
"transforms/replicate_to_island.cc",
@ -460,7 +459,6 @@ cc_library(
"transforms/tpu_variable_runtime_reformatting.cc",
"translate/breakup-islands.cc",
"translate/control_to_executor_dialect.cc",
"translate/executor_to_control_dialect.cc",
"translate/tf_functional_to_executor.cc",
],
hdrs = [


@ -811,11 +811,13 @@ ParseResult ParseEnterOp(OpAsmParser &parser, OperationState &result) {
// fully qualified) or a short form with a single type (in which case the data
// input and the outputs are all using this type).
if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
if (type.getNumInputs() != 1)
return parser.emitError(parser.getNameLoc())
<< " expects a single data type";
result.types.assign(type.getResults().begin(), type.getResults().end());
types.assign(type.getInputs().begin(), type.getInputs().end());
// One data input, and any number of control inputs.
if (type.getNumInputs() >= 1) {
result.types.assign(type.getResults().begin(), type.getResults().end());
types.assign(type.getInputs().begin(), type.getInputs().end());
} else {
return parser.emitError(parser.getNameLoc()) << " expects a data input";
}
} else {
Type control_type = ControlType::get(context);
types.append(op_infos.size() - 1, control_type);


@ -254,9 +254,12 @@ LogicalResult VerifyExportedFunc(FuncOp func) {
}
continue;
}
if (func.getArgAttr(i, "tf.resource_name")) {
continue;
}
return func.emitError()
<< "all arguments should have 'tf_saved_model.index_path' or "
"'tf_saved_model.bound_input' attributes";
<< "all arguments should have 'tf_saved_model.index_path', "
"'tf_saved_model.bound_input' or 'tf.resource_name' attributes";
}
llvm::SmallDenseSet<StringRef, 8> unique_bound_inputs;
for (int i = 0, e = func.getNumArguments(); i < e; i++) {


@ -1,4 +1,4 @@
// RUN: tf-opt -tf-executor-to-control-conversion %s | FileCheck %s --check-prefix=CONTROL --dump-input=fail
// RUN: tf-opt -tf-executor-graph-pruning %s | FileCheck %s --check-prefix=CONTROL --dump-input=fail
// RUN: tf-opt -tf-control-to-executor-conversion %s | FileCheck %s --check-prefix=EXECUTOR --dump-input=fail
// CONTROL-LABEL: func @main


@ -1,188 +0,0 @@
// RUN: tf-opt -tf-executor-to-control-conversion %s | FileCheck %s --dump-input=fail
// CHECK-LABEL: func @LoopTest() {
func @LoopTest() {
tf_executor.graph {
%0:2 = tf_executor.island {
%cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
tf_executor.yield %cst : tensor<i32>
}
%1:2 = tf_executor.Enter %0#0 frame "while/while_context" : (tensor<i32>) -> (tensor<*xi32>, !tf_executor.control) {T = "tfdtype$DT_INT32", device = "", name = "while/Enter"}
%2 = tf_executor.island {
"tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> ()
tf_executor.yield
}
%3:3 = tf_executor.NextIteration.Source : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
%4:3 = tf_executor.Merge %3#0, %1#0 : tensor<*xi32> {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"}
%5:2 = tf_executor.island(%4#2) {
%cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : () -> tensor<i32>
tf_executor.yield %cst : tensor<i32>
}
%6:2 = tf_executor.island {
%14 = "tf.Less"(%4#0, %5#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
tf_executor.yield %14 : tensor<*xi1>
}
%7:2 = tf_executor.LoopCond %6#0 : (tensor<*xi1>) -> (tensor<i1>, !tf_executor.control) {device = "", name = "while/LoopCond"}
%8:3 = tf_executor.Switch %4#0, %7#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"}
%9:2 = tf_executor.Exit %8#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"}
%10:2 = tf_executor.island {
%14 = "tf.Identity"(%8#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> tensor<*xi32>
tf_executor.yield %14 : tensor<*xi32>
}
%11:2 = tf_executor.island(%10#1) {
%cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : () -> tensor<i32>
tf_executor.yield %cst : tensor<i32>
}
%12:2 = tf_executor.island {
%14 = "tf.Add"(%10#0, %11#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
tf_executor.yield %14 : tensor<*xi32>
}
%13 = tf_executor.ControlTrigger %2, %12#1, %9#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"}
tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
tf_executor.fetch
}
return
}
// CHECK-NEXT: %[[CONST:[0-9]*]]:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
// CHECK-NEXT: %[[ENTER:[0-9]*]]:2 = "_tf.Enter"(%[[CONST]]#0) {T = "tfdtype$DT_INT32", device = "", frame_name = "while/while_context", is_constant = false, name = "while/Enter", parallel_iterations = 10 : i64} : (tensor<i32>) -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[NOOP:[0-9]*]] = "_tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> !_tf.control
// CHECK-NEXT: %[[SOURCE:[0-9]*]]:2 = "_tf.NextIteration.source"() {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : () -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = "_tf.Merge"(%[[SOURCE]]#0, %[[ENTER]]#0) {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<i32>, !_tf.control)
// CHECK-NEXT: %[[CONST_LESS:[0-9]*]]:2 = "_tf.Const"(%[[MERGE]]#2) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
// CHECK-NEXT: %[[LESS:[0-9]*]]:2 = "_tf.Less"(%[[MERGE]]#0, %[[CONST_LESS]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi1>, !_tf.control)
// CHECK-NEXT: %[[COND:[0-9]*]]:2 = "_tf.LoopCond"(%[[LESS]]#0) {device = "", name = "while/LoopCond"} : (tensor<*xi1>) -> (tensor<i1>, !_tf.control)
// CHECK-NEXT: %[[SWITCH:[0-9]*]]:3 = "_tf.Switch"(%[[MERGE]]#0, %[[COND]]#0) {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"} : (tensor<*xi32>, tensor<i1>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[EXIT:[0-9]*]]:2 = "_tf.Exit"(%[[SWITCH]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = "_tf.Identity"(%[[SWITCH]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[CONST_ADD:[0-9]*]]:2 = "_tf.Const"(%[[IDENTITY]]#1) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[CT:[0-9]*]] = "_tf.ControlTrigger"(%[[NOOP]], %[[ADD]]#1, %[[EXIT]]#1) {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"} : (!_tf.control, !_tf.control, !_tf.control) -> !_tf.control
// CHECK-NEXT: %[[SINK:[0-9]*]] = "_tf.NextIteration.sink"(%[[ADD]]#0, %[[CT]]) {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : (tensor<*xi32>, !_tf.control) -> !_tf.control
// CHECK-NEXT: return
// -----
// CHECK-LABEL: func @multiple_ops_region
func @multiple_ops_region(%arg0 : tensor<*xi32>, %arg1 : tensor<i32>) {
tf_executor.graph {
%0:2 = tf_executor.island {
// The 4 operations are independent, but the current conversion will add
// control dependencies conservatively.
%1 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add1"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%2 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add2"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%3 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add3"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%4 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add4"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
tf_executor.yield %4 : tensor<*xi32>
}
tf_executor.fetch
}
return
}
// CHECK-NEXT: %[[ADD1:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add1"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[ADD2:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD1]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add2"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[ADD3:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD2]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add3"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[ADD4:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD3]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add4"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
// -----
// CHECK-LABEL: func @switchN(
func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%fetches = tf_executor.graph {
// CHECK: [[S1:%.*]]:6 = "_tf._SwitchN"(%arg1, %arg0) {num_outs = 5 : i64}
%1:6 = tf_executor.SwitchN %arg1, %arg0 of 5 : tensor<*xf32>
// CHECK: "_tf._SwitchN"(%arg1, %arg0, [[S1]]#5) {num_outs = 12 : i64}
%2:13 = tf_executor.SwitchN %arg1, %arg0 of 12 (%1#5) : tensor<*xf32>
tf_executor.fetch %2#0 : tensor<*xf32>
}
return %fetches : tensor<*xf32>
}
// -----
// Test if tf_executor dialect ops with Ref types are mapped correctly to the ops in control dialect.
// CHECK-LABEL: func @ref_tf_executor_ops
func @ref_tf_executor_ops(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4x!tf.f32ref>, %arg3: tensor<i32>, %arg4: tensor<i1> ) -> tensor<4x!tf.f32ref> {
%result = tf_executor.graph {
// CHECK: _tf.Enter
%0:2 = tf_executor.Enter %arg0 frame "while/while_context" : (tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, !tf_executor.control)
// CHECK: _tf.Exit
%1:2 = tf_executor.Exit %arg0 : tensor<4x!tf.f32ref>
// CHECK: _tf.Switch
%2:3 = tf_executor.Switch %arg0, %arg4 : (tensor<4x!tf.f32ref>, tensor<i1>) -> (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>, !tf_executor.control)
// CHECK: _tf.Merge
%3:3 = tf_executor.Merge %arg0, %arg1 : (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control)
// CHECK: _tf.NextIteration.source
%4:3 = tf_executor.NextIteration.Source : tensor<4x!tf.f32ref>
// CHECK: _tf.NextIteration.sink
tf_executor.NextIteration.Sink [%4#1] %4#0 : tensor<4x!tf.f32ref>
tf_executor.fetch %0#0 : tensor<4x!tf.f32ref>
}
return %result : tensor<4x!tf.f32ref>
}
// -----
// Tests if empty island with just one control dependency input and output is
// handled correctly.
// CHECK-LABEL: func @empty_island_control_dep_only
func @empty_island_control_dep_only() -> tensor<i32> {
%fetch = tf_executor.graph {
%0:2 = tf_executor.island {
%4 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
tf_executor.yield %4 : tensor<i32>
}
// CHECK-NEXT: %[[CONST1:[0-9]*]]:2 = "_tf.Const"()
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
%1:2 = tf_executor.island {
%5 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
tf_executor.yield %5 : tensor<i32>
}
// CHECK-NEXT: %[[CONST2:[0-9]*]]:2 = "_tf.Const"()
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
%2 = tf_executor.island(%0#1) {
tf_executor.yield
}
%3:2 = tf_executor.island(%2, %1#1) {
%6 = "tf.Add"(%0#0, %1#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %6 : tensor<i32>
}
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[CONST1]]#0, %[[CONST2]]#0, %[[CONST1]]#1, %[[CONST2]]#1)
// CHECK-SAME: (tensor<i32>, tensor<i32>, !_tf.control, !_tf.control) -> (tensor<i32>, !_tf.control)
tf_executor.fetch %3#0 : tensor<i32>
}
return %fetch : tensor<i32>
}
// -----
// Tests if empty island with multiple control inputs will be replaced with a
// no-op.
// CHECK-LABEL: func @empty_island_multi_control_inputs
func @empty_island_multi_control_inputs() -> tensor<i32> {
%fetch = tf_executor.graph {
%0:2 = tf_executor.island {
%4 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
tf_executor.yield %4 : tensor<i32>
}
// CHECK-NEXT: %[[CONST1:[0-9]*]]:2 = "_tf.Const"()
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
%1:2 = tf_executor.island {
%5 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
tf_executor.yield %5 : tensor<i32>
}
// CHECK-NEXT: %[[CONST2:[0-9]*]]:2 = "_tf.Const"()
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
%2 = tf_executor.island(%0#1, %1#1) {
tf_executor.yield
}
// CHECK-NEXT: %[[NOOP:[0-9]*]] = "_tf.NoOp"(%[[CONST1]]#1, %[[CONST2]]#1)
// CHECK-SAME: (!_tf.control, !_tf.control) -> !_tf.control
%3:2 = tf_executor.island(%2) {
%6 = "tf.Add"(%0#0, %1#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %6 : tensor<i32>
}
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[CONST1]]#0, %[[CONST2]]#0, %[[NOOP]])
// CHECK-SAME: (tensor<i32>, tensor<i32>, !_tf.control) -> (tensor<i32>, !_tf.control)
tf_executor.fetch %3#0 : tensor<i32>
}
return %fetch : tensor<i32>
}


@ -723,6 +723,11 @@ func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x1
return %0 : tensor<3x8x8x16xf32>
}
func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32>
return %0 : tensor<3x5x1x4xf32>
}
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK-LABEL: func @biasAdd_NHWC(
@ -1596,3 +1601,14 @@ func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x1
// CHECK: [[VAL_402:%.*]] = "tf.BroadcastTo"([[VAL_400]], [[VAL_401]]) : (tensor<3x1x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32>
// CHECK: return [[VAL_402]] : tensor<3x8x8x16xf32>
// CHECK: }
// CHECK-LABEL: func @convert_dot_general(
// CHECK-SAME: [[VAL_396:%.*]]: tensor<3x2x6x5x1xf32>, [[VAL_397:%.*]]: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
// CHECK: [[VAL_398:%.*]] = "tf.Transpose"([[VAL_396]], {{.*}}) : (tensor<3x2x6x5x1xf32>, tensor<5xi64>) -> tensor<3x5x1x2x6xf32>
// CHECK: [[VAL_399:%.*]] = "tf.Transpose"([[VAL_397]], {{.*}}) : (tensor<3x2x4x6xf32>, tensor<4xi64>) -> tensor<3x2x6x4xf32>
// CHECK: [[VAL_400:%.*]] = "tf.Reshape"([[VAL_398]], {{.*}}) : (tensor<3x5x1x2x6xf32>, tensor<3xi64>) -> tensor<3x5x12xf32>
// CHECK: [[VAL_401:%.*]] = "tf.Reshape"([[VAL_399]], {{.*}}) : (tensor<3x2x6x4xf32>, tensor<3xi64>) -> tensor<3x12x4xf32>
// CHECK: [[VAL_402:%.*]] = "tf.BatchMatMulV2"([[VAL_400]], [[VAL_401]]) {adj_x = false, adj_y = false} : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32>
// CHECK: [[VAL_403:%.*]] = "tf.Reshape"([[VAL_402]], {{.*}}) : (tensor<3x5x4xf32>, tensor<4xi64>) -> tensor<3x5x1x4xf32>
// CHECK: return [[VAL_403]] : tensor<3x5x1x4xf32>
// CHECK: }


@ -1,4 +1,4 @@
// RUN: tf-opt %s -op-fusion | FileCheck %s --dump-input-on-failure
// RUN: tf-opt %s -tf-op-fusion | FileCheck %s --dump-input-on-failure
//===----------------------------------------------------------------------===//
// Conv2D + BiasAdd + <Activation> fusions.


@ -1,5 +1,4 @@
// Run a pass for promoting tf.VarHandleOps to function arguments in a format of TensorFlowSavedModelDialect.
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-saved-model-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure
// Tests main function with multiple blocks.
@ -12,27 +11,24 @@ func @main() {
// -----
"tf_saved_model.global_tensor"() {sym_name = "x", type = tensor<f32>, value = dense<1.67482901> : tensor<f32>} : () -> ()
"tf_saved_model.global_tensor"() {sym_name = "y", type = tensor<i32>, value = dense<0> : tensor<i32>} : () -> ()
// CHECK-LABEL: func @no_args
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
// CHECK-SAME: (%arg0: tensor<!tf.resource> {tf.resource_name = "x"})
// CHECK-NOT: "tf.VarHandleOp"
func @no_args() {
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
return
}
// CHECK-LABEL: func @some_args
// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<!tf.resource> {tf.resource_name = "x"})
// CHECK-NOT: "tf.VarHandleOp"
func @some_args(%arg0: tensor<i1>) {
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
return
}
// CHECK-LABEL: func @unique_vars
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x}, %arg1: tensor<!tf.resource<tensor<i32>>> {tf_saved_model.bound_input = @y})
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"}, %arg1: tensor<!tf.resource<tensor<i32>>> {tf.resource_name = "y"})
// CHECK-NOT: "tf.VarHandleOp"
func @unique_vars() {
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
@ -41,7 +37,7 @@ func @unique_vars() {
}
// CHECK-LABEL: func @duplicate_vars
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"})
// CHECK-NOT: "tf.VarHandleOp"
func @duplicate_vars() {
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
@ -50,7 +46,7 @@ func @duplicate_vars() {
}
// CHECK-LABEL: func @duplicate_vars_with_users
// CHECK-SAME: (%arg0: tensor<f32>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
// CHECK-SAME: (%arg0: tensor<f32>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"})
// CHECK: "tf.ReadVariableOp"(%arg1)
// CHECK: "tf.AssignAddVariableOp"(%arg1, %arg0)
// CHECK-NOT: "tf.VarHandleOp"


@ -1,57 +0,0 @@
// RUN: tf-opt %s -tf-raise-control-flow -split-input-file | FileCheck %s
// Test that we remove underscores.
// CHECK-LABEL: func @testSimpleAddsAndIdentity(%arg0: tensor<*xf32>)
func @testSimpleAddsAndIdentity(tensor<*xf32>) -> tensor<*xf32> {
^bb0(%0: tensor<*xf32>):
// CHECK: %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%1 = "_tf.Identity"(%0) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: %1 = "tf.Add"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%2 = "_tf.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: %2 = "tf.Add"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%3 = "_tf.Add"(%1, %2) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %2 : tensor<*xf32>
return %3 : tensor<*xf32>
}
// CHECK-LABEL: func @testAddWithControlDependency(%arg0: tensor<*xf32>)
func @testAddWithControlDependency(tensor<*xf32>) -> tensor<*xf32> {
^bb0(%0: tensor<*xf32>):
// CHECK: %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%1:2 = "_tf.Identity"(%0) : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control)
// CHECK: %1 = "tf.Add"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%2:2 = "_tf.Add"(%0, %0, %1#1) : (tensor<*xf32>, tensor<*xf32>, !_tf.control) -> (tensor<*xf32>, !_tf.control)
// CHECK: %2 = "tf.Add"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%3:2 = "_tf.Add"(%1#0, %2, %1#1, %2#1) : (tensor<*xf32>, tensor<*xf32>, !_tf.control, !_tf.control) -> (tensor<*xf32>, !_tf.control)
// CHECK: return %2 : tensor<*xf32>
return %3 : tensor<*xf32>
}
// TODO(clattner): simplify and expand these tests. This is mostly a placeholder.
func @LoopTest() {
%0:2 = "_tf.Const"() {device = "", name = "Const", dtype = "tfdtype$DT_INT32", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
%1:2 = "_tf.Enter"(%0#0) {device = "", name = "while/Enter", T = "tfdtype$DT_INT32", frame_name = "while/while_context", is_constant = false, parallel_iterations = 10} : (tensor<i32>) -> (tensor<*xi32>, !_tf.control)
%11:2 = "_tf.NextIteration.source"() {device = "", name = "while/NextIteration", T = "tfdtype$DT_INT32", id = 0} : () -> (tensor<*xi32>, !_tf.control)
%2:3 = "_tf.Merge"(%11#0, %1#0) {device = "", name = "while/Merge", N = 2, T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<i32>, !_tf.control)
%3:2 = "_tf.Const"(%2#2) {device = "", name = "while/Less/y", dtype = "tfdtype$DT_INT32", value = dense<2> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
%4:2 = "_tf.Less"(%2#0, %3#0) {device = "", name = "while/Less", T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi1>, !_tf.control)
%5:2 = "_tf.LoopCond"(%4#0) {device = "", name = "while/LoopCond"} : (tensor<*xi1>) -> (tensor<i1>, !_tf.control)
%6:3 = "_tf.Switch"(%2#0, %5#0) {device = "", name = "while/Switch", T = "tfdtype$DT_INT32", _class = ["loc:@while/Merge"]} : (tensor<*xi32>, tensor<i1>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
%7:2 = "_tf.Exit"(%6#0) {device = "", name = "while/Exit", T = "tfdtype$DT_INT32"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
%8:2 = "_tf.Identity"(%6#1) {device = "", name = "while/Identity", T = "tfdtype$DT_INT32"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
%9:2 = "_tf.Const"(%8#1) {device = "", name = "while/Add/y", dtype = "tfdtype$DT_INT32", value = dense<3> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
%10:2 = "_tf.Add"(%8#0, %9#0) {device = "", name = "while/Add", T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
%ctl = "_tf.NextIteration.sink"(%10#0) {device = "", name = "while/NextIteration", T = "tfdtype$DT_INT32", id = 0} : (tensor<*xi32>) -> (!_tf.control)
return
}


@ -1,4 +1,4 @@
// RUN: tf-opt -verify-diagnostics -readonly-references-to-resources -split-input-file %s | FileCheck %s --dump-input=fail
// RUN: tf-opt -verify-diagnostics -tf-readonly-references-to-resources -split-input-file %s | FileCheck %s --dump-input=fail
// Test case: Basic converting.


@ -416,6 +416,17 @@ func @enter_control(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @enter_control_longform(%{{.*}}: tensor<*xf32>, %{{.*}}: tensor<i1>) -> tensor<*xf32> {
func @enter_control_longform(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
%0 = tf_executor.graph {
%1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32>
// CHECK: tf_executor.Enter %{{.*}}, %{{.*}}, %{{.*}} frame "some/frame" : tensor<*xf32>
%res:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : (tensor<*xf32>, !tf_executor.control, !tf_executor.control) -> (tensor<*xf32>, !tf_executor.control)
tf_executor.fetch %res#0 : tensor<*xf32>
}
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @nextiteration(%{{.*}}: tensor<*xf32>, %{{.*}}: i1) -> tensor<*xf32> {
func @nextiteration(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
%0 = tf_executor.graph {


@ -40,3 +40,16 @@ module attributes {tf_saved_model.semantics} {
}
}
// -----
module attributes {tf_saved_model.semantics} {
// CHECK: func @f
func @f(
%arg0: tensor<f32> {tf.resource_name = "resource"}
) attributes { tf_saved_model.exported_names = ["foo.some_func"] } {
return
}
}


@ -120,7 +120,7 @@ module attributes {tf_saved_model.semantics} {
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{all arguments should have 'tf_saved_model.index_path' or 'tf_saved_model.bound_input' attributes}}
// expected-error@+1 {{all arguments should have 'tf_saved_model.index_path', 'tf_saved_model.bound_input' or 'tf.resource_name' attributes}}
func @f(
%arg0: tensor<f32>
) attributes { tf_saved_model.exported_names = ["f"] } {


@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
// This transformation pass transforms functional control flow operations in the
// standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
// TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
@ -52,7 +52,6 @@ static Value LowerCondition(Location loc, Value value, OpBuilder* builder) {
//
// Requires the function to provide arguments for each of the `fn` operands
// that is compatible for tensor cast.
//
static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
FuncOp fn, OpBuilder* builder) {
FunctionType fn_type = fn.getType();
@ -113,7 +112,6 @@ static void JumpToBlock(Location loc, const std::function<Value(int)>& get_arg,
// Requires that the block has same number of arguments as number of results of
// the operation and either they have same types or are more generic types and
// it is possible to cast them to results' types.
//
static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
Block* block, OpBuilder* builder) {
assert(op->getNumResults() == block->getNumArguments());
@ -132,9 +130,6 @@ static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
// Given a functional IfOp, transforms the enclosing code to eliminate it
// completely from the IR, breaking it into operations to evaluate the condition
// as a bool, plus some branches.
//
// This returns true on failure.
//
static LogicalResult LowerIfOp(IfOp op) {
Operation* op_inst = op.getOperation();
Location loc = op_inst->getLoc();
@ -193,9 +188,6 @@ static LogicalResult LowerIfOp(IfOp op) {
// Given a functional WhileOp, transforms the enclosing code to eliminate it
// completely from the IR, breaking it into operations to execute the loop body
// repeatedly while the loop condition is true.
//
// This returns true on failure.
//
static LogicalResult LowerWhileOp(WhileOp op) {
Operation* op_inst = op.getOperation();
Location loc = op_inst->getLoc();


@ -15,10 +15,15 @@ limitations under the License.
// This file implements logic for legalizing HLO to TensorFlow.
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
@ -40,6 +45,8 @@ namespace mlir {
namespace TF {
namespace {
using xla_hlo::DotDimensionNumbers;
class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
public:
using OpConversionPattern::OpConversionPattern;
@ -75,6 +82,205 @@ class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
};
};
// Appends all elements in `range` to `values`.
template <typename ValueT, typename Range>
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range) {
values.insert(values.end(), range.begin(), range.end());
}
// Appends all elements in the given ranges to `values`.
template <typename ValueT, typename Range, typename... RangeTs>
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range,
RangeTs &&... ranges) {
values.insert(values.end(), range.begin(), range.end());
Append(values, ranges...);
}
// Returns the number of elements in `range`.
template <typename Range>
size_t Size(Range &&range) {
return range.size();
}
// Returns the total number of elements in a variadic number of `ranges`.
template <typename Range, typename... RangeTs>
size_t Size(Range &&range, RangeTs &&... ranges) {
return range.size() + Size(std::forward<RangeTs>(ranges)...);
}
// Concats all elements in `ranges` and returns a small vector as a result.
template <typename ValueT, typename... RangeTs>
llvm::SmallVector<ValueT, 4> Concat(RangeTs &&... ranges) {
llvm::SmallVector<int64_t, 4> results;
results.reserve(Size(std::forward<RangeTs>(ranges)...));
Append(results, std::forward<RangeTs>(ranges)...);
return results;
}
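A small usage sketch of the helpers above (editor's illustration, not part of this commit); this is how the permutations further down are assembled:

llvm::SmallVector<int64_t, 4> batch = {0};
llvm::SmallVector<int64_t, 4> out = {3, 4};
llvm::SmallVector<int64_t, 4> contracting = {1, 2};
// Size(batch, out, contracting) == 5;
// Concat<int64_t>(batch, out, contracting) yields {0, 3, 4, 1, 2}.
llvm::SmallVector<int64_t, 4> permutation =
    Concat<int64_t>(batch, out, contracting);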
// A struct to hold axes and sizes for a set of dimensions.
struct DimensionSetVector {
llvm::ArrayRef<int64_t> AxesArray() const { return axes.getArrayRef(); }
llvm::ArrayRef<int64_t> SizesArray() const { return sizes.getArrayRef(); }
llvm::SmallSetVector<int64_t, 4> axes;
llvm::SmallSetVector<int64_t, 4> sizes;
};
// A struct to hold information about dimensions of dot_general operands.
class DotDimensionsInfo {
public:
DotDimensionsInfo(ShapedType type, DenseIntElementsAttr batch_dimensions,
DenseIntElementsAttr contracting_dimensions) {
const int rank = type.getRank();
for (const int dim : batch_dimensions.getValues<int64_t>()) {
batch_dimensions_.axes.insert(dim);
batch_dimensions_.sizes.insert(type.getDimSize(dim));
}
for (const int dim : contracting_dimensions.getValues<int64_t>()) {
contracting_dimensions_.axes.insert(dim);
contracting_dimensions_.sizes.insert(type.getDimSize(dim));
}
for (int dim = 0; dim < rank; ++dim) {
if (contracting_dimensions_.axes.count(dim) > 0 ||
batch_dimensions_.axes.count(dim) > 0) {
continue;
}
out_dimensions_.axes.insert(dim);
out_dimensions_.sizes.insert(type.getDimSize(dim));
}
}
const DimensionSetVector &batch_dimensions() const {
return batch_dimensions_;
}
const DimensionSetVector &contracting_dimensions() const {
return contracting_dimensions_;
}
// Out dimensions are any dimensions that are neither batch nor contracting
// dimensions, hence will be propagated to output shape.
const DimensionSetVector &out_dimensions() const { return out_dimensions_; }
// Returns the total dimension size after flattening all contracting
// dimensions.
int FlattenedContractingDimensionSize() const {
return std::accumulate(contracting_dimensions_.sizes.begin(),
contracting_dimensions_.sizes.end(), 1,
std::multiplies<int64_t>());
}
// Returns the total dimension size after flattening all out dimensions.
int FlattenedOutDimensionSize() const {
return std::accumulate(out_dimensions_.sizes.begin(),
out_dimensions_.sizes.end(), 1,
std::multiplies<int64_t>());
}
private:
DimensionSetVector batch_dimensions_;
DimensionSetVector contracting_dimensions_;
// Out dimensions are any dimensions that are neither batch nor contracting
// dimensions, hence will be propagated to output shape.
DimensionSetVector out_dimensions_;
};
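Worked example (editor's note, derived from the convert_dot_general test added earlier in this commit): for the lhs operand tensor<3x2x6x5x1xf32> with batch dimension {0} and contracting dimensions {1, 2}, the definitions above give:

// batch_dimensions():        axes {0},    sizes {3}
// contracting_dimensions():  axes {1, 2}, sizes {2, 6}
// out_dimensions():          axes {3, 4}, sizes {5, 1}
// FlattenedContractingDimensionSize() == 2 * 6 == 12
// FlattenedOutDimensionSize()         == 5 * 1 == 5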
// Converts xla_hlo.dot_general to tf.BatchMatMulV2. Reshape and Transpose ops
// are also inserted to convert to a well-formed matrix multiply.
Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
auto dot_general_op = cast<xla_hlo::DotGeneralOp>(old_op);
auto lhs_type = dot_general_op.lhs().getType().cast<ShapedType>();
auto rhs_type = dot_general_op.rhs().getType().cast<ShapedType>();
auto result_type = dot_general_op.getResult().getType().cast<ShapedType>();
DotDimensionNumbers dot_dimension_numbers =
dot_general_op.dot_dimension_numbers();
mlir::Location loc = dot_general_op.getLoc();
const int lhs_rank = lhs_type.getRank();
const int rhs_rank = rhs_type.getRank();
// Collects lhs and rhs dimensions information.
DotDimensionsInfo lhs_dot_dimensions_info(
lhs_type, dot_dimension_numbers.lhs_batching_dimensions(),
dot_dimension_numbers.lhs_contracting_dimensions());
DotDimensionsInfo rhs_dot_dimensions_info(
rhs_type, dot_dimension_numbers.rhs_batching_dimensions(),
dot_dimension_numbers.rhs_contracting_dimensions());
// Transposes lhs shape to be in the order of {batch_dimensions,
// out_dimensions, contracting dimensions}.
llvm::SmallVector<int64_t, 4> lhs_permutation = Concat<int64_t>(
lhs_dot_dimensions_info.batch_dimensions().AxesArray(),
lhs_dot_dimensions_info.out_dimensions().AxesArray(),
lhs_dot_dimensions_info.contracting_dimensions().AxesArray());
llvm::SmallVector<int64_t, 4> lhs_transposed_shape = Concat<int64_t>(
lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
lhs_dot_dimensions_info.out_dimensions().SizesArray(),
lhs_dot_dimensions_info.contracting_dimensions().SizesArray());
auto lhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
loc,
RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()),
dot_general_op.lhs(),
DenseIntElementsAttr::get(
RankedTensorType::get({lhs_rank}, rewriter.getI64Type()),
lhs_permutation));
// Transposes rhs shape to be in the order of {batch_dimensions, contracting
// dimensions, out_dimensions}.
llvm::SmallVector<int64_t, 4> rhs_permutation = Concat<int64_t>(
rhs_dot_dimensions_info.batch_dimensions().AxesArray(),
rhs_dot_dimensions_info.contracting_dimensions().AxesArray(),
rhs_dot_dimensions_info.out_dimensions().AxesArray());
llvm::SmallVector<int64_t, 4> rhs_transposed_shape = Concat<int64_t>(
rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
rhs_dot_dimensions_info.contracting_dimensions().SizesArray(),
rhs_dot_dimensions_info.out_dimensions().SizesArray());
auto rhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
loc,
RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()),
dot_general_op.rhs(),
DenseIntElementsAttr::get(
RankedTensorType::get({rhs_rank}, rewriter.getI64Type()),
rhs_permutation));
// Reshapes lhs to flatten out_dimensions and contracting_dimensions.
llvm::SmallVector<int64_t, 4> lhs_flattened_shape = Concat<int64_t>(
lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
llvm::ArrayRef<int64_t>{
lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
llvm::ArrayRef<int64_t>{
lhs_dot_dimensions_info.FlattenedContractingDimensionSize()});
auto lhs_flattend = rewriter.create<xla_hlo::ReshapeOp>(
loc,
RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()),
lhs_transposed.getResult());
// Reshapes rhs to flatten out_dimensions and contracting_dimensions.
llvm::SmallVector<int64_t, 4> rhs_flattened_shape = Concat<int64_t>(
rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
llvm::ArrayRef<int64_t>{
rhs_dot_dimensions_info.FlattenedContractingDimensionSize()},
llvm::ArrayRef<int64_t>{
rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
auto rhs_flattend = rewriter.create<xla_hlo::ReshapeOp>(
loc,
RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()),
rhs_transposed.getResult());
// Creates matmul op of `lhs_flattend` and `rhs_flattend`.
llvm::SmallVector<int64_t, 4> matmul_shape =
Concat<int64_t>(lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
llvm::ArrayRef<int64_t>{
lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
llvm::ArrayRef<int64_t>{
rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
auto matmul = rewriter.create<TF::BatchMatMulV2Op>(
loc, RankedTensorType::get(matmul_shape, result_type.getElementType()),
lhs_flattend.getResult(), rhs_flattend.getResult());
auto reshaped =
rewriter.create<xla_hlo::ReshapeOp>(loc, result_type, matmul.getResult());
return reshaped.getResult();
}
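End-to-end shape flow for that same example (editor's note; it matches the FileCheck expectations added earlier in this commit):

// lhs tensor<3x2x6x5x1xf32>: Transpose {0,3,4,1,2} -> 3x5x1x2x6; Reshape -> 3x5x12
// rhs tensor<3x2x4x6xf32>:   Transpose {0,1,3,2}   -> 3x2x6x4;   Reshape -> 3x12x4
// BatchMatMulV2(3x5x12, 3x12x4) -> 3x5x4; Reshape to the result type -> 3x5x1x4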
class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
public:
LegalizeHloToTf() = default;


@ -184,3 +184,10 @@ def ConvertDotOp : NativeCodeCall<"ConvertDotOp($_builder, "
def : Pat<(HLO_DotOp:$old_value AnyStaticShapeTensor:$lhs,
AnyStaticShapeTensor:$rhs, $precision_config),
(ConvertDotOp $old_value)>;
def ConvertDotGeneralOp : NativeCodeCall<"ConvertDotGeneralOp($_builder, "
"$0.getDefiningOp())">;
def : Pat<(HLO_DotGeneralOp:$old_value AnyStaticShapeTensor:$lhs,
AnyStaticShapeTensor:$rhs, $dot_dimension_numbers,
$precision_config),
(ConvertDotGeneralOp $old_value)>;


@ -166,7 +166,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass() {
}
static PassRegistration<OpFusionPass> pass(
"op-fusion",
"tf-op-fusion",
"Replaces commonly occurring subgraphs with optimized fused kernels");
} // namespace TF


@ -58,6 +58,8 @@ void CreateTFStandardPipeline(OpPassManager &pm,
func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass());
func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
func_pm.addPass(CreateMaterializePassthroughOpPass());
if (options.form_clusters)
func_pm.addPass(TFDevice::CreateClusterFormationPass());
// Hopefully there is a single island left, or there wasn't any to begin with.
// We now run the optimizer which operates mostly inside islands.


@ -77,6 +77,9 @@ struct StandardPipelineOptions
Option<bool> enable_inliner{*this, "enable-inliner",
llvm::cl::desc("Enable inliner."),
llvm::cl::init(false)};
Option<bool> form_clusters{*this, "form-clusters",
llvm::cl::desc("Enable Cluster Formation pass."),
llvm::cl::init(false)};
};
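For illustration (editor's sketch, not part of this commit): enabling the new option when building the pipeline programmatically; `pm` is assumed to be an existing OpPassManager.

mlir::TF::StandardPipelineOptions options;
options.enable_inliner = false;
options.form_clusters = true;  // adds TFDevice::CreateClusterFormationPass
mlir::TF::CreateTFStandardPipeline(pm, options);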
// Propagates the pass manager with the passes involved in transforming or
@ -95,11 +98,9 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass();
// of their aliasing output arguments.
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass();
// Creates a pass that promotes tf.VarHandleOp to to resource arguments of where
// resource names are `tf_saved_model.bound_input` symbol argument attributes
// for all functions.
std::unique_ptr<OperationPass<ModuleOp>>
CreatePromoteVarHandlesToSavedModelArgsPass();
// Creates a pass that promotes tf.VarHandleOp to resource arguments for all
// functions.
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass();
// Creates a pass that converts readonly reference variables to the
// corresponding resource variables.
@ -151,13 +152,6 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass();
} // namespace TF
namespace TFControlFlow {
// Raises from the "TensorFlow Control Flow" dialect to the standard TensorFlow
// dialect.
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseTFControlFlowPass();
} // namespace TFControlFlow
namespace tf_executor {
class GraphOp;


@ -389,18 +389,15 @@ void PromoteResourcesToArgsPass::runOnOperation() {
return signalPassFailure();
}
// This pass is for promoting Varhandle ops to tf_saved_model.bound_input
// attributes, which are required for TensorFlowSavedModelDialect.
class PromoteVarHandlesToSavedModelArgsPass
: public PassWrapper<PromoteVarHandlesToSavedModelArgsPass,
OperationPass<ModuleOp>> {
class PromoteVarHandlesToArgsPass
: public PassWrapper<PromoteVarHandlesToArgsPass, OperationPass<ModuleOp>> {
public:
void runOnOperation() override;
};
void PromoteVarHandlesToSavedModelArgsPass::runOnOperation() {
void PromoteVarHandlesToArgsPass::runOnOperation() {
ModuleOp module = getOperation();
MLIRContext* context = module.getContext();
for (auto function : module.getOps<FuncOp>()) {
if (failed(CheckSingleBlockFunction(function))) return signalPassFailure();
@ -409,15 +406,13 @@ void PromoteVarHandlesToSavedModelArgsPass::runOnOperation() {
&var_handle_shared_names);
// Add resource names for each `tf.VarHandleOp` that were promoted to
// saved model arguments.
// resource arguments.
const int var_handle_args_offset =
function.getNumArguments() - var_handle_shared_names.size();
for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names)) {
auto symbol_ref =
SymbolRefAttr::get(var_name_and_index.value(), &getContext());
for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names))
function.setArgAttr(var_name_and_index.index() + var_handle_args_offset,
"tf_saved_model.bound_input", symbol_ref);
}
kResourceNameArgAttr,
StringAttr::get(var_name_and_index.value(), context));
}
}
@ -427,19 +422,17 @@ std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass() {
return std::make_unique<PromoteResourcesToArgsPass>();
}
std::unique_ptr<OperationPass<ModuleOp>>
CreatePromoteVarHandlesToSavedModelArgsPass() {
return std::make_unique<PromoteVarHandlesToSavedModelArgsPass>();
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass() {
return std::make_unique<PromoteVarHandlesToArgsPass>();
}
static PassRegistration<PromoteResourcesToArgsPass> pass(
"tf-promote-resources-to-args",
"Promote resources reads/writes to function inputs/outputs.");
static PassRegistration<PromoteVarHandlesToSavedModelArgsPass> saved_model_pass(
"tf-saved-model-promote-var-handles-to-args",
"Promote tf.VarHandleOps to function arguments in a format of "
"TensorFlowSavedModelDialect.");
static PassRegistration<PromoteVarHandlesToArgsPass> var_handle_pass(
"tf-promote-var-handles-to-args",
"Promote tf.VarHandleOps to function arguments.");
} // namespace TF
} // namespace mlir

View File

@ -1,159 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for raising from the "TensorFlow control flow"
// dialect of MLIR to the standard TensorFlow dialect. The TensorFlow control
// flow dialect represents control flow with Switch/Merge and a few related
// control flow nodes, along with control dependencies.
//
// This pass rebuilds the code in terms of MLIR branches and blocks,
// eliminating control dependencies, and results in the code being in the
// canonical TensorFlow dialect.
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace TFControlFlow {
namespace {
struct RaiseTFControlFlow
: public PassWrapper<RaiseTFControlFlow, FunctionPass> {
void runOnFunction() {
// First start by recognizing loops and reconstructing a loop tree.
buildLoopNests();
// Next, transform Switch/Merge and other control flow ops into proper
// conditional control flow.
buildConditionals();
// Now that we have proper conditional control flow ops, the control edges
// can be dropped, and the underscores removed from operation names.
rewriteOps();
}
void buildLoopNests();
void buildConditionals();
void rewriteOps();
};
//===----------------------------------------------------------------------===//
// Loop nest reconstruction
//===----------------------------------------------------------------------===//
void RaiseTFControlFlow::buildLoopNests() {
// TODO(clattner)
}
//===----------------------------------------------------------------------===//
// Conditional Reconstruction
//===----------------------------------------------------------------------===//
void RaiseTFControlFlow::buildConditionals() {
// TODO.
}
//===----------------------------------------------------------------------===//
// Final rewrite from TF Control Flow form to canonical TensorFlow form
//===----------------------------------------------------------------------===//
static bool isUnderscoredTFOp(Operation &op) {
return op.getName().getStringRef().startswith("_tf.");
}
// Drop control edges, and remove underscores from operation names.
void RaiseTFControlFlow::rewriteOps() {
auto function = getFunction();
OpBuilder builder(function.getBody());
// On the first pass, create replacement operations for every one we are going
// to replace, updating anything that uses the normal results with the newly
// created operation.
for (auto &bb : function) {
for (auto &op : bb) {
// Ignore any operations that we aren't looking for.
if (!isUnderscoredTFOp(op)) continue;
// We always insert the replacement operation next to the operation it
// is replacing.
builder.setInsertionPoint(&op);
// Drop the leading _ off the name.
OperationState result(op.getLoc(),
op.getName().getStringRef().drop_front());
// Add an operand for each non-control input we find. Control values
// aren't necessary any more since the order within a block encodes the
// same information.
for (auto &operand : op.getOpOperands()) {
if (!operand.get().getType().isa<TFControlType>())
result.operands.push_back(operand.get());
// Drop all operands from the old operation, eliminating any
// inter-dependencies after this pass.
operand.drop();
}
// Add a result type for each non-control result we find.
bool sawControlResult = false;
for (auto opResult : op.getResults()) {
if (opResult.getType().isa<TFControlType>()) {
sawControlResult = true;
} else {
// We assume all control results are at the end of the result list.
assert(!sawControlResult && "all control results must be last");
(void)sawControlResult;
result.types.push_back(opResult.getType());
}
}
result.attributes.append(op.getAttrs().begin(), op.getAttrs().end());
// Create the replacement operation.
auto *replacement = builder.createOperation(result);
// We know that all the control results are last, so we can just rewrite
// the first results.
for (unsigned i = 0, e = result.types.size(); i != e; ++i)
op.getResult(i).replaceAllUsesWith(replacement->getResult(i));
}
}
// In the second pass, we can safely remove all of the old operations, because
// we know that all inter-dependencies are dropped.
for (auto &bb : function) {
// Advance the iterator so we don't invalidate it when we remove an
// operation later in the loop.
for (auto &op : llvm::make_early_inc_range(bb))
if (isUnderscoredTFOp(op)) op.erase();
}
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseTFControlFlowPass() {
return std::make_unique<RaiseTFControlFlow>();
}
static PassRegistration<RaiseTFControlFlow> pass(
"tf-raise-control-flow",
"Raise from the TensorFlow Control Flow "
"dialect to the standard TensorFlow dialect");
} // namespace TFControlFlow
} // namespace mlir

View File

@ -171,7 +171,7 @@ CreateConvertReadonlyReferenceVariablesToResourceVariablesPass() {
static PassRegistration<
ConvertReadonlyReferenceVariablesToResourceVariablesPass>
pass("readonly-references-to-resources",
pass("tf-readonly-references-to-resources",
"Convert readonly reference variables to resource variables.");
} // namespace TF

View File

@ -1,242 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass transforms from TF executor dialect to MLIR TF
// control dialect.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#define DEBUG_TYPE "tf-executor-to-ctl"
namespace mlir {
namespace {
struct ExecutorToControlDialectConversion
: public PassWrapper<ExecutorToControlDialectConversion, FunctionPass> {
void runOnFunction() override;
};
} // end anonymous namespace
static bool HasSingleGraph(FuncOp function) {
// We expect the function has only one region with one block,
if (function.getBlocks().size() != 1) return false;
auto &block = function.front();
// and the block contains two ops,
if (std::next(block.begin()) == block.end()) return false;
// one GraphOp,
if (!isa<tf_executor::GraphOp>(block.begin())) return false;
// followed by a terminator.
if (!std::next(block.begin())->isKnownTerminator()) return false;
return true;
}
void ExecutorToControlDialectConversion::runOnFunction() {
if (!HasSingleGraph(getFunction())) {
LLVM_DEBUG(llvm::dbgs()
<< "Expect a Function with a single block and a single graph op,"
" skip tf_executor dialect conversion\n");
return;
}
Type control_type = TFControlFlow::TFControlType::get(&getContext());
Block &body = getFunction().front();
auto graph = cast<tf_executor::GraphOp>(body.front());
OpBuilder builder = OpBuilder::atBlockEnd(&body);
SmallString<64> new_op_name;
for (auto &op : llvm::make_early_inc_range(llvm::reverse(graph.GetBody()))) {
LLVM_DEBUG(llvm::dbgs() << "Process: " << op.getName() << "\n");
if (auto fetch = dyn_cast<tf_executor::FetchOp>(op)) {
// Replace all the operands of the fetch op with the uses of the graph
// results, remove the fetch op afterwards.
for (auto ops_and_ret_vals :
llvm::zip(graph.getResults(), fetch.getOperands()))
std::get<0>(ops_and_ret_vals)
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
op.erase();
continue;
}
builder.setInsertionPoint(&op);
if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
Value ctl_sequence = nullptr;
if (island.GetBody().without_terminator().empty() &&
island.getNumOperands() > 1) {
// For an empty island with multiple control inputs, we create a no-op
// inside it which will group all the inputs into one control output.
// This helps reduce the number of edges when there are multiple
// islands depending on this one.
builder.setInsertionPointToStart(&island.GetBody());
builder.create<TF::NoOp>(op.getLoc(), ArrayRef<Type>{},
ArrayRef<Value>{}, ArrayRef<NamedAttribute>{});
builder.setInsertionPoint(&op);
}
for (Operation &wrapped_op : island.GetBody()) {
LLVM_DEBUG(llvm::dbgs()
<< " In island: " << wrapped_op.getName() << "\n");
if (isa<tf_executor::YieldOp>(wrapped_op)) {
for (auto ops_and_ret_vals :
llvm::zip(island.getResults(), wrapped_op.getOperands()))
std::get<0>(ops_and_ret_vals)
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
break;
}
// Add a leading _ to the name.
new_op_name = "_";
new_op_name += wrapped_op.getName().getStringRef();
OperationState state(wrapped_op.getLoc(), new_op_name);
// Add an operand for each non-control input we find. Collect control
// values separately to add them to the island operands
state.operands.append(wrapped_op.getOperands().begin(),
wrapped_op.getOperands().end());
// Chain operations through a control dependency, except for the first
// operations in the sequence that carry the control dependencies held
// by the island itself.
if (ctl_sequence) {
state.operands.push_back(ctl_sequence);
} else {
for (Value ctl_operand : island.getOperands())
state.operands.push_back(ctl_operand);
}
// Add a result type for each result
state.types.append(wrapped_op.getResultTypes().begin(),
wrapped_op.getResultTypes().end());
state.types.push_back(control_type);
// Create the replacement operation.
auto *replacement = builder.createOperation(state);
replacement->setAttrs(wrapped_op.getMutableAttrDict());
for (auto ops_and_ret_vals :
llvm::zip(wrapped_op.getResults(), replacement->getResults()))
std::get<0>(ops_and_ret_vals)
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
ctl_sequence = replacement->getResult(replacement->getNumResults() - 1);
}
if (ctl_sequence) {
// If ctl_sequence is non-null, this means at least one operation has
// been rewritten from ops in the island. The last rewritten op must logically
// carry all the island control inputs, so we can simply use it to
// replace all uses of the island's control output.
island.control().replaceAllUsesWith(ctl_sequence);
} else if (island.getNumOperands() > 0) {
// Getting here means island had an effectively empty body and there is
// just one control input. In this case, island's control output should
// be replaced with the control input.
assert(island.getNumOperands() == 1);
island.control().replaceAllUsesWith(island.getOperand(0));
}
op.erase();
continue;
}
new_op_name.clear();
if (isa<tf_executor::SwitchOp>(op)) {
new_op_name = "_tf.Switch";
} else if (isa<tf_executor::SwitchNOp>(op)) {
new_op_name = "_tf._SwitchN";
} else if (isa<tf_executor::MergeOp>(op)) {
new_op_name = "_tf.Merge";
} else if (isa<tf_executor::NextIterationSourceOp>(op)) {
new_op_name = "_tf.NextIteration.source";
} else if (isa<tf_executor::NextIterationSinkOp>(op)) {
new_op_name = "_tf.NextIteration.sink";
} else if (isa<tf_executor::LoopCondOp>(op)) {
new_op_name = "_tf.LoopCond";
} else if (isa<tf_executor::EnterOp>(op)) {
new_op_name = "_tf.Enter";
} else if (isa<tf_executor::ExitOp>(op)) {
new_op_name = "_tf.Exit";
} else if (isa<tf_executor::ControlTriggerOp>(op)) {
new_op_name = "_tf.ControlTrigger";
} else {
op.emitOpError() << "unhandled op in tf_executor to _tf conversion";
return signalPassFailure();
}
OperationState state(op.getLoc(), new_op_name);
// Drop all TokenType operands since they don't exist in the control
// dialect.
auto non_null_operands = llvm::make_filter_range(
op.getOperands(),
[](Value v) { return !v.getType().isa<tf_executor::TokenType>(); });
state.operands.append(non_null_operands.begin(), non_null_operands.end());
for (Type result_type : op.getResultTypes()) {
// Filter out TokenType, they don't exist in the control dialect.
if (result_type.isa<tf_executor::TokenType>()) continue;
if (!result_type.isa<tf_executor::ControlType>())
state.types.push_back(result_type);
else
state.types.push_back(control_type);
}
// The control dialect has a control result for the sink operation.
if (isa<tf_executor::NextIterationSinkOp>(op))
state.types.push_back(control_type);
// Create the replacement operation.
auto *replacement = builder.createOperation(state);
replacement->setAttrs(op.getMutableAttrDict());
if (auto next_iteration =
dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
next_iteration.output().replaceAllUsesWith(replacement->getResult(0));
next_iteration.token().dropAllUses();
next_iteration.control().replaceAllUsesWith(replacement->getResult(1));
} else {
for (auto ops_and_ret_vals :
llvm::zip(op.getResults(), replacement->getResults()))
std::get<0>(ops_and_ret_vals)
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
}
op.erase();
}
// Now we have rewritten all ops inside GraphOp to TF Control dialect. We need
// to move all operations outside of GraphOp and remove it.
body.getOperations().splice(body.begin(), graph.GetBody().getOperations());
graph.erase();
}
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion() {
return std::make_unique<ExecutorToControlDialectConversion>();
}
} // namespace mlir
static mlir::PassRegistration<mlir::ExecutorToControlDialectConversion> pass(
"tf-executor-to-control-conversion",
"Convert from TF executor dialect to TF control dialect");

View File

@ -89,12 +89,11 @@ StatusOr<ElementsAttr> ConvertFlatTensor(const Tensor& input_tensor,
ElementsAttr ConvertBf16Tensor(const Tensor& input_tensor,
RankedTensorType type) {
auto flat = input_tensor.flat<bfloat16>();
llvm::SmallVector<llvm::APFloat, 4> floats;
floats.reserve(flat.size());
for (bfloat16 v : llvm::makeArrayRef(flat.data(), flat.size()))
floats.push_back(llvm::APFloat(static_cast<double>(v)));
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(floats));
auto buffer = llvm::makeArrayRef(static_cast<char*>(input_tensor.data()),
input_tensor.TotalBytes());
return mlir::DenseElementsAttr::getFromRawBuffer(
type, buffer,
/*isSplatBuffer=*/type.getNumElements() == 1);
}
ElementsAttr ConvertHalfTensor(const Tensor& tensor, RankedTensorType type) {
@ -280,16 +279,11 @@ void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr,
protobuf::RepeatedField<int>* output) {
// Bfloat16 is internally represented as `double` in MLIR.
if (attr.isSplat()) {
double v = attr.getSplatValue<double>();
bfloat16 bf16_val = static_cast<bfloat16>(v);
output->Add(absl::bit_cast<int16>(bf16_val));
output->Add((*attr.begin()).bitcastToAPInt().getSExtValue());
} else {
for (auto v : attr.getValues<double>()) {
bfloat16 bf16_val = static_cast<bfloat16>(v);
output->Add(absl::bit_cast<int16>(bf16_val));
}
for (const llvm::APFloat value : attr.getFloatValues())
output->Add(value.bitcastToAPInt().getSExtValue());
}
}

View File

@ -23,12 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
namespace mlir {
/// Create a pass to convert from the TFExecutor to the TF control dialect.
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion();
} // namespace mlir
namespace tensorflow {
void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {

View File

@ -28,6 +28,8 @@ package_group(
exports_files(["ir/hlo_ops.td"])
exports_files(["ir/lhlo_ops.td"])
filegroup(
name = "hlo_ops_td_files",
srcs = [
@ -35,6 +37,7 @@ filegroup(
"ir/hlo_ops.td",
"ir/hlo_ops_base.td",
"ir/hlo_utils.td",
"ir/infer_fusibility_op_interface.td",
"ir/lhlo_ops.td",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
@ -87,6 +90,8 @@ gentbl(
tbl_outs = [
("-gen-op-decls", "ir/lhlo_ops.h.inc"),
("-gen-op-defs", "ir/lhlo_ops.cc.inc"),
("-gen-struct-attr-decls", "ir/lhlo_structs.h.inc"),
("-gen-struct-attr-defs", "ir/lhlo_structs.cc.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/lhlo_ops.td",
@ -118,6 +123,42 @@ gentbl(
td_srcs = [":hlo_ops_td_files"],
)
gentbl(
name = "infer_fusibility_op_interface_gen",
tbl_outs = [
(
"-gen-op-interface-decls",
"ir/infer_fusibility_op_interface.h.inc",
),
(
"-gen-op-interface-defs",
"ir/infer_fusibility_op_interface.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/infer_fusibility_op_interface.td",
td_srcs = [
":hlo_ops_td_files",
],
)
cc_library(
name = "infer_fusibility_op_interface",
srcs = [
"ir/infer_fusibility_op_interface.cc",
],
hdrs = [
"ir/infer_fusibility_op_interface.h",
"ir/infer_fusibility_op_interface.h.inc",
],
deps = [
":infer_fusibility_op_interface_gen",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
cc_library(
name = "xla_legalize_tf",
srcs = [
@ -362,6 +403,7 @@ cc_library(
":map_hlo_to_lhlo_op",
"@com_google_absl//absl/memory",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
@ -369,6 +411,43 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "cycle_detector",
srcs = ["transforms/cycle_detector.cc"],
hdrs = ["transforms/cycle_detector.h"],
deps = [
"@llvm-project//llvm:support",
],
alwayslink = 1,
)
tf_cc_test(
name = "cycle_detector_test",
srcs = ["transforms/cycle_detector_test.cc"],
deps = [
":cycle_detector",
"//tensorflow/compiler/xla:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "xla_hlo_fusion",
srcs = ["transforms/xla_hlo_fusion.cc"],
deps = [
":cycle_detector",
":hlo",
"@llvm-project//llvm:ir",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
gentbl(
name = "xla_legalize_to_standard_inc_gen",
tbl_outs = [
@ -555,6 +634,7 @@ cc_library(
":convert_op_folder",
":hlo_ops_base_inc_gen",
":hlo_ops_inc_gen",
":infer_fusibility_op_interface",
":xla_canonicalize_inc_gen",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:support",
@ -824,6 +904,7 @@ genrule(
":ir/hlo_ops.td",
":ir/hlo_ops_base.td",
":ir/hlo_utils.td",
":ir/infer_fusibility_op_interface.td",
],
outs = ["operator_writers.inc"],
cmd = ("$(location :operator_writer_gen) " +
@ -859,6 +940,7 @@ cc_library(
":lhlo_legalize_to_gpu",
":lhlo_legalize_to_parallel_loops",
":xla_dialect_registration",
":xla_hlo_fusion",
":xla_hlo_to_lhlo_with_xla",
":xla_legalize_control_flow",
":xla_legalize_tf",

View File

@ -44,8 +44,7 @@ template <typename CppType>
}
mlir::APFloat ConvertToAPFloat(bfloat16 val) {
// bfloat16 values are stored as double in MLIR.
return llvm::APFloat(static_cast<double>(val));
return llvm::APFloat(llvm::APFloat::BFloat(), llvm::APInt(16, val.value));
}
mlir::APFloat ConvertToAPFloat(half val) {

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/ir/infer_fusibility_op_interface.h"
namespace mlir {
class OpBuilder;

View File

@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td"
include "tensorflow/compiler/mlir/xla/ir/infer_fusibility_op_interface.td"
def HLO_Dialect : Dialect {
let name = "xla_hlo";
@ -117,7 +118,7 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
Type TensorType>: HLO_Op<mnemonic,
!listconcat(traits, [InferShapedTypeOpInterface])> {
!listconcat(traits, [InferShapedTypeOpInterface, InferFusibilityOpInterface])> {
let arguments = (ins TensorType:$operand);
let results = (outs TensorType);
let extraClassDeclaration = [{
@ -132,6 +133,12 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
return deriveShapeFromFirstOperand(&builder, getOperation(),
&reifiedReturnShapes);
}
bool inferInputOutputShapeEquality(int input, int output) {
return true;
}
llvm::Optional<Value> inferEffectiveWorkloadShape() {
return getOperation()->getResult(0);
}
}];
}
@ -257,7 +264,7 @@ def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
HLO_Op<mnemonic, !listconcat(traits, [InferShapedTypeOpInterface])> {
HLO_Op<mnemonic, !listconcat(traits, [InferShapedTypeOpInterface, InferFusibilityOpInterface])> {
let arguments = (ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs
@ -275,6 +282,15 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
return deriveShapeFromFirstOperand(&builder, getOperation(),
&reifiedReturnShapes);
}
bool inferInputsShapeEquality(int lhs, int rhs) {
return true;
}
bool inferInputOutputShapeEquality(int input, int output) {
return true;
}
llvm::Optional<Value> inferEffectiveWorkloadShape() {
return getOperation()->getResult(0);
}
}];
let results = (outs HLO_Tensor);
@ -598,7 +614,8 @@ def HLO_AllToAllOp : HLO_Op<"all_to_all",
def HLO_ReduceOp: HLO_Op<"reduce", [
RecursiveSideEffects,
SameVariadicOperandSize,
SingleBlockImplicitTerminator<"ReturnOp">
SingleBlockImplicitTerminator<"ReturnOp">,
InferFusibilityOpInterface
]>, BASE_HLO_ReduceOp {
let arguments = (ins
Variadic<HLO_TensorOrTuple>:$operands,
@ -613,6 +630,15 @@ def HLO_ReduceOp: HLO_Op<"reduce", [
"ValueRange init_values, DenseIntElementsAttr dimensions"
>];
let extraClassDeclaration = [{
bool isFusibleWithConsumer() {
return false;
}
llvm::Optional<Value> inferEffectiveWorkloadShape() {
return getOperation()->getOperand(0);
}
}];
let hasFolder = 1;
// TODO(hinsu): Verify that the attached body arguments and results are
@ -1360,4 +1386,27 @@ def HLO_DequantizeOp : HLO_Op<"dequantize", [NoSideEffect]>,
let hasCustomHLOConverter = 1;
}
def HLO_FusionOp : HLO_Op<"fusion", []> {
let summary = "Fusion operator";
let description = [{
Models the fusion instruction.
A fusion op consists of a group of basic ops (represented as a region
attached to it). It serves as a hint to the backend that it is beneficial
to emit the contained ops into a single loop nest or kernel.
}];
let regions = (region SizedRegion<1>:$fused_computation);
let arguments = (ins
Variadic<HLO_TensorOrTuple>:$operands
);
let results = (outs
Variadic<HLO_TensorOrTuple>:$results
);
// FusionOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
}
#endif // HLO_OPS

View File

@ -0,0 +1,22 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/ir/infer_fusibility_op_interface.h"
namespace mlir {
#include "tensorflow/compiler/mlir/xla/ir/infer_fusibility_op_interface.cc.inc"
} // namespace mlir

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {
#include "tensorflow/compiler/mlir/xla/ir/infer_fusibility_op_interface.h.inc"
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_

View File

@ -0,0 +1,161 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file contains InferFusibilityOpInterface, which is used to guide
// fusion decisions.
#ifndef MLIR_INFER_FUSIBILITY_OP_INTERFACE
#define MLIR_INFER_FUSIBILITY_OP_INTERFACE
include "mlir/IR/OpBase.td"
// OpInterface to query if an op is fusible and to query the shape equality
// constraint among the inputs and outputs of an op.
def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
let description = [{
Interface to query if an op is fusible and to query the shape equality
constraint among the inputs and outputs of an op.
}];
let methods = [
InterfaceMethod<
/*desc=*/[{If true, this op can be fused with its operands
}],
/*retTy=*/"bool",
/*methodName=*/"isFusibleWithOperand",
/*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
/// Returns whether this op can be fused with its operands
return true;
}]
>,
InterfaceMethod<
/*desc=*/[{If true, this op can be fused with its consumers
}],
/*retTy=*/"bool",
/*methodName=*/"isFusibleWithConsumer",
/*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
/// Return whether this op can be fused with its consumers
return true;
}]
>,
InterfaceMethod<
/*desc=*/"Return whether two inputs have the same shape (assuming no"
"implicit broadcasting).",
/*retTy=*/"bool",
/*methodName=*/"inferInputsShapeEquality",
/*args=*/(ins "int":$lhs, "int":$rhs),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
/// Return whether two inputs have the same shape.
Operation *op = this->getOperation();
assert(lhs < op->getNumOperands() && lhs >= 0 &&
rhs < op->getNumOperands() && rhs >= 0);
if (lhs == rhs) return true;
// if both lhs and rhs have static shapes, check them directly
Type lhs_ty = op->getOperand(lhs).getType();
Type rhs_ty = op->getOperand(rhs).getType();
auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
!rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
return false;
}
return lhs_shape_type.getShape() == rhs_shape_type.getShape();
}]
>,
InterfaceMethod<
/*desc=*/"Return whether two outputs have the same shape (assuming no"
" implicit broadcasting).",
/*retTy=*/"bool",
/*methodName=*/"inferOutputsShapeEquality",
/*args=*/(ins "int":$lhs, "int":$rhs),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
/// Return whether two outputs have the same shape.
Operation *op = this->getOperation();
assert(lhs < op->getNumResults() && lhs >= 0 &&
rhs < op->getNumResults() && rhs >= 0);
if (lhs == rhs) return true;
// if both lhs and rhs have static shapes, check them directly
Type lhs_ty = op->getResult(lhs).getType();
Type rhs_ty = op->getResult(rhs).getType();
auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
!rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
return false;
}
return lhs_shape_type.getShape() == rhs_shape_type.getShape();
}]
>,
InterfaceMethod<
/*desc=*/"Return whether the input and the output have the same"
" shape (assuming no implicit broadcasting).",
/*retTy=*/"bool",
/*methodName=*/"inferInputOutputShapeEquality",
/*args=*/(ins "int":$input, "int":$output),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
/// Return whether the input and the output have the same shape.
Operation *op = this->getOperation();
assert(input < op->getNumOperands() && input >= 0 &&
output < op->getNumResults() && output >= 0);
// if both input and output have static shapes, check them directly
Type input_ty = op->getOperand(input).getType();
Type output_ty = op->getResult(output).getType();
auto input_shape_type = input_ty.dyn_cast_or_null<RankedTensorType>();
auto output_shape_type = output_ty.dyn_cast_or_null<RankedTensorType>();
if (!input_shape_type || !input_shape_type.hasStaticShape() ||
!output_shape_type || !output_shape_type.hasStaticShape() ||
input_shape_type.getRank() != output_shape_type.getRank()) {
return false;
}
return input_shape_type.getShape() == output_shape_type.getShape();
}]
>,
InterfaceMethod<
/*desc=*/[{Return the effective workload shape for the operation.
Here the effective workload shape roughly represents the maximum
parallelism that can be exploited during the codegen stage. It's used to check
the shape-compatibility of the operation. During fusion, we only
try to fuse shape-compatible ops for performance.
For example, the effective workload shape of an elementwise op is its
output shape, while the effective workload shape of a reduction op may
be its operand shape.
Return None if such an inference is not possible.
}],
/*retTy=*/"llvm::Optional<Value>",
/*methodName=*/"inferEffectiveWorkloadShape",
/*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
/// Return effective workload size if possible, otherwise None.
return {};
}]
>,
];
}
#endif
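A minimal C++ sketch (not part of this change) of how a fusion heuristic might consult this interface; the interface name and its methods come from the definition above, while `IsFusionCandidate` and the guard logic are illustrative assumptions.

#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/ir/infer_fusibility_op_interface.h"

// Hypothetical check: only fuse ops that implement the interface, allow
// consumer fusion, and are shape-preserving on their first input/output.
static bool IsFusionCandidate(mlir::Operation *producer) {
  auto fusibility = llvm::dyn_cast<mlir::InferFusibilityOpInterface>(producer);
  if (!fusibility) return false;  // op does not implement the interface
  // Ops such as xla_hlo.reduce override this hook to opt out of being fused
  // with their consumers.
  if (!fusibility.isFusibleWithConsumer()) return false;
  if (producer->getNumOperands() == 0 || producer->getNumResults() == 0)
    return false;
  // True only when input 0 and output 0 are statically known to have the
  // same shape, i.e. the op behaves like a shape-preserving elementwise op.
  return fusibility.inferInputOutputShapeEquality(/*input=*/0, /*output=*/0);
}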

View File

@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h.inc"
namespace mlir {
#include "tensorflow/compiler/mlir/xla/ir/lhlo_structs.cc.inc"
namespace xla_lhlo {
XlaLhloDialect::XlaLhloDialect(MLIRContext *context)

View File

@ -33,6 +33,8 @@ limitations under the License.
namespace mlir {
class OpBuilder;
#include "tensorflow/compiler/mlir/xla/ir/lhlo_structs.h.inc"
namespace xla_lhlo {
class XlaLhloDialect : public Dialect {

View File

@ -407,11 +407,39 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
);
}
// TODO(bondhugula): Make this struct dialect independent so that it can be
// shared between the HLO and LHLO dialects.
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
// Default value: one for each of the spatial dimensions.
OptionalAttr<I64ElementsAttr>:$window_strides,
// Default value: zero for each of the spatial dimensions.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimensions.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimensions.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
}

View File

@ -937,6 +937,11 @@ LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
return success();
}
LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
// TODO(whoever): currently not supported.
return failure();
}
} // namespace
} // namespace xla_hlo
} // namespace mlir
@ -979,10 +984,10 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
values.reserve(attr.getNumElements());
for (APFloat val : attr.getValues<APFloat>()) {
bool loses_info = false;
CHECK_EQ(val.convert(llvm::APFloat::IEEEsingle(),
llvm::APFloat::rmTowardZero, &loses_info),
llvm::APFloat::opOK);
CHECK(!loses_info);
TF_RET_CHECK(val.convert(llvm::APFloat::IEEEsingle(),
llvm::APFloat::rmTowardZero,
&loses_info) == llvm::APFloat::opOK);
TF_RET_CHECK(!loses_info);
values.push_back(xla::half(val.convertToFloat()));
}
xla::Array<xla::half> source_data(shape.dimensions());
@ -992,10 +997,15 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
case xla::PrimitiveType::BF16: {
xla::Array<double> source_data(shape.dimensions());
auto attr_values = attr.getValues<APFloat>();
std::vector<double> values_double(source_data.num_elements());
for (auto index_and_value : llvm::enumerate(attr_values)) {
values_double[index_and_value.index()] =
index_and_value.value().convertToDouble();
std::vector<double> values_double;
values_double.reserve(source_data.num_elements());
for (APFloat val : attr_values) {
bool loses_info = false;
TF_RET_CHECK(val.convert(llvm::APFloat::IEEEdouble(),
llvm::APFloat::rmTowardZero,
&loses_info) == llvm::APFloat::opOK);
TF_RET_CHECK(!loses_info);
values_double.push_back(val.convertToDouble());
}
source_data.SetValues(values_double);
return xla::LiteralUtil::ConvertF64ToBF16(

View File

@ -191,7 +191,7 @@ func @const_f32_bf16() -> tensor<bf16> {
// CHECK-LABEL: func @const_bf16_f64
func @const_bf16_f64() -> tensor<f64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<f64>
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.187500e+00> : tensor<f64>
%cst = xla_hlo.constant dense<4.2> : tensor<bf16>
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
// CHECK-NEXT: return [[CST]]
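The new expected constant follows from giving bfloat16 its real semantics rather than carrying it as a double: 4.2 is not representable in bfloat16 and rounds to 4.1875. A standalone C++ check of that rounding, as an illustration only (not part of this change):

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  float f = 4.2f;
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));        // 0x40866666 (IEEE-754 single)
  // bfloat16 keeps the top 16 bits; the discarded low half 0x6666 is below
  // 0x8000, so round-to-nearest-even truncates.
  uint32_t bf16_as_f32 = (bits >> 16) << 16;   // 0x40860000
  std::memcpy(&f, &bf16_as_f32, sizeof(f));
  assert(f == 4.1875f);  // hence dense<4.187500e+00> in the CHECK line above
  return 0;
}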

View File

@ -432,7 +432,39 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]],
// CHECK-SAME: %[[RESULT:.*]]: [[TYPE]])
// CHECK: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %{{.*}}) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
%dot = "xla_hlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
return %dot : tensor<1024x1024xf32>
}
%dot = "xla_hlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
return %dot : tensor<1024x1024xf32>
}
// -----
// CHECK-LABEL: func @conv
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
%c0 = constant 0 : index
// CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
// CHECK: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// CHECK-SAME: padding = dense<[
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// CHECK-SAME: rhs_dilation = dense<[1, 2]>
// CHECK-SAME: window_strides = dense<[2, 1]>
%out = "xla_hlo.convolution"(%filter, %input) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 1 : i64,
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>
} : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
return %out : tensor<3x5x5x4xf32>
}

View File

@ -511,7 +511,7 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[L0:.*]] = constant 0 : i32
// CHECK-NEXT: %[[L0:.*]] = constant 0 : i32
// CHECK-NEXT: %[[RESULT:.*]] = subi %[[L0]], %[[OPERAND_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
@ -691,3 +691,27 @@ func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// -----
func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) {
%c0 = constant 0 : index
%0 = alloc() : memref<3x5x5x4xf32>
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
// CHECK-SAME: dilations = [1, 2]
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
// CHECK-SAME: strides = [2, 1]}
// With all attributes explicitly specified.
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
// Dilation left unspecified, sets default dilation since linalg expects it.
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
// CHECK-SAME: dilations = [1, 1]
// Padding is not set if it's zero.
// CHECK-NOT: padding
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
"xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
"xla_lhlo.terminator"() : () -> ()
}

View File

@ -0,0 +1,97 @@
// RUN: tf-opt %s -xla-hlo-fusion -split-input-file | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: func @multi_outputs_same
func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "xla_hlo.add"(%1, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:2 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.return
return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
}
// -----
// CHECK-LABEL: func @multi_outputs_same_2
func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "xla_hlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = "xla_hlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = "xla_hlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:3 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.return
return %2, %3, %4 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}
// -----
// CHECK-LABEL: func @multi_outputs_not_sure_same
func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: xla_hlo.fusion
%1 = "xla_hlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}
// -----
// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET0:.*]] = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.return
// Currently we do not support fusing arguments and ops without a direct
// producer-consumer relationship. Thus the reduce op should not be fused with
// the above two ops.
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "xla_hlo.reduce"(%arg0, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
%4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// The above two ops should not be fused since the reduce op cannot be
// fused with its consumer.
// CHECK-NOT: xla_hlo.fusion
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}
// -----
// CHECK-LABEL: func @reduce_2
func @reduce_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "xla_hlo.reduce"(%1, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK: %[[RET0:.*]]:2 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.constant
// CHECK-NEXT: xla_hlo.reduce
// CHECK: xla_hlo.return
// The following op should not be fused with the above ops since the reduce op
// cannot be fused with its consumer.
// CHECK-NOT: xla_hlo.fusion
%4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}

View File

@ -0,0 +1,340 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/transforms/cycle_detector.h"
#include <algorithm>
#include "llvm/ADT/DenseSet.h"
namespace mlir {
namespace {
using NodeSet = llvm::DenseSet<int32_t>;
using OrderedNodeSet = OrderedSet<int32_t>;
template <typename T>
struct VecStruct {
typedef llvm::SmallVector<T, 4> type;
};
template <typename T>
using Vec = typename VecStruct<T>::type;
struct Node {
// rank number assigned by Pearce-Kelly algorithm
int32_t rank;
// Temporary marker used by depth-first-search
bool visited;
// User-supplied data
void* data;
// List of immediate predecessor nodes in graph
OrderedNodeSet in;
// List of immediate successor nodes in graph
OrderedNodeSet out;
};
} // namespace
struct GraphCycles::Rep {
Vec<Node*> nodes;
// Indices for unused entries in nodes
Vec<int32_t> free_nodes;
// Temporary state.
// Results of forward DFS
Vec<int32_t> deltaf;
// Results of backward DFS
Vec<int32_t> deltab;
// All nodes to reprocess
Vec<int32_t> list;
// Rank values to assign to list entries
Vec<int32_t> merged;
// Emulates recursion stack when doing depth first search
Vec<int32_t> stack;
};
GraphCycles::GraphCycles(int32_t num_nodes) : rep_(new Rep) {
rep_->nodes.reserve(num_nodes);
for (int32_t i = 0; i < num_nodes; ++i) {
Node* n = new Node;
n->visited = false;
n->data = nullptr;
n->rank = rep_->nodes.size();
rep_->nodes.push_back(n);
}
}
GraphCycles::~GraphCycles() {
for (Vec<Node*>::size_type i = 0, e = rep_->nodes.size(); i < e; ++i) {
delete rep_->nodes[i];
}
delete rep_;
}
bool GraphCycles::HasEdge(int32_t x, int32_t y) const {
return rep_->nodes[x]->out.Contains(y);
}
void GraphCycles::RemoveEdge(int32_t x, int32_t y) {
rep_->nodes[x]->out.Erase(y);
rep_->nodes[y]->in.Erase(x);
// No need to update the rank assignment since a previous valid
// rank assignment remains valid after an edge deletion.
}
static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound);
static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound);
static void Reorder(GraphCycles::Rep* r);
static void Sort(const Vec<Node*>&, Vec<int32_t>* delta);
static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
Vec<int32_t>* dst);
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes);
bool GraphCycles::InsertEdge(int32_t x, int32_t y) {
if (x == y) return false;
Rep* r = rep_;
Node* nx = r->nodes[x];
if (!nx->out.Insert(y)) {
// Edge already exists.
return true;
}
Node* ny = r->nodes[y];
ny->in.Insert(x);
if (nx->rank <= ny->rank) {
// New edge is consistent with existing rank assignment.
return true;
}
// Current rank assignments are incompatible with the new edge. Recompute.
// We only need to consider nodes that fall in the range [ny->rank,nx->rank].
if (ForwardDFS(r, y, nx->rank)) {
// Found a cycle. Undo the insertion and tell caller.
nx->out.Erase(y);
ny->in.Erase(x);
// Since we do not call Reorder() on this path, clear any visited
// markers left by ForwardDFS.
ClearVisitedBits(r, r->deltaf);
return false;
}
BackwardDFS(r, x, ny->rank);
Reorder(r);
return true;
}
// Follows the edges from producer to consumer and searches whether the node
// with rank `n` can reach the node with rank `upper_bound` using a DFS.
// The search only considers paths on which every intermediate node has a rank
// smaller than `upper_bound`.
//
// Returns true if such path exists.
static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound) {
// Avoid recursion since stack space might be limited.
// We instead keep a stack of nodes to visit.
r->deltaf.clear();
r->stack.clear();
r->stack.push_back(n);
while (!r->stack.empty()) {
n = r->stack.back();
r->stack.pop_back();
Node* nn = r->nodes[n];
if (nn->visited) continue;
nn->visited = true;
r->deltaf.push_back(n);
for (auto w : nn->out.GetSequence()) {
Node* nw = r->nodes[w];
if (nw->rank == upper_bound) {
return true;
}
if (!nw->visited && nw->rank < upper_bound) {
r->stack.push_back(w);
}
}
}
return false;
}
// Follows the edges from consumer to producer and visits all the nodes that
// are reachable from node `n` and have a rank larger than `lower_bound`.
static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound) {
r->deltab.clear();
r->stack.clear();
r->stack.push_back(n);
while (!r->stack.empty()) {
n = r->stack.back();
r->stack.pop_back();
Node* nn = r->nodes[n];
if (nn->visited) continue;
nn->visited = true;
r->deltab.push_back(n);
for (auto w : nn->in.GetSequence()) {
Node* nw = r->nodes[w];
if (!nw->visited && lower_bound < nw->rank) {
r->stack.push_back(w);
}
}
}
}
// Recomputes rank assignments to make them compatible with the edges (producer
// has smaller rank than its consumer)
static void Reorder(GraphCycles::Rep* r) {
Sort(r->nodes, &r->deltab);
Sort(r->nodes, &r->deltaf);
// Adds contents of delta lists to list (backwards deltas first).
r->list.clear();
MoveToList(r, &r->deltab, &r->list);
MoveToList(r, &r->deltaf, &r->list);
// Produce sorted list of all ranks that will be reassigned.
r->merged.resize(r->deltab.size() + r->deltaf.size());
std::merge(r->deltab.begin(), r->deltab.end(), r->deltaf.begin(),
r->deltaf.end(), r->merged.begin());
// Assign the ranks in order to the collected list.
for (Vec<int32_t>::size_type i = 0, e = r->list.size(); i < e; ++i) {
r->nodes[r->list[i]]->rank = r->merged[i];
}
}
// Sorts nodes in the vector according to their ranks. Small rank first.
static void Sort(const Vec<Node*>& nodes, Vec<int32_t>* delta) {
struct ByRank {
const Vec<Node*>* nodes;
bool operator()(int32_t a, int32_t b) const {
return (*nodes)[a]->rank < (*nodes)[b]->rank;
}
};
ByRank cmp;
cmp.nodes = &nodes;
std::sort(delta->begin(), delta->end(), cmp);
}
// Moves the nodes in `src` to `dst`, replacing each `src` entry with that
// node's rank and clearing its visited bit.
static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
Vec<int32_t>* dst) {
for (Vec<int32_t>::size_type i = 0, e = src->size(); i < e; i++) {
int32_t w = (*src)[i];
// Replace src entry with its rank
(*src)[i] = r->nodes[w]->rank;
// Prepare for future DFS calls
r->nodes[w]->visited = false;
dst->push_back(w);
}
}
// Clears the bookkeeping fields used during the last DFS.
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes) {
for (Vec<int32_t>::size_type i = 0, e = nodes.size(); i < e; i++) {
r->nodes[nodes[i]]->visited = false;
}
}
bool GraphCycles::IsReachable(int32_t x, int32_t y) {
if (x == y) return true;
Rep* r = rep_;
Node* nx = r->nodes[x];
Node* ny = r->nodes[y];
if (nx->rank >= ny->rank) {
// x cannot reach y since it is after it in the topological ordering
return false;
}
// See if x can reach y using a DFS search that is limited to y's rank
bool reachable = ForwardDFS(r, x, ny->rank);
// Clear any visited markers left by ForwardDFS.
ClearVisitedBits(r, r->deltaf);
return reachable;
}
llvm::Optional<int32_t> GraphCycles::ContractEdge(int32_t a, int32_t b) {
assert(HasEdge(a, b));
RemoveEdge(a, b);
if (IsReachable(a, b)) {
// Restore the graph to its original state.
InsertEdge(a, b);
return {};
}
if (rep_->nodes[b]->in.Size() + rep_->nodes[b]->out.Size() >
rep_->nodes[a]->in.Size() + rep_->nodes[a]->out.Size()) {
// Swap "a" and "b" to minimize copying.
std::swap(a, b);
}
Node* nb = rep_->nodes[b];
OrderedNodeSet out = std::move(nb->out);
OrderedNodeSet in = std::move(nb->in);
for (int32_t y : out.GetSequence()) {
rep_->nodes[y]->in.Erase(b);
}
for (int32_t y : in.GetSequence()) {
rep_->nodes[y]->out.Erase(b);
}
rep_->free_nodes.push_back(b);
rep_->nodes[a]->out.Reserve(rep_->nodes[a]->out.Size() + out.Size());
for (int32_t y : out.GetSequence()) {
InsertEdge(a, y);
}
rep_->nodes[a]->in.Reserve(rep_->nodes[a]->in.Size() + in.Size());
for (int32_t y : in.GetSequence()) {
InsertEdge(y, a);
}
// Note that if the swap happened, the returned node may be the one originally
// called "b".
return a;
}
std::vector<int32_t> GraphCycles::SuccessorsCopy(int32_t node) const {
return rep_->nodes[node]->out.GetSequence();
}
namespace {
void SortInPostOrder(const Vec<Node*>& nodes, std::vector<int32_t>* to_sort) {
std::sort(to_sort->begin(), to_sort->end(), [&](int32_t a, int32_t b) {
return nodes[a]->rank > nodes[b]->rank;
});
}
} // namespace
std::vector<int32_t> GraphCycles::AllNodesInPostOrder() const {
llvm::DenseSet<int32_t> free_nodes_set;
for (int32_t n : rep_->free_nodes) free_nodes_set.insert(n);
std::vector<int32_t> all_nodes;
all_nodes.reserve(rep_->nodes.size() - free_nodes_set.size());
for (size_t i = 0, e = rep_->nodes.size(); i < e; i++) {
if (!free_nodes_set.count(i)) {
all_nodes.push_back(i);
}
}
SortInPostOrder(rep_->nodes, &all_nodes);
return all_nodes;
}
} // namespace mlir
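For orientation, a minimal usage sketch (not part of this change) of the cycle detector as a fusion legality check; the node ids and the producer-to-consumer edge direction are illustrative assumptions.

#include <cassert>
#include "tensorflow/compiler/mlir/xla/transforms/cycle_detector.h"

void Example() {
  mlir::GraphCycles cycles(/*num_nodes=*/3);
  cycles.InsertEdge(0, 1);                   // op0 feeds op1
  cycles.InsertEdge(1, 2);                   // op1 feeds op2
  bool rejected = !cycles.InsertEdge(2, 0);  // would close 0 -> 1 -> 2 -> 0
  assert(rejected);
  (void)rejected;
  // Model fusing op0 and op1: contract the edge between them. ContractEdge
  // returns llvm::None when the nodes stay connected through another path,
  // which is exactly the case where fusion would create a cycle.
  llvm::Optional<int32_t> fused = cycles.ContractEdge(0, 1);
  assert(fused.hasValue());
  (void)fused;
}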

View File

@ -0,0 +1,165 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
#include <vector>
#include "llvm/ADT/DenseMap.h"
namespace mlir {
// -------------------------------------------------------------------
// This file contains a light version of GraphCycles implemented in
// tensorflow/compiler/jit/graphcycles/graphcycles.h
//
// We re-implement it here because we do not want to rely
// on TensorFlow data structures, which makes it easy to move the
// corresponding passes to the LLVM repo if necessary.
// --------------------------------------------------------------------
// This is a set data structure that provides a deterministic iteration order.
// The iteration order of elements only depends on the sequence of
// inserts/deletes, so as long as the inserts/deletes happen in the same
// sequence, the set will have the same iteration order.
//
// Assumes that T can be cheaply copied for simplicity.
template <typename T>
class OrderedSet {
public:
// Inserts `value` into the ordered set. Returns true if the value was not
// present in the set before the insertion.
bool Insert(T value) {
bool new_insertion =
value_to_index_.insert({value, value_sequence_.size()}).second;
if (new_insertion) {
value_sequence_.push_back(value);
}
return new_insertion;
}
// Removes `value` from the set. Assumes `value` is already present in the
// set.
void Erase(T value) {
auto it = value_to_index_.find(value);
// Since we don't want to move values around in `value_sequence_`, we swap
// the value in the last position with the value to be deleted and then
// pop_back.
value_to_index_[value_sequence_.back()] = it->second;
std::swap(value_sequence_[it->second], value_sequence_.back());
value_sequence_.pop_back();
value_to_index_.erase(it);
}
void Reserve(size_t new_size) {
value_to_index_.reserve(new_size);
value_sequence_.reserve(new_size);
}
void Clear() {
value_to_index_.clear();
value_sequence_.clear();
}
bool Contains(T value) const { return value_to_index_.count(value); }
size_t Size() const { return value_sequence_.size(); }
const std::vector<T>& GetSequence() const { return value_sequence_; }
private:
// The stable order that we maintain through insertions and deletions.
std::vector<T> value_sequence_;
// Maps values to their indices in `value_sequence_`.
llvm::DenseMap<T, int> value_to_index_;
};
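// For example (illustrative usage sketch, not part of this header), the
// iteration order depends only on the sequence of inserts/erases, never on
// hashing:
//
//   OrderedSet<int> s;
//   s.Insert(3);
//   s.Insert(1);
//   s.Insert(2);   // GetSequence() == {3, 1, 2}
//   s.Erase(3);    // the last element (2) is swapped into 3's slot, then
//                  // popped; GetSequence() == {2, 1}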
// ---------------------------------------------------------------------
// GraphCycles detects the introduction of a cycle into a directed
// graph that is being built up incrementally.
//
// Nodes are identified by small integers. It is not possible to
// record multiple edges with the same (source, destination) pair;
// requests to add an edge where one already exists are silently
// ignored.
//
// It is also not possible to introduce a cycle; an attempt to insert
// an edge that would introduce a cycle fails and returns false.
//
// GraphCycles uses no internal locking; calls into it should be
// serialized externally.
// Performance considerations:
// Works well on sparse graphs, poorly on dense graphs.
// Extra information is maintained incrementally to detect cycles quickly.
// InsertEdge() is very fast when the edge already exists, and reasonably fast
// otherwise.
// FindPath() is linear in the size of the graph.
// The current implementation uses O(|V|+|E|) space.
class GraphCycles {
public:
explicit GraphCycles(int32_t num_nodes);
~GraphCycles();
// Attempt to insert an edge from x to y. If the
// edge would introduce a cycle, return false without making any
// changes. Otherwise add the edge and return true.
bool InsertEdge(int32_t x, int32_t y);
// Remove any edge that exists from x to y.
void RemoveEdge(int32_t x, int32_t y);
// Return whether there is an edge directly from x to y.
bool HasEdge(int32_t x, int32_t y) const;
// Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. One of
// the nodes is removed from the graph, and edges to/from it are added to
// the remaining one, which is returned. If contracting the edge would create
// a cycle, does nothing and returns no value.
llvm::Optional<int32_t> ContractEdge(int32_t a, int32_t b);
// Return whether dest_node `y` is reachable from source_node `x`
// by following edges. This is a non-thread-safe version.
bool IsReachable(int32_t x, int32_t y);
// Return a copy of the successors set. This is needed for code using the
// collection while modifying the GraphCycles.
std::vector<int32_t> SuccessorsCopy(int32_t node) const;
// Returns all nodes in post order.
//
// If there is a path from X to Y then X appears after Y in the
// returned vector.
std::vector<int32_t> AllNodesInPostOrder() const;
// ----------------------------------------------------
struct Rep;
private:
GraphCycles(const GraphCycles&) = delete;
GraphCycles& operator=(const GraphCycles&) = delete;
Rep* rep_; // opaque representation
};
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
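A minimal usage sketch of the API declared above (illustrative only; it is not
part of the header and assumes the include path used by the test that follows):
#include "tensorflow/compiler/mlir/xla/transforms/cycle_detector.h"
void GraphCyclesExample() {
  mlir::GraphCycles graph(/*num_nodes=*/3);
  graph.InsertEdge(0, 1);                         // 0 -> 1
  graph.InsertEdge(1, 2);                         // 1 -> 2
  bool cycle_rejected = !graph.InsertEdge(2, 0);  // true: 2 -> 0 would close a cycle
  bool reachable = graph.IsReachable(0, 2);       // true, via 0 -> 1 -> 2
  // Contracting 0 -> 1 merges the two nodes; the survivor's id is returned,
  // or no value if the contraction would create a cycle via another path.
  auto merged = graph.ContractEdge(0, 1);
  (void)cycle_rejected;
  (void)reachable;
  (void)merged;
}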

View File

@ -0,0 +1,89 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/transforms/cycle_detector.h"
#include "tensorflow/compiler/xla/test.h"
class GraphCyclesTest : public ::testing::Test {
public:
GraphCyclesTest() : g_(100) {}
bool AddEdge(int x, int y) { return g_.InsertEdge(x, y); }
void AddMultiples() {
// For every node x > 0: add edge to 2*x, 3*x
for (int x = 1; x < 25; x++) {
EXPECT_TRUE(AddEdge(x, 2 * x)) << x;
EXPECT_TRUE(AddEdge(x, 3 * x)) << x;
}
}
mlir::GraphCycles g_;
};
TEST_F(GraphCyclesTest, NoCycle) { AddMultiples(); }
TEST_F(GraphCyclesTest, SimpleCycle) {
AddMultiples();
EXPECT_FALSE(AddEdge(8, 4));
}
TEST_F(GraphCyclesTest, IndirectCycle) {
AddMultiples();
EXPECT_TRUE(AddEdge(16, 9));
EXPECT_FALSE(AddEdge(9, 2));
}
TEST_F(GraphCyclesTest, RemoveEdge) {
EXPECT_TRUE(AddEdge(1, 2));
EXPECT_TRUE(AddEdge(2, 3));
EXPECT_TRUE(AddEdge(3, 4));
EXPECT_TRUE(AddEdge(4, 5));
g_.RemoveEdge(2, 3);
EXPECT_FALSE(g_.HasEdge(2, 3));
}
TEST_F(GraphCyclesTest, IsReachable) {
EXPECT_TRUE(AddEdge(1, 2));
EXPECT_TRUE(AddEdge(2, 3));
EXPECT_TRUE(AddEdge(3, 4));
EXPECT_TRUE(AddEdge(4, 5));
EXPECT_TRUE(g_.IsReachable(1, 5));
EXPECT_FALSE(g_.IsReachable(5, 1));
}
TEST_F(GraphCyclesTest, ContractEdge) {
ASSERT_TRUE(AddEdge(1, 2));
ASSERT_TRUE(AddEdge(1, 3));
ASSERT_TRUE(AddEdge(2, 3));
ASSERT_TRUE(AddEdge(2, 4));
ASSERT_TRUE(AddEdge(3, 4));
// It will introduce a cycle if the edge is contracted
EXPECT_FALSE(g_.ContractEdge(1, 3).hasValue());
EXPECT_TRUE(g_.HasEdge(1, 3));
// Node (2) has more edges.
EXPECT_EQ(*g_.ContractEdge(1, 2), 2);
EXPECT_TRUE(g_.HasEdge(2, 3));
EXPECT_TRUE(g_.HasEdge(2, 4));
EXPECT_TRUE(g_.HasEdge(3, 4));
// Node (2) has more edges.
EXPECT_EQ(*g_.ContractEdge(2, 3), 2);
EXPECT_TRUE(g_.HasEdge(2, 4));
}

View File

@ -423,6 +423,7 @@ void populateHLOToLHLOConversionPattern(
HloToLhloOpConverter<xla_hlo::CompareOp>,
HloToLhloOpConverter<xla_hlo::ComplexOp>,
HloToLhloOpConverter<xla_hlo::ConstOp>,
HloToLhloOpConverter<xla_hlo::ConvOp>,
HloToLhloOpConverter<xla_hlo::ConvertOp>,
HloToLhloOpConverter<xla_hlo::CopyOp>,
HloToLhloOpConverter<xla_hlo::CosOp>,

View File

@ -45,6 +45,7 @@ MAP_HLO_TO_LHLO(CeilOp);
MAP_HLO_TO_LHLO(ConstOp);
MAP_HLO_TO_LHLO(CompareOp);
MAP_HLO_TO_LHLO(ComplexOp);
MAP_HLO_TO_LHLO(ConvOp);
MAP_HLO_TO_LHLO(ConvertOp);
MAP_HLO_TO_LHLO(CopyOp);
MAP_HLO_TO_LHLO(CosOp);

View File

@ -73,6 +73,9 @@ std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();
// necessary to export to XLA.
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
// Fuses xla_hlo ops into kLoop/kInput fusion patterns.
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
} // namespace xla_hlo
namespace xla_lhlo {

View File

@ -0,0 +1,579 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
#include "llvm/ADT/EquivalenceClasses.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/cycle_detector.h"
// This pass has functionality similar to the fusion pass in the XLA stack.
// However, unlike XLA, it targets the fully dynamic shape scenario.
// Currently, it implements the kLoop and kInput fusion templates.
// During conversion, it tries to greedily find kLoop/kInput fusion
// patterns.
//
// Similar to XLA, this pass supports fusion patterns having multiple outputs
// if all the output shapes are consistent. The following are some examples.
//
// kLoop kInput
// +----+ +----+ +----+ +----+ +----+ +----+
// |elem| |elem| |elem| |elem<----+elem+---->elem+----+
// +-+--+ +-+--+ +-+--+ +-+--+ +----+ +-+--+ |
// | | | | | |
// | | | | |
// +-v--+ | +-v--+ +--v---+ +--v---+ |
// |elem+<---+----<+elem| |reduce| |reduce| |
// +-+--+ +-+--+ +--+---+ +--+---+ |
// | | | | |
// | | | | |
// v v v v v
//
// To this end, we also add a simple shape constraint analysis phase.
// For the kLoop fusion template, it requires that all the outputs of the fused
// pattern have the same shape. However, we don't know the actual value
// of the shape at compile time in the dynamic shape world.
// Fortunately, we could still infer the relationship among different ops
// according to their shape constraint traits. Currently, we only consider
// shape equality propagation for elementwise ops (assuming that implicit
// shape broadcast is forbidden). The above process could be built on the
// shape dialect once it is ready.
namespace mlir {
namespace xla_hlo {
namespace {
using llvm::EquivalenceClasses;
using FusionPattern = std::vector<Operation*>;
using FusionPlan = std::vector<FusionPattern>;
// To support using EquivalenceClasses for Value
class ValueWrapper {
public:
explicit ValueWrapper(Value value) : value_(std::move(value)) {}
Value getValue() const { return value_; }
bool operator==(const ValueWrapper& rhs) const {
return getValue() == rhs.getValue();
}
private:
Value value_;
};
bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
auto lhs_value = lhs.getValue().getAsOpaquePointer();
auto rhs_value = rhs.getValue().getAsOpaquePointer();
return lhs_value < rhs_value;
}
bool IsFusible(Operation* op) {
if (matchPattern(op, m_Constant())) {
return true;
}
auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
return op_fusibility && (op_fusibility.isFusibleWithOperand() ||
op_fusibility.isFusibleWithConsumer());
}
SmallVector<Value, 4> GetInputsOfFusionPattern(const FusionPattern& pattern) {
SmallVector<Value, 4> inputs;
DenseSet<Value> input_set;
DenseSet<Operation*> op_set;
for (Operation* op : pattern) {
bool inserted = op_set.insert(op).second;
(void)inserted;
assert(inserted && "FusionPattern contains duplicate operations");
}
for (Operation* op : pattern) {
for (Value operand : op->getOperands()) {
Operation* operand_op = operand.getDefiningOp();
if (op_set.find(operand_op) != op_set.end()) {
// skip if defining op is in the pattern
continue;
}
if (input_set.insert(operand).second) {
inputs.push_back(operand);
}
}
}
return inputs;
}
SmallVector<Value, 4> GetOutputsOfFusionPattern(const FusionPattern& pattern) {
SmallVector<Value, 4> outputs;
DenseSet<Operation*> op_set;
for (Operation* op : pattern) {
bool inserted = op_set.insert(op).second;
(void)inserted;
assert(inserted && "FusionPattern contains duplicate operations");
}
for (Operation* op : pattern) {
for (Value result : op->getResults()) {
bool has_external_user = llvm::any_of(
result.getUses(),
[&](OpOperand& use) { return !op_set.count(use.getOwner()); });
if (has_external_user) {
outputs.push_back(result);
}
}
}
return outputs;
}
FusionPattern MergeFusionPattern(const FusionPattern& lhs,
const FusionPattern& rhs) {
FusionPattern pattern(lhs);
pattern.insert(pattern.end(), rhs.begin(), rhs.end());
return pattern;
}
inline int EffectiveSize(const FusionPattern& pattern) {
return llvm::count_if(
pattern, [](Operation* op) { return !matchPattern(op, m_Constant()); });
}
// This is a simple shape constraint analysis, which is used to
// guide fusion decisions (e.g. we only fuse shape-compatible ops).
//
// Currently, we only consider shape equality propagation based
// on the shape constraint traits of elementwise ops (assuming that
// implicit shape broadcast is forbidden).
class ShapeConstraintAnalysis {
public:
explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
PropagateEquality(op_list);
}
// Returns true if `lhs` and `rhs` are supposed to have the same shape.
bool HasSameShape(Value lhs, Value rhs) {
return impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
}
private:
// Shape equality propagation based on the shape constraint traits of
// elementwise ops.
void PropagateEquality(const SmallVectorImpl<Operation*>& op_list) {
bool converged = true;
do {
converged = true;
auto update = [&](Value lhs, Value rhs) {
if (!impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
converged = false;
impl_.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
}
};
for (Operation* op : op_list) {
auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
if (!op_fusibility) continue;
int numInput = op->getNumOperands();
int numOutput = op->getNumResults();
// shape equality propagation between inputs.
for (int input1 = 0; input1 < numInput; ++input1)
for (int input2 = input1 + 1; input2 < numInput; ++input2)
if (op_fusibility.inferInputsShapeEquality(input1, input2))
update(op->getOperand(input1), op->getOperand(input2));
// shape equality propagation between outputs.
for (int output1 = 0; output1 < numOutput; ++output1)
for (int output2 = output1 + 1; output2 < numOutput; ++output2)
if (op_fusibility.inferOutputsShapeEquality(output1, output2))
update(op->getResult(output1), op->getResult(output2));
// shape equality propagation between input and output.
for (int input = 0; input < numInput; ++input)
for (int output = 0; output < numOutput; ++output)
if (op_fusibility.inferInputOutputShapeEquality(input, output))
update(op->getOperand(input), op->getResult(output));
}
} while (!converged);
}
// a UnionFind set
EquivalenceClasses<ValueWrapper> impl_;
};
// A fusion planner that can propose a fusion plan for a block of ops.
// The fusion plan consists of a group of fusion patterns.
//
// Currently all proposed patterns follow XLA kLoop/kInput-like fusion
// templates, adapted to the fully dynamic shape world.
//
// The kLoop fusion template satisfies:
//   - all ops in the fusion pattern are element-wise.
//   - all the output shapes of the fusion pattern are the same, and thus can
//     fit into the same parallel loop.
//
// The kInput fusion template satisfies:
//   - any op in the fusion pattern is either element-wise or a reduction.
//   - if an op is a reduction, its output cannot be consumed by other
//     ops in the same fusion pattern.
//   - all the effective output shapes of the fusion pattern are the same.
//   - For an element-wise op, its effective shape is its output shape.
//   - For a reduction op, its effective shape is its operand shape.
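// For example, under these rules an element-wise chain feeding a reduction
// (e.g. add -> mul -> reduce) can form a single kInput pattern: the reduce is
// the only reduction, its result leaves the pattern, and its effective shape
// (its operand shape) matches the element-wise outputs. By contrast,
// add -> reduce -> mul cannot form one pattern, because the reduce's output
// would be consumed by the mul inside the same pattern.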
class FusionPlanner {
public:
explicit FusionPlanner(const SmallVectorImpl<Operation*>& op_list)
: op_list_(op_list),
shape_analysis_(op_list),
cycle_detector_(op_list.size()) {
BuildNodeMap();
}
// Returns a fusion plan on success, otherwise none.
llvm::Optional<FusionPlan> Run() {
// Greedily search for connected fusible patterns; ops belonging to
// the same fusion pattern are grouped into a cluster.
RunEdgeContractionLoop();
// After doing edge contraction, each unique cluster having size
// more than one represents a potential fusion pattern.
// We collect all these clusters and construct a fusion plan.
//
// Note that the ops in a fusion pattern are in topological ordering.
FusionPlan plan;
DenseMap<int, int> pattern_ids;
for (Operation* op : op_list_) {
Cluster* cluster = GetClusterForNode(op);
int node_id = cluster->cycles_graph_node_id();
if (!IsFusible(op_list_[node_id]) ||
EffectiveSize(GetClusterForNode(op)->fused_pattern()) <= 1) {
continue;
}
if (!pattern_ids.count(node_id)) {
int pattern_id = pattern_ids.size();
pattern_ids[node_id] = pattern_id;
plan.emplace_back();
}
plan[pattern_ids[node_id]].push_back(op);
}
return plan;
}
// Returns the op_list this planner operates on.
const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }
private:
// Represents a (partial) fused pattern.
class Cluster {
public:
Cluster(int node_id, FusionPlanner* planner) : node_id_(node_id) {
const SmallVectorImpl<Operation*>& op_list = planner->op_list();
pattern_.push_back(op_list[node_id]);
}
// Merges `other` into this cluster, and clears `other`.
void Merge(Cluster* other) {
pattern_.insert(pattern_.end(), other->pattern_.begin(),
other->pattern_.end());
other->pattern_.clear();
}
// The number of nodes in this cluster.
int cluster_size() const { return pattern_.size(); }
// The ID of the cluster as represented in `cycle_detector_`.
int cycles_graph_node_id() const { return node_id_; }
// Sets the ID of the cluster as represented in `cycle_detector_`.
void set_cycles_graph_node_id(int cycles_graph_node_id) {
node_id_ = cycles_graph_node_id;
}
// The fused pattern this cluster currently holds.
const FusionPattern& fused_pattern() { return pattern_; }
private:
// ID of the representative node of this cluster.
int node_id_;
// the fused pattern this cluster holds.
FusionPattern pattern_;
};
private:
Cluster* MakeCluster(int cycles_graph_node_id) {
cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
return cluster_storage_.back().get();
}
void BuildNodeMap() {
int num_nodes = op_list_.size();
for (int node_id = 0; node_id < num_nodes; ++node_id) {
Operation* op = op_list_[node_id];
MakeCluster(node_id);
op_to_node_id_[op] = node_id;
leader_for_node_.insert(node_id);
for (Value operand : op->getOperands()) {
Operation* operand_op = operand.getDefiningOp();
if (operand_op == nullptr) {
// skip block argument
continue;
}
auto iter = op_to_node_id_.find(operand_op);
assert(iter != op_to_node_id_.end());
cycle_detector_.InsertEdge(iter->second, node_id);
}
}
}
// Returns the cluster that contains this op.
Cluster* GetClusterForNode(Operation* n) {
int id = op_to_node_id_.at(n);
id = leader_for_node_.getLeaderValue(id);
return cluster_storage_[id].get();
}
// Returns the cluster that contains the op with `node_id`.
Cluster* GetClusterForCyclesGraphNode(int node_id) {
return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
}
// Merges the clusters `cluster_from` and `cluster_to`.
bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
int from = cluster_from->cycles_graph_node_id();
int to = cluster_to->cycles_graph_node_id();
auto optional_merged_node = cycle_detector_.ContractEdge(from, to);
if (!optional_merged_node.hasValue()) {
llvm::dbgs() << "Could not contract " << from << " -> " << to
<< " because contracting the edge would create a cycle.";
return false;
}
// Merge the clusters.
cluster_from->Merge(cluster_to);
cluster_from->set_cycles_graph_node_id(*optional_merged_node);
// Merge the UnionFind Set.
leader_for_node_.unionSets(from, to);
return true;
}
template <typename FnTy>
bool ForEachEdgeInPostOrder(FnTy fn) {
bool changed = false;
for (int32_t node : cycle_detector_.AllNodesInPostOrder()) {
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
// Make a copy of the set of successors because we may modify the graph in
// TryToContractEdge.
std::vector<int32_t> successors_copy =
cycle_detector_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
for (int to : successors_copy) {
Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
bool contracted_edge = fn(cluster_from, cluster_to);
changed |= contracted_edge;
}
}
return changed;
}
// Returns the outputs that would result if the two clusters were merged.
SmallVector<Value, 4> GetResultsOfFusedPattern(Cluster* from, Cluster* to) {
FusionPattern fused_pattern =
MergeFusionPattern(from->fused_pattern(), to->fused_pattern());
return GetOutputsOfFusionPattern(fused_pattern);
}
// This function checks if fusing `from` with `to` is valid and, if so,
// performs the merge. The validity is based on the operations in the
// clusters and the compatibility of the shapes of the outputs of the
// would-be fused clusters.
// Returns true if the merge was performed.
bool TryToContractEdge(Cluster* from, Cluster* to) {
int node_to = to->cycles_graph_node_id();
int node_from = from->cycles_graph_node_id();
// Both node_to and node_from should be fusible
if (!IsFusible(op_list_[node_to]) || !IsFusible(op_list_[node_from])) {
return false;
}
auto op_from_fusibility =
dyn_cast<InferFusibilityOpInterface>(op_list_[node_from]);
if (op_from_fusibility && !op_from_fusibility.isFusibleWithConsumer()) {
// This op cannot be fused with its consumers.
return false;
}
auto op_to_fusibility =
dyn_cast<InferFusibilityOpInterface>(op_list_[node_to]);
if (op_to_fusibility && !op_to_fusibility.isFusibleWithOperand()) {
// This op cannot be fused with its operands.
return false;
}
// Output shapes of a fusion pattern should be compatible as described in
// the document of this class.
SmallVector<Value, 4> results = GetResultsOfFusedPattern(from, to);
auto get_workload_shape = [](Value v) {
Operation* op = v.getDefiningOp();
// Block argument
if (!op) return v;
auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
// Const value
if (!op_fusibility) return v;
llvm::Optional<Value> workload =
op_fusibility.inferEffectiveWorkloadShape();
return workload.hasValue() ? *workload : v;
};
Value ref = get_workload_shape(results[0]);
if (!llvm::all_of(results, [&](Value result) {
Value val = get_workload_shape(result);
return shape_analysis_.HasSameShape(ref, val);
})) {
return false;
}
return MergeClusters(from, to);
}
// Greedily fuses connected nodes.
bool RunEdgeContractionLoop() {
using std::placeholders::_1;
using std::placeholders::_2;
return ForEachEdgeInPostOrder(
std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2));
}
const SmallVectorImpl<Operation*>& op_list_;
// Shape equality checker
ShapeConstraintAnalysis shape_analysis_;
// op -> node_id
std::unordered_map<Operation*, int> op_to_node_id_;
// Makes sure we do not introduce a cycle after fusion.
GraphCycles cycle_detector_;
std::vector<std::unique_ptr<Cluster>> cluster_storage_;
// A UnionFind set. Each set represents a (partial) fused pattern
// and has a leader as its representative.
EquivalenceClasses<int32_t> leader_for_node_;
};
struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
void runOnFunction() override {
FuncOp func = getFunction();
if (!IsTargetFunc(func)) {
return;
}
// process each block and do fusion within a block.
for (Block& block : func.getBlocks()) {
SmallVector<Operation*, 4> op_list;
for (Operation& op : block) {
op_list.push_back(&op);
}
FusionPlanner planner(op_list);
llvm::Optional<FusionPlan> plan = planner.Run();
if (!plan) {
emitError(func.getLoc(), "can't find a fusion plan");
signalPassFailure();
return;
}
if (!ApplyFusionPlan(*plan)) {
emitError(func.getLoc(), "apply fusion plan failed");
signalPassFailure();
return;
}
}
}
bool IsTargetFunc(FuncOp func) {
int num_fusible_ops = 0;
bool is_target_func = false;
// We only process functions that have enough fusion candidates.
func.walk([&](Operation* op) {
num_fusible_ops +=
static_cast<int>(dyn_cast<InferFusibilityOpInterface>(op) != nullptr);
is_target_func = (num_fusible_ops > 1);
// early stop
if (is_target_func) return WalkResult::interrupt();
return WalkResult::advance();
});
return is_target_func;
}
bool ApplyFusionPlan(const FusionPlan& plan) {
for (const FusionPattern& pattern : plan) {
OpBuilder b(pattern.back());
SmallVector<Location, 4> locations;
locations.reserve(pattern.size());
for (Operation* op : pattern) {
locations.push_back(op->getLoc());
}
Location fused_loc =
FusedLoc::get(locations, pattern.back()->getContext());
SmallVector<Value, 4> inputs = GetInputsOfFusionPattern(pattern);
SmallVector<Value, 4> outputs = GetOutputsOfFusionPattern(pattern);
SmallVector<Type, 4> output_types;
output_types.reserve(outputs.size());
for (Value v : outputs) {
output_types.push_back(v.getType());
}
FusionOp fusion =
b.create<xla_hlo::FusionOp>(fused_loc, output_types, inputs);
Region& region = fusion.fused_computation();
region.push_back(new Block);
Block& block = region.front();
for (Operation* op : pattern) {
op->moveBefore(&block, block.end());
}
b.setInsertionPoint(&block, block.end());
b.create<xla_hlo::ReturnOp>(fused_loc, outputs);
for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
Value output = std::get<0>(output_and_result);
Value fusion_result = std::get<1>(output_and_result);
for (OpOperand& use : llvm::make_early_inc_range(output.getUses())) {
if (use.getOwner()->getBlock() != &block) use.set(fusion_result);
}
}
}
return true;
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
return std::make_unique<XlaHloFusion>();
}
static PassRegistration<XlaHloFusion> xla_hlo_fusion_pass(
"xla-hlo-fusion", "fuse xla_hlo ops to kLoop/kInput fusion patterns.");
} // namespace xla_hlo
} // namespace mlir
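A rough sketch, not part of this change, of how the pass above could be wired
into a pipeline. The exact PassManager API depends on the MLIR revision, and
RunXlaHloFusion is a hypothetical helper using the createXlaHloFusion factory
defined in this file:
#include "mlir/Pass/PassManager.h"
mlir::LogicalResult RunXlaHloFusion(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // Run the fusion pass on every function in the module; the PassRegistration
  // above also exposes it as -xla-hlo-fusion to *-opt style tools.
  pm.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createXlaHloFusion());
  return pm.run(module);
}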

View File

@ -192,6 +192,108 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
}
};
//===----------------------------------------------------------------------===//
// xla_lhlo.convolution conversion pattern.
//===----------------------------------------------------------------------===//
/// Converts xla_lhlo.convolution operation to a linalg.conv op.
struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
public:
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
// This code has been adapted from IREE's
// (https://github.com/google/iree/) xla_hlo -> linalg conversion.
LogicalResult matchAndRewrite(
xla_lhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
// Check validity of dimension information.
if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers =
op.dimension_numbers()) {
const int inputSpatialRank =
llvm::size(dimensionNumbers.input_spatial_dimensions());
// The dimensions for input should follow the order of
// batch_count, spatial_dims..., input_feature_count.
if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
dimensionNumbers.input_feature_dimension().getInt() !=
(inputSpatialRank + 1))
return failure();
const int kernelSpatialRank =
llvm::size(dimensionNumbers.kernel_spatial_dimensions());
// The dimensions for filter should follow the order of
// spatial_dims..., input_feature_count, num_output_feature_count.
if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
kernelSpatialRank ||
dimensionNumbers.kernel_output_feature_dimension().getInt() !=
(kernelSpatialRank + 1))
return failure();
const int outputSpatialRank =
llvm::size(dimensionNumbers.output_spatial_dimensions());
// The dimensions for output should follow the order of
// batch_count, spatial_dims.., output_feature_count.
if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
dimensionNumbers.output_feature_dimension().getInt() !=
(outputSpatialRank + 1))
return failure();
if (inputSpatialRank != outputSpatialRank ||
inputSpatialRank != kernelSpatialRank)
return failure();
auto inputSpatialDim =
dimensionNumbers.input_spatial_dimensions().begin();
auto kernelSpatialDim =
dimensionNumbers.kernel_spatial_dimensions().begin();
auto outputSpatialDim =
dimensionNumbers.output_spatial_dimensions().begin();
// Check if spatial dims are ordered correctly.
for (int i = 0; i < inputSpatialRank; ++i) {
const int dim = i + 1;
if ((*inputSpatialDim++).getZExtValue() != dim ||
(*outputSpatialDim++).getZExtValue() != dim ||
(*kernelSpatialDim++).getZExtValue() != i)
return failure();
}
}
// TODO: LHS dilation for deconvolution not supported yet.
if (op.lhs_dilation()) {
return failure();
}
llvm::SmallVector<Attribute, 4> strides;
if (auto windowStrides = op.window_strides()) {
auto range = windowStrides->getAttributeValues();
strides.assign(range.begin(), range.end());
}
auto stridesArg = ArrayAttr::get(strides, op.getContext());
llvm::SmallVector<Attribute, 2> dilation;
if (auto rhsDilation = op.rhs_dilation()) {
auto range = rhsDilation->getAttributeValues();
dilation.assign(range.begin(), range.end());
} else {
// Default dilation of 1.
dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
}
auto dilationArg = ArrayAttr::get(dilation, op.getContext());
// Set padding only if it is non-zero.
DenseIntElementsAttr padding = op.paddingAttr();
if (!padding || !llvm::any_of(padding.getValues<APInt>(), [](APInt intVal) {
return !intVal.isNullValue();
})) {
padding = nullptr;
}
// The order of input and filter is switched for linalg.conv.
rewriter.replaceOpWithNewOp<linalg::ConvOp>(
op, args[1], args[0], args[2], stridesArg, dilationArg, padding);
return success();
}
};
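// Illustrative note: for a 2-D convolution, the checks above only accept the
// following dimension_numbers layout, i.e. NHWC input / HWIO filter / NHWC
// output:
//
//   input:  [batch, spatial_0, spatial_1, in_feature]       -> dims (0, 1, 2, 3)
//   filter: [spatial_0, spatial_1, in_feature, out_feature] -> dims (0, 1, 2, 3)
//   output: [batch, spatial_0, spatial_1, out_feature]      -> dims (0, 1, 2, 3)
//
// Any other ordering, or a set lhs_dilation (deconvolution), makes the pattern
// return failure() so another lowering path can handle the op.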
/// Base class for lowering xla operations that have one operand and one result,
/// and are semantically equivalent to a copy of the input to the output (like
/// transpose, some reshape, etc.). The derived classes need to provide a method
@ -814,6 +916,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
// clang-format off
patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
ConstConverter,
ConvToLinalgConverter,
IotaConverter,
LhloBroadcastInDimConverter,
PointwiseToLinalgConverter<xla_lhlo::AbsOp>,

View File

@ -85,6 +85,11 @@ class UnaryOpsTest(xla_test.XLATestCase):
for i in xrange(len(result)):
self.assertAllClose(result[i], expected[i], rtol, atol)
def AssertCloseAndSorted(self, result, expected, rtol, atol):
"""Tests that result and expeted are both close and sorted."""
self.assertAllClose(result, expected, rtol, atol)
self.assertAllEqual(np.sort(result), result)
@test_util.disable_mlir_bridge(
"MlirHloBuilder::Iota missing required for xla::Diag")
def testAllTypeOps(self):
@ -510,6 +515,16 @@ class UnaryOpsTest(xla_test.XLATestCase):
],
dtype=dtype))
@test_util.disable_mlir_bridge(
"TODO(b/155501444): Handle _UnaryOpsComposition ops from Grappler")
def testFloatOpsDisabledOnMlirBridge(self):
for dtype in self.float_types:
if dtype != np.float16:
self._assertOpOutputMatchesExpected(
lambda x: math_ops.sigmoid(x) / math_ops.log1p(math_ops.exp(x)),
np.array([-40, 40], dtype=dtype),
expected=np.array([1.0, 0.025], dtype=dtype))
@test_util.disable_mlir_bridge(
"TODO(b/153812660): Handle tf.QuantizeAndDequantize compilation")
def testQuantizeAndDequantize(self):
@ -1112,17 +1127,27 @@ class UnaryOpsTest(xla_test.XLATestCase):
[[[12, 13, 14, 15, 28, 29, 30, 31]]]]],
dtype=dtype))
def _assertSoftplusMatchesExpected(self, features, dtype):
def _assertSoftplusMatchesExpected(self,
features,
dtype,
equality_test=None,
rtol=1e-6,
atol=9.1e-6):
features = np.array(features, dtype=dtype)
zero = np.asarray(0).astype(dtype)
expected = np.logaddexp(zero, features).astype(dtype)
self._assertOpOutputMatchesExpected(
nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6)
nn_ops.softplus,
features,
expected=expected,
equality_test=equality_test,
rtol=rtol,
atol=atol)
@test_util.disable_mlir_bridge(
"bf16 type not supported in CreateDenseElementsAttrFromLiteral")
def testSoftplus(self):
for dtype in self.float_types:
for dtype in self.float_types & {dtypes.float32, dtypes.float64}:
self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype)
self._assertSoftplusMatchesExpected(
[[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype)
@ -1138,6 +1163,13 @@ class UnaryOpsTest(xla_test.XLATestCase):
-log_eps + ten
], dtype)
self._assertSoftplusMatchesExpected(
[0.69302183, 0.69324386],
dtype,
equality_test=self.AssertCloseAndSorted,
rtol=9e-5,
atol=9e-5)
if __name__ == "__main__":
googletest.main()

View File

@ -486,7 +486,7 @@ class SliceAssignTest(xla_test.XLATestCase):
def testUninitialized(self):
with self.assertRaisesRegexp(errors.FailedPreconditionError,
"uninitialized variable"):
"uninitialized"):
with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable([1, 2])
sess.run(v[:].assign([1, 2]))

View File

@ -89,16 +89,25 @@ XLAJIT_MAKE_UNARY(Sign,
xla::Select(xla::Ne(x, x), xla::ZerosLike(x), xla::Sign(x)));
XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x));
// softplus(x) = log(1 + exp(x))
//
// This is not numerically stable when x is large, it can easily overflow.
// However, we can compute it as LogSumExp(x, 0):
// max(x, 0) + log(exp(x - max(x, 0)) + exp(0 - max(x, 0)))
//
// This is equivalent to:
// max(x, 0) + log1p(exp(-abs(x)))
XLAJIT_MAKE_UNARY(Softplus, xla::Max(x, xla::ScalarLike(x, 0.0)) +
xla::Log1p(xla::Exp(-xla::Abs(x))));
static xla::XlaOp Softplus(xla::XlaBuilder* b, xla::XlaOp features) {
return b->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(features));
xla::XlaOp threshold =
Log(xla::Epsilon(b, shape.element_type())) + ScalarLike(features, 2.0);
// Value above which exp(x) may overflow, but softplus(x) == x
// is within machine epsilon.
xla::XlaOp too_large = Gt(features, -threshold);
// Value below which exp(x) may underflow, but softplus(x) == exp(x)
// is within machine epsilon.
xla::XlaOp too_small = Lt(features, threshold);
xla::XlaOp features_exp = Exp(features);
xla::XlaOp output =
Select(too_large, features,
Select(too_small, features_exp, Log1p(features_exp)));
return output;
});
}
XLAJIT_MAKE_UNARY(Softplus, Softplus(b, x));
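// Illustrative scalar sketch (plain C++, not XLA) of why the thresholds above
// are needed: the naive log1p(exp(x)) overflows for large x, while returning x
// directly in that regime is exact to within machine epsilon.
//
//   float NaiveSoftplus(float x)  { return std::log1p(std::exp(x)); }
//   float StableSoftplus(float x) {
//     // Mirrors the Select() chain above, with an illustrative float32
//     // overflow bound; the code above instead uses -(log(eps) + 2), a much
//     // tighter threshold at which the result is already exact.
//     if (x > 88.0f) return x;             // exp(x) would overflow float32
//     if (x < -88.0f) return std::exp(x);  // softplus(x) ~= exp(x) here
//     return std::log1p(std::exp(x));
//   }
//
//   NaiveSoftplus(100.0f)  -> inf (exp(100) overflows float32)
//   StableSoftplus(100.0f) -> 100.0f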
// softsign(x) = x / (abs(x) + 1)
XLAJIT_MAKE_UNARY(Softsign, x / (xla::Abs(x) + xla::ScalarLike(x, 1.0)));

View File

@ -413,8 +413,10 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
TF_RET_CHECK(variable != nullptr);
TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
if (!variable->initialized()) {
return errors::FailedPrecondition("Read of uninitialized variable ",
variable->name());
return errors::FailedPrecondition(
"Read variable failure ", variable->name(),
". It could mean the variable is uninitialized or the variable is on "
"another device ");
}
if (variable->type() != type) {
return errors::InvalidArgument(
@ -464,8 +466,10 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
TF_RET_CHECK(variable != nullptr);
TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
if (!variable->initialized()) {
return errors::InvalidArgument("Read of uninitialized variable ",
variable->name());
return errors::InvalidArgument(
"Read variable failure ", variable->name(),
". It could mean the variable is uninitialized or the variable is on "
"another device ");
}
*type = variable->type();
*shape = variable->shape();

View File

@ -1394,8 +1394,8 @@ XlaOp NextAfter(XlaOp from, XlaOp to) {
}
XlaOp Logistic(XlaOp x) {
auto half = xla::ScalarLike(x, 0.5);
return half + half * xla::Tanh(half * x);
auto one = xla::ScalarLike(x, 1);
return xla::Div(one, (one + xla::Exp(xla::Neg(x))));
}
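// Note (illustrative derivation, not part of the change): the two
// formulations agree mathematically, since with u = x / 2
//   1/2 + 1/2 * tanh(u) = 1/2 * (1 + (e^u - e^-u) / (e^u + e^-u))
//                       = e^u / (e^u + e^-u)
//                       = 1 / (1 + e^(-2u)) = 1 / (1 + e^(-x)),
// so the rewrite changes only how the logistic function is evaluated, not the
// value it computes.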
// Computes an approximation to the modified Bessel function of the first kind,

View File

@ -227,9 +227,9 @@ PyNumberMethods PyBfloat16_AsNumber = {
nullptr, // nb_and
nullptr, // nb_xor
nullptr, // nb_or
PyBfloat16_Int, // nb_int
nullptr, // reserved
PyBfloat16_Float, // nb_float
PyBfloat16_Int, // nb_int
nullptr, // reserved
PyBfloat16_Float, // nb_float
nullptr, // nb_inplace_add
nullptr, // nb_inplace_subtract
@ -441,6 +441,29 @@ void ByteSwap16(void* value) {
std::swap(p[0], p[1]);
}
int NPyBfloat16_Compare(const void* a, const void* b, void* arr) {
bfloat16 x;
memcpy(&x, a, sizeof(bfloat16));
bfloat16 y;
memcpy(&y, b, sizeof(bfloat16));
if (x < y) {
return -1;
}
if (y < x) {
return 1;
}
// NaNs sort to the end.
if (!std::isnan(x) && std::isnan(y)) {
return -1;
}
if (std::isnan(x) && !std::isnan(y)) {
return 1;
}
return 0;
}
void NPyBfloat16_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
npy_intp sstride, npy_intp n, int swap, void* arr) {
char* dst = reinterpret_cast<char*>(dstv);
@ -1213,7 +1236,44 @@ struct LogicalXor {
}
};
// TODO(phawkins): implement nextafter, spacing
struct NextAfter {
bfloat16 operator()(bfloat16 from, bfloat16 to) {
uint16_t from_as_int, to_as_int;
const uint16_t sign_mask = 1 << 15;
float from_as_float(from), to_as_float(to);
memcpy(&from_as_int, &from, sizeof(bfloat16));
memcpy(&to_as_int, &to, sizeof(bfloat16));
if (std::isnan(from_as_float) || std::isnan(to_as_float)) {
return bfloat16(std::numeric_limits<float>::quiet_NaN());
}
if (from_as_int == to_as_int) {
return to;
}
if (from_as_float == 0) {
if (to_as_float == 0) {
return to;
} else {
// Smallest subnormal with the same sign as `to`.
uint16_t out_int = (to_as_int & sign_mask) | 1;
bfloat16 out;
memcpy(&out, &out_int, sizeof(bfloat16));
return out;
}
}
uint16_t from_sign = from_as_int & sign_mask;
uint16_t to_sign = to_as_int & sign_mask;
uint16_t from_abs = from_as_int & ~sign_mask;
uint16_t to_abs = to_as_int & ~sign_mask;
uint16_t magnitude_adjustment =
(from_abs > to_abs || from_sign != to_sign) ? 0xFFFF : 0x0001;
uint16_t out_int = from_as_int + magnitude_adjustment;
bfloat16 out;
memcpy(&out, &out_int, sizeof(bfloat16));
return out;
}
};
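// Illustrative sketch: because bfloat16 is the top 16 bits of an IEEE float32,
// a one-bit step in the integer representation is one ULP, which near 1.0
// (7 mantissa bits) is 2^-7. A hypothetical helper makes this concrete:
//
//   // Reinterpret a bfloat16 bit pattern as a float by zero-extending it into
//   // the high half of a float32.
//   float BitsToFloat(uint16_t bits) {
//     uint32_t f32_bits = static_cast<uint32_t>(bits) << 16;
//     float out;
//     std::memcpy(&out, &f32_bits, sizeof(out));
//     return out;
//   }
//
//   BitsToFloat(0x3F80);   // 1.0
//   BitsToFloat(0x3F81);   // 1.0078125 == 1.0 + 2^-7, one ULP above 1.0
//
// This matches the +/-1 magnitude adjustment applied above.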
// TODO(phawkins): implement spacing
} // namespace ufuncs
@ -1243,6 +1303,7 @@ bool Initialize() {
PyArray_InitArrFuncs(&NPyBfloat16_ArrFuncs);
NPyBfloat16_ArrFuncs.getitem = NPyBfloat16_GetItem;
NPyBfloat16_ArrFuncs.setitem = NPyBfloat16_SetItem;
NPyBfloat16_ArrFuncs.compare = NPyBfloat16_Compare;
NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
@ -1467,7 +1528,9 @@ bool Initialize() {
RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Ceil>>(numpy.get(),
"ceil") &&
RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Trunc>>(numpy.get(),
"trunc");
"trunc") &&
RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::NextAfter>>(
numpy.get(), "nextafter");
return ok;
}

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections
import itertools
import math
from absl.testing import absltest
@ -218,6 +219,12 @@ class Bfloat16Test(parameterized.TestCase):
numpy_assert_allclose(
a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
def testSort(self):
values_to_sort = np.float32(FLOAT_VALUES)
sorted_f32 = np.sort(values_to_sort)
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
BinaryOp = collections.namedtuple("BinaryOp", ["op"])
@ -398,6 +405,26 @@ class Bfloat16NumPyTest(parameterized.TestCase):
np.testing.assert_equal(exp1, exp2)
numpy_assert_allclose(mant1, mant2, rtol=1e-2)
def testNextAfter(self):
one = np.array(1., dtype=bfloat16)
two = np.array(2., dtype=bfloat16)
zero = np.array(0., dtype=bfloat16)
nan = np.array(np.nan, dtype=bfloat16)
np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
np.testing.assert_equal(np.nextafter(one, one), one)
smallest_denormal = float.fromhex("1.0p-133")
np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
for a, b in itertools.permutations([0., -0., nan], 2):
np.testing.assert_equal(
np.nextafter(
np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
np.nextafter(
np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
if __name__ == "__main__":
absltest.main()

View File

@ -2973,6 +2973,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
// slice instruction should all have the same layout.
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
pad->shape(), nonzero_pad->mutable_shape()));
simplifier_->UpdateLayout(nonzero_pad->mutable_shape());
// Second, construct the slice instruction to perform the negative padding.
std::vector<int64> start_indices;
@ -2999,9 +3000,14 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides));
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
pad->shape(), slice->mutable_shape()));
simplifier_->UpdateLayout(slice->mutable_shape());
// Verify that the slice shape matches the pad shape.
TF_RET_CHECK(ShapeUtil::Equal(slice->shape(), pad->shape()));
auto equal = Shape::Equal();
if (!options_.is_layout_sensitive()) {
equal.IgnoreTilesInLayout();
}
TF_RET_CHECK(equal(slice->shape(), pad->shape()));
return ReplaceInstruction(pad, slice);
}
@ -3058,20 +3064,6 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
return false;
}
HloInstruction* operand = broadcast->mutable_operand(0);
auto is_scalar_broadcast = [](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kBroadcast &&
ShapeUtil::IsScalar(instruction->operand(0)->shape());
};
auto is_equal_broadcast = [operand,
broadcast](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kBroadcast &&
ShapeUtil::Equal(operand->shape(),
instruction->operand(0)->shape()) &&
broadcast->dimensions() == instruction->dimensions();
};
auto is_compatible_broadcast = [&](const HloInstruction* instruction) {
return is_scalar_broadcast(instruction) || is_equal_broadcast(instruction);
};
for (HloInstruction* user : broadcast->users()) {
if (user->user_count() == 0 && user != computation_->root_instruction()) {
continue;
@ -3090,20 +3082,18 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
continue;
}
// Check if all the operands of the user are compatible broadcasts for
// sinking. (They are either scalar broadcasts or broadcasts casting
// from/to the same shape/dimensions)
int64 compatible_broadcast_count = 0;
// Find the unique non-scalar operand or continue if there isn't one.
int64 scalar_broadcast_count = 0;
int64 broadcast_use_count = 0;
for (HloInstruction* user_operand : user->operands()) {
if (is_compatible_broadcast(user_operand)) {
++compatible_broadcast_count;
if (user_operand->opcode() == HloOpcode::kBroadcast &&
ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
++scalar_broadcast_count;
} else if (broadcast == user_operand) {
++broadcast_use_count;
}
}
if (compatible_broadcast_count + broadcast_use_count !=
user->operand_count()) {
if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) {
continue;
}
std::vector<HloInstruction*> new_operands;
@ -3111,24 +3101,14 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
Shape changed_shape;
for (HloInstruction* user_operand : user->operands()) {
// If this is a broadcast operand that is not our original broadcast input
// to this function then we might need to change the input.
if (is_compatible_broadcast(user_operand)) {
// If this is a broadcast from a scalar value rewrite a broadcast from
// the scalar to the new shape enforced from the other broadcast
// operands.
if (is_scalar_broadcast(user_operand)) {
changed_shape = ShapeUtil::ChangeElementType(
operand->shape(), user_operand->shape().element_type());
simplifier_->UpdateLayout(&changed_shape);
new_operands.push_back(
computation_->AddInstruction(HloInstruction::CreateBroadcast(
changed_shape, user_operand->mutable_operand(0), {})));
} else {
// For the non-scalar broadcasts we guarantee that the shape of the
// operand of the broadcast needs to be already a compatible shape.
new_operands.push_back(user_operand->mutable_operand(0));
}
if (user_operand->opcode() == HloOpcode::kBroadcast &&
ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
changed_shape = ShapeUtil::ChangeElementType(
operand->shape(), user_operand->shape().element_type());
simplifier_->UpdateLayout(&changed_shape);
new_operands.push_back(
computation_->AddInstruction(HloInstruction::CreateBroadcast(
changed_shape, user_operand->mutable_operand(0), {})));
} else {
CHECK_EQ(broadcast, user_operand);
new_operands.push_back(operand);
@ -3643,12 +3623,17 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
new_slice_strides.push_back(slice->slice_strides(dim));
new_slice_limits.push_back(slice->slice_limits(dim));
}
VLOG(3) << "Sink broadcast through slice";
VLOG(3) << "Original slice: " << slice->ToString();
VLOG(3) << "Original broadcast: " << broadcast->ToString();
TF_ASSIGN_OR_RETURN(auto new_slice,
MakeSliceHlo(broadcast_operand, new_slice_starts,
new_slice_limits, new_slice_strides));
return ReplaceInstruction(
slice,
MakeBroadcastHlo(new_slice, broadcast->dimensions(), slice->shape()));
auto new_broadcast = HloInstruction::CreateBroadcast(
slice->shape(), new_slice, broadcast->dimensions());
VLOG(3) << "New slice: " << slice->ToString();
VLOG(3) << "New broadcast: " << new_broadcast->ToString();
return ReplaceWithNewInstruction(slice, std::move(new_broadcast));
}
// Try to simplify concat -> slice to an operand of concat.
@ -3728,16 +3713,21 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
new_indices.push_back(dynamic_slice->mutable_operand(1 + dim));
new_slice_sizes.push_back(dynamic_slice->slice_sizes(dim));
}
VLOG(3) << "Sink broadcast through dynamic slice";
VLOG(3) << "Original dynamic slice: " << dynamic_slice->ToString();
VLOG(3) << "Original broadcast: " << operand->ToString();
HloInstruction* new_dynamic_slice = broadcast_operand;
if (!new_slice_sizes.empty()) {
TF_ASSIGN_OR_RETURN(
new_dynamic_slice,
MakeDynamicSliceHlo(broadcast_operand, new_indices, new_slice_sizes));
}
return ReplaceInstruction(
dynamic_slice,
MakeBroadcastHlo(new_dynamic_slice, operand->dimensions(),
dynamic_slice->shape()));
auto new_broadcast = HloInstruction::CreateBroadcast(
dynamic_slice->shape(), new_dynamic_slice, operand->dimensions());
VLOG(3) << "New dynamic slice: " << dynamic_slice->ToString();
VLOG(3) << "New broadcast: " << new_broadcast->ToString();
return ReplaceWithNewInstruction(dynamic_slice, std::move(new_broadcast));
}
// Convert a dynamic slice into a slice if all offsets are constant and the

View File

@ -338,79 +338,6 @@ TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) {
m::ConstantScalar(3.0))))));
}
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsScalar) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
b0 = f32[4] broadcast(p0), dimensions={}
b1 = f32[4] broadcast(p1), dimensions={}
ROOT multiply = f32[4] multiply(b1, b0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::Broadcast(m::Multiply(m::Broadcast(m::Parameter(1)),
m::Broadcast(m::Parameter(0))))));
}
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsConstantMix) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[4] parameter(0)
c0 = f32[] constant(2.0)
b0 = f32[4,2] broadcast(c0), dimensions={}
b1 = f32[4,2] broadcast(p0), dimensions={0}
ROOT multiply = f32[4,2] multiply(b1, b0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Broadcast(m::Multiply(
m::Parameter(0), m::Broadcast(m::ConstantScalar(2.0))))));
}
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsNonScalar) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[4] parameter(0)
p1 = f32[4] parameter(1)
b0 = f32[4,2] broadcast(p0), dimensions={0}
b1 = f32[4,2] broadcast(p1), dimensions={0}
ROOT multiply = f32[4,2] multiply(b1, b0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::Broadcast(m::Multiply(m::Parameter(1), m::Parameter(0)))));
}
TEST_F(AlgebraicSimplifierTest, ElementwiseNoSinkBroadcastsDifferentDims) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[4] parameter(0)
p1 = f32[8] parameter(1)
b0 = f32[4,8] broadcast(p0), dimensions={0}
b1 = f32[4,8] broadcast(p1), dimensions={1}
ROOT multiply = f32[4,8] multiply(b1, b0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Multiply(m::Broadcast(m::Parameter(1)),
m::Broadcast(m::Parameter(0)))));
}
TEST_F(AlgebraicSimplifierTest,
MultiplyReassociateMultiplyOfConstantAndBroadcast) {
const char* kModuleStr = R"(
@ -2612,6 +2539,28 @@ TEST_F(AlgebraicSimplifierTest, SliceOfBroadcast) {
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
}
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastPreserveLayout) {
const char* hlo_string = R"(
HloModule module
ENTRY test {
p0 = f32[10,20] parameter(0)
b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
ROOT s = f32[5,5,5]{2,0,1:T(256)} slice(b), slice={[0:5:1], [5:25:4], [5:15:2]}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
const Shape original_slice_shape =
module->entry_computation()->root_instruction()->shape();
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), original_slice_shape));
}
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcast) {
const char* hlo_string = R"(
HloModule module
@ -2635,6 +2584,32 @@ TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcast) {
m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
}
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcastPreserveLayout) {
const char* hlo_string = R"(
HloModule module
ENTRY test {
p0 = f32[10,20] parameter(0)
i0 = s32[] parameter(1)
i1 = s32[] parameter(2)
i2 = s32[] parameter(3)
b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
ROOT ds = f32[5,5,5]{2,0,1:T(256)} dynamic-slice(b, i0, i1, i2), dynamic_slice_sizes={5,5,5}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
const Shape original_dynslice_shape =
module->entry_computation()->root_instruction()->shape();
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::DynamicSlice(
m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), original_dynslice_shape));
}
TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) {
const char* hlo_string = R"(
HloModule module

View File

@ -75,85 +75,66 @@ CpuExecutable::CpuExecutable(
<< reinterpret_cast<void*>(compute_function_);
}
StatusOr<std::tuple<std::vector<se::DeviceMemoryBase>,
std::vector<se::OwningDeviceMemory>,
std::vector<se::OwningDeviceMemory>>>
CpuExecutable::CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
int device_ordinal,
std::vector<ExecutionInput> arguments) {
std::vector<se::DeviceMemoryBase> unowning_buffers(
assignment_->Allocations().size());
std::vector<se::OwningDeviceMemory> owning_buffers(
static StatusOr<MaybeOwningDeviceMemory> MemoryForAllocation(
const BufferAllocation& allocation,
absl::Span<ExecutionInput const> arguments,
se::DeviceMemoryAllocator* memory_allocator, int device_ordinal) {
VLOG(3) << allocation.ToString();
if (allocation.is_entry_computation_parameter()) {
se::DeviceMemoryBase out = arguments[allocation.parameter_number()]
.Buffer(allocation.param_shape_index())
.AsDeviceMemoryBase();
CHECK_EQ(allocation.size(), out.size())
<< "Size mismatch on param " << allocation.parameter_number()
<< " at shape index " << allocation.param_shape_index().ToString();
VLOG(3) << "allocation is a parameter";
return MaybeOwningDeviceMemory{out};
} else if (allocation.is_constant()) {
VLOG(3) << "allocation is a constant";
return MaybeOwningDeviceMemory{se::DeviceMemoryBase{}};
} else if (allocation.is_thread_local()) {
VLOG(3) << "buffer is thread-local";
return MaybeOwningDeviceMemory{se::DeviceMemoryBase{}};
}
int64 buffer_size = allocation.size();
TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory out,
memory_allocator->Allocate(device_ordinal, buffer_size));
VLOG(3) << "buffer allocated " << buffer_size << " bytes [" << out->opaque()
<< "]";
// Since the output buffer and all the temporary buffers were written into
// by the JITed code, msan has no way of knowing their memory was
// initialized. Mark them initialized so that msan doesn't flag loads from
// these buffers.
TF_ANNOTATE_MEMORY_IS_INITIALIZED(out->opaque(), buffer_size);
return MaybeOwningDeviceMemory{std::move(out)};
}
StatusOr<std::vector<MaybeOwningDeviceMemory>> CpuExecutable::CreateBufferTable(
se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
absl::Span<ExecutionInput const> arguments) {
std::vector<MaybeOwningDeviceMemory> buffers(
assignment_->Allocations().size());
VLOG(3) << "Allocating " << assignment_->Allocations().size()
<< " allocations for module " << module().name();
for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
++i) {
auto& allocation = assignment_->GetAllocation(i);
VLOG(3) << allocation.ToString();
if (allocation.is_entry_computation_parameter()) {
unowning_buffers[i] = arguments[allocation.parameter_number()]
.Buffer(allocation.param_shape_index())
.AsDeviceMemoryBase();
CHECK_EQ(allocation.size(), unowning_buffers[i].size())
<< "Size mismatch on param " << allocation.parameter_number()
<< " at shape index " << allocation.param_shape_index().ToString();
VLOG(3) << "allocation #" << i << " is a parameter";
continue;
}
if (allocation.is_constant()) {
VLOG(3) << "allocation #" << i << " is a constant";
continue;
}
if (allocation.is_thread_local()) {
VLOG(3) << "buffer #" << i << " is thread-local";
continue;
}
int64 buffer_size = allocation.size();
if (!owning_buffers[i].is_null()) {
VLOG(3) << "buffer #" << i
<< " is in the preallocated result ShapedBuffer";
} else {
TF_ASSIGN_OR_RETURN(owning_buffers[i], memory_allocator->Allocate(
device_ordinal, buffer_size));
unowning_buffers[i] = *owning_buffers[i];
VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes ["
<< owning_buffers[i]->opaque() << "]";
}
// Since the output buffer and all the temporary buffers were written into
// by the JITed code, msan has no way of knowing their memory was
// initialized. Mark them initialized so that msan doesn't flag loads from
// these buffers.
TF_ANNOTATE_MEMORY_IS_INITIALIZED(owning_buffers[i]->opaque(), buffer_size);
const BufferAllocation& allocation = assignment_->GetAllocation(i);
TF_ASSIGN_OR_RETURN(
buffers[i], MemoryForAllocation(allocation, arguments, memory_allocator,
device_ordinal));
}
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment_->GetUniqueTopLevelOutputSlice());
VLOG(3) << "result index: " << result_slice.index();
std::vector<se::OwningDeviceMemory> buffers_to_free;
for (auto& argument : arguments) {
for (auto& index_buffer : *argument.MutableBuffers()) {
auto maybe_owning_buffer = index_buffer.second.Release();
if (maybe_owning_buffer) {
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
}
}
}
return std::make_tuple(std::move(unowning_buffers), std::move(owning_buffers),
std::move(buffers_to_free));
return std::move(buffers);
}
Status CpuExecutable::ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
absl::Span<const se::DeviceMemoryBase> buffers,
absl::Span<MaybeOwningDeviceMemory const> buffers,
HloExecutionProfile* hlo_execution_profile) {
// The calling convention for JITed functions is:
//
@ -181,7 +162,8 @@ Status CpuExecutable::ExecuteComputeFunction(
// Call the computation function following the calling convention.
std::vector<void*> buffer_pointers;
for (auto& buffer : buffers) {
buffer_pointers.push_back(const_cast<void*>(buffer.opaque()));
buffer_pointers.push_back(
const_cast<void*>(buffer.AsDeviceMemoryBase().opaque()));
}
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment_->GetUniqueTopLevelOutputSlice());
@ -223,63 +205,82 @@ Status CpuExecutable::ExecuteComputeFunction(
return Status::OK();
}
StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
absl::Span<se::OwningDeviceMemory> buffers) {
absl::Span<MaybeOwningDeviceMemory> buffers,
absl::Span<ExecutionInput> arguments) {
se::Stream* stream = run_options->stream();
ScopedShapedBuffer result_buffer(
/*on_host_shape=*/result_shape(),
/*on_device_shape=*/result_shape(), run_options->allocator(),
stream->parent()->device_ordinal());
ExecutionOutput result(/*on_host_shape=*/result_shape(),
/*on_device_shape=*/result_shape(),
run_options->allocator(),
stream->parent()->device_ordinal());
const HloInputOutputAliasConfig& input_output_alias =
module().input_output_alias_config();
// Move se::OwningDeviceMemory values which contain the array(s) of the result
// into the respective location in ScopedShapedBuffer which is returned to the
// caller.
TF_RETURN_IF_ERROR(result_buffer.buffers().ForEachMutableElementWithStatus(
[&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) {
const auto& sources = this->GetRootValueSet().element(index);
// The points to set is unambiguous so the set should be a
// singleton.
CHECK_EQ(1, sources.values().size());
const HloValue* value_source = sources.values()[0];
HloInstruction* src = value_source->instruction();
for (auto& p : result.MutableResult()->buffers()) {
const ShapeIndex& index = p.first;
se::DeviceMemoryBase& result_buffer = p.second;
const HloValueSet& sources = this->GetRootValueSet().element(index);
// The points-to set is unambiguous so the set should be a
// singleton.
CHECK_EQ(1, sources.values().size());
const HloValue* value_source = sources.values()[0];
HloInstruction* src = value_source->instruction();
// The source for this result buffer can be a nested buffer such as
// a tuple element. The source instruction should have a
// non-parameter buffer assigned.
TF_ASSIGN_OR_RETURN(
const BufferAllocation::Slice slice,
this->assignment_->GetUniqueSlice(src, value_source->index()));
const BufferAllocation::Index buffer_index = slice.index();
se::OwningDeviceMemory& buffer = buffers[buffer_index];
if (!slice.allocation()->is_entry_computation_parameter()) {
// If the buffer coming out of the result is from a parameter, the
// owning buffer will be null, and that means the caller aliased some
// parameter buffer to an output one (via the
// HloInputOutputAliasConfig API). If that is the case, the caller
// will receive a partially complete scoped shaped buffer, which they
// will have to fill up on return. Unfortunately the interface to the
// execute APIs are ShapedBuffer pointer based, which assumes caller
// ownership, and hence a buffer coming from there cannot be part of
// the new ScopedShapedBuffer we create for the result (which assumes
// ownership).
*device_memory = buffer.Release();
} else {
auto output_alias = input_output_alias.GetAliasedOutput(
slice.allocation()->parameter_number(),
slice.allocation()->param_shape_index());
CHECK(output_alias)
<< "Output buffer is coming from parameter "
<< slice.allocation()->parameter_number() << " at index "
<< slice.allocation()->param_shape_index()
<< ", but no alias exists";
CHECK_EQ(*output_alias, index);
// TODO(cheshire): duplication with other backends.
absl::optional<HloInputOutputAliasConfig::Alias> alias =
input_output_alias.GetAliasedParameter(index);
if (alias) {
CHECK_LT(alias->parameter_number, arguments.size());
ExecutionInput& input = arguments[alias->parameter_number];
MaybeOwningDeviceMemory* maybe_owning_memory =
input.MutableBuffer(alias->parameter_index);
if (absl::optional<se::OwningDeviceMemory> owning =
maybe_owning_memory->Release()) {
// If the caller passes the ownership of the device memory, reuse it
// as the output buffer. It is up to the caller whether or not to
// donate a buffer; the aliasing information describes which buffers
// may alias, not buffers that must alias.
se::DeviceMemoryBase argument_buffer = owning->Release();
*maybe_owning_memory = argument_buffer;
result_buffer = argument_buffer;
if (alias->kind == HloInputOutputAliasConfig::kUserAlias) {
// This is a user alias, so a must alias. The caller is giving us the
// input buffer, but in case of error from the execute call, we should
// not be releasing it as it contains valid data (for example, it is a
// parameter which the user wants us to alias, in a gradient update
// computation). So we store the index into the result in the aliased
// vector, which will be fed to the ExecutionOutput, which will use the
// indices to drop the addresses from its own ScopedShapedBuffer result,
// if the ExecutionOutput is not committed.
result.AddAliasedIndex(index);
}
return Status::OK();
}));
return std::move(result_buffer);
}
}
if (result_buffer.is_null()) {
// The source for this result buffer can be a nested buffer such as
// a tuple element. The source instruction should have a
// non-parameter buffer assigned.
TF_ASSIGN_OR_RETURN(
const BufferAllocation::Slice slice,
this->assignment_->GetUniqueSlice(src, value_source->index()));
const BufferAllocation::Index buffer_index = slice.index();
MaybeOwningDeviceMemory& buffer = buffers[buffer_index];
if (absl::optional<se::OwningDeviceMemory> owned_buffer =
buffer.Release()) {
result_buffer = owned_buffer->Release();
buffer = result_buffer;
} else {
result_buffer = buffer.AsDeviceMemoryBase();
result.AddAliasedIndex(index);
}
}
}
return std::move(result);
}
StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
@ -311,22 +312,16 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
run_options->stream()->implementation());
se::Stream* stream = run_options->stream();
se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
std::vector<se::OwningDeviceMemory> owning_buffers;
std::vector<se::DeviceMemoryBase> unowning_buffers;
std::vector<se::OwningDeviceMemory> buffers_to_release;
TF_ASSIGN_OR_RETURN(
std::tie(unowning_buffers, owning_buffers, buffers_to_release),
std::vector<MaybeOwningDeviceMemory> buffers,
CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(),
std::move(arguments)));
arguments));
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer result,
CreateResultShapedBuffer(run_options, absl::MakeSpan(owning_buffers)));
ExecutionOutput result,
CreateResultShapedBuffer(run_options, absl::MakeSpan(buffers),
absl::MakeSpan(arguments)));
// At this point, `unowning_buffers` contains unowning pointers to all of our
// buffers, and `buffers` contains owning pointers to the non-live-out
// buffers. Enqueue a task which keeps alive the non-live-out buffers.
//
// Logically we want this lambda to capture `buffers` by move, ultimately our
// functor needs to be wrapped in an std::function, and that requires its
// functor to be copyable. Thus we perpetrate the hack of capturing buffers
@ -339,23 +334,33 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
struct AsyncRunTask {
CpuExecutable* executable;
ServiceExecutableRunOptions run_options;
std::vector<se::DeviceMemoryBase> unowning_buffers;
std::shared_ptr<std::vector<se::OwningDeviceMemory>> buffers;
std::shared_ptr<std::vector<MaybeOwningDeviceMemory>> task_buffers;
HloExecutionProfile* hlo_execution_profile;
void operator()() {
// Failing a CHECK here is not great, but I don't see an obvious way to
// return a failed Status asynchronously.
TF_CHECK_OK(executable->ExecuteComputeFunction(
&run_options.run_options(), unowning_buffers, hlo_execution_profile));
&run_options.run_options(), *task_buffers, hlo_execution_profile));
}
};
host_stream->EnqueueTask(
AsyncRunTask{this, *run_options, std::move(unowning_buffers),
std::make_shared<std::vector<se::OwningDeviceMemory>>(
std::move(owning_buffers)),
AsyncRunTask{this, *run_options,
std::make_shared<std::vector<MaybeOwningDeviceMemory>>(
std::move(buffers)),
hlo_execution_profile});
return ExecutionOutput(std::move(result), std::move(buffers_to_release));
// TODO(cheshire): Duplication with other executables.
for (ExecutionInput& argument : arguments) {
for (auto& index_buffer : *argument.MutableBuffers()) {
absl::optional<se::OwningDeviceMemory> maybe_owning_buffer =
index_buffer.second.Release();
if (maybe_owning_buffer) {
result.AddToBeReleased(std::move(*maybe_owning_buffer));
}
}
}
return std::move(result);
}
/*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {

View File

@ -101,24 +101,25 @@ class CpuExecutable : public Executable {
//
// - buffers_to_free: buffers whose ownership was donated by the caller that
// are to be freed by the caller.
StatusOr<std::tuple<std::vector<se::DeviceMemoryBase>,
std::vector<se::OwningDeviceMemory>,
std::vector<se::OwningDeviceMemory>>>
CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
int device_ordinal, std::vector<ExecutionInput> arguments);
StatusOr<std::vector<MaybeOwningDeviceMemory>> CreateBufferTable(
se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
absl::Span<ExecutionInput const> arguments);
// Calls the generated function performing the computation with the given
// arguments using the supplied buffers.
Status ExecuteComputeFunction(const ExecutableRunOptions* run_options,
absl::Span<const se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile);
Status ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
absl::Span<MaybeOwningDeviceMemory const> buffers,
HloExecutionProfile* hlo_execution_profile);
// Creates a ScopedShapedBuffer for holding the result of the computation,
// moving buffers out of allocated_buffers and into the result as appropriate.
// The addresses are set according to buffer assignment.
StatusOr<ScopedShapedBuffer> CreateResultShapedBuffer(
// Creates an ExecutionOutput holding a ScopedShapedBuffer for the result of
// the computation, moving buffers out of allocated_buffers and into the
// result as appropriate. The addresses are set according to buffer
// assignment.
StatusOr<ExecutionOutput> CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
absl::Span<se::OwningDeviceMemory> buffers);
absl::Span<MaybeOwningDeviceMemory> buffers,
absl::Span<ExecutionInput> arguments);
// Returns the instruction value set of the root instruction of the entry
// computation. Uses dataflow analysis from buffer assignment.

View File

@ -210,24 +210,6 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
return Status::OK();
}
// Skip 'fusion' instruction if any of its fused instructions are expensive.
// This is done to avoid the duplication of expensive instructions, which
// would occur if 'fusion' were merged into multiple users.
//
// If 'fusion' has just one user, then an earlier fusion pass chose not to
// fuse this producer/consumer pair (likely because of expensive instruction
// re-use by the consumer), and so we honor that choice here as well.
if (absl::c_any_of(fusion->fused_instructions(),
[](const HloInstruction* instruction) {
return instruction->opcode() != HloOpcode::kParameter &&
GpuInstructionFusion::IsExpensive(*instruction);
})) {
VLOG(3) << "Not merging " << fusion->name()
<< ": Contains one or more expensive instructions.";
++num_fail_expensive_fused_instruction_;
return Status::OK();
}
// Skip 'fusion' instruction if merging it into all users would result in a
// net increase in bytes transferred (currently allowing the net bytes
// transferred to be exceeded up to ~10% in exchange for eliminating the
@ -244,6 +226,35 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
return Status::OK();
}
// Skip 'fusion' instruction if any of its fused instructions are expensive.
// This is done to avoid the duplication of expensive instructions, which
// would occur if 'fusion' were merged into multiple users.
//
// If 'fusion' has just one user, then an earlier fusion pass chose not to
// fuse this producer/consumer pair (likely because of expensive instruction
// re-use by the consumer), and so we honor that choice here as well.
//
// Moreover, if we are going to save a "lot" in memory bandwidth then we
// ignore how expensive the fusion instructions are. The heuristic used to
// determine "a lot" is the following: the ratio of bytes transferred after
// merging to the bytes currently transferred must be below 0.3, and the
// amount of memory accessed must not be entirely trivial (above 1K). This
// likely has room for improvement in the future.
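// For example, with 10 KiB currently transferred and 2 KiB transferred after
// merging, the ratio is 0.2 (< 0.3) and 10 KiB exceeds 1 KiB, so expensive
// fused instructions do not block the merge.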
bool allow_expensive_ops =
merged_to_current_bytes_ratio < 0.3 && current_bytes_transferred > 1024;
if (!allow_expensive_ops &&
absl::c_any_of(fusion->fused_instructions(),
[](const HloInstruction* instruction) {
return instruction->opcode() != HloOpcode::kParameter &&
GpuInstructionFusion::IsExpensive(*instruction);
})) {
VLOG(3) << "Not merging " << fusion->name()
<< ": Contains one or more expensive instructions.";
++num_fail_expensive_fused_instruction_;
return Status::OK();
}
// Skip 'fusion' instruction if merging it into at least one of the users
// would cause too much code duplication because of inefficiencies in the
// fusion emitter.

View File

@ -367,6 +367,37 @@ TEST_F(FusionMergerTest, WillNotMergeIfFusionEmitterIsInefficient) {
EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie());
}
TEST_F(FusionMergerTest, WillMergeExpensiveFusionsIfSavesMemory) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule m
%f_a (p: f32[]) -> f32[1024,1024,1024] {
%p = f32[] parameter(0)
%b = f32[1024,1024,1024] broadcast(%p), dimensions={}
ROOT %t = f32[1024,1024,1024] tanh(%b)
}
%f_b (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
%p = f32[1024,1024,1024] parameter(0)
ROOT %t = f32[1024,1024,1024] tanh(%p)
}
%f_c (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
%p = f32[1024,1024,1024] parameter(0)
ROOT %t = f32[1024,1024,1024] tanh(%p)
}
ENTRY entry {
p0 = f32[] parameter(0)
f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_a
f2 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_b
f3 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_c
ROOT f4 = f32[1024,1024,1024] add(f2, f3)
})")
.ValueOrDie();
EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie());
}
} // namespace
} // namespace gpu
} // namespace xla

View File

@ -499,6 +499,8 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
/*allocate_buffers_for_constants=*/true,
/*colorer=*/BufferAssigner::DefaultColorer(),
/*must_not_live_out=*/{}, GetCanShareBuffer()));
VLOG(1) << "Buffer Assignment Stats "
<< buffer_assignment->GetStats().ToString();
DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations");
IrEmitterContext ir_emitter_context(

View File

@ -321,38 +321,11 @@ MatchBackwardInput(HloInstruction* conv) {
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
// TODO: Theoretically cuDNN supports grouped convolutions also
// for the backward input convolution, but based on the cudnn's current state
// there is not much performance improvement when using the
// cudnn backward input API for grouped conv.
// This needs to be re-evaluated for future cuDNN versions.
// Note that we already have the necessary code down below, the only thing to
// enable it is to remove the following early return.
if (conv->feature_group_count() > 1) {
return no_match_result;
}
// Match instruction pattern.
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
HloInstruction* reverse_filter = conv->mutable_operand(1);
ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
// Match BackwardInput for a depthwise convolution and thunk it to a forward
// convolution. The output feature dimension and the input feature dimension
// have been swapped in the bridge, so to get the actual input features we
// need to query the output feature dimension.
auto kernel_out_feature_dim = dnums.kernel_output_feature_dimension();
auto kernel_out_features =
reverse_filter->shape().dimensions(kernel_out_feature_dim);
// For a depthwise convolution, the input features must be equal to the
// feature_group_count. We can leverage this property to match a depthwise
// convolution and thunk it to a forward conv.
if (conv->feature_group_count() > 1 &&
kernel_out_features == conv->feature_group_count()) {
return no_match_result;
}
// We pattern-match to a backwards input conv if:
//
// - all spatial dims of the filter are reversed

View File

@ -360,6 +360,27 @@ StatusOr<se::DeviceMemoryBase> GpuExecutable::BufferForAllocation(
}
}
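// Checks that `buffer`, assigned to `allocation` (argument index `arg_idx`
// in error messages), is aligned to the boundary required for its kind:
// entry parameter, constant, or XLA-allocated buffer.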
static Status CheckAlignment(const BufferAllocation& allocation,
se::DeviceMemoryBase buffer, int arg_idx) {
const int64 expected_alignment = [&] {
if (allocation.is_entry_computation_parameter()) {
return kEntryParameterAlignBytes;
} else if (allocation.is_constant()) {
return kConstantBufferAlignBytes;
} else {
return kXlaAllocatedBufferAlignBytes;
}
}();
if (!buffer.is_null() &&
reinterpret_cast<uintptr_t>(buffer.opaque()) % expected_alignment != 0) {
return InternalError(
"Address of buffer %d must be a multiple of %x, but "
"was %p",
arg_idx, expected_alignment, buffer.opaque());
}
return Status::OK();
}
StatusOr<BufferAllocations> GpuExecutable::GenerateBufferAllocations(
absl::Span<ExecutionInput const> arguments,
const GpuExecutable::BufferAllocToDeviceMemoryMap* globals,
@ -378,28 +399,37 @@ StatusOr<BufferAllocations> GpuExecutable::GenerateBufferAllocations(
se::DeviceMemoryBase buffer,
BufferForAllocation(arguments, globals, allocation, memory_allocator,
executor->device_ordinal(), i));
const int64 expected_alignment = [&] {
if (allocation.is_entry_computation_parameter()) {
return kEntryParameterAlignBytes;
} else if (allocation.is_constant()) {
return kConstantBufferAlignBytes;
} else {
return kXlaAllocatedBufferAlignBytes;
}
}();
if (!buffer.is_null() &&
reinterpret_cast<uintptr_t>(buffer.opaque()) % expected_alignment !=
0) {
return InternalError(
"Address of buffer %d must be a multiple of %x, but "
"was %p",
i, expected_alignment, buffer.opaque());
}
buffers.push_back(buffer);
TF_RETURN_IF_ERROR(CheckAlignment(allocation, buffer, i));
}
return {{buffers, executor->device_ordinal(), memory_allocator}};
}
// Returns `true` if the entire tuple contents is aliased.
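// That is, `index` refers to a tuple and every nested output index under it
// has an alias recorded in `alias_config`.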
static bool EntireTupleContentsAliased(
const Shape& output_shape, const ShapeIndex& index,
const HloInputOutputAliasConfig& alias_config) {
const Shape& indexed_shape = ShapeUtil::GetSubshape(output_shape, index);
if (!indexed_shape.IsTuple()) {
return false;
}
bool all_aliased = true;
ShapeUtil::ForEachSubshape(
indexed_shape, [&](const Shape& subshape, const ShapeIndex& subindex) {
if (subindex.empty()) {
return;
}
std::vector<int64> full_index;
absl::c_copy(index, std::back_inserter(full_index));
absl::c_copy(subindex, std::back_inserter(full_index));
if (!alias_config.OutputHasAlias(
ShapeIndex(full_index.begin(), full_index.end()))) {
all_aliased = false;
}
});
return all_aliased;
}
StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ExecutionInput> arguments,
@ -425,84 +455,102 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
}
se::StreamExecutor* executor = run_options->stream()->parent();
TF_ASSIGN_OR_RETURN(BufferAllocations buffer_allocations,
GenerateBufferAllocations(arguments, globals,
memory_allocator, executor));
for (Thunk* thunk : thunk_schedule_->TotalOrder()) {
TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor));
}
VLOG(2) << buffer_allocations.ToString();
TF_RETURN_IF_ERROR(ExecuteThunks(run_options, buffer_allocations,
block_host_until_done,
hlo_execution_profile));
HloInstruction* root = hlo_module_->entry_computation()->root_instruction();
auto device_ordinal = executor->device_ordinal();
ExecutionOutput result(root->shape(), root->shape(), memory_allocator,
ExecutionOutput result(/*on_host_shape=*/root->shape(),
/*on_device_shape=*/root->shape(), memory_allocator,
device_ordinal);
TF_ASSIGN_OR_RETURN(BufferAllocations buffer_allocations,
GenerateBufferAllocations(arguments, globals,
memory_allocator, executor));
VLOG(2) << buffer_allocations.ToString();
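// Track the addresses that end up in the result so they are not freed when
// the temporary allocations are torn down below.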
std::set<se::DeviceMemoryBase> buffers_in_result;
for (auto& p : result.MutableResult()->buffers()) {
const ShapeIndex& index = p.first;
se::DeviceMemoryBase& device_memory = p.second;
se::DeviceMemoryBase& result_buffer = p.second;
const auto& sources = GetRootValueSet().element(index);
// The points-to set is unambiguous so the set should be a
// singleton. That is, we know exactly which instruction
// produced the array at this element.
CHECK_EQ(1, sources.values().size());
auto src_hlo = sources.values()[0]->instruction();
HloInstruction* src_hlo = sources.values()[0]->instruction();
VLOG(4) << "Looking at: " << sources.values()[0];
// The source instruction should have a non-parameter buffer
// assigned.
TF_ASSIGN_OR_RETURN(
const BufferAllocation::Slice slice,
assignment_->GetUniqueSlice(src_hlo, sources.values()[0]->index()));
se::DeviceMemoryBase src_base =
buffer_allocations.GetDeviceAddress(slice.index());
CHECK(!src_base.is_null() || src_base.size() == 0);
if (!slice.allocation()->is_entry_computation_parameter()) {
// If the buffer coming out of the result is from a parameter, it
// means the caller aliased some parameter buffer to an output one
// (via the HloInputOutputAliasConfig API). If that is the case, the
// caller will receive a partially complete scoped shaped buffer,
// which they will have to fill up on return.
// Unfortunately the interface to the execute APIs are ShapedBuffer
// pointer based, which assumes caller ownership, and hence a buffer
// coming from there cannot be part of the new ScopedShapedBuffer we
// create for the result (which assumes ownership).
device_memory = src_base;
} else {
const HloInputOutputAliasConfig& input_output_alias =
module().input_output_alias_config();
auto output_alias = input_output_alias.GetAliasedOutput(
slice.allocation()->parameter_number(),
slice.allocation()->param_shape_index());
CHECK(output_alias) << "Output buffer is coming from parameter "
<< slice.allocation()->parameter_number()
<< " at index "
<< slice.allocation()->param_shape_index()
<< ", but no alias exists";
CHECK_EQ(*output_alias, index);
const HloInputOutputAliasConfig& input_output_alias =
module().input_output_alias_config();
absl::optional<HloInputOutputAliasConfig::Alias> alias =
input_output_alias.GetAliasedParameter(index);
if (alias) {
CHECK_LT(alias->parameter_number, arguments.size());
ExecutionInput& input = arguments[alias->parameter_number];
MaybeOwningDeviceMemory* maybe_owning_memory =
input.MutableBuffer(alias->parameter_index);
if (absl::optional<se::OwningDeviceMemory> owning =
maybe_owning_memory->Release()) {
// If the caller passes the ownership of the device memory, reuse it
// as the output buffer. It is up to the caller whether or not to
// donate a buffer; the aliasing information describes which buffers
// may alias, not buffers that must alias.
se::DeviceMemoryBase argument_buffer = owning->Release();
*maybe_owning_memory = argument_buffer;
result_buffer = argument_buffer;
if (alias->kind == HloInputOutputAliasConfig::kUserAlias) {
// This is a user alias, so a must alias. The caller is giving us the
// input buffer, but in case of error from the execute call, we should
// not be releasing it as it contains valid data (for example, it is a
// parameter which the user wants us to alias, in a gradient update
// computation). So we store the index into the result in the aliased
// vector, which will be fed to the ExecutionOutput, which will use
// the indices to drop the addresses from its own ScopedShapedBuffer
// result, if the ExecutionOutput is not committed.
result.AddAliasedIndex(index);
}
}
}
buffers_in_result.insert(src_base);
if (result_buffer.is_null()) {
// The source instruction should have a non-parameter buffer
// assigned.
TF_ASSIGN_OR_RETURN(
const BufferAllocation::Slice slice,
assignment_->GetUniqueSlice(src_hlo, sources.values()[0]->index()));
result_buffer = buffer_allocations.GetDeviceAddress(slice.index());
// If the entire tuple contents is aliased, the copy insertion will *not*
// materialize a new tuple, so we mark it as aliased as well.
if (EntireTupleContentsAliased(root->shape(), index,
input_output_alias)) {
result.AddAliasedIndex(index);
}
}
buffers_in_result.insert(result_buffer);
}
for (Thunk* thunk : thunk_schedule_->TotalOrder()) {
TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor));
}
TF_RETURN_IF_ERROR(ExecuteThunks(run_options, buffer_allocations,
block_host_until_done,
hlo_execution_profile));
// Free all temporary allocations.
TF_RETURN_IF_ERROR(
buffer_allocations.TearDown(buffers_in_result, assignment_.get()));
std::vector<se::OwningDeviceMemory> buffers_to_free;
for (auto& argument : arguments) {
// Free allocations for arguments.
for (ExecutionInput& argument : arguments) {
for (auto& index_buffer : *argument.MutableBuffers()) {
auto maybe_owning_buffer = index_buffer.second.Release();
if (maybe_owning_buffer) {
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
if (absl::optional<se::OwningDeviceMemory> owning =
index_buffer.second.Release()) {
result.AddToBeReleased(std::move(*owning));
}
}
}
return result;
return std::move(result);
}
const InstructionValueSet& GpuExecutable::GetRootValueSet() const {

View File

@ -121,10 +121,10 @@ string HloInputOutputAliasConfig::ToString() const {
return absl::StrJoin(pieces, "\n");
}
HloInputOutputAliasConfig::AliasKind
absl::optional<HloInputOutputAliasConfig::AliasKind>
HloInputOutputAliasConfig::ParameterAliasKind(
int64 param_number, const ShapeIndex& param_index) const {
AliasKind kind = AliasKind::kNoAlias;
absl::optional<AliasKind> kind;
alias_.ForEachElement(
[&](const xla::ShapeIndex&, absl::optional<Alias> alias) {
if (alias && alias->parameter_number == param_number &&

View File

@ -36,7 +36,6 @@ class HloInputOutputAliasConfig {
// compilation time by the user, and has to be respected. A kSystemAlias one
// might be set up by the compiler, if it decides it is convenient to do so.
enum AliasKind {
kNoAlias,
kUserAlias,
kSystemAlias,
};
@ -68,15 +67,15 @@ class HloInputOutputAliasConfig {
AliasKind kind = AliasKind::kUserAlias);
// Returns the kind of alias for the given parameter number and parameter
// index. If no alias exists, AliasKind::kNoAlias is returned.
AliasKind ParameterAliasKind(int64 param_number,
const ShapeIndex& param_index) const;
// index, or nullopt if no alias exists.
absl::optional<AliasKind> ParameterAliasKind(
int64 param_number, const ShapeIndex& param_index) const;
// Returns true if the given parameter is aliased with one of the output
// buffers.
bool ParameterHasAlias(int64 param_number,
const ShapeIndex& param_index) const {
return ParameterAliasKind(param_number, param_index) != AliasKind::kNoAlias;
return ParameterAliasKind(param_number, param_index).has_value();
}
// Checks whether the provided output index has already been aliased.

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/base/casts.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"

View File

@ -432,8 +432,8 @@ std::string MemorySpaceAssignment::AllocationValue::ToString() const {
absl::StrAppend(&out, "\n position:\n");
absl::StrAppend(&out, " ", defining_position_.ToString(), "\n");
absl::StrAppend(&out, " uses:\n");
for (const HloUse& use : uses_) {
absl::StrAppend(&out, " ", use.ToString(), "\n");
for (const Use& use : uses_) {
absl::StrAppend(&out, " ", use.hlo_use.ToString(), "\n");
}
return out;
}
@ -515,6 +515,53 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues(
}
}
void AlternateMemoryBestFitHeap::FindAliases(
std::vector<AllocationValue>* allocation_values) const {
absl::flat_hash_map<const HloInstruction*, const AllocationValue*>
values_by_defining_inst;
for (AllocationValue& value : *allocation_values) {
CHECK_EQ(values_by_defining_inst.count(value.defining_instruction()), 0);
values_by_defining_inst[value.defining_instruction()] = &value;
}
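// Records the defining position of the AllocationValue defined by
// `instruction` (if any) as an alias of `use`.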
auto maybe_add_alias_with_instruction = [&](const HloInstruction* instruction,
AllocationValue::Use* use) {
auto aliased_value_it = values_by_defining_inst.find(instruction);
if (aliased_value_it != values_by_defining_inst.end()) {
VLOG(3) << "Adding aliasing for use " << use->hlo_use.ToString() << " to "
<< aliased_value_it->second->ToShortString();
use->aliases.push_back(aliased_value_it->second->defining_position());
}
};
for (AllocationValue& value : *allocation_values) {
for (AllocationValue::Use& use : value.uses()) {
// Find any aliases with the instruction itself (operand and output must
// alias).
maybe_add_alias_with_instruction(use.hlo_use.instruction, &use);
// Find any aliases with the parameters of called computations.
for (const HloComputation* called_computation :
use.hlo_use.instruction->called_computations()) {
for (const HloInstruction* parameter_instruction :
called_computation->parameter_instructions()) {
maybe_add_alias_with_instruction(parameter_instruction, &use);
}
}
// Special case for kWhile: the root of the body computation must alias as
// well.
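// (The body root produces the value carried into the next iteration and the
// while result, so it must share the allocation with the while operand.)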
if (use.hlo_use.instruction->opcode() == HloOpcode::kWhile) {
HloPosition root_alias{
use.hlo_use.instruction->while_body()->root_instruction(),
use.hlo_use.operand_index};
VLOG(3) << "Adding while body root aliasing for use "
<< use.hlo_use.ToString() << " to " << root_alias;
use.aliases.push_back(root_alias);
}
}
}
}
std::vector<const GlobalDecreasingSizeBestFitHeap::BufferInterval*>
AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
@ -675,18 +722,18 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
// multiple called computations), determine if the parameter->first use
// dependency is short.
int64 conditional_time = instruction_schedule.at(use.instruction);
for (const HloUse& other_use : value.uses()) {
if (other_use.instruction != use.instruction) {
for (const AllocationValue::Use& other_use : value.uses()) {
if (other_use.hlo_use.instruction != use.instruction) {
continue;
}
HloComputation* called_computation =
use.instruction->called_computations().at(other_use.operand_number -
1);
use.instruction->called_computations().at(
other_use.hlo_use.operand_number - 1);
const HloInstruction* parameter_instruction =
called_computation->parameter_instruction(0);
HloValue* parameter_value =
&alias_analysis_.dataflow_analysis().GetUniqueValueAt(
parameter_instruction, other_use.operand_index);
parameter_instruction, other_use.hlo_use.operand_index);
int64 parameter_time = instruction_schedule.at(parameter_instruction);
int64 min_use_time = conditional_time;
for (const HloUse& parameter_use : parameter_value->uses()) {
@ -947,6 +994,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
for (const auto& colocated_interval : colocated_intervals) {
CreateAllocationValues(colocated_interval->buffer, &allocation_values);
}
FindAliases(&allocation_values);
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
// Data structure to contain the preferred offset for a given computation.
@ -969,25 +1017,26 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
// Iterate over the uses.
for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
const HloUse& use = allocation_value.uses().at(use_idx);
int64 use_time = instruction_schedule.at(use.instruction);
const AllocationValue::Use& use = allocation_value.uses().at(use_idx);
const HloUse hlo_use = use.hlo_use;
int64 use_time = instruction_schedule.at(hlo_use.instruction);
int64 latest_prefetch_time = use_time;
bool allow_no_copy_alternate_mem_allocation = true;
absl::optional<int64> earliest_prefetch_time = absl::nullopt;
// Sequential calls include kWhile, kCall, and kConditional opcodes.
bool is_sequential_call =
(GetInstructionCallContext(use.instruction->opcode()) ==
(GetInstructionCallContext(hlo_use.instruction->opcode()) ==
CallContext::kSequential);
if (is_sequential_call) {
for (const HloComputation* called_computation :
use.instruction->called_computations()) {
hlo_use.instruction->called_computations()) {
const HloLiveRange::TimeBound& computation_span =
hlo_live_range_.computation_span_times().at(called_computation);
latest_prefetch_time =
std::min(computation_span.start, latest_prefetch_time);
}
if (use.instruction->opcode() == HloOpcode::kWhile) {
if (hlo_use.instruction->opcode() == HloOpcode::kWhile) {
// Given an example while loop and flattened schedule (logical times
// shown on the left):
//
@ -1008,10 +1057,10 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
// the interval to time 0-4. This is so that the remaining interval
// (5-6) can be allocated separately and this buffer doesn't waste
// alternate memory space within the while loop body.
HloComputation* while_body = use.instruction->while_body();
HloComputation* while_body = hlo_use.instruction->while_body();
// We require while body ROOTs to be the last in the schedule.
CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
instruction_schedule.at(use.instruction))
instruction_schedule.at(hlo_use.instruction))
<< "While body ROOTs need to be the last in the schedule! "
"Please run RootInstructionSinker.";
// Replace the use time with the parameter time so that we can decide
@ -1019,11 +1068,11 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
// look at uses within the while loop body.
use_time =
instruction_schedule.at(while_body->parameter_instruction(0));
} else if (use.instruction->opcode() == HloOpcode::kConditional) {
} else if (hlo_use.instruction->opcode() == HloOpcode::kConditional) {
// Replace the use time with the earliest parameter of called
// computations.
for (const HloComputation* called_computation :
use.instruction->called_computations()) {
hlo_use.instruction->called_computations()) {
use_time = std::min(
use_time, instruction_schedule.at(
called_computation->parameter_instruction(0)));
@ -1033,8 +1082,8 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
// Add a required assignment in default memory if the use not allowed in
// alternate memory.
if (!IsUseAllowedInAlternateMemory(allocation_value, use)) {
AddRequiredAssignment(allocation_value.value(), use.instruction,
if (!IsUseAllowedInAlternateMemory(allocation_value, hlo_use)) {
AddRequiredAssignment(allocation_value.value(), hlo_use.instruction,
MemorySpace::kDefault, use_time);
} else if (use_idx > 0) {
// We allow buffers in alternate memory that are passed into
@ -1043,14 +1092,16 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
// alternate memory allocation, subsequent uses cannot use the same
// alternate memory allocation in order not to clobber data. So we force
// default memory allocation for these subsequent uses.
const HloUse& previous_use = allocation_value.uses().at(use_idx - 1);
if (previous_use.instruction->opcode() == HloOpcode::kConditional &&
previous_use.instruction != use.instruction) {
const AllocationValue::Use& previous_use =
allocation_value.uses().at(use_idx - 1);
if (previous_use.hlo_use.instruction->opcode() ==
HloOpcode::kConditional &&
previous_use.hlo_use.instruction != hlo_use.instruction) {
allow_no_copy_alternate_mem_allocation = false;
earliest_prefetch_time =
instruction_schedule.at(previous_use.instruction);
VLOG(3) << "Previous use (" << previous_use.ToString() << ") of use ("
<< use.ToString()
instruction_schedule.at(previous_use.hlo_use.instruction);
VLOG(3) << "Previous use (" << previous_use.hlo_use.ToString()
<< ") of use (" << hlo_use.ToString()
<< ") is a conditional, so this use will need to evict. "
<< "Earliest prefetch time = " << *earliest_prefetch_time;
}
@ -1059,7 +1110,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
// Bitcasts don't define buffers and don't directly consume buffers. Skip
// allocating buffers for bitcast uses. The uses that feed from bitcasts
// will be handled specially.
if (use.instruction->opcode() != HloOpcode::kBitcast) {
if (hlo_use.instruction->opcode() != HloOpcode::kBitcast) {
AllocationRequest request;
// Rarely, (e.g., when conditional true and false parameters are the
// same), definition time can be the time of the conditional and use
@ -1072,7 +1123,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
allow_no_copy_alternate_mem_allocation;
request.earliest_prefetch_time = earliest_prefetch_time;
request.preferred_offset = preferred_offset;
request.use = use;
request.use = &use;
request.allocation_value = &allocation_value;
if (!AllocateSegment(request)) {
// If the allocation finding failed (e.g., due to running out of
@ -1085,23 +1136,25 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
// If there are multiple uses, they can try using the memory allocation
// already at the alternate memory.
definition_time = instruction_schedule.at(use.instruction);
definition_time = instruction_schedule.at(hlo_use.instruction);
}
// If the use has been a sequential call (e.g. a while loop), the other
// colocated intervals must alias with this allocation.
if (is_sequential_call) {
MemorySpaceAssignment::Allocation* aliased_allocation =
GetLiveAllocationAt(*allocation_value.allocation_sequence(),
use_time);
AddAliasedRequiredAssignmentsForSequentialCall(use, aliased_allocation);
// Remember the preferred offset to be used inside while loop body
// computations.
if (aliased_allocation->memory_space() == MemorySpace::kAlternate &&
use.instruction->opcode() == HloOpcode::kWhile) {
preferred_offset_for_computation[use.instruction->while_body()] =
aliased_allocation->chunk().offset;
}
// Propagate the allocation to any aliases this use might have had.
MemorySpaceAssignment::Allocation* aliased_allocation =
GetLiveAllocationAt(*allocation_value.allocation_sequence(),
use_time);
for (const HloPosition& aliased_position : use.aliases) {
AddAliasedRequiredAssignment(aliased_position.instruction,
aliased_position.index,
aliased_allocation);
}
// Special case for while loops since the root offset must agree with
// other offsets: remember the preferred offset for the while loop body.
if (hlo_use.instruction->opcode() == HloOpcode::kWhile &&
aliased_allocation->memory_space() == MemorySpace::kAlternate) {
preferred_offset_for_computation[hlo_use.instruction->while_body()] =
aliased_allocation->chunk().offset;
}
}
if (!allocation_success) {
@ -1212,34 +1265,45 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
pending_required_assignments_.clear();
}
void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignmentsForSequentialCall(
const HloUse& use,
const MemorySpaceAssignment::Allocation* aliased_allocation) {
// Add aliased required assignments.
if (use.instruction->opcode() == HloOpcode::kWhile) {
HloComputation* while_body = use.instruction->while_body();
HloComputation* while_condition = use.instruction->while_condition();
AddAliasedRequiredAssignment(while_condition->parameter_instruction(0),
use.operand_index, aliased_allocation);
AddAliasedRequiredAssignment(while_body->parameter_instruction(0),
use.operand_index, aliased_allocation);
AddAliasedRequiredAssignment(while_body->root_instruction(),
use.operand_index, aliased_allocation);
AddAliasedRequiredAssignment(use.instruction, use.operand_index,
aliased_allocation);
} else if (use.instruction->opcode() == HloOpcode::kConditional) {
HloComputation* called_computation =
use.instruction->called_computations().at(use.operand_number - 1);
AddAliasedRequiredAssignment(called_computation->parameter_instruction(0),
use.operand_index, aliased_allocation);
} else {
CHECK(use.instruction->opcode() == HloOpcode::kCall);
HloComputation* called_computation =
use.instruction->called_computations().at(0);
AddAliasedRequiredAssignment(
called_computation->parameter_instruction(use.operand_number),
use.operand_index, aliased_allocation);
absl::optional<RequiredMemoryAssignment>
AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
int64 time) const {
auto required_assignment_it = required_assignments_.find(buffer);
absl::optional<RequiredMemoryAssignment> required_assignment_at_time;
if (required_assignment_it != required_assignments_.end()) {
for (const RequiredMemoryAssignment& required_assignment :
required_assignment_it->second) {
if (required_assignment.time == time) {
// Sanity check that there is only one required assignment at this time.
CHECK(!required_assignment_at_time);
required_assignment_at_time = required_assignment;
}
}
}
return required_assignment_at_time;
}
absl::optional<RequiredMemoryAssignment>
AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse(
const AllocationValue::Use& use) const {
absl::optional<RequiredMemoryAssignment> required_assignment;
for (const HloPosition& position : use.aliases) {
const HloValue* value =
&alias_analysis_.dataflow_analysis().GetUniqueValueAt(
position.instruction, position.index);
int64 time =
hlo_live_range_.instruction_schedule().at(position.instruction);
absl::optional<RequiredMemoryAssignment> required_assignment_for_alias =
RequiredMemoryAssignmentAt(value, time);
if (required_assignment == absl::nullopt) {
required_assignment = required_assignment_for_alias;
} else {
CHECK(required_assignment_for_alias == absl::nullopt ||
required_assignment->equals_ignoring_time(
*required_assignment_for_alias));
}
}
return required_assignment;
}
void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment(
@ -1429,24 +1493,6 @@ void AlternateMemoryBestFitHeap::AddToPendingChunks(
CommitChunk(buffer_interval, chunk_candidate);
}
absl::optional<RequiredMemoryAssignment>
AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
int64 time) const {
auto required_assignment_it = required_assignments_.find(buffer);
absl::optional<RequiredMemoryAssignment> required_assignment_at_time;
if (required_assignment_it != required_assignments_.end()) {
for (const RequiredMemoryAssignment& required_assignment :
required_assignment_it->second) {
if (required_assignment.time == time) {
// Sanity check that there is only one required at time.
CHECK(!required_assignment_at_time);
required_assignment_at_time = required_assignment;
}
}
}
return required_assignment_at_time;
}
bool AlternateMemoryBestFitHeap::AllocateSegment(
const AllocationRequest& request) {
auto allocation_sequence = request.allocation_value->allocation_sequence();
@ -1457,7 +1503,7 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
MemorySpaceAssignment::Allocation* allocation =
GetLiveAllocationAt(*allocation_sequence, request.end_time);
CHECK_NE(allocation, nullptr);
allocation->AddUse(request.use);
allocation->AddUse(request.use->hlo_use);
return true;
}
@ -1467,8 +1513,9 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
<< request.allocation_value->ToShortString() << " ("
<< request.start_time << ", " << request.end_time
<< ") latest prefetch = " << request.latest_prefetch_time
<< " last use = " << request.allocation_value->use_times().back()
<< " use = " << request.use.ToString() << ". Size = " << request.size
<< " last use = " << request.allocation_value->uses().back().time
<< " use = " << request.use->hlo_use.ToString()
<< ". Size = " << request.size
<< ", def pos = " << defining_position.ToString();
CHECK_LE(request.start_time, request.end_time);
@ -1483,8 +1530,21 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
if (required_assignment_at_start) {
required_memory_space_at_start = required_assignment_at_start->memory_space;
}
// Find required assignment both for the use and its aliases. If they are both
// non-nullopt, then make sure they require the same assignment.
auto required_assignment_at_end = RequiredMemoryAssignmentAt(
request.allocation_value->value(), request.end_time);
auto aliased_required_assignment_at_end =
AliasedRequiredAssignmentForUse(*request.use);
if (required_assignment_at_end != aliased_required_assignment_at_end) {
if (required_assignment_at_end == absl::nullopt) {
required_assignment_at_end = aliased_required_assignment_at_end;
} else {
CHECK(aliased_required_assignment_at_end == absl::nullopt ||
aliased_required_assignment_at_end->equals_ignoring_time(
*required_assignment_at_end));
}
}
absl::optional<MemorySpace> required_memory_space_at_end;
if (required_assignment_at_end) {
required_memory_space_at_end = required_assignment_at_end->memory_space;
@ -1553,7 +1613,7 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
VLOG(3)
<< "Not trying to prefetch because use requires buffer in default mem.";
(*prev_allocation_in_default_mem_it)->Extend(request.end_time);
(*prev_allocation_in_default_mem_it)->AddUse(request.use);
(*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
return true;
}
@ -1577,7 +1637,7 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
// If a copy wasn't inserted, then add this use to the latest allocation in
// default memory.
(*prev_allocation_in_default_mem_it)->Extend(request.end_time);
(*prev_allocation_in_default_mem_it)->AddUse(request.use);
(*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
return true;
}
@ -1746,7 +1806,7 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy(
chunk_candidate->chunk, request.start_time, request.end_time));
}
request.allocation_value->allocation_sequence()->back()->AddUse(
request.use);
request.use->hlo_use);
return true;
}
return false;
@ -1833,7 +1893,7 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) {
if (!eviction_scheduled) {
// If the eviction couldn't be scheduled, then fail. This buffer will be
// kept in the default memory.
VLOG(3) << "Bailing: Could not evict " << request.use.ToString()
VLOG(3) << "Bailing: Could not evict " << request.use->hlo_use.ToString()
<< " because we hit the limit of maximum asynchronous copies "
<< "between "
<< hlo_live_range_.flattened_instruction_sequence()
@ -1868,7 +1928,8 @@ bool AlternateMemoryBestFitHeap::Prefetch(
earliest_prefetch_time =
std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
}
options_.prefetch_interval_picker->Begin(request.use, earliest_prefetch_time,
options_.prefetch_interval_picker->Begin(request.use->hlo_use,
earliest_prefetch_time,
request.latest_prefetch_time);
VLOG(3) << "Trying prefetch picker = "
<< options_.prefetch_interval_picker->ToDebugString();
@ -1922,7 +1983,7 @@ bool AlternateMemoryBestFitHeap::Prefetch(
request.allocation_value->allocation_sequence());
request.allocation_value->allocation_sequence()->back()->AddUse(
request.use);
request.use->hlo_use);
prefetch_failed_due_to_async_copy_ = false;
return true;
}
@ -1938,11 +1999,11 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate(
if (!preferred_offset) {
// Find a chunk that's as long living as possible iterating in reverse over
// the use times.
for (auto use_time = request.allocation_value->use_times().rbegin();
use_time != request.allocation_value->use_times().rend() &&
*use_time >= end_time;
++use_time) {
alternate_mem_interval->end = *use_time;
for (auto use_it = request.allocation_value->uses().rbegin();
use_it != request.allocation_value->uses().rend() &&
use_it->time >= end_time;
++use_it) {
alternate_mem_interval->end = use_it->time;
ChunkCandidate chunk_candidate =
FindChunkCandidate(*alternate_mem_interval);
if (chunk_candidate.heap_size <= available_heap_size()) {

View File

@ -620,6 +620,18 @@ class MemorySpaceAssignment {
// add.5, operand 0
class AllocationValue {
public:
// This data structure wraps an HloUse and adds additional metadata that are
// useful for allocation.
struct Use {
// The wrapped HloUse object.
HloUse hlo_use;
// The logical time this use is scheduled.
int64 time;
// All the positions that this use aliases with. The aliased positions
// must get the same allocation.
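// (For example, the parameter and root of a while body computation that
// consumes this value.)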
std::vector<HloPosition> aliases;
};
AllocationValue(const HloValue* value, const HloPosition& position)
: value_(value), defining_position_(position) {}
@ -627,8 +639,8 @@ class MemorySpaceAssignment {
const HloInstruction* defining_instruction() const {
return defining_position().instruction;
}
const std::vector<HloUse>& uses() const { return uses_; }
const std::vector<int64>& use_times() const { return use_times_; }
const std::vector<Use>& uses() const { return uses_; }
std::vector<Use>& uses() { return uses_; }
const HloValue* value() const { return value_; }
const HloComputation* computation() const {
return defining_instruction()->parent();
@ -636,8 +648,7 @@ class MemorySpaceAssignment {
AllocationSequence* allocation_sequence() { return &allocation_sequence_; }
void AddUse(const HloUse& use, int64 use_time) {
uses_.push_back(use);
use_times_.push_back(use_time);
uses_.push_back({use, use_time, {}});
}
std::string ToString() const;
@ -646,8 +657,7 @@ class MemorySpaceAssignment {
private:
const HloValue* value_;
HloPosition defining_position_;
std::vector<HloUse> uses_;
std::vector<int64> use_times_;
std::vector<Use> uses_;
AllocationSequence allocation_sequence_;
};
@ -769,10 +779,18 @@ struct RequiredMemoryAssignment {
int64 time;
absl::optional<HeapSimulator::Chunk> chunk;
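// Equality that ignores `time`; used when checking that aliased required
// assignments, which may occur at different schedule times, agree on memory
// space and chunk.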
bool equals_ignoring_time(const RequiredMemoryAssignment& other) const {
return memory_space == other.memory_space && chunk == other.chunk;
}
bool operator==(const RequiredMemoryAssignment& other) const {
return memory_space == other.memory_space && time == other.time &&
chunk == other.chunk;
}
bool operator!=(const RequiredMemoryAssignment& other) const {
return !(*this == other);
}
};
// A struct representing an asynchronous copy with its logical start and end
@ -880,7 +898,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
bool allow_no_copy_alternate_mem_allocation;
absl::optional<int64> earliest_prefetch_time;
absl::optional<int64> preferred_offset;
HloUse use;
const MemorySpaceAssignment::AllocationValue::Use* use;
MemorySpaceAssignment::AllocationValue* allocation_value;
};
@ -890,10 +908,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);
// Returns the required assignment at a particular time, if available.
absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
const HloValue* buffer, int64 time) const;
// Returns true if this buffer is allowed to be placed in the alternate
// memory.
bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
@ -914,6 +928,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
bool AllocateColocatedIntervals(
const std::vector<const BufferInterval*>& colocated_intervals);
// Go through all the uses in the AllocationValues and find the aliasing
// positions.
void FindAliases(std::vector<AllocationValue>* allocation_values) const;
// Finds an allocation for an allocation request for a segment (see the
// documentation for AllocationRequest above how a segment is defined).
//
@ -950,12 +968,14 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
const AllocationRequest& request, absl::optional<int64> preferred_offset,
BufferInterval* alternate_mem_interval) const;
// At the end of an allocation with a sequential call (while, conditional, and
// call), this function adds the necessary aliased assignments within the
// called computations.
void AddAliasedRequiredAssignmentsForSequentialCall(
const HloUse& use,
const MemorySpaceAssignment::Allocation* aliased_allocation);
// Returns the required assignment at a particular time, if available.
absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
const HloValue* buffer, int64 time) const;
// Searches for aliases in the use for a required assignment, and returns it
// if found.
absl::optional<RequiredMemoryAssignment> AliasedRequiredAssignmentForUse(
const AllocationValue::Use& use) const;
// Propagates aliased required assignment for a given position.
void AddAliasedRequiredAssignment(

View File

@ -1635,7 +1635,8 @@ TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) {
%constant.5 = s32[1]{0:T(128)} constant({1})
%prev.4 = s32[6]{0:T(128)} parameter(0)
%rng.8 = s32[5]{0:T(128)} rng(s32[]{:T(128)} %constant.6, s32[]{:T(128)} %constant.7), distribution=rng_uniform
ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %constant.5, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation
%neg = s32[1]{0:T(128)} negate(s32[1]{0:T(128)} %constant.5)
ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %neg, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation
}
%WhileWithPrngScalarResult.11 (prev.12: s32[6]) -> pred[] {
@ -1665,6 +1666,62 @@ TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) {
kDefaultMemorySpace);
}
TEST_P(MemorySpaceAssignmentTest, WhileInPlaceBuffer) {
// Ensure that a dynamic update slice within a while loop is able to get an
// alternate memory allocation.
absl::string_view hlo_string = R"(
HloModule Module, is_scheduled=true
fused_computation {
param0 = f32[2,3] parameter(0)
constant.1 = f32[] constant(0)
broadcast = f32[2,1] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
}
%WhileBody (body_param: (f32[2,3], f32[2,3], f32[])) -> (f32[2,3], f32[2,3], f32[]) {
%body_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
%get-tuple-element.1 = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=2
%get-tuple-element.2 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=0
%get-tuple-element.3 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=1
%fusion = f32[2,3]{1,0} fusion(get-tuple-element.3), kind=kLoop, calls=fused_computation
%multiply = f32[2,3]{1,0} multiply(f32[2,3]{1,0} %get-tuple-element.2, f32[2,3]{1,0} %fusion)
ROOT %tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} %multiply, f32[2,3]{1,0} %fusion, f32[] %get-tuple-element.1)
}
%WhileCond (cond_param: (f32[2,3], f32[2,3], f32[])) -> pred[] {
%cond_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
%get-tuple-element = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %cond_param), index=2
%constant = f32[] constant(50)
ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
}
ENTRY %Entry (param_data: f32[2,3], param_iter: f32[], p2: f32[2,3]) -> f32[2,3] {
%param_iter = f32[] parameter(1)
%param_data = f32[2,3]{1,0} parameter(0)
%p2 = f32[2,3]{1,0} parameter(2)
%copy1 = f32[2,3]{1,0} copy(param_data)
%copy2 = f32[2,3]{1,0} copy(p2)
%tuple.1 = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} copy1, f32[2,3]{1,0} copy2, f32[] %param_iter)
%while = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) while((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
%get-tuple-element.4 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %while), index=0
ROOT %copy3 = f32[2,3]{1,0} copy(get-tuple-element.4)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AssignMemorySpace(module.get());
const HloInstruction* while_op =
module->entry_computation()->GetInstructionWithName("while");
if (GetParam()) {
EXPECT_EQ(
ShapeUtil::GetSubshape(while_op->shape(), {1}).layout().memory_space(),
kAlternateMemorySpace);
}
}
TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
// Having control_predecessors on an HLO was preventing us from DCEing an op
// that doesn't have any users (tuple.1). The scheduler assumes the graph is

View File

@ -413,6 +413,15 @@ absl::optional<HloInstruction*> ExchangeHalo(
std::vector<HloInstruction*> concat_pieces;
int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
int64 max_right_halo_size =
right_halo_size_function.MaxInRange(0, shard_count - 1);
if (max_left_halo_size + max_right_halo_size + input_shard_size >=
input_shard_size * shard_count &&
(max_left_halo_size > input_shard_size ||
max_right_halo_size > input_shard_size)) {
return absl::nullopt;
}
// Left halo.
for (int64 i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1; i >= 0;
--i) {
std::vector<std::pair<int64, int64>> source_target_pairs;
@ -447,8 +456,6 @@ absl::optional<HloInstruction*> ExchangeHalo(
concat_pieces.push_back(hlo);
// Right halo.
int64 max_right_halo_size =
right_halo_size_function.MaxInRange(0, shard_count - 1);
for (int64 i = 0; i < CeilOfRatio(max_right_halo_size, input_shard_size);
++i) {
std::vector<std::pair<int64, int64>> source_target_pairs;
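A minimal standalone sketch of the bail-out condition that ExchangeHalo now applies; the helper name and the example shard sizes are illustrative, not taken from the diff:

#include <cstdint>
#include <iostream>

// Returns true when a plain halo exchange is not worthwhile: the two halos
// plus one shard already cover the whole dimension, and at least one halo is
// wider than a single shard.
bool HaloExchangeNotViable(int64_t max_left_halo, int64_t max_right_halo,
                           int64_t input_shard_size, int64_t shard_count) {
  return max_left_halo + max_right_halo + input_shard_size >=
             input_shard_size * shard_count &&
         (max_left_halo > input_shard_size ||
          max_right_halo > input_shard_size);
}

int main() {
  // 4 shards of size 8: a left halo of 10 spans more than one shard and,
  // together with a right halo of 15, covers the whole 32-element dimension.
  std::cout << std::boolalpha
            << HaloExchangeNotViable(/*max_left_halo=*/10,
                                     /*max_right_halo=*/15,
                                     /*input_shard_size=*/8,
                                     /*shard_count=*/4)
            << "\n";  // prints true
  return 0;
}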

View File

@ -33,6 +33,36 @@ namespace xla {
TupleSimplifier::TupleSimplifier(bool exclude_entry_computation)
: exclude_entry_computation_(exclude_entry_computation) {}
StatusOr<bool> TupleSimplifier::RemoveWholeTuple(HloInstruction* tuple) {
bool changed = false;
HloInstruction* top_tuple = nullptr;
bool can_simplify = true;
for (int64 operand_number = 0; operand_number < tuple->operand_count();
++operand_number) {
HloInstruction* operand = tuple->mutable_operand(operand_number);
if (operand->opcode() != HloOpcode::kGetTupleElement ||
operand->tuple_index() != operand_number) {
can_simplify = false;
break;
}
if (top_tuple == nullptr) {
top_tuple = operand->mutable_operand(0);
if (!ShapeUtil::Compatible(top_tuple->shape(), tuple->shape())) {
can_simplify = false;
break;
}
} else if (top_tuple != operand->operand(0)) {
can_simplify = false;
break;
}
}
if (can_simplify && top_tuple != nullptr) {
changed = true;
TF_RETURN_IF_ERROR(tuple->parent()->ReplaceInstruction(tuple, top_tuple));
}
return changed;
}
StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
// Initially add all GTE and Tuple instructions to the worklist.
bool changed = false;
@ -43,46 +73,7 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
}
for (auto* instruction : computation->MakeInstructionPostOrder()) {
if (instruction->opcode() == HloOpcode::kTuple) {
// Collapse the following structure into just 'Tuple-shaped Op':
//
// Tuple-shaped Op
// |
// +-----+-----+
// | | |
// GTE GTE GTE
// | | |
// +-----+-----+
// |
// Tuple
//
HloInstruction* top_tuple = nullptr;
bool can_simplify = true;
for (int64 operand_number = 0;
operand_number < instruction->operand_count(); ++operand_number) {
HloInstruction* operand =
instruction->mutable_operand(operand_number);
if (operand->opcode() != HloOpcode::kGetTupleElement ||
operand->tuple_index() != operand_number) {
can_simplify = false;
break;
}
if (top_tuple == nullptr) {
top_tuple = operand->mutable_operand(0);
if (!ShapeUtil::Compatible(top_tuple->shape(),
instruction->shape())) {
can_simplify = false;
break;
}
} else if (top_tuple != operand->operand(0)) {
can_simplify = false;
break;
}
}
if (can_simplify && top_tuple != nullptr) {
changed = true;
TF_RETURN_IF_ERROR(
computation->ReplaceInstruction(instruction, top_tuple));
}
TF_ASSIGN_OR_RETURN(changed, RemoveWholeTuple(instruction));
} else {
auto ancestor = instruction->LatestNonGteAncestorAndIndex();
if (ancestor.first == instruction) {
@ -102,6 +93,11 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
// GTE
// |
// GTE
//
// Note that this deletes the Tuple instruction altogether. In addition,
// if only a subset of the tuple's elements is used, this transform
// optimizes them one at a time, and after the last use is optimized,
// the Tuple will also be deleted.
if (ShapeUtil::Compatible(ancestor.first->shape(),
instruction->shape())) {
changed = true;

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@ -41,6 +42,20 @@ class TupleSimplifier : public HloModulePass {
// apart from the module's entry computation. This is used by Graphcore's
// backend.
bool exclude_entry_computation_;
// Collapse the following structure into just 'Tuple-shaped Op':
//
// Tuple-shaped Op
// |
// +-----+-----+
// | | |
// GTE GTE GTE
// | | |
// +-----+-----+
// |
// Tuple
//
StatusOr<bool> RemoveWholeTuple(HloInstruction* tuple);
};
} // namespace xla
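A test-style sketch of the whole-tuple collapse that RemoveWholeTuple performs, following the HLO-string test pattern used elsewhere in this diff; the TupleSimplifierTest fixture and the ParseAndReturnVerifiedModule helper are assumed to be available:

// Tuple(GTE(p, 0), GTE(p, 1)) with matching indices and a compatible shape
// collapses back to the producing tuple-shaped op, here the parameter p.
TEST_F(TupleSimplifierTest, CollapsesWholeTupleSketch) {
  absl::string_view hlo_string = R"(
HloModule m
ENTRY e {
  p = (f32[2]{0}, f32[3]{0}) parameter(0)
  gte0 = f32[2]{0} get-tuple-element(p), index=0
  gte1 = f32[3]{0} get-tuple-element(p), index=1
  ROOT t = (f32[2]{0}, f32[3]{0}) tuple(gte0, gte1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TupleSimplifier simplifier(/*exclude_entry_computation=*/false);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  // After the pass, the entry root is the tuple-shaped parameter itself.
  EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
            HloOpcode::kParameter);
}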

View File

@ -216,11 +216,8 @@ TEST_F(BufferDonationTest, SimpleWhileTupleTest) {
HloInstruction::CreateGetTupleElement(f32v1_, while0, 1));
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
module->AddEntryComputation(builder.Build());
// Input output aliasing is only supported on TPU.
#if defined(XLA_TEST_BACKEND_TPU)
TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));
#endif
auto arg = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<int>(0), LiteralUtil::CreateR1<float>({1.1f})});

View File

@ -1361,6 +1361,7 @@ cc_library(
hdrs = ["replicate_per_replica_nodes.h"],
copts = tf_copts(),
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -666,7 +666,9 @@ Status DirectSession::RunInternal(
std::unique_ptr<ProfilerSession> profiler_session;
if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
profiler_session = ProfilerSession::Create();
ProfileOptions options = ProfilerSession::DefaultOptions();
options.set_host_tracer_level(0);
profiler_session = ProfilerSession::Create(options);
}
// Register this step with session's cancellation manager, so that

View File

@ -847,7 +847,7 @@ Status VerifyVirtualDeviceSettings(
" #valid GPUs: ", valid_platform_gpu_ids.size(),
" virtual_devices.size(): ", virtual_devices.size());
}
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Check that memory_limit_mb and priority sizes match if priority is non-empty.
bool priority_exists = !virtual_devices.Get(0).priority().empty();
for (int i = 0; i < virtual_devices.size(); ++i) {
@ -893,15 +893,6 @@ Status VerifyVirtualDeviceSettings(
}
}
}
#elif TENSORFLOW_USE_ROCM
for (int i = 0; i < virtual_devices.size(); ++i) {
if (!virtual_devices.Get(i).priority().empty()) {
return errors::InvalidArgument(
"Priority is supported only on Nvidia GPUs."
" However, priority is set for virtual device ",
i, ", which corresponds to a non Nvidia GPU");
}
}
#endif
return Status::OK();
@ -1185,6 +1176,18 @@ Status BaseGPUDeviceFactory::CreateDevices(
platform_gpu_id.value(),
" failed. Status: ", hipGetErrorString(err));
}
int priority_low, priority_high;
err = hipDeviceGetStreamPriorityRange(&priority_low, &priority_high);
if (err != hipSuccess) {
return errors::Internal(
"hipDeviceGetStreamPriorityRange() on GPU:", original_device,
" failed. Status: ", hipGetErrorString(err));
}
VLOG(1) << "HIP stream priority range on GPU(" << original_device
<< "): " << priority_high << "," << priority_low;
supported_priority_ranges.insert(
std::make_pair(platform_gpu_id.value(),
std::make_pair(priority_low, priority_high)));
#endif
}
// Reset to the original device.
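For reference, a self-contained sketch (ROCm builds only) of the stream-priority query used above, with the HIP error code checked on the call itself; the helper name and device index are illustrative:

#include <iostream>
#include <hip/hip_runtime.h>

// Queries the least/greatest stream priorities supported by `device`.
bool QueryStreamPriorityRange(int device, int* least, int* greatest) {
  if (hipSetDevice(device) != hipSuccess) return false;
  return hipDeviceGetStreamPriorityRange(least, greatest) == hipSuccess;
}

int main() {
  int least = 0, greatest = 0;
  if (QueryStreamPriorityRange(/*device=*/0, &least, &greatest)) {
    std::cout << "HIP stream priority range: " << greatest << "," << least
              << "\n";
  }
  return 0;
}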

View File

@ -229,52 +229,89 @@ TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimitAndNoPriority) {
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithInvalidPriority) {
{
// Priority outside the range (-1, 0).
#if TENSORFLOW_USE_ROCM
// Priority outside the range (0, 2) for AMD GPUs
SessionOptions opts =
MakeSessionOptions("0", 0, 1, {{123, 456}}, {{-1, 2}});
#else
// Priority outside the range (-1, 0) for NVidia GPUs
SessionOptions opts =
MakeSessionOptions("0", 0, 1, {{123, 456}}, {{-2, 0}});
#endif
std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
#if TENSORFLOW_USE_ROCM
ExpectErrorMessageSubstr(
status,
"Priority -1 is outside the range of supported priorities [0,2] for"
" virtual device 0 on GPU# 0");
#else
ExpectErrorMessageSubstr(
status,
"Priority -2 is outside the range of supported priorities [-1,0] for"
" virtual device 0 on GPU# 0");
#endif
}
{
// Priority outside the range (-1, 0).
#if TENSORFLOW_USE_ROCM
// Priority outside the range (0, 2) for AMD GPUs
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{0, 3}});
#else
// Priority outside the range (-1, 0) for NVidia GPUs
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{0, 1}});
#endif
std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
#if TENSORFLOW_USE_ROCM
ExpectErrorMessageSubstr(
status,
"Priority 3 is outside the range of supported priorities [0,2] for"
" virtual device 0 on GPU# 0");
#else
ExpectErrorMessageSubstr(
status,
"Priority 1 is outside the range of supported priorities [-1,0] for"
" virtual device 0 on GPU# 0");
#endif
}
}
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimitAndPriority) {
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123}}, {{-1}});
// 0 is a valid priority value for both AMD and NVidia GPUs
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123}}, {{0}});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(1, devices.size());
EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit());
EXPECT_EQ(-1, static_cast<BaseGPUDevice*>(devices[0].get())->priority());
EXPECT_EQ(0, static_cast<BaseGPUDevice*>(devices[0].get())->priority());
}
TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
#if TENSORFLOW_USE_ROCM
// Valid range for priority values on AMD GPUs is (0, 2)
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{0, 1}});
#else
// Valid range for priority values on NVidia GPUs is (-1, 0)
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{0, -1}});
#endif
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(2, devices.size());
EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit());
EXPECT_EQ(456 << 20, devices[1]->attributes().memory_limit());
#if TENSORFLOW_USE_ROCM
EXPECT_EQ(0, static_cast<BaseGPUDevice*>(devices[0].get())->priority());
EXPECT_EQ(1, static_cast<BaseGPUDevice*>(devices[1].get())->priority());
#else
EXPECT_EQ(0, static_cast<BaseGPUDevice*>(devices[0].get())->priority());
EXPECT_EQ(-1, static_cast<BaseGPUDevice*>(devices[1].get())->priority());
#endif
ASSERT_EQ(1, devices[0]->attributes().locality().links().link_size());
ASSERT_EQ(1, devices[1]->attributes().locality().links().link_size());
EXPECT_EQ(1, devices[0]->attributes().locality().links().link(0).device_id());
@ -292,7 +329,8 @@ TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
TEST_F(GPUDeviceTest, MultipleVirtualDevicesWithPriority) {
{
// Multiple virtual devices with fewer priorities than devices.
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{-1}});
// 0 is a valid priority value for both AMD and NVidia GPUs
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{0}});
std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
@ -305,16 +343,27 @@ TEST_F(GPUDeviceTest, MultipleVirtualDevicesWithPriority) {
}
{
// Multiple virtual devices with matching priorities.
#if TENSORFLOW_USE_ROCM
// Valid range for priority values on AMD GPUs is (0, 2)
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{2, 1}});
#else
// Valid range for priority values on NVidia GPUs is (-1, 0)
SessionOptions opts =
MakeSessionOptions("0", 0, 1, {{123, 456}}, {{-1, 0}});
#endif
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(2, devices.size());
EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit());
EXPECT_EQ(456 << 20, devices[1]->attributes().memory_limit());
#if TENSORFLOW_USE_ROCM
EXPECT_EQ(2, static_cast<BaseGPUDevice*>(devices[0].get())->priority());
EXPECT_EQ(1, static_cast<BaseGPUDevice*>(devices[1].get())->priority());
#else
EXPECT_EQ(-1, static_cast<BaseGPUDevice*>(devices[0].get())->priority());
EXPECT_EQ(0, static_cast<BaseGPUDevice*>(devices[1].get())->priority());
#endif
}
}
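A hedged sketch of the kind of configuration a helper like MakeSessionOptions presumably assembles here: one visible physical GPU split into two virtual devices with per-device memory limits and stream priorities. Field names follow GPUOptions::Experimental::VirtualDevices; the priority value 0 is valid on both NVidia and AMD GPUs.

#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"

tensorflow::SessionOptions TwoVirtualGpus() {
  tensorflow::SessionOptions opts;
  auto* gpu_options = opts.config.mutable_gpu_options();
  gpu_options->set_visible_device_list("0");
  // One VirtualDevices entry per physical GPU; each memory_limit_mb/priority
  // pair below describes one virtual device carved out of GPU 0.
  auto* vdev = gpu_options->mutable_experimental()->add_virtual_devices();
  vdev->add_memory_limit_mb(123.0f);
  vdev->add_priority(0);
  vdev->add_memory_limit_mb(456.0f);
  vdev->add_priority(0);
  return opts;
}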

View File

@ -394,10 +394,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.batch_matmul,
mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.batch_matmul_v2,
mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.concat, mkl_op_registry::GetMklOpName(csinfo_.concat),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});

File diff suppressed because it is too large

View File

@ -25,7 +25,7 @@ limitations under the License.
namespace tensorflow {
#ifdef _OPENMP
#if defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) {
SessionOptions options;
unsetenv("OMP_NUM_THREADS");
@ -46,7 +46,7 @@ TEST(MKLThreadPoolDeviceTest, TestOmpPreSets) {
EXPECT_EQ(omp_get_max_threads(), 314);
}
#endif // _OPENMP
#endif // defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
} // namespace tensorflow

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
namespace tensorflow {
namespace {
@ -115,12 +116,36 @@ class ReplicateHelper {
// This happens when the dst node runs on a host CPU and
// captures a function with an arg node assigned to the same
// composite device (e.g. ScanDataset).
// For this case, we only need to add an edge connecting the arg
// node in the outer function and the corresponding arg in the
// inner function, since the host CPU only needs one copy of the
// ResourceHandle.
graph->AddEdge(src_replicated_nodes.at(0), edge->src_output(), dst,
edge->dst_input());
// For this case, we insert a PackOp between the replicated nodes and the
// dst node. The dst node is responsible for unpacking the packed
// tensor.
// The name of the new node contains '/Packed', which is helpful when
// debugging the graph.
NodeDefBuilder pack_builder(
graph->NewName(absl::StrCat(edge->src()->name(), "/Packed")),
"Pack");
const int num_replicas = src_replicated_nodes.size();
pack_builder.Attr("N", num_replicas);
const DataType dtype = edge->src()->output_type(edge->src_output());
pack_builder.Attr("T", dtype);
std::vector<NodeDefBuilder::NodeOut> inputs;
inputs.reserve(src_replicated_nodes.size());
for (Node* replicated_node : src_replicated_nodes) {
inputs.emplace_back(NodeDefBuilder::NodeOut{
replicated_node->name(), edge->src_output(), dtype});
}
pack_builder.Input(inputs);
NodeDef pack_def;
TF_RETURN_IF_ERROR(pack_builder.Finalize(&pack_def));
Status status;
Node* pack_node = graph->AddNode(pack_def, &status);
TF_RETURN_IF_ERROR(status);
pack_node->set_assigned_device_name(dst->assigned_device_name());
for (int i = 0; i < src_replicated_nodes.size(); ++i) {
graph->AddEdge(src_replicated_nodes[i], edge->src_output(),
pack_node, i);
}
graph->AddEdge(pack_node, /*x=*/0, dst, edge->dst_input());
} else {
return errors::InvalidArgument(
"Dst node should be assigned to an allowed device. Found an "

View File

@ -258,16 +258,24 @@ TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));
{
// _Arg(TPU:0) -> Func(CPU:0) -> _Retval(CPU:0)
EXPECT_EQ(graph.num_op_nodes(), 4);
// _Arg(TPU:0), _Arg(TPU:1) -> Pack(CPU:0) -> Func(CPU:0) -> _Retval(CPU:0)
EXPECT_EQ(graph.num_op_nodes(), 5);
GraphHelper helper(graph);
helper.CheckAssignedDevice("arg/R0", "TPU:0");
helper.CheckAssignedDevice("arg/R1", "TPU:1");
helper.CheckAssignedDevice("arg/Packed", "CPU:0");
helper.CheckAssignedDevice("func", "CPU:0");
helper.CheckAssignedDevice("ret", "CPU:0");
const EdgeSet& in_edges = helper.GetNodeByName("func")->in_edges();
EXPECT_EQ(in_edges.size(), 1);
EXPECT_EQ(helper.GetNodeByName("arg/R0"), (*in_edges.begin())->src());
const EdgeSet& packed_in_edges =
helper.GetNodeByName("arg/Packed")->in_edges();
EXPECT_EQ(packed_in_edges.size(), 2);
auto it = packed_in_edges.begin();
EXPECT_EQ(helper.GetNodeByName("arg/R0"), (*it++)->src());
EXPECT_EQ(helper.GetNodeByName("arg/R1"), (*it)->src());
const EdgeSet& func_in_edges = helper.GetNodeByName("func")->in_edges();
EXPECT_EQ(func_in_edges.size(), 1);
EXPECT_EQ(helper.GetNodeByName("arg/Packed"),
(*func_in_edges.begin())->src());
}
}
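A self-contained sketch of the NodeDefBuilder pattern ReplicateHelper now uses to stack per-replica arguments into a single Pack node; the node names mirror the test above, and the float dtype is illustrative:

#include <vector>
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"

tensorflow::Status BuildPackNodeDef(tensorflow::NodeDef* pack_def) {
  using tensorflow::DT_FLOAT;
  using tensorflow::NodeDefBuilder;
  NodeDefBuilder builder("arg/Packed", "Pack");
  builder.Attr("N", 2);         // number of replicated inputs being packed
  builder.Attr("T", DT_FLOAT);  // dtype of the packed tensors (illustrative)
  std::vector<NodeDefBuilder::NodeOut> inputs;
  inputs.emplace_back("arg/R0", 0, DT_FLOAT);
  inputs.emplace_back("arg/R1", 0, DT_FLOAT);
  builder.Input(inputs);
  return builder.Finalize(pack_def);
}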

View File

@ -344,6 +344,8 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
static constexpr const char* const kDeviceRetOp = "_DeviceRetval";
static constexpr const char* const kIntsOnDeviceAttr =
"experimental_ints_on_device";
static constexpr const char* const kSharedRendezvousAttr =
"shared_rendezvous";
static constexpr const char* const kGradientOp = "SymbolicGradient";
static constexpr const char* const kFuncAttr = "f";

View File

@ -59,9 +59,13 @@ class InterleaveMany : public Node {
(*input_times)[long_name()] = old_input_time;
return;
}
double new_input_time =
old_input_time +
SelfProcessingTimeLocked() * static_cast<double>(num_inputs() - 1);
// Here `old_input_time + SelfProcessingTimeLocked()` is the average time it
// takes the interleave node to call one of its `(num_inputs() - 1)` input
// nodes (all inputs except the first one) and get an element back. Regardless
// of the interleave node's `block_length` parameter, the average input time
// for any one of those `(num_inputs() - 1)` input nodes to be called is:
double new_input_time = (old_input_time + SelfProcessingTimeLocked()) *
static_cast<double>(num_inputs() - 1);
(*input_times)[long_name()] = new_input_time;
}
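To make the formula concrete, a small standalone sketch with illustrative numbers (not taken from the model code):

#include <iostream>

// Average time for the interleave node to obtain one element from any of its
// non-first inputs, per the comment above.
double NewInputTime(double old_input_time, double self_processing_time,
                    int num_inputs) {
  return (old_input_time + self_processing_time) *
         static_cast<double>(num_inputs - 1);
}

int main() {
  // e.g. 2us upstream input time, 1us self-processing time, 4 inputs:
  // (2 + 1) * (4 - 1) = 9us on average per element from a non-first input.
  std::cout << NewInputTime(2.0, 1.0, 4) << "\n";  // prints 9
  return 0;
}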

Some files were not shown because too many files have changed in this diff