Merge branch 'master' into mazhar/auto_mixed_preci_bdw

mazharul 2020-12-22 11:58:10 -08:00
commit 6c41436dac
54 changed files with 830 additions and 440 deletions

View File

@ -365,13 +365,9 @@ cc_library(
":flags",
":xla_activity_listener",
":xla_activity_proto_cc",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/mlir:array_container_utils",
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_context",
@ -386,13 +382,13 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:logging",
] + if_libtpu(
if_false = [
"//tensorflow/compiler/mlir:array_container_utils",
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
],
if_true = [],
),
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
tf_cc_test(

View File

@ -152,10 +152,11 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
if (node_stack_trace != nullptr) {
for (const auto& frame : *node_stack_trace) {
stack_trace.emplace_back(
StackFrameView{frame.name, frame.function_name, frame.n});
StackFrameView{frame.name, frame.function_name, frame.stack_trace});
}
}
stack_trace.emplace_back(StackFrameView{node.name(), "", &node});
stack_trace.emplace_back(
StackFrameView{node.name(), "", node.GetStackTrace()});
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
IsCompilableNode(node, lib_runtime, &stack_trace,
@ -175,7 +176,7 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
if (node_stack_trace != nullptr) {
for (const auto& frame : *node_stack_trace) {
stack_trace.emplace_back(
StackFrameView{frame.name, frame.function_name, frame.n});
StackFrameView{frame.name, frame.function_name, frame.stack_trace});
}
}
stack_trace.emplace_back(StackFrameView{call_def.name(), "", nullptr});
@ -361,7 +362,7 @@ bool RecursiveCompilabilityChecker::IsCompilableCall(
bool is_compilable = true;
for (const Node* node : fbody->graph->op_nodes()) {
stack_trace->emplace_back(
StackFrameView{node->name(), function.name(), node});
StackFrameView{node->name(), function.name(), node->GetStackTrace()});
is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace,
&function, uncompilable_nodes);
stack_trace->pop_back();
@ -586,7 +587,7 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
return StackFrame{
std::string(stack_element.name),
std::string(stack_element.function_name),
stack_element.n};
stack_element.stack_trace};
});
node_info.name = std::string(stack_trace.back().name);

View File

@ -62,7 +62,7 @@ class RecursiveCompilabilityChecker {
struct StackFrame {
std::string name;
std::string function_name;
const Node* n = nullptr;
std::shared_ptr<AbstractStackTrace> stack_trace;
};
// Contains information about uncompilable node inside a function body.
@ -197,7 +197,7 @@ class RecursiveCompilabilityChecker {
struct StackFrameView {
absl::string_view name;
absl::string_view function_name;
const Node* n = nullptr;
std::shared_ptr<AbstractStackTrace> stack_trace;
};
bool IsCompilableNode(
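
Note: the StackFrame/StackFrameView structs above now carry a std::shared_ptr<AbstractStackTrace> instead of a raw const Node*, which decouples the recorded frame from the lifetime of the Node that produced it. Below is a minimal, self-contained sketch of that ownership shift; AbstractStackTrace is reduced to a stub with only ToString(), and all names are illustrative, not TensorFlow's real declarations.

// Editor's sketch (illustrative only): owning the trace via shared_ptr keeps it
// alive independently of the graph node that handed it out.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct AbstractStackTrace {               // stand-in for TF's interface
  virtual ~AbstractStackTrace() = default;
  virtual std::string ToString() const = 0;
};

struct FakeTrace : AbstractStackTrace {
  std::string frames;
  explicit FakeTrace(std::string f) : frames(std::move(f)) {}
  std::string ToString() const override { return frames; }
};

struct StackFrame {                       // mirrors the new struct shape
  std::string name;
  std::string function_name;
  std::shared_ptr<AbstractStackTrace> stack_trace;
};

int main() {
  std::vector<StackFrame> stack;
  {
    // Imagine this trace coming from node.GetStackTrace(); once the node and
    // its graph go away, the shared_ptr in the frame still owns the trace.
    auto trace = std::make_shared<FakeTrace>("model.py:42 -> layers.py:7");
    stack.push_back({"MatMul", "", trace});
  }
  std::cout << stack.back().name << " created at "
            << stack.back().stack_trace->ToString() << "\n";
}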

View File

@ -24,6 +24,8 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@ -48,11 +50,6 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"
#if !defined(LIBTPU_ON_GCE)
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
#endif
namespace tensorflow {
constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold;
@ -292,13 +289,6 @@ Status XlaCompilationCache::CompileSingleOp(
GetMlirBridgeRolloutPolicy(*graph, *config) ==
MlirBridgeRolloutPolicy::kEnabledByUser &&
node_def.op() != "VarIsInitializedOp";
#ifdef LIBTPU_ON_GCE
if (use_mlir) {
LOG(WARNING) << "MLIR is not supported in this environment.";
}
return compiler->CompileGraph(compile_options, node_def.name(),
std::move(graph), args, result);
#else
if (!use_mlir) {
return compiler->CompileGraph(compile_options, node_def.name(),
std::move(graph), args, result);
@ -314,7 +304,6 @@ Status XlaCompilationCache::CompileSingleOp(
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
options.device_type.type_string(), compile_options.use_tuple_arg,
*options.flib_def, debug_info, options.shape_representation_fn, result);
#endif
};
return CompileImpl(options, name, args, compile_op,
/*compile_threshold=*/absl::nullopt,

View File

@ -123,13 +123,14 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
std::string node_message = absl::StrCat(
"\n", node_info.name, ": ", node_info.uncompilable_reason, "\n",
"The op is created at:\n");
const Node* n = node_info.stack_trace.back().n;
if (n && n->GetStackTrace()) {
if (node_info.stack_trace.back().stack_trace) {
AbstractStackTrace::TracePrintingOptions opts;
opts.show_line_contents = true;
opts.filter_common_prefix = true;
opts.drop_internal_frames = true;
absl::StrAppend(&node_message, n->GetStackTrace()->ToString(opts));
absl::StrAppend(
&node_message,
node_info.stack_trace.back().stack_trace->ToString(opts));
} else {
absl::StrAppend(&node_message, "<Unavailable>\n");
}

View File

@ -583,6 +583,7 @@ cc_library(
":map_hlo_to_lhlo_op",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
],
)
@ -654,6 +655,7 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
],

View File

@ -66,6 +66,7 @@ MAP_HLO_TO_LHLO(MulOp);
MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(NotOp);
MAP_HLO_TO_LHLO(OrOp);
MAP_HLO_TO_LHLO(PowOp);
MAP_HLO_TO_LHLO(RealOp);
MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(ReshapeOp);

View File

@ -16,12 +16,16 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
namespace mlir {
@ -508,6 +512,40 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
lmhlo::PowOp::Adaptor adaptor(args);
// Floating point can use std::powf
auto result_type = result_types.front();
if (result_type.isa<::mlir::FloatType>())
return MapLhloOpToStdScalarOpImpl<::mlir::PowFOp>{}(loc, result_types, args,
b);
assert(result_type.isa<::mlir::IntegerType>() &&
"only float and integer `pow` is supported right now");
// There is no powi, so lower to a simple product. Note that HLO does not
// define semantics of negative exponents.
Value init = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 1));
Value lowerBound = b->create<ConstantIndexOp>(loc, 0);
Value upperBound =
b->create<IndexCastOp>(loc, adaptor.rhs(), b->getIndexType());
Value step = b->create<ConstantIndexOp>(loc, 1);
return b
->create<scf::ForOp>(
loc, lowerBound, upperBound, step, llvm::makeArrayRef(init),
[&](OpBuilder& b, Location l, Value v, ValueRange iters) {
Value prod =
b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front());
b.create<scf::YieldOp>(l, prod);
})
.getResult(0);
}
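
Note: since there is no integer powi in the standard dialect, the lowering above materializes x^y as y repeated multiplications inside an scf.for whose iteration argument carries a running product seeded with 1. A plain C++ sketch of the same semantics follows (illustrative only; negative exponents are left undefined, matching the HLO note above).

#include <cassert>
#include <cstdint>
#include <iostream>

// Same shape as the scf.for emitted above: run `rhs` iterations, carrying a
// running product that starts at 1. Negative exponents are not handled, since
// HLO does not define their semantics.
int64_t IntegerPow(int64_t lhs, int64_t rhs) {
  assert(rhs >= 0 && "negative exponents are undefined here");
  int64_t acc = 1;                   // matches the `init` constant of 1
  for (int64_t i = 0; i < rhs; ++i)  // lowerBound = 0, step = 1, upperBound = rhs
    acc *= lhs;                      // matches the MulIOp in the loop body
  return acc;
}

int main() {
  std::cout << IntegerPow(3, 4) << "\n";  // prints 81
}
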
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,

View File

@ -650,6 +650,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<mhlo::NegOp>,
HloToLhloOpConverter<mhlo::NotOp>,
HloToLhloOpConverter<mhlo::OrOp>,
HloToLhloOpConverter<mhlo::PowOp>,
HloToLhloOpConverter<mhlo::RealOp>,
HloToLhloOpConverter<mhlo::RemOp>,
HloToLhloOpConverter<mhlo::RsqrtOp>,

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
@ -957,6 +958,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::NegOp>,
PointwiseToLinalgConverter<lmhlo::NotOp>,
PointwiseToLinalgConverter<lmhlo::OrOp>,
PointwiseToLinalgConverter<lmhlo::PowOp>,
PointwiseToLinalgConverter<lmhlo::RealOp>,
PointwiseToLinalgConverter<lmhlo::RemOp>,
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
@ -1021,13 +1023,14 @@ struct LhloLegalizeToLinalgPass
struct HloLegalizeToLinalgPass
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<linalg::LinalgDialect>();
registry.insert<linalg::LinalgDialect, scf::SCFDialect>();
}
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
scf::SCFDialect>();
auto func = getFunction();
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
@ -1075,6 +1078,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::NotOp, false>,
PointwiseToLinalgConverter<mhlo::OrOp, false>,
PointwiseToLinalgConverter<mhlo::PowOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,

View File

@ -745,3 +745,42 @@ func @constant() {
return
}
// CHECK: %[[CONSTANT:.*]] = constant dense<10> : tensor<i32>
// -----
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @float_pow
func @float_pow(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: ^{{[a-z0-9_]*}}
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
// CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = powf %[[ARG0]], %[[ARG1]]
// CHECK: linalg.yield %[[RESULT]]
%0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @integer_pow
func @integer_pow(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: ^{{[a-z0-9_]*}}
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
// CHECK: %[[UPPER:.*]] = index_cast %[[ARG1]]
// CHECK: %[[RESULT:.*]] = scf.for {{.*}} to %[[UPPER]]
// CHECK-SAME: step %c1{{[a-zA-Z0-9_]*}}
// CHECK-SAME: iter_args(%[[ITER:.*]] = %c1{{.*}}) -> (i32) {
// CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = muli %[[ARG0]], %[[ITER]]
// CHECK: scf.yield %[[ACCUM]]
// CHECK: linalg.yield %[[RESULT]]
%0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}

View File

@ -1,5 +1,20 @@
// RUN: mlir-hlo-opt %s -lhlo-legalize-to-linalg -split-input-file | FILECHECK_OPTS="" FileCheck %s
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @element_wise
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"lmhlo.power"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = powf %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @element_wise
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,

View File

@ -2062,4 +2062,28 @@ An op that groups a list of partitioned inputs together. This op
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}
def TF_TPUPartitionedOutputOp : TF_Op<"TPUPartitionedOutput", [NoSideEffect]> {
let summary = [{
An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned
}];
let description = [{
outputs outside the XLA computation.
}];
let arguments = (ins
TF_Tensor:$inputs,
DefaultValuedAttr<I64Attr, "0">:$partition_dim,
OptionalAttr<StrAttr>:$_XlaSharding
);
let results = (outs
Variadic<TF_Tensor>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultSizeAttr num_splits = TF_DerivedResultSizeAttr<0>;
}
#endif // TF_OPS

View File

@ -1276,20 +1276,6 @@ Status XlaCompiler::CompileGraph(
CompilationResult* result) {
VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name;
absl::optional<ConfigProto> config_proto;
MlirBridgeRolloutPolicy policy =
GetMlirBridgeRolloutPolicy(*graph, config_proto);
if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
VLOG(1) << "Using MLIR bridge";
GraphDebugInfo debug_info;
TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
{}, options_.device_type.type_string(), options.use_tuple_arg,
*options_.flib_def, debug_info, options_.shape_representation_fn,
result));
return Status::OK();
}
TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
graph.get(), options_.flib_def, local_flib_def_.get()));
TF_RETURN_IF_ERROR(RearrangeFunctionArguments(

View File

@ -358,6 +358,14 @@ class BufferAssignment {
return allocations_;
}
// This is similar to copying Allocations(), but since it's moved out, it
// preserves the addresses. Since BufferAllocation::Slice keeps a
// BufferAllocation*, and some backends keep BufferAllocation::Slice in
// xla::Executables, migrating off the use of addresses can be hard.
std::vector<BufferAllocation> ReleaseAllocations() {
return std::move(allocations_);
}
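
Note: the comment above relies on a property of moving a std::vector: the heap buffer is transferred wholesale, so pointers to the elements (which is effectively what BufferAllocation::Slice holds through its BufferAllocation*) stay valid in the destination vector. A small stand-alone check of that property, not TensorFlow code:

#include <cassert>
#include <utility>
#include <vector>

int main() {
  std::vector<int> owner = {10, 20, 30};
  const int* slice_like_ptr = &owner[1];  // analogous to a Slice keeping a BufferAllocation*

  // "ReleaseAllocations": move the storage out instead of copying it.
  std::vector<int> released = std::move(owner);

  // Move construction steals the buffer, so the old element address now
  // points into `released` and the pointer remains usable.
  assert(slice_like_ptr == &released[1]);
  assert(*slice_like_ptr == 20);
}
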
// Returns the total size allocation holding all temporary buffers.
int64 temp_allocation_total_size() const {
return temp_allocation_total_size_;

View File

@ -672,8 +672,8 @@ cc_library(
"gpu_debug_info_manager.h",
],
deps = [
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/compiler/xla/service:hlo_proto_util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
@ -685,15 +685,10 @@ tf_cc_test(
srcs = ["gpu_debug_info_manager_test.cc"],
tags = tf_cuda_tests_tags(),
deps = [
":gpu_constants",
":gpu_debug_info_manager",
":gpu_hlo_schedule",
":stream_assignment",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)

View File

@ -34,13 +34,13 @@ namespace gpu {
Status BufferAllocations::TearDown(
const std::set<se::DeviceMemoryBase>& live_addresses,
const BufferAssignment* buffer_assignment) {
absl::Span<const BufferAllocation> allocations) {
// Deallocate temporary buffers, taking care to try to deallocate all of them
// even if one of the deallocations fails.
Status status;
const int64 num_buffers = buffer_assignment->Allocations().size();
const int64 num_buffers = allocations.size();
for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
const BufferAllocation& allocation = buffer_assignment->GetAllocation(i);
const BufferAllocation& allocation = allocations[i];
se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index());
// Deallocate buffers marked "maybe_live_out" but aren't actually live out,
// and temp buffers.

View File

@ -70,7 +70,7 @@ class BufferAllocations {
// Tears down all buffers allocated by this object that are not in
// `live_addresses`.
Status TearDown(const std::set<se::DeviceMemoryBase>& live_addresses,
const BufferAssignment* buffer_assignment);
absl::Span<const BufferAllocation> allocations);
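
Note: taking absl::Span<const BufferAllocation> instead of a BufferAssignment* lets TearDown accept the allocations however the caller stores them, with no copy and no dependency on the heavier BufferAssignment object. A tiny sketch of the Span-parameter idiom, assuming Abseil is available as elsewhere in these BUILD files:

#include <iostream>
#include <vector>

#include "absl/types/span.h"

// Accepts any contiguous sequence of ints (std::vector, C array, ...) as a
// read-only view, with no ownership and no copy.
int Sum(absl::Span<const int> values) {
  int total = 0;
  for (int v : values) total += v;
  return total;
}

int main() {
  std::vector<int> from_vector = {1, 2, 3};
  const int from_array[] = {4, 5, 6};
  std::cout << Sum(from_vector) << " " << Sum(from_array) << "\n";  // 6 15
}
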
std::string ToString() {
std::string out;

View File

@ -865,12 +865,22 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
thunk_schedule->ToString());
}
using OutputInfoMap =
absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>;
TF_ASSIGN_OR_RETURN(OutputInfoMap output_info,
GetOutputInfo(*module, *buffer_assignment));
auto buffer_assignment_proto =
std::make_unique<BufferAssignmentProto>(buffer_assignment->ToProto());
std::vector<BufferAllocation> allocations =
buffer_assignment->ReleaseAllocations();
GpuVersion gpu_version = GetGpuVersion(stream_exec);
auto* gpu_executable = new GpuExecutable(
backend_result.first, backend_result.second, gpu_version,
std::move(thunk_schedule), std::move(module),
std::move(buffer_assignment), std::move(profile_printer),
std::move(profile_index_map), std::move(constants));
{std::move(backend_result.first), std::move(backend_result.second),
gpu_version, std::move(thunk_schedule), std::move(constants),
std::move(output_info), std::move(module), std::move(allocations),
std::move(buffer_assignment_proto), std::move(profile_printer),
std::move(profile_index_map)});
if (embed_ir_in_executable) {
DCHECK_NE("", ir_module_string_before_opt);
gpu_executable->set_ir_module_string(ir_module_string_before_opt);

View File

@ -22,7 +22,7 @@ namespace gpu {
void GpuDebugInfoManager::RegisterModule(
const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
std::shared_ptr<const BufferAssignment> buffer_assignment) {
std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
tensorflow::mutex_lock lock(mutex_);
if (active_modules_.find(module_id) != active_modules_.end()) {
active_modules_[module_id].instances.emplace_back(hlo_module,
@ -40,7 +40,7 @@ void GpuDebugInfoManager::RegisterModule(
// However during tracing, we will defer the cleanup after serialization.
void GpuDebugInfoManager::UnregisterModule(
const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
std::shared_ptr<const BufferAssignment> buffer_assignment) {
std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
tensorflow::mutex_lock lock(mutex_);
CHECK(active_modules_.find(module_id) != active_modules_.end());
GpuModuleEntry& active_module = active_modules_[module_id];
@ -146,8 +146,10 @@ void GpuDebugInfoManager::StopTracing(
// non-nullptr. Due to the inconvenience of creation of buffer_assignment
// object in test, we set it to nullptr and guard this for it.
if (m.instances[0].hlo_module && m.instances[0].buffer_assignment) {
info.hlo_proto = absl::make_unique<HloProto>(MakeHloProto(
*m.instances[0].hlo_module, *m.instances[0].buffer_assignment));
info.hlo_proto = absl::make_unique<HloProto>(
MakeHloProto(*m.instances[0].hlo_module));
*info.hlo_proto->mutable_buffer_assignment() =
*m.instances[0].buffer_assignment;
}
module_debug_info->emplace_back(std::move(info));
}

View File

@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_DEBUG_INFO_MANAGER_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/core/lib/core/status.h"
@ -56,14 +56,14 @@ class GpuDebugInfoManager {
// Modules with same module id can be registered and tracked separately.
void RegisterModule(
const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
std::shared_ptr<const BufferAssignment> buffer_assignment);
std::shared_ptr<const BufferAssignmentProto> buffer_assignment);
// Unregister an active module. When the last active module of the same
// module id is out of scope, we remove it from our database.
// However during tracing, we will defer the cleanup after serialization.
void UnregisterModule(
const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
std::shared_ptr<const BufferAssignment> buffer_assignment);
std::shared_ptr<const BufferAssignmentProto> buffer_assignment);
// Register when the module start execution on certain device.
// TODO(jiesun): Do we need to track which device this is?
@ -110,10 +110,10 @@ class GpuDebugInfoManager {
// tracking, they need to be tracked separately.
struct GpuModuleInstance {
GpuModuleInstance(std::shared_ptr<HloModule> m,
std::shared_ptr<const BufferAssignment> b)
std::shared_ptr<const BufferAssignmentProto> b)
: hlo_module(std::move(m)), buffer_assignment(std::move(b)) {}
std::shared_ptr<HloModule> hlo_module;
std::shared_ptr<const BufferAssignment> buffer_assignment;
std::shared_ptr<const BufferAssignmentProto> buffer_assignment;
bool active = true;
};
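
Note: switching the manager to std::shared_ptr<const BufferAssignmentProto> means it keeps a compact, immutable proto snapshot for debug dumps, presumably so the full BufferAssignment (with its dataflow analysis) need not stay alive just for tracing. The executable and the registry can then share the snapshot safely because neither mutates it. A generic sketch of that shared-immutable-snapshot pattern; the types here are hypothetical stand-ins, not the XLA protos:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Hypothetical, cheap-to-keep snapshot standing in for BufferAssignmentProto.
struct DebugSnapshot {
  std::string serialized_assignment;
};

// Registry and executable share read-only ownership; either side may outlive
// the other without copying the snapshot or holding heavier objects alive.
class DebugRegistry {
 public:
  void Register(std::shared_ptr<const DebugSnapshot> snap) {
    snapshots_.push_back(std::move(snap));
  }
  void Dump() const {
    for (const auto& s : snapshots_) std::cout << s->serialized_assignment << "\n";
  }

 private:
  std::vector<std::shared_ptr<const DebugSnapshot>> snapshots_;
};

int main() {
  auto snap = std::make_shared<const DebugSnapshot>(DebugSnapshot{"alloc 0: 1024B ..."});
  DebugRegistry registry;
  registry.Register(snap);  // registry shares ownership
  snap.reset();             // the "executable" lets go; the snapshot stays alive
  registry.Dump();
}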

View File

@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
@ -30,7 +30,7 @@ class GpuDebugInfoManagerTest : public HloTestBase {
int unique_id;
string id;
std::shared_ptr<HloModule> module;
std::shared_ptr<BufferAssignment> buffer_assignment;
std::shared_ptr<BufferAssignmentProto> buffer_assignment;
};
// Return unique id of this module.

View File

@ -54,31 +54,27 @@ using ::tensorflow::profiler::ScopedAnnotation;
// Implementation note: HLO profiling is always enabled for GPU executables,
// since we can use timers around thunks.
GpuExecutable::GpuExecutable(
const string& text, const std::vector<uint8>& binary,
GpuVersion gpu_version, std::unique_ptr<const ThunkSchedule> thunk_schedule,
std::shared_ptr<HloModule> hlo_module,
std::shared_ptr<const BufferAssignment> assignment,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
std::vector<ConstantInfo> globals)
: Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
std::move(hlo_profile_index_map)),
text_(text),
binary_(binary),
gpu_version_(gpu_version),
thunk_schedule_(std::move(thunk_schedule)),
assignment_(std::move(assignment)),
constants_(std::move(globals)) {
CHECK(has_module() && assignment_);
GpuExecutable::GpuExecutable(GpuExecutable::Params params)
: Executable(std::move(params.hlo_module),
std::move(params.hlo_profile_printer_data),
std::move(params.hlo_profile_index_map)),
text_(std::move(params.asm_text)),
binary_(std::move(params.binary)),
gpu_version_(params.gpu_version),
thunk_schedule_(std::move(params.thunk_schedule)),
allocations_(std::move(params.allocations)),
debug_buffer_assignment_(std::move(params.debug_buffer_assignment)),
constants_(std::move(params.constants)),
output_info_(std::move(params.output_info)) {
CHECK(has_module());
GpuDebugInfoManager::Get()->RegisterModule(module().name(), shared_module(),
assignment_);
debug_buffer_assignment_);
}
GpuExecutable::~GpuExecutable() {
CHECK(has_module() && assignment_);
CHECK(has_module());
GpuDebugInfoManager::Get()->UnregisterModule(module().name(), shared_module(),
assignment_);
debug_buffer_assignment_);
{
// We could have issued host->device mem copies in ResolveConstantGlobals.
@ -381,11 +377,11 @@ StatusOr<BufferAllocations> GpuExecutable::GenerateBufferAllocations(
[&] { return std::string("Build buffer allocations"); },
tensorflow::profiler::TraceMeLevel::kInfo);
const int64 num_buffers = assignment_->Allocations().size();
const int64 num_buffers = allocations_.size();
std::vector<se::DeviceMemoryBase> buffers;
buffers.reserve(num_buffers);
for (int64 i = 0; i < num_buffers; ++i) {
const BufferAllocation& allocation = assignment_->GetAllocation(i);
const BufferAllocation& allocation = allocations_[i];
TF_ASSIGN_OR_RETURN(
se::DeviceMemoryBase buffer,
BufferForAllocation(arguments, globals, allocation, memory_allocator,
@ -396,31 +392,6 @@ StatusOr<BufferAllocations> GpuExecutable::GenerateBufferAllocations(
return {{buffers, executor->device_ordinal(), memory_allocator}};
}
// Returns `true` if the entire tuple contents is aliased.
static bool EntireTupleContentsAliased(
const Shape& output_shape, const ShapeIndex& index,
const HloInputOutputAliasConfig& alias_config) {
const Shape& indexed_shape = ShapeUtil::GetSubshape(output_shape, index);
if (!indexed_shape.IsTuple()) {
return false;
}
bool all_aliased = true;
ShapeUtil::ForEachSubshape(
indexed_shape, [&](const Shape& subshape, const ShapeIndex& subindex) {
if (subindex.empty()) {
return;
}
std::vector<int64> full_index;
absl::c_copy(index, std::back_inserter(full_index));
absl::c_copy(subindex, std::back_inserter(full_index));
if (!alias_config.OutputHasAlias(
ShapeIndex(full_index.begin(), full_index.end()))) {
all_aliased = false;
}
});
return all_aliased;
}
StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ExecutionInput> arguments,
@ -432,10 +403,6 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
const bool block_host_until_done =
!memory_allocator->AllowsAsynchronousDeallocation();
if (GetRootValueSet().IsAmbiguous()) {
return Unimplemented("Points-to set of root instruction is ambiguous");
}
const GpuExecutable::BufferAllocToDeviceMemoryMap* globals;
{
tensorflow::profiler::TraceMe hlo_module_activity(
@ -458,33 +425,37 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
memory_allocator, executor));
VLOG(2) << buffer_allocations.ToString();
std::set<se::DeviceMemoryBase> buffers_in_result;
const bool is_entire_tuple_contents_aliased = [&] {
for (auto& p : result.MutableResult()->buffers().leaves()) {
const OutputInfo& output_info = output_info_.at(p.first);
if (!output_info.alias_config.has_value()) {
return false;
}
}
return true;
}();
for (auto& p : result.MutableResult()->buffers()) {
const ShapeIndex& index = p.first;
const OutputInfo& output_info = output_info_.at(index);
const BufferAllocation* allocation =
&allocations_[output_info.allocation_index];
se::DeviceMemoryBase& result_buffer = p.second;
const auto& sources = GetRootValueSet().element(index);
// The points-to set is unambiguous so the set should be a
// singleton. That is, we know exactly which instruction
// produced the array at this element.
CHECK_EQ(1, sources.values().size());
HloInstruction* src_hlo = sources.values()[0]->instruction();
VLOG(4) << "Looking at: " << src_hlo->ToString()
<< "@ index: " << index.ToString();
VLOG(4) << "Looking at: allocation " << output_info.allocation_index
<< " @ index: " << index.ToString();
const HloInputOutputAliasConfig& input_output_alias =
module().input_output_alias_config();
absl::optional<HloInputOutputAliasConfig::Alias> alias =
input_output_alias.GetAliasedParameter(index);
if (alias) {
CHECK_LT(alias->parameter_number, arguments.size());
ExecutionInput& input = arguments[alias->parameter_number];
if (output_info.alias_config) {
ExecutionInput& input = arguments[allocation->parameter_number()];
MaybeOwningDeviceMemory* maybe_owning_memory =
input.MutableBuffer(alias->parameter_index);
if (alias->must_alias() && !maybe_owning_memory->HasOwnership()) {
input.MutableBuffer(allocation->param_shape_index());
if (output_info.alias_config->must_alias() &&
!maybe_owning_memory->HasOwnership()) {
return InvalidArgument(
"An input was configured to be must-alias at "
"compile time but not donated at runtime: %s",
alias->ToString());
"compile time but not donated at runtime: allocation %d",
output_info.allocation_index);
}
if (absl::optional<se::OwningDeviceMemory> owning =
maybe_owning_memory->Release()) {
@ -504,7 +475,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
// the indices to drop the addresses from its own ScopedShapedBuffer
// result, if the ExecutionOutput is not committed.
result.AddAliasedIndex(index);
} else if (src_hlo->opcode() != HloOpcode::kParameter) {
} else if (!output_info.passthrough) {
// The guard is above is not to insert copy-protection when aliasing
// pass-through params, as we do not need to write into the output
// buffer.
@ -516,12 +487,9 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
se::OwningDeviceMemory allocated_buffer,
memory_allocator->Allocate(device_ordinal, allocation_size));
result_buffer = allocated_buffer.Release();
TF_ASSIGN_OR_RETURN(
const BufferAllocation::Slice slice,
assignment_->GetUniqueSlice(src_hlo, sources.values()[0]->index()));
CHECK_EQ(slice.offset(), 0) << "Parameter should get its own slice";
se::DeviceMemoryBase& aliased_buffer =
buffer_allocations.GetMutableDeviceAddress(slice.index());
buffer_allocations.GetMutableDeviceAddress(
output_info.allocation_index);
CHECK_EQ(aliased_buffer.size(), result_buffer.size());
run_options->stream()->ThenMemcpyD2D(&result_buffer, aliased_buffer,
aliased_buffer.size());
@ -532,15 +500,12 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
if (result_buffer.is_null()) {
// The source instruction should have a non-parameter buffer
// assigned.
TF_ASSIGN_OR_RETURN(
const BufferAllocation::Slice slice,
assignment_->GetUniqueSlice(src_hlo, sources.values()[0]->index()));
result_buffer = buffer_allocations.GetDeviceAddress(slice.index());
result_buffer =
buffer_allocations.GetDeviceAddress(output_info.allocation_index);
// If the entire tuple contents is aliased, the copy insertion will *not*
// materialize a new tuple, so we mark it as aliased as well.
if (EntireTupleContentsAliased(root->shape(), index,
input_output_alias)) {
if (is_entire_tuple_contents_aliased) {
result.AddAliasedIndex(index);
}
}
@ -556,18 +521,13 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
// Free all temporary allocations.
TF_RETURN_IF_ERROR(
buffer_allocations.TearDown(buffers_in_result, assignment_.get()));
buffer_allocations.TearDown(buffers_in_result, allocations_));
// Free allocations for arguments.
MarkToBeReleasedArguments(absl::MakeSpan(arguments), result);
return std::move(result);
}
const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
return assignment_->dataflow_analysis().GetInstructionValueSet(
module().entry_computation()->root_instruction());
}
int64 GpuExecutable::SizeOfGeneratedCodeInBytes() const {
// Non-empty PTX but empty cubin: compilation must have failed, return
// "unknown".
@ -575,9 +535,8 @@ int64 GpuExecutable::SizeOfGeneratedCodeInBytes() const {
return -1;
}
int64 size = binary().size();
for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
++i) {
const BufferAllocation& allocation = assignment_->GetAllocation(i);
for (BufferAllocation::Index i = 0; i < allocations_.size(); ++i) {
const BufferAllocation& allocation = allocations_[i];
if (allocation.is_constant()) {
size += allocation.size();
}
@ -585,5 +544,46 @@ int64 GpuExecutable::SizeOfGeneratedCodeInBytes() const {
return size;
}
StatusOr<absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>>
GetOutputInfo(const HloModule& hlo_module, const BufferAssignment& assignment) {
const HloInstruction* root =
hlo_module.entry_computation()->root_instruction();
InstructionValueSet root_value_set =
assignment.dataflow_analysis().GetInstructionValueSet(root);
if (root_value_set.IsAmbiguous()) {
return Unimplemented("Points-to set of root instruction is ambiguous");
}
using OutputInfoMap =
absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>;
OutputInfoMap output;
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
root->shape(),
[&](const Shape& /*sub_shape*/, const ShapeIndex& index) -> Status {
const auto& sources = root_value_set.element(index);
// The points-to set is unambiguous so the set should be a
// singleton. That is, we know exactly which instruction
// produced the array at this element.
CHECK_EQ(1, sources.values().size());
HloInstruction* src_hlo = sources.values()[0]->instruction();
GpuExecutable::OutputInfo& info = output[index];
info.passthrough = src_hlo->opcode() == HloOpcode::kParameter;
TF_ASSIGN_OR_RETURN(
const BufferAllocation::Slice slice,
assignment.GetUniqueSlice(src_hlo, sources.values()[0]->index()));
CHECK_EQ(slice.offset(), 0) << "Parameter should get its own slice";
info.allocation_index = slice.index();
output[index].alias_config =
hlo_module.input_output_alias_config().GetAliasedParameter(index);
return Status::OK();
}));
return output;
}
} // namespace gpu
} // namespace xla

View File

@ -55,17 +55,36 @@ class GpuExecutable : public Executable {
int allocation_index = -1;
};
struct OutputInfo {
// Output is passed-through from a parameter.
bool passthrough;
// Corresponding allocation index.
int allocation_index;
// Whether this output is hinted to alias a parameter (BufferAllocation*
// would indicate the aliased parameter), and what kind of alias it is.
absl::optional<HloInputOutputAliasConfig::Alias> alias_config;
};
struct Params {
std::string asm_text;
std::vector<uint8> binary;
GpuVersion gpu_version;
std::unique_ptr<const ThunkSchedule> thunk_schedule;
std::vector<ConstantInfo> constants;
absl::flat_hash_map<ShapeIndex, OutputInfo> output_info;
std::unique_ptr<HloModule> hlo_module;
std::vector<BufferAllocation> allocations;
std::unique_ptr<BufferAssignmentProto> debug_buffer_assignment;
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data = nullptr;
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map = nullptr;
};
// We need to share ownership of hlo_module and assignment with profiler to
// safely keep a reference to these objects during tracing period, thus they
// are passed as shared pointers.
GpuExecutable(const string& text, const std::vector<uint8>& binary,
GpuVersion gpu_version,
std::unique_ptr<const ThunkSchedule> thunk_schedule,
std::shared_ptr<HloModule> hlo_module,
std::shared_ptr<const BufferAssignment> assignment,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
std::vector<ConstantInfo> constants);
explicit GpuExecutable(Params params);
~GpuExecutable() override;
int64 SizeOfGeneratedCodeInBytes() const override;
@ -94,8 +113,8 @@ class GpuExecutable : public Executable {
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) override;
std::shared_ptr<const BufferAssignment> GetBufferAssignment() const {
return assignment_;
absl::Span<const BufferAllocation> GetAllocations() const {
return allocations_;
}
private:
@ -109,10 +128,6 @@ class GpuExecutable : public Executable {
bool block_host_until_done,
HloExecutionProfile* hlo_execution_profile);
// Returns the value set of the root instruction of the entry
// computation. Uses dataflow analysis from buffer assignment.
const InstructionValueSet& GetRootValueSet() const;
using BufferAllocToDeviceMemoryMap =
absl::flat_hash_map<BufferAllocation::Index, se::DeviceMemoryBase>;
@ -166,7 +181,9 @@ class GpuExecutable : public Executable {
// Owns the buffer data at runtime. It provides information to allocate
// memory for every output/temp buffers.
const std::shared_ptr<const BufferAssignment> assignment_;
const std::vector<BufferAllocation> allocations_;
std::shared_ptr<BufferAssignmentProto> debug_buffer_assignment_;
// Cache of module handles and constant buffer allocation maps used by
// `ResolveConstantGlobals`.
@ -177,10 +194,14 @@ class GpuExecutable : public Executable {
module_globals_ TF_GUARDED_BY(module_handle_mutex_);
std::vector<ConstantInfo> constants_;
const absl::flat_hash_map<ShapeIndex, OutputInfo> output_info_;
TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable);
};
StatusOr<absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>>
GetOutputInfo(const HloModule& hlo_module, const BufferAssignment& assignment);
} // namespace gpu
} // namespace xla
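
Note: the nine-argument GpuExecutable constructor is replaced by a single Params aggregate, so call sites (see the GpuCompiler and MlirCompiler hunks) can list fields by name in order, rely on the defaulted optional ones, and move everything in one shot. A stripped-down sketch of the pattern, with illustrative field names rather than the full GpuExecutable::Params, assuming C++17 aggregate initialization:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

class Executable {
 public:
  // Aggregate "parameter object": fields read like named arguments at the
  // call site, and new fields can be added without touching every caller.
  struct Params {
    std::string asm_text;
    std::vector<unsigned char> binary;
    int version = 0;
    std::unique_ptr<int> profile_data = nullptr;  // optional, defaulted
  };

  explicit Executable(Params params)
      : asm_text_(std::move(params.asm_text)),
        binary_(std::move(params.binary)),
        version_(params.version),
        profile_data_(std::move(params.profile_data)) {}

  int version() const { return version_; }

 private:
  std::string asm_text_;
  std::vector<unsigned char> binary_;
  int version_;
  std::unique_ptr<int> profile_data_;
};

int main() {
  // Brace-initialize the aggregate, then move it into the constructor;
  // trailing fields fall back to their defaults.
  Executable exec({/*asm_text=*/"ptx ...", /*binary=*/{0xde, 0xad}, /*version=*/70});
  std::cout << exec.version() << "\n";
}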

View File

@ -46,12 +46,9 @@ class GemmRewriteTest : public GpuCodegenTest {
backend().default_stream_executor()->GetAllocator()));
GpuExecutable* gpu_executable =
static_cast<GpuExecutable*>(executable.get());
std::shared_ptr<const BufferAssignment> buffer_assignment =
gpu_executable->GetBufferAssignment();
CHECK_EQ(buffer_assignment->Allocations().size(),
expected_number_of_allocations)
<< "Unexpected buffer assignment. Was:\n"
<< buffer_assignment->ToString();
absl::Span<const BufferAllocation> allocations =
gpu_executable->GetAllocations();
CHECK_EQ(allocations.size(), expected_number_of_allocations);
}
};

View File

@ -302,83 +302,6 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
return Status::OK();
}
if (IsCustomCallToDnnConvolution(*custom_call)) {
std::vector<BufferAllocation::Slice> operand_slices;
operand_slices.reserve(custom_call->operand_count());
for (const auto* operand : custom_call->operands()) {
operand_slices.push_back(GetAllocationSlice(*operand));
}
auto conv_result_slice = GetAllocationSlice(*custom_call, {0});
auto scratch_slice = GetAllocationSlice(*custom_call, {1});
// Assert that the tuple slice is not used by anyone directly. That is, all
// users of the tuple output are get-tuple-element. Also assert that the
// second element of the tuple (the scratch buffer) is not used by anyone.
for (const HloInstruction* user : custom_call->users()) {
TF_RET_CHECK(user->opcode() == HloOpcode::kGetTupleElement &&
user->tuple_index() == 0);
}
TF_ASSIGN_OR_RETURN(
GpuConvConfig config,
GetGpuConvConfig(Cast<HloCustomCallInstruction>(custom_call)));
AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
context_->GetThunkInfo(custom_call), std::move(config),
std::move(operand_slices), conv_result_slice, scratch_slice));
return Status::OK();
}
if (IsCublasGemm(*custom_call)) {
AddThunkToThunkSequence(BuildGemmThunk(custom_call));
return Status::OK();
}
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
if (custom_call->custom_call_target() == kCusolverCholeskyCallTarget) {
TF_ASSIGN_OR_RETURN(CholeskyOptions options,
custom_call->backend_config<CholeskyOptions>());
const Shape& shape = custom_call->operand(0)->shape();
int ndim = shape.dimensions_size();
CHECK_GE(ndim, 2);
int64 n = shape.dimensions(ndim - 1);
const auto& dims = shape.dimensions();
int64 batch_size = std::accumulate(dims.begin(), dims.end() - 2, int64{1},
[](int64 a, int64 b) { return a * b; });
auto operand_buffer = GetAllocationSlice(*custom_call->operand(0));
auto a_buffer = GetAllocationSlice(*custom_call, {0});
auto workspace_buffer = GetAllocationSlice(*custom_call, {1});
auto info_buffer = GetAllocationSlice(*custom_call, {2});
std::vector<std::unique_ptr<Thunk>> thunks;
if (operand_buffer != a_buffer) {
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
context_->GetThunkInfo(custom_call),
/*source_address=*/operand_buffer,
/*destination_buffer=*/a_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
}
thunks.push_back(absl::make_unique<CholeskyThunk>(
context_->GetThunkInfo(custom_call), options, a_buffer,
workspace_buffer, info_buffer,
custom_call->operand(0)->shape().element_type(), batch_size, n));
// Elide the sequential thunk if there's no copy.
if (thunks.size() == 1) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
context_->GetThunkInfo(custom_call), std::move(thunks)));
}
return Status::OK();
}
#endif
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

View File

@ -585,11 +585,18 @@ StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
"thunk_schedule", thunk_schedule->ToString());
}
module = emission_context.releaseHloModule();
TF_ASSIGN_OR_RETURN(auto output_info,
xla::gpu::GetOutputInfo(*module, *buffer_assignment));
std::vector<BufferAllocation> allocations =
buffer_assignment->ReleaseAllocations();
// TODO(b/137624192): Add profiling support.
return {absl::make_unique<GpuExecutable>(
ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule),
emission_context.releaseHloModule(), std::move(buffer_assignment),
nullptr, nullptr, std::vector<GpuExecutable::ConstantInfo>())};
return {absl::make_unique<GpuExecutable>(GpuExecutable::Params{
std::move(ptx), std::move(cubin), GetGpuVersion(stream_exec),
std::move(thunk_schedule), std::vector<GpuExecutable::ConstantInfo>(),
std::move(output_info), std::move(module), std::move(allocations)})};
}
StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(

View File

@ -325,7 +325,8 @@ static Graph* FusedConv2DWithBatchNorm(
.Run(state); \
BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D); \
} \
BENCHMARK(BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC));
BENCHMARK(BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC)) \
->Arg(/*unused arg*/ 1);
#define BM_Conv2DWithBias(N, H, W, C, FW, FH, FC, type, LABEL) \
static void BM_NAME(BM_Conv2DWithBias, type, N, H, W, C, FW, FH, \
@ -336,32 +337,35 @@ static Graph* FusedConv2DWithBatchNorm(
.Run(state); \
BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D); \
} \
BENCHMARK(BM_NAME(BM_Conv2DWithBias, type, N, H, W, C, FW, FH, FC));
BENCHMARK(BM_NAME(BM_Conv2DWithBias, type, N, H, W, C, FW, FH, FC)) \
->Arg(/*unused arg*/ 1);
#define BM_Conv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC, type, LABEL) \
static void BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, \
#define BM_Conv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC, type, LABEL) \
static void BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, \
FC)(::testing::benchmark::State & state) { \
test::Benchmark( \
#type, \
Conv2DWithBiasAndActivation<float>(N, H, W, C, FW, FH, FC, "Relu") \
.graph, \
/*old_benchmark_api=*/false) \
.Run(state); \
BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D); \
} \
BENCHMARK(BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, FC)) \
->Arg(/*unused arg*/ 1);
#define BM_FusedConv2DWithBias(N, H, W, C, FW, FH, FC, type, LABEL) \
static void BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH, \
FC)(::testing::benchmark::State & state) { \
test::Benchmark( \
#type, \
Conv2DWithBiasAndActivation<float>(N, H, W, C, FW, FH, FC, "Relu") \
.graph, \
FusedConv2DWithBias<float>(N, H, W, C, FW, FH, FC, {"BiasAdd"}), \
/*old_benchmark_api=*/false) \
.Run(state); \
BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D); \
} \
BENCHMARK(BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, FC));
#define BM_FusedConv2DWithBias(N, H, W, C, FW, FH, FC, type, LABEL) \
static void BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH, \
FC)(::testing::benchmark::State & state) { \
test::Benchmark( \
#type, \
FusedConv2DWithBias<float>(N, H, W, C, FW, FH, FC, {"BiasAdd"}), \
/*old_benchmark_api=*/false) \
.Run(state); \
BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D); \
} \
BENCHMARK(BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH, FC));
BENCHMARK(BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH, FC)) \
->Arg(/*unused arg*/ 1);
#define BM_FusedConv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC, type, LABEL) \
static void BM_NAME(BM_FusedConv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, \
@ -374,7 +378,8 @@ static Graph* FusedConv2DWithBatchNorm(
BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D); \
} \
BENCHMARK( \
BM_NAME(BM_FusedConv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, FC));
BM_NAME(BM_FusedConv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, FC)) \
->Arg(/*unused arg*/ 1);
#define BM_Conv2DWithBatchNorm(N, H, W, C, FW, FH, FC, type, LABEL) \
static void BM_NAME(BM_Conv2DWithBatchNorm, type, N, H, W, C, FW, FH, \
@ -385,7 +390,8 @@ static Graph* FusedConv2DWithBatchNorm(
.Run(state); \
BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D); \
} \
BENCHMARK(BM_NAME(BM_Conv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC));
BENCHMARK(BM_NAME(BM_Conv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC)) \
->Arg(/*unused arg*/ 1);
#define BM_Conv2DWithBatchNormAndRelu(N, H, W, C, FW, FH, FC, type, LABEL) \
static void BM_NAME(BM_Conv2DWithBatchNormAndRelu, type, N, H, W, C, FW, FH, \
@ -399,7 +405,8 @@ static Graph* FusedConv2DWithBatchNorm(
BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D); \
} \
BENCHMARK( \
BM_NAME(BM_Conv2DWithBatchNormAndRelu, type, N, H, W, C, FW, FH, FC));
BM_NAME(BM_Conv2DWithBatchNormAndRelu, type, N, H, W, C, FW, FH, FC)) \
->Arg(/*unused arg*/ 1);
#define BM_FusedConv2DWithBatchNorm(N, H, W, C, FW, FH, FC, type, LABEL) \
static void BM_NAME(BM_FusedConv2DWithBatchNorm, type, N, H, W, C, FW, FH, \
@ -411,7 +418,9 @@ static Graph* FusedConv2DWithBatchNorm(
.Run(state); \
BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D); \
} \
BENCHMARK(BM_NAME(BM_FusedConv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC));
BENCHMARK( \
BM_NAME(BM_FusedConv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC)) \
->Arg(/*unused arg*/ 1);
#define BM_FusedConv2DWithBatchNormAndRelu(N, H, W, C, FW, FH, FC, type, \
LABEL) \
@ -425,7 +434,8 @@ static Graph* FusedConv2DWithBatchNorm(
BM_SET_INFO(N, H, W, C, type, LABEL, Conv2D); \
} \
BENCHMARK(BM_NAME(BM_FusedConv2DWithBatchNormAndRelu, type, N, H, W, C, FW, \
FH, FC));
FH, FC)) \
->Arg(/*unused arg*/ 1);
// -------------------------------------------------------------------------- //
// Pixel CNN convolutions.
@ -584,7 +594,8 @@ BM_FusedConv2DWithBiasAndRelu(32, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 32");
.Run(state); \
BM_SET_INFO(N, H, W, C, type, "", Conv2D); \
} \
BENCHMARK(BM_LONG_NAME(BM_Conv2D, type, T, FORMAT, N, H, W, C, FW, FH, FC));
BENCHMARK(BM_LONG_NAME(BM_Conv2D, type, T, FORMAT, N, H, W, C, FW, FH, FC)) \
->Arg(/*unused arg*/ 1);
#if GOOGLE_CUDA
using fp32 = float;
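
Note: these benchmarks now take ::testing::benchmark::State, pass /*old_benchmark_api=*/false, and append a placeholder ->Arg(1) to each registration, mirroring the open-source Google Benchmark interface. A hedged sketch of the equivalent written directly against that library (BM_Demo and its body are made up for illustration; TF's macros appear to wrap the same Arg mechanism):

#include <vector>

#include "benchmark/benchmark.h"

// State-based benchmark body: the framework drives the loop and the timing.
static void BM_Demo(::benchmark::State& state) {
  std::vector<int> v(1024, 1);
  for (auto _ : state) {
    long long sum = 0;
    for (int x : v) sum += x;
    ::benchmark::DoNotOptimize(sum);
  }
  // state.range(0) would read the registered Arg(...) if the body needed it.
}

// Registration; Arg(1) attaches an argument the body never reads, as in the diff.
BENCHMARK(BM_Demo)->Arg(1);

BENCHMARK_MAIN();
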

View File

@ -24,9 +24,12 @@ REGISTER4(BinaryOp, CPU, "FloorDiv", functor::floor_div_real, float,
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER4(BinaryOp, GPU, "FloorDiv", functor::floor_div, uint8, uint16, int16,
int64);
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER3(BinaryOp, GPU, "FloorDiv", functor::floor_div_real, float,
Eigen::half, double);
#endif
#endif
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// A special GPU kernel for int32.

View File

@ -364,8 +364,8 @@ struct gemm_pack_colmajor_block<
// Original input column and row after applying all non-standard strides and
// dilations. Computed by padOrSkip{Row,Col}.
Index orig_c;
Index orig_r;
Index orig_c = 0;
Index orig_r = 0;
for (StorageIndex col = 0; col < cols; ++col) {
SubMapper lm = rhs.getLinearMapper(0, col);

View File

@ -135,6 +135,7 @@ tf_kernel_library(
"gpu_op_bitwise_or.cc",
"gpu_op_bitwise_xor.cc",
"gpu_op_equal.cc",
"gpu_op_floor_div.cc",
"gpu_op_greater.cc",
"gpu_op_greater_equal.cc",
"gpu_op_left_shift.cc",
@ -155,6 +156,7 @@ tf_kernel_library(
":bitwise_or_kernels",
":bitwise_xor_kernels",
":equal_kernels",
":floor_div_kernels",
":gpu_ops_base",
":greater_equal_kernels",
":greater_kernels",
@ -537,6 +539,20 @@ gen_kernel_library(
]
]
gen_kernel_library(
name = "floor_div",
tile_size = "256",
# TODO(172804967): Enable for integer types also once unsigned integers are
# supported.
types = [
"f16",
"f32",
"f64",
],
# TODO(b/174543802): Enable once fusion heuristics are better.
# unroll_factors = "4",
)
# Kernels that support all floating-point types.
[
gen_kernel_library(

View File

@ -597,5 +597,25 @@ GENERATE_DEFAULT_TESTS_2(LogicalOr, /*test_name=*/Bool, /*T=*/bool,
/*BaselineOutT=*/bool, baseline_logical_or,
/*use_constraint=*/false)
/// Test `tf.FloorDiv`.
template <typename T>
T baseline_floor_div(T lhs, T rhs) {
return std::floor(lhs / rhs);
}
template <>
Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) {
return static_cast<Eigen::half>(std::floor(static_cast<float>(lhs / rhs)));
}
GENERATE_DEFAULT_TESTS(FloorDiv,
/*test_name=*/Half, Eigen::half, Eigen::half,
baseline_floor_div);
GENERATE_DEFAULT_TESTS(FloorDiv,
/*test_name=*/Float, float, float, baseline_floor_div);
GENERATE_DEFAULT_TESTS(FloorDiv,
/*test_name=*/Double, double, double,
baseline_floor_div);
} // namespace
} // end namespace tensorflow
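
Note: the baseline above defines FloorDiv as std::floor(lhs / rhs), which rounds toward negative infinity rather than toward zero; the difference from plain integer division only shows up when the operands have opposite signs. A small illustration, not TF code:

#include <cmath>
#include <iostream>

int main() {
  // Floating-point floor division, as in baseline_floor_div above.
  std::cout << std::floor(-7.0 / 2.0) << "\n";  // -4: rounds toward -infinity
  // Plain C++ integer division truncates toward zero instead.
  std::cout << (-7 / 2) << "\n";                // -3
  // For same-sign operands the two agree.
  std::cout << std::floor(7.0 / 2.0) << " " << (7 / 2) << "\n";  // 3 3
}
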

View File

@ -0,0 +1,24 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(FloorDiv, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(FloorDiv, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(FloorDiv, f64, DT_DOUBLE, double);
} // namespace tensorflow

View File

@ -92,6 +92,9 @@ template <typename T, std::enable_if_t<
llvm::is_one_of<T, Eigen::half, float, double>::value,
bool> = true>
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name = "") {
if (op_name == "FloorDiv")
return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.1, 0.1, 1e-6, 0.1,
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1,
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
}

View File

@ -0,0 +1,6 @@
func @FloorDiv_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.FloorDiv"(%arg0, %arg1) {T = elem_type, device = ""}
: (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>
}

View File

@ -226,6 +226,7 @@ class XLineBuilder {
int64 NumEvents() const { return line_->events_size(); }
absl::string_view Name() const { return line_->name(); }
void SetName(absl::string_view name) { line_->set_name(std::string(name)); }
void SetNameIfEmpty(absl::string_view name) {
@ -271,6 +272,7 @@ class XPlaneBuilder : public XStatsBuilder<XPlane> {
int64 Id() const { return plane_->id(); }
void SetId(int64 id) { plane_->set_id(id); }
absl::string_view Name() const { return plane_->name(); }
void SetName(absl::string_view name) { plane_->set_name(std::string(name)); }
void ReserveLines(size_t num_lines) {

View File

@ -124,10 +124,10 @@ std::unique_ptr<ComputeTaskDescriptor> SelectSoftmax(const OperationDef& op_def,
const BHWC& src_shape,
const GpuInfo& gpu_info) {
if (src_shape.w == 1 && src_shape.h == 1) {
auto gpu_op = Softmax1x1(op_def, gpu_info, src_shape.c);
auto gpu_op = Softmax1x1(op_def, gpu_info);
return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
} else {
auto gpu_op = Softmax(op_def, src_shape.c);
auto gpu_op = Softmax(op_def);
return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
}
}

View File

@ -39,10 +39,10 @@ using ::tflite::gpu::TensorRef;
using ::tflite::gpu::metal::CompareVectors;
using ::tflite::gpu::metal::SingleOpModel;
@interface SoftmaxTest : XCTestCase
@interface PReLUTest : XCTestCase
@end
@implementation SoftmaxTest
@implementation PReLUTest
- (void)setUp {
[super setUp];
}

View File

@ -36,10 +36,10 @@ using ::tflite::gpu::TensorRef;
using ::tflite::gpu::metal::CompareVectors;
using ::tflite::gpu::metal::SingleOpModel;
@interface SliceTest : XCTestCase
@interface ReLUTest : XCTestCase
@end
@implementation SliceTest
@implementation ReLUTest
- (void)setUp {
[super setUp];
}

View File

@ -50,21 +50,55 @@ kernel void ComputeFunction($1
uint tid[[thread_index_in_threadgroup]],
uint3 ugid[[thread_position_in_grid]])
{
int offset = 0;
float sum = 0.0f;
int s = 0;
do {
if (offset + tid < params.size.x) {
float4 mask_temp = offset + tid == params.size.x - 1 ? params.mask : float4(1.0h);
float4 src = float4(src_tensor[offset + tid]);
sum += dot(mask_temp, exp(src));
offset += 32;
}
s++;
} while (s < params.size.y);
float4 maxx4 = float4(src_tensor[0].x);
for (int s = int(tid); s < params.size.x; s += 32) {
float4 mask_a = s == params.size.x - 1 ? params.mask : float4(1.0f);
float4 mask_b = float4(1.0f) - mask_a;
float4 src = float4(src_tensor[s]);
src = src * mask_a + mask_b * src.x;
maxx4 = max(maxx4, src);
}
float maximum = max(maxx4.x, maxx4.y);
maximum = max(maximum, maxx4.z);
maximum = max(maximum, maxx4.w);
threadgroup float4 tmp[8];
threadgroup float* tmpx1 = (threadgroup float*)tmp;
tmpx1[tid] = maximum;
)";
code += " " + barrier + "(mem_flags::mem_threadgroup);\n";
code += R"(
if (tid == 0) {
maxx4 = max(tmp[0], tmp[1]);
maxx4 = max(maxx4, tmp[2]);
maxx4 = max(maxx4, tmp[3]);
maxx4 = max(maxx4, tmp[4]);
maxx4 = max(maxx4, tmp[5]);
maxx4 = max(maxx4, tmp[6]);
maxx4 = max(maxx4, tmp[7]);
maximum = max(maxx4.x, maxx4.y);
maximum = max(maximum, maxx4.z);
maximum = max(maximum, maxx4.w);
tmpx1[0] = maximum;
}
)";
code += " " + barrier + "(mem_flags::mem_threadgroup);\n";
code += R"(
maximum = tmpx1[0];
float sum = 0.0f;
for (int s = int(tid); s < params.size.x; s += 32) {
float4 mask_temp = s == params.size.x - 1 ? params.mask : float4(1.0f);
float4 src = float4(src_tensor[s]) - float4(maximum);
sum += dot(mask_temp, exp(src));
}
)";
code += " " + barrier + "(mem_flags::mem_threadgroup);\n";
code += R"(
tmpx1[tid] = sum;
)";
code += " " + barrier + "(mem_flags::mem_threadgroup);\n";
@ -85,74 +119,90 @@ kernel void ComputeFunction($1
code += R"(
sum = tmpx1[0];
offset = 0;
s = 0;
do {
if (offset + tid < params.size.x) {
int linear_index = offset + tid;
FLT4 value = FLT4(exp(float4(src_tensor[linear_index])) * sum);
uint3 gid = uint3(0, 0, linear_index);
$2
dst_tensor[linear_index] = value;
offset += 32;
}
s++;
} while (s < params.size.y);
int dst_s = int(ugid.x);
if (dst_s < params.size.x) {
int linear_index = dst_s;
float4 src = float4(src_tensor[linear_index]) - float4(maximum);
FLT4 value = FLT4(exp(src) * sum);
uint3 gid = uint3(0, 0, linear_index);
$2
dst_tensor[linear_index] = value;
}
})";
return code;
}
} // namespace
ComputeTaskDescriptor Softmax(const OperationDef& definition,
int channels_count) {
ComputeTaskDescriptor Softmax(const OperationDef& definition) {
ComputeTaskDescriptor desc(definition);
desc.shader_source = R"(
#include <metal_stdlib>
using namespace metal;
constant int src_channels = )";
desc.shader_source += std::to_string(channels_count);
desc.shader_source += R"(;
$0
kernel void ComputeFunction(
$1
uint3 gid[[thread_position_in_grid]]) {
if (int(gid.x) >= size.x || int(gid.y) >= size.y) {
return;
}
float shift = 0.0f;
int remaining_channels = src_channels % 4;
#include <metal_stdlib>
using namespace metal;
float sum = 0.0f;
for (int d = 0; d < src_channels / 4; ++d) {
int buffer_index = (d * size.y + gid.y) * size.x + gid.x;
sum += dot(float4(1.0f), exp(float4(src_tensor[buffer_index]) - shift));
}
if (remaining_channels > 0) {
int buffer_index = ((src_channels / 4) * size.y + gid.y) * size.x + gid.x;
float4 last_element = float4(src_tensor[buffer_index]);
sum += exp(last_element.x - shift);
if (remaining_channels > 1) sum += exp(last_element.y - shift);
if (remaining_channels == 3) sum += exp(last_element.z - shift);
}
struct uniforms {
int4 size;
float4 mask;
};
$0
kernel void ComputeFunction(
$1
uint3 gid[[thread_position_in_grid]]) {
if (int(gid.x) >= params.size.x || int(gid.y) >= params.size.y) {
return;
}
for (int d = 0; d < (src_channels + 3) / 4; ++d) {
const int linear_index = (d * size.y + gid.y) * size.x + gid.x;
FLT4 value = FLT4(exp(float4(src_tensor[linear_index]) - shift) / sum);
$2
dst_tensor[linear_index] = value;
}
}
float maximum = src_tensor[gid.y * params.size.x + gid.x].x;
for (int d = 0; d < params.size.z; ++d) {
int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
float4 mask_a = d == params.size.z - 1 ? params.mask : float4(1.0f);
float4 mask_b = float4(1.0f) - mask_a;
float4 src = float4(src_tensor[buffer_index]);
src = src * mask_a + mask_b * src.x;
maximum = max(maximum, src.x);
maximum = max(maximum, src.y);
maximum = max(maximum, src.z);
maximum = max(maximum, src.w);
}
float sum = 0.0f;
for (int d = 0; d < params.size.z; ++d) {
int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
float4 mask_temp = d == params.size.z - 1 ? params.mask : float4(1.0f);
float4 src = float4(src_tensor[buffer_index]) - float4(maximum);
sum += dot(mask_temp, exp(src));
}
for (int d = 0; d < params.size.z; ++d) {
const int linear_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
float4 src = float4(src_tensor[linear_index]) - float4(maximum);
FLT4 value = FLT4(exp(src) / sum);
$2
dst_tensor[linear_index] = value;
}
}
)";
desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
desc.uniform_buffers = {
{"constant int2& size",
{"constant uniforms& params",
[](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
std::vector<int> sizes{dst_shapes[0].w, dst_shapes[0].h};
return GetByteBuffer(sizes);
const int dst_depth = DivideRoundUp(dst_shapes[0].c, 4);
struct uniforms {
int4 size;
float4 mask;
};
uniforms params;
params.size = {dst_shapes[0].w, dst_shapes[0].h, dst_depth, 1};
params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
for (int i = 0; i < reminder; ++i) {
params.mask[i] = 1.0f;
}
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
}},
};
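The uniform packing above stores the extents as `size = {width, height, dst_depth, 1}` and builds a per-lane mask for the last, possibly partial, float4 channel slice; inside the kernel that mask keeps padding lanes from distorting the maximum and zeroes their contribution to the sum. A small sketch of the same arithmetic, using a hypothetical channel count of 10:

```python
def divide_round_up(a, b):
    return (a + b - 1) // b

def pack_softmax_params(channels):
    dst_depth = divide_round_up(channels, 4)               # number of float4 slices
    remainder = 4 if channels % 4 == 0 else channels % 4   # valid lanes in the last slice
    mask = [1.0 if i < remainder else 0.0 for i in range(4)]
    return dst_depth, mask

print(pack_softmax_params(10))  # (3, [1.0, 1.0, 0.0, 0.0])
```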
@ -168,7 +218,7 @@ ComputeTaskDescriptor Softmax(const OperationDef& definition,
}
ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
const GpuInfo& gpu_info, int channels_count) {
const GpuInfo& gpu_info) {
ComputeTaskDescriptor desc(definition);
desc.shader_source = GetSoftmax1x1Code(gpu_info);
@ -177,9 +227,9 @@ ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
desc.uniform_buffers = {
{"constant uniforms& params",
[channels_count](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
const int src_depth = DivideRoundUp(channels_count, 4);
[](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
const int src_depth = DivideRoundUp(dst_shapes[0].c, 4);
struct uniforms {
int4 size;
float4 mask;
@ -187,7 +237,7 @@ ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
uniforms params;
params.size = {src_depth, DivideRoundUp(src_depth, 32), 1, 1};
params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
const int reminder = channels_count % 4 == 0 ? 4 : channels_count % 4;
int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
for (int i = 0; i < reminder; ++i) {
params.mask[i] = 1.0f;
}
@ -198,7 +248,10 @@ ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
desc.resize_function = [](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
return std::make_pair(uint3{32u, 1u, 1u}, uint3{1u, 1u, 1u});
uint3 groups_size{32, 1, 1};
uint3 groups_count{
DivideRoundUp(DivideRoundUp(dst_shapes[0].c, 4), groups_size.x), 1, 1};
return std::make_pair(groups_size, groups_count);
};
return desc;
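The new `resize_function` above launches enough 32-thread groups to cover every float4 slice of the channel dimension rather than a single group, which is what allows the final store guarded by `dst_s < params.size.x` to write all slices. A quick check of that arithmetic for a hypothetical 1001-channel tensor:

```python
def divide_round_up(a, b):
    return (a + b - 1) // b

channels = 1001                          # hypothetical example
slices = divide_round_up(channels, 4)    # 251 float4 slices to write
groups = divide_round_up(slices, 32)     # 8 threadgroups of 32 threads each
print(slices, groups, groups * 32)       # 251 8 256: threads 251..255 hit the bounds check
```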

@ -27,13 +27,12 @@ namespace tflite {
namespace gpu {
namespace metal {
ComputeTaskDescriptor Softmax(const OperationDef& definition,
int channels_count);
ComputeTaskDescriptor Softmax(const OperationDef& definition);
// Softmax for case when width = height = 1 and AXIS = CHANNELS
// We have this case in MobilenetV1/V2.
ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
const GpuInfo& gpu_info, int channels_count);
const GpuInfo& gpu_info);
} // namespace metal
} // namespace gpu

@ -133,4 +133,76 @@ using ::tflite::gpu::metal::SingleOpModel;
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testSoftmaxBigNumber {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 1, 2);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 1, 2);
SoftmaxAttributes attr;
attr.axis = Axis::CHANNELS;
double doubles[4] = {1.0, 2.0, 3.0, 100.0};
// exp(100) is inf in float (32 bit) but representable in double (64 bit)
XCTAssertTrue(std::isinf(std::exp(static_cast<float>(doubles[3]))));
XCTAssertFalse(std::isinf(std::exp(doubles[3])));
double s0 = std::exp(doubles[0]) + std::exp(doubles[1]);
double s1 = std::exp(doubles[2]) + std::exp(doubles[3]);
SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {static_cast<float>(doubles[0]),
static_cast<float>(doubles[1]),
static_cast<float>(doubles[2]),
static_cast<float>(doubles[3])}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({static_cast<float>(std::exp(doubles[0]) / s0),
static_cast<float>(std::exp(doubles[1]) / s0),
static_cast<float>(std::exp(doubles[2]) / s1),
static_cast<float>(std::exp(doubles[3]) / s1)},
model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testSoftmax1x1BigNumber {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 1, 4);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 1, 1, 4);
SoftmaxAttributes attr;
attr.axis = Axis::CHANNELS;
double doubles[4] = {1.0, 2.0, 3.0, 100.0};
// exp(100) is inf in float (32 bit) but representable in double (64 bit)
XCTAssertTrue(std::isinf(std::exp(static_cast<float>(doubles[3]))));
XCTAssertFalse(std::isinf(std::exp(doubles[3])));
double s0 = std::exp(doubles[0]) + std::exp(doubles[1]) +
std::exp(doubles[2]) + std::exp(doubles[3]);
SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {static_cast<float>(doubles[0]),
static_cast<float>(doubles[1]),
static_cast<float>(doubles[2]),
static_cast<float>(doubles[3])}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({static_cast<float>(std::exp(doubles[0]) / s0),
static_cast<float>(std::exp(doubles[1]) / s0),
static_cast<float>(std::exp(doubles[2]) / s0),
static_cast<float>(std::exp(doubles[3]) / s0)},
model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
@end
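Both big-number tests hinge on the comment that exp(100) overflows float32 but is representable in float64, and on the kernels now subtracting the per-pixel maximum before exponentiating. A quick numeric check (illustration only):

```python
import math
import numpy as np

print(math.exp(100.0))                     # ~2.69e43, representable in float64
print(np.finfo(np.float32).max)            # ~3.40e38, so float32 overflows first
print(np.exp(np.float32(100.0)))           # inf (float32 overflow, with a runtime warning)
print(np.exp(np.float32(100.0 - 100.0)))   # 1.0 once the maximum has been subtracted
```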

@ -47,6 +47,7 @@ cc_library(
cc_test(
name = "parse_example_test",
srcs = ["parse_example_test.cc"],
tags = ["no_mac"], # TODO(b/176113117): Fails to load shared object
deps = [
":parse_example",
"@flatbuffers",

@ -272,6 +272,7 @@ py_library(
":device_util",
":distribute_lib",
":reduce_util",
":sharded_variable",
":shared_variable_creator",
":tpu_values",
":values",
@ -1118,6 +1119,7 @@ tf_py_test(
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/module",
"//tensorflow/python/saved_model:load",
"//tensorflow/python/saved_model:loader",
"//tensorflow/python/saved_model:save",
"//tensorflow/python/saved_model:signature_constants",

@ -611,6 +611,26 @@ class PSStrategySaveAndLoadTest(test.TestCase):
# ShardedVariable loading only works in v1.
self.assertAllEqual(self.load_and_run_v1(model_dir, {"x": 1}), [6, 6, 6, 6])
with self.assertRaisesWithLiteralMatch(
ValueError, "Loading `ShardedVariable` is not supported"):
with strategy.scope():
tf.saved_model.load(model_dir)
with self.assertRaisesWithLiteralMatch(
ValueError, "Loading `ShardedVariable` is not supported"):
tf.saved_model.load(model_dir)
def test_load_with_partitioner_raises_error(self):
model = self.Model()
model_dir = self.get_temp_dir()
tf.saved_model.save(model, model_dir)
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
self.cluster_resolver, tf1.fixed_size_partitioner(2))
with self.assertRaisesRegex(ValueError, "`variable_partitioner`"):
with strategy.scope():
tf.saved_model.load(model_dir)
if __name__ == "__main__":
# TODO(b/172304955): enable logical devices.

@ -560,7 +560,13 @@ class ParameterServerStrategyV2Extended(
name = kwargs.get("name", None)
initial_value = kwargs.get("initial_value", None)
if initial_value is None:
raise ValueError("initial_value must be specified.")
raise ValueError(
"It looks like you are using `ParameterServerStrategy` with a "
"`variable_partitioner`, and trying to create a variable without "
"specifying `initial_value`. This is not allowed. Please specify the "
"`initial_value`. This can also happen if you are trying to load a "
"saved_model within a `ParameterServerStrategy` scope. Loading a "
"saved_model with `variable_partitioner` is not supported.")
# Two cases where initial_value can be a callable:
# 1. initial_value is passed as a callable, e.g, an `initializer` class.

@ -28,6 +28,7 @@ from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import save_context
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
@ -500,3 +501,21 @@ def embedding_lookup(params,
return embedding_ops.embedding_lookup(params.variables, ids,
partition_strategy, name,
validate_indices, max_norm)
def _raise_when_load(_):
# We don't have serialization and deserialization mechanisms for
# `ShardedVariable` in 2.x style save/load yet.
raise ValueError('Loading `ShardedVariable` is not supported')
revived_types.register_revived_type(
'_tf_distribute_sharded_variable',
lambda obj: isinstance(obj, ShardedVariable),
versions=[
revived_types.VersionedTypeRegistration(
object_factory=_raise_when_load,
version=0,
min_producer_version=0,
min_consumer_version=0)
])

@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
@ -300,6 +301,19 @@ class ShardedVariableTest(test.TestCase):
# Continue using root.train for training
self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())
def test_load_raises_error(self):
root = tracking.AutoTrackable()
v1 = variables_lib.Variable([3.])
v2 = variables_lib.Variable([2.])
root.v = sharded_variable.ShardedVariable([v1, v2])
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
save.save(root, save_dir)
with self.assertRaisesWithLiteralMatch(
ValueError, 'Loading `ShardedVariable` is not supported'):
load.load(save_dir)
def test_validation_errors(self):
with self.assertRaisesRegex(ValueError, 'Expected a list of '):
sharded_variable.ShardedVariable(

@ -809,6 +809,41 @@ def constant_value(tensor, partial=False): # pylint: disable=invalid-name
This function attempts to partially evaluate the given tensor, and
returns its value as a numpy ndarray if this succeeds.
Example usage:
>>> a = tf.constant(10)
>>> tf.get_static_value(a)
10
>>> b = tf.constant(20)
>>> tf.get_static_value(tf.add(a, b))
30
>>> # `tf.Variable` is not supported.
>>> c = tf.Variable(30)
>>> print(tf.get_static_value(c))
None
Using the `partial` option is most relevant when calling `get_static_value`
inside a `tf.function`. Setting it to `True` returns the results, with `None`
for the values that cannot be evaluated. For example:
```python
class Foo(object):
def __init__(self):
self.a = tf.Variable(1)
self.b = tf.constant(2)
@tf.function
def bar(self, partial):
packed = tf.raw_ops.Pack(values=[self.a, self.b])
static_val = tf.get_static_value(packed, partial=partial)
tf.print(static_val)
f = Foo()
f.bar(partial=True) # `array([None, array(2, dtype=int32)], dtype=object)`
f.bar(partial=False) # `None`
```
Compatibility(V1): If `constant_value(tensor)` returns a non-`None` result, it
will no longer be possible to feed a different value for `tensor`. This allows
the result of this function to influence the graph that is constructed, and

@ -72,6 +72,7 @@ except ImportError:
requests = None
# Note: `configure_callbacks` is only used in TF1.
def configure_callbacks(callbacks,
model,
do_validation=False,

@ -218,7 +218,7 @@ def _should_expand_composite(value):
# pylint: disable=protected-access
def _composite_to_tensors(value):
def _composite_to_tensors(value, is_batched=False):
"""Converts a CompositeTensor into a list of stackable tensors."""
if _should_expand_composite(value):
spec = value._type_spec
@ -227,6 +227,8 @@ def _composite_to_tensors(value):
"parallel_for or vectorized_map loop body must provide "
"a `BatchableTypeSpec` (saw: {}).".format(
value, spec))
if is_batched:
return spec._to_batched_tensor_list(value)
return spec._to_tensor_list(value)
return value
# pylint: enable=protected-access
@ -421,14 +423,26 @@ def _broadcasting_gather(x, i):
return result
# pylint: disable=protected-access
def _gather_from_tensor_or_composite(x, i):
"""Wrapper for gather that handles CompositeTensors."""
if _should_expand_composite(x):
spec = x._type_spec
gathered_tensors = [_broadcasting_gather(t, i)
for t in spec._to_batched_tensor_list(x)]
return spec._unbatch()._from_compatible_tensor_list(gathered_tensors)
return _broadcasting_gather(x, i)
# pylint: enable=protected-access
@tf_export("vectorized_map")
def vectorized_map(fn, elems, fallback_to_while_loop=True):
"""Parallel map on the list of tensors unpacked from `elems` on dimension 0.
This method works similar to `tf.map_fn` but is optimized to run much faster,
possibly with a much larger memory footprint. The speedups are obtained by
vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians,
Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea
behind vectorization is to semantically launch all the invocations of `fn` in
parallel and fuse corresponding operations across all these invocations. This
fusion is done statically at graph generation time and the generated code is
@ -518,19 +532,21 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True):
Raises:
ValueError: If vectorization fails and fallback_to_while_loop is False.
"""
def _convert_to_tensor_or_ndarray(x):
if isinstance(x, np_arrays.ndarray):
return x
return ops.convert_to_tensor(x)
elems = nest.map_structure(_convert_to_tensor_or_ndarray, elems)
elems = nest.map_structure(ops.convert_to_tensor,
elems,
expand_composites=True)
def loop_fn(i):
gathered_elems = nest.map_structure(lambda x: _broadcasting_gather(x, i),
elems)
gathered_elems = nest.map_structure(
lambda x: _gather_from_tensor_or_composite(x, i), elems)
return fn(gathered_elems)
# Extract batch size from the maximum first dimension of any element.
flat_elems = nest.flatten(elems)
flat_elems = nest.flatten(
nest.map_structure(
functools.partial(_composite_to_tensors,
is_batched=True),
elems))
def _get_shape(x):
if isinstance(x, np_arrays.ndarray):
x = x.data
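With the changes above, `vectorized_map` converts its inputs with `expand_composites=True` and gathers per-iteration values through `_gather_from_tensor_or_composite`, so composite inputs (for example a `tf.RaggedTensor` with uniform row lengths, as exercised by the new test further below) are batched and unbatched like ordinary tensors. A minimal usage sketch, assuming a TensorFlow 2.x environment:

```python
import tensorflow as tf

x = tf.constant([[1., 2.], [3., 4.]])
# Each invocation of the lambda sees one row; pfor fuses the calls into a single
# vectorized graph instead of running them in a Python loop.
y = tf.vectorized_map(lambda row: row * row, x)
print(y.numpy())  # [[ 1.  4.] [ 9. 16.]]
```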

@ -70,6 +70,7 @@ from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest
@ -2157,6 +2158,27 @@ class CompositeTensorTest(PForTestCase, parameterized.TestCase):
self.assertTrue(particles.mass.shape, [4, 1, 3])
self.assertAllEqual(particles.velocity.shape, [4, 5, 3])
def test_vectorized_map_gathers_composite_tensors(self):
particles = Particle(mass=[1., 2., 3., 4., 5.],
velocity=[1., 2., 3., 4., 5.])
self.assertAllEqual(
pfor_control_flow_ops.vectorized_map(
lambda x: x.mass * x.velocity, particles),
particles.mass * particles.velocity)
def test_vectorized_map_of_ragged_tensors(self):
# Vmap should be able to handle ragged Tensors as long as they're not
# *actually* ragged.
ragged = ragged_tensor.RaggedTensor.from_uniform_row_length(
ragged_tensor.RaggedTensor.from_row_lengths(
values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
row_lengths=[3, 3, 3, 3]),
uniform_row_length=2) # Overall shape [2, 2, 3].
self.assertAllEqual(
pfor_control_flow_ops.vectorized_map(
lambda x: x.to_tensor(shape=[2, 3]), ragged),
ragged.to_tensor(shape=[2, 2, 3]))
class ParsingTest(PForTestCase):

@ -320,7 +320,6 @@ tf_py_test(
name = "tf_stack_test",
srcs = ["tf_stack_test.py"],
python_version = "PY3",
tags = ["no_windows"], # TODO(b/175726972)
deps = [
":tf_export",
":tf_stack",

@ -140,19 +140,50 @@ class StackTraceWrapper : public AbstractStackTrace {
}
absl::Span<StackFrame const> ToFrames() const override {
GenerateCache();
if (stack_frames_cache_) {
return *stack_frames_cache_;
}
// Grabbing the GIL serves two purposes: 1) it makes the class thread-safe,
// and 2) ToStackFrames and LineContents actually need it.
PyGILState_STATE state = PyGILState_Ensure();
stack_frames_cache_ = captured_.ToStackFrames(
[&](std::pair<const char*, int> p) { return StackTraceMapping(p); },
[&](const char* f) { return StackTraceFiltering(f); });
stack_frames_cache_->pop_back(); // Drop last stack frame.
PyGILState_Release(state);
return *stack_frames_cache_;
}
StackFrame LastUserFrame() const override {
GenerateLastFrameCache();
if (last_stack_frame_cache_) {
return *last_stack_frame_cache_;
}
PyGILState_STATE state = PyGILState_Ensure();
std::vector<StackFrame> last_frame = captured_.ToStackFrames(
[&](std::pair<const char*, int> p) { return StackTraceMapping(p); },
[&](const char* file_name) {
return StackTraceFiltering(file_name) ||
IsInternalFrameForFilename(file_name);
},
/*reverse_traversal=*/true,
/*limit=*/1);
if (last_frame.empty()) {
last_stack_frame_cache_ = StackFrame{"", -1, ""};
} else {
DCHECK_EQ(last_frame.size(), 1);
last_stack_frame_cache_ = last_frame[0];
}
PyGILState_Release(state);
return *last_stack_frame_cache_;
}
std::string ToString(const TracePrintingOptions& opts) const override {
GenerateCache();
std::vector<std::string> files_to_find_prefix;
for (const StackFrame& frame : *stack_frames_cache_) {
for (const StackFrame& frame : ToFrames()) {
if (!absl::StrContains(frame.file_name, kFilenameToIgnorePrefix)) {
files_to_find_prefix.push_back(frame.file_name);
}
@ -175,50 +206,6 @@ class StackTraceWrapper : public AbstractStackTrace {
return ToStringHelper(filtered_frames, opts, shared_prefix_size);
}
bool IsCacheGenerated() const { return stack_frames_cache_.has_value(); }
void GenerateCache() const {
// TODO(mdan): We don't really need random access; this can be removed.
if (stack_frames_cache_) {
return;
}
// Grabbing the GIL serves two purposes: 1) it makes the class thread-safe, and
// 2) ToStackFrames and LineContents actually need it.
PyGILState_STATE state = PyGILState_Ensure();
stack_frames_cache_ = captured_.ToStackFrames(
[&](std::pair<const char*, int> p) { return StackTraceMapping(p); },
[&](const char* f) { return StackTraceFiltering(f); });
stack_frames_cache_->pop_back(); // Drop last stack frame.
PyGILState_Release(state);
}
void GenerateLastFrameCache() const {
if (last_stack_frame_cache_) {
return;
}
PyGILState_STATE state = PyGILState_Ensure();
auto f = [&](const char* file_name) -> bool {
return StackTraceFiltering(file_name) ||
IsInternalFrameForFilename(file_name);
};
std::vector<StackFrame> last_frame = captured_.ToStackFrames(
[&](std::pair<const char*, int> p) { return StackTraceMapping(p); }, f,
/*reverse_traversal=*/true,
/*limit=*/1);
if (last_frame.empty()) {
last_stack_frame_cache_ = StackFrame{};
} else {
DCHECK(last_frame.size() == 1);
last_stack_frame_cache_ = last_frame[0];
}
PyGILState_Release(state);
}
StackTraceWrapper(StackTraceWrapper&&) = default;
~StackTraceWrapper() override {
PyGILState_STATE state = PyGILState_Ensure();
@ -242,7 +229,8 @@ class StackTraceWrapper : public AbstractStackTrace {
static bool IsInternalFrameForFilename(absl::string_view file_name) {
// Use a simple heuristic for now.
// TODO(cheshire): Build a more sophisticated mechanism, rely on @tf.export.
return absl::StrContains(file_name, "tensorflow/python") &&
return (absl::StrContains(file_name, "tensorflow/python") ||
absl::StrContains(file_name, "tensorflow\\python")) &&
!absl::StrContains(file_name, "keras") &&
!absl::StrContains(file_name, "test.py");
}
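The added backslash variant lets the internal-frame heuristic match Windows-style paths as well, which is presumably why the `no_windows` tag was dropped from `tf_stack_test` above. A quick check of the substring logic on a hypothetical Windows path (illustration only):

```python
path = r"C:\build\tensorflow\python\framework\ops.py"  # hypothetical path
print("tensorflow/python" in path)   # False: the forward-slash check alone misses it
print("tensorflow\\python" in path)  # True: the added backslash variant matches
```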
@ -392,12 +380,10 @@ PYBIND11_MODULE(_tf_stack, m) {
})
.def("__hash__",
[](const StackTraceWrapper& self) {
self.GenerateCache();
return py::hash(py::str(self.ToString({})));
})
.def("__repr__",
[](const StackTraceWrapper& self) {
self.GenerateCache();
return py::str(self.ToString({}));
})
.def("last_user_frame",