Merge branch 'master' into allow_build_at_runtime

Tamas Bela Feher 2020-02-19 17:39:55 +01:00
commit 68e9393d9f
373 changed files with 13923 additions and 3950 deletions

View File

@ -320,10 +320,8 @@ build:xla --define=with_xla_support=true
# Options when using remote execution
# WARNING: THESE OPTIONS WON'T WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
build:rbe --auth_enabled=true
build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools
build:rbe --google_default_credentials
build:rbe --bes_backend=buildeventservice.googleapis.com
build:rbe --bes_best_effort=false
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
build:rbe --bes_timeout=600s
build:rbe --define=EXECUTOR=remote
@ -336,7 +334,7 @@ build:rbe --spawn_strategy=remote,worker,standalone,local
test:rbe --test_env=USER=anon
# Attempt to minimize the amount of data transfer between bazel and the remote
# workers:
build:rbe --experimental_inmemory_jdeps_files --experimental_inmemory_dotd_files --experimental_remote_download_outputs=toplevel
build:rbe --remote_download_toplevel
build:rbe_linux --config=rbe
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"

View File

@ -113,3 +113,28 @@ http_archive(
"https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
],
)
# Required for dependency @com_github_grpc_grpc
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
grpc_deps()
load(
"@build_bazel_rules_apple//apple:repositories.bzl",
"apple_rules_dependencies",
)
apple_rules_dependencies()
load(
"@build_bazel_apple_support//lib:repositories.bzl",
"apple_support_dependencies",
)
apple_support_dependencies()
load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
bazel_version_repository(name = "bazel_version")

View File

@ -505,13 +505,15 @@ selects.config_setting_group(
package_group(
name = "internal",
packages = [
# To pass open source testing in the pip Kokoros.
"//bazel_pip/tensorflow/...",
"//learning/brain/swift/x10/...",
"//perftools/accelerators/xprof/api/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//tensorflow/...",
"//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
],
)
@ -545,8 +547,8 @@ cc_library(
name = "grpc",
visibility = ["//visibility:public"],
deps = select({
":linux_s390x": ["@grpc//:grpc_unsecure"],
"//conditions:default": ["@grpc"],
":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"],
"//conditions:default": ["@com_github_grpc_grpc//:grpc"],
}),
)
@ -554,8 +556,8 @@ cc_library(
name = "grpc++",
visibility = ["//visibility:public"],
deps = select({
":linux_s390x": ["@grpc//:grpc++_unsecure"],
"//conditions:default": ["@grpc//:grpc++"],
":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"],
"//conditions:default": ["@com_github_grpc_grpc//:grpc++"],
}),
)

View File

@ -1883,6 +1883,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"EmptyTensorList",
"ExtractImagePatches",
"Igamma",
"IgammaGradA",
"RandomGammaGrad",
"Igammac",
"FFT",
"FFT2D",
@ -1909,7 +1911,6 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"LinSpace",
"ListDiff",
"LogMatrixDeterminant",
"LowerBound",
"MatMul",
"MatrixBandPart",
"MatrixDiag",
@ -2036,7 +2037,6 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"TensorScatterUpdate",
"TridiagonalSolve",
"TruncatedNormal",
"UpperBound",
"UnsortedSegmentMax",
"UnsortedSegmentMin",
"UnsortedSegmentProd",

View File

@ -20,15 +20,17 @@ limitations under the License.
namespace tensorflow {
bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) const {
return CanCreateXlaKernel(node_def);
bool XlaKernelCreator::CanCreateKernel(
const FunctionLibraryRuntime& flr,
const std::shared_ptr<const NodeProperties>& props) const {
return CanCreateXlaKernel(props->node_def);
}
Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) const {
return CreateXlaKernel(flr, node_def, kernel);
Status XlaKernelCreator::CreateKernel(
FunctionLibraryRuntime* flr,
const std::shared_ptr<const NodeProperties>& props,
std::unique_ptr<OpKernel>* kernel) const {
return CreateXlaKernel(flr, props->node_def, kernel);
}
namespace {

View File

@ -29,11 +29,13 @@ class XlaKernelCreator : public CustomKernelCreator {
// Given a NodeDef 'node_def' and the function library runtime 'flr', returns
// true if 'node_def' is a call to a compilable function defined in 'flr',
// with the kXlaCompileAttr set.
bool CanCreateKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) const override;
bool CanCreateKernel(
const FunctionLibraryRuntime& flr,
const std::shared_ptr<const NodeProperties>& props) const override;
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
Status CreateKernel(FunctionLibraryRuntime* flr,
const std::shared_ptr<const NodeProperties>& props,
std::unique_ptr<OpKernel>* kernel) const override;
};
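
For context, a minimal sketch (not part of this commit) of how a caller might drive the updated interface. TryCreateXlaKernel is a hypothetical helper; the NodeProperties constructor arguments (op_def, node_def, input types, output types) follow the call sites in ToNodeProperties and CreateXlaKernel later in this diff, and op_def may be null as in the tests.

// Hypothetical helper, for illustration only.
Status TryCreateXlaKernel(FunctionLibraryRuntime* flr, NodeDef node_def,
                          const DataTypeVector& input_types,
                          const DataTypeVector& output_types,
                          std::unique_ptr<OpKernel>* kernel) {
  // Build the NodeProperties the new API expects (op_def left null here).
  auto props = std::make_shared<NodeProperties>(
      /*op_def=*/nullptr, std::move(node_def), input_types, output_types);
  XlaKernelCreator creator;
  if (!creator.CanCreateKernel(*flr, props)) {
    return errors::Internal("node is not compilable with XLA");
  }
  return creator.CreateKernel(flr, props, kernel);
}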

View File

@ -30,10 +30,12 @@ limitations under the License.
namespace tensorflow {
NodeDef ToNodeDef(const string& text) {
std::shared_ptr<NodeProperties> ToNodeProperties(const string& text) {
NodeDef node_def;
DataTypeVector dummy;
EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
return node_def;
return std::make_shared<NodeProperties>(nullptr, std::move(node_def), dummy,
dummy);
}
// Create a FunctionDef that takes one resource and one regular param
@ -98,11 +100,11 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
(*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
Init({fdef});
XlaKernelCreator xla_kernel_creator;
NodeDef callsite =
ToNodeDef(R"pb(
auto callsite =
ToNodeProperties(R"pb(
name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
)pb");
(*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
(*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true);
// Note: need to set attribute on the created node.
Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
@ -127,13 +129,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
Init({fdef});
XlaKernelCreator xla_kernel_creator;
Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
Status status =
xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}
@ -143,13 +146,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
Init({fdef});
XlaKernelCreator xla_kernel_creator;
Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
Status status =
xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}

View File

@ -218,12 +218,13 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
input_memory_types, fbody->ret_types, output_memory_types,
flr->graph_def_version(), &s);
auto props = std::make_shared<NodeProperties>(
&fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()),
flr, dev->resource_manager(), props,
input_memory_types, output_memory_types,
flr->graph_def_version(), &s);
*kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function,

View File

@ -208,6 +208,7 @@ cc_library(
"ir/tfl_ops.h.inc",
"ir/tfl_ops_interface.cc.inc",
"ir/tfl_ops_interface.h.inc",
"runtime_verifiers.inc",
"utils/attribute_utils.cc",
],
hdrs = [
@ -303,12 +304,14 @@ cc_library(
"transforms/optimize_functional_ops.cc",
"transforms/prepare_composite_functions_tf.cc",
"transforms/prepare_tf.cc",
"transforms/runtime_type_verify.cc",
"transforms/split_merged_operands.cc",
"transforms/trim_functions_tf.cc",
"transforms/unroll_batch_matmul.cc",
"transforms/while_loop_outline.cc",
],
hdrs = [
"ir/tfl_ops_interface.h.inc",
"transforms/dilated_conv.h",
"transforms/passes.h",
"transforms/unroll_batch_matmul.h",
@ -461,9 +464,9 @@ cc_library(
)
tf_native_cc_binary(
name = "operator-converter-gen",
name = "converter-gen",
srcs = [
"operator_converter_gen.cc",
"converter_gen.cc",
],
deps = [
"@llvm-project//llvm:support",
@ -473,14 +476,18 @@ tf_native_cc_binary(
)
gentbl(
name = "operator_converter_inc",
name = "converter_inc",
tbl_outs = [
(
"", # This driver has no options.
"--gen-operator-converters",
"operator_converters.inc",
),
(
"--gen-runtime-verifiers",
"runtime_verifiers.inc",
),
],
tblgen = ":operator-converter-gen",
tblgen = ":converter-gen",
td_file = "ir/tfl_ops.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
@ -650,6 +657,7 @@ tf_cc_binary(
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
)

View File

@ -28,6 +28,9 @@ limitations under the License.
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Attribute.h" // TF:llvm-project
#include "mlir/TableGen/Format.h" // TF:llvm-project
#include "mlir/TableGen/Operator.h" // TF:llvm-project
#include "mlir/TableGen/Predicate.h" // TF:llvm-project
using llvm::DefInit;
using llvm::dyn_cast;
@ -41,6 +44,19 @@ using llvm::SmallVector;
using llvm::StringInit;
using llvm::StringRef;
enum ActionType {
OpConv,
RuntimeVerify,
};
// NOLINTNEXTLINE
llvm::cl::opt<ActionType> action(
llvm::cl::desc("Action to perform:"),
llvm::cl::values(clEnumValN(OpConv, "gen-operator-converters",
"Generate operator converters"),
clEnumValN(RuntimeVerify, "gen-runtime-verifiers",
"Generate TFLite runtime verifiers")));
// Returns the associated option name for the given op definition.
static inline std::string GetOperatorOptionName(const Record &def) {
assert(def.getName().startswith("TFL_") && "unexpected op prefix");
@ -342,8 +358,101 @@ static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) {
return false;
}
static void GenOperandResultVerifier(raw_ostream &os,
llvm::ArrayRef<llvm::Init *> values,
StringRef valueKind) {
mlir::tblgen::FmtContext fctx;
bool first = true;
for (auto static_value : llvm::enumerate(values)) {
auto *definit = llvm::cast<llvm::DefInit>(static_value.value());
auto *val = definit->getDef()->getValue("tflRuntimeTypePredicate");
if (!val) continue;
// Create code block on first type to verify.
if (first) {
os << " {\n";
os << " unsigned index = " << static_value.index() << ";\n";
first = false;
}
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
auto desc =
definit->getDef()->getValueAsString("tflRuntimeTypeDescription");
// Emit a loop to check all the dynamic values in the pack.
os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n",
// Capitalize the first letter to match the function name
valueKind.substr(0, 1).upper(), valueKind.substr(1),
static_value.index());
os << " (void)v;\n"
<< " if (!("
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
<< formatv(
" return op->emitOpError(\"{0} #\") << index "
"<< \" must be {1}, but got \" << v.getType();\n",
valueKind, desc)
<< " }\n" // if
<< " ++index;\n"
<< " }\n"; // for
}
// Emit closing brace if needed.
if (!first) os << " }\n";
}
// NOLINTNEXTLINE
static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
emitSourceFileHeader("MLIR TFLite Runtime Verifiers", os);
// Retrieve all the definitions derived from TFL_Op and sort by record name.
std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
llvm::sort(defs, LessRecord());
// Iterate through all the ops defined.
for (const auto *def : defs) {
mlir::tblgen::Operator op(*def);
if (!op.getTrait("TflRuntimeVerifyOpInterface::Trait")) continue;
mlir::tblgen::FmtContext verify_ctx;
os << "::mlir::LogicalResult " << op.getCppClassName()
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n";
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
verify_ctx.withOp("top");
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
auto &value = op.getOperand(i);
// Skip from the first variadic operand for now. Else the getOperand index
// used below doesn't match.
if (value.isVariadic()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
}
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
auto &value = op.getResult(i);
// Skip from the first variadic result for now. Else the getResult index
// used below doesn't match.
if (value.isVariadic()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
}
GenOperandResultVerifier(os, def->getValueAsDag("arguments")->getArgs(),
"operand");
GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
"result");
os << " return mlir::success();\n}\n";
}
return false;
}
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv);
return TableGenMain(argv[0], &OperatorWritersMain);
if (action == ActionType::OpConv)
return TableGenMain(argv[0], &OperatorWritersMain);
return TableGenMain(argv[0], &RuntimeVerifierWriterMain);
}
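
For reference, the RuntimeVerifierWriterMain backend above emits one VerifyTflRuntimeTypes definition per op into runtime_verifiers.inc; a hedged illustration of what such an emission looks like for a hypothetical single-operand op follows (the op name and the type predicate are stand-ins; the real condition and description come from tflRuntimeTypePredicate and tflRuntimeTypeDescription).

// Illustrative generated output only; structure mirrors the os << statements above.
::mlir::LogicalResult SomeTFLOp::VerifyTflRuntimeTypes(::mlir::Operation *op) {
  auto top = cast<SomeTFLOp>(op); (void)top;
  {
    unsigned index = 0;
    for (Value v : top.getODSOperands(0)) {
      (void)v;
      // Stand-in predicate; the generator substitutes the op's runtime type check.
      if (!(v.getType().isa<mlir::TensorType>())) {
        return op->emitOpError("operand #") << index
               << " must be <runtime type description>, but got " << v.getType();
      }
      ++index;
    }
  }
  return mlir::success();
}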

View File

@ -71,4 +71,23 @@ def TFL_SparseOp : OpInterface<"SparseOpInterface"> {
];
}
//===----------------------------------------------------------------------===//
// TFL runtime type verification of operand/result types.
def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
let description = [{
Interface for TFLite runtime op verification.
This verifies that the converted TFLite ops have operand/result types
supported by the TFLite runtime.
}];
let methods = [
StaticInterfaceMethod<
[{Returns whether the op's operands/results are supported by runtime.}],
"LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op)
>,
];
}
#endif // TFL_OP_INTERFACES

View File

@ -1872,6 +1872,7 @@ LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
#include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc"
Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
Attribute value,

File diff suppressed because it is too large.

View File

@ -282,6 +282,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
}
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,

View File

@ -150,7 +150,8 @@ struct QuantizationPattern : public RewritePattern {
explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
float error_tolerance, bool single_layer_verify)
: RewritePattern(DQ::getOperationName(), 1, context),
// Set the score to a large number so it is always preferred.
: RewritePattern(DQ::getOperationName(), 300, context),
enable_verify(enable_verify),
error_tolerance(error_tolerance),
single_layer_verify(single_layer_verify) {}

View File

@ -1373,3 +1373,37 @@ func @reciprocal_i64(%arg0: tensor<8xi64>) -> tensor<8xi64> {
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi64>, tensor<8xi64>) -> tensor<8xi64>
// CHECK: return
}
func @random_uniform() -> tensor<2x5xf32> {
%0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
%1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32>
return %1 : tensor<2x5xf32>
// CHECK-LABEL: random_uniform
// CHECK: %[[CST:.*]] = constant dense
// CHECK: return %[[CST:.*]] : tensor<2x5xf32>
}
func @random_uniform_no_fold(%arg0: tensor<2xi32>) -> tensor<2x5xf32> {
%1 = "tf.RandomUniform"(%arg0) { seed = 0, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32>
return %1 : tensor<2x5xf32>
// CHECK-LABEL: random_uniform_no_fold
// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
}
func @random_uniform_no_fold2(%arg0: tensor<2xi32>) -> tensor<*xf32> {
%1 = "tf.RandomUniform"(%arg0) { seed = 1, seed2 = 2} : (tensor<2xi32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
// CHECK-LABEL: random_uniform_no_fold2
// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
}
func @random_uniform_no_fold3(%arg0: tensor<2xi32>) -> tensor<*xf64> {
%1 = "tf.RandomUniform"(%arg0) { seed = 1, seed2 = 2} : (tensor<2xi32>) -> tensor<*xf64>
return %1 : tensor<*xf64>
// CHECK-LABEL: random_uniform_no_fold3
// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
}

View File

@ -1,4 +1,4 @@
// RUN: tf-opt -split-input-file -verify-diagnostics %s | FileCheck %s --dump-input-on-failure
// RUN: tf-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s --dump-input-on-failure
// Unary math ops
// -----

View File

@ -2,39 +2,44 @@
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify | FileCheck --check-prefix=DEBUG %s
// CHECK-LABEL: QuantizeFloatConst
func @QuantizeFloatConst() -> tensor<f32> {
func @QuantizeFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%0 = constant dense<-0.1> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%2 = "tfl.dequantize"(%1) : (tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<f32>
return %2 : tensor<f32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<0> : tensor<2x2xi8>}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
// CHECK: return %[[dq]] : tensor<f32>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<0> : tensor<2x2xi8>}
// CHECK: return %[[cst]]
}
// CHECK-LABEL: QuantizeDenseFloatConst
func @QuantizeDenseFloatConst() -> tensor<2x2xf32> {
func @QuantizeDenseFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%0 = constant dense<[[-0.1, 1.0], [1.0, 3.0]]> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<{{\[\[}}0, -1], {{\[}}-1, -1]]> : tensor<2x2xi8>}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
// CHECK: return %[[dq]] : tensor<2x2xf32>
// CHECK: return %[[cst]]
}
// CHECK-LABEL: QuantizeSplatFloatConst
func @QuantizeSplatFloatConst() -> tensor<2x2xf32> {
func @QuantizeSplatFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%0 = constant dense<3.0> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
// CHECK: return %[[cst]]
}
// CHECK-LABEL: NotQuantizeFloatConst
func @NotQuantizeFloatConst() -> tensor<2x2xf32> {
%0 = constant dense<-0.1> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
// CHECK: return %[[dq]] : tensor<2x2xf32>
// CHECK: %[[cst:.*]] = constant dense<-1.000000e-01> : tensor<2x2xf32>
// CHECK: return %[[cst]] : tensor<2x2xf32>
}
// CHECK-LABEL: DequantizeAndQuantize

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
@ -32,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/model.h"
@ -182,6 +184,7 @@ int main(int argc, char **argv) {
pass_config.inline_functions = inline_functions;
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
std::string result;
auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(

View File

@ -49,6 +49,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace mlir {
@ -114,9 +116,54 @@ DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(Reciprocal);
DECL_CONVERT_OP(RandomUniform);
#undef DECL_CONVERT_OP
PatternMatchResult ConvertTFRandomUniformOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto random_uniform_op = cast<TF::RandomUniformOp>(op);
if (random_uniform_op.seed() == 0 && random_uniform_op.seed2() == 0) {
return matchFailure();
}
if (!random_uniform_op.dtype().isF32()) {
return matchFailure();
}
typedef tensorflow::random::UniformDistribution<
tensorflow::random::PhiloxRandom, float>
Distribution;
tensorflow::random::PhiloxRandom generator(
random_uniform_op.seed().getSExtValue(),
random_uniform_op.seed2().getSExtValue());
Distribution dist;
int num_elements = 0;
if (auto output_type =
random_uniform_op.output().getType().dyn_cast_or_null<ShapedType>()) {
if (auto ranked_output = output_type.dyn_cast_or_null<RankedTensorType>()) {
if (!ranked_output.hasRank() || ranked_output.getNumDynamicDims() != 0) {
return matchFailure();
}
num_elements = output_type.getNumElements();
size_t offset = 0;
size_t num_samples = Distribution::kResultElementCount;
llvm::SmallVector<float, 32> data;
data.resize(num_elements);
while (offset < num_elements) {
const typename Distribution::ResultType samples = dist(&generator);
std::copy(&samples[0],
&samples[0] + std::min(num_samples, data.size() - offset),
&data[0] + offset);
offset += num_samples;
}
auto output_data = DenseFPElementsAttr::get(output_type, data);
rewriter.replaceOpWithNewOp<ConstantOp>(op, output_type, output_data);
return matchSuccess();
}
}
return matchFailure();
}
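
As a standalone sketch of the library usage the pattern above relies on (not part of the commit; the UniformFloats helper name is hypothetical), this is roughly how PhiloxRandom and UniformDistribution produce the values that get folded into the constant, assuming fixed, nonzero seeds so the result is deterministic.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"

// Hypothetical helper: fill `n` floats the same way the rewrite pattern does.
std::vector<float> UniformFloats(uint64_t seed, uint64_t seed2, size_t n) {
  using Distribution = tensorflow::random::UniformDistribution<
      tensorflow::random::PhiloxRandom, float>;
  tensorflow::random::PhiloxRandom generator(seed, seed2);
  Distribution dist;
  std::vector<float> data(n);
  size_t offset = 0;
  while (offset < n) {
    // Each call to the distribution yields kResultElementCount samples.
    const Distribution::ResultType samples = dist(&generator);
    const size_t count =
        std::min<size_t>(Distribution::kResultElementCount, n - offset);
    std::copy(&samples[0], &samples[0] + count, data.data() + offset);
    offset += count;
  }
  return data;
}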
PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatOp>(op);
@ -521,11 +568,12 @@ void LegalizeTF::runOnFunction() {
// Add the generated patterns to the list.
populateWithGenerated(ctx, &patterns);
patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op,
ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp,
ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
ConvertTFAssertOp, ConvertTFReciprocalOp>(ctx);
patterns
.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op, ConvertTFPackOp,
ConvertTFReshapeOp, ConvertTFSplitOp, ConvertTFSplitVOp,
ConvertTFStridedSliceOp, ConvertTFUnpackOp, ConvertTFAssertOp,
ConvertTFReciprocalOp, ConvertTFRandomUniformOp>(ctx);
applyPatternsGreedily(func, patterns);
}

View File

@ -91,6 +91,9 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFWhilePass();
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass();
// Verifies runtime supports types used.
std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass();
} // namespace TFL
} // namespace mlir

View File

@ -21,12 +21,20 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
// Quantize attribute $0 by using quantization parameter from %1.
def QuantizeByQuantizedType : NativeCodeCall<"quant::Quantize($0, $1.getValue())">;
def F32ElementsAttr : ElementsAttrBase<
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
// Squash tfl.dequantize and tfl.quantize pairs.
// TODO(fengliuai): Compare the scale of input and output. This can also be
// squashed to a requantize op if the scales are different.
def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;
// If the tfl.dequantize op wasn't fused, we shouldn't quantize the floating
// point constant.
def : Pat<(TFL_DequantizeOp
(TFL_QuantizeOp (ConstantOp F32ElementsAttr:$cst), $qt)),
(ConstantOp $cst)>;
// Quantize the value of a constant op if the quantization parameters have been
// propagated to the output.
def : Pat<(TFL_QuantizeOp

View File

@ -0,0 +1,52 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
namespace TFL {
namespace {
// This pass verifies that the operand and result types are supported by the
// TFLite runtime.
class RuntimeTypeVerifyPass : public mlir::FunctionPass<RuntimeTypeVerifyPass> {
public:
explicit RuntimeTypeVerifyPass() {}
private:
void runOnFunction() override;
};
void RuntimeTypeVerifyPass::runOnFunction() {
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
if (failed(op.VerifyTflRuntimeTypes(op.getOperation())))
signalPassFailure();
});
}
} // namespace
// Verifies runtime supports types used.
std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass() {
return std::make_unique<RuntimeTypeVerifyPass>();
}
static PassRegistration<RuntimeTypeVerifyPass> pass(
"tfl-runtime-verify", "TFLite runtime verification");
} // namespace TFL
} // namespace mlir

View File

@ -168,6 +168,10 @@ std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
result.getResultNumber());
return std::string(result.getOwner()->getName().getStringRef());
}
// Use the ASM syntax for BlockArgument.
if (auto arg = val.dyn_cast<mlir::BlockArgument>()) {
return "arg" + std::to_string(arg.getArgNumber());
}
return "";
}

View File

@ -41,11 +41,52 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Support/STLExtras.h" // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/logging.h"
namespace mlir {
namespace tf_device {
//===----------------------------------------------------------------------===//
// TF Device Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct TFInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//
// Defines the legality of inlining TF Device operations.
bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
// For now, enable inlining all operations.
return true;
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
// Attempts to materialize a conversion for a type mismatch between a call
// from this dialect, and a callable region. This method should generate an
// operation that takes 'input' as the only operand, and produces a single
// result of 'resultType'. If a conversion can not be generated, nullptr
// should be returned.
// This is just re-using the same logic as the TensorFlow dialect right now.
Operation* materializeCallConversion(OpBuilder& builder, Value input,
Type result_type,
Location conversion_loc) const final {
if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
return nullptr;
return builder.create<TF::CastOp>(conversion_loc, result_type, input,
/*truncate=*/builder.getBoolAttr(false));
}
};
} // end anonymous namespace
TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
: Dialect(/*name=*/"tf_device", context) {
addOperations<
@ -54,6 +95,8 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
>();
addOperations<ParallelExecuteOp>();
addInterfaces<TFInlinerInterface>();
}
//===----------------------------------------------------------------------===//

View File

@ -49,7 +49,7 @@ an output element, this operation computes \\(y = |x|\\).
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape]>,
def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x + y element-wise.";
@ -98,7 +98,7 @@ Inputs must be of same size and shape.
let hasFolder = 1;
}
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape]>,
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x + y element-wise.";
@ -6781,7 +6781,7 @@ variables.
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}
def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType]> {
def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> {
let summary = "Computes hyperbolic tangent of `x` element-wise.";
let description = [{

View File

@ -58,6 +58,10 @@ TODO: Make invariants more structured so that we can reference them in ops.
def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait<
"TF::OperandsSameAsResultsTypeOrRef">;
// Layout agnostic operations do not depend on the operands' data layout (data
// format); for example, all element-wise operations are layout agnostic.
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;
//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//

View File

@ -68,6 +68,11 @@ class OperandsSameAsResultsTypeOrRef
}
};
// Layout agnostic operations do not depend on the operands' data layout (data
// format); for example, all element-wise operations are layout agnostic.
template <typename ConcreteType>
class LayoutAgnostic : public TraitBase<ConcreteType, LayoutAgnostic> {};
} // namespace TF
} // namespace OpTrait
} // namespace mlir
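
A minimal sketch of how a transformation pass could query the trait added above, assuming the standard mlir::Operation::hasTrait mechanism; IsLayoutAgnostic is a hypothetical helper, not part of this commit.

#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"

// Returns true if the op declares that it does not care about data layout,
// so a layout-assignment pass may move transposes across it.
bool IsLayoutAgnostic(mlir::Operation *op) {
  return op->hasTrait<mlir::OpTrait::TF::LayoutAgnostic>();
}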

View File

@ -3,6 +3,10 @@
// All tests also test for idempotence.
// Test that external functions aren't processed (used to crash).
// CHECK-LABEL: func @unused_external_func
func @unused_external_func()
func @multiple_return(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi32>, tensor<*xi32>) {
%graph:2 = tf_executor.graph {
%island:3 = tf_executor.island {

View File

@ -0,0 +1,57 @@
// RUN: tf-opt %s -tf-executor-tpu-v1-island-coarsening | FileCheck %s --dump-input=fail
// Test that islands with a function call are merged if the call is to a function
// that contains ops with the same attribute.
// CHECK-LABEL: func @control_input
func @control_input(%arg0 : tensor<i1>) -> tensor<i32> {
%0:6 = tf_executor.graph {
%1:2 = tf_executor.island wraps "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
%2:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "A", body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%3:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "B", body = @while_body_with_wrong_cluster_attr, cond = @while_cond_with_wrong_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%4:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "C", body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%6:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "D", body = @while_body_without_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%5:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "E", body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
// CHECK: "tf.opA"
// CHECK-NOT: island
// CHECK: name = "A"
// CHECK-NOT: island
// CHECK: name = "C"
// CHECK-NOT: island
// CHECK: name = "E"
// CHECK: island {{.*}}name = "B"
// CHECK: island {{.*}}name = "D"
tf_executor.fetch %1#0, %2#0, %3#0, %4#0, %5#0, %6#0 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
return %0#0 : tensor<i32>
}
func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @while_body_with_wrong_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) {_tpu_replicate = "wrong_cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_with_wrong_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) {_tpu_replicate = "wrong_cluster"} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}

View File

@ -0,0 +1,44 @@
// RUN: tf-opt %s -tf-executor-tpu-v1-island-inlining | FileCheck %s --dump-input=fail
// CHECK-NOT: tf.PartitionedCall
// CHECK-NOT: module @_tpu_v1_compat_outlined
module {
func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
%0:4 = tf_executor.graph {
%outputs:4, %control = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @_tpu_v1_compat_outlined::@_tpu_v1_compat_outlined_func0} : (tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
return %0#0 : tensor<i32>
}
module @_tpu_v1_compat_outlined {
func @_tpu_v1_compat_outlined_func0(%arg0: tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1 : i64, topology = "topology"} : () -> ()
%0 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
%1 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%2 = "tf.While"(%0) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%3 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
return %0, %1, %2, %3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "tf.PartionedCalledOp"(%arg0) {f = @callee_func} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @callee_func(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
}
}

View File

@ -0,0 +1,48 @@
// RUN: tf-opt %s -tf-executor-tpu-v1-island-outlining | FileCheck %s --dump-input=fail
// CHECK: func @control_input
// CHECK-NOT: func @
// CHECK-LABEL: module @_tpu_v1_compat_outlined
// CHECK: @_tpu_v1_compat_outlined_func0
// CHECK: func @while_body_with_cluster_attr
// CHECK: func @while_cond_with_cluster_attr
// CHECK: func @while_body_without_cluster_attr
// CHECK: func @while_cond_without_cluster_attr
// CHECK: func @callee_func
module {
func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
%0:4 = tf_executor.graph {
%outputs:4, %control = tf_executor.island {
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
%1 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
%2 = "tf.While"(%1) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%3 = "tf.While"(%1) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%4 = "tf.While"(%1) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
tf_executor.yield %1, %2, %3, %4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
return %0#0 : tensor<i32>
}
func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "tf.PartionedCalledOp"(%arg0) { f = @callee_func} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @callee_func(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
}

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NCHW -verify-diagnostics | FileCheck %s
// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NCHW -verify-diagnostics | FileCheck %s --dump-input=always
// CHECK-LABEL: func @transposeBiasAdd
func @transposeBiasAdd(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> {

View File

@ -0,0 +1,67 @@
// RUN: tf-opt %s -tf-move-transposes -verify-diagnostics | FileCheck %s --dump-input=always
// CHECK-LABEL: func @move_across_single_op
func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
// CHECK: return %[[TANH]]
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
%1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
%2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
return %2 : tensor<1x8x4x4xf32>
}
// CHECK-LABEL: func @move_across_multiple_ops
func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[TANH0:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
// CHECK: %[[TANH1:[0-9]*]] = "tf.Tanh"(%[[TANH0]]) {{.*}} tensor<1x8x4x4xf32>
// CHECK: return %[[TANH1]]
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
%1 = "tf.Tanh"(%0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
%3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
return %3 : tensor<1x8x4x4xf32>
}
// CHECK-LABEL: func @move_across_multi_operand_op
func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[ARG0_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[ARG1_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg1, %[[ARG_PERM]])
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[ARG0_TRANSPOSE]], %[[ARG1_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
// CHECK: return %[[ADD]]
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
%1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
%2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
return %2 : tensor<1x8x4x4xf32>
}
// CHECK-LABEL: func @move_with_multiple_uses
func @move_with_multiple_uses(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[TANH]], %[[TANH]]) {{.*}} tensor<1x8x4x4xf32>
// CHECK: return %[[ADD]]
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
%1 = "tf.AddV2"(%0, %0) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
%3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
return %3 : tensor<1x8x4x4xf32>
}

View File

@ -542,3 +542,116 @@ func @if_else(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.
-> (tensor<*x!tf.resource<tensor<4xf32>>>) {
return %arg1 : tensor<*x!tf.resource<tensor<4xf32>>>
}
// -----
// Tests that the pass lifts resources on two partitioned call ops sharing the
// same callee. The lifting should clone the callee then modify the clone.
// CHECK-LABEL: @launch_with_partitioned_call
func @launch_with_partitioned_call() -> tensor<f32> {
// CHECK: %[[VH:.*]] = "tf.VarHandleOp"()
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: %[[CONST:.*]] = "tf.Const"()
%1 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]])
// CHECK: %[[LAUNCH:.*]] = "tf_device.launch"()
%2 = "tf_device.launch"() ( {
// CHECK: %[[PC0:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]])
// CHECK-SAME: f = @callee_resource_lifted
%3 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
// CHECK: %[[PC1:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]])
// CHECK-SAME: f = @callee_resource_lifted
%4 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[PC0]], %[[PC1]])
%5 = "tf.AddV2"(%3, %4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: tf_device.return %[[ADD]] : tensor<f32>
tf_device.return %5 : tensor<f32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<f32>
return %2 : tensor<f32>
}
// CHECK: @callee(%[[OA0:.*]]: tensor<f32>, %[[OA1:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA2:.*]]: tensor<f32>) -> tensor<f32>
func @callee(%arg0: tensor<f32>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<f32> {
// CHECK: "tf.ReadVariableOp"(%[[OA1]])
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%1 = "tf.AddV2"(%0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%2 = "tf.AddV2"(%1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %2 : tensor<f32>
}
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[ADD0:.*]] = "tf.AddV2"(%[[A1]], %[[A0]])
// CHECK-NEXT: %[[ADD1:.*]] = "tf.AddV2"(%[[ADD0]], %[[A2]])
// CHECK-NEXT: return %[[ADD1]]
// -----
// Tests that the pass lifts resources on two stateful partitioned call ops
// sharing the same callee. The lifting should clone the callee then modify the
// clone.
// CHECK-LABEL: @launch_with_stateful_partitioned_call
func @launch_with_stateful_partitioned_call() -> () {
// CHECK: %[[VH0:.*]] = "tf.VarHandleOp"()
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: %[[VH1:.*]] = "tf.VarHandleOp"()
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: %[[CONST:.*]] = "tf.Const"()
%2 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
// CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]])
// CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]])
// CHECK: %[[LAUNCH:.*]] = "tf_device.launch"()
"tf_device.launch"() ( {
// CHECK: %[[PC0:.*]] = "tf.StatefulPartitionedCall"(%[[READ0]], %[[READ1]], %[[CONST]])
// CHECK-SAME: f = @callee_resource_lifted
%3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: %[[PC1:.*]] = "tf.StatefulPartitionedCall"(%[[PC0]], %[[READ1]], %[[CONST]])
// CHECK-SAME: f = @callee_resource_lifted
%4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: tf_device.return %[[PC1]] : tensor<f32>
tf_device.return
// CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<f32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
// CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[LAUNCH]])
return
}
// CHECK: @callee(%[[OA0:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA1:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA2:.*]]: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
// CHECK: "tf.ReadVariableOp"(%[[OA1]])
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%1 = "tf.AddV2"(%0, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"tf.AssignVariableOp"(%arg0, %1) {dtype = i32} : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
}
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[A1]], %[[A2]])
// CHECK-NEXT: return %[[ADD]]
// -----
// Tests that the pass reports an error on a called function that has a
// resource output which doesn't alias an input.
func @launch_with_stateful_partitioned_call() -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
%2 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
"tf_device.launch"() ( {
%3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
%4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
tf_device.return
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
return
}
// expected-error @+1 {{Unsupported function call: resource return value does not alias an input.}}
func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
%0 = "tf._Unknown_"() : () -> tensor<*x!tf.resource<tensor<f32>>>
return %0 : tensor<*x!tf.resource<tensor<f32>>>
}

View File

@ -45,6 +45,17 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
return %1 : tensor<*xf32>
}
// CHECK-LABEL: func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<?xf32>
func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
br ^bb1
^bb1:
// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK: return %[[IDENTITY]] : tensor<?xf32>
%ret = "tf.Identity"(%arg0) : (tensor<?xf32>) -> tensor<*xf32>
return %ret : tensor<*xf32>
}
// Tests the case where an inference opportunity relies on folding.
// CHECK-LABEL: func @simple_folding

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
@ -70,10 +71,20 @@ void TPUBridgeExecutorIslandInlining::runOnModule() {
call_op.emitOpError() << "Failed to inline\n";
return WalkResult::interrupt();
}
called_func.erase();
call_op.erase();
return WalkResult::advance();
});
if (walk_result.wasInterrupted()) return signalPassFailure();
// Move all remaining nested functions back into the parent module.
Block &nested_block = nested_module->getRegion(0).front();
for (FuncOp func_op :
llvm::make_early_inc_range(nested_block.getOps<FuncOp>())) {
if (!symbol_table.lookupSymbolIn(getModule(), func_op.getName())) {
nested_block.getOperations().remove(func_op.getOperation());
symbol_table.insert(func_op.getOperation());
}
}
nested_module->erase();
}

View File

@ -29,10 +29,12 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/UseDefLists.h" // TF:llvm-project
#include "mlir/IR/Visitors.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
@ -57,8 +59,8 @@ constexpr llvm::StringRef kTpuStatusAttr = "_tpu_compilation_status";
// TPU-annotated operations and intended to preserve backward compatibility with
// TFv1.
struct TpuV1BridgeExecutorIslandCoarsening
: public FunctionPass<TpuV1BridgeExecutorIslandCoarsening> {
void runOnFunction() override;
: public ModulePass<TpuV1BridgeExecutorIslandCoarsening> {
void runOnModule() override;
};
// Sort the Operations in the provided range to enforce dominance.
@ -88,9 +90,10 @@ LogicalResult SortTopologically(Block::iterator first_op,
Operation* producer_in_block =
block->findAncestorOpInBlock(*defining_op);
if (producer_in_block && producer_in_block != &op &&
unscheduled_ops.count(producer_in_block))
unscheduled_ops.count(producer_in_block)) {
// Found an operand that isn't scheduled yet, interrupt the walk.
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
@ -113,7 +116,9 @@ LogicalResult SortTopologically(Block::iterator first_op,
// A failure is returned if a cycle prevents the merge from happening
// correctly without breaking dominance. The IR is left in an invalid state in
// case of failure.
LogicalResult MergeIsland(Operation* op, bool* changed) {
LogicalResult MergeIsland(llvm::function_ref<bool(StringAttr, Operation*)>
is_op_calling_func_for_cluster,
Operation* op, bool* changed) {
// Find the first island wrapping a single operation with the `_tpu_replicate`
// attribute, it'll be used as the root of the algorithm to find the other
// operations that are part of the same cluster.
@ -146,7 +151,9 @@ LogicalResult MergeIsland(Operation* op, bool* changed) {
if (!candidate_cluster_name)
candidate_cluster_name =
candidate_wrapped_op.getAttrOfType<StringAttr>(kTpuStatusAttr);
if (candidate_cluster_name != cluster_name) continue;
if (candidate_cluster_name != cluster_name &&
!is_op_calling_func_for_cluster(cluster_name, &candidate_wrapped_op))
continue;
// Look at captured operands to bring ReplicatedInputOp into the
// island as well. TODO: also pull in tf.Const, some optimizations can
@ -250,34 +257,71 @@ LogicalResult MergeIsland(Operation* op, bool* changed) {
first_op_after);
}
void TpuV1BridgeExecutorIslandCoarsening::runOnFunction() {
getFunction().walk([&](GraphOp graph) {
Block& graph_body = graph.GetBody();
void TpuV1BridgeExecutorIslandCoarsening::runOnModule() {
SymbolTable symbol_table(getModule());
// Iterate until fixed point on the block, as it may contain multiple
// clusters.
bool changed = true;
while (changed) {
changed = false;
for (Operation& op : graph_body) {
if (failed(MergeIsland(&op, &changed))) {
graph.emitError() << "Merging island failed: the TPU cluster likely "
<< "contains a cycle with non-TPU operations\n";
signalPassFailure();
return WalkResult::interrupt();
}
// If islands were merged, restart scanning the block from the beginning
// as we lost track of where to continue.
if (changed) break;
}
// Map TPU cluster names to the functions that contain operations for that
// cluster.
DenseMap<StringRef, DenseSet<FuncOp>> tpu_funcs;
for (FuncOp func_op : getModule().getOps<FuncOp>()) {
func_op.walk([&](Operation* op) {
StringAttr cluster_name =
op->getAttrOfType<StringAttr>(kTpuReplicateAttr);
if (!cluster_name)
cluster_name = op->getAttrOfType<StringAttr>(kTpuStatusAttr);
if (!cluster_name) return;
tpu_funcs[cluster_name.getValue()].insert(func_op);
});
}
// Returns true if the operation contains a reference to a function that
// contains operations for this cluster.
auto is_op_calling_func_for_cluster = [&](StringAttr cluster, Operation* op) {
auto funcs_for_cluster = tpu_funcs.find(cluster.getValue());
assert(funcs_for_cluster != tpu_funcs.end());
assert(!funcs_for_cluster->second.empty());
if (funcs_for_cluster->second.size() == 1) return false;
for (NamedAttribute attr : op->getAttrs()) {
auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>();
if (!symbol_ref) continue;
FuncOp callee = symbol_table.lookup<FuncOp>(symbol_ref.getValue());
if (!callee) continue;
if (funcs_for_cluster->second.count(callee)) return true;
}
return WalkResult::advance();
});
return false;
};
for (FuncOp func_op : getModule().getOps<FuncOp>()) {
func_op.walk([&](GraphOp graph) {
Block& graph_body = graph.GetBody();
// Iterate until fixed point on the block, as it may contain multiple
// clusters.
bool changed = true;
while (changed) {
changed = false;
for (Operation& op : graph_body) {
if (failed(
MergeIsland(is_op_calling_func_for_cluster, &op, &changed))) {
graph.emitError()
<< "Merging island failed: the TPU cluster likely "
<< "contains a cycle with non-TPU operations\n";
signalPassFailure();
return WalkResult::interrupt();
}
// If islands were merged, restart scanning the block from the
// beginning as we lost track of where to continue.
if (changed) break;
}
}
return WalkResult::advance();
});
}
}
} // namespace
std::unique_ptr<OpPassBase<FuncOp>>
std::unique_ptr<OpPassBase<ModuleOp>>
CreateTFExecutorTPUV1IslandCoarseningPass() {
return std::make_unique<TpuV1BridgeExecutorIslandCoarsening>();
}
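The driver loop above follows a restart-on-change pattern: merge once, then rescan from the start until nothing changes. A self-contained C++ sketch of that pattern; the Island struct and the adjacency-based merge rule are deliberate simplifications, not the MLIR implementation:

#include <iostream>
#include <string>
#include <vector>

struct Island {
  std::string cluster;  // cluster name; empty means "not a TPU island"
  int num_ops;          // number of wrapped operations
};

// Merge the first pair of adjacent islands sharing a cluster name; returns
// true if a merge happened.
bool MergeOneAdjacentPair(std::vector<Island>* islands) {
  for (size_t i = 0; i + 1 < islands->size(); ++i) {
    Island& lhs = (*islands)[i];
    Island& rhs = (*islands)[i + 1];
    if (!lhs.cluster.empty() && lhs.cluster == rhs.cluster) {
      lhs.num_ops += rhs.num_ops;
      islands->erase(islands->begin() + i + 1);
      return true;
    }
  }
  return false;
}

// Iterate until a fixed point, restarting the scan after every merge because
// the container was just mutated.
void CoarsenUntilFixedPoint(std::vector<Island>* islands) {
  while (MergeOneAdjacentPair(islands)) {
  }
}

int main() {
  std::vector<Island> islands = {
      {"tpu0", 1}, {"tpu0", 1}, {"", 1}, {"tpu1", 1}, {"tpu1", 1}};
  CoarsenUntilFixedPoint(&islands);
  for (const Island& island : islands)
    std::cout << (island.cluster.empty() ? "<non-tpu>" : island.cluster) << ":"
              << island.num_ops << "\n";  // prints tpu0:2, <non-tpu>:1, tpu1:2
  return 0;
}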

View File

@ -133,9 +133,23 @@ void TPUBridgeExecutorIslandOutlining::runOnModule() {
/*executor_type=*/builder.getStringAttr(""));
SmallVector<Value, 16> yield_operands(call_op.getResults());
builder.create<YieldOp>(island_op.getLoc(), yield_operands);
}
// TODO(aminim): handle transitively referenced function and clone them in
// the new module.
// Outline all the transitively called functions by moving them into the
// outlined module.
for (FuncOp func : outlined_module.getOps<FuncOp>()) {
func.walk([&](Operation *op) {
for (NamedAttribute attr : op->getAttrs()) {
auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>();
if (!symbol_ref) continue;
if (outlined_symbol_table.lookup<FuncOp>(symbol_ref.getValue()))
continue;
FuncOp callee = symbol_table.lookup<FuncOp>(symbol_ref.getValue());
callee.getOperation()->getBlock()->getOperations().remove(
callee.getOperation());
outlined_symbol_table.insert(callee);
}
});
}
}
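The new loop above pulls every transitively referenced callee into the outlined module. A rough standalone sketch of that reachability computation; the CallGraph map is an assumed stand-in for the walk over FlatSymbolRefAttr attributes, not the pass's data structure:

#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

using CallGraph = std::map<std::string, std::vector<std::string>>;

// Returns the set of functions reachable from `roots` through call edges.
std::set<std::string> TransitiveCallees(const CallGraph& calls,
                                        const std::set<std::string>& roots) {
  std::set<std::string> reachable(roots);
  std::queue<std::string> worklist;
  for (const std::string& root : roots) worklist.push(root);
  while (!worklist.empty()) {
    std::string caller = worklist.front();
    worklist.pop();
    auto it = calls.find(caller);
    if (it == calls.end()) continue;
    for (const std::string& callee : it->second) {
      // insert().second is true only the first time a callee is discovered.
      if (reachable.insert(callee).second) worklist.push(callee);
    }
  }
  return reachable;
}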

View File

@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
@ -25,6 +28,8 @@ namespace TF {
namespace {
// LayoutAssignmentPass assigns optimal data layout (data format) for all
// layout sensitive operations.
class LayoutAssignmentPass : public FunctionPass<LayoutAssignmentPass> {
public:
LayoutAssignmentPass() = default;
@ -39,6 +44,14 @@ class LayoutAssignmentPass : public FunctionPass<LayoutAssignmentPass> {
llvm::cl::desc("Force data format for all layout sensitive ops")};
};
// MoveTransposesPass moves all Transpose ops to the beginning or to the end of
// the basic block where they are defined. This will allow the canonicalizer to
// delete redundant transposes.
class MoveTransposesPass : public FunctionPass<MoveTransposesPass> {
public:
void runOnFunction() final;
};
using Permutation = SmallVector<int64_t, 4>;
Permutation GetDataFormatPermutation(StringRef from_data_format,
@ -128,10 +141,116 @@ void LayoutAssignmentPass::runOnFunction() {
});
}
// Move Transpose operations that permute `op` results before the `op`.
void MoveTransposeBefore(Operation* op, SmallVector<Operation*, 8>* work_list) {
// TODO(ezhulenev): Move transpose across layout sensitive operations.
if (!op->hasTrait<OpTrait::TF::LayoutAgnostic>()) return;
// Transpose operations that use operation results.
SmallVector<TransposeOp, 2> transpose_ops;
// Constant operation that defines permutation indices for result transposes.
ConstOp permutation_op;
// All operation results must be used by transpose operations with the same
// permutation indices.
for (OpResult result : op->getResults()) {
for (Operation* user : result.getUsers()) {
// Result user must be a transpose operation.
TransposeOp transpose = dyn_cast<TransposeOp>(user);
if (!transpose) return;
// With permutation defined by constant operation.
ConstOp perm =
dyn_cast_or_null<ConstOp>(transpose.getOperand(1).getDefiningOp());
if (!perm) return;
// With the same permutation indices.
auto dense_elem_attr = perm.value().dyn_cast<DenseElementsAttr>();
if (!dense_elem_attr) return;
if (!permutation_op) permutation_op = perm;
// Check that permutation matches for all result transposes.
if (perm.value() != permutation_op.value()) return;
// Add a transpose operation for later reuse.
transpose_ops.push_back(transpose);
}
}
// Nothing to do here.
if (!permutation_op || transpose_ops.empty()) return;
// At this point we have checked that we can safely move the Transpose node
// before `op` and bypass all result transposes.
Location loc = op->getLoc();
// Move constant op defining result permutation to the beginning of the block.
permutation_op.getOperation()->moveBefore(&op->getBlock()->front());
// Bypass Transpose nodes for all results.
for (OpResult result : op->getResults()) {
result.setType(cast<TransposeOp>(*result.getUsers().begin()).y().getType());
for (Operation* transpose : result.getUsers()) {
transpose->getResult(0).replaceAllUsesWith(result);
}
}
// Maybe add a Transpose node for all operands (or reuse existing transposes).
OpBuilder builder(op);
builder.setInsertionPoint(op);
for (OpOperand& operand : op->getOpOperands()) {
// Try to push transpose further up.
if (Operation* operand_op = operand.get().getDefiningOp())
work_list->push_back(operand_op);
// Try to reuse result transposes.
TransposeOp transpose;
if (!transpose_ops.empty()) {
transpose = transpose_ops.pop_back_val();
transpose.getOperation()->moveBefore(op);
transpose.setOperand(0, operand.get());
transpose.setOperand(1, permutation_op);
} else {
transpose =
builder.create<TransposeOp>(loc, operand.get(), permutation_op);
}
operand.set(transpose);
}
// Remove unused transpose operations.
while (!transpose_ops.empty()) {
TransposeOp transpose = transpose_ops.pop_back_val();
transpose.erase();
}
}
void MoveTransposesPass::runOnFunction() {
FuncOp func = getFunction();
SmallVector<Operation*, 8> work_list;
func.walk([&](TransposeOp transpose) {
for (auto operand : transpose.getOperands()) {
if (auto op = operand.getDefiningOp()) work_list.push_back(op);
}
});
while (!work_list.empty()) {
Operation* op = work_list.pop_back_val();
MoveTransposeBefore(op, &work_list);
}
}
} // namespace
static PassRegistration<LayoutAssignmentPass> pass("tf-layout-assignment",
"Layout assignment pass");
static PassRegistration<LayoutAssignmentPass> layout_assignment(
"tf-layout-assignment", "Layout assignment pass");
static PassRegistration<MoveTransposesPass> move_transposes(
"tf-move-transposes", "Move transposes pass");
} // namespace TF
} // namespace mlir
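The guard in MoveTransposeBefore() above only fires when every result user is a transpose carrying one shared permutation. A standalone sketch of that check, with a plain vector of permutations standing in for the TransposeOp users:

#include <optional>
#include <vector>

using Perm = std::vector<int>;

// Returns the shared permutation if all result users agree on one, and
// std::nullopt otherwise (in which case the transpose is not hoisted).
std::optional<Perm> CommonResultPermutation(const std::vector<Perm>& users) {
  if (users.empty()) return std::nullopt;
  for (const Perm& perm : users) {
    if (perm != users.front()) return std::nullopt;
  }
  return users.front();
}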

View File

@ -106,7 +106,8 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass();
// Creates a pass to merge IslandOps for operation marked for execution on TPU.
// This is a V1 backward compatibility.
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorTPUV1IslandCoarseningPass();
std::unique_ptr<OpPassBase<ModuleOp>>
CreateTFExecutorTPUV1IslandCoarseningPass();
// Creates a pass to outline TPU clusters from a single IslandOp into a nested
// module suitable for being processed as if it were a V2 module.

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
@ -811,16 +812,185 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
return success();
}
// A resource-lifted function for (potentially multiple) PartitionedCallOps and
// information about the lifting changes.
struct PartitionedCallLiftingInfo {
// Function with resources lifted. Can be nullptr if nothing needs to change.
FuncOp lifted_callee;
// Mapping from old resource outputs to the old inputs they alias.
llvm::SmallDenseMap<int64_t, int64_t> old_outputs_aliasing_old_inputs;
// Mapping from old to new output indices in case any output is removed.
llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
// ResourceArgUseInfo for each old resource argument.
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> use_info;
// Input for AddLoadsStoresOutsideControlFlowOp(), see its comment.
llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
arg_data_type_and_updated_output_index;
};
// Lifts loads/stores from a PartitionedCallOp's callee function. If anything
// needs to be changed, the original function will be preserved, and the lifting
// happens on a clone, which will be stored in `result`.
LogicalResult HandlePartitionedCallOpCallee(
FuncOp callee, PartitionedCallLiftingInfo* result) {
// Remove identity nodes to avoid aliasing.
RemoveIdentity(&callee.front());
// Sanity check: returned resources should be aliases of inputs. Such outputs
// will be removed later.
int64_t non_resource_results = 0;
for (auto entry :
llvm::enumerate(callee.front().getTerminator()->getOperands())) {
auto retval = entry.value();
if (!getElementTypeOrSelf(retval.getType()).isa<TF::ResourceType>()) {
result->old_to_new_output_indices.push_back(non_resource_results++);
continue;
}
auto aliasing_arg = retval.dyn_cast<BlockArgument>();
if (!aliasing_arg) {
return callee.emitOpError(
"Unsupported function call: resource return value does not alias an "
"input.");
}
result->old_outputs_aliasing_old_inputs[entry.index()] =
aliasing_arg.getArgNumber();
result->old_to_new_output_indices.push_back(-1);
}
if (failed(FindResourceArgUseInfo(callee, &result->use_info))) {
return failure();
}
if (result->use_info.empty()) {
result->lifted_callee = nullptr;
return success();
}
// Clone the callee before making changes.
SmallString<64> name_base = callee.getName();
auto module = callee.getParentOfType<ModuleOp>();
name_base += "_resource_lifted";
auto name = name_base;
{
int64_t counter = 0;
// Append an increasing counter suffix until the name is unique in the module.
while (module.lookupSymbol(name)) {
name = name_base;
name += "_" + std::to_string(counter++);
}
}
callee = callee.clone();
callee.setName(name);
SymbolTable(module).insert(callee);
result->lifted_callee = callee;
// Remove unused resources in functions.
llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
RemoveUnusedResourceArgumentsAndForwardedRetvals(
result->use_info, callee, /*old_to_new_arg_indices=*/nullptr,
&remaining_resource_data_types);
for (const auto& entry : remaining_resource_data_types) {
result->arg_data_type_and_updated_output_index[entry.getFirst()] = {
entry.getSecond(), -1};
}
llvm::SmallVector<Value, 4> new_retvals;
for (auto val : callee.front().getTerminator()->getOperands()) {
// Remove resource type outputs.
if (getElementTypeOrSelf(val.getType()).isa<TF::ResourceType>()) continue;
new_retvals.push_back(val);
}
// Lift resources.
LiftArgRetResourcesForFunction(
callee, remaining_resource_data_types, [&](int64_t index, Value value) {
result->arg_data_type_and_updated_output_index[index].second =
new_retvals.size();
new_retvals.push_back(value);
});
auto old_return = callee.front().getTerminator();
// Replace the old return with a new one that has the updated values.
OpBuilder builder(old_return);
auto new_return = builder.create<ReturnOp>(old_return->getLoc(), new_retvals);
old_return->erase();
callee.setType(FunctionType::get(
callee.getType().getInputs(),
llvm::to_vector<4>(new_return.getOperandTypes()), callee.getContext()));
return success();
}
// Updates a PartitionedCallOp/StatefulPartitionedCallOp according to the
// resource-lifted new callee function in lifting_info.
template <typename CallOpType>
void UpdatePartitionedCallOpWithNewCallee(
CallOpType call_op, const PartitionedCallLiftingInfo& lifting_info) {
if (lifting_info.lifted_callee == nullptr) return;
// Replace output resource uses with the aliasing input, so that we can remove
// this output.
for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) {
call_op.getResult(entry.getFirst())
.replaceAllUsesWith(call_op.getOperand(entry.getSecond()));
}
// Recreate the call op.
OpBuilder builder(call_op);
// Now use the filtered original operands, which will be replaced by
// AddLoadsStoresOutsideControlFlowOp().
auto new_operands =
FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
auto new_call = builder.create<CallOpType>(
call_op.getLoc(),
const_cast<FuncOp&>(lifting_info.lifted_callee).getType().getResults(),
new_operands, call_op.getAttrs());
new_call.setAttr(
"f", builder.getSymbolRefAttr(
const_cast<FuncOp&>(lifting_info.lifted_callee).getName()));
AddLoadsStoresOutsideControlFlowOp(
new_call, lifting_info.arg_data_type_and_updated_output_index);
// Replace uses.
for (int64_t i = 0; i < lifting_info.old_to_new_output_indices.size(); ++i) {
if (lifting_info.old_to_new_output_indices[i] >= 0) {
call_op.getResult(i).replaceAllUsesWith(
new_call.getResult(lifting_info.old_to_new_output_indices[i]));
}
}
call_op.erase();
}
LogicalResult HoistForFunctionalControlFlow(
Block*, ModuleOp, llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*);
// A templated routine for handling both PartitionedCallOp and
// StatefulPartitionedCallOp. If the callee is already lifted, it just updates
// the caller op itself; otherwise, it first recursively handles nested control
// flow, then performs lifting on the callee.
template <typename CallOpType>
LogicalResult HandlePartitionedCallOp(
CallOpType call_op, FuncOp callee, ModuleOp module,
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>* lifted_callees) {
auto emplace_res =
lifted_callees->try_emplace(callee, PartitionedCallLiftingInfo());
if (emplace_res.second) {
// Unseen callee. Perform resource lifting on it.
HoistForFunctionalControlFlow(&callee.front(), module, lifted_callees);
if (failed(HandlePartitionedCallOpCallee(
callee, &emplace_res.first->getSecond()))) {
return failure();
}
}
UpdatePartitionedCallOpWithNewCallee(call_op, emplace_res.first->getSecond());
return success();
}
// Hoists resource loads/stores from control flow ops in `block` outside the
// body/cond/branch functions.
LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
// body/cond/branch/callee functions.
LogicalResult HoistForFunctionalControlFlow(
Block* block, ModuleOp module,
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*
lifted_partitioned_call_callees) {
for (Operation& op : llvm::make_early_inc_range(*block)) {
if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
auto cond = llvm::cast<FuncOp>(module.lookupSymbol(while_op.cond()));
// Recursively handle the nested control flow.
HoistForFunctionalControlFlow(&body.front(), module);
HoistForFunctionalControlFlow(&cond.front(), module);
HoistForFunctionalControlFlow(&body.front(), module,
lifted_partitioned_call_callees);
HoistForFunctionalControlFlow(&cond.front(), module,
lifted_partitioned_call_callees);
if (failed(HanldeWhileLoop(while_op, body, cond))) return failure();
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
auto then_branch =
@ -828,9 +998,30 @@ LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
auto else_branch =
llvm::cast<FuncOp>(module.lookupSymbol(if_op.else_branch()));
// Recursively handle the nested control flow.
HoistForFunctionalControlFlow(&then_branch.front(), module);
HoistForFunctionalControlFlow(&else_branch.front(), module);
HoistForFunctionalControlFlow(&then_branch.front(), module,
lifted_partitioned_call_callees);
HoistForFunctionalControlFlow(&else_branch.front(), module,
lifted_partitioned_call_callees);
if (failed(HanldeIfOP(if_op, then_branch, else_branch))) return failure();
} else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
if (!call_op.f().isa<FlatSymbolRefAttr>()) {
return call_op.emitError(
"Resource lifting does not support call with nested references.");
}
auto callee = llvm::cast<FuncOp>(
module.lookupSymbol(call_op.f().getRootReference()));
if (failed(HandlePartitionedCallOp(call_op, callee, module,
lifted_partitioned_call_callees))) {
// Nested control flow handling is done in HandlePartitionedCallOp().
return failure();
}
} else if (auto call_op =
llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
auto callee = llvm::cast<FuncOp>(module.lookupSymbol(call_op.f()));
if (failed(HandlePartitionedCallOp(call_op, callee, module,
lifted_partitioned_call_callees))) {
return failure();
}
}
}
return success();
@ -840,10 +1031,13 @@ LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
// outside. Returns failure if there are remaining resource-type values that can
// not be lifted.
void ResourceOpLiftingPass::runOnModule() {
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
lifted_partitioned_call_callees;
auto result = getModule().walk([&](FuncOp func_op) {
return func_op.walk([&](tf_device::LaunchOp launch_op) {
if (failed(HoistForFunctionalControlFlow(&launch_op.GetBody(),
getModule())) ||
if (failed(HoistForFunctionalControlFlow(
&launch_op.GetBody(), getModule(),
&lifted_partitioned_call_callees)) ||
failed(HoistResourceOpsFromLaunchOp(launch_op))) {
return WalkResult::interrupt();
}
@ -901,8 +1095,11 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
<< function.getBlocks().size();
}
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
lifted_partitioned_call_callees;
return HoistForFunctionalControlFlow(&function.front(),
cast<ModuleOp>(function.getParentOp()));
cast<ModuleOp>(function.getParentOp()),
&lifted_partitioned_call_callees);
}
} // namespace TF
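A small sketch of the memoization pattern used for callees above, where each function is lifted at most once and the result is cached; LiftingInfo and LiftCallee are hypothetical stand-ins for PartitionedCallLiftingInfo and the real rewrite:

#include <map>
#include <string>

struct LiftingInfo {
  int num_lifted_resources = 0;
};

// Placeholder for the expensive per-callee rewrite.
LiftingInfo LiftCallee(const std::string& callee) {
  LiftingInfo info;
  info.num_lifted_resources = static_cast<int>(callee.size());
  return info;
}

// Lift each callee at most once; later calls reuse the cached result, just
// like the try_emplace() guard in HandlePartitionedCallOp() above.
const LiftingInfo& GetOrLiftCallee(
    const std::string& callee,
    std::map<std::string, LiftingInfo>* lifted_callees) {
  auto emplace_res = lifted_callees->try_emplace(callee, LiftingInfo());
  if (emplace_res.second) {
    emplace_res.first->second = LiftCallee(callee);
  }
  return emplace_res.first->second;
}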

View File

@ -60,16 +60,23 @@ namespace TF {
namespace {
Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType(
FuncOp func) {
// Only infer shape when there is one return op for now.
if (!has_single_element(func.getBody()) || func.front().empty()) {
// Find any return ops.
SmallVector<ReturnOp, 4> return_ops;
for (Block& block : func) {
if (auto return_op = dyn_cast<ReturnOp>(block.getTerminator())) {
return_ops.push_back(return_op);
}
}
// Right now we only handle the case of a single return op.
// To handle multiple return ops, we would need to look at all their shapes
// and come up with a common shape and insert appropriate casts.
if (return_ops.size() != 1) {
return None;
}
// Find the return type.
auto return_op = dyn_cast<mlir::ReturnOp>(func.front().back());
if (!return_op) {
return None;
}
auto return_op = return_ops.front();
// Manually fold tf.Cast that precedes the return instruction and only differs
// in shape refinement level.
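The refactoring above relaxes the single-block assumption into "exactly one return across all blocks". A minimal sketch of that guard, with a hypothetical Block record in place of the MLIR types:

#include <optional>
#include <vector>

struct Block {
  bool ends_in_return = false;
  int return_id = -1;  // identifies the terminator when ends_in_return is true
};

// Returns the lone return terminator, or std::nullopt when there are zero or
// more than one (the multi-return case would need a common shape plus casts).
std::optional<int> SingleReturn(const std::vector<Block>& blocks) {
  std::optional<int> found;
  for (const Block& block : blocks) {
    if (!block.ends_in_return) continue;
    if (found) return std::nullopt;
    found = block.return_id;
  }
  return found;
}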

View File

@ -263,6 +263,7 @@ tf_device::ReplicateOp AddInputsToReplicateOp(
llvm::SmallVector<std::pair<llvm::ArrayRef<Value>, Type>, 8>
new_replicated_inputs;
llvm::SmallVector<llvm::SmallVector<Value, 8>, 8> replicated_inputs;
replicated_inputs.reserve(replicate.GetBody().getNumArguments());
for (auto arg : llvm::enumerate(replicate.GetBody().getArguments())) {
int64_t i = arg.index();
replicated_inputs.emplace_back();

View File

@ -42,8 +42,8 @@ namespace mlir {
namespace {
struct BreakUpIslands : OperationPass<BreakUpIslands, FuncOp> {
void runOnOperation() final;
struct BreakUpIslands : FunctionPass<BreakUpIslands> {
void runOnFunction() final;
void BreakUpIsland(tf_executor::IslandOp island_op,
const TF::SideEffectAnalysis& side_effect_analysis,
@ -51,8 +51,8 @@ struct BreakUpIslands : OperationPass<BreakUpIslands, FuncOp> {
new_control_inputs);
};
void BreakUpIslands::runOnOperation() {
auto graph_op_range = getOperation().getBody().front().without_terminator();
void BreakUpIslands::runOnFunction() {
auto graph_op_range = getFunction().getBody().front().without_terminator();
tf_executor::GraphOp graph_op;
if (graph_op_range.begin() != graph_op_range.end() &&
std::next(graph_op_range.begin()) == graph_op_range.end()) {

View File

@ -63,21 +63,21 @@ Status StatusScopedDiagnosticHandler::Combine(Status status) {
}
LogicalResult StatusScopedDiagnosticHandler::handler(Diagnostic* diag) {
#ifndef NDEBUG
// Non-error diagnostics are ignored when VLOG isn't enabled.
if (diag->getSeverity() != DiagnosticSeverity::Error && VLOG_IS_ON(1))
return success();
size_t current_diag_str_size_ = diag_str_.size();
#endif
// Emit the diagnostic and flush the stream.
emitDiagnostic(*diag);
diag_stream_.flush();
#ifndef NDEBUG
// Emit non-errors to VLOG instead of the internal status.
if (diag->getSeverity() != DiagnosticSeverity::Error) {
VLOG(1) << diag_str_.substr(current_diag_str_size_);
diag_str_.resize(current_diag_str_size_);
}
#endif
// Return failure to signal propagation if necessary.
return failure(propagate_);
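A standalone sketch of the buffering trick above: every diagnostic is rendered into the captured string, and non-errors are then logged and trimmed back out so that only errors reach the returned status. std::clog stands in for VLOG(1) here:

#include <iostream>
#include <string>

enum class Severity { kNote, kWarning, kError };

void HandleDiagnostic(Severity severity, const std::string& message,
                      std::string* captured) {
  const size_t previous_size = captured->size();
  *captured += message;  // render the diagnostic into the status buffer
  if (severity != Severity::kError) {
    // Non-errors are logged and stripped from the captured status text.
    std::clog << captured->substr(previous_size) << "\n";
    captured->resize(previous_size);
  }
}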

View File

@ -370,6 +370,22 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
Convert(interior_padding))
.getOperation();
}
case HloOpcode::kScatter: {
auto scatter = static_cast<HloScatterInstruction*>(instruction);
attributes.push_back(
ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers()));
attributes.push_back(builder_->getNamedAttr(
"indices_are_sorted",
builder_->getBoolAttr(scatter->indices_are_sorted())));
attributes.push_back(builder_->getNamedAttr(
"unique_indices", builder_->getBoolAttr(scatter->unique_indices())));
auto scatter_op = func_builder->create<mlir::xla_hlo::ScatterOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportComputation(scatter->to_apply(),
&scatter_op.update_computation()));
return scatter_op.getOperation();
}
case HloOpcode::kSetDimensionSize: {
attributes.push_back(builder_->getNamedAttr(
"dimension", builder_->getIntegerAttr(builder_->getIntegerType(32),
@ -385,6 +401,16 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
ConvertDimensions(instruction->slice_strides()))
.getOperation();
}
case HloOpcode::kSort: {
auto sort_instruction = static_cast<HloSortInstruction*>(instruction);
auto sort_op = func_builder->create<mlir::xla_hlo::SortOp>(
loc, result_type, operands,
builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
builder_->getBoolAttr(sort_instruction->is_stable()));
TF_RETURN_IF_ERROR(ImportComputation(sort_instruction->to_apply(),
&sort_op.comparator()));
return sort_op.getOperation();
}
case HloOpcode::kConditional: {
llvm::SmallVector<Type, 4> rets;
TF_RETURN_IF_ERROR(GetMlirTypes(
@ -834,6 +860,22 @@ mlir::NamedAttribute HloFunctionImporter::ConvertGatherDimensionNumbers(
return builder_->getNamedAttr("dimension_numbers", attr);
}
mlir::NamedAttribute HloFunctionImporter::ConvertScatterDimensionNumbers(
const xla::ScatterDimensionNumbers& dnums) {
std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
dnums.update_window_dims().end());
std::vector<int64_t> inserted_window_dims(
dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
std::vector<int64_t> scatter_dims_to_operand_dims(
dnums.scatter_dims_to_operand_dims().begin(),
dnums.scatter_dims_to_operand_dims().end());
auto attr = mlir::xla_hlo::ScatterDimensionNumbers::get(
Convert(update_window_dims), Convert(inserted_window_dims),
Convert(scatter_dims_to_operand_dims),
builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_);
return builder_->getNamedAttr("scatter_dimension_numbers", attr);
}
mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
source_target_pairs) {

View File

@ -121,6 +121,10 @@ class HloFunctionImporter {
mlir::NamedAttribute ConvertGatherDimensionNumbers(
const xla::GatherDimensionNumbers& dnums);
// Converts the scatter dimensions to attributes.
mlir::NamedAttribute ConvertScatterDimensionNumbers(
const xla::ScatterDimensionNumbers& dnums);
// Converts XLA instruction source target pairs to MLIR attribute.
mlir::NamedAttribute ConvertSourceTargetPairs(
const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&

View File

@ -60,6 +60,13 @@ def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;
// Dynamic representation of a shape vector as a tensor. Ideally this would be
// an index type (as it stores indices) but that is currently disallowed in
// MLIR.
def HLO_DimensionTensor : ShapedContainerType<
[AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
"a 1D tensor of dimensions">;
// In general, static shaped tensor constraints should be avoided unless
// it is for a legacy op which is only correct with static shapes.
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
@ -771,10 +778,22 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim",
}
def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
[NoSideEffect]>, BASE_HLO_DynamicBroadcastInDimOp {
[NoSideEffect]> {
string summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
string description = [{
This is a generalization of the BroadcastInDimOp which accepts its output
dimensions as an argument. It should eventually supersede the statically
shaped original, but is being phased in as a separate op in order to support
compatibility with lowerings and translations that precede dynamic
shapes.
Note that the `broadcast_dimensions` attribute is optional and if omitted,
it is assumed to be an ordered, right-aligned mapping from input to
output dimensions.
}];
let arguments = (ins
HLO_Tensor:$operand,
HLO_BASE_DimensionTensor:$output_dimensions,
HLO_DimensionTensor:$output_dimensions,
BroadcastDimAttr:$broadcast_dimensions
);
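To make the default concrete: when `broadcast_dimensions` is omitted, the ordered, right-aligned mapping described above can be computed as below. Plain C++, independent of the TableGen definition:

#include <cstdint>
#include <numeric>
#include <vector>

// Operand dimension i maps to output dimension (output_rank - operand_rank + i).
std::vector<int64_t> RightAlignedBroadcastDims(int64_t operand_rank,
                                               int64_t output_rank) {
  std::vector<int64_t> dims(static_cast<size_t>(operand_rank));
  std::iota(dims.begin(), dims.end(), output_rank - operand_rank);
  return dims;
}
// Example: operand rank 2 broadcast into output rank 4 -> dims = [2, 3].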

View File

@ -27,13 +27,6 @@ def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
// matching the matrix to dimensions 1 and 2 of the cuboid.
def BroadcastDimAttr : OptionalAttr<I64ElementsAttr>;
// Dynamic representation of a shape vector as a tensor. Ideally this would be
// an index type (as it stores indices) but that is currently disallowed in
// MLIR.
def HLO_BASE_DimensionTensor : ShapedContainerType<
[AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
"a 1D tensor of dimensions">;
//===----------------------------------------------------------------------===//
// XLA nullary op definitions.
//===----------------------------------------------------------------------===//
@ -817,22 +810,6 @@ class BASE_HLO_BroadcastInDimOp {
}];
}
class BASE_HLO_DynamicBroadcastInDimOp {
string summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
string description = [{
This is a generalization of the BroadcastInDimOp which accepts its output
dimensions as an argument. It should eventually supercede the statically
shaped original, but is being phased as a separate op in order to support
compatibility with lowerings and translations that precede dynamic
shapes.
Note that the `broadcast_dimensions` attribute is optional and if omitted,
it is assumed to be an ordered, right-aligned mapping from input to
output dimensions.
}];
}
class BASE_HLO_CholeskyOp {
string summary = "Cholesky operator";

View File

@ -242,16 +242,6 @@ def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim",
);
}
def HLO_DynamicBroadcastInDimOp : LHLO_Op<"dynamic_broadcast_in_dim",
[NoSideEffect]>, BASE_HLO_DynamicBroadcastInDimOp {
let arguments = (ins
LHLO_Buffer:$operand,
HLO_BASE_DimensionTensor:$output_dimensions,
LHLO_Buffer:$output,
BroadcastDimAttr:$broadcast_dimensions
);
}
def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp {
let arguments = (ins
LHLO_Buffer:$min,

View File

@ -1,32 +1,57 @@
// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s
// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always
// RUN: tf-opt -lhlo-fuse-linalg -tile-sizes-for-linalg-fusion=2,3 %s -o - | FileCheck %s -check-prefix=TILED --dump-input-on-failure
// RUN: tf-opt -lhlo-fuse-linalg -tile-to-parallel-loops-for-linalg-fusion %s -o - | FileCheck %s -check-prefix=PLOOP --dump-input-on-failure
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
%temp_result = alloc() {temp = true} : memref<2x2xf32>
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
%temp_result = alloc() {temp = true} : memref<6x6xf32>
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32
linalg.yield %out : f32
} : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
%out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32
} : memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>
dealloc %temp_result : memref<2x2xf32>
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
dealloc %temp_result : memref<6x6xf32>
"xla_lhlo.terminator"() : () -> ()
}
// CHECK-LABEL: func @fusion
// CHECK-NOT: linalg.generic
// CHECK: loop.for
// CHECK: loop.for
// CHECK-NOT: loop.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: loop.for {{.*}} step %[[C1]]
// CHECK: loop.for {{.*}} step %[[C1]]
// CHECK-NOT: loop.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf
// TILED-LABEL: func @fusion
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-DAG: %[[C3:.*]] = constant 3
// TILED-NOT: linalg.generic
// TILED: loop.for {{.*}} step %[[C2]]
// TILED: loop.for {{.*}} step %[[C3]]
// TILED-NOT: loop.for
// TILED: linalg.generic
// TILED: addf
// TILED: linalg.generic
// TILED: mulf
// PLOOP-LABEL: func @fusion
// PLOOP-NOT: linalg.generic
// PLOOP: loop.parallel
// PLOOP-NOT: loop.parallel
// PLOOP: linalg.generic
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf
func @fusion_of_three(%arg0: memref<100x10xf32>,
%arg1: memref<100xf32>,
@ -67,12 +92,36 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
return
}
// CHECK-LABEL: func @fusion
// CHECK-NOT: linalg.generic
// CHECK: loop.for
// CHECK: loop.for
// CHECK-NOT: loop.for
// CHECK: linalg.generic
// CHECK: linalg.generic
// CHECK: subf
// CHECK: linalg.generic
// CHECK: exp
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: loop.for {{.*}} step %[[C1]]
// CHECK: loop.for {{.*}} step %[[C1]]
// CHECK-NOT: loop.for
// CHECK: linalg.generic
// CHECK: linalg.generic
// CHECK: subf
// CHECK: linalg.generic
// CHECK: exp
// TILED-LABEL: func @fusion_of_three
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-DAG: %[[C3:.*]] = constant 3
// TILED-NOT: linalg.generic
// TILED: loop.for {{.*}} step %[[C2]]
// TILED: loop.for {{.*}} step %[[C3]]
// TILED-NOT: loop.for
// TILED: linalg.generic
// TILED: linalg.generic
// TILED: subf
// TILED: linalg.generic
// TILED: exp
// PLOOP-LABEL: func @fusion_of_three
// PLOOP-NOT: linalg.generic
// PLOOP: loop.parallel
// PLOOP-NOT: loop.parallel
// PLOOP: linalg.generic
// PLOOP: linalg.generic
// PLOOP: subf
// PLOOP: linalg.generic
// PLOOP: exp
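For intuition, the TILED checks above correspond to a 6x6 element-wise add/mul fused inside a loop nest stepping by the tile sizes 2 and 3. A rough C++ analogue, for illustration only and not the code the pass generates:

#include <array>

constexpr int kRows = 6, kCols = 6, kTileRows = 2, kTileCols = 3;
using Mat = std::array<std::array<float, kCols>, kRows>;

void FusedAddMulTiled(const Mat& summand_1, const Mat& summand_2,
                      const Mat& multiplier, Mat& result) {
  for (int i0 = 0; i0 < kRows; i0 += kTileRows)      // step %[[C2]]
    for (int j0 = 0; j0 < kCols; j0 += kTileCols)    // step %[[C3]]
      for (int i = i0; i < i0 + kTileRows; ++i)
        for (int j = j0; j < j0 + kTileCols; ++j) {
          float temp = summand_1[i][j] + summand_2[i][j];  // addf
          result[i][j] = temp * multiplier[i][j];          // mulf
        }
}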

View File

@ -179,6 +179,22 @@ func @iota(%out: memref<7x10xi64>) {
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func @dynamic_broadcast
func @dynamic_broadcast(%operand: memref<?x?x?xf32>,
%result: memref<?x?x?x?x?xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result)
{broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>}
: (memref<?x?x?xf32>, memref<?x?x?x?x?xf32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func @broadcast

View File

@ -152,13 +152,6 @@ func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi
// -----
// CHECK-LABEL: func @dynamic_broadcast_in_dim_memref
func @dynamic_broadcast_in_dim_memref(%arg0: memref<?x?xi32>, %out: memref<?x?x?xi32>, %shape: tensor<3xi64>) -> () {
"xla_lhlo.dynamic_broadcast_in_dim"(%arg0, %shape, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<?x?xi32>, tensor<3xi64>, memref<?x?x?xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @reduce_memref
func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf32>) -> () {

View File

@ -716,6 +716,37 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
ROOT %Arg_0.1 = f32[] parameter(0)
}
// Test scatter
%update_computation {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %sum = f32[] add(f32[] %lhs, f32[] %rhs)
}
%test_scatter {
%input_tensor = f32[200,100,300] parameter(0)
%scatter_indices = s64[10,2] parameter(1)
%updates = f32[10,300] parameter(2)
ROOT %scatter = f32[200,100,300] scatter(f32[200,100,300] %input_tensor, s64[10,2] %scatter_indices, f32[10,300] %updates), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, to_apply=%update_computation
}
// CHECK-LABEL: func @test_scatter
// CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32>
// CHECK: "xla_hlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ( {
// CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = xla_hlo.add [[LHS]], [[RHS]]
// CHECK: "xla_hlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: indices_are_sorted = false
// CHECK-SAME: scatter_dimension_numbers = {
// CHECK-SAME: index_vector_dim = 1 : i64
// CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: update_window_dims = dense<1> : tensor<1xi64>
// CHECK-SAME: }
// CHECK-SAME: unique_indices = false
// CHECK-LABEL: func @test_select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
%test_select {
%Arg_0.1 = pred[2,3] parameter(0)
@ -743,6 +774,25 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
ROOT %sine.3 = f32[1,16,16,3]{3,2,1,0} sine(f32[1,16,16,3]{3,2,1,0} %arg0.1)
}
// Test sort
%compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
%test_sort {
x = f32[1024]{0} parameter(0)
ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare
}
// CHECK-LABEL: func @test_sort
// CHECK-SAME: [[ARG:%.*]]: tensor<1024xf32>) -> tensor<1024xf32>
// CHECK: "xla_hlo.sort"([[ARG]]) ( {
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
// CHECK: [[CMP:%.*]] = "xla_hlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT", name = "lt"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: "xla_hlo.return"([[CMP]]) : (tensor<i1>) -> ()
// CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<1024xf32>) -> tensor<1024xf32>
// CHECK-LABEL: func @test_subtract
%test_subtract (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)

View File

@ -22,6 +22,20 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project
// NOLINTNEXTLINE
static llvm::cl::opt<bool> tile_to_parallel_loops_for_linalg_fusion(
"tile-to-parallel-loops-for-linalg-fusion",
llvm::cl::desc(
"Tiles GenericOp consumer to parallel loops before linalg fusion"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::list<unsigned> tile_sizes_for_linalg_fusion(
"tile-sizes-for-linalg-fusion",
llvm::cl::desc(
"Tile sizes by which to tile linalg generic before linalg fusion"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
namespace mlir {
namespace xla_lhlo {
namespace {
@ -50,13 +64,16 @@ struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
OpBuilder b(func);
OperationFolder folder(func.getContext());
func.walk([&](linalg::GenericOp generic_op) {
const SmallVector<int64_t, 2> tile_sizes(
generic_op.getNumInputsAndOutputs(), 1);
SmallVector<int64_t, 2> tile_sizes(tile_sizes_for_linalg_fusion.begin(),
tile_sizes_for_linalg_fusion.end());
if (tile_sizes.empty()) {
tile_sizes =
SmallVector<int64_t, 2>(generic_op.getNumInputsAndOutputs(), 1);
}
auto op = cast<LinalgOp>(generic_op.getOperation());
for (const Value result : op.getOutputBuffers()) {
if (!func_args.count(result)) continue;
if (linalg::tileLinalgOp(b, op, tile_sizes, /*permutation=*/{},
&folder)) {
if (tileGenericOp(op, tile_sizes, &b, &folder)) {
generic_op.erase();
return;
}
@ -83,6 +100,18 @@ struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
}
for (auto* e : erase_set) e->erase();
}
private:
bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b,
OperationFolder* folder) {
auto tiled_generic_op =
tile_to_parallel_loops_for_linalg_fusion
? linalg::tileLinalgOpToParallelLoops(*b, op, tile_sizes,
/*permutation=*/{}, folder)
: linalg::tileLinalgOp(*b, op, tile_sizes,
/*permutation=*/{}, folder);
return tiled_generic_op.hasValue();
}
};
} // namespace
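The tile-size selection above boils down to a small fallback rule: use the command-line sizes when given, otherwise tile every dimension by 1. A hedged sketch with plain vectors in place of the llvm::cl list:

#include <cstdint>
#include <vector>

std::vector<int64_t> ChooseTileSizes(const std::vector<unsigned>& flag_sizes,
                                     int64_t num_inputs_and_outputs) {
  std::vector<int64_t> tile_sizes(flag_sizes.begin(), flag_sizes.end());
  if (tile_sizes.empty()) {
    tile_sizes.assign(static_cast<size_t>(num_inputs_and_outputs), 1);
  }
  return tile_sizes;
}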

View File

@ -227,19 +227,21 @@ class BroadcastInDimConverter
unsigned nloops = resultMemrefType.getRank();
auto operandShape = operandMemrefType.getShape();
SmallVector<AffineExpr, 4> dimExprs;
{
dimExprs.reserve(nloops);
for (const auto& broadcastDim : llvm::enumerate(
broadcastOp.broadcast_dimensions().getValue().getIntValues())) {
int dim = broadcastDim.value().getSExtValue();
auto operandShape = operandMemrefType.getShape();
int index = 0;
for (const auto& broadcastSize :
broadcastOp.broadcast_dimensions().getValue().getIntValues()) {
int size = broadcastSize.getSExtValue();
dimExprs.push_back(
operandShape[index++] == 1
// TODO(pifon): Add support for args with dynamic shapes for the case
// when a dimension of size 1 is broadcasted into dim of size N.
AffineExpr affineExpr =
operandShape[broadcastDim.index()] == 1
? mlir::getAffineConstantExpr(0, broadcastOp.getContext())
: mlir::getAffineDimExpr(size, broadcastOp.getContext()));
: mlir::getAffineDimExpr(dim, broadcastOp.getContext());
dimExprs.push_back(affineExpr);
}
}

View File

@ -18,6 +18,10 @@ package_group(
includes = [
"//tensorflow/compiler/tf2xla:internal",
],
packages = [
# To pass open source testing in the pip Kokoros.
"//bazel_pip/tensorflow/compiler/tests/...",
],
)
package_group(
@ -25,6 +29,10 @@ package_group(
includes = [
"//tensorflow/compiler/tf2xla:friends",
],
packages = [
# To pass open source testing in the pip Kokoros.
"//bazel_pip/tensorflow/compiler/tests/...",
],
)
generate_backend_suites()
@ -66,6 +74,9 @@ py_test(
size = "small",
srcs = ["xla_test_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
],
@ -76,6 +87,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["adadelta_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -90,6 +104,9 @@ tf_xla_py_test(
size = "small",
srcs = ["adagrad_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -105,6 +122,9 @@ tf_xla_py_test(
size = "small",
srcs = ["adagrad_da_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -119,6 +139,9 @@ tf_xla_py_test(
size = "small",
srcs = ["adam_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -136,6 +159,9 @@ tf_xla_py_test(
# TensorList ops are not implemented in the on-demand compilation model yet.
disabled_backends = ["cpu_ondemand"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -151,6 +177,9 @@ tf_xla_py_test(
size = "small",
srcs = ["argminmax_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -168,6 +197,7 @@ tf_xla_py_test(
shard_count = 5,
tags = [
"no_oss", # TODO(b/148108508): Re-enable this test in OSS.
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly", # Times out frequently in fastbuild mode.
],
deps = [
@ -194,6 +224,7 @@ tf_xla_py_test(
python_version = "PY3",
shard_count = 2,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly", # Times out frequently in fastbuild mode.
],
deps = [
@ -212,6 +243,9 @@ tf_xla_py_test(
size = "small",
srcs = ["bucketize_op_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -226,7 +260,10 @@ tf_xla_py_test(
size = "small",
srcs = ["categorical_op_test.py"],
python_version = "PY3",
tags = ["optonly"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:framework",
@ -242,6 +279,7 @@ tf_xla_py_test(
srcs = ["cholesky_op_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"no_rocm",
"optonly",
],
@ -261,6 +299,9 @@ tf_xla_py_test(
size = "small",
srcs = ["cond_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
@ -278,7 +319,10 @@ tf_xla_py_test(
size = "medium",
srcs = ["self_adjoint_eig_op_test.py"],
python_version = "PY3",
tags = ["optonly"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -291,18 +335,6 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "searchsorted_op_test",
size = "small",
timeout = "moderate",
srcs = ["searchsorted_op_test.py"],
python_version = "PY3",
deps = [
":xla_test",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "svd_op_test",
size = "medium",
@ -314,6 +346,7 @@ tf_xla_py_test(
],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"no_rocm",
"optonly",
],
@ -336,6 +369,7 @@ tf_xla_py_test(
srcs = ["matrix_inverse_op_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"noasan",
"nomsan",
"notsan",
@ -356,6 +390,9 @@ tf_xla_py_test(
timeout = "moderate",
srcs = ["matrix_solve_op_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:linalg_ops",
@ -371,7 +408,10 @@ tf_xla_py_test(
timeout = "moderate",
srcs = ["matrix_triangular_solve_op_test.py"],
python_version = "PY3",
tags = ["optonly"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -387,6 +427,9 @@ tf_xla_py_test(
size = "small",
srcs = ["clustering_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -403,6 +446,7 @@ tf_xla_py_test(
python_version = "PY3",
tags = [
"many_xla_args",
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"no_rocm",
],
deps = [
@ -423,6 +467,9 @@ tf_xla_py_test(
srcs = ["conv2d_test.py"],
python_version = "PY3",
shard_count = 10,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":test_utils",
":xla_test",
@ -442,6 +489,9 @@ tf_xla_py_test(
srcs = ["conv3d_test.py"],
python_version = "PY3",
shard_count = 5,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -460,6 +510,7 @@ tf_xla_py_test(
python_version = "PY3",
shard_count = 5,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"no_rocm",
"noasan",
"nomsan",
@ -482,6 +533,9 @@ tf_xla_py_test(
size = "small",
srcs = ["dynamic_slice_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
"//tensorflow/compiler/tests:xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
@ -499,6 +553,9 @@ tf_xla_py_test(
"gpu",
],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -513,6 +570,9 @@ tf_xla_py_test(
size = "small",
srcs = ["reshape_op_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
"//tensorflow/compiler/tests:xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
@ -527,6 +587,9 @@ tf_xla_py_test(
size = "small",
srcs = ["dynamic_stitch_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -541,6 +604,9 @@ tf_xla_py_test(
size = "small",
srcs = ["extract_image_patches_op_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -556,6 +622,7 @@ tf_xla_py_test(
python_version = "PY3",
tags = [
"multi_and_single_gpu",
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
@ -574,6 +641,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["fifo_queue_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -591,6 +661,7 @@ tf_xla_py_test(
python_version = "PY3",
shard_count = 6,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"no_rocm",
"optonly",
],
@ -609,6 +680,9 @@ tf_xla_py_test(
size = "small",
srcs = ["slice_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -623,6 +697,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["ftrl_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -638,6 +715,9 @@ tf_xla_py_test(
size = "small",
srcs = ["function_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -653,6 +733,7 @@ tf_xla_py_test(
python_version = "PY3",
shard_count = 10,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly", # Times out frequently in fastbuild mode.
],
deps = [
@ -669,6 +750,9 @@ tf_xla_py_test(
size = "small",
srcs = ["listdiff_op_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -685,6 +769,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["lrn_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -700,6 +787,9 @@ tf_xla_py_test(
size = "small",
srcs = ["manip_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -715,7 +805,10 @@ tf_xla_py_test(
timeout = "long",
srcs = ["matrix_band_part_test.py"],
python_version = "PY3",
tags = ["optonly"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -731,6 +824,9 @@ tf_xla_py_test(
timeout = "long",
srcs = ["matrix_diag_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -744,6 +840,9 @@ tf_xla_py_test(
size = "small",
srcs = ["momentum_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -759,6 +858,9 @@ tf_xla_py_test(
size = "small",
srcs = ["nary_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -773,6 +875,9 @@ tf_xla_py_test(
size = "small",
srcs = ["nullary_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:control_flow_ops",
@ -787,6 +892,9 @@ tf_xla_py_test(
srcs = ["pooling_ops_test.py"],
python_version = "PY3",
shard_count = 10,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -803,6 +911,9 @@ tf_xla_py_test(
srcs = ["pooling_ops_3d_test.py"],
python_version = "PY3",
shard_count = 10,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -818,6 +929,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["proximal_adagrad_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -832,6 +946,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["proximal_gradient_descent_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -852,7 +969,10 @@ tf_xla_py_test(
],
python_version = "PY3",
shard_count = 5,
tags = ["optonly"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -871,6 +991,7 @@ tf_xla_py_test(
python_version = "PY3",
shard_count = 5,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"no_rocm",
"optonly",
],
@ -892,6 +1013,7 @@ tf_xla_py_test(
python_version = "PY3",
shard_count = 10,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"notap", # TODO(b/141057424): flaky on TPU
],
deps = [
@ -911,6 +1033,9 @@ tf_xla_py_test(
srcs = ["reduce_ops_test.py"],
python_version = "PY3",
shard_count = 5,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -927,6 +1052,9 @@ tf_xla_py_test(
size = "small",
srcs = ["reduce_window_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
@ -943,6 +1071,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["reverse_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -955,7 +1086,10 @@ tf_xla_py_test(
size = "medium",
srcs = ["reverse_sequence_op_test.py"],
python_version = "PY3",
tags = ["optonly"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -969,6 +1103,9 @@ tf_xla_py_test(
size = "small",
srcs = ["rmsprop_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -984,7 +1121,10 @@ tf_xla_py_test(
size = "small",
srcs = ["scan_ops_test.py"],
python_version = "PY3",
tags = ["optonly"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -999,6 +1139,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["segment_reduction_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -1015,6 +1158,9 @@ tf_xla_py_test(
srcs = ["spacetobatch_op_test.py"],
python_version = "PY3",
shard_count = 3,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -1029,6 +1175,9 @@ tf_xla_py_test(
size = "small",
srcs = ["sparse_to_dense_op_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -1043,7 +1192,10 @@ tf_xla_py_test(
size = "small",
srcs = ["stack_ops_test.py"],
python_version = "PY3",
tags = ["config-cuda-only"],
tags = [
"config-cuda-only",
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
use_xla_device = False,
deps = [
":xla_test",
@ -1060,7 +1212,10 @@ tf_xla_py_test(
srcs = ["stateful_random_ops_test.py"],
python_version = "PY3",
shard_count = 10,
tags = ["optonly"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:framework",
@ -1076,7 +1231,10 @@ tf_xla_py_test(
size = "medium",
srcs = ["stateless_random_ops_test.py"],
python_version = "PY3",
tags = ["optonly"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:framework",
@ -1096,6 +1254,7 @@ tf_xla_py_test(
python_version = "PY3",
tags = [
"config-cuda-only",
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"v1only",
],
use_xla_device = False,
@ -1121,6 +1280,9 @@ tf_xla_py_test(
# TensorList ops are not implemented in the on-demand compilation model yet.
disabled_backends = ["cpu_ondemand"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -1136,6 +1298,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["ternary_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -1152,6 +1317,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["unary_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -1168,6 +1336,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["fused_batchnorm_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":test_utils",
":xla_test",
@ -1188,7 +1359,10 @@ tf_xla_py_test(
size = "small",
srcs = ["variable_ops_test.py"],
python_version = "PY3",
tags = ["optonly"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -1207,6 +1381,9 @@ tf_xla_py_test(
size = "small",
srcs = ["while_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
@ -1222,7 +1399,10 @@ tf_xla_py_test(
size = "medium",
srcs = ["gather_test.py"],
python_version = "PY3",
tags = ["optonly"],
tags = [
"no_pip",
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -1237,6 +1417,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["gather_nd_op_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -1250,7 +1433,10 @@ tf_xla_py_test(
size = "medium",
srcs = ["scatter_nd_op_test.py"],
python_version = "PY3",
tags = ["optonly"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -1266,7 +1452,10 @@ tf_xla_py_test(
python_version = "PY3",
shard_count = 1,
# Times out in fastbuild mode.
tags = ["optonly"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
"//tensorflow/compiler/tests:xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
@ -1280,6 +1469,9 @@ tf_xla_py_test(
size = "small",
srcs = ["data_format_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
"//tensorflow/compiler/tests:xla_test",
"//tensorflow/python:array_ops",
@ -1294,7 +1486,10 @@ tf_xla_py_test(
size = "small",
srcs = ["xla_device_test.py"],
python_version = "PY3",
tags = ["optonly"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -1307,6 +1502,9 @@ cuda_py_test(
name = "xla_device_gpu_test",
size = "small",
srcs = ["xla_device_gpu_test.py"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
xla_enable_strict_auto_jit = False,
deps = [
"//tensorflow/python:array_ops",
@ -1323,7 +1521,10 @@ cuda_py_test(
size = "medium",
srcs = ["jit_test.py"],
shard_count = 5,
tags = ["no_rocm"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"no_rocm",
],
xla_enable_strict_auto_jit = False,
deps = [
":test_utils",
@ -1344,7 +1545,10 @@ cuda_py_test(
name = "dense_layer_test",
size = "medium",
srcs = ["dense_layer_test.py"],
tags = ["no_rocm"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"no_rocm",
],
xla_enable_strict_auto_jit = False,
deps = [
":test_utils",
@ -1385,6 +1589,7 @@ tf_cuda_cc_test(
size = "large",
# This test is randomized, so only run it if explicitly requested.
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"manual",
"notap",
] + tf_cuda_tests_tags(),
@ -1394,7 +1599,9 @@ tf_cuda_cc_test(
tf_cuda_cc_test(
name = "unary_ops_composition_test",
srcs = ["unary_ops_composition_test.cc"],
tags = tf_cuda_tests_tags(),
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
] + tf_cuda_tests_tags(),
deps = [
"//tensorflow/cc:cc_ops",
"//tensorflow/compiler/jit",
@ -1430,7 +1637,10 @@ py_library(
cuda_py_test(
name = "lstm_test",
srcs = ["lstm_test.py"],
tags = ["no_rocm"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"no_rocm",
],
xla_enable_strict_auto_jit = False,
deps = [
":lstm",
@ -1474,6 +1684,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["fake_quant_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:framework",
@ -1486,6 +1699,9 @@ tf_xla_py_test(
size = "small",
srcs = ["placeholder_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -1499,6 +1715,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["quantized_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
@ -1516,6 +1735,9 @@ tf_xla_py_test(
size = "medium",
srcs = ["xla_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
@ -1535,6 +1757,7 @@ tf_xla_py_test(
shard_count = 5,
tags = [
"no_oss", # TODO(b/148108508): Re-enable this test in OSS.
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"no_rocm",
],
deps = [
@ -1560,6 +1783,7 @@ tf_xla_py_test(
],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
@ -1576,7 +1800,10 @@ tf_xla_py_test(
size = "medium",
srcs = ["special_math_test.py"],
shard_count = 5,
tags = ["optonly"],
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"optonly",
],
deps = [
":xla_test",
"//tensorflow/python:extra_py_tests_deps",

View File

@ -1,75 +0,0 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for XLA implementation of tf.searchsorted."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class SearchSorteddOpTest(xla_test.XLATestCase):
def test1D(self):
# Test against NumPy implementation (which is 1D only).
np.random.seed(1)
for side in ['left', 'right']:
for dtype in [np.float32, np.int32]:
values = np.random.uniform(
low=-1000, high=1000, size=(10,)).astype(dtype)
unsorted = np.random.uniform(
low=-1000, high=1000, size=(20,)).astype(dtype)
sorted_sequence = np.sort(unsorted)
np_ans = np.searchsorted(sorted_sequence, values, side=side)
with self.session() as session:
with self.test_scope():
tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side)
tf_out = session.run(tf_ans)
self.assertAllEqual(np_ans, tf_out)
def _test2DExample(self, dtype, side, sorted_sequence, values, correct_ans):
with self.session() as session:
with self.test_scope():
tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side)
tf_out = session.run(tf_ans)
self.assertAllEqual(correct_ans, tf_out)
def testLowerBound2DExample(self):
# 2D TensorFlow documentation example.
for dtype in self.float_types | self.int_types:
sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype)
values = np.array([[2, 4, 9], [0, 2, 6]], dtype)
correct_ans = np.array([[1, 2, 2], [0, 1, 5]], dtype)
self._test2DExample(dtype, 'left', sorted_sequence, values, correct_ans)
def testUpperBound2DExample(self):
# 2D TensorFlow documentation example.
for dtype in self.float_types | self.int_types:
sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype)
values = np.array([[2, 4, 9], [0, 2, 6]], dtype)
correct_ans = np.array([[1, 2, 4], [0, 2, 5]], dtype)
self._test2DExample(dtype, 'right', sorted_sequence, values, correct_ans)
if __name__ == '__main__':
test.main()

View File

@ -29,6 +29,10 @@ import scipy.special as sps
import six
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@ -39,6 +43,13 @@ flags.DEFINE_bool('vary_seed', False,
NUM_SAMPLES = int(1e3)
# This is df/da / df/dx, where f = igamma.
def implicit_reparameterization_grad(a, x):
log_prob = math_ops.xlogy(a - 1., x) - math_ops.lgamma(a) - x
prob = math_ops.exp(log_prob)
return -gen_math_ops.igamma_grad_a(a, x) / prob
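For reference, the identity that implicit_reparameterization_grad relies on (a standard implicit-reparameterization derivation, spelled out here for the reader; it is not part of the change itself): with P(a, x) = igamma(a, x), the CDF of a Gamma(a, 1) sample, implicitly differentiating P(a, X(a)) = u for fixed u gives
\[
\frac{dX}{da} \;=\; -\,\frac{\partial P/\partial a}{\partial P/\partial x},
\qquad
\frac{\partial P}{\partial x} \;=\; \frac{x^{a-1} e^{-x}}{\Gamma(a)} .
\]
Taking the log of the density term yields (a - 1) log x - lgamma(a) - x, which is exactly log_prob above, while gen_math_ops.igamma_grad_a supplies the partial derivative with respect to a, giving the returned ratio.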
class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
def setUp(self):
@ -48,9 +59,15 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
answer = int(entropy.encode('hex'), 16)
else:
answer = int.from_bytes(entropy, 'big')
np.random.seed(answer)
np.random.seed(answer % (2**32 - 1))
super(IgammaTest, self).setUp()
# Skip Float64 test on TPU due to missing ops.
def maybe_skip_test(self, dtype):
if self.device not in ['XLA_GPU', 'XLA_CPU', 'CPU'] and dtype == np.float64:
self.skipTest(
'Skipping test because some F64 operations not supported on TPU.')
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
def testIgammaSmallValues(self, dtype, rtol, atol):
@ -93,6 +110,97 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
actual = sess.run(math_ops.igamma(a, x))
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
# We don't check small values because the numerical gradients become quite
# large.
@parameterized.parameters((np.float32, 0.09), (np.float64, 1e-7))
def testIgammaGradMediumValues(self, dtype, tolerance):
self.maybe_skip_test(dtype)
with self.session():
with self.test_scope():
x = constant_op.constant(
np.random.uniform(low=1., high=100.,
size=[NUM_SAMPLES]).astype(dtype))
a = constant_op.constant(
np.random.uniform(low=1., high=100.,
size=[NUM_SAMPLES]).astype(dtype))
f = lambda b: math_ops.igamma(b, x)
max_error = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(f, x=[a], delta=1e-3))
self.assertLessEqual(max_error, tolerance)
@parameterized.parameters((np.float32, 0.5), (np.float64, 1e-7))
def testIgammaGradLargeValues(self, dtype, tolerance):
self.maybe_skip_test(dtype)
with self.session():
with self.test_scope():
x = constant_op.constant(
np.random.uniform(low=100., high=int(1e4),
size=[NUM_SAMPLES]).astype(dtype))
a = constant_op.constant(
np.random.uniform(low=100., high=int(1e4),
size=[NUM_SAMPLES]).astype(dtype))
f = lambda b: math_ops.igamma(b, x)
max_error = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(f, x=[a], delta=1e-2))
self.assertLessEqual(max_error, tolerance)
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
def testRandomGammaGradSmallValues(self, dtype, rtol, atol):
self.maybe_skip_test(dtype)
# Test values near zero.
with self.session() as sess:
with self.test_scope():
x = constant_op.constant(
np.random.uniform(
low=np.finfo(dtype).tiny, high=1.,
size=[NUM_SAMPLES]).astype(dtype))
a = constant_op.constant(
np.random.uniform(
low=np.finfo(dtype).tiny, high=1.,
size=[NUM_SAMPLES]).astype(dtype))
gamma_sample_grad = gen_random_ops.random_gamma_grad(a, x)
actual_grad = implicit_reparameterization_grad(a, x)
gamma_sample_grad, actual_grad = sess.run(
[gamma_sample_grad, actual_grad])
# We do this because the ratio computed in
# implicit_reparameterization_grad can very easily result in a NaN due
# to the computed numerator and denominator zeroing out.
gamma_sample_grad = gamma_sample_grad[
~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))]
actual_grad = actual_grad[
~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))]
self.assertAllClose(actual_grad, gamma_sample_grad, atol=atol, rtol=rtol)
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
def testRandomGammaGradMediumValues(self, dtype, rtol, atol):
self.maybe_skip_test(dtype)
with self.session() as sess:
with self.test_scope():
x = constant_op.constant(
np.random.uniform(low=1., high=10.,
size=[NUM_SAMPLES]).astype(dtype))
a = constant_op.constant(
np.random.uniform(low=1., high=10.,
size=[NUM_SAMPLES]).astype(dtype))
gamma_sample_grad = gen_random_ops.random_gamma_grad(a, x)
actual_grad = implicit_reparameterization_grad(a, x)
gamma_sample_grad, actual_grad = sess.run(
[gamma_sample_grad, actual_grad])
# We do this because the ratio computed in
# implicit_reparameterization_grad can very easily result in a NaN due
# to the computed numerator and denominator zeroing out.
gamma_sample_grad = gamma_sample_grad[
~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))]
actual_grad = actual_grad[
~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))]
self.assertAllClose(actual_grad, gamma_sample_grad, atol=atol, rtol=rtol)
if __name__ == '__main__':
os.environ['XLA_FLAGS'] = '--xla_cpu_enable_fast_math=false'

View File

@ -249,10 +249,12 @@ tf_cuda_library(
srcs = [
"utils/trt_int8_calibrator.cc",
"utils/trt_lru_cache.cc",
"utils/trt_shape_optimization_profiles.cc",
],
hdrs = [
"utils/trt_int8_calibrator.h",
"utils/trt_lru_cache.h",
"utils/trt_shape_optimization_profiles.h",
],
deps = [
":trt_allocator",
@ -308,6 +310,22 @@ tf_cc_test(
],
)
tf_cuda_cc_test(
name = "trt_shape_optimization_profiles_test",
size = "small",
srcs = ["utils/trt_shape_optimization_profiles_test.cc"],
tags = [
"no_cuda_on_cpu_tap",
"no_windows",
"nomac",
],
deps = [
":trt_resources",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cuda_library(
name = "logger_registry",
srcs = ["convert/logger_registry.cc"],

View File

@ -431,7 +431,8 @@ Status CreateTRTNode(const ConversionParams& params,
calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
max_batch_size, info.max_workspace_size_bytes, input_shapes, trt_logger,
alloc, /*calibrator=*/nullptr, &engine, info.use_calibration,
params.use_implicit_batch, /*convert_successfully=*/nullptr));
params.use_implicit_batch, /*convert_successfully=*/nullptr,
/*profile=*/nullptr));
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
segment_string = string(static_cast<const char*>(engine_data->data()),
engine_data->size());

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
@ -249,6 +250,16 @@ void GetInputProperties(const grappler::GraphProperties& graph_properties,
}
}
// This function checks if a tensor is compatible with TRT.
//
// We check that the shape and datatype are compatible with TensorRT. We also
// return the corresponding trt_dtype, the trt_dims and the batch_size (the
// latter is only needed in implicit batch mode).
//
// The return status indicates whether the tensor is compatible.
//
// For implicit batch mode, when validation_only == false, we also check that
// all input dimensions (besides the batch dimension) are known dimensions.
Status ValidateTensorProperties(const string& producer_node_type,
const DataType dtype,
const PartialTensorShape& shape,
@ -293,11 +304,7 @@ Status ValidateTensorProperties(const string& producer_node_type,
if (validation_only) return Status::OK();
// Following checks are only used during TRT engine creation time. In implicit
// batch mode we check that all inputs for the network has static shape (as
// required by the TensorRT). The only exception is the batch size, which
// could be unknown. In contrast, using explicit batch mode this test is not
// necessary, since any dimension could be unknown in explicit batch mode.
// Following checks are only used during TRT engine creation time.
if (use_implicit_batch) {
for (int d = first_trt_dim; d < shape.dims(); ++d) {
if (shape.dim_size(d) < 0) {
@ -1336,7 +1343,7 @@ Status Converter::RenameAndMarkOutputTensors(
Status Converter::BuildCudaEngine(
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, int max_batch_size,
size_t max_workspace_size_bytes, nvinfer1::IGpuAllocator* allocator,
TRTInt8Calibrator* calibrator) {
TRTInt8Calibrator* calibrator, TrtShapeOptimizationProfile* profiles) {
VLOG(1) << "Configuring TensorRT builder";
trt_builder_->setMaxBatchSize(max_batch_size);
trt_builder_->setGpuAllocator(allocator);
@ -1356,7 +1363,10 @@ Status Converter::BuildCudaEngine(
builder_config->setInt8Calibrator(nullptr);
}
}
if (!use_implicit_batch_ && profiles) {
TF_RETURN_IF_ERROR(profiles->ConfigureBuilder(
trt_builder_.get(), builder_config.get(), network()));
}
VLOG(1) << "Building TensorRT engine";
engine->reset(
trt_builder_->buildEngineWithConfig(*network(), *builder_config));
@ -5743,7 +5753,8 @@ Status ConvertGraphDefToEngine(
nvinfer1::ILogger* trt_logger, nvinfer1::IGpuAllocator* allocator,
TRTInt8Calibrator* calibrator,
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
const bool use_implicit_batch, bool* convert_successfully) {
const bool use_implicit_batch, bool* convert_successfully,
TrtShapeOptimizationProfile* profiles) {
engine->reset();
if (convert_successfully) *convert_successfully = false;
@ -5842,7 +5853,8 @@ Status ConvertGraphDefToEngine(
// Build the engine.
TF_RETURN_IF_ERROR(converter->BuildCudaEngine(
engine, max_batch_size, max_workspace_size_bytes, allocator, calibrator));
engine, max_batch_size, max_workspace_size_bytes, allocator, calibrator,
profiles));
VLOG(1) << "Finished conversion";
return Status::OK();

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
@ -147,7 +148,8 @@ Status ConvertGraphDefToEngine(
nvinfer1::ILogger* logger, nvinfer1::IGpuAllocator* allocator,
TRTInt8Calibrator* calibrator,
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
const bool use_implicit_batch, bool* convert_successfully);
const bool use_implicit_batch, bool* convert_successfully,
TrtShapeOptimizationProfile* profiles);
// Helper class for the segmenter to determine whether an output edge from the
// TRT segment is valid.
@ -467,7 +469,8 @@ class Converter {
Status BuildCudaEngine(TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
int max_batch_size, size_t max_workspace_size_bytes,
nvinfer1::IGpuAllocator* allocator,
TRTInt8Calibrator* calibrator);
TRTInt8Calibrator* calibrator,
TrtShapeOptimizationProfile* profiles);
//////////////////////////////////////////////////////////////////////////////
// Methods used by op converters to convert individual TF node and add layers

View File

@ -1187,7 +1187,7 @@ class ConvertGraphDefToEngineTest : public ::testing::Test {
/*max_workspace_size_bytes=*/64 << 20, input_shapes, &logger_,
/*allocator=*/nullptr, /*calibrator=*/nullptr, &engine_,
/*use_calibration=*/false, /*use_implicit_batch=*/true,
/*convert_successfully=*/nullptr);
/*convert_successfully=*/nullptr, /*profiles=*/nullptr);
}
protected:
@ -1302,7 +1302,8 @@ class OpConverterTest : public ::testing::Test {
/*max_batch_size=*/batch_size,
/*max_workspace_size_bytes=*/1 << 26,
/*allocator=*/nullptr,
/*calibrator=*/nullptr));
/*calibrator=*/nullptr,
/*profiles=*/nullptr));
CHECK_NOTNULL(engine_.get());
CheckDataTypeMatches(input_data);
CheckDataTypeMatches(*output_data);

View File

@ -133,6 +133,25 @@ string DebugString(const std::vector<TensorShape>& shapes) {
string DebugString(const std::vector<PartialTensorShape>& shapes) {
return PartialTensorShapeUtils::PartialShapeListString(shapes);
}
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) {
int n_bindings = engine->getNbBindings();
int n_input = 0;
for (int i = 0; i < n_bindings; i++) {
if (engine->bindingIsInput(i)) n_input++;
}
// According to TensorRT 7 doc: "If the engine has been built for K profiles,
// the first getNbBindings() / K bindings are used by profile number 0, the
// following getNbBindings() / K bindings are used by profile number 1 etc."
// Therefore, to get the number of input tensors, we need to divide by the
// number of profiles.
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
int n_profiles = engine->getNbOptimizationProfiles();
#else
int n_profiles = 1;
#endif
return n_input / n_profiles;
}
#endif
string GetLinkedTensorRTVersion() {

View File

@ -106,6 +106,11 @@ string GetLinkedTensorRTVersion();
// TensorRT library version information {Maj, Min, Patch}.
string GetLoadedTensorRTVersion();
// Returns the number of inputs for the engine, which also corresponds to the
// number of input tensors for the network. This can differ from the number of
// input bindings, because the number of total input bindings equals the number
// of profiles times the number of engine inputs.
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine);
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
} // namespace tensorrt

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/function.h"
@ -92,7 +93,7 @@ class TRTEngineOp : public AsyncOpKernel {
LRUCache<std::vector<TensorShape>, std::unique_ptr<EngineContext>,
VectorTensorShapeHasher>;
// Execute calibration
// Executes calibration.
void ExecuteCalibration(OpKernelContext* ctx,
TRTEngineCacheResource* cache_res,
AsyncHelper* helper);
@ -103,14 +104,15 @@ class TRTEngineOp : public AsyncOpKernel {
Status ConstructFunctionHandle(FunctionLibraryRuntime* lib,
const string& device_name);
// Execute replaced native segment as function Op.
// Executes replaced native segment as function Op.
void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
// Execute the tensorrt engine. Returns whether we need to retry by running
// Executes the tensorrt engine. Returns whether we need to retry by running
// the native segment.
bool ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context);
bool ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context,
int trt_context_idx);
// Allocate necessary resources for calibration
// Allocates necessary resources for calibration.
Status AllocateCalibrationResources(OpKernelContext* ctx,
TRTEngineCacheResource* cache_res);
@ -605,11 +607,24 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
OP_REQUIRES_OK_ASYNC(ctx, VerifyInputShapes(input_concrete_shapes), *helper);
if (!use_implicit_batch_) {
if (cache_res->profiles_.GetNumProfiles() == 0) {
// Create a single profile from the current input shape. In the future we
// will collect a set of input shapes during build mode and create
// profiles for each of them.
cache_res->profiles_.AddShape(input_concrete_shapes);
cache_res->profiles_.InitProfiles();
}
}
StatusOr<EngineContext*> status =
GetEngine(input_concrete_shapes, ctx, cache_res);
OP_REQUIRES_OK_ASYNC(ctx, status.status(), *helper);
EngineContext* engine_context = status.ValueOrDie();
// The context index equals the profile index because we create one context
// for each profile. Currently we do not have a profile_generation mode,
// therefore we have just a single profile.
int trt_context_idx = 0;
if (!engine_context->cuda_engine) {
VLOG(1) << "Engine retrieval for input shapes: "
<< TensorShapeUtils::ShapeListString(input_concrete_shapes)
@ -617,7 +632,8 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
ExecuteNativeSegment(ctx, helper);
return;
}
const bool retry = ExecuteTrtEngine(ctx, engine_context);
const bool retry = ExecuteTrtEngine(ctx, engine_context, trt_context_idx);
if (retry) {
LOG(WARNING) << "Failed to execute engine, "
<< "retrying with native segment for " << name();
@ -665,7 +681,8 @@ Status GetTrtBindingIndex(const char* tensor_name, int profile_index,
}
bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
EngineContext* engine_context) {
EngineContext* engine_context,
int trt_context_idx) {
VLOG(1) << "Executing TRT engine: " << name();
auto& cuda_engine = engine_context->cuda_engine;
@ -688,6 +705,11 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
}
const bool kRetry = true;
if (trt_context_idx >= 1) {
LOG(ERROR) << "Requested engine context with index " << trt_context_idx
<< ", but only 1 context is present.";
return kRetry;
}
const int num_binding = cuda_engine->getNbBindings();
std::vector<void*> buffers(num_binding);
@ -698,8 +720,8 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
for (int i = 0; i < ctx->num_inputs(); i++) {
const string input_name = StrCat(IONamePrefixes::kInputPHName, i);
int binding_index;
auto status = GetTrtBindingIndex(input_name.c_str(), 0, cuda_engine.get(),
&binding_index);
auto status = GetTrtBindingIndex(input_name.c_str(), trt_context_idx,
cuda_engine.get(), &binding_index);
if (!status.ok()) {
ctx->SetStatus(status);
return !kRetry;
@ -770,8 +792,8 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
for (int i = 0; i < ctx->num_outputs(); i++) {
const string output_name = StrCat(IONamePrefixes::kOutputPHName, i);
int binding_index;
auto status = GetTrtBindingIndex(output_name.c_str(), 0, cuda_engine.get(),
&binding_index);
auto status = GetTrtBindingIndex(output_name.c_str(), trt_context_idx,
cuda_engine.get(), &binding_index);
if (!status.ok()) {
ctx->SetStatus(status);
return !kRetry;
@ -801,7 +823,7 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
trt_shape.push_back(dims.d[j]);
}
}
// Allocate output tensor of TRTEngineOp
// Allocate output tensor of TRTEngineOp.
Tensor* output_tensor = nullptr;
TensorShape output_shape;
status = TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(),
@ -997,7 +1019,8 @@ StatusOr<EngineContext*> TRTEngineOp::GetEngine(
auto status = convert::ConvertGraphDefToEngine(
segment_graph_def_, precision_mode_, batch_size, workspace_size_,
conversion_input_shapes, &logger, allocator, calibrator_.get(), &engine,
use_calibration_, use_implicit_batch_, &convert_successfully);
use_calibration_, use_implicit_batch_, &convert_successfully,
&cache_res->profiles_);
if (!status.ok()) {
LOG(WARNING) << "Engine creation for " << name() << " failed. "
<< "The native segment will be used instead. "
@ -1007,11 +1030,12 @@ StatusOr<EngineContext*> TRTEngineOp::GetEngine(
cache.emplace(input_concrete_shapes, absl::make_unique<EngineContext>());
return &empty_context;
}
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
engine->createExecutionContext());
std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>> exec_context;
TF_RETURN_IF_ERROR(cache_res->profiles_.CreateExecutionContexts(
engine.get(), exec_context));
cache.emplace(input_concrete_shapes,
absl::make_unique<EngineContext>(std::move(engine),
std::move(exec_context)));
std::move(exec_context[0])));
VLOG(1) << "Added new engine to cache of " << name()
<< ". Cache size: " << cache.size();
}
@ -1085,9 +1109,9 @@ Status TRTEngineOp::AllocateCalibrationResources(
this->segment_graph_def_, TrtPrecisionMode::INT8,
cres->calibrator_->getBatchSize(), this->workspace_size_,
partial_shapes, &cache_res->GetLogger(), cache_res->allocator_.get(),
cres->calibrator_.get(), &cres->engine_,
/*use_calibration=*/true, this->use_implicit_batch_,
/*convert_successfully=*/nullptr);
cres->calibrator_.get(), &cres->engine_, /*use_calibration=*/true,
this->use_implicit_batch_, /*convert_successfully=*/nullptr,
/*profiles=*/nullptr);
if (!s.ok()) {
LOG(ERROR) << "Calibration failed: " << s;
cres->calibrator_->setDone(); // Ignore further pushes

View File

@ -129,9 +129,14 @@ class TRTEngineOpTestBase : public OpsTestBase {
private:
Status InitOpWithFunctionLibrary() {
OpKernel* kernel = nullptr;
Status status = CreateOpKernel(device_type_, device_, allocator(),
pflr_->GetFLR(device_->name()), node_def_,
TF_GRAPH_DEF_VERSION, &kernel);
auto flr = pflr_->GetFLR(device_->name());
std::shared_ptr<const NodeProperties> props;
Status status = NodeProperties::CreateFromNodeDef(
node_def_, flr->GetFunctionLibraryDefinition(), &props);
if (status.ok()) {
status.Update(CreateOpKernel(device_type_, device_, allocator(), flr,
props, TF_GRAPH_DEF_VERSION, &kernel));
}
kernel_ = std::unique_ptr<OpKernel>(kernel);
if (kernel_ != nullptr) input_types_ = kernel_->input_types();
return status;
@ -214,6 +219,7 @@ TEST_F(TRTEngineOpTestBase, AllowBuildAtRuntime) {
EXPECT_EQ(ectx->cuda_engine, nullptr);
}
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
TEST_F(TRTEngineOpTestBase, ExplicitBatch) {
// Test inference in explicit batch mode with static input shapes. Static
// shapes in this context mean that TensorRT knows all the input shapes
@ -253,15 +259,6 @@ TEST_F(TRTEngineOpTestBase, DynamicShapes) {
TensorShape input_shape({1, 2});
TRTEngineOpTestBase::AddSimpleInput<float>(input_shape);
// We expect that TensorRT engine creation fails: we would need to configure
// the engine with optimization profiles to use dynamic input shapes, but that
// feature is not yet implemented.
//
// Since TRT engine creation has failed, we fall back to native segment.
// Calling the native segment fails for the same reason that is investigated
// in https://github.com/tensorflow/tensorflow/pull/34919. This is irrelevant
// for the current test, here we want to just check wether TRT engine creation
// has failed.
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
// Get the engine cache.
@ -274,11 +271,8 @@ TEST_F(TRTEngineOpTestBase, DynamicShapes) {
auto cache = &cache_resource->cache_;
EXPECT_EQ(1, cache->size());
ASSERT_EQ(1, cache->count({input_shape}));
// TODO(bixia): re-enable the check below when the problem is fixed.
// EngineContext* ectx = cache->at({input_shape}).get();
// Since engine creation failed, we expect to find nullptr. Finding a nullptr
// indicates that unknown shapes were used to define the TensorRT network.
// EXPECT_EQ(ectx->cuda_engine, nullptr);
EngineContext* ectx = cache->at({input_shape}).get();
EXPECT_NE(ectx->cuda_engine, nullptr);
}
template <typename T>
@ -302,6 +296,7 @@ TYPED_TEST(TRTEngineOpTest, Basic) {
output->NumElements()),
ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
}
#endif
} // namespace tensorrt
} // namespace tensorflow

View File

@ -140,11 +140,24 @@ class InitializeTRTResource : public OpKernel {
engine_instance.serialized_engine().c_str(),
engine_instance.serialized_engine().size(), nullptr));
auto raw_engine = engine.get();
resource->cache_.emplace(
engine_input_shapes,
absl::make_unique<EngineContext>(
std::move(engine), TrtUniquePtrType<nvinfer1::IExecutionContext>(
raw_engine->createExecutionContext())));
std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>> ctx_vec;
if (num_loaded_engine == 0) {
// Restore profiles if there are any. Currently only 1 engine is allowed
// in dynamic mode, therefore we call this only for the 0th engine.
// It is a no-op in implicit batch mode.
OP_REQUIRES_OK(ctx, resource->profiles_.RestoreProfiles(raw_engine));
OP_REQUIRES_OK(ctx, resource->profiles_.CreateExecutionContexts(
raw_engine, ctx_vec));
} else {
// Multiple engines are only available in static mode. For each engine
// we have only a single execution context.
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_ctx(
raw_engine->createExecutionContext());
ctx_vec.push_back(std::move(exec_ctx));
}
resource->cache_.emplace(engine_input_shapes,
absl::make_unique<EngineContext>(
std::move(engine), std::move(ctx_vec[0])));
++num_loaded_engine;
} while (1);
VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines for op "

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
@ -182,6 +183,11 @@ class TRTEngineCacheResource : public ResourceBase {
// TODO(hinsu): Use different calibration context for the available shapes and
// attach it to each item of the cache.
std::unique_ptr<CalibrationContext> calib_ctx_;
// This object maintains all the optimization profiles during profile
// generation and engine build. During runtime the list of profiles is used to
// look up a matching profile for the input data.
TrtShapeOptimizationProfile profiles_;
};
#endif // GOOGLE_TENSORRT

View File

@ -0,0 +1,185 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include <algorithm>
#include <functional>
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#if GOOGLE_CUDA && GOOGLE_TENSORRT
namespace tensorflow {
namespace tensorrt {
// Creates optimization profiles for a list of input shapes. The list of input
// shapes are stored in shapes_.
void TrtShapeOptimizationProfile::InitProfiles() {
if (input_shapes_.size() == 0) {
VLOG(1) << "Not creating profiles without input_shapes. "
"You have to enable profile generation mode first (build).";
} else {
VLOG(1) << "Creating profiles with strategy of one profile "
<< "for each input (min=opt=max).";
}
for (auto& shape_vec : input_shapes_) {
std::vector<nvinfer1::Dims> dimvec;
for (auto& shape : shape_vec) {
dimvec.push_back(TensorShapeToTrtDims(shape, false));
}
// We set min=opt=max.
OptimizationProfileConfig profConfig{dimvec, dimvec, dimvec};
profiles_.push_back(std::move(profConfig));
VLOG(1) << "Created profile " << profiles_.back().DebugString();
}
}
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
Status TrtShapeOptimizationProfile::AddProfiles(
nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
const nvinfer1::INetworkDefinition* network) {
// Create a vector of optimization profiles
for (int i = 0; i < profiles_.size(); i++) {
auto* optProfile = builder->createOptimizationProfile();
Status status = profiles_[i].SetDimensions(network, optProfile);
if (!status.ok()) {
return status;
}
int idx = -1;
if (optProfile->isValid()) {
idx = config->addOptimizationProfile(optProfile);
}
if (idx >= 0) {
if (i != idx) {
return errors::Internal(
"Profile index of engine config is different from resource profile "
"index: ",
i, " != ", idx);
}
VLOG(1) << "Added optimization profile " << profiles_[i].DebugString()
<< " to builder config.";
} else {
LOG(ERROR) << "Failed to add optimization profile "
<< profiles_[i].DebugString()
<< ". This usually happens when the profile is invalid.";
}
}
if (config->getNbOptimizationProfiles() == 0) {
return errors::Internal("Failure in adding an optimization profile.");
}
// If TRT_VERSION < 6, then we do not need to add optimization profiles.
return Status::OK();
}
#endif
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
Status TrtShapeOptimizationProfile::ConfigureBuilder(
nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
const nvinfer1::INetworkDefinition* network) {
TF_RETURN_IF_ERROR(AddProfiles(builder, config, network));
return Status::OK();
}
#endif
int TrtShapeOptimizationProfile::GetProfileNumber(
std::vector<TensorShape> shapes) {
for (int i = 0; i < profiles_.size(); i++) {
if (profiles_[i].IncludesShapes(shapes)) {
return i;
}
}
VLOG(1) << "Profile not found for input shapes " << DebugString(shapes)
<< ".";
return -1;
}
Status TrtShapeOptimizationProfile::CreateExecutionContexts(
nvinfer1::ICudaEngine* engine,
std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>>& exec_context) {
int i = 0;
// The following loop runs once if we have static shapes, to create a single
// execution context without profiles. In dynamic mode we create one context
// for each profile and set the corresponding optimization profile.
do {
VLOG(1) << "Creating execution context " << i;
nvinfer1::IExecutionContext* ctx = engine->createExecutionContext();
if (ctx == nullptr) {
return errors::Internal("Failed to create execution context");
}
if (i > 0) {
// This condition is needed for two reasons:
// - With static shapes we do not have any profiles, so we cannot call
// setOptimizationProfile.
// - The 0th profile is set implicitly for the first execution context,
// therefore we do not need to set it.
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
bool stat = ctx->setOptimizationProfile(i);
if (!stat) {
ctx->destroy();
return errors::Internal("Could not set TRT optimization profile.");
}
#endif
}
exec_context.push_back(TrtUniquePtrType<nvinfer1::IExecutionContext>(ctx));
i++;
} while (i < profiles_.size());
return Status::OK();
}
Status TrtShapeOptimizationProfile::RestoreProfiles(
const nvinfer1::ICudaEngine* engine) {
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
if (!engine) {
// We do not need to restore profiles for an empty engine
return Status::OK();
}
#if IS_TRT_VERSION_GE(7, 0, 0, 0)
if (engine->hasImplicitBatchDimension()) {
// Nothing to do, we cannot have profiles in implicit batch mode
return Status::OK();
}
#endif
int n_profiles = engine->getNbOptimizationProfiles();
int n_inputs = GetNumberOfEngineInputs(engine);
VLOG(2) << "Attempting to restore " << n_profiles << " profiles, each with "
<< n_inputs << " inputs";
for (int prof_idx = 0; prof_idx < n_profiles; prof_idx++) {
OptimizationProfileConfig cfg;
for (int j = 0; j < n_inputs; j++) {
nvinfer1::Dims min = engine->getProfileDimensions(
j, prof_idx, nvinfer1::OptProfileSelector::kMIN);
nvinfer1::Dims max = engine->getProfileDimensions(
j, prof_idx, nvinfer1::OptProfileSelector::kMAX);
nvinfer1::Dims opt = engine->getProfileDimensions(
j, prof_idx, nvinfer1::OptProfileSelector::kOPT);
cfg.min.push_back(min);
cfg.max.push_back(max);
cfg.opt.push_back(opt);
}
VLOG(2) << "Restored profile " << cfg.DebugString();
profiles_.push_back(std::move(cfg));
}
#endif
return Status::OK();
}
int TrtShapeOptimizationProfile::GetNumProfiles() const {
return profiles_.size();
}
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
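A minimal sketch of the intended call sequence for TrtShapeOptimizationProfile, assuming TRT 6+ and an already-built dynamic-shape network; BuildWithProfiles and its parameters are illustrative names, not code from this change:
// Sketch only: mirrors the order used by TRTEngineOp and the unit test;
// error handling is reduced to TF_RETURN_IF_ERROR.
Status BuildWithProfiles(
    nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
    nvinfer1::INetworkDefinition* network,
    const std::vector<std::vector<TensorShape>>& collected_shapes,
    TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
    std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>>* contexts,
    TrtShapeOptimizationProfile* profiles) {
  // 1. Record the input shapes seen during profile generation.
  for (const auto& shapes : collected_shapes) profiles->AddShape(shapes);
  // 2. Turn each collected shape vector into a min=opt=max profile.
  profiles->InitProfiles();
  // 3. Register the profiles with the builder config, then build the engine.
  TF_RETURN_IF_ERROR(profiles->ConfigureBuilder(builder, config, network));
  engine->reset(builder->buildEngineWithConfig(*network, *config));
  // 4. Create one execution context per profile (or a single one without profiles).
  TF_RETURN_IF_ERROR(profiles->CreateExecutionContexts(engine->get(), *contexts));
  // 5. At run time, GetProfileNumber(input_shapes) selects which context to use.
  return Status::OK();
}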

View File

@ -0,0 +1,178 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_
#include <list>
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
// Stores optimization profile parameters (min/opt/max of each input shape).
//
// A TensorRT optimization profile describes the possible min/max values of
// each dynamic input shape along with an optimum value. These values are used
// by the TensorRT builder to select the best kernel for the optimum value among
// those kernels that are valid for all input tensors in the [min, max] range.
struct OptimizationProfileConfig {
// Length of vector == num_inputs to engine
std::vector<nvinfer1::Dims> min;
std::vector<nvinfer1::Dims> opt;
std::vector<nvinfer1::Dims> max;
string DebugString() const {
using absl::StrCat;
return StrCat("[min: ", tensorflow::tensorrt::DebugString(min),
", opt: ", tensorflow::tensorrt::DebugString(opt),
", max: ", tensorflow::tensorrt::DebugString(max), "]");
}
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
// Sets the stored min/opt/max dimensions for profile.
//
// Parameters:
// network - TensorRT network, used to enumerate all the input tensors
// profile - on exit the profile information will be set for each input tensor
Status SetDimensions(const nvinfer1::INetworkDefinition* network,
nvinfer1::IOptimizationProfile* profile) const {
int n_inputs = network->getNbInputs();
if (min.size() != n_inputs || opt.size() != n_inputs ||
max.size() != n_inputs) {
return errors::Internal("Incorrect number of profile config parameters");
}
for (int i = 0; i < n_inputs; i++) {
const char* name = network->getInput(i)->getName();
profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, min[i]);
profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, opt[i]);
profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, max[i]);
}
return Status::OK();
}
#endif
// Returns true if profile range completely includes the given shapes.
bool IncludesShapes(const std::vector<TensorShape>& shapes) const {
// min, max, and opt must have the same size which is already verified in
// SetDimensions.
if (min.size() != shapes.size()) {
return false;
}
for (int i = 0; i < shapes.size(); i++) {
auto current_shape = shapes[i];
// min, max, and opt must have the same nbDims, which is already verified
// in SetDimensions.
if (min[i].nbDims != current_shape.dims()) {
return false;
}
// Check if range [min, max] includes current_shape.
for (int dim = 0; dim < current_shape.dims(); dim++) {
if ((min[i].d[dim] > current_shape.dim_size(dim)) ||
(max[i].d[dim] < current_shape.dim_size(dim))) {
return false;
}
}
}
return true;
}
};
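A hedged illustration of how a config produced by the min=opt=max strategy answers IncludesShapes (the shapes are made up for this example):
// Sketch only: a profile built from a single collected [2, 2, 10] input, as
// InitProfiles would create it, and two lookups against it.
bool MatchesOnlyCollectedShape() {
  OptimizationProfileConfig cfg;
  nvinfer1::Dims dims = nvinfer1::Dims3(2, 2, 10);
  cfg.min.push_back(dims);
  cfg.opt.push_back(dims);
  cfg.max.push_back(dims);
  bool hit = cfg.IncludesShapes({TensorShape({2, 2, 10})});   // exact match: true
  bool miss = cfg.IncludesShapes({TensorShape({3, 3, 10})});  // outside [min, max]: false
  return hit && !miss;
}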
// Manages Optimization profiles during TRT Engine construction.
//
// An optimization profile describes a range of dimensions for each TRT network
// input, and the optimal dimensions that the auto-tuner should use for
// optimization.
//
// This class stores the list of input shapes that were seen during the
// build/profile_generation_mode phase, and using them it creates a set of
// OptimizationProfileConfigs. These configs will be added to IBuilderConfig
// before the engine is created.
class TrtShapeOptimizationProfile {
public:
TrtShapeOptimizationProfile() {}
// Stores input shape information during profile_generation_mode
void AddShape(std::vector<TensorShape> shapes) {
input_shapes_.insert(shapes);
VLOG(1) << "Collected shape(s) " << DebugString(shapes) << " for profiles.";
}
void clear() { profiles_.clear(); }
// Returns the profile number that should be used to execute the network with
// the given input shapes. Returns -1 if none of cached profiles are
// compatible with the given input shapes.
int GetProfileNumber(std::vector<TensorShape> shapes);
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
// Creates optimization profiles and add them to the builder config.
Status ConfigureBuilder(nvinfer1::IBuilder* builder,
nvinfer1::IBuilderConfig* config,
const nvinfer1::INetworkDefinition* network);
#endif
// Creates execution contexts for each optimization profile.
Status CreateExecutionContexts(
nvinfer1::ICudaEngine* engine,
std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>>& exec_context);
// Maps input vector shapes to TRT Optimization profiles (min, max, opt) i.e.
// maps input_shapes_ to profiles_
void InitProfiles();
// Returns number of created profiles.
int GetNumProfiles() const;
// Restores profiles from the engine (used after deserialization)
Status RestoreProfiles(const nvinfer1::ICudaEngine* engine);
private:
// Set of input shape vectors that we collect during profile_generation_mode.
std::unordered_set<std::vector<TensorShape>, VectorTensorShapeHasher>
input_shapes_;
// The optimization profiles generated from input_shapes_
std::vector<OptimizationProfileConfig> profiles_;
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
// Adds optimization profiles to the builder config.
Status AddProfiles(nvinfer1::IBuilder* builder,
nvinfer1::IBuilderConfig* config,
const nvinfer1::INetworkDefinition* network);
#endif
};
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA
#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_

View File

@ -0,0 +1,218 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include <string.h>
#include <vector>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/test.h"
#include "third_party/tensorrt/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
std::vector<TensorShape> DimVecToShapeVec(std::vector<nvinfer1::Dims3> dimvec) {
std::vector<TensorShape> shapevec(dimvec.size());
for (int i = 0; i < dimvec.size(); i++) {
TensorShape shape;
TF_CHECK_OK(
TensorShapeUtils::MakeShape(dimvec[i].d, dimvec[i].nbDims, &shape));
shapevec[i] = shape;
}
return shapevec;
}
bool DimsContained(const nvinfer1::Dims& dim, const nvinfer1::Dims& min,
const nvinfer1::Dims& max) {
if (dim.nbDims != min.nbDims || dim.nbDims != max.nbDims) {
return false;
}
for (int i = 0; i < dim.nbDims; i++) {
if (dim.d[i] < min.d[i] || dim.d[i] > max.d[i]) {
return false;
}
}
return true;
}
bool DimsEqual(const nvinfer1::Dims& a, const nvinfer1::Dims& b) {
if (a.nbDims != b.nbDims) {
return false;
}
for (int i = 0; i < a.nbDims; i++) {
if (a.d[i] != b.d[i]) {
return false;
}
}
return true;
}
class TrtShapeOptimizationProfileTest : public ::testing::Test {
protected:
void SetUp() override {
builder_ = TrtUniquePtrType<nvinfer1::IBuilder>(
nvinfer1::createInferBuilder(logger_));
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
network_ = TrtUniquePtrType<nvinfer1::INetworkDefinition>(
builder_->createNetworkV2(flags_));
builder_config_ = TrtUniquePtrType<nvinfer1::IBuilderConfig>(
builder_->createBuilderConfig());
builder_config_->setMaxWorkspaceSize(1 << 10);
#else
network_ = TrtUniquePtrType<nvinfer1::INetworkDefinition>(
builder_->createNetwork());
builder_->setMaxWorkspaceSize(1 << 10);
#endif
}
// Defines a simple network: output = input1 + input2.
void DefineNetwork(nvinfer1::INetworkDefinition* network,
nvinfer1::Dims3& dims) {
nvinfer1::ITensor* input1 =
network->addInput("input1", nvinfer1::DataType::kFLOAT, dims);
EXPECT_NE(nullptr, input1);
nvinfer1::ITensor* input2 =
network->addInput("input2", nvinfer1::DataType::kFLOAT, dims);
EXPECT_NE(nullptr, input1);
auto layer = network->addElementWise(*input1, *input2,
nvinfer1::ElementWiseOperation::kSUM);
EXPECT_NE(nullptr, layer);
// Mark the output.
nvinfer1::ITensor* output = layer->getOutput(0);
output->setName("output");
network->markOutput(*output);
}
Logger logger_;
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
TrtUniquePtrType<nvinfer1::INetworkDefinition> network_;
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
TrtUniquePtrType<nvinfer1::IBuilderConfig> builder_config_;
#endif
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>> exec_context_;
// The order is important: exec_context_ must be destroyed first, and logger_
// last.
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
const uint32_t flags_ =
1U << static_cast<int>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
#endif
};
TEST_F(TrtShapeOptimizationProfileTest, Static) {
// Network with static input shape
nvinfer1::Dims3 dims(8, 8, 10);
DefineNetwork(network_.get(), dims);
TrtShapeOptimizationProfile profile;
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
// Configure and build engine - should be a no-op
TF_CHECK_OK(profile.ConfigureBuilder(builder_.get(), builder_config_.get(),
network_.get()));
engine = TrtUniquePtrType<nvinfer1::ICudaEngine>(
builder_->buildEngineWithConfig(*network_, *builder_config_));
#else
engine = TrtUniquePtrType<nvinfer1::ICudaEngine>(
builder_->buildCudaEngine(*network_));
#endif
EXPECT_NE(nullptr, engine);
TF_CHECK_OK(profile.CreateExecutionContexts(engine.get(), exec_context_));
// A single execution context should be created for a graph with static input shapes.
ASSERT_EQ(exec_context_.size(), 1);
EXPECT_NE(nullptr, exec_context_[0]);
std::vector<nvinfer1::Dims3> dim_vec(2, dims);
std::vector<TensorShape> shape_vec = DimVecToShapeVec(dim_vec);
EXPECT_EQ(-1, profile.GetProfileNumber(shape_vec));
}
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
TEST_F(TrtShapeOptimizationProfileTest, Dynamic) {
// Network with dynamic input shapes
nvinfer1::Dims3 dims(-1, -1, 10);
DefineNetwork(network_.get(), dims);
TrtShapeOptimizationProfile profile;
std::vector<std::vector<nvinfer1::Dims3>> input_profiles{
{nvinfer1::Dims3(2, 2, 10), nvinfer1::Dims3(2, 2, 10)},
{nvinfer1::Dims3(3, 3, 10), nvinfer1::Dims3(3, 3, 10)},
{nvinfer1::Dims3(16, 16, 10), nvinfer1::Dims3(16, 16, 10)},
};
// Simulate a profile collection phase
for (auto dim_vec : input_profiles) {
std::vector<TensorShape> shape_vec = DimVecToShapeVec(dim_vec);
profile.AddShape(shape_vec);
}
profile.InitProfiles();
// Configure and build engine
TF_CHECK_OK(profile.ConfigureBuilder(builder_.get(), builder_config_.get(),
network_.get()));
engine = TrtUniquePtrType<nvinfer1::ICudaEngine>(
builder_->buildEngineWithConfig(*network_.get(), *builder_config_.get()));
ASSERT_NE(nullptr, engine);
TF_CHECK_OK(profile.CreateExecutionContexts(engine.get(), exec_context_));
// Each profile has an associated execution context.
EXPECT_EQ(exec_context_.size(), input_profiles.size());
// Check if the profiles are assigned correctly.
for (auto dimvec : input_profiles) {
std::vector<TensorShape> shape_vec = DimVecToShapeVec(dimvec);
int idx = profile.GetProfileNumber(shape_vec);
int prof_idx = exec_context_[idx]->getOptimizationProfile();
ASSERT_GE(prof_idx, 0);
for (int j = 0; j < dimvec.size(); j++) {
nvinfer1::Dims min = engine->getProfileDimensions(
j, prof_idx, nvinfer1::OptProfileSelector::kMIN);
nvinfer1::Dims max = engine->getProfileDimensions(
j, prof_idx, nvinfer1::OptProfileSelector::kMAX);
nvinfer1::Dims opt = engine->getProfileDimensions(
j, prof_idx, nvinfer1::OptProfileSelector::kOPT);
// This should always hold.
EXPECT_TRUE(DimsContained(dimvec[j], min, max));
// The following test depends on the profile creation strategy, and needs
// to be updated (disabled) if the default trategy (defined by
// InitProfiles) changes.
EXPECT_TRUE(DimsEqual(dimvec[j], opt));
}
}
}
#endif
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

View File

@ -133,7 +133,7 @@ Status GraphCompiler::Compile() {
OpKernel* op_kernel_raw = nullptr;
// The kernel is not actually run for functional ops, we just need it
// for metadata.
Status s = flib_->CreateKernel(n->def(), &op_kernel_raw);
Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw);
// Transfer ownership of the kernel to a local smart pointer.
std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);

View File

@ -55,7 +55,6 @@ tf_kernel_library(
"index_ops.cc",
"l2loss_op.cc",
"listdiff_op.cc",
"lower_upper_bound_ops.cc",
"lrn_ops.cc",
"matmul_op.cc",
"matrix_band_part_op.cc",
@ -150,7 +149,6 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",

View File

@ -264,6 +264,23 @@ xla::XlaOp IgammaImpl(xla::XlaOp x, xla::XlaOp y,
XLA_MAKE_BINARY(Igamma, IgammaImpl(lhs, rhs, broadcast_helper));
xla::XlaOp IgammaGradAImpl(xla::XlaOp x, xla::XlaOp y,
const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
return xla::IgammaGradA(x, y);
}
XLA_MAKE_BINARY(IgammaGradA, IgammaGradAImpl(lhs, rhs, broadcast_helper));
xla::XlaOp RandomGammaGradImpl(xla::XlaOp x, xla::XlaOp y,
const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
return xla::RandomGammaGrad(x, y);
}
XLA_MAKE_BINARY(RandomGammaGrad,
RandomGammaGradImpl(lhs, rhs, broadcast_helper));
xla::XlaOp IgammacImpl(xla::XlaOp x, xla::XlaOp y,
const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);

View File

@ -1,116 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
namespace {
// Builds a LowerBound or UpperBound op, the distinction lying in
// comparison_direction: GT => LowerBoundOp, GE => UpperBoundOp.
// Note that this is an O(MN) algorithm: all entries in each sorted_inputs row
// are considered, and their sorted nature is not fully exploited.
void BuildLowerUpperBoundOp(XlaOpKernelContext* ctx, DataType out_dtype,
xla::ComparisonDirection comparison_direction) {
const TensorShape sorted_inputs_shape = ctx->InputShape("sorted_inputs");
const TensorShape values_shape = ctx->InputShape("values");
const xla::XlaOp sorted_inputs = ctx->Input("sorted_inputs");
const xla::XlaOp values = ctx->Input("values");
// We are assuming both inputs are 2D, which they will be, given the current
// implementation of tf.searchsorted.
OP_REQUIRES(ctx, sorted_inputs_shape.dims() == 2,
errors::FailedPrecondition("sorted_inputs must be 2D"));
OP_REQUIRES(ctx, values_shape.dims() == 2,
errors::FailedPrecondition("values must be 2D"));
// Add a new inner dimension to values, to allow broadcasting along the inner
// dimension of sorted_sequence.
auto new_values_shape = values_shape;
new_values_shape.InsertDim(/* d */ 2, /* size */ 1);
auto values_reshaped = xla::Reshape(values, new_values_shape.dim_sizes());
// Add a new penultimate dimension to sorted_inputs, to allow broadcasting of
// sorted_sequence entries for each value.
auto new_sorted_inputs_shape = sorted_inputs_shape;
new_sorted_inputs_shape.InsertDim(/* d */ 1, /* size */ 1);
auto sorted_inputs_reshaped =
xla::Reshape(sorted_inputs, new_sorted_inputs_shape.dim_sizes());
// We are relying on broadcasting to compare each value against each entry in
// the associated sorted_inputs row.
// The reshapes above leave the tensors with equal rank of 3, so broadcast
// dimensions are not explicitly specified.
auto comparison = xla::Compare(values_reshaped, sorted_inputs_reshaped, {},
comparison_direction);
const DataType accumulation_type = XlaHelpers::SumAccumulationType(out_dtype);
// Convert boolean comparison results to integers so we can sum them.
auto comparison_int =
XlaHelpers::ConvertElementType(comparison, accumulation_type);
// Sum the comparison results over the inner dimension to find the index for
// each value.
xla::XlaBuilder* builder = ctx->builder();
auto reduced =
xla::Reduce(comparison_int, XlaHelpers::Zero(builder, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type), {2});
ctx->SetOutput(0, reduced);
}
class LowerBoundOp : public XlaOpKernel {
public:
explicit LowerBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGt);
}
private:
DataType out_dtype_;
};
REGISTER_XLA_OP(Name("LowerBound"), LowerBoundOp);
class UpperBoundOp : public XlaOpKernel {
public:
explicit UpperBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGe);
}
private:
DataType out_dtype_;
};
REGISTER_XLA_OP(Name("UpperBound"), UpperBoundOp);
} // namespace
} // namespace tensorflow
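
The deleted kernel above is equivalent to a per-row counting scheme: for each value, LowerBound counts the sorted entries the value is strictly greater than (GT), UpperBound counts the ones it is greater than or equal to (GE), and that count is the insertion index. A standalone sketch of this O(MN) counting view, written only for illustration, follows.

#include <cstddef>
#include <vector>

// Counts, for each value, how many entries of the sorted row it compares
// greater than (lower bound) or greater than or equal to (upper bound).
std::vector<int> SearchSortedRowSketch(const std::vector<float>& sorted_row,
                                       const std::vector<float>& values,
                                       bool upper_bound) {
  std::vector<int> indices(values.size(), 0);
  for (std::size_t i = 0; i < values.size(); ++i) {
    for (float entry : sorted_row) {
      if (upper_bound ? values[i] >= entry : values[i] > entry) ++indices[i];
    }
  }
  return indices;
}

For sorted_row = {1, 3, 3, 5} and value 3, this yields 1 for the lower bound and 3 for the upper bound, matching tf.searchsorted with side='left' and side='right' respectively.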

View File

@ -32,6 +32,8 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/pooling_ops_common.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
namespace {
@ -157,6 +159,13 @@ class MaxPoolOp : public PoolingOp {
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format"));
OP_REQUIRES(
ctx,
data_format_ != FORMAT_NCHW_VECT_C &&
data_format_ != FORMAT_NHWC_VECT_W,
errors::Unimplemented("XLA does not support the VECT_* data formats. "
"Returning unimplemented from MaxPool to keep "
"Tensorflow's intended optimized MaxPool here."));
}
void Compile(XlaOpKernelContext* ctx) override {

View File

@ -34,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@ -200,6 +201,8 @@ shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)
igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)

View File

@ -26,22 +26,6 @@ const char kShardingAttribute[] = "_XlaSharding";
} // namespace
namespace {
xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
const NodeDef& node_def) {
if (!HasNodeAttr(node_def, kShardingAttribute)) {
return absl::optional<xla::OpSharding>();
}
string value;
xla::OpSharding sharding;
TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
if (!sharding.ParseFromString(value)) {
return xla::InvalidArgument(
"Experimental _XlaSharding attribute was not a valid encoded "
"xla::OpSharding proto.");
}
return absl::optional<xla::OpSharding>(sharding);
}
Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
return errors::InvalidArgument(
"Invalid replicated core id: ", core,
@ -107,4 +91,19 @@ void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
}
}
xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
const NodeDef& node_def) {
if (!HasNodeAttr(node_def, kShardingAttribute)) {
return absl::optional<xla::OpSharding>();
}
string value;
xla::OpSharding sharding;
TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
if (!sharding.ParseFromString(value)) {
return xla::InvalidArgument(
"Experimental _XlaSharding attribute was not a valid encoded "
"xla::OpSharding proto.");
}
return absl::optional<xla::OpSharding>(sharding);
}
} // namespace tensorflow

View File

@ -45,6 +45,10 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);
// Get sharding information from node.
xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
const NodeDef& node_def);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_

View File

@ -97,6 +97,7 @@ xla::StatusOr<DataType> EncodePrimitiveTypeAsDataType(xla::PrimitiveType type) {
{xla::U16, DT_UINT16},
{xla::U32, DT_UINT32},
{xla::U64, DT_UINT64},
{xla::C128, DT_COMPLEX128},
});
auto it = data_type_map.find(type);

View File

@ -139,6 +139,86 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
return Status::OK();
}
// Rewrites the layout of xla_shape if there is tiled sharding.
Status RewriteLayoutWithShardedShape(
const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
xla::Shape* xla_shape) {
if (sharding && !sharding->IsTileMaximal()) {
// After sharding, per core shape might have different layout. For example,
// before sharding, a shape [128, 128] will be assigned default
// minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
// the sharded shapes will have minor-to-major {0, 1}.
//
// As a result, for sharded shapes, we set their layout to per core shape's
// layout.
//
// TODO(endlessroad): for variable input & update, we might have
// different layouts which will prevent input output aliasing and
// increase memory usage. Investigate such cases.
int64 device = *sharding->tile_assignment().begin();
std::vector<int64> offset =
sharding->TileOffsetForDevice(*xla_shape, device);
std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device);
std::vector<int64> dimensions(xla_shape->rank());
for (int64 i = 0; i < xla_shape->rank(); ++i) {
dimensions[i] = limit[i] - offset[i];
}
xla::Shape per_device_xla_shape =
xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
TensorShape per_device_tensor_shape;
TF_RETURN_IF_ERROR(
XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
xla_shape->element_type()));
TF_ASSIGN_OR_RETURN(per_device_xla_shape,
shape_representation_fn(per_device_tensor_shape, dtype,
use_fast_memory));
*xla_shape->mutable_layout() = per_device_xla_shape.layout();
}
return Status::OK();
}
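// Worked example for RewriteLayoutWithShardedShape, using the numbers from the
// comment above (illustrative only): an f32 [128, 128] result tiled over 2
// devices along dimension 1 gives TileOffsetForDevice(shape, 0) = {0, 0} and
// TileLimitForDevice(shape, 0) = {128, 64}, so dimensions = {128 - 0, 64 - 0}
// = {128, 64}. shape_representation_fn is then queried with the per-device
// TensorShape [128, 64], and the layout it returns (for example minor-to-major
// {0, 1}) is written back into xla_shape's layout.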
// If there is a shape_representation_fn or sharding for an output, this
// function uses a reshape to fix the layout.
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
absl::optional<xla::OpSharding> sharding, bool fast_mem) {
if (original_shape.IsTuple()) {
std::vector<xla::XlaOp> elements;
for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
TF_ASSIGN_OR_RETURN(auto element,
ReshapeWithCorrectRepresentationAndSharding(
builder, xla::GetTupleElement(original, i),
original_shape.tuple_shapes(i),
shape_representation_fn, subsharding, fast_mem));
elements.push_back(element);
}
return xla::Tuple(builder, elements);
}
if (!original_shape.IsArray()) return original;
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
original_shape.element_type()));
TF_ASSIGN_OR_RETURN(auto to_shape,
shape_representation_fn(shape, dtype, fast_mem));
if (sharding) {
TF_ASSIGN_OR_RETURN(auto hlo_sharding,
xla::HloSharding::FromProto(*sharding));
TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
}
if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
for (int64 i = 0; i < original_shape.rank(); ++i) {
to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
}
}
return xla::Reshape(to_shape, original);
}
// Builds the XLA computation.
// - `args` is the list of input arguments
// - `retvals` is the list of retvals produced by _Retval operators, in index
@ -188,10 +268,6 @@ Status BuildComputation(
std::vector<xla::XlaOp> elems;
elems.reserve(retvals.size());
// Keeps track of the layout of each retval. If a retval is not in this list,
// a descending layout is used. The first element is the output index, second
// element is the new layout.
std::vector<std::pair<int64, xla::Layout>> retval_index_and_layout;
// Keeps track of sharding of each retval. If a retval is not in this list,
// replicate sharding is used. The first element is the output index, second
// element is the sharding.
@ -219,22 +295,22 @@ Status BuildComputation(
TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
xla::XlaOp value = retval.handle();
auto it = retval_shardings.find(i);
xla::XlaScopedShardingAssignment assign_sharding(
builder, it == retval_shardings.end()
? absl::optional<xla::OpSharding>()
: it->second);
absl::optional<xla::OpSharding> sharding =
it == retval_shardings.end() ? absl::optional<xla::OpSharding>()
: it->second;
if (it != retval_shardings.end()) {
retval_index_and_sharding[elems.size()] = it->second;
}
if (shape_representation_fn) {
// If there is a shape representation function, reshape the output
// tensor to the shape given by the representation shape function.
TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
output.shape, output.type,
/*use_fast_memory=*/false));
value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
retval_index_and_layout.emplace_back(elems.size(), shape.layout());
} else if (it != retval_shardings.end()) {
TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(value));
TF_ASSIGN_OR_RETURN(value,
ReshapeWithCorrectRepresentationAndSharding(
builder, value, original_shape,
shape_representation_fn, sharding,
/*fast_mem=*/false));
}
if (it != retval_shardings.end()) {
xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
// Apply the sharding to the output, if there is a core assignment.
value = identity_op(value);
}
@ -312,43 +388,27 @@ Status BuildComputation(
update.tensor_array_gradients_accessed.insert(grad.first);
}
xla::XlaOp handle;
TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
auto sharding = it == arg_shardings.end()
? absl::optional<xla::OpSharding>()
: it->second;
// Set layout of the retval to device representation layout.
if (shape_representation_fn) {
TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(handle));
TF_ASSIGN_OR_RETURN(
handle, ReshapeWithCorrectRepresentationAndSharding(
builder, handle, original_shape,
shape_representation_fn, sharding, arg.fast_mem));
}
// Request that the value be returned on a specific core.
xla::XlaScopedShardingAssignment assign_sharding(
builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()
: it->second);
xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
if (it != arg_shardings.end()) {
retval_index_and_sharding[elems.size()] = it->second;
}
xla::XlaOp handle;
TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
// Ensures the correct sharding is applied to the output.
handle = identity_op(handle);
// Set layout of the retval to device representation layout.
absl::optional<xla::Shape> representation_shape;
if (shape_representation_fn) {
TF_ASSIGN_OR_RETURN(
xla::Shape xla_shape,
shape_representation_fn(resource->shape(), resource->type(),
/*use_fast_memory=*/false));
representation_shape = xla_shape;
}
if (resource->representation_shape().has_value()) {
const xla::Shape& xla_shape = resource->representation_shape().value();
if (representation_shape) {
TF_RET_CHECK(
xla::ShapeUtil::Compatible(*representation_shape, xla_shape));
} else {
representation_shape = xla_shape;
}
}
if (representation_shape) {
retval_index_and_layout.emplace_back(elems.size(),
representation_shape->layout());
}
elems.push_back(handle);
}
}
@ -411,20 +471,8 @@ Status BuildComputation(
}
*computation = computation_status.ConsumeValueOrDie();
TF_ASSIGN_OR_RETURN(const auto& program_shape,
computation->GetProgramShape());
TF_ASSIGN_OR_RETURN(auto program_shape, computation->GetProgramShape());
*output_shape = program_shape.result();
// Update the output layout to the layout of retval.
for (auto& index_and_layout : retval_index_and_layout) {
if (!always_return_tuple && elems.size() == 1) {
*output_shape->mutable_layout() = index_and_layout.second;
continue;
}
xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape(
output_shape, {index_and_layout.first});
*output_sub_shape->mutable_layout() = index_and_layout.second;
}
return Status::OK();
}
@ -779,47 +827,6 @@ Status XlaCompiler::XLAShapeForArgument(
const XlaCompiler::Argument& arg, bool is_entry_computation,
const absl::optional<xla::HloSharding>& arg_sharding,
xla::Shape* xla_shape) const {
auto rewrite_layout_with_sharded_shape =
[](const absl::optional<xla::HloSharding>& arg_sharding,
bool use_fast_memory,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
xla::Shape* xla_shape) {
if (arg_sharding && !arg_sharding->IsTileMaximal()) {
// After parameter sharding, per core parameter might have different
// layout. For example, before sharding, a parameter of shape [128,
// 128] will be assigned default minor-to-major {1, 0}. But after we
// shard this parameter to [128, 64] * 2, the sharded parameters
// will have minor-to-major {0, 1}.
//
// As a result, for sharded parameters, we set their layout to per
// core parameter's layout.
//
// TODO(endlessroad): for variable input & update, we might have
// different layouts which will prevent input output aliasing and
// increase memory usage. Investigate such cases.
int64 device = *arg_sharding->tile_assignment().begin();
std::vector<int64> offset =
arg_sharding->TileOffsetForDevice(*xla_shape, device);
std::vector<int64> limit =
arg_sharding->TileLimitForDevice(*xla_shape, device);
std::vector<int64> dimensions(xla_shape->rank());
for (int64 i = 0; i < xla_shape->rank(); ++i) {
dimensions[i] = limit[i] - offset[i];
}
xla::Shape per_device_xla_shape =
xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
TensorShape per_device_tensor_shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(per_device_xla_shape,
&per_device_tensor_shape));
TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
xla_shape->element_type()));
TF_ASSIGN_OR_RETURN(per_device_xla_shape,
shape_representation_fn(per_device_tensor_shape,
dtype, use_fast_memory));
*xla_shape->mutable_layout() = per_device_xla_shape.layout();
}
return Status::OK();
};
switch (arg.kind) {
case XlaCompiler::Argument::kConstant:
LOG(FATAL) << "Unreachable case";
@ -835,7 +842,7 @@ Status XlaCompiler::XLAShapeForArgument(
TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn(
shape, arg.type,
/*use_fast_memory=*/false));
TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape(
TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
arg_sharding, /*use_fast_memory=*/false,
options_.shape_representation_fn, xla_shape));
} else {
@ -863,7 +870,7 @@ Status XlaCompiler::XLAShapeForArgument(
options_.shape_representation_fn(
absl::get<TensorShape>(arg.shape), arg.type,
/*use_fast_memory=*/arg.fast_mem));
TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape(
TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
arg_sharding, arg.fast_mem, options_.shape_representation_fn,
xla_shape));
return Status::OK();

View File

@ -365,7 +365,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForFastMemVar) {
compile_options.return_updated_values_for_all_resources = true;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
args, &result));
EXPECT_EQ(fast_mem_arg_count, 1);
// Count 2: one for the argument, one for the return value.
EXPECT_EQ(fast_mem_arg_count, 2);
}
// Tests that the compiler can correctly propagate the layout assigned by
@ -417,6 +418,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
// Check that the return shapes are correctly transposed.
EXPECT_EQ(result.xla_output_shape,
xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
}
// The layout of resource variable shouldn't change after transpose
@ -1091,6 +1094,8 @@ TEST_F(XlaCompilerTest, ResultLayoutSingle) {
EXPECT_TRUE(xla::ShapeUtil::Equal(
result.xla_output_shape,
xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1})));
EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
result.xla_output_shape);
}
TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
@ -1131,6 +1136,8 @@ TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
EXPECT_TRUE(xla::ShapeUtil::Equal(
result.xla_output_shape,
xla::ShapeUtil::MakeTupleShape({result_shape, result_shape})));
EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
result.xla_output_shape);
}
// Tests a simple graph that reads and writes a variable.

View File

@ -693,7 +693,10 @@ XlaOp Digamma(XlaOp input) {
namespace {
enum kIgammaMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE };
// Helper function for computing Igamma using a power series.
template <kIgammaMode mode>
XlaOp IgammaSeries(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
xla::PrimitiveType type) {
// vals: (enabled, r, c, ans, x, dc_da, dans_da)
@ -715,24 +718,60 @@ XlaOp IgammaSeries(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
XlaOp c = vals[2];
XlaOp ans = vals[3];
XlaOp x = vals[4];
XlaOp dc_da = vals[5];
XlaOp dans_da = vals[6];
r = r + ScalarLike(r, 1);
dc_da = dc_da * (x / r) + (ScalarLike(r, -1) * c * x) / (r * r);
dans_da = dans_da + dc_da;
c = c * (x / r);
ans = ans + c;
XlaOp conditional;
if (mode == VALUE) {
conditional = And(enabled, Gt(c / ans, Epsilon(builder, type)));
} else {
conditional =
And(enabled, Gt(Abs(dc_da / dans_da), Epsilon(builder, type)));
}
return std::vector<XlaOp>{
And(enabled, Gt(c / ans, Epsilon(builder, type))),
Select(enabled, r, vals[1]), Select(enabled, c, vals[2]),
Select(enabled, ans, vals[3]), Select(enabled, x, vals[4])};
conditional,
Select(enabled, r, vals[1]),
Select(enabled, c, vals[2]),
Select(enabled, ans, vals[3]),
Select(enabled, x, vals[4]),
Select(enabled, dc_da, vals[5]),
Select(enabled, dans_da, vals[6]),
};
};
auto& b = *ax.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
std::vector<XlaOp> vals = {enabled, a, FullLike(a, 1), FullLike(a, 1), x};
std::vector<XlaOp> vals = {
enabled, a, FullLike(a, 1), FullLike(a, 1), x, FullLike(a, 0),
FullLike(a, 0),
};
TF_ASSIGN_OR_RETURN(vals, WhileLoopHelper(cond, body, vals, "igamma", &b));
XlaOp ans = vals[3];
return (ans * ax) / a;
XlaOp dans_da = vals[6];
if (mode == VALUE) {
return (ans * ax) / a;
}
XlaOp dlogax_da = Log(x) - Digamma(a + ScalarLike(a, 1));
switch (mode) {
case DERIVATIVE:
return ax * (ans * dlogax_da + dans_da) / a;
case SAMPLE_DERIVATIVE:
default:
return -(dans_da + ans * dlogax_da) * x / a;
}
});
}
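// In terms of the IgammaSeries loop above: with ax = exp(a*log(x) - x -
// lgamma(a)), c_0 = 1, c_k = c_{k-1} * x / (a + k), and ans = sum_k c_k,
// VALUE mode returns P(a, x) = ax * ans / a, the standard power series for the
// regularized lower incomplete gamma function. DERIVATIVE mode carries dc_da
// and dans_da through the same recurrence and returns
//   dP/da = ax * (ans * (log(x) - digamma(a + 1)) + dans_da) / a,
// using d(ax/a)/da = (ax/a) * (log(x) - digamma(a + 1)). SAMPLE_DERIVATIVE
// applies the implicit function theorem to P(a, x) = u with dP/dx = ax / x,
// giving dx/da = -(dP/da) / (dP/dx) = -(dans_da + ans * dlogax_da) * x / a.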
// Helper function for computing Igammac using a continued fraction.
template <kIgammaMode mode>
XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
xla::PrimitiveType type) {
// vals: enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da,
// dpkm1_da, dqkm1_da, dans_da
@ -754,6 +793,13 @@ XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
XlaOp qkm1 = vals[7];
XlaOp pkm2 = vals[8];
XlaOp qkm2 = vals[9];
XlaOp dpkm2_da = vals[10];
XlaOp dqkm2_da = vals[11];
XlaOp dpkm1_da = vals[12];
XlaOp dqkm1_da = vals[13];
XlaOp dans_da = vals[14];
c = c + ScalarLike(c, 1);
y = y + ScalarLike(y, 1);
z = z + ScalarLike(z, 2);
@ -762,18 +808,46 @@ XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
XlaOp qk = qkm1 * z - qkm2 * yc;
XlaOp qk_is_nonzero = Ne(qk, ScalarLike(qk, 0));
XlaOp r = pk / qk;
t = Select(qk_is_nonzero, Abs((ans - r) / r), FullLike(t, 1));
ans = Select(qk_is_nonzero, r, ans);
XlaOp dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
XlaOp dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;
XlaOp dans_da_new =
Select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da);
XlaOp grad_conditional =
Select(qk_is_nonzero, Abs(dans_da_new - dans_da), FullLike(dans_da, 1));
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;
dpkm2_da = dpkm1_da;
dqkm2_da = dqkm1_da;
dpkm1_da = dpk_da;
dqkm1_da = dqk_da;
XlaOp rescale = Gt(Abs(pk), Reciprocal(Epsilon(builder, type)));
pkm2 = Select(rescale, pkm2 * Epsilon(builder, type), pkm2);
pkm1 = Select(rescale, pkm1 * Epsilon(builder, type), pkm1);
qkm2 = Select(rescale, qkm2 * Epsilon(builder, type), qkm2);
qkm1 = Select(rescale, qkm1 * Epsilon(builder, type), qkm1);
return std::vector<XlaOp>{And(enabled, Gt(t, Epsilon(builder, type))),
dpkm2_da = Select(rescale, dpkm2_da * Epsilon(builder, type), dpkm2_da);
dqkm2_da = Select(rescale, dqkm2_da * Epsilon(builder, type), dqkm2_da);
dpkm1_da = Select(rescale, dpkm1_da * Epsilon(builder, type), dpkm1_da);
dqkm1_da = Select(rescale, dqkm1_da * Epsilon(builder, type), dqkm1_da);
XlaOp conditional;
if (mode == VALUE) {
conditional = And(enabled, Gt(t, Epsilon(builder, type)));
} else {
conditional = And(enabled, Gt(grad_conditional, Epsilon(builder, type)));
}
return std::vector<XlaOp>{conditional,
Select(enabled, ans, vals[1]),
Select(enabled, t, vals[2]),
Select(enabled, y, vals[3]),
@ -782,7 +856,12 @@ XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
Select(enabled, pkm1, vals[6]),
Select(enabled, qkm1, vals[7]),
Select(enabled, pkm2, vals[8]),
Select(enabled, qkm2, vals[9])};
Select(enabled, qkm2, vals[9]),
Select(enabled, dpkm2_da, vals[10]),
Select(enabled, dqkm2_da, vals[11]),
Select(enabled, dpkm1_da, vals[12]),
Select(enabled, dqkm1_da, vals[13]),
Select(enabled, dans_da_new, vals[14])};
};
auto& b = *ax.builder();
@ -796,11 +875,31 @@ XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
XlaOp qkm1 = z * x;
XlaOp ans = pkm1 / qkm1;
XlaOp t = FullLike(x, 1);
std::vector<XlaOp> vals = {enabled, ans, t, y, z,
c, pkm1, qkm1, pkm2, qkm2};
XlaOp dpkm2_da = FullLike(x, 0);
XlaOp dqkm2_da = FullLike(x, 0);
XlaOp dpkm1_da = FullLike(x, 0);
XlaOp dqkm1_da = -x;
XlaOp dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
std::vector<XlaOp> vals = {enabled, ans, t, y, z,
c, pkm1, qkm1, pkm2, qkm2,
dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da};
TF_ASSIGN_OR_RETURN(vals, WhileLoopHelper(cond, body, vals, "igammac", &b));
ans = vals[1];
return ans * ax;
if (mode == VALUE) {
return ans * ax;
}
dans_da = vals[14];
XlaOp dlogax_da = Log(x) - Digamma(a);
switch (mode) {
case DERIVATIVE:
return ax * (ans * dlogax_da + dans_da);
case SAMPLE_DERIVATIVE:
default:
return -(dans_da + ans * dlogax_da) * x;
}
});
}
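// The d*_da values in IgammacContinuedFraction above come from differentiating
// the continued-fraction recurrences with respect to a (using dy/da = dz/da =
// -1 and dc/da = 0), so dans_da tracks d(ans)/da. With Q(a, x) = ans * ax (the
// VALUE result), DERIVATIVE mode returns
//   dQ/da = ax * (ans * (log(x) - digamma(a)) + dans_da),
// and SAMPLE_DERIVATIVE returns -x * dQ/da / ax, which the callers negate to
// obtain the gradient of a Gamma(a, 1) sample on this branch.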
@ -820,9 +919,9 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
const double nan = std::numeric_limits<double>::quiet_NaN();
XlaOp output = Select(
use_igammac,
ScalarLike(a, 1) -
IgammacContinuedFraction(ax, x, a, And(enabled, use_igammac), type),
IgammaSeries(ax, x, a, And(enabled, Not(use_igammac)), type));
ScalarLike(a, 1) - IgammacContinuedFraction<VALUE>(
ax, x, a, And(enabled, use_igammac), type),
IgammaSeries<VALUE>(ax, x, a, And(enabled, Not(use_igammac)), type));
output = Select(underflow, ZerosLike(output), output);
output = Select(x_is_zero, ZerosLike(output), output);
output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
@ -852,6 +951,101 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
});
}
XlaOp IgammaGradA(XlaOp a, XlaOp x) {
auto& b = *a.builder();
auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
XlaOp is_nan = Or(IsNan(a), IsNan(x));
XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
XlaOp ax = a * Log(x) - x - Lgamma(a);
XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
ax = Exp(ax);
XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
const double nan = std::numeric_limits<double>::quiet_NaN();
XlaOp output = Select(use_igammac,
-IgammacContinuedFraction<DERIVATIVE>(
ax, x, a, And(enabled, use_igammac), type),
IgammaSeries<DERIVATIVE>(
ax, x, a, And(enabled, Not(use_igammac)), type));
output = Select(underflow, ZerosLike(output), output);
output = Select(x_is_zero, ZerosLike(output), output);
output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
return output;
};
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
if (a_shape != x_shape) {
return InvalidArgument(
"Arguments to IgammaGradA must have equal shapes and types; got %s "
"and %s",
a_shape.ToString(), x_shape.ToString());
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
bool needs_upcast =
a_shape.element_type() == F16 || a_shape.element_type() == BF16;
if (needs_upcast) {
a = ConvertElementType(a, F32);
x = ConvertElementType(x, F32);
}
XlaOp result = doit(a, x, a_shape.element_type());
if (needs_upcast) {
result = ConvertElementType(result, a_shape.element_type());
}
return result;
});
}
// Gradient of Gamma sample from Gamma(a, 1) with respect to `a`.
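// For a sample x with fixed CDF value P(a, x) = u, implicit differentiation
// gives dx/da = -(dP/da) / (dP/dx); the SAMPLE_DERIVATIVE modes of the helpers
// above evaluate exactly this quantity on their respective branches.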
XlaOp RandomGammaGrad(XlaOp a, XlaOp x) {
auto& b = *a.builder();
auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
XlaOp is_nan = Or(IsNan(a), IsNan(x));
XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
XlaOp ax = a * Log(x) - x - Lgamma(a);
XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
ax = Exp(ax);
XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
const double nan = std::numeric_limits<double>::quiet_NaN();
XlaOp output = Select(use_igammac,
-IgammacContinuedFraction<SAMPLE_DERIVATIVE>(
ax, x, a, And(enabled, use_igammac), type),
IgammaSeries<SAMPLE_DERIVATIVE>(
ax, x, a, And(enabled, Not(use_igammac)), type));
output = Select(underflow, ZerosLike(output), output);
output = Select(x_is_zero, ZerosLike(output), output);
output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
return output;
};
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
if (a_shape != x_shape) {
return InvalidArgument(
"Arguments to RandomGammaGrad must have equal shapes and types; got "
"%s and %s",
a_shape.ToString(), x_shape.ToString());
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RandomGammaGrad", a));
bool needs_upcast =
a_shape.element_type() == F16 || a_shape.element_type() == BF16;
if (needs_upcast) {
a = ConvertElementType(a, F32);
x = ConvertElementType(x, F32);
}
XlaOp result = doit(a, x, a_shape.element_type());
if (needs_upcast) {
result = ConvertElementType(result, a_shape.element_type());
}
return result;
});
}
XlaOp Igammac(XlaOp a, XlaOp x) {
auto& b = *a.builder();
auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
@ -863,10 +1057,10 @@ XlaOp Igammac(XlaOp a, XlaOp x) {
ax = Exp(ax);
XlaOp result =
Select(use_igamma,
ScalarLike(a, 1) -
IgammaSeries(ax, x, a, And(enabled, use_igamma), type),
IgammacContinuedFraction(ax, x, a, And(enabled, Not(use_igamma)),
type));
ScalarLike(a, 1) - IgammaSeries<VALUE>(
ax, x, a, And(enabled, use_igamma), type),
IgammacContinuedFraction<VALUE>(
ax, x, a, And(enabled, Not(use_igamma)), type));
return Select(underflow, ZerosLike(a),
Select(out_of_range, FullLike(a, 1), result));
};
@ -1008,12 +1202,23 @@ XlaOp Asinh(XlaOp x) {
if (primitive_util::IsComplexType(shape.element_type())) {
return Log(x + Sqrt(x * x + one));
}
// For small x, sqrt(x**2 + 1) will evaluate to 1 due to floating point
// arithmetic. However, we would like to retain the low order term of this,
// which is around 0.5 * x**2 using a binomial expansion.
// Let z = sqrt(a**2 + 1)
// log(a + sqrt(a**2 + 1)) =
// log((a + sqrt(a**2 + 1)) * (1 + sqrt(a**2 + 1)) / (1 + sqrt(a**2 + 1))) =
// log((a + a**2 + 1 + a * z + z) / (1 + z)) =
// log(1 + a + a**2 / (1 + z)) =
// log(1 + a + a ** 2 / (1 + sqrt(a**2 + 1)))
// This rewrite retains the lower order term.
auto a = Abs(x);
auto small_result = Log1p(a + a * a / (one + Sqrt(a * a + one)));
auto naive_result = Log(a + Sqrt(a * a + one));
auto overflow_result = Log(Abs(a)) + Log(ScalarLike(a, 2));
auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type()));
return Sign(x) *
Select(Ge(a, sqrt_max_value), overflow_result, naive_result);
return Sign(x) * Select(Ge(a, sqrt_max_value), overflow_result,
Select(Le(a, one), small_result, naive_result));
};
// These upcasts are not strictly necessary on all platforms to get within our
// error tolerances, so we could relax this if it ever mattered.
@ -1028,9 +1233,7 @@ XlaOp Atanh(XlaOp x) {
XlaBuilder* b = x.builder();
auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
auto naive_result =
Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) *
ScalarLike(x, 0.5);
auto naive_result = (Log1p(x) - Log1p(-x)) * ScalarLike(x, 0.5);
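// atanh(x) = 0.5 * log((1 + x) / (1 - x)) = 0.5 * (log1p(x) - log1p(-x));
// the log1p form keeps the leading x term for |x| << 1, where 1 + x would
// otherwise round to 1.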
// TODO(jlebar): For now, we ignore the nan edge case for complex inputs,
// because we don't yet have exhaustive tests for complex trig functions.
@ -1074,9 +1277,35 @@ XlaOp Cosh(XlaOp x) {
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
XlaOp Sinh(XlaOp x) {
return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
XlaBuilder* b = x.builder();
auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
auto one_half = ScalarLike(x, 0.5);
auto log_one_half = Log(ScalarLike(x, 0.5));
return Exp(x + log_one_half) - Exp(-x + log_one_half);
auto large_sinh_result = Exp(x + log_one_half) - Exp(-x + log_one_half);
if (primitive_util::IsComplexType(shape.element_type())) {
return large_sinh_result;
}
// For large |x| the 1/2 is folded into the exponent:
//   (e^x - e^-x) / 2 = e^(x + log(1/2)) - e^(-x + log(1/2)),
// which stays finite when e^x overflows but e^x / 2 does not.
// For smaller x, e^x - e^-x suffers from cancellation and can round to 0.
// Rewrite it using expm1(x), which preserves the first-order term of the
// Taylor series of e^x:
//   (e^x - e^-x) / 2
//     = (e^x - 1 + 1 - e^-x) / 2
//     = (expm1(x) + (e^x - 1) / e^x) / 2
//     = (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2.
auto expm1 = Expm1(x);
auto one = ScalarLike(x, 1.);
auto small_sinh_result = one_half * (expm1 + expm1 / (expm1 + one));
return Select(Lt(Abs(x), one), small_sinh_result, large_sinh_result);
};
return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) {
return b->ReportErrorOrReturn(do_it(x));
});
}

View File

@ -61,6 +61,14 @@ XlaOp Digamma(XlaOp input);
// Computes an approximation of the incomplete gamma function.
XlaOp Igamma(XlaOp a, XlaOp x);
// Computes an approximation of the derivative of the incomplete gamma function
// with respect to a.
XlaOp IgammaGradA(XlaOp a, XlaOp x);
// Computes an approximation of the derivative of a sample `x` from a `Gamma(a,
// 1)` distribution with respect to a.
XlaOp RandomGammaGrad(XlaOp a, XlaOp x);
// Computes an approximation of the complementary incomplete gamma function.
XlaOp Igammac(XlaOp a, XlaOp x);

View File

@ -298,6 +298,30 @@ XLA_TEST_F(MathTest, SqrtSixValues) {
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(MathTest, SinhSmallValues) {
XlaBuilder builder(TestName());
auto x = ConstantR1<float>(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11});
Sinh(x);
std::vector<float> expected = {1e-3, 1e-5, 1e-7, 1e-9, 1e-11};
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(MathTest, AsinhSmallValues) {
XlaBuilder builder(TestName());
auto x = ConstantR1<float>(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11});
Asinh(x);
std::vector<float> expected = {1e-3, 1e-5, 1e-7, 1e-9, 1e-11};
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(MathTest, AtanhSmallValues) {
XlaBuilder builder(TestName());
auto x = ConstantR1<float>(&builder, {1e-8, 1e-9, 1e-10, 1e-11});
Atanh(x);
std::vector<float> expected = {1e-8, 1e-9, 1e-10, 1e-11};
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(MathTest, Lgamma) {
XlaBuilder builder(TestName());
auto x = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5,

View File

@ -528,7 +528,8 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
}
// Eliminate the size one dimensions.
TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, Reshape(reshaped_shape, operand));
TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand,
ReshapeInternal(reshaped_shape, operand));
// Broadcast 'reshape' up to the larger size.
return InDimBroadcast(broadcast_shape, reshaped_operand,
broadcast_dimensions);
@ -828,8 +829,8 @@ XlaOp XlaBuilder::BroadcastInDim(
});
}
StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
int64 inferred_dimension) {
StatusOr<XlaOp> XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand,
int64 inferred_dimension) {
TF_RETURN_IF_ERROR(first_error_);
HloInstructionProto instr;
@ -1020,7 +1021,7 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
XlaOp transposed = IsIdentityPermutation(dimensions)
? operand
: Transpose(operand, dimensions);
return Reshape(shape, transposed, inferred_dimension);
return ReshapeInternal(shape, transposed, inferred_dimension);
});
}
@ -1034,6 +1035,13 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
});
}
XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
int64 inferred_dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
return ReshapeInternal(shape, operand, inferred_dimension);
});
}
XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span<const int64> dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (dimensions.size() <= 1) {
@ -2951,6 +2959,10 @@ XlaOp Reshape(const XlaOp operand, absl::Span<const int64> new_sizes) {
return operand.builder()->Reshape(operand, new_sizes);
}
XlaOp Reshape(const Shape& shape, XlaOp operand) {
return operand.builder()->Reshape(shape, operand);
}
XlaOp ReshapeWithInferredDimension(XlaOp operand,
absl::Span<const int64> new_sizes,
int64 inferred_dimension) {

View File

@ -397,6 +397,9 @@ class XlaBuilder {
XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
int64 inferred_dimension = -1);
XlaOp Reshape(const Shape& shape, XlaOp operand,
int64 inferred_dimension = -1);
XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);
XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
@ -668,8 +671,8 @@ class XlaBuilder {
// Internal helper method for creating a Reshape op with the already inferred
// shape.
StatusOr<XlaOp> Reshape(const Shape& shape, XlaOp operand,
int64 inferred_dimension = -1);
StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
int64 inferred_dimension = -1);
// Returns the (inferred) result for the program shape using the given root.
StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;
@ -777,6 +780,8 @@ class XlaBuilder {
friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);
friend XlaOp Reshape(const Shape& shape, XlaOp operand);
friend XlaOp ReshapeWithInferredDimension(XlaOp operand,
absl::Span<const int64> new_sizes,
int64 inferred_dimension);
@ -1252,6 +1257,9 @@ XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
// sizes. Conceptually, this is a limited form of "shape casting".
XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);
// Enqueues a Reshape op that uses an explicit target shape.
XlaOp Reshape(const Shape& shape, XlaOp operand);
// `inferred_dimension` represents the output dimension that's inferred by
// upper-level framework by dividing the input element count by the known
// output element count. While an inferred_dimension can be static, if there
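
The explicit-shape overload added above can be driven as in the following minimal sketch; the builder name and the shapes are made up for illustration and are not part of the change.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

void ReshapeToExplicitShapeExample() {
  xla::XlaBuilder b("reshape_example");
  xla::XlaOp p = xla::Parameter(
      &b, /*parameter_number=*/0, xla::ShapeUtil::MakeShape(xla::F32, {2, 3}),
      "p");
  // Reshape to an explicit target shape instead of a span of new sizes; the
  // element counts of the two shapes must match.
  xla::XlaOp reshaped =
      xla::Reshape(xla::ShapeUtil::MakeShape(xla::F32, {3, 2}), p);
  (void)reshaped;
}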

View File

@ -372,6 +372,7 @@ pybind_extension(
# not require Tensorflow.
"//tensorflow/core:lib_internal_impl", # buildcleaner: keep
"//tensorflow/core/profiler/lib:profiler_backends",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/profiler/rpc:profiler_server",
"//tensorflow/stream_executor:device_memory_allocator",

View File

@ -22,6 +22,7 @@ cc_library(
"//tensorflow/compiler/xla/python:local_client",
"//tensorflow/compiler/xla/python:semaphore",
"//tensorflow/compiler/xla/python/tpu_driver",
"//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver",
"//tensorflow/compiler/xla/python/tpu_driver:grpc_tpu_driver",
"//tensorflow/compiler/xla/python/tpu_driver:recording_tpu_driver",
"//tensorflow/compiler/xla/python/tpu_driver:tpu_driver_proto_cc",

View File

@ -27,7 +27,8 @@
namespace tpu_driver {
namespace {
// Enable the macro by default in the env where the libtpu.so is available.
// Enable the macro by default in the Google internal environment where the
// libtpu.so is linked in statically.
#ifdef PLATFORM_GOOGLE
#define TPU_SHARED_LIBRARY_COMPILE_LINK 1
#endif

View File

@ -458,6 +458,8 @@ void BuildOpsSubmodule(py::module* m) {
ops.def("Igamma", &Igamma);
ops.def("Igammac", &Igammac);
ops.def("IgammaGradA", &IgammaGradA);
ops.def("RandomGammaGrad", &RandomGammaGrad);
ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta);
#define BINARY_OP(op) \

View File

@ -1698,6 +1698,7 @@ _BINARY_OPS = [
'ShiftRightLogical',
'Atan2',
'Igamma',
'IgammaGradA',
'Igammac',
'Complex',
'NextAfter',

View File

@ -3727,7 +3727,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
// Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were
// batch dimensions of the dot. The transformation supports reducing other
// dimensions as well.
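// For example, with a batch dimension b,
//   sum_b sum_k lhs[b, m, k] * rhs[b, k, n]
// is itself a dot that contracts over both b and k.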
if (Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
if (options_.enable_dot_strength_reduction() &&
Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
Match(reduce->to_apply()->root_instruction(),
m::Add(m::Parameter(), m::Parameter())) &&
absl::c_any_of(reduce->dimensions(), [&](int64 dim) {

View File

@ -1043,15 +1043,31 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
HloInstruction* root = computation->root_instruction();
// Mark nondistinct/ambiguous indices.
absl::flat_hash_set<const HloBuffer*> seen;
absl::flat_hash_map<const HloBuffer*, ShapeIndex> seen;
ShapeUtil::ForEachSubshape(
root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
std::vector<const HloBuffer*> buffers_at_index =
alias_analysis->ComputeBuffersAt(root, index);
bool buffer_seen_before = false;
for (const HloBuffer* buffer : buffers_at_index) {
buffer_seen_before |= !seen.insert(buffer).second;
buffer_seen_before |= !seen.emplace(buffer, index).second;
}
if (buffer_seen_before && policy.copy_root_replicated_buffers &&
computation == module->entry_computation() &&
module->input_output_alias_config().OutputHasAlias(index) &&
buffers_at_index.size() == 1) {
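// This output index is aliased to an input but shares its buffer with an
// output index seen earlier; copy that earlier index instead, so this index
// can keep the original buffer and the input/output alias stays intact.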
absl::optional<HloInputOutputAliasConfig::Alias> alias =
module->input_output_alias_config().GetAliasedParameter(index);
CHECK(alias) << "Alias does not exist";
const ShapeIndex& other_index = seen[buffers_at_index[0]];
VLOG(2) << "Output indices " << index.ToString() << " and "
<< other_index.ToString() << " are both aliased to "
<< alias->parameter_number << " copying " << other_index;
add_index_to_copy(root, other_index);
return;
}
if (buffers_at_index.size() > 1 ||
(buffer_seen_before && policy.copy_root_replicated_buffers)) {
VLOG(2) << "Index " << index << " of computation "
@ -1097,6 +1113,18 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
return Status::OK();
}
static int64 GetNumExistingCopies(const HloModule* module) {
int64 num_existing_copies = 0;
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCopy) {
++num_existing_copies;
}
}
}
return num_existing_copies;
}
Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
HloModule* module) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
@ -1112,13 +1140,24 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
}
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCopy &&
copy_remover.TryElideCopy(instruction)) {
TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
TF_RETURN_IF_ERROR(
instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
int64 num_existing_copies = GetNumExistingCopies(module);
bool changed = true;
int64 num_iterations = -1;
while (changed) {
CHECK_LE(++num_iterations, num_existing_copies);
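// Each iteration that reports a change elides at least one copy, so the
// fixpoint must be reached within num_existing_copies iterations (plus one
// final pass that makes no change).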
changed = false;
VLOG(2) << "Running fixpoint iteration " << num_iterations
<< " of copy elision";
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCopy &&
copy_remover.TryElideCopy(instruction)) {
changed = true;
TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
TF_RETURN_IF_ERROR(
instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
}
}
}
}
@ -1156,17 +1195,6 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
"Call graph must be flattened before copy insertion.");
}
int64 num_existing_copies = 0;
if (VLOG_IS_ON(1)) {
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCopy) {
++num_existing_copies;
}
}
}
}
TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));
// Simplify the tuple structures introduced by the deep copies. This should be
@ -1185,7 +1213,6 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
RemoveUnnecessaryCopies(DependencyHloOrdering(module), module));
DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
*module);
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
*module);
@ -1202,7 +1229,8 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
}
}
}
VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies;
VLOG(1) << "Num copies before copy-insertion: "
<< GetNumExistingCopies(module);
VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
}
