Merge branch 'master' into allow_build_at_runtime

commit 68e9393d9f

.bazelrc (6 lines changed)

@@ -320,10 +320,8 @@ build:xla --define=with_xla_support=true
# Options when using remote execution
# WARNING: THESE OPTIONS WONT WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
build:rbe --auth_enabled=true
build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools
build:rbe --google_default_credentials
build:rbe --bes_backend=buildeventservice.googleapis.com
build:rbe --bes_best_effort=false
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
build:rbe --bes_timeout=600s
build:rbe --define=EXECUTOR=remote
@@ -336,7 +334,7 @@ build:rbe --spawn_strategy=remote,worker,standalone,local
test:rbe --test_env=USER=anon
# Attempt to minimize the amount of data transfer between bazel and the remote
# workers:
build:rbe --experimental_inmemory_jdeps_files --experimental_inmemory_dotd_files --experimental_remote_download_outputs=toplevel
build:rbe --remote_download_toplevel

build:rbe_linux --config=rbe
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
WORKSPACE (25 lines changed)

@@ -113,3 +113,28 @@ http_archive(
        "https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
    ],
)

# Required for dependency @com_github_grpc_grpc

load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")

grpc_deps()

load(
    "@build_bazel_rules_apple//apple:repositories.bzl",
    "apple_rules_dependencies",
)

apple_rules_dependencies()

load(
    "@build_bazel_apple_support//lib:repositories.bzl",
    "apple_support_dependencies",
)

apple_support_dependencies()

load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")

bazel_version_repository(name = "bazel_version")
@@ -505,13 +505,15 @@ selects.config_setting_group(
package_group(
    name = "internal",
    packages = [
        # To pass open source testing in the pip Kokoros.
        "//bazel_pip/tensorflow/...",
        "//learning/brain/swift/x10/...",
        "//perftools/accelerators/xprof/api/...",
        "//third_party/py/autograph/...",
        "//third_party/swift/tensorflow/x10/...",
        "//tensorflow/...",
        "//tensorflow_estimator/python/estimator/...",
        "//tensorflow_models/official/...",
        "//third_party/py/autograph/...",
        "//third_party/swift/tensorflow/x10/...",
    ],
)

@@ -545,8 +547,8 @@ cc_library(
    name = "grpc",
    visibility = ["//visibility:public"],
    deps = select({
        ":linux_s390x": ["@grpc//:grpc_unsecure"],
        "//conditions:default": ["@grpc"],
        ":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"],
        "//conditions:default": ["@com_github_grpc_grpc//:grpc"],
    }),
)

@@ -554,8 +556,8 @@ cc_library(
    name = "grpc++",
    visibility = ["//visibility:public"],
    deps = select({
        ":linux_s390x": ["@grpc//:grpc++_unsecure"],
        "//conditions:default": ["@grpc//:grpc++"],
        ":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"],
        "//conditions:default": ["@com_github_grpc_grpc//:grpc++"],
    }),
)
@@ -1883,6 +1883,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
      "EmptyTensorList",
      "ExtractImagePatches",
      "Igamma",
      "IgammaGradA",
      "RandomGammaGrad",
      "Igammac",
      "FFT",
      "FFT2D",
@@ -1909,7 +1911,6 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
      "LinSpace",
      "ListDiff",
      "LogMatrixDeterminant",
      "LowerBound",
      "MatMul",
      "MatrixBandPart",
      "MatrixDiag",
@@ -2036,7 +2037,6 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
      "TensorScatterUpdate",
      "TridiagonalSolve",
      "TruncatedNormal",
      "UpperBound",
      "UnsortedSegmentMax",
      "UnsortedSegmentMin",
      "UnsortedSegmentProd",
@@ -20,15 +20,17 @@ limitations under the License.

namespace tensorflow {

bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
                                       const NodeDef& node_def) const {
  return CanCreateXlaKernel(node_def);
bool XlaKernelCreator::CanCreateKernel(
    const FunctionLibraryRuntime& flr,
    const std::shared_ptr<const NodeProperties>& props) const {
  return CanCreateXlaKernel(props->node_def);
}

Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
                                      const NodeDef& node_def,
                                      std::unique_ptr<OpKernel>* kernel) const {
  return CreateXlaKernel(flr, node_def, kernel);
Status XlaKernelCreator::CreateKernel(
    FunctionLibraryRuntime* flr,
    const std::shared_ptr<const NodeProperties>& props,
    std::unique_ptr<OpKernel>* kernel) const {
  return CreateXlaKernel(flr, props->node_def, kernel);
}

namespace {
@@ -29,11 +29,13 @@ class XlaKernelCreator : public CustomKernelCreator {
  // Given a NodeDef 'node_def' and the function library runtime 'flr', returns
  // true if 'node_def' is a call to a compilable function defined in 'flr',
  // with the kXlaCompileAttr set.
  bool CanCreateKernel(const FunctionLibraryRuntime& flr,
                       const NodeDef& node_def) const override;
  bool CanCreateKernel(
      const FunctionLibraryRuntime& flr,
      const std::shared_ptr<const NodeProperties>& props) const override;

  // Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
  Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
  Status CreateKernel(FunctionLibraryRuntime* flr,
                      const std::shared_ptr<const NodeProperties>& props,
                      std::unique_ptr<OpKernel>* kernel) const override;
};
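For context, a minimal sketch (not part of this diff) of how a caller can build the NodeProperties argument that the new interface expects. The nullptr OpDef and empty type vectors mirror the ToNodeProperties() test helper below; real callers such as CreateXlaKernel pass the function signature and the actual argument/return types. WrapNodeDef is a hypothetical helper name:

#include <memory>
#include <utility>
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/types.h"

// Illustration only: wraps a bare NodeDef in the NodeProperties container
// that CanCreateKernel/CreateKernel now take.
std::shared_ptr<const tensorflow::NodeProperties> WrapNodeDef(
    tensorflow::NodeDef node_def) {
  tensorflow::DataTypeVector in_types, out_types;  // empty in this sketch
  return std::make_shared<tensorflow::NodeProperties>(
      /*op_def=*/nullptr, std::move(node_def), in_types, out_types);
}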
@@ -30,10 +30,12 @@ limitations under the License.

namespace tensorflow {

NodeDef ToNodeDef(const string& text) {
std::shared_ptr<NodeProperties> ToNodeProperties(const string& text) {
  NodeDef node_def;
  DataTypeVector dummy;
  EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
  return node_def;
  return std::make_shared<NodeProperties>(nullptr, std::move(node_def), dummy,
                                          dummy);
}

// Create a FunctionDef that takes one resource and one regular param
@@ -98,11 +100,11 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
  (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
  Init({fdef});
  XlaKernelCreator xla_kernel_creator;
  NodeDef callsite =
      ToNodeDef(R"pb(
  auto callsite =
      ToNodeProperties(R"pb(
        name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
      )pb");
  (*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
  (*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true);

  // Note: need to set attribute on the created node.
  Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
@@ -127,13 +129,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
  Init({fdef});
  XlaKernelCreator xla_kernel_creator;

  Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
                                                      name: 'XTimesY'
                                                      op: 'XTimesY'
                                                      input: 'a'
                                                      input: 'b'
                                                    )proto"),
                                                  &kernel_);
  Status status =
      xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
                                        name: 'XTimesY'
                                        op: 'XTimesY'
                                        input: 'a'
                                        input: 'b'
                                      )proto"),
                                      &kernel_);
  EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}

@@ -143,13 +146,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
  Init({fdef});
  XlaKernelCreator xla_kernel_creator;

  Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
                                                      name: 'XTimesY'
                                                      op: 'XTimesY'
                                                      input: 'a'
                                                      input: 'b'
                                                    )proto"),
                                                  &kernel_);
  Status status =
      xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
                                        name: 'XTimesY'
                                        op: 'XTimesY'
                                        input: 'a'
                                        input: 'b'
                                      )proto"),
                                      &kernel_);
  EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}
@@ -218,12 +218,13 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
  TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
  Device* dev = flr->device();
  Status s;
  OpKernelConstruction construction(
      DeviceType(dev->device_type()), dev,
      dev->GetAllocator(AllocatorAttributes()), &node_def,
      &fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
      input_memory_types, fbody->ret_types, output_memory_types,
      flr->graph_def_version(), &s);
  auto props = std::make_shared<NodeProperties>(
      &fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
  OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
                                    dev->GetAllocator(AllocatorAttributes()),
                                    flr, dev->resource_manager(), props,
                                    input_memory_types, output_memory_types,
                                    flr->graph_def_version(), &s);

  *kernel = absl::make_unique<XlaLocalLaunchBase>(
      &construction, constant_arg_indices, resource_arg_indices, function,
@@ -208,6 +208,7 @@ cc_library(
        "ir/tfl_ops.h.inc",
        "ir/tfl_ops_interface.cc.inc",
        "ir/tfl_ops_interface.h.inc",
        "runtime_verifiers.inc",
        "utils/attribute_utils.cc",
    ],
    hdrs = [
@@ -303,12 +304,14 @@ cc_library(
        "transforms/optimize_functional_ops.cc",
        "transforms/prepare_composite_functions_tf.cc",
        "transforms/prepare_tf.cc",
        "transforms/runtime_type_verify.cc",
        "transforms/split_merged_operands.cc",
        "transforms/trim_functions_tf.cc",
        "transforms/unroll_batch_matmul.cc",
        "transforms/while_loop_outline.cc",
    ],
    hdrs = [
        "ir/tfl_ops_interface.h.inc",
        "transforms/dilated_conv.h",
        "transforms/passes.h",
        "transforms/unroll_batch_matmul.h",
@@ -461,9 +464,9 @@ cc_library(
)

tf_native_cc_binary(
    name = "operator-converter-gen",
    name = "converter-gen",
    srcs = [
        "operator_converter_gen.cc",
        "converter_gen.cc",
    ],
    deps = [
        "@llvm-project//llvm:support",
@@ -473,14 +476,18 @@ tf_native_cc_binary(
)

gentbl(
    name = "operator_converter_inc",
    name = "converter_inc",
    tbl_outs = [
        (
            "",  # This driver has no options.
            "--gen-operator-converters",
            "operator_converters.inc",
        ),
        (
            "--gen-runtime-verifiers",
            "runtime_verifiers.inc",
        ),
    ],
    tblgen = ":operator-converter-gen",
    tblgen = ":converter-gen",
    td_file = "ir/tfl_ops.td",
    td_srcs = [
        ":tensorflow_lite_ops_td_files",
@@ -650,6 +657,7 @@ tf_cc_binary(
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
    ],
)
@@ -28,6 +28,9 @@ limitations under the License.
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Attribute.h"  // TF:llvm-project
#include "mlir/TableGen/Format.h"  // TF:llvm-project
#include "mlir/TableGen/Operator.h"  // TF:llvm-project
#include "mlir/TableGen/Predicate.h"  // TF:llvm-project

using llvm::DefInit;
using llvm::dyn_cast;
@@ -41,6 +44,19 @@ using llvm::SmallVector;
using llvm::StringInit;
using llvm::StringRef;

enum ActionType {
  OpConv,
  RuntimeVerify,
};

// NOLINTNEXTLINE
llvm::cl::opt<ActionType> action(
    llvm::cl::desc("Action to perform:"),
    llvm::cl::values(clEnumValN(OpConv, "gen-operator-converters",
                                "Generate operator converters"),
                     clEnumValN(RuntimeVerify, "gen-runtime-verifiers",
                                "Generate TFLite runtime verifiers")));

// Returns the associated option name for the given op definition.
static inline std::string GetOperatorOptionName(const Record &def) {
  assert(def.getName().startswith("TFL_") && "unexpected op prefix");
@@ -342,8 +358,101 @@ static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) {
  return false;
}

static void GenOperandResultVerifier(raw_ostream &os,
                                     llvm::ArrayRef<llvm::Init *> values,
                                     StringRef valueKind) {
  mlir::tblgen::FmtContext fctx;

  bool first = true;
  for (auto static_value : llvm::enumerate(values)) {
    auto *definit = llvm::cast<llvm::DefInit>(static_value.value());
    auto *val = definit->getDef()->getValue("tflRuntimeTypePredicate");
    if (!val) continue;

    // Create code block on first type to verify.
    if (first) {
      os << "  {\n";
      os << "    unsigned index = " << static_value.index() << ";\n";
      first = false;
    }

    mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
    auto desc =
        definit->getDef()->getValueAsString("tflRuntimeTypeDescription");

    // Emit a loop to check all the dynamic values in the pack.
    os << formatv("    for (Value v : top.getODS{0}{1}s({2})) {{\n",
                  // Capitalize the first letter to match the function name
                  valueKind.substr(0, 1).upper(), valueKind.substr(1),
                  static_value.index());

    os << "      (void)v;\n"
       << "      if (!("
       << tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
       << formatv(
              "        return op->emitOpError(\"{0} #\") << index "
              "<< \" must be {1}, but got \" << v.getType();\n",
              valueKind, desc)
       << "      }\n"  // if
       << "      ++index;\n"
       << "    }\n";  // for
  }

  // Emit closing brace if needed.
  if (!first) os << "  }\n";
}

// NOLINTNEXTLINE
static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
  emitSourceFileHeader("MLIR TFLite Runtime Verifiers", os);

  // Retrieve all the definitions derived from TFL_Op and sort by record name.
  std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
  llvm::sort(defs, LessRecord());

  // Iterate through all the ops defined.
  for (const auto *def : defs) {
    mlir::tblgen::Operator op(*def);
    if (!op.getTrait("TflRuntimeVerifyOpInterface::Trait")) continue;

    mlir::tblgen::FmtContext verify_ctx;
    os << "::mlir::LogicalResult " << op.getCppClassName()
       << "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n";
    os << "  auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
    verify_ctx.withOp("top");

    for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
      auto &value = op.getOperand(i);
      // Skip the first variadic operand and everything after it for now;
      // otherwise the getOperand index used below doesn't match.
      if (value.isVariadic()) break;
      if (!value.name.empty())
        verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
    }
    for (int i = 0, e = op.getNumResults(); i < e; ++i) {
      auto &value = op.getResult(i);
      // Skip the first variadic result and everything after it for now;
      // otherwise the getResult index used below doesn't match.
      if (value.isVariadic()) break;
      if (!value.name.empty())
        verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
    }
    GenOperandResultVerifier(os, def->getValueAsDag("arguments")->getArgs(),
                             "operand");
    GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
                             "result");
    os << "  return mlir::success();\n}\n";
  }

  return false;
}

int main(int argc, char **argv) {
  llvm::InitLLVM y(argc, argv);
  llvm::cl::ParseCommandLineOptions(argc, argv);
  return TableGenMain(argv[0], &OperatorWritersMain);
  if (action == ActionType::OpConv)
    return TableGenMain(argv[0], &OperatorWritersMain);
  return TableGenMain(argv[0], &RuntimeVerifierWriterMain);
}
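To make the emitted strings concrete: the verifier this backend generates for a hypothetical op (the name MyOp and the isF32 predicate are illustrative, assembled from the format strings in GenOperandResultVerifier and RuntimeVerifierWriterMain above) would look roughly like:

::mlir::LogicalResult MyOp::VerifyTflRuntimeTypes(::mlir::Operation *op) {
  auto top = cast<MyOp>(op); (void)top;
  {
    unsigned index = 0;
    // One loop per checked operand pack, emitted by GenOperandResultVerifier.
    for (Value v : top.getODSOperands(0)) {
      (void)v;
      if (!(v.getType().isF32())) {  // illustrative runtime type predicate
        return op->emitOpError("operand #")
               << index << " must be 32-bit float, but got " << v.getType();
      }
      ++index;
    }
  }
  return mlir::success();
}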
@@ -71,4 +71,23 @@ def TFL_SparseOp : OpInterface<"SparseOpInterface"> {
  ];
}

//===----------------------------------------------------------------------===//
// TFL runtime type verification of operand/result types.

def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
  let description = [{
    Interface for TFLite runtime op verification.

    This verifies that the converted TFLite ops have operand/result types
    supported by the TFLite runtime.
  }];

  let methods = [
    StaticInterfaceMethod<
      [{Returns whether the op's operands/results are supported by runtime.}],
      "LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op)
    >,
  ];
}

#endif // TFL_OP_INTERFACES
@@ -1872,6 +1872,7 @@ LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
#include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc"

Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
                                                      Attribute value,

(One file diff was suppressed because it is too large.)
@@ -282,6 +282,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
  if (pass_config.legalize_tf_while) {
    pm.addPass(mlir::TFL::CreateWhileOutlinePass());
  }
  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());

  auto status = ConvertTFExecutorToTFLOrFlatbuffer(
      module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
@@ -150,7 +150,8 @@ struct QuantizationPattern : public RewritePattern {

  explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
                               float error_tolerance, bool single_layer_verify)
      : RewritePattern(DQ::getOperationName(), 1, context),
      // Set the score to a large number so it is always preferred.
      : RewritePattern(DQ::getOperationName(), 300, context),
        enable_verify(enable_verify),
        error_tolerance(error_tolerance),
        single_layer_verify(single_layer_verify) {}
@@ -1373,3 +1373,37 @@ func @reciprocal_i64(%arg0: tensor<8xi64>) -> tensor<8xi64> {
  // CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi64>, tensor<8xi64>) -> tensor<8xi64>
  // CHECK: return
}

func @random_uniform() -> tensor<2x5xf32> {
  %0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
  %1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32>
  return %1 : tensor<2x5xf32>

  // CHECK-LABEL: random_uniform
  // CHECK: %[[CST:.*]] = constant dense
  // CHECK: return %[[CST:.*]] : tensor<2x5xf32>
}

func @random_uniform_no_fold(%arg0: tensor<2xi32>) -> tensor<2x5xf32> {
  %1 = "tf.RandomUniform"(%arg0) { seed = 0, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32>
  return %1 : tensor<2x5xf32>

  // CHECK-LABEL: random_uniform_no_fold
  // CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
}

func @random_uniform_no_fold2(%arg0: tensor<2xi32>) -> tensor<*xf32> {
  %1 = "tf.RandomUniform"(%arg0) { seed = 1, seed2 = 2} : (tensor<2xi32>) -> tensor<*xf32>
  return %1 : tensor<*xf32>

  // CHECK-LABEL: random_uniform_no_fold2
  // CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
}

func @random_uniform_no_fold3(%arg0: tensor<2xi32>) -> tensor<*xf64> {
  %1 = "tf.RandomUniform"(%arg0) { seed = 1, seed2 = 2} : (tensor<2xi32>) -> tensor<*xf64>
  return %1 : tensor<*xf64>

  // CHECK-LABEL: random_uniform_no_fold3
  // CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
}
@@ -1,4 +1,4 @@
// RUN: tf-opt -split-input-file -verify-diagnostics %s | FileCheck %s --dump-input-on-failure
// RUN: tf-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s --dump-input-on-failure

// Unary math ops
// -----
@@ -2,39 +2,44 @@
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify | FileCheck --check-prefix=DEBUG %s

// CHECK-LABEL: QuantizeFloatConst
func @QuantizeFloatConst() -> tensor<f32> {
func @QuantizeFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
  %0 = constant dense<-0.1> : tensor<2x2xf32>
  %1 = "tfl.quantize"(%0) {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
  %2 = "tfl.dequantize"(%1) : (tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<f32>
  return %2 : tensor<f32>
  %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
  return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>

  // CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<0> : tensor<2x2xi8>}
  // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
  // CHECK: return %[[dq]] : tensor<f32>
  // CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<0> : tensor<2x2xi8>}
  // CHECK: return %[[cst]]
}

// CHECK-LABEL: QuantizeDenseFloatConst
func @QuantizeDenseFloatConst() -> tensor<2x2xf32> {
func @QuantizeDenseFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
  %0 = constant dense<[[-0.1, 1.0], [1.0, 3.0]]> : tensor<2x2xf32>
  %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
  %2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
  return %2 : tensor<2x2xf32>
  return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>

  // CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<{{\[\[}}0, -1], {{\[}}-1, -1]]> : tensor<2x2xi8>}
  // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
  // CHECK: return %[[dq]] : tensor<2x2xf32>
  // CHECK: return %[[cst]]
}

// CHECK-LABEL: QuantizeSplatFloatConst
func @QuantizeSplatFloatConst() -> tensor<2x2xf32> {
func @QuantizeSplatFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
  %0 = constant dense<3.0> : tensor<2x2xf32>
  %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
  return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>

  // CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
  // CHECK: return %[[cst]]
}

// CHECK-LABEL: NotQuantizeFloatConst
func @NotQuantizeFloatConst() -> tensor<2x2xf32> {
  %0 = constant dense<-0.1> : tensor<2x2xf32>
  %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
  %2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
  return %2 : tensor<2x2xf32>

  // CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
  // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
  // CHECK: return %[[dq]] : tensor<2x2xf32>
  // CHECK: %[[cst:.*]] = constant dense<-1.000000e-01> : tensor<2x2xf32>
  // CHECK: return %[[cst]] : tensor<2x2xf32>
}

// CHECK-LABEL: DequantizeAndQuantize
@@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/FileUtilities.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
@@ -32,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/model.h"
@@ -182,6 +184,7 @@ int main(int argc, char **argv) {
  pass_config.inline_functions = inline_functions;

  tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());

  std::string result;
  auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
@@ -49,6 +49,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace mlir {
@@ -114,9 +116,54 @@ DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(Reciprocal);
DECL_CONVERT_OP(RandomUniform);

#undef DECL_CONVERT_OP

PatternMatchResult ConvertTFRandomUniformOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto random_uniform_op = cast<TF::RandomUniformOp>(op);
  if (random_uniform_op.seed() == 0 && random_uniform_op.seed2() == 0) {
    return matchFailure();
  }
  if (!random_uniform_op.dtype().isF32()) {
    return matchFailure();
  }
  typedef tensorflow::random::UniformDistribution<
      tensorflow::random::PhiloxRandom, float>
      Distribution;

  tensorflow::random::PhiloxRandom generator(
      random_uniform_op.seed().getSExtValue(),
      random_uniform_op.seed2().getSExtValue());
  Distribution dist;
  int num_elements = 0;
  if (auto output_type =
          random_uniform_op.output().getType().dyn_cast_or_null<ShapedType>()) {
    if (auto ranked_output = output_type.dyn_cast_or_null<RankedTensorType>()) {
      if (!ranked_output.hasRank() || ranked_output.getNumDynamicDims() != 0) {
        return matchFailure();
      }
      num_elements = output_type.getNumElements();
      size_t offset = 0;
      size_t num_samples = Distribution::kResultElementCount;
      llvm::SmallVector<float, 32> data;
      data.resize(num_elements);
      while (offset < num_elements) {
        const typename Distribution::ResultType samples = dist(&generator);
        std::copy(&samples[0],
                  &samples[0] + std::min(num_samples, data.size() - offset),
                  &data[0] + offset);
        offset += num_samples;
      }
      auto output_data = DenseFPElementsAttr::get(output_type, data);
      rewriter.replaceOpWithNewOp<ConstantOp>(op, output_type, output_data);
      return matchSuccess();
    }
  }
  return matchFailure();
}

PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_concat_op = cast<TF::ConcatOp>(op);
@@ -521,11 +568,12 @@ void LegalizeTF::runOnFunction() {

  // Add the generated patterns to the list.
  populateWithGenerated(ctx, &patterns);
  patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
                  ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op,
                  ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp,
                  ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
                  ConvertTFAssertOp, ConvertTFReciprocalOp>(ctx);
  patterns
      .insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
              ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op, ConvertTFPackOp,
              ConvertTFReshapeOp, ConvertTFSplitOp, ConvertTFSplitVOp,
              ConvertTFStridedSliceOp, ConvertTFUnpackOp, ConvertTFAssertOp,
              ConvertTFReciprocalOp, ConvertTFRandomUniformOp>(ctx);
  applyPatternsGreedily(func, patterns);
}
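The constant-folding loop above fills the output buffer kResultElementCount samples at a time. A self-contained sketch of the same sampling scheme (assuming only the two TensorFlow random headers added in this diff; SampleUniform is a hypothetical helper name), with the copy bound made explicit:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"

// Deterministically samples n uniform floats from the given seeds, the same
// way ConvertTFRandomUniformOp materializes tf.RandomUniform as a constant.
std::vector<float> SampleUniform(int64_t seed, int64_t seed2, size_t n) {
  using Distribution = tensorflow::random::UniformDistribution<
      tensorflow::random::PhiloxRandom, float>;
  tensorflow::random::PhiloxRandom generator(seed, seed2);
  Distribution dist;
  std::vector<float> data(n);
  size_t offset = 0;
  while (offset < n) {
    // Each call to the distribution yields kResultElementCount samples.
    const Distribution::ResultType samples = dist(&generator);
    const size_t take =
        std::min<size_t>(Distribution::kResultElementCount, n - offset);
    std::copy(&samples[0], &samples[0] + take, data.data() + offset);
    offset += take;
  }
  return data;
}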
@@ -91,6 +91,9 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFWhilePass();
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass();

// Verifies that the runtime supports the types used.
std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass();

}  // namespace TFL

}  // namespace mlir
@@ -21,12 +21,20 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"

// Quantize attribute $0 by using quantization parameter from %1.
def QuantizeByQuantizedType : NativeCodeCall<"quant::Quantize($0, $1.getValue())">;
def F32ElementsAttr : ElementsAttrBase<
  CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;

// Squash tfl.dequantize and tfl.quantize pairs.
// TODO(fengliuai): Compare the scale of input and output. This can also be
// squashed to a requantize op if the scales are different.
def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;

// If the tfl.dequantize op wasn't fused, we shouldn't quantize the floating
// point constant.
def : Pat<(TFL_DequantizeOp
           (TFL_QuantizeOp (ConstantOp F32ElementsAttr:$cst), $qt)),
          (ConstantOp $cst)>;

// Quantize the value of a constant op if the quantization parameters have been
// propagated to the output.
def : Pat<(TFL_QuantizeOp
@@ -0,0 +1,52 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
namespace TFL {
namespace {

// This pass verifies that the operand and result types are supported by the
// TFLite runtime.
class RuntimeTypeVerifyPass : public mlir::FunctionPass<RuntimeTypeVerifyPass> {
 public:
  explicit RuntimeTypeVerifyPass() {}

 private:
  void runOnFunction() override;
};

void RuntimeTypeVerifyPass::runOnFunction() {
  getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
    if (failed(op.VerifyTflRuntimeTypes(op.getOperation())))
      signalPassFailure();
  });
}
}  // namespace

// Verifies that the runtime supports the types used.
std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass() {
  return std::make_unique<RuntimeTypeVerifyPass>();
}

static PassRegistration<RuntimeTypeVerifyPass> pass(
    "tfl-runtime-verify", "TFLite runtime verification");

}  // namespace TFL
}  // namespace mlir
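A minimal sketch of using the new pass programmatically (a hypothetical driver, not part of this commit; RuntimeTypesAreSupported is an invented name, and the pm.addPass call mirrors the ones this commit adds to the two converter entry points):

#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

// Returns true if every op in the module passes TFLite runtime type checks.
bool RuntimeTypesAreSupported(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
  return mlir::succeeded(pm.run(module));
}

On the command line the same check is exposed as tf-opt -tfl-runtime-verify, as the updated RUN lines in the tests above show.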
@@ -168,6 +168,10 @@ std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
                         result.getResultNumber());
    return std::string(result.getOwner()->getName().getStringRef());
  }
  // Use the ASM syntax for BlockArgument
  if (auto arg = val.dyn_cast<mlir::BlockArgument>()) {
    return "arg" + std::to_string(arg.getArgNumber());
  }
  return "";
}
@@ -41,11 +41,52 @@ limitations under the License.
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "mlir/Support/STLExtras.h"  // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/logging.h"

namespace mlir {
namespace tf_device {

//===----------------------------------------------------------------------===//
// TF Device Dialect Interfaces
//===----------------------------------------------------------------------===//

namespace {
struct TFInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  // Defines the legality of inlining TF Device operations.
  bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
    // For now, enable inlining all operations.
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  // Attempts to materialize a conversion for a type mismatch between a call
  // from this dialect, and a callable region. This method should generate an
  // operation that takes 'input' as the only operand, and produces a single
  // result of 'resultType'. If a conversion can not be generated, nullptr
  // should be returned.
  // This is just re-using the same logic as the TensorFlow dialect right now.
  Operation* materializeCallConversion(OpBuilder& builder, Value input,
                                       Type result_type,
                                       Location conversion_loc) const final {
    if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
      return nullptr;
    return builder.create<TF::CastOp>(conversion_loc, result_type, input,
                                      /*truncate=*/builder.getBoolAttr(false));
  }
};
}  // end anonymous namespace

TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
    : Dialect(/*name=*/"tf_device", context) {
  addOperations<
@@ -54,6 +95,8 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
      >();

  addOperations<ParallelExecuteOp>();

  addInterfaces<TFInlinerInterface>();
}

//===----------------------------------------------------------------------===//
@@ -49,7 +49,7 @@ an output element, this operation computes \\(y = |x|\\).
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape]>,
def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>,
               WithBroadcastableBinOpBuilder {
  let summary = "Returns x + y element-wise.";

@@ -98,7 +98,7 @@ Inputs must be of same size and shape.
  let hasFolder = 1;
}

def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape]>,
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>,
                 WithBroadcastableBinOpBuilder {
  let summary = "Returns x + y element-wise.";

@@ -6781,7 +6781,7 @@ variables.
  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}

def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType]> {
def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> {
  let summary = "Computes hyperbolic tangent of `x` element-wise.";

  let description = [{
@@ -58,6 +58,10 @@ TODO: Make invariants more structured so that we can reference them in ops.
def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait<
  "TF::OperandsSameAsResultsTypeOrRef">;

// Layout agnostic operations do not depend on the operands data layout (data
// format); as an example, all element wise operations are layout agnostic.
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;

//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//
@@ -68,6 +68,11 @@ class OperandsSameAsResultsTypeOrRef
  }
};

// Layout agnostic operations do not depend on the operands data layout (data
// format); as an example, all element wise operations are layout agnostic.
template <typename ConcreteType>
class LayoutAgnostic : public TraitBase<ConcreteType, LayoutAgnostic> {};

}  // namespace TF
}  // namespace OpTrait
}  // namespace mlir
@@ -3,6 +3,10 @@

// All tests also test for idempotence.

// Test that external functions aren't processed (used to crash).
// CHECK-LABEL: func @unused_external_func
func @unused_external_func()

func @multiple_return(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi32>, tensor<*xi32>) {
  %graph:2 = tf_executor.graph {
    %island:3 = tf_executor.island {
@@ -0,0 +1,57 @@
// RUN: tf-opt %s -tf-executor-tpu-v1-island-coarsening | FileCheck %s --dump-input=fail


// Test that islands with a function call are merged if the call is to a function
// that contains ops with the same attribute.
// CHECK-LABEL: func @control_input
func @control_input(%arg0 : tensor<i1>) -> tensor<i32> {
  %0:6 = tf_executor.graph {
    %1:2 = tf_executor.island wraps "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
    %2:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "A", body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
    %3:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "B", body = @while_body_with_wrong_cluster_attr, cond = @while_cond_with_wrong_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
    %4:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "C", body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
    %6:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "D", body = @while_body_without_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
    %5:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "E", body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>

// CHECK: "tf.opA"
// CHECK-NOT: island
// CHECK: name = "A"
// CHECK-NOT: island
// CHECK: name = "C"
// CHECK-NOT: island
// CHECK: name = "E"
// CHECK: island {{.*}}name = "B"
// CHECK: island {{.*}}name = "D"

    tf_executor.fetch %1#0, %2#0, %3#0, %4#0, %5#0, %6#0 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
  }
  return %0#0 : tensor<i32>
}

func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
  %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
  return %0 : tensor<i32>
}
func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
  %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
  return %0 : tensor<i1>
}

func @while_body_with_wrong_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
  %0 = "some.op"(%arg0) {_tpu_replicate = "wrong_cluster"} : (tensor<i32>) -> tensor<i32>
  return %0 : tensor<i32>
}
func @while_cond_with_wrong_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
  %0 = "some.op"(%arg0) {_tpu_replicate = "wrong_cluster"} : (tensor<i32>) -> tensor<i1>
  return %0 : tensor<i1>
}

func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
  %0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
  return %0 : tensor<i32>
}
func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
  %0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
  return %0 : tensor<i1>
}
@@ -0,0 +1,44 @@
// RUN: tf-opt %s -tf-executor-tpu-v1-island-inlining | FileCheck %s --dump-input=fail

// CHECK-NOT: tf.PartitionedCall
// CHECK-NOT: module @_tpu_v1_compat_outlined

module {
  func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
    %0:4 = tf_executor.graph {
      %outputs:4, %control = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @_tpu_v1_compat_outlined::@_tpu_v1_compat_outlined_func0} : (tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
      tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
    }
    return %0#0 : tensor<i32>
  }
  module @_tpu_v1_compat_outlined {
    func @_tpu_v1_compat_outlined_func0(%arg0: tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
      "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1 : i64, topology = "topology"} : () -> ()
      %0 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
      %1 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
      %2 = "tf.While"(%0) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
      %3 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
      return %0, %1, %2, %3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
    }
    func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
      %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
      return %0 : tensor<i32>
    }
    func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
      %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
      return %0 : tensor<i1>
    }
    func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
      %0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
      return %0 : tensor<i32>
    }
    func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
      %0 = "tf.PartionedCalledOp"(%arg0) {f = @callee_func} : (tensor<i32>) -> tensor<i1>
      return %0 : tensor<i1>
    }
    func @callee_func(%arg0: tensor<i32>) -> tensor<i1> {
      %0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
      return %0 : tensor<i1>
    }
  }
}
@@ -0,0 +1,48 @@
// RUN: tf-opt %s -tf-executor-tpu-v1-island-outlining | FileCheck %s --dump-input=fail

// CHECK: func @control_input
// CHECK-NOT: func @
// CHECK-LABEL: module @_tpu_v1_compat_outlined
// CHECK: @_tpu_v1_compat_outlined_func0
// CHECK: func @while_body_with_cluster_attr
// CHECK: func @while_cond_with_cluster_attr
// CHECK: func @while_body_without_cluster_attr
// CHECK: func @while_cond_without_cluster_attr
// CHECK: func @callee_func
module {
  func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
    %0:4 = tf_executor.graph {
      %outputs:4, %control = tf_executor.island {
        "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
        %1 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
        %2 = "tf.While"(%1) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
        %3 = "tf.While"(%1) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
        %4 = "tf.While"(%1) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
        tf_executor.yield %1, %2, %3, %4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
      }
      tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
    }
    return %0#0 : tensor<i32>
  }
  func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
    return %0 : tensor<i32>
  }
  func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
    %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
    return %0 : tensor<i1>
  }
  func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
    return %0 : tensor<i32>
  }
  func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
    %0 = "tf.PartionedCalledOp"(%arg0) { f = @callee_func} : (tensor<i32>) -> tensor<i1>
    return %0 : tensor<i1>
  }
  func @callee_func(%arg0: tensor<i32>) -> tensor<i1> {
    %0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
    return %0 : tensor<i1>
  }
}
@@ -1,4 +1,4 @@
// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NCHW -verify-diagnostics | FileCheck %s
// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NCHW -verify-diagnostics | FileCheck %s --dump-input=always

// CHECK-LABEL: func @transposeBiasAdd
func @transposeBiasAdd(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
@@ -0,0 +1,67 @@
// RUN: tf-opt %s -tf-move-transposes -verify-diagnostics | FileCheck %s --dump-input=always

// CHECK-LABEL: func @move_across_single_op
func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {

  // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
  // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
  // CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
  // CHECK: return %[[TANH]]

  %0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
  %1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
  %2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>

  return %2 : tensor<1x8x4x4xf32>
}

// CHECK-LABEL: func @move_across_multiple_ops
func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {

  // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
  // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
  // CHECK: %[[TANH0:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
  // CHECK: %[[TANH1:[0-9]*]] = "tf.Tanh"(%[[TANH0]]) {{.*}} tensor<1x8x4x4xf32>
  // CHECK: return %[[TANH1]]

  %0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
  %1 = "tf.Tanh"(%0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>

  %2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
  %3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>

  return %3 : tensor<1x8x4x4xf32>
}

// CHECK-LABEL: func @move_across_multi_operand_op
func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {

  // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
  // CHECK: %[[ARG0_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
  // CHECK: %[[ARG1_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg1, %[[ARG_PERM]])
  // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[ARG0_TRANSPOSE]], %[[ARG1_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
  // CHECK: return %[[ADD]]

  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
  %1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
  %2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>

  return %2 : tensor<1x8x4x4xf32>
}

// CHECK-LABEL: func @move_with_multiple_uses
func @move_with_multiple_uses(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {

  // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
  // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
  // CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
  // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[TANH]], %[[TANH]]) {{.*}} tensor<1x8x4x4xf32>
  // CHECK: return %[[ADD]]

  %0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
  %1 = "tf.AddV2"(%0, %0) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
  %2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
  %3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>

  return %3 : tensor<1x8x4x4xf32>
}
@ -542,3 +542,116 @@ func @if_else(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.
|
||||
-> (tensor<*x!tf.resource<tensor<4xf32>>>) {
|
||||
return %arg1 : tensor<*x!tf.resource<tensor<4xf32>>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass lifts resources on two partitioned call ops sharing the
|
||||
// same callee. The lifting should clone the callee then modify the clone.
|
||||
|
||||
// CHECK-LABEL: @launch_with_partitioned_call
|
||||
func @launch_with_partitioned_call() -> tensor<f32> {
|
||||
// CHECK: %[[VH:.*]] = "tf.VarHandleOp"()
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
// CHECK: %[[CONST:.*]] = "tf.Const"()
|
||||
%1 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]])
|
||||
// CHECK: %[[LAUNCH:.*]] = "tf_device.launch"()
|
||||
%2 = "tf_device.launch"() ( {
|
||||
// CHECK: %[[PC0:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]])
|
||||
// CHECK-SAME: f = @callee_resource_lifted
|
||||
%3 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""}
|
||||
: (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[PC1:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]])
|
||||
// CHECK-SAME: f = @callee_resource_lifted
|
||||
%4 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""}
|
||||
: (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[PC0]], %[[PC1]])
|
||||
%5 = "tf.AddV2"(%3, %4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: tf_device.return %[[ADD]] : tensor<f32>
|
||||
tf_device.return %5 : tensor<f32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<f32>
|
||||
return %2 : tensor<f32>
|
||||
}
|
||||
// CHECK: @callee(%[[OA0:.*]]: tensor<f32>, %[[OA1:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA2:.*]]: tensor<f32>) -> tensor<f32>
|
||||
func @callee(%arg0: tensor<f32>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<f32> {
|
||||
// CHECK: "tf.ReadVariableOp"(%[[OA1]])
|
||||
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
%1 = "tf.AddV2"(%0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%2 = "tf.AddV2"(%1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
return %2 : tensor<f32>
|
||||
}
|
||||
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
|
||||
// CHECK-NEXT: %[[ADD0:.*]] = "tf.AddV2"(%[[A1]], %[[A0]])
|
||||
// CHECK-NEXT: %[[ADD1:.*]] = "tf.AddV2"(%[[ADD0]], %[[A2]])
|
||||
// CHECK-NEXT: return %[[ADD1]]

// -----

// Tests that the pass lifts resources on two stateful partitioned call ops
// sharing the same callee. The lifting should clone the callee then modify the
// clone.

// CHECK-LABEL: @launch_with_stateful_partitioned_call
func @launch_with_stateful_partitioned_call() -> () {
  // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"()
  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
  // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"()
  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
  // CHECK: %[[CONST:.*]] = "tf.Const"()
  %2 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
  // CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]])
  // CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]])
  // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"()
  "tf_device.launch"() ( {
    // CHECK: %[[PC0:.*]] = "tf.StatefulPartitionedCall"(%[[READ0]], %[[READ1]], %[[CONST]])
    // CHECK-SAME: f = @callee_resource_lifted
    %3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
      : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
    // CHECK: %[[PC1:.*]] = "tf.StatefulPartitionedCall"(%[[PC0]], %[[READ1]], %[[CONST]])
    // CHECK-SAME: f = @callee_resource_lifted
    %4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
      : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
    // CHECK: tf_device.return %[[PC1]] : tensor<f32>
    tf_device.return
  // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<f32>
  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
  // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[LAUNCH]])
  return
}
// CHECK: @callee(%[[OA0:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA1:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA2:.*]]: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
  // CHECK: "tf.ReadVariableOp"(%[[OA1]])
  %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
  %1 = "tf.AddV2"(%0, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  "tf.AssignVariableOp"(%arg0, %1) {dtype = i32} : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
  return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
}
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[A1]], %[[A2]])
// CHECK-NEXT: return %[[ADD]]

// -----

// Tests that the pass reports error on called function that has resource output
// which doesn't alias an input.

func @launch_with_stateful_partitioned_call() -> () {
  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
  %2 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
  "tf_device.launch"() ( {
    %3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
      : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
    %4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
      : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
    tf_device.return
  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
  return
}
// expected-error @+1 {{Unsupported function call: resource return value does not alias an input.}}
func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
  %0 = "tf._Unknown_"() : () -> tensor<*x!tf.resource<tensor<f32>>>
  return %0 : tensor<*x!tf.resource<tensor<f32>>>
}
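The expected error above comes from the callee sanity check shown later in this diff: every resource-typed return value must literally be one of the function's block arguments, since such outputs are rewired to the aliasing input and dropped. A condensed sketch of that check, assuming the same MLIR APIs used in the pass; the helper name is hypothetical:

// Minimal sketch of the alias check, assuming MLIR APIs as used in this diff.
// `CheckResourceReturnsAliasInputs` is a hypothetical helper.
LogicalResult CheckResourceReturnsAliasInputs(FuncOp callee) {
  for (auto retval : callee.front().getTerminator()->getOperands()) {
    if (!getElementTypeOrSelf(retval.getType()).isa<TF::ResourceType>())
      continue;
    // A resource result must be a forwarded block argument; anything else
    // (e.g. the tf._Unknown_ result above) cannot be lifted.
    if (!retval.dyn_cast<BlockArgument>())
      return callee.emitOpError(
          "Unsupported function call: resource return value does not alias "
          "an input.");
  }
  return success();
}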

@ -45,6 +45,17 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
  return %1 : tensor<*xf32>
}

// CHECK-LABEL: func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<?xf32>
func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
  br ^bb1
^bb1:
  // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  // CHECK: return %[[IDENTITY]] : tensor<?xf32>
  %ret = "tf.Identity"(%arg0) : (tensor<?xf32>) -> tensor<*xf32>
  return %ret : tensor<*xf32>
}

// Tests the case where an inference opportunity relies on folding.

// CHECK-LABEL: func @simple_folding

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
@ -70,10 +71,20 @@ void TPUBridgeExecutorIslandInlining::runOnModule() {
      call_op.emitOpError() << "Failed to inline\n";
      return WalkResult::interrupt();
    }
    called_func.erase();
    call_op.erase();
    return WalkResult::advance();
  });
  if (walk_result.wasInterrupted()) return signalPassFailure();
  // Move all remaining nested functions back into the parent module.
  Block &nested_block = nested_module->getRegion(0).front();
  for (FuncOp func_op :
       llvm::make_early_inc_range(nested_block.getOps<FuncOp>())) {
    if (!symbol_table.lookupSymbolIn(getModule(), func_op.getName())) {
      nested_block.getOperations().remove(func_op.getOperation());
      symbol_table.insert(func_op.getOperation());
    }
  }
  nested_module->erase();
}

@ -29,10 +29,12 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Block.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Location.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
#include "mlir/IR/UseDefLists.h"  // TF:llvm-project
#include "mlir/IR/Visitors.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
@ -57,8 +59,8 @@ constexpr llvm::StringRef kTpuStatusAttr = "_tpu_compilation_status";
// TPU-annotated operations and intended to preserve backward compatibility with
// TFv1.
struct TpuV1BridgeExecutorIslandCoarsening
    : public FunctionPass<TpuV1BridgeExecutorIslandCoarsening> {
  void runOnFunction() override;
    : public ModulePass<TpuV1BridgeExecutorIslandCoarsening> {
  void runOnModule() override;
};

// Sort the Operations in the provided range to enforce dominance.
@ -88,9 +90,10 @@ LogicalResult SortTopologically(Block::iterator first_op,
        Operation* producer_in_block =
            block->findAncestorOpInBlock(*defining_op);
        if (producer_in_block && producer_in_block != &op &&
            unscheduled_ops.count(producer_in_block))
            unscheduled_ops.count(producer_in_block)) {
          // Found an operand that isn't scheduled yet, interrupt the walk.
          return WalkResult::interrupt();
        }
      }
      return WalkResult::advance();
    });
@ -113,7 +116,9 @@ LogicalResult SortTopologically(Block::iterator first_op,
// A failure is returned if a cycle prevents the merge from happening
// correctly without breaking dominance. The IR is left in an invalid state in
// case of failure.
LogicalResult MergeIsland(Operation* op, bool* changed) {
LogicalResult MergeIsland(llvm::function_ref<bool(StringAttr, Operation*)>
                              is_op_calling_func_for_cluster,
                          Operation* op, bool* changed) {
  // Find the first island wrapping a single operation with the `_tpu_replicate`
  // attribute, it'll be used as the root of the algorithm to find the other
  // operations that are part of the same cluster.
@ -146,7 +151,9 @@ LogicalResult MergeIsland(Operation* op, bool* changed) {
    if (!candidate_cluster_name)
      candidate_cluster_name =
          candidate_wrapped_op.getAttrOfType<StringAttr>(kTpuStatusAttr);
    if (candidate_cluster_name != cluster_name) continue;
    if (candidate_cluster_name != cluster_name &&
        !is_op_calling_func_for_cluster(cluster_name, &candidate_wrapped_op))
      continue;

    // Look at captured operands to bring in ReplicatedInputOp in the
    // island as well. TODO: also pull in tf.Const, some optimizations can
@ -250,34 +257,71 @@ LogicalResult MergeIsland(Operation* op, bool* changed) {
                          first_op_after);
}

void TpuV1BridgeExecutorIslandCoarsening::runOnFunction() {
  getFunction().walk([&](GraphOp graph) {
    Block& graph_body = graph.GetBody();

    // Iterate until fixed point on the block, as it may contain multiple
    // clusters.
    bool changed = true;
    while (changed) {
      changed = false;
      for (Operation& op : graph_body) {
        if (failed(MergeIsland(&op, &changed))) {
          graph.emitError() << "Merging island failed: the TPU cluster likely "
                            << "contains a cycle with non-TPU operations\n";
          signalPassFailure();
          return WalkResult::interrupt();
        }
        // If islands were merged, restart scanning the block from the beginning
        // as we lost track of where to continue.
        if (changed) break;
      }
    }
    return WalkResult::advance();
  });
}
void TpuV1BridgeExecutorIslandCoarsening::runOnModule() {
  SymbolTable symbol_table(getModule());

  // Map tpu cluster names to the functions that contain operations for this
  // cluster.
  DenseMap<StringRef, DenseSet<FuncOp>> tpu_funcs;
  for (FuncOp func_op : getModule().getOps<FuncOp>()) {
    func_op.walk([&](Operation* op) {
      StringAttr cluster_name =
          op->getAttrOfType<StringAttr>(kTpuReplicateAttr);
      if (!cluster_name)
        cluster_name = op->getAttrOfType<StringAttr>(kTpuStatusAttr);
      if (!cluster_name) return;
      tpu_funcs[cluster_name.getValue()].insert(func_op);
    });
  }

  // Return true if the operation contains a reference to a function
  // containing operations for this cluster.
  auto is_op_calling_func_for_cluster = [&](StringAttr cluster, Operation* op) {
    auto funcs_for_cluster = tpu_funcs.find(cluster.getValue());
    assert(funcs_for_cluster != tpu_funcs.end());
    assert(!funcs_for_cluster->second.empty());
    if (funcs_for_cluster->second.size() == 1) return false;
    for (NamedAttribute attr : op->getAttrs()) {
      auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>();
      if (!symbol_ref) continue;
      FuncOp callee = symbol_table.lookup<FuncOp>(symbol_ref.getValue());
      if (!callee) continue;
      if (funcs_for_cluster->second.count(callee)) return true;
    }
    return false;
  };

  for (FuncOp func_op : getModule().getOps<FuncOp>()) {
    func_op.walk([&](GraphOp graph) {
      Block& graph_body = graph.GetBody();

      // Iterate until fixed point on the block, as it may contain multiple
      // clusters.
      bool changed = true;
      while (changed) {
        changed = false;
        for (Operation& op : graph_body) {
          if (failed(
                  MergeIsland(is_op_calling_func_for_cluster, &op, &changed))) {
            graph.emitError()
                << "Merging island failed: the TPU cluster likely "
                << "contains a cycle with non-TPU operations\n";
            signalPassFailure();
            return WalkResult::interrupt();
          }
          // If islands were merged, restart scanning the block from the
          // beginning as we lost track of where to continue.
          if (changed) break;
        }
      }
      return WalkResult::advance();
    });
  }
}

} // namespace

std::unique_ptr<OpPassBase<FuncOp>>
std::unique_ptr<OpPassBase<ModuleOp>>
CreateTFExecutorTPUV1IslandCoarseningPass() {
  return std::make_unique<TpuV1BridgeExecutorIslandCoarsening>();
}
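Since the pass is now module-scoped, the factory returns an OpPassBase<ModuleOp> and must be added to a module-level pass manager. A hedged usage sketch (namespace qualifiers omitted and assumed to be in scope; this snippet is not part of the commit):

#include "mlir/Pass/PassManager.h"

// Sketch only: assumes the factory declared above is visible and `pm` is an
// mlir::PassManager rooted at the module.
void AddIslandCoarsening(mlir::PassManager& pm) {
  // Module scope is required: the pass builds a module-wide map from TPU
  // cluster names to the functions containing ops for each cluster.
  pm.addPass(CreateTFExecutorTPUV1IslandCoarseningPass());
}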

@ -133,9 +133,23 @@ void TPUBridgeExecutorIslandOutlining::runOnModule() {
        /*executor_type=*/builder.getStringAttr(""));
    SmallVector<Value, 16> yield_operands(call_op.getResults());
    builder.create<YieldOp>(island_op.getLoc(), yield_operands);
  }

  // TODO(aminim): handle transitively referenced function and clone them in
  // the new module.
  // Outline all the transitively called functions by moving them into the
  // outlined module.
  for (FuncOp func : outlined_module.getOps<FuncOp>()) {
    func.walk([&](Operation *op) {
      for (NamedAttribute attr : op->getAttrs()) {
        auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>();
        if (!symbol_ref) continue;
        if (outlined_symbol_table.lookup<FuncOp>(symbol_ref.getValue()))
          continue;
        FuncOp callee = symbol_table.lookup<FuncOp>(symbol_ref.getValue());
        callee.getOperation()->getBlock()->getOperations().remove(
            callee.getOperation());
        outlined_symbol_table.insert(callee);
      }
    });
  }
}

@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
@ -25,6 +28,8 @@ namespace TF {

namespace {

// LayoutAssignmentPass assigns optimal data layout (data format) for all
// layout sensitive operations.
class LayoutAssignmentPass : public FunctionPass<LayoutAssignmentPass> {
 public:
  LayoutAssignmentPass() = default;
@ -39,6 +44,14 @@ class LayoutAssignmentPass : public FunctionPass<LayoutAssignmentPass> {
      llvm::cl::desc("Force data format for all layout sensitive ops")};
};

// MoveTransposesPass moves all Transpose ops to the beginning or to the end of
// the basic block where they are defined. This will allow the canonicalizer to
// delete redundant transposes.
class MoveTransposesPass : public FunctionPass<MoveTransposesPass> {
 public:
  void runOnFunction() final;
};

using Permutation = SmallVector<int64_t, 4>;

Permutation GetDataFormatPermutation(StringRef from_data_format,
@ -128,10 +141,116 @@ void LayoutAssignmentPass::runOnFunction() {
  });
}

// Move Transpose operations that permute `op` results before the `op`.
void MoveTransposeBefore(Operation* op, SmallVector<Operation*, 8>* work_list) {
  // TODO(ezhulenev): Move transpose across layout sensitive operations.
  if (!op->hasTrait<OpTrait::TF::LayoutAgnostic>()) return;

  // Transpose operations that use operation results.
  SmallVector<TransposeOp, 2> transpose_ops;

  // Constant operation that defines permutation indices for result transposes.
  ConstOp permutation_op;

  // All operation results must be used by transpose operations with the same
  // permutation indices.
  for (OpResult result : op->getResults()) {
    for (Operation* user : result.getUsers()) {
      // Result user must be a transpose operation.
      TransposeOp transpose = dyn_cast<TransposeOp>(user);
      if (!transpose) return;

      // With permutation defined by constant operation.
      ConstOp perm =
          dyn_cast_or_null<ConstOp>(transpose.getOperand(1).getDefiningOp());
      if (!perm) return;

      // With the same permutation indices.
      auto dense_elem_attr = perm.value().dyn_cast<DenseElementsAttr>();
      if (!dense_elem_attr) return;

      if (!permutation_op) permutation_op = perm;

      // Check that permutation matches for all result transposes.
      if (perm.value() != permutation_op.value()) return;

      // Add a transpose operation for later reuse.
      transpose_ops.push_back(transpose);
    }
  }

  // Nothing to do here.
  if (!permutation_op || transpose_ops.empty()) return;

  // At this point we checked that we can safely move Transpose node before
  // `op`, and bypass all result transposes.
  Location loc = op->getLoc();

  // Move constant op defining result permutation to the beginning of the block.
  permutation_op.getOperation()->moveBefore(&op->getBlock()->front());

  // Bypass Transpose nodes for all results.
  for (OpResult result : op->getResults()) {
    result.setType(cast<TransposeOp>(*result.getUsers().begin()).y().getType());
    for (Operation* transpose : result.getUsers()) {
      transpose->getResult(0).replaceAllUsesWith(result);
    }
  }

  // Maybe add a Transpose node for all operands (or reuse existing transposes).
  OpBuilder builder(op);
  builder.setInsertionPoint(op);

  for (OpOperand& operand : op->getOpOperands()) {
    // Try to push transpose further up.
    if (Operation* operand_op = operand.get().getDefiningOp())
      work_list->push_back(operand_op);

    // Try to reuse result transposes.
    TransposeOp transpose;
    if (!transpose_ops.empty()) {
      transpose = transpose_ops.pop_back_val();
      transpose.getOperation()->moveBefore(op);
      transpose.setOperand(0, operand.get());
      transpose.setOperand(1, permutation_op);
    } else {
      transpose =
          builder.create<TransposeOp>(loc, operand.get(), permutation_op);
    }

    operand.set(transpose);
  }

  // Remove unused transpose operations.
  while (!transpose_ops.empty()) {
    TransposeOp transpose = transpose_ops.pop_back_val();
    transpose.erase();
  }
}

void MoveTransposesPass::runOnFunction() {
  FuncOp func = getFunction();

  SmallVector<Operation*, 8> work_list;

  func.walk([&](TransposeOp transpose) {
    for (auto operand : transpose.getOperands()) {
      if (auto op = operand.getDefiningOp()) work_list.push_back(op);
    }
  });

  while (!work_list.empty()) {
    Operation* op = work_list.pop_back_val();
    MoveTransposeBefore(op, &work_list);
  }
}
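The driver above is a standard worklist fixed point: seed with the producers of existing transposes, and let each successful hoist enqueue further candidates. A self-contained toy version of the same control structure (plain C++, no MLIR types; the graph and names are invented for the example):

#include <iostream>
#include <vector>

// Toy stand-ins: each "op" is an index; producers[i] lists producers of op i.
bool tryRewrite(int op, const std::vector<std::vector<int>>& producers,
                std::vector<bool>* moved, std::vector<int>* work_list) {
  if ((*moved)[op]) return false;   // Already hoisted past this op.
  (*moved)[op] = true;
  // Enqueue producers: a successful move may expose new opportunities.
  for (int p : producers[op]) work_list->push_back(p);
  return true;
}

int main() {
  // 0 <- 1 <- 2 chain; seed the worklist with the consumer end.
  std::vector<std::vector<int>> producers = {{}, {0}, {1}};
  std::vector<bool> moved(3, false);
  std::vector<int> work_list = {2};
  while (!work_list.empty()) {
    int op = work_list.back();
    work_list.pop_back();
    tryRewrite(op, producers, &moved, &work_list);
  }
  std::cout << "visited all: " << (moved[0] && moved[1] && moved[2]) << "\n";
}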

} // namespace

static PassRegistration<LayoutAssignmentPass> pass("tf-layout-assignment",
                                                   "Layout assignment pass");
static PassRegistration<LayoutAssignmentPass> layout_assignment(
    "tf-layout-assignment", "Layout assignment pass");
static PassRegistration<MoveTransposesPass> move_transposes(
    "tf-move-transposes", "Move transposes pass");

} // namespace TF
} // namespace mlir

@ -106,7 +106,8 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass();

// Creates a pass to merge IslandOps for operations marked for execution on
// TPU. This is for V1 backward compatibility.
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorTPUV1IslandCoarseningPass();
std::unique_ptr<OpPassBase<ModuleOp>>
CreateTFExecutorTPUV1IslandCoarseningPass();

// Creates a pass to outline TPU clusters from a single IslandOp into a nested
// module suitable for being processed as if it were a V2 module.

@ -31,6 +31,7 @@ limitations under the License.
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
#include "mlir/IR/Types.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
@ -811,16 +812,185 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
  return success();
}

// A resource-lifted function for (potentially multiple) PartitionedCallOps and
// information about the lifting changes.
struct PartitionedCallLiftingInfo {
  // Function with resources lifted. Can be nullptr if nothing needs to change.
  FuncOp lifted_callee;
  // Mapping from old resource outputs to their aliasing inputs.
  llvm::SmallDenseMap<int64_t, int64_t> old_outputs_aliasing_old_inputs;
  // Mapping from old to new output indices in case any output is removed.
  llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
  // ResourceArgUseInfo for each old resource argument.
  llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> use_info;
  // Input for AddLoadsStoresOutsideControlFlowOp(), see its comment.
  llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
      arg_data_type_and_updated_output_index;
};
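One way to read `old_to_new_output_indices`: entry i is the new result index for old result i, or -1 if that (resource) result was dropped because it aliases an input. A self-contained illustration of the remapping rule (the values are invented for the example):

#include <cassert>
#include <vector>

int main() {
  // Old results: 0 = value, 1 = resource (dropped), 2 = value.
  std::vector<long> old_to_new = {0, -1, 1};
  // Old result 2 now lives at new result index 1; old result 1 has no
  // replacement and its uses must be rewired to the aliasing input instead.
  assert(old_to_new[2] == 1);
  assert(old_to_new[1] == -1);
  return 0;
}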

// Lifts loads/stores from a PartitionedCallOp's callee function. If anything
// needs to be changed, the original function will be preserved, and the lifting
// happens on a clone, which will be stored in `result`.
LogicalResult HandlePartitionedCallOpCallee(
    FuncOp callee, PartitionedCallLiftingInfo* result) {
  // Remove identity nodes to avoid aliasing.
  RemoveIdentity(&callee.front());
  // Sanity check: return of resources should be aliases of inputs. Such outputs
  // will be removed later.
  int64_t non_resource_results = 0;
  for (auto entry :
       llvm::enumerate(callee.front().getTerminator()->getOperands())) {
    auto retval = entry.value();
    if (!getElementTypeOrSelf(retval.getType()).isa<TF::ResourceType>()) {
      result->old_to_new_output_indices.push_back(non_resource_results++);
      continue;
    }
    auto aliasing_arg = retval.dyn_cast<BlockArgument>();
    if (!aliasing_arg) {
      return callee.emitOpError(
          "Unsupported function call: resource return value does not alias an "
          "input.");
    }
    result->old_outputs_aliasing_old_inputs[entry.index()] =
        aliasing_arg.getArgNumber();
    result->old_to_new_output_indices.push_back(-1);
  }

  if (failed(FindResourceArgUseInfo(callee, &result->use_info))) {
    return failure();
  }
  if (result->use_info.empty()) {
    result->lifted_callee = nullptr;
    return success();
  }

  // Clone the callee before making changes.
  SmallString<64> name_base = callee.getName();
  auto module = callee.getParentOfType<ModuleOp>();
  name_base += "_resource_lifted";
  auto name = name_base;
  {
    int64_t counter = 0;
    while (module.lookupSymbol(name)) {
      name = name_base;
      name += "_" + std::to_string(counter++);
    }
  }
  callee = callee.clone();
  callee.setName(name);
  SymbolTable(module).insert(callee);
  result->lifted_callee = callee;

  // Remove unused resources in functions.
  llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
  RemoveUnusedResourceArgumentsAndForwardedRetvals(
      result->use_info, callee, /*old_to_new_arg_indices=*/nullptr,
      &remaining_resource_data_types);
  for (const auto& entry : remaining_resource_data_types) {
    result->arg_data_type_and_updated_output_index[entry.getFirst()] = {
        entry.getSecond(), -1};
  }
  llvm::SmallVector<Value, 4> new_retvals;
  for (auto val : callee.front().getTerminator()->getOperands()) {
    // Remove resource type outputs.
    if (getElementTypeOrSelf(val.getType()).isa<TF::ResourceType>()) continue;
    new_retvals.push_back(val);
  }
  // Lift resources.
  LiftArgRetResourcesForFunction(
      callee, remaining_resource_data_types, [&](int64_t index, Value value) {
        result->arg_data_type_and_updated_output_index[index].second =
            new_retvals.size();
        new_retvals.push_back(value);
      });
  auto old_return = callee.front().getTerminator();
  // Replace the old return op with a new one that returns the updated values.
  OpBuilder builder(old_return);
  auto new_return = builder.create<ReturnOp>(old_return->getLoc(), new_retvals);
  old_return->erase();
  callee.setType(FunctionType::get(
      callee.getType().getInputs(),
      llvm::to_vector<4>(new_return.getOperandTypes()), callee.getContext()));
  return success();
}

// Updates a PartitionedCallOp/StatefulPartitionedCallOp according to the
// resource-lifted new callee function in lifting_info.
template <typename CallOpType>
void UpdatePartitionedCallOpWithNewCallee(
    CallOpType call_op, const PartitionedCallLiftingInfo& lifting_info) {
  if (lifting_info.lifted_callee == nullptr) return;
  // Replace output resource uses with the aliasing input, so that we can remove
  // this output.
  for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) {
    call_op.getResult(entry.getFirst())
        .replaceAllUsesWith(call_op.getOperand(entry.getSecond()));
  }
  // Recreate the call op.
  OpBuilder builder(call_op);
  // Now use the filtered original operands, which will be replaced by
  // AddLoadsStoresOutsideControlFlowOp().
  auto new_operands =
      FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
  auto new_call = builder.create<CallOpType>(
      call_op.getLoc(),
      const_cast<FuncOp&>(lifting_info.lifted_callee).getType().getResults(),
      new_operands, call_op.getAttrs());
  new_call.setAttr(
      "f", builder.getSymbolRefAttr(
               const_cast<FuncOp&>(lifting_info.lifted_callee).getName()));
  AddLoadsStoresOutsideControlFlowOp(
      new_call, lifting_info.arg_data_type_and_updated_output_index);
  // Replace uses.
  for (int64_t i = 0; i < lifting_info.old_to_new_output_indices.size(); ++i) {
    if (lifting_info.old_to_new_output_indices[i] >= 0) {
      call_op.getResult(i).replaceAllUsesWith(
          new_call.getResult(lifting_info.old_to_new_output_indices[i]));
    }
  }
  call_op.erase();
}

LogicalResult HoistForFunctionalControlFlow(
    Block*, ModuleOp, llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*);

// A templated routine for handling both PartitionedCallOp and
// StatefulPartitionedCallOp. If the callee is already lifted, it just updates
// the caller op itself; otherwise, it first recursively handles nested control
// flow, then performs lifting on the callee.
template <typename CallOpType>
LogicalResult HandlePartitionedCallOp(
    CallOpType call_op, FuncOp callee, ModuleOp module,
    llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>* lifted_callees) {
  auto emplace_res =
      lifted_callees->try_emplace(callee, PartitionedCallLiftingInfo());
  if (emplace_res.second) {
    // Unseen callee. Perform resource lifting on it.
    HoistForFunctionalControlFlow(&callee.front(), module, lifted_callees);
    if (failed(HandlePartitionedCallOpCallee(
            callee, &emplace_res.first->getSecond()))) {
      return failure();
    }
  }
  UpdatePartitionedCallOpWithNewCallee(call_op, emplace_res.first->getSecond());
  return success();
}
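The try_emplace idiom above is what makes the shared-callee tests earlier in this diff work: the first caller computes the lifting info, and every later caller of the same callee reuses it. A self-contained illustration of the same memoization pattern with std::map (all names invented for the example):

#include <iostream>
#include <map>
#include <string>

struct LiftingInfo { int version = 0; };

int main() {
  std::map<std::string, LiftingInfo> cache;
  for (const char* callee : {"callee", "callee", "callee"}) {
    // try_emplace inserts a default value only if the key is new, and tells
    // us via .second whether we are the first caller to see this callee.
    auto res = cache.try_emplace(callee, LiftingInfo());
    if (res.second) res.first->second.version = 1;  // "lift" exactly once
    std::cout << callee << " -> version " << res.first->second.version << "\n";
  }
}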

// Hoists resource loads/stores from control flow ops in `block` outside the
// body/cond/branch functions.
LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
// body/cond/branch/callee functions.
LogicalResult HoistForFunctionalControlFlow(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*
        lifted_partitioned_call_callees) {
  for (Operation& op : llvm::make_early_inc_range(*block)) {
    if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
      auto cond = llvm::cast<FuncOp>(module.lookupSymbol(while_op.cond()));
      // Recursively handle the nested control flow.
      HoistForFunctionalControlFlow(&body.front(), module);
      HoistForFunctionalControlFlow(&cond.front(), module);
      HoistForFunctionalControlFlow(&body.front(), module,
                                    lifted_partitioned_call_callees);
      HoistForFunctionalControlFlow(&cond.front(), module,
                                    lifted_partitioned_call_callees);
      if (failed(HanldeWhileLoop(while_op, body, cond))) return failure();
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      auto then_branch =
@ -828,9 +998,30 @@ LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
      auto else_branch =
          llvm::cast<FuncOp>(module.lookupSymbol(if_op.else_branch()));
      // Recursively handle the nested control flow.
      HoistForFunctionalControlFlow(&then_branch.front(), module);
      HoistForFunctionalControlFlow(&else_branch.front(), module);
      HoistForFunctionalControlFlow(&then_branch.front(), module,
                                    lifted_partitioned_call_callees);
      HoistForFunctionalControlFlow(&else_branch.front(), module,
                                    lifted_partitioned_call_callees);
      if (failed(HanldeIfOP(if_op, then_branch, else_branch))) return failure();
    } else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      if (!call_op.f().isa<FlatSymbolRefAttr>()) {
        return call_op.emitError(
            "Resource lifting does not support call with nested references.");
      }
      auto callee = llvm::cast<FuncOp>(
          module.lookupSymbol(call_op.f().getRootReference()));
      if (failed(HandlePartitionedCallOp(call_op, callee, module,
                                         lifted_partitioned_call_callees))) {
        // Nested control flow handling is done in HandlePartitionedCallOp().
        return failure();
      }
    } else if (auto call_op =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      auto callee = llvm::cast<FuncOp>(module.lookupSymbol(call_op.f()));
      if (failed(HandlePartitionedCallOp(call_op, callee, module,
                                         lifted_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
@ -840,10 +1031,13 @@ LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
// outside. Returns failure if there are remaining resource-type values that can
// not be lifted.
void ResourceOpLiftingPass::runOnModule() {
  llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
      lifted_partitioned_call_callees;
  auto result = getModule().walk([&](FuncOp func_op) {
    return func_op.walk([&](tf_device::LaunchOp launch_op) {
      if (failed(HoistForFunctionalControlFlow(&launch_op.GetBody(),
                                               getModule())) ||
      if (failed(HoistForFunctionalControlFlow(
              &launch_op.GetBody(), getModule(),
              &lifted_partitioned_call_callees)) ||
          failed(HoistResourceOpsFromLaunchOp(launch_op))) {
        return WalkResult::interrupt();
      }
@ -901,8 +1095,11 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
        << function.getBlocks().size();
  }

  llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
      lifted_partitioned_call_callees;
  return HoistForFunctionalControlFlow(&function.front(),
                                       cast<ModuleOp>(function.getParentOp()));
                                       cast<ModuleOp>(function.getParentOp()),
                                       &lifted_partitioned_call_callees);
}
} // namespace TF

@ -60,16 +60,23 @@ namespace TF {
namespace {
Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType(
    FuncOp func) {
  // Only infer shape when there is one return op for now.
  if (!has_single_element(func.getBody()) || func.front().empty()) {
  // Find any return ops.
  SmallVector<ReturnOp, 4> return_ops;
  for (Block& block : func) {
    if (auto return_op = dyn_cast<ReturnOp>(block.getTerminator())) {
      return_ops.push_back(return_op);
    }
  }

  // Right now we only handle the case of a single return op.
  // To handle multiple return ops, we would need to look at all their shapes
  // and come up with a common shape and insert appropriate casts.
  if (return_ops.size() != 1) {
    return None;
  }

  // Find the return type.
  auto return_op = dyn_cast<mlir::ReturnOp>(func.front().back());
  if (!return_op) {
    return None;
  }
  auto return_op = return_ops.front();

  // Manually fold tf.Cast that precedes the return instruction and only differs
  // in shape refinement level.

@ -263,6 +263,7 @@ tf_device::ReplicateOp AddInputsToReplicateOp(
  llvm::SmallVector<std::pair<llvm::ArrayRef<Value>, Type>, 8>
      new_replicated_inputs;
  llvm::SmallVector<llvm::SmallVector<Value, 8>, 8> replicated_inputs;
  replicated_inputs.reserve(replicate.GetBody().getNumArguments());
  for (auto arg : llvm::enumerate(replicate.GetBody().getArguments())) {
    int64_t i = arg.index();
    replicated_inputs.emplace_back();

@ -42,8 +42,8 @@ namespace mlir {

namespace {

struct BreakUpIslands : OperationPass<BreakUpIslands, FuncOp> {
  void runOnOperation() final;
struct BreakUpIslands : FunctionPass<BreakUpIslands> {
  void runOnFunction() final;

  void BreakUpIsland(tf_executor::IslandOp island_op,
                     const TF::SideEffectAnalysis& side_effect_analysis,
@ -51,8 +51,8 @@ struct BreakUpIslands : OperationPass<BreakUpIslands, FuncOp> {
                         new_control_inputs);
};

void BreakUpIslands::runOnOperation() {
  auto graph_op_range = getOperation().getBody().front().without_terminator();
void BreakUpIslands::runOnFunction() {
  auto graph_op_range = getFunction().getBody().front().without_terminator();
  tf_executor::GraphOp graph_op;
  if (graph_op_range.begin() != graph_op_range.end() &&
      std::next(graph_op_range.begin()) == graph_op_range.end()) {

@ -63,21 +63,21 @@ Status StatusScopedDiagnosticHandler::Combine(Status status) {
}

LogicalResult StatusScopedDiagnosticHandler::handler(Diagnostic* diag) {
#ifndef NDEBUG
  // Non-error diagnostics are ignored when VLOG isn't enabled.
  if (diag->getSeverity() != DiagnosticSeverity::Error && VLOG_IS_ON(1))
    return success();

  size_t current_diag_str_size_ = diag_str_.size();
#endif

  // Emit the diagnostic and flush the stream.
  emitDiagnostic(*diag);
  diag_stream_.flush();

#ifndef NDEBUG
  // Emit non-errors to VLOG instead of the internal status.
  if (diag->getSeverity() != DiagnosticSeverity::Error) {
    VLOG(1) << diag_str_.substr(current_diag_str_size_);
    diag_str_.resize(current_diag_str_size_);
  }
#endif

  // Return failure to signal propagation if necessary.
  return failure(propagate_);
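The record-size/truncate trick above keeps non-error text out of the accumulated status string while still routing it to verbose logging. A self-contained sketch of the same buffer discipline with a plain std::string (toy data, no LLVM/TF types):

#include <iostream>
#include <string>

int main() {
  std::string status_buffer = "error: bad op\n";
  // Remember where this diagnostic starts so it can be dropped later.
  size_t mark = status_buffer.size();
  status_buffer += "note: something chatty\n";

  bool is_error = false;  // Pretend this diagnostic was a note, not an error.
  if (!is_error) {
    std::cout << "to-vlog: " << status_buffer.substr(mark);
    status_buffer.resize(mark);  // Keep only real errors in the status.
  }
  std::cout << "status: " << status_buffer;
}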

@ -370,6 +370,22 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
                         Convert(interior_padding))
          .getOperation();
    }
    case HloOpcode::kScatter: {
      auto scatter = static_cast<HloScatterInstruction*>(instruction);
      attributes.push_back(
          ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers()));
      attributes.push_back(builder_->getNamedAttr(
          "indices_are_sorted",
          builder_->getBoolAttr(scatter->indices_are_sorted())));
      attributes.push_back(builder_->getNamedAttr(
          "unique_indices", builder_->getBoolAttr(scatter->unique_indices())));

      auto scatter_op = func_builder->create<mlir::xla_hlo::ScatterOp>(
          loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportComputation(scatter->to_apply(),
                                           &scatter_op.update_computation()));
      return scatter_op.getOperation();
    }
    case HloOpcode::kSetDimensionSize: {
      attributes.push_back(builder_->getNamedAttr(
          "dimension", builder_->getIntegerAttr(builder_->getIntegerType(32),
@ -385,6 +401,16 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
                         ConvertDimensions(instruction->slice_strides()))
          .getOperation();
    }
    case HloOpcode::kSort: {
      auto sort_instruction = static_cast<HloSortInstruction*>(instruction);
      auto sort_op = func_builder->create<mlir::xla_hlo::SortOp>(
          loc, result_type, operands,
          builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
          builder_->getBoolAttr(sort_instruction->is_stable()));
      TF_RETURN_IF_ERROR(ImportComputation(sort_instruction->to_apply(),
                                           &sort_op.comparator()));
      return sort_op.getOperation();
    }
    case HloOpcode::kConditional: {
      llvm::SmallVector<Type, 4> rets;
      TF_RETURN_IF_ERROR(GetMlirTypes(
@ -834,6 +860,22 @@ mlir::NamedAttribute HloFunctionImporter::ConvertGatherDimensionNumbers(
  return builder_->getNamedAttr("dimension_numbers", attr);
}

mlir::NamedAttribute HloFunctionImporter::ConvertScatterDimensionNumbers(
    const xla::ScatterDimensionNumbers& dnums) {
  std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
                                          dnums.update_window_dims().end());
  std::vector<int64_t> inserted_window_dims(
      dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
  std::vector<int64_t> scatter_dims_to_operand_dims(
      dnums.scatter_dims_to_operand_dims().begin(),
      dnums.scatter_dims_to_operand_dims().end());
  auto attr = mlir::xla_hlo::ScatterDimensionNumbers::get(
      Convert(update_window_dims), Convert(inserted_window_dims),
      Convert(scatter_dims_to_operand_dims),
      builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_);
  return builder_->getNamedAttr("scatter_dimension_numbers", attr);
}

mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
    const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
        source_target_pairs) {

@ -121,6 +121,10 @@ class HloFunctionImporter {
  mlir::NamedAttribute ConvertGatherDimensionNumbers(
      const xla::GatherDimensionNumbers& dnums);

  // Converts the scatter dimensions to attributes.
  mlir::NamedAttribute ConvertScatterDimensionNumbers(
      const xla::ScatterDimensionNumbers& dnums);

  // Converts XLA instruction source target pairs to MLIR attribute.
  mlir::NamedAttribute ConvertSourceTargetPairs(
      const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&

@ -60,6 +60,13 @@ def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;

def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;

// Dynamic representation of a shape vector as a tensor. Ideally this would be
// an index type (as it stores indices) but that is currently disallowed in
// MLIR.
def HLO_DimensionTensor : ShapedContainerType<
    [AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
    "a 1D tensor of dimensions">;

// In general, static shaped tensor constraints should be avoided unless
// it is for a legacy op which is only correct with static shapes.
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
@ -771,10 +778,22 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim",
}

def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
    [NoSideEffect]>, BASE_HLO_DynamicBroadcastInDimOp {
    [NoSideEffect]> {
  string summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
  string description = [{
    This is a generalization of the BroadcastInDimOp which accepts its output
    dimensions as an argument. It should eventually supersede the statically
    shaped original, but is being phased as a separate op in order to support
    compatibility with lowerings and translations that precede dynamic
    shapes.

    Note that the `broadcast_dimensions` attribute is optional and if omitted,
    it is assumed to be an ordered, right-aligned mapping from input to
    output dimensions.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_BASE_DimensionTensor:$output_dimensions,
    HLO_DimensionTensor:$output_dimensions,
    BroadcastDimAttr:$broadcast_dimensions
  );

@ -27,13 +27,6 @@ def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
// matching the matrix to dimensions 1 and 2 of the cuboid.
def BroadcastDimAttr : OptionalAttr<I64ElementsAttr>;

// Dynamic representation of a shape vector as a tensor. Ideally this would be
// an index type (as it stores indices) but that is currently disallowed in
// MLIR.
def HLO_BASE_DimensionTensor : ShapedContainerType<
    [AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
    "a 1D tensor of dimensions">;

//===----------------------------------------------------------------------===//
// XLA nullary op definitions.
//===----------------------------------------------------------------------===//
@ -817,22 +810,6 @@ class BASE_HLO_BroadcastInDimOp {
  }];
}

class BASE_HLO_DynamicBroadcastInDimOp {
  string summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";

  string description = [{
    This is a generalization of the BroadcastInDimOp which accepts its output
    dimensions as an argument. It should eventually supersede the statically
    shaped original, but is being phased as a separate op in order to support
    compatibility with lowerings and translations that precede dynamic
    shapes.

    Note that the `broadcast_dimensions` attribute is optional and if omitted,
    it is assumed to be an ordered, right-aligned mapping from input to
    output dimensions.
  }];
}

class BASE_HLO_CholeskyOp {
  string summary = "Cholesky operator";

@ -242,16 +242,6 @@ def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim",
  );
}

def HLO_DynamicBroadcastInDimOp : LHLO_Op<"dynamic_broadcast_in_dim",
    [NoSideEffect]>, BASE_HLO_DynamicBroadcastInDimOp {
  let arguments = (ins
    LHLO_Buffer:$operand,
    HLO_BASE_DimensionTensor:$output_dimensions,
    LHLO_Buffer:$output,
    BroadcastDimAttr:$broadcast_dimensions
  );
}

def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp {
  let arguments = (ins
    LHLO_Buffer:$min,

@ -1,32 +1,57 @@
// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s
// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always
// RUN: tf-opt -lhlo-fuse-linalg -tile-sizes-for-linalg-fusion=2,3 %s -o - | FileCheck %s -check-prefix=TILED --dump-input-on-failure
// RUN: tf-opt -lhlo-fuse-linalg -tile-to-parallel-loops-for-linalg-fusion %s -o - | FileCheck %s -check-prefix=PLOOP --dump-input-on-failure

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
             %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %temp_result = alloc() {temp = true} : memref<2x2xf32>
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
             %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
  %temp_result = alloc() {temp = true} : memref<6x6xf32>
  linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
  ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
    %out = addf %summand_1_in, %summand_2_in : f32
    linalg.yield %out : f32
  } : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>
  } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
  linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
  ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
    %out = mulf %temp_result_in, %multiplier_in : f32
    linalg.yield %out : f32
  } : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>
  dealloc %temp_result : memref<2x2xf32>
  } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
  dealloc %temp_result : memref<6x6xf32>
  "xla_lhlo.terminator"() : () -> ()
}
// CHECK-LABEL: func @fusion
// CHECK-NOT: linalg.generic
// CHECK: loop.for
// CHECK: loop.for
// CHECK-NOT: loop.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: loop.for {{.*}} step %[[C1]]
// CHECK: loop.for {{.*}} step %[[C1]]
// CHECK-NOT: loop.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf

// TILED-LABEL: func @fusion
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-DAG: %[[C3:.*]] = constant 3
// TILED-NOT: linalg.generic
// TILED: loop.for {{.*}} step %[[C2]]
// TILED: loop.for {{.*}} step %[[C3]]
// TILED-NOT: loop.for
// TILED: linalg.generic
// TILED: addf
// TILED: linalg.generic
// TILED: mulf

// PLOOP-LABEL: func @fusion
// PLOOP-NOT: linalg.generic
// PLOOP: loop.parallel
// PLOOP-NOT: loop.parallel
// PLOOP: linalg.generic
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf

func @fusion_of_three(%arg0: memref<100x10xf32>,
                      %arg1: memref<100xf32>,
@ -67,12 +92,36 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
  return
}
// CHECK-LABEL: func @fusion
// CHECK-NOT: linalg.generic
// CHECK: loop.for
// CHECK: loop.for
// CHECK-NOT: loop.for
// CHECK: linalg.generic
// CHECK: linalg.generic
// CHECK: subf
// CHECK: linalg.generic
// CHECK: exp
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: loop.for {{.*}} step %[[C1]]
// CHECK: loop.for {{.*}} step %[[C1]]
// CHECK-NOT: loop.for
// CHECK: linalg.generic
// CHECK: linalg.generic
// CHECK: subf
// CHECK: linalg.generic
// CHECK: exp

// TILED-LABEL: func @fusion_of_three
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-DAG: %[[C3:.*]] = constant 3
// TILED-NOT: linalg.generic
// TILED: loop.for {{.*}} step %[[C2]]
// TILED: loop.for {{.*}} step %[[C3]]
// TILED-NOT: loop.for
// TILED: linalg.generic
// TILED: linalg.generic
// TILED: subf
// TILED: linalg.generic
// TILED: exp

// PLOOP-LABEL: func @fusion_of_three
// PLOOP-NOT: linalg.generic
// PLOOP: loop.parallel
// PLOOP-NOT: loop.parallel
// PLOOP: linalg.generic
// PLOOP: linalg.generic
// PLOOP: subf
// PLOOP: linalg.generic
// PLOOP: exp

@ -179,6 +179,22 @@ func @iota(%out: memref<7x10xi64>) {

// -----

// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func @dynamic_broadcast
func @dynamic_broadcast(%operand: memref<?x?x?xf32>,
                        %result: memref<?x?x?x?x?xf32>) {
  "xla_lhlo.broadcast_in_dim"(%operand, %result)
      {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>}
      : (memref<?x?x?xf32>, memref<?x?x?x?x?xf32>) -> ()
  return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT:   linalg.yield %[[OPERAND]] : f32

// -----

// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func @broadcast

@ -152,13 +152,6 @@ func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi

// -----

// CHECK-LABEL: func @dynamic_broadcast_in_dim_memref
func @dynamic_broadcast_in_dim_memref(%arg0: memref<?x?xi32>, %out: memref<?x?x?xi32>, %shape: tensor<3xi64>) -> () {
  "xla_lhlo.dynamic_broadcast_in_dim"(%arg0, %shape, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<?x?xi32>, tensor<3xi64>, memref<?x?x?xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @reduce_memref
func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf32>) -> () {

@ -716,6 +716,37 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
  ROOT %Arg_0.1 = f32[] parameter(0)
}

// Test scatter
%update_computation {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %sum = f32[] add(f32[] %lhs, f32[] %rhs)
}

%test_scatter {
  %input_tensor = f32[200,100,300] parameter(0)
  %scatter_indices = s64[10,2] parameter(1)
  %updates = f32[10,300] parameter(2)
  ROOT %scatter = f32[200,100,300] scatter(f32[200,100,300] %input_tensor, s64[10,2] %scatter_indices, f32[10,300] %updates), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, to_apply=%update_computation
}

// CHECK-LABEL: func @test_scatter
// CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32>
// CHECK: "xla_hlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ( {
// CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
// CHECK:   [[ADD:%.*]] = xla_hlo.add [[LHS]], [[RHS]]
// CHECK:   "xla_hlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: indices_are_sorted = false
// CHECK-SAME: scatter_dimension_numbers = {
// CHECK-SAME:   index_vector_dim = 1 : i64
// CHECK-SAME:   inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME:   scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME:   update_window_dims = dense<1> : tensor<1xi64>
// CHECK-SAME: }
// CHECK-SAME: unique_indices = false
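For intuition, the scatter above combines each update row into the operand at the position selected by the corresponding index row, using the imported update_computation (here, add). A self-contained 1-D analogue in plain C++ (shapes simplified from the test's 3-D case; the data is invented):

#include <iostream>
#include <vector>

int main() {
  std::vector<float> operand = {0, 0, 0, 0, 0};   // f32[5]
  std::vector<int>   indices = {1, 3};            // s64[2] in the real test
  std::vector<float> updates = {10, 20};          // f32[2]
  // update_computation is add(lhs, rhs), as in %update_computation above.
  for (size_t i = 0; i < indices.size(); ++i)
    operand[indices[i]] = operand[indices[i]] + updates[i];
  for (float v : operand) std::cout << v << " ";  // prints: 0 10 0 20 0
  std::cout << "\n";
}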

// CHECK-LABEL: func @test_select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
%test_select {
  %Arg_0.1 = pred[2,3] parameter(0)
@ -743,6 +774,25 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
  ROOT %sine.3 = f32[1,16,16,3]{3,2,1,0} sine(f32[1,16,16,3]{3,2,1,0} %arg0.1)
}

// Test sort
%compare {
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

%test_sort {
  x = f32[1024]{0} parameter(0)
  ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare
}
// CHECK-LABEL: func @test_sort
// CHECK-SAME: [[ARG:%.*]]: tensor<1024xf32>) -> tensor<1024xf32>
// CHECK: "xla_hlo.sort"([[ARG]]) ( {
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
// CHECK:   [[CMP:%.*]] = "xla_hlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT", name = "lt"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK:   "xla_hlo.return"([[CMP]]) : (tensor<i1>) -> ()
// CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<1024xf32>) -> tensor<1024xf32>
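The imported sort carries its comparator as a region plus an is_stable flag; the plain C++ analogue of is_stable=true with an LT comparator is std::stable_sort rather than std::sort. A self-contained example (data invented):

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  std::vector<float> x = {3.0f, 1.0f, 2.0f};
  // `%compare` above returns lhs < rhs (direction=LT); is_stable=true maps
  // to std::stable_sort, which preserves the order of equal elements.
  std::stable_sort(x.begin(), x.end(),
                   [](float lhs, float rhs) { return lhs < rhs; });
  for (float v : x) std::cout << v << " ";  // prints: 1 2 3
  std::cout << "\n";
}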

// CHECK-LABEL: func @test_subtract
%test_subtract (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
  %Arg_0.1 = f32[4] parameter(0)

@ -22,6 +22,20 @@ limitations under the License.
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Transforms/FoldUtils.h"  // TF:llvm-project

// NOLINTNEXTLINE
static llvm::cl::opt<bool> tile_to_parallel_loops_for_linalg_fusion(
    "tile-to-parallel-loops-for-linalg-fusion",
    llvm::cl::desc(
        "Tiles GenericOp consumer to parallel loops before linalg fusion"),
    llvm::cl::init(false));

// NOLINTNEXTLINE
static llvm::cl::list<unsigned> tile_sizes_for_linalg_fusion(
    "tile-sizes-for-linalg-fusion",
    llvm::cl::desc(
        "Tile sizes by which to tile linalg generic before linalg fusion"),
    llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);

namespace mlir {
namespace xla_lhlo {
namespace {
@ -50,13 +64,16 @@ struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
    OpBuilder b(func);
    OperationFolder folder(func.getContext());
    func.walk([&](linalg::GenericOp generic_op) {
      const SmallVector<int64_t, 2> tile_sizes(
          generic_op.getNumInputsAndOutputs(), 1);
      SmallVector<int64_t, 2> tile_sizes(tile_sizes_for_linalg_fusion.begin(),
                                         tile_sizes_for_linalg_fusion.end());
      if (tile_sizes.empty()) {
        tile_sizes =
            SmallVector<int64_t, 2>(generic_op.getNumInputsAndOutputs(), 1);
      }
      auto op = cast<LinalgOp>(generic_op.getOperation());
      for (const Value result : op.getOutputBuffers()) {
        if (!func_args.count(result)) continue;
        if (linalg::tileLinalgOp(b, op, tile_sizes, /*permutation=*/{},
                                 &folder)) {
        if (tileGenericOp(op, tile_sizes, &b, &folder)) {
          generic_op.erase();
          return;
        }
@ -83,6 +100,18 @@ struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
    }
    for (auto* e : erase_set) e->erase();
  }

 private:
  bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b,
                     OperationFolder* folder) {
    auto tiled_generic_op =
        tile_to_parallel_loops_for_linalg_fusion
            ? linalg::tileLinalgOpToParallelLoops(*b, op, tile_sizes,
                                                  /*permutation=*/{}, folder)
            : linalg::tileLinalgOp(*b, op, tile_sizes,
                                   /*permutation=*/{}, folder);
    return tiled_generic_op.hasValue();
  }
};

} // namespace

@ -227,19 +227,21 @@ class BroadcastInDimConverter

    unsigned nloops = resultMemrefType.getRank();

    SmallVector<AffineExpr, 4> dimExprs;
    auto operandShape = operandMemrefType.getShape();
    int index = 0;
    for (const auto& broadcastSize :
         broadcastOp.broadcast_dimensions().getValue().getIntValues()) {
      int size = broadcastSize.getSExtValue();
      dimExprs.push_back(
          operandShape[index++] == 1
              ? mlir::getAffineConstantExpr(0, broadcastOp.getContext())
              : mlir::getAffineDimExpr(size, broadcastOp.getContext()));
    }
    auto operandShape = operandMemrefType.getShape();
    {
      dimExprs.reserve(nloops);
      for (const auto& broadcastDim : llvm::enumerate(
               broadcastOp.broadcast_dimensions().getValue().getIntValues())) {
        int dim = broadcastDim.value().getSExtValue();
        // TODO(pifon): Add support for args with dynamic shapes for the case
        // when a dimension of size 1 is broadcasted into dim of size N.
        AffineExpr affineExpr =
            operandShape[broadcastDim.index()] == 1
                ? mlir::getAffineConstantExpr(0, broadcastOp.getContext())
                : mlir::getAffineDimExpr(dim, broadcastOp.getContext());
        dimExprs.push_back(affineExpr);
      }
    }
|
||||
|
||||
|
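Editorial worked example of the dimExprs loop above (illustration only, not from the patch): broadcasting an operand of shape [1, 5] into a rank-2 result with broadcast_dimensions = [0, 1] visits two entries:

  // iteration 0: broadcastDim.index() == 0, dim == 0,
  //              operandShape[0] == 1  -> getAffineConstantExpr(0, ctx)
  // iteration 1: broadcastDim.index() == 1, dim == 1,
  //              operandShape[1] == 5  -> getAffineDimExpr(1, ctx)
  // resulting operand indexing map: (d0, d1) -> (0, d1)

so the size-1 dimension always reads element 0 while the size-5 dimension follows the matching result loop.
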
@ -18,6 +18,10 @@ package_group(
    includes = [
        "//tensorflow/compiler/tf2xla:internal",
    ],
    packages = [
        # To pass open source testing in the pip Kokoros.
        "//bazel_pip/tensorflow/compiler/tests/...",
    ],
)

package_group(
@ -25,6 +29,10 @@ package_group(
    includes = [
        "//tensorflow/compiler/tf2xla:friends",
    ],
    packages = [
        # To pass open source testing in the pip Kokoros.
        "//bazel_pip/tensorflow/compiler/tests/...",
    ],
)

generate_backend_suites()
@ -66,6 +74,9 @@ py_test(
    size = "small",
    srcs = ["xla_test_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
    ],
@ -76,6 +87,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["adadelta_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -90,6 +104,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["adagrad_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -105,6 +122,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["adagrad_da_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -119,6 +139,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["adam_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -136,6 +159,9 @@ tf_xla_py_test(
    # TensorList ops are not implemented in the on-demand compilation model yet.
    disabled_backends = ["cpu_ondemand"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -151,6 +177,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["argminmax_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -168,6 +197,7 @@ tf_xla_py_test(
    shard_count = 5,
    tags = [
        "no_oss",  # TODO(b/148108508): Re-enable this test in OSS.
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",  # Times out frequently in fastbuild mode.
    ],
    deps = [
@ -194,6 +224,7 @@ tf_xla_py_test(
    python_version = "PY3",
    shard_count = 2,
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",  # Times out frequently in fastbuild mode.
    ],
    deps = [
@ -212,6 +243,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["bucketize_op_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -226,7 +260,10 @@ tf_xla_py_test(
    size = "small",
    srcs = ["categorical_op_test.py"],
    python_version = "PY3",
    tags = ["optonly"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:framework",
@ -242,6 +279,7 @@ tf_xla_py_test(
    srcs = ["cholesky_op_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "no_rocm",
        "optonly",
    ],
@ -261,6 +299,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["cond_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/compiler/tf2xla/python:xla",
@ -278,7 +319,10 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["self_adjoint_eig_op_test.py"],
    python_version = "PY3",
    tags = ["optonly"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -291,18 +335,6 @@ tf_xla_py_test(
    ],
)

tf_xla_py_test(
    name = "searchsorted_op_test",
    size = "small",
    timeout = "moderate",
    srcs = ["searchsorted_op_test.py"],
    python_version = "PY3",
    deps = [
        ":xla_test",
        "//tensorflow/python:platform_test",
    ],
)

tf_xla_py_test(
    name = "svd_op_test",
    size = "medium",
@ -314,6 +346,7 @@ tf_xla_py_test(
    ],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "no_rocm",
        "optonly",
    ],
@ -336,6 +369,7 @@ tf_xla_py_test(
    srcs = ["matrix_inverse_op_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "noasan",
        "nomsan",
        "notsan",
@ -356,6 +390,9 @@ tf_xla_py_test(
    timeout = "moderate",
    srcs = ["matrix_solve_op_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:linalg_ops",
@ -371,7 +408,10 @@ tf_xla_py_test(
    timeout = "moderate",
    srcs = ["matrix_triangular_solve_op_test.py"],
    python_version = "PY3",
    tags = ["optonly"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -387,6 +427,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["clustering_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -403,6 +446,7 @@ tf_xla_py_test(
    python_version = "PY3",
    tags = [
        "many_xla_args",
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "no_rocm",
    ],
    deps = [
@ -423,6 +467,9 @@ tf_xla_py_test(
    srcs = ["conv2d_test.py"],
    python_version = "PY3",
    shard_count = 10,
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":test_utils",
        ":xla_test",
@ -442,6 +489,9 @@ tf_xla_py_test(
    srcs = ["conv3d_test.py"],
    python_version = "PY3",
    shard_count = 5,
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -460,6 +510,7 @@ tf_xla_py_test(
    python_version = "PY3",
    shard_count = 5,
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "no_rocm",
        "noasan",
        "nomsan",
@ -482,6 +533,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["dynamic_slice_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        "//tensorflow/compiler/tests:xla_test",
        "//tensorflow/compiler/tf2xla/python:xla",
@ -499,6 +553,9 @@ tf_xla_py_test(
        "gpu",
    ],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -513,6 +570,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["reshape_op_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        "//tensorflow/compiler/tests:xla_test",
        "//tensorflow/compiler/tf2xla/python:xla",
@ -527,6 +587,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["dynamic_stitch_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -541,6 +604,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["extract_image_patches_op_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -556,6 +622,7 @@ tf_xla_py_test(
    python_version = "PY3",
    tags = [
        "multi_and_single_gpu",
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
@ -574,6 +641,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["fifo_queue_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -591,6 +661,7 @@ tf_xla_py_test(
    python_version = "PY3",
    shard_count = 6,
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "no_rocm",
        "optonly",
    ],
@ -609,6 +680,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["slice_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -623,6 +697,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["ftrl_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -638,6 +715,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["function_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -653,6 +733,7 @@ tf_xla_py_test(
    python_version = "PY3",
    shard_count = 10,
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",  # Times out frequently in fastbuild mode.
    ],
    deps = [
@ -669,6 +750,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["listdiff_op_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -685,6 +769,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["lrn_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -700,6 +787,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["manip_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -715,7 +805,10 @@ tf_xla_py_test(
    timeout = "long",
    srcs = ["matrix_band_part_test.py"],
    python_version = "PY3",
    tags = ["optonly"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -731,6 +824,9 @@ tf_xla_py_test(
    timeout = "long",
    srcs = ["matrix_diag_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -744,6 +840,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["momentum_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -759,6 +858,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["nary_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -773,6 +875,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["nullary_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:control_flow_ops",
@ -787,6 +892,9 @@ tf_xla_py_test(
    srcs = ["pooling_ops_test.py"],
    python_version = "PY3",
    shard_count = 10,
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -803,6 +911,9 @@ tf_xla_py_test(
    srcs = ["pooling_ops_3d_test.py"],
    python_version = "PY3",
    shard_count = 10,
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -818,6 +929,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["proximal_adagrad_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -832,6 +946,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["proximal_gradient_descent_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -852,7 +969,10 @@ tf_xla_py_test(
    ],
    python_version = "PY3",
    shard_count = 5,
    tags = ["optonly"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -871,6 +991,7 @@ tf_xla_py_test(
    python_version = "PY3",
    shard_count = 5,
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "no_rocm",
        "optonly",
    ],
@ -892,6 +1013,7 @@ tf_xla_py_test(
    python_version = "PY3",
    shard_count = 10,
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "notap",  # TODO(b/141057424): flaky on TPU
    ],
    deps = [
@ -911,6 +1033,9 @@ tf_xla_py_test(
    srcs = ["reduce_ops_test.py"],
    python_version = "PY3",
    shard_count = 5,
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -927,6 +1052,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["reduce_window_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/compiler/tf2xla/python:xla",
@ -943,6 +1071,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["reverse_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -955,7 +1086,10 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["reverse_sequence_op_test.py"],
    python_version = "PY3",
    tags = ["optonly"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -969,6 +1103,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["rmsprop_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -984,7 +1121,10 @@ tf_xla_py_test(
    size = "small",
    srcs = ["scan_ops_test.py"],
    python_version = "PY3",
    tags = ["optonly"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -999,6 +1139,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["segment_reduction_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -1015,6 +1158,9 @@ tf_xla_py_test(
    srcs = ["spacetobatch_op_test.py"],
    python_version = "PY3",
    shard_count = 3,
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -1029,6 +1175,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["sparse_to_dense_op_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -1043,7 +1192,10 @@ tf_xla_py_test(
    size = "small",
    srcs = ["stack_ops_test.py"],
    python_version = "PY3",
    tags = ["config-cuda-only"],
    tags = [
        "config-cuda-only",
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    use_xla_device = False,
    deps = [
        ":xla_test",
@ -1060,7 +1212,10 @@ tf_xla_py_test(
    srcs = ["stateful_random_ops_test.py"],
    python_version = "PY3",
    shard_count = 10,
    tags = ["optonly"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:framework",
@ -1076,7 +1231,10 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["stateless_random_ops_test.py"],
    python_version = "PY3",
    tags = ["optonly"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:framework",
@ -1096,6 +1254,7 @@ tf_xla_py_test(
    python_version = "PY3",
    tags = [
        "config-cuda-only",
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "v1only",
    ],
    use_xla_device = False,
@ -1121,6 +1280,9 @@ tf_xla_py_test(
    # TensorList ops are not implemented in the on-demand compilation model yet.
    disabled_backends = ["cpu_ondemand"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -1136,6 +1298,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["ternary_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -1152,6 +1317,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["unary_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -1168,6 +1336,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["fused_batchnorm_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":test_utils",
        ":xla_test",
@ -1188,7 +1359,10 @@ tf_xla_py_test(
    size = "small",
    srcs = ["variable_ops_test.py"],
    python_version = "PY3",
    tags = ["optonly"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -1207,6 +1381,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["while_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/compiler/tf2xla/python:xla",
@ -1222,7 +1399,10 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["gather_test.py"],
    python_version = "PY3",
    tags = ["optonly"],
    tags = [
        "no_pip",
        "optonly",
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -1237,6 +1417,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["gather_nd_op_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -1250,7 +1433,10 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["scatter_nd_op_test.py"],
    python_version = "PY3",
    tags = ["optonly"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -1266,7 +1452,10 @@ tf_xla_py_test(
    python_version = "PY3",
    shard_count = 1,
    # Times out in fastbuild mode.
    tags = ["optonly"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
        "//tensorflow/compiler/tests:xla_test",
        "//tensorflow/compiler/tf2xla/python:xla",
@ -1280,6 +1469,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["data_format_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        "//tensorflow/compiler/tests:xla_test",
        "//tensorflow/python:array_ops",
@ -1294,7 +1486,10 @@ tf_xla_py_test(
    size = "small",
    srcs = ["xla_device_test.py"],
    python_version = "PY3",
    tags = ["optonly"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -1307,6 +1502,9 @@ cuda_py_test(
    name = "xla_device_gpu_test",
    size = "small",
    srcs = ["xla_device_gpu_test.py"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    xla_enable_strict_auto_jit = False,
    deps = [
        "//tensorflow/python:array_ops",
@ -1323,7 +1521,10 @@ cuda_py_test(
    size = "medium",
    srcs = ["jit_test.py"],
    shard_count = 5,
    tags = ["no_rocm"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "no_rocm",
    ],
    xla_enable_strict_auto_jit = False,
    deps = [
        ":test_utils",
@ -1344,7 +1545,10 @@ cuda_py_test(
    name = "dense_layer_test",
    size = "medium",
    srcs = ["dense_layer_test.py"],
    tags = ["no_rocm"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "no_rocm",
    ],
    xla_enable_strict_auto_jit = False,
    deps = [
        ":test_utils",
@ -1385,6 +1589,7 @@ tf_cuda_cc_test(
    size = "large",
    # This test is randomized, so only run it if explicitly requested.
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "manual",
        "notap",
    ] + tf_cuda_tests_tags(),
@ -1394,7 +1599,9 @@ tf_cuda_cc_test(
tf_cuda_cc_test(
    name = "unary_ops_composition_test",
    srcs = ["unary_ops_composition_test.cc"],
    tags = tf_cuda_tests_tags(),
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ] + tf_cuda_tests_tags(),
    deps = [
        "//tensorflow/cc:cc_ops",
        "//tensorflow/compiler/jit",
@ -1430,7 +1637,10 @@ py_library(
cuda_py_test(
    name = "lstm_test",
    srcs = ["lstm_test.py"],
    tags = ["no_rocm"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "no_rocm",
    ],
    xla_enable_strict_auto_jit = False,
    deps = [
        ":lstm",
@ -1474,6 +1684,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["fake_quant_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:framework",
@ -1486,6 +1699,9 @@ tf_xla_py_test(
    size = "small",
    srcs = ["placeholder_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@ -1499,6 +1715,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["quantized_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/compiler/tf2xla/python:xla",
@ -1516,6 +1735,9 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["xla_ops_test.py"],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
    ],
    deps = [
        ":xla_test",
        "//tensorflow/compiler/tf2xla/python:xla",
@ -1535,6 +1757,7 @@ tf_xla_py_test(
    shard_count = 5,
    tags = [
        "no_oss",  # TODO(b/148108508): Re-enable this test in OSS.
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "no_rocm",
    ],
    deps = [
@ -1560,6 +1783,7 @@ tf_xla_py_test(
    ],
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
@ -1576,7 +1800,10 @@ tf_xla_py_test(
    size = "medium",
    srcs = ["special_math_test.py"],
    shard_count = 5,
    tags = ["optonly"],
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
        "optonly",
    ],
    deps = [
        ":xla_test",
        "//tensorflow/python:extra_py_tests_deps",

@ -1,75 +0,0 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for XLA implementation of tf.searchsorted."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class SearchSorteddOpTest(xla_test.XLATestCase):

  def test1D(self):
    # Test against NumPy implementation (which is 1D only).
    np.random.seed(1)
    for side in ['left', 'right']:
      for dtype in [np.float32, np.int32]:
        values = np.random.uniform(
            low=-1000, high=1000, size=(10,)).astype(dtype)
        unsorted = np.random.uniform(
            low=-1000, high=1000, size=(20,)).astype(dtype)

        sorted_sequence = np.sort(unsorted)
        np_ans = np.searchsorted(sorted_sequence, values, side=side)

        with self.session() as session:
          with self.test_scope():
            tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side)
          tf_out = session.run(tf_ans)
        self.assertAllEqual(np_ans, tf_out)

  def _test2DExample(self, dtype, side, sorted_sequence, values, correct_ans):

    with self.session() as session:
      with self.test_scope():
        tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side)
      tf_out = session.run(tf_ans)
    self.assertAllEqual(correct_ans, tf_out)

  def testLowerBound2DExample(self):
    # 2D TensorFlow documentation example.
    for dtype in self.float_types | self.int_types:
      sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype)
      values = np.array([[2, 4, 9], [0, 2, 6]], dtype)
      correct_ans = np.array([[1, 2, 2], [0, 1, 5]], dtype)
      self._test2DExample(dtype, 'left', sorted_sequence, values, correct_ans)

  def testUpperBound2DExample(self):
    # 2D TensorFlow documentation example.
    for dtype in self.float_types | self.int_types:
      sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype)
      values = np.array([[2, 4, 9], [0, 2, 6]], dtype)
      correct_ans = np.array([[1, 2, 4], [0, 2, 5]], dtype)
      self._test2DExample(dtype, 'right', sorted_sequence, values, correct_ans)


if __name__ == '__main__':
  test.main()

@ -29,6 +29,10 @@ import scipy.special as sps
import six

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test

@ -39,6 +43,13 @@ flags.DEFINE_bool('vary_seed', False,
NUM_SAMPLES = int(1e3)


# This is df/da / df/dx, where f = igamma.
def implicit_reparameterization_grad(a, x):
  log_prob = math_ops.xlogy(a - 1., x) - math_ops.lgamma(a) - x
  prob = math_ops.exp(log_prob)
  return -gen_math_ops.igamma_grad_a(a, x) / prob


class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):

  def setUp(self):
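Editorial note: written out, the identity implicit_reparameterization_grad relies on is the standard implicit-differentiation step (not spelled out in the patch). With P(a, x) the regularized lower incomplete gamma function,

  \frac{dx}{da} = -\,\frac{\partial P(a, x)/\partial a}{\partial P(a, x)/\partial x},
  \qquad
  \frac{\partial P(a, x)}{\partial x}
      = \frac{x^{a-1} e^{-x}}{\Gamma(a)}
      = \exp\bigl((a - 1)\log x - \log\Gamma(a) - x\bigr),

where the exponent is exactly log_prob in the function above and gen_math_ops.igamma_grad_a supplies the numerator.
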
@ -48,9 +59,15 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
      answer = int(entropy.encode('hex'), 16)
    else:
      answer = int.from_bytes(entropy, 'big')
    np.random.seed(answer)
    np.random.seed(answer % (2**32 - 1))
    super(IgammaTest, self).setUp()

  # Skip Float64 test on TPU due to missing ops.
  def maybe_skip_test(self, dtype):
    if self.device not in ['XLA_GPU', 'XLA_CPU', 'CPU'] and dtype == np.float64:
      self.skipTest(
          'Skipping test because some F64 operations not supported on TPU.')

  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testIgammaSmallValues(self, dtype, rtol, atol):
@ -93,6 +110,97 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
      actual = sess.run(math_ops.igamma(a, x))
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)

  # We don't check small values because the numerical gradients become quite
  # large.
  @parameterized.parameters((np.float32, 0.09), (np.float64, 1e-7))
  def testIgammaGradMediumValues(self, dtype, tolerance):
    self.maybe_skip_test(dtype)
    with self.session():
      with self.test_scope():
        x = constant_op.constant(
            np.random.uniform(low=1., high=100.,
                              size=[NUM_SAMPLES]).astype(dtype))
        a = constant_op.constant(
            np.random.uniform(low=1., high=100.,
                              size=[NUM_SAMPLES]).astype(dtype))

        f = lambda b: math_ops.igamma(b, x)
        max_error = gradient_checker_v2.max_error(
            *gradient_checker_v2.compute_gradient(f, x=[a], delta=1e-3))
      self.assertLessEqual(max_error, tolerance)

  @parameterized.parameters((np.float32, 0.5), (np.float64, 1e-7))
  def testIgammaGradLargeValues(self, dtype, tolerance):
    self.maybe_skip_test(dtype)
    with self.session():
      with self.test_scope():
        x = constant_op.constant(
            np.random.uniform(low=100., high=int(1e4),
                              size=[NUM_SAMPLES]).astype(dtype))
        a = constant_op.constant(
            np.random.uniform(low=100., high=int(1e4),
                              size=[NUM_SAMPLES]).astype(dtype))

        f = lambda b: math_ops.igamma(b, x)
        max_error = gradient_checker_v2.max_error(
            *gradient_checker_v2.compute_gradient(f, x=[a], delta=1e-2))
      self.assertLessEqual(max_error, tolerance)

  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testRandomGammaGradSmallValues(self, dtype, rtol, atol):
    self.maybe_skip_test(dtype)
    # Test values near zero.

    with self.session() as sess:
      with self.test_scope():
        x = constant_op.constant(
            np.random.uniform(
                low=np.finfo(dtype).tiny, high=1.,
                size=[NUM_SAMPLES]).astype(dtype))
        a = constant_op.constant(
            np.random.uniform(
                low=np.finfo(dtype).tiny, high=1.,
                size=[NUM_SAMPLES]).astype(dtype))
        gamma_sample_grad = gen_random_ops.random_gamma_grad(a, x)
        actual_grad = implicit_reparameterization_grad(a, x)
        gamma_sample_grad, actual_grad = sess.run(
            [gamma_sample_grad, actual_grad])
        # We do this because the ratio computed in
        # implicit_reparameterization_grad can very easily result in a NaN due
        # to the computed numerator and denominator zeroing out.
        gamma_sample_grad = gamma_sample_grad[
            ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))]
        actual_grad = actual_grad[
            ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))]
        self.assertAllClose(actual_grad, gamma_sample_grad, atol=atol, rtol=rtol)

  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testRandomGammaGradMediumValues(self, dtype, rtol, atol):
    self.maybe_skip_test(dtype)

    with self.session() as sess:
      with self.test_scope():
        x = constant_op.constant(
            np.random.uniform(low=1., high=10.,
                              size=[NUM_SAMPLES]).astype(dtype))
        a = constant_op.constant(
            np.random.uniform(low=1., high=10.,
                              size=[NUM_SAMPLES]).astype(dtype))
        gamma_sample_grad = gen_random_ops.random_gamma_grad(a, x)
        actual_grad = implicit_reparameterization_grad(a, x)
        gamma_sample_grad, actual_grad = sess.run(
            [gamma_sample_grad, actual_grad])
        # We do this because the ratio computed in
        # implicit_reparameterization_grad can very easily result in a NaN due
        # to the computed numerator and denominator zeroing out.
        gamma_sample_grad = gamma_sample_grad[
            ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))]
        actual_grad = actual_grad[
            ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))]
        self.assertAllClose(actual_grad, gamma_sample_grad, atol=atol, rtol=rtol)


if __name__ == '__main__':
  os.environ['XLA_FLAGS'] = '--xla_cpu_enable_fast_math=false'

@ -249,10 +249,12 @@ tf_cuda_library(
    srcs = [
        "utils/trt_int8_calibrator.cc",
        "utils/trt_lru_cache.cc",
        "utils/trt_shape_optimization_profiles.cc",
    ],
    hdrs = [
        "utils/trt_int8_calibrator.h",
        "utils/trt_lru_cache.h",
        "utils/trt_shape_optimization_profiles.h",
    ],
    deps = [
        ":trt_allocator",
@ -308,6 +310,22 @@ tf_cc_test(
    ],
)

tf_cuda_cc_test(
    name = "trt_shape_optimization_profiles_test",
    size = "small",
    srcs = ["utils/trt_shape_optimization_profiles_test.cc"],
    tags = [
        "no_cuda_on_cpu_tap",
        "no_windows",
        "nomac",
    ],
    deps = [
        ":trt_resources",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

tf_cuda_library(
    name = "logger_registry",
    srcs = ["convert/logger_registry.cc"],

@ -431,7 +431,8 @@ Status CreateTRTNode(const ConversionParams& params,
      calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
      max_batch_size, info.max_workspace_size_bytes, input_shapes, trt_logger,
      alloc, /*calibrator=*/nullptr, &engine, info.use_calibration,
      params.use_implicit_batch, /*convert_successfully=*/nullptr));
      params.use_implicit_batch, /*convert_successfully=*/nullptr,
      /*profile=*/nullptr));
  TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
  segment_string = string(static_cast<const char*>(engine_data->data()),
                          engine_data->size());

@ -32,6 +32,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include "tensorflow/core/framework/node_def.pb.h"  // NOLINT
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
@ -249,6 +250,16 @@ void GetInputProperties(const grappler::GraphProperties& graph_properties,
  }
}

// This function checks if a tensor is compatible with TRT.
//
// We check that the shape and datatype are compatible with TensorRT. We also
// return the corresponding trt_dtype, the trt_dims and the batch_size (the
// latter is only needed in implicit batch mode).
//
// The return status indicates whether the tensor is compatible.
//
// For implicit batch mode, when validation_only == false, we also check that
// all input dimensions (besides the batch dimension) are known dimensions.
Status ValidateTensorProperties(const string& producer_node_type,
                                const DataType dtype,
                                const PartialTensorShape& shape,
@ -293,11 +304,7 @@ Status ValidateTensorProperties(const string& producer_node_type,

  if (validation_only) return Status::OK();

  // Following checks are only used during TRT engine creation time. In implicit
  // batch mode we check that all inputs for the network has static shape (as
  // required by the TensorRT). The only exception is the batch size, which
  // could be unknown. In contrast, using explicit batch mode this test is not
  // necessary, since any dimension could be unknown in explicit batch mode.
  // Following checks are only used during TRT engine creation time.
  if (use_implicit_batch) {
    for (int d = first_trt_dim; d < shape.dims(); ++d) {
      if (shape.dim_size(d) < 0) {
@ -1336,7 +1343,7 @@ Status Converter::RenameAndMarkOutputTensors(
Status Converter::BuildCudaEngine(
    TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, int max_batch_size,
    size_t max_workspace_size_bytes, nvinfer1::IGpuAllocator* allocator,
    TRTInt8Calibrator* calibrator) {
    TRTInt8Calibrator* calibrator, TrtShapeOptimizationProfile* profiles) {
  VLOG(1) << "Configuring TensorRT builder";
  trt_builder_->setMaxBatchSize(max_batch_size);
  trt_builder_->setGpuAllocator(allocator);
@ -1356,7 +1363,10 @@ Status Converter::BuildCudaEngine(
      builder_config->setInt8Calibrator(nullptr);
    }
  }

  if (!use_implicit_batch_ && profiles) {
    TF_RETURN_IF_ERROR(profiles->ConfigureBuilder(
        trt_builder_.get(), builder_config.get(), network()));
  }
  VLOG(1) << "Building TensorRT engine";
  engine->reset(
      trt_builder_->buildEngineWithConfig(*network(), *builder_config));
@ -5743,7 +5753,8 @@ Status ConvertGraphDefToEngine(
    nvinfer1::ILogger* trt_logger, nvinfer1::IGpuAllocator* allocator,
    TRTInt8Calibrator* calibrator,
    TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
    const bool use_implicit_batch, bool* convert_successfully) {
    const bool use_implicit_batch, bool* convert_successfully,
    TrtShapeOptimizationProfile* profiles) {
  engine->reset();
  if (convert_successfully) *convert_successfully = false;

@ -5842,7 +5853,8 @@ Status ConvertGraphDefToEngine(

  // Build the engine.
  TF_RETURN_IF_ERROR(converter->BuildCudaEngine(
      engine, max_batch_size, max_workspace_size_bytes, allocator, calibrator));
      engine, max_batch_size, max_workspace_size_bytes, allocator, calibrator,
      profiles));

  VLOG(1) << "Finished conversion";
  return Status::OK();

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
@ -147,7 +148,8 @@ Status ConvertGraphDefToEngine(
    nvinfer1::ILogger* logger, nvinfer1::IGpuAllocator* allocator,
    TRTInt8Calibrator* calibrator,
    TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
    const bool use_implicit_batch, bool* convert_successfully);
    const bool use_implicit_batch, bool* convert_successfully,
    TrtShapeOptimizationProfile* profiles);

// Helper class for the segmenter to determine whether an output edge from the
// TRT segment is valid.
@ -467,7 +469,8 @@ class Converter {
  Status BuildCudaEngine(TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
                         int max_batch_size, size_t max_workspace_size_bytes,
                         nvinfer1::IGpuAllocator* allocator,
                         TRTInt8Calibrator* calibrator);
                         TRTInt8Calibrator* calibrator,
                         TrtShapeOptimizationProfile* profiles);

  //////////////////////////////////////////////////////////////////////////////
  // Methods used by op converters to convert individual TF node and add layers

@ -1187,7 +1187,7 @@ class ConvertGraphDefToEngineTest : public ::testing::Test {
        /*max_workspace_size_bytes=*/64 << 20, input_shapes, &logger_,
        /*allocator=*/nullptr, /*calibrator=*/nullptr, &engine_,
        /*use_calibration=*/false, /*use_implicit_batch=*/true,
        /*convert_successfully=*/nullptr);
        /*convert_successfully=*/nullptr, /*profiles=*/nullptr);
  }

 protected:
@ -1302,7 +1302,8 @@ class OpConverterTest : public ::testing::Test {
        /*max_batch_size=*/batch_size,
        /*max_workspace_size_bytes=*/1 << 26,
        /*allocator=*/nullptr,
        /*calibrator=*/nullptr));
        /*calibrator=*/nullptr,
        /*profiles=*/nullptr));
    CHECK_NOTNULL(engine_.get());
    CheckDataTypeMatches(input_data);
    CheckDataTypeMatches(*output_data);

@ -133,6 +133,25 @@ string DebugString(const std::vector<TensorShape>& shapes) {
string DebugString(const std::vector<PartialTensorShape>& shapes) {
  return PartialTensorShapeUtils::PartialShapeListString(shapes);
}

int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) {
  int n_bindings = engine->getNbBindings();
  int n_input = 0;
  for (int i = 0; i < n_bindings; i++) {
    if (engine->bindingIsInput(i)) n_input++;
  }
  // According to TensorRT 7 doc: "If the engine has been built for K profiles,
  // the first getNbBindings() / K bindings are used by profile number 0, the
  // following getNbBindings() / K bindings are used by profile number 1 etc."
  // Therefore, to get the number of input tensors, we need to divide by the
  // number of profiles.
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
  int n_profiles = engine->getNbOptimizationProfiles();
#else
  int n_profiles = 1;
#endif
  return n_input / n_profiles;
}
#endif

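Editorial worked example for GetNumberOfEngineInputs (illustration only, not from the patch):

  // 2 network inputs, 1 output, built for K = 3 optimization profiles:
  //   getNbBindings()                 == 3 * (2 + 1) == 9
  //   bindings with bindingIsInput()  == 6
  //   GetNumberOfEngineInputs         == 6 / 3 == 2

so dividing by the profile count recovers the per-network input count.
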
string GetLinkedTensorRTVersion() {

@ -106,6 +106,11 @@ string GetLinkedTensorRTVersion();
// TensorRT library version information {Maj, Min, Patch}.
string GetLoadedTensorRTVersion();

// Returns the number of inputs for the engine, which also corresponds to the
// number of input tensors for the network. This can differ from the number of
// input bindings, because the number of total input bindings equals the number
// of profiles times the number of engine inputs.
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine);
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

}  // namespace tensorrt

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/function.h"
@ -92,7 +93,7 @@ class TRTEngineOp : public AsyncOpKernel {
  LRUCache<std::vector<TensorShape>, std::unique_ptr<EngineContext>,
           VectorTensorShapeHasher>;

  // Execute calibration
  // Executes calibration.
  void ExecuteCalibration(OpKernelContext* ctx,
                          TRTEngineCacheResource* cache_res,
                          AsyncHelper* helper);
@ -103,14 +104,15 @@ class TRTEngineOp : public AsyncOpKernel {
  Status ConstructFunctionHandle(FunctionLibraryRuntime* lib,
                                 const string& device_name);

  // Execute replaced native segment as function Op.
  // Executes replaced native segment as function Op.
  void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);

  // Execute the tensorrt engine. Returns whether we need to retry by running
  // Executes the tensorrt engine. Returns whether we need to retry by running
  // the native segment.
  bool ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context);
  bool ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context,
                        int trt_context_idx);

  // Allocate necessary resources for calibration
  // Allocates necessary resources for calibration.
  Status AllocateCalibrationResources(OpKernelContext* ctx,
                                      TRTEngineCacheResource* cache_res);

@ -605,11 +607,24 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,

  OP_REQUIRES_OK_ASYNC(ctx, VerifyInputShapes(input_concrete_shapes), *helper);

  if (!use_implicit_batch_) {
    if (cache_res->profiles_.GetNumProfiles() == 0) {
      // Create a single profile from the current input shape. In the future we
      // will collect a set of input shapes during build mode and create
      // profiles for each of them.
      cache_res->profiles_.AddShape(input_concrete_shapes);
      cache_res->profiles_.InitProfiles();
    }
  }
  StatusOr<EngineContext*> status =
      GetEngine(input_concrete_shapes, ctx, cache_res);
  OP_REQUIRES_OK_ASYNC(ctx, status.status(), *helper);

  EngineContext* engine_context = status.ValueOrDie();
  // Context idx equals the profile idx because we create one context for each
  // profile. Currently we do not have profile_generation mode, therefore we
  // have just a single profile.
  int trt_context_idx = 0;
  if (!engine_context->cuda_engine) {
    VLOG(1) << "Engine retrieval for input shapes: "
            << TensorShapeUtils::ShapeListString(input_concrete_shapes)
@ -617,7 +632,8 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
    ExecuteNativeSegment(ctx, helper);
    return;
  }
  const bool retry = ExecuteTrtEngine(ctx, engine_context);

  const bool retry = ExecuteTrtEngine(ctx, engine_context, trt_context_idx);
  if (retry) {
    LOG(WARNING) << "Failed to execute engine, "
                 << "retrying with native segment for " << name();

@ -665,7 +681,8 @@ Status GetTrtBindingIndex(const char* tensor_name, int profile_index,
|
||||
}
|
||||
|
||||
bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
|
||||
EngineContext* engine_context) {
|
||||
EngineContext* engine_context,
|
||||
int trt_context_idx) {
|
||||
VLOG(1) << "Executing TRT engine: " << name();
|
||||
auto& cuda_engine = engine_context->cuda_engine;
|
||||
|
||||
@ -688,6 +705,11 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
|
||||
}
|
||||
|
||||
const bool kRetry = true;
|
||||
if (trt_context_idx >= 1) {
|
||||
LOG(ERROR) << "Requested engine context with index " << trt_context_idx
|
||||
<< ", but only 1 context is present.";
|
||||
return kRetry;
|
||||
}
|
||||
const int num_binding = cuda_engine->getNbBindings();
|
||||
std::vector<void*> buffers(num_binding);
|
||||
|
||||
@ -698,8 +720,8 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
|
||||
for (int i = 0; i < ctx->num_inputs(); i++) {
|
||||
const string input_name = StrCat(IONamePrefixes::kInputPHName, i);
|
||||
int binding_index;
|
||||
auto status = GetTrtBindingIndex(input_name.c_str(), 0, cuda_engine.get(),
|
||||
&binding_index);
|
||||
auto status = GetTrtBindingIndex(input_name.c_str(), trt_context_idx,
|
||||
cuda_engine.get(), &binding_index);
|
||||
if (!status.ok()) {
|
||||
ctx->SetStatus(status);
|
||||
return !kRetry;
|
||||
@ -770,8 +792,8 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
|
||||
for (int i = 0; i < ctx->num_outputs(); i++) {
|
||||
const string output_name = StrCat(IONamePrefixes::kOutputPHName, i);
|
||||
int binding_index;
|
||||
auto status = GetTrtBindingIndex(output_name.c_str(), 0, cuda_engine.get(),
|
||||
&binding_index);
|
||||
auto status = GetTrtBindingIndex(output_name.c_str(), trt_context_idx,
|
||||
cuda_engine.get(), &binding_index);
|
||||
if (!status.ok()) {
|
||||
ctx->SetStatus(status);
|
||||
return !kRetry;
|
||||
@ -801,7 +823,7 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
|
||||
trt_shape.push_back(dims.d[j]);
|
||||
}
|
||||
}
|
||||
// Allocate output tensor of TRTEngineOp
|
||||
// Allocate output tensor of TRTEngineOp.
|
||||
Tensor* output_tensor = nullptr;
|
||||
TensorShape output_shape;
|
||||
status = TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(),
|
||||
@ -997,7 +1019,8 @@ StatusOr<EngineContext*> TRTEngineOp::GetEngine(
|
||||
auto status = convert::ConvertGraphDefToEngine(
|
||||
segment_graph_def_, precision_mode_, batch_size, workspace_size_,
|
||||
conversion_input_shapes, &logger, allocator, calibrator_.get(), &engine,
|
||||
use_calibration_, use_implicit_batch_, &convert_successfully);
|
||||
use_calibration_, use_implicit_batch_, &convert_successfully,
|
||||
&cache_res->profiles_);
|
||||
if (!status.ok()) {
|
||||
LOG(WARNING) << "Engine creation for " << name() << " failed. "
|
||||
<< "The native segment will be used instead. "
|
||||
@ -1007,11 +1030,12 @@ StatusOr<EngineContext*> TRTEngineOp::GetEngine(
|
||||
cache.emplace(input_concrete_shapes, absl::make_unique<EngineContext>());
|
||||
return &empty_context;
|
||||
}
|
||||
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
|
||||
engine->createExecutionContext());
|
||||
std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>> exec_context;
|
||||
TF_RETURN_IF_ERROR(cache_res->profiles_.CreateExecutionContexts(
|
||||
engine.get(), exec_context));
|
||||
cache.emplace(input_concrete_shapes,
|
||||
absl::make_unique<EngineContext>(std::move(engine),
|
||||
std::move(exec_context)));
|
||||
std::move(exec_context[0])));
|
||||
VLOG(1) << "Added new engine to cache of " << name()
|
||||
<< ". Cache size: " << cache.size();
|
||||
}
|
||||
@ -1085,9 +1109,9 @@ Status TRTEngineOp::AllocateCalibrationResources(
|
||||
this->segment_graph_def_, TrtPrecisionMode::INT8,
|
||||
cres->calibrator_->getBatchSize(), this->workspace_size_,
|
||||
partial_shapes, &cache_res->GetLogger(), cache_res->allocator_.get(),
|
||||
cres->calibrator_.get(), &cres->engine_,
|
||||
/*use_calibration=*/true, this->use_implicit_batch_,
|
||||
/*convert_successfully=*/nullptr);
|
||||
cres->calibrator_.get(), &cres->engine_, /*use_calibration=*/true,
|
||||
this->use_implicit_batch_, /*convert_successfully=*/nullptr,
|
||||
/*profiles=*/nullptr);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Calibration failed: " << s;
|
||||
cres->calibrator_->setDone(); // Ignore further pushes
|
||||
|
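The hunks above rely on a convention worth noting: ExecuteTrtEngine returns kRetry (true) when the failure is recoverable and the op should fall back to the native TensorFlow segment, and !kRetry after setting an error status when it is not. A minimal caller-side sketch of that control flow, mirroring ComputeAsync (the helper argument and the engine_context lookup are assumed from surrounding code not shown in this diff):

// Sketch only: the fallback logic of TRTEngineOp::ComputeAsync.
const bool retry = ExecuteTrtEngine(ctx, engine_context, trt_context_idx);
if (retry) {
  // TRT execution failed in a recoverable way: run the original
  // TensorFlow subgraph instead of the TensorRT engine.
  LOG(WARNING) << "Failed to execute engine, retrying with native segment";
  ExecuteNativeSegment(ctx, helper);
}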
@ -129,9 +129,14 @@ class TRTEngineOpTestBase : public OpsTestBase {
 private:
  Status InitOpWithFunctionLibrary() {
    OpKernel* kernel = nullptr;
    Status status = CreateOpKernel(device_type_, device_, allocator(),
                                   pflr_->GetFLR(device_->name()), node_def_,
                                   TF_GRAPH_DEF_VERSION, &kernel);
    auto flr = pflr_->GetFLR(device_->name());
    std::shared_ptr<const NodeProperties> props;
    Status status = NodeProperties::CreateFromNodeDef(
        node_def_, flr->GetFunctionLibraryDefinition(), &props);
    if (status.ok()) {
      status.Update(CreateOpKernel(device_type_, device_, allocator(), flr,
                                   props, TF_GRAPH_DEF_VERSION, &kernel));
    }
    kernel_ = std::unique_ptr<OpKernel>(kernel);
    if (kernel_ != nullptr) input_types_ = kernel_->input_types();
    return status;
@ -214,6 +219,7 @@ TEST_F(TRTEngineOpTestBase, AllowBuildAtRuntime) {
  EXPECT_EQ(ectx->cuda_engine, nullptr);
}

#if IS_TRT_VERSION_GE(6, 0, 0, 0)
TEST_F(TRTEngineOpTestBase, ExplicitBatch) {
  // Test inference in explicit batch mode with static input shapes. Static
  // shapes in this context means that TensorRT knows all the input shapes
@ -253,15 +259,6 @@ TEST_F(TRTEngineOpTestBase, DynamicShapes) {
  TensorShape input_shape({1, 2});
  TRTEngineOpTestBase::AddSimpleInput<float>(input_shape);

  // We expect that TensorRT engine creation fails: we would need to configure
  // the engine with optimization profiles to use dynamic input shapes, but
  // that feature is not yet implemented.
  //
  // Since TRT engine creation has failed, we fall back to native segment.
  // Calling the native segment fails for the same reason that is investigated
  // in https://github.com/tensorflow/tensorflow/pull/34919. This is irrelevant
  // for the current test; here we just want to check whether TRT engine
  // creation has failed.
  TF_ASSERT_OK(OpsTestBase::RunOpKernel());

  // Get the engine cache.
@ -274,11 +271,8 @@ TEST_F(TRTEngineOpTestBase, DynamicShapes) {
  auto cache = &cache_resource->cache_;
  EXPECT_EQ(1, cache->size());
  ASSERT_EQ(1, cache->count({input_shape}));
  // TODO(bixia): re-enable the check below when the problem is fixed.
  // EngineContext* ectx = cache->at({input_shape}).get();
  // Since engine creation failed, we expect to find nullptr. Finding a nullptr
  // indicates that unknown shapes were used to define the TensorRT network.
  // EXPECT_EQ(ectx->cuda_engine, nullptr);
  EngineContext* ectx = cache->at({input_shape}).get();
  EXPECT_NE(ectx->cuda_engine, nullptr);
}

template <typename T>
@ -302,6 +296,7 @@ TYPED_TEST(TRTEngineOpTest, Basic) {
                        output->NumElements()),
              ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
}
#endif

}  // namespace tensorrt
}  // namespace tensorflow
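A note on the test change above: kernel construction now goes through NodeProperties instead of handing a NodeDef directly to CreateOpKernel. As a standalone sketch of that two-step pattern (fixture-specific state such as node_def_, the device, and the allocator are abbreviated to local variables here):

// Resolve the node's properties against the function library, then hand
// the shared properties to CreateOpKernel; this mirrors
// InitOpWithFunctionLibrary in the hunk above.
std::shared_ptr<const NodeProperties> props;
TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
    node_def, flr->GetFunctionLibraryDefinition(), &props));
OpKernel* kernel = nullptr;
TF_RETURN_IF_ERROR(CreateOpKernel(device_type, device, allocator, flr, props,
                                  TF_GRAPH_DEF_VERSION, &kernel));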
@ -140,11 +140,24 @@ class InitializeTRTResource : public OpKernel {
          engine_instance.serialized_engine().c_str(),
          engine_instance.serialized_engine().size(), nullptr));
      auto raw_engine = engine.get();
      resource->cache_.emplace(
          engine_input_shapes,
          absl::make_unique<EngineContext>(
              std::move(engine), TrtUniquePtrType<nvinfer1::IExecutionContext>(
                                     raw_engine->createExecutionContext())));
      std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>> ctx_vec;
      if (num_loaded_engine == 0) {
        // Restore profiles if there are any. Currently only 1 engine is
        // allowed in dynamic mode, therefore we call this only for the 0th
        // engine. It is a no-op in implicit batch mode.
        OP_REQUIRES_OK(ctx, resource->profiles_.RestoreProfiles(raw_engine));
        OP_REQUIRES_OK(ctx, resource->profiles_.CreateExecutionContexts(
                                raw_engine, ctx_vec));
      } else {
        // Multiple engines are only available in static mode. For each engine
        // we have only a single execution context.
        TrtUniquePtrType<nvinfer1::IExecutionContext> exec_ctx(
            raw_engine->createExecutionContext());
        ctx_vec.push_back(std::move(exec_ctx));
      }
      resource->cache_.emplace(engine_input_shapes,
                               absl::make_unique<EngineContext>(
                                   std::move(engine), std::move(ctx_vec[0])));
      ++num_loaded_engine;
    } while (1);
    VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines for op "
@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"

@ -182,6 +183,11 @@ class TRTEngineCacheResource : public ResourceBase {
  // TODO(hinsu): Use different calibration context for the available shapes
  // and attach it to each item of the cache.
  std::unique_ptr<CalibrationContext> calib_ctx_;

  // This object maintains all the optimization profiles during profile
  // generation and engine build. During runtime the list of profiles is used
  // to look up a matching profile for the input data.
  TrtShapeOptimizationProfile profiles_;
};

#endif  // GOOGLE_TENSORRT
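Tying the new profiles_ member to the engine cache: at runtime the op can ask the profile collection which optimization profile covers the actual input shapes and then pick the matching execution context. A minimal sketch of that lookup (hypothetical call site; the names follow the API added in this change):

// Look up the profile that covers the current input shapes; -1 means no
// stored profile is compatible and the native segment is the fallback.
int profile_idx = cache_res->profiles_.GetProfileNumber(input_shapes);
if (profile_idx < 0) {
  ExecuteNativeSegment(ctx, helper);  // no matching profile
} else {
  // Execute with the context that was created for this profile.
  ExecuteTrtEngine(ctx, engine_context, /*trt_context_idx=*/profile_idx);
}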
@ -0,0 +1,185 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"

#include <algorithm>
#include <functional>

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
namespace tensorflow {
namespace tensorrt {

// Creates optimization profiles for a list of input shapes. The list of input
// shapes is stored in input_shapes_.
void TrtShapeOptimizationProfile::InitProfiles() {
  if (input_shapes_.size() == 0) {
    VLOG(1) << "Not creating profiles without input_shapes. "
               "You have to enable profile generation mode first (build).";
  } else {
    VLOG(1) << "Creating profiles with strategy of one profile "
            << "for each input (min=opt=max).";
  }
  for (auto& shape_vec : input_shapes_) {
    std::vector<nvinfer1::Dims> dimvec;
    for (auto& shape : shape_vec) {
      dimvec.push_back(TensorShapeToTrtDims(shape, false));
    }
    // We set min=opt=max.
    OptimizationProfileConfig profConfig{dimvec, dimvec, dimvec};
    profiles_.push_back(std::move(profConfig));
    VLOG(1) << "Created profile " << profiles_.back().DebugString();
  }
}
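As a concrete illustration of the min=opt=max strategy above (values are hypothetical): a collected shape vector {[2, 2, 10], [2, 2, 10]} for a two-input network yields one profile config whose three ranges coincide:

// Sketch of the config InitProfiles builds for one collected shape vector;
// Dims3 converts to nvinfer1::Dims, and min == opt == max by construction.
std::vector<nvinfer1::Dims> dimvec{nvinfer1::Dims3(2, 2, 10),
                                   nvinfer1::Dims3(2, 2, 10)};
OptimizationProfileConfig config{/*min=*/dimvec, /*opt=*/dimvec,
                                 /*max=*/dimvec};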
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
Status TrtShapeOptimizationProfile::AddProfiles(
    nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
    const nvinfer1::INetworkDefinition* network) {
  // Create a vector of optimization profiles.
  for (int i = 0; i < profiles_.size(); i++) {
    auto* optProfile = builder->createOptimizationProfile();
    Status status = profiles_[i].SetDimensions(network, optProfile);
    if (!status.ok()) {
      return status;
    }
    int idx = -1;
    if (optProfile->isValid()) {
      idx = config->addOptimizationProfile(optProfile);
    }
    if (idx >= 0) {
      if (i != idx) {
        return errors::Internal(
            "Profile index of engine config is different from resource profile "
            "index: ",
            i, " != ", idx);
      }
      VLOG(1) << "Added optimization profile " << profiles_[i].DebugString()
              << " to builder config.";
    } else {
      LOG(ERROR) << "Failed to add optimization profile "
                 << profiles_[i].DebugString()
                 << ". This usually happens when the profile is invalid.";
    }
  }
  if (config->getNbOptimizationProfiles() == 0) {
    return errors::Internal("Failure in adding an optimization profile.");
  }
  // If TRT_VERSION < 6, then we do not need to add profiles.
  return Status::OK();
}
#endif

#if IS_TRT_VERSION_GE(6, 0, 0, 0)
Status TrtShapeOptimizationProfile::ConfigureBuilder(
    nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
    const nvinfer1::INetworkDefinition* network) {
  TF_RETURN_IF_ERROR(AddProfiles(builder, config, network));
  return Status::OK();
}
#endif

int TrtShapeOptimizationProfile::GetProfileNumber(
    std::vector<TensorShape> shapes) {
  for (int i = 0; i < profiles_.size(); i++) {
    if (profiles_[i].IncludesShapes(shapes)) {
      return i;
    }
  }
  VLOG(1) << "Profile not found for input shapes " << DebugString(shapes)
          << ".";
  return -1;
}

Status TrtShapeOptimizationProfile::CreateExecutionContexts(
    nvinfer1::ICudaEngine* engine,
    std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>>& exec_context) {
  int i = 0;
  // The following loop runs once if we have static shapes, to create a single
  // execution context without profiles. In dynamic mode we create one context
  // for each profile and set the corresponding optimization profile.
  do {
    VLOG(1) << "Creating execution context " << i;
    nvinfer1::IExecutionContext* ctx = engine->createExecutionContext();
    if (ctx == nullptr) {
      return errors::Internal("Failed to create execution context");
    }
    if (i > 0) {
      // This condition is needed for two reasons:
      // - Using static shapes we do not have any profiles, so we cannot call
      //   setOptimizationProfile.
      // - The 0th profile is set implicitly for the first execution context,
      //   therefore we do not need to set it.
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
      bool stat = ctx->setOptimizationProfile(i);
      if (!stat) {
        ctx->destroy();
        return errors::Internal("Could not set TRT optimization profile.");
      }
#endif
    }
    exec_context.push_back(TrtUniquePtrType<nvinfer1::IExecutionContext>(ctx));
    i++;
  } while (i < profiles_.size());

  return Status::OK();
}

Status TrtShapeOptimizationProfile::RestoreProfiles(
    const nvinfer1::ICudaEngine* engine) {
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
  if (!engine) {
    // We do not need to restore profiles for an empty engine.
    return Status::OK();
  }
#if IS_TRT_VERSION_GE(7, 0, 0, 0)
  if (engine->hasImplicitBatchDimension()) {
    // Nothing to do, we cannot have profiles in implicit batch mode.
    return Status::OK();
  }
#endif
  int n_profiles = engine->getNbOptimizationProfiles();
  int n_inputs = GetNumberOfEngineInputs(engine);
  VLOG(2) << "Attempting to restore " << n_profiles << " profiles, each with "
          << n_inputs << " inputs";
  for (int prof_idx = 0; prof_idx < n_profiles; prof_idx++) {
    OptimizationProfileConfig cfg;
    for (int j = 0; j < n_inputs; j++) {
      nvinfer1::Dims min = engine->getProfileDimensions(
          j, prof_idx, nvinfer1::OptProfileSelector::kMIN);
      nvinfer1::Dims max = engine->getProfileDimensions(
          j, prof_idx, nvinfer1::OptProfileSelector::kMAX);
      nvinfer1::Dims opt = engine->getProfileDimensions(
          j, prof_idx, nvinfer1::OptProfileSelector::kOPT);
      cfg.min.push_back(min);
      cfg.max.push_back(max);
      cfg.opt.push_back(opt);
    }
    VLOG(2) << "Restored profile " << cfg.DebugString();
    profiles_.push_back(std::move(cfg));
  }
#endif
  return Status::OK();
}

int TrtShapeOptimizationProfile::GetNumProfiles() const {
  return profiles_.size();
}

}  // namespace tensorrt
}  // namespace tensorflow
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
@ -0,0 +1,178 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_

#include <list>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT

#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {

// Stores optimization profile parameters (min/opt/max of each input shape).
//
// A TensorRT optimization profile describes the possible min/max values of
// each dynamic input shape along with an optimum value. These values are used
// by the TensorRT builder to select the best kernel for the optimum value
// among those kernels that are valid for all input tensors in the [min, max]
// range.
struct OptimizationProfileConfig {
  // Length of vector == num_inputs to engine.
  std::vector<nvinfer1::Dims> min;
  std::vector<nvinfer1::Dims> opt;
  std::vector<nvinfer1::Dims> max;

  string DebugString() const {
    using absl::StrCat;
    return StrCat("[min: ", tensorflow::tensorrt::DebugString(min),
                  ", opt: ", tensorflow::tensorrt::DebugString(opt),
                  ", max: ", tensorflow::tensorrt::DebugString(max), "]");
  }

#if IS_TRT_VERSION_GE(6, 0, 0, 0)
  // Sets the stored min/opt/max dimensions for profile.
  //
  // Parameters:
  // network - TensorRT network, used to enumerate all the input tensors
  // profile - on exit the profile information will be set for each input
  //           tensor
  Status SetDimensions(const nvinfer1::INetworkDefinition* network,
                       nvinfer1::IOptimizationProfile* profile) const {
    int n_inputs = network->getNbInputs();
    if (min.size() != n_inputs || opt.size() != n_inputs ||
        max.size() != n_inputs) {
      return errors::Internal("Incorrect number of profile config parameters");
    }
    for (int i = 0; i < n_inputs; i++) {
      const char* name = network->getInput(i)->getName();
      profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, min[i]);
      profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, opt[i]);
      profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, max[i]);
    }
    return Status::OK();
  }
#endif

  // Returns true if profile range completely includes the given shapes.
  bool IncludesShapes(const std::vector<TensorShape>& shapes) const {
    // min, max, and opt must have the same size, which is already verified in
    // SetDimensions.
    if (min.size() != shapes.size()) {
      return false;
    }
    for (int i = 0; i < shapes.size(); i++) {
      auto current_shape = shapes[i];
      // min, max, and opt must have the same nbDims, which is already
      // verified in SetDimensions.
      if (min[i].nbDims != current_shape.dims()) {
        return false;
      }
      // Check if range [min, max] includes current_shape.
      for (int dim = 0; dim < current_shape.dims(); dim++) {
        if ((min[i].d[dim] > current_shape.dim_size(dim)) ||
            (max[i].d[dim] < current_shape.dim_size(dim))) {
          return false;
        }
      }
    }
    return true;
  }
};

// Manages optimization profiles during TRT engine construction.
//
// An optimization profile describes a range of dimensions for each TRT
// network input, and the optimal dimensions that the auto-tuner should use
// for optimization.
//
// This class stores the list of input shapes that were seen during the
// build/profile_generation_mode phase, and using them it creates a set of
// OptimizationProfileConfigs. These configs will be added to IBuilderConfig
// before the engine is created.
class TrtShapeOptimizationProfile {
 public:
  TrtShapeOptimizationProfile() {}

  // Stores input shape information during profile_generation_mode.
  void AddShape(std::vector<TensorShape> shapes) {
    input_shapes_.insert(shapes);
    VLOG(1) << "Collected shape(s) " << DebugString(shapes) << " for profiles.";
  }

  void clear() { profiles_.clear(); }

  // Returns the profile number that should be used to execute the network
  // with the given input shapes. Returns -1 if none of the cached profiles
  // are compatible with the given input shapes.
  int GetProfileNumber(std::vector<TensorShape> shapes);

#if IS_TRT_VERSION_GE(6, 0, 0, 0)
  // Creates optimization profiles and adds them to the builder config.
  Status ConfigureBuilder(nvinfer1::IBuilder* builder,
                          nvinfer1::IBuilderConfig* config,
                          const nvinfer1::INetworkDefinition* network);
#endif

  // Creates execution contexts for each optimization profile.
  Status CreateExecutionContexts(
      nvinfer1::ICudaEngine* engine,
      std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>>& exec_context);

  // Maps input vector shapes to TRT optimization profiles (min, max, opt),
  // i.e. maps input_shapes_ to profiles_.
  void InitProfiles();

  // Returns the number of created profiles.
  int GetNumProfiles() const;

  // Restores profiles from the engine (used after deserialization).
  Status RestoreProfiles(const nvinfer1::ICudaEngine* engine);

 private:
  // Set of input shape vectors that we collect during
  // profile_generation_mode.
  std::unordered_set<std::vector<TensorShape>, VectorTensorShapeHasher>
      input_shapes_;

  // The optimization profiles generated from input_shapes_.
  std::vector<OptimizationProfileConfig> profiles_;

#if IS_TRT_VERSION_GE(6, 0, 0, 0)
  // Adds optimization profiles to the builder config.
  Status AddProfiles(nvinfer1::IBuilder* builder,
                     nvinfer1::IBuilderConfig* config,
                     const nvinfer1::INetworkDefinition* network);
#endif
};

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA
#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_
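Putting the class together, the intended build-time flow is the following sketch, which mirrors the unit test below (the builder, builder_config, network, engine, and collected_input_shapes setup are assumed):

// Build-time workflow for TrtShapeOptimizationProfile, as a sketch.
TrtShapeOptimizationProfile profile;
for (const std::vector<TensorShape>& shapes : collected_input_shapes) {
  profile.AddShape(shapes);  // gathered during profile generation mode
}
profile.InitProfiles();      // one min=opt=max profile per shape vector
TF_RETURN_IF_ERROR(
    profile.ConfigureBuilder(builder, builder_config, network));
// ... build the engine with builder_config, then create one execution
// context per profile:
std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>> contexts;
TF_RETURN_IF_ERROR(profile.CreateExecutionContexts(engine, contexts));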
@ -0,0 +1,218 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT

#include <string.h>

#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/test.h"
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {

std::vector<TensorShape> DimVecToShapeVec(std::vector<nvinfer1::Dims3> dimvec) {
  std::vector<TensorShape> shapevec(dimvec.size());
  for (int i = 0; i < dimvec.size(); i++) {
    TensorShape shape;
    TF_CHECK_OK(
        TensorShapeUtils::MakeShape(dimvec[i].d, dimvec[i].nbDims, &shape));
    shapevec[i] = shape;
  }
  return shapevec;
}

bool DimsContained(const nvinfer1::Dims& dim, const nvinfer1::Dims& min,
                   const nvinfer1::Dims& max) {
  if (dim.nbDims != min.nbDims || dim.nbDims != max.nbDims) {
    return false;
  }
  for (int i = 0; i < dim.nbDims; i++) {
    if (dim.d[i] < min.d[i] || dim.d[i] > max.d[i]) {
      return false;
    }
  }
  return true;
}

bool DimsEqual(const nvinfer1::Dims& a, const nvinfer1::Dims& b) {
  if (a.nbDims != b.nbDims) {
    return false;
  }
  for (int i = 0; i < a.nbDims; i++) {
    if (a.d[i] != b.d[i]) {
      return false;
    }
  }
  return true;
}

class TrtShapeOptimizationProfileTest : public ::testing::Test {
 protected:
  void SetUp() override {
    builder_ = TrtUniquePtrType<nvinfer1::IBuilder>(
        nvinfer1::createInferBuilder(logger_));
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
    network_ = TrtUniquePtrType<nvinfer1::INetworkDefinition>(
        builder_->createNetworkV2(flags_));
    builder_config_ = TrtUniquePtrType<nvinfer1::IBuilderConfig>(
        builder_->createBuilderConfig());
    builder_config_->setMaxWorkspaceSize(1 << 10);
#else
    network_ = TrtUniquePtrType<nvinfer1::INetworkDefinition>(
        builder_->createNetwork());
    builder_->setMaxWorkspaceSize(1 << 10);
#endif
  }

  // Defines a simple network: output = input1 + input2.
  void DefineNetwork(nvinfer1::INetworkDefinition* network,
                     nvinfer1::Dims3& dims) {
    nvinfer1::ITensor* input1 =
        network->addInput("input1", nvinfer1::DataType::kFLOAT, dims);
    EXPECT_NE(nullptr, input1);

    nvinfer1::ITensor* input2 =
        network->addInput("input2", nvinfer1::DataType::kFLOAT, dims);
    EXPECT_NE(nullptr, input2);

    auto layer = network->addElementWise(*input1, *input2,
                                         nvinfer1::ElementWiseOperation::kSUM);
    EXPECT_NE(nullptr, layer);
    // Mark the output.
    nvinfer1::ITensor* output = layer->getOutput(0);
    output->setName("output");
    network->markOutput(*output);
  }

  Logger logger_;
  TrtUniquePtrType<nvinfer1::IBuilder> builder_;
  TrtUniquePtrType<nvinfer1::INetworkDefinition> network_;
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
  TrtUniquePtrType<nvinfer1::IBuilderConfig> builder_config_;
#endif
  TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
  std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>> exec_context_;
  // The order is important: exec_context_ must be destroyed first, and the
  // logger last.
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
  const uint32_t flags_ =
      1U << static_cast<int>(
          nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
#endif
};

TEST_F(TrtShapeOptimizationProfileTest, Static) {
  // Network with static input shape.
  nvinfer1::Dims3 dims(8, 8, 10);
  DefineNetwork(network_.get(), dims);

  TrtShapeOptimizationProfile profile;

#if IS_TRT_VERSION_GE(6, 0, 0, 0)
  // Configure and build the engine; this should be a no-op.
  TF_CHECK_OK(profile.ConfigureBuilder(builder_.get(), builder_config_.get(),
                                       network_.get()));

  engine = TrtUniquePtrType<nvinfer1::ICudaEngine>(
      builder_->buildEngineWithConfig(*network_, *builder_config_));
#else
  engine = TrtUniquePtrType<nvinfer1::ICudaEngine>(
      builder_->buildCudaEngine(*network_));
#endif
  EXPECT_NE(nullptr, engine);
  TF_CHECK_OK(profile.CreateExecutionContexts(engine.get(), exec_context_));
  // A single execution context should be created for a graph with static
  // input.
  ASSERT_EQ(exec_context_.size(), 1);
  EXPECT_NE(nullptr, exec_context_[0]);

  std::vector<nvinfer1::Dims3> dim_vec(2, dims);
  std::vector<TensorShape> shape_vec = DimVecToShapeVec(dim_vec);
  EXPECT_EQ(-1, profile.GetProfileNumber(shape_vec));
}

#if IS_TRT_VERSION_GE(6, 0, 0, 0)
TEST_F(TrtShapeOptimizationProfileTest, Dynamic) {
  // Network with dynamic input shapes.
  nvinfer1::Dims3 dims(-1, -1, 10);
  DefineNetwork(network_.get(), dims);

  TrtShapeOptimizationProfile profile;
  std::vector<std::vector<nvinfer1::Dims3>> input_profiles{
      {nvinfer1::Dims3(2, 2, 10), nvinfer1::Dims3(2, 2, 10)},
      {nvinfer1::Dims3(3, 3, 10), nvinfer1::Dims3(3, 3, 10)},
      {nvinfer1::Dims3(16, 16, 10), nvinfer1::Dims3(16, 16, 10)},
  };

  // Simulate a profile collection phase.
  for (auto dim_vec : input_profiles) {
    std::vector<TensorShape> shape_vec = DimVecToShapeVec(dim_vec);
    profile.AddShape(shape_vec);
  }
  profile.InitProfiles();

  // Configure and build the engine.
  TF_CHECK_OK(profile.ConfigureBuilder(builder_.get(), builder_config_.get(),
                                       network_.get()));
  engine = TrtUniquePtrType<nvinfer1::ICudaEngine>(
      builder_->buildEngineWithConfig(*network_.get(), *builder_config_.get()));
  ASSERT_NE(nullptr, engine);

  TF_CHECK_OK(profile.CreateExecutionContexts(engine.get(), exec_context_));

  // Each profile has an associated execution context.
  EXPECT_EQ(exec_context_.size(), input_profiles.size());

  // Check if the profiles are assigned correctly.
  for (auto dimvec : input_profiles) {
    std::vector<TensorShape> shape_vec = DimVecToShapeVec(dimvec);
    int idx = profile.GetProfileNumber(shape_vec);
    int prof_idx = exec_context_[idx]->getOptimizationProfile();
    ASSERT_GE(prof_idx, 0);

    for (int j = 0; j < dimvec.size(); j++) {
      nvinfer1::Dims min = engine->getProfileDimensions(
          j, prof_idx, nvinfer1::OptProfileSelector::kMIN);
      nvinfer1::Dims max = engine->getProfileDimensions(
          j, prof_idx, nvinfer1::OptProfileSelector::kMAX);
      nvinfer1::Dims opt = engine->getProfileDimensions(
          j, prof_idx, nvinfer1::OptProfileSelector::kOPT);

      // This should always hold.
      EXPECT_TRUE(DimsContained(dimvec[j], min, max));

      // The following test depends on the profile creation strategy, and
      // needs to be updated (disabled) if the default strategy (defined by
      // InitProfiles) changes.
      EXPECT_TRUE(DimsEqual(dimvec[j], opt));
    }
  }
}
#endif

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA
@ -133,7 +133,7 @@ Status GraphCompiler::Compile() {
      OpKernel* op_kernel_raw = nullptr;
      // The kernel is not actually run for functional ops, we just need it
      // for metadata.
      Status s = flib_->CreateKernel(n->def(), &op_kernel_raw);
      Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw);
      // Transfer ownership of the kernel to a local smart pointer.
      std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);
@ -55,7 +55,6 @@ tf_kernel_library(
        "index_ops.cc",
        "l2loss_op.cc",
        "listdiff_op.cc",
        "lower_upper_bound_ops.cc",
        "lrn_ops.cc",
        "matmul_op.cc",
        "matrix_band_part_op.cc",
@ -150,7 +149,6 @@ tf_kernel_library(
        "//tensorflow/compiler/tf2xla/lib:util",
        "//tensorflow/compiler/tf2xla/ops:xla_ops",
        "//tensorflow/compiler/xla:array4d",
        "//tensorflow/compiler/xla:comparison_util",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
@ -264,6 +264,23 @@ xla::XlaOp IgammaImpl(xla::XlaOp x, xla::XlaOp y,

XLA_MAKE_BINARY(Igamma, IgammaImpl(lhs, rhs, broadcast_helper));

xla::XlaOp IgammaGradAImpl(xla::XlaOp x, xla::XlaOp y,
                           const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  return xla::IgammaGradA(x, y);
}

XLA_MAKE_BINARY(IgammaGradA, IgammaGradAImpl(lhs, rhs, broadcast_helper));

xla::XlaOp RandomGammaGradImpl(xla::XlaOp x, xla::XlaOp y,
                               const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
  return xla::RandomGammaGrad(x, y);
}

XLA_MAKE_BINARY(RandomGammaGrad,
                RandomGammaGradImpl(lhs, rhs, broadcast_helper));

xla::XlaOp IgammacImpl(xla::XlaOp x, xla::XlaOp y,
                       const BCast& broadcast_helper) {
  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
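For reference, a sketch of the math behind the two new ops (stated here as an assumption, not quoted from the source): IgammaGradA is the derivative of the regularized lower incomplete gamma function with respect to its first argument, and RandomGammaGrad is the implicit-reparameterization gradient of a Gamma(a, 1) sample x:

\frac{\partial}{\partial a} P(a, x) \quad\text{(IgammaGradA)}, \qquad
\frac{dx}{da} = -\,\frac{\partial P(a, x)/\partial a}{\partial P(a, x)/\partial x} \quad\text{(RandomGammaGrad)}

Both quantities are evaluated by the series and continued-fraction code in the math client library changes further below.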
@ -1,116 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {

// Builds a LowerBound or UpperBound op, the distinction lying in
// comparison_direction: GT => LowerBoundOp, GE => UpperBoundOp.
// Note that this is an O(MN) algorithm: all entries in each sorted_inputs row
// are considered, and their sorted nature is not fully exploited.
void BuildLowerUpperBoundOp(XlaOpKernelContext* ctx, DataType out_dtype,
                            xla::ComparisonDirection comparison_direction) {
  const TensorShape sorted_inputs_shape = ctx->InputShape("sorted_inputs");
  const TensorShape values_shape = ctx->InputShape("values");
  const xla::XlaOp sorted_inputs = ctx->Input("sorted_inputs");
  const xla::XlaOp values = ctx->Input("values");

  // We are assuming both inputs are 2D, which they will be given the current
  // implementation of tf.searchsorted.
  OP_REQUIRES(ctx, sorted_inputs_shape.dims() == 2,
              errors::FailedPrecondition("sorted_inputs must be 2D"));
  OP_REQUIRES(ctx, values_shape.dims() == 2,
              errors::FailedPrecondition("values must be 2D"));

  // Add a new inner dimension to values, to allow broadcasting along the inner
  // dimension of sorted_sequence.
  auto new_values_shape = values_shape;
  new_values_shape.InsertDim(/* d */ 2, /* size */ 1);
  auto values_reshaped = xla::Reshape(values, new_values_shape.dim_sizes());

  // Add a new penultimate dimension to sorted_inputs, to allow broadcasting of
  // sorted_sequence entries for each value.
  auto new_sorted_inputs_shape = sorted_inputs_shape;
  new_sorted_inputs_shape.InsertDim(/* d */ 1, /* size */ 1);
  auto sorted_inputs_reshaped =
      xla::Reshape(sorted_inputs, new_sorted_inputs_shape.dim_sizes());

  // We are relying on broadcasting to compare each value against each entry in
  // the associated sorted_inputs row.
  // The reshapes above leave the tensors with equal rank of 3, so broadcast
  // dimensions are not explicitly specified.
  auto comparison = xla::Compare(values_reshaped, sorted_inputs_reshaped, {},
                                 comparison_direction);

  const DataType accumulation_type = XlaHelpers::SumAccumulationType(out_dtype);

  // Convert boolean comparison results to integers so we can sum them.
  auto comparison_int =
      XlaHelpers::ConvertElementType(comparison, accumulation_type);

  // Sum the comparison results over the inner dimension to find the index for
  // each value.
  xla::XlaBuilder* builder = ctx->builder();
  auto reduced =
      xla::Reduce(comparison_int, XlaHelpers::Zero(builder, accumulation_type),
                  *ctx->GetOrCreateAdd(accumulation_type), {2});

  ctx->SetOutput(0, reduced);
}

class LowerBoundOp : public XlaOpKernel {
 public:
  explicit LowerBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGt);
  }

 private:
  DataType out_dtype_;
};

REGISTER_XLA_OP(Name("LowerBound"), LowerBoundOp);

class UpperBoundOp : public XlaOpKernel {
 public:
  explicit UpperBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGe);
  }

 private:
  DataType out_dtype_;
};

REGISTER_XLA_OP(Name("UpperBound"), UpperBoundOp);

}  // namespace
}  // namespace tensorflow
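Although the file above is being removed, its counting trick is easy to state: the returned index equals the number of sorted entries the value compares true against (GT for LowerBound, GE for UpperBound). A minimal scalar sketch of that semantics (a hypothetical helper in plain C++, not part of the removed kernel):

#include <vector>

// Scalar illustration of the broadcast-compare-reduce trick: count the
// entries the value compares true against. E.g. for row {1, 3, 5} and
// value 3 this returns 1 (lower bound) or 2 (upper bound).
int CountingSearch(const std::vector<float>& sorted_row, float value,
                   bool lower_bound) {
  int index = 0;
  for (float entry : sorted_row) {
    if (lower_bound ? value > entry : value >= entry) ++index;
  }
  return index;
}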
@ -32,6 +32,8 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/pooling_ops_common.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {
@ -157,6 +159,13 @@ class MaxPoolOp : public PoolingOp {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(
        ctx,
        data_format_ != FORMAT_NCHW_VECT_C &&
            data_format_ != FORMAT_NHWC_VECT_W,
        errors::Unimplemented("XLA does not support the VECT_* data formats. "
                              "Returning unimplemented from MaxPool to keep "
                              "Tensorflow's intended optimized MaxPool here."));
  }

  void Compile(XlaOpKernelContext* ctx) override {
@ -34,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops

@ -200,6 +201,8 @@ shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)

igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)
@ -26,22 +26,6 @@ const char kShardingAttribute[] = "_XlaSharding";
}  // namespace

namespace {
xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
    const NodeDef& node_def) {
  if (!HasNodeAttr(node_def, kShardingAttribute)) {
    return absl::optional<xla::OpSharding>();
  }
  string value;
  xla::OpSharding sharding;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
  if (!sharding.ParseFromString(value)) {
    return xla::InvalidArgument(
        "Experimental _XlaSharding attribute was not a valid encoded "
        "xla::OpSharding proto.");
  }
  return absl::optional<xla::OpSharding>(sharding);
}

Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
  return errors::InvalidArgument(
      "Invalid replicated core id: ", core,
@ -107,4 +91,19 @@ void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
  }
}

xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
    const NodeDef& node_def) {
  if (!HasNodeAttr(node_def, kShardingAttribute)) {
    return absl::optional<xla::OpSharding>();
  }
  string value;
  xla::OpSharding sharding;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
  if (!sharding.ParseFromString(value)) {
    return xla::InvalidArgument(
        "Experimental _XlaSharding attribute was not a valid encoded "
        "xla::OpSharding proto.");
  }
  return absl::optional<xla::OpSharding>(sharding);
}
}  // namespace tensorflow
@ -45,6 +45,10 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(

void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);

// Gets sharding information from a node.
xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
    const NodeDef& node_def);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_
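A sketch of a call site for the newly exported helper (the surrounding code is hypothetical; an empty optional means the node carries no sharding attribute):

// Read the experimental _XlaSharding attribute from a node's NodeDef.
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
                    GetShardingFromNodeDef(node->def()));
if (sharding) {
  // Use *sharding to derive the per-core assignment here.
}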
@ -97,6 +97,7 @@ xla::StatusOr<DataType> EncodePrimitiveTypeAsDataType(xla::PrimitiveType type) {
          {xla::U16, DT_UINT16},
          {xla::U32, DT_UINT32},
          {xla::U64, DT_UINT64},
          {xla::C128, DT_COMPLEX128},
      });

  auto it = data_type_map.find(type);
@ -139,6 +139,86 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
  return Status::OK();
}

// Rewrites the layout of xla_shape if there is tiled sharding.
Status RewriteLayoutWithShardedShape(
    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    xla::Shape* xla_shape) {
  if (sharding && !sharding->IsTileMaximal()) {
    // After sharding, per core shape might have different layout. For example,
    // before sharding, a shape [128, 128] will be assigned default
    // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
    // the sharded shapes will have minor-to-major {0, 1}.
    //
    // As a result, for sharded shapes, we set their layout to per core shape's
    // layout.
    //
    // TODO(endlessroad): for variable input & update, we might have
    // different layouts which will prevent input output aliasing and
    // increase memory usage. Investigate such cases.
    int64 device = *sharding->tile_assignment().begin();
    std::vector<int64> offset =
        sharding->TileOffsetForDevice(*xla_shape, device);
    std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device);
    std::vector<int64> dimensions(xla_shape->rank());
    for (int64 i = 0; i < xla_shape->rank(); ++i) {
      dimensions[i] = limit[i] - offset[i];
    }
    xla::Shape per_device_xla_shape =
        xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
    TensorShape per_device_tensor_shape;
    TF_RETURN_IF_ERROR(
        XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                            xla_shape->element_type()));
    TF_ASSIGN_OR_RETURN(per_device_xla_shape,
                        shape_representation_fn(per_device_tensor_shape, dtype,
                                                use_fast_memory));
    *xla_shape->mutable_layout() = per_device_xla_shape.layout();
  }
  return Status::OK();
}
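A quick worked example of the per-device shape computation above (values are hypothetical): a [128, 128] parameter split two ways along dimension 1 gives, for device 0, offset {0, 0} and limit {128, 64}, hence a per-device shape of [128, 64]:

#include <cstdint>
#include <vector>

// dimensions[i] = limit[i] - offset[i], as in RewriteLayoutWithShardedShape.
std::vector<int64_t> PerDeviceDims(const std::vector<int64_t>& offset,
                                   const std::vector<int64_t>& limit) {
  std::vector<int64_t> dims(limit.size());
  for (size_t i = 0; i < limit.size(); ++i) dims[i] = limit[i] - offset[i];
  return dims;  // {128, 64} for the example above
}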
// If there is a shape_representation_fn or sharding for an output, this
// function uses a reshape to fix the layout.
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    absl::optional<xla::OpSharding> sharding, bool fast_mem) {
  if (original_shape.IsTuple()) {
    std::vector<xla::XlaOp> elements;
    for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
      auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
      TF_ASSIGN_OR_RETURN(auto element,
                          ReshapeWithCorrectRepresentationAndSharding(
                              builder, xla::GetTupleElement(original, i),
                              original_shape.tuple_shapes(i),
                              shape_representation_fn, subsharding, fast_mem));
      elements.push_back(element);
    }
    return xla::Tuple(builder, elements);
  }
  if (!original_shape.IsArray()) return original;
  TensorShape shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
  TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                          original_shape.element_type()));
  TF_ASSIGN_OR_RETURN(auto to_shape,
                      shape_representation_fn(shape, dtype, fast_mem));
  if (sharding) {
    TF_ASSIGN_OR_RETURN(auto hlo_sharding,
                        xla::HloSharding::FromProto(*sharding));
    TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
        hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
  }
  if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
    for (int64 i = 0; i < original_shape.rank(); ++i) {
      to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
    }
  }
  return xla::Reshape(to_shape, original);
}

// Builds the XLA computation.
// - `args` is the list of input arguments
// - `retvals` is the list of retvals produced by _Retval operators, in index
@ -188,10 +268,6 @@ Status BuildComputation(
  std::vector<xla::XlaOp> elems;
  elems.reserve(retvals.size());

  // Keeps track of the layout of each retval. If a retval is not in this list,
  // a descending layout is used. The first element is the output index, second
  // element is the new layout.
  std::vector<std::pair<int64, xla::Layout>> retval_index_and_layout;
  // Keeps track of sharding of each retval. If a retval is not in this list,
  // replicate sharding is used. The first element is the output index, second
  // element is the sharding.
@ -219,22 +295,22 @@ Status BuildComputation(
    TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
    xla::XlaOp value = retval.handle();
    auto it = retval_shardings.find(i);
    xla::XlaScopedShardingAssignment assign_sharding(
        builder, it == retval_shardings.end()
                     ? absl::optional<xla::OpSharding>()
                     : it->second);
    absl::optional<xla::OpSharding> sharding =
        it == retval_shardings.end() ? absl::optional<xla::OpSharding>()
                                     : it->second;
    if (it != retval_shardings.end()) {
      retval_index_and_sharding[elems.size()] = it->second;
    }
    if (shape_representation_fn) {
      // If there is a shape representation function, reshape the output
      // tensor to the shape given by the representation shape function.
      TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
                                                output.shape, output.type,
                                                /*use_fast_memory=*/false));
      value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
      retval_index_and_layout.emplace_back(elems.size(), shape.layout());
    } else if (it != retval_shardings.end()) {
      TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(value));
      TF_ASSIGN_OR_RETURN(value,
                          ReshapeWithCorrectRepresentationAndSharding(
                              builder, value, original_shape,
                              shape_representation_fn, sharding,
                              /*fast_mem=*/false));
    }
    if (it != retval_shardings.end()) {
      xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
      // Apply the sharding to the output, if there is a core assignment.
      value = identity_op(value);
    }
@ -312,43 +388,27 @@ Status BuildComputation(
        update.tensor_array_gradients_accessed.insert(grad.first);
      }

      xla::XlaOp handle;
      TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
      auto sharding = it == arg_shardings.end()
                          ? absl::optional<xla::OpSharding>()
                          : it->second;
      // Set layout of the retval to device representation layout.
      if (shape_representation_fn) {
        TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(handle));
        TF_ASSIGN_OR_RETURN(
            handle, ReshapeWithCorrectRepresentationAndSharding(
                        builder, handle, original_shape,
                        shape_representation_fn, sharding, arg.fast_mem));
      }

      // Request that the value be returned on a specific core.
      xla::XlaScopedShardingAssignment assign_sharding(
          builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()
                                             : it->second);
      xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
      if (it != arg_shardings.end()) {
        retval_index_and_sharding[elems.size()] = it->second;
      }

      xla::XlaOp handle;
      TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));

      // Ensures the correct sharding is applied to the output.
      handle = identity_op(handle);

      // Set layout of the retval to device representation layout.
      absl::optional<xla::Shape> representation_shape;
      if (shape_representation_fn) {
        TF_ASSIGN_OR_RETURN(
            xla::Shape xla_shape,
            shape_representation_fn(resource->shape(), resource->type(),
                                    /*use_fast_memory=*/false));
        representation_shape = xla_shape;
      }
      if (resource->representation_shape().has_value()) {
        const xla::Shape& xla_shape = resource->representation_shape().value();
        if (representation_shape) {
          TF_RET_CHECK(
              xla::ShapeUtil::Compatible(*representation_shape, xla_shape));
        } else {
          representation_shape = xla_shape;
        }
      }
      if (representation_shape) {
        retval_index_and_layout.emplace_back(elems.size(),
                                             representation_shape->layout());
      }

      elems.push_back(handle);
    }
  }
@ -411,20 +471,8 @@ Status BuildComputation(
  }
  *computation = computation_status.ConsumeValueOrDie();

  TF_ASSIGN_OR_RETURN(const auto& program_shape,
                      computation->GetProgramShape());
  TF_ASSIGN_OR_RETURN(auto program_shape, computation->GetProgramShape());
  *output_shape = program_shape.result();
  // Update the output layout to the layout of retval.
  for (auto& index_and_layout : retval_index_and_layout) {
    if (!always_return_tuple && elems.size() == 1) {
      *output_shape->mutable_layout() = index_and_layout.second;
      continue;
    }

    xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape(
        output_shape, {index_and_layout.first});
    *output_sub_shape->mutable_layout() = index_and_layout.second;
  }
  return Status::OK();
}

@ -779,47 +827,6 @@ Status XlaCompiler::XLAShapeForArgument(
    const XlaCompiler::Argument& arg, bool is_entry_computation,
    const absl::optional<xla::HloSharding>& arg_sharding,
    xla::Shape* xla_shape) const {
  auto rewrite_layout_with_sharded_shape =
      [](const absl::optional<xla::HloSharding>& arg_sharding,
         bool use_fast_memory,
         XlaCompiler::ShapeRepresentationFn shape_representation_fn,
         xla::Shape* xla_shape) {
        if (arg_sharding && !arg_sharding->IsTileMaximal()) {
          // After parameter sharding, per core parameter might have different
          // layout. For example, before sharding, a parameter of shape [128,
          // 128] will be assigned default minor-to-major {1, 0}. But after we
          // shard this parameter to [128, 64] * 2, the sharded parameters
          // will have minor-to-major {0, 1}.
          //
          // As a result, for sharded parameters, we set their layout to per
          // core parameter's layout.
          //
          // TODO(endlessroad): for variable input & update, we might have
          // different layouts which will prevent input output aliasing and
          // increase memory usage. Investigate such cases.
          int64 device = *arg_sharding->tile_assignment().begin();
          std::vector<int64> offset =
              arg_sharding->TileOffsetForDevice(*xla_shape, device);
          std::vector<int64> limit =
              arg_sharding->TileLimitForDevice(*xla_shape, device);
          std::vector<int64> dimensions(xla_shape->rank());
          for (int64 i = 0; i < xla_shape->rank(); ++i) {
            dimensions[i] = limit[i] - offset[i];
          }
          xla::Shape per_device_xla_shape =
              xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
          TensorShape per_device_tensor_shape;
          TF_RETURN_IF_ERROR(XLAShapeToTensorShape(per_device_xla_shape,
                                                   &per_device_tensor_shape));
          TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                                  xla_shape->element_type()));
          TF_ASSIGN_OR_RETURN(per_device_xla_shape,
                              shape_representation_fn(per_device_tensor_shape,
                                                      dtype, use_fast_memory));
          *xla_shape->mutable_layout() = per_device_xla_shape.layout();
        }
        return Status::OK();
      };
  switch (arg.kind) {
    case XlaCompiler::Argument::kConstant:
      LOG(FATAL) << "Unreachable case";
@ -835,7 +842,7 @@ Status XlaCompiler::XLAShapeForArgument(
        TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn(
                                            shape, arg.type,
                                            /*use_fast_memory=*/false));
        TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape(
        TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
            arg_sharding, /*use_fast_memory=*/false,
            options_.shape_representation_fn, xla_shape));
      } else {
@ -863,7 +870,7 @@ Status XlaCompiler::XLAShapeForArgument(
            options_.shape_representation_fn(
                absl::get<TensorShape>(arg.shape), arg.type,
                /*use_fast_memory=*/arg.fast_mem));
        TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape(
        TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
            arg_sharding, arg.fast_mem, options_.shape_representation_fn,
            xla_shape));
        return Status::OK();

@ -365,7 +365,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForFastMemVar) {
  compile_options.return_updated_values_for_all_resources = true;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, &result));
  EXPECT_EQ(fast_mem_arg_count, 1);
  // Count 2: one for argument, one for the return value.
  EXPECT_EQ(fast_mem_arg_count, 2);
}

// Tests that the compiler can correctly propagate the layout assigned by
@ -417,6 +418,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
  // Check that the return shapes are correctly transposed.
  EXPECT_EQ(result.xla_output_shape,
            xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
  EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
            xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
}

// The layout of resource variable shouldn't change after transpose
@ -1091,6 +1094,8 @@ TEST_F(XlaCompilerTest, ResultLayoutSingle) {
  EXPECT_TRUE(xla::ShapeUtil::Equal(
      result.xla_output_shape,
      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1})));
  EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
            result.xla_output_shape);
}

TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
@ -1131,6 +1136,8 @@ TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
  EXPECT_TRUE(xla::ShapeUtil::Equal(
      result.xla_output_shape,
      xla::ShapeUtil::MakeTupleShape({result_shape, result_shape})));
  EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
            result.xla_output_shape);
}

// Tests a simple graph that reads and writes a variable.
@ -693,7 +693,10 @@ XlaOp Digamma(XlaOp input) {
|
||||
|
||||
namespace {
|
||||
|
||||
enum kIgammaMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE };
|
||||
|
||||
// Helper function for computing Igamma using a power series.
|
||||
template <kIgammaMode mode>
|
||||
XlaOp IgammaSeries(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
|
||||
xla::PrimitiveType type) {
|
||||
// vals: (enabled, r, c, ans, x)
|
||||
@ -715,24 +718,60 @@ XlaOp IgammaSeries(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
|
||||
XlaOp c = vals[2];
|
||||
XlaOp ans = vals[3];
|
||||
XlaOp x = vals[4];
|
||||
XlaOp dc_da = vals[5];
|
||||
XlaOp dans_da = vals[6];
|
||||
|
||||
r = r + ScalarLike(r, 1);
|
||||
dc_da = dc_da * (x / r) + (ScalarLike(r, -1) * c * x) / (r * r);
|
||||
dans_da = dans_da + dc_da;
|
||||
c = c * (x / r);
|
||||
ans = ans + c;
|
||||
XlaOp conditional;
|
||||
if (mode == VALUE) {
|
||||
conditional = And(enabled, Gt(c / ans, Epsilon(builder, type)));
|
||||
} else {
|
||||
conditional =
|
||||
And(enabled, Gt(Abs(dc_da / dans_da), Epsilon(builder, type)));
|
||||
}
|
||||
|
||||
return std::vector<XlaOp>{
|
||||
And(enabled, Gt(c / ans, Epsilon(builder, type))),
|
||||
Select(enabled, r, vals[1]), Select(enabled, c, vals[2]),
|
||||
Select(enabled, ans, vals[3]), Select(enabled, x, vals[4])};
|
||||
conditional,
|
||||
Select(enabled, r, vals[1]),
|
||||
Select(enabled, c, vals[2]),
|
||||
Select(enabled, ans, vals[3]),
|
||||
Select(enabled, x, vals[4]),
|
||||
Select(enabled, dc_da, vals[5]),
|
||||
Select(enabled, dans_da, vals[6]),
|
||||
};
|
||||
};
|
||||
auto& b = *ax.builder();
|
||||
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
std::vector<XlaOp> vals = {enabled, a, FullLike(a, 1), FullLike(a, 1), x};
|
||||
std::vector<XlaOp> vals = {
|
||||
enabled, a, FullLike(a, 1), FullLike(a, 1), x, FullLike(a, 0),
|
||||
FullLike(a, 0),
|
||||
};
|
||||
|
||||
TF_ASSIGN_OR_RETURN(vals, WhileLoopHelper(cond, body, vals, "igamma", &b));
|
||||
XlaOp ans = vals[3];
|
||||
return (ans * ax) / a;
|
||||
XlaOp dans_da = vals[6];
|
||||
if (mode == VALUE) {
|
||||
return (ans * ax) / a;
|
||||
}
|
||||
|
||||
XlaOp dlogax_da = Log(x) - Digamma(a + ScalarLike(a, 1));
|
||||
|
||||
switch (mode) {
|
||||
case DERIVATIVE:
|
||||
return ax * (ans * dlogax_da + dans_da) / a;
|
||||
case SAMPLE_DERIVATIVE:
|
||||
default:
|
||||
return -(dans_da + ans * dlogax_da) * x / a;
|
||||
}
|
||||
});
|
||||
}
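
A minimal scalar sketch of the same forward-mode idea (plain C++ doubles; the function name, the 1e-16 tolerance, and the convergence assumption x < a + 1 are illustrative, not part of the patch). The loop carries dc/da next to c, so the derivative of the series accumulates through the same recurrence:

#include <cmath>
#include <cstdio>

// Sums the igamma power series and its derivative w.r.t. `a`.
// State mirrors the XLA loop: r = a, c = 1, ans = 1, dc_da = 0, dans_da = 0.
void IgammaSeriesWithGrad(double a, double x, double* ans_out,
                          double* dans_da_out) {
  double r = a, c = 1.0, ans = 1.0, dc_da = 0.0, dans_da = 0.0;
  while (c / ans > 1e-16) {
    r += 1.0;
    // Differentiating c' = c * x / r w.r.t. a (note dr/da = 1):
    dc_da = dc_da * (x / r) - (c * x) / (r * r);
    dans_da += dc_da;
    c *= x / r;
    ans += c;
  }
  *ans_out = ans;          // the Igamma value is then (ans * ax) / a
  *dans_da_out = dans_da;  // feeds the DERIVATIVE / SAMPLE_DERIVATIVE modes
}

int main() {
  double ans, dans_da;
  IgammaSeriesWithGrad(/*a=*/2.0, /*x=*/1.5, &ans, &dans_da);
  std::printf("ans=%.12g dans_da=%.12g\n", ans, dans_da);
}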

// Helper function for computing Igammac using a continued fraction.
template <kIgammaMode mode>
XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
                               xla::PrimitiveType type) {
  // vals: enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2
@ -754,6 +793,13 @@ XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
    XlaOp qkm1 = vals[7];
    XlaOp pkm2 = vals[8];
    XlaOp qkm2 = vals[9];

    XlaOp dpkm2_da = vals[10];
    XlaOp dqkm2_da = vals[11];
    XlaOp dpkm1_da = vals[12];
    XlaOp dqkm1_da = vals[13];
    XlaOp dans_da = vals[14];

    c = c + ScalarLike(c, 1);
    y = y + ScalarLike(y, 1);
    z = z + ScalarLike(z, 2);
@ -762,18 +808,46 @@ XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
    XlaOp qk = qkm1 * z - qkm2 * yc;
    XlaOp qk_is_nonzero = Ne(qk, ScalarLike(qk, 0));
    XlaOp r = pk / qk;

    t = Select(qk_is_nonzero, Abs((ans - r) / r), FullLike(t, 1));
    ans = Select(qk_is_nonzero, r, ans);

    XlaOp dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
    XlaOp dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;
    XlaOp dans_da_new =
        Select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da);
    XlaOp grad_conditional =
        Select(qk_is_nonzero, Abs(dans_da_new - dans_da), FullLike(dans_da, 1));

    pkm2 = pkm1;
    pkm1 = pk;
    qkm2 = qkm1;
    qkm1 = qk;

    dpkm2_da = dpkm1_da;
    dqkm2_da = dqkm1_da;
    dpkm1_da = dpk_da;
    dqkm1_da = dqk_da;

    XlaOp rescale = Gt(Abs(pk), Reciprocal(Epsilon(builder, type)));
    pkm2 = Select(rescale, pkm2 * Epsilon(builder, type), pkm2);
    pkm1 = Select(rescale, pkm1 * Epsilon(builder, type), pkm1);
    qkm2 = Select(rescale, qkm2 * Epsilon(builder, type), qkm2);
    qkm1 = Select(rescale, qkm1 * Epsilon(builder, type), qkm1);
    return std::vector<XlaOp>{And(enabled, Gt(t, Epsilon(builder, type))),

    dpkm2_da = Select(rescale, dpkm2_da * Epsilon(builder, type), dpkm2_da);
    dqkm2_da = Select(rescale, dqkm2_da * Epsilon(builder, type), dqkm2_da);
    dpkm1_da = Select(rescale, dpkm1_da * Epsilon(builder, type), dpkm1_da);
    dqkm1_da = Select(rescale, dqkm1_da * Epsilon(builder, type), dqkm1_da);

    XlaOp conditional;
    if (mode == VALUE) {
      conditional = And(enabled, Gt(t, Epsilon(builder, type)));
    } else {
      conditional = And(enabled, Gt(grad_conditional, Epsilon(builder, type)));
    }

    return std::vector<XlaOp>{conditional,
                              Select(enabled, ans, vals[1]),
                              Select(enabled, t, vals[2]),
                              Select(enabled, y, vals[3]),
@ -782,7 +856,12 @@ XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
                              Select(enabled, pkm1, vals[6]),
                              Select(enabled, qkm1, vals[7]),
                              Select(enabled, pkm2, vals[8]),
                              Select(enabled, qkm2, vals[9])};
                              Select(enabled, qkm2, vals[9]),
                              Select(enabled, dpkm2_da, vals[10]),
                              Select(enabled, dqkm2_da, vals[11]),
                              Select(enabled, dpkm1_da, vals[12]),
                              Select(enabled, dqkm1_da, vals[13]),
                              Select(enabled, dans_da_new, vals[14])};
  };

  auto& b = *ax.builder();
@ -796,11 +875,31 @@ XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
    XlaOp qkm1 = z * x;
    XlaOp ans = pkm1 / qkm1;
    XlaOp t = FullLike(x, 1);
    std::vector<XlaOp> vals = {enabled, ans, t, y, z,
                               c, pkm1, qkm1, pkm2, qkm2};
    XlaOp dpkm2_da = FullLike(x, 0);
    XlaOp dqkm2_da = FullLike(x, 0);
    XlaOp dpkm1_da = FullLike(x, 0);
    XlaOp dqkm1_da = -x;
    XlaOp dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
    std::vector<XlaOp> vals = {enabled, ans, t, y, z,
                               c, pkm1, qkm1, pkm2, qkm2,
                               dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da};

    TF_ASSIGN_OR_RETURN(vals, WhileLoopHelper(cond, body, vals, "igammac", &b));
    ans = vals[1];
    return ans * ax;
    if (mode == VALUE) {
      return ans * ax;
    }

    dans_da = vals[14];
    XlaOp dlogax_da = Log(x) - Digamma(a);

    switch (mode) {
      case DERIVATIVE:
        return ax * (ans * dlogax_da + dans_da);
      case SAMPLE_DERIVATIVE:
      default:
        return -(dans_da + ans * dlogax_da) * x;
    }
  });
}
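
For reference, a scalar sketch of the Cephes-style continued fraction this helper evaluates (VALUE mode only; plain C++ doubles, illustrative rather than part of the patch). The initial values mirror the loop state set up above, and the rescale-by-epsilon step is what keeps the partial numerators and denominators finite:

#include <cmath>
#include <limits>

// Continued fraction for Q(a, x); the caller multiplies the returned `ans`
// by ax = exp(a * log(x) - x - lgamma(a)). Assumes x > 1 and x > a, the
// region where Igamma/Igammac dispatch to this branch.
double IgammacContinuedFractionScalar(double a, double x) {
  const double eps = std::numeric_limits<double>::epsilon();
  double y = 1.0 - a, z = x + y + 1.0, c = 0.0;
  double pkm2 = 1.0, qkm2 = x, pkm1 = x + 1.0, qkm1 = z * x;
  double ans = pkm1 / qkm1;
  double t = 1.0;
  while (t > eps) {
    c += 1.0;
    y += 1.0;
    z += 2.0;
    double yc = y * c;
    double pk = pkm1 * z - pkm2 * yc;
    double qk = qkm1 * z - qkm2 * yc;
    if (qk != 0.0) {
      double r = pk / qk;
      t = std::fabs((ans - r) / r);
      ans = r;
    } else {
      t = 1.0;
    }
    pkm2 = pkm1; pkm1 = pk;
    qkm2 = qkm1; qkm1 = qk;
    if (std::fabs(pk) > 1.0 / eps) {  // rescale all four terms together
      pkm2 *= eps; pkm1 *= eps;
      qkm2 *= eps; qkm1 *= eps;
    }
  }
  return ans;
}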

@ -820,9 +919,9 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
    const double nan = std::numeric_limits<double>::quiet_NaN();
    XlaOp output = Select(
        use_igammac,
        ScalarLike(a, 1) -
            IgammacContinuedFraction(ax, x, a, And(enabled, use_igammac), type),
        IgammaSeries(ax, x, a, And(enabled, Not(use_igammac)), type));
        ScalarLike(a, 1) - IgammacContinuedFraction<VALUE>(
                               ax, x, a, And(enabled, use_igammac), type),
        IgammaSeries<VALUE>(ax, x, a, And(enabled, Not(use_igammac)), type));
    output = Select(underflow, ZerosLike(output), output);
    output = Select(x_is_zero, ZerosLike(output), output);
    output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
@ -852,6 +951,101 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
  });
}

XlaOp IgammaGradA(XlaOp a, XlaOp x) {
  auto& b = *a.builder();
  auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
    XlaOp is_nan = Or(IsNan(a), IsNan(x));
    XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
    XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
    XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
    XlaOp ax = a * Log(x) - x - Lgamma(a);
    XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
    ax = Exp(ax);
    XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
    const double nan = std::numeric_limits<double>::quiet_NaN();
    XlaOp output = Select(use_igammac,
                          -IgammacContinuedFraction<DERIVATIVE>(
                              ax, x, a, And(enabled, use_igammac), type),
                          IgammaSeries<DERIVATIVE>(
                              ax, x, a, And(enabled, Not(use_igammac)), type));
    output = Select(underflow, ZerosLike(output), output);
    output = Select(x_is_zero, ZerosLike(output), output);
    output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
    return output;
  };
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
    TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
    if (a_shape != x_shape) {
      return InvalidArgument(
          "Arguments to IgammaGradA must have equal shapes and types; got %s "
          "and %s",
          a_shape.ToString(), x_shape.ToString());
    }
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
    bool needs_upcast =
        a_shape.element_type() == F16 || a_shape.element_type() == BF16;

    if (needs_upcast) {
      a = ConvertElementType(a, F32);
      x = ConvertElementType(x, F32);
    }
    XlaOp result = doit(a, x, a_shape.element_type());
    if (needs_upcast) {
      result = ConvertElementType(result, a_shape.element_type());
    }
    return result;
  });
}
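
Reading the function above, the DERIVATIVE mode is a product rule applied to the series form of the regularized lower incomplete gamma. With ans the converged series sum and ax = exp(a log x - x - lgamma(a)) as in the code, the identity it computes is (an annotation derived from the code, not text from the patch):

P(a, x) = \frac{ax}{a}\,\mathrm{ans}, \qquad
\frac{\partial P}{\partial a}
    = \frac{ax}{a}\Bigl(\mathrm{ans}\,\bigl(\log x - \psi(a + 1)\bigr)
        + \frac{\partial\,\mathrm{ans}}{\partial a}\Bigr),

since \frac{d}{da}\bigl(a \log x - x - \log\Gamma(a) - \log a\bigr) = \log x - \psi(a + 1). That is exactly ax * (ans * dlogax_da + dans_da) / a with dlogax_da = Log(x) - Digamma(a + ScalarLike(a, 1)).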

// Gradient of Gamma sample from Gamma(a, 1) with respect to `a`.
XlaOp RandomGammaGrad(XlaOp a, XlaOp x) {
  auto& b = *a.builder();
  auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
    XlaOp is_nan = Or(IsNan(a), IsNan(x));
    XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
    XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
    XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
    XlaOp ax = a * Log(x) - x - Lgamma(a);
    XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
    ax = Exp(ax);
    XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
    const double nan = std::numeric_limits<double>::quiet_NaN();
    XlaOp output = Select(use_igammac,
                          -IgammacContinuedFraction<SAMPLE_DERIVATIVE>(
                              ax, x, a, And(enabled, use_igammac), type),
                          IgammaSeries<SAMPLE_DERIVATIVE>(
                              ax, x, a, And(enabled, Not(use_igammac)), type));
    output = Select(underflow, ZerosLike(output), output);
    output = Select(x_is_zero, ZerosLike(output), output);
    output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
    return output;
  };
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
    TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
    if (a_shape != x_shape) {
      return InvalidArgument(
          "Arguments to RandomGammaGrad must have equal shapes and types; got "
          "%s and %s",
          a_shape.ToString(), x_shape.ToString());
    }
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RandomGammaGrad", a));
    bool needs_upcast =
        a_shape.element_type() == F16 || a_shape.element_type() == BF16;

    if (needs_upcast) {
      a = ConvertElementType(a, F32);
      x = ConvertElementType(x, F32);
    }
    XlaOp result = doit(a, x, a_shape.element_type());
    if (needs_upcast) {
      result = ConvertElementType(result, a_shape.element_type());
    }
    return result;
  });
}
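
The SAMPLE_DERIVATIVE mode appears to implement implicit differentiation of a fixed-quantile sample: holding the CDF value u constant in P(a, x(a)) = u, and using \partial P / \partial x = x^{a-1} e^{-x} / \Gamma(a) = ax / x, gives (again an annotation, not text from the patch)

\frac{dx}{da} = -\,\frac{\partial P/\partial a}{\partial P/\partial x}
    = -\,\frac{x}{a}\Bigl(\mathrm{ans}\,\bigl(\log x - \psi(a + 1)\bigr)
        + \frac{\partial\,\mathrm{ans}}{\partial a}\Bigr),

which matches the -(dans_da + ans * dlogax_da) * x / a returned by the series branch above; the continued-fraction branch follows the same logic through Q(a, x) = ax * ans, with the sign handled by the leading minus in RandomGammaGrad.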

XlaOp Igammac(XlaOp a, XlaOp x) {
  auto& b = *a.builder();
  auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
@ -863,10 +1057,10 @@ XlaOp Igammac(XlaOp a, XlaOp x) {
    ax = Exp(ax);
    XlaOp result =
        Select(use_igamma,
               ScalarLike(a, 1) -
                   IgammaSeries(ax, x, a, And(enabled, use_igamma), type),
               IgammacContinuedFraction(ax, x, a, And(enabled, Not(use_igamma)),
                                        type));
               ScalarLike(a, 1) - IgammaSeries<VALUE>(
                                      ax, x, a, And(enabled, use_igamma), type),
               IgammacContinuedFraction<VALUE>(
                   ax, x, a, And(enabled, Not(use_igamma)), type));
    return Select(underflow, ZerosLike(a),
                  Select(out_of_range, FullLike(a, 1), result));
  };
@ -1008,12 +1202,23 @@ XlaOp Asinh(XlaOp x) {
    if (primitive_util::IsComplexType(shape.element_type())) {
      return Log(x + Sqrt(x * x + one));
    }
    // For small x, sqrt(x**2 + 1) will evaluate to 1 due to floating point
    // arithmetic. However, we would like to retain the low order term of this,
    // which is around 0.5 * x**2 using a binomial expansion.
    // Let z = sqrt(a**2 + 1)
    // log(a + sqrt(a**2 + 1)) =
    // log((a + sqrt(a**2 + 1)) * (1 + sqrt(a**2 + 1)) / (1 + sqrt(a**2 + 1))) =
    // log((a + a**2 + 1 + a * z + z) / (1 + z)) =
    // log(1 + a + a**2 / (1 + z)) =
    // log(1 + a + a ** 2 / (1 + sqrt(a**2 + 1)))
    // This rewrite retains the lower order term.
    auto a = Abs(x);
    auto small_result = Log1p(a + a * a / (one + Sqrt(a * a + one)));
    auto naive_result = Log(a + Sqrt(a * a + one));
    auto overflow_result = Log(Abs(a)) + Log(ScalarLike(a, 2));
    auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type()));
    return Sign(x) *
           Select(Ge(a, sqrt_max_value), overflow_result, naive_result);
    return Sign(x) * Select(Ge(a, sqrt_max_value), overflow_result,
                            Select(Le(a, one), small_result, naive_result));
  };
  // These upcasts are not strictly necessary on all platforms to get within our
  // error tolerances, so we could relax this if it ever mattered.
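
A small standalone C++ demonstration of the effect the new Log1p branch targets (the program and its constants are illustrative, not from the patch): for tiny a, the naive expression rounds sqrt(a*a + 1) to 1 and the addition then costs most of a's relative precision, while the rewritten form stays accurate:

#include <cmath>
#include <cstdio>

int main() {
  for (double a : {1e-3, 1e-7, 1e-11}) {
    double naive = std::log(a + std::sqrt(a * a + 1.0));
    double stable = std::log1p(a + a * a / (1.0 + std::sqrt(a * a + 1.0)));
    std::printf("a=%.0e  naive=%.17g  stable=%.17g\n", a, naive, stable);
  }
}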
@ -1028,9 +1233,7 @@ XlaOp Atanh(XlaOp x) {
  XlaBuilder* b = x.builder();
  auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
    auto naive_result =
        Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) *
        ScalarLike(x, 0.5);
    auto naive_result = (Log1p(x) - Log1p(-x)) * ScalarLike(x, 0.5);

    // TODO(jlebar): For now, we ignore the nan edge case for complex inputs,
    // because we don't yet have exhaustive tests for complex trig functions.
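
The one-line Atanh change leans on the same log1p idea; in math terms (a reading of the change, not text from it):

\operatorname{atanh}(x)
    = \tfrac{1}{2}\log\frac{1 + x}{1 - x}
    = \tfrac{1}{2}\bigl(\operatorname{log1p}(x) - \operatorname{log1p}(-x)\bigr),

and log1p is exact near zero, so the O(x) term survives for |x| << 1, which is what the new AtanhSmallValues test below exercises.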
@ -1074,9 +1277,35 @@ XlaOp Cosh(XlaOp x) {
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
XlaOp Sinh(XlaOp x) {
  return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
  XlaBuilder* b = x.builder();
  auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
    auto one_half = ScalarLike(x, 0.5);
    auto log_one_half = Log(ScalarLike(x, 0.5));
    return Exp(x + log_one_half) - Exp(-x + log_one_half);
    auto large_sinh_result = Exp(x + log_one_half) - Exp(-x + log_one_half);

    if (primitive_util::IsComplexType(shape.element_type())) {
      return large_sinh_result;
    }

    // Here we use e^x = e^(x / 2) * e^(x / 2). This avoids overflow for large
    // values of x.

    // For smaller x, we get unwanted cancellations of e^x - e^-x, resulting in
    // 0.
    // Rewrite this to avoid that. We use expm1(x) because that preserves the
    // first order term of the taylor series of e^x.
    // (e^(x) - e^(-x)) / 2. =
    // (e^(x) - 1 + 1 - e^(-x)) / 2.
    // (expm1(x) + (e^(x) - 1) / e^x) / 2.
    // (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2.
    auto expm1 = Expm1(x);
    auto one = ScalarLike(x, 1.);
    auto small_sinh_result = one_half * (expm1 + expm1 / (expm1 + one));
    return Select(Lt(Abs(x), one), small_sinh_result, large_sinh_result);
  };
  return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) {
    return b->ReportErrorOrReturn(do_it(x));
  });
}
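
A scalar illustration of the branch structure introduced above (plain C++, names invented; not part of the patch). The expm1 form keeps the leading x term that (e^x - e^-x) / 2 cancels away for small |x|, while the half-exponent form used for |x| >= 1 matches large_sinh_result:

#include <cmath>
#include <cstdio>

double StableSinh(double x) {
  if (std::fabs(x) < 1.0) {
    double em1 = std::expm1(x);
    // (e^x - e^-x) / 2 rewritten as (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2.
    return 0.5 * (em1 + em1 / (em1 + 1.0));
  }
  double log_half = std::log(0.5);
  return std::exp(x + log_half) - std::exp(-x + log_half);
}

int main() {
  for (double x : {1e-9, 0.5, 10.0}) {
    std::printf("x=%g  stable=%.17g  std=%.17g\n", x, StableSinh(x),
                std::sinh(x));
  }
}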

@ -61,6 +61,14 @@ XlaOp Digamma(XlaOp input);
// Computes an approximation of the incomplete gamma function.
XlaOp Igamma(XlaOp a, XlaOp x);

// Computes an approximation of the derivative of the incomplete gamma function
// with respect to a.
XlaOp IgammaGradA(XlaOp a, XlaOp x);

// Computes an approximation of the derivative of a sample `x` from a `Gamma(a,
// 1)` distribution with respect to a.
XlaOp RandomGammaGrad(XlaOp a, XlaOp x);

// Computes an approximation of the complementary incomplete gamma function.
XlaOp Igammac(XlaOp a, XlaOp x);

@ -298,6 +298,30 @@ XLA_TEST_F(MathTest, SqrtSixValues) {
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_F(MathTest, SinhSmallValues) {
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11});
  Sinh(x);
  std::vector<float> expected = {1e-3, 1e-5, 1e-7, 1e-9, 1e-11};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_F(MathTest, AsinhSmallValues) {
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11});
  Asinh(x);
  std::vector<float> expected = {1e-3, 1e-5, 1e-7, 1e-9, 1e-11};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_F(MathTest, AtanhSmallValues) {
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {1e-8, 1e-9, 1e-10, 1e-11});
  Atanh(x);
  std::vector<float> expected = {1e-8, 1e-9, 1e-10, 1e-11};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_F(MathTest, Lgamma) {
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5,
@ -528,7 +528,8 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
  }

  // Eliminate the size one dimensions.
  TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, Reshape(reshaped_shape, operand));
  TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand,
                      ReshapeInternal(reshaped_shape, operand));
  // Broadcast 'reshape' up to the larger size.
  return InDimBroadcast(broadcast_shape, reshaped_operand,
                        broadcast_dimensions);
@ -828,8 +829,8 @@ XlaOp XlaBuilder::BroadcastInDim(
  });
}

StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
                                    int64 inferred_dimension) {
StatusOr<XlaOp> XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand,
                                            int64 inferred_dimension) {
  TF_RETURN_IF_ERROR(first_error_);

  HloInstructionProto instr;
@ -1020,7 +1021,7 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
    XlaOp transposed = IsIdentityPermutation(dimensions)
                           ? operand
                           : Transpose(operand, dimensions);
    return Reshape(shape, transposed, inferred_dimension);
    return ReshapeInternal(shape, transposed, inferred_dimension);
  });
}

@ -1034,6 +1035,13 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
  });
}

XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
                          int64 inferred_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    return ReshapeInternal(shape, operand, inferred_dimension);
  });
}

XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span<const int64> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (dimensions.size() <= 1) {
@ -2951,6 +2959,10 @@ XlaOp Reshape(const XlaOp operand, absl::Span<const int64> new_sizes) {
  return operand.builder()->Reshape(operand, new_sizes);
}

XlaOp Reshape(const Shape& shape, XlaOp operand) {
  return operand.builder()->Reshape(shape, operand);
}

XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                   absl::Span<const int64> new_sizes,
                                   int64 inferred_dimension) {

@ -397,6 +397,9 @@ class XlaBuilder {
  XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
                int64 inferred_dimension = -1);

  XlaOp Reshape(const Shape& shape, XlaOp operand,
                int64 inferred_dimension = -1);

  XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);

  XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
@ -668,8 +671,8 @@ class XlaBuilder {

  // Internal helper method for creating a Reshape op with the already inferred
  // shape.
  StatusOr<XlaOp> Reshape(const Shape& shape, XlaOp operand,
                          int64 inferred_dimension = -1);
  StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
                                  int64 inferred_dimension = -1);

  // Returns the (inferred) result for the program shape using the given root.
  StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;
@ -777,6 +780,8 @@ class XlaBuilder {

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);

  friend XlaOp Reshape(const Shape& shape, XlaOp operand);

  friend XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                            absl::Span<const int64> new_sizes,
                                            int64 inferred_dimension);
@ -1252,6 +1257,9 @@ XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
// sizes. Conceptually, this is a limited form of "shape casting".
XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);

// Enqueues a Reshape op that uses an explicit target shape.
XlaOp Reshape(const Shape& shape, XlaOp operand);
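
A hypothetical usage sketch of this new overload (builder and parameter names invented for illustration); it enqueues a reshape straight to an explicit target shape instead of a list of dimension sizes:

xla::XlaBuilder b("reshape_example");
xla::XlaOp p =
    xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {2, 3}), "p");
// New overload from this diff: reshape 2x3 -> 6 via an explicit Shape.
xla::XlaOp flat = xla::Reshape(xla::ShapeUtil::MakeShape(xla::F32, {6}), p);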

// `inferred_dimension` represents the output dimension that's inferred by
// upper-level framework by dividing the input element count by the known
// output element count. While an inferred_dimension can be static, if there

@ -372,6 +372,7 @@ pybind_extension(
        # not require Tensorflow.
        "//tensorflow/core:lib_internal_impl",  # buildcleaner: keep
        "//tensorflow/core/profiler/lib:profiler_backends",
        "//tensorflow/core/profiler/lib:profiler_session",
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/core/profiler/rpc:profiler_server",
        "//tensorflow/stream_executor:device_memory_allocator",

@ -22,6 +22,7 @@ cc_library(
        "//tensorflow/compiler/xla/python:local_client",
        "//tensorflow/compiler/xla/python:semaphore",
        "//tensorflow/compiler/xla/python/tpu_driver",
        "//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver",
        "//tensorflow/compiler/xla/python/tpu_driver:grpc_tpu_driver",
        "//tensorflow/compiler/xla/python/tpu_driver:recording_tpu_driver",
        "//tensorflow/compiler/xla/python/tpu_driver:tpu_driver_proto_cc",

@ -27,7 +27,8 @@
namespace tpu_driver {
namespace {

// Enable the macro by default in the env where the libtpu.so is available.
// Enable the macro by default in the Google internal environment where the
// libtpu.so is linked in statically.
#ifdef PLATFORM_GOOGLE
#define TPU_SHARED_LIBRARY_COMPILE_LINK 1
#endif

@ -458,6 +458,8 @@ void BuildOpsSubmodule(py::module* m) {

  ops.def("Igamma", &Igamma);
  ops.def("Igammac", &Igammac);
  ops.def("IgammaGradA", &IgammaGradA);
  ops.def("RandomGammaGrad", &RandomGammaGrad);
  ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta);

#define BINARY_OP(op) \

@ -1698,6 +1698,7 @@ _BINARY_OPS = [
    'ShiftRightLogical',
    'Atan2',
    'Igamma',
    'IgammaGradA',
    'Igammac',
    'Complex',
    'NextAfter',

@ -3727,7 +3727,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
  // Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were
  // batch dimensions of the dot. The transformation supports reducing other
  // dimensions as well.
  if (Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
  if (options_.enable_dot_strength_reduction() &&
      Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
      Match(reduce->to_apply()->root_instruction(),
            m::Add(m::Parameter(), m::Parameter())) &&
      absl::c_any_of(reduce->dimensions(), [&](int64 dim) {

@ -1043,15 +1043,31 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
  HloInstruction* root = computation->root_instruction();

  // Mark nondistinct/ambiguous indices.
  absl::flat_hash_set<const HloBuffer*> seen;
  absl::flat_hash_map<const HloBuffer*, ShapeIndex> seen;
  ShapeUtil::ForEachSubshape(
      root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
        std::vector<const HloBuffer*> buffers_at_index =
            alias_analysis->ComputeBuffersAt(root, index);
        bool buffer_seen_before = false;
        for (const HloBuffer* buffer : buffers_at_index) {
          buffer_seen_before |= !seen.insert(buffer).second;
          buffer_seen_before |= !seen.emplace(buffer, index).second;
        }

        if (buffer_seen_before && policy.copy_root_replicated_buffers &&
            computation == module->entry_computation() &&
            module->input_output_alias_config().OutputHasAlias(index) &&
            buffers_at_index.size() == 1) {
          absl::optional<HloInputOutputAliasConfig::Alias> alias =
              module->input_output_alias_config().GetAliasedParameter(index);
          CHECK(alias) << "Alias does not exist";
          const ShapeIndex& other_index = seen[buffers_at_index[0]];
          VLOG(2) << "Output indices " << index.ToString() << " and "
                  << other_index.ToString() << " are both aliased to "
                  << alias->parameter_number << " copying " << other_index;
          add_index_to_copy(root, other_index);
          return;
        }

        if (buffers_at_index.size() > 1 ||
            (buffer_seen_before && policy.copy_root_replicated_buffers)) {
          VLOG(2) << "Index " << index << " of computation "
@ -1097,6 +1113,18 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
  return Status::OK();
}

static int64 GetNumExistingCopies(const HloModule* module) {
  int64 num_existing_copies = 0;
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopy) {
        ++num_existing_copies;
      }
    }
  }
  return num_existing_copies;
}

Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
                                              HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
@ -1112,13 +1140,24 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
  }

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopy &&
          copy_remover.TryElideCopy(instruction)) {
        TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
        TF_RETURN_IF_ERROR(
            instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));

  int64 num_existing_copies = GetNumExistingCopies(module);
  bool changed = true;
  int64 num_iterations = -1;
  while (changed) {
    CHECK_LE(++num_iterations, num_existing_copies);
    changed = false;
    VLOG(2) << "Running fixpoint iteration " << num_iterations
            << " of copy elision";
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        if (instruction->opcode() == HloOpcode::kCopy &&
            copy_remover.TryElideCopy(instruction)) {
          changed = true;
          TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
          TF_RETURN_IF_ERROR(
              instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
        }
      }
    }
  }
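
One way to read the new loop: each successful TryElideCopy removes one kCopy from the module, so no more than GetNumExistingCopies(module) sweeps can make progress, and the CHECK_LE turns a violation of that invariant into a hard failure instead of an unbounded fixpoint iteration.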

@ -1156,17 +1195,6 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
        "Call graph must be flattened before copy insertion.");
  }

  int64 num_existing_copies = 0;
  if (VLOG_IS_ON(1)) {
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        if (instruction->opcode() == HloOpcode::kCopy) {
          ++num_existing_copies;
        }
      }
    }
  }

  TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));

  // Simplify the tuple structures introduced by the deep copies. This should be
@ -1185,7 +1213,6 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
      RemoveUnnecessaryCopies(DependencyHloOrdering(module), module));
  DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
                                   *module);

  TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
  DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
                                   *module);
@ -1202,7 +1229,8 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
      }
    }
  }
  VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies;
  VLOG(1) << "Num copies before copy-insertion: "
          << GetNumExistingCopies(module);
  VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
}

Some files were not shown because too many files have changed in this diff.