Merge remote-tracking branch 'upstream/master' into offline_memory_planner
commit dc3c76758e
@@ -1 +1 @@
3.0.0
3.1.0
@@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '2.0.0'
_TF_MIN_BAZEL_VERSION = '3.1.0'
_TF_MAX_BAZEL_VERSION = '3.99.0'

NCCL_LIB_PATHS = [
@@ -40,6 +40,9 @@ using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
using MaybeParallelTensorOwned =
    absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;

using MaybeParallelTensorUnowned =
    absl::variant<ParallelTensor*, TFE_TensorHandle*>;

// A ParallelDevice on its own is not registered with a TFE_Context, and so has
// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
@@ -141,9 +144,32 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
    result.emplace(std::move(result_content));
    return result;
  }
  std::vector<ParallelTensor*> parallel_inputs;
  std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
  parallel_inputs.reserve(inputs.size());
  implicitly_broadcast_tensors.reserve(inputs.size());  // not tight
  for (const auto& input : inputs) {
    if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
      // Non-parallel tensors are implicitly broadcast, i.e. set as the input
      // to each parallel operation.
      //
      // TODO(allenl): There may be smarter ways to do this copy in some
      // cases, i.e. with a collective broadcast. We'll need to be careful
      // about things that are taken as inputs on the host or on their
      // existing device (for multi-device functions).
      std::unique_ptr<ParallelTensor> parallel_tensor(
          parallel_device.CopyToParallelDevice(
              context, absl::get<TFE_TensorHandle*>(input), status));
      if (TF_GetCode(status) != TF_OK) return result;
      parallel_inputs.push_back(parallel_tensor.get());
      implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
    } else {
      parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
    }
  }
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
      maybe_parallel_results(
          parallel_device.Execute(context, std::move(inputs), operation_name,
          parallel_device.Execute(context, parallel_inputs, operation_name,
                                  attributes, expected_max_outputs, status));
  if (!maybe_parallel_results.has_value()) return result;
  std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
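The hunk above keeps two containers in step: `parallel_inputs` holds raw pointers for the call, while `implicitly_broadcast_tensors` owns the freshly created copies so they outlive `Execute`. A minimal standalone sketch of that ownership idiom (the names and placeholder type below are illustrative, not taken from the diff):

#include <memory>
#include <vector>

struct Tensor {};  // stands in for ParallelTensor

// The callee only ever sees raw, non-owning pointers.
void RunWithBorrowedViews(const std::vector<Tensor*>& views) { /* ... */ }

void Example(const std::vector<Tensor*>& existing) {
  std::vector<Tensor*> views(existing.begin(), existing.end());
  std::vector<std::unique_ptr<Tensor>> owned;    // keeps new copies alive
  owned.push_back(std::make_unique<Tensor>());   // stands in for CopyToParallelDevice
  views.push_back(owned.back().get());
  RunWithBorrowedViews(views);  // `owned` is destroyed only after the call returns
}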
@@ -100,7 +100,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(

absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
                        std::vector<MaybeParallelTensorUnowned> inputs,
                        const std::vector<ParallelTensor*>& inputs,
                        const char* operation_name,
                        const TFE_OpAttrs* attributes, int expected_max_outputs,
                        TF_Status* status) const {
@@ -129,26 +129,10 @@ ParallelDevice::Execute(TFE_Context* context,
                     status);
    TFE_OpAddAttrs(op.get(), attributes);
    for (int input_index = 0; input_index < inputs.size(); ++input_index) {
      if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
        // Non-parallel tensors are implicitly broadcast, i.e. set as the input
        // to each parallel operation.
        //
        // TODO(allenl): There may be smarter ways to do this copy in some
        // cases, i.e. with a collective broadcast. We'll need to be careful
        // about things that are taken as inputs on the host or on their
        // existing device (for multi-device functions).
        TFE_OpAddInput(op.get(),
                       absl::get<TFE_TensorHandle*>(inputs[input_index]),
                       status);
        if (TF_GetCode(status) != TF_OK) return result;
      } else {
        // Parallel tensors are divided between operations by device.
        TFE_OpAddInput(op.get(),
                       absl::get<ParallelTensor*>(inputs[input_index])
                           ->tensor(device_index),
                       status);
        if (TF_GetCode(status) != TF_OK) return result;
      }
      // Parallel tensors are divided between operations by device.
      TFE_OpAddInput(op.get(), inputs[input_index]->tensor(device_index),
                     status);
      if (TF_GetCode(status) != TF_OK) return result;
    }
    std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
    int real_num_outputs = expected_max_outputs;
@@ -52,9 +52,6 @@ using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;

class ParallelTensor;

using MaybeParallelTensorUnowned =
    absl::variant<ParallelTensor*, TFE_TensorHandle*>;

// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
class ParallelDevice {
@@ -79,10 +76,9 @@ class ParallelDevice {

  // Takes a description of a single operation being executed on the
  // ParallelDevice, and in turn runs one operation per component device with
  // its corresponding inputs from the input ParallelTensors (or
  // implicitly-mirrored tensors on other devices). Wraps the resulting
  // per-device and per-output TFE_TensorHandles into one ParallelTensor per
  // output of the original operation.
  // its corresponding inputs from the input ParallelTensors. Wraps the
  // resulting per-device and per-output TFE_TensorHandles into one
  // ParallelTensor per output of the original operation.
  //
  // Attributes are forwarded to executed operations unmodified.
  //
@@ -90,7 +86,7 @@ class ParallelDevice {
  // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
  // if sanity checks on dtypes/metadata fail.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
      TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
      TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
      const char* operation_name, const TFE_OpAttrs* attributes,
      int expected_max_outputs, TF_Status* status) const;
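The new declaration takes the operation's inputs as `ParallelTensor*` only; any implicit broadcasting now happens in the caller. A hedged usage fragment of the updated signature (the context, attributes, and status objects are assumed to exist in the surrounding code, as in ExecuteWithSpecialOps):

// Assumes: TFE_Context* context, std::vector<ParallelTensor*> inputs,
// const TFE_OpAttrs* attributes, TF_Status* status, ParallelDevice device.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> maybe_outputs =
    device.Execute(context, inputs, "Add", attributes,
                   /*expected_max_outputs=*/1, status);
if (TF_GetCode(status) != TF_OK || !maybe_outputs.has_value()) return;
// One ParallelTensor per output of the original operation.
std::vector<std::unique_ptr<ParallelTensor>> outputs = std::move(*maybe_outputs);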
@@ -468,10 +468,6 @@ Status XlaComputationLaunchContext::PopulateOutputs(
          << "Invalid input for outputs " << i << ": " << input_index;
      ctx->set_output(i, ctx->input(input_index));
    } else {
      if (MustAliasOutput(input_output_alias, output_num)) {
        DCHECK(output.buffer({output_num}).is_null())
            << "Expected output buffer to be aliased, but it is not nil.";
      }
      if (allocate_xla_tensors_) {
        TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
            input_output_alias, output_num, ctx, i, shape, &output,
@@ -32,7 +32,6 @@ struct PassConfig {
        lower_tensor_list_ops(false),
        trim_functions_whitelist({}),
        quant_specs(std::move(specs)),
        skip_control_dialect(false),
        form_clusters(false),
        unfold_batch_matmul(true),
        legalize_tf_while(true),
@@ -49,13 +48,8 @@ struct PassConfig {
  llvm::ArrayRef<std::string> trim_functions_whitelist;
  // All information about quantization.
  QuantizationSpecs quant_specs;
  // If `skip_control_dialect` is true, TF executor dialect is not converted to
  // TF control dialect prior to legalization to TF Lite.
  // TODO(b/142911013): Remove flag once control dialect is removed.
  bool skip_control_dialect;
  // If `form_clusters` is true (and `skip_control_dialect` is true), clusters
  // are formed by grouping consecutive ops of the same device, under a
  // `tf_device.launch` op.
  // If `form_clusters` is true, clusters are formed by grouping consecutive
  // ops of the same device, under a `tf_device.launch` op.
  bool form_clusters;
  // if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
  // of tfl.fully_connected ops.
@@ -803,9 +803,17 @@ StatusOr<FuncOp> ConvertSubgraph(
  }

  for (auto output : func_outputs) {
    bool is_constant = !is_op_output[output];
    const bool is_func_input = input_index_set.contains(output);
    bool is_constant = !is_op_output[output] && !is_func_input;
    // There are two cases where a tensor without a shape in the flatbuffer is
    // a scalar:
    // 1. `is_constant` = true, which means this tensor is created from a
    //    constant op.
    // 2. `is_func_input` = true and `is_entry_point` = true, which means this
    //    tensor is a function input and the function input type is a scalar
    //    tensor.
    const bool shapeless_is_scalar =
        is_constant || (is_func_input && is_entry_point);
    auto type_or_err = GetTensorType(*subgraph.tensors.at(output), builder,
                                     /*shapeless_are_scalars=*/is_constant,
                                     shapeless_is_scalar,
                                     /*is_constant=*/is_constant);
    if (!type_or_err.ok()) {
      emitError(func_loc, "error reading return types")
@@ -0,0 +1,22 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s

// This test checks an unranked function output that is forwarded from an input: the output type should be compatible with the input type.

// CHECK: func @main(%arg0: tensor<1xf32>) -> tensor<*xf32>
// CHECK: %0 = "tf.While"(%arg0) {body = @body, cond = @cond, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
// CHECK: return %0 : tensor<*xf32>
// CHECK: func @cond(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK: func @body(%arg0: tensor<*xf32>) -> tensor<*xf32>

func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> {
  %0 = "tf.While"(%arg0) {cond = @cond, body = @body, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

func @cond(%arg1: tensor<*xf32>) -> tensor<*xf32> {
  return %arg1: tensor<*xf32>
}

func @body(%arg1: tensor<*xf32>) -> tensor<*xf32> {
  return %arg1: tensor<*xf32>
}
File diff suppressed because it is too large
@@ -58,21 +58,10 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,

void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
                                mlir::OpPassManager* pass_manager) {
  pass_manager->addPass(mlir::tf_executor::CreateSwitchFoldPass());
  if (pass_config.skip_control_dialect) {
    // Merge islands.
    pass_manager->addPass(
        mlir::tf_executor::CreateTFExecutorIslandCoarseningPass());
    // Assuming island coarsening above results in a graph with a single
    // island, a canonicalization can be run to hoist the ops of the single
    // island out.
    pass_manager->addPass(mlir::createCanonicalizerPass());

    if (pass_config.form_clusters)
      pass_manager->addPass(mlir::TFDevice::CreateClusterFormationPass());
  } else {
    pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion());
    pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
  }
  mlir::TF::StandardPipelineOptions standard_pipeline_options;
  standard_pipeline_options.enable_inliner = false;
  standard_pipeline_options.form_clusters = pass_config.form_clusters;
  mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);

  if (pass_config.shape_inference) {
    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
@@ -213,13 +202,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
  OpPassManager& func_pm = pm.nest<FuncOp>();

  // tf_executor dialect passes - Cleaning up the IR.
  func_pm.addPass(tf_executor::CreateSwitchFoldPass());
  func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass());
  func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());

  // More cleanup of the executor dialect and raise to control flow.
  pm.addPass(mlir::CreateTFExecutorToControlDialectConversion());
  pm.addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
  mlir::TF::StandardPipelineOptions standard_pipeline_options;
  mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options);

  // This is needed for control flow support with TF TensorList.
  pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass());
@ -38,12 +38,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/tools/optimize/quantize_weights.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace mlir {
|
||||
/// Create a pass to convert from the TFExecutor to the TF control dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateTFExecutorToControlDialectConversion();
|
||||
} // namespace mlir
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using mlir::MLIRContext;
|
||||
|
@ -54,7 +54,6 @@ class WhileOutlinePass
|
||||
|
||||
tensorflow::OpOrArgLocNameMapper mapper_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
|
||||
return (mapper_.GetUniqueName(op) + suffix).str();
|
||||
@ -62,7 +61,7 @@ std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
|
||||
|
||||
// Returns whether the WhileOp is already outlined (e.g., only consists of calls
|
||||
// to functions).
|
||||
static bool IsAlreadyOutlinedd(WhileOp while_op) {
|
||||
bool IsAlreadyOutlined(WhileOp while_op) {
|
||||
auto just_call = [](Region& region) {
|
||||
auto it = region.front().begin();
|
||||
if (!isa<CallOp>(*it)) return false;
|
||||
@ -120,7 +119,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
}
|
||||
|
||||
// Skip if already just calls.
|
||||
if (extra_operands.empty() && IsAlreadyOutlinedd(while_op)) return;
|
||||
if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return;
|
||||
|
||||
// Collect new types.
|
||||
SmallVector<Type, 4> types;
|
||||
@ -238,6 +237,7 @@ void WhileOutlinePass::runOnOperation() {
|
||||
getOperation().walk(
|
||||
[&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {
|
||||
|
@ -67,27 +67,57 @@ inline RankedTensorType getResultType(mlir::FuncOp func, int idx) {
|
||||
}
|
||||
|
||||
LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
|
||||
if (func.getNumResults() != 2) {
|
||||
return failure();
|
||||
}
|
||||
if (func.getNumArguments() != 1) {
|
||||
return failure();
|
||||
}
|
||||
// In the case of input tensor with 0 rank.
|
||||
// Whitespace tokenizer generates 1 output:
|
||||
// * String tensor for tokens.
|
||||
//
|
||||
// In the case of 1-D input tensor,
|
||||
// Whitespace tokenizer generates 2 outputs to make up a ragged tensor:
|
||||
// * 1st output is the value of ragged tensor;
|
||||
// * 2nd output is the offset.
|
||||
//
|
||||
// In the case of batched input tensor,
|
||||
// Whitespace tokenizer has 3 outputs to make up a nested ragged tensor:
|
||||
// * 1st output is the value of ragged tensor;
|
||||
// * 2nd output is the inner offset;
|
||||
// * 3rd output is the outer offset.
|
||||
auto input_type = getInputType(func, 0);
|
||||
if (!input_type || input_type.getRank() != 1 ||
|
||||
!input_type.getElementType().isa<mlir::TF::StringType>()) {
|
||||
return failure();
|
||||
if (!input_type || !input_type.getElementType().isa<mlir::TF::StringType>() ||
|
||||
!input_type.hasRank()) {
|
||||
return func.emitError() << "Input should be a string tensor";
|
||||
}
|
||||
|
||||
const std::vector<int> kValidNumOfOutput = {1, 2, 3};
|
||||
if (input_type.getRank() >= kValidNumOfOutput.size()) {
|
||||
return func.emitError()
|
||||
<< "Unrecognized input rank: " << input_type.getRank();
|
||||
}
|
||||
if (func.getNumResults() != kValidNumOfOutput[input_type.getRank()]) {
|
||||
return func.emitError()
|
||||
<< "Expect " << kValidNumOfOutput[input_type.getRank()]
|
||||
<< "output(s) when input has rank " << input_type.getRank();
|
||||
}
|
||||
|
||||
auto value_type = getResultType(func, 0);
|
||||
if (!value_type || value_type.getRank() != 1 ||
|
||||
if (!value_type || !value_type.hasRank() || value_type.getRank() != 1 ||
|
||||
!value_type.getElementType().isa<mlir::TF::StringType>()) {
|
||||
return failure();
|
||||
return func.emitError() << "1st output should be string tensor";
|
||||
}
|
||||
auto offset_type = getResultType(func, 1);
|
||||
if (offset_type.getRank() != 1 ||
|
||||
!offset_type.getElementType().isInteger(64)) {
|
||||
return failure();
|
||||
if (func.getNumResults() > 1) {
|
||||
auto offset_type = getResultType(func, 1);
|
||||
if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 ||
|
||||
!offset_type.getElementType().isInteger(64)) {
|
||||
return func.emitError() << "2nd output should be int64 tensor";
|
||||
}
|
||||
}
|
||||
if (func.getNumResults() > 2) {
|
||||
auto offset_type = getResultType(func, 2);
|
||||
if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 ||
|
||||
!offset_type.getElementType().isInteger(64)) {
|
||||
return func.emitError() << "3rd output should be int64 tensor";
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
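To make the output-arity rules above concrete: the verifier accepts 1, 2, or 3 results because the tokenizer returns a flat string tensor plus zero, one, or two offset vectors depending on the input rank. A small standalone illustration of the rank-1 case with made-up data (not taken from the converter):

#include <cstdint>
#include <string>
#include <vector>

int main() {
  // Rank-1 input: two strings to split on whitespace.
  std::vector<std::string> input = {"hello world", "foo"};
  // 1st output: flat token values of the ragged result.
  std::vector<std::string> tokens = {"hello", "world", "foo"};
  // 2nd output: int64 offsets; row i spans tokens[offsets[i], offsets[i+1]).
  std::vector<int64_t> offsets = {0, 2, 3};
  // A rank-0 input yields only `tokens`; a rank-2 (batched) input adds an
  // outer offsets vector, giving the 3-output nested ragged case.
  return 0;
}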
|
||||
|
||||
@ -96,19 +126,12 @@ LogicalResult ConvertWhitespaceTokenizer(mlir::FuncOp func,
|
||||
func.eraseBody();
|
||||
func.addEntryBlock();
|
||||
func.setAttr(kTFAPIImplements, StringAttr::get(api, func.getContext()));
|
||||
|
||||
Value text = func.getArgument(0);
|
||||
auto output_type = func.getType().getResult(0);
|
||||
auto offset_type = func.getType().getResult(1);
|
||||
SmallVector<Type, 2> shape = {output_type, offset_type};
|
||||
ArrayRef<Type> output_types(shape);
|
||||
|
||||
OpBuilder builder(func.getBody());
|
||||
|
||||
auto op = builder.create<mlir::TFL::CustomOp>(func.getLoc(), output_types,
|
||||
ValueRange(text), api,
|
||||
emptyCustomOption(&builder));
|
||||
|
||||
auto op = builder.create<mlir::TFL::CustomOp>(
|
||||
func.getLoc(), func.getType().getResults(), ValueRange(text), api,
|
||||
emptyCustomOption(&builder));
|
||||
builder.create<mlir::ReturnOp>(func.getLoc(), op.getResults());
|
||||
return success();
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import lit.formats
|
||||
from lit.llvm import llvm_config
|
||||
from lit.llvm.subst import ToolSubst
|
||||
|
@ -431,7 +431,6 @@ cc_library(
|
||||
"transforms/optimize_global_tensors.cc",
|
||||
"transforms/parallel_execute_to_islands.cc",
|
||||
"transforms/promote_resources_to_args.cc",
|
||||
"transforms/raise_control_flow.cc",
|
||||
"transforms/readonly_references_to_resources.cc",
|
||||
"transforms/replicate_invariant_op_hoisting.cc",
|
||||
"transforms/replicate_to_island.cc",
|
||||
@ -460,7 +459,6 @@ cc_library(
|
||||
"transforms/tpu_variable_runtime_reformatting.cc",
|
||||
"translate/breakup-islands.cc",
|
||||
"translate/control_to_executor_dialect.cc",
|
||||
"translate/executor_to_control_dialect.cc",
|
||||
"translate/tf_functional_to_executor.cc",
|
||||
],
|
||||
hdrs = [
|
||||
|
@ -811,11 +811,13 @@ ParseResult ParseEnterOp(OpAsmParser &parser, OperationState &result) {
|
||||
// fully qualified) or a short form with a single type (in which case the data
|
||||
// input and the outputs are all using this type).
|
||||
if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
|
||||
if (type.getNumInputs() != 1)
|
||||
return parser.emitError(parser.getNameLoc())
|
||||
<< " expects a single data type";
|
||||
result.types.assign(type.getResults().begin(), type.getResults().end());
|
||||
types.assign(type.getInputs().begin(), type.getInputs().end());
|
||||
// One data input, and any number of control inputs.
|
||||
if (type.getNumInputs() >= 1) {
|
||||
result.types.assign(type.getResults().begin(), type.getResults().end());
|
||||
types.assign(type.getInputs().begin(), type.getInputs().end());
|
||||
} else {
|
||||
return parser.emitError(parser.getNameLoc()) << " expects a data input";
|
||||
}
|
||||
} else {
|
||||
Type control_type = ControlType::get(context);
|
||||
types.append(op_infos.size() - 1, control_type);
|
||||
|
@ -254,9 +254,12 @@ LogicalResult VerifyExportedFunc(FuncOp func) {
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (func.getArgAttr(i, "tf.resource_name")) {
|
||||
continue;
|
||||
}
|
||||
return func.emitError()
|
||||
<< "all arguments should have 'tf_saved_model.index_path' or "
|
||||
"'tf_saved_model.bound_input' attributes";
|
||||
<< "all arguments should have 'tf_saved_model.index_path', "
|
||||
"'tf_saved_model.bound_input' or 'tf.resource_name' attributes";
|
||||
}
|
||||
llvm::SmallDenseSet<StringRef, 8> unique_bound_inputs;
|
||||
for (int i = 0, e = func.getNumArguments(); i < e; i++) {
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt -tf-executor-to-control-conversion %s | FileCheck %s --check-prefix=CONTROL --dump-input=fail
|
||||
// RUN: tf-opt -tf-executor-graph-pruning %s | FileCheck %s --check-prefix=CONTROL --dump-input=fail
|
||||
// RUN: tf-opt -tf-control-to-executor-conversion %s | FileCheck %s --check-prefix=EXECUTOR --dump-input=fail
|
||||
|
||||
// CONTROL-LABEL: func @main
|
||||
|
@ -1,188 +0,0 @@
|
||||
// RUN: tf-opt -tf-executor-to-control-conversion %s | FileCheck %s --dump-input=fail
|
||||
// CHECK-LABEL: func @LoopTest() {
|
||||
func @LoopTest() {
|
||||
tf_executor.graph {
|
||||
%0:2 = tf_executor.island {
|
||||
%cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_executor.yield %cst : tensor<i32>
|
||||
}
|
||||
%1:2 = tf_executor.Enter %0#0 frame "while/while_context" : (tensor<i32>) -> (tensor<*xi32>, !tf_executor.control) {T = "tfdtype$DT_INT32", device = "", name = "while/Enter"}
|
||||
%2 = tf_executor.island {
|
||||
"tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> ()
|
||||
tf_executor.yield
|
||||
}
|
||||
%3:3 = tf_executor.NextIteration.Source : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
|
||||
%4:3 = tf_executor.Merge %3#0, %1#0 : tensor<*xi32> {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"}
|
||||
%5:2 = tf_executor.island(%4#2) {
|
||||
%cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_executor.yield %cst : tensor<i32>
|
||||
}
|
||||
%6:2 = tf_executor.island {
|
||||
%14 = "tf.Less"(%4#0, %5#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
|
||||
tf_executor.yield %14 : tensor<*xi1>
|
||||
}
|
||||
%7:2 = tf_executor.LoopCond %6#0 : (tensor<*xi1>) -> (tensor<i1>, !tf_executor.control) {device = "", name = "while/LoopCond"}
|
||||
%8:3 = tf_executor.Switch %4#0, %7#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"}
|
||||
%9:2 = tf_executor.Exit %8#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"}
|
||||
%10:2 = tf_executor.island {
|
||||
%14 = "tf.Identity"(%8#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> tensor<*xi32>
|
||||
tf_executor.yield %14 : tensor<*xi32>
|
||||
}
|
||||
%11:2 = tf_executor.island(%10#1) {
|
||||
%cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_executor.yield %cst : tensor<i32>
|
||||
}
|
||||
%12:2 = tf_executor.island {
|
||||
%14 = "tf.Add"(%10#0, %11#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
tf_executor.yield %14 : tensor<*xi32>
|
||||
}
|
||||
%13 = tf_executor.ControlTrigger %2, %12#1, %9#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"}
|
||||
tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-NEXT: %[[CONST:[0-9]*]]:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[ENTER:[0-9]*]]:2 = "_tf.Enter"(%[[CONST]]#0) {T = "tfdtype$DT_INT32", device = "", frame_name = "while/while_context", is_constant = false, name = "while/Enter", parallel_iterations = 10 : i64} : (tensor<i32>) -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[NOOP:[0-9]*]] = "_tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> !_tf.control
|
||||
// CHECK-NEXT: %[[SOURCE:[0-9]*]]:2 = "_tf.NextIteration.source"() {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : () -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = "_tf.Merge"(%[[SOURCE]]#0, %[[ENTER]]#0) {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<i32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[CONST_LESS:[0-9]*]]:2 = "_tf.Const"(%[[MERGE]]#2) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[LESS:[0-9]*]]:2 = "_tf.Less"(%[[MERGE]]#0, %[[CONST_LESS]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi1>, !_tf.control)
|
||||
// CHECK-NEXT: %[[COND:[0-9]*]]:2 = "_tf.LoopCond"(%[[LESS]]#0) {device = "", name = "while/LoopCond"} : (tensor<*xi1>) -> (tensor<i1>, !_tf.control)
|
||||
// CHECK-NEXT: %[[SWITCH:[0-9]*]]:3 = "_tf.Switch"(%[[MERGE]]#0, %[[COND]]#0) {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"} : (tensor<*xi32>, tensor<i1>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[EXIT:[0-9]*]]:2 = "_tf.Exit"(%[[SWITCH]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = "_tf.Identity"(%[[SWITCH]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[CONST_ADD:[0-9]*]]:2 = "_tf.Const"(%[[IDENTITY]]#1) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[CT:[0-9]*]] = "_tf.ControlTrigger"(%[[NOOP]], %[[ADD]]#1, %[[EXIT]]#1) {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"} : (!_tf.control, !_tf.control, !_tf.control) -> !_tf.control
|
||||
// CHECK-NEXT: %[[SINK:[0-9]*]] = "_tf.NextIteration.sink"(%[[ADD]]#0, %[[CT]]) {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : (tensor<*xi32>, !_tf.control) -> !_tf.control
|
||||
// CHECK-NEXT: return
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @multiple_ops_region
|
||||
func @multiple_ops_region(%arg0 : tensor<*xi32>, %arg1 : tensor<i32>) {
|
||||
tf_executor.graph {
|
||||
%0:2 = tf_executor.island {
|
||||
// The 4 operations are independent, but the current conversion will add
|
||||
// control dependencies conservatively.
|
||||
%1 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add1"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%2 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add2"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%3 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add3"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%4 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add4"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
tf_executor.yield %4 : tensor<*xi32>
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-NEXT: %[[ADD1:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add1"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[ADD2:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD1]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add2"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[ADD3:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD2]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add3"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[ADD4:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD3]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add4"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @switchN(
|
||||
func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%fetches = tf_executor.graph {
|
||||
// CHECK: [[S1:%.*]]:6 = "_tf._SwitchN"(%arg1, %arg0) {num_outs = 5 : i64}
|
||||
%1:6 = tf_executor.SwitchN %arg1, %arg0 of 5 : tensor<*xf32>
|
||||
// CHECK: "_tf._SwitchN"(%arg1, %arg0, [[S1]]#5) {num_outs = 12 : i64}
|
||||
%2:13 = tf_executor.SwitchN %arg1, %arg0 of 12 (%1#5) : tensor<*xf32>
|
||||
tf_executor.fetch %2#0 : tensor<*xf32>
|
||||
}
|
||||
return %fetches : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test if tf_executor dialect ops with Ref types are mapped correctly to the ops in control dialect.
|
||||
// CHECK-LABEL: func @ref_tf_executor_ops
|
||||
func @ref_tf_executor_ops(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4x!tf.f32ref>, %arg3: tensor<i32>, %arg4: tensor<i1> ) -> tensor<4x!tf.f32ref> {
|
||||
%result = tf_executor.graph {
|
||||
// CHECK: _tf.Enter
|
||||
%0:2 = tf_executor.Enter %arg0 frame "while/while_context" : (tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, !tf_executor.control)
|
||||
// CHECK: _tf.Exit
|
||||
%1:2 = tf_executor.Exit %arg0 : tensor<4x!tf.f32ref>
|
||||
// CHECK: _tf.Switch
|
||||
%2:3 = tf_executor.Switch %arg0, %arg4 : (tensor<4x!tf.f32ref>, tensor<i1>) -> (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>, !tf_executor.control)
|
||||
// CHECK: _tf.Merge
|
||||
%3:3 = tf_executor.Merge %arg0, %arg1 : (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control)
|
||||
// CHECK: _tf.NextIteration.source
|
||||
%4:3 = tf_executor.NextIteration.Source : tensor<4x!tf.f32ref>
|
||||
// CHECK: _tf.NextIteration.sink
|
||||
tf_executor.NextIteration.Sink [%4#1] %4#0 : tensor<4x!tf.f32ref>
|
||||
tf_executor.fetch %0#0 : tensor<4x!tf.f32ref>
|
||||
}
|
||||
return %result : tensor<4x!tf.f32ref>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests if empty island with just one control dependency input and output is
|
||||
// handled correctly.
|
||||
// CHECK-LABEL: func @empty_island_control_dep_only
|
||||
func @empty_island_control_dep_only() -> tensor<i32> {
|
||||
%fetch = tf_executor.graph {
|
||||
%0:2 = tf_executor.island {
|
||||
%4 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_executor.yield %4 : tensor<i32>
|
||||
}
|
||||
// CHECK-NEXT: %[[CONST1:[0-9]*]]:2 = "_tf.Const"()
|
||||
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
|
||||
%1:2 = tf_executor.island {
|
||||
%5 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_executor.yield %5 : tensor<i32>
|
||||
}
|
||||
// CHECK-NEXT: %[[CONST2:[0-9]*]]:2 = "_tf.Const"()
|
||||
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
|
||||
%2 = tf_executor.island(%0#1) {
|
||||
tf_executor.yield
|
||||
}
|
||||
%3:2 = tf_executor.island(%2, %1#1) {
|
||||
%6 = "tf.Add"(%0#0, %1#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
tf_executor.yield %6 : tensor<i32>
|
||||
}
|
||||
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[CONST1]]#0, %[[CONST2]]#0, %[[CONST1]]#1, %[[CONST2]]#1)
|
||||
// CHECK-SAME: (tensor<i32>, tensor<i32>, !_tf.control, !_tf.control) -> (tensor<i32>, !_tf.control)
|
||||
tf_executor.fetch %3#0 : tensor<i32>
|
||||
}
|
||||
return %fetch : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests if empty island with multiple control inputs will be replaced with a
|
||||
// no-op.
|
||||
// CHECK-LABEL: func @empty_island_multi_control_inputs
|
||||
func @empty_island_multi_control_inputs() -> tensor<i32> {
|
||||
%fetch = tf_executor.graph {
|
||||
%0:2 = tf_executor.island {
|
||||
%4 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_executor.yield %4 : tensor<i32>
|
||||
}
|
||||
// CHECK-NEXT: %[[CONST1:[0-9]*]]:2 = "_tf.Const"()
|
||||
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
|
||||
%1:2 = tf_executor.island {
|
||||
%5 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_executor.yield %5 : tensor<i32>
|
||||
}
|
||||
// CHECK-NEXT: %[[CONST2:[0-9]*]]:2 = "_tf.Const"()
|
||||
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
|
||||
%2 = tf_executor.island(%0#1, %1#1) {
|
||||
tf_executor.yield
|
||||
}
|
||||
// CHECK-NEXT: %[[NOOP:[0-9]*]] = "_tf.NoOp"(%[[CONST1]]#1, %[[CONST2]]#1)
|
||||
// CHECK-SAME: (!_tf.control, !_tf.control) -> !_tf.control
|
||||
%3:2 = tf_executor.island(%2) {
|
||||
%6 = "tf.Add"(%0#0, %1#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
tf_executor.yield %6 : tensor<i32>
|
||||
}
|
||||
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[CONST1]]#0, %[[CONST2]]#0, %[[NOOP]])
|
||||
// CHECK-SAME: (tensor<i32>, tensor<i32>, !_tf.control) -> (tensor<i32>, !_tf.control)
|
||||
tf_executor.fetch %3#0 : tensor<i32>
|
||||
}
|
||||
return %fetch : tensor<i32>
|
||||
}
|
@ -723,6 +723,11 @@ func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x1
|
||||
return %0 : tensor<3x8x8x16xf32>
|
||||
}
|
||||
|
||||
func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
|
||||
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32>
|
||||
return %0 : tensor<3x5x1x4xf32>
|
||||
}
|
||||
|
||||
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||
|
||||
// CHECK-LABEL: func @biasAdd_NHWC(
|
||||
@ -1596,3 +1601,14 @@ func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x1
|
||||
// CHECK: [[VAL_402:%.*]] = "tf.BroadcastTo"([[VAL_400]], [[VAL_401]]) : (tensor<3x1x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32>
|
||||
// CHECK: return [[VAL_402]] : tensor<3x8x8x16xf32>
|
||||
// CHECK: }
|
||||
|
||||
// CHECK-LABEL: func @convert_dot_general(
|
||||
// CHECK-SAME: [[VAL_396:%.*]]: tensor<3x2x6x5x1xf32>, [[VAL_397:%.*]]: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
|
||||
// CHECK: [[VAL_398:%.*]] = "tf.Transpose"([[VAL_396]], {{.*}}) : (tensor<3x2x6x5x1xf32>, tensor<5xi64>) -> tensor<3x5x1x2x6xf32>
|
||||
// CHECK: [[VAL_399:%.*]] = "tf.Transpose"([[VAL_397]], {{.*}}) : (tensor<3x2x4x6xf32>, tensor<4xi64>) -> tensor<3x2x6x4xf32>
|
||||
// CHECK: [[VAL_400:%.*]] = "tf.Reshape"([[VAL_398]], {{.*}}) : (tensor<3x5x1x2x6xf32>, tensor<3xi64>) -> tensor<3x5x12xf32>
|
||||
// CHECK: [[VAL_401:%.*]] = "tf.Reshape"([[VAL_399]], {{.*}}) : (tensor<3x2x6x4xf32>, tensor<3xi64>) -> tensor<3x12x4xf32>
|
||||
// CHECK: [[VAL_402:%.*]] = "tf.BatchMatMulV2"([[VAL_400]], [[VAL_401]]) {adj_x = false, adj_y = false} : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32>
|
||||
// CHECK: [[VAL_403:%.*]] = "tf.Reshape"([[VAL_402]], {{.*}}) : (tensor<3x5x4xf32>, tensor<4xi64>) -> tensor<3x5x1x4xf32>
|
||||
// CHECK: return [[VAL_403]] : tensor<3x5x1x4xf32>
|
||||
// CHECK: }
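The shapes in the CHECK lines above follow from flattening the dot_general dimensions before the batched matmul. A standalone sketch of the size arithmetic for this test case (shapes are copied from the test; the `Product` helper is illustrative, not part of the pass):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int64_t Product(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  // lhs 3x2x6x5x1: batch {0}, contracting {1, 2}, remaining out dims {3, 4}.
  int64_t lhs_out = Product({5, 1});       // 5  -> lhs reshaped to 3x5x12
  int64_t lhs_contract = Product({2, 6});  // 12
  // rhs 3x2x4x6: batch {0}, contracting {1, 3}, remaining out dim {2}.
  int64_t rhs_out = Product({4});          // 4  -> rhs reshaped to 3x12x4
  // BatchMatMulV2: (3x5x12) x (3x12x4) -> 3x5x4, reshaped back to 3x5x1x4.
  (void)lhs_out; (void)lhs_contract; (void)rhs_out;
  return 0;
}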
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt %s -op-fusion | FileCheck %s --dump-input-on-failure
|
||||
// RUN: tf-opt %s -tf-op-fusion | FileCheck %s --dump-input-on-failure
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conv2D + BiasAdd + <Activation> fusions.
|
||||
|
@ -1,5 +1,4 @@
|
||||
// Run a pass for promoting tf.VarHandleOps to function arguments in a format of TensorFlowSavedModelDialect.
|
||||
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-saved-model-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure
|
||||
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure
|
||||
|
||||
// Tests main function with multiple blocks.
|
||||
|
||||
@ -12,27 +11,24 @@ func @main() {
|
||||
|
||||
// -----
|
||||
|
||||
"tf_saved_model.global_tensor"() {sym_name = "x", type = tensor<f32>, value = dense<1.67482901> : tensor<f32>} : () -> ()
|
||||
"tf_saved_model.global_tensor"() {sym_name = "y", type = tensor<i32>, value = dense<0> : tensor<i32>} : () -> ()
|
||||
|
||||
// CHECK-LABEL: func @no_args
|
||||
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
|
||||
// CHECK-SAME: (%arg0: tensor<!tf.resource> {tf.resource_name = "x"})
|
||||
// CHECK-NOT: "tf.VarHandleOp"
|
||||
func @no_args() {
|
||||
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
|
||||
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @some_args
|
||||
// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
|
||||
// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<!tf.resource> {tf.resource_name = "x"})
|
||||
// CHECK-NOT: "tf.VarHandleOp"
|
||||
func @some_args(%arg0: tensor<i1>) {
|
||||
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
|
||||
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @unique_vars
|
||||
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x}, %arg1: tensor<!tf.resource<tensor<i32>>> {tf_saved_model.bound_input = @y})
|
||||
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"}, %arg1: tensor<!tf.resource<tensor<i32>>> {tf.resource_name = "y"})
|
||||
// CHECK-NOT: "tf.VarHandleOp"
|
||||
func @unique_vars() {
|
||||
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
|
||||
@ -41,7 +37,7 @@ func @unique_vars() {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @duplicate_vars
|
||||
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
|
||||
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"})
|
||||
// CHECK-NOT: "tf.VarHandleOp"
|
||||
func @duplicate_vars() {
|
||||
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
|
||||
@ -50,7 +46,7 @@ func @duplicate_vars() {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @duplicate_vars_with_users
|
||||
// CHECK-SAME: (%arg0: tensor<f32>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x})
|
||||
// CHECK-SAME: (%arg0: tensor<f32>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"})
|
||||
// CHECK: "tf.ReadVariableOp"(%arg1)
|
||||
// CHECK: "tf.AssignAddVariableOp"(%arg1, %arg0)
|
||||
// CHECK-NOT: "tf.VarHandleOp"
|
||||
|
@ -1,57 +0,0 @@
|
||||
// RUN: tf-opt %s -tf-raise-control-flow -split-input-file | FileCheck %s
|
||||
|
||||
// Test that we remove underscores.
|
||||
|
||||
// CHECK-LABEL: func @testSimpleAddsAndIdentity(%arg0: tensor<*xf32>)
|
||||
func @testSimpleAddsAndIdentity(tensor<*xf32>) -> tensor<*xf32> {
|
||||
^bb0(%0: tensor<*xf32>):
|
||||
|
||||
// CHECK: %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%1 = "_tf.Identity"(%0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
// CHECK: %1 = "tf.Add"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
%2 = "_tf.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
// CHECK: %2 = "tf.Add"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
%3 = "_tf.Add"(%1, %2) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
// CHECK: return %2 : tensor<*xf32>
|
||||
return %3 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testAddWithControlDependency(%arg0: tensor<*xf32>)
|
||||
func @testAddWithControlDependency(tensor<*xf32>) -> tensor<*xf32> {
|
||||
^bb0(%0: tensor<*xf32>):
|
||||
|
||||
// CHECK: %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%1:2 = "_tf.Identity"(%0) : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control)
|
||||
|
||||
// CHECK: %1 = "tf.Add"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
%2:2 = "_tf.Add"(%0, %0, %1#1) : (tensor<*xf32>, tensor<*xf32>, !_tf.control) -> (tensor<*xf32>, !_tf.control)
|
||||
|
||||
// CHECK: %2 = "tf.Add"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
%3:2 = "_tf.Add"(%1#0, %2, %1#1, %2#1) : (tensor<*xf32>, tensor<*xf32>, !_tf.control, !_tf.control) -> (tensor<*xf32>, !_tf.control)
|
||||
|
||||
// CHECK: return %2 : tensor<*xf32>
|
||||
return %3 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// TODO(clattner): simplify and expand these tests. This is mostly a placeholder.
|
||||
func @LoopTest() {
|
||||
%0:2 = "_tf.Const"() {device = "", name = "Const", dtype = "tfdtype$DT_INT32", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
|
||||
%1:2 = "_tf.Enter"(%0#0) {device = "", name = "while/Enter", T = "tfdtype$DT_INT32", frame_name = "while/while_context", is_constant = false, parallel_iterations = 10} : (tensor<i32>) -> (tensor<*xi32>, !_tf.control)
|
||||
|
||||
%11:2 = "_tf.NextIteration.source"() {device = "", name = "while/NextIteration", T = "tfdtype$DT_INT32", id = 0} : () -> (tensor<*xi32>, !_tf.control)
|
||||
|
||||
%2:3 = "_tf.Merge"(%11#0, %1#0) {device = "", name = "while/Merge", N = 2, T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<i32>, !_tf.control)
|
||||
%3:2 = "_tf.Const"(%2#2) {device = "", name = "while/Less/y", dtype = "tfdtype$DT_INT32", value = dense<2> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
|
||||
%4:2 = "_tf.Less"(%2#0, %3#0) {device = "", name = "while/Less", T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi1>, !_tf.control)
|
||||
%5:2 = "_tf.LoopCond"(%4#0) {device = "", name = "while/LoopCond"} : (tensor<*xi1>) -> (tensor<i1>, !_tf.control)
|
||||
%6:3 = "_tf.Switch"(%2#0, %5#0) {device = "", name = "while/Switch", T = "tfdtype$DT_INT32", _class = ["loc:@while/Merge"]} : (tensor<*xi32>, tensor<i1>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
|
||||
%7:2 = "_tf.Exit"(%6#0) {device = "", name = "while/Exit", T = "tfdtype$DT_INT32"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||
%8:2 = "_tf.Identity"(%6#1) {device = "", name = "while/Identity", T = "tfdtype$DT_INT32"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||
%9:2 = "_tf.Const"(%8#1) {device = "", name = "while/Add/y", dtype = "tfdtype$DT_INT32", value = dense<3> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
|
||||
%10:2 = "_tf.Add"(%8#0, %9#0) {device = "", name = "while/Add", T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
|
||||
%ctl = "_tf.NextIteration.sink"(%10#0) {device = "", name = "while/NextIteration", T = "tfdtype$DT_INT32", id = 0} : (tensor<*xi32>) -> (!_tf.control)
|
||||
return
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt -verify-diagnostics -readonly-references-to-resources -split-input-file %s | FileCheck %s --dump-input=fail
|
||||
// RUN: tf-opt -verify-diagnostics -tf-readonly-references-to-resources -split-input-file %s | FileCheck %s --dump-input=fail
|
||||
|
||||
// Test case: Basic converting.
|
||||
|
||||
|
@ -416,6 +416,17 @@ func @enter_control(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @enter_control_longform(%{{.*}}: tensor<*xf32>, %{{.*}}: tensor<i1>) -> tensor<*xf32> {
|
||||
func @enter_control_longform(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
|
||||
%0 = tf_executor.graph {
|
||||
%1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32>
|
||||
// CHECK: tf_executor.Enter %{{.*}}, %{{.*}}, %{{.*}} frame "some/frame" : tensor<*xf32>
|
||||
%res:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : (tensor<*xf32>, !tf_executor.control, !tf_executor.control) -> (tensor<*xf32>, !tf_executor.control)
|
||||
tf_executor.fetch %res#0 : tensor<*xf32>
|
||||
}
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @nextiteration(%{{.*}}: tensor<*xf32>, %{{.*}}: i1) -> tensor<*xf32> {
|
||||
func @nextiteration(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
|
||||
%0 = tf_executor.graph {
|
||||
|
@ -40,3 +40,16 @@ module attributes {tf_saved_model.semantics} {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// CHECK: func @f
|
||||
func @f(
|
||||
%arg0: tensor<f32> {tf.resource_name = "resource"}
|
||||
) attributes { tf_saved_model.exported_names = ["foo.some_func"] } {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -120,7 +120,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{all arguments should have 'tf_saved_model.index_path' or 'tf_saved_model.bound_input' attributes}}
|
||||
// expected-error@+1 {{all arguments should have 'tf_saved_model.index_path', 'tf_saved_model.bound_input' or 'tf.resource_name' attributes}}
|
||||
func @f(
|
||||
%arg0: tensor<f32>
|
||||
) attributes { tf_saved_model.exported_names = ["f"] } {
|
||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass transforms functional control flow operations in the
|
||||
// standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
|
||||
// TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
@ -52,7 +52,6 @@ static Value LowerCondition(Location loc, Value value, OpBuilder* builder) {
|
||||
//
|
||||
// Requires the function to provide arguments for each of the `fn` operands
|
||||
// that is compatible for tensor cast.
|
||||
//
|
||||
static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
|
||||
FuncOp fn, OpBuilder* builder) {
|
||||
FunctionType fn_type = fn.getType();
|
||||
@ -113,7 +112,6 @@ static void JumpToBlock(Location loc, const std::function<Value(int)>& get_arg,
|
||||
// Requires that the block has same number of arguments as number of results of
|
||||
// the operation and either they have same types or are more generic types and
|
||||
// it is possible to cast them to results' types.
|
||||
//
|
||||
static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
|
||||
Block* block, OpBuilder* builder) {
|
||||
assert(op->getNumResults() == block->getNumArguments());
|
||||
@ -132,9 +130,6 @@ static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
|
||||
// Given a functional IfOp, transforms the enclosing code to eliminate it
|
||||
// completely from the IR, breaking it into operations to evaluate the condition
|
||||
// as a bool, plus some branches.
|
||||
//
|
||||
// This returns true on failure.
|
||||
//
|
||||
static LogicalResult LowerIfOp(IfOp op) {
|
||||
Operation* op_inst = op.getOperation();
|
||||
Location loc = op_inst->getLoc();
|
||||
@ -193,9 +188,6 @@ static LogicalResult LowerIfOp(IfOp op) {
|
||||
// Given a functional WhileOp, transforms the enclosing code to eliminate it
|
||||
// completely from the IR, breaking it into operations to execute the loop body
|
||||
// repeatedly while the loop condition is true.
|
||||
//
|
||||
// This returns true on failure.
|
||||
//
|
||||
static LogicalResult LowerWhileOp(WhileOp op) {
|
||||
Operation* op_inst = op.getOperation();
|
||||
Location loc = op_inst->getLoc();
|
||||
|
@ -15,10 +15,15 @@ limitations under the License.
|
||||
|
||||
// This file implements logic for legalizing HLO to TensorFlow.
|
||||
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
@ -40,6 +45,8 @@ namespace mlir {
|
||||
namespace TF {
|
||||
namespace {
|
||||
|
||||
using xla_hlo::DotDimensionNumbers;
|
||||
|
||||
class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
@ -75,6 +82,205 @@ class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
|
||||
};
|
||||
};
|
||||
|
||||
// Appends all elements in `range` to `values`.
|
||||
template <typename ValueT, typename Range>
|
||||
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range) {
|
||||
values.insert(values.end(), range.begin(), range.end());
|
||||
}
|
||||
|
||||
// Appends all elements in `range` to `values`.
|
||||
template <typename ValueT, typename Range, typename... RangeTs>
|
||||
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range,
|
||||
RangeTs &&... ranges) {
|
||||
values.insert(values.end(), range.begin(), range.end());
|
||||
Append(values, ranges...);
|
||||
}
|
||||
|
||||
// Returns the number of elements in `range`.
|
||||
template <typename Range>
|
||||
size_t Size(Range &&range) {
|
||||
return range.size();
|
||||
}
|
||||
|
||||
// Returns the total number of elements in a variadic number of `ranges`.
|
||||
template <typename Range, typename... RangeTs>
|
||||
size_t Size(Range &&range, RangeTs &&... ranges) {
|
||||
return range.size() + Size(std::forward<RangeTs>(ranges)...);
|
||||
}
|
||||
|
||||
// Concats all elements in `ranges` and returns a small vector as a result.
|
||||
template <typename ValueT, typename... RangeTs>
|
||||
llvm::SmallVector<ValueT, 4> Concat(RangeTs &&... ranges) {
|
||||
llvm::SmallVector<int64_t, 4> results;
|
||||
results.reserve(Size(std::forward<RangeTs>(ranges)...));
|
||||
Append(results, std::forward<RangeTs>(ranges)...);
|
||||
return results;
|
||||
}
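A short usage note for the variadic helpers above, written as comments so it reads alongside the surrounding code (the values are arbitrary and only for illustration):

// Example:
//   llvm::SmallVector<int64_t, 4> dims = Concat<int64_t>(
//       llvm::ArrayRef<int64_t>{3},      // batch sizes
//       llvm::ArrayRef<int64_t>{5, 1},   // out sizes
//       llvm::ArrayRef<int64_t>{2, 6});  // contracting sizes
//   // dims == {3, 5, 1, 2, 6}; Size(...) lets Concat reserve once up front.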
|
||||
|
||||
// A struct to hold axes and sizes for a set of dimensions.
|
||||
struct DimensionSetVector {
|
||||
llvm::ArrayRef<int64_t> AxesArray() const { return axes.getArrayRef(); }
|
||||
llvm::ArrayRef<int64_t> SizesArray() const { return sizes.getArrayRef(); }
|
||||
|
||||
llvm::SmallSetVector<int64_t, 4> axes;
|
||||
llvm::SmallSetVector<int64_t, 4> sizes;
|
||||
};
|
||||
|
||||
// A struct to hold information about dimensions of dot_general operands.
|
||||
class DotDimensionsInfo {
|
||||
public:
|
||||
DotDimensionsInfo(ShapedType type, DenseIntElementsAttr batch_dimensions,
|
||||
DenseIntElementsAttr contracting_dimensions) {
|
||||
const int rank = type.getRank();
|
||||
for (const int dim : batch_dimensions.getValues<int64_t>()) {
|
||||
batch_dimensions_.axes.insert(dim);
|
||||
batch_dimensions_.sizes.insert(type.getDimSize(dim));
|
||||
}
|
||||
|
||||
for (const int dim : contracting_dimensions.getValues<int64_t>()) {
|
||||
contracting_dimensions_.axes.insert(dim);
|
||||
contracting_dimensions_.sizes.insert(type.getDimSize(dim));
|
||||
}
|
||||
|
||||
for (int dim = 0; dim < rank; ++dim) {
|
||||
if (contracting_dimensions_.axes.count(dim) > 0 ||
|
||||
batch_dimensions_.axes.count(dim) > 0) {
|
||||
continue;
|
||||
}
|
||||
out_dimensions_.axes.insert(dim);
|
||||
out_dimensions_.sizes.insert(type.getDimSize(dim));
|
||||
}
|
||||
}
|
||||
|
||||
const DimensionSetVector &batch_dimensions() const {
|
||||
return batch_dimensions_;
|
||||
}
|
||||
const DimensionSetVector &contracting_dimensions() const {
|
||||
return contracting_dimensions_;
|
||||
}
|
||||
// Out dimensions are any dimensions that are neither batch nor contracting
|
||||
// dimensions, hence will be propagated to output shape.
|
||||
const DimensionSetVector &out_dimensions() const { return out_dimensions_; }
|
||||
|
||||
// Returns the total dimension size after flattening all contracting
|
||||
// dimensions.
|
||||
int FlattenedContractingDimensionSize() const {
|
||||
return std::accumulate(contracting_dimensions_.sizes.begin(),
|
||||
contracting_dimensions_.sizes.end(), 1,
|
||||
std::multiplies<int64_t>());
|
||||
}
|
||||
|
||||
// Returns the total dimension size after flattening all out dimensions.
|
||||
int FlattenedOutDimensionSize() const {
|
||||
return std::accumulate(out_dimensions_.sizes.begin(),
|
||||
out_dimensions_.sizes.end(), 1,
|
||||
std::multiplies<int64_t>());
|
||||
}
|
||||
|
||||
private:
|
||||
DimensionSetVector batch_dimensions_;
|
||||
DimensionSetVector contracting_dimensions_;
|
||||
// Out dimensions are any dimensions that are neither batch nor contracting
|
||||
// dimensions, hence will be propagated to output shape.
|
||||
DimensionSetVector out_dimensions_;
|
||||
};
|
||||
|
||||
// Converts xla_hlo.dot_general to tf.BatchMatMul. Reshape or Transpose ops will also be
|
||||
// inserted to convert to well-formed matrix multiply.
|
||||
Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
|
||||
auto dot_general_op = cast<xla_hlo::DotGeneralOp>(old_op);
|
||||
auto lhs_type = dot_general_op.lhs().getType().cast<ShapedType>();
|
||||
auto rhs_type = dot_general_op.rhs().getType().cast<ShapedType>();
|
||||
auto result_type = dot_general_op.getResult().getType().cast<ShapedType>();
|
||||
DotDimensionNumbers dot_dimension_numbers =
|
||||
dot_general_op.dot_dimension_numbers();
|
||||
mlir::Location loc = dot_general_op.getLoc();
|
||||
const int lhs_rank = lhs_type.getRank();
|
||||
const int rhs_rank = rhs_type.getRank();
|
||||
|
||||
// Collects lhs and rhs dimensions information.
|
||||
DotDimensionsInfo lhs_dot_dimensions_info(
|
||||
lhs_type, dot_dimension_numbers.lhs_batching_dimensions(),
|
||||
dot_dimension_numbers.lhs_contracting_dimensions());
|
||||
DotDimensionsInfo rhs_dot_dimensions_info(
|
||||
rhs_type, dot_dimension_numbers.rhs_batching_dimensions(),
|
||||
dot_dimension_numbers.rhs_contracting_dimensions());
|
||||
|
||||
// Transposes lhs shape to be in the order of {batch_dimensions,
|
||||
// out_dimensions, contracting dimensions}.
|
||||
llvm::SmallVector<int64_t, 4> lhs_permutation = Concat<int64_t>(
|
||||
lhs_dot_dimensions_info.batch_dimensions().AxesArray(),
|
||||
lhs_dot_dimensions_info.out_dimensions().AxesArray(),
|
||||
lhs_dot_dimensions_info.contracting_dimensions().AxesArray());
|
||||
llvm::SmallVector<int64_t, 4> lhs_transposed_shape = Concat<int64_t>(
|
||||
lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
|
||||
lhs_dot_dimensions_info.out_dimensions().SizesArray(),
|
||||
lhs_dot_dimensions_info.contracting_dimensions().SizesArray());
|
||||
auto lhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
|
||||
loc,
|
||||
RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()),
|
||||
dot_general_op.lhs(),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({lhs_rank}, rewriter.getI64Type()),
|
||||
lhs_permutation));
|
||||
|
||||
// Transposes rhs shape to be in the order of {batch_dimensions, contracting
|
||||
// dimensions, out_dimensions}.
|
||||
llvm::SmallVector<int64_t, 4> rhs_permutation = Concat<int64_t>(
|
||||
rhs_dot_dimensions_info.batch_dimensions().AxesArray(),
|
||||
rhs_dot_dimensions_info.contracting_dimensions().AxesArray(),
|
||||
rhs_dot_dimensions_info.out_dimensions().AxesArray());
|
||||
llvm::SmallVector<int64_t, 4> rhs_transposed_shape = Concat<int64_t>(
|
||||
rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
|
||||
rhs_dot_dimensions_info.contracting_dimensions().SizesArray(),
|
||||
rhs_dot_dimensions_info.out_dimensions().SizesArray());
|
||||
auto rhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
|
||||
loc,
|
||||
RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()),
|
||||
dot_general_op.rhs(),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({rhs_rank}, rewriter.getI64Type()),
|
||||
rhs_permutation));
|
||||
|
||||
// Reshapes lhs to flatten out_dimensions and contracting_dimensions.
|
||||
llvm::SmallVector<int64_t, 4> lhs_flattened_shape = Concat<int64_t>(
|
||||
lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
|
||||
llvm::ArrayRef<int64_t>{
|
||||
lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
|
||||
llvm::ArrayRef<int64_t>{
|
||||
lhs_dot_dimensions_info.FlattenedContractingDimensionSize()});
|
||||
auto lhs_flattend = rewriter.create<xla_hlo::ReshapeOp>(
|
||||
loc,
|
||||
RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()),
|
||||
lhs_transposed.getResult());
|
||||
|
||||
// Reshapes rhs to flatten out_dimensions and contracting_dimensions.
|
||||
llvm::SmallVector<int64_t, 4> rhs_flattened_shape = Concat<int64_t>(
|
||||
rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
|
||||
llvm::ArrayRef<int64_t>{
|
||||
rhs_dot_dimensions_info.FlattenedContractingDimensionSize()},
|
||||
llvm::ArrayRef<int64_t>{
|
||||
rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
|
||||
auto rhs_flattend = rewriter.create<xla_hlo::ReshapeOp>(
|
||||
loc,
|
||||
RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()),
|
||||
rhs_transposed.getResult());
|
||||
|
||||
// Creates matmul op of `lhs_flattend` and `rhs_flattend`.
|
||||
llvm::SmallVector<int64_t, 4> matmul_shape =
|
||||
Concat<int64_t>(lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
|
||||
llvm::ArrayRef<int64_t>{
|
||||
lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
|
||||
llvm::ArrayRef<int64_t>{
|
||||
rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
|
||||
auto matmul = rewriter.create<TF::BatchMatMulV2Op>(
|
||||
loc, RankedTensorType::get(matmul_shape, result_type.getElementType()),
|
||||
lhs_flattend.getResult(), rhs_flattend.getResult());
|
||||
auto reshaped =
|
||||
rewriter.create<xla_hlo::ReshapeOp>(loc, result_type, matmul.getResult());
|
||||
return reshaped.getResult();
|
||||
}
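The shape bookkeeping performed by the conversion above is easiest to follow with a concrete example. The following standalone C++ sketch is illustrative only and is not part of the pass; the shapes and dimension numbers are invented for the example. It mirrors the steps above: split each operand's dimensions into batch/contracting/out sets, flatten, form the BatchMatMulV2 shape, and reshape back.

// Standalone sketch (assumption: all shapes are static). Compile with C++17.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <set>
#include <vector>

using Shape = std::vector<int64_t>;

int64_t Product(const Shape &dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Splits a shape into batch, contracting and "out" dimensions, mirroring
// DotDimensionsInfo above.
void Split(const Shape &shape, const std::set<int> &batch,
           const std::set<int> &contracting, Shape *b, Shape *c, Shape *out) {
  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
    if (batch.count(i)) b->push_back(shape[i]);
    else if (contracting.count(i)) c->push_back(shape[i]);
    else out->push_back(shape[i]);
  }
}

void Print(const char *name, const Shape &s) {
  std::cout << name << ": [";
  for (size_t i = 0; i < s.size(); ++i) std::cout << (i ? ", " : "") << s[i];
  std::cout << "]\n";
}

int main() {
  // Invented example: lhs is 2x3x4x5 with batch dim {0} and contracting
  // dim {3}; rhs is 2x5x6 with batch dim {0} and contracting dim {1}.
  Shape lhs{2, 3, 4, 5}, rhs{2, 5, 6};
  Shape lb, lc, lo, rb, rc, ro;
  Split(lhs, {0}, {3}, &lb, &lc, &lo);
  Split(rhs, {0}, {1}, &rb, &rc, &ro);

  // After transpose + reshape, lhs is [batch..., flat_out, flat_k] and rhs
  // is [batch..., flat_k, flat_out].
  Shape lhs_flat = lb;
  lhs_flat.push_back(Product(lo));
  lhs_flat.push_back(Product(lc));
  Shape rhs_flat = rb;
  rhs_flat.push_back(Product(rc));
  rhs_flat.push_back(Product(ro));

  // BatchMatMulV2 result, then a final reshape restores the unflattened out
  // dimensions: [batch..., lhs_out..., rhs_out...].
  Shape matmul = lb;
  matmul.push_back(Product(lo));
  matmul.push_back(Product(ro));
  Shape result = lb;
  result.insert(result.end(), lo.begin(), lo.end());
  result.insert(result.end(), ro.begin(), ro.end());

  Print("lhs flattened", lhs_flat);  // [2, 12, 5]
  Print("rhs flattened", rhs_flat);  // [2, 5, 6]
  Print("batch matmul", matmul);     // [2, 12, 6]
  Print("final result", result);     // [2, 3, 4, 6]
}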
|
||||
|
||||
class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
|
||||
public:
|
||||
LegalizeHloToTf() = default;
|
||||
|
@ -184,3 +184,10 @@ def ConvertDotOp : NativeCodeCall<"ConvertDotOp($_builder, "
|
||||
def : Pat<(HLO_DotOp:$old_value AnyStaticShapeTensor:$lhs,
|
||||
AnyStaticShapeTensor:$rhs, $precision_config),
|
||||
(ConvertDotOp $old_value)>;
|
||||
|
||||
def ConvertDotGeneralOp : NativeCodeCall<"ConvertDotGeneralOp($_builder, "
|
||||
"$0.getDefiningOp())">;
|
||||
def : Pat<(HLO_DotGeneralOp:$old_value AnyStaticShapeTensor:$lhs,
|
||||
AnyStaticShapeTensor:$rhs, $dot_dimension_numbers,
|
||||
$precision_config),
|
||||
(ConvertDotGeneralOp $old_value)>;
|
||||
|
@ -166,7 +166,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass() {
|
||||
}
|
||||
|
||||
static PassRegistration<OpFusionPass> pass(
|
||||
"op-fusion",
|
||||
"tf-op-fusion",
|
||||
"Replaces commonly occurring subgraphs with optimized fused kernels");
|
||||
|
||||
} // namespace TF
|
||||
|
@ -58,6 +58,8 @@ void CreateTFStandardPipeline(OpPassManager &pm,
|
||||
func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass());
|
||||
func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
|
||||
func_pm.addPass(CreateMaterializePassthroughOpPass());
|
||||
if (options.form_clusters)
|
||||
func_pm.addPass(TFDevice::CreateClusterFormationPass());
|
||||
|
||||
// Hopefully there is a single island left, or there wasn't any to begin with.
|
||||
// We now run the optimizer which operates mostly inside islands.
|
||||
|
@ -77,6 +77,9 @@ struct StandardPipelineOptions
|
||||
Option<bool> enable_inliner{*this, "enable-inliner",
|
||||
llvm::cl::desc("Enable inliner."),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> form_clusters{*this, "form-clusters",
|
||||
llvm::cl::desc("Enable Cluster Formation pass."),
|
||||
llvm::cl::init(false)};
|
||||
};
|
||||
|
||||
// Propagates the pass manager with the passes involved in transforming or
|
||||
@ -95,11 +98,9 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass();
|
||||
// of their aliasing output arguments.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass();
|
||||
|
||||
// Creates a pass that promotes tf.VarHandleOp to resource arguments, where
|
||||
// resource names are `tf_saved_model.bound_input` symbol argument attributes
|
||||
// for all functions.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreatePromoteVarHandlesToSavedModelArgsPass();
|
||||
// Creates a pass that promotes tf.VarHandleOp to resource arguments for all
|
||||
// functions.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass();
|
||||
|
||||
// Creates a pass that converts readonly reference variables to the
|
||||
// corresponding resource variables.
|
||||
@ -151,13 +152,6 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass();
|
||||
} // namespace TF
|
||||
|
||||
namespace TFControlFlow {
|
||||
// Raises from the "TensorFlow Control Flow" dialect to the standard TensorFlow
|
||||
// dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseTFControlFlowPass();
|
||||
|
||||
} // namespace TFControlFlow
|
||||
|
||||
namespace tf_executor {
|
||||
class GraphOp;
|
||||
|
||||
|
@ -389,18 +389,15 @@ void PromoteResourcesToArgsPass::runOnOperation() {
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
// This pass is for promoting Varhandle ops to tf_saved_model.bound_input
|
||||
// attributes, which are required for TensorFlowSavedModelDialect.
|
||||
class PromoteVarHandlesToSavedModelArgsPass
|
||||
: public PassWrapper<PromoteVarHandlesToSavedModelArgsPass,
|
||||
OperationPass<ModuleOp>> {
|
||||
class PromoteVarHandlesToArgsPass
|
||||
: public PassWrapper<PromoteVarHandlesToArgsPass, OperationPass<ModuleOp>> {
|
||||
public:
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void PromoteVarHandlesToSavedModelArgsPass::runOnOperation() {
|
||||
void PromoteVarHandlesToArgsPass::runOnOperation() {
|
||||
ModuleOp module = getOperation();
|
||||
|
||||
MLIRContext* context = module.getContext();
|
||||
for (auto function : module.getOps<FuncOp>()) {
|
||||
if (failed(CheckSingleBlockFunction(function))) return signalPassFailure();
|
||||
|
||||
@ -409,15 +406,13 @@ void PromoteVarHandlesToSavedModelArgsPass::runOnOperation() {
|
||||
&var_handle_shared_names);
|
||||
|
||||
// Add resource names for each `tf.VarHandleOp` that were promoted to
|
||||
// saved model arguments.
|
||||
// resource arguments.
|
||||
const int var_handle_args_offset =
|
||||
function.getNumArguments() - var_handle_shared_names.size();
|
||||
for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names)) {
|
||||
auto symbol_ref =
|
||||
SymbolRefAttr::get(var_name_and_index.value(), &getContext());
|
||||
for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names))
|
||||
function.setArgAttr(var_name_and_index.index() + var_handle_args_offset,
|
||||
"tf_saved_model.bound_input", symbol_ref);
|
||||
}
|
||||
kResourceNameArgAttr,
|
||||
StringAttr::get(var_name_and_index.value(), context));
|
||||
}
|
||||
}
|
||||
|
||||
@ -427,19 +422,17 @@ std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass() {
|
||||
return std::make_unique<PromoteResourcesToArgsPass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreatePromoteVarHandlesToSavedModelArgsPass() {
|
||||
return std::make_unique<PromoteVarHandlesToSavedModelArgsPass>();
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass() {
|
||||
return std::make_unique<PromoteVarHandlesToArgsPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<PromoteResourcesToArgsPass> pass(
|
||||
"tf-promote-resources-to-args",
|
||||
"Promote resources reads/writes to function inputs/outputs.");
|
||||
|
||||
static PassRegistration<PromoteVarHandlesToSavedModelArgsPass> saved_model_pass(
|
||||
"tf-saved-model-promote-var-handles-to-args",
|
||||
"Promote tf.VarHandleOps to function arguments in a format of "
|
||||
"TensorFlowSavedModelDialect.");
|
||||
static PassRegistration<PromoteVarHandlesToArgsPass> var_handle_pass(
|
||||
"tf-promote-var-handles-to-args",
|
||||
"Promote tf.VarHandleOps to function arguments.");
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
@ -1,159 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file implements logic for raising from the "TensorFlow control flow"
|
||||
// dialect of MLIR to the standard TensorFlow dialect. The TensorFlow control
|
||||
// flow dialect represents control flow with Switch/Merge and a few related
|
||||
// control flow nodes, along with control dependencies.
|
||||
//
|
||||
// This pass rebuilds the code in terms of MLIR branches and blocks,
|
||||
// eliminating control dependencies, and results in the code being in the
|
||||
// canonical TensorFlow dialect.
|
||||
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFControlFlow {
|
||||
|
||||
namespace {
|
||||
struct RaiseTFControlFlow
|
||||
: public PassWrapper<RaiseTFControlFlow, FunctionPass> {
|
||||
void runOnFunction() {
|
||||
// First start by recognizing loops and reconstructing a loop tree.
|
||||
buildLoopNests();
|
||||
|
||||
// Next, transform Switch/Merge and other control flow ops into proper
|
||||
// conditional control flow.
|
||||
buildConditionals();
|
||||
|
||||
// Now that we have proper conditional control flow ops, the control edges
|
||||
// can be dropped, and the underscores removed from operation names.
|
||||
rewriteOps();
|
||||
}
|
||||
|
||||
void buildLoopNests();
|
||||
void buildConditionals();
|
||||
void rewriteOps();
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Loop nest reconstruction
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void RaiseTFControlFlow::buildLoopNests() {
|
||||
// TODO(clattner)
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conditional Reconstruction
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void RaiseTFControlFlow::buildConditionals() {
|
||||
// TODO.
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Final rewrite from TF Control Flow form to canonical TensorFlow form
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static bool isUnderscoredTFOp(Operation &op) {
|
||||
return op.getName().getStringRef().startswith("_tf.");
|
||||
}
|
||||
|
||||
// Drop control edges, and remove underscores from operation names.
|
||||
void RaiseTFControlFlow::rewriteOps() {
|
||||
auto function = getFunction();
|
||||
OpBuilder builder(function.getBody());
|
||||
|
||||
// On the first pass, create replacement operations for every one we are going
|
||||
// to replace, updating anything that uses the normal results with the newly
|
||||
// created operation.
|
||||
for (auto &bb : function) {
|
||||
for (auto &op : bb) {
|
||||
// Ignore any operations that we aren't looking for.
|
||||
if (!isUnderscoredTFOp(op)) continue;
|
||||
|
||||
// We always insert the replacement operation next to the operation it
|
||||
// is replacing.
|
||||
builder.setInsertionPoint(&op);
|
||||
|
||||
// Drop the leading _ off the name.
|
||||
OperationState result(op.getLoc(),
|
||||
op.getName().getStringRef().drop_front());
|
||||
|
||||
// Add an operand for each non-control input we find. Control values
|
||||
// aren't necessary any more since the order within a block encodes the
|
||||
// same information.
|
||||
for (auto &operand : op.getOpOperands()) {
|
||||
if (!operand.get().getType().isa<TFControlType>())
|
||||
result.operands.push_back(operand.get());
|
||||
|
||||
// Drop all operands from the old operation, eliminating any
|
||||
// inter-dependencies after this pass.
|
||||
operand.drop();
|
||||
}
|
||||
|
||||
// Add a result type for each non-control result we find.
|
||||
bool sawControlResult = false;
|
||||
for (auto opResult : op.getResults()) {
|
||||
if (opResult.getType().isa<TFControlType>()) {
|
||||
sawControlResult = true;
|
||||
} else {
|
||||
// We assume all control inputs are at the end of the result list.
|
||||
assert(!sawControlResult && "all control results must be last");
|
||||
(void)sawControlResult;
|
||||
result.types.push_back(opResult.getType());
|
||||
}
|
||||
}
|
||||
|
||||
result.attributes.append(op.getAttrs().begin(), op.getAttrs().end());
|
||||
|
||||
// Create the replacement operation.
|
||||
auto *replacement = builder.createOperation(result);
|
||||
|
||||
// We know that all the control results are last, so we can just rewrite
|
||||
// the first results.
|
||||
for (unsigned i = 0, e = result.types.size(); i != e; ++i)
|
||||
op.getResult(i).replaceAllUsesWith(replacement->getResult(i));
|
||||
}
|
||||
}
|
||||
|
||||
// In the second pass, we can safely remove all of the old operations, because
|
||||
// we know that all inter-dependencies are dropped.
|
||||
for (auto &bb : function) {
|
||||
// Advance the iterator so we don't invalidate it when we remove an
|
||||
// operation later in the loop.
|
||||
for (auto &op : llvm::make_early_inc_range(bb))
|
||||
if (isUnderscoredTFOp(op)) op.erase();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseTFControlFlowPass() {
|
||||
return std::make_unique<RaiseTFControlFlow>();
|
||||
}
|
||||
|
||||
static PassRegistration<RaiseTFControlFlow> pass(
|
||||
"tf-raise-control-flow",
|
||||
"Raise from the TensorFlow Control Flow "
|
||||
"dialect to the standard TensorFlow dialect");
|
||||
|
||||
} // namespace TFControlFlow
|
||||
} // namespace mlir
|
@ -171,7 +171,7 @@ CreateConvertReadonlyReferenceVariablesToResourceVariablesPass() {
|
||||
|
||||
static PassRegistration<
|
||||
ConvertReadonlyReferenceVariablesToResourceVariablesPass>
|
||||
pass("readonly-references-to-resources",
|
||||
pass("tf-readonly-references-to-resources",
|
||||
"Convert readonly reference variables to resource variables.");
|
||||
|
||||
} // namespace TF
|
||||
|
@ -1,242 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass transforms from TF executor dialect to MLIR TF
|
||||
// control dialect.
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
#define DEBUG_TYPE "tf-executor-to-ctl"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace {
|
||||
struct ExecutorToControlDialectConversion
|
||||
: public PassWrapper<ExecutorToControlDialectConversion, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
static bool HasSingleGraph(FuncOp function) {
|
||||
// We expect the function has only one region with one block,
|
||||
if (function.getBlocks().size() != 1) return false;
|
||||
auto &block = function.front();
|
||||
// and the block contains two ops,
|
||||
if (std::next(block.begin()) == block.end()) return false;
|
||||
// one GraphOp,
|
||||
if (!isa<tf_executor::GraphOp>(block.begin())) return false;
|
||||
// followed by a terminator.
|
||||
if (!std::next(block.begin())->isKnownTerminator()) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
void ExecutorToControlDialectConversion::runOnFunction() {
|
||||
if (!HasSingleGraph(getFunction())) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Expect a Function with a single block and a single graph op,"
|
||||
" skip tf_executor dialect conversion\n");
|
||||
return;
|
||||
}
|
||||
Type control_type = TFControlFlow::TFControlType::get(&getContext());
|
||||
|
||||
Block &body = getFunction().front();
|
||||
auto graph = cast<tf_executor::GraphOp>(body.front());
|
||||
OpBuilder builder = OpBuilder::atBlockEnd(&body);
|
||||
SmallString<64> new_op_name;
|
||||
for (auto &op : llvm::make_early_inc_range(llvm::reverse(graph.GetBody()))) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Process: " << op.getName() << "\n");
|
||||
|
||||
if (auto fetch = dyn_cast<tf_executor::FetchOp>(op)) {
|
||||
// Replace all the operands of the fetch op with the uses of the graph
|
||||
// results, remove the fetch op afterwards.
|
||||
for (auto ops_and_ret_vals :
|
||||
llvm::zip(graph.getResults(), fetch.getOperands()))
|
||||
std::get<0>(ops_and_ret_vals)
|
||||
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
||||
op.erase();
|
||||
continue;
|
||||
}
|
||||
|
||||
builder.setInsertionPoint(&op);
|
||||
|
||||
if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
|
||||
Value ctl_sequence = nullptr;
|
||||
if (island.GetBody().without_terminator().empty() &&
|
||||
island.getNumOperands() > 1) {
|
||||
// For an empty island with multiple control inputs, we create a no-op
|
||||
// inside it which will group all the inputs into one control output.
|
||||
// This helps reduce the number of edges when there are multiple
|
||||
// islands depending on this one.
|
||||
builder.setInsertionPointToStart(&island.GetBody());
|
||||
builder.create<TF::NoOp>(op.getLoc(), ArrayRef<Type>{},
|
||||
ArrayRef<Value>{}, ArrayRef<NamedAttribute>{});
|
||||
builder.setInsertionPoint(&op);
|
||||
}
|
||||
for (Operation &wrapped_op : island.GetBody()) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< " In island: " << wrapped_op.getName() << "\n");
|
||||
if (isa<tf_executor::YieldOp>(wrapped_op)) {
|
||||
for (auto ops_and_ret_vals :
|
||||
llvm::zip(island.getResults(), wrapped_op.getOperands()))
|
||||
std::get<0>(ops_and_ret_vals)
|
||||
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
||||
break;
|
||||
}
|
||||
// Add a leading _ to the name.
|
||||
new_op_name = "_";
|
||||
new_op_name += wrapped_op.getName().getStringRef();
|
||||
OperationState state(wrapped_op.getLoc(), new_op_name);
|
||||
|
||||
// Add an operand for each non-control input we find. Collect control
|
||||
// values separately to add them to the island operands
|
||||
state.operands.append(wrapped_op.getOperands().begin(),
|
||||
wrapped_op.getOperands().end());
|
||||
|
||||
// Chain operations through a control dependency, except for the first
|
||||
// operations in the sequence that carry the control dependencies held
|
||||
// by the island itself.
|
||||
if (ctl_sequence) {
|
||||
state.operands.push_back(ctl_sequence);
|
||||
} else {
|
||||
for (Value ctl_operand : island.getOperands())
|
||||
state.operands.push_back(ctl_operand);
|
||||
}
|
||||
|
||||
// Add a result type for each result
|
||||
state.types.append(wrapped_op.getResultTypes().begin(),
|
||||
wrapped_op.getResultTypes().end());
|
||||
state.types.push_back(control_type);
|
||||
|
||||
// Create the replacement operation.
|
||||
auto *replacement = builder.createOperation(state);
|
||||
replacement->setAttrs(wrapped_op.getMutableAttrDict());
|
||||
|
||||
for (auto ops_and_ret_vals :
|
||||
llvm::zip(wrapped_op.getResults(), replacement->getResults()))
|
||||
std::get<0>(ops_and_ret_vals)
|
||||
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
||||
|
||||
ctl_sequence = replacement->getResult(replacement->getNumResults() - 1);
|
||||
}
|
||||
|
||||
if (ctl_sequence) {
|
||||
// If ctl_sequence is non-null, this means at least one operation has
// been rewritten from ops in the island. The last op rewritten must
// logically carry all the island control inputs, so we can simply use it
// to replace all uses of the island's control output.
|
||||
island.control().replaceAllUsesWith(ctl_sequence);
|
||||
} else if (island.getNumOperands() > 0) {
|
||||
// Getting here means island had an effectively empty body and there is
|
||||
// just one control input. In this case, island's control output should
|
||||
// be replaced with the control input.
|
||||
assert(island.getNumOperands() == 1);
|
||||
island.control().replaceAllUsesWith(island.getOperand(0));
|
||||
}
|
||||
|
||||
op.erase();
|
||||
continue;
|
||||
}
|
||||
|
||||
new_op_name.clear();
|
||||
if (isa<tf_executor::SwitchOp>(op)) {
|
||||
new_op_name = "_tf.Switch";
|
||||
} else if (isa<tf_executor::SwitchNOp>(op)) {
|
||||
new_op_name = "_tf._SwitchN";
|
||||
} else if (isa<tf_executor::MergeOp>(op)) {
|
||||
new_op_name = "_tf.Merge";
|
||||
} else if (isa<tf_executor::NextIterationSourceOp>(op)) {
|
||||
new_op_name = "_tf.NextIteration.source";
|
||||
} else if (isa<tf_executor::NextIterationSinkOp>(op)) {
|
||||
new_op_name = "_tf.NextIteration.sink";
|
||||
} else if (isa<tf_executor::LoopCondOp>(op)) {
|
||||
new_op_name = "_tf.LoopCond";
|
||||
} else if (isa<tf_executor::EnterOp>(op)) {
|
||||
new_op_name = "_tf.Enter";
|
||||
} else if (isa<tf_executor::ExitOp>(op)) {
|
||||
new_op_name = "_tf.Exit";
|
||||
} else if (isa<tf_executor::ControlTriggerOp>(op)) {
|
||||
new_op_name = "_tf.ControlTrigger";
|
||||
} else {
|
||||
op.emitOpError() << "unhandled op in tf_executor to _tf conversion";
|
||||
return signalPassFailure();
|
||||
}
|
||||
OperationState state(op.getLoc(), new_op_name);
|
||||
// Drop all TokenType operands since they don't exist in the control
|
||||
// dialect.
|
||||
auto non_null_operands = llvm::make_filter_range(
|
||||
op.getOperands(),
|
||||
[](Value v) { return !v.getType().isa<tf_executor::TokenType>(); });
|
||||
state.operands.append(non_null_operands.begin(), non_null_operands.end());
|
||||
for (Type result_type : op.getResultTypes()) {
|
||||
// Filter out TokenType, they don't exist in the control dialect.
|
||||
if (result_type.isa<tf_executor::TokenType>()) continue;
|
||||
if (!result_type.isa<tf_executor::ControlType>())
|
||||
state.types.push_back(result_type);
|
||||
else
|
||||
state.types.push_back(control_type);
|
||||
}
|
||||
// The control dialect has a control result for the sink operation.
|
||||
if (isa<tf_executor::NextIterationSinkOp>(op))
|
||||
state.types.push_back(control_type);
|
||||
|
||||
// Create the replacement operation.
|
||||
auto *replacement = builder.createOperation(state);
|
||||
replacement->setAttrs(op.getMutableAttrDict());
|
||||
|
||||
if (auto next_iteration =
|
||||
dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
|
||||
next_iteration.output().replaceAllUsesWith(replacement->getResult(0));
|
||||
next_iteration.token().dropAllUses();
|
||||
next_iteration.control().replaceAllUsesWith(replacement->getResult(1));
|
||||
} else {
|
||||
for (auto ops_and_ret_vals :
|
||||
llvm::zip(op.getResults(), replacement->getResults()))
|
||||
std::get<0>(ops_and_ret_vals)
|
||||
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
||||
}
|
||||
op.erase();
|
||||
}
|
||||
|
||||
// Now we have rewritten all ops inside GraphOp to TF Control dialect. We need
|
||||
// to move all operations outside of GraphOp and remove it.
|
||||
body.getOperations().splice(body.begin(), graph.GetBody().getOperations());
|
||||
graph.erase();
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateTFExecutorToControlDialectConversion() {
|
||||
return std::make_unique<ExecutorToControlDialectConversion>();
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
static mlir::PassRegistration<mlir::ExecutorToControlDialectConversion> pass(
|
||||
"tf-executor-to-control-conversion",
|
||||
"Convert from TF executor dialect to TF control dialect");
|
@ -89,12 +89,11 @@ StatusOr<ElementsAttr> ConvertFlatTensor(const Tensor& input_tensor,
|
||||
|
||||
ElementsAttr ConvertBf16Tensor(const Tensor& input_tensor,
|
||||
RankedTensorType type) {
|
||||
auto flat = input_tensor.flat<bfloat16>();
|
||||
llvm::SmallVector<llvm::APFloat, 4> floats;
|
||||
floats.reserve(flat.size());
|
||||
for (bfloat16 v : llvm::makeArrayRef(flat.data(), flat.size()))
|
||||
floats.push_back(llvm::APFloat(static_cast<double>(v)));
|
||||
return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(floats));
|
||||
auto buffer = llvm::makeArrayRef(static_cast<char*>(input_tensor.data()),
|
||||
input_tensor.TotalBytes());
|
||||
return mlir::DenseElementsAttr::getFromRawBuffer(
|
||||
type, buffer,
|
||||
/*isSplatBuffer=*/type.getNumElements() == 1);
|
||||
}
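The new implementation above stores the tensor's raw bytes directly instead of widening every element to double. The minimal standalone sketch below (an illustration, not TensorFlow code; it assumes bfloat16 means the "upper 16 bits of an IEEE-754 float32" format) shows why the 16-bit pattern can be reinterpreted as-is, and why 4.2 becomes 4.1875, matching the const_bf16_f64 test change further below.

#include <cstdint>
#include <cstring>
#include <iostream>

// Truncate a float32 to its upper 16 bits (the bfloat16 bit pattern).
uint16_t FloatToBf16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

// Reinterpret a bfloat16 bit pattern as a float32 by zero-filling the low
// mantissa bits.
float Bf16BitsToFloat(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  float value = 4.2f;
  uint16_t raw = FloatToBf16Bits(value);
  // 4.2 is not exactly representable in bfloat16; dropping the low mantissa
  // bits yields 4.1875.
  std::cout << Bf16BitsToFloat(raw) << "\n";  // prints 4.1875
}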
|
||||
|
||||
ElementsAttr ConvertHalfTensor(const Tensor& tensor, RankedTensorType type) {
|
||||
@ -280,16 +279,11 @@ void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
|
||||
|
||||
void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr,
|
||||
protobuf::RepeatedField<int>* output) {
|
||||
// Bfloat16 is internally represented as `double` in MLIR.
|
||||
if (attr.isSplat()) {
|
||||
double v = attr.getSplatValue<double>();
|
||||
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
||||
output->Add(absl::bit_cast<int16>(bf16_val));
|
||||
output->Add((*attr.begin()).bitcastToAPInt().getSExtValue());
|
||||
} else {
|
||||
for (auto v : attr.getValues<double>()) {
|
||||
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
||||
output->Add(absl::bit_cast<int16>(bf16_val));
|
||||
}
|
||||
for (const llvm::APFloat value : attr.getFloatValues())
|
||||
output->Add(value.bitcastToAPInt().getSExtValue());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -23,12 +23,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
/// Create a pass to convert from the TFExecutor to the TF control dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateTFExecutorToControlDialectConversion();
|
||||
} // namespace mlir
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {
|
||||
|
@ -28,6 +28,8 @@ package_group(
|
||||
|
||||
exports_files(["ir/hlo_ops.td"])
|
||||
|
||||
exports_files(["ir/lhlo_ops.td"])
|
||||
|
||||
filegroup(
|
||||
name = "hlo_ops_td_files",
|
||||
srcs = [
|
||||
@ -35,6 +37,7 @@ filegroup(
|
||||
"ir/hlo_ops.td",
|
||||
"ir/hlo_ops_base.td",
|
||||
"ir/hlo_utils.td",
|
||||
"ir/infer_fusibility_op_interface.td",
|
||||
"ir/lhlo_ops.td",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||
@ -87,6 +90,8 @@ gentbl(
|
||||
tbl_outs = [
|
||||
("-gen-op-decls", "ir/lhlo_ops.h.inc"),
|
||||
("-gen-op-defs", "ir/lhlo_ops.cc.inc"),
|
||||
("-gen-struct-attr-decls", "ir/lhlo_structs.h.inc"),
|
||||
("-gen-struct-attr-defs", "ir/lhlo_structs.cc.inc"),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "ir/lhlo_ops.td",
|
||||
@ -118,6 +123,42 @@ gentbl(
|
||||
td_srcs = [":hlo_ops_td_files"],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "infer_fusibility_op_interface_gen",
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-op-interface-decls",
|
||||
"ir/infer_fusibility_op_interface.h.inc",
|
||||
),
|
||||
(
|
||||
"-gen-op-interface-defs",
|
||||
"ir/infer_fusibility_op_interface.cc.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "ir/infer_fusibility_op_interface.td",
|
||||
td_srcs = [
|
||||
":hlo_ops_td_files",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "infer_fusibility_op_interface",
|
||||
srcs = [
|
||||
"ir/infer_fusibility_op_interface.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"ir/infer_fusibility_op_interface.h",
|
||||
"ir/infer_fusibility_op_interface.h.inc",
|
||||
],
|
||||
deps = [
|
||||
":infer_fusibility_op_interface_gen",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_legalize_tf",
|
||||
srcs = [
|
||||
@ -362,6 +403,7 @@ cc_library(
|
||||
":map_hlo_to_lhlo_op",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LinalgOps",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
@ -369,6 +411,43 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cycle_detector",
|
||||
srcs = ["transforms/cycle_detector.cc"],
|
||||
hdrs = ["transforms/cycle_detector.h"],
|
||||
deps = [
|
||||
"@llvm-project//llvm:support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "cycle_detector_test",
|
||||
srcs = ["transforms/cycle_detector_test.cc"],
|
||||
deps = [
|
||||
":cycle_detector",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_hlo_fusion",
|
||||
srcs = ["transforms/xla_hlo_fusion.cc"],
|
||||
deps = [
|
||||
":cycle_detector",
|
||||
":hlo",
|
||||
"@llvm-project//llvm:ir",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "xla_legalize_to_standard_inc_gen",
|
||||
tbl_outs = [
|
||||
@ -555,6 +634,7 @@ cc_library(
|
||||
":convert_op_folder",
|
||||
":hlo_ops_base_inc_gen",
|
||||
":hlo_ops_inc_gen",
|
||||
":infer_fusibility_op_interface",
|
||||
":xla_canonicalize_inc_gen",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@llvm-project//llvm:support",
|
||||
@ -824,6 +904,7 @@ genrule(
|
||||
":ir/hlo_ops.td",
|
||||
":ir/hlo_ops_base.td",
|
||||
":ir/hlo_utils.td",
|
||||
":ir/infer_fusibility_op_interface.td",
|
||||
],
|
||||
outs = ["operator_writers.inc"],
|
||||
cmd = ("$(location :operator_writer_gen) " +
|
||||
@ -859,6 +940,7 @@ cc_library(
|
||||
":lhlo_legalize_to_gpu",
|
||||
":lhlo_legalize_to_parallel_loops",
|
||||
":xla_dialect_registration",
|
||||
":xla_hlo_fusion",
|
||||
":xla_hlo_to_lhlo_with_xla",
|
||||
":xla_legalize_control_flow",
|
||||
":xla_legalize_tf",
|
||||
|
@ -44,8 +44,7 @@ template <typename CppType>
|
||||
}
|
||||
|
||||
mlir::APFloat ConvertToAPFloat(bfloat16 val) {
|
||||
// bfloat16 values are stored as double in MLIR.
|
||||
return llvm::APFloat(static_cast<double>(val));
|
||||
return llvm::APFloat(llvm::APFloat::BFloat(), llvm::APInt(16, val.value));
|
||||
}
|
||||
|
||||
mlir::APFloat ConvertToAPFloat(half val) {
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/xla/ir/infer_fusibility_op_interface.h"
|
||||
|
||||
namespace mlir {
|
||||
class OpBuilder;
|
||||
|
@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/infer_fusibility_op_interface.td"
|
||||
|
||||
def HLO_Dialect : Dialect {
|
||||
let name = "xla_hlo";
|
||||
@ -117,7 +118,7 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
|
||||
|
||||
class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||
Type TensorType>: HLO_Op<mnemonic,
|
||||
!listconcat(traits, [InferShapedTypeOpInterface])> {
|
||||
!listconcat(traits, [InferShapedTypeOpInterface, InferFusibilityOpInterface])> {
|
||||
let arguments = (ins TensorType:$operand);
|
||||
let results = (outs TensorType);
|
||||
let extraClassDeclaration = [{
|
||||
@ -132,6 +133,12 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||
return deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||
&reifiedReturnShapes);
|
||||
}
|
||||
bool inferInputOutputShapeEquality(int input, int output) {
|
||||
return true;
|
||||
}
|
||||
llvm::Optional<Value> inferEffectiveWorkloadShape() {
|
||||
return getOperation()->getResult(0);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
@ -257,7 +264,7 @@ def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
|
||||
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
|
||||
|
||||
class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
|
||||
HLO_Op<mnemonic, !listconcat(traits, [InferShapedTypeOpInterface])> {
|
||||
HLO_Op<mnemonic, !listconcat(traits, [InferShapedTypeOpInterface, InferFusibilityOpInterface])> {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$lhs,
|
||||
HLO_Tensor:$rhs
|
||||
@ -275,6 +282,15 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
|
||||
return deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||
&reifiedReturnShapes);
|
||||
}
|
||||
bool inferInputsShapeEquality(int lhs, int rhs) {
|
||||
return true;
|
||||
}
|
||||
bool inferInputOutputShapeEquality(int input, int output) {
|
||||
return true;
|
||||
}
|
||||
llvm::Optional<Value> inferEffectiveWorkloadShape() {
|
||||
return getOperation()->getResult(0);
|
||||
}
|
||||
}];
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
@ -598,7 +614,8 @@ def HLO_AllToAllOp : HLO_Op<"all_to_all",
|
||||
def HLO_ReduceOp: HLO_Op<"reduce", [
|
||||
RecursiveSideEffects,
|
||||
SameVariadicOperandSize,
|
||||
SingleBlockImplicitTerminator<"ReturnOp">
|
||||
SingleBlockImplicitTerminator<"ReturnOp">,
|
||||
InferFusibilityOpInterface
|
||||
]>, BASE_HLO_ReduceOp {
|
||||
let arguments = (ins
|
||||
Variadic<HLO_TensorOrTuple>:$operands,
|
||||
@ -613,6 +630,15 @@ def HLO_ReduceOp: HLO_Op<"reduce", [
|
||||
"ValueRange init_values, DenseIntElementsAttr dimensions"
|
||||
>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
bool isFusibleWithConsumer() {
|
||||
return false;
|
||||
}
|
||||
llvm::Optional<Value> inferEffectiveWorkloadShape() {
|
||||
return getOperation()->getOperand(0);
|
||||
}
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
// TODO(hinsu): Verify that the attached body arguments and results are
|
||||
@ -1360,4 +1386,27 @@ def HLO_DequantizeOp : HLO_Op<"dequantize", [NoSideEffect]>,
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
def HLO_FusionOp : HLO_Op<"fusion", []> {
|
||||
let summary = "Fusion operator";
|
||||
let description = [{
|
||||
Models the fusion instruction.
|
||||
|
||||
A fusion op consists of a group of basic ops (represented as a region
|
||||
attached to it). It serves as a hint to the backend that it is beneficial
|
||||
to emit the contained ops into a single loop nest or kernel.
|
||||
}];
|
||||
let regions = (region SizedRegion<1>:$fused_computation);
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<HLO_TensorOrTuple>:$operands
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<HLO_TensorOrTuple>:$results
|
||||
);
|
||||
|
||||
// FusionOp has special conversion logic to HLO.
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
#endif // HLO_OPS
|
||||
|
@ -0,0 +1,22 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/ir/infer_fusibility_op_interface.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/ir/infer_fusibility_op_interface.cc.inc"
|
||||
|
||||
} // namespace mlir
|
@ -0,0 +1,28 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/ir/infer_fusibility_op_interface.h.inc"
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
|
tensorflow/compiler/mlir/xla/ir/infer_fusibility_op_interface.td (new file)
@ -0,0 +1,161 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file contains InferFusibilityOpInterface, which is used to guide
// fusion decisions.
|
||||
|
||||
#ifndef MLIR_INFER_FUSIBILITY_OP_INTERFACE
|
||||
#define MLIR_INFER_FUSIBILITY_OP_INTERFACE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
// OpInterface to query if an op is fusible and to query the shape equality
|
||||
// constraint among the inputs and outputs of an op.
|
||||
def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
|
||||
let description = [{
|
||||
Interface to query if an op is fusible and to query the shape equality
|
||||
constraint among the inputs and outputs of an op.
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{If true, this op can be fused with its operands
|
||||
}],
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"isFusibleWithOperand",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{
|
||||
/// Returns whether this op can be fused with its operands
|
||||
return true;
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{If true, this op can be fused with its consumers
|
||||
}],
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"isFusibleWithConsumer",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{
|
||||
/// Return whether this op can be fused with its consumers
|
||||
return true;
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/"Return whether two inputs have the same shape (assuming no"
         " implicit broadcasting).",
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"inferInputsShapeEquality",
|
||||
/*args=*/(ins "int":$lhs, "int":$rhs),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{
|
||||
/// Return whether two inputs have the same shape.
|
||||
Operation *op = this->getOperation();
|
||||
assert(lhs < op->getNumOperands() && lhs >= 0 &&
|
||||
rhs < op->getNumOperands() && rhs >= 0);
|
||||
if (lhs == rhs) return true;
|
||||
|
||||
// if both lhs and rhs have static shapes, check them directly
|
||||
Type lhs_ty = op->getOperand(lhs).getType();
|
||||
Type rhs_ty = op->getOperand(rhs).getType();
|
||||
auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
|
||||
auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
|
||||
if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
|
||||
!rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
|
||||
lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
|
||||
return false;
|
||||
}
|
||||
return lhs_shape_type.getShape() == rhs_shape_type.getShape();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/"Return whether two outputs have the same shape (assuming no"
|
||||
" implicit broadcasting).",
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"inferOutputsShapeEquality",
|
||||
/*args=*/(ins "int":$lhs, "int":$rhs),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{
|
||||
/// Return whether two outputs have the same shape.
|
||||
Operation *op = this->getOperation();
|
||||
assert(lhs < op->getNumResults() && lhs >= 0 &&
|
||||
rhs < op->getNumResults() && rhs >= 0);
|
||||
if (lhs == rhs) return true;
|
||||
|
||||
// if both lhs and rhs have static shapes, check them directly
|
||||
Type lhs_ty = op->getResult(lhs).getType();
|
||||
Type rhs_ty = op->getResult(rhs).getType();
|
||||
auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
|
||||
auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
|
||||
if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
|
||||
!rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
|
||||
lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
|
||||
return false;
|
||||
}
|
||||
return lhs_shape_type.getShape() == rhs_shape_type.getShape();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/"Return whether the input and the output have the same"
|
||||
" shape (assuming no implicit broadcasting).",
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"inferInputOutputShapeEquality",
|
||||
/*args=*/(ins "int":$input, "int":$output),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{
|
||||
/// Return whether the input and the output have the same shape.
|
||||
Operation *op = this->getOperation();
|
||||
assert(input < op->getNumOperands() && input >= 0 &&
|
||||
output < op->getNumResults() && output >= 0);
|
||||
|
||||
// if both input and output have static shapes, check them directly
|
||||
Type input_ty = op->getOperand(input).getType();
|
||||
Type output_ty = op->getResult(output).getType();
|
||||
auto input_shape_type = input_ty.dyn_cast_or_null<RankedTensorType>();
|
||||
auto output_shape_type = output_ty.dyn_cast_or_null<RankedTensorType>();
|
||||
if (!input_shape_type || !input_shape_type.hasStaticShape() ||
|
||||
!output_shape_type || !output_shape_type.hasStaticShape() ||
|
||||
input_shape_type.getRank() != output_shape_type.getRank()) {
|
||||
return false;
|
||||
}
|
||||
return input_shape_type.getShape() == output_shape_type.getShape();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{Return the effective workload shape for the operation.
|
||||
|
||||
Here the effective workload shape roughly represents the maximum
parallelism that can be exploited during the codegen stage. It's used
to check the shape compatibility of the operation. During fusion, we
only try to fuse shape-compatible ops for performance.
|
||||
For example, the effective workload shape of an elementwise op is its
|
||||
output shape, while the effective workload shape of a reduction op may
|
||||
be its operand shape.
|
||||
Return None if such an inference is not possible.
|
||||
}],
|
||||
/*retTy=*/"llvm::Optional<Value>",
|
||||
/*methodName=*/"inferEffectiveWorkloadShape",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{
|
||||
/// Return effective workload size if possible, otherwise None.
|
||||
return {};
|
||||
}]
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
@ -45,6 +45,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_structs.cc.inc"
|
||||
namespace xla_lhlo {
|
||||
|
||||
XlaLhloDialect::XlaLhloDialect(MLIRContext *context)
|
||||
|
@ -33,6 +33,8 @@ limitations under the License.
|
||||
namespace mlir {
|
||||
class OpBuilder;
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_structs.h.inc"
|
||||
|
||||
namespace xla_lhlo {
|
||||
|
||||
class XlaLhloDialect : public Dialect {
|
||||
|
@ -407,11 +407,39 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
|
||||
);
|
||||
}
|
||||
|
||||
// TODO(bondhugula): Make this struct dialect independent so that it can be
|
||||
// shared between the HLO and LHLO dialects.
|
||||
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
|
||||
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
||||
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
||||
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
||||
|
||||
let description = "Structure of dimension information for conv op";
|
||||
}
|
||||
|
||||
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
    // Default value: one for each of the spatial dimensions.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    // Default value: zero for each of the spatial dimensions.
    OptionalAttr<I64ElementsAttr>:$padding,
    // Default value: one for each of the spatial dimensions.
    OptionalAttr<I64ElementsAttr>:$lhs_dilation,
    // Default value: one for each of the spatial dimensions.
    OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
||||
ConvDimensionNumbers:$dimension_numbers,
|
||||
I64Attr:$feature_group_count,
|
||||
I64Attr:$batch_group_count,
|
||||
HLO_PrecisionConfigAttr:$precision_config
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -937,6 +937,11 @@ LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
|
||||
// TODO(whoever): currently not supported.
|
||||
return failure();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
||||
@ -979,10 +984,10 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
|
||||
values.reserve(attr.getNumElements());
|
||||
for (APFloat val : attr.getValues<APFloat>()) {
|
||||
bool loses_info = false;
|
||||
CHECK_EQ(val.convert(llvm::APFloat::IEEEsingle(),
|
||||
llvm::APFloat::rmTowardZero, &loses_info),
|
||||
llvm::APFloat::opOK);
|
||||
CHECK(!loses_info);
|
||||
TF_RET_CHECK(val.convert(llvm::APFloat::IEEEsingle(),
|
||||
llvm::APFloat::rmTowardZero,
|
||||
&loses_info) == llvm::APFloat::opOK);
|
||||
TF_RET_CHECK(!loses_info);
|
||||
values.push_back(xla::half(val.convertToFloat()));
|
||||
}
|
||||
xla::Array<xla::half> source_data(shape.dimensions());
|
||||
@ -992,10 +997,15 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
|
||||
case xla::PrimitiveType::BF16: {
|
||||
xla::Array<double> source_data(shape.dimensions());
|
||||
auto attr_values = attr.getValues<APFloat>();
|
||||
std::vector<double> values_double(source_data.num_elements());
|
||||
for (auto index_and_value : llvm::enumerate(attr_values)) {
|
||||
values_double[index_and_value.index()] =
|
||||
index_and_value.value().convertToDouble();
|
||||
std::vector<double> values_double;
|
||||
values_double.reserve(source_data.num_elements());
|
||||
for (APFloat val : attr_values) {
|
||||
bool loses_info = false;
|
||||
TF_RET_CHECK(val.convert(llvm::APFloat::IEEEdouble(),
|
||||
llvm::APFloat::rmTowardZero,
|
||||
&loses_info) == llvm::APFloat::opOK);
|
||||
TF_RET_CHECK(!loses_info);
|
||||
values_double.push_back(val.convertToDouble());
|
||||
}
|
||||
source_data.SetValues(values_double);
|
||||
return xla::LiteralUtil::ConvertF64ToBF16(
|
||||
|
@ -191,7 +191,7 @@ func @const_f32_bf16() -> tensor<bf16> {
|
||||
|
||||
// CHECK-LABEL: func @const_bf16_f64
|
||||
func @const_bf16_f64() -> tensor<f64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<f64>
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.187500e+00> : tensor<f64>
|
||||
%cst = xla_hlo.constant dense<4.2> : tensor<bf16>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
|
@ -432,7 +432,39 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]],
|
||||
// CHECK-SAME: %[[RESULT:.*]]: [[TYPE]])
|
||||
// CHECK: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %{{.*}}) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
|
||||
%dot = "xla_hlo.dot"(%arg0, %arg0)
|
||||
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
|
||||
return %dot : tensor<1024x1024xf32>
|
||||
}
|
||||
%dot = "xla_hlo.dot"(%arg0, %arg0)
|
||||
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
|
||||
return %dot : tensor<1024x1024xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @conv
|
||||
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
|
||||
%c0 = constant 0 : index
|
||||
// CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
|
||||
// CHECK: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
|
||||
// CHECK-SAME: padding = dense<[
|
||||
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||
// CHECK-SAME: rhs_dilation = dense<[1, 2]>
|
||||
// CHECK-SAME: window_strides = dense<[2, 1]>
|
||||
%out = "xla_hlo.convolution"(%filter, %input) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
input_batch_dimension = 0 : i64,
|
||||
input_feature_dimension = 3 : i64,
|
||||
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
|
||||
kernel_input_feature_dimension = 2 : i64,
|
||||
kernel_output_feature_dimension = 3 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||
output_batch_dimension = 0 : i64,
|
||||
output_feature_dimension = 3 : i64,
|
||||
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
},
|
||||
feature_group_count = 1 : i64,
|
||||
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
|
||||
window_strides = dense<[2, 1]> : tensor<2xi64>
|
||||
} : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
|
||||
return %out : tensor<3x5x5x4xf32>
|
||||
}
|
||||
|
@ -511,7 +511,7 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]):
|
||||
// CHECK-NEXT: %[[L0:.*]] = constant 0 : i32
|
||||
// CHECK-NEXT: %[[L0:.*]] = constant 0 : i32
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = subi %[[L0]], %[[OPERAND_IN]] : i32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||
|
||||
@ -691,3 +691,27 @@ func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
%0 = alloc() : memref<3x5x5x4xf32>
|
||||
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
|
||||
// CHECK-SAME: dilations = [1, 2]
|
||||
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||
// CHECK-SAME: strides = [2, 1]}
|
||||
// With all attributes explicitly specified.
|
||||
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
|
||||
|
||||
// Dilation left unspecified, sets default dilation since linalg expects it.
|
||||
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
|
||||
// CHECK-SAME: dilations = [1, 1]
|
||||
// Padding is not set if it's zero.
|
||||
// CHECK-NOT: padding
|
||||
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
|
||||
|
||||
"xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
tensorflow/compiler/mlir/xla/tests/xla-hlo-fusion.mlir (new file, 97 lines)
@ -0,0 +1,97 @@
|
||||
// RUN: tf-opt %s -xla-hlo-fusion -split-input-file | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// CHECK-LABEL: func @multi_outputs_same
|
||||
func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%2 = "xla_hlo.add"(%1, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RET:.*]]:2 = "xla_hlo.fusion"
|
||||
// CHECK-NEXT: xla_hlo.add
|
||||
// CHECK-NEXT: xla_hlo.subtract
|
||||
// CHECK-NEXT: xla_hlo.add
|
||||
// CHECK-NEXT: xla_hlo.return
|
||||
return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @multi_outputs_same_2
|
||||
func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
|
||||
%0 = "xla_hlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = "xla_hlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%2 = "xla_hlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%3 = "xla_hlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%4 = "xla_hlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RET:.*]]:3 = "xla_hlo.fusion"
|
||||
// CHECK-NEXT: xla_hlo.abs
|
||||
// CHECK-NEXT: xla_hlo.abs
|
||||
// CHECK-NEXT: xla_hlo.add
|
||||
// CHECK-NEXT: xla_hlo.abs
|
||||
// CHECK-NEXT: xla_hlo.abs
|
||||
// CHECK-NEXT: xla_hlo.return
|
||||
return %2, %3, %4 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @multi_outputs_not_sure_same
|
||||
func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK-NOT: xla_hlo.fusion
|
||||
%1 = "xla_hlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @reduce
|
||||
func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RET0:.*]] = "xla_hlo.fusion"
|
||||
// CHECK-NEXT: xla_hlo.add
|
||||
// CHECK-NEXT: xla_hlo.subtract
|
||||
// CHECK-NEXT: xla_hlo.return
|
||||
// Currently we do not support fusing arguments and ops without a direct
// producer-consumer relationship. Thus the reduce op should not be fused
// with the above two ops.
|
||||
|
||||
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%3 = "xla_hlo.reduce"(%arg0, %2) ( {
|
||||
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
|
||||
%4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
|
||||
%4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// The above two ops should not be fused since the reduce op cannot be
// fused with its consumer.
|
||||
// CHECK-NOT: xla_hlo.fusion
|
||||
|
||||
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @reduce_2
|
||||
func @reduce_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
|
||||
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%3 = "xla_hlo.reduce"(%1, %2) ( {
|
||||
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
|
||||
%4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
|
||||
// CHECK: %[[RET0:.*]]:2 = "xla_hlo.fusion"
|
||||
// CHECK-NEXT: xla_hlo.add
|
||||
// CHECK-NEXT: xla_hlo.subtract
|
||||
// CHECK-NEXT: xla_hlo.constant
|
||||
// CHECK-NEXT: xla_hlo.reduce
|
||||
// CHECK: xla_hlo.return
|
||||
|
||||
// The following op should not be fused with the above ops since the reduce op
// cannot be fused with its consumer.
|
||||
// CHECK-NOT: xla_hlo.fusion
|
||||
%4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
|
||||
}
|
tensorflow/compiler/mlir/xla/transforms/cycle_detector.cc (new file, 340 lines)
@ -0,0 +1,340 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/cycle_detector.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
using NodeSet = llvm::DenseSet<int32_t>;
|
||||
using OrderedNodeSet = OrderedSet<int32_t>;
|
||||
|
||||
template <typename T>
|
||||
struct VecStruct {
|
||||
typedef llvm::SmallVector<T, 4> type;
|
||||
};
|
||||
template <typename T>
|
||||
using Vec = typename VecStruct<T>::type;
|
||||
|
||||
struct Node {
|
||||
// rank number assigned by Pearce-Kelly algorithm
|
||||
int32_t rank;
|
||||
// Temporary marker used by depth-first-search
|
||||
bool visited;
|
||||
// User-supplied data
|
||||
void* data;
|
||||
// List of immediate predecessor nodes in graph
|
||||
OrderedNodeSet in;
|
||||
// List of immediate successor nodes in graph
|
||||
OrderedNodeSet out;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
struct GraphCycles::Rep {
|
||||
Vec<Node*> nodes;
|
||||
// Indices for unused entries in nodes
|
||||
Vec<int32_t> free_nodes;
|
||||
|
||||
// Temporary state.
|
||||
// Results of forward DFS
|
||||
Vec<int32_t> deltaf;
|
||||
// Results of backward DFS
|
||||
Vec<int32_t> deltab;
|
||||
// All nodes to reprocess
|
||||
Vec<int32_t> list;
|
||||
// Rank values to assign to list entries
|
||||
Vec<int32_t> merged;
|
||||
// Emulates recursion stack when doing depth first search
|
||||
Vec<int32_t> stack;
|
||||
};
|
||||
|
||||
GraphCycles::GraphCycles(int32_t num_nodes) : rep_(new Rep) {
|
||||
rep_->nodes.reserve(num_nodes);
|
||||
for (int32_t i = 0; i < num_nodes; ++i) {
|
||||
Node* n = new Node;
|
||||
n->visited = false;
|
||||
n->data = nullptr;
|
||||
n->rank = rep_->nodes.size();
|
||||
rep_->nodes.push_back(n);
|
||||
}
|
||||
}
|
||||
|
||||
GraphCycles::~GraphCycles() {
|
||||
for (Vec<Node*>::size_type i = 0, e = rep_->nodes.size(); i < e; ++i) {
|
||||
delete rep_->nodes[i];
|
||||
}
|
||||
delete rep_;
|
||||
}
|
||||
|
||||
bool GraphCycles::HasEdge(int32_t x, int32_t y) const {
|
||||
return rep_->nodes[x]->out.Contains(y);
|
||||
}
|
||||
|
||||
void GraphCycles::RemoveEdge(int32_t x, int32_t y) {
|
||||
rep_->nodes[x]->out.Erase(y);
|
||||
rep_->nodes[y]->in.Erase(x);
|
||||
// No need to update the rank assignment since a previous valid
|
||||
// rank assignment remains valid after an edge deletion.
|
||||
}
|
||||
|
||||
static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound);
|
||||
static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound);
|
||||
static void Reorder(GraphCycles::Rep* r);
|
||||
static void Sort(const Vec<Node*>&, Vec<int32_t>* delta);
|
||||
static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
|
||||
Vec<int32_t>* dst);
|
||||
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes);
|
||||
|
||||
bool GraphCycles::InsertEdge(int32_t x, int32_t y) {
|
||||
if (x == y) return false;
|
||||
Rep* r = rep_;
|
||||
Node* nx = r->nodes[x];
|
||||
if (!nx->out.Insert(y)) {
|
||||
// Edge already exists.
|
||||
return true;
|
||||
}
|
||||
|
||||
Node* ny = r->nodes[y];
|
||||
ny->in.Insert(x);
|
||||
|
||||
if (nx->rank <= ny->rank) {
|
||||
// New edge is consistent with existing rank assignment.
|
||||
return true;
|
||||
}
|
||||
|
||||
// Current rank assignments are incompatible with the new edge. Recompute.
|
||||
// We only need to consider nodes that fall in the range [ny->rank,nx->rank].
|
||||
if (ForwardDFS(r, y, nx->rank)) {
|
||||
// Found a cycle. Undo the insertion and tell caller.
|
||||
nx->out.Erase(y);
|
||||
ny->in.Erase(x);
|
||||
// Since we do not call Reorder() on this path, clear any visited
|
||||
// markers left by ForwardDFS.
|
||||
ClearVisitedBits(r, r->deltaf);
|
||||
return false;
|
||||
}
|
||||
BackwardDFS(r, x, ny->rank);
|
||||
Reorder(r);
|
||||
return true;
|
||||
}
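// A minimal sketch of the incremental behaviour (illustration only, not part
// of this change):
//
//   GraphCycles g(3);
//   g.InsertEdge(0, 1);  // ranks already consistent, no reordering needed
//   g.InsertEdge(1, 2);  // likewise
//   g.InsertEdge(2, 0);  // ForwardDFS from 0, bounded by the rank of 2,
//                        // reaches 2, so the edge is rejected (returns false)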
|
||||
|
||||
// Follows the edges from producer to consumer and searches whether the node
// having rank `n` can reach the node having rank `upper_bound` using DFS.
// During the search we only consider paths on which every node's rank is
// smaller than `upper_bound`.
|
||||
//
|
||||
// Returns true if such path exists.
|
||||
static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound) {
|
||||
// Avoid recursion since stack space might be limited.
|
||||
// We instead keep a stack of nodes to visit.
|
||||
r->deltaf.clear();
|
||||
r->stack.clear();
|
||||
r->stack.push_back(n);
|
||||
while (!r->stack.empty()) {
|
||||
n = r->stack.back();
|
||||
r->stack.pop_back();
|
||||
Node* nn = r->nodes[n];
|
||||
if (nn->visited) continue;
|
||||
|
||||
nn->visited = true;
|
||||
r->deltaf.push_back(n);
|
||||
|
||||
for (auto w : nn->out.GetSequence()) {
|
||||
Node* nw = r->nodes[w];
|
||||
if (nw->rank == upper_bound) {
|
||||
return true;
|
||||
}
|
||||
if (!nw->visited && nw->rank < upper_bound) {
|
||||
r->stack.push_back(w);
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Follows the edges from consumer to producer and visits all the nodes that
// are reachable from node `n` and have rank larger than `lower_bound`.
|
||||
static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound) {
|
||||
r->deltab.clear();
|
||||
r->stack.clear();
|
||||
r->stack.push_back(n);
|
||||
while (!r->stack.empty()) {
|
||||
n = r->stack.back();
|
||||
r->stack.pop_back();
|
||||
Node* nn = r->nodes[n];
|
||||
if (nn->visited) continue;
|
||||
|
||||
nn->visited = true;
|
||||
r->deltab.push_back(n);
|
||||
|
||||
for (auto w : nn->in.GetSequence()) {
|
||||
Node* nw = r->nodes[w];
|
||||
if (!nw->visited && lower_bound < nw->rank) {
|
||||
r->stack.push_back(w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recomputes rank assignments to make them compatible with the edges (producer
|
||||
// has smaller rank than its consumer)
|
||||
static void Reorder(GraphCycles::Rep* r) {
|
||||
Sort(r->nodes, &r->deltab);
|
||||
Sort(r->nodes, &r->deltaf);
|
||||
|
||||
// Adds contents of delta lists to list (backwards deltas first).
|
||||
r->list.clear();
|
||||
MoveToList(r, &r->deltab, &r->list);
|
||||
MoveToList(r, &r->deltaf, &r->list);
|
||||
|
||||
// Produce sorted list of all ranks that will be reassigned.
|
||||
r->merged.resize(r->deltab.size() + r->deltaf.size());
|
||||
std::merge(r->deltab.begin(), r->deltab.end(), r->deltaf.begin(),
|
||||
r->deltaf.end(), r->merged.begin());
|
||||
|
||||
// Assign the ranks in order to the collected list.
|
||||
for (Vec<int32_t>::size_type i = 0, e = r->list.size(); i < e; ++i) {
|
||||
r->nodes[r->list[i]]->rank = r->merged[i];
|
||||
}
|
||||
}
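// Worked example (illustration only): suppose inserting the edge x->y finds
// deltab = {b(rank 5), x(rank 7)} via BackwardDFS and deltaf = {y(rank 2),
// a(rank 3)} via ForwardDFS. After sorting, list = [b, x, y, a] and
// merged = [2, 3, 5, 7], so the reassignment yields b=2, x=3, y=5, a=7 and
// every producer again ranks below its consumers.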
|
||||
|
||||
// Sorts nodes in the vector according to their ranks, smallest rank first.
|
||||
static void Sort(const Vec<Node*>& nodes, Vec<int32_t>* delta) {
|
||||
struct ByRank {
|
||||
const Vec<Node*>* nodes;
|
||||
bool operator()(int32_t a, int32_t b) const {
|
||||
return (*nodes)[a]->rank < (*nodes)[b]->rank;
|
||||
}
|
||||
};
|
||||
ByRank cmp;
|
||||
cmp.nodes = &nodes;
|
||||
std::sort(delta->begin(), delta->end(), cmp);
|
||||
}
|
||||
|
||||
// Moves the node ids in `src` to `dst`, replacing each `src` entry with its rank.
|
||||
static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
|
||||
Vec<int32_t>* dst) {
|
||||
for (Vec<int32_t>::size_type i = 0, e = src->size(); i < e; i++) {
|
||||
int32_t w = (*src)[i];
|
||||
// Replace src entry with its rank
|
||||
(*src)[i] = r->nodes[w]->rank;
|
||||
// Prepare for future DFS calls
|
||||
r->nodes[w]->visited = false;
|
||||
dst->push_back(w);
|
||||
}
|
||||
}
|
||||
|
||||
// Clears bookkeeping fields used during the last DFS process.
|
||||
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes) {
|
||||
for (Vec<int32_t>::size_type i = 0, e = nodes.size(); i < e; i++) {
|
||||
r->nodes[nodes[i]]->visited = false;
|
||||
}
|
||||
}
|
||||
|
||||
bool GraphCycles::IsReachable(int32_t x, int32_t y) {
|
||||
if (x == y) return true;
|
||||
Rep* r = rep_;
|
||||
Node* nx = r->nodes[x];
|
||||
Node* ny = r->nodes[y];
|
||||
|
||||
if (nx->rank >= ny->rank) {
|
||||
// x cannot reach y since it is after it in the topological ordering
|
||||
return false;
|
||||
}
|
||||
|
||||
// See if x can reach y using a DFS search that is limited to y's rank
|
||||
bool reachable = ForwardDFS(r, x, ny->rank);
|
||||
|
||||
// Clear any visited markers left by ForwardDFS.
|
||||
ClearVisitedBits(r, r->deltaf);
|
||||
return reachable;
|
||||
}
|
||||
|
||||
llvm::Optional<int32_t> GraphCycles::ContractEdge(int32_t a, int32_t b) {
|
||||
assert(HasEdge(a, b));
|
||||
RemoveEdge(a, b);
|
||||
|
||||
if (IsReachable(a, b)) {
|
||||
// Restore the graph to its original state.
|
||||
InsertEdge(a, b);
|
||||
return {};
|
||||
}
|
||||
|
||||
if (rep_->nodes[b]->in.Size() + rep_->nodes[b]->out.Size() >
|
||||
rep_->nodes[a]->in.Size() + rep_->nodes[a]->out.Size()) {
|
||||
// Swap "a" and "b" to minimize copying.
|
||||
std::swap(a, b);
|
||||
}
|
||||
|
||||
Node* nb = rep_->nodes[b];
|
||||
OrderedNodeSet out = std::move(nb->out);
|
||||
OrderedNodeSet in = std::move(nb->in);
|
||||
for (int32_t y : out.GetSequence()) {
|
||||
rep_->nodes[y]->in.Erase(b);
|
||||
}
|
||||
for (int32_t y : in.GetSequence()) {
|
||||
rep_->nodes[y]->out.Erase(b);
|
||||
}
|
||||
rep_->free_nodes.push_back(b);
|
||||
|
||||
rep_->nodes[a]->out.Reserve(rep_->nodes[a]->out.Size() + out.Size());
|
||||
for (int32_t y : out.GetSequence()) {
|
||||
InsertEdge(a, y);
|
||||
}
|
||||
|
||||
rep_->nodes[a]->in.Reserve(rep_->nodes[a]->in.Size() + in.Size());
|
||||
for (int32_t y : in.GetSequence()) {
|
||||
InsertEdge(y, a);
|
||||
}
|
||||
|
||||
// Note, if the swap happened it might be what originally was called "b".
|
||||
return a;
|
||||
}
|
||||
|
||||
std::vector<int32_t> GraphCycles::SuccessorsCopy(int32_t node) const {
|
||||
return rep_->nodes[node]->out.GetSequence();
|
||||
}
|
||||
|
||||
namespace {
|
||||
void SortInPostOrder(const Vec<Node*>& nodes, std::vector<int32_t>* to_sort) {
|
||||
std::sort(to_sort->begin(), to_sort->end(), [&](int32_t a, int32_t b) {
|
||||
return nodes[a]->rank > nodes[b]->rank;
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<int32_t> GraphCycles::AllNodesInPostOrder() const {
|
||||
llvm::DenseSet<int32_t> free_nodes_set;
|
||||
for (int32_t n : rep_->free_nodes) free_nodes_set.insert(n);
|
||||
|
||||
std::vector<int32_t> all_nodes;
|
||||
all_nodes.reserve(rep_->nodes.size() - free_nodes_set.size());
|
||||
for (size_t i = 0, e = rep_->nodes.size(); i < e; i++) {
|
||||
if (!free_nodes_set.count(i)) {
|
||||
all_nodes.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
SortInPostOrder(rep_->nodes, &all_nodes);
|
||||
return all_nodes;
|
||||
}
|
||||
|
||||
} // namespace mlir
|
tensorflow/compiler/mlir/xla/transforms/cycle_detector.h (new file, 165 lines)
@ -0,0 +1,165 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// This file contains a light version of GraphCycles implemented in
|
||||
// tensorflow/compiler/jit/graphcycles/graphcycles.h
|
||||
//
|
||||
// We re-implement it here because we do not want to rely
// on TensorFlow data structures, which makes it easy to move the
// corresponding passes to the LLVM repository if necessary.
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
// This is a set data structure that provides a deterministic iteration order.
|
||||
// The iteration order of elements only depends on the sequence of
|
||||
// inserts/deletes, so as long as the inserts/deletes happen in the same
|
||||
// sequence, the set will have the same iteration order.
|
||||
//
|
||||
// Assumes that T can be cheaply copied for simplicity.
|
||||
template <typename T>
|
||||
class OrderedSet {
|
||||
public:
|
||||
// Inserts `value` into the ordered set. Returns true if the value was not
|
||||
// present in the set before the insertion.
|
||||
bool Insert(T value) {
|
||||
bool new_insertion =
|
||||
value_to_index_.insert({value, value_sequence_.size()}).second;
|
||||
if (new_insertion) {
|
||||
value_sequence_.push_back(value);
|
||||
}
|
||||
return new_insertion;
|
||||
}
|
||||
|
||||
// Removes `value` from the set. Assumes `value` is already present in the
|
||||
// set.
|
||||
void Erase(T value) {
|
||||
auto it = value_to_index_.find(value);
|
||||
|
||||
// Since we don't want to move values around in `value_sequence_`, we swap
// the value in the last position with the value to be deleted and then
// pop_back.
|
||||
value_to_index_[value_sequence_.back()] = it->second;
|
||||
std::swap(value_sequence_[it->second], value_sequence_.back());
|
||||
value_sequence_.pop_back();
|
||||
value_to_index_.erase(it);
|
||||
}
|
||||
|
||||
void Reserve(size_t new_size) {
|
||||
value_to_index_.reserve(new_size);
|
||||
value_sequence_.reserve(new_size);
|
||||
}
|
||||
|
||||
void Clear() {
|
||||
value_to_index_.clear();
|
||||
value_sequence_.clear();
|
||||
}
|
||||
|
||||
bool Contains(T value) const { return value_to_index_.count(value); }
|
||||
size_t Size() const { return value_sequence_.size(); }
|
||||
|
||||
const std::vector<T>& GetSequence() const { return value_sequence_; }
|
||||
|
||||
private:
|
||||
// The stable order that we maintain through insertions and deletions.
|
||||
std::vector<T> value_sequence_;
|
||||
|
||||
// Maps values to their indices in `value_sequence_`.
|
||||
llvm::DenseMap<T, int> value_to_index_;
|
||||
};
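// A minimal usage sketch (illustration only, not part of this change):
//
//   OrderedSet<int> s;
//   s.Insert(3); s.Insert(1); s.Insert(2);  // GetSequence() == {3, 1, 2}
//   s.Erase(3);                             // swap-with-last: {2, 1}
//   s.Insert(3);                            // {2, 1, 3}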
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
|
||||
// GraphCycles detects the introduction of a cycle into a directed
|
||||
// graph that is being built up incrementally.
|
||||
//
|
||||
// Nodes are identified by small integers. It is not possible to
|
||||
// record multiple edges with the same (source, destination) pair;
|
||||
// requests to add an edge where one already exists are silently
|
||||
// ignored.
|
||||
//
|
||||
// It is also not possible to introduce a cycle; an attempt to insert
|
||||
// an edge that would introduce a cycle fails and returns false.
|
||||
//
|
||||
// GraphCycles uses no internal locking; calls into it should be
|
||||
// serialized externally.
|
||||
|
||||
// Performance considerations:
|
||||
// Works well on sparse graphs, poorly on dense graphs.
|
||||
// Extra information is maintained incrementally to detect cycles quickly.
|
||||
// InsertEdge() is very fast when the edge already exists, and reasonably fast
|
||||
// otherwise.
|
||||
// FindPath() is linear in the size of the graph.
|
||||
// The current implementation uses O(|V|+|E|) space.
|
||||
|
||||
class GraphCycles {
|
||||
public:
|
||||
explicit GraphCycles(int32_t num_nodes);
|
||||
~GraphCycles();
|
||||
|
||||
// Attempt to insert an edge from x to y. If the
|
||||
// edge would introduce a cycle, return false without making any
|
||||
// changes. Otherwise add the edge and return true.
|
||||
bool InsertEdge(int32_t x, int32_t y);
|
||||
|
||||
// Remove any edge that exists from x to y.
|
||||
void RemoveEdge(int32_t x, int32_t y);
|
||||
|
||||
// Return whether there is an edge directly from x to y.
|
||||
bool HasEdge(int32_t x, int32_t y) const;
|
||||
|
||||
// Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. One of
|
||||
// the nodes is removed from the graph, and edges to/from it are added to
|
||||
// the remaining one, which is returned. If contracting the edge would create
|
||||
// a cycle, does nothing and returns no value.
|
||||
llvm::Optional<int32_t> ContractEdge(int32_t a, int32_t b);
|
||||
|
||||
// Return whether dest_node `y` is reachable from source_node `x`
|
||||
// by following edges. This is the non-thread-safe version.
|
||||
bool IsReachable(int32_t x, int32_t y);
|
||||
|
||||
// Return a copy of the successors set. This is needed for code using the
|
||||
// collection while modifying the GraphCycles.
|
||||
std::vector<int32_t> SuccessorsCopy(int32_t node) const;
|
||||
|
||||
// Returns all nodes in post order.
|
||||
//
|
||||
// If there is a path from X to Y then X appears after Y in the
|
||||
// returned vector.
|
||||
std::vector<int32_t> AllNodesInPostOrder() const;
|
||||
|
||||
// ----------------------------------------------------
|
||||
struct Rep;
|
||||
|
||||
private:
|
||||
GraphCycles(const GraphCycles&) = delete;
|
||||
GraphCycles& operator=(const GraphCycles&) = delete;
|
||||
|
||||
Rep* rep_; // opaque representation
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
|
@ -0,0 +1,89 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/cycle_detector.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
|
||||
class GraphCyclesTest : public ::testing::Test {
|
||||
public:
|
||||
GraphCyclesTest() : g_(100) {}
|
||||
|
||||
bool AddEdge(int x, int y) { return g_.InsertEdge(x, y); }
|
||||
|
||||
void AddMultiples() {
|
||||
// For every node x > 0: add edge to 2*x, 3*x
|
||||
for (int x = 1; x < 25; x++) {
|
||||
EXPECT_TRUE(AddEdge(x, 2 * x)) << x;
|
||||
EXPECT_TRUE(AddEdge(x, 3 * x)) << x;
|
||||
}
|
||||
}
|
||||
|
||||
mlir::GraphCycles g_;
|
||||
};
|
||||
|
||||
TEST_F(GraphCyclesTest, NoCycle) { AddMultiples(); }
|
||||
|
||||
TEST_F(GraphCyclesTest, SimpleCycle) {
|
||||
AddMultiples();
|
||||
EXPECT_FALSE(AddEdge(8, 4));
|
||||
}
|
||||
|
||||
TEST_F(GraphCyclesTest, IndirectCycle) {
|
||||
AddMultiples();
|
||||
EXPECT_TRUE(AddEdge(16, 9));
|
||||
EXPECT_FALSE(AddEdge(9, 2));
|
||||
}
|
||||
|
||||
TEST_F(GraphCyclesTest, RemoveEdge) {
|
||||
EXPECT_TRUE(AddEdge(1, 2));
|
||||
EXPECT_TRUE(AddEdge(2, 3));
|
||||
EXPECT_TRUE(AddEdge(3, 4));
|
||||
EXPECT_TRUE(AddEdge(4, 5));
|
||||
g_.RemoveEdge(2, 3);
|
||||
EXPECT_FALSE(g_.HasEdge(2, 3));
|
||||
}
|
||||
|
||||
TEST_F(GraphCyclesTest, IsReachable) {
|
||||
EXPECT_TRUE(AddEdge(1, 2));
|
||||
EXPECT_TRUE(AddEdge(2, 3));
|
||||
EXPECT_TRUE(AddEdge(3, 4));
|
||||
EXPECT_TRUE(AddEdge(4, 5));
|
||||
|
||||
EXPECT_TRUE(g_.IsReachable(1, 5));
|
||||
EXPECT_FALSE(g_.IsReachable(5, 1));
|
||||
}
|
||||
|
||||
TEST_F(GraphCyclesTest, ContractEdge) {
|
||||
ASSERT_TRUE(AddEdge(1, 2));
|
||||
ASSERT_TRUE(AddEdge(1, 3));
|
||||
ASSERT_TRUE(AddEdge(2, 3));
|
||||
ASSERT_TRUE(AddEdge(2, 4));
|
||||
ASSERT_TRUE(AddEdge(3, 4));
|
||||
|
||||
// It will introduce a cycle if the edge is contracted
|
||||
EXPECT_FALSE(g_.ContractEdge(1, 3).hasValue());
|
||||
EXPECT_TRUE(g_.HasEdge(1, 3));
|
||||
|
||||
// Node (2) has more edges.
|
||||
EXPECT_EQ(*g_.ContractEdge(1, 2), 2);
|
||||
EXPECT_TRUE(g_.HasEdge(2, 3));
|
||||
EXPECT_TRUE(g_.HasEdge(2, 4));
|
||||
EXPECT_TRUE(g_.HasEdge(3, 4));
|
||||
|
||||
// Node (2) has more edges.
|
||||
EXPECT_EQ(*g_.ContractEdge(2, 3), 2);
|
||||
EXPECT_TRUE(g_.HasEdge(2, 4));
|
||||
}
|
@ -423,6 +423,7 @@ void populateHLOToLHLOConversionPattern(
|
||||
HloToLhloOpConverter<xla_hlo::CompareOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ComplexOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ConstOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ConvOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ConvertOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CopyOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CosOp>,
|
||||
|
@ -45,6 +45,7 @@ MAP_HLO_TO_LHLO(CeilOp);
|
||||
MAP_HLO_TO_LHLO(ConstOp);
|
||||
MAP_HLO_TO_LHLO(CompareOp);
|
||||
MAP_HLO_TO_LHLO(ComplexOp);
|
||||
MAP_HLO_TO_LHLO(ConvOp);
|
||||
MAP_HLO_TO_LHLO(ConvertOp);
|
||||
MAP_HLO_TO_LHLO(CopyOp);
|
||||
MAP_HLO_TO_LHLO(CosOp);
|
||||
|
@ -73,6 +73,9 @@ std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();
|
||||
// necessary to export to XLA.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
|
||||
|
||||
// Fuses xla_hlo ops into kLoop/kInput fusion patterns.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
|
||||
|
||||
} // namespace xla_hlo
|
||||
|
||||
namespace xla_lhlo {
|
||||
|
tensorflow/compiler/mlir/xla/transforms/xla_hlo_fusion.cc (new file, 579 lines)
@ -0,0 +1,579 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
|
||||
#include "llvm/ADT/EquivalenceClasses.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/cycle_detector.h"
|
||||
|
||||
// This pass has similar functionality of the fusion pass in XLA stack.
|
||||
// However, unlike XLA, it targets the fully dynamic shape scenario.
|
||||
// Currently, it implements the kLoop and kInput fusion templates.
|
||||
// During conversion, it tries to greedily find kLoop/kInput fusion
|
||||
// patterns.
|
||||
//
|
||||
// Similar to XLA, this pass supports fusion patterns having multiple outputs
// if all the output shapes are consistent. Following are some examples.
|
||||
//
|
||||
// kLoop kInput
|
||||
// +----+ +----+ +----+ +----+ +----+ +----+
|
||||
// |elem| |elem| |elem| |elem<----+elem+---->elem+----+
|
||||
// +-+--+ +-+--+ +-+--+ +-+--+ +----+ +-+--+ |
|
||||
// | | | | | |
|
||||
// | | | | |
|
||||
// +-v--+ | +-v--+ +--v---+ +--v---+ |
|
||||
// |elem+<---+----<+elem| |reduce| |reduce| |
|
||||
// +-+--+ +-+--+ +--+---+ +--+---+ |
|
||||
// | | | | |
|
||||
// | | | | |
|
||||
// v v v v v
|
||||
//
|
||||
// To this end, we also add a simple shape constraint analysis phase.
|
||||
// For kLoop fusion template, it requires all the outputs of the fused
|
||||
// pattern have the same shape. However, we don't know the actual value
|
||||
// of the shape at the compile time in the dynamic shape world.
|
||||
// Fortunately, we could still infer the relationship among different ops
|
||||
// according to their shape constraint traits. Currently, we only consider
|
||||
// shape equality propagation for elementwise ops (assuming that implicit
|
||||
// shape broadcast is forbidden). The above process could be built on the
|
||||
// shape dialect once it is ready.
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
namespace {
|
||||
|
||||
using llvm::EquivalenceClasses;
|
||||
using FusionPattern = std::vector<Operation*>;
|
||||
using FusionPlan = std::vector<FusionPattern>;
|
||||
|
||||
// To support using EquivalenceClasses for Value
|
||||
class ValueWrapper {
|
||||
public:
|
||||
explicit ValueWrapper(Value value) : value_(std::move(value)) {}
|
||||
|
||||
Value getValue() const { return value_; }
|
||||
|
||||
bool operator==(const ValueWrapper& rhs) const {
|
||||
return getValue() == rhs.getValue();
|
||||
}
|
||||
|
||||
private:
|
||||
Value value_;
|
||||
};
|
||||
|
||||
bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
|
||||
auto lhs_value = lhs.getValue().getAsOpaquePointer();
|
||||
auto rhs_value = rhs.getValue().getAsOpaquePointer();
|
||||
return lhs_value < rhs_value;
|
||||
}
|
||||
|
||||
bool IsFusible(Operation* op) {
|
||||
if (matchPattern(op, m_Constant())) {
|
||||
return true;
|
||||
}
|
||||
auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
|
||||
return op_fusibility && (op_fusibility.isFusibleWithOperand() ||
|
||||
op_fusibility.isFusibleWithConsumer());
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> GetInputsOfFusionPattern(const FusionPattern& pattern) {
|
||||
SmallVector<Value, 4> inputs;
|
||||
DenseSet<Value> input_set;
|
||||
DenseSet<Operation*> op_set;
|
||||
for (Operation* op : pattern) {
|
||||
bool inserted = op_set.insert(op).second;
|
||||
(void)inserted;
|
||||
assert(inserted && "FusionPattern contains duplicate operations");
|
||||
}
|
||||
|
||||
for (Operation* op : pattern) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
Operation* operand_op = operand.getDefiningOp();
|
||||
if (op_set.find(operand_op) != op_set.end()) {
|
||||
// skip if defining op is in the pattern
|
||||
continue;
|
||||
}
|
||||
if (input_set.insert(operand).second) {
|
||||
inputs.push_back(operand);
|
||||
}
|
||||
}
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> GetOutputsOfFusionPattern(const FusionPattern& pattern) {
|
||||
SmallVector<Value, 4> outputs;
|
||||
DenseSet<Operation*> op_set;
|
||||
for (Operation* op : pattern) {
|
||||
bool inserted = op_set.insert(op).second;
|
||||
(void)inserted;
|
||||
assert(inserted && "FusionPattern contains duplicate operations");
|
||||
}
|
||||
|
||||
for (Operation* op : pattern) {
|
||||
for (Value result : op->getResults()) {
|
||||
bool has_external_user = llvm::any_of(
|
||||
result.getUses(),
|
||||
[&](OpOperand& use) { return !op_set.count(use.getOwner()); });
|
||||
if (has_external_user) {
|
||||
outputs.push_back(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
FusionPattern MergeFusionPattern(const FusionPattern& lhs,
|
||||
const FusionPattern& rhs) {
|
||||
FusionPattern pattern(lhs);
|
||||
pattern.insert(pattern.end(), rhs.begin(), rhs.end());
|
||||
return pattern;
|
||||
}
|
||||
|
||||
inline int EffectiveSize(const FusionPattern& pattern) {
|
||||
return llvm::count_if(
|
||||
pattern, [](Operation* op) { return !matchPattern(op, m_Constant()); });
|
||||
}
|
||||
|
||||
// This is a simple shape constraint analysis, which is used to
// guide fusion decisions (e.g. we only fuse shape-compatible ops).
//
// Currently, we only consider shape equality propagation based
// on the shape constraint traits of elementwise ops (assuming that
// implicit shape broadcast is forbidden).
|
||||
class ShapeConstraintAnalysis {
|
||||
public:
|
||||
explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
|
||||
PropagateEquality(op_list);
|
||||
}
|
||||
|
||||
// Returns true if `lhs` and `rhs` are supposed to have the same shape.
|
||||
bool HasSameShape(Value lhs, Value rhs) {
|
||||
return impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
|
||||
}
|
||||
|
||||
private:
|
||||
// shape equality propagation based on the shape constraints of
|
||||
// elementwise ops.
|
||||
void PropagateEquality(const SmallVectorImpl<Operation*>& op_list) {
|
||||
bool converged = true;
|
||||
do {
|
||||
converged = true;
|
||||
auto update = [&](Value lhs, Value rhs) {
|
||||
if (!impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
|
||||
converged = false;
|
||||
impl_.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
|
||||
}
|
||||
};
|
||||
for (Operation* op : op_list) {
|
||||
auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
|
||||
if (!op_fusibility) continue;
|
||||
int numInput = op->getNumOperands();
|
||||
int numOutput = op->getNumResults();
|
||||
// shape equality propagation between inputs.
|
||||
for (int input1 = 0; input1 < numInput; ++input1)
|
||||
for (int input2 = input1 + 1; input2 < numInput; ++input2)
|
||||
if (op_fusibility.inferInputsShapeEquality(input1, input2))
|
||||
update(op->getOperand(input1), op->getOperand(input2));
|
||||
|
||||
// shape equality propagation between outputs.
|
||||
for (int output1 = 0; output1 < numOutput; ++output1)
|
||||
for (int output2 = output1 + 1; output2 < numOutput; ++output2)
|
||||
if (op_fusibility.inferOutputsShapeEquality(output1, output2))
|
||||
update(op->getResult(output1), op->getResult(output2));
|
||||
|
||||
// shape equality propagation between input and output.
|
||||
for (int input = 0; input < numInput; ++input)
|
||||
for (int output = 0; output < numOutput; ++output)
|
||||
if (op_fusibility.inferInputOutputShapeEquality(input, output))
|
||||
update(op->getOperand(input), op->getResult(output));
|
||||
}
|
||||
} while (!converged);
|
||||
}
|
||||
|
||||
// a UnionFind set
|
||||
EquivalenceClasses<ValueWrapper> impl_;
|
||||
};
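// A minimal sketch of how the fixed-point propagation is meant to behave
// (illustration only; `add` stands for any element-wise op implementing
// InferFusibilityOpInterface):
//
//   %0 = add(%x, %y)  // unions {%x, %y, %0}
//   %1 = add(%0, %z)  // unions {%0, %z, %1}; all five values end up in one
//                     // equivalence class, so HasSameShape(%x, %1) returns
//                     // true once PropagateEquality converges.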
|
||||
|
||||
// A fusion planner that can propose a fusion plan for a block of ops.
// The fusion plan consists of a group of fusion patterns.
//
// Currently all proposed patterns follow XLA kLoop/kInput-like fusion
// templates while being adapted to the fully dynamic shape world.
//
// A kLoop fusion template satisfies:
//   - all ops in the fusion pattern are element-wise.
//   - all the output shapes of the fusion pattern are the same, and thus
//     fit into the same parallel loop.
//
// A kInput fusion template satisfies:
//   - any op in the fusion pattern is either element-wise or a reduction.
//   - if an op is a reduction, its output cannot be consumed by other
//     ops in the same fusion pattern.
//   - all the effective output shapes of the fusion pattern are the same.
//     - For an element-wise op, its effective shape is its output shape.
//     - For a reduction op, its effective shape is its operand shape.
|
||||
class FusionPlanner {
|
||||
public:
|
||||
explicit FusionPlanner(const SmallVectorImpl<Operation*>& op_list)
|
||||
: op_list_(op_list),
|
||||
shape_analysis_(op_list),
|
||||
cycle_detector_(op_list.size()) {
|
||||
BuildNodeMap();
|
||||
}
|
||||
|
||||
// Returns a fusion plan if success, otherwise none.
|
||||
llvm::Optional<FusionPlan> Run() {
|
||||
// Greedily search for connected fusible patterns; ops belonging to
// the same fusion pattern are grouped into a cluster.
|
||||
RunEdgeContractionLoop();
|
||||
|
||||
// After doing edge contraction, each unique cluster having size
|
||||
// more than one represents a potential fusion pattern.
|
||||
// We collect all these clusters and construct a fusion plan.
|
||||
//
|
||||
// Note that the ops in a fusion pattern are in topological ordering.
|
||||
FusionPlan plan;
|
||||
DenseMap<int, int> pattern_ids;
|
||||
for (Operation* op : op_list_) {
|
||||
Cluster* cluster = GetClusterForNode(op);
|
||||
int node_id = cluster->cycles_graph_node_id();
|
||||
if (!IsFusible(op_list_[node_id]) ||
|
||||
EffectiveSize(GetClusterForNode(op)->fused_pattern()) <= 1) {
|
||||
continue;
|
||||
}
|
||||
if (!pattern_ids.count(node_id)) {
|
||||
int pattern_id = pattern_ids.size();
|
||||
pattern_ids[node_id] = pattern_id;
|
||||
plan.emplace_back();
|
||||
}
|
||||
plan[pattern_ids[node_id]].push_back(op);
|
||||
}
|
||||
return plan;
|
||||
}
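// A minimal usage sketch (illustration only; `ops` is assumed to hold the
// operations of one block in topological order):
//
//   FusionPlanner planner(ops);
//   llvm::Optional<FusionPlan> plan = planner.Run();
//   if (plan) {
//     for (const FusionPattern& pattern : *plan) {
//       // Each pattern is a topologically ordered list of ops that can be
//       // outlined into a single xla_hlo.fusion op (see ApplyFusionPlan).
//     }
//   }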
|
||||
|
||||
// Returns the op_list this planner operates on.
|
||||
const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }
|
||||
|
||||
private:
|
||||
// Represents a (partial) fused pattern.
|
||||
class Cluster {
|
||||
public:
|
||||
Cluster(int node_id, FusionPlanner* planner) : node_id_(node_id) {
|
||||
const SmallVectorImpl<Operation*>& op_list = planner->op_list();
|
||||
pattern_.push_back(op_list[node_id]);
|
||||
}
|
||||
|
||||
// Merges `other` into this cluster, and clears `other`.
|
||||
void Merge(Cluster* other) {
|
||||
pattern_.insert(pattern_.end(), other->pattern_.begin(),
|
||||
other->pattern_.end());
|
||||
other->pattern_.clear();
|
||||
}
|
||||
|
||||
// The number of nodes in this cluster.
|
||||
int cluster_size() const { return pattern_.size(); }
|
||||
|
||||
// The ID of the cluster as represented in `cycle_detector_`.
|
||||
int cycles_graph_node_id() const { return node_id_; }
|
||||
|
||||
// Sets the ID of the cluster as represented in `cycle_detector_`.
|
||||
void set_cycles_graph_node_id(int cycles_graph_node_id) {
|
||||
node_id_ = cycles_graph_node_id;
|
||||
}
|
||||
|
||||
// Currently the fused pattern this cluster holds.
|
||||
const FusionPattern& fused_pattern() { return pattern_; }
|
||||
|
||||
private:
|
||||
// ID of the representative node of this cluster.
|
||||
int node_id_;
|
||||
|
||||
// the fused pattern this cluster holds.
|
||||
FusionPattern pattern_;
|
||||
};
|
||||
|
||||
private:
|
||||
Cluster* MakeCluster(int cycles_graph_node_id) {
|
||||
cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
|
||||
return cluster_storage_.back().get();
|
||||
}
|
||||
|
||||
void BuildNodeMap() {
|
||||
int num_nodes = op_list_.size();
|
||||
for (int node_id = 0; node_id < num_nodes; ++node_id) {
|
||||
Operation* op = op_list_[node_id];
|
||||
MakeCluster(node_id);
|
||||
op_to_node_id_[op] = node_id;
|
||||
leader_for_node_.insert(node_id);
|
||||
for (Value operand : op->getOperands()) {
|
||||
Operation* operand_op = operand.getDefiningOp();
|
||||
if (operand_op == nullptr) {
|
||||
// skip block argument
|
||||
continue;
|
||||
}
|
||||
auto iter = op_to_node_id_.find(operand_op);
|
||||
assert(iter != op_to_node_id_.end());
|
||||
cycle_detector_.InsertEdge(iter->second, node_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the cluster that contains this op.
|
||||
Cluster* GetClusterForNode(Operation* n) {
|
||||
int id = op_to_node_id_.at(n);
|
||||
id = leader_for_node_.getLeaderValue(id);
|
||||
return cluster_storage_[id].get();
|
||||
}
|
||||
|
||||
// Returns the cluster that contains the op having `node_id`.
|
||||
Cluster* GetClusterForCyclesGraphNode(int node_id) {
|
||||
return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
|
||||
}
|
||||
|
||||
// Merges the clusters `cluster_from` and `cluster_to`.
|
||||
bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
|
||||
int from = cluster_from->cycles_graph_node_id();
|
||||
int to = cluster_to->cycles_graph_node_id();
|
||||
|
||||
auto optional_merged_node = cycle_detector_.ContractEdge(from, to);
|
||||
if (!optional_merged_node.hasValue()) {
|
||||
llvm::dbgs() << "Could not contract " << from << " -> " << to
|
||||
<< " because contracting the edge would create a cycle.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Merge the clusters.
|
||||
cluster_from->Merge(cluster_to);
|
||||
cluster_from->set_cycles_graph_node_id(*optional_merged_node);
|
||||
|
||||
// Merge the UnionFind Set.
|
||||
leader_for_node_.unionSets(from, to);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename FnTy>
|
||||
bool ForEachEdgeInPostOrder(FnTy fn) {
|
||||
bool changed = false;
|
||||
for (int32_t node : cycle_detector_.AllNodesInPostOrder()) {
|
||||
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
|
||||
// Make a copy of the set of successors because we may modify the graph in
|
||||
// TryToContractEdge.
|
||||
std::vector<int32_t> successors_copy =
|
||||
cycle_detector_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
|
||||
|
||||
for (int to : successors_copy) {
|
||||
Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
|
||||
bool contracted_edge = fn(cluster_from, cluster_to);
|
||||
changed |= contracted_edge;
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
// Returns the outputs if the two clusters were merged.
|
||||
SmallVector<Value, 4> GetResultsOfFusedPattern(Cluster* from, Cluster* to) {
|
||||
FusionPattern fused_pattern =
|
||||
MergeFusionPattern(from->fused_pattern(), to->fused_pattern());
|
||||
return GetOutputsOfFusionPattern(fused_pattern);
|
||||
}
|
||||
|
||||
// This function checks whether fusing `from` with `to` is valid and, if so,
// performs the merge. The validity is based on the operations in the clusters
// and the compatibility of the shapes of the outputs of the would-be fused
// clusters.
// Returns true if the merge was performed.
|
||||
bool TryToContractEdge(Cluster* from, Cluster* to) {
|
||||
int node_to = to->cycles_graph_node_id();
|
||||
int node_from = from->cycles_graph_node_id();
|
||||
|
||||
// Both node_to and node_from should be fusible
|
||||
if (!IsFusible(op_list_[node_to]) || !IsFusible(op_list_[node_from])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto op_from_fusibility =
|
||||
dyn_cast<InferFusibilityOpInterface>(op_list_[node_from]);
|
||||
if (op_from_fusibility && !op_from_fusibility.isFusibleWithConsumer()) {
|
||||
// This op cannot be fused with its consumers.
|
||||
return false;
|
||||
}
|
||||
|
||||
auto op_to_fusibility =
|
||||
dyn_cast<InferFusibilityOpInterface>(op_list_[node_to]);
|
||||
if (op_to_fusibility && !op_to_fusibility.isFusibleWithOperand()) {
|
||||
// This op cannot be fused with its operands.
|
||||
return false;
|
||||
}
|
||||
|
||||
// Output shapes of a fusion pattern should be compatible as described in
|
||||
// the document of this class.
|
||||
SmallVector<Value, 4> results = GetResultsOfFusedPattern(from, to);
|
||||
auto get_workload_shape = [](Value v) {
|
||||
Operation* op = v.getDefiningOp();
|
||||
// Block argument
|
||||
if (!op) return v;
|
||||
auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
|
||||
// Const value
|
||||
if (!op_fusibility) return v;
|
||||
llvm::Optional<Value> workload =
|
||||
op_fusibility.inferEffectiveWorkloadShape();
|
||||
return workload.hasValue() ? *workload : v;
|
||||
};
|
||||
|
||||
Value ref = get_workload_shape(results[0]);
|
||||
if (!llvm::all_of(results, [&](Value result) {
|
||||
Value val = get_workload_shape(result);
|
||||
return shape_analysis_.HasSameShape(ref, val);
|
||||
})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return MergeClusters(from, to);
|
||||
}
|
||||
|
||||
// Greedily fuse connected nodes.
|
||||
bool RunEdgeContractionLoop() {
|
||||
using std::placeholders::_1;
|
||||
using std::placeholders::_2;
|
||||
return ForEachEdgeInPostOrder(
|
||||
std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2));
|
||||
}
|
||||
|
||||
const SmallVectorImpl<Operation*>& op_list_;
|
||||
|
||||
// Shape equality checker
|
||||
ShapeConstraintAnalysis shape_analysis_;
|
||||
|
||||
// op -> node_id
|
||||
std::unordered_map<Operation*, int> op_to_node_id_;
|
||||
|
||||
// Makes sure we do not introduce a cycle after fusion.
|
||||
GraphCycles cycle_detector_;
|
||||
std::vector<std::unique_ptr<Cluster>> cluster_storage_;
|
||||
|
||||
// a UnionFind set. Each set represents a (partial) fused pattern
|
||||
// and has a leader as its representative.
|
||||
EquivalenceClasses<int32_t> leader_for_node_;
|
||||
};
|
||||
|
||||
struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
|
||||
void runOnFunction() override {
|
||||
FuncOp func = getFunction();
|
||||
if (!IsTargetFunc(func)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// process each block and do fusion within a block.
|
||||
for (Block& block : func.getBlocks()) {
|
||||
SmallVector<Operation*, 4> op_list;
|
||||
for (Operation& op : block) {
|
||||
op_list.push_back(&op);
|
||||
}
|
||||
|
||||
FusionPlanner planner(op_list);
|
||||
llvm::Optional<FusionPlan> plan = planner.Run();
|
||||
if (!plan) {
|
||||
emitError(func.getLoc(), "can't find a fusion plan");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
if (!ApplyFusionPlan(*plan)) {
|
||||
emitError(func.getLoc(), "apply fusion plan failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool IsTargetFunc(FuncOp func) {
|
||||
int num_fusible_ops = 0;
|
||||
bool is_target_func = false;
|
||||
// We only process functions having enough candidates.
|
||||
func.walk([&](Operation* op) {
|
||||
num_fusible_ops +=
|
||||
static_cast<int>(dyn_cast<InferFusibilityOpInterface>(op) != nullptr);
|
||||
is_target_func = (num_fusible_ops > 1);
|
||||
// early stop
|
||||
if (is_target_func) return WalkResult::interrupt();
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return is_target_func;
|
||||
}
|
||||
|
||||
bool ApplyFusionPlan(const FusionPlan& plan) {
|
||||
for (const FusionPattern& pattern : plan) {
|
||||
OpBuilder b(pattern.back());
|
||||
|
||||
SmallVector<Location, 4> locations;
|
||||
locations.reserve(pattern.size());
|
||||
for (Operation* op : pattern) {
|
||||
locations.push_back(op->getLoc());
|
||||
}
|
||||
Location fused_loc =
|
||||
FusedLoc::get(locations, pattern.back()->getContext());
|
||||
|
||||
SmallVector<Value, 4> inputs = GetInputsOfFusionPattern(pattern);
|
||||
SmallVector<Value, 4> outputs = GetOutputsOfFusionPattern(pattern);
|
||||
SmallVector<Type, 4> output_types;
|
||||
output_types.reserve(outputs.size());
|
||||
for (Value v : outputs) {
|
||||
output_types.push_back(v.getType());
|
||||
}
|
||||
|
||||
FusionOp fusion =
|
||||
b.create<xla_hlo::FusionOp>(fused_loc, output_types, inputs);
|
||||
Region& region = fusion.fused_computation();
|
||||
region.push_back(new Block);
|
||||
Block& block = region.front();
|
||||
for (Operation* op : pattern) {
|
||||
op->moveBefore(&block, block.end());
|
||||
}
|
||||
b.setInsertionPoint(&block, block.end());
|
||||
b.create<xla_hlo::ReturnOp>(fused_loc, outputs);
|
||||
|
||||
for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
|
||||
Value output = std::get<0>(output_and_result);
|
||||
Value fusion_result = std::get<1>(output_and_result);
|
||||
for (OpOperand& use : llvm::make_early_inc_range(output.getUses())) {
|
||||
if (use.getOwner()->getBlock() != &block) use.set(fusion_result);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
|
||||
return std::make_unique<XlaHloFusion>();
|
||||
}
|
||||
|
||||
static PassRegistration<XlaHloFusion> xla_hlo_fusion_pass(
|
||||
"xla-hlo-fusion", "fuse xla_hlo ops to kLoop/kInput fusion patterns.");
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
@ -192,6 +192,108 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// xla_lhlo.convolution conversion pattern.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Converts xla_lhlo.convolution operation to a linalg.conv op.
|
||||
struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
|
||||
|
||||
// This code has been adapted from IREE's
|
||||
// (https://github.com/google/iree/) xla_hlo -> linalg conversion.
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::ConvOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
// Check validity of dimension information.
if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers =
op.dimension_numbers()) {
const int inputSpatialRank =
llvm::size(dimensionNumbers.input_spatial_dimensions());
// The dimensions for input should follow the order of
// batch_count, spatial_dims..., input_feature_count.
if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
dimensionNumbers.input_feature_dimension().getInt() !=
(inputSpatialRank + 1))
return failure();

const int kernelSpatialRank =
llvm::size(dimensionNumbers.kernel_spatial_dimensions());
// The dimensions for filter should follow the order of
// spatial_dims..., input_feature_count, num_output_feature_count.
if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
kernelSpatialRank ||
dimensionNumbers.kernel_output_feature_dimension().getInt() !=
(kernelSpatialRank + 1))
return failure();

const int outputSpatialRank =
llvm::size(dimensionNumbers.output_spatial_dimensions());
// The dimensions for output should follow the order of
// batch_count, spatial_dims..., output_feature_count.
if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
dimensionNumbers.output_feature_dimension().getInt() !=
(outputSpatialRank + 1))
return failure();

if (inputSpatialRank != outputSpatialRank ||
inputSpatialRank != kernelSpatialRank)
return failure();

auto inputSpatialDim =
dimensionNumbers.input_spatial_dimensions().begin();
auto kernelSpatialDim =
dimensionNumbers.kernel_spatial_dimensions().begin();
auto outputSpatialDim =
dimensionNumbers.output_spatial_dimensions().begin();
// Check if spatial dims are ordered correctly.
for (int i = 0; i < inputSpatialRank; ++i) {
const int dim = i + 1;
if ((*inputSpatialDim++).getZExtValue() != dim ||
(*outputSpatialDim++).getZExtValue() != dim ||
(*kernelSpatialDim++).getZExtValue() != i)
return failure();
}
}

// TODO: LHS dilation for deconvolution not supported yet.
if (op.lhs_dilation()) {
return failure();
}

llvm::SmallVector<Attribute, 4> strides;
if (auto windowStrides = op.window_strides()) {
auto range = windowStrides->getAttributeValues();
strides.assign(range.begin(), range.end());
}
auto stridesArg = ArrayAttr::get(strides, op.getContext());

llvm::SmallVector<Attribute, 2> dilation;
if (auto rhsDilation = op.rhs_dilation()) {
auto range = rhsDilation->getAttributeValues();
dilation.assign(range.begin(), range.end());
} else {
// Default dilation of 1.
dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
}
auto dilationArg = ArrayAttr::get(dilation, op.getContext());

// Set padding only if it is non-zero.
DenseIntElementsAttr padding = op.paddingAttr();
if (!padding || !llvm::any_of(padding.getValues<APInt>(), [](APInt intVal) {
return !intVal.isNullValue();
})) {
padding = nullptr;
}

// The order of input and filter are switched with linalg.conv.
rewriter.replaceOpWithNewOp<linalg::ConvOp>(
op, args[1], args[0], args[2], stridesArg, dilationArg, padding);
return success();
}
};

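The checks above only admit the layout that linalg.conv expects: input and output dimension numbers ordered as (batch, spatial..., feature), and the kernel ordered as (spatial..., input_feature, output_feature). As a rough illustration (not part of this change; all names below are hypothetical), the same ordering test can be written over plain Python lists of dimension indices:

# Hypothetical sketch: the ordering check the converter performs, expressed on
# plain Python lists of dimension indices.
def conv_dims_are_linalg_compatible(input_batch, input_feature, input_spatial,
                                    kernel_in_feature, kernel_out_feature,
                                    kernel_spatial, output_batch,
                                    output_feature, output_spatial):
    rank = len(input_spatial)
    if len(kernel_spatial) != rank or len(output_spatial) != rank:
        return False
    # Input/output: batch first, feature last, spatial dims 1..rank in order.
    if input_batch != 0 or input_feature != rank + 1:
        return False
    if output_batch != 0 or output_feature != rank + 1:
        return False
    # Kernel: spatial dims 0..rank-1, then input feature, then output feature.
    if kernel_in_feature != rank or kernel_out_feature != rank + 1:
        return False
    return (list(input_spatial) == list(range(1, rank + 1)) and
            list(output_spatial) == list(range(1, rank + 1)) and
            list(kernel_spatial) == list(range(rank)))

# NHWC input / HWIO kernel for a 2-D convolution passes the check:
assert conv_dims_are_linalg_compatible(0, 3, [1, 2], 2, 3, [0, 1], 0, 3, [1, 2])
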
/// Base class for lowering xla operations that have one operand and one result,
/// and are semantically equivalent to a copy of the input to the output (like
/// transpose, some reshape, etc.). The derived classes need to provide a method
@ -814,6 +916,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
// clang-format off
patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
ConstConverter,
ConvToLinalgConverter,
IotaConverter,
LhloBroadcastInDimConverter,
PointwiseToLinalgConverter<xla_lhlo::AbsOp>,
|
@ -85,6 +85,11 @@ class UnaryOpsTest(xla_test.XLATestCase):
for i in xrange(len(result)):
self.assertAllClose(result[i], expected[i], rtol, atol)

def AssertCloseAndSorted(self, result, expected, rtol, atol):
"""Tests that result and expected are both close and sorted."""
self.assertAllClose(result, expected, rtol, atol)
self.assertAllEqual(np.sort(result), result)

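The helper simply combines a tolerance comparison with a monotonicity check; a standalone numpy sketch of the same predicate (illustrative only, not part of the test class):

# Illustrative sketch of the "close and sorted" predicate with plain numpy.
import numpy as np

def close_and_sorted(result, expected, rtol, atol):
    np.testing.assert_allclose(result, expected, rtol=rtol, atol=atol)
    assert np.all(np.diff(result) >= 0), "result is not sorted"

close_and_sorted(np.array([0.1, 0.2]), np.array([0.1, 0.2]), 1e-6, 1e-6)
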
@test_util.disable_mlir_bridge(
"MlirHloBuilder::Iota missing required for xla::Diag")
def testAllTypeOps(self):
@ -510,6 +515,16 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
],
|
||||
dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"TODO(b/155501444): Handle _UnaryOpsComposition ops from Grappler")
|
||||
def testFloatOpsDisabledOnMlirBridge(self):
|
||||
for dtype in self.float_types:
|
||||
if dtype != np.float16:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
lambda x: math_ops.sigmoid(x) / math_ops.log1p(math_ops.exp(x)),
|
||||
np.array([-40, 40], dtype=dtype),
|
||||
expected=np.array([1.0, 0.025], dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"TODO(b/153812660): Handle tf.QuantizeAndDequantize compilation")
|
||||
def testQuantizeAndDequantize(self):
|
||||
@ -1112,17 +1127,27 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
[[[12, 13, 14, 15, 28, 29, 30, 31]]]]],
|
||||
dtype=dtype))
|
||||
|
||||
def _assertSoftplusMatchesExpected(self, features, dtype):
|
||||
def _assertSoftplusMatchesExpected(self,
|
||||
features,
|
||||
dtype,
|
||||
equality_test=None,
|
||||
rtol=1e-6,
|
||||
atol=9.1e-6):
|
||||
features = np.array(features, dtype=dtype)
|
||||
zero = np.asarray(0).astype(dtype)
|
||||
expected = np.logaddexp(zero, features).astype(dtype)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6)
|
||||
nn_ops.softplus,
|
||||
features,
|
||||
expected=expected,
|
||||
equality_test=equality_test,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"bf16 type not supported in CreateDenseElementsAttrFromLiteral")
|
||||
def testSoftplus(self):
|
||||
for dtype in self.float_types:
|
||||
for dtype in self.float_types & {dtypes.float32, dtypes.float64}:
|
||||
self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype)
|
||||
self._assertSoftplusMatchesExpected(
|
||||
[[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype)
|
||||
@ -1138,6 +1163,13 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
-log_eps + ten
|
||||
], dtype)
|
||||
|
||||
self._assertSoftplusMatchesExpected(
|
||||
[0.69302183, 0.69324386],
|
||||
dtype,
|
||||
equality_test=self.AssertCloseAndSorted,
|
||||
rtol=9e-5,
|
||||
atol=9e-5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
@ -486,7 +486,7 @@ class SliceAssignTest(xla_test.XLATestCase):
|
||||
|
||||
def testUninitialized(self):
|
||||
with self.assertRaisesRegexp(errors.FailedPreconditionError,
|
||||
"uninitialized variable"):
|
||||
"uninitialized"):
|
||||
with self.session() as sess, self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable([1, 2])
|
||||
sess.run(v[:].assign([1, 2]))
|
||||
|
@ -89,16 +89,25 @@ XLAJIT_MAKE_UNARY(Sign,
xla::Select(xla::Ne(x, x), xla::ZerosLike(x), xla::Sign(x)));
XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x));

// softplus(x) = log(1 + exp(x))
//
// This is not numerically stable when x is large; it can easily overflow.
// However, we can compute it as LogSumExp(x, 0):
// max(x, 0) + log(exp(x - max(x, 0)) + exp(0 - max(x, 0)))
//
// This is equivalent to:
// max(x, 0) + log1p(exp(-abs(x)))
XLAJIT_MAKE_UNARY(Softplus, xla::Max(x, xla::ScalarLike(x, 0.0)) +
xla::Log1p(xla::Exp(-xla::Abs(x))));
static xla::XlaOp Softplus(xla::XlaBuilder* b, xla::XlaOp features) {
return b->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(features));
xla::XlaOp threshold =
Log(xla::Epsilon(b, shape.element_type())) + ScalarLike(features, 2.0);
// Value above which exp(x) may overflow, but softplus(x) == x
// is within machine epsilon.
xla::XlaOp too_large = Gt(features, -threshold);
// Value below which exp(x) may underflow, but softplus(x) == exp(x)
// is within machine epsilon.
xla::XlaOp too_small = Lt(features, threshold);
xla::XlaOp features_exp = Exp(features);
xla::XlaOp output =
Select(too_large, features,
Select(too_small, features_exp, Log1p(features_exp)));
return output;
});
}
XLAJIT_MAKE_UNARY(Softplus, Softplus(b, x));

// softsign(x) = x / (abs(x) + 1)
XLAJIT_MAKE_UNARY(Softsign, x / (xla::Abs(x) + xla::ScalarLike(x, 1.0)));

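For comparison, here is a small numpy sketch of the same threshold scheme (an illustration under the stated assumptions, not the TF kernel): values above -threshold return x directly, values below threshold return exp(x), and the middle range falls back to log1p(exp(x)).

# Rough numpy equivalent of the threshold-based softplus above (sketch only).
import numpy as np

def softplus_stable(x, dtype=np.float32):
    x = np.asarray(x, dtype=dtype)
    eps = np.finfo(dtype).eps
    threshold = np.log(eps) + 2.0
    too_large = x > -threshold   # softplus(x) ~= x to within machine epsilon
    too_small = x < threshold    # softplus(x) ~= exp(x) to within machine epsilon
    with np.errstate(over='ignore'):
        x_exp = np.exp(x)
    return np.where(too_large, x, np.where(too_small, x_exp, np.log1p(x_exp)))

print(softplus_stable([-100.0, 0.0, 100.0]))  # ~[0., 0.693, 100.]
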
@ -413,8 +413,10 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
|
||||
TF_RET_CHECK(variable != nullptr);
|
||||
TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
|
||||
if (!variable->initialized()) {
|
||||
return errors::FailedPrecondition("Read of uninitialized variable ",
|
||||
variable->name());
|
||||
return errors::FailedPrecondition(
|
||||
"Read variable failure ", variable->name(),
|
||||
". It could mean the variable is uninitialized or the variable is on "
|
||||
"another device ");
|
||||
}
|
||||
if (variable->type() != type) {
|
||||
return errors::InvalidArgument(
|
||||
@ -464,8 +466,10 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
|
||||
TF_RET_CHECK(variable != nullptr);
|
||||
TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
|
||||
if (!variable->initialized()) {
|
||||
return errors::InvalidArgument("Read of uninitialized variable ",
|
||||
variable->name());
|
||||
return errors::InvalidArgument(
|
||||
"Read variable failure ", variable->name(),
|
||||
". It could mean the variable is uninitialized or the variable is on "
|
||||
"another device ");
|
||||
}
|
||||
*type = variable->type();
|
||||
*shape = variable->shape();
|
||||
|
@ -1394,8 +1394,8 @@ XlaOp NextAfter(XlaOp from, XlaOp to) {
}

XlaOp Logistic(XlaOp x) {
auto half = xla::ScalarLike(x, 0.5);
return half + half * xla::Tanh(half * x);
auto one = xla::ScalarLike(x, 1);
return xla::Div(one, (one + xla::Exp(xla::Neg(x))));
}

// Computes an approximation to the modified Bessel function of the first kind,

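The two forms of the logistic are algebraically identical, since 1/(1 + exp(-x)) = 0.5 + 0.5*tanh(0.5*x); a quick illustrative check in numpy (not part of the patch):

# Numerical check that the tanh-based and exp-based logistic forms agree.
import numpy as np

x = np.linspace(-10.0, 10.0, 101)
tanh_form = 0.5 + 0.5 * np.tanh(0.5 * x)
exp_form = 1.0 / (1.0 + np.exp(-x))
assert np.allclose(tanh_form, exp_form)
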
@ -227,9 +227,9 @@ PyNumberMethods PyBfloat16_AsNumber = {
|
||||
nullptr, // nb_and
|
||||
nullptr, // nb_xor
|
||||
nullptr, // nb_or
|
||||
PyBfloat16_Int, // nb_int
|
||||
nullptr, // reserved
|
||||
PyBfloat16_Float, // nb_float
|
||||
PyBfloat16_Int, // nb_int
|
||||
nullptr, // reserved
|
||||
PyBfloat16_Float, // nb_float
|
||||
|
||||
nullptr, // nb_inplace_add
|
||||
nullptr, // nb_inplace_subtract
|
||||
@ -441,6 +441,29 @@ void ByteSwap16(void* value) {
std::swap(p[0], p[1]);
}

int NPyBfloat16_Compare(const void* a, const void* b, void* arr) {
bfloat16 x;
memcpy(&x, a, sizeof(bfloat16));

bfloat16 y;
memcpy(&y, b, sizeof(bfloat16));

if (x < y) {
return -1;
}
if (y < x) {
return 1;
}
// NaNs sort to the end.
if (!std::isnan(x) && std::isnan(y)) {
return -1;
}
if (std::isnan(x) && !std::isnan(y)) {
return 1;
}
return 0;
}

void NPyBfloat16_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
npy_intp sstride, npy_intp n, int swap, void* arr) {
char* dst = reinterpret_cast<char*>(dstv);
@ -1213,7 +1236,44 @@ struct LogicalXor {
}
};

// TODO(phawkins): implement nextafter, spacing
struct NextAfter {
bfloat16 operator()(bfloat16 from, bfloat16 to) {
uint16_t from_as_int, to_as_int;
const uint16_t sign_mask = 1 << 15;
float from_as_float(from), to_as_float(to);
memcpy(&from_as_int, &from, sizeof(bfloat16));
memcpy(&to_as_int, &to, sizeof(bfloat16));
if (std::isnan(from_as_float) || std::isnan(to_as_float)) {
return bfloat16(std::numeric_limits<float>::quiet_NaN());
}
if (from_as_int == to_as_int) {
return to;
}
if (from_as_float == 0) {
if (to_as_float == 0) {
return to;
} else {
// Smallest subnormal signed like `to`.
uint16_t out_int = (to_as_int & sign_mask) | 1;
bfloat16 out;
memcpy(&out, &out_int, sizeof(bfloat16));
return out;
}
}
uint16_t from_sign = from_as_int & sign_mask;
uint16_t to_sign = to_as_int & sign_mask;
uint16_t from_abs = from_as_int & ~sign_mask;
uint16_t to_abs = to_as_int & ~sign_mask;
uint16_t magnitude_adjustment =
(from_abs > to_abs || from_sign != to_sign) ? 0xFFFF : 0x0001;
uint16_t out_int = from_as_int + magnitude_adjustment;
bfloat16 out;
memcpy(&out, &out_int, sizeof(bfloat16));
return out;
}
};

// TODO(phawkins): implement spacing

} // namespace ufuncs

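The operator walks the raw bfloat16 encoding up or down by one unit in the last place, relying on the fact that finite values of equal sign are ordered like their bit patterns. A hedged Python sketch of the same adjustment on 16-bit encodings (NaN and zero-to-zero cases omitted; not the real NumPy ufunc):

# Illustrative sketch of the bit-pattern trick above, on raw bfloat16
# encodings (sign/exponent/mantissa = 1/8/7 bits).
SIGN_MASK = 1 << 15

def bfloat16_nextafter_bits(from_bits, to_bits):
    if from_bits == to_bits:
        return to_bits
    if from_bits & ~SIGN_MASK & 0xFFFF == 0:
        # `from` is +/-0: the answer is the smallest subnormal signed like `to`.
        return (to_bits & SIGN_MASK) | 1
    from_abs = from_bits & ~SIGN_MASK & 0xFFFF
    to_abs = to_bits & ~SIGN_MASK & 0xFFFF
    same_sign = (from_bits & SIGN_MASK) == (to_bits & SIGN_MASK)
    # Step down (0xFFFF acts as -1 mod 2**16) or up by one ulp.
    step = 0xFFFF if (from_abs > to_abs or not same_sign) else 0x0001
    return (from_bits + step) & 0xFFFF

# 1.0 is 0x3F80 in bfloat16; the next value toward 2.0 is 0x3F81 = 1 + 2**-7.
assert bfloat16_nextafter_bits(0x3F80, 0x4000) == 0x3F81
# Stepping from 1.0 toward 0.0 moves down one ulp instead.
assert bfloat16_nextafter_bits(0x3F80, 0x0000) == 0x3F7F
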
@ -1243,6 +1303,7 @@ bool Initialize() {
|
||||
PyArray_InitArrFuncs(&NPyBfloat16_ArrFuncs);
|
||||
NPyBfloat16_ArrFuncs.getitem = NPyBfloat16_GetItem;
|
||||
NPyBfloat16_ArrFuncs.setitem = NPyBfloat16_SetItem;
|
||||
NPyBfloat16_ArrFuncs.compare = NPyBfloat16_Compare;
|
||||
NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
|
||||
NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
|
||||
NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
|
||||
@ -1467,7 +1528,9 @@ bool Initialize() {
|
||||
RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Ceil>>(numpy.get(),
|
||||
"ceil") &&
|
||||
RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Trunc>>(numpy.get(),
|
||||
"trunc");
|
||||
"trunc") &&
|
||||
RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::NextAfter>>(
|
||||
numpy.get(), "nextafter");
|
||||
|
||||
return ok;
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import itertools
|
||||
import math
|
||||
|
||||
from absl.testing import absltest
|
||||
@ -218,6 +219,12 @@ class Bfloat16Test(parameterized.TestCase):
|
||||
numpy_assert_allclose(
|
||||
a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
|
||||
|
||||
def testSort(self):
|
||||
values_to_sort = np.float32(FLOAT_VALUES)
|
||||
sorted_f32 = np.sort(values_to_sort)
|
||||
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
|
||||
np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
|
||||
|
||||
|
||||
BinaryOp = collections.namedtuple("BinaryOp", ["op"])
|
||||
|
||||
@ -398,6 +405,26 @@ class Bfloat16NumPyTest(parameterized.TestCase):
|
||||
np.testing.assert_equal(exp1, exp2)
|
||||
numpy_assert_allclose(mant1, mant2, rtol=1e-2)
|
||||
|
||||
def testNextAfter(self):
|
||||
one = np.array(1., dtype=bfloat16)
|
||||
two = np.array(2., dtype=bfloat16)
|
||||
zero = np.array(0., dtype=bfloat16)
|
||||
nan = np.array(np.nan, dtype=bfloat16)
|
||||
np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
|
||||
np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
|
||||
np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
|
||||
np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
|
||||
np.testing.assert_equal(np.nextafter(one, one), one)
|
||||
smallest_denormal = float.fromhex("1.0p-133")
|
||||
np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
|
||||
np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
|
||||
for a, b in itertools.permutations([0., -0., nan], 2):
|
||||
np.testing.assert_equal(
|
||||
np.nextafter(
|
||||
np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
|
||||
np.nextafter(
|
||||
np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
@ -2973,6 +2973,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
|
||||
// slice instruction should all have the same layout.
|
||||
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
|
||||
pad->shape(), nonzero_pad->mutable_shape()));
|
||||
simplifier_->UpdateLayout(nonzero_pad->mutable_shape());
|
||||
|
||||
// Second, construct the slice instruction to perform the negative padding.
|
||||
std::vector<int64> start_indices;
|
||||
@ -2999,9 +3000,14 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
|
||||
MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides));
|
||||
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
|
||||
pad->shape(), slice->mutable_shape()));
|
||||
simplifier_->UpdateLayout(slice->mutable_shape());
|
||||
|
||||
// Verify that the slice shape matches the pad shape.
|
||||
TF_RET_CHECK(ShapeUtil::Equal(slice->shape(), pad->shape()));
|
||||
auto equal = Shape::Equal();
|
||||
if (!options_.is_layout_sensitive()) {
|
||||
equal.IgnoreTilesInLayout();
|
||||
}
|
||||
TF_RET_CHECK(equal(slice->shape(), pad->shape()));
|
||||
|
||||
return ReplaceInstruction(pad, slice);
|
||||
}
|
||||
@ -3058,20 +3064,6 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
||||
return false;
|
||||
}
|
||||
HloInstruction* operand = broadcast->mutable_operand(0);
|
||||
auto is_scalar_broadcast = [](const HloInstruction* instruction) {
|
||||
return instruction->opcode() == HloOpcode::kBroadcast &&
|
||||
ShapeUtil::IsScalar(instruction->operand(0)->shape());
|
||||
};
|
||||
auto is_equal_broadcast = [operand,
|
||||
broadcast](const HloInstruction* instruction) {
|
||||
return instruction->opcode() == HloOpcode::kBroadcast &&
|
||||
ShapeUtil::Equal(operand->shape(),
|
||||
instruction->operand(0)->shape()) &&
|
||||
broadcast->dimensions() == instruction->dimensions();
|
||||
};
|
||||
auto is_compatible_broadcast = [&](const HloInstruction* instruction) {
|
||||
return is_scalar_broadcast(instruction) || is_equal_broadcast(instruction);
|
||||
};
|
||||
for (HloInstruction* user : broadcast->users()) {
|
||||
if (user->user_count() == 0 && user != computation_->root_instruction()) {
|
||||
continue;
|
||||
@ -3090,20 +3082,18 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if all the operands of the user are compatible broadcasts for
|
||||
// sinking. (They are either scalar broadcasts or broadcasts casting
|
||||
// from/to the same shape/dimensions)
|
||||
int64 compatible_broadcast_count = 0;
|
||||
// Find the unique non-scalar operand or continue if there isn't one.
|
||||
int64 scalar_broadcast_count = 0;
|
||||
int64 broadcast_use_count = 0;
|
||||
for (HloInstruction* user_operand : user->operands()) {
|
||||
if (is_compatible_broadcast(user_operand)) {
|
||||
++compatible_broadcast_count;
|
||||
if (user_operand->opcode() == HloOpcode::kBroadcast &&
|
||||
ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
|
||||
++scalar_broadcast_count;
|
||||
} else if (broadcast == user_operand) {
|
||||
++broadcast_use_count;
|
||||
}
|
||||
}
|
||||
if (compatible_broadcast_count + broadcast_use_count !=
|
||||
user->operand_count()) {
|
||||
if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) {
|
||||
continue;
|
||||
}
|
||||
std::vector<HloInstruction*> new_operands;
|
||||
@ -3111,24 +3101,14 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
||||
|
||||
Shape changed_shape;
|
||||
for (HloInstruction* user_operand : user->operands()) {
|
||||
// If this is a broadcast operand that is not our original broadcast input
|
||||
// to this function then we might need to change the input.
|
||||
if (is_compatible_broadcast(user_operand)) {
|
||||
// If this is a broadcast from a scalar value rewrite a broadcast from
|
||||
// the scalar to the new shape enforced from the other broadcast
|
||||
// operands.
|
||||
if (is_scalar_broadcast(user_operand)) {
|
||||
changed_shape = ShapeUtil::ChangeElementType(
|
||||
operand->shape(), user_operand->shape().element_type());
|
||||
simplifier_->UpdateLayout(&changed_shape);
|
||||
new_operands.push_back(
|
||||
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
changed_shape, user_operand->mutable_operand(0), {})));
|
||||
} else {
|
||||
// For the non-scalar broadcasts we guarantee that the shape of the
|
||||
// operand of the broadcast needs to be already a compatible shape.
|
||||
new_operands.push_back(user_operand->mutable_operand(0));
|
||||
}
|
||||
if (user_operand->opcode() == HloOpcode::kBroadcast &&
|
||||
ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
|
||||
changed_shape = ShapeUtil::ChangeElementType(
|
||||
operand->shape(), user_operand->shape().element_type());
|
||||
simplifier_->UpdateLayout(&changed_shape);
|
||||
new_operands.push_back(
|
||||
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
changed_shape, user_operand->mutable_operand(0), {})));
|
||||
} else {
|
||||
CHECK_EQ(broadcast, user_operand);
|
||||
new_operands.push_back(operand);
|
||||
@ -3643,12 +3623,17 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
|
||||
new_slice_strides.push_back(slice->slice_strides(dim));
|
||||
new_slice_limits.push_back(slice->slice_limits(dim));
|
||||
}
|
||||
VLOG(3) << "Sink broadcast through slice";
|
||||
VLOG(3) << "Original slice: " << slice->ToString();
|
||||
VLOG(3) << "Original broadcast: " << broadcast->ToString();
|
||||
TF_ASSIGN_OR_RETURN(auto new_slice,
|
||||
MakeSliceHlo(broadcast_operand, new_slice_starts,
|
||||
new_slice_limits, new_slice_strides));
|
||||
return ReplaceInstruction(
|
||||
slice,
|
||||
MakeBroadcastHlo(new_slice, broadcast->dimensions(), slice->shape()));
|
||||
auto new_broadcast = HloInstruction::CreateBroadcast(
|
||||
slice->shape(), new_slice, broadcast->dimensions());
|
||||
VLOG(3) << "New slice: " << slice->ToString();
|
||||
VLOG(3) << "New broadcast: " << new_broadcast->ToString();
|
||||
return ReplaceWithNewInstruction(slice, std::move(new_broadcast));
|
||||
}
|
||||
|
||||
// Try to simplify concat -> slice to an operand of concat.
|
||||
@ -3728,16 +3713,21 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
|
||||
new_indices.push_back(dynamic_slice->mutable_operand(1 + dim));
|
||||
new_slice_sizes.push_back(dynamic_slice->slice_sizes(dim));
|
||||
}
|
||||
|
||||
VLOG(3) << "Sink broadcast through dynamic slice";
|
||||
VLOG(3) << "Original dynamic slice: " << dynamic_slice->ToString();
|
||||
VLOG(3) << "Original broadcast: " << operand->ToString();
|
||||
HloInstruction* new_dynamic_slice = broadcast_operand;
|
||||
if (!new_slice_sizes.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
new_dynamic_slice,
|
||||
MakeDynamicSliceHlo(broadcast_operand, new_indices, new_slice_sizes));
|
||||
}
|
||||
return ReplaceInstruction(
|
||||
dynamic_slice,
|
||||
MakeBroadcastHlo(new_dynamic_slice, operand->dimensions(),
|
||||
dynamic_slice->shape()));
|
||||
auto new_broadcast = HloInstruction::CreateBroadcast(
|
||||
dynamic_slice->shape(), new_dynamic_slice, operand->dimensions());
|
||||
VLOG(3) << "New dynamic slice: " << dynamic_slice->ToString();
|
||||
VLOG(3) << "New broadcast: " << new_broadcast->ToString();
|
||||
return ReplaceWithNewInstruction(dynamic_slice, std::move(new_broadcast));
|
||||
}
|
||||
|
||||
// Convert a dynamic slice into a slice if all offsets are constant and the
|
||||
|
@ -338,79 +338,6 @@ TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) {
|
||||
m::ConstantScalar(3.0))))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsScalar) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = f32[] parameter(0)
|
||||
p1 = f32[] parameter(1)
|
||||
b0 = f32[4] broadcast(p0), dimensions={}
|
||||
b1 = f32[4] broadcast(p1), dimensions={}
|
||||
ROOT multiply = f32[4] multiply(b1, b0)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(
|
||||
m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Broadcast(m::Multiply(m::Broadcast(m::Parameter(1)),
|
||||
m::Broadcast(m::Parameter(0))))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsConstantMix) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = f32[4] parameter(0)
|
||||
c0 = f32[] constant(2.0)
|
||||
b0 = f32[4,2] broadcast(c0), dimensions={}
|
||||
b1 = f32[4,2] broadcast(p0), dimensions={0}
|
||||
ROOT multiply = f32[4,2] multiply(b1, b0)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Broadcast(m::Multiply(
|
||||
m::Parameter(0), m::Broadcast(m::ConstantScalar(2.0))))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsNonScalar) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = f32[4] parameter(0)
|
||||
p1 = f32[4] parameter(1)
|
||||
b0 = f32[4,2] broadcast(p0), dimensions={0}
|
||||
b1 = f32[4,2] broadcast(p1), dimensions={0}
|
||||
ROOT multiply = f32[4,2] multiply(b1, b0)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(
|
||||
m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Broadcast(m::Multiply(m::Parameter(1), m::Parameter(0)))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, ElementwiseNoSinkBroadcastsDifferentDims) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = f32[4] parameter(0)
|
||||
p1 = f32[8] parameter(1)
|
||||
b0 = f32[4,8] broadcast(p0), dimensions={0}
|
||||
b1 = f32[4,8] broadcast(p1), dimensions={1}
|
||||
ROOT multiply = f32[4,8] multiply(b1, b0)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Multiply(m::Broadcast(m::Parameter(1)),
|
||||
m::Broadcast(m::Parameter(0)))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest,
|
||||
MultiplyReassociateMultiplyOfConstantAndBroadcast) {
|
||||
const char* kModuleStr = R"(
|
||||
@ -2612,6 +2539,28 @@ TEST_F(AlgebraicSimplifierTest, SliceOfBroadcast) {
|
||||
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastPreserveLayout) {
|
||||
const char* hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY test {
|
||||
p0 = f32[10,20] parameter(0)
|
||||
b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
|
||||
ROOT s = f32[5,5,5]{2,0,1:T(256)} slice(b), slice={[0:5:1], [5:25:4], [5:15:2]}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
const Shape original_slice_shape =
|
||||
module->entry_computation()->root_instruction()->shape();
|
||||
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), original_slice_shape));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcast) {
|
||||
const char* hlo_string = R"(
|
||||
HloModule module
|
||||
@ -2635,6 +2584,32 @@ TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcast) {
|
||||
m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcastPreserveLayout) {
|
||||
const char* hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY test {
|
||||
p0 = f32[10,20] parameter(0)
|
||||
i0 = s32[] parameter(1)
|
||||
i1 = s32[] parameter(2)
|
||||
i2 = s32[] parameter(3)
|
||||
b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
|
||||
ROOT ds = f32[5,5,5]{2,0,1:T(256)} dynamic-slice(b, i0, i1, i2), dynamic_slice_sizes={5,5,5}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
const Shape original_dynslice_shape =
|
||||
module->entry_computation()->root_instruction()->shape();
|
||||
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::DynamicSlice(
|
||||
m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), original_dynslice_shape));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) {
|
||||
const char* hlo_string = R"(
|
||||
HloModule module
|
||||
|
@ -75,85 +75,66 @@ CpuExecutable::CpuExecutable(
|
||||
<< reinterpret_cast<void*>(compute_function_);
|
||||
}
|
||||
|
||||
StatusOr<std::tuple<std::vector<se::DeviceMemoryBase>,
|
||||
std::vector<se::OwningDeviceMemory>,
|
||||
std::vector<se::OwningDeviceMemory>>>
|
||||
CpuExecutable::CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
|
||||
int device_ordinal,
|
||||
std::vector<ExecutionInput> arguments) {
|
||||
std::vector<se::DeviceMemoryBase> unowning_buffers(
|
||||
assignment_->Allocations().size());
|
||||
std::vector<se::OwningDeviceMemory> owning_buffers(
|
||||
static StatusOr<MaybeOwningDeviceMemory> MemoryForAllocation(
|
||||
const BufferAllocation& allocation,
|
||||
absl::Span<ExecutionInput const> arguments,
|
||||
se::DeviceMemoryAllocator* memory_allocator, int device_ordinal) {
|
||||
VLOG(3) << allocation.ToString();
|
||||
if (allocation.is_entry_computation_parameter()) {
|
||||
se::DeviceMemoryBase out = arguments[allocation.parameter_number()]
|
||||
.Buffer(allocation.param_shape_index())
|
||||
.AsDeviceMemoryBase();
|
||||
CHECK_EQ(allocation.size(), out.size())
|
||||
<< "Size mismatch on param " << allocation.parameter_number()
|
||||
<< " at shape index " << allocation.param_shape_index().ToString();
|
||||
VLOG(3) << "allocation is a parameter";
|
||||
return MaybeOwningDeviceMemory{out};
|
||||
} else if (allocation.is_constant()) {
|
||||
VLOG(3) << "allocation is a constant";
|
||||
return MaybeOwningDeviceMemory{se::DeviceMemoryBase{}};
|
||||
} else if (allocation.is_thread_local()) {
|
||||
VLOG(3) << "buffer is thread-local";
|
||||
return MaybeOwningDeviceMemory{se::DeviceMemoryBase{}};
|
||||
}
|
||||
|
||||
int64 buffer_size = allocation.size();
|
||||
TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory out,
|
||||
memory_allocator->Allocate(device_ordinal, buffer_size));
|
||||
VLOG(3) << "buffer allocated " << buffer_size << " bytes [" << out->opaque()
|
||||
<< "]";
|
||||
|
||||
// Since the output buffer and all the temporary buffers were written into
|
||||
// by the JITed code, msan has no way of knowing their memory was
|
||||
// initialized. Mark them initialized so that msan doesn't flag loads from
|
||||
// these buffers.
|
||||
TF_ANNOTATE_MEMORY_IS_INITIALIZED(out->opaque(), buffer_size);
|
||||
return MaybeOwningDeviceMemory{std::move(out)};
|
||||
}
|
||||
|
||||
StatusOr<std::vector<MaybeOwningDeviceMemory>> CpuExecutable::CreateBufferTable(
|
||||
se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
|
||||
absl::Span<ExecutionInput const> arguments) {
|
||||
std::vector<MaybeOwningDeviceMemory> buffers(
|
||||
assignment_->Allocations().size());
|
||||
VLOG(3) << "Allocating " << assignment_->Allocations().size()
|
||||
<< " allocations for module " << module().name();
|
||||
for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
|
||||
++i) {
|
||||
auto& allocation = assignment_->GetAllocation(i);
|
||||
|
||||
VLOG(3) << allocation.ToString();
|
||||
|
||||
if (allocation.is_entry_computation_parameter()) {
|
||||
unowning_buffers[i] = arguments[allocation.parameter_number()]
|
||||
.Buffer(allocation.param_shape_index())
|
||||
.AsDeviceMemoryBase();
|
||||
CHECK_EQ(allocation.size(), unowning_buffers[i].size())
|
||||
<< "Size mismatch on param " << allocation.parameter_number()
|
||||
<< " at shape index " << allocation.param_shape_index().ToString();
|
||||
VLOG(3) << "allocation #" << i << " is a parameter";
|
||||
continue;
|
||||
}
|
||||
|
||||
if (allocation.is_constant()) {
|
||||
VLOG(3) << "allocation #" << i << " is a constant";
|
||||
continue;
|
||||
}
|
||||
|
||||
if (allocation.is_thread_local()) {
|
||||
VLOG(3) << "buffer #" << i << " is thread-local";
|
||||
continue;
|
||||
}
|
||||
|
||||
int64 buffer_size = allocation.size();
|
||||
if (!owning_buffers[i].is_null()) {
|
||||
VLOG(3) << "buffer #" << i
|
||||
<< " is in the preallocated result ShapedBuffer";
|
||||
} else {
|
||||
TF_ASSIGN_OR_RETURN(owning_buffers[i], memory_allocator->Allocate(
|
||||
device_ordinal, buffer_size));
|
||||
unowning_buffers[i] = *owning_buffers[i];
|
||||
|
||||
VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes ["
|
||||
<< owning_buffers[i]->opaque() << "]";
|
||||
}
|
||||
|
||||
// Since the output buffer and all the temporary buffers were written into
|
||||
// by the JITed code, msan has no way of knowing their memory was
|
||||
// initialized. Mark them initialized so that msan doesn't flag loads from
|
||||
// these buffers.
|
||||
TF_ANNOTATE_MEMORY_IS_INITIALIZED(owning_buffers[i]->opaque(), buffer_size);
|
||||
const BufferAllocation& allocation = assignment_->GetAllocation(i);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
buffers[i], MemoryForAllocation(allocation, arguments, memory_allocator,
|
||||
device_ordinal));
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
|
||||
assignment_->GetUniqueTopLevelOutputSlice());
|
||||
VLOG(3) << "result index: " << result_slice.index();
|
||||
|
||||
std::vector<se::OwningDeviceMemory> buffers_to_free;
|
||||
for (auto& argument : arguments) {
|
||||
for (auto& index_buffer : *argument.MutableBuffers()) {
|
||||
auto maybe_owning_buffer = index_buffer.second.Release();
|
||||
if (maybe_owning_buffer) {
|
||||
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_tuple(std::move(unowning_buffers), std::move(owning_buffers),
|
||||
std::move(buffers_to_free));
|
||||
return std::move(buffers);
|
||||
}
|
||||
|
||||
Status CpuExecutable::ExecuteComputeFunction(
|
||||
const ExecutableRunOptions* run_options,
|
||||
absl::Span<const se::DeviceMemoryBase> buffers,
|
||||
absl::Span<MaybeOwningDeviceMemory const> buffers,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
// The calling convention for JITed functions is:
|
||||
//
|
||||
@ -181,7 +162,8 @@ Status CpuExecutable::ExecuteComputeFunction(
|
||||
// Call the computation function following the calling convention.
|
||||
std::vector<void*> buffer_pointers;
|
||||
for (auto& buffer : buffers) {
|
||||
buffer_pointers.push_back(const_cast<void*>(buffer.opaque()));
|
||||
buffer_pointers.push_back(
|
||||
const_cast<void*>(buffer.AsDeviceMemoryBase().opaque()));
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
|
||||
assignment_->GetUniqueTopLevelOutputSlice());
|
||||
@ -223,63 +205,82 @@ Status CpuExecutable::ExecuteComputeFunction(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
|
||||
StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<se::OwningDeviceMemory> buffers) {
|
||||
absl::Span<MaybeOwningDeviceMemory> buffers,
|
||||
absl::Span<ExecutionInput> arguments) {
|
||||
se::Stream* stream = run_options->stream();
|
||||
ScopedShapedBuffer result_buffer(
|
||||
/*on_host_shape=*/result_shape(),
|
||||
/*on_device_shape=*/result_shape(), run_options->allocator(),
|
||||
stream->parent()->device_ordinal());
|
||||
ExecutionOutput result(/*on_host_shape=*/result_shape(),
|
||||
/*on_device_shape=*/result_shape(),
|
||||
run_options->allocator(),
|
||||
stream->parent()->device_ordinal());
|
||||
const HloInputOutputAliasConfig& input_output_alias =
|
||||
module().input_output_alias_config();
|
||||
|
||||
// Move se::OwningDeviceMemory values which contain the array(s) of the result
|
||||
// into the respective location in ScopedShapedBuffer which is returned to the
|
||||
// caller.
|
||||
TF_RETURN_IF_ERROR(result_buffer.buffers().ForEachMutableElementWithStatus(
|
||||
[&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) {
|
||||
const auto& sources = this->GetRootValueSet().element(index);
|
||||
// The points to set is unambiguous so the set should be a
|
||||
// singleton.
|
||||
CHECK_EQ(1, sources.values().size());
|
||||
const HloValue* value_source = sources.values()[0];
|
||||
HloInstruction* src = value_source->instruction();
|
||||
for (auto& p : result.MutableResult()->buffers()) {
|
||||
const ShapeIndex& index = p.first;
|
||||
se::DeviceMemoryBase& result_buffer = p.second;
|
||||
const HloValueSet& sources = this->GetRootValueSet().element(index);
|
||||
// The points to set is unambiguous so the set should be a
|
||||
// singleton.
|
||||
CHECK_EQ(1, sources.values().size());
|
||||
const HloValue* value_source = sources.values()[0];
|
||||
HloInstruction* src = value_source->instruction();
|
||||
|
||||
// The source for this result buffer can be a nested buffer such as
|
||||
// a tuple element. The source instruction should have a
|
||||
// non-parameter buffer assigned.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
const BufferAllocation::Slice slice,
|
||||
this->assignment_->GetUniqueSlice(src, value_source->index()));
|
||||
const BufferAllocation::Index buffer_index = slice.index();
|
||||
se::OwningDeviceMemory& buffer = buffers[buffer_index];
|
||||
if (!slice.allocation()->is_entry_computation_parameter()) {
|
||||
// If the buffer coming out of the result is from a parameter, the
|
||||
// owning buffer will be null, and that means the caller aliased some
|
||||
// parameter buffer to an output one (via the
|
||||
// HloInputOutputAliasConfig API). If that is the case, the caller
|
||||
// will receive a partially complete scoped shaped buffer, which they
|
||||
// will have to fill up on return. Unfortunately the interface to the
|
||||
// execute APIs are ShapedBuffer pointer based, which assumes caller
|
||||
// ownership, and hence a buffer coming from there cannot be part of
|
||||
// the new ScopedShapedBuffer we create for the result (which assumes
|
||||
// ownership).
|
||||
*device_memory = buffer.Release();
|
||||
} else {
|
||||
auto output_alias = input_output_alias.GetAliasedOutput(
|
||||
slice.allocation()->parameter_number(),
|
||||
slice.allocation()->param_shape_index());
|
||||
CHECK(output_alias)
|
||||
<< "Output buffer is coming from parameter "
|
||||
<< slice.allocation()->parameter_number() << " at index "
|
||||
<< slice.allocation()->param_shape_index()
|
||||
<< ", but no alias exists";
|
||||
CHECK_EQ(*output_alias, index);
|
||||
// TODO(cheshire): duplication with other backends.
|
||||
absl::optional<HloInputOutputAliasConfig::Alias> alias =
|
||||
input_output_alias.GetAliasedParameter(index);
|
||||
if (alias) {
|
||||
CHECK_LT(alias->parameter_number, arguments.size());
|
||||
ExecutionInput& input = arguments[alias->parameter_number];
|
||||
MaybeOwningDeviceMemory* maybe_owning_memory =
|
||||
input.MutableBuffer(alias->parameter_index);
|
||||
if (absl::optional<se::OwningDeviceMemory> owning =
|
||||
maybe_owning_memory->Release()) {
|
||||
// If the caller passes the ownership of the device memory, reuse it
|
||||
// as the output buffer. It is up to the caller whether or not to
|
||||
// donate a buffer; the aliasing information describes which buffers
|
||||
// may alias, not buffers that must alias.
|
||||
se::DeviceMemoryBase argument_buffer = owning->Release();
|
||||
*maybe_owning_memory = argument_buffer;
|
||||
result_buffer = argument_buffer;
|
||||
if (alias->kind == HloInputOutputAliasConfig::kUserAlias) {
|
||||
// This is a user alias, so a must alias. The caller is giving us the
|
||||
// input buffer, but in case of error of the execute call, we should
|
||||
// not be releasing it as it contains valid data (for example, it is a
|
||||
// parameter which the user wants us to alias, in a gradient update
|
||||
// computation). So we store the index into the result in the aliased
|
||||
// vactor, which will be fed to the ExecutionOutput, which will be
|
||||
// using the indices to drop the addresses from its own
|
||||
// ScopedShapedBuffer result, if the ExecutionOutput is not committed.
|
||||
result.AddAliasedIndex(index);
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
return std::move(result_buffer);
|
||||
}
|
||||
}
|
||||
|
||||
if (result_buffer.is_null()) {
|
||||
// The source for this result buffer can be a nested buffer such as
|
||||
// a tuple element. The source instruction should have a
|
||||
// non-parameter buffer assigned.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
const BufferAllocation::Slice slice,
|
||||
this->assignment_->GetUniqueSlice(src, value_source->index()));
|
||||
const BufferAllocation::Index buffer_index = slice.index();
|
||||
MaybeOwningDeviceMemory& buffer = buffers[buffer_index];
|
||||
if (absl::optional<se::OwningDeviceMemory> owned_buffer =
|
||||
buffer.Release()) {
|
||||
result_buffer = owned_buffer->Release();
|
||||
buffer = result_buffer;
|
||||
} else {
|
||||
result_buffer = buffer.AsDeviceMemoryBase();
|
||||
result.AddAliasedIndex(index);
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
|
||||
@ -311,22 +312,16 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
|
||||
run_options->stream()->implementation());
|
||||
se::Stream* stream = run_options->stream();
|
||||
se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
|
||||
std::vector<se::OwningDeviceMemory> owning_buffers;
|
||||
std::vector<se::DeviceMemoryBase> unowning_buffers;
|
||||
std::vector<se::OwningDeviceMemory> buffers_to_release;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::tie(unowning_buffers, owning_buffers, buffers_to_release),
|
||||
std::vector<MaybeOwningDeviceMemory> buffers,
|
||||
CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(),
|
||||
std::move(arguments)));
|
||||
arguments));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
ScopedShapedBuffer result,
|
||||
CreateResultShapedBuffer(run_options, absl::MakeSpan(owning_buffers)));
|
||||
ExecutionOutput result,
|
||||
CreateResultShapedBuffer(run_options, absl::MakeSpan(buffers),
|
||||
absl::MakeSpan(arguments)));
|
||||
|
||||
// At this point, `unowning_buffers` contains unowning pointers to all of our
|
||||
// buffers, and `buffers` contains owning pointers to the non-live-out
|
||||
// buffers. Enqueue a task which keeps alive the non-live-out buffers.
|
||||
//
|
||||
// Logically we want this lambda to capture `buffers` by move, ultimately our
|
||||
// functor needs to be wrapped in an std::function, and that requires its
|
||||
// functor to be copyable. Thus we perpetrate the hack of capturing buffers
|
||||
@ -339,23 +334,33 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
|
||||
struct AsyncRunTask {
|
||||
CpuExecutable* executable;
|
||||
ServiceExecutableRunOptions run_options;
|
||||
std::vector<se::DeviceMemoryBase> unowning_buffers;
|
||||
std::shared_ptr<std::vector<se::OwningDeviceMemory>> buffers;
|
||||
std::shared_ptr<std::vector<MaybeOwningDeviceMemory>> task_buffers;
|
||||
HloExecutionProfile* hlo_execution_profile;
|
||||
|
||||
void operator()() {
|
||||
// Failing a CHECK here is not great, but I don't see an obvious way to
|
||||
// return a failed Status asynchronously.
|
||||
TF_CHECK_OK(executable->ExecuteComputeFunction(
|
||||
&run_options.run_options(), unowning_buffers, hlo_execution_profile));
|
||||
&run_options.run_options(), *task_buffers, hlo_execution_profile));
|
||||
}
|
||||
};
|
||||
host_stream->EnqueueTask(
|
||||
AsyncRunTask{this, *run_options, std::move(unowning_buffers),
|
||||
std::make_shared<std::vector<se::OwningDeviceMemory>>(
|
||||
std::move(owning_buffers)),
|
||||
AsyncRunTask{this, *run_options,
|
||||
std::make_shared<std::vector<MaybeOwningDeviceMemory>>(
|
||||
std::move(buffers)),
|
||||
hlo_execution_profile});
|
||||
return ExecutionOutput(std::move(result), std::move(buffers_to_release));
|
||||
|
||||
// TODO(cheshire): Duplication with other executables.
|
||||
for (ExecutionInput& argument : arguments) {
|
||||
for (auto& index_buffer : *argument.MutableBuffers()) {
|
||||
absl::optional<se::OwningDeviceMemory> maybe_owning_buffer =
|
||||
index_buffer.second.Release();
|
||||
if (maybe_owning_buffer) {
|
||||
result.AddToBeReleased(std::move(*maybe_owning_buffer));
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
/*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {
|
||||
|
@ -101,24 +101,25 @@ class CpuExecutable : public Executable {
|
||||
//
|
||||
// - buffers_to_free: buffers whose ownership was donated by the caller that
|
||||
// are to be freed by the caller.
|
||||
StatusOr<std::tuple<std::vector<se::DeviceMemoryBase>,
|
||||
std::vector<se::OwningDeviceMemory>,
|
||||
std::vector<se::OwningDeviceMemory>>>
|
||||
CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
|
||||
int device_ordinal, std::vector<ExecutionInput> arguments);
|
||||
StatusOr<std::vector<MaybeOwningDeviceMemory>> CreateBufferTable(
|
||||
se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
|
||||
absl::Span<ExecutionInput const> arguments);
|
||||
|
||||
// Calls the generated function performing the computation with the given
|
||||
// arguments using the supplied buffers.
|
||||
Status ExecuteComputeFunction(const ExecutableRunOptions* run_options,
|
||||
absl::Span<const se::DeviceMemoryBase> buffers,
|
||||
HloExecutionProfile* hlo_execution_profile);
|
||||
Status ExecuteComputeFunction(
|
||||
const ExecutableRunOptions* run_options,
|
||||
absl::Span<MaybeOwningDeviceMemory const> buffers,
|
||||
HloExecutionProfile* hlo_execution_profile);
|
||||
|
||||
// Creates a ScopedShapedBuffer for holding the result of the computation,
|
||||
// moving buffers out of allocated_buffers and into the result as appropriate.
|
||||
// The addresses are set according to buffer assignment.
|
||||
StatusOr<ScopedShapedBuffer> CreateResultShapedBuffer(
|
||||
// Creates an Execution output holding ScopedShapedBuffer for holding the
|
||||
// result of the computation, moving buffers out of allocated_buffers and into
|
||||
// the result as appropriate. The addresses are set according to buffer
|
||||
// assignment.
|
||||
StatusOr<ExecutionOutput> CreateResultShapedBuffer(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<se::OwningDeviceMemory> buffers);
|
||||
absl::Span<MaybeOwningDeviceMemory> buffers,
|
||||
absl::Span<ExecutionInput> arguments);
|
||||
|
||||
// Returns the instruction value set of the root instruction of the entry
|
||||
// computation. Uses dataflow analysis from buffer assignment.
|
||||
|
@ -210,24 +210,6 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Skip 'fusion' instruction if any of its fused instructions are expensive.
|
||||
// This is done to avoid the duplication of expensive instructions, which
|
||||
// would occur if 'fusion' were merged into multiple users.
|
||||
//
|
||||
// If 'fusion' has just one user, then an earlier fusion pass chose not to
|
||||
// fuse this producer/consumer pair (likely because of expensive instruction
|
||||
// re-use by the consumer), and so we honor that choice here as well.
|
||||
if (absl::c_any_of(fusion->fused_instructions(),
|
||||
[](const HloInstruction* instruction) {
|
||||
return instruction->opcode() != HloOpcode::kParameter &&
|
||||
GpuInstructionFusion::IsExpensive(*instruction);
|
||||
})) {
|
||||
VLOG(3) << "Not merging " << fusion->name()
|
||||
<< ": Contains one or more expensive instructions.";
|
||||
++num_fail_expensive_fused_instruction_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Skip 'fusion' instruction if merging it into all users would result in a
|
||||
// net increase in bytes transferred (currently allowing the net bytes
|
||||
// transferred to be exceeded up to ~10% in exchange for eliminating the
|
||||
@ -244,6 +226,35 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
return Status::OK();
}

// Skip 'fusion' instruction if any of its fused instructions are expensive.
// This is done to avoid the duplication of expensive instructions, which
// would occur if 'fusion' were merged into multiple users.
//
// If 'fusion' has just one user, then an earlier fusion pass chose not to
// fuse this producer/consumer pair (likely because of expensive instruction
// re-use by the consumer), and so we honor that choice here as well.
//
// Moreover, if we are going to save a "lot" in memory bandwidth then we
// ignore how expensive the fusion instructions are. The heuristic used to
// determine "a lot" is the following: merging must reduce memory traffic by a
// factor of 0.3, and the amount of memory accessed must not be entirely
// trivial (above 1K). This likely has room for improvement in the future.

bool allow_expensive_ops =
merged_to_current_bytes_ratio < 0.3 && current_bytes_transferred > 1024;

if (!allow_expensive_ops &&
absl::c_any_of(fusion->fused_instructions(),
[](const HloInstruction* instruction) {
return instruction->opcode() != HloOpcode::kParameter &&
GpuInstructionFusion::IsExpensive(*instruction);
})) {
VLOG(3) << "Not merging " << fusion->name()
<< ": Contains one or more expensive instructions.";
++num_fail_expensive_fused_instruction_;
return Status::OK();
}

// Skip 'fusion' instruction if merging it into at least one of the users
// would cause too much code duplication because of inefficiencies in the
// fusion emitter.

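A minimal sketch of that bandwidth heuristic as described in the comment (illustrative only, with hypothetical parameter names): expensive ops are tolerated only when merging shrinks memory traffic below 30% of the current traffic and the traffic is non-trivial (more than 1 KiB).

# Sketch of the decision rule, not the GPU pass itself.
def allow_expensive_ops(merged_bytes_transferred, current_bytes_transferred):
    if current_bytes_transferred <= 1024:
        return False
    ratio = merged_bytes_transferred / current_bytes_transferred
    return ratio < 0.3

assert allow_expensive_ops(merged_bytes_transferred=2_000, current_bytes_transferred=10_000)
assert not allow_expensive_ops(merged_bytes_transferred=9_000, current_bytes_transferred=10_000)
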
@ -367,6 +367,37 @@ TEST_F(FusionMergerTest, WillNotMergeIfFusionEmitterIsInefficient) {
|
||||
EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(FusionMergerTest, WillMergeExpensiveFusionsIfSavesMemory) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule m
|
||||
|
||||
%f_a (p: f32[]) -> f32[1024,1024,1024] {
|
||||
%p = f32[] parameter(0)
|
||||
%b = f32[1024,1024,1024] broadcast(%p), dimensions={}
|
||||
ROOT %t = f32[1024,1024,1024] tanh(%b)
|
||||
}
|
||||
|
||||
%f_b (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
|
||||
%p = f32[1024,1024,1024] parameter(0)
|
||||
ROOT %t = f32[1024,1024,1024] tanh(%p)
|
||||
}
|
||||
|
||||
%f_c (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
|
||||
%p = f32[1024,1024,1024] parameter(0)
|
||||
ROOT %t = f32[1024,1024,1024] tanh(%p)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
p0 = f32[] parameter(0)
|
||||
f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_a
|
||||
f2 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_b
|
||||
f3 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_c
|
||||
ROOT f4 = f32[1024,1024,1024] add(f2, f3)
|
||||
})")
|
||||
.ValueOrDie();
|
||||
EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -499,6 +499,8 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
|
||||
/*allocate_buffers_for_constants=*/true,
|
||||
/*colorer=*/BufferAssigner::DefaultColorer(),
|
||||
/*must_not_live_out=*/{}, GetCanShareBuffer()));
|
||||
VLOG(1) << "Buffer Assignment Stats "
|
||||
<< buffer_assignment->GetStats().ToString();
|
||||
DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations");
|
||||
|
||||
IrEmitterContext ir_emitter_context(
|
||||
|
@ -321,38 +321,11 @@ MatchBackwardInput(HloInstruction* conv) {
|
||||
const auto no_match_result =
|
||||
std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
|
||||
|
||||
// TODO: Theoretically cuDNN supports grouped convolutions also
|
||||
// for the backward input convolution, but based on the cudnn's current state
|
||||
// there is not much performance improvement when using the
|
||||
// cudnn backward input API for grouped conv.
|
||||
// This needs to be re-evaluated for future cuDNN versions.
|
||||
// Note that we already have the necessary code down below, the only thing to
|
||||
// enable it is to remove the following early return.
|
||||
if (conv->feature_group_count() > 1) {
|
||||
return no_match_result;
|
||||
}
|
||||
|
||||
// Match instruction pattern.
|
||||
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
|
||||
HloInstruction* reverse_filter = conv->mutable_operand(1);
|
||||
ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
|
||||
|
||||
// Match BackwardInput for a depthwise convolution and thunk it to forward
|
||||
// convolution Output feature dimension and input feature dimension has been
|
||||
// swapped in the bridge. Hence to get the actual input features we need to
|
||||
// query the output feature dimension
|
||||
auto kernel_out_feature_dim = dnums.kernel_output_feature_dimension();
|
||||
auto kernel_out_features =
|
||||
reverse_filter->shape().dimensions(kernel_out_feature_dim);
|
||||
|
||||
// For a depthwise convolution, the input features must be equal to the
|
||||
// feature_group_count. We can leverage this property to match a depthwise
|
||||
// convolution and thunk it to forward conv
|
||||
if (conv->feature_group_count() > 1 &&
|
||||
kernel_out_features == conv->feature_group_count()) {
|
||||
return no_match_result;
|
||||
}
|
||||
|
||||
// We pattern-match to a backwards input conv if:
|
||||
//
|
||||
// - all spatial dims of the filter are reversed
|
||||
|
@ -360,6 +360,27 @@ StatusOr<se::DeviceMemoryBase> GpuExecutable::BufferForAllocation(
|
||||
}
|
||||
}
|
||||
|
||||
static Status CheckAlignment(const BufferAllocation& allocation,
|
||||
se::DeviceMemoryBase buffer, int arg_idx) {
|
||||
const int64 expected_alignment = [&] {
|
||||
if (allocation.is_entry_computation_parameter()) {
|
||||
return kEntryParameterAlignBytes;
|
||||
} else if (allocation.is_constant()) {
|
||||
return kConstantBufferAlignBytes;
|
||||
} else {
|
||||
return kXlaAllocatedBufferAlignBytes;
|
||||
}
|
||||
}();
|
||||
if (!buffer.is_null() &&
|
||||
reinterpret_cast<uintptr_t>(buffer.opaque()) % expected_alignment != 0) {
|
||||
return InternalError(
|
||||
"Address of buffer %d must be a multiple of %x, but "
|
||||
"was %p",
|
||||
arg_idx, expected_alignment, buffer.opaque());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<BufferAllocations> GpuExecutable::GenerateBufferAllocations(
|
||||
absl::Span<ExecutionInput const> arguments,
|
||||
const GpuExecutable::BufferAllocToDeviceMemoryMap* globals,
|
||||
@ -378,28 +399,37 @@ StatusOr<BufferAllocations> GpuExecutable::GenerateBufferAllocations(
|
||||
se::DeviceMemoryBase buffer,
|
||||
BufferForAllocation(arguments, globals, allocation, memory_allocator,
|
||||
executor->device_ordinal(), i));
|
||||
const int64 expected_alignment = [&] {
|
||||
if (allocation.is_entry_computation_parameter()) {
return kEntryParameterAlignBytes;
} else if (allocation.is_constant()) {
return kConstantBufferAlignBytes;
} else {
return kXlaAllocatedBufferAlignBytes;
}
}();
if (!buffer.is_null() &&
reinterpret_cast<uintptr_t>(buffer.opaque()) % expected_alignment !=
0) {
return InternalError(
"Address of buffer %d must be a multiple of %x, but "
"was %p",
i, expected_alignment, buffer.opaque());
}
buffers.push_back(buffer);
TF_RETURN_IF_ERROR(CheckAlignment(allocation, buffer, i));
}
return {{buffers, executor->device_ordinal(), memory_allocator}};
}

// Returns `true` if the entire tuple contents is aliased.
static bool EntireTupleContentsAliased(
const Shape& output_shape, const ShapeIndex& index,
const HloInputOutputAliasConfig& alias_config) {
const Shape& indexed_shape = ShapeUtil::GetSubshape(output_shape, index);
if (!indexed_shape.IsTuple()) {
return false;
}
bool all_aliased = true;
ShapeUtil::ForEachSubshape(
indexed_shape, [&](const Shape& subshape, const ShapeIndex& subindex) {
if (subindex.empty()) {
return;
}
std::vector<int64> full_index;
absl::c_copy(index, std::back_inserter(full_index));
absl::c_copy(subindex, std::back_inserter(full_index));
if (!alias_config.OutputHasAlias(
ShapeIndex(full_index.begin(), full_index.end()))) {
all_aliased = false;
}
});
return all_aliased;
}

StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ExecutionInput> arguments,
@ -425,84 +455,102 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
}

se::StreamExecutor* executor = run_options->stream()->parent();
TF_ASSIGN_OR_RETURN(BufferAllocations buffer_allocations,
GenerateBufferAllocations(arguments, globals,
memory_allocator, executor));

for (Thunk* thunk : thunk_schedule_->TotalOrder()) {
TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor));
}
VLOG(2) << buffer_allocations.ToString();
TF_RETURN_IF_ERROR(ExecuteThunks(run_options, buffer_allocations,
block_host_until_done,
hlo_execution_profile));

HloInstruction* root = hlo_module_->entry_computation()->root_instruction();
auto device_ordinal = executor->device_ordinal();
ExecutionOutput result(root->shape(), root->shape(), memory_allocator,
ExecutionOutput result(/*on_host_shape=*/root->shape(),
/*on_device_shape=*/root->shape(), memory_allocator,
device_ordinal);

TF_ASSIGN_OR_RETURN(BufferAllocations buffer_allocations,
GenerateBufferAllocations(arguments, globals,
memory_allocator, executor));
VLOG(2) << buffer_allocations.ToString();
std::set<se::DeviceMemoryBase> buffers_in_result;
for (auto& p : result.MutableResult()->buffers()) {
const ShapeIndex& index = p.first;
se::DeviceMemoryBase& device_memory = p.second;
se::DeviceMemoryBase& result_buffer = p.second;
const auto& sources = GetRootValueSet().element(index);
// The points-to set is unambiguous so the set should be a
// singleton. That is, we know exactly which instruction
// produced the array at this element.
CHECK_EQ(1, sources.values().size());
auto src_hlo = sources.values()[0]->instruction();
HloInstruction* src_hlo = sources.values()[0]->instruction();

VLOG(4) << "Looking at: " << sources.values()[0];

// The source instruction should have a non-parameter buffer
// assigned.
TF_ASSIGN_OR_RETURN(
const BufferAllocation::Slice slice,
assignment_->GetUniqueSlice(src_hlo, sources.values()[0]->index()));

se::DeviceMemoryBase src_base =
buffer_allocations.GetDeviceAddress(slice.index());
CHECK(!src_base.is_null() || src_base.size() == 0);
if (!slice.allocation()->is_entry_computation_parameter()) {
// If the buffer coming out of the result is from a parameter, it
// means the caller aliased some parameter buffer to an output one
// (via the HloInputOutputAliasConfig API). If that is the case, the
// caller will receive a partially complete scoped shaped buffer,
// which they will have to fill up on return.
// Unfortunately the interface to the execute APIs are ShapedBuffer
// pointer based, which assumes caller ownership, and hence a buffer
// coming from there cannot be part of the new ScopedShapedBuffer we
// create for the result (which assumes ownership).
device_memory = src_base;
} else {
const HloInputOutputAliasConfig& input_output_alias =
module().input_output_alias_config();
auto output_alias = input_output_alias.GetAliasedOutput(
slice.allocation()->parameter_number(),
slice.allocation()->param_shape_index());
CHECK(output_alias) << "Output buffer is coming from parameter "
<< slice.allocation()->parameter_number()
<< " at index "
<< slice.allocation()->param_shape_index()
<< ", but no alias exists";
CHECK_EQ(*output_alias, index);
const HloInputOutputAliasConfig& input_output_alias =
module().input_output_alias_config();
absl::optional<HloInputOutputAliasConfig::Alias> alias =
input_output_alias.GetAliasedParameter(index);
if (alias) {
CHECK_LT(alias->parameter_number, arguments.size());
ExecutionInput& input = arguments[alias->parameter_number];
MaybeOwningDeviceMemory* maybe_owning_memory =
input.MutableBuffer(alias->parameter_index);
if (absl::optional<se::OwningDeviceMemory> owning =
maybe_owning_memory->Release()) {
// If the caller passes the ownership of the device memory, reuse it
// as the output buffer. It is up to the caller whether or not to
// donate a buffer; the aliasing information describes which buffers
// may alias, not buffers that must alias.
se::DeviceMemoryBase argument_buffer = owning->Release();
*maybe_owning_memory = argument_buffer;
result_buffer = argument_buffer;
if (alias->kind == HloInputOutputAliasConfig::kUserAlias) {
// This is a user alias, so a must alias. The caller is giving us the
// input buffer, but in case of error from the execute call, we should
// not be releasing it as it contains valid data (for example, it is a
// parameter which the user wants us to alias, in a gradient update
// computation). So we store the index into the result in the aliased
// vector, which will be fed to the ExecutionOutput, which will use
// the indices to drop the addresses from its own ScopedShapedBuffer
// result, if the ExecutionOutput is not committed.
result.AddAliasedIndex(index);
}
}
}
buffers_in_result.insert(src_base);

if (result_buffer.is_null()) {
// The source instruction should have a non-parameter buffer
// assigned.
TF_ASSIGN_OR_RETURN(
const BufferAllocation::Slice slice,
assignment_->GetUniqueSlice(src_hlo, sources.values()[0]->index()));
result_buffer = buffer_allocations.GetDeviceAddress(slice.index());

// If the entire tuple contents is aliased, the copy insertion will *not*
// materialize a new tuple, so we mark it as aliased as well.
if (EntireTupleContentsAliased(root->shape(), index,
input_output_alias)) {
result.AddAliasedIndex(index);
}
}
buffers_in_result.insert(result_buffer);
}

for (Thunk* thunk : thunk_schedule_->TotalOrder()) {
TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor));
}
TF_RETURN_IF_ERROR(ExecuteThunks(run_options, buffer_allocations,
block_host_until_done,
hlo_execution_profile));

// Free all temporary allocations.
TF_RETURN_IF_ERROR(
buffer_allocations.TearDown(buffers_in_result, assignment_.get()));

std::vector<se::OwningDeviceMemory> buffers_to_free;
for (auto& argument : arguments) {
// Free allocations for arguments.
for (ExecutionInput& argument : arguments) {
for (auto& index_buffer : *argument.MutableBuffers()) {
auto maybe_owning_buffer = index_buffer.second.Release();
if (maybe_owning_buffer) {
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
if (absl::optional<se::OwningDeviceMemory> owning =
index_buffer.second.Release()) {
result.AddToBeReleased(std::move(*owning));
}
}
}
return result;

return std::move(result);
}

const InstructionValueSet& GpuExecutable::GetRootValueSet() const {

@ -121,10 +121,10 @@ string HloInputOutputAliasConfig::ToString() const {
return absl::StrJoin(pieces, "\n");
}

HloInputOutputAliasConfig::AliasKind
absl::optional<HloInputOutputAliasConfig::AliasKind>
HloInputOutputAliasConfig::ParameterAliasKind(
int64 param_number, const ShapeIndex& param_index) const {
AliasKind kind = AliasKind::kNoAlias;
absl::optional<AliasKind> kind;
alias_.ForEachElement(
[&](const xla::ShapeIndex&, absl::optional<Alias> alias) {
if (alias && alias->parameter_number == param_number &&

@ -36,7 +36,6 @@ class HloInputOutputAliasConfig {
// compilation time by the user, and has to be respected. A kSystemAlias one
// might be setup by the compiler, if it decides it is convenient to do so.
enum AliasKind {
kNoAlias,
kUserAlias,
kSystemAlias,
};
@ -68,15 +67,15 @@ class HloInputOutputAliasConfig {
AliasKind kind = AliasKind::kUserAlias);

// Returns the kind of alias for the given parameter number and parameter
// index. If no alias exists, AliasKind::kNoAlias is returned.
AliasKind ParameterAliasKind(int64 param_number,
const ShapeIndex& param_index) const;
// index.
absl::optional<AliasKind> ParameterAliasKind(
int64 param_number, const ShapeIndex& param_index) const;

// Returns true if the given parameter is aliased with one of the output
// buffers.
bool ParameterHasAlias(int64 param_number,
const ShapeIndex& param_index) const {
return ParameterAliasKind(param_number, param_index) != AliasKind::kNoAlias;
return ParameterAliasKind(param_number, param_index).has_value();
}

// Checks whether the provided output index has already been aliased.
@ -22,6 +22,7 @@ limitations under the License.
#include "absl/base/casts.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
@ -432,8 +432,8 @@ std::string MemorySpaceAssignment::AllocationValue::ToString() const {
absl::StrAppend(&out, "\n position:\n");
absl::StrAppend(&out, " ", defining_position_.ToString(), "\n");
absl::StrAppend(&out, " uses:\n");
for (const HloUse& use : uses_) {
absl::StrAppend(&out, " ", use.ToString(), "\n");
for (const Use& use : uses_) {
absl::StrAppend(&out, " ", use.hlo_use.ToString(), "\n");
}
return out;
}
@ -515,6 +515,53 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues(
}
}

void AlternateMemoryBestFitHeap::FindAliases(
std::vector<AllocationValue>* allocation_values) const {
absl::flat_hash_map<const HloInstruction*, const AllocationValue*>
values_by_defining_inst;
for (AllocationValue& value : *allocation_values) {
CHECK_EQ(values_by_defining_inst.count(value.defining_instruction()), 0);
values_by_defining_inst[value.defining_instruction()] = &value;
}
auto maybe_add_alias_with_instruction = [&](const HloInstruction* instruction,
AllocationValue::Use* use) {
auto aliased_value_it = values_by_defining_inst.find(instruction);
if (aliased_value_it != values_by_defining_inst.end()) {
VLOG(3) << "Adding aliasing for use " << use->hlo_use.ToString() << " to "
<< aliased_value_it->second->ToShortString();
use->aliases.push_back(aliased_value_it->second->defining_position());
}
};

for (AllocationValue& value : *allocation_values) {
for (AllocationValue::Use& use : value.uses()) {
// Find any aliases with the instruction itself (operand and output must
// alias).
maybe_add_alias_with_instruction(use.hlo_use.instruction, &use);

// Find any aliases with the parameters of called computations.
for (const HloComputation* called_computation :
use.hlo_use.instruction->called_computations()) {
for (const HloInstruction* parameter_instruction :
called_computation->parameter_instructions()) {
maybe_add_alias_with_instruction(parameter_instruction, &use);
}
}

// Special case for kWhile: the root of the body computation must alias as
// well.
if (use.hlo_use.instruction->opcode() == HloOpcode::kWhile) {
HloPosition root_alias{
use.hlo_use.instruction->while_body()->root_instruction(),
use.hlo_use.operand_index};
VLOG(3) << "Adding while body root aliasing for use "
<< use.hlo_use.ToString() << " to " << root_alias;
use.aliases.push_back(root_alias);
}
}
}
}

std::vector<const GlobalDecreasingSizeBestFitHeap::BufferInterval*>
|
||||
AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
|
||||
@ -675,18 +722,18 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
|
||||
// multiple called computations), determine if the parameter->first use
|
||||
// dependency is short.
|
||||
int64 conditional_time = instruction_schedule.at(use.instruction);
|
||||
for (const HloUse& other_use : value.uses()) {
|
||||
if (other_use.instruction != use.instruction) {
|
||||
for (const AllocationValue::Use& other_use : value.uses()) {
|
||||
if (other_use.hlo_use.instruction != use.instruction) {
|
||||
continue;
|
||||
}
|
||||
HloComputation* called_computation =
|
||||
use.instruction->called_computations().at(other_use.operand_number -
|
||||
1);
|
||||
use.instruction->called_computations().at(
|
||||
other_use.hlo_use.operand_number - 1);
|
||||
const HloInstruction* parameter_instruction =
|
||||
called_computation->parameter_instruction(0);
|
||||
HloValue* parameter_value =
|
||||
&alias_analysis_.dataflow_analysis().GetUniqueValueAt(
|
||||
parameter_instruction, other_use.operand_index);
|
||||
parameter_instruction, other_use.hlo_use.operand_index);
|
||||
int64 parameter_time = instruction_schedule.at(parameter_instruction);
|
||||
int64 min_use_time = conditional_time;
|
||||
for (const HloUse& parameter_use : parameter_value->uses()) {
|
||||
@ -947,6 +994,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
||||
for (const auto& colocated_interval : colocated_intervals) {
|
||||
CreateAllocationValues(colocated_interval->buffer, &allocation_values);
|
||||
}
|
||||
FindAliases(&allocation_values);
|
||||
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
|
||||
|
||||
// Data structure to contain the preferred offset for a given computation.
|
||||
@ -969,25 +1017,26 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
||||
|
||||
// Iterate over the uses.
|
||||
for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
|
||||
const HloUse& use = allocation_value.uses().at(use_idx);
|
||||
int64 use_time = instruction_schedule.at(use.instruction);
|
||||
const AllocationValue::Use& use = allocation_value.uses().at(use_idx);
|
||||
const HloUse hlo_use = use.hlo_use;
|
||||
int64 use_time = instruction_schedule.at(hlo_use.instruction);
|
||||
int64 latest_prefetch_time = use_time;
|
||||
bool allow_no_copy_alternate_mem_allocation = true;
|
||||
absl::optional<int64> earliest_prefetch_time = absl::nullopt;
|
||||
|
||||
// Sequential calls include kWhile, kCall, and kConditional opcodes.
|
||||
bool is_sequential_call =
|
||||
(GetInstructionCallContext(use.instruction->opcode()) ==
|
||||
(GetInstructionCallContext(hlo_use.instruction->opcode()) ==
|
||||
CallContext::kSequential);
|
||||
if (is_sequential_call) {
|
||||
for (const HloComputation* called_computation :
|
||||
use.instruction->called_computations()) {
|
||||
hlo_use.instruction->called_computations()) {
|
||||
const HloLiveRange::TimeBound& computation_span =
|
||||
hlo_live_range_.computation_span_times().at(called_computation);
|
||||
latest_prefetch_time =
|
||||
std::min(computation_span.start, latest_prefetch_time);
|
||||
}
|
||||
if (use.instruction->opcode() == HloOpcode::kWhile) {
|
||||
if (hlo_use.instruction->opcode() == HloOpcode::kWhile) {
|
||||
// Given an example while loop and flattened schedule (logical times
|
||||
// shown on the left):
|
||||
//
|
||||
@ -1008,10 +1057,10 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
||||
// the interval to time 0-4. This is so that the remaining interval
|
||||
// (5-6) can be allocated separately and this buffer doesn't waste
|
||||
// alternate memory space within the while loop body.
|
||||
HloComputation* while_body = use.instruction->while_body();
|
||||
HloComputation* while_body = hlo_use.instruction->while_body();
|
||||
// We require while body ROOTs to be the last in the schedule.
|
||||
CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
|
||||
instruction_schedule.at(use.instruction))
|
||||
instruction_schedule.at(hlo_use.instruction))
|
||||
<< "While body ROOTs need to be the last in the schedule! "
|
||||
"Please run RootInstructionSinker.";
|
||||
// Replace the use time with the parameter time so that we can decide
|
||||
@ -1019,11 +1068,11 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
||||
// look at uses within the while loop body.
|
||||
use_time =
|
||||
instruction_schedule.at(while_body->parameter_instruction(0));
|
||||
} else if (use.instruction->opcode() == HloOpcode::kConditional) {
|
||||
} else if (hlo_use.instruction->opcode() == HloOpcode::kConditional) {
|
||||
// Replace the use time with the earliest parameter of called
|
||||
// computations.
|
||||
for (const HloComputation* called_computation :
|
||||
use.instruction->called_computations()) {
|
||||
hlo_use.instruction->called_computations()) {
|
||||
use_time = std::min(
|
||||
use_time, instruction_schedule.at(
|
||||
called_computation->parameter_instruction(0)));
|
||||
@ -1033,8 +1082,8 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
||||
|
||||
// Add a required assignment in default memory if the use not allowed in
|
||||
// alternate memory.
|
||||
if (!IsUseAllowedInAlternateMemory(allocation_value, use)) {
|
||||
AddRequiredAssignment(allocation_value.value(), use.instruction,
|
||||
if (!IsUseAllowedInAlternateMemory(allocation_value, hlo_use)) {
|
||||
AddRequiredAssignment(allocation_value.value(), hlo_use.instruction,
|
||||
MemorySpace::kDefault, use_time);
|
||||
} else if (use_idx > 0) {
|
||||
// We allow buffers in alternate memory that are passed into
|
||||
@ -1043,14 +1092,16 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
||||
// alternate memory allocation, subsequent uses cannot use the same
|
||||
// alternate memory allocation in order not to clobber data. So we force
|
||||
// default memory allocation for these subsequent uses.
|
||||
const HloUse& previous_use = allocation_value.uses().at(use_idx - 1);
|
||||
if (previous_use.instruction->opcode() == HloOpcode::kConditional &&
|
||||
previous_use.instruction != use.instruction) {
|
||||
const AllocationValue::Use& previous_use =
|
||||
allocation_value.uses().at(use_idx - 1);
|
||||
if (previous_use.hlo_use.instruction->opcode() ==
|
||||
HloOpcode::kConditional &&
|
||||
previous_use.hlo_use.instruction != hlo_use.instruction) {
|
||||
allow_no_copy_alternate_mem_allocation = false;
|
||||
earliest_prefetch_time =
|
||||
instruction_schedule.at(previous_use.instruction);
|
||||
VLOG(3) << "Previous use (" << previous_use.ToString() << ") of use ("
|
||||
<< use.ToString()
|
||||
instruction_schedule.at(previous_use.hlo_use.instruction);
|
||||
VLOG(3) << "Previous use (" << previous_use.hlo_use.ToString()
|
||||
<< ") of use (" << hlo_use.ToString()
|
||||
<< ") is a conditional, so this use will need to evict. "
|
||||
<< "Earliest prefetch time = " << *earliest_prefetch_time;
|
||||
}
|
||||
@ -1059,7 +1110,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
||||
// Bitcasts don't define buffers and don't directly consume buffers. Skip
|
||||
// allocating buffers for bitcast uses. The uses that feed from bitcasts
|
||||
// will be handled specially.
|
||||
if (use.instruction->opcode() != HloOpcode::kBitcast) {
|
||||
if (hlo_use.instruction->opcode() != HloOpcode::kBitcast) {
|
||||
AllocationRequest request;
|
||||
// Rarely, (e.g., when conditional true and false parameters are the
|
||||
// same), definition time can be the time of the conditional and use
|
||||
@ -1072,7 +1123,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
||||
allow_no_copy_alternate_mem_allocation;
|
||||
request.earliest_prefetch_time = earliest_prefetch_time;
|
||||
request.preferred_offset = preferred_offset;
|
||||
request.use = use;
|
||||
request.use = &use;
|
||||
request.allocation_value = &allocation_value;
|
||||
if (!AllocateSegment(request)) {
|
||||
// If the allocation finding failed (e.g., due to running out of
|
||||
@ -1085,23 +1136,25 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
||||
|
||||
// If there are multiple uses, they can try using the memory allocation
|
||||
// already at the alternate memory.
|
||||
definition_time = instruction_schedule.at(use.instruction);
|
||||
definition_time = instruction_schedule.at(hlo_use.instruction);
|
||||
}
|
||||
|
||||
// If the use has been a sequential call (e.g. a while loop), the other
|
||||
// colocated intervals must alias with this allocation.
|
||||
if (is_sequential_call) {
|
||||
MemorySpaceAssignment::Allocation* aliased_allocation =
|
||||
GetLiveAllocationAt(*allocation_value.allocation_sequence(),
|
||||
use_time);
|
||||
AddAliasedRequiredAssignmentsForSequentialCall(use, aliased_allocation);
|
||||
// Remember the preferred offset to be used inside while loop body
|
||||
// computations.
|
||||
if (aliased_allocation->memory_space() == MemorySpace::kAlternate &&
|
||||
use.instruction->opcode() == HloOpcode::kWhile) {
|
||||
preferred_offset_for_computation[use.instruction->while_body()] =
|
||||
aliased_allocation->chunk().offset;
|
||||
}
|
||||
// Propagate the allocation to any aliases this use might have had.
|
||||
MemorySpaceAssignment::Allocation* aliased_allocation =
|
||||
GetLiveAllocationAt(*allocation_value.allocation_sequence(),
|
||||
use_time);
|
||||
for (const HloPosition& aliased_position : use.aliases) {
|
||||
AddAliasedRequiredAssignment(aliased_position.instruction,
|
||||
aliased_position.index,
|
||||
aliased_allocation);
|
||||
}
|
||||
|
||||
// Special case for while loops since the root offset must agree with
|
||||
// other offsets: remember the preferred offset for the while loop body.
|
||||
if (hlo_use.instruction->opcode() == HloOpcode::kWhile &&
|
||||
aliased_allocation->memory_space() == MemorySpace::kAlternate) {
|
||||
preferred_offset_for_computation[hlo_use.instruction->while_body()] =
|
||||
aliased_allocation->chunk().offset;
|
||||
}
|
||||
}
|
||||
if (!allocation_success) {
|
||||
@ -1212,34 +1265,45 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
|
||||
pending_required_assignments_.clear();
|
||||
}
|
||||
|
||||
void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignmentsForSequentialCall(
|
||||
const HloUse& use,
|
||||
const MemorySpaceAssignment::Allocation* aliased_allocation) {
|
||||
// Add aliased required assignments.
|
||||
if (use.instruction->opcode() == HloOpcode::kWhile) {
|
||||
HloComputation* while_body = use.instruction->while_body();
|
||||
HloComputation* while_condition = use.instruction->while_condition();
|
||||
AddAliasedRequiredAssignment(while_condition->parameter_instruction(0),
|
||||
use.operand_index, aliased_allocation);
|
||||
AddAliasedRequiredAssignment(while_body->parameter_instruction(0),
|
||||
use.operand_index, aliased_allocation);
|
||||
AddAliasedRequiredAssignment(while_body->root_instruction(),
|
||||
use.operand_index, aliased_allocation);
|
||||
AddAliasedRequiredAssignment(use.instruction, use.operand_index,
|
||||
aliased_allocation);
|
||||
} else if (use.instruction->opcode() == HloOpcode::kConditional) {
|
||||
HloComputation* called_computation =
|
||||
use.instruction->called_computations().at(use.operand_number - 1);
|
||||
AddAliasedRequiredAssignment(called_computation->parameter_instruction(0),
|
||||
use.operand_index, aliased_allocation);
|
||||
} else {
|
||||
CHECK(use.instruction->opcode() == HloOpcode::kCall);
|
||||
HloComputation* called_computation =
|
||||
use.instruction->called_computations().at(0);
|
||||
AddAliasedRequiredAssignment(
|
||||
called_computation->parameter_instruction(use.operand_number),
|
||||
use.operand_index, aliased_allocation);
|
||||
absl::optional<RequiredMemoryAssignment>
|
||||
AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
|
||||
int64 time) const {
|
||||
auto required_assignment_it = required_assignments_.find(buffer);
|
||||
absl::optional<RequiredMemoryAssignment> required_assignment_at_time;
|
||||
if (required_assignment_it != required_assignments_.end()) {
|
||||
for (const RequiredMemoryAssignment& required_assignment :
|
||||
required_assignment_it->second) {
|
||||
if (required_assignment.time == time) {
|
||||
// Sanity check that there is only one required at time.
|
||||
CHECK(!required_assignment_at_time);
|
||||
required_assignment_at_time = required_assignment;
|
||||
}
|
||||
}
|
||||
}
|
||||
return required_assignment_at_time;
|
||||
}
|
||||
|
||||
absl::optional<RequiredMemoryAssignment>
|
||||
AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse(
|
||||
const AllocationValue::Use& use) const {
|
||||
absl::optional<RequiredMemoryAssignment> required_assignment;
|
||||
for (const HloPosition& position : use.aliases) {
|
||||
const HloValue* value =
|
||||
&alias_analysis_.dataflow_analysis().GetUniqueValueAt(
|
||||
position.instruction, position.index);
|
||||
int64 time =
|
||||
hlo_live_range_.instruction_schedule().at(position.instruction);
|
||||
absl::optional<RequiredMemoryAssignment> required_assignment_for_alias =
|
||||
RequiredMemoryAssignmentAt(value, time);
|
||||
if (required_assignment == absl::nullopt) {
|
||||
required_assignment = required_assignment_for_alias;
|
||||
} else {
|
||||
CHECK(required_assignment_for_alias == absl::nullopt ||
|
||||
required_assignment->equals_ignoring_time(
|
||||
*required_assignment_for_alias));
|
||||
}
|
||||
}
|
||||
return required_assignment;
|
||||
}
|
||||
|
||||
void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment(
|
||||
@ -1429,24 +1493,6 @@ void AlternateMemoryBestFitHeap::AddToPendingChunks(
|
||||
CommitChunk(buffer_interval, chunk_candidate);
|
||||
}
|
||||
|
||||
absl::optional<RequiredMemoryAssignment>
|
||||
AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
|
||||
int64 time) const {
|
||||
auto required_assignment_it = required_assignments_.find(buffer);
|
||||
absl::optional<RequiredMemoryAssignment> required_assignment_at_time;
|
||||
if (required_assignment_it != required_assignments_.end()) {
|
||||
for (const RequiredMemoryAssignment& required_assignment :
|
||||
required_assignment_it->second) {
|
||||
if (required_assignment.time == time) {
|
||||
// Sanity check that there is only one required at time.
|
||||
CHECK(!required_assignment_at_time);
|
||||
required_assignment_at_time = required_assignment;
|
||||
}
|
||||
}
|
||||
}
|
||||
return required_assignment_at_time;
|
||||
}
|
||||
|
||||
bool AlternateMemoryBestFitHeap::AllocateSegment(
|
||||
const AllocationRequest& request) {
|
||||
auto allocation_sequence = request.allocation_value->allocation_sequence();
|
||||
@ -1457,7 +1503,7 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
|
||||
MemorySpaceAssignment::Allocation* allocation =
|
||||
GetLiveAllocationAt(*allocation_sequence, request.end_time);
|
||||
CHECK_NE(allocation, nullptr);
|
||||
allocation->AddUse(request.use);
|
||||
allocation->AddUse(request.use->hlo_use);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -1467,8 +1513,9 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
|
||||
<< request.allocation_value->ToShortString() << " ("
|
||||
<< request.start_time << ", " << request.end_time
|
||||
<< ") latest prefetch = " << request.latest_prefetch_time
|
||||
<< " last use = " << request.allocation_value->use_times().back()
|
||||
<< " use = " << request.use.ToString() << ". Size = " << request.size
|
||||
<< " last use = " << request.allocation_value->uses().back().time
|
||||
<< " use = " << request.use->hlo_use.ToString()
|
||||
<< ". Size = " << request.size
|
||||
<< ", def pos = " << defining_position.ToString();
|
||||
CHECK_LE(request.start_time, request.end_time);
|
||||
|
||||
@ -1483,8 +1530,21 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
|
||||
if (required_assignment_at_start) {
|
||||
required_memory_space_at_start = required_assignment_at_start->memory_space;
|
||||
}
|
||||
// Find required assignment both for the use and its aliases. If they are both
|
||||
// non-nullopt, then make sure they require the same assignment.
|
||||
auto required_assignment_at_end = RequiredMemoryAssignmentAt(
|
||||
request.allocation_value->value(), request.end_time);
|
||||
auto aliased_required_assignment_at_end =
|
||||
AliasedRequiredAssignmentForUse(*request.use);
|
||||
if (required_assignment_at_end != aliased_required_assignment_at_end) {
|
||||
if (required_assignment_at_end == absl::nullopt) {
|
||||
required_assignment_at_end = aliased_required_assignment_at_end;
|
||||
} else {
|
||||
CHECK(aliased_required_assignment_at_end == absl::nullopt ||
|
||||
aliased_required_assignment_at_end->equals_ignoring_time(
|
||||
*required_assignment_at_end));
|
||||
}
|
||||
}
|
||||
absl::optional<MemorySpace> required_memory_space_at_end;
|
||||
if (required_assignment_at_end) {
|
||||
required_memory_space_at_end = required_assignment_at_end->memory_space;
|
||||
@ -1553,7 +1613,7 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
|
||||
VLOG(3)
|
||||
<< "Not trying to prefetch because use requires buffer in default mem.";
|
||||
(*prev_allocation_in_default_mem_it)->Extend(request.end_time);
|
||||
(*prev_allocation_in_default_mem_it)->AddUse(request.use);
|
||||
(*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -1577,7 +1637,7 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
|
||||
// If a copy wasn't inserted, then add this use to the latest allocation in
|
||||
// default memory.
|
||||
(*prev_allocation_in_default_mem_it)->Extend(request.end_time);
|
||||
(*prev_allocation_in_default_mem_it)->AddUse(request.use);
|
||||
(*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -1746,7 +1806,7 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy(
|
||||
chunk_candidate->chunk, request.start_time, request.end_time));
|
||||
}
|
||||
request.allocation_value->allocation_sequence()->back()->AddUse(
|
||||
request.use);
|
||||
request.use->hlo_use);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
@ -1833,7 +1893,7 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) {
|
||||
if (!eviction_scheduled) {
|
||||
// If the eviction couldn't be scheduled, then fail. This buffer will be
|
||||
// kept in the default memory.
|
||||
VLOG(3) << "Bailing: Could not evict " << request.use.ToString()
|
||||
VLOG(3) << "Bailing: Could not evict " << request.use->hlo_use.ToString()
|
||||
<< " because we hit the limit of maximum asynchronous copies "
|
||||
<< "between "
|
||||
<< hlo_live_range_.flattened_instruction_sequence()
|
||||
@ -1868,7 +1928,8 @@ bool AlternateMemoryBestFitHeap::Prefetch(
|
||||
earliest_prefetch_time =
|
||||
std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
|
||||
}
|
||||
options_.prefetch_interval_picker->Begin(request.use, earliest_prefetch_time,
|
||||
options_.prefetch_interval_picker->Begin(request.use->hlo_use,
|
||||
earliest_prefetch_time,
|
||||
request.latest_prefetch_time);
|
||||
VLOG(3) << "Trying prefetch picker = "
|
||||
<< options_.prefetch_interval_picker->ToDebugString();
|
||||
@ -1922,7 +1983,7 @@ bool AlternateMemoryBestFitHeap::Prefetch(
|
||||
request.allocation_value->allocation_sequence());
|
||||
|
||||
request.allocation_value->allocation_sequence()->back()->AddUse(
|
||||
request.use);
|
||||
request.use->hlo_use);
|
||||
prefetch_failed_due_to_async_copy_ = false;
|
||||
return true;
|
||||
}
|
||||
@ -1938,11 +1999,11 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate(
|
||||
if (!preferred_offset) {
|
||||
// Find a chunk that's as long living as possible iterating in reverse over
|
||||
// the use times.
|
||||
for (auto use_time = request.allocation_value->use_times().rbegin();
|
||||
use_time != request.allocation_value->use_times().rend() &&
|
||||
*use_time >= end_time;
|
||||
++use_time) {
|
||||
alternate_mem_interval->end = *use_time;
|
||||
for (auto use_it = request.allocation_value->uses().rbegin();
|
||||
use_it != request.allocation_value->uses().rend() &&
|
||||
use_it->time >= end_time;
|
||||
++use_it) {
|
||||
alternate_mem_interval->end = use_it->time;
|
||||
ChunkCandidate chunk_candidate =
|
||||
FindChunkCandidate(*alternate_mem_interval);
|
||||
if (chunk_candidate.heap_size <= available_heap_size()) {
|
||||
|
@ -620,6 +620,18 @@ class MemorySpaceAssignment {
// add.5, operand 0
class AllocationValue {
public:
// This data structure wraps an HloUse and adds additional metadata that are
// useful for allocation.
struct Use {
// The wrapped HloUse object.
HloUse hlo_use;
// The logical time this use is scheduled.
int64 time;
// All the positions where this use aliases with. The aliased positions
// must get the same allocation.
std::vector<HloPosition> aliases;
};
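// Illustrative sketch, not part of the upstream change: how a Use might be
// populated. `hlo_use`, `use_time`, and `while_root` below are hypothetical
// placeholders for values produced during allocation.
//
//   AllocationValue::Use use{/*hlo_use=*/hlo_use, /*time=*/use_time,
//                            /*aliases=*/{}};
//   // A while-loop use also aliases the body root at the same shape index.
//   use.aliases.push_back(
//       HloPosition{while_root, /*index=*/hlo_use.operand_index});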
|
||||
|
||||
AllocationValue(const HloValue* value, const HloPosition& position)
|
||||
: value_(value), defining_position_(position) {}
|
||||
|
||||
@ -627,8 +639,8 @@ class MemorySpaceAssignment {
|
||||
const HloInstruction* defining_instruction() const {
|
||||
return defining_position().instruction;
|
||||
}
|
||||
const std::vector<HloUse>& uses() const { return uses_; }
|
||||
const std::vector<int64>& use_times() const { return use_times_; }
|
||||
const std::vector<Use>& uses() const { return uses_; }
|
||||
std::vector<Use>& uses() { return uses_; }
|
||||
const HloValue* value() const { return value_; }
|
||||
const HloComputation* computation() const {
|
||||
return defining_instruction()->parent();
|
||||
@ -636,8 +648,7 @@ class MemorySpaceAssignment {
|
||||
AllocationSequence* allocation_sequence() { return &allocation_sequence_; }
|
||||
|
||||
void AddUse(const HloUse& use, int64 use_time) {
|
||||
uses_.push_back(use);
|
||||
use_times_.push_back(use_time);
|
||||
uses_.push_back({use, use_time, {}});
|
||||
}
|
||||
|
||||
std::string ToString() const;
|
||||
@ -646,8 +657,7 @@ class MemorySpaceAssignment {
|
||||
private:
|
||||
const HloValue* value_;
|
||||
HloPosition defining_position_;
|
||||
std::vector<HloUse> uses_;
|
||||
std::vector<int64> use_times_;
|
||||
std::vector<Use> uses_;
|
||||
AllocationSequence allocation_sequence_;
|
||||
};
|
||||
|
||||
@ -769,10 +779,18 @@ struct RequiredMemoryAssignment {
int64 time;
absl::optional<HeapSimulator::Chunk> chunk;

bool equals_ignoring_time(const RequiredMemoryAssignment& other) const {
return memory_space == other.memory_space && chunk == other.chunk;
}

bool operator==(const RequiredMemoryAssignment& other) const {
return memory_space == other.memory_space && time == other.time &&
chunk == other.chunk;
}

bool operator!=(const RequiredMemoryAssignment& other) const {
return !(*this == other);
}
};
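// Illustrative sketch, not part of the upstream change: two required
// assignments that differ only in time compare equal under
// equals_ignoring_time() but not under operator==. Assumes the members are
// declared in the order (memory_space, time, chunk); `chunk` is a
// hypothetical HeapSimulator::Chunk value.
//
//   RequiredMemoryAssignment a{MemorySpace::kAlternate, /*time=*/3, chunk};
//   RequiredMemoryAssignment b{MemorySpace::kAlternate, /*time=*/7, chunk};
//   CHECK(a.equals_ignoring_time(b));
//   CHECK(a != b);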
|
||||
|
||||
// A struct representing an asynchronous copy with its logical start and end
|
||||
@ -880,7 +898,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
bool allow_no_copy_alternate_mem_allocation;
|
||||
absl::optional<int64> earliest_prefetch_time;
|
||||
absl::optional<int64> preferred_offset;
|
||||
HloUse use;
|
||||
const MemorySpaceAssignment::AllocationValue::Use* use;
|
||||
MemorySpaceAssignment::AllocationValue* allocation_value;
|
||||
};
|
||||
|
||||
@ -890,10 +908,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
|
||||
const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);
|
||||
|
||||
// Returns the required assignment at a particular time, if available.
|
||||
absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
|
||||
const HloValue* buffer, int64 time) const;
|
||||
|
||||
// Returns true if this buffer is allowed to be placed in the alternate
|
||||
// memory.
|
||||
bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
|
||||
@ -914,6 +928,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
bool AllocateColocatedIntervals(
|
||||
const std::vector<const BufferInterval*>& colocated_intervals);
|
||||
|
||||
// Go through all the uses in the AllocationValues and find the aliasing
|
||||
// positions.
|
||||
void FindAliases(std::vector<AllocationValue>* allocation_values) const;
|
||||
|
||||
// Finds an allocation for an allocation request for a segment (see the
|
||||
// documentation for AllocationRequest above how a segment is defined).
|
||||
//
|
||||
@ -950,12 +968,14 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
const AllocationRequest& request, absl::optional<int64> preferred_offset,
|
||||
BufferInterval* alternate_mem_interval) const;
|
||||
|
||||
// At the end of an allocation with a sequential call (while, conditional, and
|
||||
// call), this function adds the necessary aliased assignments within the
|
||||
// called computations.
|
||||
void AddAliasedRequiredAssignmentsForSequentialCall(
|
||||
const HloUse& use,
|
||||
const MemorySpaceAssignment::Allocation* aliased_allocation);
|
||||
// Returns the required assignment at a particular time, if available.
|
||||
absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
|
||||
const HloValue* buffer, int64 time) const;
|
||||
|
||||
// Searches for aliases in the use for a required assignment, and returns it
|
||||
// if found.
|
||||
absl::optional<RequiredMemoryAssignment> AliasedRequiredAssignmentForUse(
|
||||
const AllocationValue::Use& use) const;
|
||||
|
||||
// Propagates aliased required assignment for a given position.
|
||||
void AddAliasedRequiredAssignment(
|
||||
|
@ -1635,7 +1635,8 @@ TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) {
|
||||
%constant.5 = s32[1]{0:T(128)} constant({1})
|
||||
%prev.4 = s32[6]{0:T(128)} parameter(0)
|
||||
%rng.8 = s32[5]{0:T(128)} rng(s32[]{:T(128)} %constant.6, s32[]{:T(128)} %constant.7), distribution=rng_uniform
|
||||
ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %constant.5, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation
|
||||
%neg = s32[1]{0:T(128)} negate(s32[1]{0:T(128)} %constant.5)
|
||||
ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %neg, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation
|
||||
}
|
||||
|
||||
%WhileWithPrngScalarResult.11 (prev.12: s32[6]) -> pred[] {
|
||||
@ -1665,6 +1666,62 @@ TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) {
|
||||
kDefaultMemorySpace);
|
||||
}
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest, WhileInPlaceBuffer) {
|
||||
// Ensure that a dynamic update slice within a while loop is able to get an
|
||||
// alternate memory allocation.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module, is_scheduled=true
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[2,3] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast = f32[2,1] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
|
||||
}
|
||||
|
||||
%WhileBody (body_param: (f32[2,3], f32[2,3], f32[])) -> (f32[2,3], f32[2,3], f32[]) {
|
||||
%body_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
|
||||
%get-tuple-element.1 = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=2
|
||||
%get-tuple-element.2 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=0
|
||||
%get-tuple-element.3 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=1
|
||||
%fusion = f32[2,3]{1,0} fusion(get-tuple-element.3), kind=kLoop, calls=fused_computation
|
||||
%multiply = f32[2,3]{1,0} multiply(f32[2,3]{1,0} %get-tuple-element.2, f32[2,3]{1,0} %fusion)
|
||||
ROOT %tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} %multiply, f32[2,3]{1,0} %fusion, f32[] %get-tuple-element.1)
|
||||
}
|
||||
|
||||
%WhileCond (cond_param: (f32[2,3], f32[2,3], f32[])) -> pred[] {
|
||||
%cond_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
|
||||
%get-tuple-element = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %cond_param), index=2
|
||||
%constant = f32[] constant(50)
|
||||
ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
|
||||
}
|
||||
|
||||
ENTRY %Entry (param_data: f32[2,3], param_iter: f32[], p2: f32[2,3]) -> f32[2,3] {
|
||||
%param_iter = f32[] parameter(1)
|
||||
%param_data = f32[2,3]{1,0} parameter(0)
|
||||
%p2 = f32[2,3]{1,0} parameter(2)
|
||||
%copy1 = f32[2,3]{1,0} copy(param_data)
|
||||
%copy2 = f32[2,3]{1,0} copy(p2)
|
||||
%tuple.1 = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} copy1, f32[2,3]{1,0} copy2, f32[] %param_iter)
|
||||
%while = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) while((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
|
||||
%get-tuple-element.4 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %while), index=0
|
||||
ROOT %copy3 = f32[2,3]{1,0} copy(get-tuple-element.4)
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
AssignMemorySpace(module.get());
|
||||
const HloInstruction* while_op =
|
||||
module->entry_computation()->GetInstructionWithName("while");
|
||||
if (GetParam()) {
|
||||
EXPECT_EQ(
|
||||
ShapeUtil::GetSubshape(while_op->shape(), {1}).layout().memory_space(),
|
||||
kAlternateMemorySpace);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
|
||||
// Having control_predecessors on an HLO was preventing us from DCEing an op
|
||||
// that doesn't have any users (tuple.1). The scheduler assumes the graph is
|
||||
|
@ -413,6 +413,15 @@ absl::optional<HloInstruction*> ExchangeHalo(
|
||||
std::vector<HloInstruction*> concat_pieces;
|
||||
|
||||
int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
|
||||
int64 max_right_halo_size =
|
||||
right_halo_size_function.MaxInRange(0, shard_count - 1);
|
||||
if (max_left_halo_size + max_right_halo_size + input_shard_size >=
|
||||
input_shard_size * shard_count &&
|
||||
(max_left_halo_size > input_shard_size ||
|
||||
max_right_halo_size > input_shard_size)) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
// Left halo.
|
||||
for (int64 i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1; i >= 0;
|
||||
--i) {
|
||||
std::vector<std::pair<int64, int64>> source_target_pairs;
|
||||
@ -447,8 +456,6 @@ absl::optional<HloInstruction*> ExchangeHalo(
|
||||
concat_pieces.push_back(hlo);
|
||||
|
||||
// Right halo.
|
||||
int64 max_right_halo_size =
|
||||
right_halo_size_function.MaxInRange(0, shard_count - 1);
|
||||
for (int64 i = 0; i < CeilOfRatio(max_right_halo_size, input_shard_size);
|
||||
++i) {
|
||||
std::vector<std::pair<int64, int64>> source_target_pairs;
|
||||
|
@ -33,6 +33,36 @@ namespace xla {
TupleSimplifier::TupleSimplifier(bool exclude_entry_computation)
: exclude_entry_computation_(exclude_entry_computation) {}

StatusOr<bool> TupleSimplifier::RemoveWholeTuple(HloInstruction* tuple) {
bool changed = false;
HloInstruction* top_tuple = nullptr;
bool can_simplify = true;
for (int64 operand_number = 0; operand_number < tuple->operand_count();
++operand_number) {
HloInstruction* operand = tuple->mutable_operand(operand_number);
if (operand->opcode() != HloOpcode::kGetTupleElement ||
operand->tuple_index() != operand_number) {
can_simplify = false;
break;
}
if (top_tuple == nullptr) {
top_tuple = operand->mutable_operand(0);
if (!ShapeUtil::Compatible(top_tuple->shape(), tuple->shape())) {
can_simplify = false;
break;
}
} else if (top_tuple != operand->operand(0)) {
can_simplify = false;
break;
}
}
if (can_simplify && top_tuple != nullptr) {
changed = true;
TF_RETURN_IF_ERROR(tuple->parent()->ReplaceInstruction(tuple, top_tuple));
}
return changed;
}
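// Illustrative sketch, not part of the upstream change: the pattern this
// helper collapses, written as HLO pseudo-text. `%t` is a hypothetical
// tuple-shaped value.
//
//   %gte.0 = get-tuple-element(%t), index=0
//   %gte.1 = get-tuple-element(%t), index=1
//   %tuple = tuple(%gte.0, %gte.1)   // replaced by %t when shapes match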
|
||||
|
||||
StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
|
||||
// Initially add all GTE and Tuple instructions to the worklist.
|
||||
bool changed = false;
|
||||
@ -43,46 +73,7 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
|
||||
}
|
||||
for (auto* instruction : computation->MakeInstructionPostOrder()) {
|
||||
if (instruction->opcode() == HloOpcode::kTuple) {
|
||||
// Collapse the following structure into just 'Tuple-shaped Op':
|
||||
//
|
||||
// Tuple-shaped Op
|
||||
// |
|
||||
// +-----+-----+
|
||||
// | | |
|
||||
// GTE GTE GTE
|
||||
// | | |
|
||||
// +-----+-----+
|
||||
// |
|
||||
// Tuple
|
||||
//
|
||||
HloInstruction* top_tuple = nullptr;
|
||||
bool can_simplify = true;
|
||||
for (int64 operand_number = 0;
|
||||
operand_number < instruction->operand_count(); ++operand_number) {
|
||||
HloInstruction* operand =
|
||||
instruction->mutable_operand(operand_number);
|
||||
if (operand->opcode() != HloOpcode::kGetTupleElement ||
|
||||
operand->tuple_index() != operand_number) {
|
||||
can_simplify = false;
|
||||
break;
|
||||
}
|
||||
if (top_tuple == nullptr) {
|
||||
top_tuple = operand->mutable_operand(0);
|
||||
if (!ShapeUtil::Compatible(top_tuple->shape(),
|
||||
instruction->shape())) {
|
||||
can_simplify = false;
|
||||
break;
|
||||
}
|
||||
} else if (top_tuple != operand->operand(0)) {
|
||||
can_simplify = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (can_simplify && top_tuple != nullptr) {
|
||||
changed = true;
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation->ReplaceInstruction(instruction, top_tuple));
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(changed, RemoveWholeTuple(instruction));
|
||||
} else {
|
||||
auto ancestor = instruction->LatestNonGteAncestorAndIndex();
|
||||
if (ancestor.first == instruction) {
|
||||
@ -102,6 +93,11 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
|
||||
// GTE
|
||||
// |
|
||||
// GTE
|
||||
//
|
||||
// Note that this deletes the Tuple instruction altogether. In addition,
|
||||
// if only a subset of tuple's elements are used, this transform
|
||||
// optimizes them one at a time, and after the last use is optimized,
|
||||
// the Tuple will also be deleted.
|
||||
if (ShapeUtil::Compatible(ancestor.first->shape(),
|
||||
instruction->shape())) {
|
||||
changed = true;
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
@ -41,6 +42,20 @@ class TupleSimplifier : public HloModulePass {
|
||||
// apart from the module's entry computation. This is used by Graphcore's
|
||||
// backend.
|
||||
bool exclude_entry_computation_;
|
||||
|
||||
// Collapse the following structure into just 'Tuple-shaped Op':
|
||||
//
|
||||
// Tuple-shaped Op
|
||||
// |
|
||||
// +-----+-----+
|
||||
// | | |
|
||||
// GTE GTE GTE
|
||||
// | | |
|
||||
// +-----+-----+
|
||||
// |
|
||||
// Tuple
|
||||
//
|
||||
StatusOr<bool> RemoveWholeTuple(HloInstruction* tuple);
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -216,11 +216,8 @@ TEST_F(BufferDonationTest, SimpleWhileTupleTest) {
|
||||
HloInstruction::CreateGetTupleElement(f32v1_, while0, 1));
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
// Input output aliasing is only supported on TPU.
|
||||
#if defined(XLA_TEST_BACKEND_TPU)
|
||||
TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
|
||||
TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));
|
||||
#endif
|
||||
|
||||
auto arg = LiteralUtil::MakeTupleFromSlices(
|
||||
{LiteralUtil::CreateR0<int>(0), LiteralUtil::CreateR1<float>({1.1f})});
|
||||
|
@ -1361,6 +1361,7 @@ cc_library(
|
||||
hdrs = ["replicate_per_replica_nodes.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -666,7 +666,9 @@ Status DirectSession::RunInternal(
|
||||
|
||||
std::unique_ptr<ProfilerSession> profiler_session;
|
||||
if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
|
||||
profiler_session = ProfilerSession::Create();
|
||||
ProfileOptions options = ProfilerSession::DefaultOptions();
|
||||
options.set_host_tracer_level(0);
|
||||
profiler_session = ProfilerSession::Create(options);
|
||||
}
|
||||
|
||||
// Register this step with session's cancellation manager, so that
|
||||
|
@ -847,7 +847,7 @@ Status VerifyVirtualDeviceSettings(
|
||||
" #valid GPUs: ", valid_platform_gpu_ids.size(),
|
||||
" virtual_devices.size(): ", virtual_devices.size());
|
||||
}
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
// Check memory_limt_mb and priority sizes match if priority is non-empty.
|
||||
bool priority_exists = !virtual_devices.Get(0).priority().empty();
|
||||
for (int i = 0; i < virtual_devices.size(); ++i) {
|
||||
@ -893,15 +893,6 @@ Status VerifyVirtualDeviceSettings(
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
for (int i = 0; i < virtual_devices.size(); ++i) {
|
||||
if (!virtual_devices.Get(i).priority().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Priority is supported only on Nvidia GPUs."
|
||||
" However, priority is set for virtual device ",
|
||||
i, ", which corresponds to a non Nvidia GPU");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return Status::OK();
|
||||
@ -1185,6 +1176,18 @@ Status BaseGPUDeviceFactory::CreateDevices(
|
||||
platform_gpu_id.value(),
|
||||
" failed. Status: ", hipGetErrorString(err));
|
||||
}
|
||||
int priority_low, priority_high;
|
||||
hipDeviceGetStreamPriorityRange(&priority_low, &priority_high);
|
||||
if (err != hipSuccess) {
|
||||
return errors::Internal(
|
||||
"hipDeviceGetStreamPriorityRange() on GPU:", original_device,
|
||||
" failed. Status: ", hipGetErrorString(err));
|
||||
}
|
||||
VLOG(1) << "HIP stream priority range on GPU(" << original_device
|
||||
<< "): " << priority_high << "," << priority_low;
|
||||
supported_priority_ranges.insert(
|
||||
std::make_pair(platform_gpu_id.value(),
|
||||
std::make_pair(priority_low, priority_high)));
|
||||
#endif
|
||||
}
|
||||
// Reset to the original device.
|
||||
|
@ -229,52 +229,89 @@ TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimitAndNoPriority) {
|
||||
|
||||
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithInvalidPriority) {
|
||||
{
|
||||
// Priority outside the range (-1, 0).
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
// Priority outside the range (0, 2) for AMD GPUs
|
||||
SessionOptions opts =
|
||||
MakeSessionOptions("0", 0, 1, {{123, 456}}, {{-1, 2}});
|
||||
#else
|
||||
// Priority outside the range (-1, 0) for NVidia GPUs
|
||||
SessionOptions opts =
|
||||
MakeSessionOptions("0", 0, 1, {{123, 456}}, {{-2, 0}});
|
||||
#endif
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
|
||||
opts, kDeviceNamePrefix, &devices);
|
||||
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
ExpectErrorMessageSubstr(
|
||||
status,
|
||||
"Priority -1 is outside the range of supported priorities [0,2] for"
|
||||
" virtual device 0 on GPU# 0");
|
||||
#else
|
||||
ExpectErrorMessageSubstr(
|
||||
status,
|
||||
"Priority -2 is outside the range of supported priorities [-1,0] for"
|
||||
" virtual device 0 on GPU# 0");
|
||||
#endif
|
||||
}
|
||||
{
|
||||
// Priority outside the range (-1, 0).
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
// Priority outside the range (0, 2) for AMD GPUs
|
||||
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{0, 3}});
|
||||
#else
|
||||
// Priority outside the range (-1, 0) for NVidia GPUs
|
||||
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{0, 1}});
|
||||
#endif
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
|
||||
opts, kDeviceNamePrefix, &devices);
|
||||
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
ExpectErrorMessageSubstr(
|
||||
status,
|
||||
"Priority 3 is outside the range of supported priorities [0,2] for"
|
||||
" virtual device 0 on GPU# 0");
|
||||
#else
|
||||
ExpectErrorMessageSubstr(
|
||||
status,
|
||||
"Priority 1 is outside the range of supported priorities [-1,0] for"
|
||||
" virtual device 0 on GPU# 0");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimitAndPriority) {
|
||||
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123}}, {{-1}});
|
||||
// 0 is a valid priority value for both AMD and NVidia GPUs
|
||||
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123}}, {{0}});
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
|
||||
opts, kDeviceNamePrefix, &devices));
|
||||
EXPECT_EQ(1, devices.size());
|
||||
EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit());
|
||||
EXPECT_EQ(-1, static_cast<BaseGPUDevice*>(devices[0].get())->priority());
|
||||
EXPECT_EQ(0, static_cast<BaseGPUDevice*>(devices[0].get())->priority());
|
||||
}
|
||||
|
||||
TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
// Valid range for priority values on AMD GPUs in (0,2)
|
||||
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{0, 1}});
|
||||
#else
|
||||
// Valid range for priority values on NVidia GPUs in (-1, 0)
|
||||
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{0, -1}});
|
||||
#endif
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
|
||||
opts, kDeviceNamePrefix, &devices));
|
||||
EXPECT_EQ(2, devices.size());
|
||||
EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit());
|
||||
EXPECT_EQ(456 << 20, devices[1]->attributes().memory_limit());
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
EXPECT_EQ(0, static_cast<BaseGPUDevice*>(devices[0].get())->priority());
|
||||
EXPECT_EQ(1, static_cast<BaseGPUDevice*>(devices[1].get())->priority());
|
||||
#else
|
||||
EXPECT_EQ(0, static_cast<BaseGPUDevice*>(devices[0].get())->priority());
|
||||
EXPECT_EQ(-1, static_cast<BaseGPUDevice*>(devices[1].get())->priority());
|
||||
#endif
|
||||
ASSERT_EQ(1, devices[0]->attributes().locality().links().link_size());
|
||||
ASSERT_EQ(1, devices[1]->attributes().locality().links().link_size());
|
||||
EXPECT_EQ(1, devices[0]->attributes().locality().links().link(0).device_id());
|
||||
@ -292,7 +329,8 @@ TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
|
||||
TEST_F(GPUDeviceTest, MultipleVirtualDevicesWithPriority) {
|
||||
{
|
||||
// Multiple virtual devices with fewer priorities.
|
||||
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{-1}});
|
||||
// 0 is a valid priority value for both AMD and NVidia GPUs
|
||||
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{0}});
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
|
||||
opts, kDeviceNamePrefix, &devices);
|
||||
@ -305,16 +343,27 @@ TEST_F(GPUDeviceTest, MultipleVirtualDevicesWithPriority) {
|
||||
}
|
||||
{
|
||||
// Multiple virtual devices with matching priority.
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
// Valid range for priority values on AMD GPUs in (0,2)
|
||||
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}, {{2, 1}});
|
||||
#else
|
||||
// Valid range for priority values on NVidia GPUs in (-1, 0)
|
||||
SessionOptions opts =
|
||||
MakeSessionOptions("0", 0, 1, {{123, 456}}, {{-1, 0}});
|
||||
#endif
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
|
||||
opts, kDeviceNamePrefix, &devices));
|
||||
EXPECT_EQ(2, devices.size());
|
||||
EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit());
|
||||
EXPECT_EQ(456 << 20, devices[1]->attributes().memory_limit());
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
EXPECT_EQ(2, static_cast<BaseGPUDevice*>(devices[0].get())->priority());
|
||||
EXPECT_EQ(1, static_cast<BaseGPUDevice*>(devices[1].get())->priority());
|
||||
#else
|
||||
EXPECT_EQ(-1, static_cast<BaseGPUDevice*>(devices[0].get())->priority());
|
||||
EXPECT_EQ(0, static_cast<BaseGPUDevice*>(devices[1].get())->priority());
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -394,10 +394,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                      kRewriteForLayoutPropagation});
    rinfo_.push_back({csinfo_.batch_matmul,
                      mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.batch_matmul_v2,
                      mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.concat, mkl_op_registry::GetMklOpName(csinfo_.concat),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});

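Replacing AlwaysRewrite with MatMulRewrite makes the BatchMatMul rewrite conditional; the predicate itself is not part of this hunk. A hypothetical sketch of such a predicate, assuming the condition is simply a check on the element types the MKL kernels support:

#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"

// Hypothetical rewrite predicate: only rewrite (Batch)MatMul nodes whose
// element type is assumed to be supported by the MKL implementation.
static bool MatMulRewriteSketch(const tensorflow::Node* n) {
  tensorflow::DataType T;
  if (!tensorflow::GetNodeAttr(n->def(), "T", &T).ok()) return false;
  return T == tensorflow::DT_FLOAT || T == tensorflow::DT_BFLOAT16;
}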
File diff suppressed because it is too large
@ -25,7 +25,7 @@ limitations under the License.

namespace tensorflow {

#ifdef _OPENMP
#if defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) {
  SessionOptions options;
  unsetenv("OMP_NUM_THREADS");
@ -46,7 +46,7 @@ TEST(MKLThreadPoolDeviceTest, TestOmpPreSets) {

  EXPECT_EQ(omp_get_max_threads(), 314);
}
#endif  // _OPENMP
#endif  // defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)

}  // namespace tensorflow

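The guard change only compiles these OpenMP tests when MKL-DNN's own threadpool is not enabled. As background, the default that the first test relies on can be reproduced in isolation with the standalone snippet below (built with -fopenmp; not part of the change). With OMP_NUM_THREADS unset it typically reports the hardware thread count, and with OMP_NUM_THREADS=314 exported beforehand it reports 314 as in TestOmpPreSets, though the exact default is implementation-defined.

#include <cstdio>
#include <omp.h>

int main() {
  // Reports the OpenMP runtime's current thread limit; the value mirrors
  // OMP_NUM_THREADS when that variable was set before the runtime started.
  std::printf("omp_get_max_threads() = %d\n", omp_get_max_threads());
  return 0;
}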
@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"

namespace tensorflow {
namespace {
@ -115,12 +116,36 @@ class ReplicateHelper {
          // This happens when the dst node runs on a host CPU and
          // captures a function with an arg node assigned to the same
          // composite device (e.g. ScanDataset).
          // For this case, we only need to add an edge connecting the arg
          // node in the outer function and the corresponding arg in the
          // inner function, since the host CPU only needs one copy of the
          // ResourceHandle.
          graph->AddEdge(src_replicated_nodes.at(0), edge->src_output(), dst,
                         edge->dst_input());
          // For this case, we insert a PackOp between replicated nodes and the
          // dst node. The dst node is responsible for unpacking the packed
          // tensor.
          // Add '/Packed' as a substring to the name of the new node, which
          // could be helpful when debugging the graph.
          NodeDefBuilder pack_builder(
              graph->NewName(absl::StrCat(edge->src()->name(), "/Packed")),
              "Pack");
          const int num_replicas = src_replicated_nodes.size();
          pack_builder.Attr("N", num_replicas);
          const DataType dtype = edge->src()->output_type(edge->src_output());
          pack_builder.Attr("T", dtype);
          std::vector<NodeDefBuilder::NodeOut> inputs;
          inputs.reserve(src_replicated_nodes.size());
          for (Node* replicated_node : src_replicated_nodes) {
            inputs.emplace_back(NodeDefBuilder::NodeOut{
                replicated_node->name(), edge->src_output(), dtype});
          }
          pack_builder.Input(inputs);
          NodeDef pack_def;
          TF_RETURN_IF_ERROR(pack_builder.Finalize(&pack_def));
          Status status;
          Node* pack_node = graph->AddNode(pack_def, &status);
          TF_RETURN_IF_ERROR(status);
          pack_node->set_assigned_device_name(dst->assigned_device_name());
          for (int i = 0; i < src_replicated_nodes.size(); ++i) {
            graph->AddEdge(src_replicated_nodes[i], edge->src_output(),
                           pack_node, i);
          }
          graph->AddEdge(pack_node, /*x=*/0, dst, edge->dst_input());
        } else {
          return errors::InvalidArgument(
              "Dst node should be assigned to an allowed device. Found an "
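The node added here is an ordinary Pack op. For the two-replica case covered by the test below, the same NodeDefBuilder pattern looks as follows in isolation; the node and input names and the DT_RESOURCE element type are assumptions used only for illustration.

#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"

// Builds a Pack NodeDef that stacks two per-replica outputs into one packed
// tensor, mirroring the code path above for num_replicas == 2.
tensorflow::Status BuildPackNodeDefSketch(tensorflow::NodeDef* pack_def) {
  using tensorflow::NodeDefBuilder;
  NodeDefBuilder builder("arg/Packed", "Pack");
  builder.Attr("N", 2);
  builder.Attr("T", tensorflow::DT_RESOURCE);
  std::vector<NodeDefBuilder::NodeOut> inputs;
  inputs.emplace_back("arg/R0", 0, tensorflow::DT_RESOURCE);
  inputs.emplace_back("arg/R1", 0, tensorflow::DT_RESOURCE);
  builder.Input(inputs);
  return builder.Finalize(pack_def);
}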
@ -258,16 +258,24 @@ TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // _Arg(TPU:0) -> Func(CPU:0) -> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 4);
    // _Arg(TPU:0), _Arg(TPU:1) -> Pack(CPU:0) -> Func(CPU:0) -> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 5);
    GraphHelper helper(graph);
    helper.CheckAssignedDevice("arg/R0", "TPU:0");
    helper.CheckAssignedDevice("arg/R1", "TPU:1");
    helper.CheckAssignedDevice("arg/Packed", "CPU:0");
    helper.CheckAssignedDevice("func", "CPU:0");
    helper.CheckAssignedDevice("ret", "CPU:0");
    const EdgeSet& in_edges = helper.GetNodeByName("func")->in_edges();
    EXPECT_EQ(in_edges.size(), 1);
    EXPECT_EQ(helper.GetNodeByName("arg/R0"), (*in_edges.begin())->src());
    const EdgeSet& packed_in_edges =
        helper.GetNodeByName("arg/Packed")->in_edges();
    EXPECT_EQ(packed_in_edges.size(), 2);
    auto it = packed_in_edges.begin();
    EXPECT_EQ(helper.GetNodeByName("arg/R0"), (*it++)->src());
    EXPECT_EQ(helper.GetNodeByName("arg/R1"), (*it)->src());
    const EdgeSet& func_in_edges = helper.GetNodeByName("func")->in_edges();
    EXPECT_EQ(func_in_edges.size(), 1);
    EXPECT_EQ(helper.GetNodeByName("arg/Packed"),
              (*func_in_edges.begin())->src());
  }
}

@ -344,6 +344,8 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
  static constexpr const char* const kDeviceRetOp = "_DeviceRetval";
  static constexpr const char* const kIntsOnDeviceAttr =
      "experimental_ints_on_device";
  static constexpr const char* const kSharedRendezvousAttr =
      "shared_rendezvous";

  static constexpr const char* const kGradientOp = "SymbolicGradient";
  static constexpr const char* const kFuncAttr = "f";

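kSharedRendezvousAttr names a function attribute, so it would be attached through the FunctionDef's generic attr map; how the runtime consumes the attribute is not shown in this diff. A hypothetical sketch:

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"

// Marks a FunctionDef so that its instantiations share a rendezvous with the
// caller (assumed semantics of the "shared_rendezvous" attribute).
void MarkSharedRendezvousSketch(tensorflow::FunctionDef* fdef) {
  tensorflow::AttrValue value;
  value.set_b(true);
  (*fdef->mutable_attr())["shared_rendezvous"] = value;
}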
@ -59,9 +59,13 @@ class InterleaveMany : public Node {
      (*input_times)[long_name()] = old_input_time;
      return;
    }
    double new_input_time =
        old_input_time +
        SelfProcessingTimeLocked() * static_cast<double>(num_inputs() - 1);
    // Here `old_input_time + SelfProcessingTimeLocked()` is the average input
    // time for the interleave node to call one of the `(num_inputs() - 1)`
    // input nodes (except the first one) to return an element. Regardless of
    // the `block_length` parameter of the interleave node, the average input
    // time for any of the `(num_inputs() - 1)` input nodes to be called is
    // computed as:
    double new_input_time = (old_input_time + SelfProcessingTimeLocked()) *
                            static_cast<double>(num_inputs() - 1);
    (*input_times)[long_name()] = new_input_time;
  }

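A quick numeric check of the two formulas, with illustrative values only: for old_input_time = 10, SelfProcessingTimeLocked() = 2 and num_inputs() = 4, the removed expression gives 10 + 2 * 3 = 16, while the new one gives (10 + 2) * 3 = 36.

#include <cstdio>

// Compares the old and new input-time estimates for hypothetical values.
int main() {
  const double old_input_time = 10.0;       // illustrative
  const double self_processing_time = 2.0;  // illustrative
  const int num_inputs = 4;                 // interleave node plus 3 inputs
  const double old_formula =
      old_input_time + self_processing_time * (num_inputs - 1);
  const double new_formula =
      (old_input_time + self_processing_time) * (num_inputs - 1);
  std::printf("old: %.1f, new: %.1f\n", old_formula, new_formula);  // 16.0, 36.0
  return 0;
}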
Some files were not shown because too many files have changed in this diff