Merge branch 'master' of github.com:tensorflow/tensorflow

* 'master' of github.com:tensorflow/tensorflow:
  Delete duplicate installation of patchelf.
  Emit error messages for all missing legalizations in TF to XLA full legalization pass.
  Fix control dependency issue causing shape_assert test to fail.
  Go: Update generated wrapper functions for TensorFlow ops.
  Enable XRT cache to be shared among multiple GPU devices. Allow XRT GPU to work with multi-thread-based replication, where a single process sees all the available devices.
  Export the sole function of a saved model only when it matches the exported_names argument.
  Add named size and count methods for arg, result, and var buffers to AOT models.
  [XLA] Add a memory space propagation pass.
  Enable Squeeze Op conversion without squeeze_dim attribute in explicit batch mode.
  Check return status of reading environment variable
  Fix misspelling
  Fix build break
  Fix typo
  Support an option (environment variable) to enable gRPC port reuse.
Andrew Cavanaugh 2020-05-07 14:14:42 -04:00
commit 46ab1e9f26
41 changed files with 754 additions and 151 deletions

View File

@ -131,6 +131,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
std::vector<string> dim_vars;
string dim_sizes, indices;
int count = 1;
if (shape.rank() == 0 ||
(shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
dim_sizes = "[1]";
@ -140,6 +141,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
dim_vars.push_back(absl::StrCat("size_t dim", dim));
dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
indices += absl::StrCat("[dim", dim, "]");
count *= shape.dimensions(dim);
}
}
rewrites->push_back({"{{I}}", absl::StrCat(i)});
@ -147,6 +149,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
rewrites->push_back({"{{INDICES}}", indices});
rewrites->push_back({"{{COUNT}}", absl::StrCat(count)});
return Status::OK();
}
@ -199,6 +202,12 @@ Status GenArgMethods(const tf2xla::Config& config,
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
arg_data({{I}}))){{INDICES}};
}
int arg{{NAME}}_size() const {
return {{COUNT}} * sizeof({{TYPE}});
}
int arg{{NAME}}_count() const {
return {{COUNT}};
}
)";
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
if (!config.feed(i).name().empty()) {
@ -246,6 +255,12 @@ Status GenResultMethods(const tf2xla::Config& config,
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
result_data({{I}}))){{INDICES}};
}
int result{{NAME}}_size() const {
return {{COUNT}} * sizeof({{TYPE}});
}
int result{{NAME}}_count() const {
return {{COUNT}};
}
)";
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
if (!config.fetch(i).name().empty()) {
@ -281,6 +296,12 @@ Status GenVariableMethods(const tf2xla::Config& config,
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
arg_data({{I}}))){{INDICES}};
}
int var_{{NAME}}_size() const {
return {{COUNT}} * sizeof({{TYPE}});
}
int var_{{NAME}}_count() const {
return {{COUNT}};
}
)";
const tf2xla::Variable& var = config.variable(i - config.feed_size());
rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");

View File

@ -138,6 +138,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const float(*)[1][2]>(
arg_data(0)))[dim0][dim1];
}
int arg0_size() const {
return 2 * sizeof(float);
}
int arg0_count() const {
return 2;
}
void set_arg_myfeed_data(const void* data) {
set_arg_data(0, data);
@ -156,6 +162,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const float(*)[1][2]>(
arg_data(0)))[dim0][dim1];
}
int arg_myfeed_size() const {
return 2 * sizeof(float);
}
int arg_myfeed_count() const {
return 2;
}
void set_arg1_data(const void* data) {
set_arg_data(1, data);
@ -174,6 +186,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const tensorflow::int64(*)[3][4]>(
arg_data(1)))[dim0][dim1];
}
int arg1_size() const {
return 12 * sizeof(tensorflow::int64);
}
int arg1_count() const {
return 12;
}
// Result methods for managing output buffers. Buffers are in row-major order.
// Must only be called after a successful Run call. There is a set of methods
@ -204,6 +222,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
result_data(0)))[dim0][dim1];
}
int result0_size() const {
return 30 * sizeof(tensorflow::uint32);
}
int result0_count() const {
return 30;
}
tensorflow::uint32* result_myfetch_data() {
return static_cast<tensorflow::uint32*>(result_data(0));
@ -219,6 +243,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
result_data(0)))[dim0][dim1];
}
int result_myfetch_size() const {
return 30 * sizeof(tensorflow::uint32);
}
int result_myfetch_count() const {
return 30;
}
// Methods for managing variable buffers. Buffers are in row-major order.
//
@ -261,6 +291,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const float(*)[1]>(
arg_data(2)))[0];
}
int var_myvar_readonly_size() const {
return 1 * sizeof(float);
}
int var_myvar_readonly_count() const {
return 1;
}
void set_var_myvar_data(float* data) {
set_arg_data(3, data);
@ -279,6 +315,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const float(*)[1]>(
arg_data(3)))[0];
}
int var_myvar_size() const {
return 1 * sizeof(float);
}
int var_myvar_count() const {
return 1;
}
void set_var_myvar2_data(tensorflow::int32* data) {
set_arg_data(4, data);
@ -297,6 +339,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const tensorflow::int32(*)[5]>(
arg_data(4)))[dim0];
}
int var_myvar2_size() const {
return 5 * sizeof(tensorflow::int32);
}
int var_myvar2_count() const {
return 5;
}
private:
// Number of buffers for the compiled computation.
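A hedged usage sketch for the new accessors, based on the example header above; MyClass, myfeed, and myfetch come from that example, the typed data accessors follow the usual tfcompile codegen, and the surrounding harness is assumed:

#include <cstring>
#include <vector>

// Minimal sketch: size buffers from the generated *_count()/*_size()
// accessors instead of hard-coding the [1][2] feed and [5][6] fetch shapes.
bool RunOnce(MyClass& computation, const std::vector<float>& input) {
  if (input.size() != static_cast<size_t>(computation.arg_myfeed_count()))
    return false;  // wrong element count for the feed
  std::memcpy(computation.arg_myfeed_data(), input.data(),
              computation.arg_myfeed_size());
  if (!computation.Run()) return false;
  std::vector<tensorflow::uint32> out(computation.result_myfetch_count());
  std::memcpy(out.data(), computation.result_myfetch_data(),
              computation.result_myfetch_size());
  return true;
}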

View File

@ -146,6 +146,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
saved_model_exported_names.begin(), saved_model_exported_names.end());
absl::Span<std::string> exported_names(exported_names_in_vector);
if (exported_names.size() != 1) {
return errors::Unimplemented("Only support a single exported name.");
}
TF_ASSIGN_OR_RETURN(auto module,
ImportSavedModel(model_flags.saved_model_dir(),
model_flags.saved_model_version(), tags,

View File

@ -160,6 +160,11 @@ int main(int argc, char **argv) {
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
absl::Span<std::string> exported_names(exported_names_vector);
if (exported_names.size() != 1) {
llvm::errs() << "There should be only one exported name";
return kTrFailure;
}
module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
tags, exported_names, &context);
} else {

View File

@ -174,7 +174,7 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
return module;
} else if (saved_model_version == 1) {
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
input_filename, tags, context);
input_filename, tags, exported_names, context);
if (!module)
return tensorflow::errors::InvalidArgument("fail to open input file");

View File

@ -112,7 +112,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
// Convert the SavedModelBundle to an MLIR module.
mlir::MLIRContext context;
auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
auto module_or = ConvertSavedModelV1ToMlir(bundle, {}, &context);
if (!module_or.status().ok()) {
Set_TF_Status_from_Status(status, module_or.status());
return "// error";

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
@ -57,6 +58,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Verifier.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
@ -65,6 +67,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
@ -2428,8 +2431,8 @@ class SavedModelObjectGraphImporter : public ImporterBase {
// Main entry point: converts all functions in the given meta graph to an MLIR
// Module.
static StatusOr<mlir::OwningModuleRef> Convert(
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
absl::Span<std::string> exported_names, bool add_default_attributes);
SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
mlir::MLIRContext* context, bool add_default_attributes);
private:
explicit SavedModelObjectGraphImporter(
@ -3129,8 +3132,8 @@ Status CreateSavedModelIR(
}
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
absl::Span<std::string> exported_names, bool add_default_attributes) {
SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
mlir::MLIRContext* context, bool add_default_attributes) {
GraphDebugInfo dummy_debug_info;
const GraphDebugInfo& debug_info =
saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
@ -3207,17 +3210,20 @@ class SavedModelSignatureDefImporter {
public:
// Main entry point: converts all functions (specified by SignatureDefs) in
// the given meta graph to an MLIR Module.
static StatusOr<mlir::OwningModuleRef> Convert(const SavedModelBundle& bundle,
mlir::MLIRContext* context) {
SavedModelSignatureDefImporter importer(bundle, context);
static StatusOr<mlir::OwningModuleRef> Convert(
const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
mlir::MLIRContext* context) {
SavedModelSignatureDefImporter importer(bundle, exported_names, context);
return importer.ConvertSignatures();
}
private:
SavedModelSignatureDefImporter(const SavedModelBundle& bundle,
absl::Span<std::string> exported_names,
mlir::MLIRContext* context)
: bundle_(bundle),
exported_names_(exported_names),
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {}
// Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
@ -3250,6 +3256,7 @@ class SavedModelSignatureDefImporter {
const std::vector<std::pair<std::string, TensorInfo>>& inputs);
const SavedModelBundle& bundle_;
absl::Span<std::string> exported_names_;
mlir::OwningModuleRef module_;
};
@ -3265,6 +3272,9 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
GraphDebugInfo debug_info;
if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info;
llvm::StringSet<> exported_name_set;
exported_name_set.insert(exported_names_.begin(), exported_names_.end());
for (const auto& key_and_signature_def : signatures) {
const std::string& sig_def_key = key_and_signature_def.first;
const SignatureDef& signature_def = key_and_signature_def.second;
@ -3274,6 +3284,10 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
if (sig_def_key == "__saved_model_init_op") {
continue;
}
if (!exported_name_set.empty() &&
exported_name_set.count(sig_def_key) == 0) {
continue;
}
TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def,
debug_info, flib_def));
@ -3556,12 +3570,14 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
absl::Span<std::string> exported_names, bool add_default_attributes) {
return SavedModelObjectGraphImporter::Convert(
saved_model, context, exported_names, add_default_attributes);
saved_model, exported_names, context, add_default_attributes);
}
StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
const SavedModelBundle& saved_model, mlir::MLIRContext* context) {
return SavedModelSignatureDefImporter::Convert(saved_model, context);
const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
mlir::MLIRContext* context) {
return SavedModelSignatureDefImporter::Convert(saved_model, exported_names,
context);
}
std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {

View File

@ -55,6 +55,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
// expressed with tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningModuleRef>
ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
absl::Span<std::string> exported_names,
mlir::MLIRContext* context);
// Serialize a MLIR module to a string.

View File

@ -141,7 +141,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context) {
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
tensorflow::SavedModelBundle bundle;
tensorflow::SessionOptions session_options;
// Force saved model states to be restored to CPU.
@ -155,7 +156,7 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
return nullptr;
}
auto module_or = ConvertSavedModelV1ToMlir(bundle, context);
auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context);
if (!module_or.status().ok()) {
LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
return nullptr;

View File

@ -64,7 +64,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
// given MLIR `context`.
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context);
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
} // namespace tensorflow
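A hedged caller-side sketch of the updated import API; the path and signature name are illustrative:

// Sketch: import only the "serving_default" SignatureDef from a v1
// SavedModel via the new exported_names parameter. An empty span keeps
// the old behavior of importing every signature.
std::vector<std::string> names = {"serving_default"};
mlir::MLIRContext context;
mlir::OwningModuleRef module =
    tensorflow::SavedModelSignatureDefsToMlirImport(
        "/tmp/my_saved_model", /*tags=*/{"serve"}, absl::MakeSpan(names),
        &context);
if (!module) {
  // Import failed; the error was already logged.
}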

View File

@ -104,26 +104,24 @@ int main(int argc, char** argv) {
return 1;
}
std::unordered_set<std::string> tags = absl::StrSplit(saved_model_tags, ',');
std::vector<std::string> exported_names_vector =
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
absl::Span<std::string> exported_names(exported_names_vector);
if (import_saved_model_object_graph) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
std::vector<std::string> exported_names =
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
mlir::MLIRContext context;
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
input_filename, tags, absl::Span<std::string>(exported_names),
&context);
input_filename, tags, exported_names, &context);
if (!module) return 1;
module->print(output->os());
} else if (import_saved_model_signature_defs) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
mlir::MLIRContext context;
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
input_filename, tags, &context);
input_filename, tags, exported_names, &context);
if (!module) return 1;
module->print(output->os());

View File

@ -1,22 +1,24 @@
// RUN: tf-opt %s -xla-legalize-tf -split-input-file -verify-diagnostics
// expected-error@below{{The following operations cannot be legalized: tf.NoOp (count: 1); tf_executor.fetch (count: 1); tf_executor.graph (count: 1); tf_executor.island (count: 1); tf_executor.yield (count: 1). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}}
// expected-error@below{{Emitting more detail about one op that failed to legalize...}}
func @tf_executor_graph_op() {
// expected-error@+1 {{failed to legalize operation 'tf_executor.graph'}}
tf_executor.graph {
%0 = tf_executor.island {
// expected-error@+1 {{'tf.NoOp' op is not legalizable}}
"tf.NoOp"() {} : () -> ()
tf_executor.yield
}
tf_executor.fetch
}
return
}
// -----
// expected-error@below{{The following operations cannot be legalized: tf.OpA (count: 1). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}}
func @tf_unknown_op(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// expected-error@+1 {{failed to legalize operation 'tf.OpA'}}
// expected-error@+1 {{'tf.OpA' op is not legalizable}}
%0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}
@ -27,3 +29,16 @@ func @tf_known_op(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}
// -----
// expected-error@below{{The following operations cannot be legalized: tf.OpA (count: 1); tf.OpB (count: 2). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}}
// expected-error@below{{Emitting more detail about one op that failed to legalize...}}
func @tf_unknown_known_mix(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// expected-error@+1 {{'tf.OpA' op is not legalizable}}
%0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%1 = "tf.OpB"(%0, %0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%2 = "tf.Add"(%1, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%3 = "tf.OpB"(%2, %2) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %2: tensor<2xi32>
}

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Dialect/Traits.h" // from @llvm-project
@ -4785,6 +4786,51 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
}
};
// Emits debug information which includes the number of ops of each type which
// failed to legalize.
void EmitLegalizationErrors(Operation *op,
const DenseSet<Operation *> &nonlegalized_ops) {
// Track the legalization failures by mapping op name to information about
// that failure: the number of unlegalized occurrences of the op, and one
// example operation that failed.
std::map<StringRef, std::pair<int, Operation *>> op_name_to_error_info;
for (Operation *nonlegalized_op : nonlegalized_ops) {
// Increment count of this legalization failure.
StringRef op_name = nonlegalized_op->getName().getStringRef();
// If this emplace is successful, it's the first time we've encountered
// this op type. Initialize count to 0 so that after increment, it is 1.
auto insertion_result = op_name_to_error_info.emplace(
op_name, std::make_pair(0, nonlegalized_op));
++insertion_result.first->second.first;
}
std::vector<std::string> error_messages;
error_messages.reserve(op_name_to_error_info.size());
for (const auto &op_info : op_name_to_error_info) {
error_messages.push_back(
llvm::formatv("{0} (count: {1})", op_info.first, op_info.second.first));
}
Location loc = op->getLoc();
emitError(loc) << "The following operations cannot be legalized: "
<< llvm::join(error_messages, "; ")
<< ". These legalization failure(s) may be due to missing TF "
"to HLO lowerings and/or unsupported attributes, etc.";
// Emit more information about the missing ops. This error message
// contains useful details beyond the op name (input and output shapes,
// attributes, etc.).
if (!VLOG_IS_ON(1) && nonlegalized_ops.size() != 1) {
emitError(loc)
<< "Emitting more detail about one op that failed to legalize...";
} else if (VLOG_IS_ON(1)) {
emitError(loc) << "Emitting more detail about one of each type of op "
"that failed to legalize...";
}
for (const auto &op_info : op_name_to_error_info) {
op_info.second.second->emitOpError() << "is not legalizable";
if (!VLOG_IS_ON(1)) break;
}
}
// Performs the lowering to XLA dialect.
void LegalizeTF::runOnFunction() {
if (failed(legalizeTF(getFunction(), allow_partial_conversion_)))
@ -4841,7 +4887,16 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
if (!allow_partial_conversion) {
// Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp.
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ::mlir::ReturnOp>();
return applyFullConversion(op, target, patterns);
DenseSet<Operation *> nonlegalized_ops;
LogicalResult result = applyPartialConversion(
op, target, patterns, /*converter=*/nullptr, &nonlegalized_ops);
// In order to enforce a full conversion, fail if there are any
// nonlegalized ops in the set.
if (failed(result) || !nonlegalized_ops.empty()) {
EmitLegalizationErrors(op, nonlegalized_ops);
return failure();
}
return result;
}
return applyPartialConversion(op, target, patterns);

View File

@ -2413,26 +2413,19 @@ Status ConvertExpandDims(OpConverterParams* params) {
}
Status Converter::SqueezeTensor(nvinfer1::ITensor* input,
const std::vector<int>& trt_axes,
std::vector<int>* input_dims,
nvinfer1::ITensor** output) {
const nvinfer1::Dims dims = input->getDimensions();
std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
// Mark axes to remove by setting them to 0.
for (int axis : trt_axes) {
input_dims[axis] = 0;
}
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
// If the remaining dimensions of a squeeze operation have dynamic sizes, we
// need to use TRT ops to build the result shape for the squeeze operation.
// This is because IShuffleLayer::setReshapeDimensions treats -1 as a special
// value.
if (absl::c_any_of(input_dims, [](int i) { return i == -1; })) {
if (absl::c_any_of(*input_dims, [](int i) { return i == -1; })) {
nvinfer1::ITensor* shape = network()->addShape(*input)->getOutput(0);
std::vector<nvinfer1::ITensor const*> concat_inputs;
for (int i = 0; i < input_dims.size(); i++) {
for (int i = 0; i < input_dims->size(); i++) {
// If input dim wasn't set to 0 earlier, we include it in new shape.
if (input_dims[i] != 0) {
if (input_dims->at(i) != 0) {
concat_inputs.push_back(
network()
->addSlice(*shape, {1, {i}}, {1, {1}}, {1, {1}})
@ -2452,11 +2445,12 @@ Status Converter::SqueezeTensor(nvinfer1::ITensor* input,
}
#endif
// Remove all dims which are equal to 0.
input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 0),
input_dims.end());
input_dims->erase(std::remove(input_dims->begin(), input_dims->end(), 0),
input_dims->end());
// Reshape tensor.
nvinfer1::Dims new_dims;
TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims));
VLOG(2) << "input_dims: " << absl::StrJoin(*input_dims, ",");
TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(*input_dims, &new_dims));
TF_RETURN_IF_ERROR(PrepareTensorForShape(TRT_TensorOrWeights(input), new_dims,
/*validation_only=*/false, output));
return Status::OK();
@ -2475,31 +2469,48 @@ Status ConvertSqueeze(OpConverterParams* params) {
TFAttrs attrs(node_def);
auto squeeze_dims = attrs.get<std::vector<int64>>("squeeze_dims");
if (squeeze_dims.empty()) {
return errors::Unimplemented(
"Squeeze is only implemented for explicit dims, at ", node_def.name());
}
std::vector<int> trt_axes;
trt_axes.reserve(squeeze_dims.size());
for (int tf_axis : squeeze_dims) {
// If the axis is valid, then convert it to TRT axis, otherwise abort
// conversion.
int trt_axis;
TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
params->use_implicit_batch, &trt_axis));
// Make sure target dimension is size 1 or unknown size (-1)
if (input_dims[trt_axis] != -1 && input_dims[trt_axis] != 1) {
return errors::InvalidArgument(
"Dimension ", tf_axis, " with size ", input_dims[trt_axis],
" cannot be squeezed because it must be size 1, at ",
if (params->use_implicit_batch || !HasStaticShape(dims)) {
return errors::Unimplemented(
"Squeeze is not implemented for empty squeeze_dims, at ",
node_def.name());
} else {
// In explicit batch mode with a static input shape, we squeeze all
// singleton dimensions.
for (int& dim : input_dims) {
if (dim == 1) {
// Mark it for removal by setting it to 0
dim = 0;
}
}
}
} else {
std::vector<int> trt_axes;
trt_axes.reserve(squeeze_dims.size());
for (int tf_axis : squeeze_dims) {
// If the axis is valid, then convert it to TRT axis, otherwise abort
// conversion.
int trt_axis;
TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
params->use_implicit_batch, &trt_axis));
// Make sure target dimension is size 1 or unknown size (-1)
if (input_dims[trt_axis] != -1 && input_dims[trt_axis] != 1) {
return errors::InvalidArgument(
"Dimension ", tf_axis, " with size ", input_dims[trt_axis],
" cannot be squeezed because it must be size 1, at ",
node_def.name());
}
trt_axes.push_back(trt_axis);
}
// Mark axes to remove by setting them to 0.
for (int axis : trt_axes) {
input_dims[axis] = 0;
}
trt_axes.push_back(trt_axis);
}
if (params->validation_only) return Status::OK();
nvinfer1::ITensor* output_tensor = nullptr;
TF_RETURN_IF_ERROR(params->converter->SqueezeTensor(
input_tensor.tensor(), trt_axes, &output_tensor));
input_tensor.tensor(), &input_dims, &output_tensor));
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return Status::OK();
}

View File

@ -529,11 +529,9 @@ class Converter {
// Helper function to add a squeeze op to the network.
//
// The trt_axes argument lists those axes that need to be squeezed. Each axis
// in the list is numbered according to TRT convention (see ConvertAxis for
// details).
Status SqueezeTensor(nvinfer1::ITensor* input,
const std::vector<int>& trt_axes,
// The input_dims argument stores the TRT dimensions of the input tensor,
// where the dimensions to be squeezed are replaced by 0.
Status SqueezeTensor(nvinfer1::ITensor* input, std::vector<int>* input_dims,
nvinfer1::ITensor** output);
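A minimal caller-side sketch of the new contract (the converter and input_tensor variables are assumed to exist):

// Sketch: squeeze axis 1 of a tensor whose TRT dims are {2, 1, 3}.
// Callers now pre-mark squeezable axes with 0 in input_dims instead of
// passing a separate trt_axes list.
std::vector<int> input_dims = {2, 1, 3};
input_dims[1] = 0;  // mark the axis to squeeze
nvinfer1::ITensor* output = nullptr;
TF_RETURN_IF_ERROR(
    converter->SqueezeTensor(input_tensor, &input_dims, &output));
// On return, input_dims has been compacted to {2, 3}.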
// Creates an IConstantLayer using 'weights' whose dimensions are specified by

View File

@ -3129,11 +3129,13 @@ TEST_P(ParameterizedOpConverterTest, ConvertSqueeze) {
TestParamBase{
{1, 2, 1, 3}, // input dims
{}, // input partial dims
{2, 1, 3}, // expected output dims
{2, 3}, // expected output dims
{}, // axis
Status{
error::UNIMPLEMENTED,
"Squeeze is only implemented for explicit dims, at my_squeeze"}},
trt_mode == TrtTestMode::kExplicitBatch
? Status::OK()
: Status{error::UNIMPLEMENTED,
"Squeeze is not implemented for empty squeeze_dims, at "
"my_squeeze"}},
TestParamBase{{1, 2, 1, 3},
{},
{2, 1, 3},

View File

@ -50,6 +50,7 @@ class RunId {
public:
// Creates a new, unique RunId.
RunId();
explicit RunId(int64 value) : data_(value) {}
RunId(const RunId&) = default;
RunId& operator=(const RunId&) = default;

View File

@ -3234,6 +3234,29 @@ tf_cc_test(
],
)
cc_library(
name = "memory_space_propagation",
srcs = ["memory_space_propagation.cc"],
hdrs = ["memory_space_propagation.h"],
deps = [
":hlo",
":hlo_dataflow_analysis",
":hlo_pass",
],
)
tf_cc_test(
name = "memory_space_propagation_test",
srcs = ["memory_space_propagation_test.cc"],
deps = [
":hlo_parser",
":memory_space_propagation",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
cc_library(
name = "hlo_dce",
srcs = ["hlo_dce.cc"],

View File

@ -0,0 +1,67 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/memory_space_propagation.h"
namespace xla {
StatusOr<bool> MemorySpacePropagation::Run(HloModule* module) {
bool modified = false;
TF_ASSIGN_OR_RETURN(auto dataflow_analysis,
HloDataflowAnalysis::Run(*module));
dataflow_analysis_ = std::move(dataflow_analysis);
for (HloComputation* computation : module->MakeNonfusionComputations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kFusion) {
// Propagate the operand subshapes.
for (int operand_idx = 0; operand_idx < instruction->operand_count();
++operand_idx) {
modified |=
PropagateSubshapes(instruction->operand(operand_idx)->shape(),
instruction->fused_parameter(operand_idx));
}
// Propagate output subshapes.
modified |= PropagateSubshapes(instruction->shape(),
instruction->fused_expression_root());
}
}
}
return modified;
}
bool MemorySpacePropagation::PropagateSubshapes(
const Shape& caller_shape, const HloInstruction* callee_instruction) const {
bool modified = false;
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(caller_shape)) {
int64 memory_space = indexed_shape.shape.layout().memory_space();
const HloValue& value = dataflow_analysis_->GetUniqueValueAt(
callee_instruction, indexed_shape.index);
for (const HloPosition& position : value.positions()) {
Shape* shape = ShapeUtil::GetMutableSubshape(
position.instruction->mutable_shape(), position.index);
if (shape->layout().memory_space() != memory_space) {
shape->mutable_layout()->set_memory_space(memory_space);
modified = true;
}
}
}
return modified;
}
} // namespace xla

View File

@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// This is a legalization pass that propagates the memory space in the layout to
// the fusion computations.
class MemorySpacePropagation : public HloModulePass {
public:
~MemorySpacePropagation() override = default;
absl::string_view name() const override { return "memory-space-propagation"; }
StatusOr<bool> Run(HloModule* module) override;
private:
// Given the caller shape (operand or output) and its corresponding
// instruction in the fused computation (parameter or root), propagates the
// memory space to all the subshapes on the callee side. Returns true if the
// module is modified.
bool PropagateSubshapes(const Shape& caller_shape,
const HloInstruction* callee_instruction) const;
std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_
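A minimal sketch of invoking the pass directly, mirroring the test usage below; module is assumed to be an HloModule* whose layouts already carry memory space annotations:

// Sketch: run the pass and check whether any fusion subshape changed.
xla::MemorySpacePropagation pass;
xla::StatusOr<bool> changed = pass.Run(module);
if (changed.ok() && changed.ValueOrDie()) {
  // Memory spaces from the caller-side layouts were propagated into the
  // fusion computations' parameter and root subshapes.
}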

View File

@ -0,0 +1,203 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/memory_space_propagation.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
class MemorySpacePropagationTest : public HloTestBase {
public:
MemorySpacePropagationTest()
: HloTestBase(),
verifier_(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) {
}
Status Verify(HloModule* module) { return verifier_.Run(module).status(); }
private:
HloVerifier verifier_;
};
TEST_F(MemorySpacePropagationTest, NoMemorySpace) {
absl::string_view hlo_string = R"(
HloModule NoMemorySpace
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)} parameter(0)
ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
}
ENTRY %entry {
%param0 = s32[6]{0:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[6]{0:T(128)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)} copy(%param2)
%fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[6]{0:T(128)} copy(%fusion)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
MemorySpacePropagation memory_space_propagation;
EXPECT_FALSE(memory_space_propagation.Run(module.get()).ValueOrDie());
TF_ASSERT_OK_AND_ASSIGN(auto ref, ParseAndReturnVerifiedModule(hlo_string));
EXPECT_EQ(module->Hash(), ref->Hash());
}
TEST_F(MemorySpacePropagationTest, NonTupleOutput) {
absl::string_view hlo_string = R"(
HloModule NonTupleOutput
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)} parameter(0)
ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
}
ENTRY %entry {
%param0 = s32[6]{0:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = s32[6]{0:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[6]{0:T(128)} copy(%fusion)
}
)";
absl::string_view expected_hlo_string = R"(
HloModule NonTupleOutput
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)S(1)} parameter(0)
ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
}
ENTRY %entry {
%param0 = s32[6]{0:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = s32[6]{0:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[6]{0:T(128)} copy(%fusion)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
MemorySpacePropagation memory_space_propagation;
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
TF_EXPECT_OK(Verify(module.get()));
TF_ASSERT_OK_AND_ASSIGN(auto ref,
ParseAndReturnVerifiedModule(expected_hlo_string));
EXPECT_EQ(module->Hash(), ref->Hash());
}
TEST_F(MemorySpacePropagationTest, TupleOutput) {
absl::string_view hlo_string = R"(
HloModule TupleOutput
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)} parameter(0)
%add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
%multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
ROOT %tuple = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) tuple(%add.0, %multiply.0)
}
ENTRY %entry {
%param0 = s32[6]{0:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
%gte0 = s32[6]{0:T(128)S(1)} get-tuple-element(%fusion), index=0
%gte1 = s32[6]{0:T(128)} get-tuple-element(%fusion), index=1
ROOT %root = s32[6]{0:T(128)} add(%gte0, %gte1)
}
)";
absl::string_view expected_hlo_string = R"(
HloModule TupleOutput
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)S(1)} parameter(0)
%add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
%multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
ROOT %tuple = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) tuple(%add.0, %multiply.0)
}
ENTRY %entry {
%param0 = s32[6]{0:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
%gte0 = s32[6]{0:T(128)S(1)} get-tuple-element(%fusion), index=0
%gte1 = s32[6]{0:T(128)} get-tuple-element(%fusion), index=1
ROOT %root = s32[6]{0:T(128)} add(%gte0, %gte1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
MemorySpacePropagation memory_space_propagation;
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
TF_EXPECT_OK(Verify(module.get()));
TF_ASSERT_OK_AND_ASSIGN(auto ref,
ParseAndReturnVerifiedModule(expected_hlo_string));
EXPECT_EQ(module->Hash(), ref->Hash());
}
} // namespace
} // namespace xla

View File

@ -158,7 +158,7 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx,
argument_layout_ptrs[i] = &argument_layouts[i];
}
xla::ExecutableBuildOptions build_options;
build_options.set_device_ordinal(client->default_device_ordinal());
build_options.set_device_ordinal(device_ref.device_ordinal());
build_options.set_num_replicas(num_replicas);
build_options.set_result_layout(xla::Shape(config.program_shape().result()));
build_options.set_device_allocator(device_ref.backend()->memory_allocator());
@ -206,7 +206,8 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key));
// Process-wide cache of XLA executables.
auto cache_or = GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0);
auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
ctx, /*max_number_of_entries=*/0);
OP_REQUIRES_OK(ctx, cache_or.status());
auto cache = cache_or.ConsumeValueOrDie();
@ -259,15 +260,11 @@ void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "XRTReleaseCompilationRefOp::Compute";
auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseCompilationCell());
ResourceMgr* rm;
OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm));
// Process-wide cache of XLA executables.
XRTCompilationCache* cache;
OP_REQUIRES_OK(ctx, rm->Lookup<XRTCompilationCache>(
rm->default_container(),
kXRTCompilationCacheResourceName, &cache));
core::ScopedUnref cache_unref(cache);
auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
ctx, /*max_number_of_entries=*/0);
OP_REQUIRES_OK(ctx, cache_or.status());
auto cache = cache_or.ConsumeValueOrDie();
const Tensor& keys_tensor = ctx->input(0);
auto flat_keys = keys_tensor.flat<int64>();

View File

@ -149,13 +149,17 @@ xla::StatusOr<InputBuffers> GetChainedOpInputs(
xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref,
xla::LocalExecutable* executable, const InputBuffers& input_buffers,
se::Stream* stream, int rng_seed, int replica_id) {
se::Stream* stream, int rng_seed,
const xrt::CommonExecutionConfig& config) {
VLOG(2) << "Executing computation.";
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(device_ref->backend()->memory_allocator());
run_options.set_intra_op_thread_pool(&context->eigen_cpu_device());
run_options.set_rng_seed(rng_seed);
if (config.run_id() != 0) {
run_options.set_run_id(xla::RunId(config.run_id()));
}
if (executable->executable()
->module_config()
.has_static_device_assignment()) {
@ -164,8 +168,11 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
}
xla::GpuExecutableRunOptions gpu_options;
std::vector<xla::GlobalDeviceId> gpu_global_ids;
if (replica_id >= 0) {
gpu_global_ids.emplace_back(replica_id);
if (config.local_replica_mapping_size() > 0) {
gpu_global_ids.reserve(config.local_replica_mapping_size());
for (auto& gid : config.local_replica_mapping()) {
gpu_global_ids.emplace_back(xla::GlobalDeviceId(gid));
}
gpu_options.set_gpu_global_device_ids(gpu_global_ids);
}
std::shared_ptr<NcclUniqueIdFactory> nccl_factory = GetNcclUniqueIdFactory();
@ -222,10 +229,11 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
OpKernelContext* context, XRTMemoryManager* memory_manager,
XRTGenericDeviceAccessor::ScopedRef* device_ref,
xla::LocalExecutable* executable, const InputBuffers& input_buffers,
se::Stream* stream, int rng_seed, int replica_id) {
se::Stream* stream, int rng_seed,
const xrt::CommonExecutionConfig& config) {
auto runfn = [&]() {
return RunExecutable(context, device_ref, executable, input_buffers, stream,
rng_seed, replica_id);
rng_seed, config);
};
// We pass zero as requested_free_size as there is no simple way to get the
@ -241,14 +249,15 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
XRTGenericDeviceAccessor::ScopedRef* device_ref,
xla::LocalExecutable* executable,
const std::vector<InputCoords>& input_coords, bool release_inputs,
se::Stream* stream, int rng_seed, int replica_id) {
se::Stream* stream, int rng_seed,
const xrt::CommonExecutionConfig& config) {
XRTMemoryManager::WorkingSet working_set(memory_manager);
TF_ASSIGN_OR_RETURN(InputBuffers input_buffers,
GetInputBuffers(&working_set, device_ref->backend(),
input_coords, release_inputs));
return ExecuteComputation(context, memory_manager.get(), device_ref,
executable, input_buffers, stream, rng_seed,
replica_id);
config);
}
// XRTExecuteOp
@ -297,8 +306,9 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) {
bool release_inputs = config_proto.release_input_handles();
bool release_compilation = config_proto.release_compilation_handle();
TF_ASSIGN_OR_RETURN(
auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0));
TF_ASSIGN_OR_RETURN(auto cache,
XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
context, /*max_number_of_entries=*/0));
// We are guaranteed that the underlying device object won't be deleted out
// from under us, while the ScopedRef is live.
class XRTGenericDeviceAccessor::ScopedRef device_ref;
@ -330,7 +340,7 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) {
RefPtr<XRTTupleAllocation> output_tuple,
ExecuteComputation(context, memory_manager, &device_ref, executable,
input_coords, release_inputs, stream, rng_seed,
config_proto.replica_id()));
config_proto.common_config()));
return CreateExecuteOutput(context, memory_manager.get(),
std::move(output_tuple),
@ -379,8 +389,9 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
xrt::XRTChainedExecuteConfig config;
TF_RET_CHECK(ParseFromTString(execution_config.scalar<tstring>()(), &config));
TF_ASSIGN_OR_RETURN(
auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0));
TF_ASSIGN_OR_RETURN(auto cache,
XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
context, /*max_number_of_entries=*/0));
// We are guaranteed that the underlying device object won't be deleted out
// from under us, while the ScopedRef is live.
class XRTGenericDeviceAccessor::ScopedRef device_ref;
@ -408,7 +419,7 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
return ExecuteComputation(context, memory_manager.get(), &device_ref,
executable, input_buffers, stream, rng_seed,
config.replica_id());
config.common_config());
};
return ExecuteChained(context, memory_manager, device_ref.backend(),

View File

@ -111,6 +111,17 @@ message XLATupleNode {
repeated XLATupleNode tuples = 3;
}
message CommonExecutionConfig {
// The replica index this execute operation is driving.
int32 replica_id = 1;
// Mapping local device ordinals to global replica IDs.
// local_replica_mapping[LOCAL_DEVICE_ORDINAL] = GLOBAL_REPLICA_ID
repeated int32 local_replica_mapping = 2;
// The execution run ID used to correlate different XRT execute operations
// happening in parallel from different threads.
int64 run_id = 3;
}
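A hedged sketch of populating the new message through the standard generated protobuf setters; the values are illustrative:

// Sketch: fill the shared config rather than a per-message replica_id
// (the old fields are reserved below).
xrt::XRTExecutionConfig config_proto;
xrt::CommonExecutionConfig* common = config_proto.mutable_common_config();
common->set_replica_id(0);             // replica this execute drives
common->add_local_replica_mapping(0);  // device ordinal 0 -> replica 0
common->add_local_replica_mapping(1);  // device ordinal 1 -> replica 1
common->set_run_id(42);                // correlates parallel executes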
// Options for an XLA execution.
message XRTExecutionConfig {
// Local device to run on. This is present because the execute Op
@ -133,8 +144,9 @@ message XRTExecutionConfig {
// a single tuple allocation the execution will return a vector of
// allocations, one for each of the first-level elements of the result tuple.
bool return_exploded_tuple = 7;
// The replica index this execute is driving.
int32 replica_id = 8;
reserved 8;
// The common configuration for XRT execute operations.
CommonExecutionConfig common_config = 9;
}
message XRTChainedExecuteConfig {
@ -145,8 +157,9 @@ message XRTChainedExecuteConfig {
// Optional key to disambiguate between executions. This is only needed if
// multiple host send/recvs may be outstanding concurrently with executions.
string execution_instance_key = 3;
// The replica index this execute is driving.
int32 replica_id = 4;
reserved 4;
// The common configuration for XRT execute operations.
CommonExecutionConfig common_config = 5;
}
// A single chained execute operation. An operation can either be a device data

View File

@ -17,19 +17,56 @@ limitations under the License.
#include "tensorflow/compiler/xrt/xrt_device.h"
#include <map>
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace {
class ResourceMgrArena {
public:
static ResourceMgrArena* Get() {
static ResourceMgrArena* arena = new ResourceMgrArena();
return arena;
}
ResourceMgr* GetResourceMgr(const std::string& platform_name) {
mutex_lock lock(mutex_);
auto it = resource_managers_.find(platform_name);
if (it == resource_managers_.end()) {
it = resource_managers_.emplace(platform_name, new ResourceMgr()).first;
}
return it->second;
}
private:
mutex mutex_;
std::map<std::string, ResourceMgr*> resource_managers_;
};
} // namespace
/*static*/ Status XRTGenericDeviceAccessor::GetResourceManager(
OpKernelContext* ctx, ResourceMgr** rm) {
*rm = ctx->resource_manager();
const XlaDevice::Metadata* metadata;
TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
*rm = ResourceMgrArena::Get()->GetResourceMgr(metadata->platform()->Name());
return Status::OK();
}
/* static */ xla::StatusOr<RefPtr<XRTCompilationCache>>
XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
OpKernelContext* ctx, int64 max_number_of_entries) {
ResourceMgr* rm;
TF_RETURN_IF_ERROR(GetResourceManager(ctx, &rm));
return tensorflow::GetOrCreateCompilationCache(rm, max_number_of_entries);
}
/*static*/ Status XRTGenericDeviceAccessor::InitScopedRef(
OpKernelContext* ctx, int device_ordinal, ScopedRef* scoped_ref) {
const XlaDevice::Metadata* metadata;

View File

@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
@ -31,6 +32,9 @@ class XRTGenericDeviceAccessor {
public:
static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm);
static xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache(
OpKernelContext* ctx, int64 max_number_of_entries);
// We use a ScopedRef pattern here even though it's not strictly necessary,
// just so that templated uses of this and the TPU accessor class will be as
// similar as possible.

View File

@ -70,6 +70,18 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
plugins) override {}
};
// Define an option subclass in order to enable SO_REUSEPORT for the
// server socket.
class ReusePortOption : public ::grpc::ServerBuilderOption {
public:
void UpdateArguments(::grpc::ChannelArguments* args) override {
args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 1);
}
void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
plugins) override {}
};
// static utility function
RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
return new RpcRendezvousMgr(env);
@ -220,8 +232,18 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
GetServerCredentials(server_def_), &bound_port_);
builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
builder.SetOption(
std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
bool reuse_port = false;
const Status status =
ReadBoolFromEnvVar("TF_GRPC_REUSE_PORT", false, &reuse_port);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
}
auto server_build_option =
reuse_port
? std::unique_ptr<::grpc::ServerBuilderOption>(new ReusePortOption)
: std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption);
builder.SetOption(std::move(server_build_option));
// Allow subclasses to specify more args to pass to the gRPC server.
MaybeMutateBuilder(&builder);
master_impl_ = CreateMaster(&master_env_);
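A minimal sketch of opting into port reuse from the hosting process; the variable must be set before the server is built, and the rest of the setup is assumed:

#include <cstdlib>

int main(int argc, char** argv) {
  // Enable SO_REUSEPORT on the gRPC server socket; GrpcServer::Init reads
  // this with ReadBoolFromEnvVar and installs ReusePortOption when true.
  setenv("TF_GRPC_REUSE_PORT", "true", /*overwrite=*/1);
  // ... construct and start the TensorFlow gRPC server as usual ...
  return 0;
}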

View File

@ -12059,7 +12059,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo
//
// value: The cropped area of the image must have an aspect ratio =
// width / height within this range.
// If not specified, defaults to {f:0.75 f:1.33}
func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["aspect_ratio_range"] = value
@ -12070,7 +12070,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted
//
// value: The cropped area of the image must contain a fraction of the
// supplied image within this range.
// If not specified, defaults to {f:0.05 f:1}
func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["area_range"] = value
@ -18975,7 +18975,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2
//
// value: The cropped area of the image must have an aspect ratio =
// width / height within this range.
// If not specified, defaults to {f:0.75 f:1.33}
func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr {
return func(m optionalAttr) {
m["aspect_ratio_range"] = value
@ -18986,7 +18986,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort
//
// value: The cropped area of the image must contain a fraction of the
// supplied image within this range.
// If not specified, defaults to {f:0.05 f:1}
func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr {
return func(m optionalAttr) {
m["area_range"] = value
@ -19390,7 +19390,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr {
// ImageSummaryBadColor sets the optional bad_color attribute to value.
//
// value: Color to use for pixels with non-finite values.
// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255}
func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr {
return func(m optionalAttr) {
m["bad_color"] = value
@ -20461,7 +20461,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr {
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr {
return func(m optionalAttr) {
m["dilations"] = value
@ -21633,7 +21633,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr {
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -22341,7 +22341,7 @@ func Conv2DDataFormat(value string) Conv2DAttr {
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func Conv2DDilations(value []int64) Conv2DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -22537,7 +22537,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy
// QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -22606,7 +22606,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized
// QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -22721,7 +22721,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi
// QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -22780,7 +22780,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D
// QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -22954,7 +22954,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann
// QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value.
//
// value: list of dilation values.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -23331,7 +23331,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr {
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr {
return func(m optionalAttr) {
m["dilations"] = value
@ -25651,7 +25651,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi
type Conv3DBackpropFilterAttr func(optionalAttr)
// Conv3DBackpropFilterDilations sets the optional dilations attribute to value.
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -25714,7 +25714,7 @@ func Conv3DDataFormat(value string) Conv3DAttr {
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DDilations(value []int64) Conv3DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -25965,7 +25965,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -26449,7 +26449,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr {
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -45537,7 +45537,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -47477,7 +47477,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula
type Conv3DBackpropInputAttr func(optionalAttr)
// Conv3DBackpropInputDilations sets the optional dilations attribute to value.
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -47548,7 +47548,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr {
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -48537,7 +48537,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr {
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr {
return func(m optionalAttr) {
m["dilations"] = value

View File

@ -401,7 +401,8 @@ class TFLiteConverterBase(object):
if not self._contains_function_with_implements_attr(saved_model_proto):
self.saved_model_dir = None
else:
self._saved_model_exported_names = []
if not self._saved_model_exported_names:
self._saved_model_exported_names = []
self._saved_model_version = saved_model_proto.saved_model_schema_version
if self._saved_model_version not in [1, 2]:
raise ValueError(
@ -761,6 +762,9 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
if not signature_keys:
signature_keys = saved_model.signatures
if len(signature_keys) != 1:
raise ValueError("Only support a single signature key.")
funcs = []
for key in signature_keys:
if key not in saved_model.signatures:
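The new check above means a conversion request must resolve to exactly one signature key. A minimal usage sketch, with the SavedModel path as an assumption:

    import tensorflow as tf

    # Hypothetical SavedModel directory; passing more than one key in
    # signature_keys now raises the ValueError added above.
    converter = tf.lite.TFLiteConverter.from_saved_model(
        '/tmp/saved_model', signature_keys=['serving_default'])
    tflite_model = converter.convert()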

View File

@ -469,15 +469,10 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
save(root, save_dir, {'add': add_func, 'sub': sub_func})
# Ensure the converter generates.
converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
self.assertLen(converter._funcs, 2)
# Try converting multiple functions.
with self.assertRaises(ValueError) as error:
_ = converter.convert()
self.assertIn('This converter can only convert a single ConcreteFunction',
str(error.exception))
_ = lite.TFLiteConverterV2.from_saved_model(save_dir)
self.assertIn('Only support a single signature key.', str(error.exception))
@test_util.run_v2_only
def testNoConcreteFunctionModel(self):
@ -487,12 +482,9 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
save(root, save_dir)
converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
self.assertLen(converter._funcs, 0)
with self.assertRaises(ValueError) as error:
_ = converter.convert()
self.assertIn('No ConcreteFunction is specified.', str(error.exception))
_ = lite.TFLiteConverterV2.from_saved_model(save_dir)
self.assertIn('Only support a single signature key.', str(error.exception))
@test_util.run_v2_only
def testKerasSequentialModel(self):

View File

@ -334,7 +334,7 @@ void BenchmarkPerformanceOptions::Run() {
// profiling listener etc. in each Run() invoke because such listeners may be
// reset and become invalid in the next Run(). As a result, we record the
// number of externally-added listeners here so that they are not cleared later.
const int num_external_listners = single_option_run_->NumListeners();
const int num_external_listeners = single_option_run_->NumListeners();
// Now perform all runs, each with different performance-affecting parameters.
for (const auto& run_params : all_run_params_) {
@ -349,7 +349,7 @@ void BenchmarkPerformanceOptions::Run() {
// Clear internally created listeners before each run but keep externally
// created ones.
single_option_run_->RemoveListeners(num_external_listners);
single_option_run_->RemoveListeners(num_external_listeners);
all_run_stats_->MarkBenchmarkStart(*single_option_run_params_);
single_option_run_->Run();

View File

@ -119,7 +119,7 @@ std::vector<Flag> ExternalDelegateProvider::CreateFlags(
"The library path for the underlying external."),
CreateFlag<std::string>(
"external_delegate_options", params,
"Comma-seperated options to be passed to the external delegate")};
"Comma-separated options to be passed to the external delegate")};
return flags;
}
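On the Python side, external delegates take their options as a dict rather than the comma-separated string accepted by the benchmark flag above. A minimal sketch, with the library path and option names as assumptions:

    import tensorflow as tf

    # Hypothetical delegate library and options.
    delegate = tf.lite.experimental.load_delegate(
        '/usr/lib/libsample_delegate.so', options={'opt1': 'val1'})
    interpreter = tf.lite.Interpreter(
        model_path='/tmp/model.tflite', experimental_delegates=[delegate])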

View File

@ -33,7 +33,7 @@ class DelegateProviders {
DelegateProviders();
// Initialize delegate-related parameters from commandline arguments and
// returns true if sucessful.
// returns true if successful.
bool InitFromCmdlineArgs(int* argc, const char** argv);
// Get all parameters from all registered delegate providers.

View File

@ -66,7 +66,7 @@ TEST(EvaluationDelegateProviderTest, GetAllParamsWithTfliteInferenceParams) {
TfliteInferenceParams params;
params.set_delegate(TfliteInferenceParams::NONE);
params.set_num_threads(4);
// The same-meaning parameter in TfliteInferenceParams takes precendence.
// The same-meaning parameter in TfliteInferenceParams takes precedence.
tools::ToolParams tool_params = providers.GetAllParams(params);
EXPECT_EQ(4, tool_params.Get<int>("num_threads"));
EXPECT_EQ(1, argc);

View File

@ -246,7 +246,7 @@ BENCHMARK_LIB_SRCS := $(filter-out \
$(BENCHMARK_ALL_SRCS))
# These target-specific makefiles should modify or replace options like
# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic
# CXXFLAGS or LIBS to work for a specific targeted architecture. All logic
# based on platforms or architectures should happen within these files, to
# keep this main makefile focused on the sources and dependencies.
include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc)

View File

@ -86,7 +86,7 @@ struct OperatorProperty {
bool restrict_same_input_output_scale = false;
// Use same min of min and max of max for each group.
// Incompatable with restrict_same_input_output_scale and restricted_value.
// Incompatible with restrict_same_input_output_scale and restricted_value.
// TODO(jianlijianli): make it compatible with other restrictions when there
// is a use case.
std::vector<std::vector<int>> restrict_scale = {};

View File

@ -74,6 +74,6 @@ def modify_model_interface(input_file, output_file, input_type, output_type):
# Throw an exception if the return status is an error.
if status != 0:
raise RuntimeError(
'Error occured when trying to modify the model input type from float '
'Error occurred when trying to modify the model input type from float '
'to {input_type} and output type from float to {output_type}.'.format(
input_type=input_type, output_type=output_type))
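For context, a minimal usage sketch of the wrapper this hunk touches; the import path and argument values below are assumptions, and a non-zero status from the underlying tool surfaces as the RuntimeError above:

    # Import path and file paths are assumptions for illustration.
    from tensorflow.lite.tools.optimize.python import modify_model_interface_lib

    modify_model_interface_lib.modify_model_interface(
        input_file='/tmp/quantized.tflite',
        output_file='/tmp/quantized_int8.tflite',
        input_type='INT8',
        output_type='INT8')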

View File

@ -133,7 +133,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
if (op_sig.input_types.size() == 2) {
return 6;
}
// `keep_num_dims` is supported at verison 5.
// `keep_num_dims` is supported at version 5.
if (op_sig.options.fully_connected.keep_num_dims) {
return 5;
}
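The selection above walks from the newest feature to the oldest and returns the first version whose feature the op actually uses. A Python mirror of just the two branches shown, purely illustrative (the real table is the C++ above and covers more cases):

    def fully_connected_version(num_input_types, keep_num_dims):
        # Newest feature wins: two input types require version 6.
        if num_input_types == 2:
            return 6
        # `keep_num_dims` is supported at version 5.
        if keep_num_dims:
            return 5
        return 1  # illustrative fallback; the real code checks further cases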

View File

@ -1688,8 +1688,6 @@ class AssertShapesTest(test.TestCase):
rank_three_shapes, array_ops.constant(1), correct_rank=3, actual_rank=0)
def test_raises_dynamic_incorrect_rank(self):
self.skipTest("b/134600611")
x_value = 5
rank_two_shapes = [(1, 1), (1, 3), ("a", "b"), (None, None)]
with ops.Graph().as_default():

View File

@ -1845,7 +1845,12 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
'Specified by tensor %s dimension %d' %
(tensor_name(specified_by_y), specified_at_dim))
actual_size = sizes.actual_sizes[tensor_dim]
# This is extremely subtle. If actual_sizes is dynamic, we must
# make sure a control dependency is inserted here so that this slice
# cannot execute until the rank has been asserted to be large enough
# that the slice will not fail.
with ops.control_dependencies(rank_assertions):
actual_size = sizes.actual_sizes[tensor_dim]
if _has_known_value(actual_size) and _has_known_value(specified_size):
if int(actual_size) != int(specified_size):
raise ValueError(
@ -1871,12 +1876,17 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
size_assertions.append(
control_flow_ops.Assert(condition, data_, summarize=summarize))
else:
size = sizes.actual_sizes[tensor_dim]
# Not sure if actual_sizes is a constant, but for safety, guard
# on rank. See the explanation above about why actual_sizes needs this guard.
with ops.control_dependencies(rank_assertions):
size = sizes.actual_sizes[tensor_dim]
size_specifications[size_symbol] = (size, sizes.x, tensor_dim)
with ops.control_dependencies(rank_assertions):
shapes_assertion = control_flow_ops.group(size_assertions)
return shapes_assertion
# Ensure both assertions actually occur.
with ops.control_dependencies(rank_assertions):
shapes_assertion = control_flow_ops.group(size_assertions)
return shapes_assertion
# pylint: disable=line-too-long
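The subtlety the new comments describe is that ops.control_dependencies only attaches to ops created inside its context, so the slice has to be built there. A minimal sketch of the pattern in isolation (TF1-style graph mode, names illustrative):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    x = tf.placeholder(tf.float32)
    rank_ok = tf.Assert(tf.greater_equal(tf.rank(x), 2), [tf.rank(x)])
    shape = tf.shape(x)
    with tf.control_dependencies([rank_ok]):
        # Created inside the context, so this slice cannot run (and fail
        # on a rank-1 input) before the rank assertion has executed.
        actual_size = shape[1]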

View File

@ -17,7 +17,6 @@ RUN apt-get update && apt-get install -y \
flex \
g++ \
make \
patchelf \
rpm2cpio \
unar \
wget \