Merge remote-tracking branch 'upstream/master' into offline_memory_planner

Jens Elofsson 2020-05-28 20:24:46 +02:00
commit f409152691
262 changed files with 7122 additions and 2920 deletions

View File

@ -91,7 +91,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
}
string message = absl::StrCat(
"Function invoked by the following node is not compilable: ",
SummarizeNodeDef(node_def), ".\n");
SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:");
for (const auto& node_info : uncompilable_node_info) {
string node_message =

View File

@ -201,9 +201,7 @@ void XlaComputationLaunchContext::PopulateInputs(
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
arg_buffers_.resize(kernel->xla_input_shapes.size());
arg_ptrs_ = std::vector<ShapedBuffer*>(arg_buffers_.size());
arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
// Pass remaining parameters.
const Tensor* t;
@ -239,11 +237,11 @@ void XlaComputationLaunchContext::PopulateInputs(
<< " not the same as on-host shape "
<< xla::ShapeUtil::HumanStringWithLayout(shape);
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
arg_buffers_[i] = absl::make_unique<ShapedBuffer>(
arg_buffers_.emplace_back(
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
client_->platform(), client_->default_device_ordinal());
arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
arg_ptrs_[i] = arg_buffers_[i].get();
arg_buffers_.back().set_buffer(dmem, /*index=*/{});
arg_ptrs_[i] = &arg_buffers_.back();
}
}
}

View File

@ -165,7 +165,7 @@ class XlaComputationLaunchContext {
se::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
bool use_multiple_streams_;
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
std::deque<xla::ShapedBuffer> arg_buffers_;
std::vector<xla::ShapedBuffer*> arg_ptrs_;
};
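A note on the container swap shown above: presumably the motivation is that std::deque, unlike std::vector, never relocates existing elements when new ones are appended, so the raw pointers stored in arg_ptrs_ stay valid as arg_buffers_ grows, and the per-buffer heap allocation of the old unique_ptr scheme goes away. A minimal standalone sketch of that guarantee, using plain ints rather than XLA types:

#include <cassert>
#include <deque>
#include <vector>

int main() {
  std::deque<int> buffers;   // stand-in for arg_buffers_
  std::vector<int*> ptrs;    // stand-in for arg_ptrs_
  for (int i = 0; i < 1000; ++i) {
    buffers.emplace_back(i);
    ptrs.push_back(&buffers.back());  // mirrors arg_ptrs_[i] = &arg_buffers_.back();
  }
  // deque::emplace_back does not invalidate references to earlier elements,
  // so every stored pointer still points at the value that was inserted.
  for (int i = 0; i < 1000; ++i) {
    assert(*ptrs[i] == i);
  }
  return 0;
}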

View File

@ -279,22 +279,6 @@ cc_library(
],
)
tf_cc_test(
name = "tftext_utils_test",
size = "small",
srcs = ["utils/lstm_utils_test.cc"],
deps = [
":lstm_utils",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "stateful_ops_utils",
srcs = [

View File

@ -2297,26 +2297,17 @@ def TFL_PReluOp : TFL_Op<"prelu", [
NoSideEffect,
ResultsBroadcastableShape,
TFL_GpuTargetOp,
TFL_OperandHasRankAtMost<0, 4>,
TFL_OperandHasRankAtMost<1, 4>,
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
BinaryOpSameElementTypeConstraint,
PredOpTrait<"input and output must have the same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
PredOpTrait<"'alpha' should have one less rank than 'input'.",
Or<[TFL_OperandIsUnrankedPred<0>,
TFL_OperandIsUnrankedPred<1>,
CPred<"$_op.getOperand(0).getType().cast<ShapedType>().getRank() == "
"$_op.getOperand(1).getType().cast<ShapedType>().getRank() "
"+ 1">]>>]> {
TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
let summary = "Parameterized Relu operator";
let description = [{
Parameterized Relu operator
x -> x >= 0 ? x : (alpha * x)
where alpha is a trainable tensor.
alpha should have one less rank than the input as it doesn't have the batch
dimension, and the other dimensions either should be the same size as input
or size 1, where it is broadcasted in the second case.
input and alpha should be the same size as input or be broadcastable.
}];
let arguments = (
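To make the relaxed constraint above concrete, here is a minimal sketch of the PReLU arithmetic x >= 0 ? x : alpha * x with alpha broadcast over the leading dimension; the helper name and flat [rows, cols] layout are illustrative only, not the TFLite kernel:

#include <cstddef>
#include <vector>

// Hypothetical reference implementation: input is [rows, cols], alpha is
// [cols] and is broadcast across rows, matching "same size or broadcastable".
std::vector<float> PreluReference(const std::vector<float>& input,
                                  const std::vector<float>& alpha,
                                  std::size_t rows, std::size_t cols) {
  std::vector<float> output(input.size());
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < cols; ++c) {
      const float x = input[r * cols + c];
      output[r * cols + c] = x >= 0.0f ? x : alpha[c] * x;
    }
  }
  return output;
}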

View File

@ -55,8 +55,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<double> node_mins;
std::vector<double> node_maxs;
std::vector<llvm::Optional<double>> node_mins;
std::vector<llvm::Optional<double>> node_maxs;
// Populate quantization specs.
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(

View File

@ -125,8 +125,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<double> node_mins;
std::vector<double> node_maxs;
std::vector<llvm::Optional<double>> node_mins;
std::vector<llvm::Optional<double>> node_maxs;
// Populate quantization specs.
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(

View File

@ -177,14 +177,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
return RegisterCustomBuiltinOps(extra_tf_opdefs);
}
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs,
std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<double>* node_mins,
std::vector<double>* node_maxs) {
Status PopulateQuantizationSpecs(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<llvm::Optional<double>>* node_mins,
std::vector<llvm::Optional<double>>* node_maxs) {
quant_specs->inference_input_type =
ConvertIODataTypeToDataType(toco_flags.inference_input_type());
tensorflow::DataType inference_type =
@ -211,11 +210,16 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
flag.shape().dims().end()));
// Currently, only UINT8 and INT8 require inputs stats
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
TF_ASSIGN_OR_RETURN(
auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
inference_type));
node_mins->push_back(min_max.first);
node_maxs->push_back(min_max.second);
if (flag.has_mean_value() && flag.has_std_value()) {
TF_ASSIGN_OR_RETURN(
auto min_max, InputStatsToMinMax(flag.mean_value(),
flag.std_value(), inference_type));
node_mins->push_back(min_max.first);
node_maxs->push_back(min_max.second);
} else {
node_mins->push_back(llvm::None);
node_maxs->push_back(llvm::None);
}
}
}
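The hunks above thread llvm::Optional<double> through the min/max plumbing so that inputs without mean/std stats carry an explicit "absent" marker instead of a fabricated range. A condensed sketch of the pattern, assuming the LLVM ADT header and a hypothetical InputStats struct in place of the toco flag proto:

#include <vector>
#include "llvm/ADT/Optional.h"

struct InputStats {  // hypothetical stand-in for one parsed input flag
  bool has_stats;
  double min;
  double max;
};

void CollectRanges(const std::vector<InputStats>& flags,
                   std::vector<llvm::Optional<double>>* node_mins,
                   std::vector<llvm::Optional<double>>* node_maxs) {
  for (const InputStats& f : flags) {
    if (f.has_stats) {
      node_mins->push_back(f.min);
      node_maxs->push_back(f.max);
    } else {
      // No stats supplied: record "absent" so later passes can skip this input.
      node_mins->push_back(llvm::None);
      node_maxs->push_back(llvm::None);
    }
  }
}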

View File

@ -34,14 +34,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags);
// Populate quantization specs (or not) given user-specified ranges for each
// input array.
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs,
std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<double>* node_mins,
std::vector<double>* node_maxs);
Status PopulateQuantizationSpecs(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<llvm::Optional<double>>* node_mins,
std::vector<llvm::Optional<double>>* node_maxs);
// Convert imported MLIR file to TfLite flatbuffer.
// This will also run relevant passes as well.

View File

@ -45,7 +45,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
absl::string_view inference_type,
QuantizationSpecs* quant_specs) {
std::vector<std::string> input_nodes = absl::StrSplit(node_names, ',');
std::vector<double> node_mins;
std::vector<llvm::Optional<double>> node_mins;
if (!min_values.empty()) {
std::vector<std::string> node_mins_str = absl::StrSplit(min_values, ',');
for (int i = 0; i < node_mins_str.size(); i++) {
@ -57,7 +57,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
}
}
std::vector<double> node_maxs;
std::vector<llvm::Optional<double>> node_maxs;
if (!max_values.empty()) {
std::vector<std::string> node_maxs_str = absl::StrSplit(max_values, ',');
for (int i = 0; i < node_maxs_str.size(); i++) {
@ -79,11 +79,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
quant_specs);
}
bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
const std::vector<double>& node_mins,
const std::vector<double>& node_maxs,
tensorflow::DataType inference_type,
QuantizationSpecs* quant_specs) {
bool GetInputNodeQuantSpecs(
const std::vector<std::string>& node_names,
const std::vector<llvm::Optional<double>>& node_mins,
const std::vector<llvm::Optional<double>>& node_maxs,
tensorflow::DataType inference_type, QuantizationSpecs* quant_specs) {
quant_specs->inference_type = inference_type;
// If min/max are not specified, just return;

View File

@ -19,6 +19,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
#include <optional>
#include <string>
#include <vector>
@ -69,7 +70,8 @@ struct QuantizationSpecs {
// arguments. They are only used when `weight_quantization` is set to false,
// and the model is required to have quantization parameters, either from
// quantization aware training or calibration, for the remaining tensors.
std::vector<std::pair<double, double>> input_ranges;
std::vector<std::pair<llvm::Optional<double>, llvm::Optional<double>>>
input_ranges;
// The default ranges can be used when a tensor doesn't have quantization
// parameters and couldn't be quantized. Used only for latency tests.
@ -130,11 +132,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
// Gets the quantization specification for input arrays. The array names are not
// stored in the spec, and will be matched by position. The min/max will be
// ignored if the inference_type isn't a quantized type. Returns true if failed.
bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
const std::vector<double>& node_mins,
const std::vector<double>& node_maxs,
tensorflow::DataType inference_type,
QuantizationSpecs* quant_specs);
bool GetInputNodeQuantSpecs(
const std::vector<std::string>& node_names,
const std::vector<llvm::Optional<double>>& node_mins,
const std::vector<llvm::Optional<double>>& node_maxs,
tensorflow::DataType inference_type, QuantizationSpecs* quant_specs);
} // namespace TFL
} // namespace mlir

View File

@ -109,8 +109,8 @@ class PrepareQuantizePass
// Get the min and max values from the quantization specification for the
// current function and argument index. Uses default values if
// the function is specified in the `quantize_whitelist`.
std::pair<double, double> GetMinMaxValuesForArgument(
llvm::StringRef func_name, int index) {
std::pair<llvm::Optional<double>, llvm::Optional<double>>
GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
if (func_name == quant_specs_.target_func) {
return quant_specs_.input_ranges[index];
} else {
@ -160,10 +160,14 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
}
auto min_max = GetMinMaxValuesForArgument(func_name, i);
// The input min/max or mean/std are not specified, then skip.
if (!min_max.first.hasValue() || !min_max.second.hasValue()) return;
TypeAttr params = quant::GetQuantizedTypeAttr(
builder, input_type, builder.getF64FloatAttr(min_max.first),
builder.getF64FloatAttr(min_max.second), /*quant_dim=*/-1, num_bits,
narrow_range, is_signed);
builder, input_type,
builder.getF64FloatAttr(min_max.first.getValue()),
builder.getF64FloatAttr(min_max.second.getValue()),
/*quant_dim=*/-1, num_bits, narrow_range, is_signed);
builder.setInsertionPoint(block, insertion_point);
auto q_op =
builder.create<quant::QuantizeCastOp>(loc, params.getValue(), arg);

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/test.h"
namespace mlir {
@ -92,7 +93,9 @@ class LstmUtilsTest : public ::testing::Test {
LstmUtilsTest() {}
void SetUp() override {
builder_ = std::unique_ptr<mlir::Builder>(new Builder(&context_));
RegisterDialects();
context_ = std::make_unique<mlir::MLIRContext>();
builder_ = std::unique_ptr<mlir::Builder>(new Builder(context_.get()));
fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false);
fused_lstm_func_cifg_ =
createLstmCompositeFunc(builder_.get(), false, true);
@ -105,10 +108,17 @@ class LstmUtilsTest : public ::testing::Test {
fused_ln_lstm_func_.erase();
builder_.reset();
}
void RegisterDialects() {
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
mlir::registerDialect<TensorFlowLiteDialect>();
}
FuncOp fused_lstm_func_;
FuncOp fused_lstm_func_cifg_;
FuncOp fused_ln_lstm_func_;
mlir::MLIRContext context_;
std::unique_ptr<mlir::MLIRContext> context_;
std::unique_ptr<mlir::Builder> builder_;
};

View File

@ -1318,6 +1318,126 @@ greater than `clip_value_max` are set to `clip_value_max`.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_CollectiveBcastRecvOp : TF_Op<"CollectiveBcastRecv", []> {
let summary = "Receives a tensor value broadcast from another device.";
let description = [{
}];
let arguments = (ins
I64Attr:$group_size,
I64Attr:$group_key,
I64Attr:$instance_key,
TF_ShapeAttr:$shape,
DefaultValuedAttr<StrAttr, "auto">:$communication_hint
);
let results = (outs
TensorOf<[F16, F32, F64, I1, I32, I64]>:$data
);
TF_DerivedResultTypeAttr T = TF_DerivedResultTypeAttr<0>;
}
def TF_CollectiveBcastSendOp : TF_Op<"CollectiveBcastSend", []> {
let summary = "Broadcasts a tensor value to one or more other devices.";
let description = [{
}];
let arguments = (ins
TensorOf<[F16, F32, F64, I1, I32, I64]>:$input,
I64Attr:$group_size,
I64Attr:$group_key,
I64Attr:$instance_key,
TF_ShapeAttr:$shape,
DefaultValuedAttr<StrAttr, "auto">:$communication_hint
);
let results = (outs
TensorOf<[F16, F32, F64, I1, I32, I64]>:$data
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_CollectiveGatherOp : TF_Op<"CollectiveGather", []> {
let summary = [{
Mutually accumulates multiple tensors of identical type and shape.
}];
let description = [{
}];
let arguments = (ins
TensorOf<[F16, F32, F64, I32, I64]>:$input,
I64Attr:$group_size,
I64Attr:$group_key,
I64Attr:$instance_key,
TF_ShapeAttr:$shape,
DefaultValuedAttr<StrAttr, "auto">:$communication_hint
);
let results = (outs
TensorOf<[F16, F32, F64, I32, I64]>:$data
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_CollectivePermuteOp : TF_Op<"CollectivePermute", [NoSideEffect, SameOperandsAndResultShape]> {
let summary = "An Op to permute tensors across replicated TPU instances.";
let description = [{
Each instance supplies its own input.
For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs:
`[D, A, B, C]`.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
I32Tensor:$source_target_pairs
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_CollectiveReduceOp : TF_Op<"CollectiveReduce", [SameOperandsAndResultType]> {
let summary = [{
Mutually reduces multiple tensors of identical type and shape.
}];
let description = [{
}];
let arguments = (ins
TensorOf<[F16, F32, F64, I32, I64]>:$input,
I64Attr:$group_size,
I64Attr:$group_key,
I64Attr:$instance_key,
TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op,
TF_AnyStrAttrOf<["Id", "Div"]>:$final_op,
I64ArrayAttr:$subdiv_offsets,
DefaultValuedAttr<I64ArrayAttr, "{}">:$wait_for,
DefaultValuedAttr<StrAttr, "auto">:$communication_hint
);
let results = (outs
TensorOf<[F16, F32, F64, I32, I64]>:$data
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> {
let summary = "Converts two real numbers to a complex number.";

View File

@ -203,12 +203,12 @@ func @moving_alloc_and_inserting_missing_dealloc(%cond : i1, %arg0 : memref<2xf3
"buffer_assignment_test.unary_lowered"(%arg0, %1) : (memref<2xf32>, memref<2xf32>) -> ()
br ^exit(%1 : memref<2xf32>)
^exit(%arg2: memref<2xf32>):
"bufer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
"buffer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc()
// CHECK: "bufer_assignment_test.copy"
// CHECK: "buffer_assignment_test.copy"
// CHECK-NEXT: dealloc
// CHECK-NEXT: dealloc
// CHECK-NEXT: return
@ -226,11 +226,11 @@ func @moving_invalid_dealloc_op_complex(%cond : i1, %arg0 : memref<2xf32>, %arg1
dealloc %1 : memref<2xf32>
br ^exit(%1 : memref<2xf32>)
^exit(%arg2: memref<2xf32>):
"bufer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
"buffer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
// CHECK-NEXT: %[[ALLOC:.*]] = alloc()
// CHECK: bufer_assignment_test.copy
// CHECK: buffer_assignment_test.copy
// CHECK-NEXT: dealloc
// CHECK-NEXT: return
@ -240,10 +240,10 @@ func @moving_invalid_dealloc_op_complex(%cond : i1, %arg0 : memref<2xf32>, %arg1
func @inserting_missing_dealloc_simple(%arg0 : memref<2xf32>, %arg1: memref<2xf32>){
%0 = alloc() : memref<2xf32>
"buffer_assignment_test.unary_lowered"(%arg0, %0) : (memref<2xf32>, memref<2xf32>) -> ()
"bufer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
"buffer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
// CHECK: bufer_assignment_test.copy
// CHECK: buffer_assignment_test.copy
// CHECK-NEXT: dealloc
// -----
@ -253,8 +253,8 @@ func @moving_invalid_dealloc_op(%arg0 : memref<2xf32>, %arg1: memref<2xf32>){
%0 = alloc() : memref<2xf32>
"buffer_assignment_test.unary_lowered"(%arg0, %0) : (memref<2xf32>, memref<2xf32>) -> ()
dealloc %0 : memref<2xf32>
"bufer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
"buffer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
// CHECK: bufer_assignment_test.copy
// CHECK-NEXT: dealloc
// CHECK: buffer_assignment_test.copy
// CHECK-NEXT: dealloc

View File

@ -29,60 +29,66 @@ limitations under the License.
namespace mlir {
namespace xla {
namespace {
/// This dialect independent unary operation has been defined only for testing
/// buffer assignment.
class BufferAssignmentTestUnaryOp
: public Op<BufferAssignmentTestUnaryOp, OpTrait::OneResult,
OpTrait::OneOperand> {
public:
using Op::Op;
static StringRef getOperationName() { return "buffer_assignment_test.unary"; }
static void build(OpBuilder& b, OperationState& state, Value source) {
state.addOperands(source);
}
};
/// This dialect independent lowered unary operation has been defined only for
/// testing buffer assignment.
class BufferAssignmentTestUnaryLoweredOp
: public Op<BufferAssignmentTestUnaryLoweredOp, OpTrait::ZeroResult,
OpTrait::NOperands<2>::Impl> {
public:
using Op::Op;
static StringRef getOperationName() {
return "buffer_assignment_test.unary_lowered";
}
static void build(OpBuilder& b, OperationState& state, Value source,
Value target) {
state.addOperands(source);
state.addOperands(target);
}
};
/// This dialect independent copy operation has been defined only for testing
/// NonVoidToVoidReturnOpConverter
class BufferAssignmentTestCopyOp
: public Op<BufferAssignmentTestCopyOp, OpTrait::ZeroResult,
OpTrait::NOperands<2>::Impl> {
public:
using Op::Op;
static StringRef getOperationName() { return "buffer_assignment_test.copy"; }
static void build(OpBuilder& b, OperationState& state, Value from, Value to) {
state.addOperands(from);
state.addOperands(to);
}
};
class BufferAssignmentTestDialect : public Dialect {
public:
explicit BufferAssignmentTestDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) {
addOperations<BufferAssignmentTestCopyOp, BufferAssignmentTestUnaryOp,
BufferAssignmentTestUnaryLoweredOp>();
}
static StringRef getDialectNamespace() { return "buffer_assignment_test"; }
};
/// This pass tests two provided operation converters,
/// FunctionAndBlockSignatureConverter and NonVoidToVoidReturnOpConverter, for
/// Buffer Assignment.
struct BufferAssignmentPreparationTestPass
: mlir::PassWrapper<BufferAssignmentPreparationTestPass, FunctionPass> {
/// This dialect independent unary operation has been defined only for testing
/// buffer assignment.
class BufferAssignmentTestUnaryOp
: public Op<BufferAssignmentTestUnaryOp, OpTrait::OneResult,
OpTrait::OneOperand> {
public:
using Op::Op;
static StringRef getOperationName() {
return "buffer_assignment_test.unary";
}
static void build(OpBuilder& b, OperationState& state, Value source) {
state.addOperands(source);
}
};
/// This dialect independent lowered unary operation has been defined only for
/// testing buffer assignment.
class BufferAssignmentTestUnaryLoweredOp
: public Op<BufferAssignmentTestUnaryLoweredOp, OpTrait::ZeroResult,
OpTrait::NOperands<2>::Impl> {
public:
using Op::Op;
static StringRef getOperationName() {
return "buffer_assignment_test.unary_lowered";
}
static void build(OpBuilder& b, OperationState& state, Value source,
Value target) {
state.addOperands(source);
state.addOperands(target);
}
};
/// This dialect independent copy operation has been defined only for testing
/// NonVoidToVoidReturnOpConverter
class BufferAssignmentTestCopyOp
: public Op<BufferAssignmentTestCopyOp, OpTrait::ZeroResult,
OpTrait::NOperands<2>::Impl> {
public:
using Op::Op;
static StringRef getOperationName() {
return "buffer_assignment_test.copy";
}
static void build(OpBuilder& b, OperationState& state, Value from,
Value to) {
state.addOperands(from);
state.addOperands(to);
}
};
/// A simple converter that legalizes a BufferAssignmentTestUnaryOp to a
/// BufferAssignmentTestUnaryLoweredOp and creates buffer allocation for
/// the result of the computation.
@ -151,8 +157,12 @@ struct BufferAssignmentPreparationTestPass
}
};
};
} // namespace
static mlir::DialectRegistration<BufferAssignmentTestDialect>
buffer_assignment_test_ops;
/// This pass tests helper methods such as computeAllocPosition,
/// FunctionAndBlockSignatureConverter, NonVoidToVoidReturnOpConverter
/// conversion patterns. Furthermore, it checks buffer-assignment pass that

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import itertools
import os
import numpy as np
@ -1609,8 +1608,4 @@ class BinaryOpsTest(xla_test.XLATestCase):
if __name__ == "__main__":
# TODO(b/130689556): XLA CPU does not honor inf/nan which causes problems
os.environ[
"XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false " + os.environ.get(
"XLA_FLAGS", "")
googletest.main()

View File

@ -347,17 +347,15 @@ class UnaryOpsTest(xla_test.XLATestCase):
expected=np.array(
[1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype))
# TODO(b/130689556): Turn this on for CPU when we start honoring NaNs.
if self.device != "XLA_CPU":
self._assertOpOutputMatchesExpected(
math_ops.tanh,
np.array([[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20],
[19, -19, 22, -22]],
dtype=dtype),
expected=np.array(
[[0.76159418, 0.96402758, 0.99505478, 0.99932933],
[1.0, -1.0, np.nan, 1.0], [1.0, -1.0, 1.0, -1.0]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.tanh,
np.array([[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20],
[19, -19, 22, -22]],
dtype=dtype),
expected=np.array(
[[0.76159418, 0.96402758, 0.99505478, 0.99932933],
[1.0, -1.0, np.nan, 1.0], [1.0, -1.0, 1.0, -1.0]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.log_softmax,

View File

@ -136,8 +136,11 @@ class TensorListReserveOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
OP_REQUIRES(
ctx, num_elements >= 0,
errors::InvalidArgument("XLA compilation requires a fixed tensor list "
"size. Set the number of elements."));
errors::InvalidArgument(
"XLA compilation requires a fixed tensor list size. Set the number "
"of elements. This could also happen if you're using a TensorArray "
"in a while loop that does not have its maximum_iteration set, you "
"can fix this by setting maximum_iteration to a suitable value."));
// If element shape is compile time constant and it's not "unknown rank"
// shape (-1), create an initialized TensorList. Otherwise create an
@ -197,10 +200,13 @@ class EmptyTensorListOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
int64 max_num_elements;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements));
OP_REQUIRES(
ctx, max_num_elements >= 0,
errors::InvalidArgument("XLA compilation requires a fixed tensor list "
"size. Set the max number of elements."));
OP_REQUIRES(ctx, max_num_elements >= 0,
errors::InvalidArgument(
"XLA compilation requires a fixed tensor list size. Set "
"the max number of elements. This could also happen if "
"you're using a TensorArray in a while loop that does not "
"have its maximum_iteration set, you can fix this by "
"setting maximum_iteration to a suitable value."));
if (dtype_ != DT_VARIANT) {
// We are creating a non-nested TensorList.

View File

@ -63,10 +63,6 @@ class ExecutionInput {
explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) {}
explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers)
: buffers_(std::move(buffers)) {}
ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers,
std::vector<ShapeIndex> owner_held_indices)
: buffers_(std::move(buffers)),
unowned_indices_(std::move(owner_held_indices)) {}
ExecutionInput(ExecutionInput&&) = default;
~ExecutionInput() {

View File

@ -40,7 +40,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
public:
// A NestedComputer computes an element of the output of the given computation
// given a Span of its input elements.
using NestedComputer = std::function<StatusOr<llvm::Value*>(
using NestedComputer = std::function<StatusOr<std::vector<llvm::Value*>>(
const HloComputation&, absl::Span<llvm::Value* const>)>;
GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config,
@ -91,12 +91,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
absl::string_view) override {
// TODO(b/118332391): Supported variadic return values.
auto result = compute_nested_(callee, parameters);
if (!result.ok()) {
return result.status();
}
return std::vector<llvm::Value*>{result.ValueOrDie()};
return compute_nested_(callee, parameters);
}
llvm::Value* EmitThreadId() override;

View File

@ -698,115 +698,6 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) {
return Status::OK();
}
Status IrEmitter::HandleReduce(HloInstruction* instr) {
const HloReduceInstruction* reduce = Cast<HloReduceInstruction>(instr);
const Shape& out_shape = reduce->shape();
bool returns_tuple = !out_shape.IsArray();
int accumulators_count = 1;
if (returns_tuple) {
CHECK(out_shape.IsTuple());
accumulators_count = out_shape.tuple_shapes_size();
}
auto arg = reduce->operand(0);
absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
return EmitTargetElementLoop(
*reduce,
[=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
std::vector<llvm::Value*> accumulator_addrs;
std::vector<llvm::Type*> accumulator_types;
// Initialize accumulators with initial values.
for (int i = 0; i < accumulators_count; i++) {
auto init_value = reduce->init_values()[i];
const Shape& element_shape =
returns_tuple ? out_shape.tuple_shapes(i) : out_shape;
PrimitiveType accumulator_type = element_shape.element_type();
llvm::Type* accumulator_llvm_type =
llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
llvm::AllocaInst* accumulator_addr = Alloca(accumulator_llvm_type);
Store(Load(GetBasePointer(*init_value)), accumulator_addr);
accumulator_addrs.push_back(accumulator_addr);
accumulator_types.push_back(accumulator_llvm_type);
}
// The enclosing loops go over all the target elements. Now we have to
// compute the actual target element. For this, we build a new loop nest
// to iterate over all the reduction dimensions in the argument.
// AddLoopsForShapeOnDimensions will return an Index where induction
// Value*s are placed for each dimension in dimensions, and all the rest
// are nullptrs.
llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
std::vector<llvm::Value*> input_multi_index =
loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
"reduction_dim");
SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
// Build a full index for the input argument, using reduced_dims_index
// as the base. In reduced_dims_index only the reduction dimensions are
// filled in. We fill in the rest of the dimensions with induction
// Value*s taken from 'index' which iterates over the target array.
// See the high-level description in the XLA documentation for details.
llvm_ir::IrArray::Index::const_iterator it = index.begin();
for (auto& i : input_multi_index) {
if (i == nullptr) {
i = *it++;
}
}
CHECK(index.end() == it);
// Apply the reduction function to the loaded value.
llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
b_.getInt64Ty());
std::vector<llvm::Value*> reduction_operands(accumulator_addrs.begin(),
accumulator_addrs.end());
for (int i = 0; i < accumulators_count; i++) {
llvm::Value* input_address =
GetIrArray(*reduce->operand(i), *reduce)
.EmitArrayElementAddress(input_index, &b_);
reduction_operands.push_back(input_address);
}
llvm::Value* ret_argument;
if (!returns_tuple) {
CHECK_EQ(accumulator_addrs.size(), 1);
ret_argument = accumulator_addrs[0];
} else {
const Shape& return_shape = function->root_instruction()->shape();
llvm::Type* return_value_buffer_type =
llvm_ir::ShapeToIrType(return_shape, module_);
ret_argument = Alloca(return_value_buffer_type);
llvm_ir::IrArray tuple_array(ret_argument, return_shape);
EmitTuple(tuple_array, accumulator_addrs, &b_);
}
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*function, reduction_operands, ret_argument));
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
if (!returns_tuple) {
CHECK_EQ(accumulator_addrs.size(), 1);
return Load(accumulator_addrs[0]);
} else {
// Emit a struct for the LoopEmitter dealing with multi-output
// fusion.
llvm::Value* returned_structure = llvm::UndefValue::get(
llvm::StructType::get(b_.getContext(), accumulator_types));
for (int i = 0; i < accumulators_count; i++) {
llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
returned_structure =
b_.CreateInsertValue(returned_structure, accumulator_value, i);
}
return returned_structure;
}
});
}
Status IrEmitter::HandleFusion(HloInstruction* fusion) {
// kFusion for library calls should be handled by
// IrEmitterUnnested::HandleFusion.
@ -866,22 +757,39 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
"to a cudnn CustomCall using CudnnBatchNormRewriter.");
}
StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElement(
const HloComputation& computation,
absl::Span<llvm::Value* const> parameter_elements) {
const Shape& return_shape = computation.root_instruction()->shape();
llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(
computation.root_instruction()->shape().element_type(), module_),
"return_buffer", &b_);
llvm_ir::ShapeToIrType(return_shape, module_), "return_buffer", &b_);
std::vector<llvm::Value*> parameter_buffers;
for (llvm::Value* parameter_element : parameter_elements) {
parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
parameter_element->getType(), "parameter_buffer", &b_));
Store(parameter_element, parameter_buffers.back());
}
std::vector<llvm::Value*> allocas_for_returned_scalars;
if (!return_shape.IsTuple()) {
allocas_for_returned_scalars.push_back(return_buffer);
} else {
allocas_for_returned_scalars =
llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
llvm_ir::IrArray tuple_array(return_buffer, return_shape);
EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
}
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
return_buffer));
return Load(return_buffer);
std::vector<llvm::Value*> returned_scalars;
returned_scalars.reserve(allocas_for_returned_scalars.size());
for (llvm::Value* addr : allocas_for_returned_scalars) {
returned_scalars.push_back(Load(addr));
}
return returned_scalars;
}
std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(

View File

@ -89,7 +89,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
Status HandleRecv(HloInstruction* recv) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleParameter(HloInstruction* parameter) override;
Status HandleReduce(HloInstruction* reduce) override;
Status HandleTuple(HloInstruction* tuple) override;
Status HandleScatter(HloInstruction* scatter) override;
Status HandleSelect(HloInstruction* select) override;
@ -213,7 +212,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
const llvm_ir::IrArray::Index& compare_keys_index,
const llvm_ir::IrArray& keys_array);
StatusOr<llvm::Value*> ComputeNestedElement(
StatusOr<std::vector<llvm::Value*>> ComputeNestedElement(
const HloComputation& computation,
absl::Span<llvm::Value* const> parameter_elements);

View File

@ -312,12 +312,13 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
class HloDotDumper {
public:
HloDotDumper(const HloComputation* computation, absl::string_view label,
const DebugOptions& debug_options, bool show_backend_config,
const DebugOptions& debug_options,
HloRenderOptions hlo_render_options,
const HloExecutionProfile* profile, NodeFilter filter)
: computation_(computation),
label_(label),
debug_options_(debug_options),
show_backend_config_(show_backend_config),
hlo_render_options_(hlo_render_options),
profile_(profile),
filter_(std::move(filter)) {}
@ -384,7 +385,7 @@ class HloDotDumper {
const HloComputation* computation_; // never null
const string label_; // overall name for the graph
const DebugOptions& debug_options_;
const bool show_backend_config_;
const HloRenderOptions hlo_render_options_;
const HloExecutionProfile* profile_; // may be null
const NodeFilter filter_;
@ -565,7 +566,8 @@ bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
if (subcomp->IsFusionComputation()) {
const HloInstruction* fusion = subcomp->FusionInstruction();
if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion) ||
!hlo_render_options_.show_fusion_subcomputations) {
return false;
}
}
@ -1133,7 +1135,8 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
string HloDotDumper::GetInstructionNodeBackendConfig(
const HloInstruction* instr) {
if (!show_backend_config_ || instr->raw_backend_config_string().empty()) {
if (!hlo_render_options_.show_backend_config ||
instr->raw_backend_config_string().empty()) {
return "";
}
@ -1604,14 +1607,14 @@ StatusOr<string> RenderGraph(const HloComputation& computation,
const DebugOptions& debug_options,
RenderedGraphFormat format,
const HloExecutionProfile* hlo_execution_profile,
bool show_backend_config) {
HloRenderOptions hlo_render_options) {
tensorflow::mutex_lock lock(url_renderer_mu);
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
return Unavailable("Can't render as URL; no URL renderer was registered.");
}
string rendered_dot =
HloDotDumper(&computation, label, debug_options, show_backend_config,
HloDotDumper(&computation, label, debug_options, hlo_render_options,
hlo_execution_profile, NodeFilter())
.Dump();
return WrapDotInFormat(rendered_dot, format);
@ -1619,7 +1622,7 @@ StatusOr<string> RenderGraph(const HloComputation& computation,
StatusOr<string> RenderNeighborhoodAround(
const HloInstruction& node, int radius, RenderedGraphFormat format,
bool show_backend_config,
HloRenderOptions hlo_render_options,
const absl::flat_hash_set<const HloInstruction*>& boundary) {
tensorflow::mutex_lock lock(url_renderer_mu);
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
@ -1632,7 +1635,7 @@ StatusOr<string> RenderNeighborhoodAround(
string rendered_dot =
HloDotDumper(node.parent(), label,
node.GetModule()->config().debug_options(),
show_backend_config, /*profile=*/nullptr,
hlo_render_options, /*profile=*/nullptr,
MakeNodeRadiusAroundFilter(&node, radius, boundary))
.Dump();
return WrapDotInFormat(rendered_dot, format);
@ -1641,7 +1644,7 @@ StatusOr<string> RenderNeighborhoodAround(
StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
const HloInstruction& to, int64 max_nodes,
RenderedGraphFormat format,
bool show_backend_config) {
HloRenderOptions hlo_render_options) {
tensorflow::mutex_lock lock(url_renderer_mu);
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
return FailedPrecondition(
@ -1663,7 +1666,7 @@ StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
"NODES***<br/><br/>");
}
string rendered_dot =
HloDotDumper(from.parent(), label, debug_options, show_backend_config,
HloDotDumper(from.parent(), label, debug_options, hlo_render_options,
/*profile=*/nullptr, filter)
.Dump();
return WrapDotInFormat(rendered_dot, format);

View File

@ -50,6 +50,14 @@ enum class RenderedGraphFormat {
kUrl,
};
struct HloRenderOptions {
// Include the backend config string in the rendered graph.
bool show_backend_config = false;
// Include the fusion subcomputations in the rendered graph.
bool show_fusion_subcomputations = true;
};
// Renders an HLO module as a human-readable visual graph.
//
// Note that this only works well for relatively small graphs (no more than a
@ -61,7 +69,7 @@ StatusOr<string> RenderGraph(
const HloComputation& computation, absl::string_view label,
const DebugOptions& debug_options, RenderedGraphFormat format,
const HloExecutionProfile* hlo_execution_profile = nullptr,
bool show_backend_config = false);
HloRenderOptions hlo_render_options = {});
// Like RenderGraph, but renders only nodes "near" the given node in the graph.
//
@ -73,7 +81,7 @@ StatusOr<string> RenderGraph(
// will be omitted even if they are within the radius.
StatusOr<string> RenderNeighborhoodAround(
const HloInstruction& node, int radius, RenderedGraphFormat format,
bool show_backend_config = false,
HloRenderOptions hlo_render_options = {},
const absl::flat_hash_set<const HloInstruction*>& boundary = {});
// Renders nodes on any of the paths from `from` to `to`. If there are more
@ -82,7 +90,7 @@ StatusOr<string> RenderNeighborhoodAround(
StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
const HloInstruction& to, int64 max_nodes,
RenderedGraphFormat format,
bool show_backend_config = false);
HloRenderOptions hlo_render_options = {});
// Registers a function which implements RenderedGraphFormat::kUrl.
//
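A hedged usage sketch of the new HloRenderOptions struct, based only on the signatures shown above; the header path and the RenderedGraphFormat::kDot enumerator (alongside kUrl) are assumed, and module setup plus error handling are elided:

#include <string>
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"

// Render the 3-hop neighborhood around `node` as DOT text, showing backend
// configs but collapsing fusion bodies to a single node.
xla::StatusOr<std::string> RenderAroundNode(const xla::HloInstruction& node) {
  xla::HloRenderOptions options;
  options.show_backend_config = true;
  options.show_fusion_subcomputations = false;
  return xla::RenderNeighborhoodAround(node, /*radius=*/3,
                                       xla::RenderedGraphFormat::kDot, options,
                                       /*boundary=*/{});
}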

View File

@ -137,9 +137,8 @@ class ShapedBuffer {
std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer);
// ShapedBuffer derived class which allocates all internal buffers on
// construction and deallocates the memory when the object is
// destructed.
// ScopedShapedBuffer takes allocated buffers as inputs, and deallocates on
// destruction. This class represents an owning wrapper around `ShapedBuffer`.
//
// TODO(timshen): Remove inheritance between ScopedShapedBuffer and
// ShapedBuffer. There should never be a need to consider a ScopedShapedBuffer

View File

@ -81,11 +81,10 @@ void SpmdLogger::RegisterLogEntry(HloInstruction* hlo,
string report = hlo->ToString();
int64 max_value = -1;
for (HloInstruction* inst : group) {
if (inst->shape().IsTuple()) {
if (!inst->shape().IsArray()) {
continue;
}
max_value =
std::max<int64>(max_value, ShapeUtil::ByteSizeOf(inst->shape(), 4));
max_value = std::max<int64>(max_value, ShapeSizeInBytes(inst->shape()));
absl::StrAppend(&report, " * ", inst->ToString(), "\n");
}
entries_.push_back(std::make_pair(max_value, report));
@ -149,14 +148,14 @@ template <typename F>
const auto add_report = [&](std::vector<HloInstruction*>* insts) {
std::sort(insts->begin(), insts->end(),
[](const HloInstruction* inst0, const HloInstruction* inst1) {
return ShapeUtil::ByteSizeOf(inst0->shape()) >
ShapeUtil::ByteSizeOf(inst1->shape());
return ShapeSizeInBytes(inst0->shape()) >
ShapeSizeInBytes(inst1->shape());
});
for (int64 i = 0;
i < std::min<int64>(report_instruction_count, insts->size()); ++i) {
absl::StrAppend(&report, " ",
tensorflow::strings::HumanReadableNumBytes(
ShapeUtil::ByteSizeOf((*insts)[i]->shape())),
ShapeSizeInBytes((*insts)[i]->shape())),
" : ", (*insts)[i]->ToString(), "\n");
}
};
@ -1180,8 +1179,8 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
operand, scatter_dims_to_operand_dims, slice_size,
num_partitions_) &&
ShapeUtil::ByteSizeOf(updates.base_shape()) <
ShapeUtil::ByteSizeOf(scatter->shape())) {
ShapeSizeInBytes(updates.base_shape()) <
ShapeSizeInBytes(scatter->shape())) {
// Operand is sharded on trivial slice dims (update slice size 1). We can
// adjust the indices on each partition by subtracting the offsets. Then
// we execute a scatter on full updated indices, and out-of-bound accesses
@ -1968,8 +1967,8 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
operand, start_index_map, gather->gather_slice_sizes(),
num_partitions_) &&
ShapeUtil::ByteSizeOf(gather->shape()) <
ShapeUtil::ByteSizeOf(gather->operand(0)->shape())) {
ShapeSizeInBytes(gather->shape()) <
ShapeSizeInBytes(gather->operand(0)->shape())) {
indices = indices.Reshard(HloSharding::Replicate());
// Now the operand is partitioned in trivial slice dimensions, and the
// indices are replicated. We execute a gather on partitioned operand,
@ -2762,8 +2761,7 @@ Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs(
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type())));
if (ShapeUtil::ByteSizeOf(lhs.base_shape()) <
ShapeUtil::ByteSizeOf(rhs.base_shape())) {
if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) {
if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
return DefaultAction(hlo);
}
@ -3005,8 +3003,8 @@ Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
};
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type())));
if (ShapeUtil::ByteSizeOf(lhs.base_shape()) <
ShapeUtil::ByteSizeOf(rhs.base_shape())) {
if (ShapeSizeInBytes(lhs.base_shape()) <
ShapeSizeInBytes(rhs.base_shape())) {
if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
return DefaultAction(hlo);
}
@ -3731,7 +3729,7 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
};
if (output_lhs_non_contracting_partitions == num_partitions_ &&
output_sharding_transposed_to_match_lhs == lhs_sharding &&
ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()) >=
ShapeSizeInBytes(hlo->operand(1)->shape()) >=
options_.threshold_for_windowed_einsum_mib * 1024 * 1024) {
if (rhs_contracting_partitions == num_partitions_) {
return emit_windowed_dot_general(0, 1, true, false);
@ -3745,7 +3743,7 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
}
if (output_rhs_non_contracting_partitions == num_partitions_ &&
output_sharding_transposed_to_match_rhs == rhs_sharding &&
ShapeUtil::ByteSizeOf(hlo->operand(0)->shape()) >=
ShapeSizeInBytes(hlo->operand(0)->shape()) >=
options_.threshold_for_windowed_einsum_mib * 1024 * 1024) {
if (lhs_contracting_partitions == num_partitions_) {
return emit_windowed_dot_general(1, 0, true, false);
@ -3775,8 +3773,8 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
LiteralUtil::Zero(hlo->shape().element_type())));
// Pad both sides with zero, since NaN at one side cannot be masked by zero
// on the other side.
if (ShapeUtil::ByteSizeOf(lhs.base_shape()) <
ShapeUtil::ByteSizeOf(rhs.base_shape())) {
if (ShapeSizeInBytes(lhs.base_shape()) <
ShapeSizeInBytes(rhs.base_shape())) {
lhs =
lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero);
rhs = rhs.PadWithValue(zero);
@ -4607,8 +4605,8 @@ HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b,
xpose_permutation[i] = i + tiled_dims.size() - split_dims_added;
} else {
xpose_permutation[i] = split_dims_added;
xpose_permutation[i + 1] = i + tiled_dims.size() - split_dims_added;
split_dims_added++;
xpose_permutation[i + 1] = i + tiled_dims.size();
i++;
}
}

View File

@ -649,6 +649,43 @@ ENTRY entry {
op::ReduceWindow(masked, op::Constant())));
}
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideHaloBeyondNeighbor) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
param = f32[9,2] parameter(0), sharding={devices=[5,1]0,1,2,3,4}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT reduce-window = f32[5,2]{1,0} reduce-window(param, constant.1),
window={size=4x1 stride=2x1 pad=3_0x0_0}, to_apply=sum,
sharding={devices=[5,1]0,1,2,3,4}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/5));
VLOG(1) << module->ToString();
auto halo0 = AllOf(op::Shape("f32[1,2]"),
op::CollectivePermute(op::Slice(op::Parameter(0))));
auto halo1 =
AllOf(op::Shape("f32[2,2]"), op::CollectivePermute(op::Parameter(0)));
auto pre_mask =
AllOf(op::Shape("f32[4,2]"),
op::Slice(AllOf(op::Shape("f32[5,2]"),
op::Concatenate(halo0, halo1, op::Parameter(0)))));
auto masked =
op::Select(op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply())),
op::Broadcast(op::Constant())),
pre_mask, op::Broadcast(op::Constant()));
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"),
op::ReduceWindow(masked, op::Constant())));
}
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) {
const char* const hlo_string = R"(
HloModule module

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include <algorithm>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@ -23,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@ -104,6 +107,11 @@ Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) {
return sharding.TileShape(shape);
}
int64 ShapeSizeInBytes(const Shape& shape) {
return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) *
ShapeUtil::ElementsIn(shape);
}
Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
const HloSharding& sharding,
int64 partition_id) {
@ -402,33 +410,30 @@ absl::optional<HloInstruction*> ExchangeHalo(
std::vector<HloInstruction*> concat_pieces;
int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
if (max_left_halo_size > input_shard_size) {
VLOG(1) << "ExchangeHalo failed: halo is beyond the left neighbor.";
return absl::nullopt;
}
if (max_left_halo_size > 0) {
for (int64 i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1; i >= 0;
--i) {
std::vector<std::pair<int64, int64>> source_target_pairs;
target.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
if (indices[dim] > 0) {
if (indices[dim] > i) {
std::vector<int64> source_indices(indices.begin(), indices.end());
source_indices[dim] -= 1;
source_indices[dim] -= i + 1;
source_target_pairs.emplace_back(
target.tile_assignment()(source_indices), device);
}
});
int64 halo_size =
std::min(max_left_halo_size - input_shard_size * i, input_shard_size);
auto halo_shape = hlo->shape();
auto source_halo_slice = hlo;
if (max_left_halo_size != hlo->shape().dimensions(dim)) {
halo_shape.set_dimensions(dim, max_left_halo_size);
if (halo_size != hlo->shape().dimensions(dim)) {
halo_shape.set_dimensions(dim, halo_size);
std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
halo_start_indices[dim] =
hlo->shape().dimensions(dim) - max_left_halo_size;
halo_start_indices[dim] = hlo->shape().dimensions(dim) - halo_size;
std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
source_halo_slice = b->AddInstruction(
hlo->CreateSlice(halo_shape, hlo, halo_start_indices,
hlo->shape().dimensions(), halo_slice_strides));
source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice(
halo_shape, hlo, halo_start_indices, hlo->shape().dimensions(),
halo_slice_strides));
}
auto left_halo =
collective_ops_creator.create_cross_partition_collective_permute(
@ -441,29 +446,30 @@ absl::optional<HloInstruction*> ExchangeHalo(
// Right halo.
int64 max_right_halo_size =
right_halo_size_function.MaxInRange(0, shard_count - 1);
if (max_right_halo_size > input_shard_size) {
VLOG(1) << "ExchangeHalo failed: halo is beyond the right neighbor.";
return absl::nullopt;
}
if (max_right_halo_size > 0) {
for (int64 i = 0; i < CeilOfRatio(max_right_halo_size, input_shard_size);
++i) {
std::vector<std::pair<int64, int64>> source_target_pairs;
target.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
if (indices[dim] > 0) {
if (indices[dim] > i) {
std::vector<int64> target_indices(indices.begin(), indices.end());
target_indices[dim] -= 1;
target_indices[dim] -= i + 1;
source_target_pairs.emplace_back(
device, target.tile_assignment()(target_indices));
}
});
int64 halo_size =
std::min(max_right_halo_size - input_shard_size * i, input_shard_size);
auto halo_shape = hlo->shape();
halo_shape.set_dimensions(dim, max_right_halo_size);
std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
auto source_halo_slice = b->AddInstruction(
hlo->CreateSlice(halo_shape, hlo, halo_start_indices,
halo_shape.dimensions(), halo_slice_strides));
HloInstruction* source_halo_slice = hlo;
if (halo_size != halo_shape.dimensions(dim)) {
halo_shape.set_dimensions(dim, halo_size);
std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice(
halo_shape, hlo, halo_start_indices, halo_shape.dimensions(),
halo_slice_strides));
}
auto right_halo =
collective_ops_creator.create_cross_partition_collective_permute(
b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
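The reworked halo exchange above no longer bails out when the halo exceeds one shard; it walks progressively farther neighbors, taking at most one shard per step. A small self-contained sketch of just the left-halo size arithmetic (mirroring halo_size = min(max_left_halo_size - input_shard_size * i, input_shard_size)), checked against the new ReduceWindowTiledOneSideHaloBeyondNeighbor test where shard size 2 and halo 3 yield slices of 1 and 2 rows:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> LeftHaloSliceSizes(int64_t max_left_halo_size,
                                        int64_t input_shard_size) {
  std::vector<int64_t> sizes;
  // CeilOfRatio(max_left_halo_size, input_shard_size) collective permutes.
  const int64_t steps =
      (max_left_halo_size + input_shard_size - 1) / input_shard_size;
  for (int64_t i = steps - 1; i >= 0; --i) {
    sizes.push_back(std::min(max_left_halo_size - input_shard_size * i,
                             input_shard_size));
  }
  return sizes;
}

int main() {
  // Shard size 2, left halo 3 (as in the f32[9,2] test partitioned 5 ways):
  for (int64_t s : LeftHaloSliceSizes(3, 2)) std::cout << s << " ";  // prints 1 2
  std::cout << "\n";
  return 0;
}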

View File

@ -57,6 +57,10 @@ bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding);
// target sharding.
Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding);
// Similar to ShapeUtil::ByteSizeOf(), but does not check it has dense layout
// since this can be before layout assignment.
int64 ShapeSizeInBytes(const Shape& shape);
// Returns the shard shape for a partition without padding due to uneven
// sharding.
Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
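As a quick worked example of the helper declared above: for the f32[9,2] parameter used in the SPMD partitioning tests, ShapeSizeInBytes returns 4 bytes x 18 elements = 72 bytes, and it does so even before a layout has been assigned, which is why the partitioner hunks above switch to it from ShapeUtil::ByteSizeOf.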

View File

@ -112,8 +112,7 @@ constexpr int64 kDefaultMaxNumNodesInAllPaths = 100;
using absl::EqualsIgnoreCase;
// A global control for whether backend configuration display is enabled.
bool show_backend_config = true;
HloRenderOptions hlo_render_options;
HloInstruction* FindInstruction(const HloModule& module, string node_name) {
if (absl::StartsWith(node_name, "%")) {
@ -160,6 +159,8 @@ void DoHelpCommand() {
Renders all nodes in <computation>.
backend_config [on|off]
Controls whether backend operation configuration information is printed.
show_fusion_subcomputations [on|off]
Controls whether fusion subcomputations are shown.
list [name|op_name|op_type] <pattern>
Lists all instructions whose name, metadata op_name, or metadata op_type
contains <pattern> as a substring.
@ -182,15 +183,32 @@ void DoHelpCommand() {
// Turn metadata-printing on or off.
void DoBackendConfigCommand(const std::vector<string>& tokens) {
if (tokens.size() == 2 && tokens[1] == "on") {
show_backend_config = true;
hlo_render_options.show_backend_config = true;
} else if (tokens.size() == 2 && tokens[1] == "off") {
show_backend_config = false;
hlo_render_options.show_backend_config = false;
} else if (tokens.size() != 1) {
std::cerr << "(Illegal backend_config value. Use either 'on' or 'off'.)"
<< std::endl;
}
std::cout << "Backend configuration display "
<< (show_backend_config ? "ON" : "OFF") << std::endl;
<< (hlo_render_options.show_backend_config ? "ON" : "OFF")
<< std::endl;
}
// Turn fusion computation display on or off.
void DoShowFusionSubcomputationsCommand(const std::vector<string>& tokens) {
if (tokens.size() == 2 && tokens[1] == "on") {
hlo_render_options.show_fusion_subcomputations = true;
} else if (tokens.size() == 2 && tokens[1] == "off") {
hlo_render_options.show_fusion_subcomputations = false;
} else if (tokens.size() != 1) {
std::cerr << "(Illegal show_fusion_subcomputations value. Use either "
"'on' or 'off'.)"
<< std::endl;
}
std::cout << "Fusion subcomputations display "
<< (hlo_render_options.show_fusion_subcomputations ? "ON" : "OFF")
<< std::endl;
}
// List all computations in the module.
@ -373,7 +391,7 @@ void DoExtractCommand(const HloModule& module,
auto extracted_module = ExtractModule(instr, height);
std::cout << extracted_module->ToString(
HloPrintOptions::ShortParsable().set_print_backend_config(
show_backend_config))
hlo_render_options.show_backend_config))
<< std::endl;
}
@ -517,7 +535,7 @@ void DoAllPathsCommand(const Options& opts, const HloModule& module,
}
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
return RenderAllPathsFromTo(*from, *to, max_nodes, format,
/*show_backend_config=*/show_backend_config);
hlo_render_options);
});
}
@ -582,15 +600,13 @@ void DoPlotCommand(const Options& opts, const HloModule& module,
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
return RenderGraph(*comp, /*label=*/"",
comp->parent()->config().debug_options(), format,
/*hlo_execution_profile=*/nullptr,
/*show_backend_config=*/show_backend_config);
/*hlo_execution_profile=*/nullptr, hlo_render_options);
});
} else {
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
return RenderNeighborhoodAround(
*instr, graph_width, format,
/*show_backend_config=*/show_backend_config,
/*boundary=*/boundary);
return RenderNeighborhoodAround(*instr, graph_width, format,
hlo_render_options,
/*boundary=*/boundary);
});
}
}
@ -617,6 +633,8 @@ void InteractiveDumpGraphs(const Options& opts, const HloModule& module) {
DoHelpCommand();
} else if (tokens[0] == "backend_config") {
DoBackendConfigCommand(tokens);
} else if (tokens[0] == "show_fusion_subcomputations") {
DoShowFusionSubcomputationsCommand(tokens);
} else if (tokens[0] == "list") {
if (tokens.size() > 1 && tokens[1] == "computations") {
DoListComputationsCommand(module, tokens);

View File

@ -84,6 +84,13 @@ END
name: "Tout"
description: <<END
the types of the output tensors.
END
}
attr {
name: "enable_large_batch_splitting"
description: <<END
input with a large size (i.e., larger than the largest value of
`allowed_batch_sizes`) will be split into multiple batches with batch size.
END
}
summary: "Batches all the inputs tensors to the computation done by the function."

View File

@ -0,0 +1,5 @@
op {
graph_op_name: "DeviceIndex"
visibility: HIDDEN
summary: "Return the index of device the op runs."
}

View File

@ -100,7 +100,7 @@ string AttrSlice::DebugString() const {
return absl::StrJoin(attr_key_vals, ", ");
}
string SummarizeNodeDef(const NodeDef& node_def) {
string SummarizeNodeDef(const NodeDef& node_def, int max_inputs_in_summary) {
string ret = strings::StrCat(errors::FormatNodeNameForError(node_def.name()),
" = ", node_def.op(), "[");
strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
@ -111,6 +111,10 @@ string SummarizeNodeDef(const NodeDef& node_def) {
for (const string& input : node_def.input()) {
if (!first) strings::StrAppend(&ret, ", ");
first = false;
if (max_inputs_in_summary-- == 0) {
strings::StrAppend(&ret, "...");
break;
}
strings::StrAppend(&ret, input);
}
strings::StrAppend(&ret, ")");

View File

@ -58,7 +58,12 @@ extern const char* const kColocationGroupPrefix;
// Produce a human-readable version of a Node or NodeDef that is more concise
// than a text-format proto.
string SummarizeNodeDef(const NodeDef& node_def);
//
// The parameter `max_inputs_in_summary` specifies the maximum number of inputs
// to include in the summary (so that the resulting string is not overly large).
// The value `-1` means that all inputs will be shown.
string SummarizeNodeDef(const NodeDef& node_def,
int max_inputs_in_summary = -1);
string SummarizeAttrs(const NodeDef& node_def);
string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device);
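A minimal caller-side sketch (the surrounding error-construction code is hypothetical; only the summarization call itself comes from this header):
  // Keep error messages bounded for nodes with very long input lists.
  string summary = SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10);
  return errors::InvalidArgument("Problem with node: ", summary);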

View File

@ -1062,6 +1062,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",

View File

@ -603,16 +603,8 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/optimizers:arithmetic_optimizer",
"//tensorflow/core/grappler/optimizers:common_subgraph_elimination",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core/grappler/optimizers:dependency_optimizer",
"//tensorflow/core/grappler/optimizers:function_optimizer",
"//tensorflow/core/grappler/optimizers:loop_optimizer",
"//tensorflow/core/grappler/optimizers:model_pruner",
"//tensorflow/core/grappler/optimizers:remapper",
"//tensorflow/core/grappler/optimizers:shape_optimizer",
"//tensorflow/core/grappler/utils:functions",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core:framework",

View File

@ -21,15 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/ptr_util.h"
@ -60,14 +52,6 @@ constexpr std::array<const char*, 15> kTFDataOptimizations = {
"slack",
"inject_prefetch"};
// Standard grappler optimizations, in the order we want to perform them.
// The order matches the order in the generic meta optimizer.
constexpr std::array<const char*, 9> kGrapplerOptimizations = {
"pruning", "function", "common_subgraph_elimination",
"shape", "arithmetic", "layout_optimizer",
"remapper", "loop", "dependency",
};
// Parses a list of string optimizer configurations into a map from
// optimizer name -> rewriter config for that optimizer.
Status ToConfigMap(
@ -118,11 +102,6 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
ApplyOptimization(optimization, cluster, &optimized_item));
}
for (const auto& optimization : kGrapplerOptimizations) {
TF_RETURN_IF_ERROR(
ApplyOptimization(optimization, cluster, &optimized_item));
}
// Store the final result of all the optimizations in `output`.
output->Swap(&optimized_item.graph);
@ -132,16 +111,17 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
.ReachableDefinitions(*output);
const auto producer = output->versions().producer();
bool optimized_functions = false;
for (const FunctionDef& func : output->library().function()) {
for (const auto& name : flib.ListFunctionNames()) {
auto* func = flib.Find(name);
// Skip non tf.data functions.
if (!func.attr().contains(data::kTFDataFunction)) continue;
VLOG(3) << "Optimize function: function=" << func.signature().name();
if (!func->attr().contains(data::kTFDataFunction)) continue;
VLOG(3) << "Optimize function: function=" << func->signature().name();
optimized_functions = true;
// Make a GrapplerItem from a FunctionDef.
GrapplerFunctionItem func_item;
TF_RETURN_IF_ERROR(
MakeGrapplerFunctionItem(func, flib, producer, &func_item));
MakeGrapplerFunctionItem(*func, flib, producer, &func_item));
GraphDef optimized_func_graph;
TF_RETURN_IF_ERROR(Optimize(cluster, func_item, &optimized_func_graph));
@ -162,7 +142,7 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Replace optimized function with a new FunctionDef.
TF_RETURN_IF_ERROR(
flib.ReplaceFunction(func.signature().name(), optimized_func));
flib.ReplaceFunction(func->signature().name(), optimized_func));
}
if (optimized_functions) {
*output->mutable_library() = flib.ToProto();
@ -221,27 +201,6 @@ Status TFDataMetaOptimizer::Init(
}
}
// Enable a subset of grappler optimization that are enabled by default.
//
// Layout optimizations are excluded because they assume that ops without
// explicit device assignment will be placed on GPU (if available) but that's
// not the case for operations within tf.data functions.
//
// TODO(b/120437209): Re-enable constant folding.
//
// TODO(jsimsa): Make the set of generic Grappler optimization applied to
// tf.data functions configurable.
enabled_optimizers_["pruning"] = MakeUnique<ModelPruner>();
enabled_optimizers_["shape"] = MakeUnique<ShapeOptimizer>();
enabled_optimizers_["remapping"] = MakeUnique<Remapper>(RewriterConfig::ON);
enabled_optimizers_["common_subgraph_elimination"] =
MakeUnique<CommonSubgraphElimination>();
enabled_optimizers_["arithmetic"] = MakeUnique<ArithmeticOptimizer>();
enabled_optimizers_["dependency"] = MakeUnique<DependencyOptimizer>();
enabled_optimizers_["loop"] = MakeUnique<LoopOptimizer>();
enabled_optimizers_["function"] = MakeUnique<FunctionOptimizer>(
RewriterConfig::ON, /*lower_control_flow=*/true);
return Status::OK();
}

View File

@ -17,9 +17,12 @@ limitations under the License.
#include <string>
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
@ -36,6 +39,11 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
constexpr char kConstOp[] = "Const";
constexpr char kCaseOp[] = "Case";
constexpr char kDeviceIndexOp[] = "DeviceIndex";
// TODO(b/157615690): clean up function implementation swap code.
// The overall idea for the function swap is as follows:
// ----------- -----------
// inp_1 ->| P_C | -> out_1 g_inp_1 ->| P_C | -> g_out_1
@ -292,6 +300,74 @@ Status ImplementationSelector::MaybeOptimizeFunctionCall(
return Status::OK();
}
// Finds the index of the device from the device name list.
Status FindDeviceIndex(const utils::MutableNodeView* device_index_node,
const string& device, int* index) {
DeviceNameUtils::ParsedName parsed_name;
if (!DeviceNameUtils::ParseFullName(device, &parsed_name) ||
!parsed_name.has_type) {
return errors::Internal("Could not parse device name:", device);
}
const auto& device_list =
device_index_node->GetAttr("device_names")->list().s();
auto it = absl::c_find(device_list, parsed_name.type);
if (it != device_list.end()) {
*index = it - device_list.begin();
} else {
// Sets *index to device_list.size() because the default_fn is guaranteed to
// be the final item in the case op branching list.
*index = device_list.size();
}
return Status::OK();
}
// Rewrites the device_index op to a const op with value of the index.
void RewriteDeviceIndexOp(utils::MutableNodeView* device_index_node,
int index) {
// Modifies the DeviceIndex node to be an Const op with correct device index.
auto node = device_index_node->node();
node->set_op(kConstOp);
node->clear_attr();
(*node->mutable_attr())["dtype"].set_type(DT_INT32);
auto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
tensor->set_dtype(DT_INT32);
tensor->add_int_val(index);
VLOG(2) << "Node after rewriting:" << node->DebugString();
}
Status ImplementationSelector::SelectDeviceIndex(GraphDef* graph) const {
Status status;
VLOG(2) << "graph before rewriting device index:" << graph->DebugString();
utils::MutableGraphView graph_view(graph, &status);
TF_RETURN_IF_ERROR(status);
const int num_nodes = graph_view.NumNodes();
for (int k = 0; k < num_nodes; ++k) {
auto* node_view = graph_view.GetNode(k);
if (node_view->GetOp() != kDeviceIndexOp) {
continue;
}
VLOG(2) << "Found a node to rewrite the device index";
// Find the case node with device index node as input, rewrite the
// DeviceIndex node to have the value of the index of device type of the
// case node.
for (const auto& fanouts : node_view->GetRegularFanouts()) {
for (const auto& fanout : fanouts) {
if (fanout.node_view()->GetOp() != kCaseOp) continue;
int index;
// If any error occurs during device parsing, we simply skip and do not
// modify the DeviceIndex node.
Status status =
FindDeviceIndex(node_view, fanout.node_view()->GetDevice(), &index);
if (status.ok()) {
RewriteDeviceIndexOp(node_view, index);
}
}
}
}
return Status::OK();
}
Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
if (!graph->has_library()) {
VLOG(2) << "Skipping graph since it does not have function def";
@ -307,8 +383,9 @@ Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
TF_RETURN_IF_ERROR(status);
const int num_nodes = graph_view.NumNodes();
for (int k = 0; k < num_nodes; ++k)
for (int k = 0; k < num_nodes; ++k) {
TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph_view.GetNode(k)));
}
return Status::OK();
}
@ -326,7 +403,13 @@ Status ImplementationSelector::Optimize(Cluster* cluster,
<< "libraries: " << status;
return errors::Aborted("Skipped Optimization");
}
*optimized_graph = item.graph;
status = SelectDeviceIndex(optimized_graph);
if (!status.ok()) {
*optimized_graph = item.graph;
VLOG(2) << "Could not rewrite device index due to error:" << status;
}
return SelectImplementation(optimized_graph);
}

View File

@ -34,6 +34,28 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
// Motivation: To achieve the same high-level functionality, the underlying
// implementations are sometimes different for the various devices where the
// function runs. In order to achieve the correct result and best performance,
// the proper implementation needs to be picked dynamically.
//
// Currently there are two approaches to do this.
// (1) Utilize the Case op and dynamically change the branch index.
// (2) Swap the function implementation; this approach will be deprecated.
//
// Idea for approach 1.
// This transformation rewrites the DeviceIndex op with a Const op with value
// of the index of the device the associated Case op runs on.
// Example:
// def plus_one_gpu(x): return x + 1.0
// def plus_one_reference_implementation(x): return x + 1.0
// input = tf.constant(2.0, dtype=tf.float32)
// cpu_fn = lambda:plus_one_reference_implementation(input)
// gpu_fn = lambda:plus_one_gpu(input)
// control_flow_ops.execute_fn_for_device(
// {"CPU": cpu_fn, "GPU":gpu_fn)}, default_fn=cpu_fn)
//
// Idea for approach 2.
// This transformation replaces function calls by the appropriate function
// definition based on properties of the runtime system. For instance,
// we may choose one implementation over another if we have a GPU with
@ -58,7 +80,8 @@ namespace grappler {
// z = plus_one_gpu(input)
// print(sess.run(z))
//
// At runtime, we will trim either `plus_one_gpu` or
// At runtime, we will select either `plus_one_gpu` or
// `plus_one_reference_implementation` based on the availability of the GPU.
//
// Available annotations:
@ -106,6 +129,68 @@ class ImplementationSelector : public CustomGraphOptimizer {
// gradients.
Status SelectImplementation(GraphDef* graph) const;
// Rewrites the DeviceIndex op with a Const op with value of the index of the
// device the associated Case op runs on.
// This function first looks up all the DeviceIndex ops.
// Then for each of these ops, it finds the device of the
// associated Case op that takes the DeviceIndex op as the input, and
// calculates the index of the device in the device list of the DeviceIndex op.
// Lastly, it rewrites the DeviceIndex op with a Const op and sets the value
// to be the index.
//
// Example input nodes:
// node {
// name: "x"
// op: "DeviceIndex"
// device: "/device:CPU:0"
// attr {
// key: "device_names"
// value {
// list {
// s: "CPU"
// s: "TPU_REPLICATED_CORE"
// s: "GPU"
// }
// }
// }
// }
// node {
// name: "case"
// op: "Case"
// input: "x"
// device: "/device:GPU:0"
// ...
// }
// Example output nodes:
//
// name: "x"
// op: "Const"
// device: "/device:CPU:0"
// attr {
// key: "dtype"
// value {
// type: DT_INT32
// }
// }
// attr {
// key: "value"
// value {
// tensor {
// dtype: DT_INT32
// int_val: 2
// }
// }
// }
// node {
// name: "case"
// op: "Case"
// input: "x"
// device: "/device:GPU:0"
// ...
// }
Status SelectDeviceIndex(GraphDef* graph) const;
std::unique_ptr<FunctionLibraryApiInfo> lib_info_;
TF_DISALLOW_COPY_AND_ASSIGN(ImplementationSelector);

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@ -58,6 +59,167 @@ TEST_F(ImplementationSelectorTest, NoUpdate) {
EXPECT_EQ(item.graph.node_size(), output.node_size());
}
TEST_F(ImplementationSelectorTest, SelectDeviceIndex) {
using test::function::NDef;
ImplementationSelector optimizer;
GraphDef output;
GrapplerItem item;
AttrValue device_names;
device_names.mutable_list()->add_s("CPU");
device_names.mutable_list()->add_s("GPU");
item.graph = test::function::GDef(
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
CpuDevice),
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice)});
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
for (const NodeDef& node : output.node()) {
if (node.name() == "x") {
// Rewrite DeviceIndex op to a Const op with value of GPU index 1.
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
}
}
}
TEST_F(ImplementationSelectorTest, SelectDeviceIndexMultiOps) {
using test::function::NDef;
ImplementationSelector optimizer;
GraphDef output;
GrapplerItem item;
AttrValue device_names;
device_names.mutable_list()->add_s("CPU");
device_names.mutable_list()->add_s("TPU_REPLICATED_CORE");
device_names.mutable_list()->add_s("GPU");
item.graph = test::function::GDef(
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
CpuDevice),
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
NDef("y", "DeviceIndex", {}, {{"device_names", device_names}},
GpuDevice),
NDef("case_y", "Case", {"y"}, {{"T", DT_FLOAT}}, TpuDevice)});
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
for (const NodeDef& node : output.node()) {
if (node.name() == "x") {
// Rewrite DeviceIndex op to a Const op with value of GPU index 2.
EXPECT_EQ("Const", node.op());
EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
}
if (node.name() == "y") {
// Rewrite DeviceIndex op to a Const op with value of TPU_REPLICATED_CORE index 1.
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
}
}
}
TEST_F(ImplementationSelectorTest, SelectDeviceIndexNotFound) {
using test::function::NDef;
ImplementationSelector optimizer;
GraphDef output;
GrapplerItem item;
AttrValue device_names;
device_names.mutable_list()->add_s("CPU");
device_names.mutable_list()->add_s("GPU");
item.graph = test::function::GDef(
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
CpuDevice),
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, TpuDevice)});
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
for (const NodeDef& node : output.node()) {
if (node.name() == "x") {
// Device not found: rewrite DeviceIndex op to a Const op whose value is the
// length of the device_names list.
EXPECT_EQ("Const", node.op());
EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
}
}
}
TEST_F(ImplementationSelectorTest, SelectDeviceIndexError) {
using test::function::NDef;
ImplementationSelector optimizer;
GraphDef output;
GrapplerItem item;
AttrValue device_names;
device_names.mutable_list()->add_s("CPU");
device_names.mutable_list()->add_s("GPU");
item.graph = test::function::GDef(
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
CpuDevice),
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, "")});
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
for (const NodeDef& node : output.node()) {
if (node.name() == "x") {
// Device parsing fails, so do not rewrite the DeviceIndex node.
EXPECT_EQ("DeviceIndex", node.op());
}
}
}
TEST_F(ImplementationSelectorTest, TwoTypesOfSwapImplementation) {
using test::function::NDef;
ImplementationSelector optimizer;
GraphDef output;
GrapplerItem item;
// DeviceIndex op based implementation selector.
AttrValue device_names;
device_names.mutable_list()->add_s("CPU");
device_names.mutable_list()->add_s("TPU_REPLICATED_CORE");
device_names.mutable_list()->add_s("GPU");
// Function swap based implementation selector.
auto cpu_def = test::function::XTimesTwo();
auto* func_attr = cpu_def.mutable_attr();
(*func_attr)["api_implements"].set_s("times_two");
(*func_attr)["api_preferred_device"].set_s("CPU");
auto gpu_def = test::function::XAddX();
auto* func2_attr = gpu_def.mutable_attr();
(*func2_attr)["api_implements"].set_s("times_two");
(*func2_attr)["api_preferred_device"].set_s("GPU");
item.graph = test::function::GDef(
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
CpuDevice),
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
NDef("y", "DeviceIndex", {}, {{"device_names", device_names}},
GpuDevice),
NDef("case_y", "Case", {"y"}, {{"T", DT_FLOAT}}, TpuDevice),
NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice),
NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)},
// FunctionLib
{cpu_def, gpu_def});
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
for (const NodeDef& node : output.node()) {
if (node.name() == "x") {
// Rewrite DeviceIndex op to a Const op with value of GPU index 2.
EXPECT_EQ("Const", node.op());
EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
}
if (node.name() == "y") {
// Rewrite DeviceIndex op to a Const op with value of TPU_REPLICATED_CORE index 1.
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
}
if (node.name() == "y1") {
// Make sure the implementation has been swapped to use the GPU version.
EXPECT_EQ("XAddX", node.op());
} else if (node.name() == "y2") {
// Make sure the implementation is not changed.
EXPECT_EQ("XTimesTwo", node.op());
}
}
}
TEST_F(ImplementationSelectorTest, SwapImplementation) {
using test::function::NDef;
auto cpu_def = test::function::XTimesTwo();

View File

@ -272,6 +272,7 @@ class BatchResource : public ResourceBase {
int32 batch_timeout_micros, int32 max_enqueued_batches,
const std::vector<int32>& allowed_batch_sizes,
FunctionLibraryRuntime::Handle fhandle,
bool enable_large_batch_splitting,
std::unique_ptr<BatchResource>* resource) {
std::unique_ptr<BatchResource> new_resource(new BatchResource);
@ -286,6 +287,10 @@ class BatchResource : public ResourceBase {
new_resource->batcher_queue_options_.batch_timeout_micros =
batch_timeout_micros;
// Support for splitting large batches is still in progress.
new_resource->batcher_queue_options_.enable_large_batch_splitting =
enable_large_batch_splitting;
new_resource->allowed_batch_sizes_ = allowed_batch_sizes;
new_resource->fhandle_ = fhandle;
@ -786,6 +791,13 @@ class BatchFunctionKernel : public AsyncOpKernel {
OP_REQUIRES_OK(c, c->GetAttr("f", &func));
OP_REQUIRES_OK(
c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
if (c->HasAttr("enable_large_batch_splitting")) {
OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
&enable_large_batch_splitting_));
} else {
enable_large_batch_splitting_ = false;
}
}
bool IsExpensive() override { return false; }
@ -794,10 +806,10 @@ class BatchFunctionKernel : public AsyncOpKernel {
BatchResource* br;
std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
std::unique_ptr<BatchResource> new_resource;
TF_RETURN_IF_ERROR(
BatchResource::Create(num_batch_threads_, max_batch_size_,
batch_timeout_micros_, max_enqueued_batches_,
allowed_batch_sizes_, fhandle_, &new_resource));
TF_RETURN_IF_ERROR(BatchResource::Create(
num_batch_threads_, max_batch_size_, batch_timeout_micros_,
max_enqueued_batches_, allowed_batch_sizes_, fhandle_,
enable_large_batch_splitting_, &new_resource));
*r = new_resource.release();
return Status::OK();
};
@ -844,6 +856,7 @@ class BatchFunctionKernel : public AsyncOpKernel {
int32 max_enqueued_batches_;
std::vector<int32> allowed_batch_sizes_;
FunctionLibraryRuntime::Handle fhandle_;
bool enable_large_batch_splitting_;
};
REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
@ -876,7 +889,7 @@ class BatchKernel : public AsyncOpKernel {
std::unique_ptr<BatchResource> new_resource;
TF_RETURN_IF_ERROR(BatchResource::Create(
num_batch_threads_, max_batch_size_, batch_timeout_micros_,
max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle, false,
&new_resource));
*r = new_resource.release();
return Status::OK();

View File

@ -160,6 +160,10 @@ class SharedBatchScheduler
// See the class documentation above for guidelines on how to tune this
// parameter.
size_t max_enqueued_batches = 10;
// If true, the queue implementation will split one input batch task into
// subtasks and fit them into different batches.
bool enable_large_batch_splitting = false;
};
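A minimal sketch of opting into the new behavior (the task type name is hypothetical; `QueueOptions` and its fields come from this header):
  SharedBatchScheduler<MyBatchTask>::QueueOptions options;
  options.max_enqueued_batches = 10;            // default shown above
  options.enable_large_batch_splitting = true;  // allow oversized tasks to be split
  // `options` is then passed to AddQueue() together with the batch-processing
  // callback, exactly as for a queue that does not use splitting.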
Status AddQueue(const QueueOptions& options,
std::function<void(std::unique_ptr<Batch<TaskType>>)>

View File

@ -3,6 +3,7 @@
load(
"//tensorflow:tensorflow.bzl",
"if_not_mobile",
"tf_cc_test",
"tf_kernel_library",
)
@ -150,6 +151,7 @@ cc_library(
":dataset_utils",
":single_threaded_executor",
":stats_utils",
"@com_google_absl//absl/time",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@ -158,8 +160,10 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:variable_ops",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/time",
],
] + if_not_mobile([
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:meta_optimizer",
]),
)
cc_library(

View File

@ -35,6 +35,11 @@ limitations under the License.
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif // !IS_MOBILE_PLATFORM
namespace tensorflow {
namespace data {
namespace {
@ -612,6 +617,28 @@ Status CapturedFunction::Instantiate(
for (size_t i = 0; i < fdef->signature().output_arg_size(); ++i) {
inst_opts.output_devices.push_back(inst_opts.target);
}
#if !defined(IS_MOBILE_PLATFORM)
grappler::GrapplerItem::OptimizationOptions optimization_options;
optimization_options.allow_pruning_stateful_and_dataset_ops = false;
ConfigProto config_proto = inst_opts.config_proto;
// Layout optimizations are excluded because they assume that ops without
// explicit device assignment will be placed on GPU (if available) but
// that's not the case for operations within tf.data functions.
config_proto.mutable_graph_options()
->mutable_rewrite_options()
->set_layout_optimizer(RewriterConfig::OFF);
// TODO(b/120437209): Re-enable constant folding.
config_proto.mutable_graph_options()
->mutable_rewrite_options()
->set_constant_folding(RewriterConfig::OFF);
inst_opts.optimize_graph_fn =
std::bind(tensorflow::grappler::OptimizeGraph, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3,
std::placeholders::_4, std::placeholders::_5,
std::move(config_proto), fdef->signature().name(),
std::move(optimization_options), std::placeholders::_6);
#endif // !IS_MOBILE_PLATFORM
}
FunctionLibraryRuntime::Handle f_handle;

View File

@ -100,9 +100,13 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
AttrValue slack_period_attr;
b->BuildAttrValue(slack_period_, &slack_period_attr);
TF_RETURN_IF_ERROR(b->AddDataset(
this, {input_graph_node, buffer_size},
{std::make_pair(kSlackPeriod, slack_period_attr)}, output));
AttrValue legacy_autotune_attr;
b->BuildAttrValue(legacy_autotune_, &legacy_autotune_attr);
TF_RETURN_IF_ERROR(
b->AddDataset(this, {input_graph_node, buffer_size},
{std::make_pair(kSlackPeriod, slack_period_attr),
std::make_pair(kLegacyAutotune, legacy_autotune_attr)},
output));
return Status::OK();
}

View File

@ -39,7 +39,9 @@ namespace data {
constexpr char kCurrentFileIndex[] = "current_file_index";
constexpr char kOffset[] = "offset";
constexpr char kGcsFsPrefix[] = "gs://";
constexpr char kS3FsPrefix[] = "s3://";
constexpr int64 kCloudTpuBlockSize = 127LL << 20; // 127MB.
constexpr int64 kS3BlockSize = kCloudTpuBlockSize;
bool is_cloud_tpu_gcs_fs() {
#if defined(PLATFORM_CLOUD_TPU) && defined(TPU_GCS_FS)
@ -237,12 +239,14 @@ void TFRecordDatasetOp::MakeDataset(OpKernelContext* ctx,
errors::InvalidArgument("`filenames` must be a scalar or a vector."));
bool is_gcs_fs = true;
bool is_s3_fs = true;
std::vector<string> filenames;
filenames.reserve(filenames_tensor->NumElements());
for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
VLOG(2) << "Reading file: " << filenames_tensor->flat<tstring>()(i);
filenames.push_back(filenames_tensor->flat<tstring>()(i));
is_gcs_fs &= absl::StartsWith(filenames[i], kGcsFsPrefix);
is_s3_fs &= absl::StartsWith(filenames[i], kS3FsPrefix);
}
tstring compression_type;
@ -264,6 +268,13 @@ void TFRecordDatasetOp::MakeDataset(OpKernelContext* ctx,
buffer_size = kCloudTpuBlockSize;
}
if (is_s3_fs && buffer_size < kS3BlockSize) {
VLOG(2) << "User buffer size is too small for reading "
<< "TFRecords stored in S3. Overriding " << buffer_size
<< " to the minimum recommended buffer_size = " << kS3BlockSize;
buffer_size = kS3BlockSize;
}
*output =
new Dataset(ctx, std::move(filenames), compression_type, buffer_size);
}

View File

@ -924,5 +924,37 @@ class FakeParamOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_GPU), FakeParamOp);
// DeviceIndexOp returns the index, within its `device_names` attribute, of the
// device type this op runs on (or the list size if the type is not found).
class DeviceIndexOp : public OpKernel {
public:
explicit DeviceIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("device_names", &device_names_));
}
void Compute(OpKernelContext* ctx) override {
Tensor* device_name_t;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, TensorShape({}), &device_name_t));
DeviceNameUtils::ParsedName parsed_name;
int index = device_names_.size();
if (DeviceNameUtils::ParseFullName(ctx->device()->name(), &parsed_name) &&
parsed_name.has_type) {
auto it = absl::c_find(device_names_, parsed_name.type);
if (it != device_names_.end()) {
index = it - device_names_.begin();
}
}
device_name_t->scalar<int32>()() = index;
}
private:
PersistentTensor value_handle_;
std::vector<string> device_names_;
};
REGISTER_KERNEL_BUILDER(Name("DeviceIndex").Device(DEVICE_CPU), DeviceIndexOp);
REGISTER_KERNEL_BUILDER(
Name("DeviceIndex").Device(DEVICE_GPU).HostMemory("index"), DeviceIndexOp);
} // namespace
} // namespace tensorflow

View File

@ -244,7 +244,7 @@ class MklAddNOp : public OpKernel {
// Create Sum op, and submit net for execution.
std::vector<primitive> net;
auto sum_stream = CPU_STREAM(cpu_engine);
stream* fwd_cpu_stream = CreateStream(ctx, cpu_engine);
#ifdef ENABLE_MKLDNN_V1
mkldnn::sum sum_op(sum_pd);
std::unordered_map<int, memory> net_args = {
@ -253,10 +253,10 @@ class MklAddNOp : public OpKernel {
for (int i = 0; i < num_inputs; ++i) {
net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, inputs[i]});
}
sum_op.execute(sum_stream, net_args);
sum_op.execute(*fwd_cpu_stream, net_args);
#else
net.push_back(sum(sum_pd, inputs, dst.GetOpMem()));
sum_stream.submit(net).wait();
fwd_cpu_stream->submit(net).wait();
#endif
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +

View File

@ -136,9 +136,10 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
const T* src_data = input_tensor.flat<T>().data();
T* dst_data = output_tensor->flat<T>().data();
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine()));
// Execute pooling op.
pooling_fwd->Execute(src_data, dst_data);
pooling_fwd->Execute(src_data, dst_data, nullptr, fwd_cpu_stream);
// Pass min, max from input to output.
if (int8_forward_inference) {
@ -240,8 +241,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
: memory::desc(diff_dst_dims, MklDnnType<T>(),
this->data_format_mkldnn_);
// Pass prop_kind::forward_training to create a forward primitive
// that is used in the backward pass.
// Pass prop_kind::forward_training to create a forward primitive
// that is used in the backward pass.
#ifdef ENABLE_MKLDNN_V1
// TODO(DNNL): Find out what should we use src_md.data.format.
MklPoolingParams bwdParams(
@ -260,6 +261,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
MklPoolingBwdPrimitive<T>* pooling_bwd =
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
std::shared_ptr<stream> bwd_cpu_stream;
bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine()));
Tensor* output_tensor = nullptr;
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
orig_input_dims_mkl_order,
@ -286,7 +289,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
T* diff_src_data = output_tensor->flat<T>().data();
// Execute pooling op.
pooling_bwd->Execute(diff_dst_data, diff_src_data);
pooling_bwd->Execute(diff_dst_data, diff_src_data, nullptr,
bwd_cpu_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +

View File

@ -265,8 +265,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
public:
explicit MklConcatFwdPrimitive(const MklConcatFwdParams& concat_fwd_dims,
const std::vector<memory::desc>& srcs_md)
: cpu_engine_(ENGINE_CPU, 0) {
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
// Create concat primitive
Setup(concat_fwd_dims, srcs_md);
}
@ -278,7 +277,8 @@ class MklConcatFwdPrimitive : public MklPrimitive {
// dst_data: output data buffer of dst
void Execute(const std::vector<mkldnn::memory>& in_data,
const mkldnn::memory& dst_data,
const MklConcatFwdParams& concat_fwd_dims) {
const MklConcatFwdParams& concat_fwd_dims,
std::shared_ptr<stream> fwd_stream) {
DCHECK_EQ(in_data.size(), context_.data_mem.size());
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
context_.data_mem_shdptr[i]->set_data_handle(
@ -292,10 +292,10 @@ class MklConcatFwdPrimitive : public MklPrimitive {
}
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
execute_primitives(context_.fwd_primitives, fwd_stream,
context_.fwd_primitives_args);
#else
context_.fwd_stream->submit(context_.fwd_primitives);
fwd_stream->submit(context_.fwd_primitives);
#endif // ENABLE_MKLDNN_V1
// After exec, set data handle back
@ -335,7 +335,6 @@ class MklConcatFwdPrimitive : public MklPrimitive {
std::shared_ptr<mkldnn::concat::primitive_desc> fwd_pd;
std::shared_ptr<mkldnn::primitive> concat_fwd;
std::shared_ptr<mkldnn::stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;
#ifdef ENABLE_MKLDNN_V1
@ -343,10 +342,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
#endif // ENABLE_MKLDNN_V1
ConcatFwdContext()
: dst_mem(nullptr),
fwd_pd(nullptr),
concat_fwd(nullptr),
fwd_stream(nullptr) {}
: dst_mem(nullptr), fwd_pd(nullptr), concat_fwd(nullptr) {}
};
// Creates the src and dst memory descriptor for mkl concat
@ -417,7 +413,6 @@ class MklConcatFwdPrimitive : public MklPrimitive {
}
struct ConcatFwdContext context_;
engine cpu_engine_;
};
// Class to create/cache the mkl concat primitives based on the
@ -758,7 +753,7 @@ class MklConcatOp : public OpKernel {
for (int k = 0; k < input_tensors.size(); k++) {
if (input_tensors[k].NumElements() > 0) {
srcs[k].CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(srcs_pd[k], cpu_engine));
MEMORY_PD_WITHOUT_DATA(srcs_pd[k], cpu_engine), context);
inputs.push_back(srcs[k].GetOpMem());
}
}
@ -796,7 +791,8 @@ class MklConcatOp : public OpKernel {
if (dnn_shape_dst.IsMklTensor())
dst_md = dnn_shape_dst.GetMklLayout();
dst.SetUsrMem(dst_md, dst_tensor);
stream concat_stream = CPU_STREAM(cpu_engine);
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, cpu_engine));
#ifdef ENABLE_MKLDNN_V1
auto concat_op = concat(concat_pd);
std::unordered_map<int, memory> net_args = {
@ -805,12 +801,12 @@ class MklConcatOp : public OpKernel {
for (int i = 0; i < inputs.size(); ++i) {
net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, inputs[i]});
}
concat_op.execute(concat_stream, net_args);
concat_op.execute(*fwd_cpu_stream, net_args);
#else
auto concat_op = concat(concat_pd, inputs, dst.GetOpMem());
std::vector<primitive> net;
net.push_back(concat_op);
concat_stream.submit(net).wait();
fwd_cpu_stream->submit(net).wait();
#endif // ENABLE_MKLDNN_V1
} else {
MklConcatFwdPrimitive<T>* concat_fwd = nullptr;
@ -835,9 +831,11 @@ class MklConcatOp : public OpKernel {
dst_md = dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout()
: dst_md;
dst.SetUsrMem(dst_md, dst_tensor);
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, concat_fwd->GetEngine()));
// Execute concat
concat_fwd->Execute(srcs_mem, dst.GetOpMem(), concat_fwd_dims);
concat_fwd->Execute(srcs_mem, dst.GetOpMem(), concat_fwd_dims,
fwd_cpu_stream);
}
// For quantized concat, min and max outputs are also computed.

View File

@ -97,9 +97,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
public:
explicit MklConvBwdFilterPrimitive(
const MklConvBwdFilterParams& convBwdFilterDims)
: cpu_engine_(ENGINE_CPU, 0) {
context_.bwd_filter_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
// Create convolution backward filter primitive.
if (context_.conv_bwd_filter == nullptr) {
Setup(convBwdFilterDims);
@ -114,7 +112,8 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
// diff_bias_data: output data buffer for diff_bias
// diff_dst_data: input data buffer for diff_dst
void Execute(const T* src_data, const T* diff_filter_data,
const T* diff_bias_data, const T* diff_dst_data) {
const T* diff_bias_data, const T* diff_dst_data,
std::shared_ptr<stream> bwd_filter_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.diff_filter_mem->set_data_handle(
@ -127,11 +126,10 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
static_cast<void*>(const_cast<T*>(diff_dst_data)));
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_filter_primitives,
context_.bwd_filter_stream,
execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream,
context_.bwd_filter_primitives_args);
#else
context_.bwd_filter_stream->submit(context_.bwd_filter_primitives);
bwd_filter_stream->submit(context_.bwd_filter_primitives);
#endif
context_.src_mem->set_data_handle(DummyData);
@ -147,8 +145,10 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
// diff_filter_data: output data buffer of diff_filter
// diff_dst_data: input data buffer of diff_dst
void Execute(const T* src_data, const T* diff_filter_data,
const T* diff_dst_data) {
Execute(src_data, diff_filter_data, nullptr, diff_dst_data);
const T* diff_dst_data,
std::shared_ptr<stream> bwd_filter_stream) {
Execute(src_data, diff_filter_data, nullptr, diff_dst_data,
bwd_filter_stream);
}
#ifndef ENABLE_MKLDNN_V1
@ -223,8 +223,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
src_md(nullptr),
diff_filter_md(nullptr),
diff_bias_md(nullptr),
diff_dst_md(nullptr),
bwd_filter_stream(nullptr) {
diff_dst_md(nullptr) {
}
};
@ -345,7 +344,6 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
}
struct ConvBwdFilterContext context_;
engine cpu_engine_;
};
template <typename T>
@ -600,8 +598,10 @@ class MklConvCustomBackpropFilterOp
auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
if (IS_SRC_REORDER_NEEDED(fwd_src_md, bwd_filter_pd, conv_bwd_filter)) {
src.SetUsrMem(fwd_src_md, &src_tensor);
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_filter_pd->PRIMITIVE_DESC_SRC, cpu_engine_));
src.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(bwd_filter_pd->PRIMITIVE_DESC_SRC,
cpu_engine_),
context);
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
} else {
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
@ -612,8 +612,10 @@ class MklConvCustomBackpropFilterOp
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_filter_pd,
conv_bwd_filter)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
diff_dst.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST,
cpu_engine_),
context);
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
} else {
diff_dst_data =
@ -646,18 +648,21 @@ class MklConvCustomBackpropFilterOp
}
// Execute convolution backward filter.
std::shared_ptr<stream> bwd_cpu_stream;
bwd_cpu_stream.reset(CreateStream(context, conv_bwd_filter->GetEngine()));
if (bias_enabled) {
T* diff_bias_data =
static_cast<T*>(const_cast<T*>(diff_bias_tensor->flat<T>().data()));
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_bias_data,
diff_dst_data);
diff_dst_data, bwd_cpu_stream);
} else {
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data,
bwd_cpu_stream);
}
// Reorder diff_filter back to Tensorflow layout if necessary.
if (diff_filter_reorder_required) {
diff_filter.InsertReorderToUserMem();
diff_filter.InsertReorderToUserMem(context);
}
// Delete primitive since it is not cached.

View File

@ -99,9 +99,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
public:
explicit MklConvBwdInputPrimitive(
const MklConvBwdInputParams& convBwdInputDims)
: cpu_engine_(ENGINE_CPU, 0) {
context_.bwd_input_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
// Create conv bwd input primitive
if (context_.conv_bwd_input == nullptr) {
Setup(convBwdInputDims);
@ -116,7 +114,8 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
// diff_dst_data: input data buffer for dst
// Bias does not matter here
void Execute(const T* diff_src_data, const T* filter_data,
const T* diff_dst_data) {
const T* diff_dst_data,
std::shared_ptr<stream> bwd_input_stream) {
context_.diff_src_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(diff_src_data)));
context_.filter_mem->set_data_handle(
@ -125,10 +124,10 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
static_cast<T*>(const_cast<T*>(diff_dst_data)));
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_input_primitives, context_.bwd_input_stream,
execute_primitives(context_.bwd_input_primitives, bwd_input_stream,
context_.bwd_input_primitives_args);
#else
context_.bwd_input_stream->submit(context_.bwd_input_primitives);
bwd_input_stream->submit(context_.bwd_input_primitives);
#endif // ENABLE_MKLDNN_V1
// Set data handle back to DummyData.
@ -180,7 +179,6 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
std::shared_ptr<memory::desc> diff_dst_md;
// MKL-DNN pipeline for executing primitives.
std::shared_ptr<mkldnn::stream> bwd_input_stream;
std::vector<mkldnn::primitive> bwd_input_primitives;
#ifdef ENABLE_MKLDNN_V1
@ -203,8 +201,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
fwd_pd(nullptr),
diff_src_md(nullptr),
filter_md(nullptr),
diff_dst_md(nullptr),
bwd_input_stream(nullptr) {
diff_dst_md(nullptr) {
}
};
@ -290,7 +287,6 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
}
struct ConvBwdInputContext context_;
engine cpu_engine_;
};
template <typename T>
@ -522,8 +518,10 @@ class MklConvCustomBackpropInputOp
if (IS_FILTER_REORDER_NEEDED(fwd_filter_md, bwd_input_pd,
conv_bwd_input)) {
filter.SetUsrMem(fwd_filter_md, &filter_tensor);
filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_));
filter.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS,
cpu_engine_),
context);
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
} else {
filter_data =
@ -535,23 +533,29 @@ class MklConvCustomBackpropInputOp
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_input_pd,
conv_bwd_input)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
diff_dst.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST,
cpu_engine_),
context);
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
} else {
diff_dst_data =
static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
}
std::shared_ptr<stream> bwd_cpu_stream;
bwd_cpu_stream.reset(CreateStream(context, conv_bwd_input->GetEngine()));
// Execute conv bwd input primitive.
if (!eager_mode) {
conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data,
bwd_cpu_stream);
} else {
// In eager mode we first write the output to temporary
// buffer in MKL format. Then we convert the data to TF format.
T* tmp_data =
static_cast<T*>(const_cast<T*>(tmp_tensor.flat<T>().data()));
conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data);
conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data,
bwd_cpu_stream);
auto output_tf_md = diff_src_mkl_shape.GetTfLayout();
#ifndef ENABLE_MKLDNN_V1
auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine_);
@ -563,7 +567,7 @@ class MklConvCustomBackpropInputOp
memory* dst_data_mem =
new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, diff_src_data);
CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem,
cpu_engine_);
cpu_engine_, context);
}
// Delete primitive since it is not cached.

View File

@ -155,7 +155,8 @@ class MklDequantizeOp : public OpKernel {
// Also it does not define round_nearest (enum).
attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
#endif // !ENABLE_MKLDNN_V1
stream reorder_stream = CPU_STREAM(cpu_engine);
std::shared_ptr<stream> reorder_stream;
reorder_stream.reset(CreateStream(ctx, cpu_engine));
std::vector<primitive> net;
// Create reorder primitive and then execute.
@ -169,11 +170,10 @@ class MklDequantizeOp : public OpKernel {
reorder_net_args.push_back({{MKLDNN_ARG_FROM, *src.GetUsrMem()},
{ MKLDNN_ARG_TO,
*dst.GetUsrMem() }});
execute_primitives(net, std::make_shared<stream>(reorder_stream),
reorder_net_args);
execute_primitives(net, reorder_stream, reorder_net_args);
#else
net.push_back(reorder(reorder_pd, *src.GetUsrMem(), *dst.GetUsrMem()));
reorder_stream.submit(net);
reorder_stream->submit(net);
#endif // ENABLE_MKLDNN_V1
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +

View File

@ -79,13 +79,7 @@ template <typename T, typename U>
class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
public:
explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
#ifdef ENABLE_MKLDNN_V1
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
#else
context_.fwd_stream.reset(
new mkldnn::stream(mkldnn::stream::kind::eager_nostore));
#endif
: MklPrimitive(engine(ENGINE_CPU, 0)) {
if (context_.bn_fwd == nullptr) Setup(fwdParams);
}
@ -98,7 +92,8 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
// mean_data: output data buffer of means
// variance_data: output data buffer of variances
void Execute(const T* src_data, const U* weights_data, T* dst_data,
U* mean_data, U* variance_data, U* workspace_data) {
U* mean_data, U* variance_data,
std::shared_ptr<stream> fwd_stream, U* workspace_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
@ -117,10 +112,10 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
}
#ifdef ENABLE_MKLDNN_V1
// Execute batch-normalization forward primitives.
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
context_.net_args);
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
#else
context_.fwd_stream->submit(context_.fwd_primitives);
fwd_stream.reset(new stream(stream::kind::eager_nostore));
fwd_stream->submit(context_.fwd_primitives);
#endif // ENABLE_MKLDNN_V1
context_.src_mem->set_data_handle(DummyData);
@ -180,7 +175,6 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
// BatchNorm forward primitive.
std::shared_ptr<mkldnn::primitive> bn_fwd;
std::shared_ptr<mkldnn::stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;
#ifdef ENABLE_MKLDNN_V1
@ -195,9 +189,8 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
dst_mem(nullptr),
mean_mem(nullptr),
variance_mem(nullptr),
ws_mem(nullptr),
bn_fwd(nullptr),
fwd_stream(nullptr) {}
ws_mem(nullptr) {}
};
void Setup(const MklBatchNormFwdParams& fwdParams) {
@ -392,7 +385,6 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
}
struct BatchNormFwdContext context_;
engine cpu_engine_;
};
template <typename T, typename U>
@ -489,13 +481,7 @@ template <typename T, typename U>
class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
public:
explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
#ifdef ENABLE_MKLDNN_V1
context_.bwd_stream.reset(new CPU_STREAM(cpu_engine_));
#else
context_.bwd_stream.reset(
new mkldnn::stream(mkldnn::stream::kind::eager_nostore));
#endif
: MklPrimitive(engine(ENGINE_CPU, 0)) {
if (context_.bn_bwd == nullptr) Setup(bwdParams);
}
@ -515,7 +501,8 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
// on CPU as of now.
void Execute(const T* src_data, const U* mean_data, const U* variance_data,
const T* diff_dst_data, const U* weights_data, T* diff_src_data,
U* diff_weights_data, U* res_space_data) {
U* diff_weights_data, U* res_space_data,
std::shared_ptr<stream> bwd_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.mean_mem->set_data_handle(
@ -537,10 +524,10 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
#ifdef ENABLE_MKLDNN_V1
// Execute backward batch-normalization primitives.
DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
execute_primitives(context_.bwd_primitives, context_.bwd_stream,
context_.net_args);
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
#else
context_.bwd_stream->submit(context_.bwd_primitives);
bwd_stream.reset(new stream(stream::kind::eager_nostore));
bwd_stream->submit(context_.bwd_primitives);
#endif // ENABLE_MKLDNN_V1
// After execution, set data handle back to DummyData.
@ -593,7 +580,6 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
// Backward batch-normalization primitive.
std::shared_ptr<mkldnn::primitive> bn_bwd;
std::vector<mkldnn::primitive> bwd_primitives;
std::shared_ptr<mkldnn::stream> bwd_stream;
#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> net_args;
@ -606,8 +592,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
diff_dst_mem(nullptr),
weights_mem(nullptr),
diff_weights_mem(nullptr),
diff_src_mem(nullptr),
bwd_stream(nullptr) {}
diff_src_mem(nullptr) {}
};
void Setup(const MklBatchNormBwdParams& bwdParams) {
@ -616,7 +601,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
? GET_FLAG(use_scale_shift)
: (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));
// Memory descriptors.
// Memory descriptors.
#ifndef ENABLE_MKLDNN_V1
auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType<T>(),
bwdParams.src_format);
@ -689,7 +674,6 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
}
struct BatchNormBwdContext context_;
engine cpu_engine_;
};
template <typename T, typename U>
@ -960,8 +944,10 @@ class MklFusedBatchNormOp : public OpKernel {
std::shared_ptr<BatchNormFwdPd> bn_fwd_pd = bn_fwd->GetBatchNormFwdPd();
if (IS_SRC_REORDER_NEEDED(src_md, bn_fwd_pd, bn_fwd)) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
GET_SRC_DESC_FROM_OP_PD(bn_fwd_pd), cpu_engine_));
src.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(bn_fwd_pd),
cpu_engine_),
context);
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
} else {
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
@ -987,9 +973,10 @@ class MklFusedBatchNormOp : public OpKernel {
T* dst_data = dst_tensor->flat<T>().data();
// Execute
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, bn_fwd->GetEngine()));
bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data,
variance_op_data, ws_data);
variance_op_data, fwd_cpu_stream, ws_data);
float adjust_factor = 1.0;
if (is_training_) {
size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3];
@ -1003,9 +990,6 @@ class MklFusedBatchNormOp : public OpKernel {
auto batch_variance_data = batch_variance_tensor->flat<U>().data();
auto est_mean_data = est_mean_tensor.flat<U>().data();
auto est_variance_data = est_variance_tensor.flat<U>().data();
// TODO(intel-tf): Merge the `is_training && exponential_avg_factor == 1`
// case with the `else` (`!is_training`) case if possible.
if (is_training_) {
if (exponential_avg_factor_ == U(1.0)) {
for (int k = 0; k < depth_; k++) {
@ -1328,8 +1312,10 @@ class MklFusedBatchNormGradOp : public OpKernel {
std::shared_ptr<BatchNormBwdPd> bn_bwd_pd = bn_bwd->GetBatchNormBwdPd();
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bn_bwd_pd, bn_bwd)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
GET_DIFF_DST_DESC_FROM_OP_PD(bn_bwd_pd), cpu_engine_));
diff_dst.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(GET_DIFF_DST_DESC_FROM_OP_PD(bn_bwd_pd),
cpu_engine_),
context);
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
} else {
diff_dst_data =
@ -1366,10 +1352,11 @@ class MklFusedBatchNormGradOp : public OpKernel {
: nullptr);
// Execute
std::shared_ptr<stream> bwd_cpu_stream;
bwd_cpu_stream.reset(CreateStream(context, bn_bwd->GetEngine()));
bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data,
weights_data, diff_src_data, diff_weights_data,
res_space_data);
res_space_data, bwd_cpu_stream);
// Allocate output TF tensors diff_scale and diff_shift.
Tensor* diff_scale_tensor = nullptr;
Tensor* diff_shift_tensor = nullptr;

View File

@ -88,7 +88,6 @@ class MklLRNOp : public OpKernel {
workspace_enabled_ = false;
OP_REQUIRES_OK(context,
context->GetAttr("workspace_enabled", &workspace_enabled_));
fwd_stream_.reset(new CPU_STREAM(cpu_engine_));
}
void Compute(OpKernelContext* context) override {
@ -169,6 +168,7 @@ class MklLRNOp : public OpKernel {
lrn_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
std::vector<primitive> net;
fwd_stream_.reset(CreateStream(context, cpu_engine_));
#ifdef ENABLE_MKLDNN_V1
net.push_back(lrn_forward(lrn_prim_desc));
std::vector<std::unordered_map<int, memory>> net_args;

View File

@ -168,14 +168,13 @@ class MklMatMulOp : public OpKernel {
const int index_transa = transa ? 1 : 0;
const int index_transb = transb ? 1 : 0;
Tensor c_float;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {m, n}, &c_float));
#ifdef ENABLE_MKLDNN_V1
const char ftrans[] = {'N', 'T', 'C'};
dnnl_gemm<bfloat16>(ftrans[index_transa], ftrans[index_transb], m, n, k,
alpha, a, lda, b, ldb, beta,
c_float.flat<float>().data(), ldc);
alpha, a, lda, b, ldb, beta, c, ldc);
#else
Tensor c_float;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {m, n}, &c_float));
const char* const ftrans[] = {"N", "T", "C"};
// MKL-DNN only supports the Fortran API and requires column major while
@ -185,8 +184,8 @@ class MklMatMulOp : public OpKernel {
reinterpret_cast<const mkldnn_bfloat16_t*>(b), &ldb,
reinterpret_cast<const mkldnn_bfloat16_t*>(a), &lda,
&beta, c_float.flat<float>().data(), &ldc);
#endif // ENABLE_MKLDNN_V1
FloatToBFloat16(c_float.flat<float>().data(), c, c_float.NumElements());
#endif // ENABLE_MKLDNN_V1
}
#endif // ENABLE_INTEL_MKL_BFLOAT16
};

View File

@ -516,33 +516,190 @@ class MklDnnMatMulOpBase : public OpKernel {
// MatMul support for bfloat16 and int8 types is introduced in DNNLv1.2.
#ifdef ENABLE_MKLDNN_V1
using mkldnn::matmul;
namespace {
void dnnl_gemm_exec(const memory::desc& a_md, const memory::desc& b_md,
const memory::desc& c_md, const void* a, const void* b,
void* c, const primitive_attr& attr) {
// Create a MatMul primitive
mkldnn::engine cpu_engine = mkldnn::engine(ENGINE_CPU, 0);
mkldnn::matmul::desc matmul_desc(a_md, b_md, c_md);
mkldnn::matmul::primitive_desc matmul_pd(matmul_desc, attr, cpu_engine);
mkldnn::matmul matmul_prim(matmul_pd);
// Wrap raw pointers into DNNL memory objects
mkldnn::memory a_memory(a_md, cpu_engine, const_cast<void*>(a));
mkldnn::memory b_memory(b_md, cpu_engine, const_cast<void*>(b));
mkldnn::memory c_memory(c_md, cpu_engine, c);
// Execute the MatMul primitive.
// Since here all shapes and parameters are static, please note that we
// don't need to pass alpha (scales) again, as they are already hard-coded
// in the primitive descriptor. Also, we are not allowed to change the
// shapes of matrices A, B, and C -- they should exactly match
// the memory descriptors passed to MatMul operation descriptor.
mkldnn::stream s(cpu_engine);
matmul_prim.execute(s, {{DNNL_ARG_SRC, a_memory},
{DNNL_ARG_WEIGHTS, b_memory},
{ DNNL_ARG_DST,
c_memory }});
s.wait();
}
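A one-shot usage sketch (the dimensions and the a_buf/b_buf/c_buf pointers are hypothetical; they stand for caller-owned float buffers of matching sizes):
  // Row-major 2x3 * 3x4 -> 2x4 with default attributes; strides are row-major.
  mkldnn::memory::desc a_md({2, 3}, MklDnnType<float>(), mkldnn::memory::dims{3, 1});
  mkldnn::memory::desc b_md({3, 4}, MklDnnType<float>(), mkldnn::memory::dims{4, 1});
  mkldnn::memory::desc c_md({2, 4}, MklDnnType<float>(), mkldnn::memory::dims{4, 1});
  dnnl_gemm_exec(a_md, b_md, c_md, a_buf, b_buf, c_buf, mkldnn::primitive_attr());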
struct MklMatMulParams {
memory::dims a_dims;
memory::dims b_dims;
memory::dims c_dims;
memory::dims a_strides;
memory::dims b_strides;
memory::dims c_strides;
MklMatMulParams(memory::dims a_dims, memory::dims b_dims, memory::dims c_dims,
memory::dims a_strides, memory::dims b_strides,
memory::dims c_strides)
: a_dims(a_dims),
b_dims(b_dims),
c_dims(c_dims),
a_strides(a_strides),
b_strides(b_strides),
c_strides(c_strides) {}
};
template <typename T>
class MklMatMulPrimitive : public MklPrimitive {
public:
explicit MklMatMulPrimitive(const MklMatMulParams& params)
: cpu_engine_(ENGINE_CPU, 0) {
context_.stream.reset(new CPU_STREAM(cpu_engine_));
// Create matmul primitive
Setup(params);
}
~MklMatMulPrimitive() {}
void Execute(const T* a_data, const T* b_data, T* c_data) {
context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)));
context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)));
context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)));
execute_primitives(context_.matmul_primitives, context_.stream,
context_.net_args);
// After execution, set data handle back
context_.a_mem->set_data_handle(DummyData);
context_.b_mem->set_data_handle(DummyData);
context_.c_mem->set_data_handle(DummyData);
}
private:
// Primitive reuse context for MatMul op
struct MklMatMulContext {
// MKL-DNN memory.
std::shared_ptr<mkldnn::memory> a_mem;
std::shared_ptr<mkldnn::memory> b_mem;
std::shared_ptr<mkldnn::memory> c_mem;
// Descriptor and primitive-descriptor for MatMul.
std::shared_ptr<matmul::desc> desc;
std::shared_ptr<matmul::primitive_desc> prim_desc;
// Memory descriptors.
std::shared_ptr<mkldnn::memory::desc> a_md;
std::shared_ptr<mkldnn::memory::desc> b_md;
std::shared_ptr<mkldnn::memory::desc> c_md;
// MatMul primitive.
std::shared_ptr<mkldnn::stream> stream;
std::vector<mkldnn::primitive> matmul_primitives;
std::vector<std::unordered_map<int, memory>> net_args;
MklMatMulContext()
: a_mem(nullptr),
b_mem(nullptr),
c_mem(nullptr),
desc(nullptr),
prim_desc(nullptr),
a_md(nullptr),
b_md(nullptr),
c_md(nullptr),
stream(nullptr) {}
};
void Setup(const MklMatMulParams& params) {
std::shared_ptr<mkldnn::primitive> matmul_primitive = nullptr;
// Create MatMul descriptor and primitive descriptor.
context_.a_md.reset(
new memory::desc({params.a_dims}, MklDnnType<T>(), params.a_strides));
context_.b_md.reset(
new memory::desc({params.b_dims}, MklDnnType<T>(), params.b_strides));
context_.c_md.reset(
new memory::desc({params.c_dims}, MklDnnType<T>(), params.c_strides));
// Create matmul.
context_.desc.reset(
new matmul::desc(*context_.a_md, *context_.b_md, *context_.c_md));
context_.prim_desc.reset(
new matmul::primitive_desc(*context_.desc, cpu_engine_));
// Create memory primitive based on dummy data.
context_.a_mem.reset(
new mkldnn::memory(*context_.a_md, cpu_engine_, DummyData));
context_.b_mem.reset(
new mkldnn::memory(*context_.b_md, cpu_engine_, DummyData));
context_.c_mem.reset(
new mkldnn::memory(*context_.c_md, cpu_engine_, DummyData));
// Create matmul primitive.
matmul_primitive.reset(new mkldnn::matmul(*context_.prim_desc));
context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.a_mem},
{MKLDNN_ARG_WEIGHTS, *context_.b_mem},
{ MKLDNN_ARG_DST,
*context_.c_mem }});
context_.matmul_primitives.push_back(*matmul_primitive);
return;
}
struct MklMatMulContext context_;
engine cpu_engine_;
};
template <typename T>
class MklMatMulPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklMatMulPrimitive<T>* Get(const MklMatMulParams& params,
bool do_not_cache) {
MklMatMulPrimitive<T>* matmul_prim = nullptr;
if (do_not_cache) {
// Always create new primitive
matmul_prim = new MklMatMulPrimitive<T>(params);
} else {
// Try to find a suitable one in pool
matmul_prim = dynamic_cast<MklMatMulPrimitive<T>*>(
MklMatMulPrimitiveFactory<T>::GetInstance().GetMklMatMul(params));
if (matmul_prim == nullptr) {
matmul_prim = new MklMatMulPrimitive<T>(params);
MklMatMulPrimitiveFactory<T>::GetInstance().SetMklMatMul(params,
matmul_prim);
}
}
return matmul_prim;
}
private:
MklMatMulPrimitiveFactory() {}
~MklMatMulPrimitiveFactory() {}
static MklMatMulPrimitiveFactory& GetInstance() {
static MklMatMulPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklMatMulParams& params) {
string prefix = "matmul_";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(params.a_dims);
key_creator.AddAsKey(params.b_dims);
key_creator.AddAsKey(params.c_dims);
key_creator.AddAsKey(params.a_strides);
key_creator.AddAsKey(params.b_strides);
key_creator.AddAsKey(params.c_strides);
key_creator.AddAsKey(typeid(T).name());
return key_creator.GetKey();
}
MklPrimitive* GetMklMatMul(const MklMatMulParams& params) {
string key = CreateKey(params);
return this->GetOp(key);
}
void SetMklMatMul(const MklMatMulParams& params, MklPrimitive* op) {
string key = CreateKey(params);
this->SetOp(key, op);
}
};
template <typename T>
void dnnl_gemm_batch(const std::vector<bool>& transa,
@ -589,45 +746,47 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
!transb[0] ? dims{k[0] * n[0], n[0], 1} : dims{n[0] * k[0], 1, k[0]};
dims c_strides = dims{m[0] * n[0], n[0], 1};
// Prepare memory descriptors
memory::desc a_md(a_sizes, MklDnnType<T>(), a_strides);
memory::desc b_md(b_sizes, MklDnnType<T>(), b_strides);
memory::desc c_md(c_sizes, MklDnnType<T>(), c_strides);
// Create attributes (to handle alpha and beta if necessary)
mkldnn::primitive_attr attr;
if (alpha[0] != 1.f) attr.set_output_scales(/* mask */ 0, {alpha[0]});
if (beta[0] != 0.f) {
mkldnn::post_ops po;
po.append_sum(beta[0]);
attr.set_post_ops(po);
}
dnnl_gemm_exec(a_md, b_md, c_md, static_cast<const void*>(a),
static_cast<const void*>(b), static_cast<void*>(c), attr);
// MklMatMul uses constant alpha and beta; assert here that they are
// never changed.
DCHECK_EQ(alpha[0], 1.0f);
DCHECK_EQ(beta[0], 0.f);
MklMatMulParams params(a_sizes, b_sizes, c_sizes, a_strides, b_strides,
c_strides);
MklMatMulPrimitive<T>* matmul_prim =
MklMatMulPrimitiveFactory<T>::Get(params, 0);
// Execute matmul primitive.
matmul_prim->Execute(a, b, c);
}
template <typename T>
void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, const T* a, int64_t lda, const T* b, int64_t ldb,
float beta, float* c, int64_t ldc) {
float beta, T* c, int64_t ldc) {
using dims = mkldnn::memory::dims;
// Prepare strides based on the transa and transb flags: transposed
// matrices have strides swapped
dims a_dims = dims{m, k};
dims b_dims = dims{k, n};
dims c_dims = dims{m, n};
dims a_strides = tolower(transa) == 'n' ? dims{lda, 1} : dims{1, lda};
dims b_strides = tolower(transb) == 'n' ? dims{ldb, 1} : dims{1, ldb};
// Prepare memory descriptors
memory::desc a_md({m, k}, MklDnnType<T>(), a_strides);
memory::desc b_md({k, n}, MklDnnType<T>(), b_strides);
memory::desc c_md({m, n}, MklDnnType<float>(), {ldc, 1});
// Create attributes (to handle alpha and beta if necessary)
mkldnn::primitive_attr attr;
if (alpha != 1.f) attr.set_output_scales(/* mask */ 0, {alpha});
if (beta != 0.f) {
mkldnn::post_ops po;
po.append_sum(beta);
attr.set_post_ops(po);
}
dnnl_gemm_exec(a_md, b_md, c_md, static_cast<const void*>(a),
static_cast<const void*>(b), static_cast<void*>(c), attr);
dims c_strides = dims{ldc, 1};
// MklMatMul uses constant alpha and beta; assert here that they are
// never changed.
DCHECK_EQ(alpha, 1.0f);
DCHECK_EQ(beta, 0.f);
MklMatMulParams params(a_dims, b_dims, c_dims, a_strides, b_strides,
c_strides);
MklMatMulPrimitive<T>* matmul_prim =
MklMatMulPrimitiveFactory<T>::Get(params, 0);
// Execute matmul primitive.
matmul_prim->Execute(a, b, c);
}
} // anonymous namespace
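// Reference-only sketch (illustration, not part of this change): what
// dnnl_gemm above computes for alpha == 1 and beta == 0, spelled out with
// plain loops so the row-major leading-dimension / transpose stride
// convention is explicit. The name reference_gemm is hypothetical, not an
// existing TensorFlow helper.
#include <cctype>
#include <cstdint>
template <typename T>
void reference_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
                    const T* a, int64_t lda, const T* b, int64_t ldb, T* c,
                    int64_t ldc) {
  const bool ta = std::tolower(static_cast<unsigned char>(transa)) != 'n';
  const bool tb = std::tolower(static_cast<unsigned char>(transb)) != 'n';
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      T acc = T(0);
      for (int64_t p = 0; p < k; ++p) {
        // Matches dims{1, lda} (transposed) vs dims{lda, 1} (plain) above.
        const T av = ta ? a[p * lda + i] : a[i * lda + p];
        const T bv = tb ? b[j * ldb + p] : b[p * ldb + j];
        acc += av * bv;
      }
      c[i * ldc + j] = acc;  // alpha == 1, beta == 0, as the DCHECKs require.
    }
  }
}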

View File

@ -167,10 +167,12 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
const T* src_data = input_tensor.flat<T>().data();
T* dst_data = output_tensor->flat<T>().data();
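// The execution stream is now created per invocation from the op context
// (CreateStream) instead of being cached inside the primitive, and is passed
// explicitly to Execute().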
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine()));
if (int8_forward_inference) {
// Execute pooling op
pooling_fwd->Execute(src_data, dst_data);
pooling_fwd->Execute(src_data, dst_data, nullptr, fwd_cpu_stream);
// Pass min, max from input to output.
const Tensor& min_input_t = MklGetInput(context, 1);
@ -197,7 +199,7 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
T* ws_data =
static_cast<T*>(dnn_data_wksp.GetOpMem().get_data_handle());
// Execute pooling op.
pooling_fwd->Execute(src_data, dst_data, ws_data);
pooling_fwd->Execute(src_data, dst_data, ws_data, fwd_cpu_stream);
}
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
@ -322,6 +324,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
MklPoolingBwdPrimitive<T>* pooling_bwd =
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
std::shared_ptr<stream> bwd_cpu_stream;
bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine()));
// Allocate output tensor and memory primitive.
Tensor* output_tensor = nullptr;
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
@ -335,8 +339,10 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd,
pooling_bwd)) {
grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
grad_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
GET_DIFF_DST_DESC_FROM_OP_PD(pooling_bwd_pd), cpu_engine_));
grad_dnn_data.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(GET_DIFF_DST_DESC_FROM_OP_PD(pooling_bwd_pd),
cpu_engine_),
context);
diff_dst_data =
static_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle());
} else {
@ -361,7 +367,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
T* diff_src_data = output_tensor->flat<T>().data();
// Execute pooling op.
pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data);
pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data,
bwd_cpu_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status:" + std::to_string(e.status) +
", message: " + string(e.message) + ". in file " +

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
namespace tensorflow {
using mkldnn::prop_kind;
@ -38,11 +37,11 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
context_.alg_kind = fwdParams.alg_kind;
context_.prop_kind = fwdParams.prop_kind;
// Create memory descriptor.
// FIXME: Pooling doesn't expose a way to get the src_primitive_desc,
// so the src format is currently hard-coded.
// A utility function is used to do this,
// which may break with future CPU architectures.
#ifndef ENABLE_MKLDNN_V1
bool is_2d = (fwdParams.src_dims.size() == 4);
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value)
@ -126,7 +125,8 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
template <typename T>
void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
void* ws_data) {
void* ws_data,
std::shared_ptr<stream> fwd_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
@ -138,10 +138,9 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
}
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
context_.net_args);
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
#else
context_.fwd_stream->submit(context_.fwd_primitives);
fwd_stream->submit(context_.fwd_primitives);
#endif // ENABLE_MKLDNN_V1
// Set back data handle.
@ -268,7 +267,8 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
template <typename T>
void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
T* diff_src_data, const void* ws_data) {
T* diff_src_data, const void* ws_data,
std::shared_ptr<stream> bwd_stream) {
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
@ -278,10 +278,9 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
}
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_primitives, context_.bwd_stream,
context_.net_args);
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
#else
context_.bwd_stream->submit(context_.bwd_primitives);
bwd_stream->submit(context_.bwd_primitives);
#endif // ENABLE_MKLDNN_V1
// Set back data handle.

View File

@ -86,8 +86,7 @@ template <typename T>
class MklPoolingFwdPrimitive : public MklPrimitive {
public:
explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
if (context_.fwd == nullptr) Setup(fwdParams);
}
@ -97,7 +96,8 @@ class MklPoolingFwdPrimitive : public MklPrimitive {
// src_data: input data buffer of src
// ws_data: output data buffer of workspace
// dst_data: output data buffer of dst
void Execute(const T* src_data, T* dst_data, void* ws_data = nullptr);
void Execute(const T* src_data, T* dst_data, void* ws_data,
std::shared_ptr<stream> fwd_stream);
std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
return context_.fwd_pd;
@ -159,12 +159,10 @@ class MklPoolingFwdPrimitive : public MklPrimitive {
fwd_pd(nullptr),
src_md(nullptr),
dst_md(nullptr),
fwd(nullptr),
fwd_stream(nullptr) {}
fwd(nullptr) {}
};
struct PoolingFwdContext context_;
engine cpu_engine_;
};
template <typename T>
@ -229,8 +227,7 @@ template <typename T>
class MklPoolingBwdPrimitive : public MklPrimitive {
public:
explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
context_.bwd_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
if (context_.bwd == nullptr) Setup(bwdParams);
}
@ -240,8 +237,8 @@ class MklPoolingBwdPrimitive : public MklPrimitive {
// diff_dst_data: input data buffer of diff_dst
// diff_src_data: output data buffer of diff_src
// ws_data: input data buffer of workspace
void Execute(const T* diff_dst_data, T* diff_src_data,
const void* ws_data = nullptr);
void Execute(const T* diff_dst_data, T* diff_src_data, const void* ws_data,
std::shared_ptr<stream> bwd_stream);
public:
std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
@ -315,12 +312,10 @@ class MklPoolingBwdPrimitive : public MklPrimitive {
bwd_desc(nullptr),
fwd_pd(nullptr),
bwd_pd(nullptr),
bwd(nullptr),
bwd_stream(nullptr) {}
bwd(nullptr) {}
};
struct PoolingBwdContext context_;
engine cpu_engine_;
};
template <typename T>

View File

@ -77,7 +77,7 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
public:
explicit MklReorderWithScalePrimitive(
const MklReorderWithScaleFwdParams& fwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
: MklPrimitive(engine(ENGINE_CPU, 0)) {
// Create reorder primitive
Setup(fwdParams);
}
@ -86,14 +86,14 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }
void Execute(void* src_data, void* dst_data) {
void Execute(void* src_data, void* dst_data,
std::shared_ptr<stream> reorder_stream) {
context_.src_mem->set_data_handle(src_data);
context_.dst_mem->set_data_handle(dst_data);
#ifndef ENABLE_MKLDNN_V1
context_.reorder_stream->submit(context_.net);
reorder_stream->submit(context_.net);
#else
context_.reorder_prim->execute(*context_.reorder_stream,
context_.prim_args);
context_.reorder_prim->execute(*reorder_stream, context_.prim_args);
#endif // !ENABLE_MKLDNN_V1
// After execution, set data handle back.
context_.src_mem->set_data_handle(DummyData);
@ -124,12 +124,9 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
: src_mem(nullptr),
dst_mem(nullptr),
reorder_pd(nullptr),
reorder_prim(nullptr),
reorder_stream(nullptr) {}
reorder_prim(nullptr) {}
} context_;
engine cpu_engine_;
// Reorder primitive setup
void Setup(const MklReorderWithScaleFwdParams& fwdParams) {
// Create memory descriptors for reorder data with specified format
@ -163,7 +160,6 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
context_.prim_args.insert({MKLDNN_ARG_FROM, *context_.src_mem});
context_.prim_args.insert({MKLDNN_ARG_TO, *context_.dst_mem});
#endif // !ENABLE_MKLDNN_V1
context_.reorder_stream.reset(new CPU_STREAM(cpu_engine_));
}
};
@ -491,7 +487,10 @@ class MklQuantizeV2Op : public OpKernel {
MklReorderWithScalePrimitive* reorder_prim =
MklReorderWithScalePrimitiveFactory<T>::Get(src.GetUsrMem(),
dst.GetUsrMem(), fwdParams);
reorder_prim->Execute(src.GetUsrMemDataHandle(), dst.GetUsrMemDataHandle());
std::shared_ptr<stream> cpu_stream;
cpu_stream.reset(CreateStream(ctx, reorder_prim->GetEngine()));
reorder_prim->Execute(src.GetUsrMemDataHandle(), dst.GetUsrMemDataHandle(),
cpu_stream);
output_min_tensor->flat<float>()(0) = min_range;
output_max_tensor->flat<float>()(0) = max_range;

View File

@ -61,13 +61,11 @@ template <typename T>
class MklEltwiseFwdPrimitive : public MklPrimitive {
public:
explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
: MklPrimitive(engine(ENGINE_CPU, 0)) {
#ifndef ENABLE_MKLDNN_V1
context_.src_fmt =
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
#endif
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
// create eltwise primitive
if (context_.eltwise_fwd == nullptr) {
Setup(fwdParams);
@ -79,7 +77,8 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
// Eltwise forward execute
// src_data: input data buffer of src
// dst_data: output data buffer of dst
void Execute(const T* src_data, T* dst_data) {
void Execute(const T* src_data, T* dst_data,
std::shared_ptr<stream> fwd_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
@ -87,12 +86,10 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.fwd_primitives.size(),
context_.fwd_primitives_args.size());
for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
context_.fwd_primitives_args.at(i));
}
execute_primitives(context_.fwd_primitives, fwd_stream,
context_.fwd_primitives_args);
#else
context_.fwd_stream->submit(context_.fwd_primitives);
fwd_stream->submit(context_.fwd_primitives);
#endif
// After execution, set data handle back.
@ -134,7 +131,6 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
// Eltwise primitive
std::shared_ptr<mkldnn::primitive> eltwise_fwd;
std::shared_ptr<stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;
#ifdef ENABLE_MKLDNN_V1
@ -153,8 +149,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
src_md(nullptr),
dst_md(nullptr),
src_mpd(nullptr),
eltwise_fwd(nullptr),
fwd_stream(nullptr) {
eltwise_fwd(nullptr) {
}
};
@ -169,14 +164,12 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
#else
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
#endif
// Create an eltwise forward descriptor and primitive descriptor
context_.fwd_desc.reset(new eltwise_forward::desc(
prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
fwdParams.alpha, fwdParams.beta));
context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
auto fwd_pd = context_.fwd_pd.get();
#ifdef ENABLE_MKLDNN_V1
// Create memory primitive based on dummy data
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(fwd_pd->PRIMITIVE_DESC_SRC,
@ -195,12 +188,10 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
context_.eltwise_fwd.reset(new eltwise_forward(
*context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
#endif
context_.fwd_primitives.push_back(*context_.eltwise_fwd);
}
struct EltwiseFwdContext context_;
engine cpu_engine_;
};
template <typename T>
@ -281,14 +272,13 @@ template <typename T>
class MklEltwiseBwdPrimitive : public MklPrimitive {
public:
explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
: MklPrimitive(engine(ENGINE_CPU, 0)) {
#ifndef ENABLE_MKLDNN_V1
context_.src_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
context_.diff_dst_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
#endif
context_.bwd_stream.reset(new stream(CPU_STREAM(cpu_engine_)));
// create eltwise primitive
if (context_.eltwise_bwd == nullptr) {
Setup(bwdParams);
@ -301,7 +291,8 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
// src_data: input data buffer of src
// diff_dst_data: input data buffer of diff_dst
// diff_src_data: output data buffer of diff_src
void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data) {
void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data,
std::shared_ptr<stream> bwd_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.diff_dst_mem->set_data_handle(
@ -311,12 +302,10 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.bwd_primitives.size(),
context_.bwd_primitives_args.size());
for (size_t i = 0; i < context_.bwd_primitives.size(); ++i) {
context_.bwd_primitives.at(i).execute(*context_.bwd_stream,
context_.bwd_primitives_args.at(i));
}
execute_primitives(context_.bwd_primitives, bwd_stream,
context_.bwd_primitives_args);
#else
context_.bwd_stream->submit(context_.bwd_primitives);
bwd_stream->submit(context_.bwd_primitives);
#endif // ENABLE_MKLDNN_V1
// after execution, set data handle back
@ -367,7 +356,6 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
// Eltwise primitive.
std::shared_ptr<mkldnn::primitive> eltwise_bwd;
std::shared_ptr<stream> bwd_stream;
std::vector<mkldnn::primitive> bwd_primitives;
#ifdef ENABLE_MKLDNN_V1
@ -391,8 +379,7 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
fwd_desc(nullptr),
fwd_pd(nullptr),
bwd_pd(nullptr),
eltwise_bwd(nullptr),
bwd_stream(nullptr) {
eltwise_bwd(nullptr) {
}
};
@ -448,7 +435,6 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
}
struct EltwiseBwdContext context_;
engine cpu_engine_;
};
template <typename T>
@ -525,12 +511,10 @@ class MklReluOpBase : public OpKernel {
const Tensor& src_tensor = MklGetInput(context, src_index);
MklDnnShape dnn_shape_src;
GetMklShape(context, src_index, &dnn_shape_src);
if (src_tensor.dims() == 0) {
Compute_Scalar(context);
return;
}
MklDnnShape dnn_shape_dst;
TensorShape tf_shape_dst;
Tensor* dst_tensor = nullptr;
@ -542,7 +526,6 @@ class MklReluOpBase : public OpKernel {
dnn_shape_dst);
return;
}
// Set DNN primitive - src
MklDnnData<T> src(&cpu_engine);
memory::dims src_dims;
@ -556,26 +539,25 @@ class MklReluOpBase : public OpKernel {
// Create blocked memory descriptor
src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
}
// Try to get an eltwise forward primitive from caching pool
MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha_,
beta_);
MklEltwiseFwdPrimitive<T>* eltwise_fwd =
MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);
auto eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd();
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, eltwise_fwd->GetEngine()));
// Check if src needs to be reordered
const T* src_data = src_tensor.flat<T>().data();
if (IS_SRC_REORDER_NEEDED(src_md, eltwise_fwd_pd, eltwise_fwd)) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
eltwise_fwd_pd->PRIMITIVE_DESC_SRC, cpu_engine));
src.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(eltwise_fwd_pd->PRIMITIVE_DESC_SRC,
cpu_engine),
context);
src_data = const_cast<T*>(
reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
}
// Allocate dst tensor, always set it as MKL-DNN layout
if (dnn_shape_src.IsMklTensor()) {
dnn_shape_dst.SetMklTensor(true);
@ -590,7 +572,6 @@ class MklReluOpBase : public OpKernel {
dnn_shape_dst.SetMklTensor(false);
tf_shape_dst = src_tensor.shape();
}
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{static_cast<const int>(src_index)},
static_cast<const int>(dst_index),
@ -600,7 +581,7 @@ class MklReluOpBase : public OpKernel {
T* dst_data = dst_tensor->flat<T>().data();
// execute eltwise
eltwise_fwd->Execute(src_data, dst_data);
eltwise_fwd->Execute(src_data, dst_data, fwd_cpu_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@ -727,13 +708,16 @@ class MklReluGradOpBase : public OpKernel {
MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);
auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();
std::shared_ptr<stream> bwd_cpu_stream;
bwd_cpu_stream.reset(CreateStream(context, eltwise_bwd->GetEngine()));
// check whether need reorder for src / diff_dst
const T* src_data = src_tensor.flat<T>().data();
if (IS_SRC_REORDER_NEEDED(src_md, eltwise_bwd_pd, eltwise_bwd)) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine));
src.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(
eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine),
context);
src_data = const_cast<T*>(
reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
}
@ -742,8 +726,10 @@ class MklReluGradOpBase : public OpKernel {
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, eltwise_bwd_pd,
eltwise_bwd)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine));
diff_dst.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(
eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine),
context);
diff_dst_data = const_cast<T*>(
reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
}
@ -779,7 +765,8 @@ class MklReluGradOpBase : public OpKernel {
T* diff_src_data = diff_src_tensor->flat<T>().data();
// execute eltwise bwd
eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data);
eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data,
bwd_cpu_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +

View File

@ -130,9 +130,10 @@ class MklRequantizePerChannelOp : public OpKernel {
GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(input_mem_prim),
GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(output_mem_prim),
cpu_engine_, reorder_attr);
mkldnn::stream reorder_stream = CPU_STREAM(cpu_engine_);
std::shared_ptr<stream> reorder_stream;
reorder_stream.reset(CreateStream(ctx, cpu_engine_));
#ifndef ENABLE_MKLDNN_V1
reorder_stream.submit(
reorder_stream->submit(
{mkldnn::reorder(reorder_pd, *input_mem_prim, *output_mem_prim)});
#else
std::unordered_map<int, mkldnn::memory> reorder_args = {
@ -140,7 +141,7 @@ class MklRequantizePerChannelOp : public OpKernel {
{MKLDNN_ARG_TO, *output_mem_prim}};
std::unique_ptr<mkldnn::primitive> reorder_prim(
new mkldnn::reorder(reorder_pd));
reorder_prim->execute(reorder_stream, reorder_args);
reorder_prim->execute(*reorder_stream, reorder_args);
#endif // !ENABLE_MKLDNN_V1
Tensor* output_min = nullptr;

View File

@ -181,22 +181,22 @@ template <typename T>
class MklSlicePrimitive : public MklPrimitive {
public:
explicit MklSlicePrimitive(const MklSliceParams& sliceParams)
: cpu_engine_(ENGINE_CPU, 0) {
context_.slice_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
Setup(sliceParams);
}
~MklSlicePrimitive() {}
void Execute(const MklSliceParams& sliceParams) {
void Execute(const MklSliceParams& sliceParams,
std::shared_ptr<stream> slice_stream) {
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.slice_primitives, context_.slice_stream,
execute_primitives(context_.slice_primitives, slice_stream,
context_.slice_primitives_args);
#else
context_.slice_stream->submit(context_.slice_primitives);
slice_stream->submit(context_.slice_primitives);
#endif
// We should set it back to DummyData so as to make the primitive
@ -228,8 +228,6 @@ class MklSlicePrimitive : public MklPrimitive {
: src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
} context_;
engine cpu_engine_;
void Setup(const MklSliceParams& sliceParams) {
// Actually, DummyData will not be used in computation,
// because the real data will be filled before execution.
@ -465,7 +463,7 @@ class MklSliceOp : public OpKernel {
auto op_md =
MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
#ifdef ENABLE_MKLDNN_V1
src.CheckReorderToOpMem(op_md, cpu_engine);
src.CheckReorderToOpMem(op_md, cpu_engine, context);
#else
auto op_pd = memory::primitive_desc(op_md, cpu_engine);
src.CheckReorderToOpMem(op_pd);
@ -492,7 +490,9 @@ class MklSliceOp : public OpKernel {
MklSlicePrimitive<T>* reorder_prim =
MklSlicePrimitiveFactory<T>::Get(sliceParams);
// Execute slice reorder.
reorder_prim->Execute(sliceParams);
std::shared_ptr<stream> slice_stream;
slice_stream.reset(CreateStream(context, reorder_prim->GetEngine()));
reorder_prim->Execute(sliceParams, slice_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +

View File

@ -48,8 +48,7 @@ template <typename T>
class MklSoftmaxPrimitive : public MklPrimitive {
public:
explicit MklSoftmaxPrimitive(const MklSoftmaxParams& fwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
Setup(fwdParams);
}
@ -58,16 +57,18 @@ class MklSoftmaxPrimitive : public MklPrimitive {
// Softmax forward execute
// src_data: input data buffer of src
// dst_data: output data buffer of dst
void Execute(const T* src_data, T* dst_data) {
void Execute(const T* src_data, T* dst_data,
std::shared_ptr<stream> fwd_cpu_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_net_args.size());
execute_primitives(context_.fwd_primitives, fwd_cpu_stream,
context_.fwd_net_args);
#else
context_.fwd_stream->submit(context_.fwd_primitives);
fwd_cpu_stream->submit(context_.fwd_primitives);
#endif
// After execution, set data handle back.
@ -95,7 +96,6 @@ class MklSoftmaxPrimitive : public MklPrimitive {
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd;
std::shared_ptr<mkldnn::primitive> softmax_fwd;
std::shared_ptr<stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;
std::vector<MemoryArgsMap> fwd_net_args;
@ -105,8 +105,7 @@ class MklSoftmaxPrimitive : public MklPrimitive {
fwd_desc(nullptr),
src_md(nullptr),
fwd_pd(nullptr),
softmax_fwd(nullptr),
fwd_stream(nullptr) {}
softmax_fwd(nullptr) {}
};
// Softmax forward primitive setup
@ -143,7 +142,6 @@ class MklSoftmaxPrimitive : public MklPrimitive {
}
struct SoftmaxFwdContext context_;
engine cpu_engine_;
};
template <typename T>
@ -303,9 +301,9 @@ class MklSoftmaxOp : public OpKernel {
const T* src_data = src_tensor.flat<T>().data();
T* dst_data = reinterpret_cast<T*>(output_tensor->flat<T>().data());
// Execute softmax primitive.
softmax_fwd->Execute(src_data, dst_data);
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, softmax_fwd->GetEngine()));
softmax_fwd->Execute(src_data, dst_data, fwd_cpu_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +

View File

@ -144,12 +144,14 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
std::vector<primitive> net;
#ifdef ENABLE_MKLDNN_V1
std::shared_ptr<stream> transpose_stream;
auto* prim = FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem());
transpose_stream.reset(CreateStream(context, prim->GetEngine()));
net.push_back(*(prim->GetPrimitive()));
std::vector<MemoryArgsMap> net_args;
net_args.push_back({{MKLDNN_ARG_FROM, *in.GetUsrMem()},
{MKLDNN_ARG_TO, *out.GetUsrMem()}});
execute_primitives(net, prim->GetStream(), net_args);
execute_primitives(net, transpose_stream, net_args);
#else
std::shared_ptr<stream> transpose_stream;
transpose_stream.reset(new CPU_STREAM(cpu_engine));

View File

@ -887,6 +887,17 @@ class ResourceScatterUpdateOp : public OpKernel {
const Tensor& indices = c->input(1);
const Tensor& updates = c->input(2);
// Check that rank(updates.shape) = rank(indices.shape + params.shape[1:])
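// For example (hypothetical shapes): with params.shape = [P, D] and
// indices.shape = [N], a valid updates.shape is [N, D]; a scalar update
// (shape []) is also accepted and applied to every indexed row.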
OP_REQUIRES(c,
updates.dims() == 0 ||
updates.dims() == indices.dims() + params->dims() - 1,
errors::InvalidArgument(
"Must have updates.shape = indices.shape + "
"params.shape[1:] or updates.shape = [], got ",
"updates.shape ", updates.shape().DebugString(),
", indices.shape ", indices.shape().DebugString(),
", params.shape ", params->shape().DebugString()));
// Check that we have enough index space
const int64 N_big = indices.NumElements();
OP_REQUIRES(

View File

@ -508,6 +508,12 @@ class SparseSegmentReductionOpBase : public OpKernel {
errors::InvalidArgument("segment ids must be >= 0"));
auto output_flat = output->flat_outer_dims<T>();
Tensor temp;
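// For bfloat16 inputs, the reduction below is accumulated in this float
// scratch tensor and cast back to bfloat16 once per output row, avoiding
// precision loss from accumulating many bfloat16 additions.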
if (input.dtype() == DT_BFLOAT16) {
temp = tensorflow::Tensor(DT_FLOAT, output_shape);
}
auto temp_flat = temp.flat_outer_dims<float>();
int64 start = 0, end = 1;
// Index from which the output is not initialized.
SegmentId uninitialized_index = 0;
@ -546,8 +552,9 @@ class SparseSegmentReductionOpBase : public OpKernel {
}
auto out = output_flat.template chip<0>(out_index);
const int bad_offset =
Reduce(input_flat, indices_vec, start, end - start, out);
auto temp = temp_flat.template chip<0>(out_index);
const int bad_offset = Reduce<T, Index>(input_flat, indices_vec, start,
end - start, out, temp);
OP_REQUIRES(context, bad_offset < 0,
errors::InvalidArgument(
"Bad: indices[", start + bad_offset,
@ -572,40 +579,89 @@ class SparseSegmentReductionOpBase : public OpKernel {
}
private:
int64 Reduce(const typename TTypes<T>::ConstMatrix& input_flat,
const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
int64 num,
Eigen::TensorChippingOp<0, typename TTypes<T>::Matrix> out) {
template <typename Tin>
using EnableIfBfloat16 =
typename std::enable_if<std::is_same<Tin, bfloat16>::value, int>::type;
template <typename Tin>
using EnableIfNotBfloat16 =
typename std::enable_if<!std::is_same<Tin, bfloat16>::value, int>::type;
template <typename Tin, typename Tindex, EnableIfNotBfloat16<Tin> = 0>
EIGEN_ALWAYS_INLINE auto fetch_val(
const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
return input_flat.template chip<0>(index);
}
template <typename Tin, typename Tindex, EnableIfBfloat16<Tin> = 0>
EIGEN_ALWAYS_INLINE auto fetch_val(
const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
return input_flat.template chip<0>(index).template cast<float>();
}
template <typename Tout>
EIGEN_ALWAYS_INLINE Tout get_scaling_factor(int64 num) {
Tout m(1);
if (is_mean_ && (num < 10)) {
m = Tout(num);
}
if (is_sqrtn_ && (num < 10)) {
m = Tout(sqrt(num));
}
return Tout(1) / m;
}
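// For short segments (num < 10) the mean / sqrt(n) normalization is folded
// into a single reciprocal multiply applied to the unrolled sum; longer
// segments are normalized after the accumulation loop instead.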
template <typename Tin, typename Tindex, EnableIfNotBfloat16<Tin> = 0>
int64 Reduce(
const typename TTypes<Tin>::ConstMatrix& input_flat,
const typename TTypes<Tindex>::ConstVec& indices_vec, int64 start,
int64 num, Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
return ReduceImpl<Tin, Tindex, Tin>(input_flat, indices_vec, start, num,
out, get_scaling_factor<Tin>(num));
}
template <typename Tin, typename Tindex, EnableIfBfloat16<Tin> = 0>
int64 Reduce(
const typename TTypes<Tin>::ConstMatrix& input_flat,
const typename TTypes<Tindex>::ConstVec& indices_vec, int64 start,
int64 num, Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
int64 res =
ReduceImpl<Tin, Tindex, float>(input_flat, indices_vec, start, num,
temp, get_scaling_factor<float>(num));
out = temp.template cast<bfloat16>();
return res;
}
template <typename Tin, typename Tindex, typename Tout>
int64 ReduceImpl(
const typename TTypes<Tin>::ConstMatrix& input_flat,
const typename TTypes<Tindex>::ConstVec& indices_vec, int64 start,
int64 num, Eigen::TensorChippingOp<0, typename TTypes<Tout>::Matrix> out,
const Tout scaling_factor) {
#define INDEX(n, i) \
const auto index##n = indices_vec(start + (i)); \
if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);
#define L(n) input_flat.template chip<0>(index##n)
#define L(n) fetch_val<Tin, Tindex>(input_flat, index##n)
if (num == 1) {
INDEX(0, 0);
out = L(0);
} else {
int64 r = num % 8;
T m(1);
if (is_mean_ && (num < 10)) {
m = T(num);
}
if (is_sqrtn_ && (num < 10)) {
m = T(sqrt(num));
}
int64 r = num & 7;
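// Sum the first few (2..9) indices explicitly via the switch below, then
// accumulate the remaining indices eight at a time in the loop that follows.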
switch (r) {
case 2: {
INDEX(0, 0);
INDEX(1, 1);
out = (L(0) + L(1)) / m;
out = (L(0) + L(1)) * scaling_factor;
break;
}
case 3: {
INDEX(0, 0);
INDEX(1, 1);
INDEX(2, 2);
out = (L(0) + L(1) + L(2)) / m;
out = (L(0) + L(1) + L(2)) * scaling_factor;
break;
}
case 4: {
@ -613,7 +669,7 @@ class SparseSegmentReductionOpBase : public OpKernel {
INDEX(1, 1);
INDEX(2, 2);
INDEX(3, 3);
out = (L(0) + L(1) + L(2) + L(3)) / m;
out = (L(0) + L(1) + L(2) + L(3)) * scaling_factor;
break;
}
case 5: {
@ -622,7 +678,7 @@ class SparseSegmentReductionOpBase : public OpKernel {
INDEX(2, 2);
INDEX(3, 3);
INDEX(4, 4);
out = (L(0) + L(1) + L(2) + L(3) + L(4)) / m;
out = (L(0) + L(1) + L(2) + L(3) + L(4)) * scaling_factor;
break;
}
case 6: {
@ -632,7 +688,7 @@ class SparseSegmentReductionOpBase : public OpKernel {
INDEX(3, 3);
INDEX(4, 4);
INDEX(5, 5);
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m;
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) * scaling_factor;
break;
}
case 7: {
@ -643,7 +699,8 @@ class SparseSegmentReductionOpBase : public OpKernel {
INDEX(4, 4);
INDEX(5, 5);
INDEX(6, 6);
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m;
out =
(L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) * scaling_factor;
break;
}
case 0: {
@ -655,7 +712,8 @@ class SparseSegmentReductionOpBase : public OpKernel {
INDEX(5, 5);
INDEX(6, 6);
INDEX(7, 7);
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m;
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) *
scaling_factor;
r = 8;
break;
}
@ -669,8 +727,8 @@ class SparseSegmentReductionOpBase : public OpKernel {
INDEX(6, 6);
INDEX(7, 7);
INDEX(8, 8);
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) /
m;
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) *
scaling_factor;
r = 9;
break;
}
@ -687,10 +745,10 @@ class SparseSegmentReductionOpBase : public OpKernel {
out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
}
if (is_mean_ && num >= 10) {
out = out / static_cast<T>(num);
out = out / static_cast<Tout>(num);
}
if (is_sqrtn_ && num >= 10) {
out = out / static_cast<T>(sqrt(num));
out = out / static_cast<Tout>(sqrt(num));
}
}

View File

@ -64,6 +64,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
segment_ids_type>);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
#undef REGISTER_CPU_SPARSE_KERNELS
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
@ -85,6 +86,7 @@ REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
CPUDevice, type, index_type, segment_ids_type>);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
#undef REGISTER_CPU_SPARSE_KERNELS
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \

View File

@ -21,9 +21,7 @@ limitations under the License.
#include "tensorflow/core/kernels/strided_slice_op_gpu_impl.h"
namespace tensorflow {
TF_CALL_int8(DEFINE_GPU_KERNELS);
TF_CALL_int32(DEFINE_GPU_KERNELS);
TF_CALL_int64(DEFINE_GPU_KERNELS);
TF_CALL_INTEGRAL_TYPES(DEFINE_GPU_KERNELS);
} // end namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -25,6 +25,19 @@ REGISTER_OP("BatchFunction")
.Output("out_tensors: Tout")
.Attr("f: func")
.Attr("num_batch_threads: int")
// 'max_batch_size' denotes the maximum acceptable batch size; inputs with a
// larger batch size are simply invalidated.
// By default, 'max_batch_size' must be equal to the max value of
// 'allowed_batch_sizes'.
// By setting 'enable_large_batch_splitting' (attribute below) to true,
// 'max_batch_size' can be greater than or equal to the max value of
// 'allowed_batch_sizes'; in other words:
// 1) an input with size > 'max_batch_size' is still invalidated.
// 2) an input with
//    a) size <= 'max_batch_size', and
//    b) size > the max value of 'allowed_batch_sizes'
//    will automatically be split into multiple batches (with batch sizes
//    drawn from 'allowed_batch_sizes'), executed, and re-composed (into the
//    final output).
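// For illustration (hypothetical values): with allowed_batch_sizes = [2, 4],
// max_batch_size = 10, and splitting enabled, a batch of 12 is rejected, a
// batch of 3 is padded up to 4 as before, and a batch of 10 is split into
// sub-batches with sizes from 'allowed_batch_sizes' (e.g. 4 + 4 + 2) before
// execution.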
.Attr("max_batch_size: int")
.Attr("batch_timeout_micros: int")
.Attr("max_enqueued_batches: int = 10")
@ -35,6 +48,12 @@ REGISTER_OP("BatchFunction")
.Attr("Tin: list(type)")
.Attr("Tcaptured: list(type) >= 0")
.Attr("Tout: list(type)")
// If 'enable_large_batch_splitting' is true, for input batches exceeding
// the largest value in "allowed_batch_sizes", allow the batch to be split
// into multiple batches with batch size within "allowed_batch_sizes".
// NOTE: Support for `enable_large_batch_splitting == true` is still
// in development.
.Attr("enable_large_batch_splitting: bool = false")
// TODO(apassos): Fix this shape inference function. It requires shape
// inference of function calls.
.SetShapeFn(shape_inference::UnknownShape);

View File

@ -82,3 +82,94 @@ op {
minimum: 1
}
}
op {
name: "BatchFunction"
input_arg {
name: "in_tensors"
type_list_attr: "Tin"
}
input_arg {
name: "captured_tensors"
type_list_attr: "Tcaptured"
}
output_arg {
name: "out_tensors"
type_list_attr: "Tout"
}
attr {
name: "f"
type: "func"
}
attr {
name: "num_batch_threads"
type: "int"
}
attr {
name: "max_batch_size"
type: "int"
}
attr {
name: "batch_timeout_micros"
type: "int"
}
attr {
name: "max_enqueued_batches"
type: "int"
default_value {
i: 10
}
}
attr {
name: "allowed_batch_sizes"
type: "list(int)"
default_value {
list {
}
}
}
attr {
name: "container"
type: "string"
default_value {
s: ""
}
}
attr {
name: "shared_name"
type: "string"
default_value {
s: ""
}
}
attr {
name: "batching_queue"
type: "string"
default_value {
s: ""
}
}
attr {
name: "Tin"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "Tcaptured"
type: "list(type)"
has_minimum: true
}
attr {
name: "Tout"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "enable_large_batch_splitting"
type: "bool"
default_value {
b: false
}
}
}

View File

@ -53,3 +53,59 @@ op {
}
}
}
op {
name: "SparseSegmentMean"
input_arg {
name: "data"
type_attr: "T"
}
input_arg {
name: "indices"
type_attr: "Tidx"
}
input_arg {
name: "segment_ids"
type_attr: "Tsegmentids"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "Tidx"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "Tsegmentids"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
}

View File

@ -70,3 +70,76 @@ op {
}
}
}
op {
name: "SparseSegmentMeanWithNumSegments"
input_arg {
name: "data"
type_attr: "T"
}
input_arg {
name: "indices"
type_attr: "Tidx"
}
input_arg {
name: "segment_ids"
type_attr: "Tsegmentids"
}
input_arg {
name: "num_segments"
type_attr: "Tnumsegments"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "Tidx"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "Tnumsegments"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "Tsegmentids"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
}

View File

@ -53,3 +53,59 @@ op {
}
}
}
op {
name: "SparseSegmentSqrtN"
input_arg {
name: "data"
type_attr: "T"
}
input_arg {
name: "indices"
type_attr: "Tidx"
}
input_arg {
name: "segment_ids"
type_attr: "Tsegmentids"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "Tidx"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "Tsegmentids"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
}

View File

@ -70,3 +70,76 @@ op {
}
}
}
op {
name: "SparseSegmentSqrtNWithNumSegments"
input_arg {
name: "data"
type_attr: "T"
}
input_arg {
name: "indices"
type_attr: "Tidx"
}
input_arg {
name: "segment_ids"
type_attr: "Tsegmentids"
}
input_arg {
name: "num_segments"
type_attr: "Tnumsegments"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "Tidx"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "Tnumsegments"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "Tsegmentids"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
}

View File

@ -299,4 +299,10 @@ REGISTER_OP("FakeParam")
return Status::OK();
});
// Returns the device index.
REGISTER_OP("DeviceIndex")
.Output("index: int32")
.Attr("device_names: list(string)")
.SetShapeFn(shape_inference::ScalarShape);
} // end namespace tensorflow

View File

@ -1337,7 +1337,7 @@ REGISTER_OP("SparseSegmentMean")
.Input("indices: Tidx")
.Input("segment_ids: Tsegmentids")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn);
@ -1348,7 +1348,7 @@ REGISTER_OP("SparseSegmentMeanWithNumSegments")
.Input("segment_ids: Tsegmentids")
.Input("num_segments: Tnumsegments")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
@ -1370,7 +1370,7 @@ REGISTER_OP("SparseSegmentSqrtN")
.Input("indices: Tidx")
.Input("segment_ids: Tsegmentids")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn);
@ -1381,7 +1381,7 @@ REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
.Input("segment_ids: Tsegmentids")
.Input("num_segments: Tnumsegments")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")

View File

@ -3553,6 +3553,13 @@ op {
has_minimum: true
minimum: 1
}
attr {
name: "enable_large_batch_splitting"
type: "bool"
default_value {
b: false
}
}
}
op {
name: "BatchIFFT"
@ -46097,6 +46104,7 @@ op {
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@ -46215,6 +46223,7 @@ op {
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@ -46283,6 +46292,7 @@ op {
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@ -46401,6 +46411,7 @@ op {
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}

View File

@ -58,6 +58,15 @@ TEST(StateOpsTest, ScatterUpdate_ShapeFn) {
// Resolve shape on first updates dimension.
INFER_OK(op, "[1,2];[3];[?,2]", "in0");
// Allow the update to be a scalar.
INFER_OK(op, "[1,2];[3];?", "in0");
// Allow a scalar index.
INFER_OK(op, "[1,2];[];[2]", "in0");
// Check the requirement updates.shape = indices.shape + ref.shape[1:].
INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op, "[2];[];[2]");
}
TEST(StateOpsTest, TemporaryVariable_ShapeFn) {

View File

@ -23,7 +23,7 @@ const char* kProtobufUint64Typename = "::tensorflow::protobuf_uint64";
TStringOutputStream::TStringOutputStream(tstring* target) : target_(target) {}
bool TStringOutputStream::Next(void** data, int* size) {
int old_size = target_->size();
size_t old_size = target_->size();
// Grow the string.
if (old_size < target_->capacity()) {
@ -32,16 +32,16 @@ bool TStringOutputStream::Next(void** data, int* size) {
target_->resize_uninitialized(target_->capacity());
} else {
// Size has reached capacity, try to double the size.
if (old_size > std::numeric_limits<int>::max() / 2) {
if (old_size > std::numeric_limits<size_t>::max() / 2) {
// Cannot double the size; otherwise it will cause integer
// overflow in the expression below: old_size * 2.
return false;
}
// Double the size, also make sure that the new size is at least
// kMinimumSize.
target_->resize_uninitialized(
std::max(old_size * 2,
kMinimumSize + 0)); // "+ 0" works around GCC4 weirdness.
target_->resize_uninitialized(std::max(
old_size * 2,
(size_t)kMinimumSize + 0)); // "+ 0" works around GCC4 weirdness.
}
*data = target_->data() + old_size;

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <aws/core/utils/StringUtils.h>
#include <aws/core/utils/logging/AWSLogging.h>
#include <aws/core/utils/logging/LogSystemInterface.h>
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
#include <aws/s3/S3Client.h>
#include <aws/s3/S3Errors.h>
#include <aws/s3/model/AbortMultipartUploadRequest.h>
@ -58,10 +59,16 @@ static const char* kS3TempFileTemplate = "/tmp/s3_filesystem_XXXXXX";
static const char* kS3FileSystemAllocationTag = "S3FileSystemAllocation";
static const size_t kS3ReadAppendableFileBufferSize = 1024 * 1024;
static const int64 kS3TimeoutMsec = 300000; // 5 min
static const uint64 kS3MultiPartCopyPartSize = 50 * 1024 * 1024; // 50MB
static const uint64 kS3MultiPartUploadChunkSize = 50 * 1024 * 1024; // 50 MB
static const uint64 kS3MultiPartDownloadChunkSize = 2 * 1024 * 1024; // 2 MB
static const int kS3GetChildrenMaxKeys = 100;
static const int kExecutorPoolSize = 5;
static const int kUploadRetries = 5;
// With this change, multiple threads are used in a single download.
// The thread pool size is increased since multiple downloads
// and uploads can occur in parallel.
static const int kExecutorPoolSize = 25;
static const int kUploadRetries = 3;
static const int kDownloadRetries = 3;
static const char* kExecutorTag = "TransferManagerExecutor";
Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
@ -223,9 +230,16 @@ static Status CreateStatusFromAwsError(
class S3RandomAccessFile : public RandomAccessFile {
public:
S3RandomAccessFile(const string& bucket, const string& object,
std::shared_ptr<Aws::S3::S3Client> s3_client)
: bucket_(bucket), object_(object), s3_client_(s3_client) {}
S3RandomAccessFile(
const string& bucket, const string& object,
const bool use_multi_part_download,
std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager,
std::shared_ptr<Aws::S3::S3Client> s3_client)
: bucket_(bucket),
object_(object),
use_multi_part_download_(use_multi_part_download),
transfer_manager_(transfer_manager),
s3_client_(s3_client) {}
Status Name(StringPiece* result) const override {
return errors::Unimplemented("S3RandomAccessFile does not support Name()");
@ -235,6 +249,66 @@ class S3RandomAccessFile : public RandomAccessFile {
char* scratch) const override {
VLOG(1) << "ReadFilefromS3 s3://" << bucket_ << "/" << object_ << " from "
<< offset << " for n:" << n;
if (use_multi_part_download_) {
return ReadS3TransferManager(offset, n, result, scratch);
} else {
return ReadS3Client(offset, n, result, scratch);
}
}
Status ReadS3TransferManager(uint64 offset, size_t n, StringPiece* result,
char* scratch) const {
VLOG(3) << "Using TransferManager";
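// The stream factory below hands the TransferManager an iostream backed by
// the caller-provided 'scratch' buffer, so downloaded bytes are written
// directly into the destination without an extra copy.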
auto create_stream_fn = [&]() { // create stream lambda fn
return Aws::New<TFS3UnderlyingStream>(
"S3ReadStream",
Aws::New<Aws::Utils::Stream::PreallocatedStreamBuf>(
"S3ReadStream", reinterpret_cast<unsigned char*>(scratch), n));
};
VLOG(3) << "Created stream to read with transferManager";
std::shared_ptr<Aws::Transfer::TransferHandle> handle =
transfer_manager_.get()->DownloadFile(bucket_.c_str(), object_.c_str(),
offset, n, create_stream_fn);
handle->WaitUntilFinished();
// todo change this
int retries = 0;
while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
handle->GetLastError().GetResponseCode() !=
Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE &&
retries++ < kDownloadRetries) {
// only failed parts will be downloaded again
VLOG(1) << "Retrying read of s3://" << bucket_ << "/" << object_
<< " after failure. Current retry count:" << retries;
transfer_manager_.get()->RetryDownload(handle);
handle->WaitUntilFinished();
}
if (handle->GetStatus() != Aws::Transfer::TransferStatus::COMPLETED) {
auto error = handle->GetLastError();
if (error.GetResponseCode() ==
Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE) {
// expected when end of file is reached
n = 0;
*result = StringPiece(scratch, n);
return Status(error::OUT_OF_RANGE, "Read less bytes than requested");
}
return CreateStatusFromAwsError(error);
} else {
n = handle->GetBytesTotalSize();
*result = StringPiece(scratch, handle->GetBytesTransferred());
return Status::OK();
}
}
Status ReadS3Client(uint64 offset, size_t n, StringPiece* result,
char* scratch) const {
VLOG(3) << "ReadFile using S3Client s3://" << bucket_ << "/" << object_;
Aws::S3::Model::GetObjectRequest getObjectRequest;
getObjectRequest.WithBucket(bucket_.c_str()).WithKey(object_.c_str());
string bytes = strings::StrCat("bytes=", offset, "-", offset + n - 1);
@ -242,6 +316,7 @@ class S3RandomAccessFile : public RandomAccessFile {
getObjectRequest.SetResponseStreamFactory([]() {
return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag);
});
auto getObjectOutcome = this->s3_client_->GetObject(getObjectRequest);
if (!getObjectOutcome.IsSuccess()) {
auto error = getObjectOutcome.GetError();
@ -252,18 +327,21 @@ class S3RandomAccessFile : public RandomAccessFile {
return Status(error::OUT_OF_RANGE, "Read less bytes than requested");
}
return CreateStatusFromAwsError(error);
}
n = getObjectOutcome.GetResult().GetContentLength();
getObjectOutcome.GetResult().GetBody().read(scratch, n);
} else {
n = getObjectOutcome.GetResult().GetContentLength();
getObjectOutcome.GetResult().GetBody().read(scratch, n);
*result = StringPiece(scratch, n);
return Status::OK();
*result = StringPiece(scratch, n);
return Status::OK();
}
}
private:
string bucket_;
string object_;
std::shared_ptr<Aws::S3::S3Client> s3_client_;
std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager_;
bool use_multi_part_download_;
};
class S3WritableFile : public WritableFile {
@ -375,16 +453,53 @@ class S3ReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
S3FileSystem::S3FileSystem()
: s3_client_(nullptr, ShutdownClient),
initialization_lock_(),
transfer_manager_(nullptr, ShutdownTransferManager),
executor_(nullptr, ShutdownExecutor) {
const char* part_size_str = getenv("S3_MULTI_PART_COPY_PART_SIZE");
multi_part_copy_part_size_ = kS3MultiPartCopyPartSize;
const char* part_size_str = getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE");
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD] =
kS3MultiPartUploadChunkSize;
if (part_size_str) {
uint64 part_size_num;
if (strings::safe_strtou64(part_size_str, &part_size_num)) {
multi_part_copy_part_size_ = part_size_num;
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD] =
part_size_num;
}
}
// Different TensorFlow APIs call the download API with different buffer
// sizes. Download performance depends on that size and this chunk size.
part_size_str = getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE");
multi_part_chunk_size_[Aws::Transfer::TransferDirection::DOWNLOAD] =
kS3MultiPartDownloadChunkSize;
if (part_size_str) {
uint64 part_size_num;
if (strings::safe_strtou64(part_size_str, &part_size_num)) {
multi_part_chunk_size_[Aws::Transfer::TransferDirection::DOWNLOAD] =
part_size_num;
}
}
use_multi_part_download_ = true;
const char* disable_transfer_mgr = getenv("S3_DISABLE_MULTI_PART_DOWNLOAD");
if (disable_transfer_mgr) {
if (disable_transfer_mgr[0] == '1') {
use_multi_part_download_ = false;
}
}
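// Usage example (hypothetical values): exporting
// S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE=8388608 downloads each file in 8 MB parts
// through the TransferManager, while S3_DISABLE_MULTI_PART_DOWNLOAD=1 falls
// back to single GetObject reads.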
auto upload_pair = std::pair<Aws::Transfer::TransferDirection,
std::shared_ptr<Aws::Transfer::TransferManager>>(
Aws::Transfer::TransferDirection::UPLOAD,
std::shared_ptr<Aws::Transfer::TransferManager>(nullptr,
ShutdownTransferManager));
auto download_pair =
std::pair<Aws::Transfer::TransferDirection,
std::shared_ptr<Aws::Transfer::TransferManager>>(
Aws::Transfer::TransferDirection::DOWNLOAD,
std::shared_ptr<Aws::Transfer::TransferManager>(
nullptr, ShutdownTransferManager));
this->transfer_managers_.insert(upload_pair);
this->transfer_managers_.insert(download_pair);
}
S3FileSystem::~S3FileSystem() {}
@ -424,20 +539,22 @@ std::shared_ptr<Aws::S3::S3Client> S3FileSystem::GetS3Client() {
}
std::shared_ptr<Aws::Transfer::TransferManager>
S3FileSystem::GetTransferManager() {
S3FileSystem::GetTransferManager(
const Aws::Transfer::TransferDirection& direction) {
std::shared_ptr<Aws::S3::S3Client> s3_client = this->GetS3Client();
std::lock_guard<mutex> lock(this->initialization_lock_);
if (this->transfer_manager_.get() == nullptr) {
if (this->transfer_managers_[direction].get() == nullptr) {
Aws::Transfer::TransferManagerConfiguration config(
this->GetExecutor().get());
config.s3Client = s3_client;
config.bufferSize = this->multi_part_copy_part_size_;
// must be larger than pool size * multi_part_copy_part_size
config.bufferSize = this->multi_part_chunk_size_[direction];
// must be larger than pool size * multi part chunk size
config.transferBufferMaxHeapSize =
(kExecutorPoolSize + 1) * this->multi_part_copy_part_size_;
this->transfer_manager_ = Aws::Transfer::TransferManager::Create(config);
(kExecutorPoolSize + 1) * this->multi_part_chunk_size_[direction];
this->transfer_managers_[direction] =
Aws::Transfer::TransferManager::Create(config);
}
return this->transfer_manager_;
return this->transfer_managers_[direction];
}
std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor>
@ -452,9 +569,21 @@ S3FileSystem::GetExecutor() {
Status S3FileSystem::NewRandomAccessFile(
const string& fname, std::unique_ptr<RandomAccessFile>* result) {
return NewRandomAccessFile(fname, result, true);
}
Status S3FileSystem::NewRandomAccessFile(
const string& fname, std::unique_ptr<RandomAccessFile>* result,
bool use_multi_part_download) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
result->reset(new S3RandomAccessFile(bucket, object, this->GetS3Client()));
// Check whether an override was defined for this file; used for testing.
bool use_mpd = this->use_multi_part_download_ && use_multi_part_download;
result->reset(new S3RandomAccessFile(
bucket, object, use_mpd,
this->GetTransferManager(Aws::Transfer::TransferDirection::DOWNLOAD),
this->GetS3Client()));
return Status::OK();
}
@ -462,8 +591,11 @@ Status S3FileSystem::NewWritableFile(const string& fname,
std::unique_ptr<WritableFile>* result) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
result->reset(new S3WritableFile(bucket, object, this->GetTransferManager(),
this->GetS3Client()));
result->reset(new S3WritableFile(
bucket, object,
this->GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD),
this->GetS3Client()));
return Status::OK();
}
@ -478,8 +610,10 @@ Status S3FileSystem::NewAppendableFile(const string& fname,
string bucket, object;
TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
result->reset(new S3WritableFile(bucket, object, this->GetTransferManager(),
this->GetS3Client()));
result->reset(new S3WritableFile(
bucket, object,
this->GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD),
this->GetS3Client()));
while (true) {
status = reader->Read(offset, kS3ReadAppendableFileBufferSize, &read_chunk,
@ -773,10 +907,13 @@ Status S3FileSystem::CopyFile(const Aws::String& source_bucket,
TF_RETURN_IF_ERROR(
this->GetFileSize(string(source_full_path.c_str()), &file_length));
int num_parts;
if (file_length <= multi_part_copy_part_size_) {
if (file_length <=
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD]) {
num_parts = 1;
} else {
num_parts = ceil((float)file_length / multi_part_copy_part_size_);
num_parts =
ceil((float)file_length /
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD]);
}
if (num_parts == 1) {
@ -786,7 +923,8 @@ Status S3FileSystem::CopyFile(const Aws::String& source_bucket,
"MultiPartCopy with number of parts more than 10000 is not supported. "
"Your object ",
source, " required ", num_parts,
" as multi_part_copy_part_size is set to ", multi_part_copy_part_size_,
" as multi_part_copy_part_size is set to ",
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD],
". You can control this part size using the environment variable ",
"S3_MULTI_PART_COPY_PART_SIZE to increase it.");
return tensorflow::errors::Unimplemented(message);
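For reference, the part-count arithmetic above boils down to a single ceiling division. A self-contained sketch with made-up sizes (the helper name is illustrative only):

#include <cmath>
#include <cstdint>

// Mirrors the num_parts computation in CopyFile above.
int NumCopyParts(uint64_t file_length, uint64_t upload_chunk) {
  if (file_length <= upload_chunk) return 1;
  return static_cast<int>(
      std::ceil(static_cast<double>(file_length) / upload_chunk));
}
// Example: with a 50 MB chunk, a 120 MB object needs ceil(120 / 50) = 3 parts;
// the copy is rejected only when the count exceeds 10000.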
@ -831,7 +969,9 @@ Status S3FileSystem::MultiPartCopy(const Aws::String& source,
Aws::String uploadID = multipartUploadOutcome.GetResult().GetUploadId();
VLOG(1) << "Copying from " << source << " in " << num_parts
<< " parts of size " << multi_part_copy_part_size_ << " each";
<< " parts of size "
<< multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD]
<< " each";
Aws::S3::Model::CompletedMultipartUpload completedMPURequest;
// passed to each callback keyed by partNumber
@ -859,8 +999,12 @@ Status S3FileSystem::MultiPartCopy(const Aws::String& source,
for (std::map<int, PartState>::iterator it = incompletePartStates.begin();
it != incompletePartStates.end(); it++) {
int partNumber = it->first;
uint64 startPos = (partNumber - 1) * multi_part_copy_part_size_;
uint64 endPos = startPos + kS3MultiPartCopyPartSize - 1;
uint64 startPos =
(partNumber - 1) *
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD];
uint64 endPos =
startPos +
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD] - 1;
if (endPos >= file_length) {
endPos = file_length - 1;
}

View File

@ -52,6 +52,10 @@ class S3FileSystem : public FileSystem {
Status NewRandomAccessFile(
const string& fname, std::unique_ptr<RandomAccessFile>* result) override;
Status NewRandomAccessFile(const string& fname,
std::unique_ptr<RandomAccessFile>* result,
bool use_multi_part_download);
Status NewWritableFile(const string& fname,
std::unique_ptr<WritableFile>* result) override;
@ -101,8 +105,12 @@ class S3FileSystem : public FileSystem {
std::shared_ptr<Aws::S3::S3Client> s3_client_;
// Returns the member transfer manager, initializing as-needed.
std::shared_ptr<Aws::Transfer::TransferManager> GetTransferManager();
std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager_;
std::shared_ptr<Aws::Transfer::TransferManager> GetTransferManager(
const Aws::Transfer::TransferDirection& direction);
void InitializeTransferManagers();
std::map<Aws::Transfer::TransferDirection,
std::shared_ptr<Aws::Transfer::TransferManager> >
transfer_managers_;
// Returns the member executor for transfer manager, initializing as-needed.
std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor> GetExecutor();
@ -132,8 +140,10 @@ class S3FileSystem : public FileSystem {
// Lock held when checking for s3_client_ and transfer_manager_ initialization
mutex initialization_lock_;
// size to split objects during multipart copy
uint64 multi_part_copy_part_size_;
// size to split objects during multipart upload/download/copy
std::map<Aws::Transfer::TransferDirection, uint64> multi_part_chunk_size_;
bool use_multi_part_download_;
};
/// S3 implementation of a file system with retry on failures.
@ -147,6 +157,16 @@ class RetryingS3FileSystem : public RetryingFileSystem<S3FileSystem> {
)) {}
};
// AWS streams destroy the buffer (buf) they are given, so we define a new
// IOStream that retains the buffer, letting the calling function control
// its lifecycle.
class TFS3UnderlyingStream : public Aws::IOStream {
public:
using Base = Aws::IOStream;
TFS3UnderlyingStream(std::streambuf* buf) : Base(buf) {}
virtual ~TFS3UnderlyingStream() = default;
};
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_S3_S3_FILE_SYSTEM_H_

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/platform/s3/s3_file_system.h"
#include <time.h>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/path.h"
@ -62,6 +64,96 @@ class S3FileSystemTest : public ::testing::Test {
return Status::OK();
}
Status ReadAllInChunks(const string& fname, string* content,
bool use_multi_part_download = true) {
std::unique_ptr<RandomAccessFile> reader;
TF_RETURN_IF_ERROR(
s3fs.NewRandomAccessFile(fname, &reader, use_multi_part_download));
uint64 file_size = 0;
TF_RETURN_IF_ERROR(s3fs.GetFileSize(fname, &file_size));
content->resize(file_size);
uint64 buffer_size = 16 * 1024 * 1024;
std::size_t part_count = (std::max)(
static_cast<size_t>((file_size + buffer_size - 1) / buffer_size),
static_cast<std::size_t>(1));
VLOG(1) << "buffersize:" << buffer_size << " file_size:" << file_size
<< " part_count=" << part_count;
std::unique_ptr<char[]> buffer{new char[buffer_size]};
std::stringstream ss;
uint64 offset = 0;
uint64 result_size = 0;
using namespace std::chrono;
auto start = high_resolution_clock::now();
for (int i = 0; i < part_count; i++) {
StringPiece result;
offset = i * buffer_size;
TF_RETURN_IF_ERROR(
reader->Read(offset, buffer_size, &result, buffer.get()));
if (result.size() != 0) {
ss.write(result.data(), result.size());
result_size += result.size();
}
if (result_size == file_size) {
break;
}
if (result.size() != buffer_size) {
VLOG(1) << "Result size and buffer size did not match";
if (result.empty()) {
return errors::OutOfRange("eof");
} else {
return errors::DataLoss("truncated record at ", offset);
}
}
}
if (file_size != result_size) {
return errors::DataLoss("expected ", file_size, " got ", result_size,
" bytes");
}
auto stop = high_resolution_clock::now();
duration<double> time_taken = duration_cast<duration<double>>(stop - start);
VLOG(1) << "Time Taken"
<< " : " << time_taken.count() << "seconds";
memcpy((char*)(content->data()), ss.str().data(),
static_cast<size_t>(file_size));
return Status::OK();
}
Status ReadLargeFile() {
// const string fname = TmpDir("train-00001-of-01024");
auto large_file_name = getenv("LARGE_DOWNLOAD_FILE_NAME");
const string fname = TmpDir(large_file_name);
string content_xfer;
string content_s3client;
// Read using Chunked Transfer Manager
VLOG(1) << "Using transfer manager";
TF_RETURN_IF_ERROR(ReadAllInChunks(fname, &content_xfer));
VLOG(1) << "Without transfer manager";
// Read using old S3 API and see if the contents match with TransferManager
TF_RETURN_IF_ERROR(ReadAllInChunks(fname, &content_s3client, false));
if (content_xfer == content_s3client) {
return Status::OK();
} else {
VLOG(1) << "ReadLargeFile contents DO NOT match";
return Status(error::OUT_OF_RANGE, "ReadLargeFile contents DO NOT match");
}
}
S3FileSystem s3fs;
};
@ -236,5 +328,9 @@ TEST_F(S3FileSystemTest, HasAtomicMove) {
EXPECT_EQ(has_atomic_move, false);
}
TEST_F(S3FileSystemTest, NewRandomAccessBigFile) {
TF_EXPECT_OK(ReadLargeFile());
}
} // namespace
} // namespace tensorflow

View File

@ -63,13 +63,11 @@ void HostOpMetricsDbBuilder::UpdateHostInfeedEnqInfo(
start_timestamp_ps_diff);
}
void DeviceOpMetricsDbBuilder::EnterOp(uint64 program_id,
absl::string_view name,
absl::string_view category,
absl::string_view provenance,
bool is_eager, uint64 occurrences,
uint64 time_ps, uint64 children_time_ps,
int64 flops, int64 bytes_accessed) {
void DeviceOpMetricsDbBuilder::EnterOp(
uint64 program_id, absl::string_view name, absl::string_view category,
absl::string_view provenance, bool is_eager, uint64 occurrences,
uint64 time_ps, uint64 children_time_ps, int64 flops, int64 bytes_accessed,
const std::vector<OpMetrics::MemoryAccessed>& memory_accessed_breakdown) {
uint64 self_time_ps = time_ps - children_time_ps;
DCHECK_GE(time_ps, self_time_ps);
OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(program_id, name);
@ -89,6 +87,9 @@ void DeviceOpMetricsDbBuilder::EnterOp(uint64 program_id,
op_metrics->bytes_accessed() +
GetCappedPerf(bytes_accessed * occurrences, self_time_ps,
peak_hbm_bw_giga_bytes_per_second_ / 1000));
for (const auto& memory_accessed : memory_accessed_breakdown) {
*op_metrics->add_memory_accessed_breakdown() = memory_accessed;
}
db()->set_total_op_time_ps(db()->total_op_time_ps() + self_time_ps);
}

View File

@ -69,10 +69,14 @@ class DeviceOpMetricsDbBuilder : public OpMetricsDbBuilder {
// picoseconds.
// flops = the number of floating-point operations computed.
// bytes_accessed = the sum of bytes read and bytes written by this OP.
// memory_accessed_breakdown = the breakdown of memory accessed by operation
// type and memory space.
void EnterOp(uint64 program_id, absl::string_view name,
absl::string_view category, absl::string_view provenance,
bool is_eager, uint64 occurrences, uint64 time_ps,
uint64 children_time_ps, int64 flops, int64 bytes_accessed);
uint64 children_time_ps, int64 flops, int64 bytes_accessed,
const std::vector<OpMetrics::MemoryAccessed>&
memory_accessed_breakdown = {});
protected:
// Peak performance of a TensorCore or a GPU in TFLOP/s.
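A hedged usage sketch of the extended signature; `builder` stands for an already-constructed DeviceOpMetricsDbBuilder and every numeric value below is made up:

// Record one eager MatMul occurrence; the new breakdown argument defaults to
// an empty vector...
builder.EnterOp(/*program_id=*/1, /*name=*/"MatMul", /*category=*/"MatMul",
                /*provenance=*/"", /*is_eager=*/true, /*occurrences=*/1,
                /*time_ps=*/5000, /*children_time_ps=*/0, /*flops=*/2048,
                /*bytes_accessed=*/4096);
// ...so the call above is equivalent to passing it explicitly:
builder.EnterOp(1, "MatMul", "MatMul", "", true, 1, 5000, 0, 2048, 4096, {});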

View File

@ -108,7 +108,7 @@ limitations under the License.
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
#define TF_GRAPH_DEF_VERSION 414 // Updated: 2020/5/27
#define TF_GRAPH_DEF_VERSION 415 // Updated: 2020/5/28
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//

View File

@ -474,7 +474,9 @@ inline SparseTensor SparseTensor::Concat(
const int st_num_entries = st.num_entries();
// Fill in indices & values.
std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset));
if (st_num_entries > 0) {
std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset));
}
const auto* st_ix = &st.ix_.matrix<int64>()(0, 0);
auto* ix_out = &ix_t(offset, 0);

View File

@ -252,6 +252,7 @@ cc_library(
":version",
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/api",
"//tensorflow/lite/delegates:status",
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
"//tensorflow/lite/experimental/resource",
"//tensorflow/lite/kernels/internal:compatibility",

View File

@ -680,6 +680,9 @@ def gen_model_coverage_test(src, model_name, data, failure_type, tags, size = "m
if failure_type[i] != "none":
args.append("--failure_type=%s" % failure_type[i])
i = i + 1
# Avoid coverage timeouts for large/enormous tests.
coverage_tags = ["nozapfhahn"] if size in ["large", "enormous"] else []
native.py_test(
name = "model_coverage_test_%s_%s" % (model_name, target_op_sets.lower().replace(",", "_")),
srcs = [src],
@ -696,7 +699,7 @@ def gen_model_coverage_test(src, model_name, data, failure_type, tags, size = "m
"no_gpu", # Executing with TF GPU configurations is redundant.
"no_oss",
"no_windows",
] + tags,
] + tags + coverage_tags,
deps = [
"//tensorflow/lite/testing/model_coverage:model_coverage_lib",
"//tensorflow/lite/python:lite",

View File

@ -22,34 +22,56 @@ namespace tflite {
// A simple utility for enabling profiled event tracing in TensorFlow Lite.
class Profiler {
public:
// Since a given Profiler instance might only be interested in certain event
// types, each event type is assigned a distinct bit value so that a Profiler
// can use bitwise masking to determine whether an event should be recorded.
enum class EventType {
// Default event type, the metadata field has no special significance.
DEFAULT = 0,
DEFAULT = 1,
// The event is an operator invocation and the event_metadata field is the
// index of operator node.
OPERATOR_INVOKE_EVENT = 1,
OPERATOR_INVOKE_EVENT = 2,
// The event is an invocation for an internal operator of a TFLite delegate.
// The event_metadata field is the index of the operator node that is
// specific to the delegate.
DELEGATE_OPERATOR_INVOKE_EVENT = 2
DELEGATE_OPERATOR_INVOKE_EVENT = 4,
// The event is a recording of runtime instrumentation such as the overall
// TFLite runtime status, the TFLite delegate status (if a delegate
// is applied), and the overall model inference latency etc.
// Note, the delegate status and overall status are stored as separate
// event_metadata fields. In particular, the delegate status is encoded
// as DelegateStatus::full_status().
GENERAL_RUNTIME_INSTRUMENTATION_EVENT = 8,
};
virtual ~Profiler() {}
// Signals the beginning of an event from a subgraph indexed at
// 'event_subgraph_index', returning a handle to the profile event.
// Signals the beginning of an event and returns a handle to the profile
// event. The `event_metadata1` and `event_metadata2` have different
// interpretations based on the actual Profiler instance and the `event_type`.
// For example, as for the 'SubgraphAwareProfiler' defined in
// lite/core/subgraph.h, when the event_type is OPERATOR_INVOKE_EVENT,
// `event_metadata1` represents the index of a TFLite node, and
// `event_metadata2` represents the index of the subgraph that this event
// comes from.
virtual uint32_t BeginEvent(const char* tag, EventType event_type,
uint32_t event_metadata,
uint32_t event_subgraph_index) = 0;
// Similar w/ the above, but the event comes from the primary subgraph that's
// indexed at 0.
virtual uint32_t BeginEvent(const char* tag, EventType event_type,
uint32_t event_metadata) {
return BeginEvent(tag, event_type, event_metadata, /*primary subgraph*/ 0);
int64_t event_metadata1,
int64_t event_metadata2) = 0;
// Similar to the above, but `event_metadata2` defaults to 0.
uint32_t BeginEvent(const char* tag, EventType event_type,
int64_t event_metadata) {
return BeginEvent(tag, event_type, event_metadata, /*event_metadata2*/ 0);
}
// Signals an end to the specified profile event together with 'event_metadata's.
// This is useful when the 'event_metadata's are not available when the event
// begins, or when one wants to overwrite those set at the beginning.
virtual void EndEvent(uint32_t event_handle, int64_t event_metadata1,
int64_t event_metadata2) {}
// Signals an end to the specified profile event.
virtual void EndEvent(uint32_t event_handle) = 0;
@ -60,15 +82,18 @@ class Profiler {
// they assume the value is in microseconds ("usec"); if a subclass does not
// record usec, the values are not meaningful.
// TODO(karimnosseir): Revisit and make the function clearer.
virtual void AddEvent(const char* tag, EventType event_type,
uint32_t event_metadata, uint64_t start, uint64_t end) {
AddEvent(tag, event_type, event_metadata, start, end,
/*event_subgraph_index*/ 0);
void AddEvent(const char* tag, EventType event_type, uint64_t start,
uint64_t end, int64_t event_metadata) {
AddEvent(tag, event_type, start, end, event_metadata,
/*event_metadata2*/ 0);
}
virtual void AddEvent(const char* tag, EventType event_type,
uint32_t event_metadata, uint64_t start, uint64_t end,
uint32_t event_subgraph_index) {}
virtual void AddEvent(const char* tag, EventType event_type, uint64_t start,
uint64_t end, int64_t event_metadata1,
int64_t event_metadata2) {}
protected:
friend class ScopedProfile;
};
// Adds a profile event to `profiler` that begins with the construction
@ -79,7 +104,7 @@ class ScopedProfile {
public:
ScopedProfile(Profiler* profiler, const char* tag,
Profiler::EventType event_type = Profiler::EventType::DEFAULT,
uint32_t event_metadata = 0)
int64_t event_metadata = 0)
: profiler_(profiler), event_handle_(0) {
if (profiler) {
event_handle_ = profiler_->BeginEvent(tag, event_type, event_metadata);
@ -92,8 +117,8 @@ class ScopedProfile {
}
}
private:
Profiler* const profiler_;
protected:
Profiler* profiler_;
uint32_t event_handle_;
};
@ -113,6 +138,31 @@ class ScopedDelegateOperatorProfile : public ScopedProfile {
static_cast<uint32_t>(node_index)) {}
};
class ScopedRuntimeInstrumentationProfile : public ScopedProfile {
public:
ScopedRuntimeInstrumentationProfile(Profiler* profiler, const char* tag)
: ScopedProfile(
profiler, tag,
Profiler::EventType::GENERAL_RUNTIME_INSTRUMENTATION_EVENT, -1) {}
void set_runtime_status(int64_t delegate_status, int64_t interpreter_status) {
if (profiler_) {
delegate_status_ = delegate_status;
interpreter_status_ = interpreter_status;
}
}
~ScopedRuntimeInstrumentationProfile() {
if (profiler_) {
profiler_->EndEvent(event_handle_, delegate_status_, interpreter_status_);
}
}
private:
// Default to zero so the destructor reports defined values even if
// set_runtime_status() is never called.
int64_t delegate_status_ = 0;
int64_t interpreter_status_ = 0;
};
} // namespace tflite
#define TFLITE_VARNAME_UNIQ_IMPL(name, ctr) name##ctr
@ -130,4 +180,15 @@ class ScopedDelegateOperatorProfile : public ScopedProfile {
tflite::ScopedDelegateOperatorProfile TFLITE_VARNAME_UNIQ( \
_profile_, __COUNTER__)((profiler), (tag), (node_index))
#define TFLITE_ADD_RUNTIME_INSTRUMENTATION_EVENT( \
profiler, tag, delegate_status, interpreter_status) \
do { \
if (profiler) { \
const auto handle = profiler->BeginEvent( \
tag, Profiler::EventType::GENERAL_RUNTIME_INSTRUMENTATION_EVENT, \
delegate_status, interpreter_status); \
profiler->EndEvent(handle); \
} \
} while (false);
#endif // TENSORFLOW_LITE_CORE_API_PROFILER_H_
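A minimal sketch of a Profiler that uses the bitmask-style event values above to record only operator invocations; the class name and handle bookkeeping are illustrative, not part of TFLite:

#include <cstdint>

#include "tensorflow/lite/core/api/profiler.h"

class OpOnlyProfiler : public tflite::Profiler {
 public:
  uint32_t BeginEvent(const char* tag, EventType event_type,
                      int64_t event_metadata1,
                      int64_t event_metadata2) override {
    constexpr uint32_t kMask =
        static_cast<uint32_t>(EventType::OPERATOR_INVOKE_EVENT) |
        static_cast<uint32_t>(EventType::DELEGATE_OPERATOR_INVOKE_EVENT);
    if ((static_cast<uint32_t>(event_type) & kMask) == 0) return 0;  // skipped
    // Per the comments above, event_metadata1 carries the node index and
    // event_metadata2 the subgraph index for operator-invoke events.
    return next_handle_++;
  }
  void EndEvent(uint32_t event_handle) override {}

 private:
  uint32_t next_handle_ = 1;
};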

View File

@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_CORE_SUBGRAPH_H_
#define TENSORFLOW_LITE_CORE_SUBGRAPH_H_
#include <cstdint>
#include <cstdlib>
#include <map>
#include <utility>
@ -338,21 +339,16 @@ class Subgraph {
class SubgraphAwareProfiler : public Profiler {
public:
// Constructor should be called with the non-nullptr profiler argument.
SubgraphAwareProfiler(Profiler* profiler, uint32_t subgraph_index)
SubgraphAwareProfiler(Profiler* profiler, int64_t subgraph_index)
: profiler_(profiler), subgraph_index_(subgraph_index) {}
~SubgraphAwareProfiler() override {}
uint32_t BeginEvent(const char* tag, EventType event_type,
uint32_t event_metadata,
uint32_t subgraph_index) override {
int64_t event_metadata1,
int64_t event_metadata2) override {
if (!profiler_) return 0;
return profiler_->BeginEvent(tag, event_type, event_metadata,
subgraph_index);
}
uint32_t BeginEvent(const char* tag, EventType event_type,
uint32_t event_metadata) override {
return BeginEvent(tag, event_type, event_metadata, subgraph_index_);
return profiler_->BeginEvent(tag, event_type, event_metadata1,
subgraph_index_);
}
void EndEvent(uint32_t event_handle) override {
@ -360,17 +356,24 @@ class Subgraph {
profiler_->EndEvent(event_handle);
}
void AddEvent(const char* tag, EventType event_type,
uint32_t event_metadata, uint64_t start,
uint64_t end) override {
void EndEvent(uint32_t event_handle, int64_t event_metadata1,
int64_t event_metadata2) override {
if (!profiler_) return;
profiler_->AddEvent(tag, event_type, event_metadata, start, end);
profiler_->EndEvent(event_handle, event_metadata1, event_metadata2);
}
void AddEvent(const char* tag, EventType event_type, uint64_t start,
uint64_t end, int64_t event_metadata1,
int64_t event_metadata2) override {
if (!profiler_) return;
profiler_->AddEvent(tag, event_type, start, end, event_metadata1,
subgraph_index_);
}
private:
// Not own the memory.
Profiler* const profiler_;
const uint32_t subgraph_index_;
const int64_t subgraph_index_;
};
// Prevent 'context_' from accessing functions that are only available to

View File

@ -20,6 +20,15 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "status",
hdrs = ["status.h"],
copts = tflite_copts(),
deps = [
"//tensorflow/lite/c:common",
],
)
cc_library(
name = "utils",
srcs = ["utils.cc"],

View File

@ -43,8 +43,10 @@ cc_library(
srcs = ["arguments.cc"],
hdrs = ["arguments.h"],
deps = [
":gpu_object",
":opencl_wrapper",
":util",
"//tensorflow/lite/delegates/gpu/common:access_type",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util",
@ -305,6 +307,16 @@ cc_library(
],
)
cc_library(
name = "gpu_object",
hdrs = ["gpu_object.h"],
deps = [
":opencl_wrapper",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:status",
],
)
cc_library(
name = "inference_context",
srcs = ["inference_context.cc"],

View File

@ -23,10 +23,14 @@ namespace tflite {
namespace gpu {
namespace cl {
namespace {
bool IsWordSymbol(char symbol) {
return absl::ascii_isalnum(symbol) || symbol == '_';
}
std::string GetNextWord(const std::string& code, size_t first_position) {
size_t pos = first_position;
char t = code[pos];
while (absl::ascii_isalnum(t) || t == '_') {
while (IsWordSymbol(t)) {
pos++;
t = code[pos];
}
@ -38,13 +42,19 @@ Arguments::Arguments(Arguments&& args)
: int_values_(std::move(args.int_values_)),
shared_int4s_data_(std::move(args.shared_int4s_data_)),
float_values_(std::move(args.float_values_)),
shared_float4s_data_(std::move(args.shared_float4s_data_)) {}
shared_float4s_data_(std::move(args.shared_float4s_data_)),
buffers_(std::move(args.buffers_)),
images2d_(std::move(args.images2d_)),
objects_(std::move(args.objects_)) {}
Arguments& Arguments::operator=(Arguments&& args) {
if (this != &args) {
int_values_ = std::move(args.int_values_);
shared_int4s_data_ = std::move(args.shared_int4s_data_);
float_values_ = std::move(args.float_values_);
shared_float4s_data_ = std::move(args.shared_float4s_data_);
buffers_ = std::move(args.buffers_);
images2d_ = std::move(args.images2d_);
objects_ = std::move(args.objects_);
}
return *this;
}
@ -55,11 +65,40 @@ void Arguments::AddFloat(const std::string& name, float value) {
void Arguments::AddInt(const std::string& name, int value) {
int_values_[name].value = value;
}
void Arguments::AddBuffer(const std::string& name,
const GPUBufferDescriptor& desc) {
buffers_[name] = desc;
}
void Arguments::AddImage2D(const std::string& name,
const GPUImage2DDescriptor& desc) {
images2d_[name] = desc;
}
void Arguments::AddObject(const std::string& name, GPUObjectPtr&& object) {
objects_[name] = {AccessType::READ, std::move(object)};
}
void Arguments::AddGPUResources(const std::string& name,
const GPUResources& resources) {
for (const auto& r : resources.ints) {
AddInt(absl::StrCat(name, "_", r));
}
for (const auto& r : resources.floats) {
AddFloat(absl::StrCat(name, "_", r));
}
for (const auto& r : resources.buffers) {
AddBuffer(absl::StrCat(name, "_", r.first), r.second);
}
for (const auto& r : resources.images2d) {
AddImage2D(absl::StrCat(name, "_", r.first), r.second);
}
}
absl::Status Arguments::SetInt(const std::string& name, int value) {
auto ii = int_values_.find(name);
if (ii == int_values_.end()) {
return absl::NotFoundError(absl::StrCat("No argument with name - ", name));
return absl::NotFoundError(
absl::StrCat("No int argument with name - ", name));
}
ii->second.value = value;
if (ii->second.active) {
@ -71,7 +110,8 @@ absl::Status Arguments::SetInt(const std::string& name, int value) {
absl::Status Arguments::SetFloat(const std::string& name, float value) {
auto fi = float_values_.find(name);
if (fi == float_values_.end()) {
return absl::NotFoundError(absl::StrCat("No argument with name - ", name));
return absl::NotFoundError(
absl::StrCat("No float argument with name - ", name));
}
fi->second.value = value;
if (fi->second.active) {
@ -80,8 +120,60 @@ absl::Status Arguments::SetFloat(const std::string& name, float value) {
return absl::OkStatus();
}
absl::Status Arguments::SetImage2D(const std::string& name, cl_mem memory) {
auto ti = images2d_.find(name);
if (ti == images2d_.end()) {
return absl::NotFoundError(
absl::StrCat("No image2D argument with name - ", name));
}
ti->second.memory = memory;
return absl::OkStatus();
}
absl::Status Arguments::SetBuffer(const std::string& name, cl_mem memory) {
auto it = buffers_.find(name);
if (it == buffers_.end()) {
return absl::NotFoundError(
absl::StrCat("No buffer argument with name - ", name));
}
it->second.memory = memory;
return absl::OkStatus();
}
absl::Status Arguments::SetGPUResources(
const std::string& name, const GPUResourcesWithValue& resources) {
for (const auto& r : resources.ints) {
RETURN_IF_ERROR(SetInt(absl::StrCat(name, "_", r.first), r.second));
}
for (const auto& r : resources.floats) {
RETURN_IF_ERROR(SetFloat(absl::StrCat(name, "_", r.first), r.second));
}
for (const auto& r : resources.buffers) {
RETURN_IF_ERROR(SetBuffer(absl::StrCat(name, "_", r.first), r.second));
}
for (const auto& r : resources.images2d) {
RETURN_IF_ERROR(SetImage2D(absl::StrCat(name, "_", r.first), r.second));
}
return absl::OkStatus();
}
absl::Status Arguments::TransformToCLCode(std::string* code) {
RETURN_IF_ERROR(AddObjectArgs());
ResolveArgsPass(code);
return absl::OkStatus();
}
std::string Arguments::GetListOfArgs() {
std::string result;
for (auto& t : buffers_) {
const std::string type_name =
t.second.data_type == DataType::FLOAT32 ? "float" : "half";
absl::StrAppend(&result, ",\n __global ", type_name, t.second.element_size,
"* ", t.first);
}
for (auto& t : images2d_) {
absl::StrAppend(&result, ",\n __read_only image2d_t ", t.first);
}
for (int i = 0; i < shared_int4s_data_.size() / 4; ++i) {
absl::StrAppend(&result, ",\n int4 shared_int4_", i);
}
@ -92,6 +184,26 @@ std::string Arguments::GetListOfArgs() {
}
absl::Status Arguments::Bind(cl_kernel kernel, int offset) {
for (auto& t : buffers_) {
const int error_code =
clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
if (error_code != CL_SUCCESS) {
return absl::UnknownError(absl::StrCat(
"Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
"(at index - ", offset, ")"));
}
offset++;
}
for (auto& t : images2d_) {
const int error_code =
clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
if (error_code != CL_SUCCESS) {
return absl::UnknownError(absl::StrCat(
"Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
"(at index - ", offset, ")"));
}
offset++;
}
for (int i = 0; i < shared_int4s_data_.size() / 4; ++i) {
const int error_code = clSetKernelArg(kernel, offset, sizeof(int32_t) * 4,
&shared_int4s_data_[i * 4]);
@ -148,8 +260,8 @@ std::string Arguments::AddActiveArgument(const std::string& arg_name) {
}
void Arguments::ResolveArgsPass(std::string* code) {
std::string result;
constexpr char kPrefix[] = "args.";
std::string result;
size_t position = 0;
size_t next_position = code->find(kPrefix);
while (next_position != std::string::npos) {
@ -168,6 +280,16 @@ void Arguments::ResolveArgsPass(std::string* code) {
shared_float4s_data_.resize(shared_float4s_aligned_size);
}
absl::Status Arguments::AddObjectArgs() {
for (auto& t : objects_) {
AddGPUResources(t.first,
t.second.obj_ptr->GetGPUDescriptor()->GetGPUResources());
RETURN_IF_ERROR(
SetGPUResources(t.first, t.second.obj_ptr->GetGPUResources()));
}
return absl::OkStatus();
}
} // namespace cl
} // namespace gpu
} // namespace tflite
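A hedged sketch of how the pieces above fit together, mirroring the Transpose::Compile change further down in this diff; the kernel fragment, argument name, and helper function are invented for illustration and assume the tflite::gpu::cl namespace:

absl::Status BuildKernelSketch(std::string* out) {
  Arguments args;
  args.AddInt("stride");  // scalars land in the shared int4/float4 slots

  std::string code =
      "__kernel void main_function(__global float4* src_data$0) {\n"
      "  int s = args.stride;\n"  // rewritten by ResolveArgsPass
      "}\n";

  // Materialize object resources and rewrite every "args.<name>" reference,
  // then splice the generated parameter list into the kernel signature.
  RETURN_IF_ERROR(args.TransformToCLCode(&code));
  *out = absl::Substitute(code, args.GetListOfArgs());
  return absl::OkStatus();
}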

View File

@ -20,8 +20,10 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
#include "tensorflow/lite/delegates/gpu/cl/util.h"
#include "tensorflow/lite/delegates/gpu/common/access_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
@ -35,15 +37,21 @@ class Arguments {
Arguments() = default;
void AddFloat(const std::string& name, float value = 0.0f);
void AddInt(const std::string& name, int value = 0);
void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);
void AddObject(const std::string& name, GPUObjectPtr&& object);
absl::Status SetInt(const std::string& name, int value);
absl::Status SetFloat(const std::string& name, float value);
absl::Status SetImage2D(const std::string& name, cl_mem memory);
absl::Status SetBuffer(const std::string& name, cl_mem memory);
std::string GetListOfArgs();
absl::Status Bind(cl_kernel kernel, int offset);
void ResolveArgsPass(std::string* code);
absl::Status TransformToCLCode(std::string* code);
// Move only
Arguments(Arguments&& args);
@ -53,6 +61,14 @@ class Arguments {
private:
std::string AddActiveArgument(const std::string& arg_name);
void AddGPUResources(const std::string& name, const GPUResources& resources);
absl::Status SetGPUResources(const std::string& name,
const GPUResourcesWithValue& resources);
absl::Status AddObjectArgs();
void ResolveArgsPass(std::string* code);
struct IntValue {
int value;
@ -79,6 +95,15 @@ class Arguments {
};
std::map<std::string, FloatValue> float_values_;
std::vector<float> shared_float4s_data_;
std::map<std::string, GPUBufferDescriptor> buffers_;
std::map<std::string, GPUImage2DDescriptor> images2d_;
struct ObjectArg {
AccessType access_type;
GPUObjectPtr obj_ptr;
};
std::map<std::string, ObjectArg> objects_;
};
} // namespace cl

View File

@ -0,0 +1,121 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_GPU_OBJECT_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_GPU_OBJECT_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
namespace gpu {
namespace cl {
struct GPUImage2DDescriptor {
DataType data_type;
cl_mem memory;
};
struct GPUBufferDescriptor {
DataType data_type;
int element_size;
cl_mem memory;
};
struct GPUResources {
std::vector<std::string> ints;
std::vector<std::string> floats;
std::vector<std::pair<std::string, GPUBufferDescriptor>> buffers;
std::vector<std::pair<std::string, GPUImage2DDescriptor>> images2d;
std::vector<std::string> GetNames() const {
std::vector<std::string> names = ints;
names.insert(names.end(), floats.begin(), floats.end());
for (const auto& obj : buffers) {
names.push_back(obj.first);
}
for (const auto& obj : images2d) {
names.push_back(obj.first);
}
return names;
}
};
struct GPUResourcesWithValue {
std::vector<std::pair<std::string, int>> ints;
std::vector<std::pair<std::string, float>> floats;
std::vector<std::pair<std::string, cl_mem>> buffers;
std::vector<std::pair<std::string, cl_mem>> images2d;
};
class GPUObjectDescriptor {
public:
GPUObjectDescriptor() = default;
GPUObjectDescriptor(const GPUObjectDescriptor& obj_desc)
: state_vars_(obj_desc.state_vars_) {}
GPUObjectDescriptor& operator=(const GPUObjectDescriptor& obj_desc) {
if (this != &obj_desc) {
state_vars_ = obj_desc.state_vars_;
}
return *this;
}
virtual ~GPUObjectDescriptor() = default;
void SetStateVar(const std::string& key, const std::string& value) const {
state_vars_[key] = value;
}
virtual std::string PerformConstExpr(const std::string& const_expr) const {
return "";
}
virtual absl::Status PerformSelector(const std::string& selector,
const std::vector<std::string>& args,
std::string* result) const {
*result = "";
return absl::OkStatus();
}
virtual GPUResources GetGPUResources() const { return GPUResources(); }
protected:
mutable std::map<std::string, std::string> state_vars_;
};
class GPUObject {
public:
GPUObject() = default;
// Move only
GPUObject(GPUObject&& obj_desc) = default;
GPUObject& operator=(GPUObject&& obj_desc) = default;
GPUObject(const GPUObject&) = delete;
GPUObject& operator=(const GPUObject&) = delete;
virtual ~GPUObject() = default;
virtual const GPUObjectDescriptor* GetGPUDescriptor() const = 0;
virtual GPUResourcesWithValue GetGPUResources() const = 0;
};
using GPUObjectPtr = std::unique_ptr<GPUObject>;
} // namespace cl
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_GPU_OBJECT_H_
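A minimal sketch of the two interfaces above working together; MyBufferDescriptor and MyBuffer are hypothetical and only illustrate the descriptor-versus-runtime split that Arguments::AddObjectArgs relies on (tflite::gpu::cl namespace assumed):

class MyBufferDescriptor : public GPUObjectDescriptor {
 public:
  GPUResources GetGPUResources() const override {
    GPUResources resources;
    // Declare one buffer named "data": float4 elements, memory bound later.
    resources.buffers.push_back(
        {"data", GPUBufferDescriptor{DataType::FLOAT32, 4, nullptr}});
    return resources;
  }
};

class MyBuffer : public GPUObject {
 public:
  explicit MyBuffer(cl_mem memory) : memory_(memory) {}
  const GPUObjectDescriptor* GetGPUDescriptor() const override { return &desc_; }
  GPUResourcesWithValue GetGPUResources() const override {
    GPUResourcesWithValue resources;
    resources.buffers.push_back({"data", memory_});  // the concrete cl_mem
    return resources;
  }

 private:
  MyBufferDescriptor desc_;
  cl_mem memory_ = nullptr;
};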

View File

@ -115,8 +115,7 @@ std::string GetTransposeCode(
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("result", "X", "Y", "Z", batch_id);
c += "}\n";
args->ResolveArgsPass(&c);
return absl::Substitute(c, args->GetListOfArgs());
return c;
}
} // namespace
@ -139,8 +138,10 @@ Transpose& Transpose::operator=(Transpose&& operation) {
}
absl::Status Transpose::Compile(const CreationContext& creation_context) {
const auto code =
std::string code =
GetTransposeCode(definition_, attr_, linked_operations_, &args_);
RETURN_IF_ERROR(args_.TransformToCLCode(&code));
code = absl::Substitute(code, args_.GetListOfArgs());
return creation_context.cache->GetOrCreateCLKernel(
code, "main_function", *creation_context.context,
*creation_context.device, &kernel_);

View File

@ -402,6 +402,13 @@ class AddOperationParser : public TFLiteOperationParser {
return absl::UnimplementedError("ADD requires two input tensors.");
}
// TODO(eignasheva): Add shapes check.
for (int i = 0; i < 2; i++) {
auto input = tflite::GetInput(context, tflite_node, i);
if (IsConstantTensor(input) && input->dims->size > 0) {
RETURN_IF_ERROR(CheckIfLinearConvertible(input->dims));
}
}
TfLiteAddParams* tf_options = nullptr;
return RetrieveBuiltinData(tflite_node, &tf_options);
}
@ -2453,15 +2460,15 @@ class TransformLandmarksV2OperationParser : public TFLiteOperationParser {
RETURN_IF_ERROR(reader->AddOutputs(node));
std::string op_name = "transform_landmarks_v2";
node->operation.type = op_name;
BHWC output_shape;
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
BHWC output_shape = output_value->tensor.shape;
RETURN_IF_ERROR(
ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size,
&(node->operation.attributes), &output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
return absl::OkStatus();
}

Some files were not shown because too many files have changed in this diff.