Merge remote-tracking branch 'upstream/master' into offline_memory_planner
This commit is contained in:
commit
f409152691
@ -91,7 +91,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
}
|
||||
string message = absl::StrCat(
|
||||
"Function invoked by the following node is not compilable: ",
|
||||
SummarizeNodeDef(node_def), ".\n");
|
||||
SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
|
||||
absl::StrAppend(&message, "Uncompilable nodes:");
|
||||
for (const auto& node_info : uncompilable_node_info) {
|
||||
string node_message =
|
||||
|
||||
@ -201,9 +201,7 @@ void XlaComputationLaunchContext::PopulateInputs(
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
// Build ShapedBuffers that point directly to the Tensor buffers.
|
||||
arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
|
||||
arg_buffers_.resize(kernel->xla_input_shapes.size());
|
||||
arg_ptrs_ = std::vector<ShapedBuffer*>(arg_buffers_.size());
|
||||
arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
|
||||
|
||||
// Pass remaining parameters.
|
||||
const Tensor* t;
|
||||
@ -239,11 +237,11 @@ void XlaComputationLaunchContext::PopulateInputs(
|
||||
<< " not the same as on-host shape "
|
||||
<< xla::ShapeUtil::HumanStringWithLayout(shape);
|
||||
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
|
||||
arg_buffers_[i] = absl::make_unique<ShapedBuffer>(
|
||||
arg_buffers_.emplace_back(
|
||||
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
|
||||
client_->platform(), client_->default_device_ordinal());
|
||||
arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
|
||||
arg_ptrs_[i] = arg_buffers_[i].get();
|
||||
arg_buffers_.back().set_buffer(dmem, /*index=*/{});
|
||||
arg_ptrs_[i] = &arg_buffers_.back();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -165,7 +165,7 @@ class XlaComputationLaunchContext {
|
||||
se::DeviceMemoryAllocator* xla_allocator_;
|
||||
bool allocate_xla_tensors_;
|
||||
bool use_multiple_streams_;
|
||||
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
|
||||
std::deque<xla::ShapedBuffer> arg_buffers_;
|
||||
std::vector<xla::ShapedBuffer*> arg_ptrs_;
|
||||
};
|
||||
|
||||
|
||||
@ -279,22 +279,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tftext_utils_test",
|
||||
size = "small",
|
||||
srcs = ["utils/lstm_utils_test.cc"],
|
||||
deps = [
|
||||
":lstm_utils",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stateful_ops_utils",
|
||||
srcs = [
|
||||
|
||||
@ -2297,26 +2297,17 @@ def TFL_PReluOp : TFL_Op<"prelu", [
|
||||
NoSideEffect,
|
||||
ResultsBroadcastableShape,
|
||||
TFL_GpuTargetOp,
|
||||
TFL_OperandHasRankAtMost<0, 4>,
|
||||
TFL_OperandHasRankAtMost<1, 4>,
|
||||
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
|
||||
BinaryOpSameElementTypeConstraint,
|
||||
PredOpTrait<"input and output must have the same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
PredOpTrait<"'alpha' should have one less rank than 'input'.",
|
||||
Or<[TFL_OperandIsUnrankedPred<0>,
|
||||
TFL_OperandIsUnrankedPred<1>,
|
||||
CPred<"$_op.getOperand(0).getType().cast<ShapedType>().getRank() == "
|
||||
"$_op.getOperand(1).getType().cast<ShapedType>().getRank() "
|
||||
"+ 1">]>>]> {
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
|
||||
let summary = "Parameterized Relu operator";
|
||||
|
||||
let description = [{
|
||||
Parameterized Relu operator
|
||||
x -> x >= 0 ? x : (alpha * x)
|
||||
where alpha is a trainable tensor.
|
||||
alpha should have one less rank than the input as it doesn't have the batch
|
||||
dimension, and the other dimensions either should be the same size as input
|
||||
or size 1, where it is broadcasted in the second case.
|
||||
input and alpha should be the same size as input or be broadcastable.
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
|
||||
@ -55,8 +55,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
std::vector<string> node_names;
|
||||
std::vector<string> node_dtypes;
|
||||
std::vector<std::vector<int>> node_shapes;
|
||||
std::vector<double> node_mins;
|
||||
std::vector<double> node_maxs;
|
||||
std::vector<llvm::Optional<double>> node_mins;
|
||||
std::vector<llvm::Optional<double>> node_maxs;
|
||||
|
||||
// Populate quantization specs.
|
||||
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
|
||||
|
||||
@ -125,8 +125,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
std::vector<string> node_names;
|
||||
std::vector<string> node_dtypes;
|
||||
std::vector<std::vector<int>> node_shapes;
|
||||
std::vector<double> node_mins;
|
||||
std::vector<double> node_maxs;
|
||||
std::vector<llvm::Optional<double>> node_mins;
|
||||
std::vector<llvm::Optional<double>> node_maxs;
|
||||
|
||||
// Populate quantization specs.
|
||||
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
|
||||
|
||||
@ -177,14 +177,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
|
||||
return RegisterCustomBuiltinOps(extra_tf_opdefs);
|
||||
}
|
||||
|
||||
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs,
|
||||
std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<double>* node_mins,
|
||||
std::vector<double>* node_maxs) {
|
||||
Status PopulateQuantizationSpecs(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<llvm::Optional<double>>* node_mins,
|
||||
std::vector<llvm::Optional<double>>* node_maxs) {
|
||||
quant_specs->inference_input_type =
|
||||
ConvertIODataTypeToDataType(toco_flags.inference_input_type());
|
||||
tensorflow::DataType inference_type =
|
||||
@ -211,11 +210,16 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
||||
flag.shape().dims().end()));
|
||||
// Currently, only UINT8 and INT8 require inputs stats
|
||||
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
|
||||
inference_type));
|
||||
node_mins->push_back(min_max.first);
|
||||
node_maxs->push_back(min_max.second);
|
||||
if (flag.has_mean_value() && flag.has_std_value()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto min_max, InputStatsToMinMax(flag.mean_value(),
|
||||
flag.std_value(), inference_type));
|
||||
node_mins->push_back(min_max.first);
|
||||
node_maxs->push_back(min_max.second);
|
||||
} else {
|
||||
node_mins->push_back(llvm::None);
|
||||
node_maxs->push_back(llvm::None);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -34,14 +34,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags);
|
||||
|
||||
// Populate quantization specs (or not) given user specified ranges for each
|
||||
// input arrays.
|
||||
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs,
|
||||
std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<double>* node_mins,
|
||||
std::vector<double>* node_maxs);
|
||||
Status PopulateQuantizationSpecs(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<llvm::Optional<double>>* node_mins,
|
||||
std::vector<llvm::Optional<double>>* node_maxs);
|
||||
|
||||
// Convert imported MLIR file to TfLite flatbuffer.
|
||||
// This will also run relevant passes as well.
|
||||
|
||||
@ -45,7 +45,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
|
||||
absl::string_view inference_type,
|
||||
QuantizationSpecs* quant_specs) {
|
||||
std::vector<std::string> input_nodes = absl::StrSplit(node_names, ',');
|
||||
std::vector<double> node_mins;
|
||||
std::vector<llvm::Optional<double>> node_mins;
|
||||
if (!min_values.empty()) {
|
||||
std::vector<std::string> node_mins_str = absl::StrSplit(min_values, ',');
|
||||
for (int i = 0; i < node_mins_str.size(); i++) {
|
||||
@ -57,7 +57,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<double> node_maxs;
|
||||
std::vector<llvm::Optional<double>> node_maxs;
|
||||
if (!max_values.empty()) {
|
||||
std::vector<std::string> node_maxs_str = absl::StrSplit(max_values, ',');
|
||||
for (int i = 0; i < node_maxs_str.size(); i++) {
|
||||
@ -79,11 +79,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
|
||||
quant_specs);
|
||||
}
|
||||
|
||||
bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
|
||||
const std::vector<double>& node_mins,
|
||||
const std::vector<double>& node_maxs,
|
||||
tensorflow::DataType inference_type,
|
||||
QuantizationSpecs* quant_specs) {
|
||||
bool GetInputNodeQuantSpecs(
|
||||
const std::vector<std::string>& node_names,
|
||||
const std::vector<llvm::Optional<double>>& node_mins,
|
||||
const std::vector<llvm::Optional<double>>& node_maxs,
|
||||
tensorflow::DataType inference_type, QuantizationSpecs* quant_specs) {
|
||||
quant_specs->inference_type = inference_type;
|
||||
|
||||
// If min/max are not specified, just return;
|
||||
|
||||
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONFIG_H_
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -69,7 +70,8 @@ struct QuantizationSpecs {
|
||||
// arguments. They are only used when `weight_quantization` is set to false,
|
||||
// and the model is required to have quantization parameters, either from
|
||||
// quantization aware training or calibration, for the remaining tensors.
|
||||
std::vector<std::pair<double, double>> input_ranges;
|
||||
std::vector<std::pair<llvm::Optional<double>, llvm::Optional<double>>>
|
||||
input_ranges;
|
||||
|
||||
// The default ranges can be used when a tensor doesn't have quantization
|
||||
// parameters and couldn't be quantized. Used only for latency tests.
|
||||
@ -130,11 +132,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
|
||||
// Gets the quantization specification for input arrays. The array names are not
|
||||
// stored in the spec, and will be matched by position. The min/max will be
|
||||
// ignored if the inference_type isn't a quantized type. Returns true if failed.
|
||||
bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
|
||||
const std::vector<double>& node_mins,
|
||||
const std::vector<double>& node_maxs,
|
||||
tensorflow::DataType inference_type,
|
||||
QuantizationSpecs* quant_specs);
|
||||
bool GetInputNodeQuantSpecs(
|
||||
const std::vector<std::string>& node_names,
|
||||
const std::vector<llvm::Optional<double>>& node_mins,
|
||||
const std::vector<llvm::Optional<double>>& node_maxs,
|
||||
tensorflow::DataType inference_type, QuantizationSpecs* quant_specs);
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
||||
|
||||
@ -109,8 +109,8 @@ class PrepareQuantizePass
|
||||
// Get the min and max values from the quantization specification for the
|
||||
// current function function and argument index. Uses default values if
|
||||
// the function is specified in the `quantize_whitelist`.
|
||||
std::pair<double, double> GetMinMaxValuesForArgument(
|
||||
llvm::StringRef func_name, int index) {
|
||||
std::pair<llvm::Optional<double>, llvm::Optional<double>>
|
||||
GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
|
||||
if (func_name == quant_specs_.target_func) {
|
||||
return quant_specs_.input_ranges[index];
|
||||
} else {
|
||||
@ -160,10 +160,14 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
|
||||
}
|
||||
|
||||
auto min_max = GetMinMaxValuesForArgument(func_name, i);
|
||||
// The input min/max or mean/std are not specified, then skip.
|
||||
if (!min_max.first.hasValue() || !min_max.second.hasValue()) return;
|
||||
|
||||
TypeAttr params = quant::GetQuantizedTypeAttr(
|
||||
builder, input_type, builder.getF64FloatAttr(min_max.first),
|
||||
builder.getF64FloatAttr(min_max.second), /*quant_dim=*/-1, num_bits,
|
||||
narrow_range, is_signed);
|
||||
builder, input_type,
|
||||
builder.getF64FloatAttr(min_max.first.getValue()),
|
||||
builder.getF64FloatAttr(min_max.second.getValue()),
|
||||
/*quant_dim=*/-1, num_bits, narrow_range, is_signed);
|
||||
builder.setInsertionPoint(block, insertion_point);
|
||||
auto q_op =
|
||||
builder.create<quant::QuantizeCastOp>(loc, params.getValue(), arg);
|
||||
|
||||
@ -35,6 +35,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -92,7 +93,9 @@ class LstmUtilsTest : public ::testing::Test {
|
||||
LstmUtilsTest() {}
|
||||
|
||||
void SetUp() override {
|
||||
builder_ = std::unique_ptr<mlir::Builder>(new Builder(&context_));
|
||||
RegisterDialects();
|
||||
context_ = std::make_unique<mlir::MLIRContext>();
|
||||
builder_ = std::unique_ptr<mlir::Builder>(new Builder(context_.get()));
|
||||
fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false);
|
||||
fused_lstm_func_cifg_ =
|
||||
createLstmCompositeFunc(builder_.get(), false, true);
|
||||
@ -105,10 +108,17 @@ class LstmUtilsTest : public ::testing::Test {
|
||||
fused_ln_lstm_func_.erase();
|
||||
builder_.reset();
|
||||
}
|
||||
|
||||
void RegisterDialects() {
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
mlir::registerDialect<TensorFlowLiteDialect>();
|
||||
}
|
||||
|
||||
FuncOp fused_lstm_func_;
|
||||
FuncOp fused_lstm_func_cifg_;
|
||||
FuncOp fused_ln_lstm_func_;
|
||||
mlir::MLIRContext context_;
|
||||
std::unique_ptr<mlir::MLIRContext> context_;
|
||||
std::unique_ptr<mlir::Builder> builder_;
|
||||
};
|
||||
|
||||
|
||||
@ -1318,6 +1318,126 @@ greater than `clip_value_max` are set to `clip_value_max`.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CollectiveBcastRecvOp : TF_Op<"CollectiveBcastRecv", []> {
|
||||
let summary = "Receives a tensor value broadcast from another device.";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I64Attr:$group_size,
|
||||
I64Attr:$group_key,
|
||||
I64Attr:$instance_key,
|
||||
TF_ShapeAttr:$shape,
|
||||
DefaultValuedAttr<StrAttr, "auto">:$communication_hint
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F16, F32, F64, I1, I32, I64]>:$data
|
||||
);
|
||||
|
||||
TF_DerivedResultTypeAttr T = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CollectiveBcastSendOp : TF_Op<"CollectiveBcastSend", []> {
|
||||
let summary = "Broadcasts a tensor value to one or more other devices.";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F16, F32, F64, I1, I32, I64]>:$input,
|
||||
|
||||
I64Attr:$group_size,
|
||||
I64Attr:$group_key,
|
||||
I64Attr:$instance_key,
|
||||
TF_ShapeAttr:$shape,
|
||||
DefaultValuedAttr<StrAttr, "auto">:$communication_hint
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F16, F32, F64, I1, I32, I64]>:$data
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CollectiveGatherOp : TF_Op<"CollectiveGather", []> {
|
||||
let summary = [{
|
||||
Mutually accumulates multiple tensors of identical type and shape.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F16, F32, F64, I32, I64]>:$input,
|
||||
|
||||
I64Attr:$group_size,
|
||||
I64Attr:$group_key,
|
||||
I64Attr:$instance_key,
|
||||
TF_ShapeAttr:$shape,
|
||||
DefaultValuedAttr<StrAttr, "auto">:$communication_hint
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F16, F32, F64, I32, I64]>:$data
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CollectivePermuteOp : TF_Op<"CollectivePermute", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
let summary = "An Op to permute tensors across replicated TPU instances.";
|
||||
|
||||
let description = [{
|
||||
Each instance supplies its own input.
|
||||
|
||||
For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
|
||||
source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs:
|
||||
`[D, A, B, C]`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
|
||||
I32Tensor:$source_target_pairs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CollectiveReduceOp : TF_Op<"CollectiveReduce", [SameOperandsAndResultType]> {
|
||||
let summary = [{
|
||||
Mutually reduces multiple tensors of identical type and shape.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F16, F32, F64, I32, I64]>:$input,
|
||||
|
||||
I64Attr:$group_size,
|
||||
I64Attr:$group_key,
|
||||
I64Attr:$instance_key,
|
||||
TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op,
|
||||
TF_AnyStrAttrOf<["Id", "Div"]>:$final_op,
|
||||
I64ArrayAttr:$subdiv_offsets,
|
||||
DefaultValuedAttr<I64ArrayAttr, "{}">:$wait_for,
|
||||
DefaultValuedAttr<StrAttr, "auto">:$communication_hint
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F16, F32, F64, I32, I64]>:$data
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> {
|
||||
let summary = "Converts two real numbers to a complex number.";
|
||||
|
||||
|
||||
@ -203,12 +203,12 @@ func @moving_alloc_and_inserting_missing_dealloc(%cond : i1, %arg0 : memref<2xf3
|
||||
"buffer_assignment_test.unary_lowered"(%arg0, %1) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
br ^exit(%1 : memref<2xf32>)
|
||||
^exit(%arg2: memref<2xf32>):
|
||||
"bufer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
"buffer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
|
||||
// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc()
|
||||
// CHECK: "bufer_assignment_test.copy"
|
||||
// CHECK: "buffer_assignment_test.copy"
|
||||
// CHECK-NEXT: dealloc
|
||||
// CHECK-NEXT: dealloc
|
||||
// CHECK-NEXT: return
|
||||
@ -226,11 +226,11 @@ func @moving_invalid_dealloc_op_complex(%cond : i1, %arg0 : memref<2xf32>, %arg1
|
||||
dealloc %1 : memref<2xf32>
|
||||
br ^exit(%1 : memref<2xf32>)
|
||||
^exit(%arg2: memref<2xf32>):
|
||||
"bufer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
"buffer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK-NEXT: %[[ALLOC:.*]] = alloc()
|
||||
// CHECK: bufer_assignment_test.copy
|
||||
// CHECK: buffer_assignment_test.copy
|
||||
// CHECK-NEXT: dealloc
|
||||
// CHECK-NEXT: return
|
||||
|
||||
@ -240,10 +240,10 @@ func @moving_invalid_dealloc_op_complex(%cond : i1, %arg0 : memref<2xf32>, %arg1
|
||||
func @inserting_missing_dealloc_simple(%arg0 : memref<2xf32>, %arg1: memref<2xf32>){
|
||||
%0 = alloc() : memref<2xf32>
|
||||
"buffer_assignment_test.unary_lowered"(%arg0, %0) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
"bufer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
"buffer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: bufer_assignment_test.copy
|
||||
// CHECK: buffer_assignment_test.copy
|
||||
// CHECK-NEXT: dealloc
|
||||
|
||||
// -----
|
||||
@ -253,8 +253,8 @@ func @moving_invalid_dealloc_op(%arg0 : memref<2xf32>, %arg1: memref<2xf32>){
|
||||
%0 = alloc() : memref<2xf32>
|
||||
"buffer_assignment_test.unary_lowered"(%arg0, %0) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
dealloc %0 : memref<2xf32>
|
||||
"bufer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
"buffer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: bufer_assignment_test.copy
|
||||
// CHECK-NEXT: dealloc
|
||||
// CHECK: buffer_assignment_test.copy
|
||||
// CHECK-NEXT: dealloc
|
||||
|
||||
@ -29,60 +29,66 @@ limitations under the License.
|
||||
namespace mlir {
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
/// This dialect independent unary operation has been defined only for testing
|
||||
/// buffer assignment.
|
||||
class BufferAssignmentTestUnaryOp
|
||||
: public Op<BufferAssignmentTestUnaryOp, OpTrait::OneResult,
|
||||
OpTrait::OneOperand> {
|
||||
public:
|
||||
using Op::Op;
|
||||
static StringRef getOperationName() { return "buffer_assignment_test.unary"; }
|
||||
static void build(OpBuilder& b, OperationState& state, Value source) {
|
||||
state.addOperands(source);
|
||||
}
|
||||
};
|
||||
|
||||
/// This dialect independent lowered unary operation has been defined only for
|
||||
/// testing buffer assignment.
|
||||
class BufferAssignmentTestUnaryLoweredOp
|
||||
: public Op<BufferAssignmentTestUnaryLoweredOp, OpTrait::ZeroResult,
|
||||
OpTrait::NOperands<2>::Impl> {
|
||||
public:
|
||||
using Op::Op;
|
||||
static StringRef getOperationName() {
|
||||
return "buffer_assignment_test.unary_lowered";
|
||||
}
|
||||
static void build(OpBuilder& b, OperationState& state, Value source,
|
||||
Value target) {
|
||||
state.addOperands(source);
|
||||
state.addOperands(target);
|
||||
}
|
||||
};
|
||||
|
||||
/// This dialect independent copy operation has been defined only for testing
|
||||
/// NonVoidToVoidReturnOpConverter
|
||||
class BufferAssignmentTestCopyOp
|
||||
: public Op<BufferAssignmentTestCopyOp, OpTrait::ZeroResult,
|
||||
OpTrait::NOperands<2>::Impl> {
|
||||
public:
|
||||
using Op::Op;
|
||||
static StringRef getOperationName() { return "buffer_assignment_test.copy"; }
|
||||
static void build(OpBuilder& b, OperationState& state, Value from, Value to) {
|
||||
state.addOperands(from);
|
||||
state.addOperands(to);
|
||||
}
|
||||
};
|
||||
|
||||
class BufferAssignmentTestDialect : public Dialect {
|
||||
public:
|
||||
explicit BufferAssignmentTestDialect(MLIRContext* context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addOperations<BufferAssignmentTestCopyOp, BufferAssignmentTestUnaryOp,
|
||||
BufferAssignmentTestUnaryLoweredOp>();
|
||||
}
|
||||
static StringRef getDialectNamespace() { return "buffer_assignment_test"; }
|
||||
};
|
||||
|
||||
/// This pass tests two provided operation converters,
|
||||
/// FunctionAndBlockSignatureConverter and NonVoidToVoidReturnOpConverter, for
|
||||
/// Buffer Assignment.
|
||||
struct BufferAssignmentPreparationTestPass
|
||||
: mlir::PassWrapper<BufferAssignmentPreparationTestPass, FunctionPass> {
|
||||
/// This dialect independent unary operation has been defined only for testing
|
||||
/// buffer assignment.
|
||||
class BufferAssignmentTestUnaryOp
|
||||
: public Op<BufferAssignmentTestUnaryOp, OpTrait::OneResult,
|
||||
OpTrait::OneOperand> {
|
||||
public:
|
||||
using Op::Op;
|
||||
static StringRef getOperationName() {
|
||||
return "buffer_assignment_test.unary";
|
||||
}
|
||||
static void build(OpBuilder& b, OperationState& state, Value source) {
|
||||
state.addOperands(source);
|
||||
}
|
||||
};
|
||||
|
||||
/// This dialect independent lowered unary operation has been defined only for
|
||||
/// testing buffer assignment.
|
||||
class BufferAssignmentTestUnaryLoweredOp
|
||||
: public Op<BufferAssignmentTestUnaryLoweredOp, OpTrait::ZeroResult,
|
||||
OpTrait::NOperands<2>::Impl> {
|
||||
public:
|
||||
using Op::Op;
|
||||
static StringRef getOperationName() {
|
||||
return "buffer_assignment_test.unary_lowered";
|
||||
}
|
||||
static void build(OpBuilder& b, OperationState& state, Value source,
|
||||
Value target) {
|
||||
state.addOperands(source);
|
||||
state.addOperands(target);
|
||||
}
|
||||
};
|
||||
|
||||
/// This dialect independent copy operation has been defined only for testing
|
||||
/// NonVoidToVoidReturnOpConverter
|
||||
class BufferAssignmentTestCopyOp
|
||||
: public Op<BufferAssignmentTestCopyOp, OpTrait::ZeroResult,
|
||||
OpTrait::NOperands<2>::Impl> {
|
||||
public:
|
||||
using Op::Op;
|
||||
static StringRef getOperationName() {
|
||||
return "buffer_assignment_test.copy";
|
||||
}
|
||||
static void build(OpBuilder& b, OperationState& state, Value from,
|
||||
Value to) {
|
||||
state.addOperands(from);
|
||||
state.addOperands(to);
|
||||
}
|
||||
};
|
||||
|
||||
/// A simple converter that legalizes a BufferAssignmentTestUnaryOp to a
|
||||
/// BufferAssignmentTestUnaryLoweredOp and creates buffer allocation for
|
||||
/// the result of the computation.
|
||||
@ -151,8 +157,12 @@ struct BufferAssignmentPreparationTestPass
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
static mlir::DialectRegistration<BufferAssignmentTestDialect>
|
||||
buffer_assignment_test_ops;
|
||||
|
||||
/// This pass tests helper methods such as computeAllocPosition,
|
||||
/// FunctionAndBlockSignatureConverter, NonVoidToVoidReturnOpConverter
|
||||
/// conversion patterns. Furthermore, it checks buffer-assignment pass that
|
||||
|
||||
@ -19,7 +19,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -1609,8 +1608,4 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TODO(b/130689556): XLA CPU does not honor inf/nan which causes problems
|
||||
os.environ[
|
||||
"XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false " + os.environ.get(
|
||||
"XLA_FLAGS", "")
|
||||
googletest.main()
|
||||
|
||||
@ -347,17 +347,15 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
expected=np.array(
|
||||
[1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype))
|
||||
|
||||
# TODO(b/130689556): Turn this on for CPU when we start honoring NaNs.
|
||||
if self.device != "XLA_CPU":
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.tanh,
|
||||
np.array([[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20],
|
||||
[19, -19, 22, -22]],
|
||||
dtype=dtype),
|
||||
expected=np.array(
|
||||
[[0.76159418, 0.96402758, 0.99505478, 0.99932933],
|
||||
[1.0, -1.0, np.nan, 1.0], [1.0, -1.0, 1.0, -1.0]],
|
||||
dtype=dtype))
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.tanh,
|
||||
np.array([[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20],
|
||||
[19, -19, 22, -22]],
|
||||
dtype=dtype),
|
||||
expected=np.array(
|
||||
[[0.76159418, 0.96402758, 0.99505478, 0.99932933],
|
||||
[1.0, -1.0, np.nan, 1.0], [1.0, -1.0, 1.0, -1.0]],
|
||||
dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
nn_ops.log_softmax,
|
||||
|
||||
@ -136,8 +136,11 @@ class TensorListReserveOp : public XlaOpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
|
||||
OP_REQUIRES(
|
||||
ctx, num_elements >= 0,
|
||||
errors::InvalidArgument("XLA compilation requires a fixed tensor list "
|
||||
"size. Set the number of elements."));
|
||||
errors::InvalidArgument(
|
||||
"XLA compilation requires a fixed tensor list size. Set the number "
|
||||
"of elements. This could also happen if you're using a TensorArray "
|
||||
"in a while loop that does not have its maximum_iteration set, you "
|
||||
"can fix this by setting maximum_iteration to a suitable value."));
|
||||
|
||||
// If element shape is compile time constant and it's not "unknown rank"
|
||||
// shape (-1), create an initialized TensorList. Otherwise create an
|
||||
@ -197,10 +200,13 @@ class EmptyTensorListOp : public XlaOpKernel {
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
int64 max_num_elements;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements));
|
||||
OP_REQUIRES(
|
||||
ctx, max_num_elements >= 0,
|
||||
errors::InvalidArgument("XLA compilation requires a fixed tensor list "
|
||||
"size. Set the max number of elements."));
|
||||
OP_REQUIRES(ctx, max_num_elements >= 0,
|
||||
errors::InvalidArgument(
|
||||
"XLA compilation requires a fixed tensor list size. Set "
|
||||
"the max number of elements. This could also happen if "
|
||||
"you're using a TensorArray in a while loop that does not "
|
||||
"have its maximum_iteration set, you can fix this by "
|
||||
"setting maximum_iteration to a suitable value."));
|
||||
|
||||
if (dtype_ != DT_VARIANT) {
|
||||
// We are creating a non-nested TensorList.
|
||||
|
||||
@ -63,10 +63,6 @@ class ExecutionInput {
|
||||
explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) {}
|
||||
explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers)
|
||||
: buffers_(std::move(buffers)) {}
|
||||
ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers,
|
||||
std::vector<ShapeIndex> owner_held_indices)
|
||||
: buffers_(std::move(buffers)),
|
||||
unowned_indices_(std::move(owner_held_indices)) {}
|
||||
ExecutionInput(ExecutionInput&&) = default;
|
||||
|
||||
~ExecutionInput() {
|
||||
|
||||
@ -40,7 +40,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
|
||||
public:
|
||||
// A NestedComputer computes an element of the output of the given computation
|
||||
// given a Span of its input elements.
|
||||
using NestedComputer = std::function<StatusOr<llvm::Value*>(
|
||||
using NestedComputer = std::function<StatusOr<std::vector<llvm::Value*>>(
|
||||
const HloComputation&, absl::Span<llvm::Value* const>)>;
|
||||
|
||||
GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config,
|
||||
@ -91,12 +91,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
|
||||
StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
|
||||
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
|
||||
absl::string_view) override {
|
||||
// TODO(b/118332391): Supported variadic return values.
|
||||
auto result = compute_nested_(callee, parameters);
|
||||
if (!result.ok()) {
|
||||
return result.status();
|
||||
}
|
||||
return std::vector<llvm::Value*>{result.ValueOrDie()};
|
||||
return compute_nested_(callee, parameters);
|
||||
}
|
||||
|
||||
llvm::Value* EmitThreadId() override;
|
||||
|
||||
@ -698,115 +698,6 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IrEmitter::HandleReduce(HloInstruction* instr) {
|
||||
const HloReduceInstruction* reduce = Cast<HloReduceInstruction>(instr);
|
||||
const Shape& out_shape = reduce->shape();
|
||||
bool returns_tuple = !out_shape.IsArray();
|
||||
int accumulators_count = 1;
|
||||
if (returns_tuple) {
|
||||
CHECK(out_shape.IsTuple());
|
||||
accumulators_count = out_shape.tuple_shapes_size();
|
||||
}
|
||||
|
||||
auto arg = reduce->operand(0);
|
||||
absl::Span<const int64> dimensions(reduce->dimensions());
|
||||
HloComputation* function = reduce->to_apply();
|
||||
return EmitTargetElementLoop(
|
||||
*reduce,
|
||||
[=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
||||
std::vector<llvm::Value*> accumulator_addrs;
|
||||
std::vector<llvm::Type*> accumulator_types;
|
||||
|
||||
// Initialize accumulators with initial values.
|
||||
for (int i = 0; i < accumulators_count; i++) {
|
||||
auto init_value = reduce->init_values()[i];
|
||||
const Shape& element_shape =
|
||||
returns_tuple ? out_shape.tuple_shapes(i) : out_shape;
|
||||
PrimitiveType accumulator_type = element_shape.element_type();
|
||||
llvm::Type* accumulator_llvm_type =
|
||||
llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
|
||||
llvm::AllocaInst* accumulator_addr = Alloca(accumulator_llvm_type);
|
||||
Store(Load(GetBasePointer(*init_value)), accumulator_addr);
|
||||
accumulator_addrs.push_back(accumulator_addr);
|
||||
accumulator_types.push_back(accumulator_llvm_type);
|
||||
}
|
||||
|
||||
// The enclosing loops go over all the target elements. Now we have to
|
||||
// compute the actual target element. For this, we build a new loop nest
|
||||
// to iterate over all the reduction dimensions in the argument.
|
||||
// AddLoopsForShapeOnDimensions will return an Index where induction
|
||||
// Value*s are placed for each dimension in dimensions, and all the rest
|
||||
// are nullptrs.
|
||||
llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
|
||||
std::vector<llvm::Value*> input_multi_index =
|
||||
loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
|
||||
"reduction_dim");
|
||||
|
||||
SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
|
||||
|
||||
// Build a full index for the input argument, using reduced_dims_index
|
||||
// as the base. In reduced_dims_index only the reduction dimensions are
|
||||
// filled in. We fill in the rest of the dimensions with induction
|
||||
// Value*s taken from 'index' which iterates over the target array.
|
||||
// See the high-level description in the XLA documentation for details.
|
||||
llvm_ir::IrArray::Index::const_iterator it = index.begin();
|
||||
|
||||
for (auto& i : input_multi_index) {
|
||||
if (i == nullptr) {
|
||||
i = *it++;
|
||||
}
|
||||
}
|
||||
CHECK(index.end() == it);
|
||||
|
||||
// Apply the reduction function to the loaded value.
|
||||
llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
|
||||
b_.getInt64Ty());
|
||||
std::vector<llvm::Value*> reduction_operands(accumulator_addrs.begin(),
|
||||
accumulator_addrs.end());
|
||||
for (int i = 0; i < accumulators_count; i++) {
|
||||
llvm::Value* input_address =
|
||||
GetIrArray(*reduce->operand(i), *reduce)
|
||||
.EmitArrayElementAddress(input_index, &b_);
|
||||
reduction_operands.push_back(input_address);
|
||||
}
|
||||
|
||||
llvm::Value* ret_argument;
|
||||
if (!returns_tuple) {
|
||||
CHECK_EQ(accumulator_addrs.size(), 1);
|
||||
ret_argument = accumulator_addrs[0];
|
||||
} else {
|
||||
const Shape& return_shape = function->root_instruction()->shape();
|
||||
|
||||
llvm::Type* return_value_buffer_type =
|
||||
llvm_ir::ShapeToIrType(return_shape, module_);
|
||||
ret_argument = Alloca(return_value_buffer_type);
|
||||
llvm_ir::IrArray tuple_array(ret_argument, return_shape);
|
||||
EmitTuple(tuple_array, accumulator_addrs, &b_);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
|
||||
*function, reduction_operands, ret_argument));
|
||||
|
||||
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
|
||||
|
||||
if (!returns_tuple) {
|
||||
CHECK_EQ(accumulator_addrs.size(), 1);
|
||||
return Load(accumulator_addrs[0]);
|
||||
} else {
|
||||
// Emit a struct for the LoopEmitter dealing with multi-output
|
||||
// fusion.
|
||||
llvm::Value* returned_structure = llvm::UndefValue::get(
|
||||
llvm::StructType::get(b_.getContext(), accumulator_types));
|
||||
for (int i = 0; i < accumulators_count; i++) {
|
||||
llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
|
||||
returned_structure =
|
||||
b_.CreateInsertValue(returned_structure, accumulator_value, i);
|
||||
}
|
||||
return returned_structure;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Status IrEmitter::HandleFusion(HloInstruction* fusion) {
|
||||
// kFusion for library calls should be handled by
|
||||
// IrEmitterUnnested::HandleFusion.
|
||||
@ -866,22 +757,39 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
|
||||
"to a cudnn CustomCall using CudnnBatchNormRewriter.");
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
|
||||
StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElement(
|
||||
const HloComputation& computation,
|
||||
absl::Span<llvm::Value* const> parameter_elements) {
|
||||
const Shape& return_shape = computation.root_instruction()->shape();
|
||||
llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(
|
||||
computation.root_instruction()->shape().element_type(), module_),
|
||||
"return_buffer", &b_);
|
||||
llvm_ir::ShapeToIrType(return_shape, module_), "return_buffer", &b_);
|
||||
std::vector<llvm::Value*> parameter_buffers;
|
||||
for (llvm::Value* parameter_element : parameter_elements) {
|
||||
parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
parameter_element->getType(), "parameter_buffer", &b_));
|
||||
Store(parameter_element, parameter_buffers.back());
|
||||
}
|
||||
|
||||
std::vector<llvm::Value*> allocas_for_returned_scalars;
|
||||
if (!return_shape.IsTuple()) {
|
||||
allocas_for_returned_scalars.push_back(return_buffer);
|
||||
} else {
|
||||
allocas_for_returned_scalars =
|
||||
llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
|
||||
llvm_ir::IrArray tuple_array(return_buffer, return_shape);
|
||||
|
||||
EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
|
||||
return_buffer));
|
||||
return Load(return_buffer);
|
||||
|
||||
std::vector<llvm::Value*> returned_scalars;
|
||||
returned_scalars.reserve(allocas_for_returned_scalars.size());
|
||||
for (llvm::Value* addr : allocas_for_returned_scalars) {
|
||||
returned_scalars.push_back(Load(addr));
|
||||
}
|
||||
return returned_scalars;
|
||||
}
|
||||
|
||||
std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
|
||||
|
||||
@ -89,7 +89,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
Status HandleRecv(HloInstruction* recv) override;
|
||||
Status HandleRecvDone(HloInstruction* recv_done) override;
|
||||
Status HandleParameter(HloInstruction* parameter) override;
|
||||
Status HandleReduce(HloInstruction* reduce) override;
|
||||
Status HandleTuple(HloInstruction* tuple) override;
|
||||
Status HandleScatter(HloInstruction* scatter) override;
|
||||
Status HandleSelect(HloInstruction* select) override;
|
||||
@ -213,7 +212,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
const llvm_ir::IrArray::Index& compare_keys_index,
|
||||
const llvm_ir::IrArray& keys_array);
|
||||
|
||||
StatusOr<llvm::Value*> ComputeNestedElement(
|
||||
StatusOr<std::vector<llvm::Value*>> ComputeNestedElement(
|
||||
const HloComputation& computation,
|
||||
absl::Span<llvm::Value* const> parameter_elements);
|
||||
|
||||
|
||||
@ -312,12 +312,13 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
|
||||
class HloDotDumper {
|
||||
public:
|
||||
HloDotDumper(const HloComputation* computation, absl::string_view label,
|
||||
const DebugOptions& debug_options, bool show_backend_config,
|
||||
const DebugOptions& debug_options,
|
||||
HloRenderOptions hlo_render_options,
|
||||
const HloExecutionProfile* profile, NodeFilter filter)
|
||||
: computation_(computation),
|
||||
label_(label),
|
||||
debug_options_(debug_options),
|
||||
show_backend_config_(show_backend_config),
|
||||
hlo_render_options_(hlo_render_options),
|
||||
profile_(profile),
|
||||
filter_(std::move(filter)) {}
|
||||
|
||||
@ -384,7 +385,7 @@ class HloDotDumper {
|
||||
const HloComputation* computation_; // never null
|
||||
const string label_; // overall name for the graph
|
||||
const DebugOptions& debug_options_;
|
||||
const bool show_backend_config_;
|
||||
const HloRenderOptions hlo_render_options_;
|
||||
const HloExecutionProfile* profile_; // may be null
|
||||
const NodeFilter filter_;
|
||||
|
||||
@ -565,7 +566,8 @@ bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
|
||||
bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
|
||||
if (subcomp->IsFusionComputation()) {
|
||||
const HloInstruction* fusion = subcomp->FusionInstruction();
|
||||
if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
|
||||
if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion) ||
|
||||
!hlo_render_options_.show_fusion_subcomputations) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -1133,7 +1135,8 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
|
||||
|
||||
string HloDotDumper::GetInstructionNodeBackendConfig(
|
||||
const HloInstruction* instr) {
|
||||
if (!show_backend_config_ || instr->raw_backend_config_string().empty()) {
|
||||
if (!hlo_render_options_.show_backend_config ||
|
||||
instr->raw_backend_config_string().empty()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
@ -1604,14 +1607,14 @@ StatusOr<string> RenderGraph(const HloComputation& computation,
|
||||
const DebugOptions& debug_options,
|
||||
RenderedGraphFormat format,
|
||||
const HloExecutionProfile* hlo_execution_profile,
|
||||
bool show_backend_config) {
|
||||
HloRenderOptions hlo_render_options) {
|
||||
tensorflow::mutex_lock lock(url_renderer_mu);
|
||||
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
|
||||
return Unavailable("Can't render as URL; no URL renderer was registered.");
|
||||
}
|
||||
|
||||
string rendered_dot =
|
||||
HloDotDumper(&computation, label, debug_options, show_backend_config,
|
||||
HloDotDumper(&computation, label, debug_options, hlo_render_options,
|
||||
hlo_execution_profile, NodeFilter())
|
||||
.Dump();
|
||||
return WrapDotInFormat(rendered_dot, format);
|
||||
@ -1619,7 +1622,7 @@ StatusOr<string> RenderGraph(const HloComputation& computation,
|
||||
|
||||
StatusOr<string> RenderNeighborhoodAround(
|
||||
const HloInstruction& node, int radius, RenderedGraphFormat format,
|
||||
bool show_backend_config,
|
||||
HloRenderOptions hlo_render_options,
|
||||
const absl::flat_hash_set<const HloInstruction*>& boundary) {
|
||||
tensorflow::mutex_lock lock(url_renderer_mu);
|
||||
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
|
||||
@ -1632,7 +1635,7 @@ StatusOr<string> RenderNeighborhoodAround(
|
||||
string rendered_dot =
|
||||
HloDotDumper(node.parent(), label,
|
||||
node.GetModule()->config().debug_options(),
|
||||
show_backend_config, /*profile=*/nullptr,
|
||||
hlo_render_options, /*profile=*/nullptr,
|
||||
MakeNodeRadiusAroundFilter(&node, radius, boundary))
|
||||
.Dump();
|
||||
return WrapDotInFormat(rendered_dot, format);
|
||||
@ -1641,7 +1644,7 @@ StatusOr<string> RenderNeighborhoodAround(
|
||||
StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
|
||||
const HloInstruction& to, int64 max_nodes,
|
||||
RenderedGraphFormat format,
|
||||
bool show_backend_config) {
|
||||
HloRenderOptions hlo_render_options) {
|
||||
tensorflow::mutex_lock lock(url_renderer_mu);
|
||||
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
|
||||
return FailedPrecondition(
|
||||
@ -1663,7 +1666,7 @@ StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
|
||||
"NODES***<br/><br/>");
|
||||
}
|
||||
string rendered_dot =
|
||||
HloDotDumper(from.parent(), label, debug_options, show_backend_config,
|
||||
HloDotDumper(from.parent(), label, debug_options, hlo_render_options,
|
||||
/*profile=*/nullptr, filter)
|
||||
.Dump();
|
||||
return WrapDotInFormat(rendered_dot, format);
|
||||
|
||||
@ -50,6 +50,14 @@ enum class RenderedGraphFormat {
|
||||
kUrl,
|
||||
};
|
||||
|
||||
struct HloRenderOptions {
|
||||
// Include the backend config string in the rendered graph.
|
||||
bool show_backend_config = false;
|
||||
|
||||
// Include the fusion subcomputations in the rendered graph.
|
||||
bool show_fusion_subcomputations = true;
|
||||
};
|
||||
|
||||
// Renders an HLO module as a human-readable visual graph.
|
||||
//
|
||||
// Note that this only works well for relatively small graphs (no more than a
|
||||
@ -61,7 +69,7 @@ StatusOr<string> RenderGraph(
|
||||
const HloComputation& computation, absl::string_view label,
|
||||
const DebugOptions& debug_options, RenderedGraphFormat format,
|
||||
const HloExecutionProfile* hlo_execution_profile = nullptr,
|
||||
bool show_backend_config = false);
|
||||
HloRenderOptions hlo_render_options = {});
|
||||
|
||||
// Like RenderGraph, but renders only nodes "near" the given node in the graph.
|
||||
//
|
||||
@ -73,7 +81,7 @@ StatusOr<string> RenderGraph(
|
||||
// will be omitted even if they are within the radius.
|
||||
StatusOr<string> RenderNeighborhoodAround(
|
||||
const HloInstruction& node, int radius, RenderedGraphFormat format,
|
||||
bool show_backend_config = false,
|
||||
HloRenderOptions hlo_render_options = {},
|
||||
const absl::flat_hash_set<const HloInstruction*>& boundary = {});
|
||||
|
||||
// Renders nodes on any of the paths from `from` to `to`. If there are more
|
||||
@ -82,7 +90,7 @@ StatusOr<string> RenderNeighborhoodAround(
|
||||
StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
|
||||
const HloInstruction& to, int64 max_nodes,
|
||||
RenderedGraphFormat format,
|
||||
bool show_backend_config = false);
|
||||
HloRenderOptions hlo_render_options = {});
|
||||
|
||||
// Registers a function which implements RenderedGraphFormat::kUrl.
|
||||
//
|
||||
|
||||
@ -137,9 +137,8 @@ class ShapedBuffer {
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer);
|
||||
|
||||
// ShapedBuffer derived class which allocates all internal buffers on
|
||||
// construction and deallocates the memory when the object is
|
||||
// destructed.
|
||||
// ScopedShapedBuffer takes allocated buffers as inputs, and deallocates on
|
||||
// destruction. This class represents an owning wrapper around `ShapedBuffer`.
|
||||
//
|
||||
// TODO(timshen): Remove inheritance between ScopedShapedBuffer and
|
||||
// ShapedBuffer. There should never be a need to consider a ScopedShapedBuffer
|
||||
|
||||
@ -81,11 +81,10 @@ void SpmdLogger::RegisterLogEntry(HloInstruction* hlo,
|
||||
string report = hlo->ToString();
|
||||
int64 max_value = -1;
|
||||
for (HloInstruction* inst : group) {
|
||||
if (inst->shape().IsTuple()) {
|
||||
if (!inst->shape().IsArray()) {
|
||||
continue;
|
||||
}
|
||||
max_value =
|
||||
std::max<int64>(max_value, ShapeUtil::ByteSizeOf(inst->shape(), 4));
|
||||
max_value = std::max<int64>(max_value, ShapeSizeInBytes(inst->shape()));
|
||||
absl::StrAppend(&report, " * ", inst->ToString(), "\n");
|
||||
}
|
||||
entries_.push_back(std::make_pair(max_value, report));
|
||||
@ -149,14 +148,14 @@ template <typename F>
|
||||
const auto add_report = [&](std::vector<HloInstruction*>* insts) {
|
||||
std::sort(insts->begin(), insts->end(),
|
||||
[](const HloInstruction* inst0, const HloInstruction* inst1) {
|
||||
return ShapeUtil::ByteSizeOf(inst0->shape()) >
|
||||
ShapeUtil::ByteSizeOf(inst1->shape());
|
||||
return ShapeSizeInBytes(inst0->shape()) >
|
||||
ShapeSizeInBytes(inst1->shape());
|
||||
});
|
||||
for (int64 i = 0;
|
||||
i < std::min<int64>(report_instruction_count, insts->size()); ++i) {
|
||||
absl::StrAppend(&report, " ",
|
||||
tensorflow::strings::HumanReadableNumBytes(
|
||||
ShapeUtil::ByteSizeOf((*insts)[i]->shape())),
|
||||
ShapeSizeInBytes((*insts)[i]->shape())),
|
||||
" : ", (*insts)[i]->ToString(), "\n");
|
||||
}
|
||||
};
|
||||
@ -1180,8 +1179,8 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
|
||||
if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
|
||||
operand, scatter_dims_to_operand_dims, slice_size,
|
||||
num_partitions_) &&
|
||||
ShapeUtil::ByteSizeOf(updates.base_shape()) <
|
||||
ShapeUtil::ByteSizeOf(scatter->shape())) {
|
||||
ShapeSizeInBytes(updates.base_shape()) <
|
||||
ShapeSizeInBytes(scatter->shape())) {
|
||||
// Operand is sharded on trivial slice dims (update slice size 1). We can
|
||||
// adjust the indices on each partition by subtracting the offsets. Then
|
||||
// we execute a scatter on full updated indices, and out-of-bound accesses
|
||||
@ -1968,8 +1967,8 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
|
||||
if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
|
||||
operand, start_index_map, gather->gather_slice_sizes(),
|
||||
num_partitions_) &&
|
||||
ShapeUtil::ByteSizeOf(gather->shape()) <
|
||||
ShapeUtil::ByteSizeOf(gather->operand(0)->shape())) {
|
||||
ShapeSizeInBytes(gather->shape()) <
|
||||
ShapeSizeInBytes(gather->operand(0)->shape())) {
|
||||
indices = indices.Reshard(HloSharding::Replicate());
|
||||
// Now the operand is partitioned in trivial slice dimensions, and the
|
||||
// indices are replicated. We execute a gather on partitioned operand,
|
||||
@ -2762,8 +2761,7 @@ Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs(
|
||||
|
||||
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::Zero(hlo->shape().element_type())));
|
||||
if (ShapeUtil::ByteSizeOf(lhs.base_shape()) <
|
||||
ShapeUtil::ByteSizeOf(rhs.base_shape())) {
|
||||
if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) {
|
||||
if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
@ -3005,8 +3003,8 @@ Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
|
||||
};
|
||||
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::Zero(hlo->shape().element_type())));
|
||||
if (ShapeUtil::ByteSizeOf(lhs.base_shape()) <
|
||||
ShapeUtil::ByteSizeOf(rhs.base_shape())) {
|
||||
if (ShapeSizeInBytes(lhs.base_shape()) <
|
||||
ShapeSizeInBytes(rhs.base_shape())) {
|
||||
if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
@ -3731,7 +3729,7 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
|
||||
};
|
||||
if (output_lhs_non_contracting_partitions == num_partitions_ &&
|
||||
output_sharding_transposed_to_match_lhs == lhs_sharding &&
|
||||
ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()) >=
|
||||
ShapeSizeInBytes(hlo->operand(1)->shape()) >=
|
||||
options_.threshold_for_windowed_einsum_mib * 1024 * 1024) {
|
||||
if (rhs_contracting_partitions == num_partitions_) {
|
||||
return emit_windowed_dot_general(0, 1, true, false);
|
||||
@ -3745,7 +3743,7 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
|
||||
}
|
||||
if (output_rhs_non_contracting_partitions == num_partitions_ &&
|
||||
output_sharding_transposed_to_match_rhs == rhs_sharding &&
|
||||
ShapeUtil::ByteSizeOf(hlo->operand(0)->shape()) >=
|
||||
ShapeSizeInBytes(hlo->operand(0)->shape()) >=
|
||||
options_.threshold_for_windowed_einsum_mib * 1024 * 1024) {
|
||||
if (lhs_contracting_partitions == num_partitions_) {
|
||||
return emit_windowed_dot_general(1, 0, true, false);
|
||||
@ -3775,8 +3773,8 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
|
||||
LiteralUtil::Zero(hlo->shape().element_type())));
|
||||
// Pad both sides with zero, since NaN at one side cannot be masked by zero
|
||||
// on the other side.
|
||||
if (ShapeUtil::ByteSizeOf(lhs.base_shape()) <
|
||||
ShapeUtil::ByteSizeOf(rhs.base_shape())) {
|
||||
if (ShapeSizeInBytes(lhs.base_shape()) <
|
||||
ShapeSizeInBytes(rhs.base_shape())) {
|
||||
lhs =
|
||||
lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero);
|
||||
rhs = rhs.PadWithValue(zero);
|
||||
@ -4607,8 +4605,8 @@ HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b,
|
||||
xpose_permutation[i] = i + tiled_dims.size() - split_dims_added;
|
||||
} else {
|
||||
xpose_permutation[i] = split_dims_added;
|
||||
xpose_permutation[i + 1] = i + tiled_dims.size() - split_dims_added;
|
||||
split_dims_added++;
|
||||
xpose_permutation[i + 1] = i + tiled_dims.size();
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
@ -649,6 +649,43 @@ ENTRY entry {
|
||||
op::ReduceWindow(masked, op::Constant())));
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideHaloBeyondNeighbor) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
sum {
|
||||
a = f32[] parameter(0)
|
||||
b = f32[] parameter(1)
|
||||
ROOT add = f32[] add(a, b)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
param = f32[9,2] parameter(0), sharding={devices=[5,1]0,1,2,3,4}
|
||||
constant.1 = f32[] constant(0), sharding={replicated}
|
||||
ROOT reduce-window = f32[5,2]{1,0} reduce-window(param, constant.1),
|
||||
window={size=4x1 stride=2x1 pad=3_0x0_0}, to_apply=sum,
|
||||
sharding={devices=[5,1]0,1,2,3,4}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
PartitionComputation(hlo_string, /*num_devices=*/5));
|
||||
VLOG(1) << module->ToString();
|
||||
auto halo0 = AllOf(op::Shape("f32[1,2]"),
|
||||
op::CollectivePermute(op::Slice(op::Parameter(0))));
|
||||
auto halo1 =
|
||||
AllOf(op::Shape("f32[2,2]"), op::CollectivePermute(op::Parameter(0)));
|
||||
auto pre_mask =
|
||||
AllOf(op::Shape("f32[4,2]"),
|
||||
op::Slice(AllOf(op::Shape("f32[5,2]"),
|
||||
op::Concatenate(halo0, halo1, op::Parameter(0)))));
|
||||
auto masked =
|
||||
op::Select(op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply())),
|
||||
op::Broadcast(op::Constant())),
|
||||
pre_mask, op::Broadcast(op::Constant()));
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"),
|
||||
op::ReduceWindow(masked, op::Constant())));
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
@ -23,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
|
||||
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
@ -104,6 +107,11 @@ Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) {
|
||||
return sharding.TileShape(shape);
|
||||
}
|
||||
|
||||
int64 ShapeSizeInBytes(const Shape& shape) {
|
||||
return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) *
|
||||
ShapeUtil::ElementsIn(shape);
|
||||
}
|
||||
|
||||
Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
|
||||
const HloSharding& sharding,
|
||||
int64 partition_id) {
|
||||
@ -402,33 +410,30 @@ absl::optional<HloInstruction*> ExchangeHalo(
|
||||
std::vector<HloInstruction*> concat_pieces;
|
||||
|
||||
int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
|
||||
if (max_left_halo_size > input_shard_size) {
|
||||
VLOG(1) << "ExchangeHalo failed: halo is beyond the left neighbor.";
|
||||
return absl::nullopt;
|
||||
}
|
||||
if (max_left_halo_size > 0) {
|
||||
for (int64 i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1; i >= 0;
|
||||
--i) {
|
||||
std::vector<std::pair<int64, int64>> source_target_pairs;
|
||||
target.tile_assignment().Each(
|
||||
[&](absl::Span<const int64> indices, int64 device) {
|
||||
if (indices[dim] > 0) {
|
||||
if (indices[dim] > i) {
|
||||
std::vector<int64> source_indices(indices.begin(), indices.end());
|
||||
source_indices[dim] -= 1;
|
||||
source_indices[dim] -= i + 1;
|
||||
source_target_pairs.emplace_back(
|
||||
target.tile_assignment()(source_indices), device);
|
||||
}
|
||||
});
|
||||
int64 halo_size =
|
||||
std::min(max_left_halo_size - input_shard_size * i, input_shard_size);
|
||||
auto halo_shape = hlo->shape();
|
||||
auto source_halo_slice = hlo;
|
||||
if (max_left_halo_size != hlo->shape().dimensions(dim)) {
|
||||
halo_shape.set_dimensions(dim, max_left_halo_size);
|
||||
if (halo_size != hlo->shape().dimensions(dim)) {
|
||||
halo_shape.set_dimensions(dim, halo_size);
|
||||
std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
|
||||
halo_start_indices[dim] =
|
||||
hlo->shape().dimensions(dim) - max_left_halo_size;
|
||||
halo_start_indices[dim] = hlo->shape().dimensions(dim) - halo_size;
|
||||
std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
|
||||
|
||||
source_halo_slice = b->AddInstruction(
|
||||
hlo->CreateSlice(halo_shape, hlo, halo_start_indices,
|
||||
hlo->shape().dimensions(), halo_slice_strides));
|
||||
source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice(
|
||||
halo_shape, hlo, halo_start_indices, hlo->shape().dimensions(),
|
||||
halo_slice_strides));
|
||||
}
|
||||
auto left_halo =
|
||||
collective_ops_creator.create_cross_partition_collective_permute(
|
||||
@ -441,29 +446,30 @@ absl::optional<HloInstruction*> ExchangeHalo(
|
||||
// Right halo.
|
||||
int64 max_right_halo_size =
|
||||
right_halo_size_function.MaxInRange(0, shard_count - 1);
|
||||
if (max_right_halo_size > input_shard_size) {
|
||||
VLOG(1) << "ExchangeHalo failed: halo is beyond the right neighbor.";
|
||||
return absl::nullopt;
|
||||
}
|
||||
if (max_right_halo_size > 0) {
|
||||
for (int64 i = 0; i < CeilOfRatio(max_right_halo_size, input_shard_size);
|
||||
++i) {
|
||||
std::vector<std::pair<int64, int64>> source_target_pairs;
|
||||
target.tile_assignment().Each(
|
||||
[&](absl::Span<const int64> indices, int64 device) {
|
||||
if (indices[dim] > 0) {
|
||||
if (indices[dim] > i) {
|
||||
std::vector<int64> target_indices(indices.begin(), indices.end());
|
||||
target_indices[dim] -= 1;
|
||||
target_indices[dim] -= i + 1;
|
||||
source_target_pairs.emplace_back(
|
||||
device, target.tile_assignment()(target_indices));
|
||||
}
|
||||
});
|
||||
int64 halo_size =
|
||||
std::min(max_right_halo_size - input_shard_size * i, input_shard_size);
|
||||
auto halo_shape = hlo->shape();
|
||||
halo_shape.set_dimensions(dim, max_right_halo_size);
|
||||
std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
|
||||
std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
|
||||
|
||||
auto source_halo_slice = b->AddInstruction(
|
||||
hlo->CreateSlice(halo_shape, hlo, halo_start_indices,
|
||||
halo_shape.dimensions(), halo_slice_strides));
|
||||
HloInstruction* source_halo_slice = hlo;
|
||||
if (halo_size != halo_shape.dimensions(dim)) {
|
||||
halo_shape.set_dimensions(dim, halo_size);
|
||||
std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
|
||||
std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
|
||||
source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice(
|
||||
halo_shape, hlo, halo_start_indices, halo_shape.dimensions(),
|
||||
halo_slice_strides));
|
||||
}
|
||||
auto right_halo =
|
||||
collective_ops_creator.create_cross_partition_collective_permute(
|
||||
b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
|
||||
|
||||
@ -57,6 +57,10 @@ bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding);
// target sharding.
Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding);

// Similar to ShapeUtil::ByteSizeOf(), but does not check that the shape has a
// dense layout, since this can be called before layout assignment.
int64 ShapeSizeInBytes(const Shape& shape);

// Returns the shard shape for a partition without padding due to uneven
// sharding.
Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,

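As a small illustration of the helper above (not part of the patch; the shape is arbitrary):

  // f32[9,2] holds 18 elements of 4 bytes each, so the helper returns 72
  // whether or not a layout has been assigned yet.
  Shape shape = ShapeUtil::MakeShape(F32, {9, 2});
  int64 bytes = ShapeSizeInBytes(shape);  // 4 * 18 = 72
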
@ -112,8 +112,7 @@ constexpr int64 kDefaultMaxNumNodesInAllPaths = 100;
|
||||
|
||||
using absl::EqualsIgnoreCase;
|
||||
|
||||
// A global control for whether backend configuration display is enabled.
|
||||
bool show_backend_config = true;
|
||||
HloRenderOptions hlo_render_options;
|
||||
|
||||
HloInstruction* FindInstruction(const HloModule& module, string node_name) {
|
||||
if (absl::StartsWith(node_name, "%")) {
|
||||
@ -160,6 +159,8 @@ void DoHelpCommand() {
|
||||
Renders all nodes in <computation>.
|
||||
backend_config [on|off]
|
||||
Controls whether backend operation configuration information is printed.
|
||||
show_fusion_subcomputations [on|off]
|
||||
Controls whether fusion subcomputations are shown.
|
||||
list [name|op_name|op_type] <pattern>
|
||||
Lists all instructions whose name, metadata op_name, or metadata op_type
|
||||
contains <pattern> as a substring.
|
||||
@ -182,15 +183,32 @@ void DoHelpCommand() {
|
||||
// Turn metadata-printing on or off.
|
||||
void DoBackendConfigCommand(const std::vector<string>& tokens) {
|
||||
if (tokens.size() == 2 && tokens[1] == "on") {
|
||||
show_backend_config = true;
|
||||
hlo_render_options.show_backend_config = true;
|
||||
} else if (tokens.size() == 2 && tokens[1] == "off") {
|
||||
show_backend_config = false;
|
||||
hlo_render_options.show_backend_config = false;
|
||||
} else if (tokens.size() != 1) {
|
||||
std::cerr << "(Illegal backend_config value. Use either 'on' or 'off'.)"
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "Backend configuration display "
|
||||
<< (show_backend_config ? "ON" : "OFF") << std::endl;
|
||||
<< (hlo_render_options.show_backend_config ? "ON" : "OFF")
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Turn fusion computation display on or off.
|
||||
void DoShowFusionSubcomputationsCommand(const std::vector<string>& tokens) {
|
||||
if (tokens.size() == 2 && tokens[1] == "on") {
|
||||
hlo_render_options.show_fusion_subcomputations = true;
|
||||
} else if (tokens.size() == 2 && tokens[1] == "off") {
|
||||
hlo_render_options.show_fusion_subcomputations = false;
|
||||
} else if (tokens.size() != 1) {
|
||||
std::cerr << "(Illegal show_fusion_subcomputations value. Use either "
|
||||
"'on' or 'off'.)"
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "Fusion subcomputations display "
|
||||
<< (hlo_render_options.show_fusion_subcomputations ? "ON" : "OFF")
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// List all computations in the module.
|
||||
@ -373,7 +391,7 @@ void DoExtractCommand(const HloModule& module,
|
||||
auto extracted_module = ExtractModule(instr, height);
|
||||
std::cout << extracted_module->ToString(
|
||||
HloPrintOptions::ShortParsable().set_print_backend_config(
|
||||
show_backend_config))
|
||||
hlo_render_options.show_backend_config))
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
@ -517,7 +535,7 @@ void DoAllPathsCommand(const Options& opts, const HloModule& module,
|
||||
}
|
||||
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
|
||||
return RenderAllPathsFromTo(*from, *to, max_nodes, format,
|
||||
/*show_backend_config=*/show_backend_config);
|
||||
hlo_render_options);
|
||||
});
|
||||
}
|
||||
|
||||
@ -582,15 +600,13 @@ void DoPlotCommand(const Options& opts, const HloModule& module,
|
||||
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
|
||||
return RenderGraph(*comp, /*label=*/"",
|
||||
comp->parent()->config().debug_options(), format,
|
||||
/*hlo_execution_profile=*/nullptr,
|
||||
/*show_backend_config=*/show_backend_config);
|
||||
/*hlo_execution_profile=*/nullptr, hlo_render_options);
|
||||
});
|
||||
} else {
|
||||
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
|
||||
return RenderNeighborhoodAround(
|
||||
*instr, graph_width, format,
|
||||
/*show_backend_config=*/show_backend_config,
|
||||
/*boundary=*/boundary);
|
||||
return RenderNeighborhoodAround(*instr, graph_width, format,
|
||||
hlo_render_options,
|
||||
/*boundary=*/boundary);
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -617,6 +633,8 @@ void InteractiveDumpGraphs(const Options& opts, const HloModule& module) {
|
||||
DoHelpCommand();
|
||||
} else if (tokens[0] == "backend_config") {
|
||||
DoBackendConfigCommand(tokens);
|
||||
} else if (tokens[0] == "show_fusion_subcomputations") {
|
||||
DoShowFusionSubcomputationsCommand(tokens);
|
||||
} else if (tokens[0] == "list") {
|
||||
if (tokens.size() > 1 && tokens[1] == "computations") {
|
||||
DoListComputationsCommand(module, tokens);
|
||||
|
||||
@ -84,6 +84,13 @@ END
    name: "Tout"
    description: <<END
the types of the output tensors.
END
  }
  attr {
    name: "enable_large_batch_splitting"
    description: <<END
input with a large size (i.e., larger than the largest value of
`allowed_batch_sizes`) will be split into multiple batches with batch size.
END
  }
  summary: "Batches all the input tensors to the computation done by the function."

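For a rough sense of the new attribute (numbers here are illustrative, not from the op definition): with `allowed_batch_sizes` of [2, 4, 8] and a request carrying 11 examples, enabling splitting lets the kernel break the request into sub-batches that each fit under the batch-size limit instead of treating the oversized input as an error; leaving the attribute at its default of false preserves the old behavior.
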
@ -0,0 +1,5 @@
op {
  graph_op_name: "DeviceIndex"
  visibility: HIDDEN
  summary: "Returns the index of the device the op runs on."
}

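As a concrete reading of the op's behavior (values here are illustrative; the kernel itself appears later in this diff): a DeviceIndex node with device_names = ["CPU", "TPU_REPLICATED_CORE", "GPU"] that ends up placed on a GPU device returns the scalar int32 value 2, and if the actual device type is not in the list it returns the list length, which selects the default branch of the associated Case op.
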
@ -100,7 +100,7 @@ string AttrSlice::DebugString() const {
  return absl::StrJoin(attr_key_vals, ", ");
}

string SummarizeNodeDef(const NodeDef& node_def) {
string SummarizeNodeDef(const NodeDef& node_def, int max_inputs_in_summary) {
  string ret = strings::StrCat(errors::FormatNodeNameForError(node_def.name()),
                               " = ", node_def.op(), "[");
  strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
@ -111,6 +111,10 @@ string SummarizeNodeDef(const NodeDef& node_def) {
  for (const string& input : node_def.input()) {
    if (!first) strings::StrAppend(&ret, ", ");
    first = false;
    if (max_inputs_in_summary-- == 0) {
      strings::StrAppend(&ret, "...");
      break;
    }
    strings::StrAppend(&ret, input);
  }
  strings::StrAppend(&ret, ")");

@ -58,7 +58,12 @@ extern const char* const kColocationGroupPrefix;

// Produce a human-readable version of a Node or NodeDef that is more concise
// than a text-format proto.
string SummarizeNodeDef(const NodeDef& node_def);
//
// The parameter `max_inputs_in_summary` specifies how many inputs at most to
// serialize in the output (so that the resulting string does not become
// overly large). The value `-1` specifies that all inputs will be shown.
string SummarizeNodeDef(const NodeDef& node_def,
                        int max_inputs_in_summary = -1);
string SummarizeAttrs(const NodeDef& node_def);
string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device);

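A minimal example of the new parameter at a call site (illustrative only; `node_def` stands for any NodeDef and the cap of 10 is arbitrary):

  string full = SummarizeNodeDef(node_def);   // unchanged: all inputs listed
  string brief = SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10);
  // `brief` lists at most 10 inputs and then appends "..." for the rest.
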
@ -1062,6 +1062,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
|
||||
@ -603,16 +603,8 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
"//tensorflow/core/grappler/optimizers:arithmetic_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:common_subgraph_elimination",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||
"//tensorflow/core/grappler/optimizers:dependency_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:function_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:loop_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:model_pruner",
|
||||
"//tensorflow/core/grappler/optimizers:remapper",
|
||||
"//tensorflow/core/grappler/optimizers:shape_optimizer",
|
||||
"//tensorflow/core/grappler/utils:functions",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core:framework",
|
||||
|
||||
@ -21,15 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
||||
#include "tensorflow/core/grappler/optimizers/remapper.h"
|
||||
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
|
||||
#include "tensorflow/core/grappler/utils/functions.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
@ -60,14 +52,6 @@ constexpr std::array<const char*, 15> kTFDataOptimizations = {
|
||||
"slack",
|
||||
"inject_prefetch"};
|
||||
|
||||
// Standard grappler optimizations, in the order we want to perform them.
|
||||
// The order matches the order in the generic meta optimizer.
|
||||
constexpr std::array<const char*, 9> kGrapplerOptimizations = {
|
||||
"pruning", "function", "common_subgraph_elimination",
|
||||
"shape", "arithmetic", "layout_optimizer",
|
||||
"remapper", "loop", "dependency",
|
||||
};
|
||||
|
||||
// Parses a list of string optimizer configurations into a map from
|
||||
// optimizer name -> rewriter config for that optimizer.
|
||||
Status ToConfigMap(
|
||||
@ -118,11 +102,6 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
ApplyOptimization(optimization, cluster, &optimized_item));
|
||||
}
|
||||
|
||||
for (const auto& optimization : kGrapplerOptimizations) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ApplyOptimization(optimization, cluster, &optimized_item));
|
||||
}
|
||||
|
||||
// Store the final result of all the optimizations in `output`.
|
||||
output->Swap(&optimized_item.graph);
|
||||
|
||||
@ -132,16 +111,17 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
.ReachableDefinitions(*output);
|
||||
const auto producer = output->versions().producer();
|
||||
bool optimized_functions = false;
|
||||
for (const FunctionDef& func : output->library().function()) {
|
||||
for (const auto& name : flib.ListFunctionNames()) {
|
||||
auto* func = flib.Find(name);
|
||||
// Skip non tf.data functions.
|
||||
if (!func.attr().contains(data::kTFDataFunction)) continue;
|
||||
VLOG(3) << "Optimize function: function=" << func.signature().name();
|
||||
if (!func->attr().contains(data::kTFDataFunction)) continue;
|
||||
VLOG(3) << "Optimize function: function=" << func->signature().name();
|
||||
optimized_functions = true;
|
||||
|
||||
// Make a GrapplerItem from a FunctionDef.
|
||||
GrapplerFunctionItem func_item;
|
||||
TF_RETURN_IF_ERROR(
|
||||
MakeGrapplerFunctionItem(func, flib, producer, &func_item));
|
||||
MakeGrapplerFunctionItem(*func, flib, producer, &func_item));
|
||||
|
||||
GraphDef optimized_func_graph;
|
||||
TF_RETURN_IF_ERROR(Optimize(cluster, func_item, &optimized_func_graph));
|
||||
@ -162,7 +142,7 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
|
||||
// Replace optimized function with a new FunctionDef.
|
||||
TF_RETURN_IF_ERROR(
|
||||
flib.ReplaceFunction(func.signature().name(), optimized_func));
|
||||
flib.ReplaceFunction(func->signature().name(), optimized_func));
|
||||
}
|
||||
if (optimized_functions) {
|
||||
*output->mutable_library() = flib.ToProto();
|
||||
@ -221,27 +201,6 @@ Status TFDataMetaOptimizer::Init(
|
||||
}
|
||||
}
|
||||
|
||||
// Enable a subset of grappler optimization that are enabled by default.
|
||||
//
|
||||
// Layout optimizations are excluded because they assume that ops without
|
||||
// explicit device assignment will be placed on GPU (if available) but that's
|
||||
// not the case for operations within tf.data functions.
|
||||
//
|
||||
// TODO(b/120437209): Re-enable constant folding.
|
||||
//
|
||||
// TODO(jsimsa): Make the set of generic Grappler optimization applied to
|
||||
// tf.data functions configurable.
|
||||
enabled_optimizers_["pruning"] = MakeUnique<ModelPruner>();
|
||||
enabled_optimizers_["shape"] = MakeUnique<ShapeOptimizer>();
|
||||
enabled_optimizers_["remapping"] = MakeUnique<Remapper>(RewriterConfig::ON);
|
||||
enabled_optimizers_["common_subgraph_elimination"] =
|
||||
MakeUnique<CommonSubgraphElimination>();
|
||||
enabled_optimizers_["arithmetic"] = MakeUnique<ArithmeticOptimizer>();
|
||||
enabled_optimizers_["dependency"] = MakeUnique<DependencyOptimizer>();
|
||||
enabled_optimizers_["loop"] = MakeUnique<LoopOptimizer>();
|
||||
enabled_optimizers_["function"] = MakeUnique<FunctionOptimizer>(
|
||||
RewriterConfig::ON, /*lower_control_flow=*/true);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
@ -17,9 +17,12 @@ limitations under the License.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
@ -36,6 +39,11 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
constexpr char kConstOp[] = "Const";
|
||||
constexpr char kCaseOp[] = "Case";
|
||||
constexpr char kDeviceIndexOp[] = "DeviceIndex";
|
||||
|
||||
// TODO(b/157615690): clean up function implementation swap code.
|
||||
// The overall idea for the function swap is like below:
|
||||
// ----------- -----------
|
||||
// inp_1 ->| P_C | -> out_1 g_inp_1 ->| P_C | -> g_out_1
|
||||
@ -292,6 +300,74 @@ Status ImplementationSelector::MaybeOptimizeFunctionCall(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Finds the index of the device from the device name list.
|
||||
Status FindDeviceIndex(const utils::MutableNodeView* device_index_node,
|
||||
const string& device, int* index) {
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
if (!DeviceNameUtils::ParseFullName(device, &parsed_name) ||
|
||||
!parsed_name.has_type) {
|
||||
return errors::Internal("Could not parse device name:", device);
|
||||
}
|
||||
const auto& device_list =
|
||||
device_index_node->GetAttr("device_names")->list().s();
|
||||
auto it = absl::c_find(device_list, parsed_name.type);
|
||||
if (it != device_list.end()) {
|
||||
*index = it - device_list.begin();
|
||||
} else {
|
||||
// Sets *index to device_list.size() because the default_fn is guaranteed to
|
||||
// be the final item in the case op branching list.
|
||||
*index = device_list.size();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Rewrites the device_index op to a const op with value of the index.
|
||||
void RewriteDeviceIndexOp(utils::MutableNodeView* device_index_node,
|
||||
int index) {
|
||||
// Modifies the DeviceIndex node to be an Const op with correct device index.
|
||||
auto node = device_index_node->node();
|
||||
node->set_op(kConstOp);
|
||||
node->clear_attr();
|
||||
(*node->mutable_attr())["dtype"].set_type(DT_INT32);
|
||||
auto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
|
||||
tensor->set_dtype(DT_INT32);
|
||||
tensor->add_int_val(index);
|
||||
VLOG(2) << "Node after rewriting:" << node->DebugString();
|
||||
}
|
||||
|
||||
Status ImplementationSelector::SelectDeviceIndex(GraphDef* graph) const {
|
||||
Status status;
|
||||
VLOG(2) << "graph before rewriting device index:" << graph->DebugString();
|
||||
utils::MutableGraphView graph_view(graph, &status);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
const int num_nodes = graph_view.NumNodes();
|
||||
for (int k = 0; k < num_nodes; ++k) {
|
||||
auto* node_view = graph_view.GetNode(k);
|
||||
if (node_view->GetOp() != kDeviceIndexOp) {
|
||||
continue;
|
||||
}
|
||||
VLOG(2) << "Found a node to rewrite the device index";
|
||||
|
||||
// Find the case node with device index node as input, rewrite the
|
||||
// DeviceIndex node to have the value of the index of device type of the
|
||||
// case node.
|
||||
for (const auto& fanouts : node_view->GetRegularFanouts()) {
|
||||
for (const auto& fanout : fanouts) {
|
||||
if (fanout.node_view()->GetOp() != kCaseOp) continue;
|
||||
int index;
|
||||
// If any error is thrown out during device parsing, we simply skip
|
||||
// and do not modify the DeviceIndexNode.
|
||||
Status status =
|
||||
FindDeviceIndex(node_view, fanout.node_view()->GetDevice(), &index);
|
||||
if (status.ok()) {
|
||||
RewriteDeviceIndexOp(node_view, index);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
|
||||
if (!graph->has_library()) {
|
||||
VLOG(2) << "Skipping graph since it does not have function def";
|
||||
@ -307,8 +383,9 @@ Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
|
||||
const int num_nodes = graph_view.NumNodes();
|
||||
for (int k = 0; k < num_nodes; ++k)
|
||||
for (int k = 0; k < num_nodes; ++k) {
|
||||
TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph_view.GetNode(k)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@ -326,7 +403,13 @@ Status ImplementationSelector::Optimize(Cluster* cluster,
|
||||
<< "libraries: " << status;
|
||||
return errors::Aborted("Skipped Optimization");
|
||||
}
|
||||
|
||||
*optimized_graph = item.graph;
|
||||
status = SelectDeviceIndex(optimized_graph);
|
||||
if (!status.ok()) {
|
||||
*optimized_graph = item.graph;
|
||||
VLOG(2) << "Could not rewrite device index due to error:" << status;
|
||||
}
|
||||
return SelectImplementation(optimized_graph);
|
||||
}
|
||||
|
||||
|
||||
@ -34,6 +34,28 @@ limitations under the License.
namespace tensorflow {
namespace grappler {

// Motivation: To achieve the same high-level functionality, the underlying
// implementations are sometimes different for the various devices where the
// function runs. In order to achieve the correct result and best performance,
// the proper implementation needs to be picked dynamically.
//
// Currently there are two approaches to do this.
// (1) Utilize the Case op and dynamically change the branch index.
// (2) Swap the function implementation; this approach will be deprecated.
//
// Idea for approach 1.
// This transformation rewrites the DeviceIndex op with a Const op whose value
// is the index of the device the associated Case op runs on.
// Example:
// def plus_one_gpu(x): return x + 1.0
// def plus_one_reference_implementation(x): return x + 1.0
// input = tf.constant(2.0, dtype=tf.float32)
// cpu_fn = lambda: plus_one_reference_implementation(input)
// gpu_fn = lambda: plus_one_gpu(input)
// control_flow_ops.execute_fn_for_device(
//     {"CPU": cpu_fn, "GPU": gpu_fn}, default_fn=cpu_fn)
//
// Idea for approach 2.
// This transformation replaces function calls by the appropriate function
// definition based on properties of the runtime system. For instance,
// we may choose one implementation over another if we have a GPU with
@ -58,7 +80,8 @@ namespace grappler {
|
||||
// z = plus_one_gpu(input)
|
||||
// print(sess.run(z))
|
||||
//
|
||||
// At runtime, we will trim either `plus_one_gpu` or
|
||||
|
||||
// At runtime, we will select either `plus_one_gpu` or
|
||||
// `plus_one_reference_implementation` based on the availability of the GPU.
|
||||
//
|
||||
// Available annotations:
|
||||
@ -106,6 +129,68 @@ class ImplementationSelector : public CustomGraphOptimizer {
|
||||
// gradients.
|
||||
Status SelectImplementation(GraphDef* graph) const;
|
||||
|
||||
// Rewrites the DeviceIndex op with a Const op with value of the index of the
|
||||
// device the associcated Case op runs.
|
||||
|
||||
// This function first looks up all the DeviceIndex ops.
|
||||
// Then for each of these ops, it finds the device of the
|
||||
// associated Case op that takes the DeviceIndex op as the input, and
|
||||
// caculates the index of the device in the device list of DeviceIndex op.
|
||||
// Lastly, it rewrites the DeviceIndex op with a Const op and sets the value
|
||||
// to be the index.
|
||||
//
|
||||
// Example input nodes:
|
||||
// node {
|
||||
// name: "x"
|
||||
// op: "DeviceIndex"
|
||||
// device: "/device:CPU:0"
|
||||
// attr {
|
||||
// key: "device_names"
|
||||
// value {
|
||||
// list {
|
||||
// s: "CPU"
|
||||
// s: "TPU_REPLICATED_CORE"
|
||||
// s: "GPU"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// node {
|
||||
// name: "case"
|
||||
// op: "Case"
|
||||
// input: "x"
|
||||
// device: "/device:GPU:0"
|
||||
// ...
|
||||
// }
|
||||
// Example output nodes:
|
||||
//
|
||||
// name: "x"
|
||||
// op: "Const"
|
||||
// device: "/device:CPU:0"
|
||||
// attr {
|
||||
// key: "dtype"
|
||||
// value {
|
||||
// type: DT_INT32
|
||||
// }
|
||||
// }
|
||||
// attr {
|
||||
// key: "value"
|
||||
// value {
|
||||
// tensor {
|
||||
// dtype: DT_INT32
|
||||
// int_val: 2
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// node {
|
||||
// name: "case"
|
||||
// op: "Case"
|
||||
// input: "x"
|
||||
// device: "/device:GPU:0"
|
||||
// ...
|
||||
// }
|
||||
Status SelectDeviceIndex(GraphDef* graph) const;
|
||||
|
||||
std::unique_ptr<FunctionLibraryApiInfo> lib_info_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ImplementationSelector);
|
||||
|
||||
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
@ -58,6 +59,167 @@ TEST_F(ImplementationSelectorTest, NoUpdate) {
|
||||
EXPECT_EQ(item.graph.node_size(), output.node_size());
|
||||
}
|
||||
|
||||
TEST_F(ImplementationSelectorTest, SelectDeviceIndex) {
|
||||
using test::function::NDef;
|
||||
ImplementationSelector optimizer;
|
||||
GraphDef output;
|
||||
GrapplerItem item;
|
||||
AttrValue device_names;
|
||||
device_names.mutable_list()->add_s("CPU");
|
||||
device_names.mutable_list()->add_s("GPU");
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
CpuDevice),
|
||||
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice)});
|
||||
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "x") {
|
||||
// Rewrite DeviceIndex op to a Const op with value of GPU index 1.
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ImplementationSelectorTest, SelectDeviceIndexMultiOps) {
|
||||
using test::function::NDef;
|
||||
ImplementationSelector optimizer;
|
||||
GraphDef output;
|
||||
GrapplerItem item;
|
||||
AttrValue device_names;
|
||||
device_names.mutable_list()->add_s("CPU");
|
||||
device_names.mutable_list()->add_s("TPU_REPLICATED_CORE");
|
||||
device_names.mutable_list()->add_s("GPU");
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
CpuDevice),
|
||||
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
|
||||
NDef("y", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
GpuDevice),
|
||||
NDef("case_y", "Case", {"y"}, {{"T", DT_FLOAT}}, TpuDevice)});
|
||||
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "x") {
|
||||
// Rewrite DeviceIndex op to a Const op with value of GPU index 1.
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
|
||||
}
|
||||
if (node.name() == "y") {
|
||||
// Rewrite DeviceIndex op to a Const op with value of CPU index 0.
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ImplementationSelectorTest, SelectDeviceIndexNotFound) {
|
||||
using test::function::NDef;
|
||||
ImplementationSelector optimizer;
|
||||
GraphDef output;
|
||||
GrapplerItem item;
|
||||
AttrValue device_names;
|
||||
device_names.mutable_list()->add_s("CPU");
|
||||
device_names.mutable_list()->add_s("GPU");
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
CpuDevice),
|
||||
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, TpuDevice)});
|
||||
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "x") {
|
||||
// Rewrite DeviceIndex op to a Const op with value of device names length.
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ImplementationSelectorTest, SelectDeviceIndexError) {
|
||||
using test::function::NDef;
|
||||
ImplementationSelector optimizer;
|
||||
GraphDef output;
|
||||
GrapplerItem item;
|
||||
AttrValue device_names;
|
||||
device_names.mutable_list()->add_s("CPU");
|
||||
device_names.mutable_list()->add_s("GPU");
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
CpuDevice),
|
||||
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, "")});
|
||||
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "x") {
|
||||
// Device parse has error, do not rewrite the DeviceIndexNode.
|
||||
EXPECT_EQ("DeviceIndex", node.op());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ImplementationSelectorTest, TwoTypesOfSwapImplementation) {
|
||||
using test::function::NDef;
|
||||
ImplementationSelector optimizer;
|
||||
GraphDef output;
|
||||
GrapplerItem item;
|
||||
// DeviceIndex op based implementation selector.
|
||||
AttrValue device_names;
|
||||
device_names.mutable_list()->add_s("CPU");
|
||||
device_names.mutable_list()->add_s("TPU_REPLICATED_CORE");
|
||||
device_names.mutable_list()->add_s("GPU");
|
||||
|
||||
// Function swap based implementation selector.
|
||||
auto cpu_def = test::function::XTimesTwo();
|
||||
auto* func_attr = cpu_def.mutable_attr();
|
||||
(*func_attr)["api_implements"].set_s("times_two");
|
||||
(*func_attr)["api_preferred_device"].set_s("CPU");
|
||||
|
||||
auto gpu_def = test::function::XAddX();
|
||||
auto* func2_attr = gpu_def.mutable_attr();
|
||||
(*func2_attr)["api_implements"].set_s("times_two");
|
||||
(*func2_attr)["api_preferred_device"].set_s("GPU");
|
||||
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
CpuDevice),
|
||||
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
|
||||
NDef("y", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
GpuDevice),
|
||||
NDef("case_y", "Case", {"y"}, {{"T", DT_FLOAT}}, TpuDevice),
|
||||
NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
|
||||
NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice),
|
||||
NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
|
||||
NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)},
|
||||
// FunctionLib
|
||||
{cpu_def, gpu_def});
|
||||
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "x") {
|
||||
// Rewrite DeviceIndex op to a Const op with value of GPU index 1.
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
|
||||
}
|
||||
if (node.name() == "y") {
|
||||
// Rewrite DeviceIndex op to a Const op with value of CPU index 0.
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
|
||||
}
|
||||
if (node.name() == "y1") {
|
||||
// Make sure the implementation has been swapped to use the GPU version.
|
||||
EXPECT_EQ("XAddX", node.op());
|
||||
} else if (node.name() == "y2") {
|
||||
// Make sure the implementation is not changed.
|
||||
EXPECT_EQ("XTimesTwo", node.op());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ImplementationSelectorTest, SwapImplementation) {
|
||||
using test::function::NDef;
|
||||
auto cpu_def = test::function::XTimesTwo();
|
||||
|
||||
@ -272,6 +272,7 @@ class BatchResource : public ResourceBase {
|
||||
int32 batch_timeout_micros, int32 max_enqueued_batches,
|
||||
const std::vector<int32>& allowed_batch_sizes,
|
||||
FunctionLibraryRuntime::Handle fhandle,
|
||||
bool enable_large_batch_splitting,
|
||||
std::unique_ptr<BatchResource>* resource) {
|
||||
std::unique_ptr<BatchResource> new_resource(new BatchResource);
|
||||
|
||||
@ -286,6 +287,10 @@ class BatchResource : public ResourceBase {
|
||||
new_resource->batcher_queue_options_.batch_timeout_micros =
|
||||
batch_timeout_micros;
|
||||
|
||||
// Support for splitting large batch is still in progress.
|
||||
new_resource->batcher_queue_options_.enable_large_batch_splitting =
|
||||
enable_large_batch_splitting;
|
||||
|
||||
new_resource->allowed_batch_sizes_ = allowed_batch_sizes;
|
||||
|
||||
new_resource->fhandle_ = fhandle;
|
||||
@ -786,6 +791,13 @@ class BatchFunctionKernel : public AsyncOpKernel {
|
||||
OP_REQUIRES_OK(c, c->GetAttr("f", &func));
|
||||
OP_REQUIRES_OK(
|
||||
c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
|
||||
|
||||
if (c->HasAttr("enable_large_batch_splitting")) {
|
||||
OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
|
||||
&enable_large_batch_splitting_));
|
||||
} else {
|
||||
enable_large_batch_splitting_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsExpensive() override { return false; }
|
||||
@ -794,10 +806,10 @@ class BatchFunctionKernel : public AsyncOpKernel {
|
||||
BatchResource* br;
|
||||
std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
|
||||
std::unique_ptr<BatchResource> new_resource;
|
||||
TF_RETURN_IF_ERROR(
|
||||
BatchResource::Create(num_batch_threads_, max_batch_size_,
|
||||
batch_timeout_micros_, max_enqueued_batches_,
|
||||
allowed_batch_sizes_, fhandle_, &new_resource));
|
||||
TF_RETURN_IF_ERROR(BatchResource::Create(
|
||||
num_batch_threads_, max_batch_size_, batch_timeout_micros_,
|
||||
max_enqueued_batches_, allowed_batch_sizes_, fhandle_,
|
||||
enable_large_batch_splitting_, &new_resource));
|
||||
*r = new_resource.release();
|
||||
return Status::OK();
|
||||
};
|
||||
@ -844,6 +856,7 @@ class BatchFunctionKernel : public AsyncOpKernel {
|
||||
int32 max_enqueued_batches_;
|
||||
std::vector<int32> allowed_batch_sizes_;
|
||||
FunctionLibraryRuntime::Handle fhandle_;
|
||||
bool enable_large_batch_splitting_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
|
||||
@ -876,7 +889,7 @@ class BatchKernel : public AsyncOpKernel {
|
||||
std::unique_ptr<BatchResource> new_resource;
|
||||
TF_RETURN_IF_ERROR(BatchResource::Create(
|
||||
num_batch_threads_, max_batch_size_, batch_timeout_micros_,
|
||||
max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
|
||||
max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle, false,
|
||||
&new_resource));
|
||||
*r = new_resource.release();
|
||||
return Status::OK();
|
||||
|
||||
@ -160,6 +160,10 @@ class SharedBatchScheduler
|
||||
// See the class documentation above for guidelines on how to tune this
|
||||
// parameter.
|
||||
size_t max_enqueued_batches = 10;
|
||||
|
||||
// If true, queue implementation would split one input batch task into
|
||||
// subtasks and fit them into different batches.
|
||||
bool enable_large_batch_splitting = false;
|
||||
};
|
||||
Status AddQueue(const QueueOptions& options,
|
||||
std::function<void(std::unique_ptr<Batch<TaskType>>)>
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_not_mobile",
|
||||
"tf_cc_test",
|
||||
"tf_kernel_library",
|
||||
)
|
||||
@ -150,6 +151,7 @@ cc_library(
|
||||
":dataset_utils",
|
||||
":single_threaded_executor",
|
||||
":stats_utils",
|
||||
"@com_google_absl//absl/time",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
@ -158,8 +160,10 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"@com_google_absl//absl/time",
|
||||
],
|
||||
] + if_not_mobile([
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler/optimizers:meta_optimizer",
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
||||
@ -35,6 +35,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/notification.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
@ -612,6 +617,28 @@ Status CapturedFunction::Instantiate(
|
||||
for (size_t i = 0; i < fdef->signature().output_arg_size(); ++i) {
|
||||
inst_opts.output_devices.push_back(inst_opts.target);
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
grappler::GrapplerItem::OptimizationOptions optimization_options;
|
||||
optimization_options.allow_pruning_stateful_and_dataset_ops = false;
|
||||
ConfigProto config_proto = inst_opts.config_proto;
|
||||
// Layout optimizations are excluded because they assume that ops without
|
||||
// explicit device assignment will be placed on GPU (if available) but
|
||||
// that's not the case for operations within tf.data functions.
|
||||
config_proto.mutable_graph_options()
|
||||
->mutable_rewrite_options()
|
||||
->set_layout_optimizer(RewriterConfig::OFF);
|
||||
// TODO(b/120437209): Re-enable constant folding.
|
||||
config_proto.mutable_graph_options()
|
||||
->mutable_rewrite_options()
|
||||
->set_constant_folding(RewriterConfig::OFF);
|
||||
inst_opts.optimize_graph_fn =
|
||||
std::bind(tensorflow::grappler::OptimizeGraph, std::placeholders::_1,
|
||||
std::placeholders::_2, std::placeholders::_3,
|
||||
std::placeholders::_4, std::placeholders::_5,
|
||||
std::move(config_proto), fdef->signature().name(),
|
||||
std::move(optimization_options), std::placeholders::_6);
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime::Handle f_handle;
|
||||
|
||||
@ -100,9 +100,13 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
|
||||
AttrValue slack_period_attr;
|
||||
b->BuildAttrValue(slack_period_, &slack_period_attr);
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(
|
||||
this, {input_graph_node, buffer_size},
|
||||
{std::make_pair(kSlackPeriod, slack_period_attr)}, output));
|
||||
AttrValue legacy_autotune_attr;
|
||||
b->BuildAttrValue(legacy_autotune_, &legacy_autotune_attr);
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this, {input_graph_node, buffer_size},
|
||||
{std::make_pair(kSlackPeriod, slack_period_attr),
|
||||
std::make_pair(kLegacyAutotune, legacy_autotune_attr)},
|
||||
output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
@ -39,7 +39,9 @@ namespace data {
|
||||
constexpr char kCurrentFileIndex[] = "current_file_index";
|
||||
constexpr char kOffset[] = "offset";
|
||||
constexpr char kGcsFsPrefix[] = "gs://";
|
||||
constexpr char kS3FsPrefix[] = "s3://";
|
||||
constexpr int64 kCloudTpuBlockSize = 127LL << 20; // 127MB.
|
||||
constexpr int64 kS3BlockSize = kCloudTpuBlockSize;
|
||||
|
||||
bool is_cloud_tpu_gcs_fs() {
|
||||
#if defined(PLATFORM_CLOUD_TPU) && defined(TPU_GCS_FS)
|
||||
@ -237,12 +239,14 @@ void TFRecordDatasetOp::MakeDataset(OpKernelContext* ctx,
|
||||
errors::InvalidArgument("`filenames` must be a scalar or a vector."));
|
||||
|
||||
bool is_gcs_fs = true;
|
||||
bool is_s3_fs = true;
|
||||
std::vector<string> filenames;
|
||||
filenames.reserve(filenames_tensor->NumElements());
|
||||
for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
|
||||
VLOG(2) << "Reading file: " << filenames_tensor->flat<tstring>()(i);
|
||||
filenames.push_back(filenames_tensor->flat<tstring>()(i));
|
||||
is_gcs_fs &= absl::StartsWith(filenames[i], kGcsFsPrefix);
|
||||
is_s3_fs &= absl::StartsWith(filenames[i], kS3FsPrefix);
|
||||
}
|
||||
|
||||
tstring compression_type;
|
||||
@ -264,6 +268,13 @@ void TFRecordDatasetOp::MakeDataset(OpKernelContext* ctx,
|
||||
buffer_size = kCloudTpuBlockSize;
|
||||
}
|
||||
|
||||
if (is_s3_fs && buffer_size < kS3BlockSize) {
|
||||
VLOG(2) << "User buffer size is too small for reading "
|
||||
<< "TFRecords stored in S3. Overriding " << buffer_size
|
||||
<< " to the minimum recommended buffer_size = " << kS3BlockSize;
|
||||
buffer_size = kS3BlockSize;
|
||||
}
|
||||
|
||||
*output =
|
||||
new Dataset(ctx, std::move(filenames), compression_type, buffer_size);
|
||||
}
|
||||
|
||||
@ -924,5 +924,37 @@ class FakeParamOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_GPU), FakeParamOp);
|
||||
|
||||
// DeviceIndexOP returns the current device index.
|
||||
class DeviceIndexOp : public OpKernel {
|
||||
public:
|
||||
explicit DeviceIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("device_names", &device_names_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
Tensor* device_name_t;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->allocate_output(0, TensorShape({}), &device_name_t));
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
int index = device_names_.size();
|
||||
if (DeviceNameUtils::ParseFullName(ctx->device()->name(), &parsed_name) &&
|
||||
parsed_name.has_type) {
|
||||
auto it = absl::c_find(device_names_, parsed_name.type);
|
||||
if (it != device_names_.end()) {
|
||||
index = it - device_names_.begin();
|
||||
}
|
||||
}
|
||||
device_name_t->scalar<int32>()() = index;
|
||||
}
|
||||
|
||||
private:
|
||||
PersistentTensor value_handle_;
|
||||
std::vector<string> device_names_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("DeviceIndex").Device(DEVICE_CPU), DeviceIndexOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("DeviceIndex").Device(DEVICE_GPU).HostMemory("index"), DeviceIndexOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -244,7 +244,7 @@ class MklAddNOp : public OpKernel {
|
||||
|
||||
// Create Sum op, and submit net for execution.
|
||||
std::vector<primitive> net;
|
||||
auto sum_stream = CPU_STREAM(cpu_engine);
|
||||
stream* fwd_cpu_stream = CreateStream(ctx, cpu_engine);
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
mkldnn::sum sum_op(sum_pd);
|
||||
std::unordered_map<int, memory> net_args = {
|
||||
@ -253,10 +253,10 @@ class MklAddNOp : public OpKernel {
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, inputs[i]});
|
||||
}
|
||||
sum_op.execute(sum_stream, net_args);
|
||||
sum_op.execute(*fwd_cpu_stream, net_args);
|
||||
#else
|
||||
net.push_back(sum(sum_pd, inputs, dst.GetOpMem()));
|
||||
sum_stream.submit(net).wait();
|
||||
fwd_cpu_stream->submit(net).wait();
|
||||
#endif
|
||||
} catch (mkldnn::error& e) {
|
||||
string error_msg = "Status: " + std::to_string(e.status) +
|
||||
|
||||
@ -136,9 +136,10 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
|
||||
const T* src_data = input_tensor.flat<T>().data();
|
||||
|
||||
T* dst_data = output_tensor->flat<T>().data();
|
||||
|
||||
std::shared_ptr<stream> fwd_cpu_stream;
|
||||
fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine()));
|
||||
// Execute pooling op.
|
||||
pooling_fwd->Execute(src_data, dst_data);
|
||||
pooling_fwd->Execute(src_data, dst_data, nullptr, fwd_cpu_stream);
|
||||
|
||||
// Pass min, max from input to output.
|
||||
if (int8_forward_inference) {
|
||||
@ -240,8 +241,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
||||
: memory::desc(diff_dst_dims, MklDnnType<T>(),
|
||||
this->data_format_mkldnn_);
|
||||
|
||||
// Pass prop_kind::forward_training to create a forward primitive
|
||||
// that is used in the backward pass.
|
||||
// Pass prop_kind::forward_training to create a forward primitive
|
||||
// that is used in the backward pass.
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
// TODO(DNNL): Find out what should we use src_md.data.format.
|
||||
MklPoolingParams bwdParams(
|
||||
@ -260,6 +261,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
||||
MklPoolingBwdPrimitive<T>* pooling_bwd =
|
||||
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
|
||||
|
||||
std::shared_ptr<stream> bwd_cpu_stream;
|
||||
bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine()));
|
||||
Tensor* output_tensor = nullptr;
|
||||
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
|
||||
orig_input_dims_mkl_order,
|
||||
@ -286,7 +289,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
||||
T* diff_src_data = output_tensor->flat<T>().data();
|
||||
|
||||
// Execute pooling op.
|
||||
pooling_bwd->Execute(diff_dst_data, diff_src_data);
|
||||
pooling_bwd->Execute(diff_dst_data, diff_src_data, nullptr,
|
||||
bwd_cpu_stream);
|
||||
} catch (mkldnn::error& e) {
|
||||
string error_msg = "Status: " + std::to_string(e.status) +
|
||||
", message: " + string(e.message) + ", in file " +
|
||||
|
||||
@ -265,8 +265,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
public:
|
||||
explicit MklConcatFwdPrimitive(const MklConcatFwdParams& concat_fwd_dims,
|
||||
const std::vector<memory::desc>& srcs_md)
|
||||
: cpu_engine_(ENGINE_CPU, 0) {
|
||||
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
|
||||
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||
// Create concat primitive
|
||||
Setup(concat_fwd_dims, srcs_md);
|
||||
}
|
||||
@ -278,7 +277,8 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
// dst_data: output data buffer of dst
|
||||
void Execute(const std::vector<mkldnn::memory>& in_data,
|
||||
const mkldnn::memory& dst_data,
|
||||
const MklConcatFwdParams& concat_fwd_dims) {
|
||||
const MklConcatFwdParams& concat_fwd_dims,
|
||||
std::shared_ptr<stream> fwd_stream) {
|
||||
DCHECK_EQ(in_data.size(), context_.data_mem.size());
|
||||
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
||||
context_.data_mem_shdptr[i]->set_data_handle(
|
||||
@ -292,10 +292,10 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
}
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
|
||||
execute_primitives(context_.fwd_primitives, fwd_stream,
|
||||
context_.fwd_primitives_args);
|
||||
#else
|
||||
context_.fwd_stream->submit(context_.fwd_primitives);
|
||||
fwd_stream->submit(context_.fwd_primitives);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
// After exec, set data handle back
|
||||
@ -335,7 +335,6 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
std::shared_ptr<mkldnn::concat::primitive_desc> fwd_pd;
|
||||
std::shared_ptr<mkldnn::primitive> concat_fwd;
|
||||
|
||||
std::shared_ptr<mkldnn::stream> fwd_stream;
|
||||
std::vector<mkldnn::primitive> fwd_primitives;
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
@ -343,10 +342,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
ConcatFwdContext()
|
||||
: dst_mem(nullptr),
|
||||
fwd_pd(nullptr),
|
||||
concat_fwd(nullptr),
|
||||
fwd_stream(nullptr) {}
|
||||
: dst_mem(nullptr), fwd_pd(nullptr), concat_fwd(nullptr) {}
|
||||
};
|
||||
|
||||
// Creates the src and dst memory descriptor for mkl concat
|
||||
@ -417,7 +413,6 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
}
|
||||
|
||||
struct ConcatFwdContext context_;
|
||||
engine cpu_engine_;
|
||||
};
|
||||
|
||||
// Class to create/cache the mkl concat primitives based on the
|
||||
@ -758,7 +753,7 @@ class MklConcatOp : public OpKernel {
|
||||
for (int k = 0; k < input_tensors.size(); k++) {
|
||||
if (input_tensors[k].NumElements() > 0) {
|
||||
srcs[k].CheckReorderToOpMem(
|
||||
MEMORY_PD_WITHOUT_DATA(srcs_pd[k], cpu_engine));
|
||||
MEMORY_PD_WITHOUT_DATA(srcs_pd[k], cpu_engine), context);
|
||||
inputs.push_back(srcs[k].GetOpMem());
|
||||
}
|
||||
}
|
||||
@ -796,7 +791,8 @@ class MklConcatOp : public OpKernel {
|
||||
if (dnn_shape_dst.IsMklTensor())
|
||||
dst_md = dnn_shape_dst.GetMklLayout();
|
||||
dst.SetUsrMem(dst_md, dst_tensor);
|
||||
stream concat_stream = CPU_STREAM(cpu_engine);
|
||||
std::shared_ptr<stream> fwd_cpu_stream;
|
||||
fwd_cpu_stream.reset(CreateStream(context, cpu_engine));
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
auto concat_op = concat(concat_pd);
|
||||
std::unordered_map<int, memory> net_args = {
|
||||
@ -805,12 +801,12 @@ class MklConcatOp : public OpKernel {
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, inputs[i]});
|
||||
}
|
||||
concat_op.execute(concat_stream, net_args);
|
||||
concat_op.execute(*fwd_cpu_stream, net_args);
|
||||
#else
|
||||
auto concat_op = concat(concat_pd, inputs, dst.GetOpMem());
|
||||
std::vector<primitive> net;
|
||||
net.push_back(concat_op);
|
||||
concat_stream.submit(net).wait();
|
||||
fwd_cpu_stream->submit(net).wait();
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
} else {
|
||||
MklConcatFwdPrimitive<T>* concat_fwd = nullptr;
|
||||
@ -835,9 +831,11 @@ class MklConcatOp : public OpKernel {
|
||||
dst_md = dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout()
|
||||
: dst_md;
|
||||
dst.SetUsrMem(dst_md, dst_tensor);
|
||||
|
||||
std::shared_ptr<stream> fwd_cpu_stream;
|
||||
fwd_cpu_stream.reset(CreateStream(context, concat_fwd->GetEngine()));
|
||||
// Execute concat
|
||||
concat_fwd->Execute(srcs_mem, dst.GetOpMem(), concat_fwd_dims);
|
||||
concat_fwd->Execute(srcs_mem, dst.GetOpMem(), concat_fwd_dims,
|
||||
fwd_cpu_stream);
|
||||
}
|
||||
|
||||
// For quantized concat, min and max outputs are also computed.
|
||||
|
||||
@ -97,9 +97,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
public:
|
||||
explicit MklConvBwdFilterPrimitive(
|
||||
const MklConvBwdFilterParams& convBwdFilterDims)
|
||||
: cpu_engine_(ENGINE_CPU, 0) {
|
||||
context_.bwd_filter_stream.reset(new CPU_STREAM(cpu_engine_));
|
||||
|
||||
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||
// Create convolution backward filter primitive.
|
||||
if (context_.conv_bwd_filter == nullptr) {
|
||||
Setup(convBwdFilterDims);
|
||||
@ -114,7 +112,8 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
// diff_bias_data: output data buffer for diff_bias
|
||||
// diff_dst_data: input data buffer for diff_dst
|
||||
void Execute(const T* src_data, const T* diff_filter_data,
|
||||
const T* diff_bias_data, const T* diff_dst_data) {
|
||||
const T* diff_bias_data, const T* diff_dst_data,
|
||||
std::shared_ptr<stream> bwd_filter_stream) {
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)));
|
||||
context_.diff_filter_mem->set_data_handle(
|
||||
@ -127,11 +126,10 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
static_cast<void*>(const_cast<T*>(diff_dst_data)));
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.bwd_filter_primitives,
|
||||
context_.bwd_filter_stream,
|
||||
execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream,
|
||||
context_.bwd_filter_primitives_args);
|
||||
#else
|
||||
context_.bwd_filter_stream->submit(context_.bwd_filter_primitives);
|
||||
bwd_filter_stream->submit(context_.bwd_filter_primitives);
|
||||
#endif
|
||||
|
||||
context_.src_mem->set_data_handle(DummyData);
|
||||
@ -147,8 +145,10 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
// diff_filter_data: output data buffer of diff_filter
|
||||
// diff_dst_data: input data buffer of diff_dst
|
||||
void Execute(const T* src_data, const T* diff_filter_data,
|
||||
const T* diff_dst_data) {
|
||||
Execute(src_data, diff_filter_data, nullptr, diff_dst_data);
|
||||
const T* diff_dst_data,
|
||||
std::shared_ptr<stream> bwd_filter_stream) {
|
||||
Execute(src_data, diff_filter_data, nullptr, diff_dst_data,
|
||||
bwd_filter_stream);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
@ -223,8 +223,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
src_md(nullptr),
|
||||
diff_filter_md(nullptr),
|
||||
diff_bias_md(nullptr),
|
||||
diff_dst_md(nullptr),
|
||||
bwd_filter_stream(nullptr) {
|
||||
diff_dst_md(nullptr) {
|
||||
}
|
||||
};
|
||||
|
||||
@ -345,7 +344,6 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
}
|
||||
|
||||
struct ConvBwdFilterContext context_;
|
||||
engine cpu_engine_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -600,8 +598,10 @@ class MklConvCustomBackpropFilterOp
|
||||
auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
|
||||
if (IS_SRC_REORDER_NEEDED(fwd_src_md, bwd_filter_pd, conv_bwd_filter)) {
|
||||
src.SetUsrMem(fwd_src_md, &src_tensor);
|
||||
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
bwd_filter_pd->PRIMITIVE_DESC_SRC, cpu_engine_));
|
||||
src.CheckReorderToOpMem(
|
||||
MEMORY_PD_WITHOUT_DATA(bwd_filter_pd->PRIMITIVE_DESC_SRC,
|
||||
cpu_engine_),
|
||||
context);
|
||||
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
|
||||
} else {
|
||||
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
|
||||
@ -612,8 +612,10 @@ class MklConvCustomBackpropFilterOp
|
||||
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_filter_pd,
|
||||
conv_bwd_filter)) {
|
||||
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
|
||||
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
|
||||
diff_dst.CheckReorderToOpMem(
|
||||
MEMORY_PD_WITHOUT_DATA(bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST,
|
||||
cpu_engine_),
|
||||
context);
|
||||
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
|
||||
} else {
|
||||
diff_dst_data =
|
||||
@ -646,18 +648,21 @@ class MklConvCustomBackpropFilterOp
|
||||
}
|
||||
|
||||
// Execute convolution backward filter.
|
||||
std::shared_ptr<stream> bwd_cpu_stream;
|
||||
bwd_cpu_stream.reset(CreateStream(context, conv_bwd_filter->GetEngine()));
|
||||
if (bias_enabled) {
|
||||
T* diff_bias_data =
|
||||
static_cast<T*>(const_cast<T*>(diff_bias_tensor->flat<T>().data()));
|
||||
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_bias_data,
|
||||
diff_dst_data);
|
||||
diff_dst_data, bwd_cpu_stream);
|
||||
} else {
|
||||
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
|
||||
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data,
|
||||
bwd_cpu_stream);
|
||||
}
|
||||
|
||||
// Reorder diff_filter back to Tensorflow layout if necessary.
|
||||
if (diff_filter_reorder_required) {
|
||||
diff_filter.InsertReorderToUserMem();
|
||||
diff_filter.InsertReorderToUserMem(context);
|
||||
}
|
||||
|
||||
// Delete primitive since it is not cached.
|
||||
|
||||
@ -99,9 +99,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
public:
explicit MklConvBwdInputPrimitive(
const MklConvBwdInputParams& convBwdInputDims)
: cpu_engine_(ENGINE_CPU, 0) {
context_.bwd_input_stream.reset(new CPU_STREAM(cpu_engine_));

: MklPrimitive(engine(ENGINE_CPU, 0)) {
// Create conv bwd input primitive
if (context_.conv_bwd_input == nullptr) {
Setup(convBwdInputDims);
@ -116,7 +114,8 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
// diff_dst_data: input data buffer for dst
// Bias does not matter here
void Execute(const T* diff_src_data, const T* filter_data,
const T* diff_dst_data) {
const T* diff_dst_data,
std::shared_ptr<stream> bwd_input_stream) {
context_.diff_src_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(diff_src_data)));
context_.filter_mem->set_data_handle(
@ -125,10 +124,10 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
static_cast<T*>(const_cast<T*>(diff_dst_data)));

#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_input_primitives, context_.bwd_input_stream,
execute_primitives(context_.bwd_input_primitives, bwd_input_stream,
context_.bwd_input_primitives_args);
#else
context_.bwd_input_stream->submit(context_.bwd_input_primitives);
bwd_input_stream->submit(context_.bwd_input_primitives);
#endif // ENABLE_MKLDNN_V1

// Set data handle back to DummyData.
@ -180,7 +179,6 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
std::shared_ptr<memory::desc> diff_dst_md;

// MKL-DNN pipeline for executing primitives.
std::shared_ptr<mkldnn::stream> bwd_input_stream;
std::vector<mkldnn::primitive> bwd_input_primitives;

#ifdef ENABLE_MKLDNN_V1
@ -203,8 +201,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
fwd_pd(nullptr),
diff_src_md(nullptr),
filter_md(nullptr),
diff_dst_md(nullptr),
bwd_input_stream(nullptr) {
diff_dst_md(nullptr) {
}
};

@ -290,7 +287,6 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
}

struct ConvBwdInputContext context_;
engine cpu_engine_;
};

template <typename T>
@ -522,8 +518,10 @@ class MklConvCustomBackpropInputOp
if (IS_FILTER_REORDER_NEEDED(fwd_filter_md, bwd_input_pd,
conv_bwd_input)) {
filter.SetUsrMem(fwd_filter_md, &filter_tensor);
filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_));
filter.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS,
cpu_engine_),
context);
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
} else {
filter_data =
@ -535,23 +533,29 @@ class MklConvCustomBackpropInputOp
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_input_pd,
conv_bwd_input)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
diff_dst.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST,
cpu_engine_),
context);
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
} else {
diff_dst_data =
static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
}

std::shared_ptr<stream> bwd_cpu_stream;
bwd_cpu_stream.reset(CreateStream(context, conv_bwd_input->GetEngine()));
// Execute conv bwd input primitive.
if (!eager_mode) {
conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data,
bwd_cpu_stream);
} else {
// In eager mode we first write the output to temporary
// buffer in MKL format. Then we convert the data to TF format.
T* tmp_data =
static_cast<T*>(const_cast<T*>(tmp_tensor.flat<T>().data()));
conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data);
conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data,
bwd_cpu_stream);
auto output_tf_md = diff_src_mkl_shape.GetTfLayout();
#ifndef ENABLE_MKLDNN_V1
auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine_);
@ -563,7 +567,7 @@ class MklConvCustomBackpropInputOp
memory* dst_data_mem =
new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, diff_src_data);
CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem,
cpu_engine_);
cpu_engine_, context);
}

// Delete primitive since it is not cached.

@ -155,7 +155,8 @@ class MklDequantizeOp : public OpKernel {
// Also it does not define round_nearest (enum).
attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
#endif // !ENABLE_MKLDNN_V1
stream reorder_stream = CPU_STREAM(cpu_engine);
std::shared_ptr<stream> reorder_stream;
reorder_stream.reset(CreateStream(ctx, cpu_engine));
std::vector<primitive> net;

// Create reorder primitive and then execute.
@ -169,11 +170,10 @@ class MklDequantizeOp : public OpKernel {
reorder_net_args.push_back({{MKLDNN_ARG_FROM, *src.GetUsrMem()},
{ MKLDNN_ARG_TO,
*dst.GetUsrMem() }});
execute_primitives(net, std::make_shared<stream>(reorder_stream),
reorder_net_args);
execute_primitives(net, reorder_stream, reorder_net_args);
#else
net.push_back(reorder(reorder_pd, *src.GetUsrMem(), *dst.GetUsrMem()));
reorder_stream.submit(net);
reorder_stream->submit(net);
#endif // ENABLE_MKLDNN_V1
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +

@ -79,13 +79,7 @@ template <typename T, typename U>
class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
public:
explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
#ifdef ENABLE_MKLDNN_V1
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
#else
context_.fwd_stream.reset(
new mkldnn::stream(mkldnn::stream::kind::eager_nostore));
#endif
: MklPrimitive(engine(ENGINE_CPU, 0)) {
if (context_.bn_fwd == nullptr) Setup(fwdParams);
}

@ -98,7 +92,8 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
// mean_data: output data buffer of means
// variance_data: output data buffer of variances
void Execute(const T* src_data, const U* weights_data, T* dst_data,
U* mean_data, U* variance_data, U* workspace_data) {
U* mean_data, U* variance_data,
std::shared_ptr<stream> fwd_stream, U* workspace_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
@ -117,10 +112,10 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
}
#ifdef ENABLE_MKLDNN_V1
// Execute batch-normalization forward primitives.
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
context_.net_args);
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
#else
context_.fwd_stream->submit(context_.fwd_primitives);
fwd_stream.reset(new stream(stream::kind::eager_nostore));
fwd_stream->submit(context_.fwd_primitives);
#endif // ENABLE_MKLDNN_V1

context_.src_mem->set_data_handle(DummyData);
@ -180,7 +175,6 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {

// BatchNorm forward primitive.
std::shared_ptr<mkldnn::primitive> bn_fwd;
std::shared_ptr<mkldnn::stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;

#ifdef ENABLE_MKLDNN_V1
@ -195,9 +189,8 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
dst_mem(nullptr),
mean_mem(nullptr),
variance_mem(nullptr),
ws_mem(nullptr),
bn_fwd(nullptr),
fwd_stream(nullptr) {}
ws_mem(nullptr) {}
};

void Setup(const MklBatchNormFwdParams& fwdParams) {
@ -392,7 +385,6 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
}

struct BatchNormFwdContext context_;
engine cpu_engine_;
};

template <typename T, typename U>
@ -489,13 +481,7 @@ template <typename T, typename U>
class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
public:
explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
#ifdef ENABLE_MKLDNN_V1
context_.bwd_stream.reset(new CPU_STREAM(cpu_engine_));
#else
context_.bwd_stream.reset(
new mkldnn::stream(mkldnn::stream::kind::eager_nostore));
#endif
: MklPrimitive(engine(ENGINE_CPU, 0)) {
if (context_.bn_bwd == nullptr) Setup(bwdParams);
}

@ -515,7 +501,8 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
// on CPU as of now.
void Execute(const T* src_data, const U* mean_data, const U* variance_data,
const T* diff_dst_data, const U* weights_data, T* diff_src_data,
U* diff_weights_data, U* res_space_data) {
U* diff_weights_data, U* res_space_data,
std::shared_ptr<stream> bwd_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.mean_mem->set_data_handle(
@ -537,10 +524,10 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
#ifdef ENABLE_MKLDNN_V1
// Execute backward batch-normalization primitives.
DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
execute_primitives(context_.bwd_primitives, context_.bwd_stream,
context_.net_args);
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
#else
context_.bwd_stream->submit(context_.bwd_primitives);
bwd_stream.reset(new stream(stream::kind::eager_nostore));
bwd_stream->submit(context_.bwd_primitives);
#endif // ENABLE_MKLDNN_V1

// After execution, set data handle back to DummyData.
@ -593,7 +580,6 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
// Backward batch-normalization primitive.
std::shared_ptr<mkldnn::primitive> bn_bwd;
std::vector<mkldnn::primitive> bwd_primitives;
std::shared_ptr<mkldnn::stream> bwd_stream;

#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> net_args;
@ -606,8 +592,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
diff_dst_mem(nullptr),
weights_mem(nullptr),
diff_weights_mem(nullptr),
diff_src_mem(nullptr),
bwd_stream(nullptr) {}
diff_src_mem(nullptr) {}
};

void Setup(const MklBatchNormBwdParams& bwdParams) {
@ -616,7 +601,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
? GET_FLAG(use_scale_shift)
: (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));

// Memory descriptors.
// Memory descriptors.
#ifndef ENABLE_MKLDNN_V1
auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType<T>(),
bwdParams.src_format);
@ -689,7 +674,6 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
}

struct BatchNormBwdContext context_;
engine cpu_engine_;
};

template <typename T, typename U>
@ -960,8 +944,10 @@ class MklFusedBatchNormOp : public OpKernel {
std::shared_ptr<BatchNormFwdPd> bn_fwd_pd = bn_fwd->GetBatchNormFwdPd();
if (IS_SRC_REORDER_NEEDED(src_md, bn_fwd_pd, bn_fwd)) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
GET_SRC_DESC_FROM_OP_PD(bn_fwd_pd), cpu_engine_));
src.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(bn_fwd_pd),
cpu_engine_),
context);
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
} else {
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
@ -987,9 +973,10 @@ class MklFusedBatchNormOp : public OpKernel {
T* dst_data = dst_tensor->flat<T>().data();

// Execute
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, bn_fwd->GetEngine()));
bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data,
variance_op_data, ws_data);

variance_op_data, fwd_cpu_stream, ws_data);
float adjust_factor = 1.0;
if (is_training_) {
size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3];
@ -1003,9 +990,6 @@ class MklFusedBatchNormOp : public OpKernel {
auto batch_variance_data = batch_variance_tensor->flat<U>().data();
auto est_mean_data = est_mean_tensor.flat<U>().data();
auto est_variance_data = est_variance_tensor.flat<U>().data();

// TODO(intel-tf): Merge the `is_training && exponential_avg_factor == 1`
// case with the `else` (`!is_training`) case if possible.
if (is_training_) {
if (exponential_avg_factor_ == U(1.0)) {
for (int k = 0; k < depth_; k++) {
@ -1328,8 +1312,10 @@ class MklFusedBatchNormGradOp : public OpKernel {
std::shared_ptr<BatchNormBwdPd> bn_bwd_pd = bn_bwd->GetBatchNormBwdPd();
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bn_bwd_pd, bn_bwd)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
GET_DIFF_DST_DESC_FROM_OP_PD(bn_bwd_pd), cpu_engine_));
diff_dst.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(GET_DIFF_DST_DESC_FROM_OP_PD(bn_bwd_pd),
cpu_engine_),
context);
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
} else {
diff_dst_data =
@ -1366,10 +1352,11 @@ class MklFusedBatchNormGradOp : public OpKernel {
: nullptr);

// Execute
std::shared_ptr<stream> bwd_cpu_stream;
bwd_cpu_stream.reset(CreateStream(context, bn_bwd->GetEngine()));
bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data,
weights_data, diff_src_data, diff_weights_data,
res_space_data);

res_space_data, bwd_cpu_stream);
// Allocate output TF tensors diff_scale and diff_shift.
Tensor* diff_scale_tensor = nullptr;
Tensor* diff_shift_tensor = nullptr;

@ -88,7 +88,6 @@ class MklLRNOp : public OpKernel {
workspace_enabled_ = false;
OP_REQUIRES_OK(context,
context->GetAttr("workspace_enabled", &workspace_enabled_));
fwd_stream_.reset(new CPU_STREAM(cpu_engine_));
}

void Compute(OpKernelContext* context) override {
@ -169,6 +168,7 @@ class MklLRNOp : public OpKernel {
lrn_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));

std::vector<primitive> net;
fwd_stream_.reset(CreateStream(context, cpu_engine_));
#ifdef ENABLE_MKLDNN_V1
net.push_back(lrn_forward(lrn_prim_desc));
std::vector<std::unordered_map<int, memory>> net_args;

@ -168,14 +168,13 @@ class MklMatMulOp : public OpKernel {
const int index_transa = transa ? 1 : 0;
const int index_transb = transb ? 1 : 0;

Tensor c_float;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {m, n}, &c_float));
#ifdef ENABLE_MKLDNN_V1
const char ftrans[] = {'N', 'T', 'C'};
dnnl_gemm<bfloat16>(ftrans[index_transa], ftrans[index_transb], m, n, k,
alpha, a, lda, b, ldb, beta,
c_float.flat<float>().data(), ldc);
alpha, a, lda, b, ldb, beta, c, ldc);
#else
Tensor c_float;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {m, n}, &c_float));
const char* const ftrans[] = {"N", "T", "C"};

// MKL-DNN only supports the Fortran API and requires column major while
@ -185,8 +184,8 @@ class MklMatMulOp : public OpKernel {
reinterpret_cast<const mkldnn_bfloat16_t*>(b), &ldb,
reinterpret_cast<const mkldnn_bfloat16_t*>(a), &lda,
&beta, c_float.flat<float>().data(), &ldc);
#endif // ENABLE_MKLDNN_V1
FloatToBFloat16(c_float.flat<float>().data(), c, c_float.NumElements());
#endif // ENABLE_MKLDNN_V1
}
#endif // ENABLE_INTEL_MKL_BFLOAT16
};

@ -516,33 +516,190 @@ class MklDnnMatMulOpBase : public OpKernel {

// MatMul support for bfloat16 and int8 types is introduced in DNNLv1.2.
#ifdef ENABLE_MKLDNN_V1

using mkldnn::matmul;

namespace {

void dnnl_gemm_exec(const memory::desc& a_md, const memory::desc& b_md,
const memory::desc& c_md, const void* a, const void* b,
void* c, const primitive_attr& attr) {
// Create a MatMul primitive
mkldnn::engine cpu_engine = mkldnn::engine(ENGINE_CPU, 0);
mkldnn::matmul::desc matmul_desc(a_md, b_md, c_md);
mkldnn::matmul::primitive_desc matmul_pd(matmul_desc, attr, cpu_engine);
mkldnn::matmul matmul_prim(matmul_pd);
// Wrap raw pointers into DNNL memory objects
mkldnn::memory a_memory(a_md, cpu_engine, const_cast<void*>(a));
mkldnn::memory b_memory(b_md, cpu_engine, const_cast<void*>(b));
mkldnn::memory c_memory(c_md, cpu_engine, c);
// Execute the MatMul primitive.
// Since here all shapes and parameters are static, please note that we
// don't need to pass alpha (scales) again, as they are already hard-coded
// in the primitive descriptor. Also, we are not allowed to change the
// shapes of matrices A, B, and C -- they should exactly match
// the memory descriptors passed to MatMul operation descriptor.
mkldnn::stream s(cpu_engine);
matmul_prim.execute(s, {{DNNL_ARG_SRC, a_memory},
{DNNL_ARG_WEIGHTS, b_memory},
{ DNNL_ARG_DST,
c_memory }});
s.wait();
}
struct MklMatMulParams {
memory::dims a_dims;
memory::dims b_dims;
memory::dims c_dims;
memory::dims a_strides;
memory::dims b_strides;
memory::dims c_strides;

MklMatMulParams(memory::dims a_dims, memory::dims b_dims, memory::dims c_dims,
memory::dims a_strides, memory::dims b_strides,
memory::dims c_strides)
: a_dims(a_dims),
b_dims(b_dims),
c_dims(c_dims),
a_strides(a_strides),
b_strides(b_strides),
c_strides(c_strides) {}
};

template <typename T>
class MklMatMulPrimitive : public MklPrimitive {
public:
explicit MklMatMulPrimitive(const MklMatMulParams& params)
: cpu_engine_(ENGINE_CPU, 0) {
context_.stream.reset(new CPU_STREAM(cpu_engine_));
// Create matmul primitive
Setup(params);
}

~MklMatMulPrimitive() {}

void Execute(const T* a_data, const T* b_data, T* c_data) {
context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)));
context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)));
context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)));

execute_primitives(context_.matmul_primitives, context_.stream,
context_.net_args);

// After execution, set data handle back
context_.a_mem->set_data_handle(DummyData);
context_.b_mem->set_data_handle(DummyData);
context_.c_mem->set_data_handle(DummyData);
}

private:
// Primitive reuse context for MatMul op
struct MklMatMulContext {
// MKL-DNN memory.
std::shared_ptr<mkldnn::memory> a_mem;
std::shared_ptr<mkldnn::memory> b_mem;
std::shared_ptr<mkldnn::memory> c_mem;

// Descriptor and primitive-descriptor for MatMul.
std::shared_ptr<matmul::desc> desc;
std::shared_ptr<matmul::primitive_desc> prim_desc;

// Memory descriptors.
std::shared_ptr<mkldnn::memory::desc> a_md;
std::shared_ptr<mkldnn::memory::desc> b_md;
std::shared_ptr<mkldnn::memory::desc> c_md;

// MatMul primitive.
std::shared_ptr<mkldnn::stream> stream;
std::vector<mkldnn::primitive> matmul_primitives;
std::vector<std::unordered_map<int, memory>> net_args;

MklMatMulContext()
: a_mem(nullptr),
b_mem(nullptr),
c_mem(nullptr),
desc(nullptr),
prim_desc(nullptr),
a_md(nullptr),
b_md(nullptr),
c_md(nullptr),
stream(nullptr) {}
};

void Setup(const MklMatMulParams& params) {
std::shared_ptr<mkldnn::primitive> matmul_primitive = nullptr;

// Create MatMul descriptor and primitive descriptor.
context_.a_md.reset(
new memory::desc({params.a_dims}, MklDnnType<T>(), params.a_strides));

context_.b_md.reset(
new memory::desc({params.b_dims}, MklDnnType<T>(), params.b_strides));

context_.c_md.reset(
new memory::desc({params.c_dims}, MklDnnType<T>(), params.c_strides));

// Create matmul.
context_.desc.reset(
new matmul::desc(*context_.a_md, *context_.b_md, *context_.c_md));
context_.prim_desc.reset(
new matmul::primitive_desc(*context_.desc, cpu_engine_));

// Create memory primitive based on dummy data.
context_.a_mem.reset(
new mkldnn::memory(*context_.a_md, cpu_engine_, DummyData));
context_.b_mem.reset(
new mkldnn::memory(*context_.b_md, cpu_engine_, DummyData));
context_.c_mem.reset(
new mkldnn::memory(*context_.b_md, cpu_engine_, DummyData));

// Create matmul primitive.
matmul_primitive.reset(new mkldnn::matmul(*context_.prim_desc));
context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.a_mem},
{MKLDNN_ARG_WEIGHTS, *context_.b_mem},
{ MKLDNN_ARG_DST,
*context_.c_mem }});

context_.matmul_primitives.push_back(*matmul_primitive);
return;
}

struct MklMatMulContext context_;
engine cpu_engine_;
};

template <typename T>
class MklMatMulPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklMatMulPrimitive<T>* Get(const MklMatMulParams& params,
bool do_not_cache) {
MklMatMulPrimitive<T>* matmul_prim = nullptr;

if (do_not_cache) {
// Always create new primitive
matmul_prim = new MklMatMulPrimitive<T>(params);
} else {
// Try to find a suitable one in pool
matmul_prim = dynamic_cast<MklMatMulPrimitive<T>*>(
MklMatMulPrimitiveFactory<T>::GetInstance().GetMklMatMul(params));
if (matmul_prim == nullptr) {
matmul_prim = new MklMatMulPrimitive<T>(params);
MklMatMulPrimitiveFactory<T>::GetInstance().SetMklMatMul(params,
matmul_prim);
}
}

return matmul_prim;
}

private:
MklMatMulPrimitiveFactory() {}
~MklMatMulPrimitiveFactory() {}

static MklMatMulPrimitiveFactory& GetInstance() {
static MklMatMulPrimitiveFactory instance_;
return instance_;
}

static string CreateKey(const MklMatMulParams& params) {
string prefix = "matmul_";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(params.a_dims);
key_creator.AddAsKey(params.b_dims);
key_creator.AddAsKey(params.c_dims);
key_creator.AddAsKey(params.a_strides);
key_creator.AddAsKey(params.b_strides);
key_creator.AddAsKey(params.c_strides);
key_creator.AddAsKey(typeid(T).name());

return key_creator.GetKey();
}

MklPrimitive* GetMklMatMul(const MklMatMulParams& params) {
string key = CreateKey(params);
return this->GetOp(key);
}

void SetMklMatMul(const MklMatMulParams& params, MklPrimitive* op) {
string key = CreateKey(params);
this->SetOp(key, op);
}
};

template <typename T>
void dnnl_gemm_batch(const std::vector<bool>& transa,
@ -589,45 +746,47 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
!transb[0] ? dims{k[0] * n[0], n[0], 1} : dims{n[0] * k[0], 1, k[0]};
dims c_strides = dims{m[0] * n[0], n[0], 1};

// Prepare memory descriptors
memory::desc a_md(a_sizes, MklDnnType<T>(), a_strides);
memory::desc b_md(b_sizes, MklDnnType<T>(), b_strides);
memory::desc c_md(c_sizes, MklDnnType<T>(), c_strides);
// Create attributes (to handle alpha and beta if necessary)
mkldnn::primitive_attr attr;
if (alpha[0] != 1.f) attr.set_output_scales(/* mask */ 0, {alpha[0]});
if (beta[0] != 0.f) {
mkldnn::post_ops po;
po.append_sum(beta[0]);
attr.set_post_ops(po);
}
dnnl_gemm_exec(a_md, b_md, c_md, static_cast<const void*>(a),
static_cast<const void*>(b), static_cast<void*>(c), attr);
// MklMatMul uses const alpha and beta, make guarantee here to ensure
// they are never changed.
DCHECK_EQ(alpha, 1.0f);
DCHECK_EQ(beta, 0.f);

MklMatMulParams params(a_sizes, b_sizes, c_sizes, a_strides, b_strides,
c_strides);
MklMatMulPrimitive<T>* matmul_prim =
MklMatMulPrimitiveFactory<T>::Get(params, 0);

// Execute matmul primitive.
matmul_prim->Execute(a, b, c);
}

template <typename T>
void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, const T* a, int64_t lda, const T* b, int64_t ldb,
float beta, float* c, int64_t ldc) {
float beta, T* c, int64_t ldc) {
using dims = mkldnn::memory::dims;

// Prepare strides based on the transa and transb flags: transposed
// matrices have strides swapped
dims a_dims = dims{m, k};
dims b_dims = dims{k, n};
dims c_dims = dims{m, n};
dims a_strides = tolower(transa) == 'n' ? dims{lda, 1} : dims{1, lda};
dims b_strides = tolower(transb) == 'n' ? dims{ldb, 1} : dims{1, ldb};
// Prepare memory descriptors
memory::desc a_md({m, k}, MklDnnType<T>(), a_strides);
memory::desc b_md({k, n}, MklDnnType<T>(), b_strides);
memory::desc c_md({m, n}, MklDnnType<float>(), {ldc, 1});
// Create attributes (to handle alpha and beta if necessary)
mkldnn::primitive_attr attr;
if (alpha != 1.f) attr.set_output_scales(/* mask */ 0, {alpha});
if (beta != 0.f) {
mkldnn::post_ops po;
po.append_sum(beta);
attr.set_post_ops(po);
}
dnnl_gemm_exec(a_md, b_md, c_md, static_cast<const void*>(a),
static_cast<const void*>(b), static_cast<void*>(c), attr);
dims c_strides = dims{ldc, 1};

// MklMatMul uses const alpha and beta, make guarantee here to ensure
// they are never changed.
DCHECK_EQ(alpha, 1.0f);
DCHECK_EQ(beta, 0.f);

MklMatMulParams params(a_dims, b_dims, c_dims, a_strides, b_strides,
c_strides);
MklMatMulPrimitive<T>* matmul_prim =
MklMatMulPrimitiveFactory<T>::Get(params, 0);

// Execute matmul primitive.
matmul_prim->Execute(a, b, c);
}

} // anonymous namespace

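A side note on the dnnl_gemm path above: with the matmul primitive, transposition is expressed purely through strides, and alpha/beta map to an output scale and a sum post-op on the primitive attributes. A minimal, self-contained sketch of that mapping, written against the plain dnnl v1.x API (local names only, no TensorFlow types, and assuming alpha = 1 and beta = 0 as the DCHECKs above require so no attributes are needed):

// Illustrative sketch, not from the TensorFlow sources: one row-major GEMM
// expressed as a dnnl::matmul. Shapes and strides mirror dnnl_gemm above.
#include <cstdint>
#include "dnnl.hpp"

void simple_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
                 const float* a, int64_t lda, const float* b, int64_t ldb,
                 float* c, int64_t ldc) {
  using dims = dnnl::memory::dims;
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream s(eng);

  // A transposed operand keeps its logical {rows, cols} shape; only the
  // strides are swapped, exactly as in dnnl_gemm above.
  dims a_strides = (transa == 'N') ? dims{lda, 1} : dims{1, lda};
  dims b_strides = (transb == 'N') ? dims{ldb, 1} : dims{1, ldb};

  dnnl::memory::desc a_md({m, k}, dnnl::memory::data_type::f32, a_strides);
  dnnl::memory::desc b_md({k, n}, dnnl::memory::data_type::f32, b_strides);
  dnnl::memory::desc c_md({m, n}, dnnl::memory::data_type::f32, dims{ldc, 1});

  dnnl::matmul::desc md(a_md, b_md, c_md);
  dnnl::matmul::primitive_desc pd(md, eng);
  dnnl::matmul prim(pd);

  dnnl::memory a_mem(a_md, eng, const_cast<float*>(a));
  dnnl::memory b_mem(b_md, eng, const_cast<float*>(b));
  dnnl::memory c_mem(c_md, eng, c);

  prim.execute(s, {{DNNL_ARG_SRC, a_mem},
                   {DNNL_ARG_WEIGHTS, b_mem},
                   {DNNL_ARG_DST, c_mem}});
  s.wait();
}

The cached MklMatMulPrimitive above does the same thing once at Setup time and then only swaps data handles per call, which is why the factory keys on dims and strides rather than on the raw pointers.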
@ -167,10 +167,12 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
const T* src_data = input_tensor.flat<T>().data();

T* dst_data = output_tensor->flat<T>().data();
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine()));

if (int8_forward_inference) {
// Execute pooling op
pooling_fwd->Execute(src_data, dst_data);
pooling_fwd->Execute(src_data, dst_data, nullptr, fwd_cpu_stream);

// Pass min, max from input to output.
const Tensor& min_input_t = MklGetInput(context, 1);
@ -197,7 +199,7 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
T* ws_data =
static_cast<T*>(dnn_data_wksp.GetOpMem().get_data_handle());
// Execute pooling op.
pooling_fwd->Execute(src_data, dst_data, ws_data);
pooling_fwd->Execute(src_data, dst_data, ws_data, fwd_cpu_stream);
}
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
@ -322,6 +324,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
MklPoolingBwdPrimitive<T>* pooling_bwd =
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);

std::shared_ptr<stream> bwd_cpu_stream;
bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine()));
// Allocate output tensor and memory primitive.
Tensor* output_tensor = nullptr;
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
@ -335,8 +339,10 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd,
pooling_bwd)) {
grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
grad_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
GET_DIFF_DST_DESC_FROM_OP_PD(pooling_bwd_pd), cpu_engine_));
grad_dnn_data.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(GET_DIFF_DST_DESC_FROM_OP_PD(pooling_bwd_pd),
cpu_engine_),
context);
diff_dst_data =
static_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle());
} else {
@ -361,7 +367,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
T* diff_src_data = output_tensor->flat<T>().data();

// Execute pooling op.
pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data);
pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data,
bwd_cpu_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status:" + std::to_string(e.status) +
", message: " + string(e.message) + ". in file " +

@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"

namespace tensorflow {
using mkldnn::prop_kind;

@ -38,11 +37,11 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
context_.alg_kind = fwdParams.alg_kind;
context_.prop_kind = fwdParams.prop_kind;

// Create memory descriptor
// FIXME: Pooling doesn't expose to get the src_primitive_desc,
// so src format is currently hard-coded.
// A utility function is used to do this,
// which may be broken with future CPU architectures
// Create memory descriptor
// FIXME: Pooling doesn't expose to get the src_primitive_desc,
// so src format is currently hard-coded.
// A utility function is used to do this,
// which may be broken with future CPU architectures
#ifndef ENABLE_MKLDNN_V1
bool is_2d = (fwdParams.src_dims.size() == 4);
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value)
@ -126,7 +125,8 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {

template <typename T>
void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
void* ws_data) {
void* ws_data,
std::shared_ptr<stream> fwd_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
@ -138,10 +138,9 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
}

#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
context_.net_args);
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
#else
context_.fwd_stream->submit(context_.fwd_primitives);
fwd_stream->submit(context_.fwd_primitives);
#endif // ENABLE_MKLDNN_V1

// Set back data handle.
@ -268,7 +267,8 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {

template <typename T>
void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
T* diff_src_data, const void* ws_data) {
T* diff_src_data, const void* ws_data,
std::shared_ptr<stream> bwd_stream) {
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
@ -278,10 +278,9 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
}

#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_primitives, context_.bwd_stream,
context_.net_args);
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
#else
context_.bwd_stream->submit(context_.bwd_primitives);
bwd_stream->submit(context_.bwd_primitives);
#endif // ENABLE_MKLDNN_V1

// Set back data handle.

@ -86,8 +86,7 @@ template <typename T>
class MklPoolingFwdPrimitive : public MklPrimitive {
public:
explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
if (context_.fwd == nullptr) Setup(fwdParams);
}

@ -97,7 +96,8 @@ class MklPoolingFwdPrimitive : public MklPrimitive {
// src_data: input data buffer of src
// ws_data: output data buffer of workspace
// dst_data: output data buffer of dst
void Execute(const T* src_data, T* dst_data, void* ws_data = nullptr);
void Execute(const T* src_data, T* dst_data, void* ws_data,
std::shared_ptr<stream> fwd_stream);

std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
return context_.fwd_pd;
@ -159,12 +159,10 @@ class MklPoolingFwdPrimitive : public MklPrimitive {
fwd_pd(nullptr),
src_md(nullptr),
dst_md(nullptr),
fwd(nullptr),
fwd_stream(nullptr) {}
fwd(nullptr) {}
};

struct PoolingFwdContext context_;
engine cpu_engine_;
};

template <typename T>
@ -229,8 +227,7 @@ template <typename T>
class MklPoolingBwdPrimitive : public MklPrimitive {
public:
explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
context_.bwd_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
if (context_.bwd == nullptr) Setup(bwdParams);
}

@ -240,8 +237,8 @@ class MklPoolingBwdPrimitive : public MklPrimitive {
// diff_dst_data: input data buffer of diff_dst
// diff_src_data: output data buffer of diff_src
// ws_data: input data buffer of workspace
void Execute(const T* diff_dst_data, T* diff_src_data,
const void* ws_data = nullptr);
void Execute(const T* diff_dst_data, T* diff_src_data, const void* ws_data,
std::shared_ptr<stream> bwd_stream);

public:
std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
@ -315,12 +312,10 @@ class MklPoolingBwdPrimitive : public MklPrimitive {
bwd_desc(nullptr),
fwd_pd(nullptr),
bwd_pd(nullptr),
bwd(nullptr),
bwd_stream(nullptr) {}
bwd(nullptr) {}
};

struct PoolingBwdContext context_;
engine cpu_engine_;
};

template <typename T>

@ -77,7 +77,7 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
public:
explicit MklReorderWithScalePrimitive(
const MklReorderWithScaleFwdParams& fwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
: MklPrimitive(engine(ENGINE_CPU, 0)) {
// Create reorder primitive
Setup(fwdParams);
}
@ -86,14 +86,14 @@ class MklReorderWithScalePrimitive : public MklPrimitive {

std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }

void Execute(void* src_data, void* dst_data) {
void Execute(void* src_data, void* dst_data,
std::shared_ptr<stream> reorder_stream) {
context_.src_mem->set_data_handle(src_data);
context_.dst_mem->set_data_handle(dst_data);
#ifndef ENABLE_MKLDNN_V1
context_.reorder_stream->submit(context_.net);
reorder_stream->submit(context_.net);
#else
context_.reorder_prim->execute(*context_.reorder_stream,
context_.prim_args);
context_.reorder_prim->execute(*reorder_stream, context_.prim_args);
#endif // !ENABLE_MKLDNN_V1
// After execution, set data handle back.
context_.src_mem->set_data_handle(DummyData);
@ -124,12 +124,9 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
: src_mem(nullptr),
dst_mem(nullptr),
reorder_pd(nullptr),
reorder_prim(nullptr),
reorder_stream(nullptr) {}
reorder_prim(nullptr) {}
} context_;

engine cpu_engine_;

// Reorder primitive setup
void Setup(const MklReorderWithScaleFwdParams& fwdParams) {
// Create memory descriptors for reorder data with specified format
@ -163,7 +160,6 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
context_.prim_args.insert({MKLDNN_ARG_FROM, *context_.src_mem});
context_.prim_args.insert({MKLDNN_ARG_TO, *context_.dst_mem});
#endif // !ENABLE_MKLDNN_V1
context_.reorder_stream.reset(new CPU_STREAM(cpu_engine_));
}
};

@ -491,7 +487,10 @@ class MklQuantizeV2Op : public OpKernel {
MklReorderWithScalePrimitive* reorder_prim =
MklReorderWithScalePrimitiveFactory<T>::Get(src.GetUsrMem(),
dst.GetUsrMem(), fwdParams);
reorder_prim->Execute(src.GetUsrMemDataHandle(), dst.GetUsrMemDataHandle());
std::shared_ptr<stream> cpu_stream;
cpu_stream.reset(CreateStream(ctx, reorder_prim->GetEngine()));
reorder_prim->Execute(src.GetUsrMemDataHandle(), dst.GetUsrMemDataHandle(),
cpu_stream);

output_min_tensor->flat<float>()(0) = min_range;
output_max_tensor->flat<float>()(0) = max_range;

@ -61,13 +61,11 @@ template <typename T>
class MklEltwiseFwdPrimitive : public MklPrimitive {
public:
explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
: MklPrimitive(engine(ENGINE_CPU, 0)) {
#ifndef ENABLE_MKLDNN_V1
context_.src_fmt =
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
#endif
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));

// create eltwise primitive
if (context_.eltwise_fwd == nullptr) {
Setup(fwdParams);
@ -79,7 +77,8 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
// Eltwise forward execute
// src_data: input data buffer of src
// dst_data: output data buffer of dst
void Execute(const T* src_data, T* dst_data) {
void Execute(const T* src_data, T* dst_data,
std::shared_ptr<stream> fwd_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
@ -87,12 +86,10 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.fwd_primitives.size(),
context_.fwd_primitives_args.size());
for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
context_.fwd_primitives_args.at(i));
}
execute_primitives(context_.fwd_primitives, fwd_stream,
context_.fwd_primitives_args);
#else
context_.fwd_stream->submit(context_.fwd_primitives);
fwd_stream->submit(context_.fwd_primitives);
#endif

// After execution, set data handle back.
@ -134,7 +131,6 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
// Eltwise primitive
std::shared_ptr<mkldnn::primitive> eltwise_fwd;

std::shared_ptr<stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;

#ifdef ENABLE_MKLDNN_V1
@ -153,8 +149,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
src_md(nullptr),
dst_md(nullptr),
src_mpd(nullptr),
eltwise_fwd(nullptr),
fwd_stream(nullptr) {
eltwise_fwd(nullptr) {
}
};

@ -169,14 +164,12 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
#else
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
#endif

// Create an eltwise forward descriptor and primitive descriptor
context_.fwd_desc.reset(new eltwise_forward::desc(
prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
fwdParams.alpha, fwdParams.beta));
context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
auto fwd_pd = context_.fwd_pd.get();

#ifdef ENABLE_MKLDNN_V1
// Create memory primitive based on dummy data
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(fwd_pd->PRIMITIVE_DESC_SRC,
@ -195,12 +188,10 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
context_.eltwise_fwd.reset(new eltwise_forward(
*context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
#endif

context_.fwd_primitives.push_back(*context_.eltwise_fwd);
}

struct EltwiseFwdContext context_;
engine cpu_engine_;
};

template <typename T>
@ -281,14 +272,13 @@ template <typename T>
class MklEltwiseBwdPrimitive : public MklPrimitive {
public:
explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
: MklPrimitive(engine(ENGINE_CPU, 0)) {
#ifndef ENABLE_MKLDNN_V1
context_.src_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
context_.diff_dst_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
#endif
context_.bwd_stream.reset(new stream(CPU_STREAM(cpu_engine_)));
// create eltwise primitive
if (context_.eltwise_bwd == nullptr) {
Setup(bwdParams);
@ -301,7 +291,8 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
// src_data: input data buffer of src
// diff_dst_data: input data buffer of diff_dst
// diff_src_data: output data buffer of diff_src
void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data) {
void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data,
std::shared_ptr<stream> bwd_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.diff_dst_mem->set_data_handle(
@ -311,12 +302,10 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.bwd_primitives.size(),
context_.bwd_primitives_args.size());
for (size_t i = 0; i < context_.bwd_primitives.size(); ++i) {
context_.bwd_primitives.at(i).execute(*context_.bwd_stream,
context_.bwd_primitives_args.at(i));
}
execute_primitives(context_.bwd_primitives, bwd_stream,
context_.bwd_primitives_args);
#else
context_.bwd_stream->submit(context_.bwd_primitives);
bwd_stream->submit(context_.bwd_primitives);
#endif // ENABLE_MKLDNN_V1

// after execution, set data handle back
@ -367,7 +356,6 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
// Eltwise primitive.
std::shared_ptr<mkldnn::primitive> eltwise_bwd;

std::shared_ptr<stream> bwd_stream;
std::vector<mkldnn::primitive> bwd_primitives;

#ifdef ENABLE_MKLDNN_V1
@ -391,8 +379,7 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
fwd_desc(nullptr),
fwd_pd(nullptr),
bwd_pd(nullptr),
eltwise_bwd(nullptr),
bwd_stream(nullptr) {
eltwise_bwd(nullptr) {
}
};

@ -448,7 +435,6 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
}

struct EltwiseBwdContext context_;
engine cpu_engine_;
};

template <typename T>
@ -525,12 +511,10 @@ class MklReluOpBase : public OpKernel {
const Tensor& src_tensor = MklGetInput(context, src_index);
MklDnnShape dnn_shape_src;
GetMklShape(context, src_index, &dnn_shape_src);

if (src_tensor.dims() == 0) {
Compute_Scalar(context);
return;
}

MklDnnShape dnn_shape_dst;
TensorShape tf_shape_dst;
Tensor* dst_tensor = nullptr;
@ -542,7 +526,6 @@ class MklReluOpBase : public OpKernel {
dnn_shape_dst);
return;
}

// Set DNN primitive - src
MklDnnData<T> src(&cpu_engine);
memory::dims src_dims;
@ -556,26 +539,25 @@ class MklReluOpBase : public OpKernel {
// Create blocked memory descriptor
src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
}

// Try to get an eltwise forward primitive from caching pool
MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha_,
beta_);

MklEltwiseFwdPrimitive<T>* eltwise_fwd =
MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);

auto eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd();

std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, eltwise_fwd->GetEngine()));
// Check if src needs to be reordered
const T* src_data = src_tensor.flat<T>().data();
if (IS_SRC_REORDER_NEEDED(src_md, eltwise_fwd_pd, eltwise_fwd)) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
eltwise_fwd_pd->PRIMITIVE_DESC_SRC, cpu_engine));
src.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(eltwise_fwd_pd->PRIMITIVE_DESC_SRC,
cpu_engine),
context);
src_data = const_cast<T*>(
reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
}

// Allocate dst tensor, always set it as MKL-DNN layout
if (dnn_shape_src.IsMklTensor()) {
dnn_shape_dst.SetMklTensor(true);
@ -590,7 +572,6 @@ class MklReluOpBase : public OpKernel {
dnn_shape_dst.SetMklTensor(false);
tf_shape_dst = src_tensor.shape();
}

OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{static_cast<const int>(src_index)},
static_cast<const int>(dst_index),
@ -600,7 +581,7 @@ class MklReluOpBase : public OpKernel {
T* dst_data = dst_tensor->flat<T>().data();

// execute eltwise
eltwise_fwd->Execute(src_data, dst_data);
eltwise_fwd->Execute(src_data, dst_data, fwd_cpu_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@ -727,13 +708,16 @@ class MklReluGradOpBase : public OpKernel {
MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);

auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();

std::shared_ptr<stream> bwd_cpu_stream;
bwd_cpu_stream.reset(CreateStream(context, eltwise_bwd->GetEngine()));
// check whether need reorder for src / diff_dst
const T* src_data = src_tensor.flat<T>().data();
if (IS_SRC_REORDER_NEEDED(src_md, eltwise_bwd_pd, eltwise_bwd)) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine));
src.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(
eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine),
context);
src_data = const_cast<T*>(
reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
}
@ -742,8 +726,10 @@ class MklReluGradOpBase : public OpKernel {
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, eltwise_bwd_pd,
eltwise_bwd)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine));
diff_dst.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(
eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine),
context);
diff_dst_data = const_cast<T*>(
reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
}
@ -779,7 +765,8 @@ class MklReluGradOpBase : public OpKernel {
T* diff_src_data = diff_src_tensor->flat<T>().data();

// execute eltwise bwd
eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data);
eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data,
bwd_cpu_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +

@ -130,9 +130,10 @@ class MklRequantizePerChannelOp : public OpKernel {
GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(input_mem_prim),
GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(output_mem_prim),
cpu_engine_, reorder_attr);
mkldnn::stream reorder_stream = CPU_STREAM(cpu_engine_);
std::shared_ptr<stream> reorder_stream;
reorder_stream.reset(CreateStream(ctx, cpu_engine_));
#ifndef ENABLE_MKLDNN_V1
reorder_stream.submit(
reorder_stream->submit(
{mkldnn::reorder(reorder_pd, *input_mem_prim, *output_mem_prim)});
#else
std::unordered_map<int, mkldnn::memory> reorder_args = {
@ -140,7 +141,7 @@ class MklRequantizePerChannelOp : public OpKernel {
{MKLDNN_ARG_TO, *output_mem_prim}};
std::unique_ptr<mkldnn::primitive> reorder_prim(
new mkldnn::reorder(reorder_pd));
reorder_prim->execute(reorder_stream, reorder_args);
reorder_prim->execute(*reorder_stream, reorder_args);
#endif // !ENABLE_MKLDNN_V1

Tensor* output_min = nullptr;

@ -181,22 +181,22 @@ template <typename T>
class MklSlicePrimitive : public MklPrimitive {
public:
explicit MklSlicePrimitive(const MklSliceParams& sliceParams)
: cpu_engine_(ENGINE_CPU, 0) {
context_.slice_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
Setup(sliceParams);
}

~MklSlicePrimitive() {}

void Execute(const MklSliceParams& sliceParams) {
void Execute(const MklSliceParams& sliceParams,
std::shared_ptr<stream> slice_stream) {
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());

#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.slice_primitives, context_.slice_stream,
execute_primitives(context_.slice_primitives, slice_stream,
context_.slice_primitives_args);
#else
context_.slice_stream->submit(context_.slice_primitives);
slice_stream->submit(context_.slice_primitives);
#endif

// We should set it back to DummyData so as to make the primitive
@ -228,8 +228,6 @@ class MklSlicePrimitive : public MklPrimitive {
: src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
} context_;

engine cpu_engine_;

void Setup(const MklSliceParams& sliceParams) {
// Actually, DummyData will not be used in computation,
// because the real data will be filled before execution.
@ -465,7 +463,7 @@ class MklSliceOp : public OpKernel {
auto op_md =
MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
#ifdef ENABLE_MKLDNN_V1
src.CheckReorderToOpMem(op_md, cpu_engine);
src.CheckReorderToOpMem(op_md, cpu_engine, context);
#else
auto op_pd = memory::primitive_desc(op_md, cpu_engine);
src.CheckReorderToOpMem(op_pd);
@ -492,7 +490,9 @@ class MklSliceOp : public OpKernel {
MklSlicePrimitive<T>* reorder_prim =
MklSlicePrimitiveFactory<T>::Get(sliceParams);
// Execute slice reorder.
reorder_prim->Execute(sliceParams);
std::shared_ptr<stream> slice_stream;
slice_stream.reset(CreateStream(context, reorder_prim->GetEngine()));
reorder_prim->Execute(sliceParams, slice_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +

@ -48,8 +48,7 @@ template <typename T>
class MklSoftmaxPrimitive : public MklPrimitive {
public:
explicit MklSoftmaxPrimitive(const MklSoftmaxParams& fwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
Setup(fwdParams);
}

@ -58,16 +57,18 @@ class MklSoftmaxPrimitive : public MklPrimitive {
// Softmax forward execute
// src_data: input data buffer of src
// dst_data: output data buffer of dst
void Execute(const T* src_data, T* dst_data) {
void Execute(const T* src_data, T* dst_data,
std::shared_ptr<stream> fwd_cpu_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));

#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_net_args.size());
execute_primitives(context_.fwd_primitives, fwd_cpu_stream,
context_.fwd_net_args);
#else
context_.fwd_stream->submit(context_.fwd_primitives);
fwd_cpu_stream->submit(context_.fwd_primitives);
#endif

// After execution, set data handle back.
@ -95,7 +96,6 @@ class MklSoftmaxPrimitive : public MklPrimitive {
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd;
std::shared_ptr<mkldnn::primitive> softmax_fwd;

std::shared_ptr<stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;
std::vector<MemoryArgsMap> fwd_net_args;

@ -105,8 +105,7 @@ class MklSoftmaxPrimitive : public MklPrimitive {
fwd_desc(nullptr),
src_md(nullptr),
fwd_pd(nullptr),
softmax_fwd(nullptr),
fwd_stream(nullptr) {}
softmax_fwd(nullptr) {}
};

// Softmax forward primitive setup
@ -143,7 +142,6 @@ class MklSoftmaxPrimitive : public MklPrimitive {
}

struct SoftmaxFwdContext context_;
engine cpu_engine_;
};

template <typename T>
@ -303,9 +301,9 @@ class MklSoftmaxOp : public OpKernel {

const T* src_data = src_tensor.flat<T>().data();
T* dst_data = reinterpret_cast<T*>(output_tensor->flat<T>().data());

// Execute softmax primitive.
softmax_fwd->Execute(src_data, dst_data);
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, softmax_fwd->GetEngine()));
softmax_fwd->Execute(src_data, dst_data, fwd_cpu_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +

@ -144,12 +144,14 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
|
||||
|
||||
std::vector<primitive> net;
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
std::shared_ptr<stream> transpose_stream;
|
||||
auto* prim = FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem());
|
||||
transpose_stream.reset(CreateStream(context, prim->GetEngine()));
|
||||
net.push_back(*(prim->GetPrimitive()));
|
||||
std::vector<MemoryArgsMap> net_args;
|
||||
net_args.push_back({{MKLDNN_ARG_FROM, *in.GetUsrMem()},
|
||||
{MKLDNN_ARG_TO, *out.GetUsrMem()}});
|
||||
execute_primitives(net, prim->GetStream(), net_args);
|
||||
execute_primitives(net, transpose_stream, net_args);
|
||||
#else
|
||||
std::shared_ptr<stream> transpose_stream;
|
||||
transpose_stream.reset(new CPU_STREAM(cpu_engine));
|
||||
|
||||
@ -887,6 +887,17 @@ class ResourceScatterUpdateOp : public OpKernel {
|
||||
const Tensor& indices = c->input(1);
|
||||
const Tensor& updates = c->input(2);
|
||||
|
||||
// Check that rank(updates.shape) = rank(indices.shape + params.shape[1:])
OP_REQUIRES(c,
updates.dims() == 0 ||
updates.dims() == indices.dims() + params->dims() - 1,
errors::InvalidArgument(
"Must have updates.shape = indices.shape + "
"params.shape[1:] or updates.shape = [], got ",
"updates.shape ", updates.shape().DebugString(),
", indices.shape ", indices.shape().DebugString(),
", params.shape ", params->shape().DebugString()));

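As an illustration of the rank requirement enforced above (the concrete shapes here are hypothetical, not taken from the source), the check accepts

$$\text{updates.shape} = \text{indices.shape} \,\Vert\, \text{params.shape}[1{:}]$$

so with params of shape (10, 4) and indices of shape (3, 2), updates must have shape (3, 2, 4), i.e. updates.dims() = 2 + 2 - 1 = 3; a scalar updates (dims() == 0) is also accepted.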
// Check that we have enough index space
|
||||
const int64 N_big = indices.NumElements();
|
||||
OP_REQUIRES(
|
||||
|
||||
@ -508,6 +508,12 @@ class SparseSegmentReductionOpBase : public OpKernel {
|
||||
errors::InvalidArgument("segment ids must be >= 0"));
|
||||
auto output_flat = output->flat_outer_dims<T>();
|
||||
|
||||
Tensor temp;
|
||||
if (input.dtype() == DT_BFLOAT16) {
|
||||
temp = tensorflow::Tensor(DT_FLOAT, output_shape);
|
||||
}
|
||||
auto temp_flat = temp.flat_outer_dims<float>();
|
||||
|
||||
int64 start = 0, end = 1;
|
||||
// Index from which the output is not initialized.
|
||||
SegmentId uninitialized_index = 0;
|
||||
@ -546,8 +552,9 @@ class SparseSegmentReductionOpBase : public OpKernel {
|
||||
}
|
||||
|
||||
auto out = output_flat.template chip<0>(out_index);
|
||||
const int bad_offset =
|
||||
Reduce(input_flat, indices_vec, start, end - start, out);
|
||||
auto temp = temp_flat.template chip<0>(out_index);
|
||||
const int bad_offset = Reduce<T, Index>(input_flat, indices_vec, start,
|
||||
end - start, out, temp);
|
||||
OP_REQUIRES(context, bad_offset < 0,
|
||||
errors::InvalidArgument(
|
||||
"Bad: indices[", start + bad_offset,
|
||||
@ -572,40 +579,89 @@ class SparseSegmentReductionOpBase : public OpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
int64 Reduce(const typename TTypes<T>::ConstMatrix& input_flat,
|
||||
const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
|
||||
int64 num,
|
||||
Eigen::TensorChippingOp<0, typename TTypes<T>::Matrix> out) {
|
||||
template <typename Tin>
using EnableIfBfloat16 =
typename std::enable_if<std::is_same<Tin, bfloat16>::value, int>::type;
template <typename Tin>
using EnableIfNotBfloat16 =
typename std::enable_if<!std::is_same<Tin, bfloat16>::value, int>::type;

template <typename Tin, typename Tindex, EnableIfNotBfloat16<Tin> = 0>
EIGEN_ALWAYS_INLINE auto fetch_val(
const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
return input_flat.template chip<0>(index);
}

template <typename Tin, typename Tindex, EnableIfBfloat16<Tin> = 0>
EIGEN_ALWAYS_INLINE auto fetch_val(
const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
return input_flat.template chip<0>(index).template cast<float>();
}

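The two `fetch_val` overloads above use SFINAE to pick a float-accumulating path for bfloat16 inputs. Below is a minimal, self-contained sketch of the same `std::enable_if` dispatch pattern; it is not part of the commit, and the `bfloat16` struct is only a stand-in for the real type.

#include <iostream>
#include <type_traits>

struct bfloat16 {};  // stand-in for the real bfloat16 type in this sketch

template <typename T>
using EnableIfBfloat16 =
    typename std::enable_if<std::is_same<T, bfloat16>::value, int>::type;
template <typename T>
using EnableIfNotBfloat16 =
    typename std::enable_if<!std::is_same<T, bfloat16>::value, int>::type;

// Selected for every type except bfloat16: accumulate in T itself.
template <typename T, EnableIfNotBfloat16<T> = 0>
const char* AccumulationType() { return "T"; }

// Selected only for bfloat16: widen to float before summing.
template <typename T, EnableIfBfloat16<T> = 0>
const char* AccumulationType() { return "float"; }

int main() {
  std::cout << AccumulationType<float>() << "\n";     // prints "T"
  std::cout << AccumulationType<bfloat16>() << "\n";  // prints "float"
  return 0;
}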
template <typename Tout>
EIGEN_ALWAYS_INLINE Tout get_scaling_factor(int64 num) {
Tout m(1);
if (is_mean_ && (num < 10)) {
m = Tout(num);
}
if (is_sqrtn_ && (num < 10)) {
m = Tout(sqrt(num));
}
return Tout(1) / m;
}

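For reference, the unrolled accumulation together with `get_scaling_factor` and the trailing division computes the usual segment statistics over the `num` rows selected by `indices_vec`; restated in math form (not text from the source):

$$\text{out} = \frac{1}{m}\sum_{i=0}^{\text{num}-1} x_{\text{indices}(\text{start}+i)}, \qquad m = \begin{cases} \text{num} & \text{mean} \\ \sqrt{\text{num}} & \text{sqrtn} \\ 1 & \text{sum} \end{cases}$$

The `num < 10` test only decides whether 1/m is folded into the unrolled sum as `scaling_factor` or applied as a division after the loop; the result is the same either way.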
template <typename Tin, typename Tindex, EnableIfNotBfloat16<Tin> = 0>
|
||||
int64 Reduce(
|
||||
const typename TTypes<Tin>::ConstMatrix& input_flat,
|
||||
const typename TTypes<Tindex>::ConstVec& indices_vec, int64 start,
|
||||
int64 num, Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
|
||||
Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
|
||||
return ReduceImpl<Tin, Tindex, Tin>(input_flat, indices_vec, start, num,
|
||||
out, get_scaling_factor<Tin>(num));
|
||||
}
|
||||
|
||||
template <typename Tin, typename Tindex, EnableIfBfloat16<Tin> = 0>
|
||||
int64 Reduce(
|
||||
const typename TTypes<Tin>::ConstMatrix& input_flat,
|
||||
const typename TTypes<Tindex>::ConstVec& indices_vec, int64 start,
|
||||
int64 num, Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
|
||||
Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
|
||||
int64 res =
|
||||
ReduceImpl<Tin, Tindex, float>(input_flat, indices_vec, start, num,
|
||||
temp, get_scaling_factor<float>(num));
|
||||
out = temp.template cast<bfloat16>();
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename Tin, typename Tindex, typename Tout>
|
||||
int64 ReduceImpl(
|
||||
const typename TTypes<Tin>::ConstMatrix& input_flat,
|
||||
const typename TTypes<Tindex>::ConstVec& indices_vec, int64 start,
|
||||
int64 num, Eigen::TensorChippingOp<0, typename TTypes<Tout>::Matrix> out,
|
||||
const Tout scaling_factor) {
|
||||
#define INDEX(n, i) \
|
||||
const auto index##n = indices_vec(start + (i)); \
|
||||
if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);
|
||||
|
||||
#define L(n) input_flat.template chip<0>(index##n)
|
||||
#define L(n) fetch_val<Tin, Tindex>(input_flat, index##n)
|
||||
|
||||
if (num == 1) {
|
||||
INDEX(0, 0);
|
||||
out = L(0);
|
||||
} else {
|
||||
int64 r = num % 8;
|
||||
T m(1);
|
||||
if (is_mean_ && (num < 10)) {
|
||||
m = T(num);
|
||||
}
|
||||
if (is_sqrtn_ && (num < 10)) {
|
||||
m = T(sqrt(num));
|
||||
}
|
||||
int64 r = num & 7;
|
||||
switch (r) {
|
||||
case 2: {
|
||||
INDEX(0, 0);
|
||||
INDEX(1, 1);
|
||||
out = (L(0) + L(1)) / m;
|
||||
out = (L(0) + L(1)) * scaling_factor;
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
INDEX(0, 0);
|
||||
INDEX(1, 1);
|
||||
INDEX(2, 2);
|
||||
out = (L(0) + L(1) + L(2)) / m;
|
||||
out = (L(0) + L(1) + L(2)) * scaling_factor;
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
@ -613,7 +669,7 @@ class SparseSegmentReductionOpBase : public OpKernel {
|
||||
INDEX(1, 1);
|
||||
INDEX(2, 2);
|
||||
INDEX(3, 3);
|
||||
out = (L(0) + L(1) + L(2) + L(3)) / m;
|
||||
out = (L(0) + L(1) + L(2) + L(3)) * scaling_factor;
|
||||
break;
|
||||
}
|
||||
case 5: {
|
||||
@ -622,7 +678,7 @@ class SparseSegmentReductionOpBase : public OpKernel {
|
||||
INDEX(2, 2);
|
||||
INDEX(3, 3);
|
||||
INDEX(4, 4);
|
||||
out = (L(0) + L(1) + L(2) + L(3) + L(4)) / m;
|
||||
out = (L(0) + L(1) + L(2) + L(3) + L(4)) * scaling_factor;
|
||||
break;
|
||||
}
|
||||
case 6: {
|
||||
@ -632,7 +688,7 @@ class SparseSegmentReductionOpBase : public OpKernel {
|
||||
INDEX(3, 3);
|
||||
INDEX(4, 4);
|
||||
INDEX(5, 5);
|
||||
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m;
|
||||
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) * scaling_factor;
|
||||
break;
|
||||
}
|
||||
case 7: {
|
||||
@ -643,7 +699,8 @@ class SparseSegmentReductionOpBase : public OpKernel {
|
||||
INDEX(4, 4);
|
||||
INDEX(5, 5);
|
||||
INDEX(6, 6);
|
||||
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m;
|
||||
out =
|
||||
(L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) * scaling_factor;
|
||||
break;
|
||||
}
|
||||
case 0: {
|
||||
@ -655,7 +712,8 @@ class SparseSegmentReductionOpBase : public OpKernel {
|
||||
INDEX(5, 5);
|
||||
INDEX(6, 6);
|
||||
INDEX(7, 7);
|
||||
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m;
|
||||
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) *
|
||||
scaling_factor;
|
||||
r = 8;
|
||||
break;
|
||||
}
|
||||
@ -669,8 +727,8 @@ class SparseSegmentReductionOpBase : public OpKernel {
|
||||
INDEX(6, 6);
|
||||
INDEX(7, 7);
|
||||
INDEX(8, 8);
|
||||
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) /
|
||||
m;
|
||||
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) *
|
||||
scaling_factor;
|
||||
r = 9;
|
||||
break;
|
||||
}
|
||||
@ -687,10 +745,10 @@ class SparseSegmentReductionOpBase : public OpKernel {
|
||||
out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
|
||||
}
|
||||
if (is_mean_ && num >= 10) {
|
||||
out = out / static_cast<T>(num);
|
||||
out = out / static_cast<Tout>(num);
|
||||
}
|
||||
if (is_sqrtn_ && num >= 10) {
|
||||
out = out / static_cast<T>(sqrt(num));
|
||||
out = out / static_cast<Tout>(sqrt(num));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -64,6 +64,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
|
||||
segment_ids_type>);
|
||||
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
|
||||
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
|
||||
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
|
||||
#undef REGISTER_CPU_SPARSE_KERNELS
|
||||
|
||||
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
|
||||
@ -85,6 +86,7 @@ REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
|
||||
CPUDevice, type, index_type, segment_ids_type>);
|
||||
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
|
||||
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
|
||||
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
|
||||
#undef REGISTER_CPU_SPARSE_KERNELS
|
||||
|
||||
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
|
||||
|
||||
@ -21,9 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/strided_slice_op_gpu_impl.h"
|
||||
|
||||
namespace tensorflow {
|
||||
TF_CALL_int8(DEFINE_GPU_KERNELS);
|
||||
TF_CALL_int32(DEFINE_GPU_KERNELS);
|
||||
TF_CALL_int64(DEFINE_GPU_KERNELS);
|
||||
TF_CALL_INTEGRAL_TYPES(DEFINE_GPU_KERNELS);
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
@ -25,6 +25,19 @@ REGISTER_OP("BatchFunction")
|
||||
.Output("out_tensors: Tout")
|
||||
.Attr("f: func")
|
||||
.Attr("num_batch_threads: int")
|
||||
// 'max_batch_size' denotes the maximum batch size acceptable, i.e., inputs
// with larger batch size are simply invalidated.
// By default, 'max_batch_size' must be equal to max value of
// 'allowed_batch_sizes'.
// By setting 'enable_large_batch_splitting' (attribute below) to true,
// 'max_batch_size' can be greater than or equal to max value of
// 'allowed_batch_sizes', in other words,
// 1) input with size > 'max_batch_size' is still invalidated.
// 2) input with
// a) size <= 'max_batch_size'
// b) size > max value of 'allowed_batch_sizes'
// will automatically be split into multiple batches (with batch size in
// 'allowed_batch_sizes'), executed, and re-composed (as final output).
.Attr("max_batch_size: int")
.Attr("batch_timeout_micros: int")
|
||||
.Attr("max_enqueued_batches: int = 10")
|
||||
@ -35,6 +48,12 @@ REGISTER_OP("BatchFunction")
|
||||
.Attr("Tin: list(type)")
|
||||
.Attr("Tcaptured: list(type) >= 0")
|
||||
.Attr("Tout: list(type)")
|
||||
// If 'enable_large_batch_splitting' is true, for input batches exceeding
// the largest value in "allowed_batch_sizes", allow the batch to be split
// into multiple batches with batch size within "allowed_batch_sizes".
// NOTE: Support for `enable_large_batch_splitting == true` is still
// in development.
.Attr("enable_large_batch_splitting: bool = false")
// TODO(apassos): Fix this shape inference function. It requires shape
// inference of function calls.
.SetShapeFn(shape_inference::UnknownShape);
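For illustration only, here is one plausible decomposition consistent with the 'max_batch_size' / 'enable_large_batch_splitting' comments above; this is not the kernel's actual splitting code, and the sizes are made up. With allowed_batch_sizes = {2, 4, 8}, max_batch_size = 16, and splitting enabled, an incoming batch of 10 examples is accepted but larger than 8, so it could be executed as sub-batches of 8 and 2 and then re-composed.

#include <algorithm>
#include <iostream>
#include <vector>

// Hypothetical split: peel off the largest allowed size, then round the tail
// up to the smallest allowed size that can hold it (i.e. pad the last batch).
std::vector<int> SplitBatch(int input_size, std::vector<int> allowed) {
  std::sort(allowed.begin(), allowed.end());
  std::vector<int> parts;
  int remaining = input_size;
  while (remaining > allowed.back()) {
    parts.push_back(allowed.back());
    remaining -= allowed.back();
  }
  parts.push_back(*std::lower_bound(allowed.begin(), allowed.end(), remaining));
  return parts;
}

int main() {
  for (int part : SplitBatch(/*input_size=*/10, {2, 4, 8})) {
    std::cout << part << " ";  // prints "8 2"
  }
  std::cout << "\n";
  return 0;
}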
|
||||
@ -82,3 +82,94 @@ op {
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "BatchFunction"
|
||||
input_arg {
|
||||
name: "in_tensors"
|
||||
type_list_attr: "Tin"
|
||||
}
|
||||
input_arg {
|
||||
name: "captured_tensors"
|
||||
type_list_attr: "Tcaptured"
|
||||
}
|
||||
output_arg {
|
||||
name: "out_tensors"
|
||||
type_list_attr: "Tout"
|
||||
}
|
||||
attr {
|
||||
name: "f"
|
||||
type: "func"
|
||||
}
|
||||
attr {
|
||||
name: "num_batch_threads"
|
||||
type: "int"
|
||||
}
|
||||
attr {
|
||||
name: "max_batch_size"
|
||||
type: "int"
|
||||
}
|
||||
attr {
|
||||
name: "batch_timeout_micros"
|
||||
type: "int"
|
||||
}
|
||||
attr {
|
||||
name: "max_enqueued_batches"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: 10
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "allowed_batch_sizes"
|
||||
type: "list(int)"
|
||||
default_value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "container"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "shared_name"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "batching_queue"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tin"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "Tcaptured"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
}
|
||||
attr {
|
||||
name: "Tout"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "enable_large_batch_splitting"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,3 +53,59 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "SparseSegmentMean"
|
||||
input_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tidx"
|
||||
}
|
||||
input_arg {
|
||||
name: "segment_ids"
|
||||
type_attr: "Tsegmentids"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_BFLOAT16
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tidx"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tsegmentids"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -70,3 +70,76 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "SparseSegmentMeanWithNumSegments"
|
||||
input_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tidx"
|
||||
}
|
||||
input_arg {
|
||||
name: "segment_ids"
|
||||
type_attr: "Tsegmentids"
|
||||
}
|
||||
input_arg {
|
||||
name: "num_segments"
|
||||
type_attr: "Tnumsegments"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_BFLOAT16
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tidx"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tnumsegments"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tsegmentids"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,3 +53,59 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "SparseSegmentSqrtN"
|
||||
input_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tidx"
|
||||
}
|
||||
input_arg {
|
||||
name: "segment_ids"
|
||||
type_attr: "Tsegmentids"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_BFLOAT16
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tidx"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tsegmentids"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -70,3 +70,76 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "SparseSegmentSqrtNWithNumSegments"
|
||||
input_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tidx"
|
||||
}
|
||||
input_arg {
|
||||
name: "segment_ids"
|
||||
type_attr: "Tsegmentids"
|
||||
}
|
||||
input_arg {
|
||||
name: "num_segments"
|
||||
type_attr: "Tnumsegments"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_BFLOAT16
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tidx"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tnumsegments"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tsegmentids"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -299,4 +299,10 @@ REGISTER_OP("FakeParam")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
// Returns the device index.
|
||||
REGISTER_OP("DeviceIndex")
|
||||
.Output("index: int32")
|
||||
.Attr("device_names: list(string)")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
@ -1337,7 +1337,7 @@ REGISTER_OP("SparseSegmentMean")
|
||||
.Input("indices: Tidx")
|
||||
.Input("segment_ids: Tsegmentids")
|
||||
.Output("output: T")
|
||||
.Attr("T: {float, double}")
|
||||
.Attr("T: {bfloat16, float, double}")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(SparseSegmentReductionShapeFn);
|
||||
@ -1348,7 +1348,7 @@ REGISTER_OP("SparseSegmentMeanWithNumSegments")
|
||||
.Input("segment_ids: Tsegmentids")
|
||||
.Input("num_segments: Tnumsegments")
|
||||
.Output("output: T")
|
||||
.Attr("T: {float, double}")
|
||||
.Attr("T: {bfloat16, float, double}")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
|
||||
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
|
||||
@ -1370,7 +1370,7 @@ REGISTER_OP("SparseSegmentSqrtN")
|
||||
.Input("indices: Tidx")
|
||||
.Input("segment_ids: Tsegmentids")
|
||||
.Output("output: T")
|
||||
.Attr("T: {float, double}")
|
||||
.Attr("T: {bfloat16, float, double}")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(SparseSegmentReductionShapeFn);
|
||||
@ -1381,7 +1381,7 @@ REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
|
||||
.Input("segment_ids: Tsegmentids")
|
||||
.Input("num_segments: Tnumsegments")
|
||||
.Output("output: T")
|
||||
.Attr("T: {float, double}")
|
||||
.Attr("T: {bfloat16, float, double}")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
|
||||
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
|
||||
|
||||
@ -3553,6 +3553,13 @@ op {
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "enable_large_batch_splitting"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "BatchIFFT"
|
||||
@ -46097,6 +46104,7 @@ op {
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_BFLOAT16
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
@ -46215,6 +46223,7 @@ op {
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_BFLOAT16
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
@ -46283,6 +46292,7 @@ op {
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_BFLOAT16
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
@ -46401,6 +46411,7 @@ op {
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_BFLOAT16
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
|
||||
@ -58,6 +58,15 @@ TEST(StateOpsTest, ScatterUpdate_ShapeFn) {
|
||||
|
||||
// Resolve shape on first updates dimension.
|
||||
INFER_OK(op, "[1,2];[3];[?,2]", "in0");
|
||||
|
||||
// Allow the update to be a scalar.
|
||||
INFER_OK(op, "[1,2];[3];?", "in0");
|
||||
|
||||
// Allow a scalar index.
|
||||
INFER_OK(op, "[1,2];[];[2]", "in0");
|
||||
|
||||
// Check the requirement updates.shape = indices.shape + ref.shape[1:].
|
||||
INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op, "[2];[];[2]");
|
||||
}
|
||||
|
||||
TEST(StateOpsTest, TemporaryVariable_ShapeFn) {
|
||||
|
||||
@ -23,7 +23,7 @@ const char* kProtobufUint64Typename = "::tensorflow::protobuf_uint64";
|
||||
TStringOutputStream::TStringOutputStream(tstring* target) : target_(target) {}
|
||||
|
||||
bool TStringOutputStream::Next(void** data, int* size) {
|
||||
int old_size = target_->size();
|
||||
size_t old_size = target_->size();
|
||||
|
||||
// Grow the string.
|
||||
if (old_size < target_->capacity()) {
|
||||
@ -32,16 +32,16 @@ bool TStringOutputStream::Next(void** data, int* size) {
|
||||
target_->resize_uninitialized(target_->capacity());
|
||||
} else {
|
||||
// Size has reached capacity, try to double the size.
|
||||
if (old_size > std::numeric_limits<int>::max() / 2) {
|
||||
if (old_size > std::numeric_limits<size_t>::max() / 2) {
|
||||
// Can not double the size otherwise it is going to cause integer
|
||||
// overflow in the expression below: old_size * 2 ";
|
||||
return false;
|
||||
}
|
||||
// Double the size, also make sure that the new size is at least
|
||||
// kMinimumSize.
|
||||
target_->resize_uninitialized(
|
||||
std::max(old_size * 2,
|
||||
kMinimumSize + 0)); // "+ 0" works around GCC4 weirdness.
|
||||
target_->resize_uninitialized(std::max(
|
||||
old_size * 2,
|
||||
(size_t)kMinimumSize + 0)); // "+ 0" works around GCC4 weirdness.
|
||||
}
|
||||
|
||||
*data = target_->data() + old_size;
|
||||
|
||||
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <aws/core/utils/StringUtils.h>
|
||||
#include <aws/core/utils/logging/AWSLogging.h>
|
||||
#include <aws/core/utils/logging/LogSystemInterface.h>
|
||||
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
|
||||
#include <aws/s3/S3Client.h>
|
||||
#include <aws/s3/S3Errors.h>
|
||||
#include <aws/s3/model/AbortMultipartUploadRequest.h>
|
||||
@ -58,10 +59,16 @@ static const char* kS3TempFileTemplate = "/tmp/s3_filesystem_XXXXXX";
static const char* kS3FileSystemAllocationTag = "S3FileSystemAllocation";
static const size_t kS3ReadAppendableFileBufferSize = 1024 * 1024;
static const int64 kS3TimeoutMsec = 300000; // 5 min
static const uint64 kS3MultiPartCopyPartSize = 50 * 1024 * 1024; // 50 MB
static const uint64 kS3MultiPartUploadChunkSize = 50 * 1024 * 1024; // 50 MB
static const uint64 kS3MultiPartDownloadChunkSize = 2 * 1024 * 1024; // 2 MB
static const int kS3GetChildrenMaxKeys = 100;
static const int kExecutorPoolSize = 5;
static const int kUploadRetries = 5;

// With this change, multiple threads are used in a single download.
// Increasing the thread pool size since multiple downloads
// and uploads can occur in parallel.
static const int kExecutorPoolSize = 25;
static const int kUploadRetries = 3;
static const int kDownloadRetries = 3;
static const char* kExecutorTag = "TransferManagerExecutor";

Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
|
||||
@ -223,9 +230,16 @@ static Status CreateStatusFromAwsError(
|
||||
|
||||
class S3RandomAccessFile : public RandomAccessFile {
|
||||
public:
|
||||
S3RandomAccessFile(const string& bucket, const string& object,
|
||||
std::shared_ptr<Aws::S3::S3Client> s3_client)
|
||||
: bucket_(bucket), object_(object), s3_client_(s3_client) {}
|
||||
S3RandomAccessFile(
|
||||
const string& bucket, const string& object,
|
||||
const bool use_multi_part_download,
|
||||
std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager,
|
||||
std::shared_ptr<Aws::S3::S3Client> s3_client)
|
||||
: bucket_(bucket),
|
||||
object_(object),
|
||||
use_multi_part_download_(use_multi_part_download),
|
||||
transfer_manager_(transfer_manager),
|
||||
s3_client_(s3_client) {}
|
||||
|
||||
Status Name(StringPiece* result) const override {
|
||||
return errors::Unimplemented("S3RandomAccessFile does not support Name()");
|
||||
@ -235,6 +249,66 @@ class S3RandomAccessFile : public RandomAccessFile {
|
||||
char* scratch) const override {
|
||||
VLOG(1) << "ReadFilefromS3 s3://" << bucket_ << "/" << object_ << " from "
|
||||
<< offset << " for n:" << n;
|
||||
if (use_multi_part_download_) {
|
||||
return ReadS3TransferManager(offset, n, result, scratch);
|
||||
} else {
|
||||
return ReadS3Client(offset, n, result, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
Status ReadS3TransferManager(uint64 offset, size_t n, StringPiece* result,
|
||||
char* scratch) const {
|
||||
VLOG(3) << "Using TransferManager";
|
||||
|
||||
auto create_stream_fn = [&]() { // create stream lambda fn
|
||||
return Aws::New<TFS3UnderlyingStream>(
|
||||
"S3ReadStream",
|
||||
Aws::New<Aws::Utils::Stream::PreallocatedStreamBuf>(
|
||||
"S3ReadStream", reinterpret_cast<unsigned char*>(scratch), n));
|
||||
};
|
||||
|
||||
VLOG(3) << "Created stream to read with transferManager";
|
||||
|
||||
std::shared_ptr<Aws::Transfer::TransferHandle> handle =
|
||||
transfer_manager_.get()->DownloadFile(bucket_.c_str(), object_.c_str(),
|
||||
offset, n, create_stream_fn);
|
||||
handle->WaitUntilFinished();
|
||||
|
||||
// todo change this
|
||||
int retries = 0;
|
||||
|
||||
while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
|
||||
handle->GetLastError().GetResponseCode() !=
|
||||
Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE &&
|
||||
retries++ < kDownloadRetries) {
|
||||
// only failed parts will be downloaded again
|
||||
VLOG(1) << "Retrying read of s3://" << bucket_ << "/" << object_
|
||||
<< " after failure. Current retry count:" << retries;
|
||||
transfer_manager_.get()->RetryDownload(handle);
|
||||
handle->WaitUntilFinished();
|
||||
}
|
||||
|
||||
if (handle->GetStatus() != Aws::Transfer::TransferStatus::COMPLETED) {
|
||||
auto error = handle->GetLastError();
|
||||
if (error.GetResponseCode() ==
|
||||
Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE) {
|
||||
// expected when end of file is reached
|
||||
n = 0;
|
||||
*result = StringPiece(scratch, n);
|
||||
return Status(error::OUT_OF_RANGE, "Read less bytes than requested");
|
||||
}
|
||||
return CreateStatusFromAwsError(error);
|
||||
} else {
|
||||
n = handle->GetBytesTotalSize();
|
||||
*result = StringPiece(scratch, handle->GetBytesTransferred());
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
Status ReadS3Client(uint64 offset, size_t n, StringPiece* result,
|
||||
char* scratch) const {
|
||||
VLOG(3) << "ReadFile using S3Client s3://" << bucket_ << "/" << object_;
|
||||
|
||||
Aws::S3::Model::GetObjectRequest getObjectRequest;
|
||||
getObjectRequest.WithBucket(bucket_.c_str()).WithKey(object_.c_str());
|
||||
string bytes = strings::StrCat("bytes=", offset, "-", offset + n - 1);
|
||||
@ -242,6 +316,7 @@ class S3RandomAccessFile : public RandomAccessFile {
|
||||
getObjectRequest.SetResponseStreamFactory([]() {
|
||||
return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag);
|
||||
});
|
||||
|
||||
auto getObjectOutcome = this->s3_client_->GetObject(getObjectRequest);
|
||||
if (!getObjectOutcome.IsSuccess()) {
|
||||
auto error = getObjectOutcome.GetError();
|
||||
@ -252,18 +327,21 @@ class S3RandomAccessFile : public RandomAccessFile {
|
||||
return Status(error::OUT_OF_RANGE, "Read less bytes than requested");
|
||||
}
|
||||
return CreateStatusFromAwsError(error);
|
||||
}
|
||||
n = getObjectOutcome.GetResult().GetContentLength();
|
||||
getObjectOutcome.GetResult().GetBody().read(scratch, n);
|
||||
} else {
|
||||
n = getObjectOutcome.GetResult().GetContentLength();
|
||||
getObjectOutcome.GetResult().GetBody().read(scratch, n);
|
||||
|
||||
*result = StringPiece(scratch, n);
|
||||
return Status::OK();
|
||||
*result = StringPiece(scratch, n);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
string bucket_;
|
||||
string object_;
|
||||
std::shared_ptr<Aws::S3::S3Client> s3_client_;
|
||||
std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager_;
|
||||
bool use_multi_part_download_;
|
||||
};
|
||||
|
||||
class S3WritableFile : public WritableFile {
|
||||
@ -375,16 +453,53 @@ class S3ReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
|
||||
S3FileSystem::S3FileSystem()
|
||||
: s3_client_(nullptr, ShutdownClient),
|
||||
initialization_lock_(),
|
||||
transfer_manager_(nullptr, ShutdownTransferManager),
|
||||
executor_(nullptr, ShutdownExecutor) {
|
||||
const char* part_size_str = getenv("S3_MULTI_PART_COPY_PART_SIZE");
|
||||
multi_part_copy_part_size_ = kS3MultiPartCopyPartSize;
|
||||
const char* part_size_str = getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE");
|
||||
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD] =
|
||||
kS3MultiPartUploadChunkSize;
|
||||
if (part_size_str) {
|
||||
uint64 part_size_num;
|
||||
if (strings::safe_strtou64(part_size_str, &part_size_num)) {
|
||||
multi_part_copy_part_size_ = part_size_num;
|
||||
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD] =
|
||||
part_size_num;
|
||||
}
|
||||
}
|
||||
|
||||
// Different TensorFlow APIs call the download API with different
// buffer sizes. Download performance depends on that size and this chunk size.
part_size_str = getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE");
multi_part_chunk_size_[Aws::Transfer::TransferDirection::DOWNLOAD] =
kS3MultiPartDownloadChunkSize;
if (part_size_str) {
uint64 part_size_num;
if (strings::safe_strtou64(part_size_str, &part_size_num)) {
multi_part_chunk_size_[Aws::Transfer::TransferDirection::DOWNLOAD] =
part_size_num;
}
}

use_multi_part_download_ = true;
const char* disable_transfer_mgr = getenv("S3_DISABLE_MULTI_PART_DOWNLOAD");
if (disable_transfer_mgr) {
if (disable_transfer_mgr[0] == '1') {
use_multi_part_download_ = false;
}
}

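A hypothetical configuration snippet (not part of the commit) showing how a client process could set these knobs before the S3 filesystem is constructed; the byte values are arbitrary examples, and setenv is POSIX-only.

#include <cstdlib>

void ConfigureS3Transfers() {
  // 16 MB download chunks instead of the 2 MB default (kS3MultiPartDownloadChunkSize).
  setenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE", "16777216", /*overwrite=*/1);
  // 100 MB upload/copy chunks instead of the 50 MB default (kS3MultiPartUploadChunkSize).
  setenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE", "104857600", /*overwrite=*/1);
  // Set to "1" to fall back to plain GetObject reads instead of the
  // TransferManager-based multi-part download.
  setenv("S3_DISABLE_MULTI_PART_DOWNLOAD", "0", /*overwrite=*/1);
}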
auto upload_pair = std::pair<Aws::Transfer::TransferDirection,
|
||||
std::shared_ptr<Aws::Transfer::TransferManager>>(
|
||||
Aws::Transfer::TransferDirection::UPLOAD,
|
||||
std::shared_ptr<Aws::Transfer::TransferManager>(nullptr,
|
||||
ShutdownTransferManager));
|
||||
auto download_pair =
|
||||
std::pair<Aws::Transfer::TransferDirection,
|
||||
std::shared_ptr<Aws::Transfer::TransferManager>>(
|
||||
Aws::Transfer::TransferDirection::DOWNLOAD,
|
||||
std::shared_ptr<Aws::Transfer::TransferManager>(
|
||||
nullptr, ShutdownTransferManager));
|
||||
|
||||
this->transfer_managers_.insert(upload_pair);
|
||||
this->transfer_managers_.insert(download_pair);
|
||||
}
|
||||
|
||||
S3FileSystem::~S3FileSystem() {}
|
||||
@ -424,20 +539,22 @@ std::shared_ptr<Aws::S3::S3Client> S3FileSystem::GetS3Client() {
|
||||
}
|
||||
|
||||
std::shared_ptr<Aws::Transfer::TransferManager>
|
||||
S3FileSystem::GetTransferManager() {
|
||||
S3FileSystem::GetTransferManager(
|
||||
const Aws::Transfer::TransferDirection& direction) {
|
||||
std::shared_ptr<Aws::S3::S3Client> s3_client = this->GetS3Client();
|
||||
std::lock_guard<mutex> lock(this->initialization_lock_);
|
||||
if (this->transfer_manager_.get() == nullptr) {
|
||||
if (this->transfer_managers_[direction].get() == nullptr) {
|
||||
Aws::Transfer::TransferManagerConfiguration config(
|
||||
this->GetExecutor().get());
|
||||
config.s3Client = s3_client;
|
||||
config.bufferSize = this->multi_part_copy_part_size_;
|
||||
// must be larger than pool size * multi_part_copy_part_size
|
||||
config.bufferSize = this->multi_part_chunk_size_[direction];
|
||||
// must be larger than pool size * multi part chunk size
|
||||
config.transferBufferMaxHeapSize =
|
||||
(kExecutorPoolSize + 1) * this->multi_part_copy_part_size_;
|
||||
this->transfer_manager_ = Aws::Transfer::TransferManager::Create(config);
|
||||
(kExecutorPoolSize + 1) * this->multi_part_chunk_size_[direction];
|
||||
this->transfer_managers_[direction] =
|
||||
Aws::Transfer::TransferManager::Create(config);
|
||||
}
|
||||
return this->transfer_manager_;
|
||||
return this->transfer_managers_[direction];
|
||||
}
|
||||
|
||||
std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor>
|
||||
@ -452,9 +569,21 @@ S3FileSystem::GetExecutor() {
|
||||
|
||||
Status S3FileSystem::NewRandomAccessFile(
|
||||
const string& fname, std::unique_ptr<RandomAccessFile>* result) {
|
||||
return NewRandomAccessFile(fname, result, true);
|
||||
}
|
||||
|
||||
Status S3FileSystem::NewRandomAccessFile(
|
||||
const string& fname, std::unique_ptr<RandomAccessFile>* result,
|
||||
bool use_multi_part_download) {
|
||||
string bucket, object;
|
||||
TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
|
||||
result->reset(new S3RandomAccessFile(bucket, object, this->GetS3Client()));
|
||||
|
||||
// check if an override was defined for this file. used for testing
|
||||
bool use_mpd = this->use_multi_part_download_ && use_multi_part_download;
|
||||
result->reset(new S3RandomAccessFile(
|
||||
bucket, object, use_mpd,
|
||||
this->GetTransferManager(Aws::Transfer::TransferDirection::DOWNLOAD),
|
||||
this->GetS3Client()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -462,8 +591,11 @@ Status S3FileSystem::NewWritableFile(const string& fname,
|
||||
std::unique_ptr<WritableFile>* result) {
|
||||
string bucket, object;
|
||||
TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
|
||||
result->reset(new S3WritableFile(bucket, object, this->GetTransferManager(),
|
||||
this->GetS3Client()));
|
||||
result->reset(new S3WritableFile(
|
||||
bucket, object,
|
||||
this->GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD),
|
||||
this->GetS3Client()));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -478,8 +610,10 @@ Status S3FileSystem::NewAppendableFile(const string& fname,
|
||||
|
||||
string bucket, object;
|
||||
TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
|
||||
result->reset(new S3WritableFile(bucket, object, this->GetTransferManager(),
|
||||
this->GetS3Client()));
|
||||
result->reset(new S3WritableFile(
|
||||
bucket, object,
|
||||
this->GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD),
|
||||
this->GetS3Client()));
|
||||
|
||||
while (true) {
|
||||
status = reader->Read(offset, kS3ReadAppendableFileBufferSize, &read_chunk,
|
||||
@ -773,10 +907,13 @@ Status S3FileSystem::CopyFile(const Aws::String& source_bucket,
|
||||
TF_RETURN_IF_ERROR(
|
||||
this->GetFileSize(string(source_full_path.c_str()), &file_length));
|
||||
int num_parts;
|
||||
if (file_length <= multi_part_copy_part_size_) {
|
||||
if (file_length <=
|
||||
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD]) {
|
||||
num_parts = 1;
|
||||
} else {
|
||||
num_parts = ceil((float)file_length / multi_part_copy_part_size_);
|
||||
num_parts =
|
||||
ceil((float)file_length /
|
||||
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD]);
|
||||
}
|
||||
|
||||
if (num_parts == 1) {
|
||||
@ -786,7 +923,8 @@ Status S3FileSystem::CopyFile(const Aws::String& source_bucket,
|
||||
"MultiPartCopy with number of parts more than 10000 is not supported. "
|
||||
"Your object ",
|
||||
source, " required ", num_parts,
|
||||
" as multi_part_copy_part_size is set to ", multi_part_copy_part_size_,
|
||||
" as multi_part_copy_part_size is set to ",
|
||||
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD],
|
||||
". You can control this part size using the environment variable ",
|
||||
"S3_MULTI_PART_COPY_PART_SIZE to increase it.");
|
||||
return tensorflow::errors::Unimplemented(message);
|
||||
@ -831,7 +969,9 @@ Status S3FileSystem::MultiPartCopy(const Aws::String& source,
|
||||
|
||||
Aws::String uploadID = multipartUploadOutcome.GetResult().GetUploadId();
|
||||
VLOG(1) << "Copying from " << source << " in " << num_parts
|
||||
<< " parts of size " << multi_part_copy_part_size_ << " each";
|
||||
<< " parts of size "
|
||||
<< multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD]
|
||||
<< " each";
|
||||
Aws::S3::Model::CompletedMultipartUpload completedMPURequest;
|
||||
|
||||
// passed to each callback keyed by partNumber
|
||||
@ -859,8 +999,12 @@ Status S3FileSystem::MultiPartCopy(const Aws::String& source,
|
||||
for (std::map<int, PartState>::iterator it = incompletePartStates.begin();
|
||||
it != incompletePartStates.end(); it++) {
|
||||
int partNumber = it->first;
|
||||
uint64 startPos = (partNumber - 1) * multi_part_copy_part_size_;
|
||||
uint64 endPos = startPos + kS3MultiPartCopyPartSize - 1;
|
||||
uint64 startPos =
|
||||
(partNumber - 1) *
|
||||
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD];
|
||||
uint64 endPos =
|
||||
startPos +
|
||||
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD] - 1;
|
||||
if (endPos >= file_length) {
|
||||
endPos = file_length - 1;
|
||||
}
|
||||
|
||||
@ -52,6 +52,10 @@ class S3FileSystem : public FileSystem {
|
||||
Status NewRandomAccessFile(
|
||||
const string& fname, std::unique_ptr<RandomAccessFile>* result) override;
|
||||
|
||||
Status NewRandomAccessFile(const string& fname,
|
||||
std::unique_ptr<RandomAccessFile>* result,
|
||||
bool use_multi_part_download);
|
||||
|
||||
Status NewWritableFile(const string& fname,
|
||||
std::unique_ptr<WritableFile>* result) override;
|
||||
|
||||
@ -101,8 +105,12 @@ class S3FileSystem : public FileSystem {
|
||||
std::shared_ptr<Aws::S3::S3Client> s3_client_;
|
||||
|
||||
// Returns the member transfer manager, initializing as-needed.
|
||||
std::shared_ptr<Aws::Transfer::TransferManager> GetTransferManager();
|
||||
std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager_;
|
||||
std::shared_ptr<Aws::Transfer::TransferManager> GetTransferManager(
|
||||
const Aws::Transfer::TransferDirection& direction);
|
||||
void InitializeTransferManagers();
|
||||
std::map<Aws::Transfer::TransferDirection,
|
||||
std::shared_ptr<Aws::Transfer::TransferManager> >
|
||||
transfer_managers_;
|
||||
|
||||
// Returns the member executor for transfer manager, initializing as-needed.
|
||||
std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor> GetExecutor();
|
||||
@ -132,8 +140,10 @@ class S3FileSystem : public FileSystem {
|
||||
// Lock held when checking for s3_client_ and transfer_manager_ initialization
|
||||
mutex initialization_lock_;
|
||||
|
||||
// size to split objects during multipart copy
|
||||
uint64 multi_part_copy_part_size_;
|
||||
// size to split objects during multipart upload/download/copy
|
||||
std::map<Aws::Transfer::TransferDirection, uint64> multi_part_chunk_size_;
|
||||
|
||||
bool use_multi_part_download_;
|
||||
};
|
||||
|
||||
/// S3 implementation of a file system with retry on failures.
|
||||
@ -147,6 +157,16 @@ class RetryingS3FileSystem : public RetryingFileSystem<S3FileSystem> {
|
||||
)) {}
|
||||
};
|
||||
|
||||
// AWS streams destroy the buffer (buf) passed to them, so we create a new
// IOStream that retains the buffer, letting the calling function
// control its lifecycle.
class TFS3UnderlyingStream : public Aws::IOStream {
public:
using Base = Aws::IOStream;
TFS3UnderlyingStream(std::streambuf* buf) : Base(buf) {}
virtual ~TFS3UnderlyingStream() = default;
};

} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_S3_S3_FILE_SYSTEM_H_
|
||||
|
||||
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/platform/s3/s3_file_system.h"
|
||||
|
||||
#include <time.h>
|
||||
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/file_system.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
@ -62,6 +64,96 @@ class S3FileSystemTest : public ::testing::Test {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ReadAllInChunks(const string& fname, string* content,
|
||||
bool use_multi_part_download = true) {
|
||||
std::unique_ptr<RandomAccessFile> reader;
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
s3fs.NewRandomAccessFile(fname, &reader, use_multi_part_download));
|
||||
|
||||
uint64 file_size = 0;
|
||||
TF_RETURN_IF_ERROR(s3fs.GetFileSize(fname, &file_size));
|
||||
|
||||
content->resize(file_size);
|
||||
|
||||
uint64 buffer_size = 16 * 1024 * 1024;
|
||||
|
||||
std::size_t part_count = (std::max)(
|
||||
static_cast<size_t>((file_size + buffer_size - 1) / buffer_size),
|
||||
static_cast<std::size_t>(1));
|
||||
VLOG(1) << "buffersize:" << buffer_size << " file_size:" << file_size
|
||||
<< " part_count=" << part_count;
|
||||
std::unique_ptr<char[]> buffer{new char[buffer_size]};
|
||||
std::stringstream ss;
|
||||
|
||||
int offset = 0;
|
||||
int result_size = 0;
|
||||
|
||||
using namespace std::chrono;
|
||||
auto start = high_resolution_clock::now();
|
||||
|
||||
for (int i = 0; i < part_count; i++) {
|
||||
StringPiece result;
|
||||
offset = i * buffer_size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->Read(offset, buffer_size, &result, buffer.get()));
|
||||
|
||||
if (result.size() != 0) {
|
||||
ss.write(result.data(), result.size());
|
||||
result_size += result.size();
|
||||
}
|
||||
if (result_size == file_size) {
|
||||
break;
|
||||
}
|
||||
if (result.size() != buffer_size) {
|
||||
VLOG(1) << "Result size and buffer size did not match";
|
||||
if (result.empty()) {
|
||||
return errors::OutOfRange("eof");
|
||||
} else {
|
||||
return errors::DataLoss("truncated record at ", offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (file_size != result_size) {
|
||||
return errors::DataLoss("expected ", file_size, " got ", result_size,
|
||||
" bytes");
|
||||
}
|
||||
|
||||
auto stop = high_resolution_clock::now();
|
||||
duration<double> time_taken = duration_cast<duration<double>>(stop - start);
|
||||
VLOG(1) << "Time Taken"
|
||||
<< " : " << time_taken.count() << "seconds";
|
||||
|
||||
memcpy((char*)(content->data()), ss.str().data(),
|
||||
static_cast<size_t>(file_size));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ReadLargeFile() {
|
||||
// const string fname = TmpDir("train-00001-of-01024");
|
||||
auto large_file_name = getenv("LARGE_DOWNLOAD_FILE_NAME");
|
||||
const string fname = TmpDir(large_file_name);
|
||||
string content_xfer;
|
||||
string content_s3client;
|
||||
|
||||
// Read using Chunked Transfer Manager
|
||||
VLOG(1) << "Using transfer manager";
|
||||
TF_RETURN_IF_ERROR(ReadAllInChunks(fname, &content_xfer));
|
||||
|
||||
VLOG(1) << "Without transfer manager";
|
||||
// Read using old S3 API and see if the contents match with TransferManager
|
||||
TF_RETURN_IF_ERROR(ReadAllInChunks(fname, &content_s3client, false));
|
||||
|
||||
if (content_xfer == content_s3client) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
VLOG(1) << "ReadLargeFile contents DO NOT match";
|
||||
return Status(error::OUT_OF_RANGE, "ReadLargeFile contents DO NOT match");
|
||||
}
|
||||
}
|
||||
|
||||
S3FileSystem s3fs;
|
||||
};
|
||||
|
||||
@ -236,5 +328,9 @@ TEST_F(S3FileSystemTest, HasAtomicMove) {
|
||||
EXPECT_EQ(has_atomic_move, false);
|
||||
}
|
||||
|
||||
TEST_F(S3FileSystemTest, NewRandomAccessBigFile) {
|
||||
TF_EXPECT_OK(ReadLargeFile());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -63,13 +63,11 @@ void HostOpMetricsDbBuilder::UpdateHostInfeedEnqInfo(
|
||||
start_timestamp_ps_diff);
|
||||
}
|
||||
|
||||
void DeviceOpMetricsDbBuilder::EnterOp(uint64 program_id,
|
||||
absl::string_view name,
|
||||
absl::string_view category,
|
||||
absl::string_view provenance,
|
||||
bool is_eager, uint64 occurrences,
|
||||
uint64 time_ps, uint64 children_time_ps,
|
||||
int64 flops, int64 bytes_accessed) {
|
||||
void DeviceOpMetricsDbBuilder::EnterOp(
|
||||
uint64 program_id, absl::string_view name, absl::string_view category,
|
||||
absl::string_view provenance, bool is_eager, uint64 occurrences,
|
||||
uint64 time_ps, uint64 children_time_ps, int64 flops, int64 bytes_accessed,
|
||||
const std::vector<OpMetrics::MemoryAccessed>& memory_accessed_breakdown) {
|
||||
uint64 self_time_ps = time_ps - children_time_ps;
|
||||
DCHECK_GE(time_ps, self_time_ps);
|
||||
OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(program_id, name);
|
||||
@ -89,6 +87,9 @@ void DeviceOpMetricsDbBuilder::EnterOp(uint64 program_id,
|
||||
op_metrics->bytes_accessed() +
|
||||
GetCappedPerf(bytes_accessed * occurrences, self_time_ps,
|
||||
peak_hbm_bw_giga_bytes_per_second_ / 1000));
|
||||
for (const auto& memory_accessed : memory_accessed_breakdown) {
|
||||
*op_metrics->add_memory_accessed_breakdown() = memory_accessed;
|
||||
}
|
||||
db()->set_total_op_time_ps(db()->total_op_time_ps() + self_time_ps);
|
||||
}
|
||||
|
||||
|
||||
@ -69,10 +69,14 @@ class DeviceOpMetricsDbBuilder : public OpMetricsDbBuilder {
|
||||
// picoseconds.
|
||||
// flops = the number of floating-point operations computed.
|
||||
// bytes_accessed = the sum of bytes read and bytes written by this OP.
|
||||
// memory_accessed_breakdown = the breakdown of memory accessed by operation
|
||||
// type and memory space.
|
||||
void EnterOp(uint64 program_id, absl::string_view name,
|
||||
absl::string_view category, absl::string_view provenance,
|
||||
bool is_eager, uint64 occurrences, uint64 time_ps,
|
||||
uint64 children_time_ps, int64 flops, int64 bytes_accessed);
|
||||
uint64 children_time_ps, int64 flops, int64 bytes_accessed,
|
||||
const std::vector<OpMetrics::MemoryAccessed>&
|
||||
memory_accessed_breakdown = {});
|
||||
|
||||
protected:
|
||||
// Peak performance of a TensorCore or a GPU in TFLOP/s.
|
||||
|
||||
@ -108,7 +108,7 @@ limitations under the License.
|
||||
|
||||
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
|
||||
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
|
||||
#define TF_GRAPH_DEF_VERSION 414 // Updated: 2020/5/27
|
||||
#define TF_GRAPH_DEF_VERSION 415 // Updated: 2020/5/28
|
||||
|
||||
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
|
||||
//
|
||||
|
||||
@ -474,7 +474,9 @@ inline SparseTensor SparseTensor::Concat(
|
||||
const int st_num_entries = st.num_entries();
|
||||
|
||||
// Fill in indices & values.
|
||||
std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset));
|
||||
if (st_num_entries > 0) {
|
||||
std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset));
|
||||
}
|
||||
|
||||
const auto* st_ix = &st.ix_.matrix<int64>()(0, 0);
|
||||
auto* ix_out = &ix_t(offset, 0);
|
||||
|
||||
@ -252,6 +252,7 @@ cc_library(
|
||||
":version",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/delegates:status",
|
||||
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
||||
"//tensorflow/lite/experimental/resource",
|
||||
"//tensorflow/lite/kernels/internal:compatibility",
|
||||
|
||||
@ -680,6 +680,9 @@ def gen_model_coverage_test(src, model_name, data, failure_type, tags, size = "m
|
||||
if failure_type[i] != "none":
|
||||
args.append("--failure_type=%s" % failure_type[i])
|
||||
i = i + 1
|
||||
|
||||
# Avoid coverage timeouts for large/enormous tests.
|
||||
coverage_tags = ["nozapfhahn"] if size in ["large", "enormous"] else []
|
||||
native.py_test(
|
||||
name = "model_coverage_test_%s_%s" % (model_name, target_op_sets.lower().replace(",", "_")),
|
||||
srcs = [src],
|
||||
@ -696,7 +699,7 @@ def gen_model_coverage_test(src, model_name, data, failure_type, tags, size = "m
|
||||
"no_gpu", # Executing with TF GPU configurations is redundant.
|
||||
"no_oss",
|
||||
"no_windows",
|
||||
] + tags,
|
||||
] + tags + coverage_tags,
|
||||
deps = [
|
||||
"//tensorflow/lite/testing/model_coverage:model_coverage_lib",
|
||||
"//tensorflow/lite/python:lite",
|
||||
|
||||
@ -22,34 +22,56 @@ namespace tflite {
|
||||
// A simple utility for enabling profiled event tracing in TensorFlow Lite.
class Profiler {
public:
// As a given Profiler instance might only be interested in certain event
// types, we define each event type value to allow a Profiler to use
// bitmasking bitwise operations to determine whether an event should be
// recorded or not.
enum class EventType {
// Default event type, the metadata field has no special significance.
DEFAULT = 0,
DEFAULT = 1,

// The event is an operator invocation and the event_metadata field is the
// index of the operator node.
OPERATOR_INVOKE_EVENT = 1,
OPERATOR_INVOKE_EVENT = 2,

// The event is an invocation for an internal operator of a TFLite delegate.
// The event_metadata field is the index of the operator node that's specific
// to the delegate.
DELEGATE_OPERATOR_INVOKE_EVENT = 2
DELEGATE_OPERATOR_INVOKE_EVENT = 4,

// The event is a recording of runtime instrumentation such as the overall
// TFLite runtime status, the TFLite delegate status (if a delegate
// is applied), and the overall model inference latency etc.
// Note, the delegate status and overall status are stored as separate
// event_metadata fields. In particular, the delegate status is encoded
// as DelegateStatus::full_status().
GENERAL_RUNTIME_INSTRUMENTATION_EVENT = 8,
};

virtual ~Profiler() {}
|
||||
|
||||
// Signals the beginning of an event from a subgraph indexed at
|
||||
// 'event_subgraph_index', returning a handle to the profile event.
|
||||
// Signals the beginning of an event and returns a handle to the profile
|
||||
// event. The `event_metadata1` and `event_metadata2` have different
|
||||
// interpretations based on the actual Profiler instance and the `event_type`.
|
||||
// For example, as for the 'SubgraphAwareProfiler' defined in
|
||||
// lite/core/subgraph.h, when the event_type is OPERATOR_INVOKE_EVENT,
|
||||
// `event_metadata1` represents the index of a TFLite node, and
|
||||
// `event_metadata2` represents the index of the subgraph that this event
|
||||
// comes from.
|
||||
virtual uint32_t BeginEvent(const char* tag, EventType event_type,
|
||||
uint32_t event_metadata,
|
||||
uint32_t event_subgraph_index) = 0;
|
||||
// Similar w/ the above, but the event comes from the primary subgraph that's
|
||||
// indexed at 0.
|
||||
virtual uint32_t BeginEvent(const char* tag, EventType event_type,
|
||||
uint32_t event_metadata) {
|
||||
return BeginEvent(tag, event_type, event_metadata, /*primary subgraph*/ 0);
|
||||
int64_t event_metadata1,
|
||||
int64_t event_metadata2) = 0;
|
||||
// Similar w/ the above, but `event_metadata2` defaults to 0.
|
||||
uint32_t BeginEvent(const char* tag, EventType event_type,
|
||||
int64_t event_metadata) {
|
||||
return BeginEvent(tag, event_type, event_metadata, /*event_metadata2*/ 0);
|
||||
}
|
||||
|
||||
// Signals an end to the specified profile event with 'event_metadata's, This
|
||||
// is useful when 'event_metadata's are not available when the event begins
|
||||
// or when one wants to overwrite the 'event_metadata's set at the beginning.
|
||||
virtual void EndEvent(uint32_t event_handle, int64_t event_metadata1,
|
||||
int64_t event_metadata2) {}
|
||||
// Signals an end to the specified profile event.
|
||||
virtual void EndEvent(uint32_t event_handle) = 0;
|
||||
|
||||
@ -60,15 +82,18 @@ class Profiler {
|
||||
// they assume the value is in "usec", if in any case subclasses
|
||||
// didn't put usec, then the values are not meaningful.
|
||||
// TODO karimnosseir: Revisit and make the function more clear.
|
||||
virtual void AddEvent(const char* tag, EventType event_type,
|
||||
uint32_t event_metadata, uint64_t start, uint64_t end) {
|
||||
AddEvent(tag, event_type, event_metadata, start, end,
|
||||
/*event_subgraph_index*/ 0);
|
||||
void AddEvent(const char* tag, EventType event_type, uint64_t start,
|
||||
uint64_t end, int64_t event_metadata) {
|
||||
AddEvent(tag, event_type, start, end, event_metadata,
|
||||
/*event_metadata2*/ 0);
|
||||
}
|
||||
|
||||
virtual void AddEvent(const char* tag, EventType event_type,
|
||||
uint32_t event_metadata, uint64_t start, uint64_t end,
|
||||
uint32_t event_subgraph_index) {}
|
||||
virtual void AddEvent(const char* tag, EventType event_type, uint64_t start,
|
||||
uint64_t end, int64_t event_metadata1,
|
||||
int64_t event_metadata2) {}
|
||||
|
||||
protected:
friend class ScopedProfile;
};

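A minimal sketch of a Profiler built against the updated interface above; it is not part of the commit, and the include path is inferred from the TENSORFLOW_LITE_CORE_API_PROFILER_H_ guard. It records only operator-invoke events, using the power-of-two EventType values as a bitmask as described at the top of the class, and logs the two metadata fields.

#include <cstdint>
#include <cstdio>

#include "tensorflow/lite/core/api/profiler.h"

class LoggingProfiler : public tflite::Profiler {
 public:
  uint32_t BeginEvent(const char* tag, EventType event_type,
                      int64_t event_metadata1,
                      int64_t event_metadata2) override {
    // Power-of-two event type values allow a single AND for filtering.
    constexpr uint64_t kAccepted =
        static_cast<uint64_t>(EventType::OPERATOR_INVOKE_EVENT) |
        static_cast<uint64_t>(EventType::DELEGATE_OPERATOR_INVOKE_EVENT);
    if ((kAccepted & static_cast<uint64_t>(event_type)) == 0) return 0;
    std::printf("begin %s meta1=%lld meta2=%lld\n", tag,
                static_cast<long long>(event_metadata1),
                static_cast<long long>(event_metadata2));
    return next_handle_++;
  }

  void EndEvent(uint32_t event_handle) override {
    if (event_handle != 0) {
      std::printf("end handle=%u\n", static_cast<unsigned>(event_handle));
    }
  }

 private:
  uint32_t next_handle_ = 1;
};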
// Adds a profile event to `profiler` that begins with the construction
|
||||
@ -79,7 +104,7 @@ class ScopedProfile {
|
||||
public:
|
||||
ScopedProfile(Profiler* profiler, const char* tag,
|
||||
Profiler::EventType event_type = Profiler::EventType::DEFAULT,
|
||||
uint32_t event_metadata = 0)
|
||||
int64_t event_metadata = 0)
|
||||
: profiler_(profiler), event_handle_(0) {
|
||||
if (profiler) {
|
||||
event_handle_ = profiler_->BeginEvent(tag, event_type, event_metadata);
|
||||
@ -92,8 +117,8 @@ class ScopedProfile {
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Profiler* const profiler_;
|
||||
protected:
|
||||
Profiler* profiler_;
|
||||
uint32_t event_handle_;
|
||||
};
|
||||
|
||||
@ -113,6 +138,31 @@ class ScopedDelegateOperatorProfile : public ScopedProfile {
|
||||
static_cast<uint32_t>(node_index)) {}
|
||||
};
|
||||
|
||||
class ScopedRuntimeInstrumentationProfile : public ScopedProfile {
|
||||
public:
|
||||
ScopedRuntimeInstrumentationProfile(Profiler* profiler, const char* tag)
|
||||
: ScopedProfile(
|
||||
profiler, tag,
|
||||
Profiler::EventType::GENERAL_RUNTIME_INSTRUMENTATION_EVENT, -1) {}
|
||||
|
||||
void set_runtime_status(int64_t delegate_status, int64_t interpreter_status) {
|
||||
if (profiler_) {
|
||||
delegate_status_ = delegate_status;
|
||||
interpreter_status_ = interpreter_status;
|
||||
}
|
||||
}
|
||||
|
||||
~ScopedRuntimeInstrumentationProfile() {
|
||||
if (profiler_) {
|
||||
profiler_->EndEvent(event_handle_, delegate_status_, interpreter_status_);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t delegate_status_;
|
||||
int64_t interpreter_status_;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#define TFLITE_VARNAME_UNIQ_IMPL(name, ctr) name##ctr
|
||||
@ -130,4 +180,15 @@ class ScopedDelegateOperatorProfile : public ScopedProfile {
|
||||
tflite::ScopedDelegateOperatorProfile TFLITE_VARNAME_UNIQ( \
|
||||
_profile_, __COUNTER__)((profiler), (tag), (node_index))
|
||||
|
||||
#define TFLITE_ADD_RUNTIME_INSTRUMENTATION_EVENT( \
|
||||
profiler, tag, delegate_status, interpreter_status) \
|
||||
do { \
|
||||
    if (profiler) { \
      const auto handle = profiler->BeginEvent( \
          tag, Profiler::EventType::GENERAL_RUNTIME_INSTRUMENTATION_EVENT, \
          delegate_status, interpreter_status); \
      profiler->EndEvent(handle); \
    } \
  } while (false);

#endif // TENSORFLOW_LITE_CORE_API_PROFILER_H_

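For orientation (this example is not part of the diff): a minimal sketch of how delegate code might use the new end-time metadata path declared above. It assumes the TFLite header is on the include path; ReportDelegateRun and its parameters are illustrative names only.

#include <cstdint>

#include "tensorflow/lite/core/api/profiler.h"

namespace tflite {

// `profiler` may be null; both helpers used below already guard against that.
void ReportDelegateRun(Profiler* profiler, int64_t delegate_status,
                       int64_t interpreter_status) {
  {
    // Scoped variant: the statuses are attached before the scope closes and
    // are reported through EndEvent(handle, event_metadata1, event_metadata2).
    ScopedRuntimeInstrumentationProfile scoped(profiler, "delegate_run");
    scoped.set_runtime_status(delegate_status, interpreter_status);
  }

  // One-shot variant via the convenience macro added at the end of the header.
  TFLITE_ADD_RUNTIME_INSTRUMENTATION_EVENT(profiler, "delegate_run",
                                           delegate_status,
                                           interpreter_status);
}

}  // namespace tflite
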
@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_CORE_SUBGRAPH_H_
#define TENSORFLOW_LITE_CORE_SUBGRAPH_H_

#include <cstdint>
#include <cstdlib>
#include <map>
#include <utility>
@ -338,21 +339,16 @@ class Subgraph {
  class SubgraphAwareProfiler : public Profiler {
   public:
    // Constructor should be called with the non-nullptr profiler argument.
    SubgraphAwareProfiler(Profiler* profiler, uint32_t subgraph_index)
    SubgraphAwareProfiler(Profiler* profiler, int64_t subgraph_index)
        : profiler_(profiler), subgraph_index_(subgraph_index) {}
    ~SubgraphAwareProfiler() override {}

    uint32_t BeginEvent(const char* tag, EventType event_type,
                        uint32_t event_metadata,
                        uint32_t subgraph_index) override {
                        int64_t event_metadata1,
                        int64_t event_metadata2) override {
      if (!profiler_) return 0;
      return profiler_->BeginEvent(tag, event_type, event_metadata,
                                   subgraph_index);
    }

    uint32_t BeginEvent(const char* tag, EventType event_type,
                        uint32_t event_metadata) override {
      return BeginEvent(tag, event_type, event_metadata, subgraph_index_);
      return profiler_->BeginEvent(tag, event_type, event_metadata1,
                                   subgraph_index_);
    }

    void EndEvent(uint32_t event_handle) override {
@ -360,17 +356,24 @@ class Subgraph {
      profiler_->EndEvent(event_handle);
    }

    void AddEvent(const char* tag, EventType event_type,
                  uint32_t event_metadata, uint64_t start,
                  uint64_t end) override {
    void EndEvent(uint32_t event_handle, int64_t event_metadata1,
                  int64_t event_metadata2) override {
      if (!profiler_) return;
      profiler_->AddEvent(tag, event_type, event_metadata, start, end);
      profiler_->EndEvent(event_handle, event_metadata1, event_metadata2);
    }

    void AddEvent(const char* tag, EventType event_type, uint64_t start,
                  uint64_t end, int64_t event_metadata1,
                  int64_t event_metadata2) override {
      if (!profiler_) return;
      profiler_->AddEvent(tag, event_type, start, end, event_metadata1,
                          subgraph_index_);
    }

   private:
    // Does not own the memory.
    Profiler* const profiler_;
    const uint32_t subgraph_index_;
    const int64_t subgraph_index_;
  };

  // Prevent 'context_' from accessing functions that are only available to

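Not part of the change, but for context: with the forwarding above, the interpreter-level profiler receives the index of the emitting subgraph in the event_metadata2 slot. A hypothetical root profiler that decodes this convention (class and names are illustrative, assuming the profiler header is available):

#include <cstdint>
#include <cstdio>

#include "tensorflow/lite/core/api/profiler.h"

// Events routed through Subgraph::SubgraphAwareProfiler arrive here with
// event_metadata2 == subgraph index.
class SubgraphTaggingProfiler : public tflite::Profiler {
 public:
  uint32_t BeginEvent(const char* tag, EventType event_type,
                      int64_t event_metadata1,
                      int64_t event_metadata2) override {
    std::printf("[subgraph %lld] begin '%s' (metadata1=%lld)\n",
                static_cast<long long>(event_metadata2), tag,
                static_cast<long long>(event_metadata1));
    return next_handle_++;
  }
  void EndEvent(uint32_t event_handle) override {
    std::printf("end handle %u\n", static_cast<unsigned>(event_handle));
  }

 private:
  uint32_t next_handle_ = 1;
};
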
@ -20,6 +20,15 @@ package(
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "status",
    hdrs = ["status.h"],
    copts = tflite_copts(),
    deps = [
        "//tensorflow/lite/c:common",
    ],
)

cc_library(
    name = "utils",
    srcs = ["utils.cc"],

@ -43,8 +43,10 @@ cc_library(
    srcs = ["arguments.cc"],
    hdrs = ["arguments.h"],
    deps = [
        ":gpu_object",
        ":opencl_wrapper",
        ":util",
        "//tensorflow/lite/delegates/gpu/common:access_type",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/common:util",
@ -305,6 +307,16 @@ cc_library(
    ],
)

cc_library(
    name = "gpu_object",
    hdrs = ["gpu_object.h"],
    deps = [
        ":opencl_wrapper",
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "//tensorflow/lite/delegates/gpu/common:status",
    ],
)

cc_library(
    name = "inference_context",
    srcs = ["inference_context.cc"],

@ -23,10 +23,14 @@ namespace tflite {
namespace gpu {
namespace cl {
namespace {
bool IsWordSymbol(char symbol) {
  return absl::ascii_isalnum(symbol) || symbol == '_';
}

std::string GetNextWord(const std::string& code, size_t first_position) {
  size_t pos = first_position;
  char t = code[pos];
  while (absl::ascii_isalnum(t) || t == '_') {
  while (IsWordSymbol(t)) {
    pos++;
    t = code[pos];
  }
@ -38,13 +42,19 @@ Arguments::Arguments(Arguments&& args)
    : int_values_(std::move(args.int_values_)),
      shared_int4s_data_(std::move(args.shared_int4s_data_)),
      float_values_(std::move(args.float_values_)),
      shared_float4s_data_(std::move(args.shared_float4s_data_)) {}
      shared_float4s_data_(std::move(args.shared_float4s_data_)),
      buffers_(std::move(args.buffers_)),
      images2d_(std::move(args.images2d_)),
      objects_(std::move(args.objects_)) {}
Arguments& Arguments::operator=(Arguments&& args) {
  if (this != &args) {
    int_values_ = std::move(args.int_values_);
    shared_int4s_data_ = std::move(args.shared_int4s_data_);
    float_values_ = std::move(args.float_values_);
    shared_float4s_data_ = std::move(args.shared_float4s_data_);
    buffers_ = std::move(args.buffers_);
    images2d_ = std::move(args.images2d_);
    objects_ = std::move(args.objects_);
  }
  return *this;
}
@ -55,11 +65,40 @@ void Arguments::AddFloat(const std::string& name, float value) {
void Arguments::AddInt(const std::string& name, int value) {
  int_values_[name].value = value;
}
void Arguments::AddBuffer(const std::string& name,
                          const GPUBufferDescriptor& desc) {
  buffers_[name] = desc;
}
void Arguments::AddImage2D(const std::string& name,
                           const GPUImage2DDescriptor& desc) {
  images2d_[name] = desc;
}

void Arguments::AddObject(const std::string& name, GPUObjectPtr&& object) {
  objects_[name] = {AccessType::READ, std::move(object)};
}

void Arguments::AddGPUResources(const std::string& name,
                                const GPUResources& resources) {
  for (const auto& r : resources.ints) {
    AddInt(absl::StrCat(name, "_", r));
  }
  for (const auto& r : resources.floats) {
    AddFloat(absl::StrCat(name, "_", r));
  }
  for (const auto& r : resources.buffers) {
    AddBuffer(absl::StrCat(name, "_", r.first), r.second);
  }
  for (const auto& r : resources.images2d) {
    AddImage2D(absl::StrCat(name, "_", r.first), r.second);
  }
}

absl::Status Arguments::SetInt(const std::string& name, int value) {
  auto ii = int_values_.find(name);
  if (ii == int_values_.end()) {
    return absl::NotFoundError(absl::StrCat("No argument with name - ", name));
    return absl::NotFoundError(
        absl::StrCat("No int argument with name - ", name));
  }
  ii->second.value = value;
  if (ii->second.active) {
@ -71,7 +110,8 @@ absl::Status Arguments::SetInt(const std::string& name, int value) {
absl::Status Arguments::SetFloat(const std::string& name, float value) {
  auto fi = float_values_.find(name);
  if (fi == float_values_.end()) {
    return absl::NotFoundError(absl::StrCat("No argument with name - ", name));
    return absl::NotFoundError(
        absl::StrCat("No float argument with name - ", name));
  }
  fi->second.value = value;
  if (fi->second.active) {
@ -80,8 +120,60 @@ absl::Status Arguments::SetFloat(const std::string& name, float value) {
  return absl::OkStatus();
}

absl::Status Arguments::SetImage2D(const std::string& name, cl_mem memory) {
  auto ti = images2d_.find(name);
  if (ti == images2d_.end()) {
    return absl::NotFoundError(
        absl::StrCat("No image2D argument with name - ", name));
  }
  ti->second.memory = memory;
  return absl::OkStatus();
}

absl::Status Arguments::SetBuffer(const std::string& name, cl_mem memory) {
  auto it = buffers_.find(name);
  if (it == buffers_.end()) {
    return absl::NotFoundError(
        absl::StrCat("No buffer argument with name - ", name));
  }
  it->second.memory = memory;
  return absl::OkStatus();
}

absl::Status Arguments::SetGPUResources(
    const std::string& name, const GPUResourcesWithValue& resources) {
  for (const auto& r : resources.ints) {
    RETURN_IF_ERROR(SetInt(absl::StrCat(name, "_", r.first), r.second));
  }
  for (const auto& r : resources.floats) {
    RETURN_IF_ERROR(SetFloat(absl::StrCat(name, "_", r.first), r.second));
  }
  for (const auto& r : resources.buffers) {
    RETURN_IF_ERROR(SetBuffer(absl::StrCat(name, "_", r.first), r.second));
  }
  for (const auto& r : resources.images2d) {
    RETURN_IF_ERROR(SetImage2D(absl::StrCat(name, "_", r.first), r.second));
  }
  return absl::OkStatus();
}

absl::Status Arguments::TransformToCLCode(std::string* code) {
  RETURN_IF_ERROR(AddObjectArgs());
  ResolveArgsPass(code);
  return absl::OkStatus();
}

std::string Arguments::GetListOfArgs() {
  std::string result;
  for (auto& t : buffers_) {
    const std::string type_name =
        t.second.data_type == DataType::FLOAT32 ? "float" : "half";
    absl::StrAppend(&result, ",\n __global ", type_name, t.second.element_size,
                    "* ", t.first);
  }
  for (auto& t : images2d_) {
    absl::StrAppend(&result, ",\n __read_only image2d_t ", t.first);
  }
  for (int i = 0; i < shared_int4s_data_.size() / 4; ++i) {
    absl::StrAppend(&result, ",\n int4 shared_int4_", i);
  }
@ -92,6 +184,26 @@ std::string Arguments::GetListOfArgs() {
}

absl::Status Arguments::Bind(cl_kernel kernel, int offset) {
  for (auto& t : buffers_) {
    const int error_code =
        clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
    if (error_code != CL_SUCCESS) {
      return absl::UnknownError(absl::StrCat(
          "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
          "(at index - ", offset, ")"));
    }
    offset++;
  }
  for (auto& t : images2d_) {
    const int error_code =
        clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
    if (error_code != CL_SUCCESS) {
      return absl::UnknownError(absl::StrCat(
          "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
          "(at index - ", offset, ")"));
    }
    offset++;
  }
  for (int i = 0; i < shared_int4s_data_.size() / 4; ++i) {
    const int error_code = clSetKernelArg(kernel, offset, sizeof(int32_t) * 4,
                                          &shared_int4s_data_[i * 4]);
@ -148,8 +260,8 @@ std::string Arguments::AddActiveArgument(const std::string& arg_name) {
}

void Arguments::ResolveArgsPass(std::string* code) {
  std::string result;
  constexpr char kPrefix[] = "args.";
  std::string result;
  size_t position = 0;
  size_t next_position = code->find(kPrefix);
  while (next_position != std::string::npos) {
@ -168,6 +280,16 @@ void Arguments::ResolveArgsPass(std::string* code) {
  shared_float4s_data_.resize(shared_float4s_aligned_size);
}

absl::Status Arguments::AddObjectArgs() {
  for (auto& t : objects_) {
    AddGPUResources(t.first,
                    t.second.obj_ptr->GetGPUDescriptor()->GetGPUResources());
    RETURN_IF_ERROR(
        SetGPUResources(t.first, t.second.obj_ptr->GetGPUResources()));
  }
  return absl::OkStatus();
}

} // namespace cl
} // namespace gpu
} // namespace tflite

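To show how these pieces are meant to fit together (this sketch is not in the diff; the kernel body, the "scale" argument, and BuildScaleKernelSource are invented for illustration): kernel sources reference arguments as "args.<name>", TransformToCLCode rewrites those references and registers object resources, and GetListOfArgs produces the extra parameter declarations that are spliced into a "$0" placeholder in the kernel signature.

#include <string>

#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"

namespace tflite {
namespace gpu {
namespace cl {

absl::Status BuildScaleKernelSource(Arguments* args, std::string* code) {
  args->AddFloat("scale");  // referenced from the kernel body as "args.scale"

  *code = R"(
__kernel void main_function(__global float* dst$0) {
  int gid = get_global_id(0);
  dst[gid] *= args.scale;
})";

  // Rewrites "args.scale" into a generated uniform and declares object args.
  absl::Status status = args->TransformToCLCode(code);
  if (!status.ok()) return status;
  // Splices the generated declarations into the "$0" placeholder.
  *code = absl::Substitute(*code, args->GetListOfArgs());
  return absl::OkStatus();
}

}  // namespace cl
}  // namespace gpu
}  // namespace tflite

At dispatch time the host side would call args->SetFloat("scale", value) and then args->Bind(kernel, /*offset=*/1), where the offset skips the explicitly declared dst parameter.
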
@ -20,8 +20,10 @@ limitations under the License.
#include <string>
#include <vector>

#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
#include "tensorflow/lite/delegates/gpu/cl/util.h"
#include "tensorflow/lite/delegates/gpu/common/access_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
@ -35,15 +37,21 @@ class Arguments {
  Arguments() = default;
  void AddFloat(const std::string& name, float value = 0.0f);
  void AddInt(const std::string& name, int value = 0);
  void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
  void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);

  void AddObject(const std::string& name, GPUObjectPtr&& object);

  absl::Status SetInt(const std::string& name, int value);
  absl::Status SetFloat(const std::string& name, float value);
  absl::Status SetImage2D(const std::string& name, cl_mem memory);
  absl::Status SetBuffer(const std::string& name, cl_mem memory);

  std::string GetListOfArgs();

  absl::Status Bind(cl_kernel kernel, int offset);

  void ResolveArgsPass(std::string* code);
  absl::Status TransformToCLCode(std::string* code);

  // Move only
  Arguments(Arguments&& args);
@ -53,6 +61,14 @@ class Arguments {

 private:
  std::string AddActiveArgument(const std::string& arg_name);
  void AddGPUResources(const std::string& name, const GPUResources& resources);

  absl::Status SetGPUResources(const std::string& name,
                               const GPUResourcesWithValue& resources);

  absl::Status AddObjectArgs();

  void ResolveArgsPass(std::string* code);

  struct IntValue {
    int value;
@ -79,6 +95,15 @@ class Arguments {
  };
  std::map<std::string, FloatValue> float_values_;
  std::vector<float> shared_float4s_data_;

  std::map<std::string, GPUBufferDescriptor> buffers_;
  std::map<std::string, GPUImage2DDescriptor> images2d_;

  struct ObjectArg {
    AccessType access_type;
    GPUObjectPtr obj_ptr;
  };
  std::map<std::string, ObjectArg> objects_;
};

} // namespace cl

121
tensorflow/lite/delegates/gpu/cl/gpu_object.h
Normal file
@ -0,0 +1,121 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_GPU_OBJECT_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_GPU_OBJECT_H_

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {
namespace cl {

struct GPUImage2DDescriptor {
  DataType data_type;
  cl_mem memory;
};

struct GPUBufferDescriptor {
  DataType data_type;
  int element_size;
  cl_mem memory;
};

struct GPUResources {
  std::vector<std::string> ints;
  std::vector<std::string> floats;
  std::vector<std::pair<std::string, GPUBufferDescriptor>> buffers;
  std::vector<std::pair<std::string, GPUImage2DDescriptor>> images2d;

  std::vector<std::string> GetNames() const {
    std::vector<std::string> names = ints;
    names.insert(names.end(), floats.begin(), floats.end());
    for (const auto& obj : buffers) {
      names.push_back(obj.first);
    }
    for (const auto& obj : images2d) {
      names.push_back(obj.first);
    }
    return names;
  }
};

struct GPUResourcesWithValue {
  std::vector<std::pair<std::string, int>> ints;
  std::vector<std::pair<std::string, float>> floats;
  std::vector<std::pair<std::string, cl_mem>> buffers;
  std::vector<std::pair<std::string, cl_mem>> images2d;
};

class GPUObjectDescriptor {
 public:
  GPUObjectDescriptor() = default;
  GPUObjectDescriptor(const GPUObjectDescriptor& obj_desc)
      : state_vars_(obj_desc.state_vars_) {}
  GPUObjectDescriptor& operator=(const GPUObjectDescriptor& obj_desc) {
    if (this != &obj_desc) {
      state_vars_ = obj_desc.state_vars_;
    }
    return *this;
  }
  virtual ~GPUObjectDescriptor() = default;

  void SetStateVar(const std::string& key, const std::string& value) const {
    state_vars_[key] = value;
  }

  virtual std::string PerformConstExpr(const std::string& const_expr) const {
    return "";
  }

  virtual absl::Status PerformSelector(const std::string& selector,
                                       const std::vector<std::string>& args,
                                       std::string* result) const {
    *result = "";
    return absl::OkStatus();
  }
  virtual GPUResources GetGPUResources() const { return GPUResources(); }

 protected:
  mutable std::map<std::string, std::string> state_vars_;
};

class GPUObject {
 public:
  GPUObject() = default;
  // Move only
  GPUObject(GPUObject&& obj_desc) = default;
  GPUObject& operator=(GPUObject&& obj_desc) = default;
  GPUObject(const GPUObject&) = delete;
  GPUObject& operator=(const GPUObject&) = delete;
  virtual ~GPUObject() = default;
  virtual const GPUObjectDescriptor* GetGPUDescriptor() const = 0;
  virtual GPUResourcesWithValue GetGPUResources() const = 0;
};

using GPUObjectPtr = std::unique_ptr<GPUObject>;

} // namespace cl
} // namespace gpu
} // namespace tflite

#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_GPU_OBJECT_H_
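As an illustration of the extension point this header introduces (not part of the diff; MyBuffer, MyBufferDescriptor, and the "data" resource name are invented for the example), a GPUObject exposes named resources through its descriptor at code-generation time and the concrete cl_mem handles at bind time:

#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"

namespace tflite {
namespace gpu {
namespace cl {

// Hypothetical descriptor: declares one float buffer resource named "data".
class MyBufferDescriptor : public GPUObjectDescriptor {
 public:
  GPUResources GetGPUResources() const override {
    GPUResources resources;
    GPUBufferDescriptor desc;
    desc.data_type = DataType::FLOAT32;
    desc.element_size = 4;
    desc.memory = nullptr;  // filled in from the object at bind time
    resources.buffers.push_back({"data", desc});
    return resources;
  }
};

// Hypothetical object: supplies the actual cl_mem for the declared resource.
class MyBuffer : public GPUObject {
 public:
  explicit MyBuffer(cl_mem memory) : memory_(memory) {}
  const GPUObjectDescriptor* GetGPUDescriptor() const override {
    return &descriptor_;
  }
  GPUResourcesWithValue GetGPUResources() const override {
    GPUResourcesWithValue resources;
    resources.buffers.push_back({"data", memory_});
    return resources;
  }

 private:
  MyBufferDescriptor descriptor_;
  cl_mem memory_ = nullptr;
};

}  // namespace cl
}  // namespace gpu
}  // namespace tflite

Such an object would be registered with args.AddObject("tensor", std::move(obj)); TransformToCLCode then calls AddObjectArgs, which declares "tensor_data" via AddGPUResources and attaches the cl_mem via SetGPUResources.
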
@ -115,8 +115,7 @@ std::string GetTransposeCode(
  c += PostProcess(linked_operations, context);
  c += " " + dst_tensor.WriteWHSB("result", "X", "Y", "Z", batch_id);
  c += "}\n";
  args->ResolveArgsPass(&c);
  return absl::Substitute(c, args->GetListOfArgs());
  return c;
}
} // namespace

@ -139,8 +138,10 @@ Transpose& Transpose::operator=(Transpose&& operation) {
}

absl::Status Transpose::Compile(const CreationContext& creation_context) {
  const auto code =
  std::string code =
      GetTransposeCode(definition_, attr_, linked_operations_, &args_);
  RETURN_IF_ERROR(args_.TransformToCLCode(&code));
  code = absl::Substitute(code, args_.GetListOfArgs());
  return creation_context.cache->GetOrCreateCLKernel(
      code, "main_function", *creation_context.context,
      *creation_context.device, &kernel_);

@ -402,6 +402,13 @@ class AddOperationParser : public TFLiteOperationParser {
      return absl::UnimplementedError("ADD requires two input tensors.");
    }
    // TODO(eignasheva): Add shapes check.
    for (int i = 0; i < 2; i++) {
      auto input = tflite::GetInput(context, tflite_node, i);
      if (IsConstantTensor(input) && input->dims->size > 0) {
        RETURN_IF_ERROR(CheckIfLinearConvertible(input->dims));
      }
    }

    TfLiteAddParams* tf_options = nullptr;
    return RetrieveBuiltinData(tflite_node, &tf_options);
  }
@ -2453,15 +2460,15 @@ class TransformLandmarksV2OperationParser : public TFLiteOperationParser {
    RETURN_IF_ERROR(reader->AddOutputs(node));
    std::string op_name = "transform_landmarks_v2";
    node->operation.type = op_name;
    BHWC output_shape;

    auto output_value = graph->FindOutputs(node->id)[0];
    output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
    BHWC output_shape = output_value->tensor.shape;
    RETURN_IF_ERROR(
        ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
                              tflite_node->custom_initial_data_size,
                              &(node->operation.attributes), &output_shape));

    auto output_value = graph->FindOutputs(node->id)[0];

    output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
    return absl::OkStatus();
  }
